From 67aa60deacc4c8a46d6e6f3598131f0c2d97e54e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Y=C3=A4n=2EPnG?= Date: Sun, 8 Mar 2020 21:21:58 +0000 Subject: [PATCH] add tests for CKY (#53) * minimize the CKY for debugging * add tests for the CKY * fix formatting issues --- torch_struct/cky.py | 15 +++-- torch_struct/test_cky.py | 133 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 6 deletions(-) create mode 100644 torch_struct/test_cky.py diff --git a/torch_struct/cky.py b/torch_struct/cky.py index dc594771..54e57f99 100644 --- a/torch_struct/cky.py +++ b/torch_struct/cky.py @@ -5,7 +5,7 @@ class CKY(_Struct): - def _dp(self, scores, lengths=None, force_grad=False): + def _dp(self, scores, lengths=None, force_grad=False, cache=True): semiring = self.semiring @@ -26,7 +26,9 @@ def _dp(self, scores, lengths=None, force_grad=False): lengths = torch.LongTensor([N] * batch) # Charts - beta = [Chart((batch, N, N, NT), rules, semiring) for _ in range(2)] + beta = [ + Chart((batch, N, N, NT), rules, semiring, cache=cache) for _ in range(2) + ] span = [None for _ in range(N)] v = (ssize, batch) term_use = terms + 0.0 @@ -97,12 +99,11 @@ def marginals(self, scores, lengths=None, _autograd=True, _raw=False): _, NT, _, _ = rules.shape v, (term_use, rule_use, root_use, spans), alpha = self._dp( - scores, lengths=lengths, force_grad=True + scores, lengths=lengths, force_grad=True, cache=not _raw ) - inputs = (rule_use, root_use, term_use) + tuple(spans) def marginal(obj, inputs): - obj = self.semiring.unconvert(v).sum(dim=0) + obj = self.semiring.unconvert(obj).sum(dim=0) marg = torch.autograd.grad( obj, inputs, create_graph=True, only_inputs=True, allow_unused=False, ) @@ -112,7 +113,8 @@ def marginal(obj, inputs): ) span_ls = marg[3:] for w in range(len(span_ls)): - spans_marg[:, w, : N - w - 1] = self.semiring.unconvert(span_ls[w]) + x = span_ls[w].sum(dim=0, keepdim=True) + spans_marg[:, w, : N - w - 1] = self.semiring.unconvert(x) rule_marg = 
self.semiring.unconvert(marg[0]).squeeze(1) root_marg = self.semiring.unconvert(marg[1]) @@ -123,6 +125,7 @@ def marginal(obj, inputs): assert rule_marg.shape == (batch, NT, NT + T, NT + T) return (term_marg, rule_marg, root_marg, spans_marg) + inputs = (rule_use, root_use, term_use) + tuple(spans) if _raw: paths = [] for k in range(v.shape[0]): diff --git a/torch_struct/test_cky.py b/torch_struct/test_cky.py new file mode 100644 index 00000000..d2f18e07 --- /dev/null +++ b/torch_struct/test_cky.py @@ -0,0 +1,133 @@ +import torch + +from torch_struct import SentCFG + + +def params_l3(): + """ + seq = x y z, t0, t1 & n0, n1, n2 + """ + terms = [[2, 1], [1, 2], [1, 1]] + # term4 = [[1, 1], [2, 1], [1, 2]] + roots = [1, 1, 1] + rule1 = [ + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 5], + [1, 1, 1, 2, 1], + ] + rule2 = [ + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 6], + [1, 1, 1, 2, 1], + ] + rule3 = [ + [1, 1, 1, 1, 1], + [1, 1, 1, 4, 5], + [1, 1, 1, 8, 9], + [1, 1, 1, 1, 4], + [1, 1, 1, 2, 1], + ] + rules = [rule1, rule2, rule3] + + terms = ( + torch.tensor(terms, dtype=torch.float64, requires_grad=True) + .unsqueeze(0) + .float() + ) + roots = ( + torch.tensor(roots, dtype=torch.float64, requires_grad=True) + .unsqueeze(0) + .float() + ) + rules = ( + torch.tensor(rules, dtype=torch.float64, requires_grad=True) + .unsqueeze(0) + .float() + ) + + length = torch.tensor([3]).long() + + # print('term:\n', terms, terms.shape) + # print('root:\n', roots, roots.shape) + # print('rule:\n', rules, rules.shape) + return ((terms, rules, roots), length) + + +def extract_parse(span, length): + tree = [(i, str(i)) for i in range(length)] + tree = dict(tree) + spans = [] + cover = (span > 0).float().nonzero() + for i in range(cover.shape[0]): + w, r, A = cover[i].tolist() + w = w + 1 + r = r + w + l = r - w + spans.append((l, r, A)) + span = "({} {})".format(tree[l], tree[r]) + tree[r] = tree[l] = span + return spans, tree[0] 
def extract_topk(matrix, lengths):
    """Decode every k-best span chart in ``matrix`` into spans and a tree.

    ``matrix`` is indexed as ``matrix[b][k]`` (batch item first, then rank);
    each such chart is passed to ``extract_parse`` together with
    ``lengths[b]``.

    Returns:
        ``(spans, trees)``: two flat lists with one entry per ``(b, k)``
        pair, ordered batch-major (all ranks of batch item 0 come first).
    """
    batch, K = matrix.shape[:2]
    spans = []
    trees = []
    for b in range(batch):
        for k in range(K):
            span, tree = extract_parse(matrix[b][k], lengths[b])
            spans.append(span)
            trees.append(tree)
    return spans, trees


def extract_parses(matrix, lengths):
    """Decode one span chart per batch item into spans and a tree.

    ``matrix[b]`` is passed to ``extract_parse`` together with ``lengths[b]``.

    Returns:
        ``(spans, trees)``: two lists with one entry per batch item.
    """
    spans = []
    trees = []
    # Only the leading (batch) dimension is needed here; the remaining
    # dimensions are consumed inside extract_parse.
    for b in range(matrix.shape[0]):
        span, tree = extract_parse(matrix[b], lengths[b])
        spans.append(span)
        trees.append(tree)
    return spans, trees


def test_l3_kbest():
    """Check the argmax parse and the 4-best parses of the ``params_l3`` grammar."""
    params, lengths = params_l3()
    dist = SentCFG(params, lengths=lengths)
    best_tree = "((0 1) 2)"

    # Single best parse.
    _, _, _, spans = dist.argmax
    spans, trees = extract_parses(spans, lengths)
    assert spans[0] == [(0, 1, 2), (0, 2, 2)]
    assert trees[0] == best_tree

    # Four best parses.
    _, _, _, spans = dist.topk(4)
    # Swap the first two dimensions so the chart is indexed [batch][k],
    # which is the layout extract_topk unpacks.
    order = (1, 0) + tuple(range(2, spans.dim()))
    spans, trees = extract_topk(spans.permute(order), lengths)
    expected_spans = [
        [(0, 1, 2), (0, 2, 2)],
        [(0, 1, 2), (0, 2, 2)],
        [(0, 1, 1), (0, 2, 2)],
        [(0, 1, 1), (0, 2, 2)],
    ]
    # zip() stops at the shortest input, so without these checks a topk()
    # that returned too few parses would pass the loop below unnoticed.
    assert len(spans) == len(expected_spans)
    assert len(trees) == len(expected_spans)
    for span, tree, expected in zip(spans, trees, expected_spans):
        assert span == expected
        assert tree == best_tree


if __name__ == "__main__":
    test_l3_kbest()