Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Oct 31, 2019
1 parent 6909467 commit 01cfcda
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
8 changes: 4 additions & 4 deletions torch_struct/alignment.py
Expand Up @@ -344,8 +344,8 @@ def merge2(xa, xb, size, rsize):
st2.append(torch.stack([semiring.zero_(right.clone()), right], dim=-2))
st = torch.cat([st, torch.stack(st2, dim=-1)], dim=-1)
return semiring.sum(st)
reporter = MemReporter()
reporter.report()
# reporter = MemReporter()
# reporter.report()

size = bin_MN // 2
rsize = 2
Expand All @@ -369,8 +369,8 @@ def merge2(xa, xb, size, rsize):
:, :, 0, M - N + (charta[-1].shape[3] // 2), N, Open, Open, Mid
]

reporter = MemReporter()
reporter.report()
# reporter = MemReporter()
# reporter.report()
return v, [log_potentials], None

@staticmethod
Expand Down
13 changes: 7 additions & 6 deletions torch_struct/semirings.py
Expand Up @@ -142,6 +142,7 @@ def unaccumulate_(a, b, grad_output, fn, step=10000):
a_grad = torch.tensor(0.0, device=a.device).set_(a.clone().storage(), a.storage_offset(), a.size(), a.stride()).fill_(0)

b_grad = torch.tensor(0.0, device=b.device).set_(b.clone().storage(), b.storage_offset(), b.size(), b.stride()).fill_(0)
print(b_grad.shape, a.shape)

print("chcek", a_grad.shape)
total = 1
Expand Down Expand Up @@ -268,9 +269,9 @@ def forward(ctx, a, b):
def backward(ctx, grad_output):

a, b = ctx.saved_tensors
print("backing out", a.shape)
reporter = MemReporter()
reporter.report()
# print("backing out", a.shape)
# reporter = MemReporter()
# reporter.report()

size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]

Expand All @@ -291,9 +292,9 @@ def backward(ctx, grad_output):
grad_a = back.sum(dim=asum, keepdim=True)
grad_b = back.sum(dim=bsum, keepdim=True)

print("backing out 2", a.shape)
reporter = MemReporter()
reporter.report()
# print("backing out 2", a.shape)
# reporter = MemReporter()
# reporter.report()


return grad_a, grad_b
Expand Down
4 changes: 3 additions & 1 deletion torch_struct/test_algorithms.py
Expand Up @@ -180,7 +180,9 @@ def test_dp_custom():

def test_align_custom():
model = Alignment
vals, _ = model._rand()
#vals, _ = model._rand()
vals = torch.rand(1, 10, 10, 3)

struct = Alignment(LogSemiring)
marginals = struct.marginals(vals)
s = struct.sum(vals)
Expand Down

0 comments on commit 01cfcda

Please sign in to comment.