In [None]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import optuna
from optuna.samplers import TPESampler
import warnings
import logging
import sys

# TropiGAT modules
import TropiGAT_graph
import TropiGAT_models


warnings.filterwarnings("ignore")
# *****************************************************************************
# Load the Dataframes :
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn"

# Read the annotation table and keep a single row per protein name.
DF_info = (
    pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.tsv", sep = "\t" ,  header = 0)
    .drop_duplicates(subset = ["Protein_name"])
)
# Discard every row whose KL type contains a "|" character.
DF_info = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]

# One row per prophage (first occurrence), turned into a lookup keyed by prophage name.
df_prophages = DF_info.drop_duplicates(subset = ["Phage"], keep = "first")
dico_prophage_info = {
    prophage_row["Phage"] : {
        "prophage_strain" : prophage_row["prophage_id"] ,
        "ancestor" : prophage_row["Infected_ancestor"],
        "KL_type" : prophage_row["KL_type_LCA"],
    }
    for _, prophage_row in df_prophages.iterrows()
}

# *****************************************************************************
# The model : Classifier
class TropiGAT_small_module(torch.nn.Module):
    """One heterogeneous graph-attention layer followed by a feed-forward head.

    The convolution is registered for a single edge type (``edge_type``); the
    embeddings it produces for the "B1" node type are reduced to one logit per
    node by three linear layers.

    Parameters
    ----------
    hidden_channels : int
        Channel count handed to the convolution.
    heads : int
        Number of attention heads; the feed-forward input width is
        ``heads*hidden_channels``.
    edge_type : tuple
        Edge-type key under which the convolution is registered in ``HeteroConv``.
    dropout : float
        Dropout rate given to the convolution and used between the linear layers.
    conv : class
        Convolution class to instantiate (``GATv2Conv`` by default).
    """
    def __init__(self,hidden_channels, heads, edge_type = ("B2", "expressed", "B1") ,dropout = 0.2, conv = GATv2Conv):
        super().__init__()
        # GATv2 module :
        # NOTE(review): torch_geometric documents this GATv2Conv keyword as
        # `share_weights`; confirm that `shared_weights` has the intended effect.
        self.conv = conv(
            (-1,-1),
            hidden_channels,
            add_self_loops = False,
            heads = heads,
            dropout = dropout,
            shared_weights = True,
        )
        self.hetero_conv = HeteroConv({edge_type: self.conv})
        # FNN layers :
        feed_forward = [
            nn.Linear(heads*hidden_channels, 1280),
            nn.BatchNorm1d(1280),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(1280, 480),
            nn.BatchNorm1d(480),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(480 , 1),
        ]
        self.linear_layers = nn.Sequential(*feed_forward)

    def forward(self, graph_data):
        """Return a flat tensor holding one logit per "B1" node of ``graph_data``."""
        node_embeddings = self.hetero_conv(graph_data.x_dict, graph_data.edge_index_dict)
        logits = self.linear_layers(node_embeddings["B1"])
        return logits.view(-1)


In [None]:
# *****************************************************************************
# Pre-process data :
# First filtration step :
def get_filtered_prophages(prophage) :
    """Split the prophages sharing ``prophage``'s strain and ancestor into kept / excluded sets.

    The group is made of every ``DF_info`` row with the same ``prophage_id`` and
    ``Infected_ancestor`` as ``prophage``. Within the group, another prophage is
    excluded when its set of ``domain_seq`` values equals the one of ``prophage``
    or the one of a prophage already kept; otherwise it is kept.

    Returns
    -------
    tuple
        ``(df_prophage_group, to_exclude, to_keep)`` where ``to_keep`` always
        contains ``prophage`` itself.
    """
    reference_info = dico_prophage_info[prophage]
    same_strain = DF_info["prophage_id"] == reference_info["prophage_strain"]
    same_ancestor = DF_info["Infected_ancestor"] == reference_info["ancestor"]
    df_prophage_group = DF_info[same_strain & same_ancestor]

    to_keep = {prophage}
    to_exclude = set()
    # A group of exactly one row has nothing to compare against.
    if len(df_prophage_group) == 1 :
        return df_prophage_group , to_exclude , to_keep

    def domains_of(phage_name) :
        # Set of domain sequences carried by one prophage of the group.
        return set(df_prophage_group[df_prophage_group["Phage"] == phage_name]["domain_seq"].values)

    reference_domains = domains_of(prophage)
    seen_domain_sets = []
    for candidate in df_prophage_group["Phage"].unique().tolist() :
        if candidate == prophage :
            continue
        candidate_domains = domains_of(candidate)
        is_redundant = candidate_domains == reference_domains or candidate_domains in seen_domain_sets
        if is_redundant :
            to_exclude.add(candidate)
        else :
            to_keep.add(candidate)
            seen_domain_sets.append(candidate_domains)
    return df_prophage_group , to_exclude , to_keep

good_prophages = set()
excluded_prophages = set()

# Visit each prophage once; those already classified through an earlier group are skipped.
for prophage in tqdm(dico_prophage_info) :
    if prophage in excluded_prophages or prophage in good_prophages :
        continue
    _, excluded_members , kept_members = get_filtered_prophages(prophage)
    good_prophages.update(kept_members)
    excluded_prophages.update(excluded_members)

# Keep the rows of the retained prophages, then drop KL types containing "|".
kept_rows = DF_info["Phage"].isin(good_prophages)
DF_info_lvl_0_filtered = DF_info[kept_rows]
DF_info_lvl_0_final = DF_info_lvl_0_filtered[~DF_info_lvl_0_filtered["KL_type_LCA"].str.contains("\\|")]


# Second filtration step :
duplicate_prophage = []
dico_kltype_duplica = {}
for kltype in DF_info_lvl_0_final["KL_type_LCA"].unique():
    # Rows of the current KL type, restricted to the listed columns.
    kl_rows = DF_info_lvl_0_final["KL_type_LCA"] == kltype
    df_kl = DF_info_lvl_0_final[kl_rows][["Phage", "Protein_name", "KL_type_LCA", "Infected_ancestor", "index", "seq", "domain_seq"]]
    prophages_tmp_list = df_kl["Phage"].unique().tolist()
    set_sets_depo = []
    duplicated = {}
    for prophage_tmp in prophages_tmp_list:
        set_depo = frozenset(df_kl[df_kl["Phage"] == prophage_tmp]["domain_seq"].values)
        if set_depo in set_sets_depo:
            # Same domain set as an earlier prophage of this KL type: count it and flag the prophage.
            duplicated[set_depo] = duplicated.get(set_depo, 0) + 1
            duplicate_prophage.append(prophage_tmp)
        else:
            set_sets_depo.append(set_depo)
            duplicated[set_depo] = 1
    dico_kltype_duplica[kltype] = duplicated

# Remove every prophage flagged as a duplicate and work on an independent copy.
is_duplicate = DF_info_lvl_0_final["Phage"].isin(duplicate_prophage)
DF_info_lvl_0_final_ultrafiltered = DF_info_lvl_0_final[~is_duplicate]
DF_info_lvl_0 = DF_info_lvl_0_final_ultrafiltered.copy()

# Input graph: 
# graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
# graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(graph_baseline , dico_prophage_kltype_associated, DF_info_lvl_0, KL_type, 5, 0.7, 0.2, 0.1, seed = seed)



In [None]:
# *****************************************************************************
# Training :
# Select the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): `device` is not referenced anywhere else in the visible cells,
# so the model and graph are never moved to it — confirm whether that is intended.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TropiGAT_small_module(torch.nn.Module):
    """One heterogeneous graph-attention layer followed by a feed-forward head.

    The convolution is registered for a single edge type (``edge_type``); the
    embeddings it produces for the "B1" node type are reduced to one logit per
    node by three linear layers.

    Parameters
    ----------
    hidden_channels : int
        Channel count handed to the convolution.
    heads : int
        Number of attention heads; the feed-forward input width is
        ``heads*hidden_channels``.
    edge_type : tuple
        Edge-type key under which the convolution is registered in ``HeteroConv``.
    dropout : float
        Dropout rate given to the convolution and used between the linear layers.
    conv : class
        Convolution class to instantiate (``GATv2Conv`` by default).
    """
    def __init__(self, hidden_channels, heads, edge_type=("B2", "expressed", "B1"), dropout=0.2, conv=GATv2Conv):
        super().__init__()
        # GATv2 module :
        # NOTE(review): torch_geometric documents this GATv2Conv keyword as
        # `share_weights`; confirm that `shared_weights` has the intended effect.
        self.conv = conv(
            (-1, -1),
            hidden_channels,
            add_self_loops=False,
            heads=heads,
            dropout=dropout,
            shared_weights=True,
        )
        self.hetero_conv = HeteroConv({edge_type: self.conv})
        # FNN layers :
        feed_forward = [
            nn.Linear(heads*hidden_channels, 1280),
            nn.BatchNorm1d(1280),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(1280, 480),
            nn.BatchNorm1d(480),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(480, 1),
        ]
        self.linear_layers = nn.Sequential(*feed_forward)

    def forward(self, graph_data):
        """Return a flat tensor holding one logit per "B1" node of ``graph_data``."""
        node_embeddings = self.hetero_conv(graph_data.x_dict, graph_data.edge_index_dict)
        logits = self.linear_layers(node_embeddings["B1"])
        return logits.view(-1)

class EarlyStopping:
    """Stop training when the validation loss has not improved for ``patience`` calls.

    Each call compares ``-val_loss`` with the best score seen so far. An
    improvement (score >= best + ``delta``) resets the counter and saves the
    model's ``state_dict`` to ``path``; anything else increments the counter,
    and ``early_stop`` becomes True once the counter reaches ``patience``.

    A NaN validation loss is counted as a non-improving call: it never becomes
    the best score and never overwrites the checkpoint.

    Parameters
    ----------
    patience : int
        Number of consecutive non-improving calls tolerated before stopping.
    verbose : bool
        When True, print a message every time a checkpoint is written.
    path : str
        File the best ``state_dict`` is saved to.
    delta : float
        Margin added to the best score in the improvement comparison.
    """
    def __init__(self, patience=60, verbose=True, path='best_model.pt', delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one validation loss, checkpointing ``model`` on improvement."""
        import math

        # NaN compares False against everything: without this guard it would be
        # stored as the best score and every later loss would look like an
        # improvement, so the counter would never grow again.
        is_nan = math.isnan(val_loss)
        score = -val_loss

        improved = (not is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``val_loss`` as the minimum."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

@torch.no_grad()
def evaluate(model, graph, criterion, mask):
    """Score ``model`` on the "B1" nodes of ``graph`` selected by ``mask``.

    Returns
    -------
    tuple
        ``(loss, (f1, precision, recall, mcc, accuracy, auc))`` where ``loss`` is
        a Python float. Hard predictions are the rounded sigmoid of the logits;
        the AUC is computed from the raw logits.
    """
    model.eval()
    masked_logits = model(graph)[mask]
    targets = graph["B1"].y[mask]
    loss = criterion(masked_logits, targets.float())
    hard_predictions = torch.sigmoid(masked_logits).round()

    # Move everything to the CPU once before handing it to scikit-learn.
    y_true = targets.cpu()
    y_pred = hard_predictions.cpu()
    y_score = masked_logits.cpu()

    metrics = (
        f1_score(y_true, y_pred, average='binary'),
        precision_score(y_true, y_pred, average='binary'),
        recall_score(y_true, y_pred, average='binary'),
        matthews_corrcoef(y_true, y_pred),
        accuracy_score(y_true, y_pred),
        roc_auc_score(y_true, y_score),
    )
    return loss.item(), metrics

def train(model, data, optimizer, criterion):
    """Run one full-graph optimisation step and return the training loss as a float.

    The loss is computed only on the "B1" nodes selected by ``train_mask``.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(data)
    b1_store = data["B1"]
    train_mask = b1_store.train_mask
    targets = b1_store.y[train_mask].float()
    batch_loss = criterion(logits[train_mask], targets)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

def objective(trial):
    """Optuna objective: train one TropiGAT model and return its loss on the ``eval_mask`` nodes.

    Samples learning rate, weight decay, number of attention heads and dropout
    from ``trial``, builds the graph for the module-level ``KL_type``, trains
    for at most 500 epochs with plateau learning-rate scheduling and early
    stopping, then reloads the best checkpoint and scores it.

    Any exception is logged with its traceback and converted into
    ``optuna.exceptions.TrialPruned`` so that the study keeps running.
    """
    try:
        # Define the hyperparameters (suggest_float replaces the deprecated
        # suggest_loguniform / suggest_uniform helpers):
        lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-4, log=True)
        heads = trial.suggest_int('heads', 1, 5, step=1)
        dropout = trial.suggest_float('dropout', 0, 0.5)
        # Fixed hidden channels:
        hidden_channels = 1280
        # Define the model:
        model = TropiGAT_small_module(hidden_channels=hidden_channels, heads=heads, dropout=dropout)
        # Input graph:
        graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
        graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(graph_baseline , dico_prophage_kltype_associated, DF_info_lvl_0, KL_type, 5, 0.7, 0.2, 0.1, seed = 243)
        # set up the training:
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = ReduceLROnPlateau(optimizer, 'min')
        # One checkpoint file per trial so that parallel trials do not overwrite each other.
        checkpoint_path = f"best_model_trial_{trial.number}_{KL_type}.pt"
        early_stopping = EarlyStopping(patience=100, verbose=True, path=checkpoint_path)
        # training:
        # The nodes under `test_mask` drive the scheduler and early stopping;
        # the nodes under `eval_mask` are only used for the final score below.
        for epoch in range(500):
            train(model, graph_data_kltype, optimizer, criterion)
            val_loss, metrics = evaluate(model, graph_data_kltype, criterion, graph_data_kltype["B1"].test_mask)

            scheduler.step(val_loss)
            early_stopping(val_loss, model)

            if early_stopping.early_stop:
                logging.info(f"Early stopping triggered at epoch {epoch}")
                break

        # Final evaluation with the best weights saved by early stopping
        model.load_state_dict(torch.load(checkpoint_path))
        final_val_loss, metrics = evaluate(model, graph_data_kltype, criterion, graph_data_kltype["B1"].eval_mask)

        # Log the results
        logging.info(f"Trial {trial.number}: Val Loss: {final_val_loss:.4f}, MCC: {metrics[3]:.4f}, AUC: {metrics[5]:.4f}")

        return final_val_loss  # Return validation loss for minimization
    except Exception as e:
        # logging.exception keeps the traceback, which the pruned state alone would hide.
        logging.exception(f"Error in trial {trial.number}: {str(e)}")
        raise optuna.exceptions.TrialPruned() from e

# Optimize
# KL type the model is optimised for; it is read by `objective` and used in the output file names.
KL_type = "KL1"

logging.basicConfig(filename = f"{path_work}/GATv2Conv.{KL_type}.loss.1209.optuna.log",format='%(asctime)s | %(levelname)s: %(message)s', level=logging.NOTSET, filemode='w')
logging.info("Starting hyperparameter optimization")
study = optuna.create_study(sampler=TPESampler(), direction='minimize')


try:
    study.optimize(objective, n_trials=100, n_jobs=-1, catch=(Exception,))

    # `study.best_trial` raises ValueError when no trial completed (it never
    # returns None), so check for completed trials explicitly.
    completed_trials = study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.COMPLETE])
    if completed_trials:
        print(f"Best parameters: {study.best_params}")
        logging.info(f"Best parameters: {study.best_params}")
        best_trial = study.best_trial
        print(f"Best trial: Val Loss: {best_trial.value:.4f}")
        logging.info(f"Best trial: Val Loss: {best_trial.value:.4f}")
        # Optionally, you can retrain the model with the best parameters and evaluate on the test set
        best_model = TropiGAT_small_module(hidden_channels=1280, heads=best_trial.params['heads'], dropout=best_trial.params['dropout'])
        # ... (retrain with best parameters)
        # final_test_loss, final_metrics = evaluate(best_model, graph_data, criterion, graph_data["B1"].test_mask)
        # print(f"Final Test Results: Loss: {final_test_loss:.4f}, MCC: {final_metrics[3]:.4f}, AUC: {final_metrics[5]:.4f}")
    else:
        logging.warning("No trials were successfully completed.")
        print("No trials were successfully completed. Check the logs for more information.")

except Exception as e:
    logging.error(f"An error occurred during optimization: {str(e)}")
    print("An error occurred during optimization. Check the logs for more information.")

finally:
    # Print a summary of the study
    trial_data = study.trials_dataframe()
    if not trial_data.empty:
        print("\nStudy Summary:")
        print(f"Number of trials: {len(study.trials)}")
        print(f"Number of pruned trials: {len(study.get_trials(states=[optuna.trial.TrialState.PRUNED]))}")
        print(f"Number of completed trials: {len(study.get_trials(states=[optuna.trial.TrialState.COMPLETE]))}")
        print("\nBest 5 trials:")
        print(trial_data.sort_values('value').head())
    else:
        print("No trial data available.")

# Save the study results
study.trials_dataframe().to_csv(f"{path_work}/optuna_study_results.{KL_type}.1209.csv", index=False)
logging.info(f"Study results saved to 'optuna_study_results.{KL_type}.1209.csv'")



In [None]:
#!/bin/bash
#SBATCH --job-name=Optuna_1009_
#SBATCH --qos=short
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
#SBATCH --mem=120gb
#SBATCH --time=2-00:00:00
#SBATCH --output=Optuna_1009_%j.log

# SLURM submission script for the Optuna hyperparameter search.
# Every directive above must start with "#SBATCH"; a line spelled "#BATCH" is
# treated as an ordinary comment and its option is silently ignored.

# Load the saved module collection and the conda environment used for training.
module restore la_base
conda activate torch_geometric

python /home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/script_files/optuna_GATv2Conv_Hetero.loss.1009.py


In [None]:
# Copy the Optuna result table of each KL type from the cluster to the local disk.
remote_dir="conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn"
local_dir="/media/concha-eloko/Linux/PPT_clean/ficheros_28032023"

for kl_type in KL1 KL64 KL17 ; do
    rsync -avzhe ssh \
    "${remote_dir}/optuna_study_results.${kl_type}.1209.csv" \
    "${local_dir}"
done

In [29]:
import pandas as pd
import os

path_df = "/media/concha-eloko/Linux/PPT_clean/ficheros_28032023"

# All three result tables share the same layout: comma-separated, header on the
# first line, indexed by the trial number.
read_options = dict(header = 0, sep = ",", index_col = ["number"])

kl_64_df = pd.read_csv(f"{path_df}/optuna_study_results.KL64.1209.csv", **read_options)
kl_1_df = pd.read_csv(f"{path_df}/optuna_study_results.KL1.1209.csv", **read_options)
# Only the KL17 table has its missing entries replaced by 1.
kl_17_df = pd.read_csv(f"{path_df}/optuna_study_results.KL17.1209.csv", **read_options).fillna(1)


In [30]:
# Display the KL64 trial table (one row per Optuna trial, indexed by trial number).
kl_64_df

Unnamed: 0_level_0,value,datetime_start,datetime_complete,duration,params_dropout,params_heads,params_lr,params_weight_decay,state
number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.214616,2024-09-11 13:13:15.012388,2024-09-11 16:02:22.856435,0 days 02:49:07.844047,0.063113,2,0.000246,0.000080,COMPLETE
1,0.288147,2024-09-11 13:13:15.015634,2024-09-11 16:02:19.889560,0 days 02:49:04.873926,0.387089,4,0.000377,0.000007,COMPLETE
2,0.402074,2024-09-11 13:13:15.017499,2024-09-11 15:40:39.095528,0 days 02:27:24.078029,0.200835,1,0.000033,0.000056,COMPLETE
3,0.341281,2024-09-11 13:13:15.021601,2024-09-11 16:06:48.181022,0 days 02:53:33.159421,0.078880,3,0.000106,0.000003,COMPLETE
4,0.453863,2024-09-11 13:13:15.022937,2024-09-11 15:18:59.106928,0 days 02:05:44.083991,0.135402,1,0.000014,0.000005,COMPLETE
...,...,...,...,...,...,...,...,...,...
95,0.291542,2024-09-11 15:28:52.372036,2024-09-11 16:42:31.864205,0 days 01:13:39.492169,0.485289,2,0.000534,0.000055,COMPLETE
96,0.292580,2024-09-11 15:32:57.953364,2024-09-11 16:44:26.652001,0 days 01:11:28.698637,0.484483,2,0.000561,0.000003,COMPLETE
97,0.267515,2024-09-11 15:34:28.631242,2024-09-11 16:43:48.094296,0 days 01:09:19.463054,0.494634,2,0.000528,0.000002,COMPLETE
98,0.295961,2024-09-11 15:40:39.597347,2024-09-11 16:43:01.275574,0 days 01:02:21.678227,0.484797,2,0.000446,0.000002,COMPLETE


In [12]:
# Order the KL64 trials by objective value, lowest first, and display them.
kl_64_sorted = kl_64_df.sort_values("value")
kl_64_sorted

Unnamed: 0_level_0,value,datetime_start,datetime_complete,duration,params_dropout,params_heads,params_lr,params_weight_decay,state
number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.214616,2024-09-11 13:13:15.012388,2024-09-11 16:02:22.856435,0 days 02:49:07.844047,0.063113,2,0.000246,0.000080,COMPLETE
76,0.225334,2024-09-11 13:13:15.142941,2024-09-11 15:52:34.947279,0 days 02:39:19.804338,0.274048,5,0.000348,0.000001,COMPLETE
54,0.237554,2024-09-11 13:13:15.110488,2024-09-11 15:32:57.287727,0 days 02:19:42.177239,0.403687,1,0.000653,0.000001,COMPLETE
36,0.239207,2024-09-11 13:13:15.083388,2024-09-11 15:01:26.942961,0 days 01:48:11.859573,0.441065,1,0.000946,0.000007,COMPLETE
58,0.251839,2024-09-11 13:13:15.124498,2024-09-11 15:15:26.883346,0 days 02:02:11.758848,0.141927,1,0.000513,0.000023,COMPLETE
...,...,...,...,...,...,...,...,...,...
25,0.675084,2024-09-11 13:13:15.063330,2024-09-11 15:57:11.447594,0 days 02:43:56.384264,0.154847,4,0.000011,0.000015,COMPLETE
22,0.685594,2024-09-11 13:13:15.057566,2024-09-11 15:24:53.420901,0 days 02:11:38.363335,0.262334,4,0.000015,0.000025,COMPLETE
41,0.687717,2024-09-11 13:13:15.092231,2024-09-11 15:43:50.830644,0 days 02:30:35.738413,0.294077,5,0.000063,0.000021,COMPLETE
65,0.693526,2024-09-11 13:13:15.130973,2024-09-11 16:02:47.585835,0 days 02:49:32.454862,0.215736,3,0.000021,0.000070,COMPLETE


In [None]:
# FIXME: the column name is an empty string; none of the columns displayed for
# this frame has that name, so this lookup raises KeyError. Fill in the intended column.
kl_64_sorted[""]

In [14]:
# Order the KL1 trials by objective value, lowest first, and display them.
kl_1_sorted = kl_1_df.sort_values("value")
kl_1_sorted

Unnamed: 0_level_0,value,datetime_start,datetime_complete,duration,params_dropout,params_heads,params_lr,params_weight_decay,state
number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
82,0.322256,2024-09-11 15:02:17.661959,2024-09-11 16:28:02.713687,0 days 01:25:45.051728,0.007658,5,0.000942,0.000001,COMPLETE
84,0.356998,2024-09-11 15:04:22.007250,2024-09-11 16:30:23.946330,0 days 01:26:01.939080,0.015245,5,0.000976,0.000001,COMPLETE
17,0.360367,2024-09-11 13:13:39.797088,2024-09-11 15:12:12.441510,0 days 01:58:32.644422,0.466263,4,0.000855,0.000080,COMPLETE
21,0.366745,2024-09-11 13:13:39.801364,2024-09-11 15:04:22.811193,0 days 01:50:43.009829,0.159013,5,0.000293,0.000002,COMPLETE
10,0.370765,2024-09-11 13:13:39.787550,2024-09-11 15:04:21.406635,0 days 01:50:41.619085,0.044992,5,0.000749,0.000084,COMPLETE
...,...,...,...,...,...,...,...,...,...
27,0.705620,2024-09-11 13:13:39.809160,2024-09-11 14:33:09.197650,0 days 01:19:29.388490,0.308462,1,0.000013,0.000002,COMPLETE
69,,2024-09-11 14:30:24.441561,2024-09-11 15:29:20.796954,0 days 00:58:56.355393,0.151598,4,0.000075,0.000039,PRUNED
74,,2024-09-11 14:44:34.472092,2024-09-11 15:33:23.910163,0 days 00:48:49.438071,0.481098,2,0.000918,0.000095,PRUNED
81,,2024-09-11 15:00:58.560005,2024-09-11 15:34:56.648735,0 days 00:33:58.088730,0.003221,5,0.000815,0.000001,PRUNED


In [35]:
import numpy as np

# Order the KL17 trials by objective value, lowest first, and print the ten best.
kl_17_sorted = kl_17_df.sort_values("value")
print(kl_17_sorted.head(10))

           value              datetime_start           datetime_complete  \
number                                                                     
48      0.112556  2024-09-12 14:32:46.673197  2024-09-12 16:33:16.399220   
75      0.132413  2024-09-12 16:02:30.164488  2024-09-12 18:56:30.528288   
78      0.138315  2024-09-12 16:08:11.195608  2024-09-12 18:57:03.477956   
121     0.139090  2024-09-12 17:35:12.034846  2024-09-12 19:30:05.549491   
157     0.141805  2024-09-12 18:58:42.520549  2024-09-12 20:45:04.126730   
77      0.145535  2024-09-12 16:08:07.208407  2024-09-12 18:56:36.216733   
164     0.146187  2024-09-12 19:07:47.717052  2024-09-12 20:41:22.904269   
88      0.146418  2024-09-12 16:32:15.918852  2024-09-12 18:58:42.049945   
60      0.148787  2024-09-12 14:32:46.683049  2024-09-12 16:47:46.937569   
165     0.150669  2024-09-12 19:07:58.615268  2024-09-12 20:44:58.370366   

                      duration  params_dropout  params_heads  params_lr  \
number      