Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Nov 2, 2019
1 parent e92d0b7 commit 6be5ca2
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions torch_struct/semirings.py
Expand Up @@ -332,7 +332,7 @@ def forward(ctx, a, b, band, o1, o2):
def backward(ctx, grad_output):
a, b, opt = ctx.saved_tensors
band, o1, o2 = opt.tolist()
print("backing out banded", a.shape, b)
print("backing out banded", a.shape, band)
reporter = MemReporter()
reporter.report()

Expand All @@ -347,7 +347,7 @@ def inner(a, b):
p = p.transpose(-1, -3).transpose(-2, -3)
return p.mul(g.unsqueeze(-1)).sum(-1)

if True:
if False:
asum, bsum = [], []
for i, (x, y) in enumerate(zip(a.shape, b.shape)):
if x == 1:
Expand All @@ -357,7 +357,12 @@ def inner(a, b):
back = fn(a, b, grad_output)
grad_a = back.sum(dim=asum, keepdim=True)
grad_b = back.sum(dim=bsum, keepdim=True)

else:
grad_a, grad_b = unaccumulate_(
a, b, grad_output, fn,
step=max_size // a.shape[-1] + 2
)

print("backing out banded 2",
grad_a.shape, grad_b.shape, a.shape, b)
reporter = MemReporter()
Expand Down

0 comments on commit 6be5ca2

Please sign in to comment.