In [1]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import HeteroConv , GCNConv , SAGEConv , GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore") 


In [3]:
path_work = "/media/concha-eloko/Linux/PPT_clean"
graph_data = torch.load(f'{path_work}/graph_file.1107.pt')

graph_data

HeteroData(
  [1mA[0m={ x=[4530, 127] },
  [1mB1[0m={ x=[11339, 0] },
  [1mB2[0m={ x=[3608, 1280] },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 9677],
    y=[9677]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 13285],
    y=[13285]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 9677],
    y=[9677]
  }
)

In [4]:
# *****************************************************************************
# Pre-process data :
transform = T.RandomLinkSplit(
    num_val=0.1, 
    num_test=0.2, 
    #disjoint_train_ratio=...,  
    neg_sampling_ratio=1.0,  
    add_negative_train_samples=True, 
    edge_types=("B1", "infects", "A"),
    rev_edge_types=("A", "harbors", "B1"), 
)

train_data, val_data, test_data = transform(graph_data)

train_loader = LinkNeighborLoader(
    data=train_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), train_data["B1", "infects", "A"].edge_label_index),
    edge_label=train_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

val_loader = LinkNeighborLoader(
    data=val_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), val_data["B1", "infects", "A"].edge_label_index),
    edge_label=val_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

test_loader = LinkNeighborLoader(
    data=test_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), test_data["B1", "infects", "A"].edge_label_index),
    edge_label=test_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

In [6]:
sampled_data = next(iter(train_loader))

sampled_data

HeteroData(
  [1mA[0m={
    x=[155, 127],
    n_id=[155]
  },
  [1mB1[0m={
    x=[289, 0],
    n_id=[289]
  },
  [1mB2[0m={
    x=[127, 1280],
    n_id=[127]
  },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 280],
    y=[280],
    edge_label=[128],
    edge_label_index=[2, 128],
    e_id=[280],
    input_id=[128]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 173],
    y=[173],
    e_id=[173]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 127],
    y=[127],
    e_id=[127]
  }
)

> That one seems to work :

In [None]:
class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, features_A, features_B1, graph_data):
        index_B1 , index_A = graph_data["B1", "infects", "A"].edge_label_index
        z = torch.cat([features_B1[index_B1] ,features_A[index_A]], dim=-1)  # Can you explain why this line gives me an error 
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)

In [10]:
# *****************************************************************************
# The model : GAT dot product
class GNN(torch.nn.Module):
    def __init__(self, edge_type , hidden_channels, conv=GATConv): # GCNConv(-1, 64) , SAGEConv((-1, -1), 64), GATConv((-1, -1), 64)
        super().__init__()
        self.conv = conv((-1,-1), hidden_channels, add_self_loops = False)
        self.hetero_conv = HeteroConv({edge_type: self.conv})
    def forward(self, x_dict, edge_index_dict):
        x = self.hetero_conv(x_dict, edge_index_dict)
        return x

# FNN layers product :
class Classifier_linear(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, x_dict_A, x_dict_B1, graph):
        edge_type = ("B1", "infects", "A")
        edge_feat_A = x_dict_A["A"][graph[edge_type].edge_label_index[1]]
        edge_feat_B1 = x_dict_B1["B1"][graph[edge_type].edge_label_index[0]]
        z = torch.cat([edge_feat_A ,edge_feat_B1], dim=-1)  # Can you explain why this line gives me an error 
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)
        
class Model(torch.nn.Module):
    def __init__(self, out_channels , conv=GATConv):
        super().__init__()
        self.single_layer_model = GNN(("B2", "expressed", "B1") , out_channels)
        self.second_layer_model = GNN(("B1", "infects", "A") , out_channels)
        self.classifier_linear = Classifier_linear(out_channels)

    def forward(self, graph_data):
        b1_nodes = self.single_layer_model(graph_data.x_dict , graph_data.edge_index_dict)
        updated_dict = {}
        updated_dict["A"], updated_dict["B2"] = graph_data.x_dict["A"], graph_data.x_dict["B2"]
        updated_dict["B1"] = b1_nodes["B1"]
        a_nodes = self.second_layer_model(updated_dict , graph_data.edge_index_dict)
        value = self.classifier_linear(a_nodes ,b1_nodes, graph_data)
        return value



In [11]:
model = Model(20)
val = model(sampled_data)

In [12]:
val

tensor([-0.1699, -0.1771, -0.1870, -0.1753, -0.1675, -0.1704, -0.1726, -0.1553,
        -0.1563, -0.1624, -0.1549, -0.1738, -0.1643, -0.1543, -0.1602, -0.1659,
        -0.1673, -0.1663, -0.1608, -0.1782, -0.1558, -0.1776, -0.1550, -0.1614,
        -0.1861, -0.1661, -0.1629, -0.1564, -0.1638, -0.1656, -0.1661, -0.1682,
        -0.1608, -0.1864, -0.1556, -0.1641, -0.1560, -0.1517, -0.1758, -0.1606,
        -0.1685, -0.1617, -0.1836, -0.1733, -0.1818, -0.1670, -0.1661, -0.1735,
        -0.1670, -0.1558, -0.1640, -0.1889, -0.1633, -0.1717, -0.1604, -0.1662,
        -0.1878, -0.1793, -0.1678, -0.1549, -0.1615, -0.1626, -0.1665, -0.1676,
        -0.1561, -0.1645, -0.1776, -0.1732, -0.1665, -0.1591, -0.1796, -0.1862,
        -0.1661, -0.1853, -0.1509, -0.1726, -0.1576, -0.1576, -0.1645, -0.1712,
        -0.1652, -0.1823, -0.1661, -0.1693, -0.1645, -0.1670, -0.1656, -0.1674,
        -0.1608, -0.1625, -0.1661, -0.1659, -0.1551, -0.1648, -0.1655, -0.1684,
        -0.1722, -0.1604, -0.1695, -0.18