# In [None]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import optuna
from optuna.samplers import TPESampler
import warnings
import logging
import sys

# TropiGAT modules
import TropiGAT_graph
import TropiGAT_models


warnings.filterwarnings("ignore")
# *****************************************************************************
# Load the Dataframes :
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn"

# One row per protein: read the table and keep the first occurrence of each Protein_name.
DF_info = (
    pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.tsv", sep="\t", header=0)
    .drop_duplicates(subset=["Protein_name"])
)
# Drop rows whose KL_type_LCA contains a literal "|" (presumably several
# candidate KL types were merged into one label — confirm against the data).
DF_info = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]

# Per-phage metadata, taken from the first row seen for each Phage.
df_prophages = DF_info.drop_duplicates(subset=["Phage"], keep="first")
dico_prophage_info = {
    phage: {"prophage_strain": strain, "ancestor": ancestor, "KL_type": kl_type}
    for phage, strain, ancestor, kl_type in zip(
        df_prophages["Phage"],
        df_prophages["prophage_id"],
        df_prophages["Infected_ancestor"],
        df_prophages["KL_type_LCA"],
    )
}

# *****************************************************************************
# The model : Classifier
class TropiGAT_small_module(torch.nn.Module):
    """One heterogeneous attention layer followed by an MLP classifier head.

    A single convolution (``conv``, GATv2Conv by default) is wrapped in a
    ``HeteroConv`` and run along ``edge_type`` (``B2 -expressed-> B1`` by
    default). The resulting ``"B1"`` embeddings feed a
    1280 -> 480 -> 1 MLP (BatchNorm, LeakyReLU, dropout), yielding one logit
    per ``"B1"`` node.

    Args:
        hidden_channels: output channels per attention head.
        heads: number of attention heads; the MLP input is
            ``heads * hidden_channels`` wide, so head outputs are expected to
            be concatenated.
        edge_type: (source type, relation, destination type) triple.
        dropout: dropout rate passed to the convolution and used in the MLP.
        conv: convolution class accepting the keyword arguments used below.
    """

    def __init__(self, hidden_channels, heads, edge_type=("B2", "expressed", "B1"), dropout=0.2, conv=GATv2Conv):
        super().__init__()
        # (-1, -1): source/target input widths are left for PyG's lazy
        # initialisation to infer on the first forward pass.
        # NOTE(review): GATv2Conv documents this option as `share_weights`;
        # `shared_weights` is a different keyword, so weight sharing may not be
        # in effect (or the call may be rejected) depending on the PyG version.
        self.conv = conv(
            (-1, -1),
            hidden_channels,
            add_self_loops=False,
            heads=heads,
            dropout=dropout,
            shared_weights=True,
        )
        self.hetero_conv = HeteroConv({edge_type: self.conv})

        # MLP head: one (Linear -> BatchNorm -> LeakyReLU -> Dropout) stage per
        # hidden width, then a single-unit output layer.
        stages = []
        width_in = heads * hidden_channels
        for width_out in (1280, 480):
            stages += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.LeakyReLU(),
                torch.nn.Dropout(dropout),
            ]
            width_in = width_out
        stages.append(nn.Linear(width_in, 1))
        self.linear_layers = nn.Sequential(*stages)

    def forward(self, graph_data):
        """Return a flat tensor holding one logit per ``"B1"`` node."""
        embeddings = self.hetero_conv(graph_data.x_dict, graph_data.edge_index_dict)
        return self.linear_layers(embeddings["B1"]).view(-1)

# *****************************************************************************
# Pre-process data :
# First filtration step :
def get_filtered_prophages(prophage):
    """Split the phages grouped with ``prophage`` into kept and excluded sets.

    The group is every row of the module-level ``DF_info`` whose
    ``prophage_id`` and ``Infected_ancestor`` equal the values stored for
    ``prophage`` in ``dico_prophage_info``. ``prophage`` itself is always
    kept. Each other phage in the group is compared by the set of its
    ``domain_seq`` values:

    * equal to ``prophage``'s set, or to the set of a phage already kept ->
      excluded;
    * otherwise -> kept.

    Returns:
        tuple: ``(df_prophage_group, to_exclude, to_keep)`` — the group's rows
        plus two sets of phage names.
    """
    info = dico_prophage_info[prophage]
    in_group = (DF_info["prophage_id"] == info["prophage_strain"]) & (
        DF_info["Infected_ancestor"] == info["ancestor"]
    )
    df_prophage_group = DF_info[in_group]
    to_keep = {prophage}
    to_exclude = set()

    # A single row leaves nothing to compare against.
    if len(df_prophage_group) == 1:
        return df_prophage_group, to_exclude, to_keep

    def domain_set(phage):
        rows = df_prophage_group[df_prophage_group["Phage"] == phage]
        return frozenset(rows["domain_seq"].values)

    reference = domain_set(prophage)
    # Depolymerase sets of the other phages kept so far.
    kept_sets = set()
    for other in df_prophage_group["Phage"].unique().tolist():
        if other == prophage:
            continue
        candidate = domain_set(other)
        if candidate == reference or candidate in kept_sets:
            to_exclude.add(other)
        else:
            to_keep.add(other)
            kept_sets.add(candidate)
    return df_prophage_group, to_exclude, to_keep

# Visit every prophage once. Each one not yet classified anchors its
# (prophage_id, Infected_ancestor) group, whose members are all sorted into
# kept / excluded in a single call.
good_prophages = set()
excluded_prophages = set()

for prophage in tqdm(dico_prophage_info):
    if prophage in excluded_prophages or prophage in good_prophages:
        continue
    _, excluded_members, kept_members = get_filtered_prophages(prophage)
    good_prophages.update(kept_members)
    excluded_prophages.update(excluded_members)

# Keep only rows of retained phages whose KL_type_LCA has no literal "|".
DF_info_lvl_0_filtered = DF_info[DF_info["Phage"].isin(good_prophages)]
DF_info_lvl_0_final = DF_info_lvl_0_filtered[~DF_info_lvl_0_filtered["KL_type_LCA"].str.contains("\\|")]


# Second filtration step :
# Within each KL type, only the first phage carrying a given set of
# depolymerase domain sequences survives. Later phages with an identical set
# are recorded in `duplicate_prophage` and removed below.
duplicate_prophage = []
dico_kltype_duplica = {}
kl_columns = ["Phage", "Protein_name", "KL_type_LCA", "Infected_ancestor", "index", "seq", "domain_seq"]
for kltype in DF_info_lvl_0_final["KL_type_LCA"].unique():
    df_kl = DF_info_lvl_0_final[DF_info_lvl_0_final["KL_type_LCA"] == kltype][kl_columns]
    # Distinct depolymerase set -> number of phages of this KL type carrying it.
    duplicated = {}
    for prophage_tmp in df_kl["Phage"].unique().tolist():
        set_depo = frozenset(df_kl[df_kl["Phage"] == prophage_tmp]["domain_seq"].values)
        if set_depo in duplicated:
            duplicated[set_depo] += 1
            duplicate_prophage.append(prophage_tmp)
        else:
            duplicated[set_depo] = 1
    dico_kltype_duplica[kltype] = duplicated

DF_info_lvl_0_final_ultrafiltered = DF_info_lvl_0_final[~DF_info_lvl_0_final["Phage"].isin(duplicate_prophage)]
DF_info_lvl_0 = DF_info_lvl_0_final_ultrafiltered.copy()

# Input graph: 
# graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
# graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(graph_baseline , dico_prophage_kltype_associated, DF_info_lvl_0, KL_type, 5, 0.7, 0.2, 0.1, seed = seed)

# *****************************************************************************
# Training :
# Prefer a CUDA device when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE: TropiGAT_small_module used to be redefined here, verbatim, a second
# time. The duplicate was removed so the classifier has a single definition:
# the one under "The model : Classifier" near the top of this file.

class EarlyStopping:
    """Stop training once the monitored validation loss stops improving.

    Each call compares a new loss with the best one seen so far.

    * Improvement (lower than the best by at least ``delta``): the model's
      ``state_dict`` is saved to ``path`` and the patience counter resets.
    * Anything else, including a NaN loss: the counter increases, and
      ``early_stop`` becomes True once it reaches ``patience``.

    Args:
        patience: consecutive non-improving calls tolerated before stopping.
        verbose: print a message each time a checkpoint is written.
        path: file the best ``state_dict`` is written to with ``torch.save``.
        delta: minimum decrease in loss that counts as an improvement.
    """

    def __init__(self, patience=60, verbose=True, path='best_model.pt', delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # negated best loss, so higher is better
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one validation loss; checkpoint ``model`` if it improved."""
        score = -val_loss

        # A NaN loss (e.g. diverged training) must never count as an improvement.
        # Every comparison with NaN is False, so treating it as "not worse" would
        # checkpoint it as the best model and then block all later comparisons,
        # silently disabling early stopping.
        if np.isnan(score):
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = score >= self.best_score + self.delta

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

@torch.no_grad()
def evaluate(model, graph, criterion, mask):
    """Score ``model`` on the ``"B1"`` nodes selected by ``mask``.

    Runs one forward pass in eval mode without gradient tracking. It computes
    ``criterion`` between the masked logits and ``graph["B1"].y[mask]``, and
    derives binary metrics from predictions thresholded at probability 0.5.

    Args:
        model: module called as ``model(graph)``; expected to return one logit
            per ``"B1"`` node, indexable by ``mask``.
        graph: graph exposing ``graph["B1"].y`` labels.
        criterion: loss taking ``(logits, float_labels)``,
            e.g. ``BCEWithLogitsLoss``.
        mask: boolean index selecting the ``"B1"`` nodes to score.

    Returns:
        ``(loss, (f1, precision, recall, mcc, accuracy, auc))``, where ``loss``
        is a Python float. ``auc`` is NaN when the selected labels contain
        fewer than two classes, since ROC AUC is undefined there.
    """
    model.eval()
    logits = model(graph)[mask]
    labels = graph["B1"].y[mask]
    val_loss = criterion(logits, labels.float())

    # Move to host memory once for scikit-learn instead of per metric.
    y_true = labels.cpu().numpy()
    y_pred = torch.sigmoid(logits).round().cpu().numpy()
    # Raw logits suffice for AUC: sigmoid is monotonic, so the ranking is identical.
    y_score = logits.cpu().numpy()

    # zero_division=0 is the value scikit-learn already falls back to by
    # default; passing it explicitly just avoids the UndefinedMetricWarning.
    f1 = f1_score(y_true, y_pred, average='binary', zero_division=0)
    precision = precision_score(y_true, y_pred, average='binary', zero_division=0)
    recall = recall_score(y_true, y_pred, average='binary', zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)
    # roc_auc_score raises ValueError on single-class input; report NaN
    # instead of aborting the caller's training loop.
    if np.unique(y_true).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(y_true, y_score)

    return val_loss.item(), (f1, precision, recall, mcc, accuracy, auc)

def train(model, data, optimizer, criterion):
    """Run one full-batch optimisation step on the ``"B1"`` training nodes.

    The loss is computed only on nodes where ``data["B1"].train_mask`` is
    set, against ``data["B1"].y`` cast to float.

    Returns:
        float: the training loss before the parameter update.
    """
    b1 = data["B1"]
    selected = b1.train_mask
    model.train()
    optimizer.zero_grad()
    logits = model(data)
    batch_loss = criterion(logits[selected], b1.y[selected].float())
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

def objective(trial):
    """Optuna objective: train one TropiGAT classifier for ``KL_type`` and score it.

    Procedure:

    1. Sample learning rate, weight decay, attention heads and dropout.
    2. Build the KL-type-specific masked graph with a fixed seed.
    3. Train full-batch for up to 500 epochs. The LR scheduler and early
       stopping monitor the loss on the ``test_mask`` nodes.
    4. Reload the best checkpoint and return its loss on the ``eval_mask``
       nodes; this is the value Optuna minimises.

    Reads the module-level globals ``DF_info_lvl_0`` and ``KL_type``, and writes
    a checkpoint file ``best_model_trial_<number>_<KL_type>.pt`` per trial.

    Raises:
        optuna.exceptions.TrialPruned: on any failure. The study keeps going
            and the trial is recorded as pruned; the full traceback is logged.
    """
    try:
        # Hyperparameters. suggest_float(..., log=True) replaces the deprecated
        # suggest_loguniform / suggest_uniform with the same distributions.
        lr = trial.suggest_float('lr', 1e-6, 1e-3, log=True)
        weight_decay = trial.suggest_float('weight_decay', 1e-7, 1e-4, log=True)
        heads = trial.suggest_int('heads', 1, 6, step=1)
        dropout = trial.suggest_float('dropout', 0, 0.5)

        # Fixed hidden channels
        hidden_channels = 1280

        model = TropiGAT_small_module(hidden_channels=hidden_channels, heads=heads, dropout=dropout)
        # Input graph (fixed seed, so every trial sees the same split).
        graph_baseline, dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
        graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(graph_baseline, dico_prophage_kltype_associated, DF_info_lvl_0, KL_type, 5, 0.7, 0.2, 0.1, seed=243)

        checkpoint_path = f"best_model_trial_{trial.number}_{KL_type}.pt"
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = ReduceLROnPlateau(optimizer, 'min')
        early_stopping = EarlyStopping(patience=100, verbose=True, path=checkpoint_path)

        for epoch in range(500):
            train(model, graph_data_kltype, optimizer, criterion)
            # Only the loss drives scheduling / early stopping; the metrics are unused here.
            val_loss, _ = evaluate(model, graph_data_kltype, criterion, graph_data_kltype["B1"].test_mask)

            scheduler.step(val_loss)
            early_stopping(val_loss, model)

            if early_stopping.early_stop:
                logging.info(f"Early stopping triggered at epoch {epoch}")
                break

        # Final evaluation on the best checkpoint written by EarlyStopping.
        model.load_state_dict(torch.load(checkpoint_path))
        final_val_loss, metrics = evaluate(model, graph_data_kltype, criterion, graph_data_kltype["B1"].eval_mask)

        logging.info(f"Trial {trial.number}: Val Loss: {final_val_loss:.4f}, MCC: {metrics[3]:.4f}, AUC: {metrics[5]:.4f}")

        return final_val_loss  # Return validation loss for minimization
    except Exception as e:
        # logging.exception keeps the traceback that logging.error dropped.
        logging.exception(f"Error in trial {trial.number}: {str(e)}")
        raise optuna.exceptions.TrialPruned() from e

# Optimize
KL_type = "KL10"

logging.basicConfig(
    filename=f"{path_work}/GATv2Conv.{KL_type}.loss.1209.optuna.log",
    format='%(asctime)s | %(levelname)s: %(message)s',
    level=logging.NOTSET,
    filemode='w',
)
logging.info("Starting hyperparameter optimization")
study = optuna.create_study(sampler=TPESampler(), direction='minimize')


try:
    study.optimize(objective, n_trials=200, n_jobs=-1, catch=(Exception,))

    # Study.best_trial never returns None: it raises ValueError when no trial
    # has completed. Check for completed trials explicitly so the "nothing
    # completed" branch is actually reachable.
    completed_trials = study.get_trials(states=[optuna.trial.TrialState.COMPLETE])
    if completed_trials:
        print(f"Best parameters: {study.best_params}")
        logging.info(f"Best parameters: {study.best_params}")
        best_trial = study.best_trial
        print(f"Best trial: Val Loss: {best_trial.value:.4f}")
        logging.info(f"Best trial: Val Loss: {best_trial.value:.4f}")

        # Optionally, you can retrain the model with the best parameters and evaluate on the test set
        best_model = TropiGAT_small_module(hidden_channels=1280, heads=best_trial.params['heads'], dropout=best_trial.params['dropout'])
        # ... (retrain with best parameters)
        # final_test_loss, final_metrics = evaluate(best_model, graph_data, criterion, graph_data["B1"].test_mask)
        # print(f"Final Test Results: Loss: {final_test_loss:.4f}, MCC: {final_metrics[3]:.4f}, AUC: {final_metrics[5]:.4f}")
    else:
        logging.warning("No trials were successfully completed.")
        print("No trials were successfully completed. Check the logs for more information.")

except Exception as e:
    logging.exception(f"An error occurred during optimization: {str(e)}")
    print("An error occurred during optimization. Check the logs for more information.")

finally:
    # Print a summary of the study
    trial_data = study.trials_dataframe()
    if not trial_data.empty:
        print("\nStudy Summary:")
        # len(study.trials) counts every trial (complete, pruned, failed, ...).
        print(f"Number of trials: {len(study.trials)}")
        print(f"Number of pruned trials: {len(study.get_trials(states=[optuna.trial.TrialState.PRUNED]))}")
        print(f"Number of completed trials: {len(study.get_trials(states=[optuna.trial.TrialState.COMPLETE]))}")
        print("\nBest 5 trials:")
        print(trial_data.sort_values('value').head(5))
    else:
        print("No trial data available.")

# Save the study results
study.trials_dataframe().to_csv(f"{path_work}/optuna_study_results.{KL_type}.1209.csv", index=False)
logging.info(f"Study results saved to 'optuna_study_results.{KL_type}.1209.csv'")
#best_model = TropiGAT_small_module(hidden_channels=1280, heads=best_trial.params['heads'], dropout=best_trial.params['dropout'])

