Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 8, 2019
1 parent 8cade23 commit 6d51297
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions torch_struct/semirings/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def forward(ctx, a, b):
size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]
return accumulate_(a, b, size,
lambda a, b: cls.dot(a, b),
preserve=len(ret.shape),
preserve=len(size),
step=max_size // a.shape[-1] + 2)

@staticmethod
Expand Down Expand Up @@ -57,8 +57,6 @@ def accumulate_(a, b, size, fn, preserve, step=10000):
if step > total:
return fn(a, b)

print("miss")

ret = torch.zeros(*size, dtype=a.dtype, device=a.device)
a_one, b_one = ones(a), ones(b)
indices = torch.tensor(np.mgrid[slices]).view(len(ret.shape[:preserve]), -1)
Expand Down

0 comments on commit 6d51297

Please sign in to comment.