In [1]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, HeteroConv , GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import f1_score, precision_score, recall_score, matthews_corrcoef, accuracy_score, roc_auc_score

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore")


In [2]:
# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
graph_data = torch.load(f'{path_work}/graph_file.1107.pt')


In [None]:
# *****************************************************************************
# Load the Dataframes :
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
graph_data = torch.load(f'{path_work}/train_nn/graph_file.1107.pt')


In [None]:
# The model : dot product
class GNN(torch.nn.Module):
    def __init__(self, edge_type , hidden_channels, conv=GATConv): # GCNConv(-1, 64) , SAGEConv((-1, -1), 64), GATConv((-1, -1), 64)
        super().__init__()
        self.conv = conv((-1,-1), hidden_channels, add_self_loops = False, heads = 3, dropout = 0.1)
        self.hetero_conv = HeteroConv({edge_type: self.conv})
    def forward(self, x_dict, edge_index_dict):
        x = self.hetero_conv(x_dict, edge_index_dict)  
        return x

# Dot product :
class Classifier(torch.nn.Module):
    def forward(self, x_dict_A , x_dict_B1, edge_index):
        edge_type = ("B1", "infects", "A")
        edge_feat_A = x_dict_A["A"][edge_index[edge_type].edge_label_index[1]]
        edge_feat_B1 = x_dict_B1["B1"][edge_index[edge_type].edge_label_index[0]]
        return (edge_feat_A * edge_feat_B1).sum(dim=-1)

class Model(torch.nn.Module):
    def __init__(self, out_channels , conv=GATConv):
        super().__init__()
        self.single_layer_model = GNN(("B2", "expressed", "B1") , out_channels) 
        self.second_layer_model = GNN(("B1", "infects", "A") , out_channels)
        self.classifier_dot = Classifier()
        
    def forward(self, graph_data):
        b1_nodes = self.single_layer_model(graph_data.x_dict , graph_data.edge_index_dict)
        updated_dict = {}
        updated_dict["A"], updated_dict["B2"] = graph_data.x_dict["A"], graph_data.x_dict["B2"]
        updated_dict["B1"] = b1_nodes["B1"]
        a_nodes = self.second_layer_model(updated_dict , graph_data.edge_index_dict)
        dot_product = self.classifier_dot(a_nodes ,b1_nodes , graph_data)
        
        return dot_product


In [4]:
# *****************************************************************************
# Pre-process data :
transform = T.RandomLinkSplit(
    num_val=0.1, 
    num_test=0.2, 
    #disjoint_train_ratio=...,  
    neg_sampling_ratio=1.0,  
    add_negative_train_samples=True, 
    edge_types=("B1", "infects", "A"),
    rev_edge_types=("A", "harbors", "B1"), 
)

train_data, val_data, test_data = transform(graph_data)

train_loader = LinkNeighborLoader(
    data=train_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), train_data["B1", "infects", "A"].edge_label_index),
    edge_label=train_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

val_loader = LinkNeighborLoader(
    data=val_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), val_data["B1", "infects", "A"].edge_label_index),
    edge_label=val_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

test_loader = LinkNeighborLoader(
    data=test_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), test_data["B1", "infects", "A"].edge_label_index),
    edge_label=test_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)


In [5]:
sampled_data = next(iter(train_loader))

sampled_data

HeteroData(
  [1mA[0m={
    x=[157, 127],
    n_id=[157]
  },
  [1mB1[0m={
    x=[270, 0],
    n_id=[270]
  },
  [1mB2[0m={
    x=[139, 1280],
    n_id=[139]
  },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 250],
    y=[250],
    edge_label=[128],
    edge_label_index=[2, 128],
    e_id=[250],
    input_id=[128]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 176],
    y=[176],
    e_id=[176]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 125],
    y=[125],
    e_id=[125]
  }
)

In [6]:
model = Model(20)
val = model(sampled_data)

In [7]:
val

tensor([-0.0097,  0.0120,  0.0120,  0.0120, -0.0154,  0.0123,  0.0123,  0.0123,
         0.0123,  0.0123,  0.0123,  0.0123,  0.0123,  0.0119,  0.0119,  0.0119,
         0.0119, -0.0036,  0.0119,  0.0119,  0.0119,  0.0117,  0.0117,  0.0117,
         0.0034,  0.0117,  0.0123,  0.0123,  0.0123,  0.0123,  0.0123,  0.0123,
         0.0123,  0.0120,  0.0120,  0.0120,  0.0120,  0.0120,  0.0120, -0.0057,
         0.0120,  0.0120, -0.0291,  0.0122,  0.0122,  0.0122, -0.0085,  0.0122,
         0.0122,  0.0122, -0.0081, -0.0130, -0.0130,  0.0116,  0.0116,  0.0123,
         0.0123,  0.0115,  0.0115,  0.0115,  0.1085, -0.0099,  0.0113, -0.0099,
         0.0119,  0.0119, -0.0002,  0.0119,  0.0119,  0.0119,  0.0119,  0.0119,
         0.0123,  0.0123,  0.0792,  0.0117,  0.0792,  0.0117,  0.0117,  0.0104,
         0.0104, -0.0090, -0.0090,  0.0115,  0.0545,  0.0121,  0.0121,  0.0121,
         0.0121,  0.0121,  0.0121,  0.0121,  0.0121,  0.0121,  0.0121,  0.0121,
         0.0121,  0.0121, -0.0200, -0.02

In [None]:
# *****************************************************************************
# Training :
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available

def train(model, loader, optimizer, criterion, edge_type):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        edge_labels = data[edge_type].edge_label
        loss = criterion(out, edge_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

@torch.no_grad()
def evaluate(model, loader, criterion, edge_type):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []  # Collect output probabilities for AUC
    for data in loader:
        data = data.to(device)
        out = model(data)
        edge_labels = data[edge_type].edge_label
        val_loss = criterion(out, edge_labels)
        total_loss += val_loss.item()
        probs = torch.sigmoid(out)  # Convert output to probabilities
        pred_class = probs.round()  # Round to nearest integer to get class predictions
        all_preds.extend(pred_class.cpu().numpy())
        all_labels.extend(edge_labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())  # Collect output probabilities
    # Calculate the metrics
    f1 = f1_score(all_labels, all_preds, average='binary')
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')  # Calculate recall
    mcc = matthews_corrcoef(all_labels, all_preds)  # Calculate MCC
    accuracy = accuracy_score(all_labels, all_preds)
    auc = roc_auc_score(all_labels, all_probs)
    return total_loss / len(loader), f1, precision, recall, mcc, accuracy, auc  # Include recall and MCC in return values

def main():
    hidden_channels = 1000
    model = Model(hidden_channels).to(device)
    # Due to lazy initialization, we need to run one model step so the number
    # of parameters can be inferred:
    eg_gratia_data = next(iter(val_loader))
    edge_type = ("B1", "infects", "A")
    with torch.no_grad():
        model(eg_gratia_data.to(device))
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    for epoch in range(1000):
        train_loss = train(model, train_loader, optimizer, criterion, edge_type)
        if epoch % 10 == 0:
            val_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, test_loader, criterion, edge_type)
            print(f'Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}, F1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}')
    # Save the model
    torch.save(model.state_dict(), f"{path_work}/SAGEConv.dot.model.2007.pt")
    print("Final evaluation ...")
    val_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, val_loader, criterion, edge_type)
    print(f'F1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}')

if __name__ == "__main__":
    main()
