# Fine-tuning ESM2 to Predict Mutant Effects on Binding

### Resources
- Data from this [paper](https://www.cell.com/cell/fulltext/S0092-8674(20)31003-5?_returnURL=https%3A%2F%2Flinkinghub.elsevier.com%2Fretrieve%2Fpii%2FS0092867420310035%3Fshowall%3Dtrue#author-abstract), downloaded from this [repo](https://github.com/jbloomlab/SARS-CoV-2-RBD_DMS/tree/master)

As a benchmark, I'll first replicate the basic fine-tuning procedure with ESM2 to predict mutant effects on binding.

The task is to predict a binding score from a given sequence. The score is the change in log10(K_D) between the wildtype and the variant, signed so that lower values mean lower affinity and higher values mean higher affinity. For example, -2 means 100-fold weaker binding and 1 means 10-fold stronger binding.
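
To make the scale concrete, the conversion from a score to a fold change in affinity can be sketched as follows. This is a minimal illustration that assumes the score is expressed in log10 units as described above; the function name `fold_change_in_affinity` is hypothetical and does not appear elsewhere in the notebook.

```python
def fold_change_in_affinity(bind_score: float) -> float:
    """Convert a bind score in log10 units into a fold change in affinity.

    Values above 1 mean stronger binding than wildtype, values below 1 weaker.
    """
    return 10.0 ** bind_score


# The two examples from the text, plus the wildtype-like score of 0.
for score in (-2.0, 0.0, 1.0):
    fold = fold_change_in_affinity(score)
    print(f"score {score:+.1f} -> {fold:g}x wildtype affinity")
```

A score of 0 maps to a fold change of 1, i.e. no change relative to wildtype.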

In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
model_checkpoint = 'facebook/esm2_t12_35M_UR50D'

In [17]:
import pandas as pd

train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')

In [18]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [19]:
train_tokenized = tokenizer(train.sequences.to_list())
test_tokenized = tokenizer(test.sequences.to_list())

In [20]:
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

In [21]:
train_dataset = train_dataset.add_column('labels', train.bind_score.to_list())
test_dataset = test_dataset.add_column('labels', test.bind_score.to_list())

In [13]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=1)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
model_name = model_checkpoint.split('/')[-1]
batch_size = 8

args = TrainingArguments(
    f'{model_name}-finetuned-regression',
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=20,
    weight_decay=0.01,
    seed=69,
    load_best_model_at_end=True,
    metric_for_best_model='spearmanr',
)

In [15]:
from evaluate import load

metric = load('spearmanr')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return metric.compute(predictions=predictions, references=labels)
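
Since Spearman correlation is the metric used to pick the best checkpoint, it may help to see what it measures: the Pearson correlation of the ranks, so only the ordering of the predictions matters, not their scale. The sketch below is a pure-Python illustration, not the implementation behind `evaluate`; the helper names (`average_ranks`, `pearson`, `spearman`) are hypothetical, and the flattening step assumes predictions may arrive as one-element rows, as a single-output regression head would produce. The example values are the first five rows of the prediction table shown later in this notebook.

```python
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """Return 1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the run while the next sorted value ties with the current one.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1  # positions are 0-based, ranks 1-based
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Plain Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / (var_x * var_y) ** 0.5


def spearman(predictions: Sequence, references: Sequence[float]) -> float:
    """Spearman correlation: Pearson correlation computed on the ranks."""
    # Flatten one-element rows such as [[0.1], [0.2]] into [0.1, 0.2].
    flat = [p[0] if isinstance(p, (list, tuple)) else p for p in predictions]
    return pearson(average_ranks(flat), average_ranks(references))


labels = [-1.46, -2.82, -0.15, -0.05, 0.00]
preds = [[-0.384313], [-3.164898], [-0.060463], [-0.004087], [-0.005046]]
rho = spearman(preds, labels)
print(round(rho, 6))  # 0.9: only the last two predictions are in swapped order
```

Because the score depends only on ranks, a model can reach a high Spearman value while still being off in absolute terms, which is why the validation loss is worth reading alongside it.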

In [22]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [21]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Spearmanr |
|---|---|---|---|
| 1 | No log | 1.896662 | 0.264815 |
| 2 | 1.941200 | 1.939587 | 0.345828 |
| 3 | 1.948200 | 3.275375 | 0.460322 |
| 4 | 1.732600 | 1.300373 | 0.570217 |
| 5 | 1.367500 | 1.437272 | 0.645552 |
| 6 | 1.008900 | 0.71746 | 0.689318 |
| 7 | 0.858100 | 0.671648 | 0.708898 |
| 8 | 0.853800 | 0.588366 | 0.7264 |
| 9 | 0.769100 | 0.552489 | 0.709654 |
| 10 | 0.578400 | 0.531976 | 0.790295 |


TrainOutput(global_step=9020, training_loss=0.7731567820794302, metrics={'train_runtime': 1005.936, 'train_samples_per_second': 71.615, 'train_steps_per_second': 8.967, 'total_flos': 2938135634478240.0, 'train_loss': 0.7731567820794302, 'epoch': 20.0})
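
The numbers in this `TrainOutput` can be cross-checked against the training arguments with a little arithmetic. The sketch below assumes a single device and no gradient accumulation (neither is stated explicitly in the notebook), so that one optimizer step consumes one batch of 8 sequences; the variable names are illustrative.

```python
import math

# Values copied from the TrainOutput above and from the TrainingArguments cell.
batch_size = 8
global_step = 9020
num_epochs = 20
train_runtime_s = 1005.936
train_samples_per_second = 71.615

# Optimizer steps per epoch, assuming one device and no gradient accumulation.
steps_per_epoch = global_step // num_epochs

# Total samples processed divided by epochs estimates the training-set size.
estimated_train_size = round(train_samples_per_second * train_runtime_s / num_epochs)

# With a final partial batch, steps per epoch should be ceil(size / batch_size).
expected_steps = math.ceil(estimated_train_size / batch_size)

print(steps_per_epoch, estimated_train_size, expected_steps)  # 451 3602 451
```

Under these assumptions the log is consistent with roughly 3,602 training sequences and 451 steps per epoch. The resumed run reported further down, with `global_step=22550` at 50 epochs, works out to the same 451 steps per epoch.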

Resume training from the latest checkpoint, extending the run from 20 to 50 epochs.

In [32]:
# Raise the epoch budget, then continue from the most recent checkpoint in the output directory
trainer.args.num_train_epochs = 50
trainer.train(resume_from_checkpoint=True)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



| Epoch | Training Loss | Validation Loss | Spearmanr |
|---|---|---|---|
| 41 | 0.0815 | 0.390831 | 0.89103 |
| 42 | 0.1216 | 0.348801 | 0.894317 |
| 43 | 0.1107 | 0.293054 | 0.900565 |
| 44 | 0.1029 | 0.326584 | 0.897693 |
| 45 | 0.0944 | 0.395043 | 0.880987 |
| 46 | 0.0789 | 0.309577 | 0.906141 |
| 47 | 0.0777 | 0.305595 | 0.898411 |
| 48 | 0.0769 | 0.311686 | 0.903546 |
| 49 | 0.0544 | 0.304756 | 0.905195 |
| 50 | 0.058 | 0.29815 | 0.905757 |


TrainOutput(global_step=22550, training_loss=0.017094619882080348, metrics={'train_runtime': 459.7037, 'train_samples_per_second': 391.774, 'train_steps_per_second': 49.053, 'total_flos': 7345339086195600.0, 'train_loss': 0.017094619882080348, 'epoch': 50.0})

In [33]:
model_save_path = f'{model_name}_single_mut_effects_regression.model'
torch.save(model.state_dict(), model_save_path)

### Plots

In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [34]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# set up model inference
model_checkpoint = 'facebook/esm2_t12_35M_UR50D'
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

model_name = model_checkpoint.split('/')[-1]
model_save_path = f'{model_name}_single_mut_effects_regression.model'
model.load_state_dict(torch.load(model_save_path))
model.eval()

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmForSequenceClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 480, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((480,), eps=1e-0

In [35]:
import pandas as pd

test = pd.read_csv('test.csv')

In [36]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def predict_score(sequence):
    # Tokenize a single sequence and return the model's one regression output as a float
    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        prediction = model(**inputs).logits.item()
    return prediction

# Score every test sequence, one at a time
predictions = [predict_score(seq) for seq in test.sequences]

In [37]:
# Collect sequences, measured scores, and model predictions side by side
df_pred = pd.DataFrame({
    'sequence': test.sequences,
    'true_score': test.bind_score,
    'predicted_score': predictions,
})
df_pred

|  | sequence | true_score | predicted_score |
|---|---|---|---|
| 0 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -1.46 | -0.384313 |
| 1 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -2.82 | -3.164898 |
| 2 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -0.15 | -0.060463 |
| 3 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -0.05 | -0.004087 |
| 4 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | 0.00 | -0.005046 |
| ... | ... | ... | ... |
| 396 | NITNLCPFGEVFNATRCASVYAWNRKRISNCVADYSVLYNSASFST... | -1.00 | -0.462451 |
| 397 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -4.57 | -3.280554 |
| 398 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -1.41 | -2.781244 |
| 399 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -0.13 | -0.037748 |


In [38]:
# 1 if the variant binds more strongly than wildtype (score > 0), otherwise 0
df_pred['binary_bind'] = (df_pred.true_score > 0).astype(int)
df_pred

|  | sequence | true_score | predicted_score | binary_bind |
|---|---|---|---|---|
| 0 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -1.46 | -0.384313 | 0 |
| 1 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -2.82 | -3.164898 | 0 |
| 2 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -0.15 | -0.060463 | 0 |
| 3 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -0.05 | -0.004087 | 0 |
| 4 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | 0.00 | -0.005046 | 0 |
| ... | ... | ... | ... | ... |
| 396 | NITNLCPFGEVFNATRCASVYAWNRKRISNCVADYSVLYNSASFST... | -1.00 | -0.462451 | 0 |
| 397 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -4.57 | -3.280554 | 0 |
| 398 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -1.41 | -2.781244 | 0 |
| 399 | NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST... | -0.13 | -0.037748 | 0 |


In [39]:
import plotly.express as px

fig = px.scatter(df_pred, x='predicted_score', y='true_score', color='binary_bind', trendline='ols')
fig.show()