Trellis Networks for Sequence Modeling
Clone or download
jerrybai1995 Merge pull request #1 from ip01/master
Transition to be pytorch v1.0 compatible
Latest commit 2a8fa3d Dec 7, 2018
Permalink
Type Name Latest commit message Commit time
Failed to load latest commit information.
TrellisNet Merge branch 'master' into master Dec 7, 2018
.gitignore Updated pre-trained weights loading code; README updated; version com… Nov 26, 2018
LICENSE Create MIT License Nov 1, 2018
README.md Minor edits to README Nov 26, 2018

README.md

TrellisNet for Sequence Modeling

This repository contains the experiments done in paper Trellis Networks for Sequence Modeling by Shaojie Bai, J. Zico Kolter and Vladlen Koltun.

On the one hand, a trellis network is a temporal convolutional network with special structure, characterized by weight tying across depth and direct injection of the input into deep layers. On the other hand, we show that truncated recurrent networks are equivalent to trellis networks with special sparsity structure in their weight matrices. Thus trellis networks with general weight matrices generalize truncated recurrent networks. This allows trellis networks to serve as bridge between recurrent and convolutional architectures, benefitting from algorithmic and architectural techniques developed in either context. We leverage these relationships to design high-performing trellis networks that absorb ideas from both architectural families. Experiments demonstrate that trellis networks outperform the current state of the art on a variety of challenging benchmarks, including word-level language modeling on Penn Treebank and WikiText-103, character-level language modeling on Penn Treebank, and stress tests designed to evaluate long-term memory retention.

Our experiments were done in PyTorch. If you find our work, or this repository helpful, please consider citing our work:

@article{BaiTrellis2018,
	author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
	title     = {Trellis Networks for Sequence Modeling},
	journal   = {arXiv:1810.06682},
	year      = {2018},
}

Datasets

The code should be directly runnable with PyTorch 0.4.0 (although slight modifications may be needed for other versions). This repository contains the training script for the following tasks:

  • Sequential MNIST handwritten digit classification
  • Permuted Sequential MNIST that randomly permutes the pixel order in sequential MNIST
  • Sequential CIFAR-10 classification (more challenging, due to more intra-class variations, channel complexities and larger images)
  • Penn Treebank (PTB) word-level language modeling (with and without the mixture of softmax); vocabulary size 10K
  • Wikitext-103 (WT103) large-scale word-level language modeling; vocabulary size 268K
  • Penn Treebank medium-scale character-level language modeling

Note that these tasks are on very different scales, with unique properties that challenge sequence models in different ways. For example, word-level PTB is a small dataset that a typical model easily overfits, so judicious regularization is essential. WT103 is a hundred times larger, with less danger of overfitting, but with a vocabulary size of 268K that makes training more challenging (due to large embedding size).

Pre-trained Model(s)

We provide some reasonably good pre-trained weights here so that the users don't need to train from scratch. We'll update the table from time to time. Note: 1) it could take a while to load the weights; 2) if you train from scratch using different seeds (which we didn't explore), it's likely you will get better results :-)

Description Task Dataset Model
TrellisNet-LM Language Modeling Penn Treebank (PTB) download (.pkl)

To use the pre-trained weights, use the flag --load_weight [.pkl PATH] when starting the training script.

Usage

All tasks share the same underlying TrellisNet model, which is in file trellisnet.py (and the eventual models, including components like embedding layer, are in model.py). As discussed in the paper, TrellisNet is able to benefit significantly from techniques developed originally for RNNs as well as temporal convolutional networks (TCNs). Some of these techniques are also included in this repository. Each task is organized in the following structure:

[TASK_NAME] /
    data/
    logs/
    [TASK_NAME].py
    model.py
    utils.py
    data.py

where [TASK_NAME].py is the training script for the task (with argument flags; use -h to see the details).