Skip to content

Commit

Permalink
Add some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Sep 2, 2019
1 parent e3837bb commit b2d7fcd
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 26 deletions.
3 changes: 3 additions & 0 deletions examples/supervised.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import torchtext

# Reference the Universal Dependencies POS-tagging dataset shipped with
# torchtext. The module is `datasets` and the class is spelled `UDPOS`;
# the previous `torchtext.datsets.UDPos` raised AttributeError on import.
torchtext.datasets.UDPOS
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
version="0.0.1",
author="Alexander Rush",
author_email="arush@cornell.edu",
packages=["torch_struct", ],
packages=["torch_struct"],
package_data={"torch_struct": []},
url="https://github.com/harvardnlp/pytorch_struct",
install_requires=["torch"],
setup_requires=["pytest-runner"],
tests_require=["pytest"]

tests_require=["pytest"],
)
15 changes: 11 additions & 4 deletions torch_struct/cky.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
A, B = 0, 1


def cky_inside(terms, rules, roots, semiring=LogSemiring, lengths=None):
def cky_inside(
terms, rules, roots, semiring=LogSemiring, lengths=None, force_grad=False
):
"""
Compute the inside pass of a CFG using CKY.
Expand All @@ -23,9 +25,14 @@ def cky_inside(terms, rules, roots, semiring=LogSemiring, lengths=None):
_, NT, _, _ = rules.shape
if lengths is None:
lengths = torch.LongTensor([N] * batch)
beta = [_make_chart((batch, N, N, NT + T), rules, semiring) for _ in range(2)]

span = [_make_chart((batch, N, NT + T), rules, semiring) for _ in range(N)]
beta = [
_make_chart((batch, N, N, NT + T), rules, semiring, force_grad)
for _ in range(2)
]

span = [
_make_chart((batch, N, NT + T), rules, semiring, force_grad) for _ in range(N)
]
rule_use = [None for _ in range(N - 1)]
term_use = terms.requires_grad_(True)
beta[A][:, :, 0, NT:] = term_use
Expand Down
18 changes: 13 additions & 5 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _unconvert(logits):
A, B, R, C, L, I = 0, 1, 1, 1, 0, 0


def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None):
def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None, force_grad=False):
"""
Compute the inside pass of a projective dependency CRF.
Expand All @@ -49,10 +49,12 @@ def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None):
"""
arc_scores = _convert(arc_scores)
batch, N, _ = arc_scores.shape
batch, N, N2 = arc_scores.shape
assert N == N2, "Non-square potentials"
DIRS = 2
if lengths is None:
lengths = torch.LongTensor([N] * batch)
assert max(lengths) <= N, "Length longer than N"

def stack(a, b):
return torch.stack([a, b])
Expand All @@ -61,10 +63,16 @@ def sstack(a):
return torch.stack([a, a])

alpha = [
[_make_chart((DIRS, batch, N, N), arc_scores, semiring) for _ in [I, C]]
[
_make_chart((DIRS, batch, N, N), arc_scores, semiring, force_grad)
for _ in [I, C]
]
for _ in range(2)
]
arcs = [_make_chart((DIRS, batch, N), arc_scores, semiring) for _ in range(N)]
arcs = [
_make_chart((DIRS, batch, N), arc_scores, semiring, force_grad)
for _ in range(N)
]

# Inside step. assumes first token is root symbol
alpha[A][C][:, :, :, 0].data.fill_(semiring.one())
Expand Down Expand Up @@ -108,7 +116,7 @@ def deptree(arc_scores, semiring=LogSemiring, lengths=None):
"""
batch, N, _ = arc_scores.shape
N = N + 1
v, arcs = deptree_inside(arc_scores, semiring, lengths)
v, arcs = deptree_inside(arc_scores, semiring, lengths, force_grad=True)
grads = torch.autograd.grad(
v.sum(dim=0), arcs[1:], create_graph=True, only_inputs=True, allow_unused=False
)
Expand Down
4 changes: 2 additions & 2 deletions torch_struct/helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch


def _make_chart(size, potentials, semiring, force_grad=False):
    """
    Allocate a chart tensor filled with the semiring's zero element.

    Parameters:
         size : sequence of ints giving the chart shape
         potentials : tensor whose type the chart is cast to (via type_as)
         semiring : object exposing zero(), the fill value for the chart
         force_grad : if True, mark the chart as requiring gradients, but
                      only when `potentials` does not already require them

    Returns:
         chart : tensor of shape `size`, filled with semiring.zero()
    """
    chart = torch.zeros(*size).type_as(potentials).fill_(semiring.zero())
    # The chart becomes a grad-requiring leaf only on explicit request, and
    # never when the potentials themselves already track gradients.
    return chart.requires_grad_(force_grad and not potentials.requires_grad)
18 changes: 12 additions & 6 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .helpers import _make_chart


def linearchain_forward(edge, semiring=LogSemiring, lengths=None):
def linearchain_forward(edge, semiring=LogSemiring, lengths=None, force_grad=False):
"""
Compute the forward pass of a linear chain CRF.
Expand All @@ -18,10 +18,16 @@ def linearchain_forward(edge, semiring=LogSemiring, lengths=None):
inside: list of N, b x C x C table
"""
batch, N, C, _ = edge.shape
batch, N, C, C2 = edge.shape
if lengths is None:
lengths = torch.LongTensor([N] * batch)
alpha = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)]
assert max(lengths) <= N, "Length longer than edge scores"
assert C == C2, "Transition shape doesn't match"

alpha = [
_make_chart((batch, C), edge, semiring, force_grad=force_grad)
for n in range(N + 1)
]
edge_store = [None for _ in range(N)]
alpha[0].data.fill_(semiring.one())
for n in range(1, N + 1):
Expand All @@ -33,20 +39,20 @@ def linearchain_forward(edge, semiring=LogSemiring, lengths=None):
return v, edge_store


def linearchain(edge, semiring=LogSemiring):
def linearchain(edge, semiring=LogSemiring, lengths=None):
"""
Compute the marginals of a linear chain CRF.
Parameters:
edge : b x N x C x C markov potentials
(t x z_t x z_{t-1})
semiring
lengths: None or b long tensor mask
Returns:
marginals: b x N x C x C table
"""
v, alpha = linearchain_forward(edge, semiring)
v, alpha = linearchain_forward(edge, semiring, force_grad=True)
marg = torch.autograd.grad(
v.sum(dim=0), alpha, create_graph=True, only_inputs=True, allow_unused=False
)
Expand Down
17 changes: 11 additions & 6 deletions torch_struct/semimarkov.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .helpers import _make_chart


def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
def semimarkov_forward(edge, semiring=LogSemiring, lengths=None, force_grad=False):
"""
Compute the forward pass of a semimarkov CRF.
Expand All @@ -17,12 +17,17 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
spans: list of N, b x K x C x C table
"""
batch, N, K, C, _ = edge.shape
batch, N, K, C, C2 = edge.shape
if lengths is None:
lengths = torch.LongTensor([N] * batch)
assert max(lengths) <= N, "Length longer than edge scores"
assert C == C2, "Transition shape doesn't match"

spans = [None for _ in range(N)]
alpha = [_make_chart((batch, K, C), edge, semiring) for n in range(N + 1)]
beta = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)]
alpha = [
_make_chart((batch, K, C), edge, semiring, force_grad) for n in range(N + 1)
]
beta = [_make_chart((batch, C), edge, semiring, force_grad) for n in range(N + 1)]
beta[0].data.fill_(semiring.one())
for n in range(1, N + 1):
spans[n - 1] = semiring.times(
Expand All @@ -33,7 +38,7 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
f1 = torch.arange(n - 1, t, -1)
f2 = torch.arange(1, len(f1) + 1)
print(n - 1, f1, f2)
beta[n] = semiring.sum(
beta[n][:] = semiring.sum(
torch.stack([alpha[a][:, b] for a, b in zip(f1, f2)]), dim=0
)
v = semiring.sum(torch.stack([beta[l][i] for i, l in enumerate(lengths)]), dim=1)
Expand All @@ -52,7 +57,7 @@ def semimarkov(edge, semiring=LogSemiring, lengths=None):
marginals: b x N x K x C table
"""
v, spans = semimarkov_forward(edge, semiring, lengths)
v, spans = semimarkov_forward(edge, semiring, lengths, force_grad=True)
marg = torch.autograd.grad(
v.sum(dim=0), spans, create_graph=True, only_inputs=True, allow_unused=False
)
Expand Down
26 changes: 26 additions & 0 deletions torch_struct/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def test_linearchain(batch, N, C):
assert torch.isclose(score.sum(), marginals.mul(vals).sum()).all()


@given(smint, smint, smint)
def test_params(batch, N, C):
    """Smoke test: backward() runs through linearchain_forward when the
    edge potentials themselves require gradients (StdSemiring)."""
    potentials = torch.ones(batch, N, C, C, requires_grad=True)
    total, _ = linearchain_forward(potentials, StdSemiring)
    total.sum().backward()


def test_hmm():
C, V, batch, N = 5, 20, 2, 5
transition = torch.rand(C, C)
Expand Down Expand Up @@ -84,6 +92,13 @@ def test_dep(N):
assert torch.isclose(score.sum(), marginals.mul(scores).sum()).all()


def test_dep_params():
    """Smoke test: backward() runs through deptree_inside when the arc
    scores themselves require gradients."""
    batch, N = 2, 2
    arc_scores = torch.rand(batch, N, N, requires_grad=True)
    # Only the first returned value is needed; the second is discarded.
    result = deptree_inside(arc_scores)
    top = result[0]
    top.sum().backward()


def test_dep_np():
N = 5
batch = 2
Expand Down Expand Up @@ -114,3 +129,14 @@ def test_cky(N, NT, T):
+ m_root.mul(roots).sum()
).sum(),
).all()


@given(smint, tint, tint)
@settings(max_examples=3)
def test_cky_params(N, NT, T):
    """Smoke test: backward() runs through cky_inside when the rule and
    root potentials require gradients (terms are created without grad)."""
    batch = 2
    symbols = NT + T
    terms = torch.rand(batch, N, T)
    rules = torch.rand(batch, NT, symbols, symbols, requires_grad=True)
    roots = torch.rand(batch, NT, requires_grad=True)
    outputs = cky_inside(terms, rules, roots)
    # First element is the value to differentiate; the rest is unused here.
    outputs[0].sum().backward()

0 comments on commit b2d7fcd

Please sign in to comment.