In [2]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, to_hetero , SAGEConv , HeteroConv , GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging

# Silence every warning category for the rest of the session.
warnings.filterwarnings("ignore") 

# *****************************************************************************
# Load the serialized graph object :
# `path_work` is the working directory; it is reused further down for the log
# file and for the saved model weights.
path_work = "/media/concha-eloko/Linux/PPT_clean"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load graph files that come from a trusted source.
graph_data = torch.load(f'{path_work}/graph_file.2407.pt')


In [3]:
graph_data

HeteroData(
  [1mA[0m={ x=[4530, 127] },
  [1mB1[0m={ x=[11339, 0] },
  [1mB2[0m={ x=[3608, 1280] },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 7731],
    y=[7731]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 13285],
    y=[13285]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 7731],
    y=[7731]
  }
)

In [27]:
# Binary classification 
# *****************************************************************************
# Route log records to GATConv.log inside path_work. filemode='w' truncates the
# file whenever this configuration takes effect, and level=NOTSET on the root
# logger lets records of every severity through.
logging.basicConfig(filename = f"{path_work}/GATConv.log",format='%(asctime)s | %(levelname)s: %(message)s', level=logging.NOTSET, filemode='w')

# Encoder : a single heterogeneous message-passing layer
class GNN(torch.nn.Module):
    """One convolution applied along a single edge type of a heterogeneous graph.

    Args:
        edge_type: ``(source, relation, target)`` key the convolution is
            registered under inside the ``HeteroConv`` wrapper.
        conv: convolution class to instantiate. It is called with lazy
            bipartite input sizes ``(-1, -1)`` and with the keyword arguments
            ``add_self_loops``, ``heads``, ``dropout`` and ``shared_weights``,
            so it must accept all of them.
        hidden_channels: output size handed to the convolution.
        heads: value forwarded as the convolution's ``heads`` argument.
        dropout: value forwarded as the convolution's ``dropout`` argument.
    """

    def __init__(self, edge_type , conv, hidden_channels, heads, dropout):
        super().__init__()
        # NOTE(review): ``shared_weights`` does not appear among GATConv's
        # documented arguments (GATv2Conv has ``share_weights``) -- confirm
        # that it has any effect here.
        conv_options = dict(
            add_self_loops=False,
            heads=heads,
            dropout=dropout,
            shared_weights=True,
        )
        self.conv = conv((-1, -1), hidden_channels, **conv_options)
        self.hetero_conv = HeteroConv({edge_type: self.conv})

    def forward(self, x_dict, edge_index_dict):
        """Return whatever the wrapped ``HeteroConv`` yields for the given inputs."""
        return self.hetero_conv(x_dict, edge_index_dict)

# Classifier, binary :
class EdgeDecoder(torch.nn.Module):
    """Score candidate ("B1", "infects", "A") edges with a two-layer MLP.

    For every column of ``edge_label_index`` the feature row of the "A" node
    (row 1 of the index) is concatenated with the feature row of the "B1"
    node (row 0 of the index); the result goes through
    ``Linear -> ReLU -> Linear`` to produce one logit per candidate edge.

    Args:
        hidden_channels: together with ``heads`` gives the expected width of
            the "B1" rows, ``heads * hidden_channels``.
        heads: see ``hidden_channels``.
        a_channels: width of the "A" feature rows that are concatenated to
            the "B1" rows. Defaults to 127, the value that used to be
            hard-coded, so existing two-argument calls behave as before.
    """

    def __init__(self, hidden_channels, heads, a_channels=127):
        super().__init__()
        self.lin1 = torch.nn.Linear(heads * hidden_channels + a_channels, 512)
        self.lin2 = torch.nn.Linear(512, 1)

    def forward(self, x_dict_A , x_dict_B1, graph_data):
        """Return a 1-D tensor holding one logit per column of ``edge_label_index``.

        Args:
            x_dict_A: mapping whose "A" entry holds the "A" node features.
            x_dict_B1: mapping whose "B1" entry holds the "B1" node features.
            graph_data: object indexable by the edge type, exposing
                ``edge_label_index`` for ("B1", "infects", "A").
        """
        edge_type = ("B1", "infects", "A")
        label_index = graph_data[edge_type].edge_label_index
        edge_feat_A = x_dict_A["A"][label_index[1]]
        edge_feat_B1 = x_dict_B1["B1"][label_index[0]]
        # Concatenation order ("A" first, then "B1") must match the layout lin1 was trained on.
        edge_features = torch.cat((edge_feat_A, edge_feat_B1), dim=-1)
        hidden = self.lin1(edge_features).relu()
        return self.lin2(hidden).view(-1)

class Model(torch.nn.Module):
    """Encoder plus edge decoder for ("B1", "infects", "A") link scoring.

    The encoder is a single ``GNN`` layer running over the
    ("B2", "expressed", "B1") edges; its output is handed to ``EdgeDecoder``
    together with the untouched feature dictionary of the input graph.

    Args:
        conv: convolution class forwarded to ``GNN``.
        hidden_channels: forwarded to both ``GNN`` and ``EdgeDecoder``.
        heads: forwarded to both ``GNN`` and ``EdgeDecoder``.
        dropout: forwarded to ``GNN``.
    """

    def __init__(self, conv, hidden_channels, heads, dropout):
        super().__init__()
        encoder_edge_type = ("B2", "expressed", "B1")
        self.single_layer_model = GNN(encoder_edge_type, conv, hidden_channels, heads, dropout)
        self.EdgeDecoder = EdgeDecoder(hidden_channels, heads)

    def forward(self, graph_data):
        """Encode the graph, then return the decoder output for its labelled edges."""
        b1_embeddings = self.single_layer_model(
            graph_data.x_dict,
            graph_data.edge_index_dict,
        )
        # The decoder reads the "A" features straight from the input graph.
        return self.EdgeDecoder(graph_data.x_dict, b1_embeddings, graph_data)


In [32]:
# Multiclass classification 
# *****************************************************************************
# Same logging configuration as the binary-classification cell.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers (unless force=True is passed), so if the binary cell ran first in
# this kernel this call has no effect and the log file is not truncated again.
logging.basicConfig(filename = f"{path_work}/GATConv.log",format='%(asctime)s | %(levelname)s: %(message)s', level=logging.NOTSET, filemode='w')

# Encoder : one convolution over a single edge type
# (running this cell rebinds the name GNN defined in the binary-classification cell)
class GNN(torch.nn.Module):
    """Wrap one convolution layer in a ``HeteroConv`` keyed by ``edge_type``.

    Args:
        edge_type: ``(source, relation, target)`` key for the ``HeteroConv`` mapping.
        conv: convolution class; instantiated with lazy input sizes ``(-1, -1)``
            and the keywords ``add_self_loops``, ``heads``, ``dropout`` and
            ``shared_weights``.
        hidden_channels: output size passed to the convolution.
        heads: forwarded as ``heads``.
        dropout: forwarded as ``dropout``.
    """

    def __init__(self, edge_type , conv, hidden_channels, heads, dropout):
        super().__init__()
        layer = conv(
            (-1, -1),
            hidden_channels,
            add_self_loops=False,
            heads=heads,
            dropout=dropout,
            shared_weights=True,
        )
        self.conv = layer
        self.hetero_conv = HeteroConv({edge_type: layer})

    def forward(self, x_dict, edge_index_dict):
        """Apply the heterogeneous convolution and return its output unchanged."""
        encoded = self.hetero_conv(x_dict, edge_index_dict)
        return encoded

# Classifier, multiclass :
class EdgeDecoder(torch.nn.Module):
    """Produce class logits for every candidate ("B1", "infects", "A") edge.

    Only the "B1" row selected by row 0 of ``edge_label_index`` is used as
    input; ``x_dict_A`` stays in the signature so ``Model`` can call this
    decoder unchanged, but it is not read.

    Args:
        hidden_channels: together with ``heads`` gives the expected width of
            the "B1" rows, ``heads * hidden_channels``.
        heads: see ``hidden_channels``.
        num_classes: number of logits produced per edge. Defaults to 127, the
            value that used to be hard-coded.
        flatten: when True (default, the historical behaviour) the
            ``[num_edges, num_classes]`` logits are flattened row by row into
            a 1-D tensor, so entry ``i * num_classes + c`` is the logit of
            class ``c`` for edge ``i``. Pass False to keep the class axis,
            which is the layout per-edge losses such as
            ``torch.nn.CrossEntropyLoss`` expect.
    """

    def __init__(self, hidden_channels, heads, num_classes=127, flatten=True):
        super().__init__()
        self.flatten_output = flatten
        self.lin1 = torch.nn.Linear(heads * hidden_channels, 512)
        self.lin2 = torch.nn.Linear(512, num_classes)

    def forward(self, x_dict_A , x_dict_B1, graph_data):
        """Return the logits, flattened to 1-D unless ``flatten=False`` was given."""
        edge_type = ("B1", "infects", "A")
        source_index = graph_data[edge_type].edge_label_index[0]
        hidden = self.lin1(x_dict_B1["B1"][source_index]).relu()
        logits = self.lin2(hidden)
        if self.flatten_output:
            return logits.view(-1)
        return logits

class Model(torch.nn.Module):
    """Chain the single-layer encoder with the edge decoder.

    ``GNN`` runs over the ("B2", "expressed", "B1") edges and its output is
    decoded by ``EdgeDecoder`` for the labelled edges of the input graph.

    Args:
        conv: convolution class handed to ``GNN``.
        hidden_channels: handed to ``GNN`` and ``EdgeDecoder``.
        heads: handed to ``GNN`` and ``EdgeDecoder``.
        dropout: handed to ``GNN``.
    """

    def __init__(self, conv, hidden_channels, heads, dropout):
        super().__init__()
        self.single_layer_model = GNN(
            ("B2", "expressed", "B1"),
            conv,
            hidden_channels,
            heads,
            dropout,
        )
        self.EdgeDecoder = EdgeDecoder(hidden_channels, heads)

    def forward(self, graph_data):
        """Return the decoder output for ``graph_data``."""
        encoded_nodes = self.single_layer_model(graph_data.x_dict, graph_data.edge_index_dict)
        input_features = graph_data.x_dict
        return self.EdgeDecoder(input_features, encoded_nodes, graph_data)

In [33]:
# Smoke test: build a model (GATConv, 1000 hidden channels, 3 heads, dropout 0.1)
# and run a single forward pass on one sampled mini-batch.
# NOTE(review): `sampled_data` is only assigned in a cell that appears further
# down (`next(iter(train_loader))`), so this cell fails on a fresh top-to-bottom run.
model_test = Model(GATConv , 1000, 3, 0.1)
out = model_test(sampled_data)


In [36]:
out

tensor([-0.0277,  0.0309, -0.0310,  ..., -0.0512, -0.0195,  0.0744],
       grad_fn=<ViewBackward0>)

In [4]:
# *****************************************************************************
# Pre-process data :
# Edge type whose links are split and predicted, and its reverse counterpart.
LINK_EDGE_TYPE = ("B1", "infects", "A")
LINK_REV_EDGE_TYPE = ("A", "harbors", "B1")

# Hold out 10% of the links for validation and 20% for testing; one negative
# link is sampled per positive one, for the training split as well.
# (disjoint_train_ratio is deliberately left at its default.)
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.2,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=True,
    edge_types=LINK_EDGE_TYPE,
    rev_edge_types=LINK_REV_EDGE_TYPE,
)

train_data, val_data, test_data = transform(graph_data)


def _link_loader(split):
    """Build the mini-batch loader over the labelled links of one split.

    Every loader uses the same settings: all neighbours of a single hop
    (``num_neighbors=[-1]``), batches of 128 labelled links, shuffled.
    """
    labelled_links = split[LINK_EDGE_TYPE]
    return LinkNeighborLoader(
        data=split,
        num_neighbors=[-1],
        edge_label_index=(LINK_EDGE_TYPE, labelled_links.edge_label_index),
        edge_label=labelled_links.edge_label,
        batch_size=128,
        shuffle=True,
    )


train_loader = _link_loader(train_data)
val_loader = _link_loader(val_data)
test_loader = _link_loader(test_data)


In [12]:
# Draw one mini-batch from the training loader for inspection.
sampled_data = next(iter(train_loader))

# Only the last expression of a cell is displayed, so the edge_index_dict lookup
# on the next line produces no visible output; the store of the
# ("B1", "infects", "A") edge type is what gets shown.
sampled_data.edge_index_dict[("B1", "infects", "A")]
sampled_data[("B1", "infects", "A")]

{'edge_index': tensor([[128,   0, 129, 130, 131, 132, 133, 134, 135, 136, 137,  78,  90, 138,
         139, 140, 141, 142, 143, 144, 145,   2,   3, 146, 147, 148,   6, 149,
           4, 150,   5, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161,
         162,  13, 163, 164, 165,  15, 166, 167, 168,   9, 169, 170, 171, 172,
         173, 174, 175, 176, 177, 178, 179,  21, 180,  59, 181, 182, 183,  30,
         184, 185, 186, 187, 188,  68, 189,  11, 190, 191,  17, 192,  18, 193,
         194,  40, 195, 196, 197, 198, 199, 200, 201, 202, 203,  67,  66,  23,
         204, 205,  25, 206, 207,  27, 208,  29, 209, 210, 211, 212,  31,  32,
         213, 214,  34, 215, 216,  35, 217, 218, 219, 220, 221, 222, 223, 224,
         225, 226, 227, 228, 229, 230,  42, 231,  57, 232, 233, 234, 235, 236,
         237, 238, 239, 240, 241,  46, 242,  48, 243, 244, 245, 246,  53,  54,
         247, 248,  72,  55, 249, 250,  56, 251,  58,  98, 252,  62,  63, 253,
          99, 254, 255, 256, 257, 258

In [None]:
# *****************************************************************************
# Training :
# Prefer the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def train(model, data, optimizer, criterion, edge_type):
    """Run one optimisation step on ``data`` and return the loss as a float.

    Args:
        model: module called as ``model(data)`` to obtain the predictions.
        data: graph object; moved to ``device`` with ``.to`` and indexed by
            ``edge_type`` to fetch ``edge_label``.
        optimizer: optimizer holding the model parameters.
        criterion: callable ``criterion(predictions, labels)`` returning the loss.
        edge_type: key of the edge store that carries the labels.
    """
    model.train()
    batch = data.to(device)
    optimizer.zero_grad()
    # Arguments are evaluated left to right: forward pass first, then the label lookup.
    loss = criterion(model(batch), batch[edge_type].edge_label)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, data, criterion, edge_type):
    """Score ``model`` on ``data`` without tracking gradients.

    The model outputs are treated as logits: they go through a sigmoid and are
    rounded to obtain hard 0/1 predictions, while the sigmoid values themselves
    are used as scores for the AUC.

    Returns:
        Tuple ``(loss, f1, precision, recall, mcc, accuracy, auc)`` where
        ``loss`` is ``criterion(logits, labels)`` as a Python float.
    """
    model.eval()
    batch = data.to(device)
    logits = model(batch)
    targets = batch[edge_type].edge_label
    loss_value = criterion(logits, targets)

    scores = torch.sigmoid(logits)
    y_pred = scores.round().cpu().numpy()
    y_true = targets.cpu().numpy()
    y_score = scores.cpu().numpy()

    # Threshold-based metrics use the rounded predictions ...
    f1 = f1_score(y_true, y_pred, average='binary')
    precision = precision_score(y_true, y_pred, average='binary')
    recall = recall_score(y_true, y_pred, average='binary')
    mcc = matthews_corrcoef(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)
    # ... while the ranking metric uses the raw probabilities.
    auc = roc_auc_score(y_true, y_score)
    return loss_value.item(), f1, precision, recall, mcc, accuracy, auc

def main():
    """Train the link-scoring model full-batch, log metrics and save the weights.

    Uses the module-level ``train_data``, ``val_data``, ``test_data``,
    ``device`` and ``path_work``. Every 10th epoch the model is evaluated on
    ``test_data``; after the last epoch the state dict is written to
    ``{path_work}/{conv.__name__}.model.single_batch.2407.pt`` and a final
    evaluation is run on ``val_data``.

    NOTE(review): the split named ``test_data`` is used for the periodic
    monitoring and ``val_data`` for the final report -- confirm this is intended.
    """
    hidden_channels = 1000
    lr = 0.0001
    conv = GATConv
    heads = 3
    dropout = 0.1
    logging.info(f"Let's start the work with {conv}\t{hidden_channels}\t{dropout}\t{lr}\t{heads}")
    model = Model(conv,hidden_channels,heads,dropout).to(device)
    # Warm-up forward pass: gives the lazily sized (-1) layers concrete shapes
    # before the optimizer collects the parameters. The data must sit on the
    # same device as the model, and no gradients are needed for this pass.
    with torch.no_grad():
        model(train_data.to(device))
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr) #lr=0.00001, weight_decay=0.001
    edge_type = ("B1", "infects", "A")
    for epoch in range(3000):
        train_loss = train(model, train_data, optimizer, criterion, edge_type)
        if epoch % 10 == 0:
            # Periodic monitoring: loss plus classification metrics on the held-out split.
            test_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, test_data, criterion, edge_type)
            info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, F1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}'
            logging.info(info_training)
            print(info_training)
    # Save the model. Use the class name: interpolating the class object itself
    # would put "<class '...GATConv'>" into the file name.
    torch.save(model.state_dict(), f"{path_work}/{conv.__name__}.model.single_batch.2407.pt")
    # The final eval :
    print("Final evaluation ...")
    val_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, val_data, criterion, edge_type)
    print(f'F1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}')
    logging.info(f"Final evaluation ...\nF1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}")



In [None]:
# Launch the training run when this cell is executed.
if __name__ == "__main__":
    main()
