In [4]:
cd code

[Errno 2] No such file or directory: 'code'
/mnt/aiongpfs/users/egomez/Projects/GSoC24/ESMbind/code


# I / O, helpers

In [5]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#import wandb
import numpy as np
import torch
import torch.nn as nn
import pickle
import xml.etree.ElementTree as ET
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer
)
from datasets import Dataset
from accelerate import Accelerator
# Imports specific to the custom peft lora model
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType


In [6]:

# Helper Functions and Data Preparation
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

def compute_metrics_train(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    
    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    
    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)
    
    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    
    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions) 
    
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc} 

def compute_loss(model, inputs):
    """Custom compute_loss function."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss


# Data load & embeddings

In [21]:

# Load the data from pickle files (replace with your local paths)
with open("../data/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("../data/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("../data/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("../data/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# make dataset smaller:
train_sequences = train_sequences[0:300]
test_sequences = test_sequences[0:100]
train_labels = train_labels[0:300]
test_labels = test_labels[0:100]

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # esm2_t12_35M_UR50D
max_sequence_length = 500 # 1000

train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

# Training

In [65]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForTokenClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=2)

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [44]:
#!pip install evaluate
import evaluate
metric = evaluate.load("accuracy")

In [70]:
def compute_metrics_train(eval_pred):
    """Compute metrics for evaluation."""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)
    
    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    
    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)
    
    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    
    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions) 
    
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc} 

In [66]:
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
training_args = TrainingArguments(
    output_dir=f"trained_models/esm2_t6_8M-binding-sites_{timestamp}", 
    eval_strategy="epoch",
    seed=8893,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True
)

In [67]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # esm2_t12_35M_UR50D


In [71]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics_train,
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer)
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [72]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Auc,Mcc
1,No log,0.152388,0.9665,0.0,0.0,0.0,0.5,0.0
2,No log,0.157112,0.9665,0.0,0.0,0.0,0.5,0.0
3,No log,0.156422,0.9665,0.0,0.0,0.0,0.5,0.0


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TrainOutput(global_step=114, training_loss=0.03993859207421018, metrics={'train_runtime': 1325.1209, 'train_samples_per_second': 0.679, 'train_steps_per_second': 0.086, 'total_flos': 19977740100000.0, 'train_loss': 0.03993859207421018, 'epoch': 3.0})

In [74]:
# Save Model
save_path = os.path.join("../results/binding_sites", f"best_model_esm2_t6_8M_{timestamp}")
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

('../results/binding_sites/best_model_esm2_t6_8M_2024-05-27_17-22-09/tokenizer_config.json',
 '../results/binding_sites/best_model_esm2_t6_8M_2024-05-27_17-22-09/special_tokens_map.json',
 '../results/binding_sites/best_model_esm2_t6_8M_2024-05-27_17-22-09/vocab.txt',
 '../results/binding_sites/best_model_esm2_t6_8M_2024-05-27_17-22-09/added_tokens.json')

In [76]:
# Define paths to the tuned and base models
base_model_path = "facebook/esm2_t6_8M_UR50D"  
tuned_model_path = "../results/binding_sites/best_model_esm2_t6_8M_2024-05-27_17-22-09" 

# Load the model
tuned_model = AutoModelForTokenClassification.from_pretrained(tuned_model_path)
accelerator = Accelerator()
model = accelerator.prepare(tuned_model)  # Prepare the model using the accelerator


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [77]:

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator
data_collator = DataCollatorForTokenClassification(tokenizer)


In [78]:
# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)
    
    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute the MCC
    
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}  # Include the MCC in the returned dictionary


In [79]:
# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

train_metrics, test_metrics

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


({'accuracy': 0.9938003552301905,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'auc': 0.5,
  'mcc': 0.0},
 {'accuracy': 0.966500023361211,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'auc': 0.5,
  'mcc': 0.0})

# Inference

In [80]:
# Set paths and model if not loaded before
# Path to the saved tuned model and ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"  
tuned_model_path = "../results/binding_sites/best_model_esm2_t6_8M_2024-05-27_17-22-09" 

# Load the model
tuned_model = AutoModelForTokenClassification.from_pretrained(tuned_model_path)

# Ensure the model is in evaluation mode
tuned_model.eval()

# Load the tokenizer if not loaded already
# tokenizer = AutoTokenizer.from_pretrained(base_model_path)


EsmForTokenClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, el

In [81]:

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = tuned_model(**inputs).logits

# Get predictions
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))


('M', 'No binding site')
('A', 'No binding site')
('V', 'No binding site')
('P', 'No binding site')
('E', 'No binding site')
('T', 'No binding site')
('R', 'No binding site')
('P', 'No binding site')
('N', 'No binding site')
('H', 'No binding site')
('T', 'No binding site')
('I', 'No binding site')
('Y', 'No binding site')
('I', 'No binding site')
('N', 'No binding site')
('N', 'No binding site')
('L', 'No binding site')
('N', 'No binding site')
('E', 'No binding site')
('K', 'No binding site')
('I', 'No binding site')
('K', 'No binding site')
('K', 'No binding site')
('D', 'No binding site')
('E', 'No binding site')
('L', 'No binding site')
('K', 'No binding site')
('K', 'No binding site')
('S', 'No binding site')
('L', 'No binding site')
('H', 'No binding site')
('A', 'No binding site')
('I', 'No binding site')
('F', 'No binding site')
('S', 'No binding site')
('R', 'No binding site')
('F', 'No binding site')
('G', 'No binding site')
('Q', 'No binding site')
('I', 'No binding site')
