# Fine-tuning ESM-2 with natively paired Ab sequences

#### architecture
[ESM-2](https://www.science.org/doi/10.1126/science.ade2574) is a state-of-the-art, general-purpose protein language model (LM) that uses a modified BERT architecture. Aside from model size, the primary modification is the use of rotary position encoding ([RoPE](https://arxiv.org/abs/2104.09864)) rather than absolute position embeddings. Due to compute constraints, we fine-tuned the 650M-parameter variant of ESM-2 (rather than the larger 3B- or 15B-parameter variants), updating all of the model's weights (full fine-tuning).
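
To make the RoPE idea concrete, here is a minimal pure-Python sketch of the interleaved formulation described in the RoFormer paper: each pair of dimensions is rotated by an angle proportional to the token's position, so the dot product between a rotated query and a rotated key depends only on their relative offset. This is an illustration, not ESM-2's implementation. The `rope` and `dot` helpers and the `base=10000` frequency constant are assumptions here, and real implementations may pair dimensions differently (for example, by splitting the vector into halves instead of interleaving).

```python
import math


def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[i], x[i+1]) of an even-length vector by pos * base**(-i/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("vector length must be even")
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)  # lower dimensions rotate faster
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + 1]
        out.extend([a * c - b * s, a * s + b * c])  # 2-D rotation of the pair
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]

# The same relative offset (2) at different absolute positions gives the same score.
print(math.isclose(dot(rope(q, 5), rope(k, 3)), dot(rope(q, 12), rope(k, 10)), abs_tol=1e-12))  # True
```

Because rotations preserve vector norms, RoPE changes only the angle between queries and keys, which is what lets attention scores encode relative rather than absolute position.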

#### dataset
We used all unique, productive paired sequences reported in [_Functional antibodies exhibit light chain coherence_](https://www.nature.com/articles/s41586-022-05371-z) (Jaffe et al., Nature 2022):
* [dataset DOI](https://plus.figshare.com/articles/dataset/Dataset_supporting_Functional_antibodies_exhibit_light_chain_coherence_/20338177) (figshare)
* the dataset was split into train/eval/test subsets at a 90:5:5 ratio, which produced the following dataset sizes:
    * **train**: `1,202,269` paired sequences
    * **eval**: `66,791` paired sequences
    * **test**: `66,971` paired sequences  
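
The split procedure itself isn't described here. As an illustration only, a reproducible 90:5:5 split of deduplicated sequences could look like the sketch below. The `split_90_5_5` helper, the seeded shuffle, and the rounding rule are all assumptions, so the sizes it produces will not necessarily match the counts listed above.

```python
import random


def split_90_5_5(records, seed=42):
    """Deduplicate, shuffle, and split records into train/eval/test at 90:5:5 (hypothetical helper)."""
    unique = list(dict.fromkeys(records))  # drop duplicates, keep first-seen order
    random.Random(seed).shuffle(unique)  # seeded shuffle -> reproducible split
    n_eval = n_test = round(len(unique) * 0.05)
    n_train = len(unique) - n_eval - n_test
    train = unique[:n_train]
    eval_ = unique[n_train:n_train + n_eval]
    test = unique[n_train + n_eval:]
    return train, eval_, test


train, eval_, test = split_90_5_5([f"seq{i}" for i in range(1000)] + ["seq0"])
print(len(train), len(eval_), len(test))  # 900 50 50
```

Deduplicating before shuffling guarantees that no sequence can land in more than one subset.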
    
Each line of the input files contains a single paired antibody amino acid sequence, with the heavy and light chain sequences concatenated. For training, the two chains are separated by two `<cls>` tokens, like so:  
  > `HEAVY_CHAIN_AA_SEQUENCE<cls><cls>LIGHT_CHAIN_AA_SEQUENCE`
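
A minimal sketch of building a line in this format. The load + tokenize step of this notebook converts lines that use `</s>` between the chains, and `reformat_line` mirrors that `str.replace` call. The helper names and the short example sequence fragments are hypothetical.

```python
SEPARATOR = "<cls><cls>"


def format_pair(heavy: str, light: str) -> str:
    """Join heavy- and light-chain amino acid sequences into one training line."""
    return f"{heavy}{SEPARATOR}{light}"


def reformat_line(line: str, old_separator: str = "</s>") -> str:
    """Convert a line that uses `old_separator` between the chains to the <cls><cls> format."""
    return line.strip().replace(old_separator, SEPARATOR)


print(format_pair("EVQLVESG", "DIQMTQSP"))  # EVQLVESG<cls><cls>DIQMTQSP
```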

#### training
Full fine-tuning of ESM-2 on eight NVIDIA A100 GPUs took approximately 10 days.

<br>  
  
## setup  

Fine-tuning ESM-2 requires several [Hugging Face](https://huggingface.co/) libraries. If they're not already installed, you can install them by uncommenting and running the following code block:

In [None]:
# !pip install transformers
# !pip install datasets
# !pip install accelerate
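
To check which of these libraries are missing before installing anything, here is a small standard-library sketch. The `missing_packages` helper is hypothetical. It only checks whether each top-level package can be found, not which version is installed.

```python
import importlib.util

REQUIRED = ["transformers", "datasets", "accelerate"]


def missing_packages(names):
    """Return the package names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print(missing_packages(REQUIRED))  # [] when everything is installed
```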

In [None]:
from datetime import date
import warnings
warnings.simplefilter('ignore')

from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)

from datasets import load_dataset

## ESM-2 config

In [None]:
run_name = f"ESM-2_fine-tuning_{date.today().isoformat()}"
print(f"Run name: {run_name}")

esm_config = {
    "run_name": run_name,
    
    # training parameters
    "batch_size": 32,
    "max_steps": 150000,
    "warmup_steps": 30000,
    "save_steps": 50000,
    "logging_steps": 100,
    "eval_steps": 25000,
    "weight_decay": 0.01,
    "peak_learning_rate": 4e-4,
    "adam_epsilon": 1e-6,
    "adam_beta1": 0.9,
    "adam_beta2": 0.98,
    
    # outputs and logging
    "output_dir": f"./checkpoints/{run_name}",  # where the checkpoint data will be written
    "report_to": "none",  # set to "wandb" to enable logging to W&B
    "logging_dir": f"./logs/{run_name}",
}
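
A quick sanity check on what this config implies. The sketch below assumes that the eight A100s mentioned above run data-parallel, so the per-device `batch_size` is multiplied by the number of GPUs. It also uses the training-set size from the dataset section. The `training_budget` helper is hypothetical.

```python
def training_budget(batch_size, n_gpus, grad_accum, max_steps, warmup_steps, n_train):
    """Back-of-the-envelope training budget for data-parallel training."""
    effective_batch = batch_size * n_gpus * grad_accum  # sequences per optimizer step
    sequences_seen = effective_batch * max_steps
    epochs = sequences_seen / n_train  # passes over the training set
    warmup_fraction = warmup_steps / max_steps
    return effective_batch, sequences_seen, epochs, warmup_fraction


effective_batch, sequences_seen, epochs, warmup_fraction = training_budget(
    batch_size=32, n_gpus=8, grad_accum=1,
    max_steps=150_000, warmup_steps=30_000, n_train=1_202_269,
)
print(effective_batch)  # 256
print(sequences_seen)   # 38400000
print(f"{epochs:.1f}")  # 31.9
print(warmup_fraction)  # 0.2
```

Under these assumptions, the 150,000 steps amount to roughly 32 passes over the ~1.2M training pairs, with the first 20% of steps spent in warmup.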

<br>  
  
If you'd like to use [Weights & Biases](https://wandb.ai) for logging, uncomment and run the following code block:

In [None]:
# import os
# os.environ["WANDB_PROJECT"] = run_name
# esm_config["report_to"] = "wandb"

# import wandb
# wandb.login()

## model

In [None]:
model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D")

In [None]:
model_size = sum(p.numel() for p in model.parameters())
print(f"Model size: {model_size/1e6:.2f}M")
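
The parameter count also gives a rough sense of why full fine-tuning is memory-hungry. A common rule of thumb for PyTorch-style mixed-precision training with Adam(W), where the master weights stay in fp32, is about 16 bytes per parameter: fp32 weights (4), fp32 gradients (4), and two fp32 Adam moment estimates (8). This excludes activations. The sketch below applies that assumption. It is an estimate, not a measurement from this run.

```python
def training_state_bytes(n_params, bytes_per_param=16):
    """Rough memory for weights + gradients + Adam state, excluding activations."""
    return n_params * bytes_per_param


gib = training_state_bytes(650_000_000) / 2**30
print(f"{gib:.1f} GiB")  # 9.7 GiB
```

Passing the `model_size` value printed above instead of the nominal 650M gives a slightly more precise figure.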

## load + tokenize data

In [None]:
%%bash
# download and extract the train/eval/test data if it doesn't already exist
if [ ! -d "./data/train-test-eval_paired" ]; then
    mkdir -p ./data  # tar -C fails if the target directory is missing
    curl -o 'train-test-eval_paired.tar.gz' -L 'https://zenodo.org/record/8253367/files/train-test-eval_paired.tar.gz?download=1'
    tar xzvf 'train-test-eval_paired.tar.gz' -C ./data
    rm 'train-test-eval_paired.tar.gz'
fi

In [None]:
data_files = {
    "train": ['./data/train-test-eval_paired/train.txt'],
    "eval": ['./data/train-test-eval_paired/eval.txt'],
    "test": ['./data/train-test-eval_paired/test.txt']
}

dataset = load_dataset("text", data_files=data_files)

# reformat dataset so that the heavy (HC) and light (LC) chains are separated by <cls><cls> instead of </s>
dataset = dataset.map(lambda x: {"text": x["text"].replace("</s>", "<cls><cls>")})
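
To confirm that the downloaded files match the split sizes listed in the dataset section, you can count the lines in each file. The `count_sequences` and `count_file` helpers below are hypothetical additions. They assume one paired sequence per non-empty line.

```python
from pathlib import Path


def count_sequences(lines):
    """Count non-empty lines; each is assumed to hold one paired sequence."""
    return sum(1 for line in lines if line.strip())


def count_file(path):
    """Count the paired sequences in a single text file."""
    with Path(path).open() as handle:
        return count_sequences(handle)
```

For example, `{split: sum(count_file(p) for p in paths) for split, paths in data_files.items()}` gives per-split counts. You can compare those counts against the sizes in the dataset section and against `dataset[split].num_rows`.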

In [None]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

In [None]:
tokenized_dataset = dataset.map(
    lambda x: tokenizer(
        x["text"],
        padding="max_length",
        truncation=True,
        max_length=320,
        return_special_tokens_mask=True,
    ),
    remove_columns=["text"],
)
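
Because the ESM-2 tokenizer emits one token per amino acid, you can estimate whether a pair will be truncated at `max_length=320` before tokenizing it. The sketch below makes three assumptions: one token per residue, two tokens for the `<cls><cls>` separator, and the `<cls>` and `<eos>` tokens that the tokenizer adds at the start and end of each sequence. The helper names are hypothetical.

```python
def esm_token_length(heavy: str, light: str, n_separator_tokens: int = 2) -> int:
    """Estimated token count for HEAVY<cls><cls>LIGHT, including the leading <cls> and trailing <eos>."""
    return 1 + len(heavy) + n_separator_tokens + len(light) + 1


def is_truncated(heavy: str, light: str, max_length: int = 320) -> bool:
    """True if the estimated length exceeds max_length, i.e. truncation=True would cut the pair."""
    return esm_token_length(heavy, light) > max_length
```

To verify the estimate against the real tokenizer, compare it with `len(tokenizer(text)["input_ids"])` for a few lines of the dataset.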

## data collator

In [None]:
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
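
For intuition about what `DataCollatorForLanguageModeling` does with `mlm_probability=0.15`, here is a simplified pure-Python sketch of BERT-style masking. Each token not flagged in `special_tokens_mask` is selected with probability 0.15. Of the selected tokens, about 80% are replaced with the mask token, 10% with a random token, and 10% are left unchanged. Only the selected positions keep a label; all others get `-100`, so the loss ignores them. This mirrors the collator's documented behavior but is not its code. The `mask_tokens` helper and the example token ids (`<mask>` = 32 in a 33-token vocabulary, chosen to resemble ESM-2's) are assumptions.

```python
import random


def mask_tokens(input_ids, special_tokens_mask, mask_token_id, vocab_size,
                mlm_probability=0.15, seed=None):
    """BERT-style 80/10/10 masking (simplified sketch). Returns (inputs, labels)."""
    rng = random.Random(seed)
    inputs = list(input_ids)
    labels = [-100] * len(inputs)  # -100 = ignored by the cross-entropy loss
    for i, (token, is_special) in enumerate(zip(input_ids, special_tokens_mask)):
        if is_special or rng.random() >= mlm_probability:
            continue  # not selected for prediction
        labels[i] = token  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_token_id  # ~80%: replace with <mask>
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # ~10%: random token
        # remaining ~10%: keep the original token
    return inputs, labels


ids = [0] + [4 + i % 25 for i in range(50)] + [2]  # <cls> ... residues ... <eos> (assumed ids)
special = [1] + [0] * 50 + [1]
inputs, labels = mask_tokens(ids, special, mask_token_id=32, vocab_size=33, seed=0)
```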

## trainer

In [None]:
training_args = TrainingArguments(
    fp16=True,
    eval_strategy="steps",  # older transformers releases name this argument evaluation_strategy
    seed=42,
    per_device_train_batch_size=esm_config.get("batch_size", 32),
    per_device_eval_batch_size=esm_config.get("batch_size", 32),
    max_steps=esm_config.get("max_steps", 150000),
    save_steps=esm_config.get("save_steps", 50000),
    logging_steps=esm_config.get("logging_steps", 100),
    eval_steps=esm_config.get("eval_steps", 25000),
    adam_beta1=esm_config.get("adam_beta1", 0.9),
    adam_beta2=esm_config.get("adam_beta2", 0.98),
    adam_epsilon=esm_config.get("adam_epsilon", 1e-6),
    weight_decay=esm_config.get("weight_decay", 0.01),
    warmup_steps=esm_config.get("warmup_steps", 30000),
    learning_rate=esm_config.get("peak_learning_rate", 4e-4),
    gradient_accumulation_steps=esm_config.get("gradient_accumulation_steps", 1),
    
    # output and logging
    output_dir=esm_config.get("output_dir", f"./checkpoints/{run_name}"),
    overwrite_output_dir=True,
    logging_dir=esm_config.get("logging_dir", f"./logs/{run_name}"),
    report_to=esm_config.get("report_to", "none"),  # "none" disables all logging integrations
    run_name=run_name,  # name of the W&B run
    logging_first_step=True,
)
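
These `TrainingArguments` don't set `lr_scheduler_type`, so the Trainer's default linear schedule applies. The learning rate ramps linearly from 0 to `peak_learning_rate` over `warmup_steps`, then decays linearly to 0 at `max_steps`, which is the same shape as `transformers.get_linear_schedule_with_warmup`. Below is a pure-Python sketch of that curve for plotting or sanity checks. The `linear_warmup_decay` helper is hypothetical, and its defaults are copied from `esm_config`.

```python
def linear_warmup_decay(step, peak_lr=4e-4, warmup_steps=30_000, max_steps=150_000):
    """Learning rate at `step` under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = (max_steps - step) / max(1, max_steps - warmup_steps)
    return peak_lr * max(0.0, remaining)


for step in (0, 15_000, 30_000, 90_000, 150_000):
    print(step, linear_warmup_decay(step))
```

The peak is reached at step 30,000. The rate then falls back to half of the peak at step 90,000, midway through the decay phase.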

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"]
)

## train

In [None]:
trainer.train()

In [None]:
trainer.save_model(f"./models/{run_name}")

In [None]:
# wandb.finish()