
Commit

Sasha committed Oct 4, 2019
1 parent 3eb35bb commit d39690c
Showing 10 changed files with 140 additions and 93 deletions.
169 changes: 110 additions & 59 deletions examples/tree.py
@@ -1,39 +1,58 @@
# -*- coding: utf-8 -*-
# wandb login 7cd7ade39e2d850ec1cf4e914d9a148586a20900
from torch_struct import (
    CKY_CRF,
    CKY,
    LogSemiring,
    MaxSemiring,
    SampledSemiring,
    EntropySemiring,
    SelfCritical,
)
import torchtext.data as data
from torch_struct.data import ListOpsDataset, TokenBucket
from torch_struct.networks import NeuralCFG, TreeLSTM, SpanLSTM
import torch
import torch.nn as nn
import wandb
from torch_struct import MultiSampledSemiring

config = {"method": "reinforce", "baseline": "mean", "opt": "adadelta",
"lr_struct": 0.1, "lr_params": 1, "train_model":True,
"var_norm": False, "entropy": 0.001, "v": 3, "RL_K": 5,
"H" : 100, "train_len": 100, "div_ent" : 1
config = {
"method": "reinforce",
"baseline": "mean",
"opt": "adadelta",
"lr_struct": 0.1,
"lr_params": 1,
"train_model": True,
"var_norm": False,
"entropy": 0.001,
"v": 3,
"RL_K": 5,
"H": 100,
"train_len": 100,
"div_ent": 1,
}


NAME = "yoyo3"

wandb.init(project="pytorch-struct", config=config)


def clip(p):
    torch.nn.utils.clip_grad_norm_(parameters=p, max_norm=0.5, norm_type=float("inf"))
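
With norm_type=float("inf"), clip_grad_norm_ rescales every gradient so the largest absolute entry across the parameter list is at most max_norm. A quick standalone check (toy parameter, not from this repo):

import torch

w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
w.grad = torch.tensor([3.0, -0.2])
# inf-norm of the gradient is 3.0 > 0.5, so everything is scaled by 0.5 / 3.0
torch.nn.utils.clip_grad_norm_([w], max_norm=0.5, norm_type=float("inf"))
print(w.grad)  # tensor([ 0.5000, -0.0333])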


def expand_spans(spans, words, K, V):
    batch, N = words.shape
    spans[:, torch.arange(N), torch.arange(N)] = 0
    new_spans = torch.zeros(K, batch, N, N, 1 + V).cuda()
    new_spans[:, :, :, :, :1] = spans.view(K, batch, N, N, 1)
    new_spans[:, :, torch.arange(N), torch.arange(N), :].fill_(0)
    new_spans[:, :, torch.arange(N), torch.arange(N), 1:] = (
        torch.nn.functional.one_hot(words, V).float().cuda().view(1, batch, N, V)
    )
    new_spans = new_spans.view(batch * K, N, N, 1 + V)
    return new_spans
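
For intuition, expand_spans tiles K sampled span sets per sentence into a (K, batch, N, N, 1 + V) tensor whose last axis holds a span indicator plus a one-hot word identity on the diagonal. A minimal CPU sketch with hypothetical sizes (the real code runs on CUDA):

import torch

K, batch, N, V = 2, 3, 5, 7
spans = torch.zeros(K * batch, N, N)    # binary span indicators, K samples per sentence
words = torch.randint(V, (batch, N))    # word ids

spans[:, torch.arange(N), torch.arange(N)] = 0
new_spans = torch.zeros(K, batch, N, N, 1 + V)
new_spans[:, :, :, :, :1] = spans.view(K, batch, N, N, 1)
new_spans[:, :, torch.arange(N), torch.arange(N), :] = 0
new_spans[:, :, torch.arange(N), torch.arange(N), 1:] = (
    torch.nn.functional.one_hot(words, V).float().view(1, batch, N, V)
)
print(new_spans.view(batch * K, N, N, 1 + V).shape)  # torch.Size([6, 5, 5, 8])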


@@ -52,23 +71,26 @@ def tree_reward(spans):
            g, labels, indices, topo = TreeLSTM.spans_to_dgl(new_spans)
            _, am = tree_lstm(g, labels, indices, topo, lengths).max(-1)
            return (label == am).sum(), label.shape[0]

        words = words.cuda()
        phi = model(words, lengths)
        argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
        argmax_tree = struct().from_parts(argmax.detach())[0]
        score, tota = tree_reward(argmax_tree)
        total += int(tota)
        correct += score

        if i == 25:
            break
    print(correct.item() / float(total), correct, total)
    return correct.item() / float(total)


def run_train(train_iter, valid_iter, model, tree_lstm, V):
    opt_struct = torch.optim.Adadelta(list(model.parameters()), lr=config["lr_struct"])
    opt_params = torch.optim.Adadelta(
        list(tree_lstm.parameters()), lr=config["lr_params"]
    )

    model.train()
    tree_lstm.train()
@@ -91,57 +113,76 @@ def run_train(train_iter, valid_iter, model, tree_lstm, V):
        def tree_reward(spans, K):
            new_spans = expand_spans(spans, words, K, V)
            g, labels, indices, topo = TreeLSTM.spans_to_dgl(new_spans)
            ret = tree_lstm(
                g, labels, indices, topo, torch.cat([lengths for _ in range(K)])
            )
            ret = ret.view(K, batch, -1)
            return -ret[:, torch.arange(batch), label].view(K, batch)

        sc = SelfCritical(CKY_CRF, tree_reward)
        phi = model(words, lengths)
        structs, rewards, score, max_score = sc.forward(
            phi, lengths, K=config["RL_K"]
        )

        if config["train_model"]:
            opt_params.zero_grad()
            score.mean().backward()
            clip(tree_lstm.parameters())
            opt_params.step()
            opt_params.zero_grad()

        if config["method"] == "reinforce":
            opt_struct.zero_grad()
            log_partition, entropy = entropy_fn.sum(
                phi, lengths=lengths, _raw=True
            ).unbind()
            r = struct().score(
                phi.unsqueeze(0), structs, batch_dims=[0, 1]
            ) - log_partition.unsqueeze(0)
            obj = rewards.mul(r).mean(-1).mean(-1)
            policy = (
                obj - config["entropy"] * entropy.div(lengths.float().cuda()).mean()
            )
            policy.backward()
            clip(model.parameters())
            opt_struct.step()
        losses.append(-max_score.mean().detach())

        # DEBUG
        if i % 50 == 9:
            print(torch.tensor(losses).mean(), words.shape)
            print("Round")
            print("Entropy", entropy.mean().item())
            print("Reward", rewards.mean().item())
            if i % 1000 == 9:
                valid_loss = valid_sup(valid_iter, model, tree_lstm, V)
                fname = "/tmp/checkpoint.%s.%0d.%0d.%s" % (
                    NAME,
                    epoch,
                    i,
                    valid_loss,
                )
                torch.save((model, tree_lstm), fname)
                wandb.save(fname)
                trees = valid_show(valid_iter, model)
            else:
                print(valid_loss)

            wandb.log(
                {
                    "entropy": entropy.mean(),
                    "valid_loss": valid_loss,
                    "reward": rewards.mean(),
                    "step": step,
                    "tree": trees,
                    "reward_var": rewards.var(),
                    "loss": torch.tensor(losses).mean(),
                }
            )
            losses = []
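
The policy update above is the score-function (REINFORCE) estimator with a self-critical baseline: each sampled tree's log-probability under the CRF, score(z) - log Z, is weighted by its baselined reward, minus an entropy bonus. A self-contained sketch of the same estimator on a toy categorical distribution (hypothetical rewards, not the CKY_CRF API):

import torch

logits = torch.randn(4, requires_grad=True)      # scores over 4 candidate structures
rewards = torch.tensor([0.1, 0.9, 0.3, 0.5])     # hypothetical task rewards

dist = torch.distributions.Categorical(logits=logits)
samples = dist.sample((5,))                      # K = 5 samples
baseline = rewards[dist.probs.argmax()]          # self-critical: reward of the argmax
advantage = rewards[samples] - baseline

# REINFORCE: ascend (reward - baseline) * log p(sample), plus entropy regularization.
policy = -(advantage * dist.log_prob(samples)).mean() - 0.001 * dist.entropy()
policy.backward()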


def valid_show(valid_iter, model):
    struct = CKY_CRF
    table = wandb.Table(columns=["Sent", "Predicted Tree", "True Tree"])
@@ -151,15 +192,16 @@ def valid_show(valid_iter, model):
        batch = label.shape[0]
        words = words.cuda()
        phi = model(words, lengths)
        argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
        argmax_tree = struct().from_parts(argmax.detach())[0].cpu()
        for b in range(words.shape[0]):
            out = [WORD.vocab.itos[w.item()] for w in words[b]]
            sent = " ".join(out)

            def show(tree):
                output = ""
                start = {}
                end = {}
                for i, j, _ in tree.nonzero():
                    i = i.item()
                    j = j.item()
@@ -174,44 +216,54 @@ def show(tree):
                for _ in range(end.get(i, 0)):
                    output += ") "
                return output

            predict_text = show(ex.tree[b].cpu())
            true_text = show(argmax_tree[b].cpu())
            table.add_data(sent, predict_text, true_text)
        break
    return table
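
show linearizes a binary span tensor into bracketed text by counting, at each position, how many spans open there and how many close there. The same idea over a plain list of (i, j) spans (hypothetical input, tensors omitted):

def show_spans(spans, words):
    start, end = {}, {}
    for i, j in spans:
        if i != j:                           # single-word spans add no brackets
            start[i] = start.get(i, 0) + 1
            end[j] = end.get(j, 0) + 1
    out = []
    for k, w in enumerate(words):
        out += ["("] * start.get(k, 0) + [w] + [")"] * end.get(k, 0)
    return " ".join(out)

print(show_spans([(0, 2), (1, 2)], ["max", "2", "7"]))  # ( max ( 2 7 ) )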


WORD = None


def main():
    global WORD
    WORD = data.Field(
        include_lengths=True, batch_first=True, eos_token=None, init_token=None
    )
    LABEL = data.Field(sequential=False, batch_first=True)
    TREE = data.RawField(postprocessing=ListOpsDataset.tree_field(WORD))
    TREE.is_target = False
    train = ListOpsDataset(
        "data/train_d20s.tsv",
        (("word", WORD), ("label", LABEL), ("tree", TREE)),
        filter_pred=lambda x: 5 < len(x.word) < config["train_len"],
    )
    WORD.build_vocab(train)
    LABEL.build_vocab(train)
    valid = ListOpsDataset(
        "data/test_d20s.tsv",
        (("word", WORD), ("label", LABEL), ("tree", TREE)),
        filter_pred=lambda x: 5 < len(x.word) < 150,
    )

    train_iter = TokenBucket(
        train, batch_size=1500, device="cuda:0", key=lambda x: len(x.word)
    )
    train_iter.repeat = False
    valid_iter = data.BucketIterator(
        train, batch_size=50, train=False, sort=False, device="cuda:0"
    )

    NT = 1
    T = len(WORD.vocab)
    V = T

    if True:
        tree_lstm = TreeLSTM(
            config["H"], len(WORD.vocab) + 100, len(LABEL.vocab)
        ).cuda()
        for p in tree_lstm.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)
@@ -229,7 +281,6 @@ def main():
print("loading")
model, tree_lstm = torch.load("cp.yoyo.model")
print(valid_sup(valid_iter, model, tree_lstm, V))
main()



main()
6 changes: 3 additions & 3 deletions requirements.dev.txt
@@ -1,6 +1,6 @@
pytest
pytest - runner
hypothesis == 4.38
flake8
black
pep8 - naming
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,2 +1,2 @@
[aliases]
test = pytest
2 changes: 1 addition & 1 deletion torch_struct/__init__.py
@@ -36,7 +36,7 @@
    MaxSemiring,
    EntropySemiring,
    MultiSampledSemiring,
    SelfCritical,
    StructDistribution,
    LinearChainCRF,
    SemiMarkovCRF,
11 changes: 4 additions & 7 deletions torch_struct/data/trees.py
@@ -1,6 +1,7 @@
import torchtext.data as data
import torch


class ConllXDataset(data.Dataset):
    def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
        examples = []
@@ -23,10 +24,7 @@ def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
        super(ConllXDataset, self).__init__(examples, fields, **kwargs)


class ListOpsDataset(data.Dataset):
    @staticmethod
    def tree_field(v):
        def post(ls):
@@ -41,6 +39,7 @@ def post(ls):
                    else:
                        ret[b, i, j, 0] = 1
            return ret.long()

        return post

    def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
@@ -50,19 +49,17 @@ def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
                a, b = line.split("\t")
                label = a
                words = [w for w in b.split() if w not in "()"]
                cur = 0
                spans = []
                stack = []
                for w in b.split():
                    if w == "(":
                        stack.append(cur)
                    elif w == ")":
                        nt = last if stack[-1] == cur else "nt"
                        spans.append((stack[-1], cur - 1, w))
                        stack = stack[:-1]
                    else:
                        last = w
                        spans.append((cur, cur, w))
                        cur += 1
                examples.append(data.Example.fromlist((words, label, spans), fields))
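
The stack loop turns a bracketed ListOps string into (start, end, label) triples: "(" pushes the current word index, ")" pops it and emits the enclosing span, and every word emits a singleton span. A quick trace on a toy string (hypothetical example mirroring the loop above):

line = "( max 2 7 )"
cur, spans, stack = 0, [], []
for w in line.split():
    if w == "(":
        stack.append(cur)
    elif w == ")":
        spans.append((stack.pop(), cur - 1, w))
    else:
        spans.append((cur, cur, w))
        cur += 1
print(spans)  # [(0, 0, 'max'), (1, 1, '2'), (2, 2, '7'), (0, 2, ')')]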
7 changes: 1 addition & 6 deletions torch_struct/distributions.py
@@ -5,12 +5,7 @@
from .semimarkov import SemiMarkov
from .deptree import DepTree
from .cky_crf import CKY_CRF
from .semirings import LogSemiring, MaxSemiring, EntropySemiring, MultiSampledSemiring


class StructDistribution(Distribution):
5 changes: 1 addition & 4 deletions torch_struct/linearchain.py
@@ -42,10 +42,7 @@ def _dp(self, edge, lengths=None, force_grad=False):
                edge[:, :, n - 1].view(ssize, batch, C, C),
            )
            alpha[n][:] = semiring.sum(edge_store[n - 1])
        ret = [alpha[lengths[i] - 1][:, i] for i in range(batch)]
        ret = torch.stack(ret, dim=1)
        v = semiring.sum(ret)
        return v, edge_store, alpha
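
The ret comprehension gathers, for each batch element, the forward value at that sequence's true final position, so variable-length inputs share one padded tensor. A minimal log-semiring version of the gather (hypothetical shapes; semiring.sum is logsumexp here):

import torch

ssize, batch, C = 1, 2, 3
lengths = torch.tensor([4, 2])
alpha = [torch.randn(ssize, batch, C) for _ in range(4)]  # forward scores per position

ret = [alpha[lengths[i] - 1][:, i] for i in range(batch)]
ret = torch.stack(ret, dim=1)           # (ssize, batch, C)
v = torch.logsumexp(ret, dim=-1)        # log-partition per batch element
print(v.shape)                          # torch.Size([1, 2])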
