In [None]:
# Torch geometric modules
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

# Torch modules
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# SKlearn modules
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

# Ground modules
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool

# TropiGAT modules
import TropiGAT_graph
import TropiGAT_models

# Install a global filter that ignores every warning category, to keep the
# training output readable. NOTE(review): this also hides deprecation and
# numerical warnings raised by torch / sklearn.
warnings.filterwarnings("ignore")


In [None]:
# *****************************************************************************
# Load the training table and derive the per-KL-type prophage counts.

# Working directory and the folder receiving the checkpoints / metric reports.
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
path_ensemble = f"{path_work}/train_nn/ensemble_2709"

DF_info = pd.read_csv(f"{path_work}/train_nn/TropiGATv2.final_df.tsv", sep = "\t" ,  header = 0)

# Keep only the rows whose KL_type_LCA holds no "|" character, then drop the
# repeated (Infected_ancestor, index, prophage_id) triplets.
DF_info_lvl_0 = (
    DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
    .drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first")
    .reset_index(drop=True)
)

# One row per phage, so that every prophage is counted once per KL type.
df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

# KL types backed by at least 20 prophages.
KLtypes = [
    kltype
    for kltype, n_prophages in dico_prophage_count.items()
    if n_prophages >= 20
]


In [None]:
# *****************************************************************************
# Build the shared baseline graph, then one masked graph per KL type.
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(
    DF_info_lvl_0
)

# Maps every KL type found in DF_info_lvl_0 to its own masked graph.
# NOTE(review): 5, 0.8, 0.1, 0.1 are passed positionally; they look like a
# ratio followed by train/test/eval fractions - confirm against
# TropiGAT_graph.build_graph_masking.
graph_dico = dict(
    (
        kltype,
        TropiGAT_graph.build_graph_masking(
            graph_baseline ,
            dico_prophage_kltype_associated,
            DF_info_lvl_0,
            kltype,
            5, 0.8, 0.1, 0.1,
        ),
    )
    for kltype in DF_info_lvl_0["KL_type_LCA"].unique()
)



In [None]:
# *****************************************************************************
def train_graph(KL_type) :
    """Train, checkpoint and evaluate one TropiGAT model for a single KL type.

    Reads the notebook-level ``dico_prophage_count``, ``graph_dico``,
    ``path_work`` and ``path_ensemble``. A "small" module is built when the KL
    type has at most 125 prophages, a "big" one otherwise. Training runs for at
    most 200 epochs; every 5th epoch the model is scored on the ``test_mask``
    of the ``"B1"`` nodes, the test loss is fed to the LR scheduler and the MCC
    (``metrics[3]``) to the early-stopping helper. The checkpoint is then loaded
    into a fresh model that is scored on ``eval_mask``.

    Parameters
    ----------
    KL_type : key present in both ``dico_prophage_count`` and ``graph_dico``.

    Side effects
    ------------
    * writes ``ensemble_2709_log_files/<KL_type>__node_classification.2705.log``
    * writes the checkpoint ``<path_ensemble>/<KL_type>.TropiGATv2.2709.pt``
    * appends one row to ``<path_ensemble>/Metric_Report.2709.tsv``

    Exceptions raised during training or the final evaluation are not
    propagated: they are written to the log file and the KL type is reported
    as ``***Issue***`` in the metric report.
    """
    checkpoint_path = f"{path_ensemble}/{KL_type}.TropiGATv2.2709.pt"
    metric_report_path = f"{path_ensemble}/Metric_Report.2709.tsv"
    with open(f"{path_work}/train_nn/ensemble_2709_log_files/{KL_type}__node_classification.2705.log" , "w") as log_outfile :
        n_prophage = dico_prophage_count[KL_type]
        graph_data_kltype = graph_dico[KL_type]
        # The architecture depends on how many prophages back this KL type; the
        # same class is reused below to rebuild the model for the final eval.
        if n_prophage <= 125 :
            model_class = TropiGAT_models.TropiGAT_small_module
        else :
            model_class = TropiGAT_models.TropiGAT_big_module
        model = model_class(hidden_channels = 1280, heads = 1)
        # One forward pass before the optimizer is created.
        # NOTE(review): presumably materialises lazily-shaped parameters - confirm.
        model(graph_data_kltype)
        optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001 , weight_decay= 0.000001)
        scheduler = ReduceLROnPlateau(optimizer, 'min')
        criterion = torch.nn.BCEWithLogitsLoss()
        early_stopping = TropiGAT_models.EarlyStopping(patience=40, verbose=True, path=checkpoint_path, metric='MCC')
        try :
            for epoch in range(200):
                train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer,criterion)
                if epoch % 5 == 0:
                    # Get all metrics
                    test_loss, metrics = TropiGAT_models.evaluate(model, graph_data_kltype,criterion, graph_data_kltype["B1"].test_mask)
                    info_training_concise = f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTest Loss: {test_loss}\tMCC: {metrics[3]}\tAUC: {metrics[5]}\n'
                    info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss},F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
                    log_outfile.write(info_training_concise)
                    print(info_training)
                    scheduler.step(test_loss)
                    early_stopping(metrics[3], model, epoch)
                    if early_stopping.early_stop:
                        # Trailing newline keeps this record separate from the
                        # "Final evaluation" record written below.
                        log_outfile.write(f"Early stopping at epoch = {epoch}\n")
                        break
            else :
                # All 200 epochs ran without early stopping. Save the weights as
                # a state_dict: the final evaluation reloads the file with
                # load_state_dict(), which cannot consume a pickled nn.Module.
                # NOTE(review): this replaces whatever EarlyStopping stored at
                # the same path with the last-epoch weights - confirm intended.
                torch.save(model.state_dict(), checkpoint_path)
            # The final eval :
            print("Final evaluation ...")
            model_final = model_class(hidden_channels = 1280, heads = 1)
            model_final.load_state_dict(torch.load(checkpoint_path))
            eval_loss, metrics = TropiGAT_models.evaluate(model_final, graph_data_kltype, criterion,graph_data_kltype["B1"].eval_mask)
            with open(metric_report_path, "a+") as metric_outfile :
                metric_outfile.write(f"{KL_type}\t{n_prophage}\t{metrics[0]}\t{metrics[1]}\t{metrics[2]}\t{metrics[3]}\t{metrics[4]}\t{metrics[5]}\n")
            info_eval = f'Epoch: {epoch}, F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
            print(info_eval)
            log_outfile.write(f"Final evaluation ...\n{info_eval}")
        except Exception as e :
            # Best effort: one failing KL type must not abort the other workers.
            log_outfile.write(f"***Issue here : {e}")
            with open(metric_report_path, "a+") as metric_outfile :
                metric_outfile.write(f"{KL_type}\t{n_prophage}\t***Issue***\n")


if __name__ == '__main__':
    # Train the selected KL types concurrently on 15 worker threads; map()
    # blocks until every train_graph call has returned.
    with ThreadPool(15) as worker_pool:
        worker_pool.map(train_graph, KLtypes)


> V1.1

In [None]:
def train_graph(KL_type) :
    """Train and evaluate a "big" TropiGAT model for ``KL_type`` (V1.1 workflow).

    Reads the notebook-level ``graph_dico``, ``path_work`` and
    ``path_ensemble``. Training runs for at most 100 epochs; every 5th epoch
    the model is scored on the ``test_mask`` of the ``"B1"`` nodes, the test
    loss is fed to the LR scheduler and the MCC (``metrics[3]``) to the
    early-stopping helper. The checkpoint at
    ``<path_ensemble>/<KL_type>.TropiGATv2.2509.pt`` is then loaded into a
    fresh model that is scored on ``eval_mask``.

    Parameters
    ----------
    KL_type : key of ``graph_dico``.

    Side effects
    ------------
    * configures the root logger to write to a per-KL-type log file, and
      removes every root handler again before returning or raising
    * appends one row to ``<path_ensemble>/Metric_Report.2609``

    Raises
    ------
    Any exception from training / evaluation is propagated to the caller.
    """
    checkpoint_path = f"{path_ensemble}/{KL_type}.TropiGATv2.2509.pt"
    logging.basicConfig(filename = f"{path_work}/train_nn/ensemble_2509_log_files/{KL_type}__node_classification.2505.log",format='%(asctime)s | %(levelname)s: %(message)s', level=logging.NOTSET, filemode='w')
    try :
        logging.info(f"***{KL_type}___")
        graph_data_kltype = graph_dico[KL_type]
        model = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
        # One forward pass before the optimizer is created.
        # NOTE(review): presumably materialises lazily-shaped parameters - confirm.
        model(graph_data_kltype)
        optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001 , weight_decay= 0.000001)
        scheduler = ReduceLROnPlateau(optimizer, 'min')
        criterion = torch.nn.BCEWithLogitsLoss()
        early_stopping = TropiGAT_models.EarlyStopping(patience=20, verbose=True, path=checkpoint_path, metric='MCC')
        for epoch in range(100):
            train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer,criterion)
            if epoch % 5 == 0:
                # Get all metrics
                test_loss, metrics = TropiGAT_models.evaluate(model, graph_data_kltype,criterion, graph_data_kltype["B1"].test_mask)
                info_training_concise = f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTest Loss: {test_loss}\tMCC: {metrics[3]}\tAUC: {metrics[5]}'
                info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss},F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
                logging.info(info_training_concise)
                print(info_training)
                scheduler.step(test_loss)
                # epoch is passed as well, matching the three-argument
                # EarlyStopping call used by the newer training cell of this
                # notebook.
                early_stopping(metrics[3], model, epoch)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break
        # The final eval :
        print("Final evaluation ...")
        model_final = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
        model_final.load_state_dict(torch.load(checkpoint_path))
        eval_loss, metrics = TropiGAT_models.evaluate(model_final, graph_data_kltype, criterion,graph_data_kltype["B1"].eval_mask)
        with open(f"{path_ensemble}/Metric_Report.2609", "a+") as metric_outfile :
            metric_outfile.write(f"{KL_type}\t{metrics[0]}\t{metrics[1]}\t{metrics[2]}\t{metrics[3]}\t{metrics[4]}\t{metrics[5]}\n")
        info_eval = f'Epoch: {epoch}, F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
        print(info_eval)
        logging.info(f"Final evaluation ...\n{info_eval}")
    finally :
        # Close the existing handlers, even on failure: basicConfig() does
        # nothing while the root logger still has handlers, so a leftover
        # handler would send the next KL type's records to this log file.
        for handler in logging.root.handlers[:]:
            handler.close()
            logging.root.removeHandler(handler)


if __name__ == "__main__":
    # Train every KL type of graph_dico one after the other. A failure on one
    # KL type is recorded in the report and must not stop the remaining ones.
    wrong_kltypes = []
    with open(f"{path_ensemble}/report.2609", "a+") as outfile :
        for kltype in graph_dico :
            try :
                train_graph(kltype)
            except Exception as e :
                wrong_kltypes.append(kltype)
                # Keep the reason of the failure instead of discarding it.
                outfile.write(f"{kltype}\t{type(e).__name__}: {e}\n")
        # The file is opened in append mode: end the summary with a newline so
        # that successive runs do not end up on the same line.
        outfile.write(f"Something wrong with those KLtypes: {wrong_kltypes}\n")

> Original

In [None]:
def train_graph(KL_type = "KL64") :
    """Train and evaluate a "big" TropiGAT model on one KL type (original workflow).

    A masked graph is built for ``KL_type`` from the notebook-level
    ``graph_baseline``, ``dico_prophage_kltype_associated`` and
    ``DF_info_lvl_0``. The model is trained for 100 epochs; every 5th epoch it
    is scored on the ``test_mask`` of the ``"B1"`` nodes and the test loss is
    fed to the LR scheduler. The final weights are saved as a state_dict to
    ``<path_work>/<KL_type>.TropiGATv2.2509.pt`` and the same in-memory model is
    scored on ``eval_mask``.

    Parameters
    ----------
    KL_type : KL type passed to ``TropiGAT_graph.build_graph_masking`` and used
        in the checkpoint file name. Defaults to ``"KL64"``, the value that was
        previously hard-coded.
    """
    # NOTE(review): 5, 0.8, 0.1, 0.1 are passed positionally; they look like a
    # ratio followed by train/test/eval fractions - confirm against
    # TropiGAT_graph.build_graph_masking.
    graph_data_kltype = TropiGAT_graph.build_graph_masking(graph_baseline ,
                                 dico_prophage_kltype_associated,
                                 DF_info_lvl_0,
                                 KL_type,
                                 5, 0.8, 0.1,0.1)
    model = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
    # One forward pass before the optimizer is created.
    # NOTE(review): presumably materialises lazily-shaped parameters - confirm.
    model(graph_data_kltype)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001 , weight_decay= 0.000001)
    scheduler = ReduceLROnPlateau(optimizer, 'min')
    criterion = torch.nn.BCEWithLogitsLoss()
    for epoch in range(100):
        train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer,criterion)
        if epoch % 5 == 0:
            # Get all metrics
            test_loss, metrics = TropiGAT_models.evaluate(model, graph_data_kltype,criterion, graph_data_kltype["B1"].test_mask)
            info_training_concise = f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTest Loss: {test_loss}\tMCC: {metrics[3]}\tAUC: {metrics[5]}'
            info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss},F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
            logging.info(info_training_concise)
            print(info_training)
            scheduler.step(test_loss)
    # Save the model
    torch.save(model.state_dict(), f"{path_work}/{KL_type}.TropiGATv2.2509.pt")
    # The final eval :
    print("Final evaluation ...")
    eval_loss, metrics = TropiGAT_models.evaluate(model, graph_data_kltype, criterion,graph_data_kltype["B1"].eval_mask)
    info_eval = f'Epoch: {epoch}, F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
    print(info_eval)
    logging.info(f"Final evaluation ...\n{info_eval}")


if __name__ == "__main__":
    # Run the single-model training experiment defined by train_graph().
    train_graph()
