In [None]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

#import TropiGAT_functions 
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Project directories.
# The working directory defaults to the original workstation path; set the
# PPT_WORK_DIR environment variable to run the notebook on another machine
# without editing this cell.
path_work = os.environ.get("PPT_WORK_DIR", "/media/concha-eloko/Linux/PPT_clean")
# Sub-directory of the working directory; train_graph writes its model
# checkpoints and the metric report here.
path_ensemble = f"{path_work}/ficheros_28032023/ensemble_2809"

In [None]:
def train_graph(KL_type) :
    """Train, checkpoint and evaluate a TropiGAT model for a single KL type.

    The function looks up the graph and prophage count for ``KL_type`` in the
    module-level ``graph_dico`` and ``dico_prophage_count`` mappings, trains a
    ``TropiGAT_models.TropiGAT_big_module`` for up to 200 epochs, evaluates on
    the ``test_mask`` every 5 epochs, and finally reloads the saved weights to
    evaluate on the ``eval_mask``.

    Side effects:
        * writes a per-KL-type training log under
          ``{path_work}/train_nn/ensemble_2709_log_files/``;
        * writes the model weights (a ``state_dict``) to
          ``{path_ensemble}/{KL_type}.TropiGATv2.2709.pt`` after every epoch;
        * appends one line of final metrics to
          ``{path_ensemble}/Metric_Report.2709.tsv``.

    Parameters
    ----------
    KL_type :
        Key identifying the KL type; must be present in ``graph_dico`` and
        ``dico_prophage_count`` (a missing key raises ``KeyError``).

    Returns
    -------
    None
        Errors raised during training or evaluation are not propagated; they
        are written to the log file instead.
    """
    checkpoint_path = f"{path_ensemble}/{KL_type}.TropiGATv2.2709.pt"
    with open(f"{path_work}/train_nn/ensemble_2709_log_files/{KL_type}__node_classification.2705.log" , "w") as log_outfile :
        n_prophage = dico_prophage_count[KL_type]
        graph_data_kltype = graph_dico[KL_type]
        model = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
        # One forward pass before building the optimizer.
        # NOTE(review): presumably this materialises lazily-sized parameters - confirm.
        model(graph_data_kltype)
        optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001 , weight_decay= 0.000001)
        # Fully-qualified name: ReduceLROnPlateau is not imported at the top of the notebook.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
        criterion = torch.nn.BCEWithLogitsLoss()
        # NOTE(review): this early-stopping helper is created but never invoked in the
        # loop below, so training always runs the full 200 epochs. Its call signature is
        # not visible here; wire it in once confirmed.
        early_stopping = TropiGAT_models.EarlyStopping(patience=40, verbose=True, path=checkpoint_path, metric='MCC')
        try :
            for epoch in range(200):
                train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer,criterion)
                if epoch % 5 == 0:
                    # Get all metrics (order, per the labels below:
                    # F1, precision, recall, MCC, accuracy, AUC).
                    test_loss, metrics = TropiGAT_models.evaluate(model, graph_data_kltype,criterion, graph_data_kltype["B1"].test_mask)
                    info_training_concise = f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTest Loss: {test_loss}\tMCC: {metrics[3]}\tAUC: {metrics[5]}\n'
                    info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss},F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
                    log_outfile.write(info_training_concise)
                    print(info_training)
                    # The scheduler only sees the test loss every 5th epoch, so its
                    # patience is counted in evaluations, not in epochs.
                    scheduler.step(test_loss)
                # Save the weights only (state_dict): the final evaluation reloads them
                # with load_state_dict, which cannot consume a pickled nn.Module.
                torch.save(model.state_dict(), checkpoint_path)
            # The final eval :
            print("Final evaluation ...")
            model_final = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
            model_final.load_state_dict(torch.load(checkpoint_path))
            eval_loss, metrics = TropiGAT_models.evaluate(model_final, graph_data_kltype, criterion,graph_data_kltype["B1"].eval_mask)
            with open(f"{path_ensemble}/Metric_Report.2709.tsv", "a+") as metric_outfile :
                metric_outfile.write(f"{KL_type}\t{n_prophage}\t{metrics[0]}\t{metrics[1]}\t{metrics[2]}\t{metrics[3]}\t{metrics[4]}\t{metrics[5]}\n")
            info_eval = f'Epoch: {epoch}, F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
            print(info_eval)
            log_outfile.write(f"Final evaluation ...\n{info_eval}")
        except Exception as e :
            # Best-effort: record the failure in the per-KL-type log instead of
            # aborting the caller (e.g. a loop over many KL types).
            log_outfile.write(f"***Issue here : {e}\n")





In [None]:
# Fine-tune a pre-trained TropiGAT model on new graph data.
# Assuming you already have the model weights and new data
pretrained_model_path = "path_to_pretrained_model.pt"  # placeholder: point to a saved state_dict
new_data = graph_data_for_fine_tuning  # Replace with your new data

# Load the pre-trained model. The class is reached through the TropiGAT_models
# namespace, consistent with the train/evaluate calls below.
model = TropiGAT_models.TropiGAT_big_module(hidden_channels=1280, heads=1)
model.load_state_dict(torch.load(pretrained_model_path))

# Create an optimizer and scheduler. Both are referenced through `optim`
# (imported as `torch.optim`), since the bare names `Adam` and
# `ReduceLROnPlateau` are not imported at the top of the notebook.
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.000001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# Define the loss function
criterion = nn.BCEWithLogitsLoss()

# Training loop for fine-tuning
try:
    for epoch in range(200):
        # Training
        train_loss = TropiGAT_models.train(model, new_data, optimizer, criterion)

        if epoch % 5 == 0:
            # Validation on the test mask (metric order, per the labels below:
            # F1, precision, recall, MCC, accuracy, AUC).
            test_loss, metrics = TropiGAT_models.evaluate(model, new_data, criterion, new_data["B1"].test_mask)

            info_training_concise = f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTest Loss: {test_loss}\tMCC: {metrics[3]}\tAUC: {metrics[5]}\n'
            info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss},F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'

            print(info_training)
            # The scheduler is stepped once per evaluation (every 5th epoch).
            scheduler.step(test_loss)

    # Save the fine-tuned model (weights only)
    torch.save(model.state_dict(), "fine_tuned_model.pt")

    # Final evaluation: reload the saved weights into a fresh model and score
    # it on the held-out eval mask.
    model_final = TropiGAT_models.TropiGAT_big_module(hidden_channels=1280, heads=1)
    model_final.load_state_dict(torch.load("fine_tuned_model.pt"))
    eval_loss, metrics = TropiGAT_models.evaluate(model_final, new_data, criterion, new_data["B1"].eval_mask)

    info_eval = f'Epoch: {epoch}, F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
    print("Final evaluation ...")
    print(info_eval)

except Exception as e:
    # Best-effort: report the failure without raising out of the cell.
    print(f"***Issue here: {e}")