Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 9, 2019
1 parent 7cec16c commit 47fc6b2
Showing 1 changed file with 14 additions and 32 deletions.
46 changes: 14 additions & 32 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,51 +70,33 @@ def sstack(a):

AIR = alpha[A][I][R, :, :N-k, 1:k]
BIL = alpha[B][I][L, :, k:N, N-k:N-1]

k = 1
AC2 = alpha[A][C][:, :, :N - k, :k]
BC2 = alpha[B][C][:, :, k:, N - k:]
ends = [None]
for k in range(1, N):
f = torch.arange(N - k), torch.arange(k, N)

if k > 1:
ACR2 = torch.cat([ACR[:, :-1], ACR_next[:, :-1].unsqueeze(-1)],
dim=2)
else:
ACR2 = alpha[A][C][R, :, :N - k, :k]

AC2 = torch.cat([AC[:, :, :-1], AC_next[:, :, :-1].unsqueeze(-1)],
dim=3)
if k > 1:
ACL2 = torch.cat([ACL[:, :-1], ACL_next[:, :-1].unsqueeze(-1)],
dim=2)
else:
ACL2 = alpha[A][C][L, :, :N - k, :k]
BC2 = torch.cat([AC_next[:, :, 1:].unsqueeze(-1), BC[:, :, 1:]], dim=3)

if k > 1:
BCL2 = torch.cat([ACL_next[:, 1:].unsqueeze(-1), BCL[:, 1:]], dim=2)
else:
BCL2 = alpha[B][C][L, :, k:, N - k:]

if k > 1:
BCR2 = torch.cat([ACR_next[:, 1:].unsqueeze(-1), BCR[:, 1:]],dim=2)
else:
BCR2 = alpha[B][C][R, :, k:, N-k:]

start = semiring.dot(BCL2, ACR2)
start = semiring.dot(BC2[L], AC2[R])
arcs[k] = stack(semiring.times(start, arc_scores[:, f[1], f[0]]),
semiring.times(start, arc_scores[:, f[0], f[1]]))

AIR2 = torch.cat([AIR[:, :-1], arcs[k][R].unsqueeze(-1)], dim=2)
BIL2 = torch.cat([arcs[k][L].unsqueeze(-1), BIL[:, 1:]], dim=2)
AC_next = stack(semiring.dot(AC2[L], BIL2), semiring.dot(AIR2, BC2[R]))

ACL_next = semiring.dot(ACL2, BIL2)
ACR_next = semiring.dot(AIR2, BCR2)

alpha[A][C][R, :, : N - k, k] = ACR_next
ACR = ACR2
BCL = BCL2
ACL = ACL2
ends.append(AC_next[R, :, 0])
AC = AC2
BC = BC2
AIR = AIR2
BIL = BIL2
BCR = BCR2

v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
v = torch.stack([ends[l][i] for i, l in enumerate(lengths)])
# v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
return (v, arcs[1:], alpha)

def _check_potentials(self, arc_scores, lengths=None):
Expand Down

0 comments on commit 47fc6b2

Please sign in to comment.