From 8fa0fce704df55680caa98e3c348ae120dde2315 Mon Sep 17 00:00:00 2001 From: Sasha Date: Sat, 7 Sep 2019 14:40:47 -0400 Subject: [PATCH 1/4] outside --- torch_struct/cky.py | 151 ++++++++++++++++++++++----- torch_struct/deptree.py | 178 ++++++++++++++++++++++++++++---- torch_struct/helpers.py | 23 +++++ torch_struct/linearchain.py | 65 ++++++++++-- torch_struct/semirings.py | 10 ++ torch_struct/test_algorithms.py | 32 +++++- 6 files changed, 400 insertions(+), 59 deletions(-) diff --git a/torch_struct/cky.py b/torch_struct/cky.py index 754b4398..885acf16 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -1,11 +1,31 @@ import torch -from .helpers import _Struct +from .helpers import _Struct, DPManual +from .semirings import LogSemiring +from torch.autograd import Function A, B = 0, 1 +class DPManual2(Function): + @staticmethod + def forward(ctx, obj, terms, rules, roots, lengths): + v, _, alpha = obj._dp((terms, rules, roots), lengths, False) + ctx.obj = obj + ctx.lengths = lengths + ctx.alpha = alpha + ctx.v = v + ctx.save_for_backward(terms, rules, roots) + return v + + @staticmethod + def backward(ctx, grad_v): + terms, rules, roots = ctx.saved_tensors + marginals = ctx.obj._dp_backward((terms, rules, roots), ctx.lengths, ctx.alphactx.v) + return None, marginals, None + + class CKY(_Struct): - def sum(self, scores, lengths=None, force_grad=False): + def sum(self, scores, lengths=None, force_grad=False, _autograd=False): """ Compute the inside pass of a CFG using CKY. @@ -18,7 +38,10 @@ def sum(self, scores, lengths=None, force_grad=False): v: b tensor of total sum spans: list of N, b x N x (NT+t) """ - return self._dp(scores, lengths)[0] + if _autograd or not self.semiring is LogSemiring: + return self._dp(scores, lengths)[0] + else: + return DPManual2.apply(self, *scores, lengths) def _dp(self, scores, lengths=None, force_grad=False): terms, rules, roots = scores @@ -40,7 +63,6 @@ def _dp(self, scores, lengths=None, force_grad=False): term_use[:] = terms + 0.0 beta[A][:, :, 0, NT:] = term_use beta[B][:, :, N - 1, NT:] = term_use - for w in range(1, N): Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1) Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S) @@ -56,9 +78,82 @@ def _dp(self, scores, lengths=None, force_grad=False): top[:] = torch.stack([beta[A][i, 0, l - 1, :NT] for i, l in enumerate(lengths)]) log_Z = semiring.dot(top, roots) - return log_Z, (term_use, rule_use, top) + return log_Z, (term_use, rule_use, top), beta + + def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False): + terms, rules, roots = scores + semiring = self.semiring + batch, N, T = terms.shape + _, NT, _, _ = rules.shape + S = NT + T + if lengths is None: + lengths = torch.LongTensor([N] * batch) + + beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad) + span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad) + span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad) + top = self._make_chart(1, (batch, NT), rules, force_grad)[0] + term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0] + - def marginals(self, scores, lengths=None): + ssum = semiring.sum + st = semiring.times + X_Y_Z = rules.view(batch, 1, NT, S, S) + + for w in range(N-1, -1, -1): + for b, l in enumerate(lengths): + beta[A][b, 0, l-1, :NT] = roots[b] + beta[B][b, l-1, N-(l), :NT] = roots[b] + + # LEFT + # all bigger on the left. 
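+ # Outside update, left-child case: X holds the outside scores of the
+ # parent spans (beta) and Z the inside scores of the right sibling
+ # (alpha_in). After applying the rule potentials X_Y_Z, summing out the
+ # parent symbol (dim -3) and the sibling (dim -1) leaves the left
+ # child's outside score in span_l[w].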
+ X = beta[A][:, :N-w-1, w+1:, :NT].view(batch, N-w-1, N-w-1, NT, 1, 1) + Z = alpha_in[A][:, w+1:N, 0:N-w-1].view(batch, N-w-1, N-w-1, 1, 1, S) + t = st(ssum(st(X, Z), dim=2), X_Y_Z) + # sum out x and y + span_l[w] = ssum(ssum(t, dim =-3), dim=-1) + + # RIGHT + X = beta[B][:, w+1:, :N-1-w, :NT].view(batch, N-w-1, N-w-1, NT, 1, 1) + Y = alpha_in[B][:, :N-w-1, w+1:, :].view(batch, N-w-1, N-w-1, 1, S, 1) + t = st(ssum(st(X, Y), dim=2), X_Y_Z) + + span_r[w] = ssum(ssum(t, dim=-3), dim=-2) + + beta[A][:, :N-w-1, w, :] = span_l[w] + beta[A][:, 1:N-w, w, :] = ssum(torch.stack([span_r[w], + beta[A][:, 1:N-w, w, :]]), dim=0) + beta[B][:, w:, N-w-1, :] = beta[A][:, :N - w, w, :] + + + term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms) + term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0] + for n in range(N): + term_marginals[:, n] = semiring.div_exp(term_use[:, n], + v.view(batch, 1)) + + root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0] + for b in range(batch): + root_marginals[b] = semiring.div_exp(st(alpha_in[A][b, 0, lengths[b]-1, :NT], roots[b]), + v[b].view(1)) + edge_marginals = self._make_chart(1, (batch, N, N, NT, S, S), terms, force_grad=False)[0] + edge_marginals.fill_(0) + for w in range(1, N): + Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1) + Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S) + score = semiring.times( + semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z + ) + score = st(score, beta[A][:, :N-w, w, :NT].view(batch, N-w, NT, 1, 1)) + edge_marginals[:, :N-w, w-1] = semiring.div_exp(score, + v.view(batch, 1, 1, 1, 1)) + edge_marginals = edge_marginals.transpose(1, 2) + + + return (term_marginals, edge_marginals, root_marginals) + + + def marginals(self, scores, lengths=None, _autograd=False): """ Compute the marginals of a CFG using CKY. 
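The `DPManual2` class above (and `DPManual` in helpers.py) follows the standard `torch.autograd.Function` recipe for swapping in a hand-written backward: run the inside pass in `forward`, stash the charts on `ctx`, and return hand-computed marginals as the gradient. (Note the call in patch 1's `backward` is missing a comma, `ctx.alphactx.v`; patch 2 corrects it to `ctx.alpha, ctx.v`.) Below is a minimal, self-contained sketch of the same pattern on a toy log-partition; the names are illustrative and not part of this patch:

```python
import torch
from torch.autograd import Function

class ManualLogSumExp(Function):
    """Toy analogue of DPManual2 (illustrative name, not in this patch)."""

    @staticmethod
    def forward(ctx, x):
        # Run the "inside" computation off the tape and stash what the
        # hand-written backward will need.
        with torch.no_grad():
            v = torch.logsumexp(x, dim=-1)
        ctx.save_for_backward(x, v)
        return v

    @staticmethod
    def backward(ctx, grad_v):
        # Return hand-computed marginals as the gradient: for logsumexp
        # the marginals are exactly a softmax.
        x, v = ctx.saved_tensors
        return grad_v.unsqueeze(-1) * (x - v.unsqueeze(-1)).exp()

x = torch.randn(3, 5, requires_grad=True)
ManualLogSumExp.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.softmax(x, dim=-1))
```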
@@ -76,23 +171,26 @@ def marginals(self, scores, lengths=None): batch, N, T = terms.shape _, NT, _, _ = rules.shape S = NT + T - v, (term_use, rule_use, top) = self._dp( + v, (term_use, rule_use, top), alpha = self._dp( scores, lengths=lengths, force_grad=True ) - marg = torch.autograd.grad( - v.sum(dim=0), - tuple(rule_use) + (top, term_use), - create_graph=True, - only_inputs=True, - allow_unused=False, - ) - rule_use = marg[:-2] - rules = torch.zeros(batch, N, N, NT, S, S) - for w in range(len(rule_use)): - rules[:, w, : N - w - 1] = rule_use[w] - assert marg[-1].shape == (batch, N, T) - assert marg[-2].shape == (batch, NT) - return (marg[-1], rules, marg[-2]) + if _autograd or not self.semiring is LogSemiring: + marg = torch.autograd.grad( + v.sum(dim=0), + tuple(rule_use) + (top, term_use), + create_graph=True, + only_inputs=True, + allow_unused=False, + ) + rule_use = marg[:-2] + rules = torch.zeros(batch, N, N, NT, S, S) + for w in range(len(rule_use)): + rules[:, w, : N - w - 1] = rule_use[w] + assert marg[-1].shape == (batch, N, T) + assert marg[-2].shape == (batch, NT) + return (marg[-1], rules, marg[-2]) + else: + return self._dp_backward(edge, lengths, alpha, v) @staticmethod def to_parts(spans, extra, lengths=None): @@ -141,7 +239,6 @@ def from_parts(chart): :, n, torch.arange(N - n - 1) ] spans[:, torch.arange(N), torch.arange(N), NT:] = terms - print(rules.nonzero(), spans.nonzero()) return spans, (NT, S - NT) ###### Test @@ -168,8 +265,6 @@ def enumerate(x, start, end): [(x, start, w, end)] + y1 + z1, ) - # for nt in range(NT): - # print(list(enumerate(nt, 0, N))) ls = [] for nt in range(NT): ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)] @@ -177,10 +272,10 @@ def enumerate(x, start, end): @staticmethod def _rand(): - batch = torch.randint(2, 4, (1,)) - N = torch.randint(2, 4, (1,)) - NT = torch.randint(2, 4, (1,)) - T = torch.randint(2, 4, (1,)) + batch = torch.randint(2, 5, (1,)) + N = torch.randint(2, 5, (1,)) + NT = torch.randint(2, 5, (1,)) + T = torch.randint(2, 5, (1,)) terms = torch.rand(batch, N, T) rules = torch.rand(batch, NT, (NT + T), (NT + T)) roots = torch.rand(batch, NT) diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 6c46ec7b..0e03491e 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -1,7 +1,7 @@ import torch import itertools -from .helpers import _Struct - +from .helpers import _Struct, DPManual +from .semirings import LogSemiring def _convert(logits): "move root arcs from diagonal" @@ -34,7 +34,7 @@ def _unconvert(logits): class DepTree(_Struct): - def sum(self, arc_scores, lengths=None): + def sum(self, arc_scores, lengths=None, _autograd=False): """ Compute the inside pass of a projective dependency CRF. 
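All of these dynamic programs are written against the semiring interface from semirings.py: `mul`/`sum` lifted pointwise, `times` folding `mul` over its arguments, `dot` a `sum` of `times`, plus the `div_exp` ratio added in this patch series. A compact sketch of the log-space instance, assuming the usual logsumexp definition of `LogSemiring.sum`:

```python
import torch

class LogSemiring:
    """Log-space semiring: mul is +, sum is logsumexp (a sketch of the
    interface in semirings.py, assuming the standard definitions)."""

    @staticmethod
    def mul(a, b):
        return a + b

    @staticmethod
    def sum(xs, dim=-1):
        return torch.logsumexp(xs, dim=dim)

    @classmethod
    def times(cls, *ls):
        cur = ls[0]
        for l in ls[1:]:
            cur = cls.mul(cur, l)
        return cur

    @classmethod
    def dot(cls, *ls):
        return cls.sum(cls.times(*ls))

    @staticmethod
    def div_exp(a, b):
        # Ratio of the underlying (exponentiated) scores, as added here.
        return a.exp().div(b.exp())

a, b = torch.randn(4), torch.randn(4)
# A log-space dot equals the log of the inner product of the exponentials.
assert torch.isclose(LogSemiring.dot(a, b), (a.exp() * b.exp()).sum().log())
```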
@@ -48,7 +48,10 @@ def sum(self, arc_scores, lengths=None): arcs: list of N, LR x b x N table """ - return self._dp(arc_scores, lengths)[0] + if _autograd or not self.semiring is LogSemiring: + return self._dp(arc_scores, lengths)[0] + else: + return DPManual.apply(self, *scores, lengths) def _dp(self, arc_scores, lengths=None, force_grad=False): semiring = self.semiring @@ -105,9 +108,142 @@ def sstack(a): return ( torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)]), arcs, + alpha ) - def marginals(self, arc_scores, lengths=None): + def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False): + semiring = self.semiring + arc_scores = _convert(arc_scores) + batch, N, N2 = arc_scores.shape + + assert N == N2, "Non-square potentials" + DIRS = 2 + if lengths is None: + lengths = torch.LongTensor([N - 1] * batch) + assert max(lengths) <= N, "Length longer than N" + for b in range(batch): + arc_scores[b, lengths[b] + 1 :, :] = semiring.zero() + arc_scores[b, :, lengths[b] + 1 :] = semiring.zero() + + def stack(*a): + return torch.stack(a) + + def sstack(a): + return torch.stack([a, a]) + + alpha = [ + self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad) + for _ in range(2) + ] + arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad) + + + for k in range(N-1, -1, -1): + for b, l in enumerate(lengths): + alpha[A][C][R, b, 0, l] = semiring.one() + alpha[B][C][R, b, l, N-l-1] = semiring.one() + + a = semiring.dot( + alpha[A][I][R, :, :N - (k+1), (k+1):], + alpha_in[A][C][L,:, k+1:, : N-(k+1)], + arc_scores[:, :N-(k+1), k+1:] + ) + + b = semiring.dot( + alpha[A][I][L, :, :N - (k+1), k+1:], + alpha_in[A][C][L, :, k+1:, : N-(k+1)], + arc_scores.transpose(1,2)[:, :N-(k+1), k+1:] + ) + + + + alpha[A][C][R, :, :N-k-1, k] = semiring.sum(stack(a, b, + alpha[A][C][R, :, :N-k-1, k]), dim=0) + + a = semiring.dot( + alpha[B][I][L, :, k+1:, :N - k-1], + alpha_in[B][C][R, :, :N-k-1, k+1:], + arc_scores.transpose(1,2)[:, 1:N-(k), k+1:] + + ) + b = semiring.dot( + alpha[B][I][R, :, k+1:, :N-k-1], + alpha_in[B][C][R, :, :N-k-1, k+1:], + arc_scores[:, 1:N-(k), k+1:] + ) + + alpha[A][C][L, :, 1:N-k, k] = \ + semiring.sum(stack(a, b, + alpha[A][C][L, :, 1:N-k, k]), dim=0) + + + alpha[B][C][L, :, 1:N-k, N - k - 1] = alpha[A][C][L, :, 1: N - k, k] + alpha[B][C][R, :, :N-k-1, N - k - 1] = alpha[A][C][R, :, : N - k-1, k] + + print("C", k, alpha[A][C][:, 0, :, k].nonzero()) + + + + alpha[A][I][R][:, :N-k, k] = semiring.dot( + alpha[A][C][R, :, :N-k, k:], + alpha_in[A][C][R, :, k:, :N-k] + ) + + alpha[A][I][L][:, :N-k, k] = semiring.dot( + alpha[B][C][L, :, k:, :N - (k)], + alpha_in[B][C][L, :, :N-k, k:] + ) + + alpha[A][C][R][:, :N-k, k] = semiring.sum(stack(alpha[A][C][R][:, :N-k, k], + + semiring.dot( + alpha[B][C][R, :, k:, :N - k], + alpha_in[B][I][R, :, :N-k, k:] + )), dim =0) + + alpha[A][C][L][:, :N-k, k] =semiring.sum(stack( + alpha[A][C][L][:, :N-k, k], + semiring.dot( + alpha[A][C][L, :, :N - k, k:], + alpha_in[A][I][L, :, k:, : N-k] + )), dim=0) + + + alpha[B][C][:, :, :N-k, N - k - 1] = alpha[A][C][:, :, : N - k, k] + + alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] + print("IC", k, alpha[A][C][:, 0, :, k].nonzero()) + print("I", k, alpha[A][I][:, 0, :, k].nonzero()) + + + print("COMP", alpha[A][C][:, :, :, :].nonzero()) + v = alpha[A][C][R, 0, 1, 0] + print("finsh", v) + + + left = semiring.div_exp(semiring.times(alpha[A][I][L, :, :, :], + alpha_in[A][I][L, :, :, :]), + v.view(batch, 1, 1)) + + right = 
semiring.div_exp(semiring.times(alpha[A][I][R, :, :, :], + alpha_in[A][I][R, :, :, :]), + v.view(batch, 1, 1)) + ret = torch.zeros(batch, N, N) + print(left, right, alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :]) + for k in range(N): + for d in range(N-k): + ret[:, k+d, k] = left[:, k, d] + ret[:, k, k+d] = right[:, k, d] + return _unconvert(ret) + + + # return ( + # torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)]), + # arcs, + # ) + + + def marginals(self, arc_scores, lengths=None, _autograd=False): """ Compute the marginals of a projective dependency CRF. @@ -121,20 +257,24 @@ def marginals(self, arc_scores, lengths=None): """ batch, N, _ = arc_scores.shape N = N + 1 - v, arcs = self._dp(arc_scores, lengths, force_grad=True) - grads = torch.autograd.grad( - v.sum(dim=0), - arcs[1:], - create_graph=True, - only_inputs=True, - allow_unused=False, - ) - ret = torch.zeros(batch, N, N).cpu() - for k, grad in enumerate(grads, 1): - f = torch.arange(N - k), torch.arange(k, N) - ret[:, f[0], f[1]] = grad[R].cpu() - ret[:, f[1], f[0]] = grad[L].cpu() - return _unconvert(ret) + v, arcs, alpha = self._dp(arc_scores, lengths, force_grad=True) + + if _autograd or not self.semiring is LogSemiring: + grads = torch.autograd.grad( + v.sum(dim=0), + arcs[1:], + create_graph=True, + only_inputs=True, + allow_unused=False, + ) + ret = torch.zeros(batch, N, N).cpu() + for k, grad in enumerate(grads, 1): + f = torch.arange(N - k), torch.arange(k, N) + ret[:, f[0], f[1]] = grad[R].cpu() + ret[:, f[1], f[0]] = grad[L].cpu() + return _unconvert(ret) + else: + return self._dp_backward(arc_scores, lengths, alpha) @staticmethod def to_parts(sequence, extra=None, lengths=None): diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index aab5ac90..f0095a6d 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -1,5 +1,28 @@ import torch from .semirings import LogSemiring +from torch.autograd import Function + +class DPManual(Function): + @staticmethod + def forward(ctx, obj, input, lengths): + v, _, alpha = obj._dp(input, lengths, False) + ctx.obj = obj + ctx.lengths = lengths + ctx.alpha = alpha + + if isinstance(input, tuple): + ctx.save_for_backward(*input) + else: + ctx.save_for_backward(input) + return v + + @staticmethod + def backward(ctx, grad_v): + input = ctx.saved_tensors + if len(input) == 1: + input = input[0] + marginals = ctx.obj._dp_backward(input, ctx.lengths, ctx.alpha) + return None, marginals, None class _Struct: diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index cf263084..69d1ce6b 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -1,5 +1,8 @@ import torch -from .helpers import _Struct +from .helpers import _Struct, DPManual +from torch.autograd import Function +from .semirings import LogSemiring + class LinearChain(_Struct): @@ -14,7 +17,6 @@ def _dp(self, edge, lengths=None, force_grad=False): assert C == C2, "Transition shape doesn't match" alpha = self._make_chart(N, (batch, C), edge, force_grad=force_grad) - edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad=force_grad) alpha[0].data.fill_(semiring.one()) @@ -26,9 +28,43 @@ def _dp(self, edge, lengths=None, force_grad=False): v = semiring.sum( torch.stack([alpha[l - 1][i] for i, l in enumerate(lengths)]), dim=-1 ) - return v, edge_store + return v, edge_store, alpha - def sum(self, edge, lengths=None): + def _dp_backward(self, edge, lengths, alpha_in): + semiring = self.semiring + batch, N_1, C, C2 = edge.shape + N = N_1 + 
1 + if lengths is None: + lengths = torch.LongTensor([N] * batch) + + alpha = self._make_chart(N, (batch, C), edge, force_grad=False) + edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad=False) + + for n in range(N-1, 0, -1): + for b, l in enumerate(lengths): + alpha[l-1][b].data.fill_(semiring.one()) + + edge_store[n-1][:] = semiring.times( + alpha[n].view(batch, C, 1), edge[:, n-1] + ) + alpha[n - 1][:] = semiring.sum(edge_store[n-1], dim=-2) + v = semiring.sum( + torch.stack([alpha[0][i] for i, l in enumerate(lengths)]), dim=-1 + ) + edge_marginals = self._make_chart(1, (batch, N-1, C, C), edge, force_grad=False)[0] + + + + for n in range(N_1): + edge_marginals[:, n] = semiring.div_exp(semiring.times(alpha_in[n].view(batch, 1, C), + edge[:, n], + alpha[n+1].view(batch, C, 1)), + v.view(batch, 1, 1)) + + return edge_marginals + + + def sum(self, edge, lengths=None, _autograd=False): """ Compute the forward pass of a linear chain CRF. @@ -42,9 +78,12 @@ def sum(self, edge, lengths=None): inside: list of N, b x C x C table """ - return self._dp(edge, lengths)[0] + if _autograd or not self.semiring is LogSemiring: + return self._dp(edge, lengths)[0] + else: + return DPManual.apply(self, edge, lengths) - def marginals(self, edge, lengths=None): + def marginals(self, edge, lengths=None, _autograd=False): """ Compute the marginals of a linear chain CRF. @@ -56,11 +95,15 @@ def marginals(self, edge, lengths=None): marginals: b x (N-1) x C x C table """ - v, alpha = self._dp(edge, lengths=lengths, force_grad=True) - marg = torch.autograd.grad( - v.sum(dim=0), alpha, create_graph=True, only_inputs=True, allow_unused=False - ) - return torch.stack(marg, dim=1) + if _autograd or not self.semiring is LogSemiring: + v, edge, _ = self._dp(edge, lengths=lengths, force_grad=True) + marg = torch.autograd.grad( + v.sum(dim=0), edge, create_graph=True, only_inputs=True, allow_unused=False + ) + return torch.stack(marg, dim=1) + else: + v, _, alpha = self._dp(edge, lengths=lengths, force_grad=False) + return self._dp_backward(edge, lengths, alpha) # Adapters @staticmethod diff --git a/torch_struct/semirings.py b/torch_struct/semirings.py index 2860b3d5..a5518bfb 100644 --- a/torch_struct/semirings.py +++ b/torch_struct/semirings.py @@ -33,6 +33,8 @@ class StdSemiring(_Base): def sum(xs, dim=-1): return torch.sum(xs, dim=dim) + def div_exp(a, b): + return a.exp().div(b.exp()) class _BaseLog(Semiring): @staticmethod @@ -47,6 +49,10 @@ def zero(): def one(): return 0.0 + @staticmethod + def div_exp(a, b): + return a.exp().div(b.exp()) + class LogSemiring(_BaseLog): @staticmethod @@ -59,6 +65,10 @@ class MaxSemiring(_BaseLog): def sum(xs, dim=-1): return torch.max(xs, dim=dim)[0] + @staticmethod + def div_exp(a, b): + return a == b + class _SampledLogSumExp(torch.autograd.Function): @staticmethod diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index 71665134..7a6b616b 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -21,6 +21,36 @@ def test_simple(batch, N, C): LinearChain(SampledSemiring).sum(vals) +def test_fb_m(): + vals = torch.rand(2, 4, 5, 5) + v, _, alpha = LinearChain(MaxSemiring)._dp(vals) + marginals = LinearChain(MaxSemiring)._dp_backward(vals, None, alpha) + +@given(data()) +def test_fb(data): + model = data.draw(sampled_from([DepTree])) + torch.manual_seed(1) + vals, (batch, N) = model._rand() + + + lengths = torch.tensor( + [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N] + ) + 
vals, (batch, N) = torch.ones(1, 2, 2), (1, 2) + lengths = None + marginals2 = model().marginals(vals, lengths=lengths, _autograd=True) + v, _, alpha = model()._dp(vals, lengths=lengths) + print(v) + marginals = model()._dp_backward(vals, lengths, alpha, v) + + if isinstance(marginals, tuple): + for i, (m1, m2) in enumerate(zip(marginals[:], marginals2[:]) ): + print((torch.isclose(m1, m2) == False).nonzero()) + assert(torch.isclose(m1, m2).all()), (torch.isclose(m1, m2) == False).nonzero() + else: + assert(torch.isclose(marginals, marginals2).all()) + + @given(data()) @settings(max_examples=50, deadline=None) def test_generic(data): @@ -96,7 +126,7 @@ def test_params(data, seed): else: vals.requires_grad_(True) # torch.autograd.set_detect_anomaly(True) - semiring = StdSemiring + semiring = LogSemiring alpha = model(semiring).sum(vals) alpha.sum().backward() From 69846dfe43ebc08cf0f9d27e3820f8765d20f90a Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Sat, 7 Sep 2019 17:30:23 -0400 Subject: [PATCH 2/4] . --- torch_struct/cky.py | 11 +- torch_struct/deptree.py | 246 +++++++++++++------------------- torch_struct/helpers.py | 51 ++++++- torch_struct/linearchain.py | 83 ++++------- torch_struct/semirings.py | 5 + torch_struct/test_algorithms.py | 11 +- 6 files changed, 193 insertions(+), 214 deletions(-) diff --git a/torch_struct/cky.py b/torch_struct/cky.py index 885acf16..a1d17f72 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -1,5 +1,5 @@ import torch -from .helpers import _Struct, DPManual +from .helpers import _Struct from .semirings import LogSemiring from torch.autograd import Function @@ -8,7 +8,8 @@ class DPManual2(Function): @staticmethod def forward(ctx, obj, terms, rules, roots, lengths): - v, _, alpha = obj._dp((terms, rules, roots), lengths, False) + with torch.no_grad(): + v, _, alpha = obj._dp((terms, rules, roots), lengths, False) ctx.obj = obj ctx.lengths = lengths ctx.alpha = alpha @@ -19,8 +20,10 @@ def forward(ctx, obj, terms, rules, roots, lengths): @staticmethod def backward(ctx, grad_v): terms, rules, roots = ctx.saved_tensors - marginals = ctx.obj._dp_backward((terms, rules, roots), ctx.lengths, ctx.alphactx.v) - return None, marginals, None + with torch.no_grad(): + marginals = ctx.obj._dp_backward((terms, rules, roots), + ctx.lengths, ctx.alpha, ctx.v) + return None, marginals[0], marginals[1].sum(1).sum(1), marginals[2], None diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 0e03491e..77d93e91 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -1,6 +1,6 @@ import torch import itertools -from .helpers import _Struct, DPManual +from .helpers import _Struct, DPManual, roll from .semirings import LogSemiring def _convert(logits): @@ -34,39 +34,19 @@ def _unconvert(logits): class DepTree(_Struct): - def sum(self, arc_scores, lengths=None, _autograd=False): - """ - Compute the inside pass of a projective dependency CRF. - - Parameters: - arc_scores : b x N x N arc scores with root scores on diagonal. - semiring - lengths: None or b long tensor mask - - Returns: - v: b tensor of total sum - arcs: list of N, LR x b x N table - - """ - if _autograd or not self.semiring is LogSemiring: - return self._dp(arc_scores, lengths)[0] - else: - return DPManual.apply(self, *scores, lengths) + """ + A projective dependency CRF. + + Parameters: + arc_scores : b x N x N arc scores with root scores on diagonal. 
+ """ def _dp(self, arc_scores, lengths=None, force_grad=False): semiring = self.semiring arc_scores = _convert(arc_scores) - batch, N, N2 = arc_scores.shape - - assert N == N2, "Non-square potentials" + batch, N, lengths = self._check_potentials(arc_scores, lengths) + DIRS = 2 - if lengths is None: - lengths = torch.LongTensor([N - 1] * batch) - assert max(lengths) <= N, "Length longer than N" - for b in range(batch): - arc_scores[b, lengths[b] + 1 :, :] = semiring.zero() - arc_scores[b, :, lengths[b] + 1 :] = semiring.zero() - def stack(a, b): return torch.stack([a, b]) @@ -105,19 +85,19 @@ def sstack(a): ), ) alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k] + v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)]) + print(v) return ( - torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)]), - arcs, + v, + arcs[1:], alpha ) - def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False): + + def _check_potentials(self, arc_scores, lengths = None): semiring = self.semiring - arc_scores = _convert(arc_scores) batch, N, N2 = arc_scores.shape - assert N == N2, "Non-square potentials" - DIRS = 2 if lengths is None: lengths = torch.LongTensor([N - 1] * batch) assert max(lengths) <= N, "Length longer than N" @@ -125,6 +105,15 @@ def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False arc_scores[b, lengths[b] + 1 :, :] = semiring.zero() arc_scores[b, :, lengths[b] + 1 :] = semiring.zero() + return batch, N, lengths + + def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False): + + # This function is super complicated. + semiring = self.semiring + arc_scores = _convert(arc_scores) + batch, N, lengths = self._check_potentials(arc_scores, lengths) + DIRS = 2 def stack(*a): return torch.stack(a) @@ -137,145 +126,112 @@ def sstack(a): ] arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad) - for k in range(N-1, -1, -1): + # Initialize for b, l in enumerate(lengths): alpha[A][C][R, b, 0, l] = semiring.one() alpha[B][C][R, b, l, N-l-1] = semiring.one() - a = semiring.dot( - alpha[A][I][R, :, :N - (k+1), (k+1):], - alpha_in[A][C][L,:, k+1:, : N-(k+1)], - arc_scores[:, :N-(k+1), k+1:] - ) - b = semiring.dot( - alpha[A][I][L, :, :N - (k+1), k+1:], - alpha_in[A][C][L, :, k+1:, : N-(k+1)], - arc_scores.transpose(1,2)[:, :N-(k+1), k+1:] - ) + # R completes + #I -> C* C + a = semiring.dot(*roll(alpha[A][I][R], + alpha_in[A][C][L], N, k, 1)) + #I -> C* C + b = semiring.dot(*roll(alpha[A][I][L], + alpha_in[A][C][L], N, k, 1)) + #C -> I C* + c = semiring.dot(*roll(alpha_in[B][I][R], + alpha[B][C][R], N, k, 0)) - alpha[A][C][R, :, :N-k-1, k] = semiring.sum(stack(a, b, - alpha[A][C][R, :, :N-k-1, k]), dim=0) + alpha[A][C][R, :, :N-k-1, k] = \ + semiring.plus(a, b, alpha[A][C][R, :, :N-k-1, k]) - a = semiring.dot( - alpha[B][I][L, :, k+1:, :N - k-1], - alpha_in[B][C][R, :, :N-k-1, k+1:], - arc_scores.transpose(1,2)[:, 1:N-(k), k+1:] + alpha[A][C][R][:, :N-k, k] = \ + semiring.plus(alpha[A][C][R][:, :N-k, k], c) + + # L completes + #I -> C* C + f = torch.arange(N-k-1), torch.arange(k, N-1) + a = semiring.dot(*roll( + alpha_in[B][C][R], alpha[B][I][L], N, k, 1 + )) - ) + #I -> C* C + f = torch.arange(k+1, N), torch.arange(N-(k+1)) b = semiring.dot( alpha[B][I][R, :, k+1:, :N-k-1], - alpha_in[B][C][R, :, :N-k-1, k+1:], - arc_scores[:, 1:N-(k), k+1:] + alpha_in[B][C][R, :, :N-k-1, k+1:] ) - - alpha[A][C][L, :, 1:N-k, k] = \ - semiring.sum(stack(a, b, - alpha[A][C][L, :, 1:N-k, k]), 
dim=0) - - - alpha[B][C][L, :, 1:N-k, N - k - 1] = alpha[A][C][L, :, 1: N - k, k] - alpha[B][C][R, :, :N-k-1, N - k - 1] = alpha[A][C][R, :, : N - k-1, k] - - print("C", k, alpha[A][C][:, 0, :, k].nonzero()) - - - - alpha[A][I][R][:, :N-k, k] = semiring.dot( - alpha[A][C][R, :, :N-k, k:], - alpha_in[A][C][R, :, k:, :N-k] - ) - - alpha[A][I][L][:, :N-k, k] = semiring.dot( - alpha[B][C][L, :, k:, :N - (k)], - alpha_in[B][C][L, :, :N-k, k:] - ) - - alpha[A][C][R][:, :N-k, k] = semiring.sum(stack(alpha[A][C][R][:, :N-k, k], - - semiring.dot( - alpha[B][C][R, :, k:, :N - k], - alpha_in[B][I][R, :, :N-k, k:] - )), dim =0) - - alpha[A][C][L][:, :N-k, k] =semiring.sum(stack( - alpha[A][C][L][:, :N-k, k], - semiring.dot( + + #C -> I C* + c = semiring.dot( alpha[A][C][L, :, :N - k, k:], alpha_in[A][I][L, :, k:, : N-k] - )), dim=0) - + ) - alpha[B][C][:, :, :N-k, N - k - 1] = alpha[A][C][:, :, : N - k, k] - - alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] - print("IC", k, alpha[A][C][:, 0, :, k].nonzero()) - print("I", k, alpha[A][I][:, 0, :, k].nonzero()) - - - print("COMP", alpha[A][C][:, :, :, :].nonzero()) - v = alpha[A][C][R, 0, 1, 0] - print("finsh", v) + alpha[A][C][L, :, 1:N-k, k] = \ + semiring.sum(stack(a, b, + alpha[A][C][L, :, 1:N-k, k]), dim=0) + alpha[A][C][L][:, :N-k, k] = \ + semiring.sum(stack(c, + alpha[A][C][L][:, :N-k, k]), dim=0) - left = semiring.div_exp(semiring.times(alpha[A][I][L, :, :, :], - alpha_in[A][I][L, :, :, :]), - v.view(batch, 1, 1)) + alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k] + + if k > 0: + f = torch.arange(N-k), torch.arange(k, N) + + # Incomplete + alpha[A][I][R][:, :N-k, k] = semiring.dot( + alpha[A][C][R, :, :N-k, k:], + alpha_in[A][C][R, :, k:, :N-k], + arc_scores[:, f[0], f[1]].unsqueeze(-1) + ) + + #C -> C I + alpha[A][I][L][:, :N-k, k] = semiring.dot( + alpha[B][C][L, :, k:, :N - (k)], + alpha_in[B][C][L, :, :N-k, k:], + arc_scores[:, f[1], f[0]].unsqueeze(-1) + ) + if k == 1: + print(alpha[B][C][L, 0, :, :].nonzero(), + alpha_in[B][C][L, 0, :, :].nonzero()) + + alpha[A][I][:, :, :] + alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] + + + v = alpha[A][C][R, :, 0, 0] + left = semiring.times(alpha[A][I][L, :, :, :], + alpha_in[A][I][L, :, :, :]) + right = semiring.times(alpha[A][I][R, :, :, :], + alpha_in[A][I][R, :, :, :]) - right = semiring.div_exp(semiring.times(alpha[A][I][R, :, :, :], - alpha_in[A][I][R, :, :, :]), - v.view(batch, 1, 1)) ret = torch.zeros(batch, N, N) print(left, right, alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :]) for k in range(N): for d in range(N-k): - ret[:, k+d, k] = left[:, k, d] - ret[:, k, k+d] = right[:, k, d] + ret[:, k+d, k] = semiring.div_exp(left[:, k, d] - arc_scores[:, k+d, k], v.view(batch)) + ret[:, k, k+d] = semiring.div_exp(right[:, k, d]- arc_scores[:, k, k+d], v.view(batch)) return _unconvert(ret) - # return ( - # torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)]), - # arcs, - # ) - - - def marginals(self, arc_scores, lengths=None, _autograd=False): - """ - Compute the marginals of a projective dependency CRF. - - Parameters: - arc_scores : b x N x N arc scores with root scores on diagonal. - semiring - lengths - Returns: - arc_marginals : b x N x N. 
- """ - batch, N, _ = arc_scores.shape + def _arrange_marginals(self, grads): + batch, N = grads[0][0].shape N = N + 1 - v, arcs, alpha = self._dp(arc_scores, lengths, force_grad=True) - - if _autograd or not self.semiring is LogSemiring: - grads = torch.autograd.grad( - v.sum(dim=0), - arcs[1:], - create_graph=True, - only_inputs=True, - allow_unused=False, - ) - ret = torch.zeros(batch, N, N).cpu() - for k, grad in enumerate(grads, 1): - f = torch.arange(N - k), torch.arange(k, N) - ret[:, f[0], f[1]] = grad[R].cpu() - ret[:, f[1], f[0]] = grad[L].cpu() - return _unconvert(ret) - else: - return self._dp_backward(arc_scores, lengths, alpha) - + ret = torch.zeros(batch, N, N).cpu() + for k, grad in enumerate(grads, 1): + f = torch.arange(N - k), torch.arange(k, N) + ret[:, f[0], f[1]] = grad[R].cpu() + ret[:, f[1], f[0]] = grad[L].cpu() + return _unconvert(ret) + @staticmethod def to_parts(sequence, extra=None, lengths=None): """ diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index f0095a6d..e03c3e20 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -2,10 +2,18 @@ from .semirings import LogSemiring from torch.autograd import Function + +def roll(a, b, N, k, gap=0): + return (a[:, :N - (k+gap), (k+gap):], \ + b[:, k+gap:, : N-(k+gap)]) + + + class DPManual(Function): @staticmethod def forward(ctx, obj, input, lengths): - v, _, alpha = obj._dp(input, lengths, False) + with torch.no_grad(): + v, _, alpha = obj._dp(input, lengths, False) ctx.obj = obj ctx.lengths = lengths ctx.alpha = alpha @@ -21,7 +29,8 @@ def backward(ctx, grad_v): input = ctx.saved_tensors if len(input) == 1: input = input[0] - marginals = ctx.obj._dp_backward(input, ctx.lengths, ctx.alpha) + with torch.no_grad(): + marginals = ctx.obj._dp_backward(input, ctx.lengths, ctx.alpha) return None, marginals, None @@ -43,3 +52,41 @@ def _make_chart(self, N, size, potentials, force_grad): ) for _ in range(N) ] + + + def sum(self, edge, lengths=None, _autograd=False): + """ + Compute the (semiring) sum over all structures model. + + Parameters: + params : generic params (see class) + lengths: None or b long tensor mask + + Returns: + v: b tensor of total sum + + """ + if _autograd or not self.semiring is LogSemiring and "_dp_backward" in dir(self): + return self._dp(edge, lengths)[0] + else: + return DPManual.apply(self, edge, lengths) + + def marginals(self, edge, lengths=None, _autograd=False): + """ + Compute the marginals of a structured model. + + Parameters: + params : generic params (see class) + lengths: None or b long tensor mask + Returns: + marginals: b x (N-1) x C x C table + + """ + v, edge, alpha = self._dp(edge, lengths=lengths, force_grad=True) + if _autograd or not self.semiring is LogSemiring and "_dp_backward" in dir(self): + marg = torch.autograd.grad( + v.sum(dim=0), edge, create_graph=True, only_inputs=True, allow_unused=False + ) + return self._arrange_marginals(marg) + else: + return self._dp_backward(edge, lengths, alpha) diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 69d1ce6b..fed87d0d 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -4,10 +4,21 @@ from .semirings import LogSemiring - class LinearChain(_Struct): - def _dp(self, edge, lengths=None, force_grad=False): - semiring = self.semiring + """ + Represents structured linear-chain CRFs, generalizing HMMs smoothing, tagging models, + and anything with chain-like dynamics. 
+ + + Potentials are of the form: + + edge : b x (N-1) x C x C markov potentials + (n-1 x z_n x z_{n-1}) + + + """ + + def _check_potentials(self, edge, lengths = None): batch, N_1, C, C2 = edge.shape N = N_1 + 1 if lengths is None: @@ -15,9 +26,14 @@ def _dp(self, edge, lengths=None, force_grad=False): assert max(lengths) <= N, "Length longer than edge scores" assert max(lengths) == N, "One length must be at least N" assert C == C2, "Transition shape doesn't match" + return batch, N, C, lengths - alpha = self._make_chart(N, (batch, C), edge, force_grad=force_grad) - edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad=force_grad) + def _dp(self, edge, lengths=None, force_grad=False): + semiring = self.semiring + batch, N, C, lengths = self._check_potentials(edge, lengths) + + alpha = self._make_chart(N, (batch, C), edge, force_grad) + edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad) alpha[0].data.fill_(semiring.one()) for n in range(1, N): @@ -25,17 +41,13 @@ def _dp(self, edge, lengths=None, force_grad=False): alpha[n - 1].view(batch, 1, C), edge[:, n - 1] ) alpha[n][:] = semiring.sum(edge_store[n - 1]) - v = semiring.sum( - torch.stack([alpha[l - 1][i] for i, l in enumerate(lengths)]), dim=-1 - ) + ret = [alpha[l - 1][i] for i, l in enumerate(lengths)] + v = semiring.sum(torch.stack(ret)) return v, edge_store, alpha - def _dp_backward(self, edge, lengths, alpha_in): + def _dp_backward(self, edge, lengths, alpha_in, v=None): semiring = self.semiring - batch, N_1, C, C2 = edge.shape - N = N_1 + 1 - if lengths is None: - lengths = torch.LongTensor([N] * batch) + batch, N, C, lengths = self._check_potentials(edge, lengths) alpha = self._make_chart(N, (batch, C), edge, force_grad=False) edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad=False) @@ -53,9 +65,7 @@ def _dp_backward(self, edge, lengths, alpha_in): ) edge_marginals = self._make_chart(1, (batch, N-1, C, C), edge, force_grad=False)[0] - - - for n in range(N_1): + for n in range(N-1): edge_marginals[:, n] = semiring.div_exp(semiring.times(alpha_in[n].view(batch, 1, C), edge[:, n], alpha[n+1].view(batch, C, 1)), @@ -64,46 +74,9 @@ def _dp_backward(self, edge, lengths, alpha_in): return edge_marginals - def sum(self, edge, lengths=None, _autograd=False): - """ - Compute the forward pass of a linear chain CRF. - - Parameters: - edge : b x (N-1) x C x C markov potentials - (n-1 x z_n x z_{n-1}) - lengths: None or b long tensor mask - - Returns: - v: b tensor of total sum - inside: list of N, b x C x C table - - """ - if _autograd or not self.semiring is LogSemiring: - return self._dp(edge, lengths)[0] - else: - return DPManual.apply(self, edge, lengths) - def marginals(self, edge, lengths=None, _autograd=False): - """ - Compute the marginals of a linear chain CRF. 
- - Parameters: - edge : b x (N-1) x C x C markov potentials - (t x z_t x z_{t-1}) - lengths: None or b long tensor mask - Returns: - marginals: b x (N-1) x C x C table - - """ - if _autograd or not self.semiring is LogSemiring: - v, edge, _ = self._dp(edge, lengths=lengths, force_grad=True) - marg = torch.autograd.grad( - v.sum(dim=0), edge, create_graph=True, only_inputs=True, allow_unused=False - ) - return torch.stack(marg, dim=1) - else: - v, _, alpha = self._dp(edge, lengths=lengths, force_grad=False) - return self._dp_backward(edge, lengths, alpha) + def _arrange_marginals(self, marg): + return torch.stack(marg, dim=1) # Adapters @staticmethod diff --git a/torch_struct/semirings.py b/torch_struct/semirings.py index a5518bfb..56562836 100644 --- a/torch_struct/semirings.py +++ b/torch_struct/semirings.py @@ -9,6 +9,11 @@ def times(cls, *ls): cur = cls.mul(cur, l) return cur + @classmethod + def plus(cls, *ls): + return cls.sum(torch.stack(ls), dim=0) + + @classmethod def dot(cls, *ls): return cls.sum(cls.times(*ls)) diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index 7a6b616b..45f65e18 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -28,24 +28,19 @@ def test_fb_m(): @given(data()) def test_fb(data): - model = data.draw(sampled_from([DepTree])) + model = data.draw(sampled_from([LinearChain, DepTree, CKY])) torch.manual_seed(1) vals, (batch, N) = model._rand() - lengths = torch.tensor( [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N] ) - vals, (batch, N) = torch.ones(1, 2, 2), (1, 2) - lengths = None marginals2 = model().marginals(vals, lengths=lengths, _autograd=True) v, _, alpha = model()._dp(vals, lengths=lengths) - print(v) marginals = model()._dp_backward(vals, lengths, alpha, v) - + if isinstance(marginals, tuple): - for i, (m1, m2) in enumerate(zip(marginals[:], marginals2[:]) ): - print((torch.isclose(m1, m2) == False).nonzero()) + for i, (m1, m2) in enumerate(zip(marginals[:], marginals2[:])): assert(torch.isclose(m1, m2).all()), (torch.isclose(m1, m2) == False).nonzero() else: assert(torch.isclose(marginals, marginals2).all()) From fa308147376694087c3de40ca7c4163220817b2e Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Sat, 7 Sep 2019 17:53:38 -0400 Subject: [PATCH 3/4] . 
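Patch 3 factors the paired chart slicing into a `roll` helper in helpers.py and rewrites the dependency outside pass in terms of it, stacking the two directions so each update is a single semiring `dot`. A small sketch of what `roll` pairs up (shapes only), using the definition added in this patch:

```python
import torch

def roll(a, b, N, k, gap=0):
    # helpers.roll from this patch: pair chart a's rows [0, N-k-gap) with
    # chart b's rows [k+gap, N), trimming the trailing dimension to match.
    return (a[:, : N - (k + gap), (k + gap):], b[:, k + gap:, : N - (k + gap)])

batch, N = 2, 6
a = torch.randn(batch, N, N)
b = torch.randn(batch, N, N)
for k in range(N - 1):
    x, y = roll(a, b, N, k, gap=1)
    # Both slices come out as (batch, N-k-1, N-k-1), so a semiring dot
    # over the last dimension combines each span with its sibling chart.
    assert x.shape == y.shape == (batch, N - k - 1, N - k - 1)
```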
--- torch_struct/deptree.py | 76 ++++++++++++++------------------------ torch_struct/helpers.py | 5 ++- torch_struct/semimarkov.py | 39 ++++--------------- 3 files changed, 37 insertions(+), 83 deletions(-) diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 77d93e91..887ec625 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -114,17 +114,16 @@ def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False arc_scores = _convert(arc_scores) batch, N, lengths = self._check_potentials(arc_scores, lengths) DIRS = 2 - def stack(*a): - return torch.stack(a) - - def sstack(a): - return torch.stack([a, a]) alpha = [ self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad) for _ in range(2) ] arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad) + def stack(a, b): + return torch.stack([a, b], dim=-1) + def sstack(a): + return torch.stack([a, a], dim=-1) for k in range(N-1, -1, -1): # Initialize @@ -135,51 +134,38 @@ def sstack(a): # R completes #I -> C* C - a = semiring.dot(*roll(alpha[A][I][R], - alpha_in[A][C][L], N, k, 1)) - #I -> C* C - b = semiring.dot(*roll(alpha[A][I][L], - alpha_in[A][C][L], N, k, 1)) - #C -> I C* + a = semiring.dot(*roll(stack(alpha[A][I][R], alpha[A][I][L]), + sstack(alpha_in[A][C][L]), N, k, 1)) + c = semiring.dot(*roll(alpha_in[B][I][R], alpha[B][C][R], N, k, 0)) - alpha[A][C][R, :, :N-k-1, k] = \ - semiring.plus(a, b, alpha[A][C][R, :, :N-k-1, k]) + alpha[A][C][R, :, :N-k-1, k] = semiring.plus(semiring.sum(a), + alpha[A][C][R, :, :N-k-1, k]) alpha[A][C][R][:, :N-k, k] = \ semiring.plus(alpha[A][C][R][:, :N-k, k], c) # L completes #I -> C* C - f = torch.arange(N-k-1), torch.arange(k, N-1) - a = semiring.dot(*roll( - alpha_in[B][C][R], alpha[B][I][L], N, k, 1 - )) - #I -> C* C - f = torch.arange(k+1, N), torch.arange(N-(k+1)) - b = semiring.dot( - alpha[B][I][R, :, k+1:, :N-k-1], - alpha_in[B][C][R, :, :N-k-1, k+1:] - ) - #C -> I C* - c = semiring.dot( - alpha[A][C][L, :, :N - k, k:], - alpha_in[A][I][L, :, k:, : N-k] - ) + a = semiring.dot(*roll(sstack(alpha_in[B][C][R]), + stack(alpha[B][I][L], alpha[B][I][R]), + N, k, 1)) - alpha[A][C][L, :, 1:N-k, k] = \ - semiring.sum(stack(a, b, - alpha[A][C][L, :, 1:N-k, k]), dim=0) + + c = semiring.dot(*roll(alpha[A][C][L], + alpha_in[A][I][L], N, k, 0)) + alpha[A][C][L, :, 1:N-k, k] = \ + semiring.plus(semiring.sum(a), alpha[A][C][L, :, 1:N-k, k]) alpha[A][C][L][:, :N-k, k] = \ - semiring.sum(stack(c, - alpha[A][C][L][:, :N-k, k]), dim=0) + semiring.plus(c, alpha[A][C][L][:, :N-k, k]) + # Compute reverses. 
alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k] if k > 0: @@ -187,24 +173,18 @@ def sstack(a): # Incomplete alpha[A][I][R][:, :N-k, k] = semiring.dot( - alpha[A][C][R, :, :N-k, k:], - alpha_in[A][C][R, :, k:, :N-k], - arc_scores[:, f[0], f[1]].unsqueeze(-1) - ) + arc_scores[:, f[0], f[1]].unsqueeze(-1), + *roll(alpha[A][C][R], + alpha_in[A][C][R], N, k)) #C -> C I alpha[A][I][L][:, :N-k, k] = semiring.dot( - alpha[B][C][L, :, k:, :N - (k)], - alpha_in[B][C][L, :, :N-k, k:], - arc_scores[:, f[1], f[0]].unsqueeze(-1) - ) - if k == 1: - print(alpha[B][C][L, 0, :, :].nonzero(), - alpha_in[B][C][L, 0, :, :].nonzero()) - - alpha[A][I][:, :, :] - alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] + arc_scores[:, f[1], f[0]].unsqueeze(-1), + *roll(alpha_in[B][C][L], + alpha[B][C][L], N, k)) + # Compute reverses + alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] v = alpha[A][C][R, :, 0, 0] left = semiring.times(alpha[A][I][L, :, :, :], @@ -213,7 +193,6 @@ def sstack(a): alpha_in[A][I][R, :, :, :]) ret = torch.zeros(batch, N, N) - print(left, right, alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :]) for k in range(N): for d in range(N-k): ret[:, k+d, k] = semiring.div_exp(left[:, k, d] - arc_scores[:, k+d, k], v.view(batch)) @@ -221,7 +200,6 @@ def sstack(a): return _unconvert(ret) - def _arrange_marginals(self, grads): batch, N = grads[0][0].shape N = N + 1 diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index e03c3e20..8d951622 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -9,6 +9,7 @@ def roll(a, b, N, k, gap=0): + class DPManual(Function): @staticmethod def forward(ctx, obj, input, lengths): @@ -66,7 +67,7 @@ def sum(self, edge, lengths=None, _autograd=False): v: b tensor of total sum """ - if _autograd or not self.semiring is LogSemiring and "_dp_backward" in dir(self): + if _autograd or not self.semiring is LogSemiring or "_dp_backward" not in self.__dict__: return self._dp(edge, lengths)[0] else: return DPManual.apply(self, edge, lengths) @@ -83,7 +84,7 @@ def marginals(self, edge, lengths=None, _autograd=False): """ v, edge, alpha = self._dp(edge, lengths=lengths, force_grad=True) - if _autograd or not self.semiring is LogSemiring and "_dp_backward" in dir(self): + if _autograd or not self.semiring is LogSemiring or "_dp_backward" not in self.__dict__: marg = torch.autograd.grad( v.sum(dim=0), edge, create_graph=True, only_inputs=True, allow_unused=False ) diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py index 856ec561..68c877a8 100644 --- a/torch_struct/semimarkov.py +++ b/torch_struct/semimarkov.py @@ -3,21 +3,10 @@ class SemiMarkov(_Struct): - def sum(self, edge, lengths=None, force_grad=False): - """ - Compute the forward pass of a semimarkov CRF. 
- - Parameters: - edge : b x N x K x C x C semimarkov potentials - lengths: None or b long tensor mask - - Returns: - v: b tensor of total sum - spans: list of N, b x K x C x C table - - """ - return self._dp(edge, lengths)[0] - + """ + edge : b x N x K x C x C semimarkov potentials + """ + def _dp(self, edge, lengths=None, force_grad=False): semiring = self.semiring batch, N_1, K, C, C2 = edge.shape @@ -47,7 +36,7 @@ def _dp(self, edge, lengths=None, force_grad=False): v = semiring.sum( torch.stack([beta[l - 1][i] for i, l in enumerate(lengths)]), dim=1 ) - return v, spans + return v, spans, beta @staticmethod def _rand(): @@ -57,24 +46,10 @@ def _rand(): C = torch.randint(2, 4, (1,)) return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item()) - def marginals(self, edge, lengths=None): - """ - Compute the marginals of a semimarkov CRF. - - Parameters: - edge : b x N x K x C x C semimarkov potentials - semiring - - Returns: - marginals: b x N x K x C table - """ - v, spans = self._dp(edge, lengths, force_grad=True) - marg = torch.autograd.grad( - v.sum(dim=0), spans, create_graph=True, only_inputs=True, allow_unused=False - ) + def _arrange_marginals(self, marg): return torch.stack(marg, dim=1) - + @staticmethod def to_parts(sequence, extra, lengths=None): """ From 068880d1e910375a3fa88bcdef4029c8017dcc3b Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Sat, 7 Sep 2019 17:59:21 -0400 Subject: [PATCH 4/4] . --- torch_struct/cky.py | 74 +++++++++-------- torch_struct/deptree.py | 139 +++++++++++++++++--------------- torch_struct/helpers.py | 24 ++++-- torch_struct/linearchain.py | 46 +++++------ torch_struct/semimarkov.py | 5 +- torch_struct/semirings.py | 3 +- torch_struct/test_algorithms.py | 11 ++- 7 files changed, 165 insertions(+), 137 deletions(-) diff --git a/torch_struct/cky.py b/torch_struct/cky.py index a1d17f72..7356082d 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -5,6 +5,7 @@ A, B = 0, 1 + class DPManual2(Function): @staticmethod def forward(ctx, obj, terms, rules, roots, lengths): @@ -21,12 +22,12 @@ def forward(ctx, obj, terms, rules, roots, lengths): def backward(ctx, grad_v): terms, rules, roots = ctx.saved_tensors with torch.no_grad(): - marginals = ctx.obj._dp_backward((terms, rules, roots), - ctx.lengths, ctx.alpha, ctx.v) + marginals = ctx.obj._dp_backward( + (terms, rules, roots), ctx.lengths, ctx.alpha, ctx.v + ) return None, marginals[0], marginals[1].sum(1).sum(1), marginals[2], None - class CKY(_Struct): def sum(self, scores, lengths=None, force_grad=False, _autograd=False): """ @@ -41,7 +42,7 @@ def sum(self, scores, lengths=None, force_grad=False, _autograd=False): v: b tensor of total sum spans: list of N, b x N x (NT+t) """ - if _autograd or not self.semiring is LogSemiring: + if _autograd or self.semiring is not LogSemiring: return self._dp(scores, lengths)[0] else: return DPManual2.apply(self, *scores, lengths) @@ -95,66 +96,71 @@ def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False): beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad) span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad) span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad) - top = self._make_chart(1, (batch, NT), rules, force_grad)[0] term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0] - ssum = semiring.sum st = semiring.times X_Y_Z = rules.view(batch, 1, NT, S, S) - for w in range(N-1, -1, -1): + for w in range(N - 1, -1, -1): for b, l in enumerate(lengths): - beta[A][b, 0, l-1, :NT] = roots[b] - 
beta[B][b, l-1, N-(l), :NT] = roots[b] + beta[A][b, 0, l - 1, :NT] = roots[b] + beta[B][b, l - 1, N - (l), :NT] = roots[b] # LEFT # all bigger on the left. - X = beta[A][:, :N-w-1, w+1:, :NT].view(batch, N-w-1, N-w-1, NT, 1, 1) - Z = alpha_in[A][:, w+1:N, 0:N-w-1].view(batch, N-w-1, N-w-1, 1, 1, S) + X = beta[A][:, : N - w - 1, w + 1 :, :NT].view( + batch, N - w - 1, N - w - 1, NT, 1, 1 + ) + Z = alpha_in[A][:, w + 1 : N, 0 : N - w - 1].view( + batch, N - w - 1, N - w - 1, 1, 1, S + ) t = st(ssum(st(X, Z), dim=2), X_Y_Z) # sum out x and y - span_l[w] = ssum(ssum(t, dim =-3), dim=-1) + span_l[w] = ssum(ssum(t, dim=-3), dim=-1) # RIGHT - X = beta[B][:, w+1:, :N-1-w, :NT].view(batch, N-w-1, N-w-1, NT, 1, 1) - Y = alpha_in[B][:, :N-w-1, w+1:, :].view(batch, N-w-1, N-w-1, 1, S, 1) + X = beta[B][:, w + 1 :, : N - 1 - w, :NT].view( + batch, N - w - 1, N - w - 1, NT, 1, 1 + ) + Y = alpha_in[B][:, : N - w - 1, w + 1 :, :].view( + batch, N - w - 1, N - w - 1, 1, S, 1 + ) t = st(ssum(st(X, Y), dim=2), X_Y_Z) span_r[w] = ssum(ssum(t, dim=-3), dim=-2) - beta[A][:, :N-w-1, w, :] = span_l[w] - beta[A][:, 1:N-w, w, :] = ssum(torch.stack([span_r[w], - beta[A][:, 1:N-w, w, :]]), dim=0) - beta[B][:, w:, N-w-1, :] = beta[A][:, :N - w, w, :] - + beta[A][:, : N - w - 1, w, :] = span_l[w] + beta[A][:, 1 : N - w, w, :] = ssum( + torch.stack([span_r[w], beta[A][:, 1 : N - w, w, :]]), dim=0 + ) + beta[B][:, w:, N - w - 1, :] = beta[A][:, : N - w, w, :] term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms) term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0] for n in range(N): - term_marginals[:, n] = semiring.div_exp(term_use[:, n], - v.view(batch, 1)) + term_marginals[:, n] = semiring.div_exp(term_use[:, n], v.view(batch, 1)) root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0] for b in range(batch): - root_marginals[b] = semiring.div_exp(st(alpha_in[A][b, 0, lengths[b]-1, :NT], roots[b]), - v[b].view(1)) - edge_marginals = self._make_chart(1, (batch, N, N, NT, S, S), terms, force_grad=False)[0] + root_marginals[b] = semiring.div_exp( + st(alpha_in[A][b, 0, lengths[b] - 1, :NT], roots[b]), v[b].view(1) + ) + edge_marginals = self._make_chart( + 1, (batch, N, N, NT, S, S), terms, force_grad=False + )[0] edge_marginals.fill_(0) for w in range(1, N): Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1) Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S) - score = semiring.times( - semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z + score = semiring.times(semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z) + score = st(score, beta[A][:, : N - w, w, :NT].view(batch, N - w, NT, 1, 1)) + edge_marginals[:, : N - w, w - 1] = semiring.div_exp( + score, v.view(batch, 1, 1, 1, 1) ) - score = st(score, beta[A][:, :N-w, w, :NT].view(batch, N-w, NT, 1, 1)) - edge_marginals[:, :N-w, w-1] = semiring.div_exp(score, - v.view(batch, 1, 1, 1, 1)) edge_marginals = edge_marginals.transpose(1, 2) - - return (term_marginals, edge_marginals, root_marginals) - + return (term_marginals, edge_marginals, root_marginals) def marginals(self, scores, lengths=None, _autograd=False): """ @@ -177,7 +183,7 @@ def marginals(self, scores, lengths=None, _autograd=False): v, (term_use, rule_use, top), alpha = self._dp( scores, lengths=lengths, force_grad=True ) - if _autograd or not self.semiring is LogSemiring: + if _autograd or self.semiring is not LogSemiring: marg = torch.autograd.grad( v.sum(dim=0), tuple(rule_use) + (top, term_use), @@ -193,7 +199,7 @@ def marginals(self, 
scores, lengths=None, _autograd=False): assert marg[-2].shape == (batch, NT) return (marg[-1], rules, marg[-2]) else: - return self._dp_backward(edge, lengths, alpha, v) + return self._dp_backward(scores, lengths, alpha, v) @staticmethod def to_parts(spans, extra, lengths=None): diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 887ec625..c9ca3199 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -1,7 +1,7 @@ import torch import itertools -from .helpers import _Struct, DPManual, roll -from .semirings import LogSemiring +from .helpers import _Struct, roll + def _convert(logits): "move root arcs from diagonal" @@ -36,7 +36,7 @@ def _unconvert(logits): class DepTree(_Struct): """ A projective dependency CRF. - + Parameters: arc_scores : b x N x N arc scores with root scores on diagonal. """ @@ -45,8 +45,9 @@ def _dp(self, arc_scores, lengths=None, force_grad=False): semiring = self.semiring arc_scores = _convert(arc_scores) batch, N, lengths = self._check_potentials(arc_scores, lengths) - + DIRS = 2 + def stack(a, b): return torch.stack([a, b]) @@ -87,14 +88,9 @@ def sstack(a): alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k] v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)]) print(v) - return ( - v, - arcs[1:], - alpha - ) - - - def _check_potentials(self, arc_scores, lengths = None): + return (v, arcs[1:], alpha) + + def _check_potentials(self, arc_scores, lengths=None): semiring = self.semiring batch, N, N2 = arc_scores.shape assert N == N2, "Non-square potentials" @@ -106,8 +102,8 @@ def _check_potentials(self, arc_scores, lengths = None): arc_scores[b, :, lengths[b] + 1 :] = semiring.zero() return batch, N, lengths - - def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False): + + def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False): # This function is super complicated. 
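        # Concretely: this is the outside pass matching _dp's inside pass.
        # alpha is filled with outside scores in decreasing span width,
        # alpha_in carries the inside scores computed by _dp, and the arc
        # marginals at the end are outside * inside / Z, dividing out the
        # arc score that both charts count once.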
semiring = self.semiring @@ -119,87 +115,102 @@ def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad) for _ in range(2) ] - arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad) + def stack(a, b): return torch.stack([a, b], dim=-1) + def sstack(a): return torch.stack([a, a], dim=-1) - for k in range(N-1, -1, -1): + for k in range(N - 1, -1, -1): # Initialize for b, l in enumerate(lengths): alpha[A][C][R, b, 0, l] = semiring.one() - alpha[B][C][R, b, l, N-l-1] = semiring.one() - + alpha[B][C][R, b, l, N - l - 1] = semiring.one() # R completes - #I -> C* C - #I -> C* C - #C -> I C* - a = semiring.dot(*roll(stack(alpha[A][I][R], alpha[A][I][L]), - sstack(alpha_in[A][C][L]), N, k, 1)) + # I -> C* C + # I -> C* C + # C -> I C* + a = semiring.dot( + *roll( + stack(alpha[A][I][R], alpha[A][I][L]), + sstack(alpha_in[A][C][L]), + N, + k, + 1, + ) + ) - c = semiring.dot(*roll(alpha_in[B][I][R], - alpha[B][C][R], N, k, 0)) + c = semiring.dot(*roll(alpha_in[B][I][R], alpha[B][C][R], N, k, 0)) - alpha[A][C][R, :, :N-k-1, k] = semiring.plus(semiring.sum(a), - alpha[A][C][R, :, :N-k-1, k]) + alpha[A][C][R, :, : N - k - 1, k] = semiring.plus( + semiring.sum(a), alpha[A][C][R, :, : N - k - 1, k] + ) + + alpha[A][C][R][:, : N - k, k] = semiring.plus( + alpha[A][C][R][:, : N - k, k], c + ) - alpha[A][C][R][:, :N-k, k] = \ - semiring.plus(alpha[A][C][R][:, :N-k, k], c) - # L completes - #I -> C* C - #I -> C* C - #C -> I C* - a = semiring.dot(*roll(sstack(alpha_in[B][C][R]), - stack(alpha[B][I][L], alpha[B][I][R]), - N, k, 1)) - - - c = semiring.dot(*roll(alpha[A][C][L], - alpha_in[A][I][L], N, k, 0)) - - alpha[A][C][L, :, 1:N-k, k] = \ - semiring.plus(semiring.sum(a), alpha[A][C][L, :, 1:N-k, k]) - alpha[A][C][L][:, :N-k, k] = \ - semiring.plus(c, alpha[A][C][L][:, :N-k, k]) + # I -> C* C + # I -> C* C + # C -> I C* + a = semiring.dot( + *roll( + sstack(alpha_in[B][C][R]), + stack(alpha[B][I][L], alpha[B][I][R]), + N, + k, + 1, + ) + ) + + c = semiring.dot(*roll(alpha[A][C][L], alpha_in[A][I][L], N, k, 0)) + + alpha[A][C][L, :, 1 : N - k, k] = semiring.plus( + semiring.sum(a), alpha[A][C][L, :, 1 : N - k, k] + ) + alpha[A][C][L][:, : N - k, k] = semiring.plus( + c, alpha[A][C][L][:, : N - k, k] + ) # Compute reverses. 
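            # The B charts mirror A with each span re-indexed from its
            # right endpoint, so fixed-width spans stay contiguous slices.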
alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k] - + if k > 0: - f = torch.arange(N-k), torch.arange(k, N) - + f = torch.arange(N - k), torch.arange(k, N) + # Incomplete - alpha[A][I][R][:, :N-k, k] = semiring.dot( + alpha[A][I][R][:, : N - k, k] = semiring.dot( arc_scores[:, f[0], f[1]].unsqueeze(-1), - *roll(alpha[A][C][R], - alpha_in[A][C][R], N, k)) + *roll(alpha[A][C][R], alpha_in[A][C][R], N, k) + ) - #C -> C I - alpha[A][I][L][:, :N-k, k] = semiring.dot( + # C -> C I + alpha[A][I][L][:, : N - k, k] = semiring.dot( arc_scores[:, f[1], f[0]].unsqueeze(-1), - *roll(alpha_in[B][C][L], - alpha[B][C][L], N, k)) + *roll(alpha_in[B][C][L], alpha[B][C][L], N, k) + ) # Compute reverses alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k] v = alpha[A][C][R, :, 0, 0] - left = semiring.times(alpha[A][I][L, :, :, :], - alpha_in[A][I][L, :, :, :]) - right = semiring.times(alpha[A][I][R, :, :, :], - alpha_in[A][I][R, :, :, :]) + left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :]) + right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :]) ret = torch.zeros(batch, N, N) for k in range(N): - for d in range(N-k): - ret[:, k+d, k] = semiring.div_exp(left[:, k, d] - arc_scores[:, k+d, k], v.view(batch)) - ret[:, k, k+d] = semiring.div_exp(right[:, k, d]- arc_scores[:, k, k+d], v.view(batch)) + for d in range(N - k): + ret[:, k + d, k] = semiring.div_exp( + left[:, k, d] - arc_scores[:, k + d, k], v.view(batch) + ) + ret[:, k, k + d] = semiring.div_exp( + right[:, k, d] - arc_scores[:, k, k + d], v.view(batch) + ) return _unconvert(ret) - def _arrange_marginals(self, grads): batch, N = grads[0][0].shape N = N + 1 @@ -209,7 +220,7 @@ def _arrange_marginals(self, grads): ret[:, f[0], f[1]] = grad[R].cpu() ret[:, f[1], f[0]] = grad[L].cpu() return _unconvert(ret) - + @staticmethod def to_parts(sequence, extra=None, lengths=None): """ diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 8d951622..82c148ce 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -4,10 +4,7 @@ def roll(a, b, N, k, gap=0): - return (a[:, :N - (k+gap), (k+gap):], \ - b[:, k+gap:, : N-(k+gap)]) - - + return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)]) class DPManual(Function): @@ -54,7 +51,6 @@ def _make_chart(self, N, size, potentials, force_grad): for _ in range(N) ] - def sum(self, edge, lengths=None, _autograd=False): """ Compute the (semiring) sum over all structures model. 
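A note on the dispatch condition rewritten in the hunk below: methods defined on a class do not appear in an instance's `__dict__`, so `"_dp_backward" not in self.__dict__` is always true here and the check always selects the autograd branch; patch 3's `dir(self)` form does search the class, as would `hasattr`. Patch 3's unparenthesized version also parses as `_autograd or (self.semiring is not LogSemiring and ...)`, since `and` binds tighter than `or`. A small demonstration of the attribute-lookup point:

```python
class Example:
    def _dp_backward(self):
        pass

e = Example()
print("_dp_backward" in e.__dict__)  # False: methods live on the class
print("_dp_backward" in dir(e))      # True: dir() searches the class too
print(hasattr(e, "_dp_backward"))    # True
```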
@@ -67,7 +63,11 @@ def sum(self, edge, lengths=None, _autograd=False): v: b tensor of total sum """ - if _autograd or not self.semiring is LogSemiring or "_dp_backward" not in self.__dict__: + if ( + _autograd + or self.semiring is not LogSemiring + or "_dp_backward" not in self.__dict__ + ): return self._dp(edge, lengths)[0] else: return DPManual.apply(self, edge, lengths) @@ -84,9 +84,17 @@ def marginals(self, edge, lengths=None, _autograd=False): """ v, edge, alpha = self._dp(edge, lengths=lengths, force_grad=True) - if _autograd or not self.semiring is LogSemiring or "_dp_backward" not in self.__dict__: + if ( + _autograd + or self.semiring is not LogSemiring + or "_dp_backward" not in self.__dict__ + ): marg = torch.autograd.grad( - v.sum(dim=0), edge, create_graph=True, only_inputs=True, allow_unused=False + v.sum(dim=0), + edge, + create_graph=True, + only_inputs=True, + allow_unused=False, ) return self._arrange_marginals(marg) else: diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index fed87d0d..2ca6221f 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -1,24 +1,20 @@ import torch -from .helpers import _Struct, DPManual -from torch.autograd import Function -from .semirings import LogSemiring +from .helpers import _Struct class LinearChain(_Struct): """ - Represents structured linear-chain CRFs, generalizing HMMs smoothing, tagging models, - and anything with chain-like dynamics. + Represents structured linear-chain CRFs, generalizing HMMs smoothing, tagging models, + and anything with chain-like dynamics. Potentials are of the form: edge : b x (N-1) x C x C markov potentials (n-1 x z_n x z_{n-1}) - - """ - def _check_potentials(self, edge, lengths = None): + def _check_potentials(self, edge, lengths=None): batch, N_1, C, C2 = edge.shape N = N_1 + 1 if lengths is None: @@ -31,7 +27,7 @@ def _check_potentials(self, edge, lengths = None): def _dp(self, edge, lengths=None, force_grad=False): semiring = self.semiring batch, N, C, lengths = self._check_potentials(edge, lengths) - + alpha = self._make_chart(N, (batch, C), edge, force_grad) edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad) @@ -52,29 +48,33 @@ def _dp_backward(self, edge, lengths, alpha_in, v=None): alpha = self._make_chart(N, (batch, C), edge, force_grad=False) edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad=False) - for n in range(N-1, 0, -1): + for n in range(N - 1, 0, -1): for b, l in enumerate(lengths): - alpha[l-1][b].data.fill_(semiring.one()) + alpha[l - 1][b].data.fill_(semiring.one()) - edge_store[n-1][:] = semiring.times( - alpha[n].view(batch, C, 1), edge[:, n-1] + edge_store[n - 1][:] = semiring.times( + alpha[n].view(batch, C, 1), edge[:, n - 1] ) - alpha[n - 1][:] = semiring.sum(edge_store[n-1], dim=-2) + alpha[n - 1][:] = semiring.sum(edge_store[n - 1], dim=-2) v = semiring.sum( torch.stack([alpha[0][i] for i, l in enumerate(lengths)]), dim=-1 ) - edge_marginals = self._make_chart(1, (batch, N-1, C, C), edge, force_grad=False)[0] - - for n in range(N-1): - edge_marginals[:, n] = semiring.div_exp(semiring.times(alpha_in[n].view(batch, 1, C), - edge[:, n], - alpha[n+1].view(batch, C, 1)), - v.view(batch, 1, 1)) + edge_marginals = self._make_chart( + 1, (batch, N - 1, C, C), edge, force_grad=False + )[0] + + for n in range(N - 1): + edge_marginals[:, n] = semiring.div_exp( + semiring.times( + alpha_in[n].view(batch, 1, C), + edge[:, n], + alpha[n + 1].view(batch, C, 1), + ), + v.view(batch, 1, 1), + ) return 
edge_marginals - - def _arrange_marginals(self, marg): return torch.stack(marg, dim=1) diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py index 68c877a8..bbaf2e85 100644 --- a/torch_struct/semimarkov.py +++ b/torch_struct/semimarkov.py @@ -6,7 +6,7 @@ class SemiMarkov(_Struct): """ edge : b x N x K x C x C semimarkov potentials """ - + def _dp(self, edge, lengths=None, force_grad=False): semiring = self.semiring batch, N_1, K, C, C2 = edge.shape @@ -46,10 +46,9 @@ def _rand(): C = torch.randint(2, 4, (1,)) return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item()) - def _arrange_marginals(self, marg): return torch.stack(marg, dim=1) - + @staticmethod def to_parts(sequence, extra, lengths=None): """ diff --git a/torch_struct/semirings.py b/torch_struct/semirings.py index 56562836..a3f6bd10 100644 --- a/torch_struct/semirings.py +++ b/torch_struct/semirings.py @@ -13,7 +13,6 @@ def times(cls, *ls): def plus(cls, *ls): return cls.sum(torch.stack(ls), dim=0) - @classmethod def dot(cls, *ls): return cls.sum(cls.times(*ls)) @@ -38,9 +37,11 @@ class StdSemiring(_Base): def sum(xs, dim=-1): return torch.sum(xs, dim=dim) + @staticmethod def div_exp(a, b): return a.exp().div(b.exp()) + class _BaseLog(Semiring): @staticmethod def mul(a, b): diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index 45f65e18..ccdbb8ed 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -24,7 +24,8 @@ def test_simple(batch, N, C): def test_fb_m(): vals = torch.rand(2, 4, 5, 5) v, _, alpha = LinearChain(MaxSemiring)._dp(vals) - marginals = LinearChain(MaxSemiring)._dp_backward(vals, None, alpha) + LinearChain(MaxSemiring)._dp_backward(vals, None, alpha) + @given(data()) def test_fb(data): @@ -38,12 +39,14 @@ def test_fb(data): marginals2 = model().marginals(vals, lengths=lengths, _autograd=True) v, _, alpha = model()._dp(vals, lengths=lengths) marginals = model()._dp_backward(vals, lengths, alpha, v) - + if isinstance(marginals, tuple): for i, (m1, m2) in enumerate(zip(marginals[:], marginals2[:])): - assert(torch.isclose(m1, m2).all()), (torch.isclose(m1, m2) == False).nonzero() + assert torch.isclose(m1, m2).all(), ( + not torch.isclose(m1, m2) + ).nonzero() else: - assert(torch.isclose(marginals, marginals2).all()) + assert torch.isclose(marginals, marginals2).all() @given(data())
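To close the loop, here is a compact, standalone version of the check `test_fb` performs for the linear chain: run the forward (inside) and backward (outside) recursions by hand, form edge marginals as alpha * edge * beta / Z, and confirm they match the gradient of log Z. This is a sketch independent of torch_struct, with its own edge convention (`edge[:, n][i, j]` scores state i at step n followed by state j at step n+1), not the library's code:

```python
import torch

def forward_pass(edge):
    # alpha[n][i]: log-sum over all prefixes ending in state i at step n.
    batch, N_1, C, _ = edge.shape
    alpha = [edge.new_zeros(batch, C)]
    for n in range(N_1):
        alpha.append(torch.logsumexp(alpha[-1].unsqueeze(-1) + edge[:, n], dim=-2))
    return alpha, torch.logsumexp(alpha[-1], dim=-1)

def backward_pass(edge):
    # beta[n][j]: log-sum over all suffixes starting in state j at step n.
    batch, N_1, C, _ = edge.shape
    beta = [edge.new_zeros(batch, C)]
    for n in range(N_1 - 1, -1, -1):
        beta.insert(0, torch.logsumexp(beta[0].unsqueeze(-2) + edge[:, n], dim=-1))
    return beta

edge = torch.randn(2, 4, 5, 5, requires_grad=True)
alpha, log_z = forward_pass(edge)
beta = backward_pass(edge)

# Manual edge marginals: alpha (left context) * edge * beta (right) / Z.
manual = torch.stack(
    [(alpha[n].unsqueeze(-1) + edge[:, n] + beta[n + 1].unsqueeze(-2)
      - log_z.view(-1, 1, 1)).exp() for n in range(edge.shape[1])], dim=1)

# Autograd marginals: the gradient of log Z w.r.t. the potentials.
auto, = torch.autograd.grad(log_z.sum(), edge)
assert torch.allclose(manual, auto, atol=1e-4)
```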