### I. Load the data

In [None]:
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, accuracy_score
from transformers import AutoTokenizer
from datasets import Dataset
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer , AutoTokenizer

import torch 
from torch import nn 
from torch.utils.data import Dataset , DataLoader
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
from Bio import SeqIO
import os 
import pandas as pd 
import numpy as np
import warnings
import subprocess

# RuntimeWarnings are silenced for the whole session.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ***********************************************************
# Open the dataframes :
# Root working directory (the closing quote was missing here, which made the
# whole cell a SyntaxError).
path_work = "/home/conchae/PhageDepo_pdb"
# Directory holding the clustering outputs; also receives the cluster JSON dumps.
path_tmp = f"{path_work}/low_cutoff_cluster"
os.makedirs(path_tmp, exist_ok=True)

df_depo = pd.read_csv(f"{path_work}/Phagedepo.Dataset.21032024.tsv" , sep = "\t" , header = 0)
# Keep only the four fold classes handled by the label dictionaries below.
df_depo = df_depo[df_depo["Fold"].isin(["Negative", "right-handed beta-helix", "6-bladed beta-propeller", "triple-helix"])]
# One row per unique full sequence.
df_depo = df_depo.drop_duplicates(subset = ["Full_seq"], keep = "first")
# Re-number rows 0..n-1; the previous index is kept as an "index" column.
df_depo.reset_index(inplace = True)

path_models = f"{path_work}/script_files"
# Maps a clustering threshold to the fine-tuned token-classification checkpoint
# used for it.
dico_path_models = {0.3 : f"{path_models}/esm2_t12_35M_UR50D__0.3__finetuneddepolymerase.2205.4_labels/checkpoint-1185"}
                    #0.35 : f"{path_models}/esm2_t12_35M_UR50D__0.7__finetuneddepolymerase.2103.4_labels/checkpoint-1945",
                    #0.4 : f"{path_models}/esm2_t12_35M_UR50D__0.75__finetuneddepolymerase.2103.4_labels/checkpoint-1995",
                    #0.45 : f"{path_models}/esm2_t12_35M_UR50D__0.8__finetuneddepolymerase.2103.4_labels/checkpoint-1980",
                    #0.5 : f"{path_models}/esm2_t12_35M_UR50D__0.85__finetuneddepolymerase.2103.4_labels/checkpoint-1990"}


In [None]:
# ***********************************************************
# Clustering thresholds to process. Each value is used both as a key of
# dico_path_models and to build the "{threshold}__psicdhit.out" file name.
thresholds = [0.3]

def make_cluster_dico(cdhit_out) :
    """Parse a CD-HIT style ``.clstr`` file into ``{cluster index: [member ids]}``.

    Parameters
    ----------
    cdhit_out : str
        Path of the clustering output *without* the ``.clstr`` suffix; the
        file actually read is ``f"{cdhit_out}.clstr"``.

    Returns
    -------
    tuple
        ``(dico_cluster, threshold)`` where ``dico_cluster`` maps the 0-based
        cluster position in the file to the list of member ids (as strings),
        and ``threshold`` is the basename of ``cdhit_out`` up to ``".out"``.

    Side effects
    ------------
    Writes the mapping as JSON to
    ``{path_tmp}/dico_cluster.psicdhit__{threshold}.json`` (``path_tmp`` is a
    module-level variable).
    """
    import json

    threshold = cdhit_out.split("/")[-1].split(".out")[0]
    cluster_file = f"{cdhit_out}.clstr"
    # Context manager so the file handle is closed deterministically.
    with open(cluster_file) as handle :
        cluster_out = handle.read().split(">Cluster")

    dico_cluster = {}
    # cluster_out[0] is whatever precedes the first ">Cluster" marker.
    for index, cluster in enumerate(cluster_out[1:]) :
        # The first line of each chunk is the remainder of the ">Cluster N"
        # header. Member lines are recognised by their ">" marker rather than
        # by position, so the last member is kept even when the file does not
        # end with a newline, and blank lines are ignored.
        dico_cluster[index] = [
            line.split(">")[1].split(".")[0]
            for line in cluster.split("\n")[1:]
            if ">" in line
        ]

    with open(f"{path_tmp}/dico_cluster.psicdhit__{threshold}.json", "w") as outfile:
        json.dump(dico_cluster, outfile)
    return dico_cluster , threshold


def reverse_dico(dico) :
    """Invert a ``{key: [members]}`` mapping into ``{member: key}``.

    When a member appears under several keys, the key encountered last
    (in dictionary iteration order) wins.
    """
    return {
        member: cluster_key
        for cluster_key, members in dico.items()
        for member in members
    }


def make_list_group(list_seq, r_dico, id_dico) :
    """Return, for every sequence of ``list_seq``, the group it belongs to.

    Each sequence is first translated to its id through ``id_dico``; the id is
    converted to ``str`` and then looked up in ``r_dico`` (member id -> group).
    A ``KeyError`` is raised when either lookup fails.
    """
    return [r_dico[str(id_dico[sequence])] for sequence in list_seq]


def cvalue_to_list_group(threshold, df_depo) :
    """Return the cluster id of every ``Full_seq`` row of ``df_depo``.

    The clusters are read from ``{path_tmp}/{threshold}__psicdhit.out.clstr``
    through ``make_cluster_dico`` (which also rewrites its JSON dump on each
    call). Sequences are translated to ids with the module-level
    ``dico_seq_id`` mapping.
    """
    cdhit_prefix = f"{path_tmp}/{threshold}__psicdhit.out"
    clusters = make_cluster_dico(cdhit_prefix)[0]
    member_to_cluster = reverse_dico(clusters)
    sequences = df_depo["Full_seq"].tolist()
    return make_list_group(sequences, member_to_cluster, dico_seq_id)


def get_labels(tuple_data ) :
    """Build per-residue label lists for ``(sequence, boundaries, fold)`` rows.

    For each row, ``row[0]`` is the sequence, ``row[1]`` the boundary
    annotation and ``row[2]`` the fold name. The fold is mapped to an integer
    label (0 = Negative, 1-3 = the three folds). Depending on the annotation:

    * ``"Negative"``, ``"full_protein"`` or ``"full"``: every position gets
      the fold label;
    * ``"start:end"``: positions in ``[start, end)`` get the fold label, the
      others 0;
    * anything else: the last two ``"_"``-separated fields are read as
      ``start`` and ``end`` and used the same way.

    Returns a list with one label list per row.
    """
    fold_to_label = {"Negative" : 0,
                     "right-handed beta-helix" : 1,
                     "6-bladed beta-propeller" : 2,
                     "triple-helix" : 3}
    whole_sequence_tags = ("Negative", "full_protein", "full")

    per_residue_labels = []
    for row in tuple_data :
        annotation = row[1]
        n_residues = len(row[0])
        fold_label = fold_to_label[row[2]]

        if annotation in whole_sequence_tags :
            per_residue_labels.append([fold_label] * n_residues)
            continue

        # Domain annotated as a sub-range of the sequence.
        if annotation.count(":") :
            fields = annotation.split(":")
            start, end = int(fields[0]), int(fields[1])
        else :
            fields = annotation.split("_")
            start, end = int(fields[-2]), int(fields[-1])
        per_residue_labels.append(
            [fold_label if start <= position < end else 0 for position in range(n_residues)]
        )
    return per_residue_labels

def get_labels_seq(df) :
    """Return one binary label per row: 0 when the fold is ``"Negative"``, else 1.

    ``df`` is an iterable of rows whose third element (``row[2]``) is the fold
    name; the other elements are not read.
    """
    # The sequence length was previously computed here but never used.
    return [0 if row[2] == "Negative" else 1 for row in df]

def training_data(threshold):
    """Build the sequence-classification train and evaluation sets.

    The module-level ``df_depo`` is split twice with cluster-aware
    ``GroupShuffleSplit`` (groups come from ``cvalue_to_list_group``):

    1. 70 % / 30 % -- only the 30 % part is used in this function;
    2. that 30 % part is split again 66 % / 34 % into the training set of the
       sequence classifier and the evaluation set.

    Returns
    -------
    tuple
        ``(train_cnv_seq, training_data_cnv_labels, eval_seq,
        eval_data_token_labels)``: the two sequence Series and their binary
        labels (0 = Negative fold, 1 = any other fold).
    """
    # First split : keep only the held-out part.
    first_splitter = GroupShuffleSplit(n_splits=1, train_size=0.7, test_size = 0.3, random_state=243)
    groups_full = cvalue_to_list_group(threshold, df_depo)
    _, held_out_idx = next(first_splitter.split(df_depo["Full_seq"], df_depo["Fold"], groups_full))

    # Intermediate DF, re-indexed from 0 so that positions and labels coincide.
    df_depo_s2 = df_depo[df_depo.index.isin(held_out_idx)].reset_index()

    # Second split :
    second_splitter = GroupShuffleSplit(n_splits=1, train_size=0.66, test_size = 0.34, random_state=243)
    groups_s2 = cvalue_to_list_group(threshold, df_depo_s2)
    cnv_idx, eval_idx = next(second_splitter.split(df_depo_s2["Full_seq"], df_depo_s2["Fold"], groups_s2))

    # Sequence classification data :
    train_cnv_seq = df_depo_s2["Full_seq"][cnv_idx]
    cnv_rows = tuple(zip(train_cnv_seq,
                         df_depo_s2["Boundaries"][cnv_idx],
                         df_depo_s2["Fold"][cnv_idx]))
    training_data_cnv_labels = get_labels_seq(cnv_rows)

    # Evaluation data :
    eval_seq = df_depo_s2["Full_seq"][eval_idx]
    eval_rows = tuple(zip(eval_seq,
                          df_depo_s2["Boundaries"][eval_idx],
                          df_depo_s2["Fold"][eval_idx]))
    eval_data_token_labels = get_labels_seq(eval_rows)

    return train_cnv_seq, training_data_cnv_labels ,eval_seq , eval_data_token_labels

def data_to_tensor(train_cnv_seq, training_data_cnv_labels ,eval_seq , eval_data_token_labels) :
    """Wrap the train and evaluation sequences/labels into two ``DataLoader`` objects.

    Each (sequences, labels) pair is put in a two-column DataFrame
    (``"sequence"``, ``"Label"``), wrapped in a ``Dpo_Dataset`` and served in
    shuffled batches of 12 with 4 worker processes.

    Returns ``(train_single_loader, test_single_loader)``.
    """
    def _as_loader(sequences, labels) :
        frame = pd.DataFrame({"sequence" : list(sequences) , "Label" : list(labels)})
        return DataLoader(Dpo_Dataset(frame), batch_size=12, shuffle=True, num_workers=4)

    train_single_loader = _as_loader(train_cnv_seq, training_data_cnv_labels)
    test_single_loader = _as_loader(eval_seq, eval_data_token_labels)
    return train_single_loader, test_single_loader

class Dpo_Dataset(Dataset):
    """Map-style dataset over a DataFrame with ``sequence`` and ``Label`` columns.

    Indexing returns ``(sequence, label)`` where ``label`` is an element of a
    ``torch.long`` tensor built from the ``Label`` column.
    """

    def __init__(self, Dataset_df):
        # Raw sequences are kept as a numpy array of Python objects.
        self.sequence = Dataset_df["sequence"].values
        label_values = Dataset_df["Label"].values
        self.labels = torch.tensor(label_values, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.sequence[idx], self.labels[idx]

class Dpo_classifier(nn.Module):
    """Binary sequence classifier stacked on a token-classification model.

    Each input sequence is tokenized and passed through ``pretrained_model``;
    the per-token arg-max class ids are padded/truncated to ``max_length`` and
    fed, as a single-channel float signal, to two 1-D convolutions followed by
    two linear layers producing one logit per sequence.
    """

    def __init__(self, pretrained_model, tokenizer):
        super().__init__()
        self.max_length = 1024
        self.pretrained_model = pretrained_model
        self.tokenizer = tokenizer
        kernel = 5
        self.conv1 = nn.Conv1d(1, 64, kernel_size=kernel, stride=1)  # Convolutional layer
        self.conv2 = nn.Conv1d(64, 128, kernel_size=kernel, stride=1)  # Convolutional layer
        # Each un-padded convolution shortens the signal by (kernel - 1).
        conv_out_length = self.max_length - 2 * (kernel - 1)
        self.fc1 = nn.Linear(128 * conv_out_length, 32)
        self.classifier = nn.Linear(32, 1)  # Binary classification

    def make_prediction(self,fasta_txt):
        """Return the predicted class id of every token as a ``(1, n_tokens)`` tensor."""
        input_ids = self.tokenizer.encode(fasta_txt, truncation=True, return_tensors='pt')
        # The token-level model is only used for inference: no gradient is tracked.
        with torch.no_grad():
            logits = self.pretrained_model(input_ids).logits
            probabilities = F.softmax(logits, dim=-1)
            _, class_ids = probabilities.max(dim=-1)
        return class_ids.view(1, -1)  # ensure 2D shape

    def pad_or_truncate(self, tokens):
        """Right-pad with zeros, or cut, ``tokens`` so that dim 1 equals ``max_length``."""
        missing = self.max_length - tokens.size(1)
        if missing > 0:
            return F.pad(tokens, (0, missing))
        if missing < 0:
            return tokens[:, :self.max_length]
        return tokens

    def forward(self, sequences):
        """Return ``(logits, features)`` for a batch of raw sequences.

        ``logits`` has shape ``(batch, 1)``; ``features`` is the float tensor
        of shape ``(batch, 1, max_length)`` holding the token class ids that
        were fed to the convolutions.
        """
        batch_size = len(sequences)
        per_sequence = [
            self.pad_or_truncate(self.make_prediction(sequence))
            for sequence in sequences
        ]
        # ensure 3D shape : (batch, channel, length)
        outputs = torch.cat(per_sequence).view(batch_size, 1, self.max_length).float()

        hidden = F.relu(self.conv1(outputs))
        hidden = F.relu(self.conv2(hidden))
        hidden = hidden.view(batch_size, -1)  # Flatten the tensor
        hidden = F.relu(self.fc1(hidden))
        return self.classifier(hidden), outputs



def train_cnv(c_value, train_single_loader, test_single_loader) :
    """Train and evaluate a ``Dpo_classifier`` for the given threshold.

    Parameters
    ----------
    c_value :
        Key of ``dico_path_models`` selecting the fine-tuned token-classification
        checkpoint to build the classifier on.
    train_single_loader, test_single_loader :
        Loaders yielding ``(sequences, labels)`` batches, with ``labels`` a
        1-D integer tensor of 0/1 values.

    For each of the 5 epochs the function prints the training loss/accuracy,
    then the accuracy, precision, recall and F1-score on the test loader, and
    saves the classifier ``state_dict`` (the same file is overwritten at every
    epoch). Returns ``None``.
    """
    # get model
    model_path = dico_path_models[c_value]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    # Initialize model
    model_classifier = Dpo_classifier(model,tokenizer)
    optimizer = optim.Adam(model_classifier.parameters(), lr=0.001)
    criterion = nn.BCEWithLogitsLoss()
    epochs = 5
    # Training loop
    for epoch in range(epochs):
        model_classifier.train()
        epoch_loss = 0
        epoch_correct = 0
        total_samples = 0
        for sequences, labels in train_single_loader:
            # Zero the parameter gradients
            optimizer.zero_grad()
            # Forward pass
            outputs, _ = model_classifier(sequences)
            # Flatten (batch, 1) -> (batch,) once, for both the loss and the accuracy.
            logits = outputs.view(-1)
            loss = criterion(logits, labels.float())
            loss.backward()
            optimizer.step()
            # A positive logit means a probability above 0.5.
            predicted = (logits > 0).long()
            # Compute accuracy. `predicted` must have the same shape as `labels`:
            # comparing a (batch, 1) tensor with a (batch,) tensor broadcasts to
            # (batch, batch) and counts spurious matches.
            total_samples += labels.size(0)
            epoch_correct += (predicted == labels).sum().item()
            # Accumulate loss
            epoch_loss += loss.item()
        print(f'Epoch {epoch + 1}, Training Loss: {epoch_loss / len(train_single_loader):.4f}, Training Accuracy: {epoch_correct / total_samples:.4f}')
        # Evaluation
        model_classifier.eval()
        y_true = []
        y_pred = []
        with torch.no_grad():
            for sequences, labels in test_single_loader:
                outputs, _ = model_classifier(sequences)
                # Flattened so that y_pred ends up 1-D, like y_true.
                predicted = (outputs.view(-1) > 0).long()
                y_true.extend(labels.numpy())
                y_pred.extend(predicted.numpy())
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        # Calculate metrics
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        print(f'Testing Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}')
        # Checkpoint after every epoch (overwrites the previous one).
        torch.save(model_classifier.state_dict(), f"{path_work}/Deposcope__{c_value}__.esm2_t12_35M_UR50D.2205.review.model")


In [None]:
# ***********************************************************
# Sequence -> id lookup : each distinct sequence is mapped to the position of
# its first occurrence in df_depo["Full_seq"].
# NOTE(review): these ids presumably match the record ids of the multifasta
# that was clustered -- confirm against the clustering input.
dico_seq_id = {}
for position, sequence in enumerate(df_depo["Full_seq"].tolist()) :
    dico_seq_id.setdefault(sequence, position)

In [None]:
def full_training(c_value) :
    """Run the whole pipeline for one threshold: split the data, build the loaders, train.

    Parameters
    ----------
    c_value :
        Clustering threshold; must be a key of ``dico_path_models``.

    Raises
    ------
    KeyError
        If no checkpoint is registered for ``c_value``.
    """
    # Fail fast on an unknown threshold, before any data preparation.
    # The tokenizer and model themselves are loaded by train_cnv; loading them
    # here as well only duplicated that work.
    if c_value not in dico_path_models :
        raise KeyError(c_value)
    # get training data :
    train_cnv_seq, training_data_cnv_labels ,eval_seq , eval_data_token_labels = training_data(c_value)
    train_single_loader, test_single_loader = data_to_tensor(train_cnv_seq, training_data_cnv_labels ,eval_seq , eval_data_token_labels)
    train_cnv(c_value, train_single_loader, test_single_loader)

# Train one classifier per configured threshold.
for c_value in thresholds :
    full_training(c_value)


In [None]:
#!/bin/bash
# SLURM submission script for the training script below.
# Every directive must start with "#SBATCH"; the job-name line was written
# "#BATCH" and was therefore ignored by sbatch.
#SBATCH --job-name=cnv_review
#SBATCH --qos=medium 
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=30
#SBATCH --mem=40gb 
#SBATCH --time=2-00:00:00 
#SBATCH --output=cnv_review%j.log 

module restore la_base
conda activate embeddings

python /home/conchae/PhageDepo_pdb/script_files/cnv_training.review_v2.py