Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 25, 2019
1 parent db13392 commit 04ac887
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,36 +61,36 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
semiring.one_(alpha[B][C][L].data[ :, :, :, -1].data)
semiring.one_(alpha[B][C][R].data[ :, :, :, -1].data)

for k in range(1, N):
f = torch.arange(N - k), torch.arange(k, N)
ACL = alpha[A][C][L][: N - k, :k]
ACR = alpha[A][C][R][: N - k, :k]
# for k in range(1, N):
# f = torch.arange(N - k), torch.arange(k, N)
# ACL = alpha[A][C][L][: N - k, :k]
# ACR = alpha[A][C][R][: N - k, :k]

BCL = alpha[B][C][L][k:, N - k :]
BCR = alpha[B][C][R][k:, N - k :]
x = semiring.dot(ACR, BCL)
# BCL = alpha[B][C][L][k:, N - k :]
# BCR = alpha[B][C][R][k:, N - k :]
# x = semiring.dot(ACR, BCL)

arcs_l = semiring.times(
x, arc_scores[:, :, f[1], f[0]])
# arcs_l = semiring.times(
# x, arc_scores[:, :, f[1], f[0]])

# alpha[A][I][L][: N - k, k] = arcs_l
# alpha[B][I][L][k:N, N - k - 1] = arcs_l
# # alpha[A][I][L][: N - k, k] = arcs_l
# # alpha[B][I][L][k:N, N - k - 1] = arcs_l

arcs_r = semiring.times(
x , arc_scores[:, :, f[0], f[1]])
# alpha[A][I][R][: N - k, k] = arcs_r
# alpha[B][I][R][k:N, N - k - 1] = arcs_r
# arcs_r = semiring.times(
# x , arc_scores[:, :, f[0], f[1]])
# # alpha[A][I][R][: N - k, k] = arcs_r
# # alpha[B][I][R][k:N, N - k - 1] = arcs_r

AIR = alpha[A][I][R][: N - k, 1 : k + 1]
BIL = alpha[B][I][L][ k:, N - k - 1 : N - 1]
# AIR = alpha[A][I][R][: N - k, 1 : k + 1]
# BIL = alpha[B][I][L][ k:, N - k - 1 : N - 1]

new = semiring.dot(ACL, BIL)
# alpha[A][C][L][ : N - k, k] = new
# alpha[B][C][L][ k:N, N - k - 1] = new
# new = semiring.dot(ACL, BIL)
# # alpha[A][C][L][ : N - k, k] = new
# # alpha[B][C][L][ k:N, N - k - 1] = new

new = semiring.dot(AIR, BCR)
# alpha[A][C][R][ : N - k, k] = new
# alpha[B][C][R][ k:N, N - k - 1] = new
# new = semiring.dot(AIR, BCR)
# # alpha[A][C][R][ : N - k, k] = new
# # alpha[B][C][R][ k:N, N - k - 1] = new

final = alpha[A][C][R][(0,)]
v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
Expand Down

0 comments on commit 04ac887

Please sign in to comment.