Skip to content

Commit

Permalink
add tests for CKY (#53)
Browse files Browse the repository at this point in the history
* minimize the CKY for debugging

* add tests for the CKY

* fix formatting issues
  • Loading branch information
zhaoyanpeng committed Mar 8, 2020
1 parent 8608b6c commit 67aa60d
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 6 deletions.
15 changes: 9 additions & 6 deletions torch_struct/cky.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class CKY(_Struct):
def _dp(self, scores, lengths=None, force_grad=False):
def _dp(self, scores, lengths=None, force_grad=False, cache=True):

semiring = self.semiring

Expand All @@ -26,7 +26,9 @@ def _dp(self, scores, lengths=None, force_grad=False):
lengths = torch.LongTensor([N] * batch)

# Charts
beta = [Chart((batch, N, N, NT), rules, semiring) for _ in range(2)]
beta = [
Chart((batch, N, N, NT), rules, semiring, cache=cache) for _ in range(2)
]
span = [None for _ in range(N)]
v = (ssize, batch)
term_use = terms + 0.0
Expand Down Expand Up @@ -97,12 +99,11 @@ def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
_, NT, _, _ = rules.shape

v, (term_use, rule_use, root_use, spans), alpha = self._dp(
scores, lengths=lengths, force_grad=True
scores, lengths=lengths, force_grad=True, cache=not _raw
)
inputs = (rule_use, root_use, term_use) + tuple(spans)

def marginal(obj, inputs):
obj = self.semiring.unconvert(v).sum(dim=0)
obj = self.semiring.unconvert(obj).sum(dim=0)
marg = torch.autograd.grad(
obj, inputs, create_graph=True, only_inputs=True, allow_unused=False,
)
Expand All @@ -112,7 +113,8 @@ def marginal(obj, inputs):
)
span_ls = marg[3:]
for w in range(len(span_ls)):
spans_marg[:, w, : N - w - 1] = self.semiring.unconvert(span_ls[w])
x = span_ls[w].sum(dim=0, keepdim=True)
spans_marg[:, w, : N - w - 1] = self.semiring.unconvert(x)

rule_marg = self.semiring.unconvert(marg[0]).squeeze(1)
root_marg = self.semiring.unconvert(marg[1])
Expand All @@ -123,6 +125,7 @@ def marginal(obj, inputs):
assert rule_marg.shape == (batch, NT, NT + T, NT + T)
return (term_marg, rule_marg, root_marg, spans_marg)

inputs = (rule_use, root_use, term_use) + tuple(spans)
if _raw:
paths = []
for k in range(v.shape[0]):
Expand Down
133 changes: 133 additions & 0 deletions torch_struct/test_cky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch

from torch_struct import SentCFG


def params_l3():
"""
seq = x y z, t0, t1 & n0, n1, n2
"""
terms = [[2, 1], [1, 2], [1, 1]]
# term4 = [[1, 1], [2, 1], [1, 2]]
roots = [1, 1, 1]
rule1 = [
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 5],
[1, 1, 1, 2, 1],
]
rule2 = [
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 6],
[1, 1, 1, 2, 1],
]
rule3 = [
[1, 1, 1, 1, 1],
[1, 1, 1, 4, 5],
[1, 1, 1, 8, 9],
[1, 1, 1, 1, 4],
[1, 1, 1, 2, 1],
]
rules = [rule1, rule2, rule3]

terms = (
torch.tensor(terms, dtype=torch.float64, requires_grad=True)
.unsqueeze(0)
.float()
)
roots = (
torch.tensor(roots, dtype=torch.float64, requires_grad=True)
.unsqueeze(0)
.float()
)
rules = (
torch.tensor(rules, dtype=torch.float64, requires_grad=True)
.unsqueeze(0)
.float()
)

length = torch.tensor([3]).long()

# print('term:\n', terms, terms.shape)
# print('root:\n', roots, roots.shape)
# print('rule:\n', rules, rules.shape)
return ((terms, rules, roots), length)


def extract_parse(span, length):
tree = [(i, str(i)) for i in range(length)]
tree = dict(tree)
spans = []
cover = (span > 0).float().nonzero()
for i in range(cover.shape[0]):
w, r, A = cover[i].tolist()
w = w + 1
r = r + w
l = r - w
spans.append((l, r, A))
span = "({} {})".format(tree[l], tree[r])
tree[r] = tree[l] = span
return spans, tree[0]


def extract_topk(matrix, lengths):
batch, K, N = matrix.shape[:3]
spans = []
trees = []
for b in range(batch):
for k in range(K):
this_span = matrix[b][k]
span, tree = extract_parse(this_span, lengths[b])
trees.append(tree)
spans.append(span)
# print(span)
# print(tree)
# break
return spans, trees


def extract_parses(matrix, lengths):
batch, K, N = matrix.shape[:3]
spans = []
trees = []
for b in range(batch):
span, tree = extract_parse(matrix[b], lengths[b])
trees.append(tree)
spans.append(span)
# print(span, tree)
# break
return spans, trees


def test_l3_kbest():
params, lengths = params_l3()
dist = SentCFG(params, lengths=lengths)

_, _, _, spans = dist.argmax
spans, trees = extract_parses(spans, lengths)
best_trees = "((0 1) 2)"
best_spans = [(0, 1, 2), (0, 2, 2)]
assert spans[0] == best_spans
assert trees[0] == best_trees

_, _, _, spans = dist.topk(4)
size = (1, 0) + tuple(range(2, spans.dim()))
spans = spans.permute(size)
spans, trees = extract_topk(spans, lengths)
best_trees = "((0 1) 2)"
best_spans = [
[(0, 1, 2), (0, 2, 2)],
[(0, 1, 2), (0, 2, 2)],
[(0, 1, 1), (0, 2, 2)],
[(0, 1, 1), (0, 2, 2)],
]
for i, (span, tree) in enumerate(zip(spans, trees)):
assert span == best_spans[i]
assert tree == best_trees


if __name__ == "__main__":
test_l3_kbest()

0 comments on commit 67aa60d

Please sign in to comment.