diff --git a/torch_struct/cky.py b/torch_struct/cky.py
index 7356082d..472a33df 100644
--- a/torch_struct/cky.py
+++ b/torch_struct/cky.py
@@ -29,7 +29,7 @@ def backward(ctx, grad_v):
 
 
 class CKY(_Struct):
-    def sum(self, scores, lengths=None, force_grad=False, _autograd=False):
+    def sum(self, scores, lengths=None, force_grad=False, _autograd=True):
         """
         Compute the inside pass of a CFG using CKY.
 
@@ -162,7 +162,7 @@ def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):
 
         return (term_marginals, edge_marginals, root_marginals)
 
-    def marginals(self, scores, lengths=None, _autograd=False):
+    def marginals(self, scores, lengths=None, _autograd=True):
         """
         Compute the marginals of a CFG using CKY.
diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py
index d9db01a9..f6ff5818 100644
--- a/torch_struct/deptree.py
+++ b/torch_struct/deptree.py
@@ -1,6 +1,6 @@
 import torch
 import itertools
-from .helpers import _Struct, roll
+from .helpers import _Struct, roll2
 
 
 def _convert(logits):
@@ -48,16 +48,35 @@ def _dp(self, arc_scores, lengths=None, force_grad=False):
         DIRS = 2
 
+        alpha = [
+            self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
+            for _ in range(2)
+        ]
+        # Want to fix this slicing function.
+        # class MySlice(Function):
+        #     @staticmethod
+        #     def forward(ctx, alpha, beta, s1, s2, e, a, b, c, d):
+        #         indices = torch.tensor([s1, s2, e, a, b, c, d])
+        #         ctx.save_for_backward(indices)
+        #         return alpha[e, :, a:b, c:d]
+
+        #     @staticmethod
+        #     def backward(ctx, grad_v):
+        #         a, = ctx.saved_tensors
+        #         beta[a[0]][a[1]][a[2], :, a[3]:a[4], a[5]:a[6]] += grad_v
+        #         return None, None, None, None, None, None, None, None, None
+
+        # s = MySlice.apply
+        def s(input, e, a, b, c, d):
+            return input[e, :, a:b, c:d]
+
         def stack(a, b):
             return torch.stack([a, b])
 
         def sstack(a):
             return torch.stack([a, a])
 
-        alpha = [
-            self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
-            for _ in range(2)
-        ]
         arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad)
 
         # Inside step. assumes first token is root symbol
@@ -66,28 +85,32 @@ def sstack(a):
         for k in range(1, N):
             f = torch.arange(N - k), torch.arange(k, N)
-            arcs[k] = semiring.dot(
-                sstack(alpha[A][C][R, :, : N - k, :k]),
-                sstack(alpha[B][C][L, :, k:, N - k :]),
-                stack(arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]]).unsqueeze(
-                    -1
+            arcs[k] = semiring.times(
+                sstack(
+                    semiring.sum(
+                        semiring.times(
+                            s(alpha[A][C], R, 0, N - k, 0, k),
+                            s(alpha[B][C], L, k, N, N - k, N),
+                        )
+                    )
                 ),
+                stack(arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]]),
             )
             alpha[A][I][:, :, : N - k, k] = arcs[k]
             alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]
             alpha[A][C][:, :, : N - k, k] = semiring.dot(
                 stack(
-                    alpha[A][C][L, :, : N - k, :k],
-                    alpha[A][I][R, :, : N - k, 1 : k + 1],
+                    s(alpha[A][C], L, 0, N - k, 0, k),
+                    s(alpha[A][I], R, 0, N - k, 1, k + 1),
                 ),
                 stack(
-                    alpha[B][I][L, :, k:, N - k - 1 : N - 1],
-                    alpha[B][C][R, :, k:, N - k :],
+                    s(alpha[B][I], L, k, N, N - k - 1, N - 1),
+                    s(alpha[B][C], R, k, N, N - k, N),
                 ),
             )
             alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]
+
         v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
-        print(v)
         return (v, arcs[1:], alpha)
 
     def _check_potentials(self, arc_scores, lengths=None):
@@ -116,63 +139,60 @@ def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False):
             for _ in range(2)
         ]
 
-        def stack(a, b):
-            return torch.stack([a, b], dim=-1)
+        def stack(a, b, dim=-1):
+            return torch.stack([a, b], dim=dim)
 
         def sstack(a):
             return torch.stack([a, a], dim=-1)
 
         for k in range(N - 1, -1, -1):
             # Initialize
-            for b, l in enumerate(lengths):
-                alpha[A][C][R, b, 0, l] = semiring.one()
-                alpha[B][C][R, b, l, N - l - 1] = semiring.one()
-
-            # R completes
-            # I -> C* C
-            # I -> C* C
-            # C -> I C*
-            a = semiring.dot(
-                *roll(
-                    stack(alpha[A][I][R], alpha[A][I][L]),
-                    sstack(alpha_in[A][C][L]),
-                    N,
-                    k,
-                    1,
+            if N - k - 1 > 0:
+                # R completes
+                # I -> C* C
+                # I -> C* C
+                # C -> I C*
+                a = semiring.sum(
+                    semiring.times(
+                        *roll2(
+                            stack(
+                                stack(alpha[A][I][R], alpha[A][I][L]),
+                                sstack(alpha_in[B][C][R]),
+                                dim=0,
+                            ),
+                            stack(
+                                sstack(alpha_in[A][C][L]),
+                                stack(alpha[B][I][L], alpha[B][I][R]),
+                                dim=0,
+                            ),
+                            N,
+                            k,
+                            1,
+                        )
+                    ).view(2, batch, N - k - 1, -1),
+                    dim=-1,
                 )
-            )
-
-            c = semiring.dot(*roll(alpha_in[B][I][R], alpha[B][C][R], N, k, 0))
-
-            alpha[A][C][R, :, : N - k - 1, k] = semiring.plus(
-                semiring.sum(a), alpha[A][C][R, :, : N - k - 1, k]
-            )
+                alpha[A][C][L, :, 1 : N - k, k] = a[1]
+                alpha[A][C][R, :, : N - k - 1, k] = a[0]
 
-            alpha[A][C][R][:, : N - k, k] = semiring.plus(
-                alpha[A][C][R][:, : N - k, k], c
-            )
-
-            # L completes
-            # I -> C* C
-            # I -> C* C
-            # C -> I C*
-            a = semiring.dot(
-                *roll(
-                    sstack(alpha_in[B][C][R]),
-                    stack(alpha[B][I][L], alpha[B][I][R]),
-                    N,
-                    k,
-                    1,
+            for b, l in enumerate(lengths):
+                if l == k:
+                    alpha[A][C][R, b, 0, l] = semiring.one()
+                    alpha[B][C][R, b, l, N - l - 1] = semiring.one()
+
+            c = semiring.sum(
+                semiring.times(
+                    *roll2(
+                        stack(alpha[A][C][L], alpha_in[B][I][R], dim=0),
+                        stack(alpha_in[A][I][L], alpha[B][C][R], dim=0),
+                        N,
+                        k,
+                        0,
+                    )
                 )
             )
-
-            c = semiring.dot(*roll(alpha[A][C][L], alpha_in[A][I][L], N, k, 0))
-
-            alpha[A][C][L, :, 1 : N - k, k] = semiring.plus(
-                semiring.sum(a), alpha[A][C][L, :, 1 : N - k, k]
-            )
-            alpha[A][C][L][:, : N - k, k] = semiring.plus(
-                c, alpha[A][C][L][:, : N - k, k]
+            alpha[A][C][:, :, : N - k, k] = semiring.plus(
+                alpha[A][C][:, :, : N - k, k], c
             )
 
             # Compute reverses.
@@ -180,35 +200,29 @@ def sstack(a):
 
             if k > 0:
                 f = torch.arange(N - k), torch.arange(k, N)
-
-                # Incomplete
-                alpha[A][I][R][:, : N - k, k] = semiring.dot(
-                    arc_scores[:, f[0], f[1]].unsqueeze(-1),
-                    *roll(alpha[A][C][R], alpha_in[A][C][R], N, k)
+                alpha[A][I][:, :, : N - k, k] = semiring.dot(
+                    stack(
+                        arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0
+                    ).unsqueeze(-1),
+                    *roll2(
+                        stack(alpha_in[B][C][L], alpha[A][C][R], dim=0),
+                        stack(alpha[B][C][L], alpha_in[A][C][R], dim=0),
+                        N,
+                        k,
+                    )
                 )
-
-                # C -> C I
-                alpha[A][I][L][:, : N - k, k] = semiring.dot(
-                    arc_scores[:, f[1], f[0]].unsqueeze(-1),
-                    *roll(alpha_in[B][C][L], alpha[B][C][L], N, k)
-                )
-
-                # Compute reverses
                 alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]
 
         v = alpha[A][C][R, :, 0, 0]
         left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
         right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
-        ret = torch.zeros(batch, N, N).type_as(left)
-        for k in range(N):
-            for d in range(N - k):
-                ret[:, k + d, k] = semiring.div_exp(
-                    left[:, k, d] - arc_scores[:, k + d, k], v.view(batch)
-                )
-                ret[:, k, k + d] = semiring.div_exp(
-                    right[:, k, d] - arc_scores[:, k, k + d], v.view(batch)
-                )
+        ret = torch.zeros(batch, N, N).type_as(left)
+        for k in torch.arange(N):
+            f = torch.arange(N - k), torch.arange(k, N)
+            ret[:, f[1], k] = left[:, k, f[0]]
+            ret[:, k, f[1]] = right[:, k, f[0]]
+
+        ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
         return _unconvert(ret)
 
     def _arrange_marginals(self, grads):
diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py
index 77394f19..9d5cf58f 100644
--- a/torch_struct/helpers.py
+++ b/torch_struct/helpers.py
@@ -7,29 +7,10 @@ def roll(a, b, N, k, gap=0):
     return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)])
 
 
-class DPManual(Function):
-    @staticmethod
-    def forward(ctx, obj, input, lengths):
-        with torch.no_grad():
-            v, _, alpha = obj._dp(input, lengths, False)
-        ctx.obj = obj
-        ctx.lengths = lengths
-        ctx.alpha = alpha
-
-        if isinstance(input, tuple):
-            ctx.save_for_backward(*input)
-        else:
-            ctx.save_for_backward(input)
-        return v
+def roll2(a, b, N, k, gap=0):
+    return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)])
+
 
-    @staticmethod
-    def backward(ctx, grad_v):
-        input = ctx.saved_tensors
-        if len(input) == 1:
-            input = input[0]
-        with torch.no_grad():
-            marginals = ctx.obj._dp_backward(input, ctx.lengths, ctx.alpha)
-        return None, marginals, None
 
 
 class _Struct:
@@ -61,8 +42,9 @@ def sum(self, edge, lengths=None, _autograd=True):
 
         Returns:
             v: b tensor of total sum
-        """
+        """
+
         if (
             _autograd
             or self.semiring is not LogSemiring
@@ -70,8 +52,20 @@ def sum(self, edge, lengths=None, _autograd=True):
         ):
             return self._dp(edge, lengths)[0]
         else:
-            return DPManual.apply(self, edge, lengths)
+            v, _, alpha = self._dp(edge, lengths, False)
+
+            class DPManual(Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return v
+
+                @staticmethod
+                def backward(ctx, grad_v):
+                    marginals = self._dp_backward(edge, lengths, alpha)
+                    return marginals.mul(
+                        grad_v.view((grad_v.shape[0],) + tuple([1] * (marginals.dim() - 1))))
+            return DPManual.apply(edge)
 
     def marginals(self, edge, lengths=None, _autograd=True):
         """
         Compute the marginals of a structured model.
@@ -83,7 +77,7 @@ def marginals(self, edge, lengths=None, _autograd=True):
 
             marginals: b x (N-1) x C x C table
 
         """
-        v, edge, alpha = self._dp(edge, lengths=lengths, force_grad=True)
+        v, edges, alpha = self._dp(edge, lengths=lengths, force_grad=True)
         if (
             _autograd
             or self.semiring is not LogSemiring
@@ -91,7 +85,7 @@ def marginals(self, edge, lengths=None, _autograd=True):
         ):
             marg = torch.autograd.grad(
                 v.sum(dim=0),
-                edge,
+                edges,
                 create_graph=True,
                 only_inputs=True,
                 allow_unused=False,
diff --git a/torch_struct/semirings.py b/torch_struct/semirings.py
index a3f6bd10..65e5796c 100644
--- a/torch_struct/semirings.py
+++ b/torch_struct/semirings.py
@@ -57,7 +57,7 @@ def one():
 
     @staticmethod
     def div_exp(a, b):
-        return a.exp().div(b.exp())
+        return (a - b).exp()
 
 
 class LogSemiring(_BaseLog):
diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py
index ccdbb8ed..43ed579e 100644
--- a/torch_struct/test_algorithms.py
+++ b/torch_struct/test_algorithms.py
@@ -42,9 +42,7 @@ def test_fb(data):
 
     if isinstance(marginals, tuple):
         for i, (m1, m2) in enumerate(zip(marginals[:], marginals2[:])):
-            assert torch.isclose(m1, m2).all(), (
-                not torch.isclose(m1, m2)
-            ).nonzero()
+            assert torch.isclose(m1, m2).all(), (~torch.isclose(m1, m2)).nonzero()
     else:
         assert torch.isclose(marginals, marginals2).all()
 
@@ -115,12 +113,12 @@ def test_generic_lengths(data, seed):
 
 @given(data(), integers(min_value=1, max_value=10))
 def test_params(data, seed):
-    model = data.draw(sampled_from([LinearChain, SemiMarkov, DepTree, CKY]))
+    model = data.draw(sampled_from([DepTree]))  # LinearChain, SemiMarkov, CKY
     struct = model()
     torch.manual_seed(seed)
     vals, (batch, N) = struct._rand()
     if isinstance(vals, tuple):
-        vals = (v.requires_grad_(True) for v in vals)
+        vals = tuple(v.requires_grad_(True) for v in vals)
     else:
         vals.requires_grad_(True)
     # torch.autograd.set_detect_anomaly(True)
@@ -128,6 +126,14 @@ def test_params(data, seed):
     alpha = model(semiring).sum(vals)
     alpha.sum().backward()
 
+    if not isinstance(vals, tuple):
+        b = vals.grad.detach()
+        vals.grad.zero_()
+        alpha = model(semiring).sum(vals, _autograd=False)
+        alpha.sum().backward()
+        c = vals.grad.detach()
+        assert torch.isclose(b, c).all()
+
 
 def test_hmm():
     C, V, batch, N = 5, 20, 2, 5
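Note on the inlined DPManual pattern in helpers.py above: sum(..., _autograd=False) runs the forward DP eagerly under no_grad, then wraps the result in a throwaway torch.autograd.Function whose backward dispatches to the hand-written _dp_backward, broadcasting the incoming gradient over the non-batch dimensions. A minimal self-contained sketch of the same trick follows; the toy f/df functions are illustrative stand-ins, not part of the torch_struct API:

    import torch
    from torch.autograd import Function


    def manual_grad(x, f, df):
        # Compute the value and its hand-written gradient outside the graph.
        with torch.no_grad():
            v, g = f(x), df(x)

        class Manual(Function):
            @staticmethod
            def forward(ctx, inp):
                # inp only registers the dependency; v is precomputed.
                return v

            @staticmethod
            def backward(ctx, grad_v):
                # Broadcast grad_v (shape: batch) over g's non-batch dims,
                # the same reshape DPManual.backward performs above.
                return g.mul(grad_v.view((grad_v.shape[0],) + (1,) * (g.dim() - 1)))

        return Manual.apply(x)


    # Mirror the new test_params check: the manually wired gradient
    # should agree with what autograd computes through f directly.
    x = torch.randn(3, 4, requires_grad=True)
    f = lambda t: t.logsumexp(dim=-1)        # toy "inside" score per batch item
    df = lambda t: torch.softmax(t, dim=-1)  # its known closed-form gradient

    manual_grad(x, f, df).sum().backward()
    auto = torch.autograd.grad(f(x).sum(), x)[0]
    assert torch.isclose(x.grad, auto).all()

Because v and g are computed under no_grad, no graph is ever built for the forward DP, which is the point of the manual backward path.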