
Commit

Sasha committed Oct 10, 2019
2 parents 2289c07 + 632c43e commit 9c014ca
Showing 3 changed files with 26 additions and 45 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -117,6 +117,8 @@ show(dist.marginals.detach()[0].sum(-1))

## Library

+Full docs: http://nlp.seas.harvard.edu/pytorch-struct/
+
Current distributions implemented:

* LinearChainCRF
47 changes: 17 additions & 30 deletions examples/tree.py
@@ -1,21 +1,13 @@
# -*- coding: utf-8 -*-
# wandb login 7cd7ade39e2d850ec1cf4e914d9a148586a20900
-from torch_struct import (
-    CKY_CRF,
-    CKY,
-    LogSemiring,
-    MaxSemiring,
-    SampledSemiring,
-    EntropySemiring,
-    SelfCritical,
-)
+from torch_struct import TreeCRF, SelfCritical
import torchtext.data as data
from torch_struct.data import ListOpsDataset, TokenBucket
-from torch_struct.networks import NeuralCFG, TreeLSTM, SpanLSTM
+from torch_struct.networks import TreeLSTM, SpanLSTM
import torch
import torch.nn as nn
import wandb
-from torch_struct import MultiSampledSemiring


config = {
    "method": "reinforce",
@@ -59,7 +51,7 @@ def expand_spans(spans, words, K, V):
def valid_sup(valid_iter, model, tree_lstm, V):
    total = 0
    correct = 0
-    struct = CKY_CRF
+    Dist = TreeCRF
    for i, ex in enumerate(valid_iter):
        words, lengths = ex.word
        trees = ex.tree
@@ -74,8 +66,9 @@ def tree_reward(spans):

        words = words.cuda()
        phi = model(words, lengths)
-        argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
-        argmax_tree = struct().from_parts(argmax.detach())[0]
+        dist = TreeCRF(phi, lengths)
+        argmax = dist.argmax
+        argmax_tree = dist.struct.from_parts(argmax.detach())[0]
        score, tota = tree_reward(argmax_tree)
        total += int(tota)
        correct += score
@@ -95,8 +88,7 @@ def run_train(train_iter, valid_iter, model, tree_lstm, V):
    model.train()
    tree_lstm.train()
    losses = []
-    struct = CKY_CRF
-    entropy_fn = struct(EntropySemiring)
+    Dist = TreeCRF
    step = 0
    trees = None
    for epoch in range(100):
@@ -119,11 +111,10 @@ def tree_reward(spans, K):
                ret = ret.view(K, batch, -1)
                return -ret[:, torch.arange(batch), label].view(K, batch)

-            sc = SelfCritical(CKY_CRF, tree_reward)
+            sc = SelfCritical(tree_reward)
            phi = model(words, lengths)
-            structs, rewards, score, max_score = sc.forward(
-                phi, lengths, K=config["RL_K"]
-            )
+            dist = Dist(phi)
+            structs, rewards, score, max_score = sc.forward(dist, K=config["RL_K"])

            if config["train_model"]:
                opt_params.zero_grad()
@@ -134,12 +125,8 @@ def tree_reward(spans, K):

if config["method"] == "reinforce":
opt_struct.zero_grad()
log_partition, entropy = entropy_fn.sum(
phi, lengths=lengths, _raw=True
).unbind()
r = struct().score(
phi.unsqueeze(0), structs, batch_dims=[0, 1]
) - log_partition.unsqueeze(0)
entropy = dist.entropy
r = dist.log_prob(structs)
obj = rewards.mul(r).mean(-1).mean(-1)
policy = (
obj - config["entropy"] * entropy.div(lengths.float().cuda()).mean()
@@ -184,16 +171,17 @@


def valid_show(valid_iter, model):
-    struct = CKY_CRF
    table = wandb.Table(columns=["Sent", "Predicted Tree", "True Tree"])
+    Dist = TreeCRF
    for i, ex in enumerate(valid_iter):
        words, lengths = ex.word
        label = ex.label
        batch = label.shape[0]
        words = words.cuda()
        phi = model(words, lengths)
-        argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
-        argmax_tree = struct().from_parts(argmax.detach())[0].cpu()
+        dist = Dist(phi)
+        argmax = dist.argmax
+        argmax_tree = dist.struct.from_parts(argmax.detach())[0].cpu()
        for b in range(words.shape[0]):
            out = [WORD.vocab.itos[w.item()] for w in words[b]]
            sent = " ".join(out)
@@ -272,7 +260,6 @@ def main():
    for p in model.parameters():
        if p.dim() > 1:
            torch.nn.init.xavier_uniform_(p)
-    struct = CKY_CRF

    wandb.watch((model, tree_lstm))
    print(wandb.config)
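In short, examples/tree.py now builds a TreeCRF distribution from the potentials and reads argmax, entropy, and log_prob off it, instead of instantiating CKY_CRF with different semirings. A minimal sketch of the updated REINFORCE step follows; phi, lengths, and tree_reward mirror the names in the diff, while K and the entropy weight are placeholder values, not part of this commit:

```python
# Sketch only: mirrors the calls made in the updated examples/tree.py.
# phi (span potentials from the model), lengths, and tree_reward follow the
# names in the diff; K and entropy_weight are placeholder values.
from torch_struct import TreeCRF, SelfCritical


def reinforce_step(phi, lengths, tree_reward, K=5, entropy_weight=0.001):
    dist = TreeCRF(phi, lengths)      # distribution over binary trees
    sc = SelfCritical(tree_reward)    # self-critical reward wrapper from the commit
    structs, rewards, score, max_score = sc.forward(dist, K=K)
    r = dist.log_prob(structs)        # log-probability of each sampled tree
    entropy = dist.entropy            # entropy of the tree distribution
    obj = rewards.mul(r).mean(-1).mean(-1)
    policy = obj - entropy_weight * entropy.div(lengths.float()).mean()
    return policy, score, max_score
```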
22 changes: 7 additions & 15 deletions torch_struct/rl.py
@@ -1,27 +1,19 @@
import torch
-from .semirings import MultiSampledSemiring, MaxSemiring


class SelfCritical:
-    def __init__(self, struct, reward_fn):
-        self.struct = struct
+    def __init__(self, reward_fn):
        self.reward_fn = reward_fn
-        self.max_fn = self.struct(MaxSemiring)
-        self.sample_fn = self.struct(MultiSampledSemiring)

-    def forward(self, phi, lengths, K=5):
-        sample = self.sample_fn.marginals(phi, lengths=lengths)
-        sample = sample.detach()
+    def forward(self, dist, K=5):
+        samples = dist.sample((K,))
        trees = []
-        samples = []
        for k in range(K):
-            tmp_sample = MultiSampledSemiring.to_discrete(sample, k + 1)
-            samples.append(tmp_sample)
-            sampled_tree = self.max_fn.from_parts(tmp_sample)[0].cpu()
+            sampled_tree = dist.struct.from_parts(samples[k])[0].cpu()
            trees.append(sampled_tree)
-        structs = torch.stack(samples)
-        argmax = self.max_fn.marginals(phi, lengths=lengths)
-        argmax_tree = self.max_fn.from_parts(argmax.detach())[0].cpu()
+        structs = torch.stack(trees)
+        argmax = dist.argmax
+        argmax_tree = dist.struct.from_parts(argmax.detach())[0].cpu()
        trees.append(argmax_tree)
        sample_score = self.reward_fn(torch.cat(trees), K + 1)
        total = sample_score[:-1].mean(dim=0)
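With this change, SelfCritical is constructed from the reward function alone, and forward takes the distribution object: it draws K samples, decodes each with dist.struct.from_parts, and scores them together with the argmax tree, which presumably supplies the self-critical baseline (the tail of forward is elided above). A hypothetical call site, reusing phi, lengths, and tree_reward from the example script:

```python
# Hypothetical usage; variable names follow examples/tree.py in this commit.
from torch_struct import TreeCRF, SelfCritical

dist = TreeCRF(phi, lengths)       # phi: span potentials from the parser model
sc = SelfCritical(tree_reward)     # tree_reward(trees, K) -> per-tree scores
structs, rewards, score, max_score = sc.forward(dist, K=5)
# structs: the K sampled trees stacked as parts
# score:   mean sample reward (sample_score[:-1].mean(dim=0) above)
# rewards, max_score: per-sample reward signal and argmax-tree reward
```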
