Commit 65e2b19

Merge 67e85bc into fd0f974
srush committed Oct 4, 2019
2 parents fd0f974 + 67e85bc

Showing 14 changed files with 278 additions and 92 deletions.
69 changes: 35 additions & 34 deletions README.md
@@ -17,13 +17,14 @@ A library of tested, GPU implementations of core structured prediction algorithms


```python
-!pip install -qU git+https://github.com/harvardnlp/pytorch-struct matplotlib
+!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
+!pip install -q matplotlib
```


```python
import torch
-from torch_struct import DepTree, LinearChain, MaxSemiring, SampledSemiring
+from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
```
@@ -34,8 +35,8 @@ def show(x): plt.imshow(x.detach())
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5)
-vals = vals.log()
-show(vals[0])
+dist = DependencyCRF(vals.log())
+show(dist.log_potentials[0])
```


@@ -45,8 +46,7 @@ show(vals[0])

```python
# Compute marginals
-marginals = DepTree().marginals(vals)
-show(marginals[0])
+show(dist.marginals[0])
```


@@ -56,8 +56,7 @@ show(marginals[0])

```python
# Compute argmax
-argmax = DepTree(MaxSemiring).marginals(vals)
-show(argmax.detach()[0])
+show(dist.argmax.detach()[0])
```


@@ -67,16 +66,14 @@ show(argmax.detach()[0])

```python
# Compute scoring and enumeration (forward / inside)
-log_partition = DepTree().sum(vals)
-max_score = DepTree(MaxSemiring).sum(vals)
-max_score = DepTree().score(argmax, vals)
+log_partition = dist.partition
+max_score = dist.log_prob(dist.argmax)
```


```python
# Compute samples
-sample = DepTree(SampledSemiring).marginals(vals)
-show(sample.detach()[0])
+show(dist.sample((1,)).detach()[0, 0])
```


@@ -86,12 +83,10 @@ show(sample.detach()[0])

```python
# Padding/Masking built into library.
-marginals = DepTree().marginals(
-    vals,
-    lengths=torch.tensor([10, 7]))
-show(marginals[0])
+dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))
+show(dist.marginals[0])
plt.show()
-show(marginals[1])
+show(dist.marginals[1])
```


@@ -112,8 +107,8 @@ chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

-marginals = LinearChain().marginals(chain)
-show(marginals.detach()[0].sum(-1))
+dist = LinearChainCRF(chain)
+show(dist.marginals.detach()[0].sum(-1))
```


@@ -122,36 +117,42 @@ show(marginals.detach()[0].sum(-1))

## Library

-Current algorithms implemented:
+Current distributions implemented:

-* Linear Chain (CRF / HMM)
-* Semi-Markov (CRF / HSMM)
-* Dependency Parsing (Projective and Non-Projective)
-* CKY (CFG, CKY_CRF)
+* LinearChainCRF
+* SemiMarkovCRF
+* DependencyCRF
+* TreeCRF

-* Integration with `torchtext` and `pytorch-transformers`
+Extensions:

+* Integration with `torchtext`, `pytorch-transformers`, `dgl`
+* Adapters for generative structured models (CFG / HMM / HSMM)
+* Common tree-structured parameterizations: TreeLSTM / SpanLSTM

Design Strategy:

-1) Minimal implementations. Most are 10 lines.
+1) Minimal, efficient Python implementations.
2) Batched for GPU.
3) Code can be ported to other backends.

-Semirings:
+## Low-level API:
+
+Everything is implemented through semiring dynamic programming.

* Log Marginals
* Max and MAP computation
* Sampling through specialized backprop
-* Entropy
+* Entropy and first-order semirings

-Networks:
-* CKY CRF LSTM
-* Tree-LSTM

## Examples

* BERT <a href="https://github.com/harvardnlp/pytorch-struct/blob/master/notebooks/BertTagger.ipynb">Part-of-Speech</a>
* BERT <a href="https://github.com/harvardnlp/pytorch-struct/blob/master/notebooks/BertDependencies.ipynb">Dependency Parsing</a>
+* <a href="https://github.com/harvardnlp/pytorch-struct/blob/master/notebooks/Unsupervised_CFG.ipynb">Unsupervised Learning</a>
-* Unsupervised Learning (to come)
* Structured VAE (to come)
* Structured attention (to come)
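To connect the two levels, a minimal sketch (assuming the `LogSemiring` default that `partition` uses, per `distributions.py` below):

```python
import torch
from torch_struct import LinearChain, LinearChainCRF
from torch_struct.semirings import LogSemiring

vals = torch.randn(2, 5, 3, 3)  # (batch, N, C, C) transition log-potentials
dist = LinearChainCRF(vals)

# The distribution property is computed by the low-level semiring call.
assert torch.isclose(dist.partition, LinearChain(LogSemiring).sum(vals)).all()
```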
Binary file modified README_files/README_10_0.png
Binary file modified README_files/README_4_0.png
Binary file modified README_files/README_6_0.png
Binary file modified README_files/README_8_0.png
Binary file modified README_files/README_9_0.png
Binary file modified README_files/README_9_1.png
87 changes: 45 additions & 42 deletions notebooks/Examples.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
    name="torch_struct",
-    version="0.0.1",
+    version="0.2",
    author="Alexander Rush",
    author_email="arush@cornell.edu",
    packages=["torch_struct", "torch_struct.data", "torch_struct.networks"],
14 changes: 13 additions & 1 deletion torch_struct/__init__.py
@@ -1,4 +1,11 @@
from .cky import CKY
+from .distributions import (
+    StructDistribution,
+    LinearChainCRF,
+    SemiMarkovCRF,
+    DependencyCRF,
+    TreeCRF,
+)
from .cky_crf import CKY_CRF
from .deptree import DepTree
from .linearchain import LinearChain
@@ -13,7 +20,7 @@
)


version = "0.0.1"
version = "0.2"

# For flake8 compatibility.
__all__ = [
@@ -28,4 +35,9 @@
    MaxSemiring,
    EntropySemiring,
    MultiSampledSemiring,
+    StructDistribution,
+    LinearChainCRF,
+    SemiMarkovCRF,
+    DependencyCRF,
+    TreeCRF,
]
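The practical effect of this hunk: the new distribution classes can be imported straight from the top-level package, e.g.

```python
from torch_struct import LinearChainCRF, SemiMarkovCRF, DependencyCRF, TreeCRF
```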
107 changes: 107 additions & 0 deletions torch_struct/distributions.py
@@ -0,0 +1,107 @@
import torch
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property
from .linearchain import LinearChain
from .semimarkov import SemiMarkov
from .deptree import DepTree
from .cky_crf import CKY_CRF
from .semirings import (
LogSemiring,
MaxSemiring,
EntropySemiring,
MultiSampledSemiring,
)


class StructDistribution(Distribution):
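    # A torch.distributions.Distribution over discrete structures, parameterized
    # by a tensor of log-potentials; subclasses bind `struct` to a
    # dynamic-programming backend (LinearChain, SemiMarkov, DepTree, CKY_CRF below).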
has_enumerate_support = True

def __init__(self, log_potentials, lengths=None):
batch_shape = log_potentials.shape[:1]
event_shape = log_potentials.shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
super().__init__(batch_shape=batch_shape, event_shape=event_shape)

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

# @constraints.dependent_property
# def support(self):
# pass

# @property
# def param_shape(self):
# return self._param.size()

@lazy_property
def partition(self):
return self.struct(LogSemiring).sum(self.log_potentials, self.lengths)

@property
def mean(self):
pass

@property
def variance(self):
pass

def sample(self, sample_shape=torch.Size()):
assert len(sample_shape) == 1
nsamples = sample_shape[0]
samples = []
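        # A fresh DP pass is run every 10 draws: MultiSampledSemiring packs
        # several samples into a single marginals call, and to_discrete
        # unpacks the (k % 10 + 1)-th sample from the packed result.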
for k in range(nsamples):
if k % 10 == 0:
sample = self.struct(MultiSampledSemiring).marginals(
self.log_potentials, lengths=self.lengths
)
sample = sample.detach()
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
samples.append(tmp_sample)
return torch.stack(samples)

def log_prob(self, value):
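        # log p(value) = score(value) - log Z; any leading sample dimensions
        # on `value` beyond the event shape are treated as extra batch dims.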
d = value.dim()
batch_dims = range(d - len(self.event_shape))
v = self.struct().score(
self.log_potentials,
value.type_as(self.log_potentials),
batch_dims=batch_dims,
)
return v - self.partition

@lazy_property
def entropy(self):
return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths)

@lazy_property
def argmax(self):
return self.struct(MaxSemiring).marginals(self.log_potentials, self.lengths)

@lazy_property
def marginals(self):
return self.struct(LogSemiring).marginals(self.log_potentials, self.lengths)

def enumerate_support(self, expand=True):
_, _, edges, enum_lengths = self.struct().enumerate(
self.log_potentials, self.lengths
)
# if expand:
# edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:])
return edges, enum_lengths


class LinearChainCRF(StructDistribution):
struct = LinearChain


class SemiMarkovCRF(StructDistribution):
struct = SemiMarkov


class DependencyCRF(StructDistribution):
struct = DepTree


class TreeCRF(StructDistribution):
struct = CKY_CRF
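A usage sketch for the classes defined above (shapes follow the `LinearChain` convention of `(batch, N, C, C)` transition log-potentials):

```python
import torch
from torch_struct import LinearChainCRF

log_potentials = torch.randn(2, 5, 3, 3)  # batch=2, 5 transitions, 3 labels
dist = LinearChainCRF(log_potentials)

logZ = dist.partition        # log partition (LogSemiring)
best = dist.argmax           # MAP structure as a parts tensor (MaxSemiring)
marg = dist.marginals        # edge marginals via backprop
logp = dist.log_prob(best)   # score(best) - logZ
samp = dist.sample((30,))    # 30 samples, drawn 10 per DP pass
```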
47 changes: 34 additions & 13 deletions torch_struct/linearchain.py
@@ -37,15 +37,15 @@ def _dp(self, edge, lengths=None, force_grad=False):
        semiring.one_(alpha[0].data)
        BATCH_DIM, N_DIM = 1, 2

-        for n in torch.arange(1, N):
+        for n in range(1, N):
            edge_store[n - 1][:] = semiring.times(
                alpha[n - 1].view(ssize, batch, 1, C),
-                edge.index_select(N_DIM, n - 1).view(ssize, batch, C, C),
+                edge[:, :, n - 1].view(ssize, batch, C, C),
            )
            alpha[n][:] = semiring.sum(edge_store[n - 1])
        ret = [
-            alpha[lengths[i] - 1].index_select(BATCH_DIM, i)
-            for i in torch.arange(batch)
+            alpha[lengths[i] - 1][:, i]
+            for i in range(batch)
        ]
        ret = torch.cat(ret, dim=1)
        v = semiring.sum(ret)
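The loop above is the forward (inside) recursion. For orientation, a plain log-space sketch of the same computation outside the semiring abstraction (`forward_logZ` is a hypothetical helper, not library code):

```python
import torch

def forward_logZ(edge):
    # edge: (batch, N, C, C) log-potentials for N transitions over C labels.
    batch, N, C, _ = edge.shape
    alpha = edge.new_zeros(batch, C)  # log alpha_0 = 0 for every start label
    for n in range(N):
        # alpha_n[c] = logsumexp_{c'}( alpha_{n-1}[c'] + edge[n, c, c'] )
        alpha = torch.logsumexp(alpha.unsqueeze(1) + edge[:, n], dim=-1)
    return torch.logsumexp(alpha, dim=-1)  # log Z, shape (batch,)
```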
@@ -166,28 +166,49 @@ def _rand():
        return torch.rand(b, N, C, C), (b.item(), (N + 1).item())

    ### Tests

-    def enumerate(self, edge):
+    def enumerate(self, edge, lengths=None):
        semiring = self.semiring
        ssize = semiring.size()
-        edge, batch, N, C, lengths = self._check_potentials(edge, None)
-        chains = [([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]
+        edge, batch, N, C, lengths = self._check_potentials(edge, lengths)
+        chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
+
+        enum_lengths = torch.LongTensor(lengths.shape)
        for n in range(1, N):
            new_chains = []
-            for chain, score in chains:
+            for chain, score in chains[-1]:
                for c in range(C):
                    new_chains.append(
                        (
                            chain + [c],
                            semiring.mul(score, edge[:, :, n - 1, c, chain[-1]]),
                        )
                    )
-            chains = new_chains
+            chains.append(new_chains)
+
+            for b in range(lengths.shape[0]):
+                if lengths[b] == n + 1:
+                    enum_lengths[b] = len(new_chains)

-        edges = self.to_parts(torch.stack([torch.tensor(c) for (c, _) in chains]), C)
+        edges = self.to_parts(
+            torch.stack([torch.tensor(c) for (c, _) in chains[-1]]), C
+        )
        # Sum out non-batch
        a = torch.einsum("ancd,sbncd->sbancd", edges.float(), edge)
        a = semiring.prod(a.view(*a.shape[:3] + (-1,)), dim=3)
        a = semiring.sum(a, dim=2)
-        b = semiring.sum(torch.stack([s for (_, s) in chains], dim=1), dim=1)
-        assert torch.isclose(a, b).all(), "%s %s" % (a, b)
-        return semiring.unconvert(b), [s for (_, s) in chains]
+        ret = semiring.sum(torch.stack([s for (_, s) in chains[-1]], dim=1), dim=1)
+        assert torch.isclose(a, ret).all(), "%s %s" % (a, ret)
+
+        edges = torch.zeros(len(chains[-1]), batch, N - 1, C, C)
+        for b in range(lengths.shape[0]):
+            edges[: enum_lengths[b], b, : lengths[b] - 1] = self.to_parts(
+                torch.stack([torch.tensor(c) for (c, _) in chains[lengths[b] - 1]]), C
+            )
+
+        return (
+            semiring.unconvert(ret),
+            [s for (_, s) in chains[-1]],
+            edges,
+            enum_lengths,
+        )
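`enumerate` remains the brute-force reference: it scores all C^N label chains directly, and now also returns the enumerated edge tensors plus per-batch chain counts for length-aware tests. A cross-check sketch against the dynamic program (assuming `LinearChain`'s default `LogSemiring`):

```python
import torch
from torch_struct import LinearChain

vals = torch.rand(2, 4, 3, 3).log()  # batch=2, 4 transitions, 3 labels
logZ_dp = LinearChain().sum(vals)    # forward algorithm
logZ_bf, _, edges, enum_lengths = LinearChain().enumerate(vals)
assert torch.isclose(logZ_dp, logZ_bf).all()
```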
1 change: 0 additions & 1 deletion torch_struct/test_algorithms.py
@@ -28,7 +28,6 @@ def test_simple(batch, N, C):
    alpha = LinearChain(semiring).sum(vals)
    assert (alpha == pow(C, N + 1)).all()
    LinearChain(SampledSemiring).marginals(vals)
-
    LinearChain(MultiSampledSemiring).marginals(vals)
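(Reading the assert: with all-one potentials under a counting semiring, the partition function simply counts label sequences, and N transitions give C^(N+1) of them; this assumes `vals` is set to ones earlier in the test.)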


