### Updated PhageRBPdetect_v3_ESMfine_benchmark.py

In [1]:
"""
PhageRBPdetect (ESM2-fine) - benchmarking
@author: dimiboeckaerts
@date: 2023-12-19

Notes: 
You will probably want to run this script on a GPU-enabled machine (e.g. Google Colab or Kaggle).
The ESM-2 T12 model can run on a single GPU with 16GB of memory.
Taken from https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_language_modeling.ipynb#scrollTo=3d2edc14.
If you want to train the ESM-2 T33 model, you will need a machine with 32GB or more memory.
"""
#!pip install evaluate datasets

# 0 - SET THE PATHS
# ------------------------------------------
path = '.'

import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from evaluate import load
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split

GPU = 0
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)
device = torch.device("cuda")

# 1 - TRAIN & TUNING THE MODEL
# ------------------------------------------
# define path & model checkpoint
model_checkpoint = "facebook/esm2_t12_35M_UR50D" # esm2_t12_35M_UR50D, esm2_t33_650M_UR50D, esm2_t36_3B_UR50D
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#os.environ["WANDB_DISABLED"] = "true"

# get data until SEPT 2021
RBPs = pd.read_csv('../2025_data/annotated_RBPs_2025-01.csv')
nonRBPs = pd.read_csv('../2025_data/annotated_nonRBPs_2025-01.csv')
nonRBPs_sub = nonRBPs.sample(n=10*RBPs.shape[0], random_state=42)
nonRBPs_sub = nonRBPs_sub.reset_index(drop=True)
months = ['OCT-2021', 'NOV-2021', 'DEC-2021']
to_delete_rbps = [i for i, date in enumerate(RBPs['RecordDate']) if any(x in date for x in months)]
rbps_upto2021 = RBPs.drop(to_delete_rbps)
rbps_upto2021 = rbps_upto2021.reset_index(drop=True)
to_delete_nonrbps = [i for i, date in enumerate(nonRBPs_sub['RecordDate']) if any(x in date for x in months)]
nonrbps_upto2021 = nonRBPs_sub.drop(to_delete_nonrbps)
nonrbps_upto2021 = nonrbps_upto2021.reset_index(drop=True)
RBPseqs = rbps_upto2021['ProteinSeq'].tolist()
nonRBPseqs = [seq[:2000] for seq in list(nonrbps_upto2021['ProteinSeq'])]
sequences = RBPseqs + nonRBPseqs
labels = [1]*rbps_upto2021.shape[0] + [0]*nonrbps_upto2021.shape[0]

# process the data
train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.1, stratify=labels, random_state=42)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

# define function for metric
metric = load("f1")
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

# finetune the model (takes around 1h per epoch on NVIDIA P100 GPU)
nlabels = len(set(labels))
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=nlabels)
batch_size = 2
args = TrainingArguments(
    'RBPdetect_ESM2finetune',
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

# save the model & tokenizer
model_path = path+'/RBPdetect_v3_ESMfine_2025'
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)



  from .autonotebook import tqdm as notebook_tqdm
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,F1
1,0.1285,0.092918,0.873862


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [3]:
# 2 - BENCHMARKING
# ------------------------------------------
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from evaluate import load
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split

GPU = 2
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)
device = torch.device("cuda")
path = "."

model_path = path+'/RBPdetect_v3_ESMfine_2025'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval().cuda() # or without cuda if not available

RBPs = pd.read_csv('../2025_data/annotated_RBPs_2025-01.csv')
nonRBPs = pd.read_csv('../2025_data/annotated_nonRBPs_2025-01.csv')
nonRBPs_sub = nonRBPs.sample(n=10*RBPs.shape[0], random_state=42)
nonRBPs_sub = nonRBPs_sub.reset_index(drop=True)
months = ['OCT-2021', 'NOV-2021', 'DEC-2021']
to_delete_rbps = [i for i, date in enumerate(RBPs['RecordDate']) if all(x not in date for x in months)]
rbps_2021 = RBPs.drop(to_delete_rbps)
rbps_2021 = rbps_2021.reset_index(drop=True)
to_delete_nonrbps = [i for i, date in enumerate(nonRBPs_sub['RecordDate']) if all(x not in date for x in months)]
nonrbps_2021 = nonRBPs_sub.drop(to_delete_nonrbps)
nonrbps_2021 = nonrbps_2021.reset_index(drop=True)
testdata = list(rbps_2021['ProteinSeq']) + list(nonrbps_2021['ProteinSeq'])
testlabels = [1]*rbps_2021.shape[0] + [0]*nonrbps_2021.shape[0]

predictions = []
scores = []
for sequence in tqdm(testdata):
    encoding = tokenizer(sequence, return_tensors="pt", truncation=True).to('cuda:0')#.to('mps:0') # or without cuda if not available
    with torch.no_grad():
        output = model(**encoding)
        predictions.append(int(output.logits.argmax(-1)))
        scores.append(float(output.logits.softmax(-1)[:, 1]))

# define function for metric
metric = load("f1")
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)
# Compute the F1 score
f1_score = metric.compute(predictions=predictions, references=testlabels)

# Print the F1 score
print(f"Benchmarking F1 Score: {f1_score['f1']:.4f}")

esm_results = pd.concat([pd.DataFrame(predictions, columns=['preds']), 
                        pd.DataFrame(scores, columns=['score'])], axis=1)
results_path = './RBP_detection'
esm_results.to_csv(results_path+'/esm_finetuneT33_test_predictions.csv', index=False)


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.      | 0/3204 [00:00<?, ?it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3204/3204 [00:35<00:00, 90.19it/s]


Benchmarking F1 Score: 0.9170


In [9]:
from sklearn.metrics import f1_score, matthews_corrcoef, confusion_matrix

# Compute F1 score
f1 = f1_score(testlabels, predictions)

# Compute MCC
mcc = matthews_corrcoef(testlabels, predictions)

# Compute Confusion Matrix
tn, fp, fn, tp = confusion_matrix(testlabels, predictions).ravel()

# Compute Sensitivity (Recall for Positive Class)
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0

# Compute Specificity (Recall for Negative Class)
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

# Print the results
print(f"Benchmarking Metrics:")
print(f"  - F1 Score: {f1:.4f}")
print(f"  - MCC Score: {mcc:.4f}")
print(f"  - Sensitivity (Recall+): {sensitivity:.4f}")
print(f"  - Specificity (Recall-): {specificity:.4f}")


Benchmarking Metrics:
  - F1 Score: 0.9170
  - MCC Score: 0.9027
  - Sensitivity (Recall+): 0.9384
  - Specificity (Recall-): 0.9813


In [20]:
print(f"{round(((len(set(testdata).intersection(sequences)))/len(set(testdata)))*100,2)}% intersection was found between the test data and the train data.")

0.87% intersection was found between the test data and the train data.
