diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 4a9f5810..acb3364c 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -61,36 +61,36 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False): semiring.one_(alpha[B][C][L].data[ :, :, :, -1].data) semiring.one_(alpha[B][C][R].data[ :, :, :, -1].data) - for k in range(1, N): - f = torch.arange(N - k), torch.arange(k, N) - ACL = alpha[A][C][L][: N - k, :k] - ACR = alpha[A][C][R][: N - k, :k] + # for k in range(1, N): + # f = torch.arange(N - k), torch.arange(k, N) + # ACL = alpha[A][C][L][: N - k, :k] + # ACR = alpha[A][C][R][: N - k, :k] - BCL = alpha[B][C][L][k:, N - k :] - BCR = alpha[B][C][R][k:, N - k :] - x = semiring.dot(ACR, BCL) + # BCL = alpha[B][C][L][k:, N - k :] + # BCR = alpha[B][C][R][k:, N - k :] + # x = semiring.dot(ACR, BCL) - arcs_l = semiring.times( - x, arc_scores[:, :, f[1], f[0]]) + # arcs_l = semiring.times( + # x, arc_scores[:, :, f[1], f[0]]) - # alpha[A][I][L][: N - k, k] = arcs_l - # alpha[B][I][L][k:N, N - k - 1] = arcs_l + # # alpha[A][I][L][: N - k, k] = arcs_l + # # alpha[B][I][L][k:N, N - k - 1] = arcs_l - arcs_r = semiring.times( - x , arc_scores[:, :, f[0], f[1]]) - # alpha[A][I][R][: N - k, k] = arcs_r - # alpha[B][I][R][k:N, N - k - 1] = arcs_r + # arcs_r = semiring.times( + # x , arc_scores[:, :, f[0], f[1]]) + # # alpha[A][I][R][: N - k, k] = arcs_r + # # alpha[B][I][R][k:N, N - k - 1] = arcs_r - AIR = alpha[A][I][R][: N - k, 1 : k + 1] - BIL = alpha[B][I][L][ k:, N - k - 1 : N - 1] + # AIR = alpha[A][I][R][: N - k, 1 : k + 1] + # BIL = alpha[B][I][L][ k:, N - k - 1 : N - 1] - new = semiring.dot(ACL, BIL) - # alpha[A][C][L][ : N - k, k] = new - # alpha[B][C][L][ k:N, N - k - 1] = new + # new = semiring.dot(ACL, BIL) + # # alpha[A][C][L][ : N - k, k] = new + # # alpha[B][C][L][ k:N, N - k - 1] = new - new = semiring.dot(AIR, BCR) - # alpha[A][C][R][ : N - k, k] = new - # alpha[B][C][R][ k:N, N - k - 1] = new + # new = semiring.dot(AIR, BCR) + # # alpha[A][C][R][ : N - k, k] = new + # # alpha[B][C][R][ k:N, N - k - 1] = new final = alpha[A][C][R][(0,)] v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)