### Pip install operations


In [1]:
! pip install spektral
! pip install ogb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting spektral
  Downloading spektral-1.3.0-py3-none-any.whl (140 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.1/140.1 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: spektral
Successfully installed spektral-1.3.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ogb
  Downloading ogb-1.3.6-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting outdated>=0.2.0 (from ogb)
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collecting littleutils (from outdated>=0.2.0->ogb)
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py)

### Loading the dataset

In [3]:
# import ogbn ogbn-arxiv
from spektral.datasets.ogb import OGB
from spektral.transforms import AdjToSpTensor, GCNFilter
from ogb.nodeproppred import Evaluator, NodePropPredDataset

dataset_name = "ogbn-arxiv"
ogb_dataset = NodePropPredDataset(dataset_name)
dataset = OGB(ogb_dataset, transforms=[GCNFilter(), AdjToSpTensor()])

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip


Downloaded 0.08 GB: 100%|██████████| 81/81 [00:08<00:00,  9.13it/s]


Extracting dataset/arxiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 11214.72it/s]

Saving...





In [None]:
import tensorflow as tf

edges = dataset[0].a

# Access the properties of the SparseTensor
indices = edges.indices
values = edges.values
dense_shape = edges.dense_shape

# As we are dealing with undirected edges, we consider both columns of the indices tensor
nodes = tf.concat([indices[:, 0], indices[:, 1]], axis=0)

# Count the number of edges per node
unique_nodes, _, counts = tf.unique_with_counts(nodes)

num_edges = {int(node): int(count) for node, count in zip(unique_nodes.numpy(), counts.numpy())}

In [None]:
! pip install torch_geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# convert tf dataset to torch tensor

import torch
from torch_geometric.data import Data

# Get the node features, edge indices, and labels
features = dataset[0].x
edge_indices = dataset[0].a.indices
labels = dataset[0].y

# Convert TensorFlow tensors to PyTorch Tensors
features_torch = torch.from_numpy(features)
edge_indices_torch = torch.from_numpy(edge_indices.numpy().T).long()  # Transpose to fit PyG's edge_index format and convert to long
labels_torch = torch.from_numpy(labels)

# Create a PyTorch Geometric Data object
data = Data(x=features_torch, edge_index=edge_indices_torch, y=labels_torch)



In [None]:
print("type:", type(data))
print("features: ", data.x, " \n feat len: ", len(data.x))
print()
print("edge indices: ", data.edge_index, " \n edge idx len: ", len(data.edge_index))
print()
print("labels: ", data.y)
print()

type: <class 'torch_geometric.data.data.Data'>
features:  tensor([[-0.0579, -0.0525, -0.0726,  ...,  0.1734, -0.1728, -0.1401],
        [-0.1245, -0.0707, -0.3252,  ...,  0.0685, -0.3721, -0.3010],
        [-0.0802, -0.0233, -0.1838,  ...,  0.1099,  0.1176, -0.1399],
        ...,
        [-0.2205, -0.0366, -0.4022,  ...,  0.1134, -0.1614, -0.1452],
        [-0.1382,  0.0409, -0.2518,  ..., -0.0893, -0.0413, -0.3761],
        [-0.0299,  0.2684, -0.1611,  ...,  0.1208,  0.0776, -0.0910]])  
 feat len:  169343

edge indices:  tensor([[     0,      0,      0,  ..., 169342, 169342, 169342],
        [     0,  52893,  93487,  ...,  27824, 158981, 169342]])  
 edge idx len:  2

labels:  tensor([[ 4],
        [ 5],
        [28],
        ...,
        [10],
        [ 4],
        [ 1]])



In [None]:
# split the edges into train/val w/ 50% ratio
# ratio: train ratio

import numpy as np

def split_edges_torch(data, ratio=0.5):
    edges = data.edge_index

    # Randomly shuffle the edge indices
    perm = np.random.permutation(edges.shape[1])

    # Calculate split index
    split_idx = int(edges.shape[1] * ratio)

    # Split into training and validation edges
    train_edges = edges[:, perm[:split_idx]]
    val_edges = edges[:, perm[split_idx:]]

    return train_edges, val_edges



train_edges, val_edges = split_edges_torch(data, ratio=0.5)



In [None]:
print(len(train_edges[0]))
print(len(val_edges[0]))

667793
667793


In [None]:
train_edges

tensor([[ 78962,   5182,   9224,  ...,    401,   3142,   6416],
        [135860,  95915,  40010,  ...,  31399,  49490,  78341]])

## Default GCN training

In [None]:
! pip install torch_geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910459 sha256=7ea5d1829cb7417c52c8fd65d76da63a964a0551a9c255197eb4703389779d9f
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geomet

In [None]:
import torch
from torch_geometric.nn import GCNConv
from torch.nn import BCEWithLogitsLoss
import torch.nn.functional as F

# Define a simple GNN model
class GNN(torch.nn.Module):
    def __init__(self, num_features):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(num_features, 256)
        self.conv2 = GCNConv(256, 256)
        self.conv3 = GCNConv(256, 256)
        self.scoring = torch.nn.Sequential(
            torch.nn.Linear(2 * 256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1),
        )

    def forward(self, data, edge_index):
        x = self.conv1(data.x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv3(x, edge_index)
        return x

    def decode(self, z, indices):
        start, end = indices
        edge_features = torch.cat([z[start], z[end]], dim=1)
        return self.scoring(edge_features).squeeze(-1)


def bpr_loss(pos_logit, neg_logit):
    return -F.logsigmoid(pos_logit - neg_logit).mean()



In [None]:
from sklearn.metrics import recall_score
from itertools import combinations
import itertools

# recall_at_k_per_node(model, z, val_edges, k, unique_nodes, data.edge_index)
def recall_at_k_per_node(model, z, val_edges, k, unique_nodes, edge_index):

  # get val nodes
  val_nodes = torch.unique(val_edges)

  # get all the real edges from the val nodes: pos_v
  mask = torch.isin(edge_index, val_nodes).any(dim=0)
  positive_edges = edge_index[:, mask]

  # get all possible edges: all_v
  all_edges_val = list(itertools.product(val_nodes.tolist(), unique_nodes.tolist()))
  all_edges_val = torch.tensor(all_edges_val, dtype=torch.long).t().contiguous()

  # get scores for all possible edges
  scores = model.decode(z, all_edges_val)

  # Get top k scores and their corresponding edges
  _, top_k_indices = scores.topk(k, largest=True)
  top_k_edges = all_edges_val[:, top_k_indices.cpu()]

  # check how many of them in top50
  top_k_edges = set( tuple( sorted((int(n1), int(n2))) ) for n1, n2 in zip(top_k_edges[0], top_k_edges[1]))
  positive_edges = set(tuple(sorted((int(n1), int(n2)))) for n1, n2 in zip(positive_edges[0], positive_edges[1]))

  print("top_k_edges: \n", top_k_edges)
  print("positive_edges \n", positive_edges)

  # calculate recall@k
  num_hits = len(top_k_edges & graph2_edges)
  recall_at_K = num_hits / len(positive_edges)

  return recall_at_K




In [None]:
def validation(model, data, val_edges, k=50):
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        z_val = model(data, val_edges)  # Get the embeddings for validation edges
        pos_edge_index_val = val_edges  # Positive examples for validation
        neg_edge_index_val = negative_sampling(edge_index=pos_edge_index_val, num_nodes=z_val.size(0))  # Negative examples for validation
        pos_logit_val = model.decode(z_val, pos_edge_index_val)
        neg_logit_val = model.decode(z_val, neg_edge_index_val)
        val_loss = bpr_loss(pos_logit_val, neg_logit_val)

        unique_nodes = torch.unique(data.edge_index)

        recall = recall_at_k_per_node(model, z_val, val_edges, k, unique_nodes, data.edge_index)

    return val_loss.item(), recall


In [None]:
from torch_geometric.utils import negative_sampling
import torch

def train(model, data, train_edges, val_edges, optimizer, patience=10):

  # Define some initial best validation loss as infinity
  best_val_loss = float('inf')
  epochs_no_improve = 0

  # Training loop
  data, train_edges, val_edges = data.to(device), train_edges.to(device), val_edges.to(device)
  for epoch in range(1000):  # 1000 epochs
      print("epoch ", epoch)
      model.train()
      optimizer.zero_grad()
      z_train = model(data, train_edges)  # embeddings for training edges
      pos_edge_index = train_edges  # positive examples
      neg_edge_index = negative_sampling(edge_index=pos_edge_index, num_nodes=z_train.size(0))  # negative examples
      #print("pos_edge_index.shape: ", pos_edge_index.shape)
      pos_logit = model.decode(z_train, pos_edge_index)
      neg_logit = model.decode(z_train, neg_edge_index)
      loss = bpr_loss(pos_logit, neg_logit)
      loss.backward()
      optimizer.step()
      print("train loss: ", loss.item())

      # Validation:
      if (epoch +1) % 5 == 0:
        # validation function calls model.eval(), calculating both val loss & recall@50
        val_loss, recall_50 = validation(model, data, val_edges, 50)
        print(f'Validation Loss: {val_loss}, Recall@50: {recall_50}')

        # Check if early stopping conditions are met
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f'Early stopping triggered after {epoch+1} epochs.')
                break


In [None]:
data

Data(x=[169343, 128], edge_index=[2, 1335586], y=[169343, 1])

In [None]:
data.x

tensor([[-0.0579, -0.0525, -0.0726,  ...,  0.1734, -0.1728, -0.1401],
        [-0.1245, -0.0707, -0.3252,  ...,  0.0685, -0.3721, -0.3010],
        [-0.0802, -0.0233, -0.1838,  ...,  0.1099,  0.1176, -0.1399],
        ...,
        [-0.2205, -0.0366, -0.4022,  ...,  0.1134, -0.1614, -0.1452],
        [-0.1382,  0.0409, -0.2518,  ..., -0.0893, -0.0413, -0.3761],
        [-0.0299,  0.2684, -0.1611,  ...,  0.1208,  0.0776, -0.0910]])

In [None]:
val_nodes = data.edge_index[:,(1,2)]

# get val nodes
val_nodes = torch.unique(val_edges)
print("val_nodes \n", val_nodes)
print()

# get all the real edges from the val nodes: pos_v
mask = torch.isin(data.edge_index, val_nodes).any(dim=0)
positive_edges = data.edge_index[:, mask]
print("positive_edges: \n", positive_edges)
print()


graph2_edges = set( tuple((int(n1), int(n2))) for n1, n2 in zip(positive_edges[0], positive_edges[1]) )

print("positive_edges: \n", graph2_edges)



val_nodes 
 tensor([    0, 52893, 93487], device='cuda:0')

positive_edges: 
 tensor([[     0,      0,      0,    411,    640,    640,   1162,   1162,   1897,
           3396,   3787,   4383,   4586,   4692,   4692,   4851,   5037,   5190,
           5537,   5537,   5611,   5803,   5950,   7026,   7043,   7223,   7558,
           8205,   8481,   9646,  10098,  10313,  10839,  10839,  10839,  12093,
          12110,  12939,  13760,  14291,  14291,  14982,  15134,  15577,  15577,
          15736,  16415,  17685,  17685,  17790,  17846,  18523,  18959,  19117,
          19188,  20468,  20468,  20500,  21173,  21760,  21842,  22110,  22110,
          22110,  22898,  22898,  23786,  24324,  24986,  25103,  25131,  25782,
          25933,  27167,  27242,  27521,  27607,  28475,  28491,  28932,  28943,
          28943,  29442,  29958,  30038,  30038,  30293,  30293,  30832,  30984,
          30984,  32066,  32085,  32085,  32984,  33025,  34297,  34405,  35100,
          35458,  36054,  36054

In [None]:

import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


train_edges, val_edges = split_edges_torch(data, ratio=0.5)
val_edges = data.edge_index[:,(1,2)]

print(train_edges)
print(data)

# Initialize your model, loss, and optimizer
num_features = data.x.shape[1]
model = GNN(num_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.001)  # L2 regularization

train(model, data, train_edges, val_edges, optimizer, patience=10)




tensor([[ 60203, 120151, 160419,  ..., 137939,  41674, 159210],
        [157548, 120151,  62615,  ..., 104703, 155309, 159210]],
       device='cuda:0')
Data(x=[169343, 128], edge_index=[2, 1335586], y=[169343, 1])
epoch  0
train loss:  0.6925304532051086
epoch  1
train loss:  0.6427505612373352
epoch  2
train loss:  0.6052559018135071
epoch  3
train loss:  0.5761004686355591
epoch  4
train loss:  0.5519051551818848
top_k_edges: 
 {(49164, 93487), (8759, 52893), (0, 40539), (93487, 156724), (93487, 122176), (0, 123745), (45937, 52893), (93487, 123500), (52893, 156724), (8759, 93487), (93487, 120867), (14001, 52893), (52893, 122176), (93487, 121004), (93487, 123643), (52893, 123500), (11404, 52893), (52893, 120867), (93487, 158270), (93487, 150829), (45937, 93487), (52893, 121004), (52893, 123643), (0, 122176), (52893, 150829), (14001, 93487), (57051, 93487), (52893, 158270), (52893, 78548), (11404, 93487), (4092, 52893), (0, 39238), (0, 123643), (40539, 52893), (0, 158270), (93487, 945