In [None]:
import os
import random
import warnings
from collections import Counter
from itertools import product

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, HeteroConv , GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    matthews_corrcoef,
)
warnings.filterwarnings("ignore")


In [None]:
# *****************************************************************************
# Load the graph and set up the compute device :
# NOTE: absolute, machine-specific directory; it is also where main() writes the
# trained weights. Adjust it when running on another machine.
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"

# `device` is referenced by the training, evaluation and prediction cells further
# down; define it once here so a fresh "Restart & Run All" does not hit a NameError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.load unpickles the file, so only point it at graph files you trust.
# Later cells index this object with the edge types ("B2", "expressed", "B1") and
# ("B1", "infects", "A") and the node types "A", "B1" and "B2".
graph_data = torch.load(f'{path_work}/train_nn/graph_file.1107.pt')


In [None]:
# *****************************************************************************
# The model : dot product
class GNN(torch.nn.Module):
    """One heterogeneous message-passing layer restricted to a single edge type.

    Args:
        edge_type: ``(source_type, relation, target_type)`` triple the layer
            operates on; it is the only key registered in the ``HeteroConv``.
        hidden_channels: ``out_channels`` value handed to the convolution.
        conv: Convolution class to instantiate. It is called as
            ``conv((-1, -1), hidden_channels, **conv_kwargs)``, so it must accept
            a pair of input sizes (e.g. ``GATConv`` or ``SAGEConv``).
        conv_kwargs: Extra keyword arguments for ``conv``. ``None`` (default)
            keeps the GAT settings this notebook was written with
            (``add_self_loops=False, heads=3, dropout=0.1``); pass ``{}`` or a
            custom dict when using a convolution that does not understand them.

    NOTE(review): per the GATConv documentation, heads are concatenated by
    default, so the default configuration presumably yields
    ``3 * hidden_channels`` features per node - confirm before relying on it.
    """

    def __init__(self, edge_type, hidden_channels, conv=GATConv, conv_kwargs=None):
        super().__init__()
        if conv_kwargs is None:
            # Historical defaults; they are GAT-specific, which is why they can
            # now be overridden for other convolution classes.
            conv_kwargs = {"add_self_loops": False, "heads": 3, "dropout": 0.1}
        # NOTE(review): the layer is kept as an attribute *and* registered inside
        # HeteroConv, so its weights likely appear under both prefixes in
        # state_dict(); dropping `self.conv` would change saved checkpoint keys.
        self.conv = conv((-1, -1), hidden_channels, **conv_kwargs)
        self.hetero_conv = HeteroConv({edge_type: self.conv})

    def forward(self, x_dict, edge_index_dict):
        """Apply the wrapped ``HeteroConv`` and return its output dictionary."""
        return self.hetero_conv(x_dict, edge_index_dict)

# Dot product :
class Classifier(torch.nn.Module):
    """Scores candidate ("B1", "infects", "A") links with an embedding dot product."""

    def forward(self, x_dict_A , x_dict_B1, edge_index):
        """Return one raw score (logit) per candidate link.

        Args:
            x_dict_A: Mapping whose ``"A"`` entry holds the "A" node embeddings.
            x_dict_B1: Mapping whose ``"B1"`` entry holds the "B1" node embeddings.
            edge_index: Despite the name, any object that can be indexed by the
                edge type and whose entry exposes ``edge_label_index``; row 0 is
                used to index "B1" nodes and row 1 to index "A" nodes.

        Returns:
            The element-wise product of the paired embeddings summed over the
            last dimension.
        """
        link_store = edge_index[("B1", "infects", "A")]
        b1_positions = link_store.edge_label_index[0]
        a_positions = link_store.edge_label_index[1]

        a_embeddings = x_dict_A["A"][a_positions]
        b1_embeddings = x_dict_B1["B1"][b1_positions]
        return torch.sum(a_embeddings * b1_embeddings, dim=-1)

class Model(torch.nn.Module):
    """Two chained single-edge-type GNN layers followed by a dot-product scorer.

    The first layer works on ("B2", "expressed", "B1") edges, the second on
    ("B1", "infects", "A") edges; the resulting "A" and "B1" embeddings are
    handed to ``Classifier`` to score the labelled candidate links.
    """

    def __init__(self, out_channels):
        super().__init__()
        # Registration order is kept as-is: it determines parameter ordering.
        self.single_layer_model = GNN(("B2", "expressed", "B1") , out_channels)
        self.second_layer_model = GNN(("B1", "infects", "A") , out_channels)
        self.classifier_dot = Classifier()

    def forward(self, graph_data):
        """Return the link scores for ``graph_data``'s labelled "infects" edges."""
        raw_features = graph_data.x_dict

        # Layer 1: its output is read for the "B1" node type only.
        b1_nodes = self.single_layer_model(raw_features, graph_data.edge_index_dict)

        # Layer 2 sees the raw "A"/"B2" features plus the refreshed "B1" embeddings.
        second_layer_input = {
            "A": raw_features["A"],
            "B2": raw_features["B2"],
            "B1": b1_nodes["B1"],
        }
        a_nodes = self.second_layer_model(second_layer_input, graph_data.edge_index_dict)

        return self.classifier_dot(a_nodes, b1_nodes, graph_data)


In [None]:
# *****************************************************************************
# Pre-process data :
# Seed every RNG in scope so the random edge split and the sampled negative
# edges are the same on each "Restart & Run All".
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.2,
    #disjoint_train_ratio=...,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=True,
    edge_types=("B1", "infects", "A"),
    rev_edge_types=("A", "harbors", "B1"),
)

train_data, val_data, test_data = transform(graph_data)

# Mini-batch alternative (currently unused - training below is full-batch):
# wrap each split in a LinkNeighborLoader, e.g. for the training split
#   LinkNeighborLoader(
#       data=train_data,
#       num_neighbors=[-1],
#       edge_label_index=(("B1", "infects", "A"), train_data["B1", "infects", "A"].edge_label_index),
#       edge_label=train_data["B1", "infects", "A"].edge_label,
#       batch_size=128,
#       shuffle=True,
#   )
# and build the validation / test loaders the same way from val_data / test_data.

In [None]:
# *****************************************************************************
# Training :

def train(model, data, optimizer, criterion, edge_type):
    """Run a single full-batch optimisation step.

    Args:
        model: Module called as ``model(data)``; its output is fed to ``criterion``.
        data: Graph object supporting ``.to(device)`` and ``data[edge_type].edge_label``.
        optimizer: Optimizer holding ``model``'s parameters.
        criterion: Loss called as ``criterion(output, edge_labels)``.
        edge_type: Key used to look up the target labels in ``data``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    # Move the data to wherever the model lives rather than depending on a
    # notebook-level `device` variable (which was never defined).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        data = data.to(first_param.device)

    optimizer.zero_grad()
    out = model(data)
    edge_labels = data[edge_type].edge_label
    loss = criterion(out, edge_labels)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, data, criterion, edge_type):
    """Score ``model`` on one data split without tracking gradients.

    Args:
        model: Module called as ``model(data)``; its output is treated as logits.
        data: Graph object supporting ``.to(device)`` and ``data[edge_type].edge_label``.
        criterion: Loss called as ``criterion(output, edge_labels)``.
        edge_type: Key used to look up the target labels in ``data``.

    Returns:
        Tuple ``(loss, f1, precision, recall, mcc, accuracy, auc)``. The class
        metrics use the sigmoid probabilities rounded to 0/1; AUC uses the
        probabilities themselves.
    """
    model.eval()
    # Move the data to wherever the model lives rather than depending on a
    # notebook-level `device` variable (which was never defined).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        data = data.to(first_param.device)

    out = model(data)
    edge_labels = data[edge_type].edge_label
    val_loss = criterion(out, edge_labels)

    probs = torch.sigmoid(out)
    pred_class = probs.round()

    # scikit-learn works on host-side numpy arrays.
    all_preds = pred_class.cpu().numpy()
    all_labels = edge_labels.cpu().numpy()
    all_probs = probs.cpu().numpy()

    f1 = f1_score(all_labels, all_preds, average='binary')
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    mcc = matthews_corrcoef(all_labels, all_preds)
    accuracy = accuracy_score(all_labels, all_preds)
    auc = roc_auc_score(all_labels, all_probs)
    return val_loss.item(), f1, precision, recall, mcc, accuracy, auc

def main(hidden_channels=1000, lr=0.00001, num_epochs=3000, eval_every=10, weight_decay=0.0):
    """Train the link-prediction model full-batch, save it, and report metrics.

    Args:
        hidden_channels: Width passed to ``Model`` (1280 was also tried).
        lr: Adam learning rate.
        num_epochs: Number of full-batch optimisation steps.
        eval_every: Evaluate and print metrics every this many epochs (must be > 0).
        weight_decay: Adam weight decay; 0.0 reproduces the original run, a small
            positive value (e.g. 0.001) can be used against overfitting.

    The defaults reproduce the configuration that was previously hard-coded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(hidden_channels).to(device)

    # Warm-up forward pass: the conv layers are built with (-1, -1) input sizes,
    # so they need to see data once before the optimizer collects parameters.
    # The data has to sit on the same device as the model, and no gradients are
    # needed for this throw-away pass.
    with torch.no_grad():
        model(train_data.to(device))

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    edge_type = ("B1", "infects", "A")

    # NOTE(review): progress is monitored on the *test* split and the final
    # report uses the *validation* split - the reverse of the usual naming.
    # Neither split drives model selection here, so behaviour is left unchanged.
    for epoch in range(num_epochs):
        train_loss = train(model, train_data, optimizer, criterion, edge_type)
        if epoch % eval_every == 0:
            test_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, test_data, criterion, edge_type)
            print(f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, F1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}')

    # Save the model
    torch.save(model.state_dict(), f"{path_work}/GATConv.model.single_batch.1807.pt")

    # The final eval :
    print("Final evaluation ...")
    val_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, val_data, criterion, edge_type)
    print(f'F1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}')

# In a notebook kernel __name__ is "__main__", so this cell starts training when run.
if __name__ == "__main__":
    main()



> Make predictions 

In [None]:
def make_predictions(model, data):
    """Predict link classes and probabilities for ``data``.

    Args:
        model: Module called as ``model(data)``; its output is treated as logits.
        data: Object supporting ``.to(device)`` that ``model`` accepts as input.

    Returns:
        Tuple ``(predictions, probabilities)`` of numpy arrays: the sigmoid
        probabilities rounded to 0/1, and the probabilities themselves.
    """
    model.eval()  # Disable dropout and other training-only behaviour
    # Move the data to wherever the model lives rather than depending on a
    # notebook-level `device` variable (which was never defined).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        data = data.to(first_param.device)

    with torch.no_grad():  # No need to track gradients for prediction
        output = model(data)
        probabilities = torch.sigmoid(output)  # Convert logits to probabilities
        predictions = probabilities.round()  # Convert probabilities to class labels
    return predictions.cpu().numpy(), probabilities.cpu().numpy()

# Load the saved model
hidden_channels = 1000  # must match the value used during training
# Defined here as well so this cell works on its own in a fresh kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(hidden_channels)
# map_location lets weights saved on a GPU machine load on a CPU-only one.
model.load_state_dict(
    torch.load(f"{path_work}/GATConv.model.single_batch.1807.pt", map_location=device)
)
model = model.to(device)

# Predict using the new data
new_data = ...  # load or create new data here. It should have the same format as your training/test data
if new_data is Ellipsis:
    # The placeholder above has not been replaced yet; skip instead of failing
    # inside make_predictions with an unrelated AttributeError.
    print("No new data provided: replace the `new_data = ...` placeholder to run predictions.")
else:
    predictions, probabilities = make_predictions(model, new_data)

    print(f"Predictions: {predictions}")
    print(f"Probabilities: {probabilities}")