Fast, general, and tested differentiable structured prediction in PyTorch

Pytorch-Struct

A library of tested GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x):
    # Display a tensor as a heatmap.
    plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])

# Compute marginals
show(dist.marginals[0])

# Compute argmax
show(dist.argmax.detach()[0])

# Compute the log-partition function by enumeration (forward / inside)
log_partition = dist.partition
# Score the best structure; log_prob returns its normalized log-probability
max_score = dist.log_prob(dist.argmax)
# Compute samples
show(dist.sample((1,)).detach()[0, 0])

# Padding/masking is built into the library: pass per-example lengths.
# Log-potentials are used here, as in the first example.
dist = DependencyCRF(vals.log(), lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max
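
To make these quantities concrete, here is a minimal sketch that computes a few of them by brute-force enumeration on a tiny linear chain. It is plain Python and does not use the library; the toy `scores` table, its `[t][i][j]` layout, and the helper names are assumptions made for illustration only. The library computes the same kinds of quantities with batched dynamic programming rather than enumeration.

```python
import itertools
import math

# Hypothetical toy model: a chain of 3 positions with 2 labels each.
# scores[t][i][j] is the log-potential of label i at position t
# followed by label j at position t + 1.
scores = [
    [[0.5, -1.0], [0.2, 0.3]],
    [[1.0, 0.0], [-0.5, 0.8]],
]
NUM_LABELS = 2
LENGTH = len(scores) + 1


def path_score(path):
    """Sum the log-potentials along one label sequence."""
    return sum(scores[t][path[t]][path[t + 1]] for t in range(LENGTH - 1))


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# Enumerate every label sequence (only feasible for tiny problems).
paths = list(itertools.product(range(NUM_LABELS), repeat=LENGTH))
all_scores = [path_score(p) for p in paths]

# Partition: log of the total unnormalized mass over all structures.
log_partition = logsumexp(all_scores)
# log_probs: score of each structure minus the log-partition.
log_probs = {p: s - log_partition for p, s in zip(paths, all_scores)}
# Argmax: the single highest-scoring structure.
best_path = max(paths, key=path_score)
# Entropy of the distribution over structures.
entropy = -sum(math.exp(lp) * lp for lp in log_probs.values())
```

Enumeration grows exponentially with sequence length, which is why dynamic programming is needed in practice.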

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree-structured parameterizations (TreeLSTM / SpanLSTM)

Low-level API:

Everything is implemented through semiring dynamic programming.

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.
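
The idea behind semiring dynamic programming can be sketched in a few lines of plain Python: one forward recurrence, run with different (plus, times) operations, yields different quantities. This is an illustrative sketch, not the library's implementation; `chain_forward`, `LOG_SEMIRING`, `MAX_SEMIRING`, and the `scores[t][i][j]` layout are hypothetical names and conventions chosen for the example.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# A semiring here is a (plus, times) pair acting on log-space scores.
# In both cases "times" is ordinary addition of log-potentials.
LOG_SEMIRING = (logsumexp, lambda a, b: a + b)  # sums over all paths
MAX_SEMIRING = (max, lambda a, b: a + b)        # keeps only the best path


def chain_forward(scores, semiring):
    """Run the forward recurrence over a chain under the given semiring.

    scores[t][i][j] is the log-potential of label i at position t
    followed by label j at position t + 1 (layout assumed for this sketch).
    """
    plus, times = semiring
    num_labels = len(scores[0])
    alpha = [0.0] * num_labels  # 0.0 is the multiplicative identity in log space
    for step in scores:
        alpha = [
            plus([times(alpha[i], step[i][j]) for i in range(num_labels)])
            for j in range(num_labels)
        ]
    return plus(alpha)


# Hypothetical toy potentials: 3 positions, 2 labels.
scores = [
    [[0.5, -1.0], [0.2, 0.3]],
    [[1.0, 0.0], [-0.5, 0.8]],
]
log_partition = chain_forward(scores, LOG_SEMIRING)  # log-sum over all paths
max_score = chain_forward(scores, MAX_SEMIRING)      # score of the best path
```

Swapping the semiring changes what the same recurrence computes, so one algorithm can serve several of the operations listed above.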

Examples