From bc9707c24f5afe5bbfebf8411710424d244b3850 Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Sun, 8 Sep 2019 23:50:27 -0400 Subject: [PATCH] . --- torch_struct/deptree.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 9149ee7e..ead9f93c 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -83,14 +83,16 @@ def sstack(a): ends = [None] for k in range(1, N): - def tf(a): return torch.narrow(a, 2, 0, N-k) - def tb(a): return torch.narrow(a, 2, 1, N-k) + + def tf(a): + return torch.narrow(a, 2, 0, N - k) + + def tb(a): + return torch.narrow(a, 2, 1, N - k) f = torch.arange(N - k), torch.arange(k, N) if k > 1: - AC2 = torch.cat( - [tf(AC), tf(AC_next).unsqueeze(-1)], dim=3 - ) + AC2 = torch.cat([tf(AC), tf(AC_next).unsqueeze(-1)], dim=3) if k > 1: BC2 = torch.cat([tb(AC_next).unsqueeze(-1), tb(BC)], dim=3) @@ -98,15 +100,19 @@ def tb(a): return torch.narrow(a, 2, 1, N-k) # if k == 1: arcs[k] = stack( semiring.times(start, arc_scores[:, f[1], f[0]]), - semiring.times(start, arc_scores[:, f[0], f[1]]) + semiring.times(start, arc_scores[:, f[0], f[1]]), ) # else: # arcs[k] = stack(semiring.times(start), #, arc_scores[:, f[1], f[0]]), # semiring.times(start)) #, arc_scores[:, f[0], f[1]])) - AIR2 = torch.cat([torch.narrow(AIR, 1, 0, N-k), arcs[k][R].unsqueeze(-1)], dim=2) - BIL2 = torch.cat([arcs[k][L].unsqueeze(-1), torch.narrow(BIL, 1, 1, N-k)], dim=2) + AIR2 = torch.cat( + [torch.narrow(AIR, 1, 0, N - k), arcs[k][R].unsqueeze(-1)], dim=2 + ) + BIL2 = torch.cat( + [arcs[k][L].unsqueeze(-1), torch.narrow(BIL, 1, 1, N - k)], dim=2 + ) AC_next = stack(semiring.dot(AC2[L], BIL2), semiring.dot(AIR2, BC2[R])) ends.append(AC_next[R, :, 0])