In [1]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, to_hetero , SAGEConv , HeteroConv , GATConv, GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging

warnings.filterwarnings("ignore") 

# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
graph_data = torch.load(f'{path_work}/graph_file.2607.LE.pt')

In [25]:
# *****************************************************************************
# Pre-process data :
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.2,
    add_negative_train_samples=False,
    edge_types=("B1", "infects", "A"),
    rev_edge_types=("A", "harbors", "B1"),
)

train_data, val_data, test_data = transform(graph_data)

train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors= [-1],
    edge_label_index=(("B1", "infects", "A"), train_data["B1", "infects", "A"].edge_label_index),
    edge_label=train_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors= [-1],
    edge_label_index=(("B1", "infects", "A"), val_data["B1", "infects", "A"].edge_label_index),
    edge_label=val_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors= [-1],
    edge_label_index=(("B1", "infects", "A"), test_data["B1", "infects", "A"].edge_label_index),
    edge_label=test_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)


In [27]:
sampled_data = next(iter(train_loader))
sampled_data

HeteroData(
  [1mA[0m={
    x=[127, 127],
    n_id=[127]
  },
  [1mB1[0m={
    x=[299, 0],
    n_id=[299]
  },
  [1mB2[0m={
    x=[111, 1280],
    n_id=[111]
  },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 299],
    y=[299],
    edge_label=[128],
    edge_label_index=[2, 128],
    e_id=[299],
    input_id=[128]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 155],
    y=[155],
    e_id=[155]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 128],
    y=[128],
    e_id=[128]
  }
)

In [28]:
val_data

HeteroData(
  [1mA[0m={ x=[4530, 127] },
  [1mB1[0m={ x=[11339, 0] },
  [1mB2[0m={ x=[3608, 1280] },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 5412],
    y=[5412],
    edge_label=[1546],
    edge_label_index=[2, 1546]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 13285],
    y=[13285]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 5412],
    y=[5412]
  }
)

In [14]:
# ***************************************************************************
# The model : multi class classification 
class GNN(torch.nn.Module):
    def __init__(self, edge_type , conv, hidden_channels, heads, dropout): # GCNConv(-1, 64) , SAGEConv((-1, -1), 64), GATConv((-1, -1), 64)
        super().__init__()
        self.conv = conv((-1,-1), hidden_channels, add_self_loops = False, heads = heads, dropout = dropout, shared_weights = True)
        self.hetero_conv = HeteroConv({edge_type: self.conv})
    def forward(self, x_dict, edge_index_dict):
        x = self.hetero_conv(x_dict, edge_index_dict)  
        return x

# Classifier, multiclass :
class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        self.lin1 = torch.nn.Linear(heads*hidden_channels, 512)
        self.lin2 = torch.nn.Linear(512, 127)
        
    def forward(self , x_dict_B1, graph_data):
        edge_type = ("B1", "infects", "A")
        labels = graph_data.x_dict["A"][graph_data[edge_type].edge_label_index[1]]
        edge_feat_B1 = x_dict_B1["B1"][graph_data[edge_type].edge_label_index[0]]
        x = self.lin1(edge_feat_B1).relu()
        x = self.lin2(x)
        return x , labels

class Model(torch.nn.Module):
    def __init__(self, conv, hidden_channels, heads, dropout):
        super().__init__()
        self.single_layer_model = GNN(("B2", "expressed", "B1") ,conv, hidden_channels,heads,dropout) 
        self.EdgeDecoder = EdgeDecoder(hidden_channels,heads)
        
    def forward(self, graph_data):
        b1_nodes = self.single_layer_model(graph_data.x_dict , graph_data.edge_index_dict)
        a_nodes =  graph_data.x_dict
        out ,labels = self.EdgeDecoder(b1_nodes , graph_data)
        return out , labels



In [24]:
sampled_data[("B1", "infects", "A")].edge_label

tensor([0., 1., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1.,
        0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1.,
        1., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0.,
        1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 0.,
        0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
        1., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 0.,
        1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0.,
        1., 1.])

In [31]:
model = Model(GATv2Conv,20,1,0.1)
out , labels = model(train_data)

In [45]:
train_data.x_dict["A"]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [33]:
len(out)

5412

In [39]:
labels.unique()

tensor([0., 1.])

In [43]:
clean_labels = set()
for label in labels:
    clean_labels.add(tuple(label.tolist())) 
len(clean_labels)

5

In [35]:
labels_back = torch.argmax(labels, dim=1)

In [38]:
labels_back.unique()

tensor([ 7,  8, 39, 74, 95])

In [None]:
# *****************************************************************************
# Training : Multiclass :
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available

def train(model, data, optimizer, criterion):
    model.train()
    data = data.to(device)
    optimizer.zero_grad()
    out, labels = model(data)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, data, criterion, num_classes):
    model.eval()
    data = data.to(device)
    out, labels = model(data)
    labels = torch.argmax(labels, dim=1)  # one-hot to indices
    val_loss = criterion(out, labels)
    pred_class = out.argmax(dim=1)
    accuracy = (pred_class == labels).sum().item() / labels.size(0)
    per_class_accuracy = [(pred_class[labels == i] == labels[labels == i]).sum().item() / (labels == i).sum().item() for i in range(num_classes)]
    conf_mat = confusion_matrix(labels, pred_class)
    auc_roc = roc_auc_score(label_binarize(labels, classes=range(num_classes)), out.cpu().detach().numpy(), multi_class='ovr')
    return val_loss.item(), accuracy, per_class_accuracy, conf_mat, auc_roc

def main():
    hidden_channels = 1000
    lr = 0.0001
    conv = GATConv
    heads = 1
    dropout = 0.1
    decay = 5e-4
    num_classes = 127  # modify this to match your number of classes
    logging.info(f"Let's start the work with {conv}\t{hidden_channels}\t{dropout}\t{lr}\t{heads}")
    model = Model(conv,hidden_channels,heads,dropout).to(device)
    criterion = CrossEntropyLoss()  
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay) 
    for epoch in range(3000):
        train_loss = train(model, train_data, optimizer, criterion)
        if epoch % 50 == 0:
            val_loss, accuracy, per_class_accuracy, conf_mat, auc_roc = evaluate(model, test_data, criterion, num_classes)
            info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {val_loss}, Accuracy: {accuracy}, Per Class Accuracy: {per_class_accuracy}, Confusion Matrix: {conf_mat}, AUC_ROC: {auc_roc}'
            logging.info(info_training)
            print(info_training)
    # Save the model
    torch.save(model.state_dict(), f"{path_work}/GATv2Conv.model.{heads}.multiclass.single_batch.2407.pt")
    # The final eval : 
    print("Final evaluation ...")
    val_loss, accuracy, per_class_accuracy, conf_mat, auc_roc = evaluate(model, val_data, criterion, num_classes)
    print(f'Final Test Loss: {val_loss}, Accuracy: {accuracy}, Per Class Accuracy: {per_class_accuracy}, Confusion Matrix: {conf_mat}, AUC_ROC: {auc_roc}')
    logging.info(f"Final evaluation ...\nFinal Test Loss: {val_loss}, Accuracy: {accuracy}, Per Class Accuracy: {per_class_accuracy}, Confusion Matrix: {conf_mat}, AUC_ROC: {auc_roc}")


if __name__ == "__main__":
    main()
