Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Aug 30, 2019
1 parent e028d56 commit e3837bb
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 38 deletions.
37 changes: 20 additions & 17 deletions torch_struct/cky.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
A, B = 0, 1


def cky_inside(terms, rules, roots, semiring=LogSemiring, lengths=None):
    """
    Compute the inside pass of a CFG using CKY.

    Parameters:
        terms : b x N x T terminal potentials
        rules : b x NT x (NT+T) x (NT+T) rule potentials
        roots : b x NT root potentials
        semiring : semiring implementing times/sum/dot (default LogSemiring)
        lengths : None or b long tensor of sentence lengths; defaults to
            full length N for every batch element

    Returns:
        v: b tensor of total sum
        spans: (term_use, rule_use, top) charts for marginal computation
    """
    batch, N, T = terms.shape
    _, NT, _, _ = rules.shape
    if lengths is None:
        # No mask given: every sentence spans the full chart width.
        lengths = torch.LongTensor([N] * batch)
    # Two triangular charts (left- and right-anchored views of the same table).
    beta = [_make_chart((batch, N, N, NT + T), rules, semiring) for _ in range(2)]
    span = [_make_chart((batch, N, NT + T), rules, semiring) for _ in range(N)]
    rule_use = [None for _ in range(N - 1)]
    # Terminals require grad so marginals can be recovered by backprop.
    term_use = terms.requires_grad_(True)
    beta[A][:, :, 0, NT:] = term_use
    beta[B][:, :, N - 1, NT:] = term_use

    S = NT + T
    for w in range(1, N):
        # Combine all left (Y) / right (Z) splits of each width-w span.
        Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
        Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
        X_Y_Z = rules.view(batch, 1, NT, S, S)
        rule_use[w - 1] = semiring.times(
            semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z
        )
        rulesmid = rule_use[w - 1].view(batch, N - w, NT, S * S)
        span[w] = semiring.sum(rulesmid, dim=3)
        beta[A][:, : N - w, w, :NT] = span[w]
        beta[B][:, w:N, N - w - 1, :NT] = beta[A][:, : N - w, w, :NT]

    # Read each sentence's root cell at its own length, not the padded N.
    top = torch.stack([beta[A][i, 0, l - 1, :NT] for i, l in enumerate(lengths)])
    log_Z = semiring.dot(top, roots)
    return log_Z, (term_use, rule_use, top)


def cky(terms, rules, roots, semiring=LogSemiring, lengths=None):
    """
    Compute the marginals of a CFG using CKY.

    Parameters:
        terms : b x N x T terminal potentials
        rules : b x NT x (NT+T) x (NT+T) rule potentials
        roots : b x NT root potentials
        semiring : semiring implementing times/sum/dot (default LogSemiring)
        lengths : None or b long tensor of sentence lengths, forwarded to
            the inside pass

    Returns:
        spans: bxNxT terms, (bxNxNxNTxSxS) rules, bxNT roots
    """
    batch, N, T = terms.shape
    _, NT, _, _ = rules.shape
    S = NT + T
    v, (term_use, rule_use, top) = cky_inside(
        terms, rules, roots, semiring=semiring, lengths=lengths
    )
    # Marginals are gradients of the log-partition w.r.t. the chart entries.
    marg = torch.autograd.grad(
        v.sum(dim=0),
        tuple(rule_use) + (top, term_use),
        create_graph=True,
        only_inputs=True,
        allow_unused=False,
    )

    rule_use = marg[:-2]
    rules = torch.zeros(batch, N, N, NT, S, S)
    # Scatter the per-width rule marginals back into a dense chart.
    for w in range(len(rule_use)):
        rules[:, w, : N - w - 1] = rule_use[w]
    assert marg[-1].shape == (batch, N, T)
    assert marg[-2].shape == (batch, NT)
    return (marg[-1], rules, marg[-2])


###### Test


def cky_check(terms, rules, roots, semiring=LogSemiring):
batch_size, N, T = terms.shape
batch, N, T = terms.shape
_, NT, _, _ = rules.shape

def enumerate(x, start, end):
Expand Down
28 changes: 17 additions & 11 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,25 @@ def _unconvert(logits):
A, B, R, C, L, I = 0, 1, 1, 1, 0, 0


def deptree_inside(arc_scores, semiring=LogSemiring):
def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None):
"""
Compute the inside pass of a projective dependency CRF.
Parameters:
arc_scores : b x N x N arc scores with root scores on diagonal.
semiring
lengths: None or b long tensor mask
Returns:
v: b tensor of total sum
arcs: list of N, LR x b x N table
"""
arc_scores = _convert(arc_scores)
batch_size, N, _ = arc_scores.shape
batch, N, _ = arc_scores.shape
DIRS = 2
if lengths is None:
lengths = torch.LongTensor([N] * batch)

def stack(a, b):
return torch.stack([a, b])
Expand All @@ -58,10 +61,10 @@ def sstack(a):
return torch.stack([a, a])

alpha = [
[_make_chart((DIRS, batch_size, N, N), arc_scores, semiring) for _ in [I, C]]
[_make_chart((DIRS, batch, N, N), arc_scores, semiring) for _ in [I, C]]
for _ in range(2)
]
arcs = [_make_chart((DIRS, batch_size, N), arc_scores, semiring) for _ in range(N)]
arcs = [_make_chart((DIRS, batch, N), arc_scores, semiring) for _ in range(N)]

# Inside step. assumes first token is root symbol
alpha[A][C][:, :, :, 0].data.fill_(semiring.one())
Expand All @@ -85,28 +88,31 @@ def sstack(a):
),
)
alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]
return alpha[A][C][R, :, 0, N - 1], arcs
return (
torch.stack([alpha[A][C][R, i, 0, l - 1] for i, l in enumerate(lengths)]),
arcs,
)


def deptree(arc_scores, semiring=LogSemiring):
def deptree(arc_scores, semiring=LogSemiring, lengths=None):
"""
Compute the marginals of a projective dependency CRF.
Parameters:
arc_scores : b x N x N arc scores with root scores on diagonal.
semiring
lengths
Returns:
arc_marginals : b x N x N.
"""
batch_size, N, _ = arc_scores.shape
batch, N, _ = arc_scores.shape
N = N + 1
v, arcs = deptree_inside(arc_scores, semiring)
v, arcs = deptree_inside(arc_scores, semiring, lengths)
grads = torch.autograd.grad(
v.sum(dim=0), arcs[1:], create_graph=True, only_inputs=True, allow_unused=False
)
ret = torch.zeros(batch_size, N, N).cpu()
ret = torch.zeros(batch, N, N).cpu()
for k, grad in enumerate(grads, 1):
f = torch.arange(N - k), torch.arange(k, N)
ret[:, f[0], f[1]] = grad[R].cpu()
Expand Down Expand Up @@ -163,7 +169,7 @@ def deptree_check(arc_scores, semiring=LogSemiring, non_proj=False):
parses = []
q = []
arc_scores = _convert(arc_scores)
batch_size, N, _ = arc_scores.shape
batch, N, _ = arc_scores.shape
for mid in itertools.product(range(N + 1), repeat=N - 1):
parse = [-1] + list(mid)
if not _is_spanning(parse):
Expand Down
10 changes: 7 additions & 3 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,34 @@
from .helpers import _make_chart


def linearchain_forward(edge, semiring=LogSemiring, lengths=None):
    """
    Compute the forward pass of a linear chain CRF.

    Parameters:
        edge : b x N x C x C markov potentials
                    (n-1 x z_n x z_{n-1})
        semiring : semiring implementing times/sum (default LogSemiring)
        lengths : None or b long tensor of chain lengths; defaults to the
            full length N for every batch element

    Returns:
        v: b tensor of total sum
        inside: list of N, b x C x C table
    """
    batch, N, C, _ = edge.shape
    if lengths is None:
        # No mask given: every chain uses all N transitions.
        lengths = torch.LongTensor([N] * batch)
    alpha = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)]
    edge_store = [None for _ in range(N)]
    alpha[0].data.fill_(semiring.one())
    for n in range(1, N + 1):
        # Extend every path by one transition, keeping the per-edge scores
        # around so marginals can be recovered by backprop.
        edge_store[n - 1] = semiring.times(
            alpha[n - 1].view(batch, 1, C), edge[:, n - 1]
        )
        # In-place slice assignment keeps alpha[n] as the chart tensor
        # created above (preserving its autograd identity).
        alpha[n][:] = semiring.sum(edge_store[n - 1])
    # Each batch element terminates at its own length, not the padded N.
    v = semiring.sum(torch.stack([alpha[l][i] for i, l in enumerate(lengths)]), dim=-1)
    return v, edge_store


def linearchain(edge, semiring=LogSemiring):
Expand Down
13 changes: 8 additions & 5 deletions torch_struct/semimarkov.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@
from .helpers import _make_chart


def semimarkov_forward(edge, semiring=LogSemiring):
def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
"""
Compute the forward pass of a semimarkov CRF.
Parameters:
edge : b x N x K x C x C semimarkov potentials
semiring
lengths: None or b long tensor mask
Returns:
v: b tensor of total sum
spans: list of N, b x K x C x C table
"""
batch, N, K, C, _ = edge.shape
if lengths is None:
lengths = torch.LongTensor([N] * batch)
spans = [None for _ in range(N)]
alpha = [_make_chart((batch, K, C), edge, semiring) for n in range(N + 1)]
beta = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)]
Expand All @@ -33,11 +36,11 @@ def semimarkov_forward(edge, semiring=LogSemiring):
beta[n] = semiring.sum(
torch.stack([alpha[a][:, b] for a, b in zip(f1, f2)]), dim=0
)
v = semiring.sum(torch.stack([beta[l][i] for i, l in enumerate(lengths)]), dim=1)
return v, spans

return semiring.sum(beta[N], dim=1), spans


def semimarkov(edge, semiring=LogSemiring):
def semimarkov(edge, semiring=LogSemiring, lengths=None):
"""
Compute the marginals of a semimarkov CRF.
Expand All @@ -49,7 +52,7 @@ def semimarkov(edge, semiring=LogSemiring):
marginals: b x N x K x C table
"""
v, spans = semimarkov_forward(edge, semiring)
v, spans = semimarkov_forward(edge, semiring, lengths)
marg = torch.autograd.grad(
v.sum(dim=0), spans, create_graph=True, only_inputs=True, allow_unused=False
)
Expand Down
5 changes: 3 additions & 2 deletions torch_struct/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def test_simple(batch, N, C):
linearchain(vals, SampledSemiring)


@given(smint, smint, smint)
@settings(max_examples=50)
@given(smint, smint, tint)
@settings(max_examples=25)
def test_linearchain(batch, N, C):
for semiring in [LogSemiring, MaxSemiring]:
vals = torch.rand(batch, N, C, C)
Expand Down Expand Up @@ -69,6 +69,7 @@ def test_semimarkov(N, K, V, C):


@given(smint)
@settings(max_examples=25)
def test_dep(N):
batch = 2
for semiring in [LogSemiring, MaxSemiring]:
Expand Down

0 comments on commit e3837bb

Please sign in to comment.