Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 8, 2019
1 parent 1d63091 commit 8cade23
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions torch_struct/semirings/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@ class _Check(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
if True:
return cls.dot(a, b)
else:
size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]
return accumulate_(a, b, size,
lambda a, b: cls.dot(a, b),
preserve=len(ret.shape),
step=max_size // a.shape[-1] + 2)
size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]
return accumulate_(a, b, size,
lambda a, b: cls.dot(a, b),
preserve=len(ret.shape),
step=max_size // a.shape[-1] + 2)

@staticmethod
def backward(ctx, grad_output):
Expand Down

0 comments on commit 8cade23

Please sign in to comment.