# In [None]:
# TropiGAT: A Graph Neural Network for Prophage Prediction
# This script filters the prophage/depolymerase table, builds the prophage graph, and trains a 5-seed TropiGAT ensemble for each KL type

import os
import random
import warnings
from collections import Counter
from itertools import product
from multiprocessing.pool import ThreadPool

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import (accuracy_score, f1_score, matthews_corrcoef,
                             precision_score, recall_score, roc_auc_score)
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, label_binarize
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import DataLoader, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import GATv2Conv, HeteroConv, to_hetero
from torch_geometric.utils import negative_sampling
from tqdm import tqdm

import TropiGAT_graph
import TropiGAT_models

# Suppress every Python warning for the rest of the process.
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Constants: input root, run tag and output directories
# ---------------------------------------------------------------------------
PATH_WORK = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
# Run tag embedded in every output file name (and, suffixed with "2024", in the directory names).
DATE = "1209"
# Model checkpoints (*.pt) are written here.
ENSEMBLE_PATH = f"{PATH_WORK}/train_nn/ensemble_{DATE}2024_optimized"
# Per-seed training logs and the Metric_Report TSV are written here.
ENSEMBLE_PATH_log = f"{PATH_WORK}/train_nn/ensemble_{DATE}2024_log_optimized"
# Created at import time (module-level side effect).
os.makedirs(ENSEMBLE_PATH, exist_ok=True)
os.makedirs(ENSEMBLE_PATH_log, exist_ok=True)

# ---------------------------------------------------------------------------
# Hyperparameters, one entry per KL type, consumed by train_graph:
#   para_heads   -> attention heads of TropiGAT_small_module
#   para_lr      -> AdamW learning rate
#   para_wd      -> AdamW weight decay
#   para_dropout -> dropout of TropiGAT_small_module
# main() trains exactly the KL types listed here.
# NOTE(review): the name suggests these came from an Optuna search — confirm provenance.
# ---------------------------------------------------------------------------
DICO_OPTUNA = {
    kl_type: {
        "para_heads": heads,
        "para_lr": lr,
        "para_wd": wd,
        "para_dropout": dropout,
    }
    for kl_type, heads, lr, wd, dropout in (
        # (KL type, heads, learning rate, weight decay, dropout)
        ("KL64", 5, 0.000246, 0.000080, 0.063113),
        ("KL1", 5, 0.0009415397708661039, 1.132790862878068e-06, 0.007657626670776924),
        ("KL10", 2, 0.0006633594884735811, 3.7430295738223034e-06, 0.4493747213067273),
        ("KL15", 5, 0.00017766142057218653, 5.245213610566463e-05, 0.15214512795626994),
        ("KL17", 2, 0.0002068133316219641, 5.303964308479191e-05, 0.4810681327179018),
        ("KL19", 5, 0.00028386856144729176, 6.667568504410857e-07, 0.4460345479421262),
        ("KL2", 2, 0.0006115983973072073, 3.521041854903662e-06, 0.16320044607028428),
        ("KL47", 2, 0.0007352151826846244, 8.666317429082471e-06, 0.1877399746783721),
        ("KL74", 1, 0.0004137122657073261, 3.5238343953806846e-05, 0.3464829958840639),
    )
}




def load_and_preprocess_data():
    """Read the TropiGAT protein table and index per-prophage metadata.

    The TSV ``{PATH_WORK}/train_nn/TropiGATv2.final_df_v2.tsv`` is loaded and
    rows with a repeated ``Protein_name`` are dropped (first occurrence kept).

    Returns:
        tuple: ``(df_info, dico_prophage_info)`` where ``df_info`` is the
        de-duplicated table and ``dico_prophage_info`` maps each ``Phage`` to
        ``{"prophage_strain": prophage_id, "ancestor": Infected_ancestor}``
        taken from that phage's first row.
    """
    table_path = f"{PATH_WORK}/train_nn/TropiGATv2.final_df_v2.tsv"
    raw_table = pd.read_csv(table_path, sep="\t", header=0)
    df_info = raw_table.drop_duplicates(subset=["Protein_name"])

    # One representative row per phage: its first occurrence in the table.
    first_rows = df_info.drop_duplicates(subset=["Phage"], keep="first")
    dico_prophage_info = {}
    for _, record in first_rows.iterrows():
        dico_prophage_info[record["Phage"]] = {
            "prophage_strain": record["prophage_id"],
            "ancestor": record["Infected_ancestor"],
        }

    return df_info, dico_prophage_info

def filter_prophages(df_info, dico_prophage_info):
    """Keep one prophage per distinct depolymerase repertoire within each prophage group.

    A group is every row of ``df_info`` sharing the reference prophage's
    ``prophage_id`` and ``Infected_ancestor`` (as recorded in
    ``dico_prophage_info``). Within a group the reference prophage is always
    kept. Each other member is dropped when its set of ``domain_seq`` values
    equals the reference's or a set already kept, and kept otherwise. Prophages
    already classified (kept or dropped) are not used as references again.

    Finally, rows whose ``KL_type_LCA`` contains a literal ``|`` are removed
    (presumably multi-valued / ambiguous KL assignments).
    NOTE(review): a NaN in ``KL_type_LCA`` makes ``str.contains`` yield NaN and
    the ``~`` negation fail — confirm the column has no missing values.

    Args:
        df_info: Protein table with at least ``Phage``, ``prophage_id``,
            ``Infected_ancestor``, ``domain_seq`` and ``KL_type_LCA`` columns.
        dico_prophage_info: Mapping built by ``load_and_preprocess_data``.

    Returns:
        pandas.DataFrame: rows of the kept prophages without ``|`` KL types.
    """
    def _split_group(reference):
        # Returns (kept, dropped) phage sets for the group of `reference`.
        info = dico_prophage_info[reference]
        group = df_info[
            (df_info["prophage_id"] == info["prophage_strain"])
            & (df_info["Infected_ancestor"] == info["ancestor"])
        ]
        kept, dropped = {reference}, set()
        if len(group) == 1:
            # A single row: nothing else to compare against.
            return kept, dropped

        def depos_of(phage):
            return set(group[group["Phage"] == phage]["domain_seq"].values)

        reference_depos = depos_of(reference)
        seen_repertoires = []
        for member in group["Phage"].unique():
            if member == reference:
                continue
            member_depos = depos_of(member)
            if member_depos == reference_depos or member_depos in seen_repertoires:
                dropped.add(member)
            else:
                kept.add(member)
                seen_repertoires.append(member_depos)
        return kept, dropped

    good_prophages = set()
    excluded_prophages = set()
    for prophage in tqdm(dico_prophage_info):
        if prophage in excluded_prophages or prophage in good_prophages:
            continue
        kept_members, excluded_members = _split_group(prophage)
        good_prophages.update(kept_members)
        excluded_prophages.update(excluded_members)

    survivors = df_info[df_info["Phage"].isin(good_prophages)]
    has_pipe = survivors["KL_type_LCA"].str.contains("\\|")
    return survivors[~has_pipe]

def ultrafilter_prophages(df_info):
    """Remove prophages whose depolymerase repertoire duplicates another in the same KL type.

    KL types are visited in order of first appearance in ``KL_type_LCA``.
    Within each KL type, prophages are visited in order of first appearance,
    and each is summarised by the frozenset of its ``domain_seq`` values. The
    first prophage with a given set is kept. Any later prophage with an
    identical set is marked duplicate, and all of its rows are removed from the
    returned table.

    Args:
        df_info: Table with at least ``Phage``, ``KL_type_LCA`` and
            ``domain_seq`` columns.

    Returns:
        pandas.DataFrame: ``df_info`` without the rows of duplicate prophages.
    """
    duplicate_prophages = []

    for kltype in df_info["KL_type_LCA"].unique():
        # Only the columns actually used; selecting unrelated ones used to raise
        # KeyError on tables lacking them.
        df_kl = df_info[df_info["KL_type_LCA"] == kltype][["Phage", "domain_seq"]]
        # Hash-based membership replaces the former linear scan over past sets
        # (frozenset equality and hashing are consistent, so results match).
        seen_repertoires = set()
        for prophage in df_kl["Phage"].unique():
            repertoire = frozenset(df_kl[df_kl["Phage"] == prophage]["domain_seq"].values)
            if repertoire in seen_repertoires:
                duplicate_prophages.append(prophage)
            else:
                seen_repertoires.add(repertoire)

    return df_info[~df_info["Phage"].isin(duplicate_prophages)]

def prepare_kltypes(df_info, min_count=10):
    """Count prophages per KL type and select the KL types with enough of them.

    Each phage is counted once, under the ``KL_type_LCA`` of its first row.

    Args:
        df_info: Table with at least ``Phage`` and ``KL_type_LCA`` columns.
        min_count: Minimum number of distinct prophages a KL type needs to be
            selected (default 10, the previously hard-coded threshold).

    Returns:
        tuple: ``(kltypes, dico_prophage_count)`` where ``kltypes`` lists the KL
        types with ``count >= min_count`` in order of first appearance, and
        ``dico_prophage_count`` maps every KL type to its prophage count.
    """
    df_prophages = df_info.drop_duplicates(subset=["Phage"])
    dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))
    kltypes = [kltype for kltype, count in dico_prophage_count.items() if count >= min_count]
    return kltypes, dico_prophage_count

def train_graph(kl_type, graph_baseline, dico_prophage_kltype_associated, df_info, dico_prophage_count):
    """Train and evaluate a 5-seed ensemble of TropiGAT models for one KL type.

    For every seed in 1..5 this function:

    * builds a masked graph with ``TropiGAT_graph.build_graph_masking_v2``;
    * trains a ``TropiGAT_small_module`` with the KL-specific hyperparameters
      from ``DICO_OPTUNA`` for up to 500 epochs. Every 5 epochs it evaluates
      on the test mask, steps ``ReduceLROnPlateau`` on the test loss, and feeds
      the MCC to ``EarlyStopping``;
    * reloads the checkpoint and scores it on the eval mask.

    Progress goes to a per-seed log under ``ENSEMBLE_PATH_log``. One row per
    seed (F1, precision, recall, MCC, accuracy, AUC) is appended to the shared
    ``Metric_Report.{DATE}.tsv``.

    Failures are handled per seed on a best-effort basis. The message and
    traceback are written to that seed's log, an ``***Error***`` row is written
    to the report, and the next seed is attempted.

    Args:
        kl_type: KL type to train; looked up in ``DICO_OPTUNA``.
        graph_baseline: Baseline graph forwarded to ``build_graph_masking_v2``.
        dico_prophage_kltype_associated: Mapping returned alongside
            ``graph_baseline`` by ``TropiGAT_graph.build_graph_baseline``.
        df_info: Filtered protein table used to build the graph.
        dico_prophage_count: KL type -> prophage count (written to the report).
    """
    import traceback  # only needed to log failures

    metric_report = f"{ENSEMBLE_PATH_log}/Metric_Report.{DATE}.tsv"

    for seed in range(1, 6):
        log_file = f"{ENSEMBLE_PATH_log}/{kl_type}__{seed}__node_classification.{DATE}.log"
        checkpoint_path = f"{ENSEMBLE_PATH}/{kl_type}__{seed}.TropiGATv2.{DATE}.pt"
        # Bound before the try so the error row can always be written. Previously,
        # a failing count lookup raised UnboundLocalError inside the handler, which
        # escaped and aborted the whole pool instead of being logged.
        n_prophage = "NA"
        with open(log_file, "w") as log_outfile:
            try:
                n_prophage = dico_prophage_count[kl_type]
                params = DICO_OPTUNA[kl_type]
                # NOTE(review): (5, 0.7, 0.2, 0.1) are forwarded unchanged; presumably a
                # neighbour count and train/test/eval fractions — confirm in TropiGAT_graph.
                graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(
                    graph_baseline, dico_prophage_kltype_associated, df_info,
                    kl_type, 5, 0.7, 0.2, 0.1, seed=seed
                )

                model = TropiGAT_models.TropiGAT_small_module(
                    hidden_channels=1280, heads=params["para_heads"], dropout=params["para_dropout"]
                )
                # One forward pass before the optimizer collects parameters.
                # NOTE(review): presumably materializes lazily-sized layers — confirm.
                model(graph_data_kltype)

                optimizer = torch.optim.AdamW(
                    model.parameters(),
                    lr=params["para_lr"],
                    weight_decay=params["para_wd"],
                )
                scheduler = ReduceLROnPlateau(optimizer, 'min')
                criterion = torch.nn.BCEWithLogitsLoss()
                # NOTE(review): EarlyStopping is consulted only every 5 epochs (100 calls
                # in 500 epochs); if patience counts calls, patience=100 can never fire.
                early_stopping = TropiGAT_models.EarlyStopping(
                    patience=100,
                    verbose=True,
                    path=checkpoint_path,
                    metric='MCC',
                )

                for epoch in range(500):
                    train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer, criterion)
                    if epoch % 5 == 0:
                        test_loss, metrics = TropiGAT_models.evaluate(
                            model, graph_data_kltype, criterion, graph_data_kltype["B1"].test_mask
                        )
                        log_outfile.write(f'Epoch: {epoch}\tTrain Loss: {train_loss:.4f}\t'
                                          f'Test Loss: {test_loss:.4f}\tMCC: {metrics[3]:.4f}\t'
                                          f'AUC: {metrics[5]:.4f}\tAccuracy: {metrics[4]:.4f}\n')
                        scheduler.step(test_loss)
                        early_stopping(metrics[3], model, epoch)
                        if early_stopping.early_stop:
                            log_outfile.write(f"Early stopping at epoch {epoch}\n")
                            break
                else:
                    # No early stop: persist the final weights. They must be saved as a
                    # state_dict because they are read back with load_state_dict() below.
                    # Pickling the whole module (as before) made that reload fail for
                    # every run that reached 500 epochs.
                    # NOTE(review): this overwrites the checkpoint EarlyStopping wrote here.
                    torch.save(model.state_dict(), checkpoint_path)

                # Final evaluation on the held-out eval mask with a freshly built model.
                model_final = TropiGAT_models.TropiGAT_small_module(
                    hidden_channels=1280, heads=params["para_heads"], dropout=params["para_dropout"]
                )
                model_final.load_state_dict(torch.load(checkpoint_path))
                eval_loss, metrics = TropiGAT_models.evaluate(
                    model_final, graph_data_kltype, criterion, graph_data_kltype["B1"].eval_mask
                )

                with open(metric_report, "a+") as metric_outfile:
                    metric_outfile.write(f"{kl_type}__{seed}\t{n_prophage}\t"
                                         f"{metrics[0]:.4f}\t{metrics[1]:.4f}\t{metrics[2]:.4f}\t"
                                         f"{metrics[3]:.4f}\t{metrics[4]:.4f}\t{metrics[5]:.4f}\n")

                log_outfile.write(f"Final evaluation:\n"
                                  f"F1 Score: {metrics[0]:.4f}, Precision: {metrics[1]:.4f}, "
                                  f"Recall: {metrics[2]:.4f}, MCC: {metrics[3]:.4f}, "
                                  f"Accuracy: {metrics[4]:.4f}, AUC: {metrics[5]:.4f}")

            except Exception as e:
                # Best effort: record the failure (with traceback) and move on to the next seed.
                log_outfile.write(f"Error occurred: {str(e)}\n{traceback.format_exc()}")
                with open(metric_report, "a+") as metric_outfile:
                    metric_outfile.write(f"{kl_type}__{seed}\t{n_prophage}\t***Error***\n")

def main():
    """Run the TropiGAT pipeline end to end.

    Steps: load the table, filter and ultra-filter prophages, count prophages
    per KL type, build the baseline graph, then train every KL type listed in
    ``DICO_OPTUNA`` with ``train_graph`` on a 5-thread pool. A KL type is
    skipped when its seed-1 log file already exists (presumably so an
    interrupted run can be resumed).
    """
    df_info, dico_prophage_info = load_and_preprocess_data()
    df_info_filtered = filter_prophages(df_info, dico_prophage_info)
    df_info_final = ultrafilter_prophages(df_info_filtered)

    # NOTE(review): the >=10-prophage KL list is not used to restrict training;
    # only the per-KL counts are needed here — confirm this is intended.
    _kltypes, dico_prophage_count = prepare_kltypes(df_info_final)

    graph_baseline, dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(df_info_final)

    pending_kltypes = [
        kl_type for kl_type in DICO_OPTUNA
        if not os.path.isfile(f"{ENSEMBLE_PATH_log}/{kl_type}__1__node_classification.{DATE}.log")
    ]
    tasks = [
        (kl_type, graph_baseline, dico_prophage_kltype_associated, df_info_final, dico_prophage_count)
        for kl_type in pending_kltypes
    ]
    with ThreadPool(5) as p:
        p.starmap(train_graph, tasks)


if __name__ == '__main__':
    main()
