### Pre-work:


In [None]:
# Copy the local in_vitro_DFs directory (archive mode, compressed, over ssh) into
# ficheros_28032023 on garnatxa.srv.cpd; the Python cells below read it as path_project.
rsync -avzhe ssh \
/media/concha-eloko/Linux/PPT_clean/in_vitro/in_vitro_DFs \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023

***

In [2]:
# Torch geometric modules
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

# Torch modules
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# SKlearn modules
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

# Ground modules
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool

# TropiGAT modules
import TropiGAT_graph
import TropiGAT_models

# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings("ignore")


In [None]:
# *****************************************************************************
# Load the TropiGATv2 training dataframe and keep only rows whose KL_type_LCA is a
# single KL type (no "|"-separated alternatives), de-duplicated per
# (Infected_ancestor, index, prophage_id).
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
DF_info = pd.read_csv(f"{path_work}/train_nn/TropiGATv2.final_df.tsv", sep="\t", header=0)
DF_info_lvl_0 = (
    DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
    .drop_duplicates(subset=["Infected_ancestor", "index", "prophage_id"], keep="first")
    .reset_index(drop=True)
)

# Model directories: pre-trained ensemble (input) and fine-tuned checkpoints (output).
path_ensemble = f"{path_work}/train_nn/ensemble_2809"
path_finetuned = f"{path_work}/train_nn/fine_tuning/models"

# Count each phage once (its first row) per KL type; only KL types with at least
# 20 such phages have a TropiGAT model to fine-tune.
df_prophages = DF_info_lvl_0.drop_duplicates(subset=["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))
KLtypes = [kltype for kltype, count in dico_prophage_count.items() if count >= 20]


> Load the ESM2 embedding and fine-tuning dataframes

In [3]:
import re

path_project = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/in_vitro_DFs"


def _load_esm2_embeddings(csv_path, normalise_id):
    """Read an ESM2 embedding CSV and return it indexed by normalised protein ID.

    Column 0 holds the protein identifier (rewritten with ``normalise_id``) and
    column 1281 is dropped, leaving columns 1..1280 as the embedding values.
    NOTE(review): column 1281 is presumably an empty trailing field — confirm.
    """
    embeddings = pd.read_csv(csv_path, sep=",", header=None)
    embeddings = embeddings.drop([1281], axis=1)
    embeddings[0] = embeddings[0].apply(normalise_id)
    embeddings.set_index([0], inplace=True)
    return embeddings


def _kltype_pool(targets):
    """Return the set of KL types named in an iterable of comma-separated target strings.

    Tokens containing "wzi" or "pass" are skipped; the others are whitespace-stripped.
    """
    return {
        token.strip()
        for target in targets
        for token in target.split(",")
        if "wzi" not in token and "pass" not in token
    }


# ESM2 embedding dataframes :
# Ferriol: IDs are truncated at "_Dpo".
dpo_embeddings = _load_esm2_embeddings(
    f"{path_project}/Dpo_domains_77.esm2.embedding.csv", lambda x: x.split("_Dpo")[0]
)
# Beamud: "_" is doubled to "__", the same rewrite applied to bea_df["Protein"] below.
bea_embeddings = _load_esm2_embeddings(
    f"{path_project}/Bea_phages.esm2.embedding.csv", lambda x: x.replace("_", "__")
)
# Townsend: "_" is doubled to "__", the same rewrite applied to towndsend_df["Protein"] below.
towndsend_embeddings = _load_esm2_embeddings(
    f"{path_project}/Townsed_phages.esm2.embedding.csv", lambda x: x.replace("_", "__")
)

# ==> DF embeddings
df_embeddings = pd.concat([towndsend_embeddings, bea_embeddings, dpo_embeddings], axis=0)

# ************************************
# The matrices :
# Beamud matrix :
bea_df = pd.read_csv(f"{path_project}/bea_fine_tuning.df", sep="\t", header=0)
bea_df["Protein"] = bea_df["Protein"].apply(lambda x: x.replace("_", "__"))
pool_bea = _kltype_pool(bea_df["Target"])

# Ferriol matrix :
ferriol_df = pd.read_csv(f"{path_project}/ferriol_fine_tuning.df", sep="\t", header=0)
# Rewrite K-type names to the KL naming used elsewhere (e.g. "K2" -> "KL2"). The
# lookahead leaves a "K" that is already followed by "L" alone, so names that are
# already "KL..." are not turned into "KLL...".
ferriol_df["Target"] = ferriol_df["Target"].apply(lambda x: re.sub(r"K(?!L)", "KL", x))
pool_ferriol = _kltype_pool(ferriol_df["Target"])

# Townsend matrix :
towndsend_df = pd.read_csv(f"{path_project}/towndsend_fine_tuning.df", sep="\t", header=0)
towndsend_df["Protein"] = towndsend_df["Protein"].apply(lambda x: x.replace("_", "__"))
pool_towndsend = _kltype_pool(towndsend_df["Target"])

# ==> dico Data
dico_matrices = {"ferriol" : {"matrix" : ferriol_df, "pool" : pool_ferriol},
                 "bea" : {"matrix": bea_df, "pool" : pool_bea},
                 "towndsend" : {"matrix" : towndsend_df, "pool" : pool_towndsend}}

# Every KL type with in-vitro data. It is sorted so the job order is reproducible,
# because iteration order of a set of strings changes between interpreter runs.
pools_kltypes = sorted(pool_ferriol | pool_bea | pool_towndsend)


> Functions:

In [None]:
# Functions : 
def finetune_kltype_df(kltype):
    """Collect the in-vitro rows used to fine-tune the model of one KL type.

    Only matrices in ``dico_matrices`` whose KL-type pool contains ``kltype`` are
    scanned, and rows whose Target mentions "pass" are skipped. A row is positive
    when ``kltype`` is among its comma-separated targets (tokens containing "wzi"
    ignored), negative otherwise.

    Returns:
        (df_kltype, n_positives, n_negatives): all kept rows, positives first, as a
        DataFrame with columns ["phage", "depo", "KLtypes"], plus the number of
        distinct first-column values among the positive and the negative rows.
    """
    positives = []
    negatives = []
    for dataset in dico_matrices.values():
        if kltype not in dataset["pool"]:
            continue
        for _, row in dataset["matrix"].iterrows():
            raw_targets = row["Target"]
            if "pass" in raw_targets:
                continue
            row_targets = [
                token.strip()
                for token in raw_targets.split(",")
                if "wzi" not in token and "pass" not in token
            ]
            bucket = positives if kltype in row_targets else negatives
            bucket.append(list(row.values))
    n_positives = len({line[0] for line in positives})
    n_negatives = len({line[0] for line in negatives})
    df_kltype = pd.DataFrame(positives + negatives, columns=["phage", "depo", "KLtypes"])
    return df_kltype, n_positives, n_negatives
    
def build_graph_baseline(df_info, n_positives, n_negatives) :
    """Build the HeteroData graph linking depolymerases to the phages that carry them.

    Node types:
      - "A":  a single placeholder node with feature ``tensor([0.])``.
      - "B1": one node per distinct ``df_info["phage"]`` (first-seen order), with
        empty (0-column) features.
      - "B2": one node per distinct ``df_info["depo"]`` (first-seen order), featured
        by its ESM2 embedding (columns 1..1280 of ``df_embeddings``).
    Edges ('B2', 'expressed', 'B1') join each depolymerase to its phage, one per row
    of ``df_info``.

    Args:
        df_info: DataFrame with "phage" and "depo" columns whose positive phages
            come first (as returned by ``finetune_kltype_df``).
        n_positives: number of distinct positive phages.
        n_negatives: number of distinct negative phages.

    Returns:
        HeteroData with B1 labels ``y`` (1 for the first ``n_positives`` phages, 0 for
        the following ``n_negatives``) and a boolean ``train_mask`` that selects all
        B1 nodes.
    """
    graph_data = HeteroData()
    # **************************************************************
    # Node indexation, in first-seen order
    phages = df_info["phage"].unique().tolist()
    depos = df_info["depo"].unique().tolist()
    ID_nodes_B1 = {phage: index for index, phage in enumerate(phages)}
    ID_nodes_B2 = {depo: index for index, depo in enumerate(depos)}
    # **************************************************************
    # Node features
    embeddings_columns = list(range(1, 1281))
    graph_data["A"].x = torch.tensor([0], dtype=torch.float)
    graph_data["B1"].x = torch.zeros((len(phages), 0), dtype=torch.float)
    # One feature row per *unique* depolymerase, so row i is the node with ID i.
    # Iterating over every df_info row would add duplicate rows and shift the
    # features of later nodes whenever a depolymerase appears in several rows.
    graph_data["B2"].x = torch.tensor(
        [df_embeddings[df_embeddings.index == depo][embeddings_columns].values[0].tolist() for depo in depos],
        dtype=torch.float,
    )
    # **************************************************************
    # Edges: depolymerase (B2) -> phage (B1), grouped by phage in first-seen order
    edges = []
    for phage in phages:
        for depo in df_info.loc[df_info["phage"] == phage, "depo"]:
            edges.append([ID_nodes_B2[depo], ID_nodes_B1[phage]])
    edge_index_B2_B1 = torch.tensor(edges, dtype=torch.long)
    graph_data['B2', 'expressed', 'B1'].edge_index = edge_index_B2_B1.t().contiguous()
    # **************************************************************
    # Labels follow the B1 order: positives were listed first in df_info.
    # NOTE(review): assumes no phage is both positive and negative; otherwise
    # len(labels) differs from the number of B1 nodes.
    labels = [1] * n_positives + [0] * n_negatives
    graph_data["B1"].y = torch.tensor(labels)
    # Every phage node is used for fine-tuning. Stored as a bool mask (PyG convention):
    # an int64 tensor of ones, if used to index, would select node 1 repeatedly.
    # NOTE(review): TropiGAT_models.train is not visible here — confirm it treats
    # train_mask as a boolean mask.
    graph_data["B1"].train_mask = torch.ones(len(labels), dtype=torch.bool)
    return graph_data

def train_graph(KL_type, graph_data, n_epochs=50, checkpoint_every=5) :
    """Fine-tune the pre-trained TropiGAT model of ``KL_type`` on ``graph_data``.

    The architecture is the small module when ``dico_prophage_count[KL_type] <= 125``,
    the big one otherwise. Its weights are loaded from
    ``{path_ensemble}/{KL_type}.TropiGATv2.2809.pt``. The model is then trained with
    Adam (lr 1e-4, weight decay 1e-6), BCEWithLogitsLoss and a ReduceLROnPlateau
    scheduler on the training loss. Every ``checkpoint_every`` epochs, and after the
    final epoch, the loss is appended to the KL type's log file (overwritten on each
    call) and the weights are saved to ``path_finetuned``.

    Any exception, including a missing pre-trained checkpoint, is written to the log
    file rather than raised. An error for one KL type is therefore recorded instead of
    being propagated into the thread pool running the other KL types.

    Args:
        KL_type: KL type name; must be a key of ``dico_prophage_count``.
        graph_data: graph produced by ``build_graph_baseline``.
        n_epochs: number of training epochs (default 50).
        checkpoint_every: interval, in epochs, between loss logs / checkpoints (default 5).
    """
    log_path = f"{path_work}/train_nn/fine_tuning/log_files/{KL_type}__finetuned_node_classification.2111.log"
    with open(log_path, "w") as log_outfile :
        try :
            if dico_prophage_count[KL_type] <= 125 :
                model = TropiGAT_models.TropiGAT_small_module(hidden_channels = 1280, heads = 1)
            else :
                model = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
            # The model is never moved off the CPU, so load the weights there even if
            # they were saved from a GPU run.
            state_dict = torch.load(f"{path_ensemble}/{KL_type}.TropiGATv2.2809.pt", map_location="cpu")
            model.load_state_dict(state_dict)
            # NOTE(review): the output of this forward pass is discarded; presumably a
            # check that the graph fits the model (or lazy-layer initialisation) — confirm.
            model(graph_data)
            #
            optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001 , weight_decay= 0.000001)
            scheduler = ReduceLROnPlateau(optimizer, 'min')
            criterion = torch.nn.BCEWithLogitsLoss()
            #
            for epoch in range(n_epochs):
                train_loss = TropiGAT_models.train(model, graph_data, optimizer, criterion)
                # Lower the learning rate when the training loss stops improving.
                scheduler.step(train_loss)
                # Checkpoint periodically and always after the last epoch, so the final
                # epochs of training are not lost.
                if epoch % checkpoint_every == 0 or epoch == n_epochs - 1:
                    info = f'Epoch: {epoch}\tTrain Loss: {train_loss}\n'
                    log_outfile.write(info)
                    checkpoint_path = f"{path_finetuned}/{KL_type}.{epoch}.finetuned.TropiGATv2.2211.pt"
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"Checkpoint saved: {checkpoint_path}")
        except Exception as e :
            log_outfile.write(f"***Issue here : {e}")


def finetune_kltype(kltype) :
    """Fine-tune the TropiGAT model of ``kltype`` on the in-vitro data, if it has one.

    KL types not in ``KLtypes`` have no pre-trained model. For those, only a log file
    stating "Not in the TropiGAT system" is written.
    """
    if kltype not in KLtypes :
        # The log file is named after this function's own argument (``kltype``).
        with open(f"{path_work}/train_nn/fine_tuning/log_files/{kltype}__finetuned_node_classification.2111.log" , "w") as log_outfile :
            log_outfile.write("Not in the TropiGAT system")
        return
    df_kltype , n_positives, n_negatives = finetune_kltype_df(kltype)
    graph_data = build_graph_baseline(df_kltype , n_positives, n_negatives)
    train_graph(kltype , graph_data)



if __name__ == '__main__':
    # Fine-tune every KL type found in the in-vitro matrices, running up to
    # 10 KL types concurrently in worker threads.
    with ThreadPool(processes=10) as workers:
        workers.map(finetune_kltype, pools_kltypes)


In [None]:
#!/bin/bash
#SBATCH --job-name=2211_finetuning__
#SBATCH --qos=short
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=50gb
#SBATCH --time=1-00:00:00
#SBATCH --output=2211_finetuning__%j.log
# (The job-name directive must use the "#SBATCH" prefix; sbatch ignores any other
#  comment, so a "#BATCH" line would silently drop the job name.)

# Load the saved module collection and activate the torch_geometric conda env.
module restore la_base
source /storage/apps/ANACONDA/anaconda3/etc/profile.d/conda.sh
conda activate torch_geometric

# Run the fine-tuning script.
python /home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/fine_tuning/script_files/finetuning.2211.py

In [None]:
***
