diff --git a/torch_struct/cky.py b/torch_struct/cky.py index 070697a8..6d95f619 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -78,7 +78,7 @@ def arr(a, b): log_Z = semiring.dot(top, roots) return semiring.unconvert(log_Z), (term_use, rules, top, span[1:]), beta - def marginals(self, scores, lengths=None, _autograd=True): + def marginals(self, scores, lengths=None, _autograd=True, _raw=False): """ Compute the marginals of a CFG using CKY. @@ -92,6 +92,8 @@ def marginals(self, scores, lengths=None, _autograd=True): spans: bxNxT terms, (bxNTx(NT+S)x(NT+S)) rules, bxNT roots """ + assert not _raw, "top k > 1 with CKY is not supported." + terms, rules, roots = scores batch, N, T = terms.shape _, NT, _, _ = rules.shape