Skip to content

Commit

Permalink
Merge 77526df into 876588d
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyanpeng committed Feb 28, 2020
2 parents 876588d + 77526df commit 5260354
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torch_struct/cky.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def arr(a, b):
log_Z = semiring.dot(top, roots)
return semiring.unconvert(log_Z), (term_use, rules, top, span[1:]), beta

def marginals(self, scores, lengths=None, _autograd=True):
def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
"""
Compute the marginals of a CFG using CKY.
Expand All @@ -92,6 +92,8 @@ def marginals(self, scores, lengths=None, _autograd=True):
spans: bxNxT terms, (bxNTx(NT+S)x(NT+S)) rules, bxNT roots
"""
assert not _raw, "top k > 1 with CKY is not supported."

terms, rules, roots = scores
batch, N, T = terms.shape
_, NT, _, _ = rules.shape
Expand Down

0 comments on commit 5260354

Please sign in to comment.