# training BALM-unpaired

#### architecture
BALM-unpaired is built on the RoBERTa-large architecture, with the following hyperparameter modifications:
* max input length of **256**, which was selected to be suitable for unpaired (single heavy or light chain) sequences and to be 0.5x the max input length of BALM-paired
* per-GPU batch size of **64**, which is 2x the batch size of BALM-paired to equalize training steps between the two models
* **500k** training steps, which should be roughly 100 epochs over the training set when trained on eight GPUs for a total batch size of **512**
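
The relationship between steps, batch size, and epochs can be checked with a few lines of arithmetic. This is a minimal sketch that uses the per-GPU batch size, GPU count, and step count from the list above and the training set size from the dataset section; it assumes no gradient accumulation (the config does not set `gradient_accumulation_steps`, so the trainer falls back to 1). The variable names are illustrative only.

```python
# values taken from this notebook
per_gpu_batch_size = 64
num_gpus = 8
max_steps = 500_000
train_sequences = 2_404_538

# assumption: gradient_accumulation_steps == 1
total_batch_size = per_gpu_batch_size * num_gpus
sequences_seen = max_steps * total_batch_size
epochs = sequences_seen / train_sequences

print(f"total batch size: {total_batch_size}")  # 512
print(f"approximate epochs: {epochs:.1f}")      # 106.5
```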

#### dataset
We used all unique, productive paired sequences reported in [_Functional antibodies exhibit light chain coherence_](https://www.nature.com/articles/s41586-022-05371-z) (Jaffe et al., Nature 2022).
* [dataset DOI](https://plus.figshare.com/articles/dataset/Dataset_supporting_Functional_antibodies_exhibit_light_chain_coherence_/20338177) (figshare)
* the dataset was split into train/eval/test subsets at a ratio of 90:5:5, which produced the following dataset sizes:
    * **train**: `2,404,538` unpaired sequences
    * **eval**: `133,582` unpaired sequences
    * **test**: `133,582` unpaired sequences  

Each input file contains a single antibody heavy or light chain amino acid sequence per line.
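
If you are preparing your own input files in this one-sequence-per-line format, a quick sanity check can catch formatting problems before tokenization. The sketch below is not part of the original pipeline: the function name `find_problem_lines` is hypothetical, the example sequences are short made-up fragments, and it assumes that sequences contain only the 20 standard amino acids and that two of the 256 input positions are taken by special tokens.

```python
# assumption: only the 20 standard amino acids appear in the data
VALID_RESIDUES = set("ACDEFGHIKLMNPQRSTVWY")
# assumption: max_len of 256 minus two special tokens
MAX_RESIDUES = 254


def find_problem_lines(lines):
    """Return (line_number, reason) for each line that looks malformed."""
    problems = []
    for line_number, line in enumerate(lines, start=1):
        sequence = line.strip()
        if not sequence:
            problems.append((line_number, "empty line"))
        elif set(sequence) - VALID_RESIDUES:
            problems.append((line_number, "unexpected character"))
        elif len(sequence) > MAX_RESIDUES:
            problems.append((line_number, "will be truncated"))
    return problems


# hypothetical example lines, not taken from the dataset
example_lines = [
    "EVQLVESGGGLVQPGGSLRLSCAAS\n",
    "DIQMTQSPSSLSASVGDRVTITC\n",
    "EVQL VESGG\n",
    "\n",
]
print(find_problem_lines(example_lines))
# [(3, 'unexpected character'), (4, 'empty line')]
```

Because the function accepts any iterable of lines, an open file handle can be passed to it directly.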

#### training
Training BALM-unpaired on eight NVIDIA A100 GPUs took approximately 5 days.

<br>  
  
## setup  

Training BALM-unpaired requires several [huggingface](https://huggingface.co/) libraries. If they're not already installed, you can install them by uncommenting and running the following code block:

In [None]:
# !pip install transformers
# !pip install datasets
# !pip install accelerate

In [None]:
from datetime import date
import os

from transformers import (
    RobertaConfig,
    RobertaTokenizer,
    RobertaForMaskedLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)

from datasets import load_dataset

## BALM config

In [None]:
run_name = f"BALM-unpaired_lc-coherence-data_90-5-5-split_{date.today().isoformat()}"
print(f"Run name: {run_name}")

balm_config = {
    "run_name": run_name,
    
    # model architecture
    "num_hidden_layers": 24,
    "num_attention_heads": 16,
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "vocab_size": 25,
    "max_len": 256,
    "max_position_embeddings": 258,
    
    # tokenizer
    "padding": "max_length",
    "truncation": True,  # key must match the "truncation" lookup in the tokenization step
    "return_special_tokens_mask": True,
    
    # training parameters
    "batch_size": 64,
    "max_steps": 500000,
    "warmup_steps": 30000,
    "weight_decay": 0.01,
    "peak_learning_rate": 4e-4,
    "adam_epsilon": 1e-6,
    "adam_beta1": 0.9,
    "adam_beta2": 0.98,
    "type_vocab_size": 1,  # this should be 2 for paired/mixed models, 1 for unpaired models
    "fp16": True,
    "evaluation_strategy": "steps",
    "seed": 42,
    
    # outputs and logging
    "save_steps": 100000,
    "eval_steps": 25000,
    "output_dir": f"./checkpoints/{run_name}",  # where the checkpoint data will be written
    "logging_dir": f"./logs/{run_name}",
    "logging_steps": 100,
    "overwrite_output_dir": True,
    "logging_first_step": True,
}

<br>  
  
If you'd like to use [weights and biases](https://wandb.ai) for logging, uncomment and run the following code block:

In [None]:
# os.environ["WANDB_PROJECT"] = run_name
# balm_config["report_to"] = "wandb"

# import wandb
# wandb.login()

## model
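
The config sets `max_position_embeddings` to 258 even though `max_len` is 256. This follows from how RoBERTa-style models in `transformers` number positions: padding tokens get the position `padding_idx`, and real tokens are numbered starting at `padding_idx + 1`. The pure-Python sketch below mimics that numbering to show where the extra two slots come from. It assumes `padding_idx` is 1 (the `RobertaConfig` default for `pad_token_id`, which this notebook does not override); the function name and the token IDs are illustrative.

```python
PAD_TOKEN_ID = 1  # assumption: RobertaConfig default pad_token_id


def roberta_position_ids(input_ids, padding_idx=PAD_TOKEN_ID):
    """Number non-padding tokens from padding_idx + 1; padding keeps padding_idx."""
    position_ids = []
    tokens_seen = 0
    for token_id in input_ids:
        if token_id == padding_idx:
            position_ids.append(padding_idx)
        else:
            tokens_seen += 1
            position_ids.append(tokens_seen + padding_idx)
    return position_ids


# hypothetical input with 256 non-padding tokens (the maximum input length)
full_length_input = [0] + [5] * 254 + [2]
largest_position = max(roberta_position_ids(full_length_input))
print(largest_position)      # 257
print(largest_position + 1)  # 258 rows needed in the position embedding table
```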

In [None]:
# initialize the model using the BALM config dictionary
# defaults are based on what was used in the paper
model_config = RobertaConfig(
    vocab_size=balm_config.get("vocab_size", 25),
    hidden_size=balm_config.get("hidden_size", 1024),
    intermediate_size=balm_config.get("intermediate_size", 4096),
    max_position_embeddings=balm_config.get("max_position_embeddings", 258),  # max_len + 2
    num_hidden_layers=balm_config.get("num_hidden_layers", 24),
    num_attention_heads=balm_config.get("num_attention_heads", 16),
    type_vocab_size=balm_config.get("type_vocab_size", 1),
)
    
model = RobertaForMaskedLM(model_config)

In [None]:
model_size = sum(p.numel() for p in model.parameters())
print(f"Model size: {model_size/1e6:.2f}M")

## load data

In [None]:
%%bash
# download the train/eval/test data if it doesn't exist
if [ ! -d "./data/train-test-eval_unpaired" ]; then
    curl -o 'train-test-eval_unpaired.tar.gz' -L 'https://zenodo.org/record/8253367/files/train-test-eval_unpaired.tar.gz?download=1'
    mkdir -p ./data  # tar -C fails if the target directory does not exist
    tar xzvf 'train-test-eval_unpaired.tar.gz' -C ./data
    rm 'train-test-eval_unpaired.tar.gz'
fi

In [None]:
# load the train, eval, and test data
data_files = {
    "train": ['./data/train-test-eval_unpaired/train.txt'],
    "eval": ['./data/train-test-eval_unpaired/eval.txt'],
    "test": ['./data/train-test-eval_unpaired/test.txt']
}

dataset = load_dataset("text", data_files=data_files)

## data tokenization

In [None]:
tokenizer = RobertaTokenizer.from_pretrained(
    "tokenizer"
)

In [None]:
tokenized_dataset = dataset.map(
    lambda x: tokenizer(
        x["text"],
        padding=balm_config.get("padding", "max_length"),
        truncation=balm_config.get("truncation", True),
        max_length=balm_config.get("max_len", 256),
        return_special_tokens_mask=balm_config.get("return_special_tokens_mask", True),
    ),
    remove_columns=["text"],
)

## data collator
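
With `mlm=True` and `mlm_probability=0.15`, the collator selects roughly 15% of the non-special tokens in each batch as prediction targets. By default in `transformers`, 80% of the selected tokens are replaced with the mask token, 10% are replaced with a random token, and 10% are left unchanged; all other positions get a label of `-100` so they are ignored by the loss. The following is a simplified pure-Python sketch of that procedure, not the library's implementation. The function name `mask_tokens`, the mask token ID of 4, and the example token IDs are assumptions made for illustration.

```python
import random


def mask_tokens(input_ids, special_tokens_mask, mask_id, vocab_size,
                mlm_probability=0.15, rng=None):
    """Return (masked_ids, labels) following the 80/10/10 masking scheme."""
    rng = rng or random.Random()
    masked_ids = list(input_ids)
    labels = [-100] * len(input_ids)  # -100 means "ignored by the loss"
    for i, (token_id, is_special) in enumerate(zip(input_ids, special_tokens_mask)):
        if is_special or rng.random() >= mlm_probability:
            continue
        labels[i] = token_id  # the model must predict the original token
        roll = rng.random()
        if roll < 0.8:
            masked_ids[i] = mask_id  # 80%: replace with the mask token
        elif roll < 0.9:
            masked_ids[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return masked_ids, labels


# hypothetical token IDs: 0 and 2 stand in for the start/end special tokens
example_ids = [0] + [5, 6, 7, 8, 9] * 20 + [2]
example_special = [1] + [0] * 100 + [1]
masked, labels = mask_tokens(example_ids, example_special, mask_id=4,
                             vocab_size=25, rng=random.Random(0))
print(sum(label != -100 for label in labels), "of 100 tokens selected")
```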

In [None]:
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

## trainer
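
The config specifies a peak learning rate and a number of warmup steps but no scheduler type. Assuming the `TrainingArguments` default of a linear scheduler, the learning rate ramps linearly from 0 to the peak over the warmup steps and then decays linearly to 0 at `max_steps`. The sketch below reproduces that shape in plain Python so you can see what learning rate to expect at a given step; the function name is hypothetical and the defaults mirror the values in `balm_config`.

```python
def linear_schedule_lr(step, peak_lr=4e-4, warmup_steps=30_000, max_steps=500_000):
    """Learning rate at `step` for linear warmup followed by linear decay."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))
    return peak_lr * remaining


for step in (0, 15_000, 30_000, 265_000, 500_000):
    print(f"step {step:>7,}: lr = {linear_schedule_lr(step):.2e}")
```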

In [None]:
training_args = TrainingArguments(
    fp16=balm_config.get("fp16", True),
    evaluation_strategy=balm_config.get("evaluation_strategy", "steps"),
    seed=balm_config.get("seed", 42),
    per_device_train_batch_size=balm_config.get("batch_size", 64),
    per_device_eval_batch_size=balm_config.get("batch_size", 64),
    max_steps=balm_config.get("max_steps", 500000),
    save_steps=balm_config.get("save_steps", 100000),
    logging_steps=balm_config.get("logging_steps", 100),
    eval_steps=balm_config.get("eval_steps", 25000),
    adam_beta1=balm_config.get("adam_beta1", 0.9),
    adam_beta2=balm_config.get("adam_beta2", 0.98),
    adam_epsilon=balm_config.get("adam_epsilon", 1e-6),
    weight_decay=balm_config.get("weight_decay", 0.01),
    warmup_steps=balm_config.get("warmup_steps", 30000),
    learning_rate=balm_config.get("peak_learning_rate", 4e-4),
    gradient_accumulation_steps=balm_config.get("gradient_accumulation_steps", 1),
    
    # output and logging
    run_name=balm_config.get("run_name", None),
    output_dir=balm_config.get("output_dir", f"./checkpoints/{run_name}"),
    overwrite_output_dir=balm_config.get("overwrite_output_dir", True),
    logging_dir=balm_config.get("logging_dir", f"./logs/{run_name}"),
    report_to=balm_config.get("report_to", None),
    logging_first_step=balm_config.get("logging_first_step", True),
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"]
)

In [None]:
trainer.train()

In [None]:
trainer.save_model(f"../models/{run_name}")

In [None]:
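# close the weights and biases run (only needed if wandb logging was enabled above)
if balm_config.get("report_to") == "wandb":
    import wandb
    wandb.finish()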