From b2d7fcd2397346257fa1fd692c54bf867da00d0e Mon Sep 17 00:00:00 2001 From: Sasha Date: Mon, 2 Sep 2019 15:37:45 -0400 Subject: [PATCH] Add some more tests --- examples/supervised.py | 3 +++ setup.py | 5 ++--- torch_struct/cky.py | 15 +++++++++++---- torch_struct/deptree.py | 18 +++++++++++++----- torch_struct/helpers.py | 4 ++-- torch_struct/linearchain.py | 18 ++++++++++++------ torch_struct/semimarkov.py | 17 +++++++++++------ torch_struct/test_algorithms.py | 26 ++++++++++++++++++++++++++ 8 files changed, 80 insertions(+), 26 deletions(-) diff --git a/examples/supervised.py b/examples/supervised.py index e69de29b..12198361 100644 --- a/examples/supervised.py +++ b/examples/supervised.py @@ -0,0 +1,3 @@ +import torchtext + +torchtext.datasets.UDPOS diff --git a/setup.py b/setup.py index 7e6796d9..f295c1a4 100644 --- a/setup.py +++ b/setup.py @@ -5,11 +5,10 @@ version="0.0.1", author="Alexander Rush", author_email="arush@cornell.edu", - packages=["torch_struct", ], + packages=["torch_struct"], package_data={"torch_struct": []}, url="https://github.com/harvardnlp/pytorch_struct", install_requires=["torch"], setup_requires=["pytest-runner"], - tests_require=["pytest"] - + tests_require=["pytest"], ) diff --git a/torch_struct/cky.py b/torch_struct/cky.py index b1be8535..8d4ed270 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -5,7 +5,9 @@ A, B = 0, 1 -def cky_inside(terms, rules, roots, semiring=LogSemiring, lengths=None): +def cky_inside( + terms, rules, roots, semiring=LogSemiring, lengths=None, force_grad=False +): """ Compute the inside pass of a CFG using CKY. 
@@ -23,9 +25,14 @@ def cky_inside(terms, rules, roots, semiring=LogSemiring, lengths=None): _, NT, _, _ = rules.shape if lengths is None: lengths = torch.LongTensor([N] * batch) - beta = [_make_chart((batch, N, N, NT + T), rules, semiring) for _ in range(2)] - - span = [_make_chart((batch, N, NT + T), rules, semiring) for _ in range(N)] + beta = [ + _make_chart((batch, N, N, NT + T), rules, semiring, force_grad) + for _ in range(2) + ] + + span = [ + _make_chart((batch, N, NT + T), rules, semiring, force_grad) for _ in range(N) + ] rule_use = [None for _ in range(N - 1)] term_use = terms.requires_grad_(True) beta[A][:, :, 0, NT:] = term_use diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 5cd13130..82973f35 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -34,7 +34,7 @@ def _unconvert(logits): A, B, R, C, L, I = 0, 1, 1, 1, 0, 0 -def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None): +def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None, force_grad=False): """ Compute the inside pass of a projective dependency CRF. @@ -49,10 +49,12 @@ def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None): """ arc_scores = _convert(arc_scores) - batch, N, _ = arc_scores.shape + batch, N, N2 = arc_scores.shape + assert N == N2, "Non-square potentials" DIRS = 2 if lengths is None: lengths = torch.LongTensor([N] * batch) + assert max(lengths) <= N, "Length longer than N" def stack(a, b): return torch.stack([a, b]) @@ -61,10 +63,16 @@ def sstack(a): return torch.stack([a, a]) alpha = [ - [_make_chart((DIRS, batch, N, N), arc_scores, semiring) for _ in [I, C]] + [ + _make_chart((DIRS, batch, N, N), arc_scores, semiring, force_grad) + for _ in [I, C] + ] for _ in range(2) ] - arcs = [_make_chart((DIRS, batch, N), arc_scores, semiring) for _ in range(N)] + arcs = [ + _make_chart((DIRS, batch, N), arc_scores, semiring, force_grad) + for _ in range(N) + ] # Inside step. 
assumes first token is root symbol alpha[A][C][:, :, :, 0].data.fill_(semiring.one()) @@ -108,7 +116,7 @@ def deptree(arc_scores, semiring=LogSemiring, lengths=None): """ batch, N, _ = arc_scores.shape N = N + 1 - v, arcs = deptree_inside(arc_scores, semiring, lengths) + v, arcs = deptree_inside(arc_scores, semiring, lengths, force_grad=True) grads = torch.autograd.grad( v.sum(dim=0), arcs[1:], create_graph=True, only_inputs=True, allow_unused=False ) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 182744aa..faaa66dd 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -1,10 +1,10 @@ import torch -def _make_chart(size, potentials, semiring): +def _make_chart(size, potentials, semiring, force_grad): return ( torch.zeros(*size) .type_as(potentials) .fill_(semiring.zero()) - .requires_grad_(True) + .requires_grad_(force_grad and not potentials.requires_grad) ) diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index c8a2a734..8e13d32c 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -3,7 +3,7 @@ from .helpers import _make_chart -def linearchain_forward(edge, semiring=LogSemiring, lengths=None): +def linearchain_forward(edge, semiring=LogSemiring, lengths=None, force_grad=False): """ Compute the forward pass of a linear chain CRF. 
@@ -18,10 +18,16 @@ def linearchain_forward(edge, semiring=LogSemiring, lengths=None): inside: list of N, b x C x C table """ - batch, N, C, _ = edge.shape + batch, N, C, C2 = edge.shape if lengths is None: lengths = torch.LongTensor([N] * batch) - alpha = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)] + assert max(lengths) <= N, "Length longer than edge scores" + assert C == C2, "Transition shape doesn't match" + + alpha = [ + _make_chart((batch, C), edge, semiring, force_grad=force_grad) + for n in range(N + 1) + ] edge_store = [None for _ in range(N)] alpha[0].data.fill_(semiring.one()) for n in range(1, N + 1): @@ -33,7 +39,7 @@ def linearchain_forward(edge, semiring=LogSemiring, lengths=None): return v, edge_store -def linearchain(edge, semiring=LogSemiring): +def linearchain(edge, semiring=LogSemiring, lengths=None): """ Compute the marginals of a linear chain CRF. @@ -41,12 +47,12 @@ def linearchain(edge, semiring=LogSemiring): edge : b x N x C x C markov potentials (t x z_t x z_{t-1}) semiring - + lengths: None or b long tensor mask Returns: marginals: b x N x C x C table """ - v, alpha = linearchain_forward(edge, semiring) + v, alpha = linearchain_forward(edge, semiring, lengths, force_grad=True) marg = torch.autograd.grad( v.sum(dim=0), alpha, create_graph=True, only_inputs=True, allow_unused=False ) diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py index 28777a5b..6138e343 100644 --- a/torch_struct/semimarkov.py +++ b/torch_struct/semimarkov.py @@ -3,7 +3,7 @@ from .helpers import _make_chart -def semimarkov_forward(edge, semiring=LogSemiring, lengths=None, force_grad=False): +def semimarkov_forward(edge, semiring=LogSemiring, lengths=None, force_grad=False): """ Compute the forward pass of a semimarkov CRF. 
@@ -17,12 +17,17 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None): spans: list of N, b x K x C x C table """ - batch, N, K, C, _ = edge.shape + batch, N, K, C, C2 = edge.shape if lengths is None: lengths = torch.LongTensor([N] * batch) + assert max(lengths) <= N, "Length longer than edge scores" + assert C == C2, "Transition shape doesn't match" + spans = [None for _ in range(N)] - alpha = [_make_chart((batch, K, C), edge, semiring) for n in range(N + 1)] - beta = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)] + alpha = [ + _make_chart((batch, K, C), edge, semiring, force_grad) for n in range(N + 1) + ] + beta = [_make_chart((batch, C), edge, semiring, force_grad) for n in range(N + 1)] beta[0].data.fill_(semiring.one()) for n in range(1, N + 1): spans[n - 1] = semiring.times( @@ -33,7 +38,7 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None): f1 = torch.arange(n - 1, t, -1) f2 = torch.arange(1, len(f1) + 1) print(n - 1, f1, f2) - beta[n] = semiring.sum( + beta[n][:] = semiring.sum( torch.stack([alpha[a][:, b] for a, b in zip(f1, f2)]), dim=0 ) v = semiring.sum(torch.stack([beta[l][i] for i, l in enumerate(lengths)]), dim=1) @@ -52,7 +57,7 @@ def semimarkov(edge, semiring=LogSemiring, lengths=None): marginals: b x N x K x C table """ - v, spans = semimarkov_forward(edge, semiring, lengths) + v, spans = semimarkov_forward(edge, semiring, lengths, force_grad=True) marg = torch.autograd.grad( v.sum(dim=0), spans, create_graph=True, only_inputs=True, allow_unused=False ) diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index 102d3533..04a06965 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -41,6 +41,14 @@ def test_linearchain(batch, N, C): assert torch.isclose(score.sum(), marginals.mul(vals).sum()).all() +@given(smint, smint, smint) +def test_params(batch, N, C): + vals = torch.ones(batch, N, C, C, requires_grad=True) + semiring = StdSemiring + 
alpha, _ = linearchain_forward(vals, semiring) + alpha.sum().backward() + + def test_hmm(): C, V, batch, N = 5, 20, 2, 5 transition = torch.rand(C, C) @@ -84,6 +92,13 @@ def test_dep(N): assert torch.isclose(score.sum(), marginals.mul(scores).sum()).all() +def test_dep_params(): + batch, N = 2, 2 + scores = torch.rand(batch, N, N, requires_grad=True) + top, arcs = deptree_inside(scores) + top.sum().backward() + + def test_dep_np(): N = 5 batch = 2 @@ -114,3 +129,14 @@ def test_cky(N, NT, T): + m_root.mul(roots).sum() ).sum(), ).all() + + +@given(smint, tint, tint) +@settings(max_examples=3) +def test_cky_params(N, NT, T): + batch = 2 + terms = torch.rand(batch, N, T) + rules = torch.rand(batch, NT, (NT + T), (NT + T), requires_grad=True) + roots = torch.rand(batch, NT, requires_grad=True) + v, _ = cky_inside(terms, rules, roots) + v.sum().backward()