Skip to content

Commit

Permalink
Merge e773067 into 876588d
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyanpeng committed Feb 28, 2020
2 parents 876588d + e773067 commit 15db24a
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions torch_struct/cky.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def arr(a, b):
final = beta[A][0, :, NTs]
top = torch.stack([final[:, i, l - 1] for i, l in enumerate(lengths)], dim=1)
log_Z = semiring.dot(top, roots)
return semiring.unconvert(log_Z), (term_use, rules, top, span[1:]), beta
return log_Z, (term_use, rules, roots, span[1:]), beta

def marginals(self, scores, lengths=None, _autograd=True):
def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
"""
Compute the marginals of a CFG using CKY.
Expand All @@ -95,34 +95,46 @@ def marginals(self, scores, lengths=None, _autograd=True):
terms, rules, roots = scores
batch, N, T = terms.shape
_, NT, _, _ = rules.shape
v, (term_use, rule_use, top, spans), alpha = self._dp(

v, (term_use, rule_use, root_use, spans), alpha = self._dp(
scores, lengths=lengths, force_grad=True
)
inputs = (rule_use, root_use, term_use) + tuple(spans)

marg = torch.autograd.grad(
v.sum(dim=0),
(rule_use, top, term_use) + tuple(spans),
create_graph=True,
only_inputs=True,
allow_unused=False,
)
def marginal(obj, inputs):
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
obj, inputs, create_graph=True, only_inputs=True, allow_unused=False,
)

spans_marg = torch.zeros(
batch, N, N, NT, dtype=scores[1].dtype, device=scores[1].device
)
span_ls = marg[3:]
for w in range(len(span_ls)):
spans_marg[:, w, : N - w - 1] = self.semiring.unconvert(
span_ls[w].squeeze(1)
spans_marg = torch.zeros(
batch, N, N, NT, dtype=scores[1].dtype, device=scores[1].device
)
rule_use = self.semiring.unconvert(marg[0]).squeeze(1)
term_marg = self.semiring.unconvert(marg[2])
root_marg = self.semiring.unconvert(marg[1])

assert term_marg.shape == (batch, N, T)
assert root_marg.shape == (batch, NT)
assert rule_use.shape == (batch, NT, NT + T, NT + T)
return (term_marg, rule_use, root_marg, spans_marg)
span_ls = marg[3:]
for w in range(len(span_ls)):
spans_marg[:, w, : N - w - 1] = self.semiring.unconvert(span_ls[w])

rule_marg = self.semiring.unconvert(marg[0]).squeeze(1)
root_marg = self.semiring.unconvert(marg[1])
term_marg = self.semiring.unconvert(marg[2])

assert term_marg.shape == (batch, N, T)
assert root_marg.shape == (batch, NT)
assert rule_marg.shape == (batch, NT, NT + T, NT + T)
return (term_marg, rule_marg, root_marg, spans_marg)

if _raw:
paths = []
for k in range(v.shape[0]):
obj = v[k : k + 1]
marg = marginal(obj, inputs)
paths.append(marg[-1])
paths = torch.stack(paths, 0)
obj = v.sum(dim=0, keepdim=True)
term_marg, rule_marg, root_marg, _ = marginal(obj, inputs)
return term_marg, rule_marg, root_marg, paths
else:
return marginal(v, inputs)

def score(self, potentials, parts):
terms, rules, roots = potentials[:3]
Expand Down

0 comments on commit 15db24a

Please sign in to comment.