Commit

Merge 551a042 into 67aadfe
srush committed Sep 22, 2019
2 parents (67aadfe + 551a042), commit 9c1a21b
Showing 17 changed files with 909 additions and 244 deletions.
500 changes: 500 additions & 0 deletions examples/tree.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
    version="0.0.1",
    author="Alexander Rush",
    author_email="arush@cornell.edu",
    packages=["torch_struct", "torch_struct.data"],
    packages=["torch_struct", "torch_struct.data", "torch_struct.networks"],
    package_data={"torch_struct": []},
    url="https://github.com/harvardnlp/pytorch_struct",
    install_requires=["torch"],
13 changes: 12 additions & 1 deletion torch_struct/__init__.py
@@ -1,20 +1,31 @@
from .cky import CKY
from .cky_crf import CKY_CRF
from .deptree import DepTree
from .linearchain import LinearChain
from .semimarkov import SemiMarkov
from .semirings import LogSemiring, StdSemiring, SampledSemiring, MaxSemiring
from .semirings import (
    LogSemiring,
    StdSemiring,
    SampledSemiring,
    MaxSemiring,
    EntropySemiring,
    MultiSampledSemiring,
)


version = "0.0.1"

# For flake8 compatibility.
__all__ = [
    "CKY",
    "CKY_CRF",
    "DepTree",
    "LinearChain",
    "SemiMarkov",
    "LogSemiring",
    "StdSemiring",
    "SampledSemiring",
    "MaxSemiring",
    "EntropySemiring",
    "MultiSampledSemiring",
]
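
The two newly exported semirings drop into the same structure classes as the existing ones. As a quick orientation, a minimal sketch of using EntropySemiring with a linear-chain model; the potential shapes and sizes here are illustrative assumptions, not values from this commit:

    import torch
    from torch_struct import LinearChain, EntropySemiring

    # Hypothetical toy potentials: batch=2, 5 edges, 3 states.
    edge = torch.randn(2, 5, 3, 3)

    # With the entropy semiring, the inside sum evaluates the entropy
    # of the CRF distribution instead of log Z.
    H = LinearChain(EntropySemiring).sum(edge)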
220 changes: 65 additions & 155 deletions torch_struct/cky.py
@@ -1,52 +1,11 @@
import torch
from .helpers import _Struct
from .semirings import LogSemiring
from torch.autograd import Function

A, B = 0, 1


class DPManual2(Function):
    @staticmethod
    def forward(ctx, obj, terms, rules, roots, lengths):
        with torch.no_grad():
            v, _, alpha = obj._dp((terms, rules, roots), lengths, False)
        ctx.obj = obj
        ctx.lengths = lengths
        ctx.alpha = alpha
        ctx.v = v
        ctx.save_for_backward(terms, rules, roots)
        return v

    @staticmethod
    def backward(ctx, grad_v):
        terms, rules, roots = ctx.saved_tensors
        with torch.no_grad():
            marginals = ctx.obj._dp_backward(
                (terms, rules, roots), ctx.lengths, ctx.alpha, ctx.v
            )
        return None, marginals[0], marginals[1].sum(1).sum(1), marginals[2], None
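
DPManual2 follows the standard torch.autograd.Function contract: forward runs the inside pass under no_grad and stashes what it needs on ctx, and backward returns exactly one gradient per forward argument, with None for the non-tensor obj and lengths. A self-contained sketch of that contract on a toy function (the names here are hypothetical):

    import torch
    from torch.autograd import Function

    class ScaleBy(Function):
        @staticmethod
        def forward(ctx, x, k):
            ctx.k = k
            return x * k

        @staticmethod
        def backward(ctx, grad_out):
            # One return value per forward argument; None for the
            # non-differentiable constant k.
            return grad_out * ctx.k, None

    x = torch.ones(3, requires_grad=True)
    ScaleBy.apply(x, 2.0).sum().backward()  # x.grad is all 2.0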


class CKY(_Struct):
    def sum(self, scores, lengths=None, force_grad=False, _autograd=True):
        """
        Compute the inside pass of a CFG using CKY.
        Parameters:
            terms : b x N x T
            rules : b x NT x (NT+T) x (NT+T)
            root : b x NT
        Returns:
            v: b tensor of total sum
        """
        if _autograd or self.semiring is not LogSemiring:
            return self._dp(scores, lengths)[0]
        else:
            return DPManual2.apply(self, *scores, lengths=lengths)
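
For reference, a sketch of calling the inside sum with the shapes from the docstring; the grammar sizes below are made-up toy values:

    import torch
    from torch_struct import CKY

    batch, N, NT, T = 2, 6, 3, 4  # assumed toy sizes
    terms = torch.randn(batch, N, T)
    rules = torch.randn(batch, NT, NT + T, NT + T)
    roots = torch.randn(batch, NT)

    v = CKY().sum((terms, rules, roots))  # one log-partition value per batch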

    def _dp(self, scores, lengths=None, force_grad=False):
        terms, rules, roots = scores
        semiring = self.semiring
@@ -111,84 +70,6 @@ def _dp(self, scores, lengths=None, force_grad=False):
        log_Z = semiring.dot(top, roots)
        return semiring.unconvert(log_Z), (term_use, rule_use, top), beta
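
In the log semiring, semiring.dot over the nonterminal dimension is a logsumexp of elementwise sums, so the final step above is equivalent to this small standalone computation:

    import torch

    top = torch.randn(2, 3)    # b x NT chart cell covering the full span
    roots = torch.randn(2, 3)  # b x NT root scores
    log_Z = torch.logsumexp(top + roots, dim=-1)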

    # def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):
    #     terms, rules, roots = scores
    #     semiring = self.semiring
    #     batch, N, T = terms.shape
    #     _, NT, _, _ = rules.shape
    #     S = NT + T
    #     if lengths is None:
    #         lengths = torch.LongTensor([N] * batch)

    #     beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad)
    #     span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
    #     span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
    #     term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0]

    #     ssum = semiring.sum
    #     st = semiring.times
    #     X_Y_Z = rules.view(batch, 1, NT, S, S)

    #     for w in range(N - 1, -1, -1):
    #         for b, l in enumerate(lengths):
    #             beta[A][b, 0, l - 1, :NT] = roots[b]
    #             beta[B][b, l - 1, N - (l), :NT] = roots[b]

    #         # LEFT
    #         # all bigger on the left.
    #         X = beta[A][:, : N - w - 1, w + 1 :, :NT].view(
    #             batch, N - w - 1, N - w - 1, NT, 1, 1
    #         )
    #         Z = alpha_in[A][:, w + 1 : N, 0 : N - w - 1].view(
    #             batch, N - w - 1, N - w - 1, 1, 1, S
    #         )
    #         t = st(ssum(st(X, Z), dim=2), X_Y_Z)
    #         # sum out x and y
    #         span_l[w] = ssum(ssum(t, dim=-3), dim=-1)

    #         # RIGHT
    #         X = beta[B][:, w + 1 :, : N - 1 - w, :NT].view(
    #             batch, N - w - 1, N - w - 1, NT, 1, 1
    #         )
    #         Y = alpha_in[B][:, : N - w - 1, w + 1 :, :].view(
    #             batch, N - w - 1, N - w - 1, 1, S, 1
    #         )
    #         t = st(ssum(st(X, Y), dim=2), X_Y_Z)

    #         span_r[w] = ssum(ssum(t, dim=-3), dim=-2)

    #         beta[A][:, : N - w - 1, w, :] = span_l[w]
    #         beta[A][:, 1 : N - w, w, :] = ssum(
    #             torch.stack([span_r[w], beta[A][:, 1 : N - w, w, :]]), dim=0
    #         )
    #         beta[B][:, w:, N - w - 1, :] = beta[A][:, : N - w, w, :]

    #     term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms)
    #     term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0]
    #     for n in range(N):
    #         term_marginals[:, n] = semiring.div_exp(term_use[:, n], v.view(batch, 1))

    #     root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0]
    #     for b in range(batch):
    #         root_marginals[b] = semiring.div_exp(
    #             st(alpha_in[A][b, 0, lengths[b] - 1, :NT], roots[b]), v[b].view(1)
    #         )
    #     edge_marginals = self._make_chart(
    #         1, (batch, N, N, NT, S, S), terms, force_grad=False
    #     )[0]
    #     edge_marginals.fill_(0)
    #     for w in range(1, N):
    #         Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
    #         Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
    #         score = semiring.times(semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z)
    #         score = st(score, beta[A][:, : N - w, w, :NT].view(batch, N - w, NT, 1, 1))
    #         edge_marginals[:, : N - w, w - 1] = semiring.div_exp(
    #             score, v.view(batch, 1, 1, 1, 1)
    #         )
    #     edge_marginals = edge_marginals.transpose(1, 2)

    #     return (term_marginals, edge_marginals, root_marginals)

    def marginals(self, scores, lengths=None, _autograd=True):
        """
        Compute the marginals of a CFG using CKY.
@@ -219,7 +100,9 @@ def marginals(self, scores, lengths=None, _autograd=True):
            allow_unused=False,
        )
        rule_use = marg[:-2]
        rules = torch.zeros(batch, N, N, NT, S, S)
        rules = torch.zeros(
            batch, N, N, NT, S, S, dtype=scores[1].dtype, device=scores[1].device
        )
        for w in range(len(rule_use)):
            rules[:, w, : N - w - 1] = self.semiring.unconvert(rule_use[w])
Expand Down Expand Up @@ -263,17 +146,17 @@ def to_parts(spans, extra, lengths=None):
                        break
            if j > i:
                assert B is not None, "%s" % ((i, j, left[i], right[j], cover),)

                rules[b, A, B, C] += 1
        return terms, rules, roots

    @staticmethod
    def from_parts(chart):
        terms, rules, roots = chart
        batch, N, N, NT, S, S = rules.shape
        spans = torch.zeros(batch, N, N, S).type_as(terms)
        rules = rules.sum(dim=-1).sum(dim=-1)
        assert terms.shape[1] == N

        spans = torch.zeros(batch, N, N, S, dtype=rules.dtype, device=rules.device)
        rules = rules.sum(dim=-1).sum(dim=-1)
        for n in range(N):
            spans[:, torch.arange(N - n - 1), torch.arange(n + 1, N), :NT] = rules[
                :, n, torch.arange(N - n - 1)
@@ -284,42 +167,69 @@ def from_parts(chart):
    @staticmethod
    def _intermediary(spans):
        batch, N = spans.shape[:2]
        splits = []
        for b in range(batch):
            cover = spans[b].nonzero()
            left = {i: [] for i in range(N)}
            right = {i: [] for i in range(N)}
            batch_split = {}
            for i in range(cover.shape[0]):
                i, j, A = cover[i].tolist()
                left[i].append((A, j, j - i + 1))
                right[j].append((A, i, j - i + 1))
            for i in range(cover.shape[0]):
                i, j, A = cover[i].tolist()
                for B_p, k, a_span in left[i]:
                    for C_p, k_2, b_span in right[j]:
                        if k_2 == k + 1 and a_span + b_span == j - i + 1:
                            k_final = k
                            break
                if j > i:
                    batch_split[(i, j)] = k_final
            splits.append(batch_split)
        splits = {}
        cover = spans.nonzero()
        left, right = {}, {}
        for k in range(cover.shape[0]):
            b, i, j, A = cover[k].tolist()
            left.setdefault((b, i), [])
            right.setdefault((b, j), [])
            left[b, i].append((A, j, j - i + 1))
            right[b, j].append((A, i, j - i + 1))

        for x in range(cover.shape[0]):
            b, i, j, A = cover[x].tolist()
            if i == j:
                continue
            b_final = None
            c_final = None
            k_final = None
            for B_p, k, a_span in left.get((b, i), []):
                if k > j:
                    continue
                for C_p, k_2, b_span in right.get((b, j), []):
                    if k_2 == k + 1 and a_span + b_span == j - i + 1:
                        k_final = k
                        b_final = B_p
                        c_final = C_p
                        break
                if b_final is not None:
                    break
            assert k_final is not None, "%s %s %s %s" % (b, i, j, spans[b].nonzero())
            splits[(b, i, j)] = k_final, b_final, c_final
        return splits
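
The returned dictionary records, for every covered span, the split point and the child labels that meet there. A sketch of consuming it, given a spans tensor of the kind from_parts produces above:

    # splits maps (b, i, j) -> (k, B, C): span [i, j] in batch item b
    # splits into a left child labeled B over [i, k] and a right child
    # labeled C over [k + 1, j].
    splits = CKY._intermediary(spans)
    for (b, i, j), (k, B, C) in splits.items():
        assert i <= k < j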

    @classmethod
    def to_networkx(cls, spans):
        import networkx as nx

        splits = cls._intermediary(spans)
        G = nx.DiGraph()
        batch, _ = spans.shape
        for b in range(batch):
            for n in spans[0].nonzero():
                G.add_node((b, n[0], n[1]), label=n[2])
            for k, v in splits[0].items():
                G.add_edge(k, (b, k[0], v))
                G.add_edge(k, (b, v + 1, k[1]))
        return G
        cur = 0
        N = spans.shape[1]
        n_nodes = int(spans.sum().item())
        cover = spans.nonzero().cpu()
        order = torch.argsort(cover[:, 2] - cover[:, 1])
        left = {}
        right = {}
        ordered = cover[order]
        label = ordered[:, 3]
        a = []
        b = []
        topo = [[] for _ in range(N)]
        for n in ordered:
            batch, i, j, _ = n.tolist()
            # G.add_node(cur, label=A)
            if i - j != 0:
                a.append(left[(batch, i)][0])
                a.append(right[(batch, j)][0])
                b.append(cur)
                b.append(cur)
                order = max(left[(batch, i)][1], right[(batch, j)][1]) + 1
            else:
                order = 0
            left[(batch, i)] = (cur, order)
            right[(batch, j)] = (cur, order)
            topo[order].append(cur)
            cur += 1
        indices = left
        return (n_nodes, a, b, label), indices, topo
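
Despite its name, the rewritten method no longer builds an nx.DiGraph; it returns flat parent/child index arrays plus a height-ordered grouping of nodes. A sketch of unpacking the result, with the field meanings read off the code above:

    (n_nodes, a, b, label), indices, topo = CKY.to_networkx(spans)
    # Edge e points from child node a[e] up to its parent b[e];
    # label[n] is the symbol id of node n; topo[h] lists the nodes at
    # height h, so scanning topo in order visits children before
    # parents. indices maps (batch, i) to the last (node, height)
    # created for a span starting at position i.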

###### Test
