Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 7, 2019
1 parent c54576d commit c7b9bd5
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,11 @@ def sstack(a):
left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
ret = torch.zeros(batch, N, N).type_as(left)
f = torch.arange(N - k), torch.arange(k, N)
ret[:, k, f[1]] = right[:, k, f[0]]
ret[:, k, f[0]] = left[:, k, f[1]]
for k in range(N):
f = torch.arange(N - k), torch.arange(k, N)
ret[:, f[1], k] = left[:, k, f[0]]
ret[:, k, f[1]] = right[:, k, f[0]]

ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
return _unconvert(ret)

Expand Down

0 comments on commit c7b9bd5

Please sign in to comment.