Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 7, 2019
1 parent 26f2e42 commit 6f06de9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def sstack(a):
v = alpha[A][C][R, :, 0, 0]
left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
print("here")

ret = torch.zeros(batch, N, N).type_as(left)
for k in range(N):
for d in range(N - k):
Expand Down
4 changes: 2 additions & 2 deletions torch_struct/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _make_chart(self, N, size, potentials, force_grad):
for _ in range(N)
]

def sum(self, edge, lengths=None, _autograd=False):
def sum(self, edge, lengths=None, _autograd=True):
"""
Compute the (semiring) sum over all structures model.
Expand All @@ -72,7 +72,7 @@ def sum(self, edge, lengths=None, _autograd=False):
else:
return DPManual.apply(self, edge, lengths)

def marginals(self, edge, lengths=None, _autograd=False):
def marginals(self, edge, lengths=None, _autograd=True):
"""
Compute the marginals of a structured model.
Expand Down

0 comments on commit 6f06de9

Please sign in to comment.