diff --git a/README.md b/README.md index 8c82ccc9..aaba4ee6 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,8 @@ show(dist.marginals.detach()[0].sum(-1)) ## Library +Full docs: http://nlp.seas.harvard.edu/pytorch-struct/ + Current distributions implemented: * LinearChainCRF diff --git a/examples/tree.py b/examples/tree.py index 744c7a8c..eac27951 100644 --- a/examples/tree.py +++ b/examples/tree.py @@ -1,21 +1,13 @@ # -*- coding: utf-8 -*- # wandb login 7cd7ade39e2d850ec1cf4e914d9a148586a20900 -from torch_struct import ( - CKY_CRF, - CKY, - LogSemiring, - MaxSemiring, - SampledSemiring, - EntropySemiring, - SelfCritical, -) +from torch_struct import TreeCRF, SelfCritical import torchtext.data as data from torch_struct.data import ListOpsDataset, TokenBucket -from torch_struct.networks import NeuralCFG, TreeLSTM, SpanLSTM +from torch_struct.networks import TreeLSTM, SpanLSTM import torch import torch.nn as nn import wandb -from torch_struct import MultiSampledSemiring + config = { "method": "reinforce", @@ -59,7 +51,7 @@ def expand_spans(spans, words, K, V): def valid_sup(valid_iter, model, tree_lstm, V): total = 0 correct = 0 - struct = CKY_CRF + Dist = TreeCRF for i, ex in enumerate(valid_iter): words, lengths = ex.word trees = ex.tree @@ -74,8 +66,9 @@ def tree_reward(spans): words = words.cuda() phi = model(words, lengths) - argmax = struct(MaxSemiring).marginals(phi, lengths=lengths) - argmax_tree = struct().from_parts(argmax.detach())[0] + dist = TreeCRF(phi, lengths) + argmax = dist.argmax + argmax_tree = dist.struct.from_parts(argmax.detach())[0] score, tota = tree_reward(argmax_tree) total += int(tota) correct += score @@ -95,8 +88,7 @@ def run_train(train_iter, valid_iter, model, tree_lstm, V): model.train() tree_lstm.train() losses = [] - struct = CKY_CRF - entropy_fn = struct(EntropySemiring) + Dist = TreeCRF step = 0 trees = None for epoch in range(100): @@ -119,11 +111,10 @@ def tree_reward(spans, K): ret = ret.view(K, batch, -1) return -ret[:, torch.arange(batch), label].view(K, batch) - sc = SelfCritical(CKY_CRF, tree_reward) + sc = SelfCritical(tree_reward) phi = model(words, lengths) - structs, rewards, score, max_score = sc.forward( - phi, lengths, K=config["RL_K"] - ) + dist = Dist(phi) + structs, rewards, score, max_score = sc.forward(dist, K=config["RL_K"]) if config["train_model"]: opt_params.zero_grad() @@ -134,12 +125,8 @@ def tree_reward(spans, K): if config["method"] == "reinforce": opt_struct.zero_grad() - log_partition, entropy = entropy_fn.sum( - phi, lengths=lengths, _raw=True - ).unbind() - r = struct().score( - phi.unsqueeze(0), structs, batch_dims=[0, 1] - ) - log_partition.unsqueeze(0) + entropy = dist.entropy + r = dist.log_prob(structs) obj = rewards.mul(r).mean(-1).mean(-1) policy = ( obj - config["entropy"] * entropy.div(lengths.float().cuda()).mean() @@ -184,16 +171,17 @@ def tree_reward(spans, K): def valid_show(valid_iter, model): - struct = CKY_CRF table = wandb.Table(columns=["Sent", "Predicted Tree", "True Tree"]) + Dist = TreeCRF for i, ex in enumerate(valid_iter): words, lengths = ex.word label = ex.label batch = label.shape[0] words = words.cuda() phi = model(words, lengths) - argmax = struct(MaxSemiring).marginals(phi, lengths=lengths) - argmax_tree = struct().from_parts(argmax.detach())[0].cpu() + dist = Dist(phi) + argmax = dist.argmax + argmax_tree = dist.struct.from_parts(argmax.detach())[0].cpu() for b in range(words.shape[0]): out = [WORD.vocab.itos[w.item()] for w in words[b]] sent = " ".join(out) @@ -272,7 +260,6 @@ def main(): for p in model.parameters(): if p.dim() > 1: torch.nn.init.xavier_uniform_(p) - struct = CKY_CRF wandb.watch((model, tree_lstm)) print(wandb.config) diff --git a/torch_struct/rl.py b/torch_struct/rl.py index 9e0a6b01..49db79a8 100644 --- a/torch_struct/rl.py +++ b/torch_struct/rl.py @@ -1,27 +1,19 @@ import torch -from .semirings import MultiSampledSemiring, MaxSemiring class SelfCritical: - def __init__(self, struct, reward_fn): - self.struct = struct + def __init__(self, reward_fn): self.reward_fn = reward_fn - self.max_fn = self.struct(MaxSemiring) - self.sample_fn = self.struct(MultiSampledSemiring) - def forward(self, phi, lengths, K=5): - sample = self.sample_fn.marginals(phi, lengths=lengths) - sample = sample.detach() + def forward(self, dist, K=5): + samples = dist.sample((K,)) trees = [] - samples = [] for k in range(K): - tmp_sample = MultiSampledSemiring.to_discrete(sample, k + 1) - samples.append(tmp_sample) - sampled_tree = self.max_fn.from_parts(tmp_sample)[0].cpu() + sampled_tree = dist.struct.from_parts(samples[k])[0].cpu() trees.append(sampled_tree) - structs = torch.stack(samples) - argmax = self.max_fn.marginals(phi, lengths=lengths) - argmax_tree = self.max_fn.from_parts(argmax.detach())[0].cpu() + structs = torch.stack(trees) + argmax = dist.argmax + argmax_tree = dist.struct.from_parts(argmax.detach())[0].cpu() trees.append(argmax_tree) sample_score = self.reward_fn(torch.cat(trees), K + 1) total = sample_score[:-1].mean(dim=0)