Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 9, 2019
1 parent cd555b6 commit bc9707c
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,30 +83,36 @@ def sstack(a):

ends = [None]
for k in range(1, N):
def tf(a): return torch.narrow(a, 2, 0, N-k)
def tb(a): return torch.narrow(a, 2, 1, N-k)

def tf(a):
return torch.narrow(a, 2, 0, N - k)

def tb(a):
return torch.narrow(a, 2, 1, N - k)

f = torch.arange(N - k), torch.arange(k, N)
if k > 1:
AC2 = torch.cat(
[tf(AC), tf(AC_next).unsqueeze(-1)], dim=3
)
AC2 = torch.cat([tf(AC), tf(AC_next).unsqueeze(-1)], dim=3)
if k > 1:
BC2 = torch.cat([tb(AC_next).unsqueeze(-1), tb(BC)], dim=3)

start = semiring.dot(BC2[L], AC2[R])
# if k == 1:
arcs[k] = stack(
semiring.times(start, arc_scores[:, f[1], f[0]]),
semiring.times(start, arc_scores[:, f[0], f[1]])
semiring.times(start, arc_scores[:, f[0], f[1]]),
)

# else:
# arcs[k] = stack(semiring.times(start), #, arc_scores[:, f[1], f[0]]),
# semiring.times(start)) #, arc_scores[:, f[0], f[1]]))

AIR2 = torch.cat([torch.narrow(AIR, 1, 0, N-k), arcs[k][R].unsqueeze(-1)], dim=2)
BIL2 = torch.cat([arcs[k][L].unsqueeze(-1), torch.narrow(BIL, 1, 1, N-k)], dim=2)
AIR2 = torch.cat(
[torch.narrow(AIR, 1, 0, N - k), arcs[k][R].unsqueeze(-1)], dim=2
)
BIL2 = torch.cat(
[arcs[k][L].unsqueeze(-1), torch.narrow(BIL, 1, 1, N - k)], dim=2
)
AC_next = stack(semiring.dot(AC2[L], BIL2), semiring.dot(AIR2, BC2[R]))

ends.append(AC_next[R, :, 0])
Expand Down

0 comments on commit bc9707c

Please sign in to comment.