# ESMB for Protein Binding Residue Prediction

**ESMBind** (or ESMB) predicts which residues in a protein sequence are binding sites or active sites. The `ESMB_35M` models are Low Rank Adaptation (LoRA) finetuned versions of the protein language model `esm2_t12_35M_UR50D`. They were finetuned on ~549K protein sequences and appear to achieve competitive performance compared to current SOTA geometric/structural models, surpassing many on certain classes and for certain metrics. These models have especially high recall, meaning they are likely to find most binding residues.

However, they have relatively low precision, meaning that some of the residues they flag will be false positives. Their MCC, AUC, and accuracy values are also quite high. We hope that scaling the model and dataset size in a 1-to-1 fashion will achieve SOTA performance. Even at this scale, the simplicity of the models and training procedure makes them attractive: the domain knowledge and data preparation required to use them are modest, which keeps the barrier to entry low for using these models in practice.
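To make the trade-off between high recall and lower precision concrete, here is a minimal sketch that computes precision, recall, F1, and MCC from confusion-matrix counts. The function name `binary_metrics` and the counts are hypothetical and chosen only for illustration; they are not results from these models.

```python
import math

def binary_metrics(tp, fp, fn, tn):
    """Compute precision, recall, F1, and MCC from confusion-matrix counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "mcc": mcc}

# Hypothetical counts for illustration only (not results from these models):
# most true binding residues are found, but many flagged residues are wrong.
example = binary_metrics(tp=90, fp=210, fn=10, tn=690)
print(example)
```

With these made-up counts, recall is 0.9 while precision is only 0.3, which is the pattern described above: few binding residues are missed, at the cost of extra false positives.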

These models also predict binding residues from sequence alone. Since most proteins do not yet have predicted 3D folds and backbone structures, we hope this resource will be valuable to the community for this reason as well. Before we proceed, we need to run a few pip install statements.

In [None]:
!pip install transformers -q 
!pip install accelerate -q 
!pip install peft -q 
!pip install datasets -q 

## Running Inference 

To run inference, simply replace the protein sequence below with your own. Afterwards, you'll be guided through how to check the train/test metrics on the datasets yourself. 

In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference (replace with your own)
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # @param {type:"string"}

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
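The cell above prints one label per token. If you would rather have a compact list of predicted binding residues with their positions in the sequence, a small helper along these lines may be useful. It is a sketch: `binding_residue_positions` is a hypothetical name, and it assumes each residue maps to exactly one token and that class `1` means "Binding site", as in `id2label` above.

```python
def binding_residue_positions(tokens, predictions,
                              special_tokens=("<pad>", "<cls>", "<eos>")):
    """Return (1-based position, residue) pairs for tokens predicted as class 1."""
    positions = []
    residue_index = 0
    for token, label in zip(tokens, predictions):
        if token in special_tokens:
            continue  # special tokens do not correspond to residues
        residue_index += 1
        if int(label) == 1:
            positions.append((residue_index, token))
    return positions

# Toy input standing in for `tokens` and `predictions[0].tolist()` from the cell above
toy_tokens = ["<cls>", "M", "A", "V", "P", "<eos>", "<pad>"]
toy_predictions = [0, 0, 1, 0, 1, 0, 0]
toy_result = binding_residue_positions(toy_tokens, toy_predictions)
print(toy_result)  # [(2, 'A'), (4, 'P')]
```

With the real outputs, this could be called as `binding_residue_positions(tokens, predictions[0].tolist())`.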

## Train/Test Metrics

### Loading and Tokenizing the Datasets

To use this notebook to run the model on the train/test split and get the various metrics (accuracy, precision, recall, F1 score, AUC, and MCC), you will need to download the pickle files [found on Hugging Face here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). Navigate to the "Files and versions" tab and download the four pickle files (you can ignore the TSV files unless you want to preprocess the data in a different way yourself). Once you have downloaded the pickle files, change the four pickle file paths in the cell below to match their local paths on your machine, then run the cell.

In [1]:
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

# Set the maximum sequence length
max_sequence_length = 1000

# Load the data from pickle files (change to match your local paths)
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Truncate the labels to match the tokenized sequence lengths
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

train_dataset, test_dataset


(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 450330
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 113475
 }))
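The label lists in these datasets have different lengths, so they must be padded before they can be batched. `DataCollatorForTokenClassification` pads labels with `-100` by default, and the metrics code later filters on that value. The sketch below is a simplified stand-in for that behavior, not the collator's actual code; `pad_labels`, `drop_ignored`, and the toy values are hypothetical.

```python
IGNORE_INDEX = -100  # default label padding value in DataCollatorForTokenClassification

def pad_labels(batch_labels, target_length, pad_value=IGNORE_INDEX):
    """Right-pad each label list to target_length with pad_value."""
    return [labels + [pad_value] * (target_length - len(labels))
            for labels in batch_labels]

def drop_ignored(labels, predictions, ignore_index=IGNORE_INDEX):
    """Keep only (label, prediction) pairs whose label is not the ignore index."""
    kept = [(l, p) for l, p in zip(labels, predictions) if l != ignore_index]
    return [l for l, _ in kept], [p for _, p in kept]

# Two toy label lists of different lengths
batch = [[0, 1, 1], [1, 0]]
padded = pad_labels(batch, target_length=4)
print(padded)  # [[0, 1, 1, -100], [1, 0, -100, -100]]

# Flatten and discard padded positions before scoring
flat_labels = [l for row in padded for l in row]
flat_preds = [0, 1, 0, 1, 1, 0, 1, 1]  # hypothetical predictions, one per position
kept_labels, kept_preds = drop_ignored(flat_labels, flat_preds)
print(kept_labels, kept_preds)  # [0, 1, 1, 1, 0] [0, 1, 0, 1, 0]
```

Only the five real label positions survive the filter, so padded positions never contribute to accuracy, precision, recall, or MCC.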

### Getting the Train/Test Metrics

Next, run the following cell. Depending on your hardware, this may take a while. There are ~549K protein sequences to process in total. The train dataset will take much longer than the test dataset, so be patient and let both complete to see both the train and test metrics.

In [None]:
import numpy as np  # needed for np.argmax in compute_metrics
from sklearn.metrics import (
    matthews_corrcoef, 
    accuracy_score, 
    precision_recall_fscore_support, 
    roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification
from transformers import Trainer
from accelerate import Accelerator

# Instantiate the accelerator
accelerator = Accelerator()

# Define paths to the LoRA and base models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"  # Replace with the path to your own LoRA model if needed

# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Load the LoRA model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)  # Prepare the model using the accelerator

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator
data_collator = DataCollatorForTokenClassification(tokenizer)

# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)
    
    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    # Note: this AUC is computed from hard 0/1 predictions, not class probabilities
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute the MCC
    
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}  # Include the MCC in the returned dictionary

# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

train_metrics, test_metrics

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


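One detail of the metrics cell is worth illustrating: it passes argmax predictions (hard 0/1 labels) to `roc_auc_score`. With hard labels the ROC curve has a single operating point, so the resulting value equals balanced accuracy, (TPR + TNR) / 2, rather than a threshold-free AUC. The sketch below shows the difference with a small pure-Python AUC; `roc_auc` and the toy values are hypothetical, and the quadratic loop is for illustration only.

```python
def roc_auc(labels, scores):
    """ROC AUC as P(positive score > negative score); ties count as 0.5."""
    positives = [s for l, s in zip(labels, scores) if l == 1]
    negatives = [s for l, s in zip(labels, scores) if l == 0]
    if not positives or not negatives:
        raise ValueError("Both classes are required to compute ROC AUC.")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))

labels = [1, 1, 1, 0, 0, 0, 0]
probs = [0.9, 0.6, 0.4, 0.7, 0.3, 0.2, 0.1]   # hypothetical P(binding) values
hard = [1 if p >= 0.5 else 0 for p in probs]  # thresholded 0/1 predictions

auc_from_probs = roc_auc(labels, probs)
auc_from_hard = roc_auc(labels, hard)
print(auc_from_probs, auc_from_hard)
```

On this toy data the probability-based AUC is 10/12 while the hard-label value is 8.5/12, which matches the balanced accuracy (2/3 + 3/4) / 2. To obtain a probability-based AUC in the metrics cell, one option would be to apply a softmax over the last axis of `predictions`, take the class 1 column at the unmasked positions, and pass those scores to `roc_auc_score`.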