In [None]:
import os
import json
import random
import warnings
from collections import Counter
from itertools import product
from multiprocessing.pool import ThreadPool

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import (accuracy_score, f1_score, matthews_corrcoef,
                             precision_score, recall_score, roc_auc_score)
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, label_binarize
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import DataLoader, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import GATv2Conv, HeteroConv, to_hetero
from torch_geometric.utils import negative_sampling
from tqdm import tqdm

import TropiGAT_graph
import TropiGAT_models

warnings.filterwarnings("ignore")

# ==================================================
# Configuration
# Set `ultrafiltration = True` to train on the ultra-filtered data set
# (see ultrafilter_prophages) and to use the "ultraF" variants of the
# output and Optuna directories below.
# ==================================================
ultrafiltration = False

PATH_WORK = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
DATE = "27_11_2024"

# Checkpoint directory for this run; its logs/metric report go to "<dir>_log".
ENSEMBLE_PATH = (
    f"{PATH_WORK}/train_nn/TropiGAT_ensemble_ultraF_{DATE}"
    if ultrafiltration
    else f"{PATH_WORK}/train_nn/TropiGAT_ensemble_{DATE}"
)
ENSEMBLE_PATH_log = f"{ENSEMBLE_PATH}_log"

# Directory holding the Optuna best-hyper-parameter files (file name prefix = KL type).
OPTUNA_PATH = (
    f"{PATH_WORK}/train_nn/ensemble_20112024_log_optimized_TropiGAT_ultraF"
    if ultrafiltration
    else f"{PATH_WORK}/train_nn/ensemble_20112024_log_optimized_TropiGAT"
)

os.makedirs(ENSEMBLE_PATH, exist_ok=True)
os.makedirs(ENSEMBLE_PATH_log, exist_ok=True)


# Optuna-tuned hyper-parameters per KL type. Each file in OPTUNA_PATH is keyed by
# the part of its name before the first "_". The values are read below as
# DICO_OPTUNA[kl_type]["att_heads" | "dropout" | "lr" | "weight_decay"].
DICO_OPTUNA = {}
for file in os.listdir(OPTUNA_PATH):
    kl_type = file.split("_")[0]
    # json.load() needs a file object: the previous json.load(open(...).read())
    # handed it a str (AttributeError) and never closed the file.
    with open(f"{OPTUNA_PATH}/{file}") as optuna_file:
        best_parameters = json.load(optuna_file)
    DICO_OPTUNA[kl_type] = best_parameters
    

def load_and_preprocess_data():
    """Read the TropiGAT protein table and index prophage metadata by phage.

    Returns:
        tuple: ``(df_info, dico_prophage_info)`` where ``df_info`` is the table with
        duplicate ``Protein_name`` rows removed, and ``dico_prophage_info`` maps each
        ``Phage`` (first occurrence) to
        ``{"prophage_strain": prophage_id, "ancestor": Infected_ancestor}``.
    """
    table_path = f"{PATH_WORK}/train_nn/TropiGATv2.final_df_v2.tsv"
    proteins = pd.read_csv(table_path, sep="\t", header=0).drop_duplicates(subset=["Protein_name"])

    first_row_per_phage = proteins.drop_duplicates(subset=["Phage"], keep="first")
    prophage_lookup = {}
    for _, record in first_row_per_phage.iterrows():
        prophage_lookup[record["Phage"]] = {
            "prophage_strain": record["prophage_id"],
            "ancestor": record["Infected_ancestor"],
        }

    return proteins, prophage_lookup

def filter_prophages(df_info, dico_prophage_info):
    """Keep one prophage per distinct depolymerase repertoire inside each prophage group.

    A prophage's group is every row of ``df_info`` that shares its ``prophage_id`` and
    ``Infected_ancestor`` (looked up through ``dico_prophage_info``). Inside a group the
    reference prophage is kept. Any other member is excluded when its set of
    ``domain_seq`` values equals the reference's set or the set of a member already
    kept. Prophages are processed in ``dico_prophage_info`` order, and one that was
    already kept or excluded is skipped. Finally, rows whose ``KL_type_LCA`` contains
    ``"|"`` are dropped.

    Args:
        df_info: Protein-level table with "Phage", "prophage_id", "Infected_ancestor",
            "domain_seq" and "KL_type_LCA" columns.
        dico_prophage_info: Mapping ``phage -> {"prophage_strain": ..., "ancestor": ...}``,
            e.g. the one returned by ``load_and_preprocess_data``.

    Returns:
        The rows of ``df_info`` belonging to kept prophages whose KL type has no "|".
    """
    # Split the table into (prophage_id, Infected_ancestor) groups once. The old code
    # re-scanned the whole frame with a boolean mask for every prophage
    # (O(n_prophages * n_rows)). groupby drops null keys by default, which matches
    # the old equality mask, where a null key never matched any row.
    group_frames = {
        key: frame
        for key, frame in df_info.groupby(["prophage_id", "Infected_ancestor"], sort=False)
    }
    no_rows = df_info.iloc[0:0]

    def get_filtered_prophages(prophage):
        """Return ``(group_rows, to_exclude, to_keep)`` for ``prophage``'s group."""
        combinations = []
        to_exclude = set()
        to_keep = {prophage}
        info = dico_prophage_info[prophage]
        df_prophage_group = group_frames.get((info["prophage_strain"], info["ancestor"]), no_rows)
        if len(df_prophage_group) == 1:
            # Nothing to compare against in a one-row group.
            return df_prophage_group, to_exclude, to_keep

        depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage]["domain_seq"].values)
        for prophage_tmp in df_prophage_group["Phage"].unique():
            if prophage_tmp == prophage:
                continue
            tmp_depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage_tmp]["domain_seq"].values)
            if depo_set == tmp_depo_set or tmp_depo_set in combinations:
                # Same repertoire as the reference or as an already-kept member.
                to_exclude.add(prophage_tmp)
            else:
                to_keep.add(prophage_tmp)
                combinations.append(tmp_depo_set)
        return df_prophage_group, to_exclude, to_keep

    good_prophages = set()
    excluded_prophages = set()

    for prophage in tqdm(dico_prophage_info.keys()):
        if prophage not in excluded_prophages and prophage not in good_prophages:
            _, excluded_members, kept_members = get_filtered_prophages(prophage)
            good_prophages.update(kept_members)
            excluded_prophages.update(excluded_members)

    df_info_filtered = df_info[df_info["Phage"].isin(good_prophages)]
    # A "|" in KL_type_LCA separates several KL types; only single-KL rows are kept.
    df_info_final = df_info_filtered[~df_info_filtered["KL_type_LCA"].str.contains("\\|")]

    return df_info_final


def ultrafilter_prophages(df_info):
    """Drop prophages whose depolymerase repertoire repeats an earlier one in the same KL type.

    Within each ``KL_type_LCA`` value, a prophage's repertoire is the set of its
    ``domain_seq`` values. The first prophage (in row order) with a given repertoire is
    kept. Every later prophage with an identical repertoire is marked as a duplicate,
    and all of its rows are removed from the result.

    Args:
        df_info: Protein-level table with "Phage", "KL_type_LCA" and "domain_seq" columns.

    Returns:
        ``df_info`` without the rows of duplicate prophages.
    """
    duplicate_prophages = set()

    for kltype in df_info["KL_type_LCA"].unique():
        df_kl = df_info[df_info["KL_type_LCA"] == kltype]
        seen_repertoires = set()
        # sort=False keeps phages in order of first appearance, so the earliest
        # prophage with a given repertoire is the one that survives.
        for prophage, rows in df_kl.groupby("Phage", sort=False):
            repertoire = frozenset(rows["domain_seq"].values)
            if repertoire in seen_repertoires:
                duplicate_prophages.add(prophage)
            else:
                seen_repertoires.add(repertoire)

    return df_info[~df_info["Phage"].isin(duplicate_prophages)]


def prepare_kltypes(df_info, min_prophages=10):
    """Select the KL types with enough prophages to train a dedicated model.

    Each phage is counted once, under the ``KL_type_LCA`` of its first row.

    Args:
        df_info: Protein-level table with "Phage" and "KL_type_LCA" columns.
        min_prophages: Minimum number of prophages a KL type needs to be selected
            (default 10, the previously hard-coded threshold).

    Returns:
        tuple: ``(kltypes, dico_prophage_count)``. ``kltypes`` lists the KL types with
        at least ``min_prophages`` prophages, in first-appearance order.
        ``dico_prophage_count`` maps every KL type to its prophage count.
    """
    df_prophages = df_info.drop_duplicates(subset=["Phage"])
    dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))
    kltypes = [kltype for kltype, count in dico_prophage_count.items() if count >= min_prophages]
    return kltypes, dico_prophage_count



def _build_tropigat_model(kl_type):
    """Instantiate TropiGAT_small_module with the Optuna-tuned heads/dropout for ``kl_type``.

    Training and final evaluation both use this helper, so the reloaded model always has
    the architecture that ``load_state_dict`` expects.
    """
    params = DICO_OPTUNA[kl_type]
    return TropiGAT_models.TropiGAT_small_module(
        hidden_channels=1280, heads=params["att_heads"], dropout=params["dropout"]
    )


def train_graph(kl_type, graph_baseline, dico_prophage_kltype_associated, df_info, dico_prophage_count):
    """Train and evaluate one TropiGAT model per seed (1-5) for a single KL type.

    Seeds whose log file already exists are skipped. For each remaining seed:
      * build the KL-type graph with ``TropiGAT_graph.build_graph_masking_v2``;
      * train for up to 500 epochs with AdamW + ReduceLROnPlateau, evaluating on the
        test mask every 5 epochs and checkpointing through ``EarlyStopping`` (MCC);
      * reload the checkpoint, evaluate it on the eval mask and append the metrics
        to ``Metric_Report.{DATE}.tsv``.
    An exception for one seed is written to that seed's log and recorded as an
    ``***Error***`` row in the metric report; the other seeds still run.

    Args:
        kl_type: KL type to train; must be a key of ``DICO_OPTUNA``.
        graph_baseline, dico_prophage_kltype_associated: Outputs of
            ``TropiGAT_graph.build_graph_baseline``.
        df_info: Filtered protein-level table the graph was built from.
        dico_prophage_count: Mapping KL type -> prophage count (from ``prepare_kltypes``).
    """
    import traceback

    metric_report_path = f"{ENSEMBLE_PATH_log}/Metric_Report.{DATE}.tsv"
    # Looked up outside the try block so the error handler can always report it.
    # Previously a failure before this assignment left n_prophage unbound there.
    n_prophage = dico_prophage_count.get(kl_type)

    for seed in range(1, 6):
        log_file = f"{ENSEMBLE_PATH_log}/{kl_type}__{seed}__node_classification.{DATE}.log"
        if os.path.isfile(log_file):
            # This (kl_type, seed) was already attempted in an earlier run.
            continue
        checkpoint_path = f"{ENSEMBLE_PATH}/{kl_type}__{seed}.TropiGATv2.{DATE}.pt"

        with open(log_file, "w") as log_outfile:
            try:
                graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(
                    graph_baseline, dico_prophage_kltype_associated, df_info,
                    kl_type, 5, 0.7, 0.2, 0.1, seed=seed
                )

                model = _build_tropigat_model(kl_type)
                # One forward pass before the optimizer is created. Presumably this
                # materializes lazily-initialized parameters — TODO confirm.
                model(graph_data_kltype)

                optimizer = torch.optim.AdamW(
                    model.parameters(),
                    lr=DICO_OPTUNA[kl_type]["lr"],
                    weight_decay=DICO_OPTUNA[kl_type]["weight_decay"],
                )
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, min_lr=1e-6)
                criterion = torch.nn.BCEWithLogitsLoss()
                early_stopping = TropiGAT_models.EarlyStopping(
                    patience=100,
                    verbose=True,
                    path=checkpoint_path,
                    metric='MCC',
                )

                for epoch in range(500):
                    train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer, criterion)
                    if epoch % 5 == 0:
                        test_loss, metrics = TropiGAT_models.evaluate(
                            model, graph_data_kltype, criterion, graph_data_kltype["B1"].test_mask
                        )
                        log_outfile.write(f'Epoch: {epoch}\tTrain Loss: {train_loss:.4f}\t'
                                          f'Test Loss: {test_loss:.4f}\tMCC: {metrics[3]:.4f}\t'
                                          f'AUC: {metrics[5]:.4f}\tAccuracy: {metrics[4]:.4f}\n')
                        scheduler.step(test_loss)
                        early_stopping(metrics[3], model, epoch)
                        if early_stopping.early_stop:
                            log_outfile.write(f"Early stopping at epoch {epoch}\n")
                            break
                else:
                    # All 500 epochs ran without early stopping. Save a state_dict:
                    # the reload below uses load_state_dict(), which rejects the whole
                    # pickled module that torch.save(model) used to write here.
                    # NOTE(review): this overwrites the best-MCC checkpoint written by
                    # EarlyStopping with the final-epoch weights — confirm intended.
                    torch.save(model.state_dict(), checkpoint_path)

                # Final evaluation of the saved checkpoint on the held-out eval mask.
                model_final = _build_tropigat_model(kl_type)
                model_final.load_state_dict(torch.load(checkpoint_path))
                eval_loss, metrics = TropiGAT_models.evaluate(
                    model_final, graph_data_kltype, criterion, graph_data_kltype["B1"].eval_mask
                )

                # metrics order: F1, precision, recall, MCC, accuracy, AUC.
                with open(metric_report_path, "a+") as metric_outfile:
                    metric_outfile.write(f"{kl_type}__{seed}\t{n_prophage}\t"
                                         f"{metrics[0]:.4f}\t{metrics[1]:.4f}\t{metrics[2]:.4f}\t"
                                         f"{metrics[3]:.4f}\t{metrics[4]:.4f}\t{metrics[5]:.4f}\n")

                log_outfile.write(f"Final evaluation:\n"
                                  f"F1 Score: {metrics[0]:.4f}, Precision: {metrics[1]:.4f}, "
                                  f"Recall: {metrics[2]:.4f}, MCC: {metrics[3]:.4f}, "
                                  f"Accuracy: {metrics[4]:.4f}, AUC: {metrics[5]:.4f}")

            except Exception as e:
                # Keep the other seeds/KL types running, but record the full traceback
                # so the failure can be diagnosed from the log.
                log_outfile.write(f"Error occurred: {str(e)}\n{traceback.format_exc()}")
                with open(metric_report_path, "a+") as metric_outfile:
                    metric_outfile.write(f"{kl_type}__{seed}\t{n_prophage}\t***Error***\n")

def main():
    """Run the TropiGAT ensemble workflow end to end.

    Steps: load the protein table, de-duplicate prophages (plus the ultra-filtration
    step when ``ultrafiltration`` is set), select KL types with enough prophages,
    build the baseline graph, and train each selected KL type with ``train_graph``
    on a pool of 5 threads.
    """
    df_info, dico_prophage_info = load_and_preprocess_data()
    df_info_final = filter_prophages(df_info, dico_prophage_info)
    if ultrafiltration:
        df_info_final = ultrafilter_prophages(df_info_final)

    # Count KL types on the same table the graph is built from. This previously
    # referenced the undefined name `df_info_filtered` and raised NameError.
    kltypes, dico_prophage_count = prepare_kltypes(df_info_final)
    graph_baseline, dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(df_info_final)

    with ThreadPool(5) as p:
        p.starmap(
            train_graph,
            [
                (kl_type, graph_baseline, dico_prophage_kltype_associated, df_info_final, dico_prophage_count)
                for kl_type in kltypes
            ],
        )


if __name__ == '__main__':
    main()


***
### Copy the best-performing seed of each KL-type model and sync it to the local machine

In [15]:
import os 
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

> Regular training:

In [17]:
metric_report_path = "/media/concha-eloko/Linux/PPT_clean/ficheros_28032023/Metric_Report.review.GAT.F.tsv"
average_metric_df = pd.read_csv(metric_report_path, sep="\t", index_col=False, header=0)

# Each "model_version" entry presumably looks like "KL1__4": the KL type is the text
# before the first "_" and the selected seed is the text after the last "_".
model_version = {}
for version_tag in average_metric_df["model_version"]:
    model_version[version_tag.partition("_")[0]] = version_tag.rpartition("_")[2]
print(model_version)

{'KL1': '4', 'KL2': '3', 'KL3': '5', 'KL4': '5', 'KL5': '4', 'KL6': '3', 'KL7': '5', 'KL8': '1', 'KL9': '5', 'KL10': '4', 'KL12': '2', 'KL13': '1', 'KL14': '1', 'KL15': '1', 'KL16': '3', 'KL17': '2', 'KL18': '3', 'KL19': '5', 'KL20': '3', 'KL21': '3', 'KL22': '5', 'KL23': '5', 'KL24': '4', 'KL25': '4', 'KL26': '4', 'KL27': '4', 'KL28': '1', 'KL29': '4', 'KL30': '2', 'KL31': '1', 'KL34': '5', 'KL35': '3', 'KL36': '4', 'KL38': '5', 'KL39': '3', 'KL41': '3', 'KL43': '4', 'KL45': '5', 'KL46': '2', 'KL47': '5', 'KL48': '2', 'KL51': '2', 'KL52': '2', 'KL53': '2', 'KL54': '3', 'KL55': '5', 'KL56': '4', 'KL57': '2', 'KL60': '5', 'KL61': '3', 'KL62': '2', 'KL63': '3', 'KL64': '2', 'KL67': '4', 'KL70': '2', 'KL71': '3', 'KL74': '2', 'KL81': '1', 'KL102': '5', 'KL103': '3', 'KL105': '2', 'KL106': '3', 'KL107': '3', 'KL108': '4', 'KL109': '3', 'KL110': '1', 'KL111': '4', 'KL112': '5', 'KL114': '3', 'KL116': '2', 'KL117': '5', 'KL118': '1', 'KL122': '3', 'KL123': '2', 'KL124': '4', 'KL125': '3', 'K

In [None]:
import os
import shutil

path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn"
path_best_models = f"{path_work}/best_models"
path_models = f"{path_work}/TropiGAT_ensemble_27_11_2024"

index = "TropiGATv2.27_11_2024.pt"

# Selected seed per KL type, taken from the output of the regular-training cell above.
model_version = {'KL1': '4', 'KL2': '3', 'KL3': '5', 'KL4': '5', 'KL5': '4', 'KL6': '3', 'KL7': '5', 'KL8': '1', 'KL9': '5', 'KL10': '4', 'KL12': '2', 'KL13': '1', 'KL14': '1', 'KL15': '1', 'KL16': '3', 'KL17': '2', 'KL18': '3', 'KL19': '5', 'KL20': '3', 'KL21': '3', 'KL22': '5', 'KL23': '5', 'KL24': '4', 'KL25': '4', 'KL26': '4', 'KL27': '4', 'KL28': '1', 'KL29': '4', 'KL30': '2', 'KL31': '1', 'KL34': '5', 'KL35': '3', 'KL36': '4', 'KL38': '5', 'KL39': '3', 'KL41': '3', 'KL43': '4', 'KL45': '5', 'KL46': '2', 'KL47': '5', 'KL48': '2', 'KL51': '2', 'KL52': '2', 'KL53': '2', 'KL54': '3', 'KL55': '5', 'KL56': '4', 'KL57': '2', 'KL60': '5', 'KL61': '3', 'KL62': '2', 'KL63': '3', 'KL64': '2', 'KL67': '4', 'KL70': '2', 'KL71': '3', 'KL74': '2', 'KL81': '1', 'KL102': '5', 'KL103': '3', 'KL105': '2', 'KL106': '3', 'KL107': '3', 'KL108': '4', 'KL109': '3', 'KL110': '1', 'KL111': '4', 'KL112': '5', 'KL114': '3', 'KL116': '2', 'KL117': '5', 'KL118': '1', 'KL122': '3', 'KL123': '2', 'KL124': '4', 'KL125': '3', 'KL127': '1', 'KL128': '1', 'KL136': '2', 'KL139': '2', 'KL140': '3', 'KL142': '2', 'KL143': '2', 'KL145': '3', 'KL149': '4', 'KL151': '5', 'KL152': '1', 'KL153': '4', 'KL155': '2', 'KL157': '4', 'KL166': '3', 'KL169': '3'}

# best_models_TropiGAT was filled earlier by moving each selected
# {kltype}__{version}.{index} out of path_models; copy them back here.
# os.system("cp ...") silently ignored failures; missing files are now reported.
os.makedirs(path_models, exist_ok=True)
missing_models = []
for kltype, version in model_version.items():
    source = f"{path_best_models}/best_models_TropiGAT/{kltype}__{version}.{index}"
    if not os.path.isfile(source):
        missing_models.append(source)
        continue
    shutil.copy(source, path_models)

print(f"Copied {len(model_version) - len(missing_models)}/{len(model_version)} models to {path_models}")
if missing_models:
    print("Missing:", *missing_models, sep="\n  ")
    
    

> Ultrafiltration:

In [16]:
metric_report_path = "/media/concha-eloko/Linux/PPT_clean/ficheros_28032023/Metric_Report.review.GAT.UF.tsv"
average_metric_df = pd.read_csv(metric_report_path, sep="\t", index_col=False, header=0)

# Each "model_version" entry presumably looks like "KL1__3": the KL type is the text
# before the first "_" and the selected seed is the text after the last "_".
model_version = {}
for version_tag in average_metric_df["model_version"]:
    model_version[version_tag.partition("_")[0]] = version_tag.rpartition("_")[2]
print(model_version)

{'KL1': '3', 'KL2': '2', 'KL3': '5', 'KL4': '1', 'KL5': '1', 'KL6': '1', 'KL7': '5', 'KL8': '3', 'KL9': '3', 'KL10': '3', 'KL12': '4', 'KL13': '1', 'KL14': '3', 'KL15': '2', 'KL16': '4', 'KL17': '3', 'KL18': '5', 'KL19': '2', 'KL20': '3', 'KL21': '2', 'KL22': '4', 'KL23': '4', 'KL24': '2', 'KL25': '3', 'KL26': '4', 'KL27': '5', 'KL28': '5', 'KL29': '4', 'KL30': '3', 'KL31': '4', 'KL34': '5', 'KL35': '2', 'KL36': '5', 'KL38': '3', 'KL39': '3', 'KL43': '2', 'KL45': '3', 'KL46': '4', 'KL47': '4', 'KL48': '2', 'KL51': '3', 'KL52': '2', 'KL53': '4', 'KL55': '4', 'KL56': '2', 'KL57': '3', 'KL60': '4', 'KL62': '4', 'KL63': '3', 'KL64': '3', 'KL67': '4', 'KL70': '1', 'KL71': '1', 'KL74': '1', 'KL81': '2', 'KL102': '5', 'KL103': '3', 'KL105': '2', 'KL106': '5', 'KL107': '3', 'KL108': '5', 'KL109': '1', 'KL110': '1', 'KL111': '5', 'KL112': '5', 'KL114': '5', 'KL116': '2', 'KL117': '2', 'KL118': '4', 'KL122': '5', 'KL123': '4', 'KL124': '5', 'KL125': '3', 'KL127': '3', 'KL128': '4', 'KL136': '3',

In [None]:
import os
import shutil

path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn"
path_best_models = f"{path_work}/best_models"
path_models = f"{path_work}/TropiGAT_ensemble_ultraF_27_11_2024"

index = "TropiGATv2.27_11_2024.pt"

# Selected seed per KL type, taken from the output of the ultrafiltration cell above.
model_version = {'KL1': '3', 'KL2': '2', 'KL3': '5', 'KL4': '1', 'KL5': '1', 'KL6': '1', 'KL7': '5', 'KL8': '3', 'KL9': '3', 'KL10': '3', 'KL12': '4', 'KL13': '1', 'KL14': '3', 'KL15': '2', 'KL16': '4', 'KL17': '3', 'KL18': '5', 'KL19': '2', 'KL20': '3', 'KL21': '2', 'KL22': '4', 'KL23': '4', 'KL24': '2', 'KL25': '3', 'KL26': '4', 'KL27': '5', 'KL28': '5', 'KL29': '4', 'KL30': '3', 'KL31': '4', 'KL34': '5', 'KL35': '2', 'KL36': '5', 'KL38': '3', 'KL39': '3', 'KL43': '2', 'KL45': '3', 'KL46': '4', 'KL47': '4', 'KL48': '2', 'KL51': '3', 'KL52': '2', 'KL53': '4', 'KL55': '4', 'KL56': '2', 'KL57': '3', 'KL60': '4', 'KL62': '4', 'KL63': '3', 'KL64': '3', 'KL67': '4', 'KL70': '1', 'KL71': '1', 'KL74': '1', 'KL81': '2', 'KL102': '5', 'KL103': '3', 'KL105': '2', 'KL106': '5', 'KL107': '3', 'KL108': '5', 'KL109': '1', 'KL110': '1', 'KL111': '5', 'KL112': '5', 'KL114': '5', 'KL116': '2', 'KL117': '2', 'KL118': '4', 'KL122': '5', 'KL123': '4', 'KL124': '5', 'KL125': '3', 'KL127': '3', 'KL128': '4', 'KL136': '3', 'KL140': '4', 'KL142': '5', 'KL145': '2', 'KL149': '5', 'KL151': '1', 'KL153': '1', 'KL155': '2', 'KL157': '4', 'KL169': '5'}

# best_models_TropiGAT_UF was filled earlier by moving each selected
# {kltype}__{version}.{index} out of path_models; copy them back here.
# os.system("cp ...") silently ignored failures; missing files are now reported.
os.makedirs(path_models, exist_ok=True)
missing_models = []
for kltype, version in model_version.items():
    source = f"{path_best_models}/best_models_TropiGAT_UF/{kltype}__{version}.{index}"
    if not os.path.isfile(source):
        missing_models.append(source)
        continue
    shutil.copy(source, path_models)

print(f"Copied {len(model_version) - len(missing_models)}/{len(model_version)} models to {path_models}")
if missing_models:
    print("Missing:", *missing_models, sep="\n  ")


In [None]:

rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/best_models \
/media/concha-eloko/Linux/PPT_clean/reviewed_models

