Merge pull request #8 from harvardnlp/cky
Extensions
srush committed Sep 10, 2019
2 parents 9aab50f + fbf64f5 commit b368540
Showing 10 changed files with 570 additions and 278 deletions.
280 changes: 168 additions & 112 deletions torch_struct/cky.py
@@ -50,10 +50,16 @@ def sum(self, scores, lengths=None, force_grad=False, _autograd=True):
def _dp(self, scores, lengths=None, force_grad=False):
terms, rules, roots = scores
semiring = self.semiring
ssize = semiring.size()
batch, N, T = terms.shape
_, NT, _, _ = rules.shape
S = NT + T

terms, rules, roots = (
semiring.convert(terms),
semiring.convert(rules),
semiring.convert(roots),
)
if lengths is None:
lengths = torch.LongTensor([N] * batch)
beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad)
@@ -65,117 +71,123 @@ def _dp(self, scores, lengths=None, force_grad=False):
top = self._make_chart(1, (batch, NT), rules, force_grad)[0]
term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0]
term_use[:] = terms + 0.0
beta[A][:, :, 0, NT:] = term_use
beta[B][:, :, N - 1, NT:] = term_use
X_Y_Z = rules.view(batch, 1, NT, S, S)[:, :, :, :NT, :NT]
X_Y_Z1 = rules.view(batch, 1, NT, S, S)[:, :, :, :NT, NT:]
X_Y1_Z = rules.view(batch, 1, NT, S, S)[:, :, :, NT:, :NT]
X_Y1_Z1 = rules.view(batch, 1, NT, S, S)[:, :, :, NT:, NT:]
beta[A][:, :, :, 0, NT:] = term_use
beta[B][:, :, :, N - 1, NT:] = term_use
X_Y_Z = rules.view(ssize, batch, 1, NT, S, S)[:, :, :, :, :NT, :NT]
X_Y_Z1 = rules.view(ssize, batch, 1, NT, S, S)[:, :, :, :, :NT, NT:]
X_Y1_Z = rules.view(ssize, batch, 1, NT, S, S)[:, :, :, :, NT:, :NT]
X_Y1_Z1 = rules.view(ssize, batch, 1, NT, S, S)[:, :, :, :, NT:, NT:]

# here
for w in range(1, N):
Y = beta[A][:, : N - w, :w, :NT].view(batch, N - w, w, 1, NT, 1)
Z = beta[B][:, w:, N - w :, :NT].view(batch, N - w, w, 1, 1, NT)
rule_use[w - 1][:, :, :, :NT, :NT] = semiring.times(
semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z
Y = beta[A][:, :, : N - w, :w, :NT].view(ssize, batch, N - w, w, 1, NT, 1)
Z = beta[B][:, :, w:, N - w :, :NT].view(ssize, batch, N - w, w, 1, 1, NT)
rule_use[w - 1][:, :, :, :, :NT, :NT] = semiring.times(
semiring.sum(semiring.times(Y, Z), dim=3), X_Y_Z
)
Y = beta[A][:, : N - w, w - 1, :NT].view(batch, N - w, 1, NT, 1)
Z = beta[B][:, w:, N - 1, NT:].view(batch, N - w, 1, 1, T)
rule_use[w - 1][:, :, :, :NT, NT:] = semiring.times(Y, Z, X_Y_Z1)
Y = beta[A][:, :, : N - w, w - 1, :NT].view(ssize, batch, N - w, 1, NT, 1)
Z = beta[B][:, :, w:, N - 1, NT:].view(ssize, batch, N - w, 1, 1, T)
rule_use[w - 1][:, :, :, :, :NT, NT:] = semiring.times(Y, Z, X_Y_Z1)

Y = beta[A][:, : N - w, 0, NT:].view(batch, N - w, 1, T, 1)
Z = beta[B][:, w:, N - w, :NT].view(batch, N - w, 1, 1, NT)
rule_use[w - 1][:, :, :, NT:, :NT] = semiring.times(Y, Z, X_Y1_Z)
Y = beta[A][:, :, : N - w, 0, NT:].view(ssize, batch, N - w, 1, T, 1)
Z = beta[B][:, :, w:, N - w, :NT].view(ssize, batch, N - w, 1, 1, NT)
rule_use[w - 1][:, :, :, :, NT:, :NT] = semiring.times(Y, Z, X_Y1_Z)

if w == 1:
Y = beta[A][:, : N - w, w - 1, NT:].view(batch, N - w, 1, T, 1)
Z = beta[B][:, w:, N - w, NT:].view(batch, N - w, 1, 1, T)
rule_use[w - 1][:, :, :, NT:, NT:] = semiring.times(Y, Z, X_Y1_Z1)

rulesmid = rule_use[w - 1].view(batch, N - w, NT, S * S)
span[w] = semiring.sum(rulesmid, dim=3)
beta[A][:, : N - w, w, :NT] = span[w]
beta[B][:, w:N, N - w - 1, :NT] = beta[A][:, : N - w, w, :NT]

top[:] = torch.stack([beta[A][i, 0, l - 1, :NT] for i, l in enumerate(lengths)])
Y = beta[A][:, :, : N - w, w - 1, NT:].view(
ssize, batch, N - w, 1, T, 1
)
Z = beta[B][:, :, w:, N - w, NT:].view(ssize, batch, N - w, 1, 1, T)
rule_use[w - 1][:, :, :, :, NT:, NT:] = semiring.times(Y, Z, X_Y1_Z1)

rulesmid = rule_use[w - 1].view(ssize, batch, N - w, NT, S * S)
span[w] = semiring.sum(rulesmid, dim=4)
beta[A][:, :, : N - w, w, :NT] = span[w]
beta[B][:, :, w:N, N - w - 1, :NT] = beta[A][:, :, : N - w, w, :NT]

top[:] = torch.stack(
[beta[A][:, i, 0, l - 1, :NT] for i, l in enumerate(lengths)], dim=1
)
log_Z = semiring.dot(top, roots)
return log_Z, (term_use, rule_use, top), beta

def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):
terms, rules, roots = scores
semiring = self.semiring
batch, N, T = terms.shape
_, NT, _, _ = rules.shape
S = NT + T
if lengths is None:
lengths = torch.LongTensor([N] * batch)

beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad)
span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0]

ssum = semiring.sum
st = semiring.times
X_Y_Z = rules.view(batch, 1, NT, S, S)

for w in range(N - 1, -1, -1):
for b, l in enumerate(lengths):
beta[A][b, 0, l - 1, :NT] = roots[b]
beta[B][b, l - 1, N - (l), :NT] = roots[b]

# LEFT
# all bigger on the left.
X = beta[A][:, : N - w - 1, w + 1 :, :NT].view(
batch, N - w - 1, N - w - 1, NT, 1, 1
)
Z = alpha_in[A][:, w + 1 : N, 0 : N - w - 1].view(
batch, N - w - 1, N - w - 1, 1, 1, S
)
t = st(ssum(st(X, Z), dim=2), X_Y_Z)
# sum out x and y
span_l[w] = ssum(ssum(t, dim=-3), dim=-1)

# RIGHT
X = beta[B][:, w + 1 :, : N - 1 - w, :NT].view(
batch, N - w - 1, N - w - 1, NT, 1, 1
)
Y = alpha_in[B][:, : N - w - 1, w + 1 :, :].view(
batch, N - w - 1, N - w - 1, 1, S, 1
)
t = st(ssum(st(X, Y), dim=2), X_Y_Z)

span_r[w] = ssum(ssum(t, dim=-3), dim=-2)

beta[A][:, : N - w - 1, w, :] = span_l[w]
beta[A][:, 1 : N - w, w, :] = ssum(
torch.stack([span_r[w], beta[A][:, 1 : N - w, w, :]]), dim=0
)
beta[B][:, w:, N - w - 1, :] = beta[A][:, : N - w, w, :]

term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms)
term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0]
for n in range(N):
term_marginals[:, n] = semiring.div_exp(term_use[:, n], v.view(batch, 1))

root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0]
for b in range(batch):
root_marginals[b] = semiring.div_exp(
st(alpha_in[A][b, 0, lengths[b] - 1, :NT], roots[b]), v[b].view(1)
)
edge_marginals = self._make_chart(
1, (batch, N, N, NT, S, S), terms, force_grad=False
)[0]
edge_marginals.fill_(0)
for w in range(1, N):
Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
score = semiring.times(semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z)
score = st(score, beta[A][:, : N - w, w, :NT].view(batch, N - w, NT, 1, 1))
edge_marginals[:, : N - w, w - 1] = semiring.div_exp(
score, v.view(batch, 1, 1, 1, 1)
)
edge_marginals = edge_marginals.transpose(1, 2)

return (term_marginals, edge_marginals, root_marginals)
return semiring.unconvert(log_Z), (term_use, rule_use, top), beta

# def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):
# terms, rules, roots = scores
# semiring = self.semiring
# batch, N, T = terms.shape
# _, NT, _, _ = rules.shape
# S = NT + T
# if lengths is None:
# lengths = torch.LongTensor([N] * batch)

# beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad)
# span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
# span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
# term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0]

# ssum = semiring.sum
# st = semiring.times
# X_Y_Z = rules.view(batch, 1, NT, S, S)

# for w in range(N - 1, -1, -1):
# for b, l in enumerate(lengths):
# beta[A][b, 0, l - 1, :NT] = roots[b]
# beta[B][b, l - 1, N - (l), :NT] = roots[b]

# # LEFT
# # all bigger on the left.
# X = beta[A][:, : N - w - 1, w + 1 :, :NT].view(
# batch, N - w - 1, N - w - 1, NT, 1, 1
# )
# Z = alpha_in[A][:, w + 1 : N, 0 : N - w - 1].view(
# batch, N - w - 1, N - w - 1, 1, 1, S
# )
# t = st(ssum(st(X, Z), dim=2), X_Y_Z)
# # sum out x and y
# span_l[w] = ssum(ssum(t, dim=-3), dim=-1)

# # RIGHT
# X = beta[B][:, w + 1 :, : N - 1 - w, :NT].view(
# batch, N - w - 1, N - w - 1, NT, 1, 1
# )
# Y = alpha_in[B][:, : N - w - 1, w + 1 :, :].view(
# batch, N - w - 1, N - w - 1, 1, S, 1
# )
# t = st(ssum(st(X, Y), dim=2), X_Y_Z)

# span_r[w] = ssum(ssum(t, dim=-3), dim=-2)

# beta[A][:, : N - w - 1, w, :] = span_l[w]
# beta[A][:, 1 : N - w, w, :] = ssum(
# torch.stack([span_r[w], beta[A][:, 1 : N - w, w, :]]), dim=0
# )
# beta[B][:, w:, N - w - 1, :] = beta[A][:, : N - w, w, :]

# term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms)
# term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0]
# for n in range(N):
# term_marginals[:, n] = semiring.div_exp(term_use[:, n], v.view(batch, 1))

# root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0]
# for b in range(batch):
# root_marginals[b] = semiring.div_exp(
# st(alpha_in[A][b, 0, lengths[b] - 1, :NT], roots[b]), v[b].view(1)
# )
# edge_marginals = self._make_chart(
# 1, (batch, N, N, NT, S, S), terms, force_grad=False
# )[0]
# edge_marginals.fill_(0)
# for w in range(1, N):
# Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
# Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
# score = semiring.times(semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z)
# score = st(score, beta[A][:, : N - w, w, :NT].view(batch, N - w, NT, 1, 1))
# edge_marginals[:, : N - w, w - 1] = semiring.div_exp(
# score, v.view(batch, 1, 1, 1, 1)
# )
# edge_marginals = edge_marginals.transpose(1, 2)

# return (term_marginals, edge_marginals, root_marginals)

def marginals(self, scores, lengths=None, _autograd=True):
"""
@@ -209,10 +221,14 @@ def marginals(self, scores, lengths=None, _autograd=True):
rule_use = marg[:-2]
rules = torch.zeros(batch, N, N, NT, S, S)
for w in range(len(rule_use)):
rules[:, w, : N - w - 1] = rule_use[w]
assert marg[-1].shape == (batch, N, T)
assert marg[-2].shape == (batch, NT)
return (marg[-1], rules, marg[-2])
rules[:, w, : N - w - 1] = self.semiring.unconvert(rule_use[w])

term_marg = self.semiring.unconvert(marg[-1])
root_marg = self.semiring.unconvert(marg[-2])

assert term_marg.shape == (batch, N, T)
assert root_marg.shape == (batch, NT)
return (term_marg, rules, root_marg)
else:
return self._dp_backward(scores, lengths, alpha, v)

@@ -231,7 +247,6 @@ def to_parts(spans, extra, lengths=None):
b, torch.arange(lengths[b]), torch.arange(lengths[b]), NT:
]
cover = spans[b].nonzero()

left = {i: [] for i in range(N)}
right = {i: [] for i in range(N)}
for i in range(cover.shape[0]):
@@ -248,14 +263,15 @@ def to_parts(spans, extra, lengths=None):
break
if j > i:
assert B is not None, "%s" % ((i, j, left[i], right[j], cover),)

rules[b, A, B, C] += 1
return terms, rules, roots

@staticmethod
def from_parts(chart):
terms, rules, roots = chart
batch, N, N, NT, S, S = rules.shape
spans = torch.zeros(batch, N, N, S)
spans = torch.zeros(batch, N, N, S).type_as(terms)
rules = rules.sum(dim=-1).sum(dim=-1)

for n in range(N):
@@ -265,6 +281,46 @@ def from_parts(chart):
spans[:, torch.arange(N), torch.arange(N), NT:] = terms
return spans, (NT, S - NT)

@staticmethod
def _intermediary(spans):
batch, N = spans.shape[:2]
splits = []
for b in range(batch):
cover = spans[b].nonzero()
left = {i: [] for i in range(N)}
right = {i: [] for i in range(N)}
batch_split = {}
for i in range(cover.shape[0]):
i, j, A = cover[i].tolist()
left[i].append((A, j, j - i + 1))
right[j].append((A, i, j - i + 1))
for i in range(cover.shape[0]):
i, j, A = cover[i].tolist()
for B_p, k, a_span in left[i]:
for C_p, k_2, b_span in right[j]:
if k_2 == k + 1 and a_span + b_span == j - i + 1:
k_final = k
break
if j > i:
batch_split[(i, j)] = k_final
splits.append(batch_split)
return splits

@classmethod
def to_networkx(cls, spans):
import networkx as nx

splits = cls._intermediary(spans)
G = nx.DiGraph()
batch, _ = spans.shape
for b in range(batch):
for n in spans[0].nonzero():
G.add_node((b, n[0], n[1]), label=n[2])
for k, v in splits[0].items():
G.add_edge(k, (b, k[0], v))
G.add_edge(k, (b, v + 1, k[1]))
return G

###### Test

def enumerate(self, scores):
Expand Down Expand Up @@ -292,7 +348,7 @@ def enumerate(x, start, end):
ls = []
for nt in range(NT):
ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)]
return semiring.sum(torch.stack(ls, dim=-1))
return semiring.sum(torch.stack(ls, dim=-1)), None

@staticmethod
def _rand():
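A minimal sketch (not part of the diff) of the semiring-batch convention this commit threads through _dp: inputs now pass through semiring.convert, which prepends an axis of size semiring.size() to every chart, all sum reductions shift one dimension right, and semiring.unconvert strips the axis before results reach the caller. The shapes and the _dp/marginals return values below are taken from the diff itself; the CKY and LogSemiring import locations and the CKY(LogSemiring) constructor are assumptions about the package layout at this commit.

# Hedged sketch, not part of the commit; import paths are assumptions.
import torch
from torch_struct import CKY, LogSemiring  # assumed import locations

batch, N, NT, T = 2, 5, 3, 4
S = NT + T
terms = torch.rand(batch, N, T).log()      # terminal scores (batch, N, T)
rules = torch.rand(batch, NT, S, S).log()  # rule scores (batch, NT, S, S)
roots = torch.rand(batch, NT).log()        # root scores (batch, NT)

struct = CKY(LogSemiring)

# _dp converts its inputs (prepending a semiring axis of size
# LogSemiring.size() == 1) and unconverts the partition value on return,
# so the caller still sees a (batch,)-shaped result.
log_Z, _, _ = struct._dp((terms, rules, roots))
assert log_Z.shape == (batch,)

# marginals likewise unconverts before its shape asserts.
term_marg, rule_marg, root_marg = struct.marginals((terms, rules, roots))
assert term_marg.shape == (batch, N, T)
assert rule_marg.shape == (batch, N, N, NT, S, S)
assert root_marg.shape == (batch, NT)

The new to_networkx classmethod is intended to render a (batch, N, N, S) span tensor as a networkx.DiGraph, one node per labeled span keyed (batch_index, left, right) with edges to the two children whose split points _intermediary recovers; it is not exercised in this sketch.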
6 changes: 3 additions & 3 deletions torch_struct/data/data.py
@@ -59,9 +59,9 @@ def SubTokenizedField(tokenizer):
return FIELD


def TokenBucket(train, batch_size, device='cuda:0'):
def TokenBucket(train, batch_size, device='cuda:0', key=lambda x: max(len(x.word[0]), 5)):
def batch_size_fn(x, _, size):
return size + max(len(x.word[0]), 5)
return size + key(x)

return torchtext.data.BucketIterator(
train,
@@ -70,7 +70,7 @@ def batch_size_fn(x, _, size):
sort_within_batch=True,
shuffle=True,
batch_size=batch_size,
sort_key=lambda x: len(x.word[0]),
sort_key=lambda x: key(x),
repeat=True,
batch_size_fn=batch_size_fn,
device=device,
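For data.py, the new key argument lets callers replace the word-count heuristic that is now shared by both the bucketing sort_key and the token-budget batch_size_fn. A hedged usage sketch; the dataset and the char field are hypothetical, and the torch_struct.data import path is inferred from the file location.

# Hypothetical usage of the new TokenBucket key hook: size batches by
# character count instead of the default max(len(x.word[0]), 5).
from torch_struct.data import TokenBucket  # assumed import path

train_iter = TokenBucket(
    train,                         # a torchtext Dataset (hypothetical)
    batch_size=1500,               # token budget per batch
    device="cuda:0",
    key=lambda x: len(x.char[0]),  # hypothetical `char` field
)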
(8 more changed files not shown)
