From 1b6f731f7d9c6442eae6a6483967fe3aa40f533c Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Mon, 9 Sep 2019 08:40:29 -0400 Subject: [PATCH 1/5] . --- torch_struct/cky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_struct/cky.py b/torch_struct/cky.py index 472a33df..9c742779 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -70,7 +70,7 @@ def _dp(self, scores, lengths=None, force_grad=False): for w in range(1, N): Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1) Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S) - Y, Z = Y.clone(), Z.clone() + # Y, Z = Y.clone(), Z.clone() X_Y_Z = rules.view(batch, 1, NT, S, S) rule_use[w - 1][:] = semiring.times( semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z From 606081ee362738480dd05211fd74c2542c8f65a1 Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Mon, 9 Sep 2019 08:42:51 -0400 Subject: [PATCH 2/5] . --- torch_struct/helpers.py | 3 +-- torch_struct/test_algorithms.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index d45efbb0..4113e342 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -22,8 +22,7 @@ def score(self, potentials, parts): def _make_chart(self, N, size, potentials, force_grad=False): return [ ( - torch.zeros(*size) - .type_as(potentials) + torch.zeros(*size, dtype=potentials.dtype, device=potentials.device) .fill_(self.semiring.zero()) .requires_grad_(force_grad and not potentials.requires_grad) ) diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index c435d561..fdb9f006 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -114,7 +114,7 @@ def test_generic_lengths(data, seed): @given(data(), integers(min_value=1, max_value=10)) def test_params(data, seed): model = data.draw( - sampled_from([DepTree]) + sampled_from([DepTree, CKY]) ) # LinearChain, SemiMarkov, DepTree, CKY])) struct = model() torch.manual_seed(seed) From 27ecf2a3072be4927fe4dd13d3b29e4ffcd7cab1 Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Mon, 9 Sep 2019 09:02:39 -0400 Subject: [PATCH 3/5] . --- torch_struct/cky.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/torch_struct/cky.py b/torch_struct/cky.py index 9c742779..7afab0e8 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -67,14 +67,40 @@ def _dp(self, scores, lengths=None, force_grad=False): term_use[:] = terms + 0.0 beta[A][:, :, 0, NT:] = term_use beta[B][:, :, N - 1, NT:] = term_use + X_Y_Z = rules.view(batch, 1, NT, S, S)[:, :, :, :NT, :NT] + X_Y_Z1 = rules.view(batch, 1, NT, S, S)[:, :, :, :NT, NT:] + X_Y1_Z = rules.view(batch, 1, NT, S, S)[:, :, :, NT:, :NT] + X_Y1_Z1 = rules.view(batch, 1, NT, S, S)[:, :, :, NT:, NT:] for w in range(1, N): - Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1) - Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S) - # Y, Z = Y.clone(), Z.clone() - X_Y_Z = rules.view(batch, 1, NT, S, S) - rule_use[w - 1][:] = semiring.times( + Y = beta[A][:, : N - w, :w, :NT].view(batch, N - w, w, 1, NT, 1) + Z = beta[B][:, w:, N - w :, :NT].view(batch, N - w, w, 1, 1, NT) + a = semiring.times( semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z ) + Y = beta[A][:, : N - w, :w, :NT].view(batch, N - w, w, 1, NT, 1) + Z = beta[B][:, w:, N - w :, NT:].view(batch, N - w, w, 1, 1, T) + b = semiring.times( + semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z1 + ) + + Y = beta[A][:, : N - w, :w, NT:].view(batch, N - w, w, 1, T, 1) + Z = beta[B][:, w:, N - w :, :NT].view(batch, N - w, w, 1, 1, NT) + c = semiring.times( + semiring.sum(semiring.times(Y, Z), dim=2), X_Y1_Z + ) + + Y = beta[A][:, : N - w, :w, NT:].view(batch, N - w, w, 1, T, 1) + Z = beta[B][:, w:, N - w :, NT:].view(batch, N - w, w, 1, 1, T) + d = semiring.times( + semiring.sum(semiring.times(Y, Z), dim=2), X_Y1_Z1 + ) + + # Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1) + # Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S) + + # Y, Z = Y.clone(), Z.clone() + + rule_use[w - 1][:] = semiring.sum(torch.stack([a, b, c, d]), dim=0) rulesmid = rule_use[w - 1].view(batch, N - w, NT, S * S) span[w] = semiring.sum(rulesmid, dim=3) beta[A][:, : N - w, w, :NT] = span[w] From 89bf0084b1c4047133639dc35fb975f76df668f6 Mon Sep 17 00:00:00 2001 From: srush Date: Mon, 9 Sep 2019 14:45:54 +0000 Subject: [PATCH 4/5] . --- torch_struct/cky.py | 35 ++---- torch_struct/deptree.py | 198 ++++++++++++++++---------------- torch_struct/helpers.py | 8 +- torch_struct/test_algorithms.py | 10 +- 4 files changed, 120 insertions(+), 131 deletions(-) diff --git a/torch_struct/cky.py b/torch_struct/cky.py index 7afab0e8..95d00d99 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -45,7 +45,7 @@ def sum(self, scores, lengths=None, force_grad=False, _autograd=True): if _autograd or self.semiring is not LogSemiring: return self._dp(scores, lengths)[0] else: - return DPManual2.apply(self, *scores, lengths) + return DPManual2.apply(self, *scores, lengths=lengths) def _dp(self, scores, lengths=None, force_grad=False): terms, rules, roots = scores @@ -74,33 +74,22 @@ def _dp(self, scores, lengths=None, force_grad=False): for w in range(1, N): Y = beta[A][:, : N - w, :w, :NT].view(batch, N - w, w, 1, NT, 1) Z = beta[B][:, w:, N - w :, :NT].view(batch, N - w, w, 1, 1, NT) - a = semiring.times( + rule_use[w - 1][:, :, :, :NT, :NT] = semiring.times( semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z ) - Y = beta[A][:, : N - w, :w, :NT].view(batch, N - w, w, 1, NT, 1) - Z = beta[B][:, w:, N - w :, NT:].view(batch, N - w, w, 1, 1, T) - b = semiring.times( - semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z1 - ) - - Y = beta[A][:, : N - w, :w, NT:].view(batch, N - w, w, 1, T, 1) - Z = beta[B][:, w:, N - w :, :NT].view(batch, N - w, w, 1, 1, NT) - c = semiring.times( - semiring.sum(semiring.times(Y, Z), dim=2), X_Y1_Z - ) - - Y = beta[A][:, : N - w, :w, NT:].view(batch, N - w, w, 1, T, 1) - Z = beta[B][:, w:, N - w :, NT:].view(batch, N - w, w, 1, 1, T) - d = semiring.times( - semiring.sum(semiring.times(Y, Z), dim=2), X_Y1_Z1 - ) + Y = beta[A][:, : N - w, w - 1, :NT].view(batch, N - w, 1, NT, 1) + Z = beta[B][:, w:, N - 1, NT:].view(batch, N - w, 1, 1, T) + rule_use[w - 1][:, :, :, :NT, NT:] = semiring.times(Y, Z, X_Y_Z1) - # Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1) - # Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S) + Y = beta[A][:, : N - w, 0, NT:].view(batch, N - w, 1, T, 1) + Z = beta[B][:, w:, N - w, :NT].view(batch, N - w, 1, 1, NT) + rule_use[w - 1][:, :, :, NT:, :NT] = semiring.times(Y, Z, X_Y1_Z) - # Y, Z = Y.clone(), Z.clone() + if w == 1: + Y = beta[A][:, : N - w, w - 1, NT:].view(batch, N - w, 1, T, 1) + Z = beta[B][:, w:, N - w, NT:].view(batch, N - w, 1, 1, T) + rule_use[w - 1][:, :, :, NT:, NT:] = semiring.times(Y, Z, X_Y1_Z1) - rule_use[w - 1][:] = semiring.sum(torch.stack([a, b, c, d]), dim=0) rulesmid = rule_use[w - 1].view(batch, N - w, NT, S * S) span[w] = semiring.sum(rulesmid, dim=3) beta[A][:, : N - w, w, :NT] = span[w] diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index f0987c3f..e796bb63 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -1,6 +1,6 @@ import torch import itertools -from .helpers import _Struct, roll2 +from .helpers import _Struct def _convert(logits): @@ -139,104 +139,104 @@ def _check_potentials(self, arc_scores, lengths=None): return batch, N, lengths - def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False): - - # This function is super complicated. - semiring = self.semiring - arc_scores = _convert(arc_scores) - batch, N, lengths = self._check_potentials(arc_scores, lengths) - DIRS = 2 - - alpha = [ - self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad) - for _ in range(2) - ] - - def stack(a, b, dim=-1): - return torch.stack([a, b], dim=dim) - - def sstack(a): - return torch.stack([a, a], dim=-1) - - for k in range(N - 1, -1, -1): - # Initialize - if N - k - 1 > 0: - # R completes - # I -> C* C - # I -> C* C - # C -> I C* - a = semiring.sum( - semiring.times( - *roll2( - stack( - stack(alpha[A][I][R], alpha[A][I][L]), - sstack(alpha_in[B][C][R]), - dim=0, - ), - stack( - sstack(alpha_in[A][C][L]), - stack(alpha[B][I][L], alpha[B][I][R]), - dim=0, - ), - N, - k, - 1, - ) - ).view(2, batch, N - k - 1, -1), - dim=-1, - ) - alpha[A][C][L, :, 1 : N - k, k] = a[1] - alpha[A][C][R, :, : N - k - 1, k] = a[0] - - for b, l in enumerate(lengths): - if l == k: - alpha[A][C][R, b, 0, l] = semiring.one() - alpha[B][C][R, b, l, N - l - 1] = semiring.one() - - c = semiring.sum( - semiring.times( - *roll2( - stack(alpha[A][C][L], alpha_in[B][I][R], dim=0), - stack(alpha_in[A][I][L], alpha[B][C][R], dim=0), - N, - k, - 0, - ) - ) - ) - alpha[A][C][:, :, : N - k, k] = semiring.plus( - alpha[A][C][:, :, : N - k, k], c - ) - - # Compute reverses. - alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k] - - if k > 0: - f = torch.arange(N - k), torch.arange(k, N) - alpha[A][I][:, :, : N - k, k] = semiring.dot( - stack( - arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0 - ).unsqueeze(-1), - *roll2( - stack(alpha_in[B][C][L], alpha[A][C][R], dim=0), - stack(alpha[B][C][L], alpha_in[A][C][R], dim=0), - N, - k, - ) - ) - alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] - - v = alpha[A][C][R, :, 0, 0] - left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :]) - right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :]) - ret = torch.zeros(batch, N, N, dtype=left.dtype) - for k in torch.arange(N): - f = torch.arange(N - k), torch.arange(k, N) - ret[:, f[1], k] = left[:, k, f[0]] - ret[:, k, f[1]] = right[:, k, f[0]] - - ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1)) - return _unconvert(ret) + # def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False): + + # # This function is super complicated and was just too slow to include + # semiring = self.semiring + # arc_scores = _convert(arc_scores) + # batch, N, lengths = self._check_potentials(arc_scores, lengths) + # DIRS = 2 + + # alpha = [ + # self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad) + # for _ in range(2) + # ] + + # def stack(a, b, dim=-1): + # return torch.stack([a, b], dim=dim) + + # def sstack(a): + # return torch.stack([a, a], dim=-1) + + # for k in range(N - 1, -1, -1): + # # Initialize + # if N - k - 1 > 0: + # # R completes + # # I -> C* C + # # I -> C* C + # # C -> I C* + # a = semiring.sum( + # semiring.times( + # *roll2( + # stack( + # stack(alpha[A][I][R], alpha[A][I][L]), + # sstack(alpha_in[B][C][R]), + # dim=0, + # ), + # stack( + # sstack(alpha_in[A][C][L]), + # stack(alpha[B][I][L], alpha[B][I][R]), + # dim=0, + # ), + # N, + # k, + # 1, + # ) + # ).view(2, batch, N - k - 1, -1), + # dim=-1, + # ) + # alpha[A][C][L, :, 1 : N - k, k] = a[1] + # alpha[A][C][R, :, : N - k - 1, k] = a[0] + + # for b, l in enumerate(lengths): + # if l == k: + # alpha[A][C][R, b, 0, l] = semiring.one() + # alpha[B][C][R, b, l, N - l - 1] = semiring.one() + + # c = semiring.sum( + # semiring.times( + # *roll2( + # stack(alpha[A][C][L], alpha_in[B][I][R], dim=0), + # stack(alpha_in[A][I][L], alpha[B][C][R], dim=0), + # N, + # k, + # 0, + # ) + # ) + # ) + # alpha[A][C][:, :, : N - k, k] = semiring.plus( + # alpha[A][C][:, :, : N - k, k], c + # ) + + # # Compute reverses. + # alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k] + + # if k > 0: + # f = torch.arange(N - k), torch.arange(k, N) + # alpha[A][I][:, :, : N - k, k] = semiring.dot( + # stack( + # arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0 + # ).unsqueeze(-1), + # *roll2( + # stack(alpha_in[B][C][L], alpha[A][C][R], dim=0), + # stack(alpha[B][C][L], alpha_in[A][C][R], dim=0), + # N, + # k, + # ) + # ) + # alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] + + # v = alpha[A][C][R, :, 0, 0] + # left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :]) + # right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :]) + # ret = torch.zeros(batch, N, N, dtype=left.dtype) + # for k in torch.arange(N): + # f = torch.arange(N - k), torch.arange(k, N) + # ret[:, f[1], k] = left[:, k, f[0]] + # ret[:, k, f[1]] = right[:, k, f[0]] + + # ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1)) + # return _unconvert(ret) def _arrange_marginals(self, grads): batch, N = grads[0][0].shape diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 4113e342..7d582097 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -3,12 +3,12 @@ from torch.autograd import Function -def roll(a, b, N, k, gap=0): - return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)]) +# def roll(a, b, N, k, gap=0): +# return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)]) -def roll2(a, b, N, k, gap=0): - return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)]) +# def roll2(a, b, N, k, gap=0): +# return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)]) class _Struct: diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index fdb9f006..5a8f1dca 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -59,11 +59,11 @@ def test_generic_a(data): print(alpha, count) assert torch.isclose(count[0], alpha[0]) - # vals, _ = model._rand() - # struct = model(MaxSemiring) - # score = struct.sum(vals) - # marginals = struct.marginals(vals) - # assert torch.isclose(score, struct.score(vals, marginals)).all() + vals, _ = model._rand() + struct = model(MaxSemiring) + score = struct.sum(vals) + marginals = struct.marginals(vals) + assert torch.isclose(score, struct.score(vals, marginals)).all() @given(data(), integers(min_value=1, max_value=10)) From a990699301c69114a99abc779c4d3f0bf1abbabc Mon Sep 17 00:00:00 2001 From: srush Date: Mon, 9 Sep 2019 14:53:01 +0000 Subject: [PATCH 5/5] . --- torch_struct/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 7d582097..a7110b27 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -75,12 +75,12 @@ def marginals(self, edge, lengths=None, _autograd=True): marginals: b x (N-1) x C x C table """ - v, edges, alpha = self._dp(edge, lengths=lengths, force_grad=True) if ( _autograd or self.semiring is not LogSemiring or not hasattr(self, "_dp_backward") ): + v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True) marg = torch.autograd.grad( v.sum(dim=0), edges, @@ -90,4 +90,5 @@ def marginals(self, edge, lengths=None, _autograd=True): ) return self._arrange_marginals(marg) else: + v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True) return self._dp_backward(edge, lengths, alpha)