From cf99d99802751cd5213629e992b1c41896ffebea Mon Sep 17 00:00:00 2001
From: srush
Date: Tue, 8 Oct 2019 13:33:29 -0400
Subject: [PATCH 1/3] Port RL Example to use new API (#16)

ListOps now uses distribution API
---
 examples/tree.py   | 47 +++++++++++++++++-----------------------------
 torch_struct/rl.py | 22 +++++++---------------
 2 files changed, 24 insertions(+), 45 deletions(-)

diff --git a/examples/tree.py b/examples/tree.py
index 744c7a8c..eac27951 100644
--- a/examples/tree.py
+++ b/examples/tree.py
@@ -1,21 +1,13 @@
 # -*- coding: utf-8 -*-
 # wandb login 7cd7ade39e2d850ec1cf4e914d9a148586a20900
-from torch_struct import (
-    CKY_CRF,
-    CKY,
-    LogSemiring,
-    MaxSemiring,
-    SampledSemiring,
-    EntropySemiring,
-    SelfCritical,
-)
+from torch_struct import TreeCRF, SelfCritical
 import torchtext.data as data
 from torch_struct.data import ListOpsDataset, TokenBucket
-from torch_struct.networks import NeuralCFG, TreeLSTM, SpanLSTM
+from torch_struct.networks import TreeLSTM, SpanLSTM
 import torch
 import torch.nn as nn
 import wandb
-from torch_struct import MultiSampledSemiring
+

 config = {
     "method": "reinforce",
@@ -59,7 +51,7 @@ def expand_spans(spans, words, K, V):
 def valid_sup(valid_iter, model, tree_lstm, V):
     total = 0
     correct = 0
-    struct = CKY_CRF
+    Dist = TreeCRF
     for i, ex in enumerate(valid_iter):
         words, lengths = ex.word
         trees = ex.tree
@@ -74,8 +66,9 @@ def tree_reward(spans):

         words = words.cuda()
         phi = model(words, lengths)
-        argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
-        argmax_tree = struct().from_parts(argmax.detach())[0]
+        dist = TreeCRF(phi, lengths)
+        argmax = dist.argmax
+        argmax_tree = dist.struct.from_parts(argmax.detach())[0]
         score, tota = tree_reward(argmax_tree)
         total += int(tota)
         correct += score
@@ -95,8 +88,7 @@ def run_train(train_iter, valid_iter, model, tree_lstm, V):
     model.train()
     tree_lstm.train()
     losses = []
-    struct = CKY_CRF
-    entropy_fn = struct(EntropySemiring)
+    Dist = TreeCRF
     step = 0
     trees = None
     for epoch in range(100):
@@ -119,11 +111,10 @@ def tree_reward(spans, K):
                 ret = ret.view(K, batch, -1)
                 return -ret[:, torch.arange(batch), label].view(K, batch)

-            sc = SelfCritical(CKY_CRF, tree_reward)
+            sc = SelfCritical(tree_reward)
             phi = model(words, lengths)
-            structs, rewards, score, max_score = sc.forward(
-                phi, lengths, K=config["RL_K"]
-            )
+            dist = Dist(phi)
+            structs, rewards, score, max_score = sc.forward(dist, K=config["RL_K"])

             if config["train_model"]:
                 opt_params.zero_grad()
@@ -134,12 +125,8 @@ def tree_reward(spans, K):

             if config["method"] == "reinforce":
                 opt_struct.zero_grad()
-                log_partition, entropy = entropy_fn.sum(
-                    phi, lengths=lengths, _raw=True
-                ).unbind()
-                r = struct().score(
-                    phi.unsqueeze(0), structs, batch_dims=[0, 1]
-                ) - log_partition.unsqueeze(0)
+                entropy = dist.entropy
+                r = dist.log_prob(structs)
                 obj = rewards.mul(r).mean(-1).mean(-1)
                 policy = (
                     obj - config["entropy"] * entropy.div(lengths.float().cuda()).mean()
                 )
@@ -184,16 +171,17 @@ def tree_reward(spans, K):


 def valid_show(valid_iter, model):
-    struct = CKY_CRF
     table = wandb.Table(columns=["Sent", "Predicted Tree", "True Tree"])
+    Dist = TreeCRF
     for i, ex in enumerate(valid_iter):
         words, lengths = ex.word
         label = ex.label
         batch = label.shape[0]
         words = words.cuda()
         phi = model(words, lengths)
-        argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
-        argmax_tree = struct().from_parts(argmax.detach())[0].cpu()
+        dist = Dist(phi)
+        argmax = dist.argmax
+        argmax_tree = dist.struct.from_parts(argmax.detach())[0].cpu()
         for b in range(words.shape[0]):
             out = [WORD.vocab.itos[w.item()] for w in words[b]]
             sent = " ".join(out)
@@ -272,7 +260,6 @@ def main():
     for p in model.parameters():
         if p.dim() > 1:
             torch.nn.init.xavier_uniform_(p)
-    struct = CKY_CRF
     wandb.watch((model, tree_lstm))

     print(wandb.config)

diff --git a/torch_struct/rl.py b/torch_struct/rl.py
index 9e0a6b01..49db79a8 100644
--- a/torch_struct/rl.py
+++ b/torch_struct/rl.py
@@ -1,27 +1,19 @@
 import torch
-from .semirings import MultiSampledSemiring, MaxSemiring


 class SelfCritical:
-    def __init__(self, struct, reward_fn):
-        self.struct = struct
+    def __init__(self, reward_fn):
         self.reward_fn = reward_fn
-        self.max_fn = self.struct(MaxSemiring)
-        self.sample_fn = self.struct(MultiSampledSemiring)

-    def forward(self, phi, lengths, K=5):
-        sample = self.sample_fn.marginals(phi, lengths=lengths)
-        sample = sample.detach()
+    def forward(self, dist, K=5):
+        samples = dist.sample((K,))
         trees = []
-        samples = []
         for k in range(K):
-            tmp_sample = MultiSampledSemiring.to_discrete(sample, k + 1)
-            samples.append(tmp_sample)
-            sampled_tree = self.max_fn.from_parts(tmp_sample)[0].cpu()
+            sampled_tree = dist.struct.from_parts(samples[k])[0].cpu()
             trees.append(sampled_tree)
-        structs = torch.stack(samples)
-        argmax = self.max_fn.marginals(phi, lengths=lengths)
-        argmax_tree = self.max_fn.from_parts(argmax.detach())[0].cpu()
+        structs = torch.stack(trees)
+        argmax = dist.argmax
+        argmax_tree = dist.struct.from_parts(argmax.detach())[0].cpu()
         trees.append(argmax_tree)
         sample_score = self.reward_fn(torch.cat(trees), K + 1)
         total = sample_score[:-1].mean(dim=0)
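For context, the refactored SelfCritical above now takes a distribution object rather than a (struct, potentials, lengths) triple. A minimal sketch of the intended call pattern, with a toy reward standing in for the ListOps tree-LSTM evaluator and purely illustrative shapes:

    import torch
    from torch_struct import TreeCRF, SelfCritical

    batch, N, NT = 4, 8, 1

    def toy_reward(spans, K):
        # Stand-in for the task reward: spans stacks K structures per
        # batch element, and the reward must come back shaped (K, batch).
        b = spans.shape[0] // K
        return -spans.sum((-3, -2, -1)).float().view(K, b)

    phi = torch.randn(batch, N, N, NT)  # span log-potentials from a model
    dist = TreeCRF(phi)
    sc = SelfCritical(toy_reward)
    structs, rewards, score, max_score = sc.forward(dist, K=5)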
From 0a1d34c5eff81a04462ffeca5caaa6c5e37075c9 Mon Sep 17 00:00:00 2001
From: srush
Date: Thu, 10 Oct 2019 07:36:22 -0700
Subject: [PATCH 2/3] Add documents for web (#17)

---
 docs/requirements.txt              |   1 +
 docs/source/conf.py                |   5 +-
 docs/source/index.rst              | 130 ++++++++++++++++++++-
 requirements.dev.txt               |   5 +-
 torch_struct/distributions.py      | 175 ++++++++++++++++++++++++-----
 torch_struct/linearchain.py        |   6 -
 torch_struct/networks/NeuralCFG.py |   4 +-
 torch_struct/networks/SpanLSTM.py  |   3 +
 torch_struct/networks/TreeLSTM.py  |   4 +
 9 files changed, 290 insertions(+), 43 deletions(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 74af9144..2404a798 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,5 @@
 sphinx
 sphinx-jinja
+sphinxcontrib-bibtex
 sphinx-rtd-theme
 recommonmark

diff --git a/docs/source/conf.py b/docs/source/conf.py
index f0e9f3f0..eb69279d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -29,6 +29,7 @@
     'sphinx.ext.githubpages',
     'sphinx.ext.napoleon',
     'sphinxcontrib.jinja',
+    'sphinxcontrib.bibtex',
     'sphinx.ext.intersphinx'
 ]

@@ -47,8 +48,8 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = [
-]
+# extensions = [
+# ]

 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 60b4d652..71bab2da 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,15 +1,135 @@
-.. pytorch-struct documentation master file, created by
-   sphinx-quickstart on Sun Oct  6 13:13:53 2019.
-   You can adapt this file completely to your liking, but it should at least
-   contain the root `toctree` directive.

-Welcome to pytorch-struct's documentation!
+PyTorch-Struct
 ==========================================

 .. toctree::
    :maxdepth: 2
    :caption: Contents:
+
+Introduction
+============
+
+A library for structured prediction.
+
+
+Distributional Interface
+========================
+
+The main interface is through structured distribution objects. Each
+of these implements a conditional random field over a class of
+structures. Roughly, these represent specialized softmaxes over
+exponentially sized spaces. Each distribution object takes in
+log_potentials (generalized logits) and can return properties of the
+distribution. The properties of interest are:
+
+* Partition (e.g. logsumexp)
+* Marginals (e.g. softmax)
+* Argmax
+* Entropy
+* Samples
+* to_event / from_event (adapters)
+
+
+.. autoclass:: torch_struct.StructDistribution
+   :members:
+
+Linear Chain
+--------------
+
+.. autoclass:: torch_struct.LinearChainCRF
+
+
+Semi-Markov
+--------------
+
+.. autoclass:: torch_struct.SemiMarkovCRF
+
+
+Dependency Tree
+----------------
+
+.. autoclass:: torch_struct.DependencyCRF
+
+
+Binary Tree
+--------------
+
+.. autoclass:: torch_struct.TreeCRF
+
+Context-Free Grammar
+---------------------
+
+.. autoclass:: torch_struct.SentCFG
+
+
+Networks
+===========
+
+Common structured networks.
+
+.. autoclass:: torch_struct.networks.TreeLSTM
+
+.. autoclass:: torch_struct.networks.NeuralCFG
+
+.. autoclass:: torch_struct.networks.SpanLSTM
+
+
+Data
+====
+
+Datasets for common structured prediction tasks.
+
+.. autoclass:: torch_struct.data.ConllXDataset
+.. autoclass:: torch_struct.data.ListOpsDataset
+
+
+Advanced Usage: Semirings
+=========================
+
+All of the distributional code is implemented through a series of
+semiring objects. These are passed to dynamic programming
+backends that compute the distributions.
+
+
+Standard Semirings
+------------------
+
+.. autoclass:: torch_struct.LogSemiring
+.. autoclass:: torch_struct.StdSemiring
+.. autoclass:: torch_struct.MaxSemiring
+
+Higher-Order Semirings
+----------------------
+
+.. autoclass:: torch_struct.EntropySemiring
+
+Sampling Semirings
+------------------
+
+.. autoclass:: torch_struct.SampledSemiring
+.. autoclass:: torch_struct.MultiSampledSemiring
+
+
+Dynamic Programming
+-------------------
+
+.. autoclass:: torch_struct.LinearChain
+.. autoclass:: torch_struct.SemiMarkov
+.. autoclass:: torch_struct.DepTree
+.. autoclass:: torch_struct.CKY
+
+
+References
+==========
+
+.. bibliography:: refs.bib

 Indices and tables
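The properties in the list above map directly onto attributes of each distribution object. A shape-level sketch (sizes are illustrative), using the LinearChainCRF documented below:

    import torch
    from torch_struct import LinearChainCRF

    batch, N, C = 2, 6, 4
    dist = LinearChainCRF(torch.randn(batch, N - 1, C, C))

    dist.partition     # (batch,) log-partition via logsumexp
    dist.marginals     # (batch, N-1, C, C) edge marginals, the softmax analogue
    dist.argmax        # (batch, N-1, C, C) one-hot MAP structure
    dist.entropy       # (batch,)
    dist.sample((5,))  # (5, batch, N-1, C, C)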
diff --git a/requirements.dev.txt b/requirements.dev.txt
index ea3be805..8d4b4759 100644
--- a/requirements.dev.txt
+++ b/requirements.dev.txt
@@ -1,6 +1,7 @@
 pytest
-pytest - runner
+pytest-runner
 hypothesis == 4.38
 flake8
 black
-pep8 - naming
+pep8-naming
+dgl

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 1ff903e8..e5b06e36 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -10,6 +10,16 @@


 class StructDistribution(Distribution):
+    r"""
+    Base structured distribution class.
+
+    Dynamic distribution over structures :math:`p(z)` of length N.
+
+    Parameters:
+        log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
+        lengths (long tensor, batch_shape) : integers for length masking
+    """
+
     has_enumerate_support = True

     def __init__(self, log_potentials, lengths=None):
@@ -22,6 +32,53 @@ def __init__(self, log_potentials, lengths=None):
     def _new(self, *args, **kwargs):
         return self._param.new(*args, **kwargs)

+    def log_prob(self, value):
+        """
+        Compute log probability over values :math:`p(z)`.
+
+        Parameters:
+            value (tensor): sample_shape x batch_shape x event_shape
+        """
+        d = value.dim()
+        batch_dims = range(d - len(self.event_shape))
+        v = self.struct().score(
+            self.log_potentials,
+            value.type_as(self.log_potentials),
+            batch_dims=batch_dims,
+        )
+        return v - self.partition
+
+    @lazy_property
+    def entropy(self):
+        """
+        Compute entropy for distribution :math:`H[z]`.
+
+        Returns:
+            entropy - batch_shape
+        """
+        return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths)
+
+    @lazy_property
+    def argmax(self):
+        r"""
+        Compute an argmax for distribution :math:`\arg\max p(z)`.
+
+        Returns:
+            argmax (*batch_shape x event_shape*)
+        """
+        return self.struct(MaxSemiring).marginals(self.log_potentials, self.lengths)
+
+    @lazy_property
+    def marginals(self):
+        """
+        Compute marginals for distribution :math:`p(z_t)`.
+
+        Returns:
+            marginals (*batch_shape x event_shape*)
+        """
+        return self.struct(LogSemiring).marginals(self.log_potentials, self.lengths)
+
     # @constraints.dependent_property
     # def support(self):
     #     pass
@@ -32,17 +89,19 @@ def _new(self, *args, **kwargs):

     @lazy_property
     def partition(self):
+        "Compute the partition function."
         return self.struct(LogSemiring).sum(self.log_potentials, self.lengths)

-    @property
-    def mean(self):
-        pass
+    def sample(self, sample_shape=torch.Size()):
+        r"""
+        Compute structured samples from the distribution :math:`z \sim p(z)`.

-    @property
-    def variance(self):
-        pass
+        Parameters:
+            sample_shape (torch.Size): number of samples

-    def sample(self, sample_shape=torch.Size()):
+        Returns:
+            samples - sample_shape x batch_shape x event_shape
+        """
         assert len(sample_shape) == 1
         nsamples = sample_shape[0]
         samples = []
@@ -56,29 +115,21 @@ def sample(self, sample_shape=torch.Size()):
             samples.append(tmp_sample)
         return torch.stack(samples)

-    def log_prob(self, value):
-        d = value.dim()
-        batch_dims = range(d - len(self.event_shape))
-        v = self.struct().score(
-            self.log_potentials,
-            value.type_as(self.log_potentials),
-            batch_dims=batch_dims,
-        )
-        return v - self.partition
-
-    @lazy_property
-    def entropy(self):
-        return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths)
+    def to_event(self, sequence, extra, lengths=None):
+        "Convert simple representation to event."
+        return self.struct.to_parts(sequence, extra, lengths=lengths)

-    @lazy_property
-    def argmax(self):
-        return self.struct(MaxSemiring).marginals(self.log_potentials, self.lengths)
-
-    @lazy_property
-    def marginals(self):
-        return self.struct(LogSemiring).marginals(self.log_potentials, self.lengths)
+    def from_event(self, event):
+        "Convert event to simple representation."
+        return self.struct.from_parts(event)

     def enumerate_support(self, expand=True):
+        """
+        Compute the full exponential enumeration set.
+
+        Returns:
+            (enum, enum_lengths) - tuple cardinality x batch_shape x event_shape
+        """
         _, _, edges, enum_lengths = self.struct().enumerate(
             self.log_potentials, self.lengths
         )
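Taken together, log_prob and enumerate_support allow a quick correctness check: the probabilities of all structures in the support should sum to one. A small sketch (tiny sizes only, since the support is exponential):

    import torch
    from torch_struct import LinearChainCRF

    dist = LinearChainCRF(torch.randn(2, 3, 2, 2))  # batch=2, N=4, C=2

    support, _ = dist.enumerate_support()
    # logsumexp over the support of log p(z) should be ~0 (= log 1) per batch
    print(dist.log_prob(support).logsumexp(0))
    # and the MAP structure has a proper (non-positive) log-probability
    print(dist.log_prob(dist.argmax))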
@@ -88,22 +139,92 @@
 class LinearChainCRF(StructDistribution):
+    r"""
+    Represents structured linear-chain CRFs with C classes.
+
+    Event shape is of the form:
+
+    Parameters:
+        log_potentials (tensor) : event shape ((N-1) x C x C) e.g.
+            :math:`\phi(n, z_{n+1}, z_{n})`
+        lengths (long tensor) : batch_shape integers for length masking.
+
+    Compact representation: N long tensor in [0, ..., C-1]
+    """
+
     struct = LinearChain


 class SemiMarkovCRF(StructDistribution):
+    r"""
+    Represents a semi-Markov or segmental CRF with C classes of max width K.
+
+    Event shape is of the form:
+
+    Parameters:
+        log_potentials : event shape (N x K x C x C) e.g.
+            :math:`\phi(n, k, z_{n+1}, z_{n})`
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Compact representation: N long tensor in [-1, 0, ..., C-1]
+    """
+
     struct = SemiMarkov


 class DependencyCRF(StructDistribution):
+    r"""
+    Represents a projective dependency CRF.
+
+    Event shape is of the form:
+
+    Parameters:
+        log_potentials (tensor) : event shape (N x N), head-child arc scores
+            with root scores on the diagonal, e.g.
+            :math:`\phi(i, j)` where :math:`\phi(i, i)` is (root, i).
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Compact representation: N long tensor in [0, N] (indexing is +1)
+    """
+
     struct = DepTree


 class TreeCRF(StructDistribution):
+    r"""
+    Represents a 0th-order span parser with NT nonterminals.
+
+    Event shape is of the form:
+
+    Parameters:
+        log_potentials (tensor) : event shape (N x N x NT), e.g.
+            :math:`\phi(i, j, nt)`
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Compact representation: N x N x NT long tensor (same as event shape)
+    """
     struct = CKY_CRF


 class SentCFG(StructDistribution):
+    """
+    Represents a full generative context-free grammar with
+    non-terminals NT and terminals T.
+
+    Event shape is of the form:
+
+    Parameters:
+        log_potentials (tuple) : event tuple with event shapes
+            terms (N x T)
+            rules (NT x (NT+T) x (NT+T))
+            root (NT)
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Compact representation: N x N x NT long tensor
+    """
+
     struct = CKY

     def __init__(self, log_potentials, lengths=None):

diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py
index e93906c4..786306eb 100644
--- a/torch_struct/linearchain.py
+++ b/torch_struct/linearchain.py
@@ -6,12 +6,6 @@
 class LinearChain(_Struct):
     """
     Represents structured linear-chain CRFs, generalizing HMMs smoothing,
     tagging models, and anything with chain-like dynamics.
-
-
-    Potentials are of the form:
-
-        edge : b x (N-1) x C x C markov potentials
-            (n-1 x z_n x z_{n-1})
     """

     def _check_potentials(self, edge, lengths=None):

diff --git a/torch_struct/networks/NeuralCFG.py b/torch_struct/networks/NeuralCFG.py
index 25986ebf..a1a78b16 100644
--- a/torch_struct/networks/NeuralCFG.py
+++ b/torch_struct/networks/NeuralCFG.py
@@ -2,7 +2,6 @@
 import torch.nn as nn


-# NeuralCFG From Kim et al
 class Res(nn.Module):
     def __init__(self, H):
         super().__init__()
@@ -20,6 +19,9 @@ def forward(self, y):


 class NeuralCFG(torch.nn.Module):
+    """
+    NeuralCFG from Kim et al.
+    """
     def __init__(self, V, T, NT, H):
         super().__init__()
         self.NT = NT

diff --git a/torch_struct/networks/SpanLSTM.py b/torch_struct/networks/SpanLSTM.py
index 29b792ce..3fee45e7 100644
--- a/torch_struct/networks/SpanLSTM.py
+++ b/torch_struct/networks/SpanLSTM.py
@@ -19,6 +19,9 @@ def forward(self, y):


 class SpanLSTM(torch.nn.Module):
+    """
+    SpanLSTM model.
+    """
     def __init__(self, NT, V, H):
         super().__init__()
         self.H = H
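The event shapes documented above fix the log_potentials layout each subclass expects; a shape-only sketch with illustrative sizes:

    import torch
    from torch_struct import DependencyCRF, SemiMarkovCRF, TreeCRF

    batch, N, K, C, NT = 2, 6, 3, 4, 5

    DependencyCRF(torch.randn(batch, N, N))        # phi(i, j), root scores on diagonal
    SemiMarkovCRF(torch.randn(batch, N, K, C, C))  # phi(n, k, z_{n+1}, z_n)
    TreeCRF(torch.randn(batch, N, N, NT))          # phi(i, j, nt)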
+ """ + def __init__(self, hidden, in_size, out_size): super().__init__() self.emb = torch.nn.Embedding(in_size, hidden) From 632c43ea79a274db2d67310399541a43f715b5b1 Mon Sep 17 00:00:00 2001 From: srush Date: Thu, 10 Oct 2019 13:49:06 -0400 Subject: [PATCH 3/3] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 8c82ccc9..aaba4ee6 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,8 @@ show(dist.marginals.detach()[0].sum(-1)) ## Library +Full docs: http://nlp.seas.harvard.edu/pytorch-struct/ + Current distributions implemented: * LinearChainCRF