Skip to content

Commit

Permalink
Merge af55cfe into b368540
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 12, 2019
2 parents b368540 + af55cfe commit 01c5407
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 75 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
version="0.0.1",
author="Alexander Rush",
author_email="arush@cornell.edu",
packages=["torch_struct", "torch_struct.data"],
packages=["torch_struct", "torch_struct.data", "torch_struct.networks"],
package_data={"torch_struct": []},
url="https://github.com/harvardnlp/pytorch_struct",
install_requires=["torch"],
Expand Down
11 changes: 10 additions & 1 deletion torch_struct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
from .deptree import DepTree
from .linearchain import LinearChain
from .semimarkov import SemiMarkov
from .semirings import LogSemiring, StdSemiring, SampledSemiring, MaxSemiring
from .semirings import (
LogSemiring,
StdSemiring,
SampledSemiring,
MaxSemiring,
EntropySemiring,
MultiSampledSemiring,
)


version = "0.0.1"
Expand All @@ -17,4 +24,6 @@
StdSemiring,
SampledSemiring,
MaxSemiring,
EntropySemiring,
MultiSampledSemiring,
]
75 changes: 44 additions & 31 deletions torch_struct/cky.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def marginals(self, scores, lengths=None, _autograd=True):
allow_unused=False,
)
rule_use = marg[:-2]
rules = torch.zeros(batch, N, N, NT, S, S)
rules = torch.zeros(batch, N, N, NT, S, S, dtype=scores[1].dtype, device=scores[1].device)
for w in range(len(rule_use)):
rules[:, w, : N - w - 1] = self.semiring.unconvert(rule_use[w])

Expand Down Expand Up @@ -263,14 +263,15 @@ def to_parts(spans, extra, lengths=None):
break
if j > i:
assert B is not None, "%s" % ((i, j, left[i], right[j], cover),)

rules[b, A, B, C] += 1
return terms, rules, roots

@staticmethod
def from_parts(chart):
terms, rules, roots = chart
batch, N, N, NT, S, S = rules.shape
assert terms.shape[1] == N

spans = torch.zeros(batch, N, N, S).type_as(terms)
rules = rules.sum(dim=-1).sum(dim=-1)

Expand All @@ -284,42 +285,54 @@ def from_parts(chart):
@staticmethod
def _intermediary(spans):
batch, N = spans.shape[:2]
splits = []
for b in range(batch):
cover = spans[b].nonzero()
left = {i: [] for i in range(N)}
right = {i: [] for i in range(N)}
batch_split = {}
for i in range(cover.shape[0]):
i, j, A = cover[i].tolist()
left[i].append((A, j, j - i + 1))
right[j].append((A, i, j - i + 1))
for i in range(cover.shape[0]):
i, j, A = cover[i].tolist()
for B_p, k, a_span in left[i]:
for C_p, k_2, b_span in right[j]:
if k_2 == k + 1 and a_span + b_span == j - i + 1:
k_final = k
break
if j > i:
batch_split[(i, j)] = k_final
splits.append(batch_split)
splits = {}
cover = spans.nonzero()
left, right = {}, {}
for k in range(cover.shape[0]):
b, i, j, A = cover[k].tolist()
left.setdefault((b, i), [])
right.setdefault((b, j), [])
left[b, i].append((A, j, j - i + 1))
right[b, j].append((A, i, j - i + 1))

for x in range(cover.shape[0]):
b, i, j, A = cover[x].tolist()
if i == j:
continue
b_final = None
c_final = None
k_final = None
for B_p, k, a_span in left.get((b, i), []):
if k > j:
continue
for C_p, k_2, b_span in right.get((b, j), []):
if k_2 == k + 1 and a_span + b_span == j - i + 1:
k_final = k
b_final = B_p
c_final = C_p
break
if b_final is not None:
break
assert k_final is not None, "%s %s %s %s" % (b, i, j, spans[b].nonzero())
splits[(b, i, j)] = k_final, b_final, c_final
return splits

@classmethod
def to_networkx(cls, spans):
import networkx as nx

splits = cls._intermediary(spans)
splits = cls._intermediary(spans.cpu())
G = nx.DiGraph()
batch, _ = spans.shape
for b in range(batch):
for n in spans[0].nonzero():
G.add_node((b, n[0], n[1]), label=n[2])
for k, v in splits[0].items():
G.add_edge(k, (b, k[0], v))
G.add_edge(k, (b, v + 1, k[1]))
return G
cur = 0
indices = {}
for n in spans.nonzero():
indices[(n[0].item(), n[1].item(), n[2].item())] = cur
G.add_node(cur, label=n[3].item())
cur += 1
for k, v in splits.items():
G.add_edge(indices[(k[0], k[1], v[0])], indices[k])
G.add_edge(indices[(k[0], v[0] + 1, k[2])], indices[k])
return G, indices

###### Test

Expand Down
3 changes: 2 additions & 1 deletion torch_struct/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .data import SubTokenizedField, TokenBucket
from .trees import ConllXDataset, ListOpsDataset

__all__ = [SubTokenizedField, TokenBucket]
# ``__all__`` must contain the exported names as strings; listing the objects
# themselves makes ``from torch_struct.data import *`` raise TypeError.
__all__ = ["SubTokenizedField", "TokenBucket", "ConllXDataset", "ListOpsDataset"]
4 changes: 3 additions & 1 deletion torch_struct/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def SubTokenizedField(tokenizer):
return FIELD


def TokenBucket(train, batch_size, device='cuda:0', key=lambda x: max(len(x.word[0]), 5)):
def TokenBucket(
train, batch_size, device="cuda:0", key=lambda x: max(len(x.word[0]), 5)
):
def batch_size_fn(x, _, size):
return size + key(x)

Expand Down
36 changes: 36 additions & 0 deletions torch_struct/data/trees.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torchtext.data as data


class ConllXDataset(data.Dataset):
    """Dataset that reads sentences from a CoNLL-X style column file.

    Each non-blank line describes one token as ``separator``-delimited fields,
    and blank lines separate sentences. Only the fields at 0-based positions
    1 and 6 are kept (presumably FORM and HEAD in the standard CoNLL-X layout
    -- confirm against the data being loaded). Every sentence becomes one
    example built from the two resulting columns.

    Args:
        path: Path of the file to read.
        fields: Field specification forwarded to ``data.Example.fromlist``
            and to the ``data.Dataset`` constructor.
        encoding: Text encoding used to open ``path``.
        separator: Delimiter between the fields of a token line.
        **kwargs: Extra keyword arguments forwarded to ``data.Dataset``.
    """

    def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
        # Source field index -> index of the output column it is collected in.
        column_map = {1: 0, 6: 1}
        examples = []
        columns = [[], []]
        with open(path, encoding=encoding) as input_file:
            for line in input_file:
                line = line.strip()
                if line == "":
                    # ``columns`` is a non-empty list even when both inner
                    # lists are empty, so test the contents instead; this
                    # keeps repeated blank lines from yielding empty examples.
                    if any(columns):
                        examples.append(data.Example.fromlist(columns, fields))
                    columns = [[], []]
                else:
                    for i, column in enumerate(line.split(separator)):
                        if i in column_map:
                            columns[column_map[i]].append(column)

        # Flush the last sentence when the file does not end with a blank line.
        if any(columns):
            examples.append(data.Example.fromlist(columns, fields))
        super(ConllXDataset, self).__init__(examples, fields, **kwargs)


class ListOpsDataset(data.Dataset):
    """Dataset that reads ``label<separator>expression`` lines from a file.

    Every non-blank line is split into a label and an expression. The
    expression is tokenised on whitespace and parenthesis tokens are dropped;
    the remaining words and the label form one example.

    Args:
        path: Path of the file to read.
        fields: Field specification forwarded to ``data.Example.fromlist``
            and to the ``data.Dataset`` constructor.
        encoding: Text encoding used to open ``path``.
        separator: Delimiter between the label and the expression.
        **kwargs: Extra keyword arguments forwarded to ``data.Dataset``.

    Raises:
        ValueError: If a non-blank line does not contain exactly one
            ``separator``.
    """

    def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
        # Tokens the original substring test ``w not in "()"`` filtered out.
        brackets = ("(", ")", "()")
        examples = []
        with open(path, encoding=encoding) as input_file:
            for line in input_file:
                # Blank lines (e.g. a trailing newline) carry no example.
                if not line.strip():
                    continue
                # Honour the ``separator`` argument rather than a fixed tab.
                label, expression = line.split(separator)
                words = [w for w in expression.split() if w not in brackets]

                examples.append(data.Example.fromlist((words, label), fields))
        super(ListOpsDataset, self).__init__(examples, fields, **kwargs)
66 changes: 66 additions & 0 deletions torch_struct/networks/NeuralCFG.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn


# NeuralCFG From Kim et al
class Res(nn.Module):
    """Linear projection followed by two residual two-layer ReLU blocks.

    Args:
        H: Size of both the input and the output feature dimension.
    """

    def __init__(self, H):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # registration order and random initialisation are unchanged.
        self.u1 = nn.Linear(H, H)
        self.u2 = nn.Linear(H, H)

        self.v1 = nn.Linear(H, H)
        self.v2 = nn.Linear(H, H)
        self.w = nn.Linear(H, H)

    def forward(self, y):
        """Return ``y`` projected by ``w`` and refined by both residual blocks."""
        hidden = self.w(y)
        # Each block adds relu(v(relu(u(hidden)))) back onto its own input.
        for inner, outer in ((self.u1, self.v1), (self.u2, self.v2)):
            hidden = hidden + torch.relu(outer(torch.relu(inner(hidden))))
        return hidden


class NeuralCFG(torch.nn.Module):
    """Neural parameterisation of a context-free grammar's log-potentials.

    Args:
        V: Vocabulary size (number of word embeddings).
        T: Number of preterminal symbols.
        NT: Number of nonterminal symbols.
        H: Embedding / hidden dimension.
    """

    def __init__(self, V, T, NT, H):
        super().__init__()
        self.NT = NT
        self.V = V
        self.T = T
        # ``torch.empty`` only allocates; every tensor below has more than one
        # dimension and is therefore filled by the Xavier loop at the end.
        self.word_emb = nn.Parameter(torch.empty(V, H))
        self.term_emb = nn.Parameter(torch.empty(T, H))
        self.nonterm_emb = nn.Parameter(torch.empty(NT, H))
        self.nonterm_emb_c = nn.Parameter(torch.empty(NT + T, NT + T, H))
        self.root_emb = nn.Parameter(torch.empty(NT, H))
        self.s_emb = nn.Parameter(torch.empty(1, H))
        self.mlp1 = Res(H)
        self.mlp2 = Res(H)
        # Also re-initialises the weight matrices of the two ``Res`` modules;
        # one-dimensional parameters (biases) keep their default values.
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

    def forward(self, input):
        """Compute the grammar log-potentials for a batch of word indices.

        Args:
            input: Integer tensor of word indices with shape ``(batch, n)``.

        Returns:
            A tuple ``(terms, rules, roots)`` of log-probability tensors with
            shapes ``(batch, n, T)``, ``(batch, NT, NT + T, NT + T)`` and
            ``(batch, NT)`` respectively.
        """
        T, NT = self.T, self.NT

        def terms(words):
            # Emission scores must be normalised over the vocabulary for each
            # preterminal. Build the full (V, T) table first and only then
            # look up the observed words; normalising the per-sentence scores
            # directly would normalise across sentence positions instead.
            emission = torch.einsum(
                "vh,th->vt", self.word_emb, self.mlp1(self.term_emb)
            ).log_softmax(0)
            return emission[words]

        def rules(b):
            # Normalise jointly over both children for every parent symbol.
            return (
                torch.einsum("sh,tuh->stu", self.nonterm_emb, self.nonterm_emb_c)
                .view(NT, -1)
                .log_softmax(-1)
                .view(1, NT, NT + T, NT + T)
                .expand(b, NT, NT + T, NT + T)
            )

        def roots(b):
            return (
                torch.einsum("ah,th->t", self.s_emb, self.mlp2(self.root_emb))
                .log_softmax(-1)
                .view(1, NT)
                .expand(b, NT)
            )

        batch = input.shape[0]
        return terms(input), rules(batch), roots(batch)
6 changes: 0 additions & 6 deletions torch_struct/networks/TreeLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,16 @@ def message_func(self, edges):
return {"h": edges.src["h"], "c": edges.src["c"]}

def reduce_func(self, nodes):
# concatenate h_jl for equation (1), (2), (3), (4)
h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
# equation (2)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
# second term of equation (5)
c = th.sum(f * nodes.mailbox["c"], 1)
return {"iou": self.U_iou(h_cat), "c": c}

def apply_node_func(self, nodes):
# equation (1), (3), (4)
iou = nodes.data["iou"] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (5)
c = i * u + nodes.data["c"]
# equation (6)
h = o * th.tanh(c)
return {"h": h, "c": c}

Expand Down
5 changes: 3 additions & 2 deletions torch_struct/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .TreeLSTM import TreeLSTM
from .TreeLSTM import TreeLSTMCell
from .NeuralCFG import NeuralCFG

__all__ = [TreeLSTM]
# ``__all__`` must contain the exported names as strings; listing the objects
# themselves makes ``from torch_struct.networks import *`` raise TypeError.
__all__ = ["TreeLSTMCell", "NeuralCFG"]
Loading

0 comments on commit 01c5407

Please sign in to comment.