Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 9, 2019
1 parent a452764 commit 38403ba
Showing 1 changed file with 14 additions and 23 deletions.
37 changes: 14 additions & 23 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,46 +91,37 @@ def sstack(a):

for k in range(1, N):
f = torch.arange(N - k), torch.arange(k, N)
(ACR2,
BCL2,
ACL2,
AIR2,
BIL2,
BCR2) = self._make_chart(6, (batch, N-k, k), arc_scores)

if k > 1:
ACR2[:, :, :k-1] = ACR[:, :N-k, :k-1]
ACR2[:, :, k-1] = ACR_next[:, :-1]
# ACR2[:, :, :k-1] = torch.cat([ACR[:, :N-k, :k-1],
# ACR2[:, :, k-1] = ACR_next[:, :-1]]
ACR2 = torch.cat([ACR[:, :N-k, :k-1], ACR_next[:, :-1].unsqueeze(-1)],
dim=2)
else:
ACR2[:] = alpha[A][C][R, :, :N - k, :k]
ACR2 = alpha[A][C][R, :, :N - k, :k]

if k > 1:
ACL2[:, :, :k-1] = ACL[:, :N-k, :k-1]
ACL2[:, :, k-1] = ACL_next[:, :-1]
ACL2 = torch.cat([ACL[:, :N-k, :k-1], ACL_next[:, :-1].unsqueeze(-1)],
dim=2)
else:
ACL2[:] = alpha[A][C][L, :, :N - k, :k]
ACL2 = alpha[A][C][L, :, :N - k, :k]

if k > 1:
BCL2[:, :, 1:] = BCL[:, 1:, :]
BCL2[:, :, 0] = ACL_next[:, 1:]
BCL2 = torch.cat([ACL_next[:, 1:].unsqueeze(-1), BCL[:, 1:, :]], dim=2)
else:
BCL2[:] = alpha[B][C][L, :, k:N, N - k:]
BCL2 = alpha[B][C][L, :, k:N, N - k:]

if k > 1:
BCR2[:, :, 1:] = BCR[:, 1:, :]
BCR2[:, :, 0] = ACR_next[:, 1:]
BCR2 = torch.cat([ACR_next[:, 1:].unsqueeze(-1), BCR[:, 1:, :]],dim=2)
else:
BCR2[:] = alpha[B][C][R, :, k:N, N-k:]
BCR2 = alpha[B][C][R, :, k:N, N-k:]

start = semiring.dot(BCL2, ACR2)
arcs[k] = stack(semiring.times(start, arc_scores[:, f[1], f[0]]),
semiring.times(start, arc_scores[:, f[0], f[1]]))

AIR2[:, :, :k-1] = AIR[:, :N-k, :k-1]
AIR2[:, :, k-1] = arcs[k][R]

BIL2[:, :, 1:] = BIL[:, 1:, :k-1]
BIL2[:, :, 0] = arcs[k][L]
AIR2 = torch.cat([AIR[:, :N-k, :k-1], arcs[k][R].unsqueeze(-1)], dim=2)
BIL2 = torch.cat([arcs[k][L].unsqueeze(-1), BIL[:, 1:, :k-1]], dim=2)

ACL_next = semiring.dot(ACL2, BIL2)
ACR_next = semiring.dot(AIR2, BCR2)
Expand Down

0 comments on commit 38403ba

Please sign in to comment.