# Import and load data

In [None]:
import re
import numpy as np
import pandas as pd
import copy

from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, EsmConfig, AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
import torch

from scipy import stats
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import os

In [None]:
from sklearn.model_selection import train_test_split

# Split the dataset into training and temporary sets (70% train, 30% temp)
# train_set, temp_set = train_test_split(combined, test_size=0.3, random_state=950806, stratify=combined['label'])

# Further split the temporary set into validation and test sets (50% validation, 50% test)
# valid_set, test_set = train_test_split(temp_set, test_size=0.5, random_state=950806, stratify=temp_set['label'])


In [None]:
# Read the pre-split train / validation / test tables (in that order).
train_set, valid_set, test_set = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/home/raylab/Zihao/BCR/ModelData/training_set.csv',
        '/home/raylab/Zihao/BCR/ModelData/valid_set.csv',
        '/home/raylab/Zihao/BCR/ModelData/test_set.csv',
    )
)


# Preprocessing and Load Model from Pretrained

In [None]:
# Pull the 'sequence' and 'label' columns out of each split as plain lists.
train_seqs, train_labels = (train_set[col].tolist() for col in ('sequence', 'label'))
valid_seqs, valid_labels = (valid_set[col].tolist() for col in ('sequence', 'label'))
test_seqs, test_labels = (test_set[col].tolist() for col in ('sequence', 'label'))

In [None]:
# Hub identifier of the pretrained checkpoint; swap in another model id here
# to fine-tune a different model.
model_name = "facebook/esm2_t6_8M_UR50D"
# Tokenizer that belongs to the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
)

In [None]:
# Tokenize each split's sequence list with the shared tokenizer
# (same call order as before: train, test, validation).
train_tokenized, test_tokenized, valid_tokenized = map(
    tokenizer,
    (train_seqs, test_seqs, valid_seqs),
)

In [None]:
from datasets import Dataset
# Wrap each tokenized split in a `datasets.Dataset`.
train_dataset, test_dataset, valid_dataset = (
    Dataset.from_dict(encoding)
    for encoding in (train_tokenized, test_tokenized, valid_tokenized)
)
# Leave the training dataset as the cell's last expression so it is displayed.
train_dataset

In [None]:
# Attach the target labels to every split under the "labels" column.
train_dataset = train_dataset.add_column(name="labels", column=train_labels)
test_dataset = test_dataset.add_column(name="labels", column=test_labels)
valid_dataset = valid_dataset.add_column(name="labels", column=valid_labels)

# Leave the labelled training dataset as the cell's last expression.
train_dataset

In [None]:
# Show the validation dataset's repr as this cell's output.
valid_dataset

In [None]:
# Load the pretrained checkpoint with a two-class sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)

# Hyperparameter and Training

In [None]:
# Raise the hidden-layer dropout used during fine-tuning.
# Dropout layers are created when the model is constructed, so editing the
# config of an already-built model does not change them. The model is
# therefore rebuilt from the edited config (which already carries the
# two-label setup), so the 0.3 rate is actually applied.
# NOTE: this re-creates the model from `model_name`; run it before training.
config = model.config
config.hidden_dropout_prob = 0.3
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
print(config)

In [None]:
# Fine-tuning configuration: evaluate and checkpoint once per epoch, then
# restore the checkpoint with the best validation accuracy when training ends.
training_args = TrainingArguments(
    output_dir="ESM2-finetuned-localization",
    # Optimisation
    learning_rate=1e-4,
    weight_decay=0.01,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluation / checkpointing
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

In [None]:
from evaluate import load
import numpy as np

# Accuracy metric object from the `evaluate` library.
metric = load("accuracy")


def compute_metrics(eval_pred):
    """Score an evaluation pass with the accuracy metric.

    `eval_pred` unpacks into (model outputs, reference labels); the
    predicted class of each row is the argmax over axis 1 of the outputs.
    Returns whatever `metric.compute` returns for those predictions.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

In [None]:
# Bundle model, arguments, data splits, tokenizer and metric function into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning with the Trainer built from training_args (per-epoch
# evaluation and checkpoint saving are configured there).
trainer.train()

# Test

In [None]:
# Run the trainer's prediction pass over the held-out test split.
predictions = trainer.predict(test_dataset)

In [None]:
import numpy as np
from scipy.special import softmax

# Row-wise softmax turns the raw model outputs into class probabilities.
probabilities = softmax(predictions.predictions, axis=1)

# The predicted class is the column with the highest probability in each row.
predicted_classes = probabilities.argmax(axis=1)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Reference labels returned alongside the test-set predictions.
true_labels = predictions.label_ids

# Score the hard predictions with each metric, in a fixed order.
accuracy, precision, recall, f1 = (
    scorer(true_labels, predicted_classes)
    for scorer in (accuracy_score, precision_score, recall_score, f1_score)
)

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

In [None]:
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report

# ROC AUC is scored on the class-1 probability column.
roc_auc = roc_auc_score(true_labels, probabilities[:, 1])
# Confusion matrix and per-class report are based on the hard predictions.
conf_matrix = confusion_matrix(true_labels, predicted_classes)
class_report = classification_report(true_labels, predicted_classes)

# Report everything in one place: matrix, per-class report, then ROC AUC.
print("Confusion Matrix:")
print(conf_matrix)
print("\nClassification Report:")
print(class_report)
print("\nROC AUC:", roc_auc)

In [None]:
from sklearn.metrics import RocCurveDisplay

# Plot the ROC curve from the true labels and the class-1 probabilities.
positive_class_scores = probabilities[:, 1]
RocCurveDisplay.from_predictions(true_labels, positive_class_scores)

# BCR Prediction

In [None]:
# Read the tab-separated heavy-chain BCR repertoire table.
heavy_chains = pd.read_table('BCR/ModelData/heavy_vdj_airr.tsv')

In [None]:
import torch
# Score every heavy-chain amino-acid sequence with the classifier and store
# the predicted label and class-1 probability back on the DataFrame.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# Switch to inference mode explicitly so dropout is disabled no matter which
# cell last touched the model; otherwise predictions may be stochastic.
model.eval()

# Number of sequences per forward pass; lower it if the device runs out of memory.
inference_batch_size = 64

# Accumulators, one entry per row of heavy_chains (in row order).
predicted_labels = []
predicted_probabilities = []

sequences = heavy_chains['sequence_aa'].tolist()
for batch_start in range(0, len(sequences), inference_batch_size):
    batch_sequences = sequences[batch_start:batch_start + inference_batch_size]

    # Tokenize the batch; pad to the longest sequence in the batch and
    # truncate anything longer than 256 tokens.
    encoded_batch = tokenizer(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        max_length=256,
        truncation=True,
    )

    # Move the input tensors to the same device as the model.
    input_ids = encoded_batch['input_ids'].to(device)
    attention_mask = encoded_batch['attention_mask'].to(device)

    # Forward pass without gradient tracking; the attention mask marks padding.
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    # Softmax over the class dimension; column 1 is the class-1 probability,
    # and the argmax over classes is the predicted label.
    batch_probabilities = torch.softmax(logits, dim=1)
    predicted_probabilities.extend(batch_probabilities[:, 1].tolist())
    predicted_labels.extend(torch.argmax(logits, dim=1).tolist())

# Add to df
heavy_chains['predicted_label'] = predicted_labels
heavy_chains['predicted_probability'] = predicted_probabilities

In [None]:
# Rank rows from highest to lowest predicted probability and renumber the
# index so it reflects the new order.
heavy_chains_sorted = (
    heavy_chains
    .sort_values(by='predicted_probability', ascending=False)
    .reset_index(drop=True)
)
# Show the 40 top-ranked rows.
heavy_chains_sorted.head(40)

In [None]:
# Histogram of the predicted probabilities (20 bins; adjust as needed).
probability_values = heavy_chains['predicted_probability']
plt.hist(probability_values, bins=20)
plt.title('Frequency Distribution of predicted_probability')
plt.xlabel('Predicted Probability')
plt.ylabel('Frequency')
plt.show()


In [None]:
# Draw the predicted-probability histogram again (20 bins; adjust as needed).
plt.hist(x=heavy_chains['predicted_probability'], bins=20)
plt.ylabel('Frequency')
plt.xlabel('Predicted Probability')
plt.title('Frequency Distribution of predicted_probability')
plt.show()


# Save Model and Results

In [None]:
# Keep only the rows whose predicted probability is at least 0.6.
candidates = heavy_chains.loc[heavy_chains['predicted_probability'] >= 0.6]

In [None]:
# Persist the model and its tokenizer side by side in one directory.
# NOTE(review): the directory name says "lr2e-4" but the TrainingArguments in
# this file use learning_rate=1e-4 — confirm the intended name.
save_dir = "ESM_fintuned/lr2e-4"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

In [None]:
# Write the filtered candidate rows (index included) to CSV.
candidates.to_csv(path_or_buf='BCR/esm_candidates_low_threshold.csv')

# Load Model

In [None]:
# Reload tokenizer and classifier from a saved training checkpoint; the
# config is loaded with output_hidden_states=True and passed to the model.
model_name = 'ESM2-finetuned-localization/checkpoint-23410'
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
)

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Put the model in evaluation mode and move it to the selected device.
model.eval().to(device)