Commit

Merge 068880d into b8d0ee3
srush committed Sep 7, 2019
2 parents b8d0ee3 + 068880d commit 57faa9a
Showing 7 changed files with 442 additions and 140 deletions.
158 changes: 131 additions & 27 deletions torch_struct/cky.py
@@ -1,11 +1,35 @@
import torch
from .helpers import _Struct
from .semirings import LogSemiring
from torch.autograd import Function

A, B = 0, 1
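# A minimal sketch of the semiring interface this file leans on, assuming the
# LogSemiring operates in log space: `times` adds log potentials, `sum` is a
# logsumexp reduction, `dot` reduces a product over the last dimension, and
# `div_exp(a, b)` computes exp(a - b). The real implementation lives in
# .semirings; the class below is illustrative only.
class _LogSemiringSketch:
    @staticmethod
    def times(*xs):
        # Log-space multiplication is ordinary addition of log potentials.
        total = xs[0]
        for x in xs[1:]:
            total = total + x
        return total

    @staticmethod
    def sum(xs, dim=-1):
        # Log-space addition is logsumexp over the given dimension.
        return torch.logsumexp(xs, dim=dim)

    @classmethod
    def dot(cls, a, b):
        # Inner product over the last dimension.
        return cls.sum(cls.times(a, b), dim=-1)

    @staticmethod
    def div_exp(a, b):
        # exp of a log-space ratio, used to normalize scores by log Z.
        return torch.exp(a - b)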


class DPManual2(Function):
@staticmethod
def forward(ctx, obj, terms, rules, roots, lengths):
with torch.no_grad():
v, _, alpha = obj._dp((terms, rules, roots), lengths, False)
ctx.obj = obj
ctx.lengths = lengths
ctx.alpha = alpha
ctx.v = v
ctx.save_for_backward(terms, rules, roots)
return v

@staticmethod
def backward(ctx, grad_v):
terms, rules, roots = ctx.saved_tensors
with torch.no_grad():
marginals = ctx.obj._dp_backward(
(terms, rules, roots), ctx.lengths, ctx.alpha, ctx.v
)
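# The edge marginals have shape (batch, N, N, NT, S, S); summing over the two
# span dimensions gives a gradient with the (batch, NT, S, S) shape of the
# rule tensor. None is returned for the non-tensor arguments (obj, lengths).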
return None, marginals[0], marginals[1].sum(1).sum(1), marginals[2], None
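# A toy torch.autograd.Function following the same pattern as DPManual2
# above: forward runs without building a graph and caches what it needs on
# ctx, and backward returns one gradient per forward argument. This sketch is
# illustrative only and is not part of torch_struct.
class _SquareSketch(Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            y = x * x
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^2)/dx = 2x, scaled by the incoming gradient.
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output


# Example: _SquareSketch.apply(torch.ones(3, requires_grad=True)).sum().backward()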


class CKY(_Struct):
def sum(self, scores, lengths=None, force_grad=False):
def sum(self, scores, lengths=None, force_grad=False, _autograd=False):
"""
Compute the inside pass of a CFG using CKY.
@@ -18,7 +42,10 @@ def sum(self, scores, lengths=None, force_grad=False):
v: b tensor of total sum
spans: list of N, b x N x (NT+t)
"""
return self._dp(scores, lengths)[0]
if _autograd or self.semiring is not LogSemiring:
return self._dp(scores, lengths)[0]
else:
return DPManual2.apply(self, *scores, lengths)
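# With _autograd=True or a semiring other than LogSemiring, sum falls back to
# ordinary autograd through _dp; only the LogSemiring case routes through the
# hand-written backward in DPManual2 above.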

def _dp(self, scores, lengths=None, force_grad=False):
terms, rules, roots = scores
@@ -40,7 +67,6 @@ def _dp(self, scores, lengths=None, force_grad=False):
term_use[:] = terms + 0.0
beta[A][:, :, 0, NT:] = term_use
beta[B][:, :, N - 1, NT:] = term_use

for w in range(1, N):
Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
@@ -56,9 +82,87 @@ def _dp(self, scores, lengths=None, force_grad=False):

top[:] = torch.stack([beta[A][i, 0, l - 1, :NT] for i, l in enumerate(lengths)])
log_Z = semiring.dot(top, roots)
return log_Z, (term_use, rule_use, top)
return log_Z, (term_use, rule_use, top), beta
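# _dp now also returns the inside charts (beta) so that the manual backward
# and the non-autograd marginals path can reuse them rather than rebuilding
# an autograd graph; beta[A] appears to index inside scores by start position
# and width (width w + 1 at third index w), with beta[B] holding the same
# scores indexed from the right edge.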

def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):
terms, rules, roots = scores
semiring = self.semiring
batch, N, T = terms.shape
_, NT, _, _ = rules.shape
S = NT + T
if lengths is None:
lengths = torch.LongTensor([N] * batch)

beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad)
span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0]

ssum = semiring.sum
st = semiring.times
X_Y_Z = rules.view(batch, 1, NT, S, S)

for w in range(N - 1, -1, -1):
for b, l in enumerate(lengths):
beta[A][b, 0, l - 1, :NT] = roots[b]
beta[B][b, l - 1, N - (l), :NT] = roots[b]
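# The outside chart is seeded at the root: for each sentence of length l, the
# full span starting at 0 with width l receives the root potentials before
# narrower spans are filled in top-down.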

def marginals(self, scores, lengths=None):
# LEFT
# all bigger on the left.
X = beta[A][:, : N - w - 1, w + 1 :, :NT].view(
batch, N - w - 1, N - w - 1, NT, 1, 1
)
Z = alpha_in[A][:, w + 1 : N, 0 : N - w - 1].view(
batch, N - w - 1, N - w - 1, 1, 1, S
)
t = st(ssum(st(X, Z), dim=2), X_Y_Z)
# sum out x and z, keeping y (the left-child symbol)
span_l[w] = ssum(ssum(t, dim=-3), dim=-1)

# RIGHT
X = beta[B][:, w + 1 :, : N - 1 - w, :NT].view(
batch, N - w - 1, N - w - 1, NT, 1, 1
)
Y = alpha_in[B][:, : N - w - 1, w + 1 :, :].view(
batch, N - w - 1, N - w - 1, 1, S, 1
)
t = st(ssum(st(X, Y), dim=2), X_Y_Z)

span_r[w] = ssum(ssum(t, dim=-3), dim=-2)

beta[A][:, : N - w - 1, w, :] = span_l[w]
beta[A][:, 1 : N - w, w, :] = ssum(
torch.stack([span_r[w], beta[A][:, 1 : N - w, w, :]]), dim=0
)
beta[B][:, w:, N - w - 1, :] = beta[A][:, : N - w, w, :]
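# Working top-down, the outside score of a span combines the outside score of
# an enclosing parent (beta) with the inside score of the sibling child
# (alpha_in) through the rule tensor X_Y_Z: span_l covers the case where the
# current span is the left child, span_r the case where it is the right
# child, and the two contributions are semiring-summed into beta above.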

term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms)
term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0]
for n in range(N):
term_marginals[:, n] = semiring.div_exp(term_use[:, n], v.view(batch, 1))

root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0]
for b in range(batch):
root_marginals[b] = semiring.div_exp(
st(alpha_in[A][b, 0, lengths[b] - 1, :NT], roots[b]), v[b].view(1)
)
edge_marginals = self._make_chart(
1, (batch, N, N, NT, S, S), terms, force_grad=False
)[0]
edge_marginals.fill_(0)
for w in range(1, N):
Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
score = semiring.times(semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z)
score = st(score, beta[A][:, : N - w, w, :NT].view(batch, N - w, NT, 1, 1))
edge_marginals[:, : N - w, w - 1] = semiring.div_exp(
score, v.view(batch, 1, 1, 1, 1)
)
edge_marginals = edge_marginals.transpose(1, 2)

return (term_marginals, edge_marginals, root_marginals)
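# Each marginal is an inside-outside product normalized by the partition
# value: assuming div_exp(a, b) computes exp(a - b) as sketched above, every
# entry is exp(score - log Z), i.e. a probability under the LogSemiring. The
# return order (terms, edges, roots) matches the gradient order that
# DPManual2.backward hands back to autograd.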

def marginals(self, scores, lengths=None, _autograd=False):
"""
Compute the marginals of a CFG using CKY.
@@ -76,23 +180,26 @@ def marginals(self, scores, lengths=None):
batch, N, T = terms.shape
_, NT, _, _ = rules.shape
S = NT + T
v, (term_use, rule_use, top) = self._dp(
v, (term_use, rule_use, top), alpha = self._dp(
scores, lengths=lengths, force_grad=True
)
marg = torch.autograd.grad(
v.sum(dim=0),
tuple(rule_use) + (top, term_use),
create_graph=True,
only_inputs=True,
allow_unused=False,
)
rule_use = marg[:-2]
rules = torch.zeros(batch, N, N, NT, S, S)
for w in range(len(rule_use)):
rules[:, w, : N - w - 1] = rule_use[w]
assert marg[-1].shape == (batch, N, T)
assert marg[-2].shape == (batch, NT)
return (marg[-1], rules, marg[-2])
if _autograd or self.semiring is not LogSemiring:
marg = torch.autograd.grad(
v.sum(dim=0),
tuple(rule_use) + (top, term_use),
create_graph=True,
only_inputs=True,
allow_unused=False,
)
rule_use = marg[:-2]
rules = torch.zeros(batch, N, N, NT, S, S)
for w in range(len(rule_use)):
rules[:, w, : N - w - 1] = rule_use[w]
assert marg[-1].shape == (batch, N, T)
assert marg[-2].shape == (batch, NT)
return (marg[-1], rules, marg[-2])
else:
return self._dp_backward(scores, lengths, alpha, v)
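# The autograd branch recovers marginals as gradients of v and scatters the
# per-width rule gradients into a single (batch, N, N, NT, S, S) tensor; the
# else branch reuses the charts returned by _dp and runs the hand-written
# outside pass in _dp_backward. Both branches return (term, rule, root)
# marginals in the same order.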

@staticmethod
def to_parts(spans, extra, lengths=None):
@@ -141,7 +248,6 @@ def from_parts(chart):
:, n, torch.arange(N - n - 1)
]
spans[:, torch.arange(N), torch.arange(N), NT:] = terms
print(rules.nonzero(), spans.nonzero())
return spans, (NT, S - NT)

###### Test
@@ -168,19 +274,17 @@ def enumerate(x, start, end):
[(x, start, w, end)] + y1 + z1,
)

# for nt in range(NT):
# print(list(enumerate(nt, 0, N)))
ls = []
for nt in range(NT):
ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)]
return semiring.sum(torch.stack(ls, dim=-1))
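# Brute-force check: enumerate recursively yields the semiring score of every
# binary parse of a span, so semiring-summing the root-weighted scores over
# all nonterminals reproduces the partition value that _dp computes with CKY.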

@staticmethod
def _rand():
batch = torch.randint(2, 4, (1,))
N = torch.randint(2, 4, (1,))
NT = torch.randint(2, 4, (1,))
T = torch.randint(2, 4, (1,))
batch = torch.randint(2, 5, (1,))
N = torch.randint(2, 5, (1,))
NT = torch.randint(2, 5, (1,))
T = torch.randint(2, 5, (1,))
terms = torch.rand(batch, N, T)
rules = torch.rand(batch, NT, (NT + T), (NT + T))
roots = torch.rand(batch, NT)
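# Random test instances now draw sizes in 2-4 (torch.randint's upper bound is
# exclusive) instead of 2-3, exercising slightly larger grammars and batches.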
