Description
Applying a binary operation to a tensor that has been expanded with a new axis (in order to compute operation(t1, t2) for every combination of elements along that new axis) produces wrong local output shapes when the tensors are split.
The new gshape is calculated correctly, but along the new axis each process uses only its local elements of t2 (the ones along t2's split axis), so not all combinations are calculated.
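The mismatch comes from how broadcasting aligns axes. The sketch below uses plain NumPy (not heat) on the shapes from the reproduction to show that b's split axis 0 ends up as output axis 1, while the output is split along axis 0, inherited from a. The helper `aligned_axis` is hypothetical and only illustrates right-aligned broadcasting.

```python
import numpy as np

# Global shapes from the reproduction: a after expand_dims(a, 1), and b.
a_shape = (2, 1, 2)   # split=0
b_shape = (3, 2)      # split=0

# NumPy-style broadcasting gives the expected global result shape (gshape).
out_shape = np.broadcast_shapes(a_shape, b_shape)

# Hypothetical helper: broadcasting aligns shapes from the right, so axis k
# of an input with in_ndim dimensions lands on output axis out_ndim - in_ndim + k.
def aligned_axis(in_ndim, axis, out_ndim):
    return out_ndim - in_ndim + axis

a_split_in_out = aligned_axis(len(a_shape), 0, len(out_shape))
b_split_in_out = aligned_axis(len(b_shape), 0, len(out_shape))

print(out_shape, a_split_in_out, b_split_in_out)
```

So b is distributed along output axis 1, but each process only owns a slice of the output along axis 0; a process that holds just part of b cannot fill its full slice along axis 1.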
To Reproduce
```python
import heat as ht
a = ht.array([[0,1],[1,2]], split=0)
a = ht.expand_dims(a, 1)
b = ht.array([[1,1],[2,2],[3,3]], split=0)
c = ht.sub(a,b)
print('c', c._DNDarray__array, c.split, c.lshape, c.gshape, c.comm.rank)
```
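With n=2 processes this returns:
```
c 0 (1, 2, 2) (2, 3, 2) 0
c 0 (1, 1, 2) (2, 3, 2) 1
```
Note that the lshape differs between the two ranks along axis 1, and neither matches the global size of 3 along that axis.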
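The observed local shapes can be reproduced without MPI. This NumPy sketch assumes the rows are distributed the way `np.array_split` does it (earlier ranks receive the remainder, which agrees with the lshapes reported above) and mimics the current behaviour, where each rank combines its local part of a with only its local part of b.

```python
import numpy as np

# Same data as the reproduction, as plain NumPy arrays.
a = np.expand_dims(np.array([[0, 1], [1, 2]]), 1)   # shape (2, 1, 2)
b = np.array([[1, 1], [2, 2], [3, 3]])              # shape (3, 2)

nprocs = 2
# Assumption: split=0 distributes rows like np.array_split
# (a: 1 + 1 rows, b: 2 + 1 rows).
a_chunks = np.array_split(a, nprocs, axis=0)
b_chunks = np.array_split(b, nprocs, axis=0)

# Current behaviour: each rank uses only its local chunk of b.
local_shapes = [(a_loc - b_loc).shape for a_loc, b_loc in zip(a_chunks, b_chunks)]
print(local_shapes)
```

The resulting local shapes, (1, 2, 2) on rank 0 and (1, 1, 2) on rank 1, are exactly the lshapes printed by heat.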
Expected behavior
```
c 0 (1, 3, 2) (2, 3, 2) 0
c 0 (1, 3, 2) (2, 3, 2) 1
```
A possible quick-and-dirty fix (although there is probably a better way, by broadcasting the shapes, that avoids the resplit) is to insert the following at line 124 of operations.py:
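```python
# If t2 (or t1) is split along an axis that does not line up with the
# output shape, replicate it on every process before the local operation.
if len(t2.shape) < len(output_shape) or t2.shape[t2.split] != output_shape[t2.split]:
    if t2.shape[t2.split] > 1 and t2.comm.is_distributed():
        t2 = resplit(t2)  # default axis=None gathers the full t2 on each process
elif len(t1.shape) < len(output_shape) or t1.shape[t1.split] != output_shape[t1.split]:
    if t1.shape[t1.split] > 1 and t1.comm.is_distributed():
        t1 = resplit(t1)
```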
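The effect of this fix can also be checked without MPI. The NumPy sketch below assumes that `resplit(t2)` with its default axis gathers the complete `t2` onto every process, and simulates that by concatenating all chunks of b before the local operation. It is an illustration of the idea, not heat's implementation.

```python
import numpy as np

a = np.expand_dims(np.array([[0, 1], [1, 2]]), 1)   # shape (2, 1, 2), split=0
b = np.array([[1, 1], [2, 2], [3, 3]])              # shape (3, 2), split=0

nprocs = 2
a_chunks = np.array_split(a, nprocs, axis=0)
b_chunks = np.array_split(b, nprocs, axis=0)

# Simulated resplit(b) to axis=None: every rank gathers all chunks of b.
b_replicated = np.concatenate(b_chunks, axis=0)

local_results = [a_loc - b_replicated for a_loc in a_chunks]
local_shapes = [r.shape for r in local_results]

# Stitching the local results back together along the split axis (0)
# reproduces the non-distributed result.
global_result = np.concatenate(local_results, axis=0)
print(local_shapes, global_result.shape)
```

Each rank now holds a (1, 3, 2) slice, matching the expected behavior. The cost is that every process stores all of b, which is why a broadcasting-aware approach that avoids the resplit would likely scale better.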