# Fine-tuning ESM2 to Predict Mutant Effects on Binding

### Resources
- Deep mutational scanning (DMS) data for the SARS-CoV-2 receptor-binding domain (RBD) from this [paper](https://www.cell.com/cell/fulltext/S0092-8674(20)31003-5?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS0092867420310035%3Fshowall%3Dtrue#author-abstract), downloaded from this [repo](https://github.com/jbloomlab/SARS-CoV-2-RBD_DMS/tree/master)

As a baseline, I'll first replicate the standard fine-tuning procedure: ESM2 with a single-output regression head, trained to predict each mutant's effect on binding.

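```python
import torch

# Trainer places the model on the available device by itself,
# so `device` is informational here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```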

```python
# ESM2 with 12 transformer layers and ~35M parameters
model_checkpoint = 'facebook/esm2_t12_35M_UR50D'
```
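The checkpoint name encodes the model size: `t12` is the number of transformer layers and `35M` the parameter count. The hypothetical helper below decodes that pattern, assuming ESM2 checkpoints follow `esm2_t<layers>_<params>_UR50D` (as, for example, `facebook/esm2_t33_650M_UR50D` does). Trying a larger model is then a one-line change to `model_checkpoint`.

```python
import re
from typing import Tuple

# Assumed naming pattern of the public ESM2 checkpoints on the Hugging Face Hub.
ESM2_NAME = re.compile(r"esm2_t(\d+)_(\d+[MB])_UR50D$")

def parse_esm2_checkpoint(checkpoint: str) -> Tuple[int, str]:
    """Hypothetical helper: return (number of layers, parameter-count label)."""
    name = checkpoint.split("/")[-1]  # drop the "facebook/" organisation prefix
    match = ESM2_NAME.match(name)
    if match is None:
        raise ValueError(f"not an ESM2 checkpoint name: {checkpoint!r}")
    return int(match.group(1)), match.group(2)

print(parse_esm2_checkpoint("facebook/esm2_t12_35M_UR50D"))  # (12, '35M')
```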

```python
import pandas as pd

# Columns used below: `sequences` (protein sequence) and `bind_score` (regression target)
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')
```
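Before tokenizing, it can be worth checking that both CSVs have the expected `sequences` and `bind_score` columns, that sequences use only the 20 standard amino acids (a `*` would appear if nonsense variants were kept), and that every label is numeric. Below is a stdlib sketch of such a check; the helper `validate_rows` and the toy rows are hypothetical, and the same checks could be written with pandas on `train` and `test` directly.

```python
import csv
import io
import math

STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")

def validate_rows(handle, seq_col="sequences", label_col="bind_score"):
    """Return (number of valid rows, list of (line number, problem))."""
    reader = csv.DictReader(handle)
    missing = {seq_col, label_col} - set(reader.fieldnames or [])
    if missing:
        raise KeyError(f"missing columns: {sorted(missing)}")
    n_valid, problems = 0, []
    for line_no, row in enumerate(reader, start=2):  # line 1 is the header
        bad = set(row[seq_col].strip().upper()) - STANDARD_AA
        try:
            label = float(row[label_col])
        except ValueError:  # empty or non-numeric cell
            label = math.nan
        if bad:
            problems.append((line_no, f"non-standard residues {sorted(bad)}"))
        elif math.isnan(label):
            problems.append((line_no, "missing or non-numeric label"))
        else:
            n_valid += 1
    return n_valid, problems

# Hypothetical toy rows, not taken from the real train.csv.
toy = io.StringIO("sequences,bind_score\nMKTAYIAK,-0.5\nMKTAY*AK,0.1\nMKTAYIAK,\n")
print(validate_rows(toy))  # 1 valid row; problems flagged on lines 3 and 4
```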

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
```

```python
# No padding here: batches are padded on the fly by the Trainer's default collator
train_tokenized = tokenizer(train.sequences.to_list())
test_tokenized = tokenizer(test.sequences.to_list())
```
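The ESM tokenizer works at the residue level: each amino acid becomes one token, and the sequence is wrapped between a `<cls>` and an `<eos>` token, so every encoding is two tokens longer than its sequence. The sketch below reproduces that layout with a toy vocabulary; the token IDs are made up and do not match ESM2's real vocabulary.

```python
# Toy vocabulary: special tokens first, then the 20 standard residues.
# These IDs are illustrative only, NOT the real ESM2 token IDs.
SPECIAL = ["<cls>", "<pad>", "<eos>", "<unk>"]
RESIDUES = list("ACDEFGHIKLMNPQRSTVWY")
VOCAB = {tok: i for i, tok in enumerate(SPECIAL + RESIDUES)}

def toy_tokenize(seqs):
    """Mimic the output layout of tokenizer(list_of_sequences) without padding."""
    input_ids, attention_mask = [], []
    for seq in seqs:
        ids = [VOCAB["<cls>"]]
        ids += [VOCAB.get(aa, VOCAB["<unk>"]) for aa in seq]  # one token per residue
        ids.append(VOCAB["<eos>"])
        input_ids.append(ids)
        attention_mask.append([1] * len(ids))  # every real token is attended to
    return {"input_ids": input_ids, "attention_mask": attention_mask}

enc = toy_tokenize(["MKT", "MKTAY"])
print([len(ids) for ids in enc["input_ids"]])  # [5, 7]
```

Because no padding is requested, the two encodings have different lengths; padding to a common length happens later, per batch.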

```python
from datasets import Dataset

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)
```

```python
# Trainer reads regression targets from a column named `labels`
train_dataset = train_dataset.add_column('labels', train.bind_score.to_list())
test_dataset = test_dataset.add_column('labels', test.bind_score.to_list())
```

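```python
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# num_labels=1 gives a single-output head; the model then treats the task
# as regression and trains with mean-squared-error loss.
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=1)
```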


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
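The warning lists the weights of the new regression head (`classifier.dense` and `classifier.out_proj`), which start from random values. As far as I understand the Hugging Face ESM head, it takes the hidden state of the first (`<cls>`) token and applies dense, tanh, then `out_proj` down to one output, with dropout in between. The pure-Python sketch below shows that computation plus the MSE loss on toy 2-dimensional vectors; the weights are made up, dropout is omitted, and the real hidden size is much larger.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # rows = output features, columns = input features

def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """y = W x + b, written out with plain lists."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def regression_head(cls_hidden: Vector, w_dense: Matrix, b_dense: Vector,
                    w_out: Matrix, b_out: Vector) -> float:
    """Sketch of dense -> tanh -> out_proj on the <cls> hidden state (no dropout)."""
    hidden = [math.tanh(v) for v in linear(cls_hidden, w_dense, b_dense)]
    return linear(hidden, w_out, b_out)[0]  # num_labels=1 -> a single score

def mse_loss(predictions: Vector, labels: Vector) -> float:
    """Mean squared error, the regression loss used when num_labels=1."""
    return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(predictions)

# Made-up weights: identity dense layer, out_proj = [2, 0], bias 0.1.
score = regression_head([0.5, 0.0], [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0], [[2.0, 0.0]], [0.1])
print(score)  # 2 * tanh(0.5) + 0.1
print(mse_loss([1.0, 2.0], [1.0, 4.0]))  # 2.0
```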


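```python
model_name = model_checkpoint.split('/')[-1]
batch_size = 8

args = TrainingArguments(
    f'{model_name}-finetuned-regression',  # output_dir for checkpoints
    eval_strategy='epoch',                 # evaluate after every epoch (`evaluation_strategy` in older releases)
    save_strategy='epoch',                 # must match eval_strategy for load_best_model_at_end
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=20,
    weight_decay=0.01,
    seed=69,
    load_best_model_at_end=True,
    metric_for_best_model='spearmanr',     # looked up as 'eval_spearmanr' in the eval metrics
)
```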
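With `load_best_model_at_end=True`, the weights kept at the end come from the checkpoint with the best `metric_for_best_model`. The hypothetical `select_best` below roughly mirrors how `Trainer` resolves that setting: a bare name such as `'spearmanr'` gets an `eval_` prefix, and higher is treated as better unless the name ends in `loss`. The history records are made-up numbers, chosen to show that selecting on Spearman and selecting on validation loss can pick different epochs.

```python
def select_best(history, metric_for_best_model, greater_is_better=None):
    """Pick the eval record that Trainer would treat as best (simplified sketch)."""
    key = metric_for_best_model
    if not key.startswith("eval_"):
        key = f"eval_{key}"  # eval metrics are reported with an eval_ prefix
    if greater_is_better is None:
        # Default: higher is better unless the metric is a loss
        greater_is_better = not metric_for_best_model.endswith("loss")
    pick = max if greater_is_better else min
    return pick(history, key=lambda record: record[key])

history = [  # hypothetical per-epoch eval records, not the real training log
    {"epoch": 1, "eval_loss": 1.9, "eval_spearmanr": 0.30},
    {"epoch": 2, "eval_loss": 2.1, "eval_spearmanr": 0.50},
    {"epoch": 3, "eval_loss": 1.2, "eval_spearmanr": 0.45},
]
print(select_best(history, "spearmanr")["epoch"])  # 2
print(select_best(history, "eval_loss")["epoch"])  # 3
```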

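```python
import numpy as np
from evaluate import load

metric = load('spearmanr')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Regression outputs have shape (N, 1); flatten to (N,) for the metric
    predictions = np.asarray(predictions).reshape(-1)
    return metric.compute(predictions=predictions, references=labels)
```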
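The `spearmanr` metric from `evaluate` wraps SciPy's Spearman correlation: the Pearson correlation of the two rank vectors, with tied values given their average rank. A dependency-free sketch of the same computation follows, including the flattening of `(N, 1)` regression outputs; it is for illustration, not a replacement for the library metric.

```python
import math

def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1  # extend the run of ties
        mean_rank = (i + j) / 2 + 1  # positions i..j are 0-based
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks

def pearson(x, y):
    """Pearson correlation; raises ZeroDivisionError if either input is constant."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def spearman(predictions, labels):
    # Flatten [[p1], [p2], ...] (shape (N, 1)) to [p1, p2, ...]
    preds = [float(p[0]) if isinstance(p, (list, tuple)) else float(p) for p in predictions]
    return pearson(average_ranks(preds), average_ranks(labels))

print(spearman([[1.0], [2.0], [3.0], [4.0]], [1, 3, 2, 4]))  # 0.8 (up to float rounding)
```

Because only ranks matter, any monotone relationship scores 1.0, which suits binding scores whose scale the model need not reproduce exactly.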

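```python
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,    # also used to select the best checkpoint
    tokenizer=tokenizer,          # enables per-batch padding; newer releases call this `processing_class`
    compute_metrics=compute_metrics,
)
```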
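Because a tokenizer is passed to `Trainer` and no `data_collator` is given, `Trainer` falls back to `DataCollatorWithPadding`, which pads each batch only to the length of its own longest sequence. The stdlib sketch below imitates that behaviour on plain lists; the helper `pad_batch`, the pad ID of `1`, and right-side padding are assumptions for the toy example, not values read from the ESM tokenizer.

```python
def pad_batch(features, pad_token_id):
    """Right-pad a list of {'input_ids', 'attention_mask', 'labels'} dicts
    to the longest sequence in this batch (dynamic padding sketch)."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)  # 0 = ignore padding
        batch["labels"].append(f["labels"])  # regression targets pass through unchanged
    return batch

features = [  # toy encodings with hypothetical token IDs
    {"input_ids": [0, 5, 2], "attention_mask": [1, 1, 1], "labels": 0.3},
    {"input_ids": [0, 5, 6, 7, 2], "attention_mask": [1, 1, 1, 1, 1], "labels": -1.2},
]
print(pad_batch(features, pad_token_id=1))
```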

```python
trainer.train()
```

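Training log for epochs 1 to 10 (the run continued to epoch 20, per the `TrainOutput` below):

| Epoch | Training loss | Validation loss | Spearman ρ |
|------:|--------------:|----------------:|-----------:|
| 1 | No log | 1.896662 | 0.264815 |
| 2 | 1.941200 | 1.939587 | 0.345828 |
| 3 | 1.948200 | 3.275375 | 0.460322 |
| 4 | 1.732600 | 1.300373 | 0.570217 |
| 5 | 1.367500 | 1.437272 | 0.645552 |
| 6 | 1.008900 | 0.717460 | 0.689318 |
| 7 | 0.858100 | 0.671648 | 0.708898 |
| 8 | 0.853800 | 0.588366 | 0.726400 |
| 9 | 0.769100 | 0.552489 | 0.709654 |
| 10 | 0.578400 | 0.531976 | 0.790295 |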


TrainOutput(global_step=9020, training_loss=0.7731567820794302, metrics={'train_runtime': 1005.936, 'train_samples_per_second': 71.615, 'train_steps_per_second': 8.967, 'total_flos': 2938135634478240.0, 'train_loss': 0.7731567820794302, 'epoch': 20.0})
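The `TrainOutput` numbers can be cross-checked with a little arithmetic. Assuming a single device and no gradient accumulation (neither is changed in `args`), 9020 steps over 20 epochs is 451 steps per epoch, which at batch size 8 implies 3,601 to 3,608 training sequences. It also explains the `No log` in epoch 1: `logging_steps` is not set, so the default of 500 applies, and the first training-loss log falls in epoch 2.

```python
import math

# Values copied from TrainOutput and the TrainingArguments above.
global_step, epochs, batch_size = 9020, 20, 8
logging_steps = 500  # assumed TrainingArguments default (not set explicitly)

steps_per_epoch = global_step // epochs  # optimizer steps per epoch
# The last batch may be partial, so ceil(n / batch_size) == steps_per_epoch:
min_train = (steps_per_epoch - 1) * batch_size + 1
max_train = steps_per_epoch * batch_size
first_log_epoch = math.ceil(logging_steps / steps_per_epoch)  # epoch containing step 500

print(steps_per_epoch, (min_train, max_train), first_log_epoch)  # 451 (3601, 3608) 2
```

As a consistency check, `train_samples_per_second × train_runtime` ≈ 71.615 × 1005.936 ≈ 72,040 samples, about 20 × 3,602, which falls inside that range.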

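```python
# With load_best_model_at_end=True, `model` now holds the best checkpoint's weights.
# A state_dict stores weights only: to reload, rebuild the same architecture with
# AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=1)
# and then call model.load_state_dict(torch.load(model_save_path)).
model_save_path = f'{model_name}_single_mut_effects_regression.model'
torch.save(model.state_dict(), model_save_path)
```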