# In [None]:
import os
import json
import random
import warnings
from collections import Counter
from itertools import product
from multiprocessing.pool import ThreadPool

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import (accuracy_score, f1_score, matthews_corrcoef,
                             precision_score, recall_score, roc_auc_score)
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, label_binarize
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import DataLoader, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, HeteroConv, to_hetero
from torch_geometric.utils import negative_sampling
from tqdm import tqdm

import TropiGAT_graph
import TropiGAT_models

# Suppress every warning raised during the run.
warnings.filterwarnings("ignore")

# Constants
# **************************************************
# Switch between the regular and the "ultraF" set of directories below.
ultrafiltration = False
# **************************************************

PATH_WORK = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
DATE = "27_11_2024"

# ENSEMBLE_PATH      : where the trained checkpoints are written
# ENSEMBLE_PATH_log  : where the per-run logs and the metric report are written
# OPTUNA_PATH        : where the per-KL-type hyper-parameter JSON files are read from
if ultrafiltration:
    ENSEMBLE_PATH = f"{PATH_WORK}/train_nn/TropiSAGE_ensemble_ultraF_{DATE}"
    ENSEMBLE_PATH_log = f"{PATH_WORK}/train_nn/TropiSAGE_ensemble_ultraF_{DATE}_log"
    OPTUNA_PATH = f"{PATH_WORK}/train_nn/ensemble_20112024_log_optimized_SAGE_ultraF"
else:
    ENSEMBLE_PATH = f"{PATH_WORK}/train_nn/TropiSAGE_ensemble_{DATE}"
    ENSEMBLE_PATH_log = f"{PATH_WORK}/train_nn/TropiSAGE_ensemble_{DATE}_log"
    OPTUNA_PATH = f"{PATH_WORK}/train_nn/ensemble_20112024_log_optimized_SAGE"

# Both output directories are created up front; existing ones are left untouched.
os.makedirs(ENSEMBLE_PATH_log, exist_ok=True)
os.makedirs(ENSEMBLE_PATH, exist_ok=True)

# Hyper-parameters per KL type, read from the JSON files found in OPTUNA_PATH.
# The KL type is taken from the part of the file name before the first "_".
DICO_OPTUNA = {}
for file in os.listdir(OPTUNA_PATH):
    kl_type = file.split("_")[0]
    # json.load() needs a file object (passing the text returned by .read()
    # raises AttributeError); the context manager also closes the handle.
    with open(f"{OPTUNA_PATH}/{file}") as optuna_file:
        best_parameters = json.load(optuna_file)
    DICO_OPTUNA[kl_type] = best_parameters
    

def load_and_preprocess_data():
    """Read the prophage table and index its prophages.

    The tab-separated table is read from ``PATH_WORK`` and reduced to one row
    per ``Protein_name``.

    Returns:
        tuple: ``(df_info, dico_prophage_info)`` where ``df_info`` is the
        de-duplicated table and ``dico_prophage_info`` maps each ``Phage``
        value to ``{"prophage_strain": <prophage_id>, "ancestor":
        <Infected_ancestor>}``, taken from the first row of that phage.
    """
    table_path = f"{PATH_WORK}/train_nn/TropiGATv2.final_df_v2.tsv"
    df_info = pd.read_csv(table_path, sep="\t", header=0).drop_duplicates(subset=["Protein_name"])

    # One representative row per phage is enough to describe its strain/ancestor.
    one_row_per_phage = df_info.drop_duplicates(subset=["Phage"], keep="first")
    dico_prophage_info = {}
    for _, record in one_row_per_phage.iterrows():
        dico_prophage_info[record["Phage"]] = {
            "prophage_strain": record["prophage_id"],
            "ancestor": record["Infected_ancestor"],
        }

    return df_info, dico_prophage_info

def filter_prophages(df_info, dico_prophage_info):
    """Keep one prophage per distinct depolymerase-domain set within a group.

    Prophages are grouped by identical ``prophage_id`` and
    ``Infected_ancestor``. Inside a group, a prophage is discarded when its set
    of ``domain_seq`` values equals that of the prophage being examined or of
    a prophage already kept during that examination. Rows whose
    ``KL_type_LCA`` contains a ``|`` character are removed at the end.

    Args:
        df_info: Table with at least the columns ``Phage``, ``prophage_id``,
            ``Infected_ancestor``, ``domain_seq`` and ``KL_type_LCA``.
        dico_prophage_info: Mapping ``Phage -> {"prophage_strain", "ancestor"}``.

    Returns:
        The filtered table.
    """
    def get_filtered_prophages(prophage):
        # Rows sharing both the strain and the ancestor of the examined prophage.
        info = dico_prophage_info[prophage]
        same_strain = df_info["prophage_id"] == info["prophage_strain"]
        same_ancestor = df_info["Infected_ancestor"] == info["ancestor"]
        group = df_info[same_strain & same_ancestor]

        kept = {prophage}
        dropped = set()
        if len(group) == 1:
            return group, dropped, kept

        def domains_of(name):
            return set(group[group["Phage"] == name]["domain_seq"].values)

        reference = domains_of(prophage)
        seen_profiles = []
        for candidate in group["Phage"].unique():
            if candidate == prophage:
                continue
            profile = domains_of(candidate)
            # A candidate survives only if its domain set is new: different
            # from the reference and from every profile already retained.
            if profile != reference and profile not in seen_profiles:
                kept.add(candidate)
                seen_profiles.append(profile)
            else:
                dropped.add(candidate)
        return group, dropped, kept

    good_prophages = set()
    excluded_prophages = set()

    for prophage in tqdm(dico_prophage_info.keys()):
        # Prophages already classified through an earlier group are skipped.
        if prophage in excluded_prophages or prophage in good_prophages:
            continue
        _, dropped_members, retained_members = get_filtered_prophages(prophage)
        good_prophages.update(retained_members)
        excluded_prophages.update(dropped_members)

    df_info_filtered = df_info[df_info["Phage"].isin(good_prophages)]
    has_pipe = df_info_filtered["KL_type_LCA"].str.contains("\\|")
    return df_info_filtered[~has_pipe]


def ultrafilter_prophages(df_info):
    """Perform ultra-filtration to remove duplicate prophages within KL types.

    Within each ``KL_type_LCA`` value, prophages are visited in order of first
    appearance. The first prophage carrying a given set of ``domain_seq``
    values is kept; every later prophage with an identical set is removed
    from the whole table.

    Args:
        df_info: Table with at least the columns ``Phage``, ``KL_type_LCA``
            and ``domain_seq``.

    Returns:
        ``df_info`` without the rows of the duplicated prophages (row order
        and index are preserved).
    """
    duplicate_prophages = set()

    # sort=False keeps first-appearance order, so the first prophage carrying
    # a given domain set is the one that is retained.
    for _, df_kl in df_info.groupby("KL_type_LCA", sort=False):
        seen_domain_sets = set()
        for prophage, domains in df_kl.groupby("Phage", sort=False)["domain_seq"]:
            domain_set = frozenset(domains.values)
            # frozensets are hashable: membership is a hash lookup instead of
            # a linear scan over every previously seen set.
            if domain_set in seen_domain_sets:
                duplicate_prophages.add(prophage)
            else:
                seen_domain_sets.add(domain_set)

    return df_info[~df_info["Phage"].isin(duplicate_prophages)]


def prepare_kltypes(df_info, min_count=10):
    """Prepare KL types for training.

    Args:
        df_info: Table with at least the columns ``Phage`` and ``KL_type_LCA``.
        min_count: Minimum number of distinct prophages a KL type needs to be
            selected. Defaults to 10, the previously hard-coded threshold.

    Returns:
        tuple: ``(kltypes, dico_prophage_count)`` where ``kltypes`` lists the
        KL types reaching ``min_count`` (in order of first appearance) and
        ``dico_prophage_count`` maps every KL type to its prophage count.
    """
    # Count each prophage once, using the KL type of its first row.
    df_prophages = df_info.drop_duplicates(subset=["Phage"])
    dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))
    kltypes = [kltype for kltype, count in dico_prophage_count.items() if count >= min_count]
    return kltypes, dico_prophage_count


class TropiGAT_small_sage_module(torch.nn.Module):
    """One heterogeneous graph convolution followed by a feed-forward head.

    Args:
        hidden_channels: Output size of the convolution and input size of the
            feed-forward head.
        edge_type: Edge type the convolution is registered for in ``HeteroConv``.
        dropout: Dropout probability used after each hidden layer of the head.
        conv: Convolution class, instantiated as ``conv((-1, -1), hidden_channels)``
            (``SAGEConv`` by default).
    """
    def __init__(self,hidden_channels, edge_type = ("B2", "expressed", "B1") ,dropout = 0.2, conv = SAGEConv):
        super().__init__()
        # Graph convolution layer.
        # NOTE(review): (-1, -1) presumably asks PyG to infer the input sizes
        # on the first forward pass — confirm against the PyG documentation.
        self.conv = conv((-1,-1), hidden_channels)
        self.hetero_conv = HeteroConv({edge_type: self.conv})
        # Feed-forward head: hidden_channels -> 1280 -> 480 -> 1.
        widths = (hidden_channels, 1280, 480)
        head = []
        for n_in, n_out in zip(widths, widths[1:]):
            head.append(nn.Linear(n_in, n_out))
            head.append(nn.BatchNorm1d(n_out))
            head.append(nn.LeakyReLU())
            head.append(torch.nn.Dropout(dropout))
        head.append(nn.Linear(480 , 1))
        self.linear_layers = nn.Sequential(*head)

    def forward(self, graph_data):
        """Return one value per ``"B1"`` node as a flat tensor.

        The ``"B1"`` key is read literally, whatever ``edge_type`` was given
        to the constructor.
        """
        conv_out = self.hetero_conv(graph_data.x_dict, graph_data.edge_index_dict)
        scores = self.linear_layers(conv_out["B1"])
        return scores.view(-1)
    

def train_graph(kl_type, graph_baseline, dico_prophage_kltype_associated, df_info, dico_prophage_count):
    """Train the graph neural network for a specific KL type.

    For each seed from 1 to 5 whose log file does not exist yet, a masked
    graph is built, a ``TropiGAT_small_sage_module`` is trained for at most
    500 epochs (evaluated every 5 epochs on the test mask, with learning-rate
    scheduling and early stopping on MCC), and the saved checkpoint is finally
    evaluated on the eval mask. One line per run is appended to the shared
    metric report; failures are written to the run log and reported as
    ``***Error***``.

    Args:
        kl_type: KL type to train for; also the key into ``DICO_OPTUNA``.
        graph_baseline: Baseline graph passed to
            ``TropiGAT_graph.build_graph_masking_v2``.
        dico_prophage_kltype_associated: Passed through to the graph builder.
        df_info: Prophage table passed through to the graph builder.
        dico_prophage_count: Mapping KL type -> number of prophages, used for
            the metric report.

    Returns:
        None. Results are written to ``ENSEMBLE_PATH`` and ``ENSEMBLE_PATH_log``.
    """
    metric_report = f"{ENSEMBLE_PATH_log}/Metric_Report.{DATE}.tsv"
    for seed in range(1, 6):
        log_file = f"{ENSEMBLE_PATH_log}/{kl_type}__{seed}__node_classification.{DATE}.log"
        checkpoint_path = f"{ENSEMBLE_PATH}/{kl_type}__{seed}.TropiSAGE.{DATE}.pt"
        if os.path.isfile(log_file):
            # An existing log means this (KL type, seed) run was already attempted.
            continue

        # Bound before the try block so the error report below can always be
        # written, even if the lookup of the real count is what fails.
        n_prophage = "NA"
        with open(log_file, "w") as log_outfile:
            try:
                n_prophage = dico_prophage_count[kl_type]
                graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(
                    graph_baseline, dico_prophage_kltype_associated, df_info,
                    kl_type, 5, 0.7, 0.2, 0.1, seed=seed
                )
                model = TropiGAT_small_sage_module(
                    hidden_channels = 1280,
                    dropout = DICO_OPTUNA[kl_type]["dropout"],
                )

                # One forward pass before the optimizer is created.
                # NOTE(review): presumably needed to materialise the lazily
                # sized convolution parameters — confirm.
                model(graph_data_kltype)

                optimizer = torch.optim.AdamW(
                    model.parameters(),
                    lr = DICO_OPTUNA[kl_type]["lr"],
                    weight_decay = DICO_OPTUNA[kl_type]["weight_decay"]
                )
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, min_lr=1e-6)
                criterion = torch.nn.BCEWithLogitsLoss()
                early_stopping = TropiGAT_models.EarlyStopping(
                    patience=100,
                    verbose=True,
                    path=checkpoint_path,
                    metric='MCC'
                )

                for epoch in range(500):
                    train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer, criterion)
                    if epoch % 5 == 0:
                        # metrics order, as used in the report strings below:
                        # F1, precision, recall, MCC, accuracy, AUC.
                        test_loss, metrics = TropiGAT_models.evaluate(
                            model, graph_data_kltype, criterion, graph_data_kltype["B1"].test_mask
                        )
                        log_outfile.write(f'Epoch: {epoch}\tTrain Loss: {train_loss:.4f}\t'
                                          f'Test Loss: {test_loss:.4f}\tMCC: {metrics[3]:.4f}\t'
                                          f'AUC: {metrics[5]:.4f}\tAccuracy: {metrics[4]:.4f}\n')
                        scheduler.step(test_loss)
                        early_stopping(metrics[3], model, epoch)
                        if early_stopping.early_stop:
                            log_outfile.write(f"Early stopping at epoch {epoch}\n")
                            break
                else:
                    # Training ran all 500 epochs without early stopping.
                    # Save the state_dict (not the whole module): the final
                    # evaluation below reloads the file with load_state_dict().
                    # NOTE(review): this overwrites any checkpoint that
                    # EarlyStopping may have written to the same path — confirm
                    # that keeping the last-epoch weights is intended.
                    torch.save(model.state_dict(), checkpoint_path)

                # Final evaluation
                model_final = TropiGAT_small_sage_module(
                    hidden_channels = 1280,
                    dropout = DICO_OPTUNA[kl_type]["dropout"],
                )
                model_final.load_state_dict(torch.load(checkpoint_path))
                eval_loss, metrics = TropiGAT_models.evaluate(
                    model_final, graph_data_kltype, criterion, graph_data_kltype["B1"].eval_mask
                )

                with open(metric_report, "a+") as metric_outfile:
                    metric_outfile.write(f"{kl_type}__{seed}\t{n_prophage}\t"
                                         f"{metrics[0]:.4f}\t{metrics[1]:.4f}\t{metrics[2]:.4f}\t"
                                         f"{metrics[3]:.4f}\t{metrics[4]:.4f}\t{metrics[5]:.4f}\n")

                log_outfile.write(f"Final evaluation:\n"
                                  f"F1 Score: {metrics[0]:.4f}, Precision: {metrics[1]:.4f}, "
                                  f"Recall: {metrics[2]:.4f}, MCC: {metrics[3]:.4f}, "
                                  f"Accuracy: {metrics[4]:.4f}, AUC: {metrics[5]:.4f}")

            except Exception as e:
                # Best effort: one failing run must not stop the other seeds.
                log_outfile.write(f"Error occurred: {str(e)}")
                with open(metric_report, "a+") as metric_outfile:
                    metric_outfile.write(f"{kl_type}__{seed}\t{n_prophage}\t***Error***\n")

def main():
    """Main function to orchestrate the TropiSAGE workflow.

    Loads and filters the prophage table (with the extra ultra-filtration
    step when ``ultrafiltration`` is set), selects the KL types to train,
    builds the baseline graph and trains the KL types in a pool of 5 threads.
    """
    df_info, dico_prophage_info = load_and_preprocess_data()
    df_info_final = filter_prophages(df_info, dico_prophage_info)
    if ultrafiltration:
        df_info_final = ultrafilter_prophages(df_info_final)

    # The KL types are derived from the same filtered table that is used to
    # build the graph (the previous name, df_info_filtered, was undefined here).
    kltypes, dico_prophage_count = prepare_kltypes(df_info_final)
    graph_baseline, dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(df_info_final)

    jobs = [
        (kl_type, graph_baseline, dico_prophage_kltype_associated, df_info_final, dico_prophage_count)
        for kl_type in kltypes
    ]
    with ThreadPool(5) as p:
        p.starmap(train_graph, jobs)

if __name__ == '__main__':
    main()
