Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Oct 31, 2019
1 parent 3e87121 commit e14047e
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion torch_struct/semirings.py
@@ -1,7 +1,7 @@
import torch
import torch.distributions
import numpy as np

from pytorch_memlab import MemReporter

class Semiring:
"""
Expand Down Expand Up @@ -265,6 +265,9 @@ def backward(ctx, grad_output):

a, b = ctx.saved_tensors
print("backing out", a.shape)
reporter = MemReporter()
reporter.report()

size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]

fn = lambda a, b, g: torch.softmax(a + b, dim=-1).mul(g.unsqueeze(-1))
Expand All @@ -283,7 +286,12 @@ def backward(ctx, grad_output):
back = fn(a, b, grad_output)
grad_a = back.sum(dim=asum, keepdim=True)
grad_b = back.sum(dim=bsum, keepdim=True)

print("backing out 2", a.shape)
reporter = MemReporter()
reporter.report()


return grad_a, grad_b


Expand Down

0 comments on commit e14047e

Please sign in to comment.