In [2]:
from datasets import Dataset
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from transformers import AutoModelForTokenClassification,AutoTokenizer, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
from evaluate import load


import numpy as np
import pandas as pd
import os 
from Bio import SeqIO
from collections import Counter
import subprocess


In [None]:
# ***********************************************************
# Server-side setup: working folders, dataset and tokenizer
path_work = "/home/conchae/PhageDepo_pdb"
path_tmp = f"{path_work}/low_cutoff_cluster"
os.makedirs(path_tmp, exist_ok=True)

# Read the dataset, keep only the negatives and the three folds of interest,
# keep the first occurrence of every distinct full sequence, then renumber the
# rows from 0 (the previous row labels are kept as an extra column).
df_depo = (
    pd.read_csv(f"{path_work}/Phagedepo.Dataset.21032024.tsv", sep="\t", header=0)
    .loc[lambda frame: frame["Fold"].isin(["Negative", "right-handed beta-helix", "6-bladed beta-propeller", "triple-helix"])]
    .drop_duplicates(subset=["Full_seq"], keep="first")
    .reset_index()
)

# Checkpoint shared by the tokenizer here and the model loaded further down
model_checkpoint = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


***
### Make the clusters locally with psi-cd-hit : 

> Copy the dataset from the server to the local machine

In [None]:
# Shell command (not Python): copy the dataset TSV from the server into the
# local working folder over ssh.
rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/PhageDepo_pdb/Phagedepo.Dataset.21032024.tsv \
/media/concha-eloko/Linux/PhageDEPOdetection



In [5]:
# Local setup: same dataset preparation as on the server, with local paths
path_work = "/media/concha-eloko/Linux/PhageDEPOdetection"
path_tmp = f"{path_work}/tmp"
os.makedirs(path_tmp, exist_ok=True)

# Read the dataset, keep only the negatives and the three folds of interest,
# keep the first occurrence of every distinct full sequence, then renumber the
# rows from 0 (the previous row labels are kept as an extra column).
df_depo = (
    pd.read_csv(f"{path_work}/Phagedepo.Dataset.21032024.tsv", sep="\t", header=0)
    .loc[lambda frame: frame["Fold"].isin(["Negative", "right-handed beta-helix", "6-bladed beta-propeller", "triple-helix"])]
    .drop_duplicates(subset=["Full_seq"], keep="first")
    .reset_index()
)

In [17]:
# 
# Build the clusters on the local machine.

# Folder that contains the psi-cd-hit.pl script of the local installation
path_tool = "/media/concha-eloko/Linux/conda_envs/blast_life/cdhit/psi-cd-hit"
# Values passed to psi-cd-hit's -c option, one clustering run per value
thresholds = [0.3, 0.35,0.4, 0.45, 0.5]

def make_cdhit_cluster(threshold) :
    """Run the local psi-cd-hit.pl on the training FASTA for one threshold.

    The command reads ``{path_tmp}/training_sequences.fasta`` and writes
    ``{path_tmp}/{threshold}__psicdhit.out`` (psi-cd-hit options: ``-c
    threshold -G 0 -aL 0.8``).

    Parameters
    ----------
    threshold : value inserted after ``-c`` and used as output-file prefix.

    Returns
    -------
    None. The tool's combined stdout/stderr is printed as decoded text, and a
    warning line is printed when the command exits with a non-zero status.
    """
    cdhit_command = f"{path_tool}/psi-cd-hit.pl -i {path_tmp}/training_sequences.fasta -o {path_tmp}/{threshold}__psicdhit.out -c {threshold} -G 0 -aL 0.8"
    # stderr is merged into stdout, so there is a single stream to report.
    cdhit_process = subprocess.run(cdhit_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    # Decode instead of printing the bytes repr; the tool may emit localized
    # (non-ASCII) messages, hence errors="replace".
    print(cdhit_process.stdout.decode("utf-8", errors="replace"))
    if cdhit_process.returncode != 0 :
        print(f"WARNING: psi-cd-hit exited with status {cdhit_process.returncode} for threshold {threshold}")


In [18]:
# ***********************************************************
# Write the multi-FASTA (header = row position) and remember, for every
# sequence, the first position at which it was seen.
dico_seq_id = {}
with open(f"{path_tmp}/training_sequences.fasta", "w") as outfile :
    for index, seq in enumerate(df_depo["Full_seq"].tolist()) :
        print(f">{index}", seq, sep="\n", file=outfile)
        dico_seq_id.setdefault(seq, index)

# One psi-cd-hit run per threshold
for c_value in thresholds :
    make_cdhit_cluster(c_value)

b'BLAST version:\nblastp: 2.9.0+\n Package: blast 2.9.0, build Dec  4 2019 10:13:08\n\n\nmkdir: impossible de cr\xc3\xa9er le r\xc3\xa9pertoire \xc2\xab/media/concha-eloko/Linux/PhageDEPOdetection/tmp/training_sequences.fasta-bl\xc2\xbb: Le fichier existe\nmkdir: impossible de cr\xc3\xa9er le r\xc3\xa9pertoire \xc2\xab/media/concha-eloko/Linux/PhageDEPOdetection/tmp/training_sequences.fasta-blm\xc2\xbb: Le fichier existe\nmkdir: impossible de cr\xc3\xa9er le r\xc3\xa9pertoire \xc2\xab/media/concha-eloko/Linux/PhageDEPOdetection/tmp/training_sequences.fasta-seq\xc2\xbb: Le fichier existe\n/media/concha-eloko/Linux/PhageDEPOdetection/tmp/training_sequences.fasta.122017-bl.sh: 8: .//media/concha-eloko/Linux/PhageDEPOdetection/tmp/training_sequences.fasta.122017-bl.pl: not found\n/media/concha-eloko/Linux/PhageDEPOdetection/tmp/training_sequences.fasta.122017-bl.sh: 8: .//media/concha-eloko/Linux/PhageDEPOdetection/tmp/training_sequences.fasta.122017-bl.pl: not found\n/media/concha-eloko/L

In [None]:
# Shell commands (not Python).

# Pull the dataset TSV from the server into the local working folder.
rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/PhageDepo_pdb/Phagedepo.Dataset.21032024.tsv \
/media/concha-eloko/Linux/PhageDEPOdetection

# Push the local clustering results to the server.
# The trailing slash on the source copies the *contents* of tmp/ directly into
# low_cutoff_cluster/, which is where the server-side code looks for
# "{threshold}__psicdhit.out.clstr". Without it rsync would create
# low_cutoff_cluster/tmp/ instead.
rsync -avzhe ssh \
/media/concha-eloko/Linux/PhageDEPOdetection/tmp/ \
conchae@garnatxa.srv.cpd:/home/conchae/PhageDepo_pdb/low_cutoff_cluster

***
### Server : 

In [None]:
# ***********************************************************
# Token classification task : clustering thresholds used on the server
thresholds = [0.3,0.4, 0.45, 0.5]

def make_cdhit_cluster(threshold) :
    """Run psi-cd-hit.pl (resolved through PATH) on the training FASTA.

    The command reads ``{path_tmp}/training_sequences.fasta`` and writes
    ``{path_tmp}/{threshold}__psicdhit.out``. The ``__psicdhit.out`` suffix is
    the name that ``cvalue_to_list_group`` opens, so results produced here can
    be consumed without renaming.

    Parameters
    ----------
    threshold : value inserted after ``-c`` and used as output-file prefix.

    Returns
    -------
    None. The tool's combined stdout/stderr is printed as decoded text, and a
    warning line is printed when the command exits with a non-zero status.
    """
    cdhit_command = f"psi-cd-hit.pl -i {path_tmp}/training_sequences.fasta -o {path_tmp}/{threshold}__psicdhit.out -c {threshold} -G 0 -aL 0.8"
    # stderr is merged into stdout, so there is a single stream to report.
    cdhit_process = subprocess.run(cdhit_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    # Decode instead of printing the bytes repr.
    print(cdhit_process.stdout.decode("utf-8", errors="replace"))
    if cdhit_process.returncode != 0 :
        print(f"WARNING: psi-cd-hit exited with status {cdhit_process.returncode} for threshold {threshold}")


def make_cluster_dico(cdhit_out) :
    """Parse a cd-hit ``.clstr`` file into ``{cluster number: [member ids]}``.

    Parameters
    ----------
    cdhit_out : path of the clustering output *without* the ``.clstr`` suffix.

    Returns
    -------
    (dico_cluster, threshold)
        ``dico_cluster`` maps the 0-based position of each ``>Cluster`` record
        to the list of member ids (strings). ``threshold`` is the file name of
        ``cdhit_out`` with everything from ``.out`` onwards removed.

    Side effect: the mapping is also dumped to
    ``{path_tmp}/dico_cluster.psicdhit__{threshold}.json``.
    """
    import json
    threshold = cdhit_out.split("/")[-1].split(".out")[0]
    cluster_file = f"{cdhit_out}.clstr"
    # Context manager so the handle is closed even if parsing fails.
    with open(cluster_file) as infile :
        cluster_out = infile.read().split(">Cluster")
    dico_cluster = {}
    # cluster_out[0] is whatever precedes the first ">Cluster" marker.
    for index, cluster in enumerate(cluster_out[1:]) :
        tmp_dpo = []
        # The first line of a record is the cluster number; member lines carry
        # the sequence id after a ">" and before the first ".".
        for line in cluster.splitlines()[1:] :
            # Skip blank lines instead of slicing off the last line, so the
            # final member survives when the file has no trailing newline.
            if ">" not in line :
                continue
            tmp_dpo.append(line.split(">")[1].split(".")[0])
        dico_cluster[index] = tmp_dpo
    with open(f"{path_tmp}/dico_cluster.psicdhit__{threshold}.json", "w") as outfile:
        json.dump(dico_cluster, outfile)
    return dico_cluster , threshold


def reverse_dico(dico) :
    """Invert ``{key: [members]}`` into ``{member: key}``.

    When a member appears under several keys, the key met last in iteration
    order wins.
    """
    return {member: key for key, members in dico.items() for member in members}


def make_list_group(list_seq, r_dico, id_dico) :
    """Return, for every sequence of ``list_seq``, the group it belongs to.

    ``id_dico`` maps a sequence to its id; the id is converted to ``str`` and
    looked up in ``r_dico`` (id -> group). Order follows ``list_seq``.
    """
    return [r_dico[str(id_dico[sequence])] for sequence in list_seq]


def cvalue_to_list_group(threshold, df_depo) :
    """Give the cluster number of every row of ``df_depo`` at one threshold.

    Parses ``{path_tmp}/{threshold}__psicdhit.out`` (through
    ``make_cluster_dico``), inverts the cluster mapping and looks up each
    ``Full_seq`` value via the module-level ``dico_seq_id``.
    """
    clusters = make_cluster_dico(f"{path_tmp}/{threshold}__psicdhit.out")[0]
    member_to_cluster = reverse_dico(clusters)
    sequences = df_depo["Full_seq"].tolist()
    return make_list_group(sequences, member_to_cluster, dico_seq_id)


def get_labels(tuple_data ) :
    """Build one per-residue label list for every ``(sequence, info, fold)`` row.

    ``fold`` selects the class id (Negative=0, right-handed beta-helix=1,
    6-bladed beta-propeller=2, triple-helix=3). ``info`` decides which
    positions receive it:

    * ``"Negative"``, ``"full_protein"`` or ``"full"``: every position;
    * text containing ``":"``: ``start:end`` taken from the first two fields;
    * anything else: start and end are the last two ``"_"``-separated fields.

    With boundaries, positions ``i`` with ``start <= i < end`` (``i`` counted
    from 0) get the class id and all others get 0.
    """
    dico_labels = {"Negative" : 0,
                   "right-handed beta-helix" : 1,
                   "6-bladed beta-propeller" : 2,
                   "triple-helix" : 3}
    whole_sequence_tags = ("Negative", "full_protein", "full")
    labels_df = []
    for row in tuple_data :
        info = row[1]
        n_residues = len(row[0])
        class_id = dico_labels[row[2]]
        if info in whole_sequence_tags :
            labels_df.append([class_id] * n_residues)
            continue
        # Pick the separator and the positions of the two boundary fields
        if info.count(":") > 0 :
            fields, first, second = info.split(":"), 0, 1
        else :
            fields, first, second = info.split("_"), -2, -1
        start = int(fields[first])
        end = int(fields[second])
        labels_df.append([class_id if start <= i < end else 0 for i in range(n_residues)])
    return labels_df

def training_data(threshold):
    """Build the token-classification training set and the evaluation set.

    Two successive group-aware splits are made, the groups being the
    psi-cd-hit clusters at ``threshold`` (from ``cvalue_to_list_group``):

    1. module-level ``df_depo`` -> 70 % token-classification training rows,
       30 % held-out rows;
    2. held-out rows -> 66 % set aside (the sequence-classification share,
       which this function does not return) and 34 % evaluation rows.

    Returns
    -------
    (train_tok_seq, training_data_tok_labels, eval_seq, eval_data_token_labels)
        The two ``Full_seq`` selections and their per-residue label lists as
        produced by ``get_labels``.
    """
    gss_token_class = GroupShuffleSplit(n_splits=1, train_size=0.7, test_size = 0.3, random_state=243)
    gss_seq_class = GroupShuffleSplit(n_splits=1, train_size=0.66, test_size = 0.34, random_state=243)

    # First split : token-classification training rows vs. everything else.
    # n_splits=1, so the generator yields exactly one (train, test) pair.
    list_group_1 = cvalue_to_list_group(threshold, df_depo)
    train_token_classification_indices, other_indices = next(
        gss_token_class.split(df_depo["Full_seq"], df_depo["Fold"], list_group_1)
    )

    train_tok_seq = df_depo["Full_seq"][train_token_classification_indices]
    train_tok_boundaries = df_depo["Boundaries"][train_token_classification_indices]
    train_tok_fold = df_depo["Fold"][train_token_classification_indices]
    training_data_tok_labels = get_labels(tuple(zip(train_tok_seq, train_tok_boundaries, train_tok_fold)))

    # Intermediate DF : held-out rows, renumbered from 0 so that the positions
    # returned by the second split can be used directly. A new frame is built
    # rather than resetting the index of the filtered slice in place.
    df_depo_s2 = df_depo[df_depo.index.isin(other_indices)].reset_index()

    # Second split : only the evaluation part is needed here.
    list_group_2 = cvalue_to_list_group(threshold, df_depo_s2)
    _, eval_data_indices = next(
        gss_seq_class.split(df_depo_s2["Full_seq"], df_depo_s2["Fold"], list_group_2)
    )

    # Evaluation data :
    eval_seq = df_depo_s2["Full_seq"][eval_data_indices]
    eval_seq_boundaries = df_depo_s2["Boundaries"][eval_data_indices]
    eval_seq_fold = df_depo_s2["Fold"][eval_data_indices]
    eval_data_token_labels = get_labels(tuple(zip(eval_seq, eval_seq_boundaries, eval_seq_fold)))

    return train_tok_seq, training_data_tok_labels ,eval_seq , eval_data_token_labels

def _align_labels_with_special_tokens(input_ids, labels, tokenizer) :
    """Insert -100 where the tokenizer wrapped a sequence in cls/eos tokens."""
    aligned = []
    for ids, residue_labels in zip(input_ids, labels) :
        residue_labels = list(residue_labels)
        has_cls = len(ids) > 0 and ids[0] == tokenizer.cls_token_id
        has_eos = len(ids) > 0 and ids[-1] == tokenizer.eos_token_id
        padded = [-100] * has_cls + residue_labels + [-100] * has_eos
        if len(padded) != len(ids) :
            raise ValueError(
                f"Cannot align {len(residue_labels)} residue labels with {len(ids)} tokens"
            )
        aligned.append(padded)
    return aligned


def mount_data(train_tok_seq, training_data_tok_labels ,eval_seq , eval_data_token_labels) :
    """Tokenize the sequences and attach token-aligned labels.

    The label lists hold one entry per residue, whereas the tokenizer output
    can start with a cls token and end with an eos token (the ESM-2 tokenizer
    does both). Attaching the residue labels unchanged would pair label ``i``
    with token ``i`` and therefore shift every label by one residue. Special
    tokens receive -100 instead, the value ``compute_metrics`` filters out.

    Parameters
    ----------
    train_tok_seq, eval_seq : iterables of sequences.
    training_data_tok_labels, eval_data_token_labels : per-residue label
        lists, in the same order as the sequences.

    Returns
    -------
    (train_dataset, test_dataset) with a ``labels`` column as long as
    ``input_ids`` for every row.

    Raises
    ------
    ValueError
        If a label list cannot be matched to its token sequence.
    """
    # ***********************************************************
    # Mount the data :
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    train_tokenized = tokenizer(list(train_tok_seq))
    test_tokenized = tokenizer(list(eval_seq))

    train_labels = _align_labels_with_special_tokens(train_tokenized["input_ids"], training_data_tok_labels, tokenizer)
    test_labels = _align_labels_with_special_tokens(test_tokenized["input_ids"], eval_data_token_labels, tokenizer)

    train_dataset = Dataset.from_dict(train_tokenized)
    test_dataset = Dataset.from_dict(test_tokenized)

    train_dataset = train_dataset.add_column("labels", train_labels)
    test_dataset = test_dataset.add_column("labels", test_labels)

    return train_dataset, test_dataset

def compute_metrics(eval_pred):
    """Accuracy over every position whose label is not -100.

    ``eval_pred`` unpacks into ``(predictions, labels)``. The predicted class
    is the argmax along axis 2 of ``predictions``; both arrays are flattened
    and positions labelled -100 are removed before scoring.
    """
    accuracy = load("accuracy")
    scores, gold = eval_pred
    gold = gold.reshape((-1,))
    predicted = np.argmax(scores, axis=2).reshape((-1,))
    scored_positions = gold != -100
    return accuracy.compute(predictions=predicted[scored_positions], references=gold[scored_positions])


# Write the multi-FASTA (header = row position) and remember, for every
# sequence, the first position at which it was seen.
dico_seq_id = {}
with open(f"{path_tmp}/training_sequences.fasta", "w") as outfile :
    for index, seq in enumerate(df_depo["Full_seq"].tolist()) :
        print(f">{index}", seq, sep="\n", file=outfile)
        dico_seq_id.setdefault(seq, index)



In [None]:
def training_model(c_value) :
    """Fine-tune the token-classification model for one clustering threshold.

    The run directory is
    ``{model_name}__{c_value}__finetuneddepolymerase.2205.{num_labels}_labels``
    (relative to the current working directory). When it already exists the
    threshold is skipped, and this is checked first so that neither the data
    nor the model is loaded for nothing.

    Parameters
    ----------
    c_value : clustering threshold forwarded to ``training_data``.

    Returns
    -------
    None
    """
    num_labels = 4
    model_name = model_checkpoint.split("/")[-1]
    output_dir = f"{model_name}__{c_value}__finetuneddepolymerase.2205.{num_labels}_labels"
    if os.path.isdir(output_dir) :
        return

    # get training data :
    train_tok_seq, training_data_tok_labels ,eval_seq , eval_data_token_labels = training_data(c_value)
    train_dataset, test_dataset = mount_data(train_tok_seq, training_data_tok_labels ,eval_seq , eval_data_token_labels)

    # setup the training :
    model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
    data_collator = DataCollatorForTokenClassification(tokenizer)
    batch_size = 4
    args = TrainingArguments(
        output_dir,
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        learning_rate=1e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=5,
        weight_decay=0.001,
        # the best checkpoint is chosen on the accuracy from compute_metrics
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        logging_dir='./logs',
        push_to_hub=False,
    )
    # training :
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
    )
    trainer.train()


# Train one model per threshold, going through the list in reverse order
for c_value in thresholds[::-1] :
    training_model(c_value)

    

In [None]:
#!/bin/bash
#SBATCH --job-name=token_class
#SBATCH --qos=medium 
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=20
#SBATCH --mem=75gb 
#SBATCH --time=2-00:00:00 
#SBATCH --output=token_class%j.log 

# Restore the saved module collection and activate the conda environment
# before launching the fine-tuning script.
module restore la_base
conda activate embeddings

python /home/conchae/PhageDepo_pdb/script_files/esm2_finetuning.review_v2.py