/
basics.py
1307 lines (1149 loc) · 51 KB
/
basics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import itertools
import torch
from typing import List, Callable, Union, Optional
from ..communication import MPI
from .. import arithmetics
from .. import exponential
from ..dndarray import DNDarray
from .. import factories
from .. import manipulations
from .. import types
# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = ["dot", "matmul", "norm", "outer", "projection", "transpose", "tril", "triu"]
def _dot_fill_out(out: DNDarray, result: DNDarray) -> DNDarray:
    """
    Copy the local data and the metadata of ``result`` into ``out`` and return ``out``.

    NOTE(review): only the attributes below are transferred; ``out`` is assumed to already have
    the global shape of ``result`` -- confirm against the DNDarray internals.
    """
    out._DNDarray__array = result._DNDarray__array
    out._DNDarray__dtype = result.dtype
    out._DNDarray__split = result.split
    out._DNDarray__device = result.device
    out._DNDarray__comm = result.comm
    return out


def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]:
    """
    Returns the dot product of two ``DNDarrays``.
    Specifically,

        1. If both a and b are 1-D arrays, it is inner product of vectors.
        2. If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or ``a@b`` is preferred.
        3. If either a or b is 0-D (scalar), it is equivalent to multiply and using ``multiply(a, b)`` or ``a*b`` is preferred.

    Parameters
    ----------
    a : DNDarray
        First operand. A Python ``int`` or ``float`` is also accepted and treated as a scalar.
    b : DNDarray
        Second operand. A Python ``int`` or ``float`` is also accepted and treated as a scalar.
    out : DNDarray, optional
        Place to put the result. It is filled (and returned) in the scalar and 2-D cases whenever
        the result is a ``DNDarray``. In the 1-D case the result is a Python scalar, so ``out`` is
        left untouched.

    Returns
    -------
    DNDarray or float
        The product; a Python scalar for the 1-D inner product.

    Raises
    ------
    NotImplementedError
        If the operands do not fall into one of the three cases above.
    """
    if isinstance(a, (float, int)) or isinstance(b, (float, int)) or a.ndim == 0 or b.ndim == 0:
        # 3. If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred.
        ret = a * b
        # rebinding ``out`` would not reach the caller, so its contents are replaced instead;
        # a product of two Python scalars cannot be stored in a DNDarray and is returned as is
        if out is not None and isinstance(ret, DNDarray):
            return _dot_fill_out(out, ret)
        return ret

    if a.ndim == 1 and b.ndim == 1:
        # 1. If both a and b are 1-D arrays, it is inner product of vectors.
        if a.split is None and b.split is None:
            sl = slice(None)
        else:  # at least one of them is split
            # the local chunk of whichever operand is split selects the matching slice of both
            sl = a.comm.chunk(a.shape, a.split if a.split is not None else b.split)[2]
        ret = torch.dot(a[sl]._DNDarray__array, b[sl]._DNDarray__array)
        if a.is_distributed() or b.is_distributed():
            # each process only holds a partial sum
            a.comm.Allreduce(MPI.IN_PLACE, ret, MPI.SUM)
        # the result is a Python scalar; ``out`` cannot hold it and is left untouched
        return ret.item()

    if a.ndim == 2 and b.ndim == 2:
        # 2. If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.
        ret = matmul(a, b)
        if out is not None:
            return _dot_fill_out(out, ret)
        return ret

    raise NotImplementedError("ht.dot not implemented for N-D dot M-D arrays")
def matmul(a: DNDarray, b: DNDarray, allow_resplit: Optional[bool] = False) -> DNDarray:
    """
    Matrix multiplication of two ``DNDarrays``: ``a@b=c`` or ``A@B=c``.
    Returns a tensor with the result of ``a@b``. The split dimension of the returned array is
    typically the split dimension of a. However, if ``a.split=None`` then the ``c.split`` will be
    set as the split dimension of ``b``. If both are ``None`` then ``c.split`` is also ``None``.

    Parameters
    ----------
    a : DNDarray
        2 dimensional: :math:`L \\times P`
    b : DNDarray
        2 dimensional: :math:`P \\times Q`
    allow_resplit : bool, optional
        Flag for if to resplit the DNDarray ``a`` in the case that both ``a`` and ``b`` are not split.
        - Default: if both are not split then both will remain not split.
        - True: if both are not split then ``a`` will be split in-place along axis 0, i.e. the split axis of ``a`` will
          become 0 and the ``DNDarray`` will be distributed in the standard fashion.
        - The default case should be the most efficient case for large matrices.

    Raises
    ------
    ValueError
        If the last dimension of ``a`` does not have the same size as the first dimension of ``b``.
    NotImplementedError
        If the split axes of ``a`` and ``b`` form a combination that is not handled (splits > 1).

    Notes
    -----
    - If ``a`` is a split vector then the returned vector will be of shape (:math:`1xQ`) and will be split in the 1st dimension
    - If ``b`` is a vector and either ``a`` or ``b`` is split, then the returned vector will be of shape (:math:`Lx1`) and will be split in the 0th dimension
    - If both ``a`` and ``b`` are vectors and at least one of them is split, both are resplit in place along axis 0.

    References
    ----------
    [1] R. Gu, et al., "Improving Execution Concurrency of Large-scale Matrix Multiplication on
    Distributed Data-parallel Platforms," IEEE Transactions on Parallel and Distributed Systems,
    vol 28, no. 9. 2017. \n
    [2] S. Ryu and D. Kim, "Parallel Huge Matrix Multiplication on a Cluster with GPGPU
    Accelerators," 2018 IEEE International Parallel and Distributed Processing Symposium
    Workshops (IPDPSW), Vancouver, BC, 2018, pp. 877-882.

    Example
    -------
    (2 processes, n=5, m=4, j=4, k=7)

    >>> a = ht.ones((n, m), split=1)
    >>> a[0] = ht.arange(1, m + 1)
    >>> a[:, -1] = ht.arange(1, n + 1)
    >>> a._DNDarray__array
    [0/1] tensor([[1., 2.],
                  [1., 1.],
                  [1., 1.],
                  [1., 1.],
                  [1., 1.]])
    [1/1] tensor([[3., 1.],
                  [1., 2.],
                  [1., 3.],
                  [1., 4.],
                  [1., 5.]])
    >>> b = ht.ones((j, k), split=0)
    >>> b[0] = ht.arange(1, k + 1)
    >>> b[:, 0] = ht.arange(1, j + 1)
    >>> b._DNDarray__array
    [0/1] tensor([[1., 2., 3., 4., 5., 6., 7.],
                  [2., 1., 1., 1., 1., 1., 1.]])
    [1/1] tensor([[3., 1., 1., 1., 1., 1., 1.],
                  [4., 1., 1., 1., 1., 1., 1.]])
    >>> linalg.matmul(a, b)._DNDarray__array
    [0/1] tensor([[18., 8., 9., 10.],
                  [14., 6., 7., 8.],
                  [18., 7., 8., 9.],
                  [22., 8., 9., 10.],
                  [26., 9., 10., 11.]])
    [1/1] tensor([[11., 12., 13.],
                  [ 9., 10., 11.],
                  [10., 11., 12.],
                  [11., 12., 13.],
                  [12., 13., 14.]])
    """
    # the inner dimensions must agree; b.gshape[0] (not b.gshape[-2]) is reported so that the
    # error message can also be built when b is a 1-D vector
    if a.gshape[-1] != b.gshape[0]:
        raise ValueError(
            "The last dimension of a ({}) must be the same size as the first dimension of b ({})".format(
                a.gshape[-1], b.gshape[0]
            )
        )

    # determine if a larger type is needed for c
    c_type = types.promote_types(a.dtype, b.dtype)
    if a.dtype != c_type:
        a = c_type(a, device=a.device)
    if b.dtype != c_type:
        b = c_type(b, device=b.device)

    if a.split is None and b.split is None:  # matmul from torch
        if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
            # if either of A or B is a vector (or resplitting is not allowed): neither operand is
            # split, so the product is computed locally with torch and returned unsplit
            return factories.array(
                torch.matmul(a._DNDarray__array, b._DNDarray__array), device=a.device
            )
        else:
            # split a along its rows in place, compute the local row block of c and
            # assemble the full (unsplit) result with a sum-reduction
            a.resplit_(0)
            slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
            hold = a._DNDarray__array @ b._DNDarray__array

            c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
            c._DNDarray__array[slice_0.start : slice_0.stop, :] += hold
            c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
            return c

    # if they are vectors they need to be expanded to be the proper dimensions
    vector_flag = False  # flag to run squeeze at the end of the function
    if len(a.gshape) < 2 and len(b.gshape) < 2:
        # make both split 0, do a local mm then a sum
        a.resplit_(0)
        b.resplit_(0)
        res = a._DNDarray__array @ b._DNDarray__array
        a.comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM)
        return factories.array(res, split=None, device=a.device)
    elif len(a.gshape) < 2:
        # a becomes a row vector (1 x P)
        a = manipulations.expand_dims(a, axis=0)
        vector_flag = True
    elif len(b.gshape) < 2:
        # b becomes a column vector (P x 1)
        b = manipulations.expand_dims(b, axis=1)
        vector_flag = True

    split_0_flag = False
    split_1_flag = False
    split_01_flag = False
    split_10_flag = False

    # cases in which exactly one operand is split: local products, at most one reduction
    if (
        (a.split == 0 and b.split is None) or (a.split is None and b.split == 1)
    ) and not vector_flag:
        split = a.split if a.split is not None else b.split
        split = split if not vector_flag else 0
        c = factories.zeros((a.gshape[-2], b.gshape[1]), split=split, dtype=c_type, device=a.device)
        c._DNDarray__array += a._DNDarray__array @ b._DNDarray__array

        return c if not vector_flag else c.squeeze()

    elif a.split == 1 and b.split is None:
        c = torch.zeros(
            (a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=a.device.torch_device
        )

        # multiply the local columns of a with the matching rows of b, then sum the partial products
        a_idx = a.comm.chunk(a.shape, a.split)[2]
        c += (
            a._DNDarray__array
            @ b._DNDarray__array[a_idx[1].start : a_idx[1].start + a.lshape[-1], :]
        )
        a.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
        c = c if not vector_flag else c.squeeze()
        c = factories.array(c, split=a.split if b.gshape[1] > 1 else 0, device=a.device)
        return c

    elif a.split is None and b.split == 0:
        c = torch.zeros(
            (a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=a.device.torch_device
        )
        # multiply the columns of a matching the local rows of b, then sum the partial products
        b_idx = b.comm.chunk(b.shape, b.split)[2]
        c += (
            a._DNDarray__array[:, b_idx[0].start : b_idx[0].start + b.lshape[0]]
            @ b._DNDarray__array
        )
        b.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
        c = c if not vector_flag else c.squeeze()
        c = factories.array(c, split=b.split if a.gshape[-2] > 1 else 0, device=a.device)
        return c

    elif (
        a.split == 0 and b.split is None
    ):  # this case and the one below will only be reached if one of them is a vector
        c = torch.zeros(
            (a.gshape[-2], b.lshape[1]), dtype=c_type.torch_type(), device=a.device.torch_device
        )
        a_idx = a.comm.chunk(a.shape, a.split)[2]
        c[a_idx[0]] += a._DNDarray__array @ b._DNDarray__array
        a.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
        c = c if not vector_flag else c.squeeze()
        split = a.split if b.gshape[1] > 1 else 0
        split = split if not vector_flag else 0
        c = factories.array(c, split=split, device=a.device)
        return c

    elif a.split is None and b.split == 1:
        c = torch.zeros(
            (a.gshape[-2], b.lshape[1]), dtype=c_type.torch_type(), device=a.device.torch_device
        )
        # each process computes complete local columns of c, hence ``is_split`` below
        c += a._DNDarray__array @ b._DNDarray__array
        c = c if not vector_flag else c.squeeze()
        split = b.split if a.gshape[1] > 1 else 0
        split = split if not vector_flag else 0
        c = factories.array(c, is_split=split, device=a.device)
        return c

    elif a.split == 0 and b.split == 0:
        split_0_flag = True
    elif a.split == 1 and b.split == 1:
        split_1_flag = True
    elif a.split == 0 and b.split == 1:
        split_01_flag = True
    elif a.split == 1 and b.split == 0:
        split_10_flag = True
    else:
        raise NotImplementedError("splits > 1 not implemented")

    # block sizes don't need to be the same, they just need the same inner dimension (kB)
    kB = 0
    rem_a, rem_b = [0] * 2
    if a.split == len(a.gshape) - 1 and b.split == len(a.gshape) - 2:
        # if the split direction is the last dim in a and the first dim in b
        # the max inner dim (kB) is the min value from the result of the integer division
        # of the last dim of a/world size and the first dim of b/world size
        kB = min([a.gshape[-1] // a.comm.size, b.gshape[0] // b.comm.size])
    elif a.split == len(a.gshape) - 2 and b.split == len(a.gshape) - 1:
        kB = a.gshape[-1]
    elif a.split == len(a.gshape) - 1:
        kB = a.gshape[-1] // a.comm.size
    elif b.split == len(a.gshape) - 2:
        kB = b.gshape[0] // b.comm.size
        kB = kB if kB < a.gshape[-1] else a.gshape[-1]

    # flag whether the local inner dimension of a (columns) / b (rows) leaves a remainder w.r.t. kB
    if a.lshape[-1] % kB != 0 or (kB == 1 and a.lshape[-1] != 1):
        rem_a = 1
    if b.lshape[0] % kB != 0 or (kB == 1 and b.lshape[-2] != 1):
        rem_b = 1

    # get the lshape map to determine what needs to be sent where as well as M and N
    # lshape map dims -> {node, a=0, b=1, lshape}
    lshape_map = torch.zeros(
        (a.comm.size, 2, len(a.gshape)), dtype=int, device=a.device.torch_device
    )
    lshape_map[a.comm.rank, 0, :] = torch.tensor(a.lshape, device=a.device.torch_device)
    lshape_map[b.comm.rank, 1, :] = torch.tensor(b.lshape, device=a.device.torch_device)
    a.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)

    # find mB (first blocking dim for a) and nB (2nd blocking dim for b)
    mB = lshape_map[:, 0, -2].min().item()
    nB = lshape_map[:, 1, -1].min().item()

    # check for remaining dims in the outside dimensions
    rem_a_out, rem_b_out = 0, 0
    if a.lshape[-2] % mB != 0 or (kB == 1 and a.lshape[-2] != 1):
        rem_a_out = 1
    if b.lshape[-1] % nB != 0 or (kB == 1 and b.lshape[-1] != 1):
        rem_b_out = 1

    # get the flags from all processes
    # rem_map dims guide -> {process number, a/b (0/1), True/False (1/0)
    # if there is a remainder in this dimension
    rem_map = torch.zeros((a.comm.size, 2, 2))
    rem_map[a.comm.rank, 0, :] = torch.tensor((rem_a_out, rem_a), device=a._DNDarray__array.device)
    rem_map[a.comm.rank, 1, :] = torch.tensor((rem_b, rem_b_out), device=a._DNDarray__array.device)
    rem_map_comm = a.comm.Iallreduce(MPI.IN_PLACE, rem_map, MPI.SUM)

    # index_map dims guide -> {process number, a=0/b=1, relevant 1st index, 2nd index}
    index_map = torch.zeros((a.comm.size, 2, 2, 2), dtype=int, device=b._DNDarray__array.device)
    a_idx = a.comm.chunk(a.shape, a.split)[2]
    index_map[a.comm.rank, 0, 0] = torch.tensor(
        (a_idx[0].start, a_idx[0].stop), device=b._DNDarray__array.device
    )
    index_map[a.comm.rank, 0, 1] = torch.tensor(
        (a_idx[1].start, a_idx[1].stop), device=b._DNDarray__array.device
    )
    b_idx = b.comm.chunk(b.shape, b.split)[2]
    index_map[b.comm.rank, 1, 0] = torch.tensor(
        (b_idx[0].start, b_idx[0].stop), device=b._DNDarray__array.device
    )
    index_map[b.comm.rank, 1, 1] = torch.tensor(
        (b_idx[1].start, b_idx[1].stop), device=b._DNDarray__array.device
    )
    index_map_comm = a.comm.Iallreduce(MPI.IN_PLACE, index_map, MPI.SUM)

    # for the communication scheme, the output array needs to be created
    c_shape = (a.gshape[-2], b.gshape[1])
    c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=a.device)

    # get the index map for c
    c_index_map = factories.zeros((c.comm.size, 2, 2), device=a.device)
    c_idx = c.comm.chunk(c.shape, c.split)[2]
    c_index_map[c.comm.rank, 0, :] = (c_idx[0].start, c_idx[0].stop)
    c_index_map[c.comm.rank, 1, :] = (c_idx[1].start, c_idx[1].stop)
    c_wait = c.comm.Iallreduce(MPI.IN_PLACE, c_index_map, MPI.SUM)

    if a.split == 0:
        a_block_map = torch.zeros(
            (a.comm.size, a.shape[-2] // mB // a.comm.size, a.shape[-1] // kB, 2),
            dtype=torch.int,
            device=a.device.torch_device,
        )
    elif a.split == 1:
        a_block_map = torch.zeros(
            (a.comm.size, a.shape[-2] // mB, a.shape[-1] // kB // a.comm.size, 2),
            dtype=torch.int,
            device=a.device.torch_device,
        )
    # units-> [process, dim0 block number, dim1 block number, start coord] **indices are local

    # below is to handle the edge case where there is only one element in one dimension of a
    a_d0_1s_flag, a_d1_1s_flag = False, False
    if any(lshape_map[:, 0, :][:, 0] == 1):
        a_d0_1s_flag = True
    if any(lshape_map[:, 0, :][:, 1] == 1):
        a_d1_1s_flag = True

    index_map_comm.wait()
    for pr in range(a.comm.size):
        start0 = index_map[pr, 0, 0, 0].item()
        stop0 = index_map[pr, 0, 0, 1].item()
        start1 = index_map[pr, 0, 1, 0].item()
        stop1 = index_map[pr, 0, 1, 1].item()

        for dim0 in range(
            (stop0 - start0) // mB // a.comm.size if a_d0_1s_flag else (stop0 - start0) // mB
        ):
            # loop over the number of blocks in the 0th dimension
            for dim1 in range(
                (stop1 - start1) // kB // a.comm.size if a_d1_1s_flag else (stop1 - start1) // kB
            ):
                # loop over the number of blocks in the 1st dimension
                a_block_map[pr, dim0, dim1] = torch.tensor(
                    (dim0 * mB, dim1 * kB), dtype=torch.int, device=a._DNDarray__array.device
                )
    rem_map_comm.wait()
    if b.split == 0:
        # the blocks are shifted in the 2nd dimension of A for as many remainders
        # there are between the blocks in the first dim of B
        cnt = 0
        for r in rem_map[:, 1, 0]:
            if r.item():
                cnt += 1
                a_block_map[:, :, cnt:, 1] += 1

    if b.split == 0:
        b_block_map = torch.zeros(
            (b.comm.size, b.shape[-2] // kB // b.comm.size, b.shape[-1] // nB, 2),
            dtype=torch.int,
            device=b.device.torch_device,
        )
    elif b.split == 1:
        b_block_map = torch.zeros(
            (b.comm.size, b.shape[-2] // kB, b.shape[-1] // nB // b.comm.size, 2),
            dtype=torch.int,
            device=b.device.torch_device,
        )
    # units-> [process, dim0 block number, dim1 block number, start coord] **indices are local

    # below is to handle the edge case where there is only one element in one dimension of b
    b_d0_1s_flag, b_d1_1s_flag = False, False
    if any(lshape_map[:, 1, :][:, 0] == 1):
        b_d0_1s_flag = True
    if any(lshape_map[:, 1, :][:, 1] == 1):
        b_d1_1s_flag = True

    for pr in range(b.comm.size):
        start0 = index_map[pr, 1, 0, 0].item()
        stop0 = index_map[pr, 1, 0, 1].item()
        start1 = index_map[pr, 1, 1, 0].item()
        stop1 = index_map[pr, 1, 1, 1].item()

        # loop over the number of blocks in the 0th dimension
        for dim0 in range(
            (stop0 - start0) // kB // b.comm.size if b_d0_1s_flag else (stop0 - start0) // kB
        ):
            # loop over the number of blocks in the 1st dimension
            for dim1 in range(
                (stop1 - start1) // nB // b.comm.size if b_d1_1s_flag else (stop1 - start1) // nB
            ):
                b_block_map[pr, dim0, dim1] = torch.tensor(
                    (dim0 * kB, dim1 * nB), dtype=torch.int, device=b._DNDarray__array.device
                )

    if a.split == 1:
        cnt = 0
        # this loop will push the blocks in B to adjust for the remainders in A
        for r in rem_map[:, 0, 1]:
            if r.item():
                cnt += 1
                b_block_map[:, cnt:, :, 0] += 1

    # work loop: loop over all processes (also will incorporate the remainder calculations)
    c_wait.wait()

    if split_0_flag:
        # need to send b here and not a
        # the rows on 'a' are complete, and the columns of 'b' are split
        # locations of the remainders in b
        b_rem_locs0 = (rem_map[:, 1, 0] == 1).nonzero()
        a_rem_locs0 = (rem_map[:, 0, 0] == 1).nonzero()
        # remainders for a in the
        a_node_rem_s0 = a._DNDarray__array[:mB, kB : (kB + 1) * b_rem_locs0.numel() : kB + 1]
        b_rem = torch.empty(
            b_rem_locs0.numel(),
            b.lshape[-1],
            dtype=a.dtype.torch_type(),
            device=b.device.torch_device,
        )

        # this if/else is for the handling of the remainder row of A on this process
        if a.comm.rank in a_rem_locs0:
            # if A is split in dim0 and the rank has a remainder in this direction
            r = a._DNDarray__array[-1]
            r_loc = index_map[a.comm.rank, 0, 0, 1] - index_map[a.comm.rank, 0, 0, 0] - 1
        else:
            r = None
            r_loc = None

        req = {}
        b_lp_data = {}
        for pr in range(b.comm.size):
            # ibcast data on node first
            if b.comm.rank == pr:
                b_lp_data[pr] = b._DNDarray__array.clone()
            else:
                b_lp_data[pr] = torch.zeros(
                    (lshape_map[pr, 1, 0].item(), lshape_map[pr, 1, 1].item()),
                    dtype=b.dtype.torch_type(),
                    device=b.device.torch_device,
                )

            # sending b to all nodes for a to operate with
            req[pr] = b.comm.Ibcast(b_lp_data[pr], root=pr)

            # receive the data from the last loop and do the calculation with that
            if pr != 0:
                req[pr - 1].wait()
                # after receiving the last loop's bcast
                __mm_c_block_setter(
                    b_proc=pr - 1,
                    a_proc=a.comm.rank,
                    a_data=a._DNDarray__array,
                    b_data=b_lp_data[pr - 1],
                    b_block_map=b_block_map,
                    a_block_map=a_block_map,
                    b_split=b.split,
                    a_split=a.split,
                    mB=mB,
                    kB=kB,
                    nB=nB,
                    c=c._DNDarray__array,
                )

                # check if there is a remainder on b in the previous node
                # this loop is intended to get the remainders of b since it is the one being passed
                if pr - 1 in b_rem_locs0:
                    # takes care of the remainders in b as well as dim0 of a
                    b_rem[pr - 1] = b_lp_data[pr - 1][-1]

                # this loop is to take care of the remainders in dim0 of A
                if a_rem_locs0.nelement() != 0:
                    if r_loc is not None:
                        st = index_map[pr - 1, 1, 0, 0].item()
                        sp = index_map[pr - 1, 1, 0, 1].item()
                        c._DNDarray__array[r_loc.item(), :] += r[st:sp] @ b_lp_data[pr - 1]
                # the previous buffer is no longer needed
                del b_lp_data[pr - 1]

            # need to wait if its the last loop, also need to collect the remainders
            if pr == b.comm.size - 1:
                req[pr].wait()
                __mm_c_block_setter(
                    b_proc=pr,
                    a_proc=a.comm.rank,
                    a_data=a._DNDarray__array,
                    b_data=b_lp_data[pr],
                    b_block_map=b_block_map,
                    a_block_map=a_block_map,
                    b_split=b.split,
                    a_split=a.split,
                    mB=mB,
                    kB=kB,
                    nB=nB,
                    c=c._DNDarray__array,
                )
                # check if there is a remainder on b on the last node (there shouldnt be)
                if pr in b_rem_locs0:
                    # this is to save the data from B required by the remainders from dim1 of A
                    b_rem[pr] = b_lp_data[pr][-1]

                # this loop is to take care of the remainders in the 0th dimension of A
                if a_rem_locs0.nelement() != 0:
                    if r_loc is not None:
                        st = index_map[pr, 1, 0, 0].item()
                        sp = index_map[pr, 1, 0, 1].item()

                        if split_01_flag:
                            st1 = index_map[pr, 1, 1, 0].item()
                            sp1 = index_map[pr, 1, 1, 1].item()
                            c._DNDarray__array[r_loc.item(), st1:sp1] += r[st:sp] @ b_lp_data[pr]
                        else:
                            c._DNDarray__array[r_loc.item(), :] += r[st:sp] @ b_lp_data[pr]

                # set the final blocks on the last loop, then adjust for the
                # the remainders which were collected in b_rem
                if b_rem_locs0.numel():
                    c._DNDarray__array[: a_node_rem_s0.shape[0]] += a_node_rem_s0 @ b_rem
                del b_lp_data[pr]

        if vector_flag:
            # drop the singleton dimension introduced by expand_dims
            c_loc = c._DNDarray__array.squeeze()
            c = factories.array(c_loc, is_split=0, device=a.device)

        return c

    elif split_1_flag:
        # for this case, a is sent to b
        # this is because 'b' has complete columns and the rows of 'a' are split
        # locations of the remainders in b
        b_rem_locs1 = (rem_map[:, 1, 1] == 1).nonzero()
        a_rem_locs1 = (rem_map[:, 0, 1] == 1).nonzero()
        b_node_rem_s1 = b._DNDarray__array[
            kB : (kB + 1) * a_rem_locs1.numel() : kB + 1, :nB
        ]  # remainders for a in the
        a_rem = torch.empty(
            a.lshape[-2],
            a_rem_locs1.numel(),
            dtype=b.dtype.torch_type(),
            device=a.device.torch_device,
        )

        # this if/else is for the handling of the remainder column of B on this process
        if b.comm.rank in b_rem_locs1:
            # if b is split in dim1 and the rank has a remainder in this direction
            r = b._DNDarray__array[:, -1]
            r_loc = index_map[a.comm.rank, 1, 1, 1] - index_map[a.comm.rank, 1, 1, 0] - 1
        else:
            r = None
            r_loc = None

        req = {}
        a_lp_data = {}
        for pr in range(a.comm.size):
            # ibcast data on node first
            if a.comm.rank == pr:
                a_lp_data[pr] = a._DNDarray__array.clone()
            else:
                a_lp_data[pr] = torch.zeros(
                    (lshape_map[pr, 0, 0].item(), lshape_map[pr, 0, 1].item()),
                    dtype=a.dtype.torch_type(),
                    device=a.device.torch_device,
                )

            # sending a to all nodes for b to operate with
            req[pr] = a.comm.Ibcast(a_lp_data[pr], root=pr)

            # receive the data from the last loop and do the calculation with that
            if pr != 0:
                # after receiving the last loop's bcast
                req[pr - 1].wait()
                __mm_c_block_setter(
                    a_proc=pr - 1,
                    b_proc=b.comm.rank,
                    a_data=a_lp_data[pr - 1],
                    b_data=b._DNDarray__array,
                    b_block_map=b_block_map,
                    a_block_map=a_block_map,
                    b_split=b.split,
                    a_split=a.split,
                    mB=mB,
                    kB=kB,
                    nB=nB,
                    c=c._DNDarray__array,
                )

                # check if there is a remainder on a in the previous node
                # this loop is intended to get the remainders of a since it is the one being passed
                if pr - 1 in a_rem_locs1:
                    # takes care of the remainders in a as well as dim1 of b
                    a_rem[:, pr - 1] = a_lp_data[pr - 1][:, -1]

                # this loop is to take care of the remainders in dim1 of B
                if b_rem_locs1.nelement() != 0:
                    if r_loc is not None:
                        st = index_map[pr - 1, 0, 1, 0].item()
                        sp = index_map[pr - 1, 0, 1, 1].item()
                        c._DNDarray__array[:, r_loc.item()] += (
                            a_lp_data[pr - 1] @ r[st:sp, None]
                        ).flatten()
                # the previous buffer is no longer needed
                del a_lp_data[pr - 1]

            # need to wait if its the last loop, also need to collect the remainders
            if pr == b.comm.size - 1:
                req[pr].wait()
                __mm_c_block_setter(
                    a_proc=pr,
                    b_proc=a.comm.rank,
                    a_data=a_lp_data[pr],
                    b_data=b._DNDarray__array,
                    b_block_map=b_block_map,
                    a_block_map=a_block_map,
                    b_split=b.split,
                    a_split=a.split,
                    mB=mB,
                    kB=kB,
                    nB=nB,
                    c=c._DNDarray__array,
                )
                # check if there is a remainder on a on the last node (there shouldnt be)
                if pr in a_rem_locs1:
                    # this is to save the data from A required by the remainders from dim0 of B
                    a_rem[:, pr] = a_lp_data[pr][:, -1]

                # this loop is to take care of the remainders in the 1st dimension of B
                if b_rem_locs1.nelement() != 0:
                    if r_loc is not None:
                        st = index_map[pr, 0, 1, 0].item()
                        sp = index_map[pr, 0, 1, 1].item()
                        c._DNDarray__array[:, r_loc.item()] += (
                            a_lp_data[pr] @ r[st:sp, None]
                        ).flatten()

                # set the final blocks on the last loop, then adjust for the remainders which were collected in a_rem
                if a_rem_locs1.numel():
                    c._DNDarray__array[:, : b_node_rem_s1.shape[1]] += a_rem @ b_node_rem_s1
                del a_lp_data[pr]

        c = (
            c
            if not vector_flag
            else factories.array(c._DNDarray__array.squeeze(), is_split=0, device=a.device)
        )
        return c

    elif split_01_flag:
        # for this case there are no remainders which need to be taken care of
        req = {}
        b_lp_data = {}
        for pr in range(a.comm.size):
            # ibcast data on node first
            if b.comm.rank == pr:
                b_lp_data[pr] = b._DNDarray__array.clone()
            else:
                b_lp_data[pr] = torch.empty(
                    (lshape_map[pr, 1, 0].item(), lshape_map[pr, 1, 1].item()),
                    dtype=b.dtype.torch_type(),
                    device=b.device.torch_device,
                )

            # sending b to all nodes for a to operate with
            req[pr] = b.comm.Ibcast(b_lp_data[pr], root=pr)

            # receive the data from the last loop and do the calculation with that
            if pr != 0:
                req[pr - 1].wait()
                # after receiving the last loop's bcast
                st0 = index_map[pr - 1, 0, 0, 0].item()
                sp0 = index_map[pr - 1, 0, 0, 1].item() + 1
                st1 = index_map[pr - 1, 1, 1, 0].item()
                sp1 = index_map[pr - 1, 1, 1, 1].item()
                c._DNDarray__array[: sp0 - st0, st1:sp1] += a._DNDarray__array @ b_lp_data[pr - 1]
                del b_lp_data[pr - 1]

            if pr == b.comm.size - 1:
                req[pr].wait()
                st0 = index_map[pr, 0, 0, 0].item()
                sp0 = index_map[pr, 0, 0, 1].item() + 1
                st1 = index_map[pr, 1, 1, 0].item()
                sp1 = index_map[pr, 1, 1, 1].item()
                c._DNDarray__array[: sp0 - st0, st1:sp1] += a._DNDarray__array @ b_lp_data[pr]
                del b_lp_data[pr]

        c = (
            c
            if not vector_flag
            else factories.array(c._DNDarray__array.squeeze(), is_split=0, device=a.device)
        )
        return c

    elif split_10_flag:
        # todo: this may create the full matrix on every process, issue #360
        # for this case, only a sum is needed at the end
        a_rem_locs1 = (rem_map[:, 0, 1] == 1).nonzero()
        # locations of the remainders in b
        b_rem_locs0 = (rem_map[:, 1, 0] == 1).nonzero()
        res = torch.zeros(
            (a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=c.device.torch_device
        )
        for i in range(a.lshape[-1] // kB):
            res += (
                a._DNDarray__array[:mB, i * kB : i * kB + kB]
                @ b._DNDarray__array[i * kB : i * kB + kB, :nB]
            )
        if a.comm.rank in a_rem_locs1 and b.comm.rank in b_rem_locs0 and kB > 1:
            # these Nones are used to change the dims if the full process is not covered
            res += a._DNDarray__array[:, -1, None] @ b._DNDarray__array[None, -1, :]

        # every process holds a partial sum of the full product
        a.comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM)
        split = a.split if b.gshape[1] > 1 else 0
        split = split if not vector_flag else 0
        res = res if not vector_flag else res.squeeze()
        c = factories.array(res, split=split, device=a.device)
        return c
# Bind the ``@`` operator to :func:`matmul` (called with its default ``allow_resplit=False``).
# The operator always supplies both operands, so ``other`` takes no default.
DNDarray.__matmul__: Callable[[DNDarray, DNDarray], DNDarray] = lambda self, other: matmul(
    self, other
)
DNDarray.__matmul__.__doc__ = matmul.__doc__
def norm(a: DNDarray) -> float:
    """
    Return the vector norm (Frobenius norm) of ``a`` as a Python scalar.

    The element-wise squares of ``a`` are summed one axis at a time, starting with the last
    axis, and the square root of the total is returned.

    Parameters
    ----------
    a : DNDarray
        Input vector

    Raises
    ------
    TypeError
        If ``a`` is not a ``DNDarray``.
    """
    if not isinstance(a, DNDarray):
        raise TypeError("a must be of type ht.DNDarray, but was {}".format(type(a)))

    squared_sum = a ** 2
    # reduce from the highest axis down, so that (assuming ``sum`` drops the reduced axis)
    # the lower axis indices remain valid
    for axis in reversed(range(len(a.shape))):
        squared_sum = arithmetics.sum(squared_sum, axis=axis)
    return exponential.sqrt(squared_sum).item()
# Expose norm as a DNDarray method: ``x.norm()`` calls ``norm(x)`` and shares its docstring.
DNDarray.norm: Callable[[DNDarray], float] = lambda self: norm(self)
DNDarray.norm.__doc__ = norm.__doc__
def outer(
a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None, split: Optional[int] = None
) -> DNDarray:
"""
Compute the outer product of two 1-D DNDarrays: out[i, j] = a[i] * b[j].
Given two vectors, :math:`a = (a_0, a_1, ..., a_N)` and :math:`b = (b_0, b_1, ..., b_M)`, the outer product is:
.. math::
:nowrap:
\\begin{pmatrix}
a_0 \\cdot b_0 & a_0 \\cdot b_1 & . & . & a_0 \\cdot b_M \\
a_1 \\cdot b_0 & a_1 \\cdot b_1 & . & . & a_1 \\cdot b_M \\
. & . & . & . & . \\
a_N \\cdot b_0 & a_N \\cdot b_1 & . & . & a_N \\cdot b_M
\\end{pmatrix}
Parameters
----------
a : DNDarray
1-dimensional: :math: `N`
Will be flattened by default if more than 1-D.
b : DNDarray
1-dimensional: :math: `M`
Will be flattened by default if more than 1-D.
out : DNDarray, optional
2-dimensional: :math: `N \\times M`
A location where the result is stored
split : int, optional
Split dimension of the resulting DNDarray. Can be 0, 1, or None.
This is only relevant if the calculations are memory-distributed,
in which case default is ``split=0`` (see Note).
Notes
----------
Parallel implementation of outer product, arrays are dense.
In the classical (dense) case, one DNDarray stays put, the other one is passed around the ranks in
ring communication. The slice-by-slice outer product is calculated locally (here via torch.einsum()).
N.B.:
* If ``b`` is sent around, the resulting outer product is split along the rows dimension (``split = 0``).\n
* If ``a`` is sent around, the resulting outer product is split along the columns (``split = 1``).\n
So if ``split`` is not None, ``split`` defines which DNDarray stays put and which one is passed around. No
communication is needed beyond ring communication of one of the DNDarrays.
If ``split`` is None or unspecified, the result will be distributed along axis 0, i.e. by default ``b`` is
passed around, ``a`` stays put.
Examples
--------
>>> a = ht.arange(4)
>>> b = ht.arange(3)
>>> ht.outer(a, b)._DNDarray__array
(3 processes)
[0/2] tensor([[0, 0, 0],
[0, 1, 2],
[0, 2, 4],
[0, 3, 6]], dtype=torch.int32)
[1/2] tensor([[0, 0, 0],
[0, 1, 2],
[0, 2, 4],
[0, 3, 6]], dtype=torch.int32)
[2/2] tensor([[0, 0, 0],
[0, 1, 2],
[0, 2, 4],
[0, 3, 6]], dtype=torch.int32)
>>> a = ht.arange(4, split=0)
>>> b = ht.arange(3, split=0)
>>> ht.outer(a, b)._DNDarray__array
[0/2] tensor([[0, 0, 0],
[0, 1, 2]], dtype=torch.int32)
[1/2] tensor([[0, 2, 4]], dtype=torch.int32)
[2/2] tensor([[0, 3, 6]], dtype=torch.int32)
>>> ht.outer(a, b, split=1)._DNDarray__array
[0/2] tensor([[0],
[0],
[0],
[0]], dtype=torch.int32)
[1/2] tensor([[0],
[1],
[2],
[3]], dtype=torch.int32)
[2/2] tensor([[0],
[2],
[4],
[6]], dtype=torch.int32)
>>> a = ht.arange(5, dtype=ht.float32, split=0)
>>> b = ht.arange(4, dtype=ht.float64, split=0)
>>> out = ht.empty((5,4), dtype=ht.float64, split=1)
>>> ht.outer(a, b, split=1, out=out)
>>> out._DNDarray__array
[0/2] tensor([[0., 0.],
[0., 1.],
[0., 2.],
[0., 3.],
[0., 4.]], dtype=torch.float64)
[1/2] tensor([[0.],
[2.],
[4.],
[6.],
[8.]], dtype=torch.float64)
[2/2] tensor([[ 0.],
[ 3.],
[ 6.],
[ 9.],
[12.]], dtype=torch.float64)
"""
# sanitize input
if not isinstance(a, DNDarray) or not isinstance(b, DNDarray):
raise TypeError(
"a, b must be of type ht.DNDarray, but were {}, {}".format(type(a), type(b))
)
# sanitize dimensions
# TODO move to sanitation module #468
if a.ndim > 1:
a = manipulations.flatten(a)
if b.ndim > 1:
b = manipulations.flatten(b)
if a.ndim == 0 or b.ndim == 0:
raise RuntimeError(
"a, b must be 1-D DNDarrays, but were {}-D and {}-D".format(a.ndim, b.ndim)
)
outer_gshape = (a.gshape[0], b.gshape[0])
t_a = a._DNDarray__array
t_b = b._DNDarray__array
t_outer_dtype = torch.promote_types(t_a.dtype, t_b.dtype)
t_a, t_b = t_a.type(t_outer_dtype), t_b.type(t_outer_dtype)
outer_dtype = types.canonical_heat_type(t_outer_dtype)
if out is not None:
if not isinstance(out, DNDarray):
raise TypeError("out must be of type ht.DNDarray, was {}".format(type(out)))
if out.dtype is not outer_dtype:
raise TypeError(
"Wrong datatype for out: expected {}, got {}".format(outer_dtype, out.dtype)
)
if out.gshape != outer_gshape:
raise ValueError("out must have shape {}, got {}".format(outer_gshape, out.gshape))
if out.split is not split:
raise ValueError(
"Split dimension mismatch for out: expected {}, got {}".format(split, out.split)
)
# distributed outer product, dense arrays (TODO: sparse, #384)
if a.comm.is_distributed() and split is not None or a.split is not None or b.split is not None:
# MPI coordinates
rank = a.comm.rank
size = a.comm.size
t_outer_slice = 2 * [slice(None, None, None)]
if a.split is None:
a.resplit_(axis=0)
t_a = a._DNDarray__array.type(t_outer_dtype)
if b.split is None:
b.resplit_(axis=0)
t_b = b._DNDarray__array.type(t_outer_dtype)
if split is None:
# Split semantics: default out.split = a.split
split = a.split
if out is not None and out.split is None:
out.resplit_(axis=split)
# calculate local slice of outer product
if split == 0:
lshape_map = b.create_lshape_map()
t_outer_shape = (a.lshape[0], b.gshape[0])
_, _, local_slice = b.comm.chunk(b.gshape, b.split)
t_outer_slice[1] = local_slice[0]
elif split == 1:
lshape_map = a.create_lshape_map()
t_outer_shape = (a.gshape[0], b.lshape[0])
_, _, local_slice = a.comm.chunk(a.gshape, a.split)
t_outer_slice[0] = local_slice[0]
t_outer = torch.zeros(t_outer_shape, dtype=t_outer_dtype, device=t_a.device)
if lshape_map[rank] != 0:
t_outer[t_outer_slice] = torch.einsum("i,j->ij", t_a, t_b)
# Ring: fill in missing slices of outer product
# allocate memory for traveling data
if split == 0:
t_b_run = torch.empty(lshape_map[0], dtype=t_outer_dtype, device=t_a.device)
elif split == 1:
t_a_run = torch.empty(lshape_map[0], dtype=t_outer_dtype, device=t_b.device)
for p in range(size - 1):
# prepare for sending
dest_rank = rank + 1 if rank != size - 1 else 0
# prepare for receiving
origin_rank = rank - 1 if rank != 0 else size - 1
actual_origin = origin_rank - p
if origin_rank < p:
actual_origin += size
# blocking send and recv
if split == 0:
b.comm.Send(t_b, dest_rank)
b.comm.Recv(t_b_run, origin_rank)
# buffer from actual_origin could be smaller than allocated buffer
t_b = t_b_run[: lshape_map[actual_origin]]
_, _, remote_slice = b.comm.chunk(
b.gshape, b.split, rank=actual_origin, w_size=size
)
t_outer_slice[1] = remote_slice[0]
elif split == 1:
a.comm.Send(t_a, dest_rank)
a.comm.Recv(t_a_run, origin_rank)
# buffer from actual_origin could be smaller than allocated buffer
t_a = t_a_run[: lshape_map[actual_origin]]
_, _, remote_slice = a.comm.chunk(
a.gshape, a.split, rank=actual_origin, w_size=size