# The Depo Detection tool
>We finetuned the ESM2 model successfully (92% accuracy)<br>
>The goal now is to stack a classification head on top of it (implemented below with 1D convolutional layers rather than an RNN) for binary classification into Dpo or Not Dpo categories
***
## I. Load prebuilt model 
## II. Stack RNN layer
## III. Train Eval
## IV. Metrics
***

### I. Load the data

In [1]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, accuracy_score
from transformers import AutoTokenizer
from datasets import Dataset
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer , AutoTokenizer

import torch 
from torch import nn 
from torch.utils.data import Dataset , DataLoader
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
from Bio import SeqIO
import os 
import pandas as pd 
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning) 

# Locate the fine-tuned ESM2 checkpoint under the project working directory.
path_work = "/home/conchae/PhageDepo_pdb"
model_path = (
    f"{path_work}/script_files/esm2_t30_150M_UR50D-finetuned-depolymerase.1608.3_labels.final/checkpoint-6015"
)

# Both the token-classification model and its tokenizer are restored from the
# same checkpoint directory.
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)


  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at /media/concha-eloko/Linux/depolymerase_building/esm2_t12_35M_UR50D-finetuned-depolymerase/checkpoint-198/ were not used when initializing EsmForTokenClassification: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
- This IS expected if you are initializing EsmForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def get_labels(df):
    """Derive one binary label per row of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"Boundaries"`` column.

    Returns
    -------
    list of int
        ``0`` for rows whose ``"Boundaries"`` value equals the string
        ``"Negative"``, ``1`` for every other row, in row order. An empty
        frame yields an empty list.
    """
    # Only the "Boundaries" column decides the label, so no other column
    # (e.g. the sequence itself) is read here.
    return [0 if info == "Negative" else 1 for info in df["Boundaries"]]

In [1]:
# Read the depolymerase table (tab-separated, first row holds the column names).
df_depo = pd.read_csv(f"{path_work}/Phagedepo.Dataset.2007.tsv", sep="\t", header=0)

# Split the table into one sub-frame per value of the "Fold" column.
fold_column = df_depo["Fold"]
df_beta_helix = df_depo[fold_column == "right-handed beta-helix"]
df_beta_prope = df_depo[fold_column == "6-bladed beta-propeller"]
df_beta_triple = df_depo[fold_column == "triple-helix"]
df_negative = df_depo[fold_column == "Negative"]

# Sequences and labels for each fold-specific sub-frame.
# Beta-helix
seq_beta_helix = df_beta_helix["Full_seq"].tolist()
labels_beta_helix = get_labels(df_beta_helix)

# Beta-propeller
seq_beta_propeller = df_beta_prope["Full_seq"].tolist()
labels_beta_propeller = get_labels(df_beta_prope)

# Triple helix
seq_triple_helix = df_beta_triple["Full_seq"].tolist()
labels_triple_helix = get_labels(df_beta_triple)

# Negative
seq_negative = df_negative["Full_seq"].tolist()
labels_negative = get_labels(df_negative)

# Phage proteins associated with PL16: every record of the FASTA file is
# labelled as positive (1).
pl16_interpro = SeqIO.parse(f"{path_work}/PL_16.phage_proteins.fasta", "fasta")
seq_pl16_interpro = []
for record in pl16_interpro:
    seq_pl16_interpro.append(str(record.seq))
labels_pl16 = [1 for _ in seq_pl16_interpro]

In [3]:
# Pool the fold-specific sequences and their labels (same order in both lists).
sequences = [*seq_beta_helix, *seq_beta_propeller, *seq_triple_helix, *seq_negative]
labels = [*labels_beta_helix, *labels_beta_propeller, *labels_triple_helix, *labels_negative]

# First split: 80 % train / 20 % test.
train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    sequences,
    labels,
    test_size=0.2,
    random_state=243,
)
# Second split of the training part: 75 % (train_esm2) / 25 % (train_CNV).
# Only the train_CNV part is used to build the training frame below.
train_esm2, train_CNV, esm2_labels, CNV_labels = train_test_split(
    train_sequences,
    train_labels,
    test_size=0.25,
    random_state=243,
)

# The PL16 sequences are shared half/half between training and testing.
train_sequences_PL16, test_sequences_PL16, train_labels_PL16, test_labels_PL16 = train_test_split(
    seq_pl16_interpro,
    labels_pl16,
    test_size=0.5,
    random_state=243,
)

# Append the PL16 halves to the corresponding sets.
train_CNV = [*train_CNV, *train_sequences_PL16]
CNV_labels = [*CNV_labels, *train_labels_PL16]

test_sequences = [*test_sequences, *test_sequences_PL16]
test_labels = [*test_labels, *test_labels_PL16]

# Final (sequence, label) tables consumed by the dataset class.
Dataset_train_df = pd.DataFrame({"sequence": train_CNV, "Label": CNV_labels})
Dataset_test_df = pd.DataFrame({"sequence": test_sequences, "Label": test_labels})

In [7]:
#********************************************
class Dpo_Dataset(Dataset):
    """Map-style dataset pairing each sequence with its class label.

    Parameters
    ----------
    Dataset_df : pandas.DataFrame
        Frame with a ``sequence`` column (returned as-is) and a ``Label``
        column (converted once to a ``torch.long`` tensor).

    Indexing returns the tuple ``(sequence, label)`` where ``label`` is a
    0-dimensional ``torch.long`` tensor.
    """

    def __init__(self, Dataset_df):
        self.sequence = Dataset_df["sequence"].values
        # Labels are stored exactly as found in the frame (no re-indexing).
        self.labels = torch.tensor(Dataset_df["Label"].values, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.sequence[idx], self.labels[idx]

In [5]:
# Wrap the train / test frames into datasets.
train_singledata = Dpo_Dataset(Dataset_train_df)
test_singledata = Dpo_Dataset(Dataset_test_df)

# Training batches are reshuffled at every epoch.
train_single_loader = DataLoader(train_singledata, batch_size=12, shuffle=True, num_workers=4)
# Evaluation does not need shuffling: keeping a fixed order makes the
# per-sample predictions reproducible and comparable across epochs.
test_single_loader = DataLoader(test_singledata, batch_size=12, shuffle=False, num_workers=4)


In [6]:
class Dpo_classifier(nn.Module):
    """Binary classifier stacked on top of a token-classification model.

    For every input sequence the pretrained model predicts one class id per
    token; the resulting id track is padded/truncated to ``max_length`` and
    fed to two 1D convolutions followed by two linear layers that output a
    single logit per sequence.

    Parameters
    ----------
    pretrained_model : torch.nn.Module
        Called as ``pretrained_model(input_ids)``; its result must expose a
        ``logits`` attribute whose last dimension indexes the token classes.
        It is only used for inference (under ``torch.no_grad()``).
    tokenizer : optional
        Object providing ``encode(text, truncation=True, return_tensors='pt')``.
        When ``None`` (default) the module-level ``tokenizer`` is used, which
        is the behaviour of the original implementation.
    max_length : int, optional
        Length the per-token predictions are padded/truncated to
        (default 1024).
    """

    def __init__(self, pretrained_model, tokenizer=None, max_length=1024):
        super(Dpo_classifier, self).__init__()
        kernel_size = 5
        self.max_length = max_length
        self.pretrained_model = pretrained_model
        self.tokenizer = tokenizer
        self.conv1 = nn.Conv1d(1, 64, kernel_size=kernel_size, stride=1)  # Convolutional layer
        self.conv2 = nn.Conv1d(64, 128, kernel_size=kernel_size, stride=1)  # Convolutional layer
        # Each un-padded, stride-1 convolution shortens the signal by (kernel_size - 1).
        conv_out_length = self.max_length - 2 * (kernel_size - 1)
        self.fc1 = nn.Linear(128 * conv_out_length, 32)
        self.classifier = nn.Linear(32, 1)  # Binary classification (single logit)

    def train(self, mode=True):
        """Switch the trainable head to train/eval mode.

        The pretrained model is only used as a frozen feature extractor, so
        it is always kept in eval mode: otherwise calling ``.train()`` on this
        module would also enable any dropout layers it contains while the
        token predictions are computed.
        """
        super().train(mode)
        self.pretrained_model.eval()
        return self

    def make_prediction(self, fasta_txt):
        """Return the predicted class id of every token as a ``(1, n_tokens)`` tensor."""
        # Fall back on the notebook-level tokenizer when none was injected.
        active_tokenizer = self.tokenizer if self.tokenizer is not None else tokenizer
        input_ids = active_tokenizer.encode(fasta_txt, truncation=True, return_tensors='pt')
        with torch.no_grad():
            outputs = self.pretrained_model(input_ids)
            # Softmax is monotonic, so the arg-max of the logits is the
            # arg-max of the class probabilities.
            token_ids = outputs.logits.argmax(dim=-1)
        return token_ids.view(1, -1)  # ensure 2D shape

    def pad_or_truncate(self, tokens):
        """Right-pad with zeros, or cut, ``tokens`` (2D) to ``max_length`` columns."""
        n_tokens = tokens.size(1)
        if n_tokens < self.max_length:
            tokens = F.pad(tokens, (0, self.max_length - n_tokens))
        elif n_tokens > self.max_length:
            tokens = tokens[:, :self.max_length]
        return tokens

    def forward(self, sequences):
        """Classify a batch of raw sequences.

        Parameters
        ----------
        sequences : sequence of str
            Raw protein sequences.

        Returns
        -------
        out : torch.Tensor
            Logits of shape ``(batch, 1)``.
        outputs : torch.Tensor
            Float tensor of shape ``(batch, 1, max_length)`` holding the
            padded per-token class ids given to the convolutions.
        """
        batch_size = len(sequences)
        tokens_batch = [self.pad_or_truncate(self.make_prediction(seq)) for seq in sequences]

        outputs = torch.cat(tokens_batch).view(batch_size, 1, self.max_length)  # ensure 3D shape
        outputs = outputs.float()

        out = F.relu(self.conv1(outputs))
        out = F.relu(self.conv2(out))
        out = out.view(batch_size, -1)  # Flatten the tensor
        out = F.relu(self.fc1(out))
        out = self.classifier(out)
        return out, outputs

In [7]:
# Initialize model
model_classifier = Dpo_classifier(model)
model_classifier.train()

optimizer = optim.Adam(model_classifier.parameters(), lr=0.001)  # Set learning rate
criterion = nn.BCEWithLogitsLoss()  # Expects raw logits and float targets

epochs = 10  # Number of training epochs

# Training loop: one pass over the training loader followed by an evaluation
# on the test loader at every epoch.
for epoch in range(epochs):
    # ---- Training ----
    model_classifier.train()
    epoch_loss = 0
    epoch_correct = 0
    total_samples = 0
    # Dedicated batch names so the loop does not overwrite the notebook-level
    # `sequences` / `labels` lists built when assembling the dataset.
    for batch_sequences, batch_labels in train_single_loader:
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass: the classifier returns logits of shape (batch, 1);
        # flatten them to (batch,) so they line up with the labels.
        outputs, _ = model_classifier(batch_sequences)
        logits = outputs.view(-1)
        loss = criterion(logits, batch_labels.float())  # Convert labels to float
        loss.backward()
        optimizer.step()
        # A positive logit means a probability above 0.5.
        predicted = (logits > 0).long()
        # Compute accuracy. Both tensors are 1D here: comparing a (batch, 1)
        # tensor with a (batch,) tensor would broadcast to (batch, batch) and
        # over-count the correct predictions.
        total_samples += batch_labels.size(0)
        epoch_correct += (predicted == batch_labels).sum().item()
        # Accumulate loss
        epoch_loss += loss.item()
    print(f'Epoch {epoch + 1}, Training Loss: {epoch_loss / len(train_single_loader):.4f}, Training Accuracy: {epoch_correct / total_samples:.4f}')

    # ---- Evaluation ----
    model_classifier.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch_sequences, batch_labels in test_single_loader:
            outputs, _ = model_classifier(batch_sequences)
            predicted = (outputs.view(-1) > 0).long()
            y_true.extend(batch_labels.tolist())
            y_pred.extend(predicted.tolist())
    # 1D integer arrays, one entry per test sample.
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    print(f'Testing Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}')


('MALISQSIKNLKGGISQQPDILRYPDQGSRQVNGWSSETEGLQKRPPMVFIKTLGDRGALGQAPYIHLINRDENEQYYAVFTGNGIRVFDLAGNEKQVRYPNGSDYIKTSNPRNDLRMVTVADYTFVVNRNVAVQKNTTSVNLPNYNPKRDGLINVRGGQYGRELIVHINGKDVAKYKIPDGSQPAHVNNTDAQWLAEELAKQMRTNLSGWAVNVGQGFIHVAAPSGQQIDSFTTKDGYADQLINPVTHYAQSFSKLPPNAPNGYMVKVVGDASRSADQYYVRYDAERKVWVETLGWNTENQVRWETMPHALVRAADGNFDFKWLEWSPKSCGDIDTNPWPSFVGSSINDVFFFRNRLGFLSGENIILSRTAKYFNFYPASVANLSDDDPIDVAVSTNRISVLKYAVPFSEELLIWSDEAQFVLTASGTLTSKSVELNLTTQFDVQDRARPYGIGRNVYFASPRSSYTSIHRYYAVQDVSSVKNAEDITAHVPNYIPNGVFSICGSGTENFCSVLSHGDPSKIFMYKFLYLNEELRQQSWSHWDFGANVQVLACQSISSDMYVILRNEFNTFLTKISFTKNAIDLQGEPYRAFMDMKIRYTIPSGTYNDDTYNTSIHLPTIYGANFGRGRITVLEPDGKITVFEQPTAGWKSDPWLRLDGNLEGRMVYIGFNIDFVYEFSKFLIKQTADDGSSSTEDIGRLQLRRAWVNYENSGAFDIYVENQSSNWKYSMAGARLGSNTLRAGRLNLGTGQYRFPVVGNAKFNTVSILSDETTPLNIIGCGWEGNYLRRSSGI',
 tensor(1))

In [None]:
# Training is over: report it, then write the classifier's state dict to disk.
print('Finished Training')

trained_state = model_classifier.state_dict()
torch.save(
    trained_state,
    f"{path_work}/DepoDetection.esm2_t30_150M_UR50D.1608.final.model",
)
