Skip to content

Commit

Permalink
Merge c282a3d into b6816a4
Browse files Browse the repository at this point in the history
  • Loading branch information
sustcsonglin committed Jul 27, 2020
2 parents b6816a4 + c282a3d commit b81cc0c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
52 changes: 29 additions & 23 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DepTree(_Struct):
"""

def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
multiroot = getattr(self, "multiroot")
if arc_scores_in.dim() not in (3, 4):
raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")

Expand All @@ -72,34 +73,39 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)

for k in range(1, N):
f = torch.arange(N - k), torch.arange(k, N)
ACL = alpha[A][C][L][: N - k, :k]
ACR = alpha[A][C][R][: N - k, :k]

BCL = alpha[B][C][L][k:, N - k :]
BCR = alpha[B][C][R][k:, N - k :]
if multiroot:
start_idx = 0
else:
start_idx = 1

for k in range(1, N-start_idx):
f = torch.arange(start_idx, N - k), torch.arange(k+start_idx, N)
ACL = alpha[A][C][L][start_idx: N - k, :k]
ACR = alpha[A][C][R][start_idx: N - k, :k]
BCL = alpha[B][C][L][k+start_idx:, N - k :]
BCR = alpha[B][C][R][k+start_idx:, N - k :]
x = semiring.dot(ACR, BCL)

arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])

alpha[A][I][L][: N - k, k] = arcs_l
alpha[B][I][L][k:N, N - k - 1] = arcs_l

alpha[A][I][L][start_idx:N - k, k] = arcs_l
alpha[B][I][L][k+start_idx:N, N - k - 1] = arcs_l
arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
alpha[A][I][R][: N - k, k] = arcs_r
alpha[B][I][R][k:N, N - k - 1] = arcs_r

AIR = alpha[A][I][R][: N - k, 1 : k + 1]
BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]

alpha[A][I][R][start_idx:N - k, k] = arcs_r
alpha[B][I][R][k+start_idx:N, N - k - 1] = arcs_r
AIR = alpha[A][I][R][start_idx: N - k, 1 : k + 1]
BIL = alpha[B][I][L][k+start_idx:, N - k - 1 : N - 1]
new = semiring.dot(ACL, BIL)
alpha[A][C][L][: N - k, k] = new
alpha[B][C][L][k:N, N - k - 1] = new

alpha[A][C][L][start_idx: N - k, k] = new
alpha[B][C][L][k+start_idx:N, N - k - 1] = new
new = semiring.dot(AIR, BCR)
alpha[A][C][R][: N - k, k] = new
alpha[B][C][R][k:N, N - k - 1] = new
alpha[A][C][R][start_idx: N - k, k] = new
alpha[B][C][R][k+start_idx:N, N - k - 1] = new

if not multiroot:
root_incomplete_span = semiring.times(alpha[A][C][L][1, :N-1], arc_scores[:, :, 0, 1:])
for k in range(1,N):
AIR = root_incomplete_span[:, :, :k]
BCR = alpha[B][C][R][k, N-k:]
alpha[A][C][R][0, k] = semiring.dot(AIR, BCR)

final = alpha[A][C][R][(0,)]
v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
Expand Down
5 changes: 4 additions & 1 deletion torch_struct/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,10 @@ class DependencyCRF(StructDistribution):
"""

struct = DepTree
def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
super(DependencyCRF, self).__init__(log_potentials, lengths, args)
self.struct = DepTree
setattr(self.struct, "multiroot", multiroot)


class TreeCRF(StructDistribution):
Expand Down

0 comments on commit b81cc0c

Please sign in to comment.