Lite Transformer with Long-Short Range Attention

@inproceedings{Wu2020LiteTransformer,
  title={Lite Transformer with Long-Short Range Attention},
  author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}

Overview

We release the PyTorch code for the Lite Transformer. [Paper | Website | Slides]

[Figure: model overview]

Consistent Improvement Across Trade-off Curves

[Figure: trade-off curves]

20,000x Lower Search Cost than the Evolved Transformer

[Figure: comparison with the Evolved Transformer]

Further Compresses the Transformer by 18.2x

[Figure: compression results]

How to Use

Prerequisite

  • Python version >= 3.6
  • PyTorch version >= 1.0.0
  • configargparse >= 0.14
  • For training new models, you'll also need an NVIDIA GPU and NCCL
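
A quick way to check the environment before installing (a minimal sketch, not part of the repo):

python -c "import sys; assert sys.version_info >= (3, 6), sys.version"
python -c "import torch; print('torch', torch.__version__, 'CUDA available:', torch.cuda.is_available())"
pip show configargparse | head -n 2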

Installation

  1. Codebase

    To install fairseq from source and develop locally:

    pip install --editable .
  2. Customized Modules

    We also need to build the lightconv and dynamicconv CUDA kernels for GPU support.

    Lightconv_layer

    cd fairseq/modules/lightconv_layer
    python cuda_function_gen.py
    python setup.py install

    Dynamicconv_layer

    cd fairseq/modules/dynamicconv_layer
    python cuda_function_gen.py
    python setup.py install
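
    To verify that both kernels built and installed correctly, a quick import check can help (a sketch; it assumes the extensions are installed under the names lightconv_cuda and dynamicconv_cuda, as in the fairseq setup scripts):

    python -c "import lightconv_cuda, dynamicconv_cuda; print('custom CUDA kernels OK')"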

Data Preparation

IWSLT'14 De-En

We follow the data preparation in fairseq. To download and preprocess the data, one can run

bash configs/iwslt14.de-en/prepare.sh

WMT'14 En-Fr

We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

bash configs/wmt14.en-fr/prepare.sh

WMT'16 En-De

We follow the data pre-processing in fairseq. One should first download the preprocessed data from the Google Drive folder provided by Google. To binarize the data, one can run

bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file]
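
For example, if the archive were saved as ~/Downloads/wmt16_en_de.zip (a hypothetical path and filename), the command would be

bash configs/wmt16.en-de/prepare.sh ~/Downloads/wmt16_en_de.zip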

WIKITEXT-103

As the language modeling task requires many additional code changes, we place it in a separate branch: language-model. We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

git checkout language-model
bash configs/wikitext-103/prepare.sh

Testing

For example, to test the models on WMT'14 En-Fr, one can run

configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid]

For instance, to evaluate Lite Transformer on GPU 0 (reporting the BLEU score on the test set of WMT'14 En-Fr), one can run

configs/wmt14.en-fr/test.sh embed496/ 0 test

We provide several pretrained models at the bottom. You can download a model and extract the archive by

tar -xzvf [filename]
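
Putting the pieces together (a sketch; the archive name embed496.tar.gz is hypothetical and depends on which checkpoint you download, while the embed496/ directory matches the example above):

tar -xzvf embed496.tar.gz
configs/wmt14.en-fr/test.sh embed496/ 0 test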

Training

We provide several examples for training Lite Transformer with this repo:

To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run

python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml

To train Lite Transformer with fewer GPUs, e.g. 4 GPUs, one can run

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32

In general, to train a model, one can run

python train.py [path to the data binary] --configs [path to config file] [override options]

Note that --update-freq should be adjusted according to the number of GPUs (16 for 8 GPUs, 32 for 4 GPUs), so that the product of the GPU count and --update-freq, and hence the effective batch size, stays constant.
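
As a sketch of that rule (not part of the repo; it assumes nvidia-smi is available and targets the same effective batch as 8 GPUs with --update-freq 16, i.e. a product of 128):

# Hypothetical helper: scale --update-freq so that n_gpus * update_freq = 128.
NGPUS=$(nvidia-smi -L | wc -l)   # or set manually if restricting via CUDA_VISIBLE_DEVICES
UPDATE_FREQ=$((128 / NGPUS))
python train.py data/binary/wmt14_en_fr \
    --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
    --update-freq $UPDATE_FREQ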

Distributed Training (optional)

To train Lite Transformer in a distributed manner, e.g. on two nodes with 8 GPUs each (16 GPUs in total), one can run the commands below. Note that --update-freq drops to 8 here, since 16 GPUs x 8 keeps the same effective batch size as 8 GPUs x 16.

# On host1
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=0 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8
# On host2
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=1 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8

Models

We provide the checkpoints for our Lite Transformer reported in the paper:

Dataset           #Mult-Adds   Test Score    Model and Test Set
WMT'14 En-Fr      90M          35.3 (BLEU)   download
WMT'14 En-Fr      360M         39.1 (BLEU)   download
WMT'14 En-Fr      527M         39.6 (BLEU)   download
WMT'16 En-De      90M          22.5 (BLEU)   download
WMT'16 En-De      360M         25.6 (BLEU)   download
WMT'16 En-De      527M         26.5 (BLEU)   download
CNN / DailyMail   800M         38.3 (R-L)    download
WIKITEXT-103      1147M        22.2 (PPL)    download
