<a href="https://colab.research.google.com/github/cai91/protein-language-models/blob/main/Protein_Language_Modeling_with_fairseq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Protein Language Modeling with fairseq

This is a quick walkthrough of how to set up a protein language modeling run using [pytorch/fairseq](https://github.com/pytorch/fairseq), an open-source language modeling toolkit developed by FAIR. fairseq was recently updated to allow training directly from FASTA files, which makes setting up protein language modeling very easy!



## What you'll need

### Packages

*   numpy
*   pytorch
*   fairseq
*   biopython (used in the tutorial to process fasta files)
*   apex (optional, can speed up training)

Colab comes with numpy and pytorch installed, and we won't be using apex in the tutorial, so we're only going to install biopython and fairseq. If you're training on your own devices you may need to install numpy + pytorch separately.

In [None]:
%%bash
pushd . 
pip install biopython
git clone -q https://github.com/pytorch/fairseq
cd fairseq
pip install -e .
popd


#### Slightly Hacking Colab

There is an issue with Colab and pip editable installs (see [this issue](https://github.com/pytorch/fairseq/issues/2407)). To work around this, we'll add the fairseq directory directly to `sys.path`. You can remove this step if you are following this tutorial outside of a notebook.

In [None]:
import sys
sys.path.insert(0, "/content/fairseq")

### Data
For data, you're going to need a FASTA file containing the set of sequences you want to train on. For the tutorial, we're going to use the 15051 sequences used to train [trRosetta](https://www.pnas.org/content/117/3/1496). This is far fewer sequences than you would actually want for training a language model, so this is just for demonstration purposes.

In [None]:
%%bash
pushd .
mkdir -p data && cd data
wget https://s3.amazonaws.com/proteindata/list15051.fasta -o /dev/null
popd

/content /content
/content
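
Before training, it can help to know how many sequences the file holds and how long they are, since those numbers feed into the batch-size and length settings later on. The sketch below uses only the standard library; `fasta_lengths` and `summarize_fasta` are hypothetical helper names, not part of fairseq or Biopython.

```python
from statistics import mean


def fasta_lengths(path):
    """Yield the length of each sequence in a FASTA file, one record at a time."""
    length = None  # stays None until the first ">" header line is seen
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if line.startswith(">"):
                if length is not None:
                    yield length  # finished the previous record
                length = 0
            elif line and length is not None:
                length += len(line)  # sequences may wrap over several lines
    if length is not None:
        yield length  # last record in the file


def summarize_fasta(path):
    """Return the record count, mean length, and maximum length of a FASTA file."""
    lengths = list(fasta_lengths(path))
    return {
        "num_sequences": len(lengths),
        "mean_length": mean(lengths) if lengths else 0,
        "max_length": max(lengths, default=0),
    }
```

Calling `summarize_fasta("data/list15051.fasta")` after the download gives a quick sanity check that the file arrived intact.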


### Compute
Fairseq will automatically distribute to as many GPUs as you have available. If you want to limit the number of GPUs used, set the `CUDA_VISIBLE_DEVICES` environment variable. I often set only one visible device when debugging. 

With a bit more work it can also be set up to train on multiple nodes, but we won't be covering how to do that here.
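
To restrict training to a single GPU from inside a notebook, one option is to set `CUDA_VISIBLE_DEVICES` in the kernel's environment before anything initializes CUDA. This is a minimal sketch: `limit_visible_gpus` is a hypothetical helper, and it assumes that commands launched later with `!` inherit the kernel's environment.

```python
import os


def limit_visible_gpus(device_ids):
    """Expose only the given CUDA device indices to this process and its children."""
    value = ",".join(str(i) for i in device_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


# Expose only GPU 0, e.g. while debugging.
limit_visible_gpus([0])
```

Outside a notebook, prefixing the shell command with `CUDA_VISIBLE_DEVICES=0` has the same effect.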

## Setting Things Up

fairseq training requires three data files: `train.fasta`, `valid.fasta`, and `dict.txt`. The first two are simply training and validation splits of the full set of fasta sequences, while `dict.txt` is the dictionary that should be used to map tokens to integers. We'll set all three of these files up now.

### Training + Validation Splits

We're just going to use a 95/5 random-split training and validation set. Depending on your problem setting, you may want to use specific sequence identity thresholds for holdouts, structural holdouts, or some other form of splitting. 

If your files are very large, this particular method of generating training + validation splits may be slow and memory-intensive, since it loads every record into memory before shuffling.

In [None]:
from Bio import SeqIO
import numpy as np

# Load every record into memory, then shuffle in place so the split is random.
all_records = list(SeqIO.parse("data/list15051.fasta", "fasta"))
np.random.shuffle(all_records)

# Hold out the first 5% of the shuffled records for validation.
valid_pct = 0.05
num_valid_records = int(len(all_records) * valid_pct)
valid_records = all_records[:num_valid_records]
train_records = all_records[num_valid_records:]

# SeqIO.write returns the number of records written.
SeqIO.write(train_records, "data/train.fasta", "fasta")
SeqIO.write(valid_records, "data/valid.fasta", "fasta")

752
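
If the FASTA file is too large to hold in memory, one alternative is to stream it once and assign each record to a split as it goes by. The sketch below is a standard-library stand-in for the Biopython code above, not a drop-in replacement: `stream_split_fasta` is a hypothetical helper, and because each record is assigned independently the validation set is only approximately `valid_pct` of the data.

```python
import random


def stream_split_fasta(src, train_path, valid_path, valid_pct=0.05, seed=0):
    """Copy each FASTA record to the valid file with probability valid_pct, else train."""
    rng = random.Random(seed)  # seeded so the split is reproducible
    counts = {"train": 0, "valid": 0}
    out = None  # file handle for the record currently being copied
    with open(src) as fin, open(train_path, "w") as ftrain, open(valid_path, "w") as fvalid:
        for line in fin:
            if line.startswith(">"):
                # A header starts a new record: decide which split it goes to.
                if rng.random() < valid_pct:
                    out, key = fvalid, "valid"
                else:
                    out, key = ftrain, "train"
                counts[key] += 1
            if out is not None:
                out.write(line)  # header and sequence lines follow the same handle
    return counts
```

For example, `stream_split_fasta("data/list15051.fasta", "data/train.fasta", "data/valid.fasta")` writes both splits while holding only one line in memory at a time.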

### The Dictionary

There are two ways to set up a dictionary: automatically or manually. fairseq dictionaries were originally designed for natural language, where the vocabulary is usually too large to specify by hand. Protein and nucleotide sequences have small alphabets, so writing the dictionary manually is a realistic option.

The code below builds the dictionary automatically by scanning the training records from the previous section and adding every amino acid it finds. It also records how often each token appears. This count is not used during training, but it is a nice piece of metadata to have.

In [None]:
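from fairseq.data import Dictionary
from collections import Counter

# Count how often each residue appears across the training sequences.
token_counter = Counter()
for record in train_records:
  token_counter.update(record.seq)

# Add tokens in sorted order so the token-to-index mapping is deterministic.
dictionary = Dictionary()
for token, count in sorted(token_counter.items()):
  dictionary.add_symbol(token, count)

# Write the dictionary next to the training and validation files.
with open("data/dict.txt", "w") as f:
  dictionary.save(f)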
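
The manual route mentioned above can be sketched with the standard library alone. This assumes the saved dictionary is a plain text file with one `token count` pair per line, which is what `dictionary.save` produces. `write_manual_dict` is a hypothetical helper, the extra codes are an assumption about what your FASTA files may contain, and the count of 1 is only a placeholder.

```python
# One-letter codes for the 20 standard amino acids.
STANDARD_AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
# Less common or ambiguous codes (assumed extras; trim to match your data).
EXTRA_CODES = "XBZUO"


def write_manual_dict(path, tokens, placeholder_count=1):
    """Write one 'token count' line per token and return the number of lines written."""
    written = 0
    with open(path, "w") as f:
        for token in tokens:
            f.write(f"{token} {placeholder_count}\n")
            written += 1
    return written
```

Calling `write_manual_dict("data/dict.txt", STANDARD_AMINO_ACIDS + EXTRA_CODES)` fixes the vocabulary up front, so tokens keep the same indices even if the training set changes. Each token should appear only once in the file.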

## Training the Model

And now you're ready to train a protein language model! The command for fairseq training is, well, `fairseq-train`. You can see a list of all options via `fairseq-train --help`. Let's just walk through the ones we're going to use below.

### Anatomy of the fairseq-train command

#### Basic Arguments

There are a few basic arguments that you'll need to pass in to almost every `fairseq-train` command. 

*    `data`: This is actually a positional argument, corresponding to the directory containing the training data files. In our case, this is going to be `./data`.
*    `--dataset-impl`: This tells fairseq what type of dataset to load. In our case, this is going to be `fasta`.
*    `--task`: This is a fairseq construct that loads data, performs any necessary transformations, and logs metrics. To train with masked language modeling, we'll use the `masked_lm` task.
*    `--criterion`: This is a fairseq construct that computes the loss and metrics given model outputs and targets. This can sometimes be swapped around, depending on the task you're interested in, but here we're just going to use the `masked_lm` criterion as well.
*    `--arch`: This is the model architecture to use. Fairseq has several built-in model architectures, and you can also register your own. We're going to use the 12 layer `roberta_base` architecture.

Check out the fairseq docs for a list of the dataset implementations, tasks, criteria, and architectures built in!
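
To make the masked language modeling objective concrete, here is a simplified sketch of BERT-style corruption applied to a protein sequence. It is an illustration, not fairseq's implementation: `mask_sequence` is a hypothetical function, the 15% masking rate and the 80/10/10 split between mask, random, and unchanged tokens are assumed defaults, and each position is sampled independently for simplicity.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def mask_sequence(seq, mask_prob=0.15, rng=None):
    """Return (inputs, targets); targets is None wherever the loss is not computed."""
    rng = rng or random.Random(0)
    inputs, targets = [], []
    for residue in seq:
        if rng.random() < mask_prob:
            targets.append(residue)  # the model must recover the original residue
            roll = rng.random()
            if roll < 0.8:
                inputs.append("<mask>")  # most selected positions become a mask token
            elif roll < 0.9:
                inputs.append(rng.choice(AMINO_ACIDS))  # some become a random residue
            else:
                inputs.append(residue)  # the rest are left unchanged
        else:
            inputs.append(residue)
            targets.append(None)  # unselected positions are not scored
    return inputs, targets
```

The `masked_lm` task prepares corrupted inputs of this kind, and the `masked_lm` criterion scores the model's predictions at the selected positions only.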

#### Training Arguments

These are a set of common arguments that you'll want to set to train the model:

*    `--max-tokens`: The maximum number of tokens to feed into the model in any given batch. Fairseq uses adaptive batch sizes, allowing larger batches of short sequences and smaller batches of long sequences. Divide `--max-tokens` by the average sequence length in your dataset to get an approximate batch size. I usually set this to the largest power of 2 that will fit on the GPU.
*    `--max-sentences`: This sets an actual maximum batch size. Use this instead of `--max-tokens` if you want a fixed batch size.
*    `--update-freq`: How many batches to accumulate gradients over before taking an optimizer step. This lets you simulate a larger batch size. I typically debug runs with this set to 1, then increase it to simulate the batch size I actually want when I'm ready to train.
*    `--lr`: The learning rate. I usually default to `1e-4`.
*    `--optimizer`: The optimizer. I usually default to `adam`.
*    `--lr-scheduler`: Learning rate scheduler. There are several built into fairseq; I usually use `inverse_sqrt`.
*    `--warmup-updates`: How many learning rate warmup steps to use. I usually default to `16000` for large models + lots of data.
*    `--max-positions`: The maximum number of positions to pass into the model. Will skip sequences longer than this.
*    `--skip-invalid-size-inputs-valid-test`: If you set max positions, you will need to set this as well to skip those sequences during validation, otherwise fairseq will complain.
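
The arithmetic behind `--max-tokens`, `--update-freq`, and the `inverse_sqrt` schedule can be sketched in a few lines. Both functions are hypothetical helpers for back-of-the-envelope estimates. The schedule assumes a linear warmup from 0 to the peak learning rate followed by inverse-square-root decay, and the batch-size estimate ignores padding.

```python
import math


def approx_batch_size(max_tokens, avg_seq_len, update_freq=1, num_gpus=1):
    """Approximate number of sequences contributing to one optimizer step."""
    per_batch = max_tokens // avg_seq_len  # sequences that fit in one batch
    return per_batch * update_freq * num_gpus


def inverse_sqrt_lr(step, peak_lr=1e-4, warmup_updates=16000):
    """Learning rate at a given update: linear warmup, then 1/sqrt(step) decay."""
    if step < warmup_updates:
        return peak_lr * step / warmup_updates
    return peak_lr * math.sqrt(warmup_updates / step)
```

With illustrative values, `approx_batch_size(4096, 256, update_freq=4)` comes to 64 sequences per update. Under the assumed schedule, the learning rate peaks at update 16000 and has fallen back to half the peak by update 64000.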

#### Apex

These arguments enable half-precision training, which works best with [nvidia/apex](https://github.com/nvidia/apex) installed.

*   `--fp16`: Use half precision (speeds up training if GPU supports it, reduces memory).
*   `--fp16-init-scale`: Initial fp16 loss scale. I usually default to 4.

#### Saving Arguments

*   `--validate-interval-updates`: Fairseq by default will run the validation pass at the end of each epoch, but this can be very infrequent if you have a large dataset. This allows you to also run the validation pass after a certain number of updates.
*   `--save-interval-updates`: Fairseq by default will save a checkpoint at the end of each epoch, but this can be very infrequent if you have a large dataset. This allows you to also save a checkpoint after a certain number of updates.
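
To judge whether end-of-epoch validation and checkpointing is frequent enough, it helps to estimate how many updates an epoch contains. The sketch below is a rough estimate under the assumption that every batch is packed to exactly `--max-tokens`; real batches include padding, so the true number will be somewhat higher. `updates_per_epoch` is a hypothetical helper.

```python
import math


def updates_per_epoch(num_sequences, avg_seq_len, max_tokens, update_freq=1, num_gpus=1):
    """Rough number of optimizer steps needed to see every sequence once."""
    total_tokens = num_sequences * avg_seq_len
    tokens_per_update = max_tokens * update_freq * num_gpus
    return math.ceil(total_tokens / tokens_per_update)
```

As an illustration, a million sequences averaging 256 residues, with `--max-tokens 4096` and `--update-freq 4` on one GPU, works out to 15625 updates per epoch. An interval of 5000 updates would then validate and checkpoint about three times per epoch instead of once.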


### Running the Training

Below is the full fairseq-train command with some defaults filled in.

In [None]:
# Argument groups, in order: basic, training, saving.
# Notes live up here because comments and blank lines cannot appear between
# the backslash-continued lines of a shell command.
# For half precision, also pass: --fp16 --fp16-init-scale 4
!fairseq-train \
  ./data/ \
  --dataset-impl fasta \
  --task masked_lm \
  --criterion masked_lm \
  --arch roberta_base \
  --max-tokens 8096 \
  --update-freq 4 \
  --lr 1e-4 \
  --optimizer adam \
  --lr-scheduler inverse_sqrt \
  --warmup-updates 16000 \
  --max-positions 1024 \
  --skip-invalid-size-inputs-valid-test \
  --validate-interval-updates 5000 \
  --save-interval-updates 5000

@hydra.main(strict) flag is deprecated and will removed in the next version.
See https://hydra.cc/docs/next/upgrades/0.11_to_1.0/strict_mode_flag_deprecated
2020-11-03 20:13:55 | INFO | fairseq_cli.train | {'common': {'no_progress_bar': False, 'log_interval': 100, 'log_format': None, 'tensorboard_logdir': None, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'fp16': True, 'memory_efficient_fp16': False, 'memory_efficient_bf16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 4, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'user_dir': None, 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False}, 'distributed_training': {'distributed_rank': 0, 'distributed_backend': 'nccl', 'distributed_init_method': None, 'distributed_port': -1, 'device_id': 0, 'local_rank': 0, 'distributed_no_spawn': False, 'ddp_backend': 'no_c10d', 'bucket_cap_