# Fine-tuning ESM-2 with LoRA for Protein Binding Site Prediction

This tutorial is based on the Hugging Face article by AmelieSchreiber (https://huggingface.co/blog/AmelieSchreiber/esmbind), which describes low-rank adaptation (LoRA) of ESM-2 for predicting protein binding sites. The curated dataset can either be downloaded directly from Hugging Face or generated with `data_preprocessing_notebook.ipynb`.

## ESM-2 pLM: 

The ESM-2 protein language model (pLM) and the ESMFold structure predictor built on it advance protein sequence analysis by removing the need for a Multiple Sequence Alignment (MSA). This makes them easier to use, requires less specialized knowledge, and returns results more quickly. ESMFold reaches accuracy comparable to AlphaFold2 on many proteins while running up to 60 times faster. ESM-2 also does not require structural data, which is a real advantage because many proteins lack a characterized 3D structure. ESMFold's fast predictions have already produced a large repository of predicted structures, the Metagenomic Atlas. Although these models are not yet as widely used as AlphaFold2, their value is increasingly recognized.

Users with varying levels of expertise in deep learning and protein science are encouraged to explore these models. Performance can potentially be improved by training your own LoRA models or by refining the training data, which is especially approachable for anyone familiar with databases such as UniProt.

## Low-Rank Adaptation of ESM-2:

Low-Rank Adaptation (LoRA) is a parameter-efficient way to fine-tune large pre-trained models, applied here to ESM-2 (Evolutionary Scale Modeling 2). Instead of updating the full weight matrices, LoRA freezes the pre-trained weights and learns, for selected weight matrices in the transformer layers, a low-rank update: two small matrices whose product is added to the frozen weight. Only these small matrices (and usually the task-specific head) are trained, which greatly reduces the memory and compute needed for fine-tuning. This is particularly useful for ESM-2, a model trained to capture protein structure and function from sequence, because it lets researchers adapt the same pre-trained model to new datasets or biological problems at low cost.
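To make the low-rank update concrete, here is a minimal sketch in plain PyTorch. It is not PEFT's actual implementation: the class name `LoRALinear` and the initialization scale of `A` are illustrative assumptions. It shows the core idea, a frozen weight `W` plus a trainable product `B A` scaled by `alpha / r`, using the hidden size of `esm2_t12_35M` (480).

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / r) * B A x, with W frozen."""

    def __init__(self, base: nn.Linear, r: int = 2, alpha: float = 1.0, dropout: float = 0.0):
        super().__init__()
        self.base = base
        # Freeze the pre-trained weight and bias; only A and B are trainable
        for p in self.base.parameters():
            p.requires_grad = False
        self.scaling = alpha / r
        # A: (r x in_features) with a small random init (assumed scale);
        # B: (out_features x r) initialized to zero, so the wrapped layer
        # starts out behaving exactly like the base layer
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Low-rank path: x -> A -> B, then scaled and added to the frozen output
        update = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return self.base(x) + self.scaling * update


# Toy usage with the hidden size of esm2_t12_35M (480) and the notebook's r=2, alpha=1
base = nn.Linear(480, 480)
layer = LoRALinear(base, r=2, alpha=1.0)
x = torch.randn(3, 480)

# With B initialized to zero, the output matches the frozen base layer
print(torch.allclose(layer(x), base(x)))  # True
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 2 * 480 + 480 * 2 = 1920
```

Zero-initializing `B` is the design choice that makes training start from the pre-trained model's exact behavior; the update only grows as `B` is learned.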

LoRA can also help mitigate overfitting, a common problem when adapting a large model to a specific task with limited data. Because the updates are constrained to a low-rank subspace, only a small number of parameters are trained, which limits the model's capacity to memorize the training data and acts as a form of regularization. The frozen pre-trained weights preserve the representations learned during pre-training, and the fine-tuned model cannot drift far from the original parameters, which helps it stay robust on unseen data. LoRA is therefore a practical way to specialize ESM-2 for a biological task without the cost, and the overfitting risk, of full fine-tuning.
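A quick back-of-the-envelope count shows how small the trainable subspace is. This sketch assumes the `esm2_t12_35M` configuration (12 transformer layers, hidden size 480) and LoRA with `r=2` on the query, key, and value projections, as in the training function later in this notebook. It ignores biases and the token-classification head, which PEFT typically also trains.

```python
def lora_params_per_matrix(d_in: int, d_out: int, r: int) -> int:
    """Trainable LoRA parameters for one weight matrix: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


num_layers, hidden, r = 12, 480, 2
targets_per_layer = 3  # query, key, value

lora_total = num_layers * targets_per_layer * lora_params_per_matrix(hidden, hidden, r)
full_total = num_layers * targets_per_layer * hidden * hidden  # full weights of the same matrices

print(lora_total)                        # 69120
print(full_total)                        # 8294400
print(f"{lora_total / full_total:.2%}")  # 0.83%
```

Under these assumptions, LoRA trains under 1% of the parameters in the targeted matrices, which is where both the efficiency and the regularizing effect come from.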

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use only the first GPU
import pickle
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef,
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset
from accelerate import Accelerator
# Imports for the PEFT (LoRA) model
from peft import PeftModel, get_peft_model, LoraConfig, TaskType
# import wandb  # optional: only needed for experiment tracking with report_to="wandb"
```


## Defining metrics and losses:

The cell below prepares the protein sequence data for predicting binding sites from sequence alone. The task is framed as token classification: each residue is labeled as a binding site (1) or not (0) and classified with a pre-trained ESM-2 model. The cell first defines helper functions that prepare the labels (truncating them to a maximum length), compute evaluation metrics (accuracy, precision, recall, F1 score, AUC, and the Matthews correlation coefficient, MCC), and compute a custom class-weighted loss for training. It then loads the protein sequences and their labels from pickle files which, as the file names suggest, were chunked and split by protein family.
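MCC is worth including because accuracy is misleading on imbalanced residue labels. The toy sketch below (made-up labels, and a hypothetical helper `mcc_from_confusion`) computes MCC from confusion-matrix counts and compares a weak model with a predictor that always outputs "No binding site": both reach the same accuracy, but only MCC tells them apart.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, matthews_corrcoef


def mcc_from_confusion(y_true, y_pred) -> float:
    """Binary Matthews correlation coefficient from the confusion-matrix counts."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    # By convention MCC is 0 when any marginal count is zero (e.g. constant predictions)
    return 0.0 if denom == 0 else float((tp * tn - fp * fn) / denom)


# Toy per-residue labels: binding residues (1) are rare, as in the real data
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 0, 0, 1, 1, 0])  # one hit, one miss, one false alarm
all_zero = np.zeros_like(y_true)             # always predicts "No binding site"

for name, pred in [("model", y_pred), ("all zero", all_zero)]:
    print(name, accuracy_score(y_true, pred), round(mcc_from_confusion(y_true, pred), 4))
# model 0.75 0.3333
# all zero 0.75 0.0
```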

The sequences are tokenized with the tokenizer that ships with the ESM-2 checkpoint and truncated to a maximum length; padding is handled so that all sequences in a batch share one length. The labels are truncated to the same maximum length so that they line up with the tokenized sequences, and both are then converted into Hugging Face `Dataset` objects for training and evaluation.

The code also computes class weights to counter the imbalance in the label distribution: binding residues are far rarer than non-binding ones, so without weighting the model could lower its loss simply by predicting the majority class. The weights are converted to a tensor and moved to the device selected by the `Accelerator` class from the accelerate library (a GPU if one is available, otherwise the CPU), so that they are on the same device as the model's logits when the loss is computed.
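The `'balanced'` option of `compute_class_weight` assigns each class the weight `n_samples / (n_classes * count_of_class)`. The short sketch below uses made-up flattened labels (8 non-binding, 2 binding residues) to reproduce that formula by hand and check it against scikit-learn.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy flattened residue labels: 8 non-binding (0), 2 binding (1)
y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
classes = np.array([0, 1])

# 'balanced' weights: n_samples / (n_classes * count of each class)
manual = len(y) / (len(classes) * np.bincount(y))
sk = compute_class_weight(class_weight="balanced", classes=classes, y=y)

print(manual.tolist())          # [0.625, 2.5]
print(np.allclose(manual, sk))  # True
```

The rare binding class receives a weight four times larger than the common class, mirroring the 4:1 imbalance in this toy example.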

Download the `.pkl` dataset files [here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family/tree/main).


```python
# Helper functions and data preparation
def align_labels(labels, max_length):
    """Truncate per-residue labels and align them with the ESM-2 tokens.

    The ESM-2 tokenizer adds <cls> at the start and <eos> at the end of each
    sequence, so at most max_length - 2 residues survive truncation. Both special
    positions get the label -100, which the loss and the metrics ignore.
    """
    return [[-100] + list(label[: max_length - 2]) + [-100] for label in labels]

def compute_metrics(p):
    """Compute evaluation metrics over the labeled residues."""
    logits, labels = p
    predictions = np.argmax(logits, axis=2)
    # Softmax probability of the "Binding site" class, used for AUC
    shifted = logits - logits.max(axis=2, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=2, keepdims=True)
    binding_probs = probs[..., 1]

    # Keep only real residues (-100 marks padding and special tokens)
    mask = labels != -100
    predictions = predictions[mask]
    binding_probs = binding_probs[mask]
    labels = labels[mask]

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    # AUC needs continuous scores, not hard predictions
    auc = roc_auc_score(labels, binding_probs)
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, inputs):
    """Class-weighted cross-entropy over non-padding tokens (uses the global class_weights)."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

# Load the data from pickle files (replace with your local paths)
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
max_sequence_length = 1000

# No padding here: DataCollatorForTokenClassification pads each batch dynamically
train_tokenized = tokenizer(train_sequences, truncation=True, max_length=max_sequence_length)
test_tokenized = tokenizer(test_sequences, truncation=True, max_length=max_sequence_length)

# Truncate the labels and add -100 at the <cls>/<eos> positions
train_labels = align_labels(train_labels, max_sequence_length)
test_labels = align_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

# Compute class weights from the real residue labels (skip the -100 placeholders)
classes = np.array([0, 1])
flat_train_labels = [label for sublist in train_labels for label in sublist if label != -100]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
accelerator = Accelerator()
# Move the weights to the device the model will run on
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
```


## Define Custom Trainer Class:

In protein sequence analysis, and particularly when predicting binding-site residues, datasets are usually heavily imbalanced: non-binding residues vastly outnumber binding residues. This imbalance can bias the model toward predicting the majority class. To address it, a custom `WeightedTrainer` class inherits from `Trainer` and overrides its `compute_loss` method with a cross-entropy loss weighted by the class weights computed from the dataset. Misclassifying the minority class is then penalized more heavily, which encourages the model to learn both binding and non-binding sites. This matters here because the rarer binding sites are exactly the functionally important residues we want to identify, so the weighting mainly aims to improve recall and F1 on the binding class.
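The toy sketch below (made-up logits, labels, and weights) shows what a weighted `nn.CrossEntropyLoss` actually computes with its default mean reduction: each token's negative log-likelihood is multiplied by the weight of its true class, and the sum is divided by the total weight of the non-ignored tokens. Positions labeled `-100` are skipped.

```python
import torch
import torch.nn as nn

# Toy logits for 4 residues and 2 classes, plus their true labels
logits = torch.tensor([[2.0, 0.5], [1.0, 1.0], [0.2, 1.5], [0.3, 0.1]])
labels = torch.tensor([0, 1, 1, -100])  # the last position is padding and is ignored
weights = torch.tensor([0.625, 2.5])    # e.g. 'balanced' weights for a 4:1 imbalance

loss_fn = nn.CrossEntropyLoss(weight=weights)  # ignore_index defaults to -100
loss = loss_fn(logits, labels)

# Manual check: weighted per-token negative log-likelihoods,
# normalized by the summed weights of the non-ignored tokens
keep = labels != -100
log_probs = torch.log_softmax(logits[keep], dim=-1)
nll = -log_probs[torch.arange(int(keep.sum())), labels[keep]]
w = weights[labels[keep]]
manual = (w * nll).sum() / w.sum()

print(torch.allclose(loss, manual))  # True
```

Because the binding class has the larger weight, an error on a binding residue moves the loss four times as much as an error on a non-binding residue.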

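```python
class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy over the token labels."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
        # transformers versions pass to compute_loss
        outputs = model(**inputs)  # a single forward pass
        logits = outputs.logits
        labels = inputs["labels"].reshape(-1)
        # Uses the global class_weights computed during data preparation
        loss_fct = nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        # Positions outside the attention mask (padding) are set to ignore_index (-100)
        active = inputs["attention_mask"].reshape(-1) == 1
        active_labels = torch.where(active, labels, torch.full_like(labels, loss_fct.ignore_index))
        loss = loss_fct(logits.reshape(-1, logits.size(-1)), active_labels)
        return (loss, outputs) if return_outputs else loss
```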

## Training the LoRA ESM-2:

The training function fine-tunes the ESM-2 protein language model for token classification, that is, labeling each residue of a protein sequence as a binding site or not. It uses LoRA, whose key hyperparameters are the rank (`r`) and the scaling factor (`lora_alpha`); the low-rank update is scaled by `lora_alpha / r`. Experiment with these settings, and with the choice of weight matrices LoRA is applied to (`target_modules`), to optimize performance on this dataset.

The function sets up training with a fixed configuration that includes the learning rate, the learning-rate scheduler, gradient clipping (`max_grad_norm`), and weight decay. It uses the custom `WeightedTrainer` to account for class imbalance, and the model is prepared with an `Accelerator` so that it runs on the available hardware. The training arguments control the number of epochs, the batch size, and the evaluation strategy: the model is evaluated and checkpointed after every epoch, the best checkpoint is selected by F1 score, and the best model and the tokenizer are saved to a timestamped directory for reproducibility and later use.
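The configuration uses `lr_scheduler_type="cosine"` with no warmup (the `TrainingArguments` default). The sketch below reproduces the shape of that schedule, a single half-cosine decay from the base learning rate to zero, as I understand the transformers cosine scheduler to behave with `num_cycles=0.5`. The function `cosine_lr` and the total of 1000 steps are hypothetical; in practice the step count depends on the dataset size, batch size, and number of epochs.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, warmup_steps: int = 0) -> float:
    """Optional linear warmup, then a half-cosine decay from base_lr down to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 5.701568055793089e-04  # learning rate from the config in the training cell
total = 1000                     # hypothetical number of optimization steps
for s in (0, 250, 500, 750, 1000):
    print(s, f"{cosine_lr(s, total, base_lr):.3e}")
```

The learning rate stays close to its peak early on, passes half its value at the midpoint, and approaches zero at the end, which tends to give smooth convergence in short fine-tuning runs.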

```python
from accelerate import Accelerator
accelerator = Accelerator()

def train_function_no_sweeps(train_dataset, test_dataset):

    # LoRA and training hyperparameters
    config = {
        "lora_alpha": 1,  # try 0.5, 1, 2, ..., 16
        "lora_dropout": 0.2,
        "lr": 5.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 3,
        "per_device_train_batch_size": 4,
        "r": 2,
        "weight_decay": 0.2,
        # Add other hyperparameters as needed
    }
    # The base model to train a LoRA on top of
    model_checkpoint = "facebook/esm2_t12_35M_UR50D"

    # Define labels and model
    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}
    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id
    )

    # Convert the model into a PeftModel with LoRA adapters
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=["query", "key", "value"],  # also try adding "dense" (attention-output and feed-forward projections)
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
    )
    model = get_peft_model(model, peft_config)

    # Place the model on the available device; the Trainer builds its own data loaders
    model = accelerator.prepare(model)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t12_35M-lora-binding-sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        eval_strategy="epoch",  # named evaluation_strategy in transformers < 4.41
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        seed=8893,
        fp16=False,
        report_to="none",  # set to "wandb" to log to Weights & Biases
    )

    # Initialize the class-weighted Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    # Train, then save the best model (selected by F1) and the tokenizer
    trainer.train()
    save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)
```


```python
train_function_no_sweeps(train_dataset, test_dataset)
```

## Evaluating the model performance:

After training the ESM-2 model with LoRA for protein binding site prediction, evaluate it to check that it generalizes to new data rather than overfitting. The cell below computes the same metrics for both the training and the test dataset. Ideally the two sets of metrics are comparable, indicating that the model performs consistently on seen and unseen data. Training metrics that are markedly better than the test metrics suggest overfitting: the model has memorized the training data instead of learning patterns that generalize. If both sets of metrics are low, the model is likely underfitting and may benefit from more training or more capacity (for example, a larger rank `r`). The metrics (accuracy, precision, recall, F1 score, AUC, and MCC) are computed by first predicting labels for each dataset, then flattening the predictions and labels while excluding the ignored positions (label -100), such as padding.
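One detail to keep in mind when reading the AUC: `roc_auc_score` is designed for continuous scores. If it receives argmax predictions instead of class probabilities, it only sees a single threshold and loses the ranking information. The toy sketch below (made-up labels and probabilities) shows the difference.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Toy labels and predicted probabilities of the "Binding site" class
y_true = np.array([0, 0, 1, 1])
p_binding = np.array([0.1, 0.4, 0.45, 0.8])

# AUC from probabilities uses the full ranking of the residues
auc_scores = roc_auc_score(y_true, p_binding)
# AUC from hard labels (threshold 0.5, equivalent to argmax for two classes)
hard = (p_binding >= 0.5).astype(int)
auc_hard = roc_auc_score(y_true, hard)

print(float(auc_scores), float(auc_hard))  # 1.0 0.75
```

Here every binding residue is ranked above every non-binding residue (AUC 1.0), but the hard predictions miss one binding residue at the 0.5 threshold, so their AUC drops to 0.75.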

```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
)
from peft import PeftModel
from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification, Trainer

# Define paths to the LoRA and base models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "nidhinthomas/esm2_t12_35M_lora_binding_sites"  # replace with the path to your own LoRA model

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Load the base model (its classifier head is replaced by the one saved with the adapter)
base_model = AutoModelForTokenClassification.from_pretrained(
    base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id
)

# Load the LoRA model and place it on the available device
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)

# Create a data collator (pads inputs and labels per batch)
data_collator = DataCollatorForTokenClassification(tokenizer)

def compute_dataset_metrics(dataset):
    """Predict on a dataset and compute metrics over the labeled residues."""
    trainer = Trainer(model=model, data_collator=data_collator)
    logits, labels, _ = trainer.predict(test_dataset=dataset)

    # Drop ignored positions (label -100)
    mask = labels != -100
    true_labels = labels[mask]
    flat_predictions = np.argmax(logits, axis=2)[mask]
    # Softmax probability of the "Binding site" class, used for AUC
    shifted = logits - logits.max(axis=2, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=2, keepdims=True)
    binding_probs = probs[..., 1][mask]

    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, binding_probs)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

# Get the metrics for the training and test datasets
train_metrics = compute_dataset_metrics(train_dataset)
test_metrics = compute_dataset_metrics(test_dataset)

train_metrics, test_metrics
```

## Model Inference:

You can now run inference on new protein sequences of interest. Load the trained LoRA model on top of the base model, together with the tokenizer, and put the model in evaluation mode to disable training-specific behavior such as dropout. Loading the base checkpoint prints a warning that the classifier weights are newly initialized; this is expected, because `PeftModel.from_pretrained` then loads the classifier head that was trained and saved with the adapter. The tokenizer converts the protein sequence into model inputs, and inference runs inside `torch.no_grad()` so that no gradients are computed. The model outputs logits for each token; taking the argmax gives the predicted class, which is mapped to a human-readable label (binding or non-binding site). Iterating over the tokens and their predictions, while skipping special tokens such as `<cls>`, `<eos>`, and `<pad>`, shows the prediction for each residue of the original sequence and gives insight into the binding-site propensities of your protein.
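Beyond argmax labels, it is often useful to report a probability and the 1-based residue position for each predicted binding site. The sketch below uses made-up tokens and logits as stand-ins for the tokenizer output and `loaded_model(**inputs).logits`; the variable names are illustrative, not part of the original notebook.

```python
import torch

# Toy stand-ins for the tokenizer output and model logits (batch x tokens x 2 classes)
tokens = ["<cls>", "M", "A", "V", "<eos>", "<pad>"]
logits = torch.tensor([[[3.0, -1.0], [2.0, 0.0], [-0.5, 1.5], [0.2, 0.1], [3.0, -2.0], [0.0, 0.0]]])

SPECIAL_TOKENS = {"<cls>", "<eos>", "<pad>"}

# Softmax turns logits into class probabilities; column 1 is "Binding site"
probs = torch.softmax(logits, dim=-1)[0, :, 1]
preds = logits.argmax(dim=-1)[0]

binding_residues = []
residue_index = 0  # 1-based position within the protein sequence
for token, pred, p in zip(tokens, preds.tolist(), probs.tolist()):
    if token in SPECIAL_TOKENS:
        continue  # special tokens are not residues
    residue_index += 1
    if pred == 1:
        binding_residues.append((residue_index, token, round(p, 3)))

print(binding_residues)  # [(2, 'A', 0.881)]
```

Counting only non-special tokens keeps the positions aligned with the input sequence, which makes it easy to map predictions onto a structure.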

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
# ESM-2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model, then apply the LoRA adapter (which also restores the trained classifier head)
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Switch to evaluation mode (disables dropout)
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence (no padding is needed for a single sequence)
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024)

# Run the model without tracking gradients
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}
special_tokens = {"<pad>", "<cls>", "<eos>"}

# Print the predicted label for each residue
for token, prediction in zip(tokens, predictions[0].tolist()):
    if token not in special_tokens:
        print((token, id2label[prediction]))

# Print the predicted binding-site residues with their 1-based positions in the sequence
residue_predictions = [(t, p) for t, p in zip(tokens, predictions[0].tolist()) if t not in special_tokens]
for position, (token, prediction) in enumerate(residue_predictions, start=1):
    if prediction == 1:
        print(position, token)
```

```text
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

('M', 'No binding site')
('A', 'No binding site')
('V', 'No binding site')
('P', 'No binding site')
('E', 'No binding site')
('T', 'No binding site')
('R', 'No binding site')
('P', 'No binding site')
('N', 'Binding site')
('H', 'Binding site')
('T', 'Binding site')
('I', 'Binding site')
('Y', 'Binding site')
('I', 'Binding site')
('N', 'Binding site')
('N', 'Binding site')
('L', 'Binding site')
('N', 'Binding site')
('E', 'Binding site')
('K', 'Binding site')
('I', 'No binding site')
('K', 'Binding site')
('K', 'No binding site')
('D', 'No binding site')
('E', 'No binding site')
('L', 'Binding site')
('K', 'No binding site')
('K', 'No binding site')
('S', 'No binding site')
('L', 'No binding site')
('H', 'No binding site')
('A', 'No binding site')
('I', 'No binding site')
('F', 'Binding site')
('S', 'No binding site')
('R', 'No binding site')
('F', 'Binding site')
('G', 'Binding site')
('Q', 'No binding site')
('I', 'No binding site')
('L', 'No binding site')
('D', 'No binding site')
...
```

```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
)
from peft import PeftModel
from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification, Trainer

# Define paths to the LoRA and base models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"  # replace with the path to your own LoRA model

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Load the base model (its classifier head is replaced by the one saved with the adapter)
base_model = AutoModelForTokenClassification.from_pretrained(
    base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id
)

# Load the LoRA model and place it on the available device
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)

# Create a data collator (pads inputs and labels per batch)
data_collator = DataCollatorForTokenClassification(tokenizer)

def compute_dataset_metrics(dataset):
    """Predict on a dataset and compute metrics over the labeled residues."""
    trainer = Trainer(model=model, data_collator=data_collator)
    logits, labels, _ = trainer.predict(test_dataset=dataset)

    # Drop ignored positions (label -100)
    mask = labels != -100
    true_labels = labels[mask]
    flat_predictions = np.argmax(logits, axis=2)[mask]
    # Softmax probability of the "Binding site" class, used for AUC
    shifted = logits - logits.max(axis=2, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=2, keepdims=True)
    binding_probs = probs[..., 1][mask]

    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, binding_probs)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

# Get the metrics for the training and test datasets
train_metrics = compute_dataset_metrics(train_dataset)
test_metrics = compute_dataset_metrics(test_dataset)

train_metrics, test_metrics
```


The structure of the example protein used for inference above is shown below, with the predicted binding site highlighted in green.

![esm2_binding_site](esm2_binding_site.png)