In [None]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, HeteroConv , GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore")


In [None]:
# *****************************************************************************
# Load the saved graph object :
# Working directory that holds the input graph and receives the trained model.
# The original location stays the default; set DEPOLYMERASE_PATH_WORK to point
# the notebook at another checkout without editing this cell.
path_work = os.environ.get(
    "DEPOLYMERASE_PATH_WORK",
    "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023",
)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load graph files that come from a trusted source.
# NOTE(review): recent PyTorch releases may need weights_only=False to unpickle
# a full graph object -- TODO confirm against the installed version.
# map_location="cpu" keeps the load working on hosts without a GPU even when the
# file was written from CUDA tensors; batches are moved to the training device
# later, one at a time.
graph_data = torch.load(f'{path_work}/train_nn/graph_file.1107.pt', map_location="cpu")


In [None]:
# *****************************************************************************
# The model : Dot Classifier
class GNN(torch.nn.Module):
    """One heterogeneous message-passing step restricted to a single edge type.

    Args:
        edge_type: ``(source, relation, target)`` triple the layer operates on.
        hidden_channels: Output size handed to the convolution.
        conv: Convolution class to instantiate. It is called as
            ``conv((-1, -1), hidden_channels, add_self_loops=False)``; the
            ``(-1, -1)`` input size relies on lazy initialisation (see the
            warm-up forward pass in ``main``).
            Candidates noted by the author: GCNConv(-1, 64),
            SAGEConv((-1, -1), 64), GATConv((-1, -1), 64).
    """

    def __init__(self, edge_type, hidden_channels, conv=SAGEConv):
        super().__init__()
        # NOTE(review): add_self_loops is forwarded to whatever class is passed
        # in -- confirm the chosen convolution accepts that keyword.
        layer = conv((-1, -1), hidden_channels, add_self_loops=False)
        # Both attribute names are kept because they determine the state_dict keys.
        self.conv = layer
        self.hetero_conv = HeteroConv({edge_type: layer}, aggr='mean')

    def forward(self, x_dict, edge_index_dict):
        """Apply the wrapped convolution and return HeteroConv's output dict."""
        return self.hetero_conv(x_dict, edge_index_dict)

class Classifier(torch.nn.Module):
    """Dot-product decoder that scores ("B1", "infects", "A") node pairs."""

    def forward(self, x_dict_A , x_dict_B1, edge_index):
        """Return one raw score (logit) per candidate edge.

        Args:
            x_dict_A: Mapping holding the "A" node embeddings under key ``"A"``.
            x_dict_B1: Mapping holding the "B1" node embeddings under key ``"B1"``.
            edge_index: Graph/batch object. It must support
                ``edge_index[("B1", "infects", "A")]`` and expose
                ``edge_index_dict``.

        Returns:
            1-D tensor with the dot product of the two endpoint embeddings for
            every scored pair.
        """
        edge_type = ("B1", "infects", "A")
        # The training and evaluation loops compare these scores with
        # ``edge_label``, which is aligned with ``edge_label_index`` (the
        # supervision edges), not with the message-passing ``edge_index``.
        # Score the supervision edges whenever the batch provides them so that
        # predictions and labels refer to the same pairs.
        pairs = getattr(edge_index[edge_type], "edge_label_index", None)
        if pairs is None:
            # No supervision edges attached (e.g. a graph that was never
            # split): fall back to scoring the existing edges.
            pairs = edge_index.edge_index_dict[edge_type]
        # Row 0 indexes the source ("B1") nodes, row 1 the target ("A") nodes.
        edge_feat_A = x_dict_A["A"][pairs[1]]
        edge_feat_B1 = x_dict_B1["B1"][pairs[0]]
        return (edge_feat_A * edge_feat_B1).sum(dim=-1)

class Model(torch.nn.Module):
    """Two-step heterogeneous encoder followed by a dot-product decoder.

    Step 1 updates the "B1" nodes through the ("B2", "expressed", "B1") edges,
    step 2 updates the "A" nodes through the ("B1", "infects", "A") edges using
    the step-1 "B1" output, and the classifier scores ("B1", "A") pairs.

    Args:
        out_channels: Embedding size produced by both message-passing steps.
        conv: Convolution class used by both steps (default ``SAGEConv``).
    """

    def __init__(self, out_channels , conv=SAGEConv):
        super().__init__()
        # ``conv`` used to be accepted and then ignored, so every model was
        # built with SAGEConv regardless of the argument; forward it to both
        # layers. Attribute names are unchanged to keep state_dict keys stable.
        self.single_layer_model = GNN(("B2", "expressed", "B1") , out_channels, conv=conv)
        self.second_layer_model = GNN(("B1", "infects", "A") , out_channels, conv=conv)
        self.classifier_dot = Classifier()

    def forward(self, graph_data):
        """Return the classifier scores for ``graph_data``.

        Args:
            graph_data: Heterogeneous graph or mini-batch exposing ``x_dict``
                and ``edge_index_dict`` with node types "A", "B1" and "B2".

        Returns:
            Whatever ``Classifier.forward`` returns for this batch (one score
            per scored ("B1", "A") pair).
        """
        # Step 1: refresh the "B1" representations.
        b1_nodes = self.single_layer_model(graph_data.x_dict , graph_data.edge_index_dict)

        # Step 2 consumes the raw "A"/"B2" features together with the refreshed
        # "B1" representations.
        updated_dict = {
            "A": graph_data.x_dict["A"],
            "B2": graph_data.x_dict["B2"],
            "B1": b1_nodes["B1"],
        }
        a_nodes = self.second_layer_model(updated_dict , graph_data.edge_index_dict)

        # Decode: the classifier reads "A" from step 2 and "B1" from step 1.
        dot_product = self.classifier_dot(a_nodes ,b1_nodes , graph_data)
        return dot_product


In [None]:
# *****************************************************************************
# Pre-process data :
# Split the ("B1", "infects", "A") edges into train / validation / test sets and
# add one negative example per positive edge (neg_sampling_ratio=1.0).
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.2,
    #disjoint_train_ratio=...,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=True,
    edge_types=("B1", "infects", "A"),
    rev_edge_types=("A", "harbors", "B1"),
)

train_data, val_data, test_data = transform(graph_data)


def _make_link_loader(split_data, shuffle, edge_type=("B1", "infects", "A"), batch_size=128):
    """Build the mini-batch loader for one split.

    Args:
        split_data: One of the graphs returned by ``transform``; it must carry
            ``edge_label_index`` and ``edge_label`` for ``edge_type``.
        shuffle: Whether the supervision edges are reshuffled on each pass.
        edge_type: Edge type whose labelled edges seed the batches.
        batch_size: Number of supervision edges per batch.

    Returns:
        A ``LinkNeighborLoader`` over ``split_data``.
    """
    store = split_data[edge_type]
    return LinkNeighborLoader(
        data=split_data,
        # NOTE(review): the model stacks two message-passing steps
        # (B2 -> B1, then B1 -> A) while only one hop is sampled here; check
        # whether num_neighbors=[-1, -1] is what the model actually needs.
        num_neighbors=[-1],
        edge_label_index=(edge_type, store.edge_label_index),
        edge_label=store.edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
    )


# Only the training edges are shuffled. Evaluation loaders keep a fixed order so
# the reported validation/test loss does not depend on how batches happen to
# be composed on a given pass.
train_loader = _make_link_loader(train_data, shuffle=True)
val_loader = _make_link_loader(val_data, shuffle=False)
test_loader = _make_link_loader(test_data, shuffle=False)


In [None]:
# *****************************************************************************
# Training :
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def train(model, loader, optimizer, criterion, edge_type):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Module called as ``model(batch)``; it must own at least one
            parameter (the optimizer requires that anyway).
        loader: Iterable of batches supporting ``batch.to(device)`` and
            ``batch[edge_type].edge_label``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss called as ``criterion(out, edge_labels)``; assumed to
            return the mean loss of the batch.
        edge_type: Edge type whose ``edge_label`` holds the targets.

    Returns:
        float: Loss averaged over every label seen in the pass, or ``0.0``
        when the loader yields no batch.
    """
    model.train()
    # Follow the model's own device rather than a notebook-level global, so the
    # function behaves the same whatever cell order was executed.
    model_device = next(model.parameters()).device
    loss_sum = 0.0
    n_labels = 0
    for data in loader:
        data = data.to(model_device)
        optimizer.zero_grad()
        out = model(data)
        edge_labels = data[edge_type].edge_label
        loss = criterion(out, edge_labels)
        loss.backward()
        optimizer.step()
        # Weight each batch by its number of labels: a plain mean of batch
        # means would over-weight a short final batch.
        batch_labels = edge_labels.numel()
        loss_sum += loss.item() * batch_labels
        n_labels += batch_labels
    # Guard against an empty loader instead of dividing by zero.
    return loss_sum / n_labels if n_labels else 0.0

@torch.no_grad()
def evaluate(model, loader, criterion, edge_type):
    """Score ``model`` on ``loader`` and compute binary classification metrics.

    Args:
        model: Module called as ``model(batch)`` and returning one logit per
            label.
        loader: Iterable of batches supporting ``batch.to(device)`` and
            ``batch[edge_type].edge_label``.
        criterion: Loss called as ``criterion(out, edge_labels)``; assumed to
            return the mean loss of the batch.
        edge_type: Edge type whose ``edge_label`` holds the targets.

    Returns:
        tuple: ``(mean_loss, f1, precision, accuracy, auc)``. ``auc`` is NaN
        when the labels contain a single class, because ROC AUC is undefined
        in that case.

    Raises:
        ValueError: If ``loader`` yields no batch.
    """
    model.eval()
    # Follow the model's own device rather than a notebook-level global; a
    # parameter-free model leaves the batches where they are.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    loss_sum = 0.0
    n_labels = 0
    prob_chunks = []
    label_chunks = []
    for data in loader:
        if model_device is not None:
            data = data.to(model_device)
        out = model(data)
        edge_labels = data[edge_type].edge_label
        # Weight each batch by its number of labels so a short final batch
        # does not distort the average.
        batch_labels = edge_labels.numel()
        loss_sum += criterion(out, edge_labels).item() * batch_labels
        n_labels += batch_labels
        prob_chunks.append(torch.sigmoid(out).cpu())  # logits -> probabilities
        label_chunks.append(edge_labels.cpu())

    if n_labels == 0:
        raise ValueError("evaluate() received an empty loader; no metrics can be computed")

    probs = torch.cat(prob_chunks)
    labels = torch.cat(label_chunks)
    preds = probs.round()  # round probabilities to obtain the predicted class
    single_class = labels.unique().numel() < 2

    all_probs = probs.numpy()
    all_labels = labels.numpy()
    all_preds = preds.numpy()

    # Calculate the metrics
    f1 = f1_score(all_labels, all_preds, average='binary')
    precision = precision_score(all_labels, all_preds, average='binary')
    accuracy = accuracy_score(all_labels, all_preds)
    # AUC is computed from the probabilities, not from the rounded predictions.
    auc = float("nan") if single_class else roc_auc_score(all_labels, all_probs)
    return loss_sum / n_labels, f1, precision, accuracy, auc

def main(hidden_channels=1000, lr=0.0001, num_epochs=1000, eval_every=10):
    """Train the dot-product link predictor and save its weights.

    Args:
        hidden_channels: Embedding size passed to ``Model``.
        lr: Learning rate of the Adam optimizer.
        num_epochs: Number of passes over ``train_loader``.
        eval_every: Evaluate and print metrics every this many epochs
            (must be a positive integer).

    Side effects:
        Prints metrics and writes the final ``state_dict`` to
        ``{path_work}/SAGEConv.dot.model.1707.pt``.
    """
    model = Model(hidden_channels).to(device)
    edge_type = ("B1", "infects", "A")

    # Due to lazy initialization, we need to run one model step so the number
    # of parameters can be inferred. The warm-up batch must live on the same
    # device as the model, otherwise this forward pass fails on a GPU.
    eg_gratia_data = next(iter(val_loader)).to(device)
    with torch.no_grad():
        model(eg_gratia_data)

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # NOTE(review): progress is monitored on the test split while the
    # validation split is only used for the warm-up pass above -- confirm this
    # is intended before using these numbers for model selection.
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, optimizer, criterion, edge_type)
        if epoch % eval_every == 0:
            test_loss, f1, precision, accuracy, auc = evaluate(model, test_loader, criterion, edge_type)
            print(f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, F1 Score: {f1}, Precision: {precision}, Accuracy: {accuracy}, AUC: {auc}')

    # Save the model
    torch.save(model.state_dict(), f"{path_work}/SAGEConv.dot.model.1707.pt")

# Entry point: start the training run only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
