### Pip Install Operations


In [1]:
! pip install spektral
! pip install ogb
! pip install torch_geometric



## Setup


In [2]:
import torch
from torch_geometric.data import Data
import random
import numpy as np
import math

from spektral.datasets.ogb import OGB
from spektral.transforms import AdjToSpTensor, GCNFilter
from ogb.nodeproppred import Evaluator, NodePropPredDataset

from torch_geometric.nn import GCNConv
from torch.nn import BCEWithLogitsLoss
import torch.nn.functional as F

from torch_geometric.utils import negative_sampling

from itertools import product, combinations

from sklearn.metrics import recall_score
import itertools

In [3]:
# Detect the best available accelerator. The result is kept in `device_name`
# for reference only; the run itself is pinned to the CPU just below.
if torch.backends.mps.is_available():
    device_name = 'mps'
elif torch.cuda.is_available():
    device_name = 'cuda'
else:
    device_name = 'cpu'

# Explicit CPU override (the detected accelerator is intentionally not used).
device = torch.device('cpu')

# Fraction of edges used for training, and fraction of nodes placed in V.
train_edge_percentage = 0.5
V_percentage = 0.95
SEED = 42

import os
# NOTE(review): CUDA debugging aid; it has no visible effect while `device`
# is pinned to the CPU — confirm it is still wanted.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Seed every RNG the notebook draws from: `random` is used for the edge
# shuffle, torch for permutations and negative sampling.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x12a1ee370>

### Training Related Setup

In [4]:
device

device(type='cpu')

## Dataset Related


### Loading the dataset

In [5]:
# Load the ogbn-arxiv node-property-prediction dataset through OGB and wrap it
# in Spektral's OGB dataset, applying GCNFilter and then AdjToSpTensor.
# NOTE(review): presumably this leaves the adjacency as a TensorFlow sparse
# tensor, since the next cell calls `.indices.numpy()` on it — confirm.


dataset_name = "ogbn-arxiv"
ogb_dataset = NodePropPredDataset(dataset_name)
dataset = OGB(ogb_dataset, transforms=[GCNFilter(), AdjToSpTensor()])

### Converting dataset from TF to Torch

In [6]:
# Re-package the first (only one used here) Spektral graph as a PyG `Data` object.
graph = dataset[0]

# Raw pieces of the graph: node features, adjacency indices, node labels.
features, edge_indices, labels = graph.x, graph.a.indices, graph.y

# Features and labels go straight through `torch.from_numpy`.
features_torch, labels_torch = (torch.from_numpy(arr) for arr in (features, labels))

# The adjacency indices are converted via `.numpy()`, transposed to PyG's
# (2, num_edges) edge_index layout, and cast to int64.
edge_indices_torch = torch.from_numpy(edge_indices.numpy().T).long()

# Assemble the PyTorch Geometric graph.
data = Data(x=features_torch, edge_index=edge_indices_torch, y=labels_torch)

### Applying dataset splits



In [7]:


# ------------------------------------------------------------------
# Node split: V (first `V_percentage` of a random permutation) and
# V_new (the remainder).
# ------------------------------------------------------------------

# Total number of nodes in the PyG graph and the cut position inside the
# permutation.
num_nodes = data.num_nodes
split_idx = int(num_nodes * V_percentage)

# Random ordering of the node ids 0 .. num_nodes-1.
perm = torch.randperm(num_nodes)

# Head of the permutation -> V, tail -> V_new; both are moved to `device`.
V, V_new = (part.to(device) for part in (perm[:split_idx], perm[split_idx:]))

# V and V_new are now the indices of the nodes in the 95% and 5% splits, respectively.

# ------> For node classification





In [8]:


# # # # # #
# Splitting edges to training and validation edges
# # # # # #



# Assuming your data is in this format
# data = Data(x=features_torch, edge_index=edge_indices_torch, y=labels_torch)

# Number of edges (columns of edge_index).
num_edges = data.edge_index.size(1)

# Random permutation of the edge column indices. `torch.randperm` is governed
# by `torch.manual_seed`, so the split is reproducible, and indexing with a
# tensor avoids building and shuffling a Python list with one entry per edge.
edge_indices = torch.randperm(num_edges)

# Number of edges assigned to the training split.
num_train_edges = int(train_edge_percentage * num_edges)

# First part of the permutation -> training, remainder -> validation.
train_edge_indices = edge_indices[:num_train_edges]
val_edge_indices = edge_indices[num_train_edges:]

def create_edge_index_subset(edge_index, selected_indices):
    """Return the columns of `edge_index` picked by `selected_indices`.

    Parameters:
    - edge_index: tensor whose second dimension enumerates edges.
    - selected_indices: column selector (anything accepted by tensor indexing).

    Returns:
    - Tensor with every row of `edge_index` restricted to the selected columns.
    """
    every_row = slice(None)
    return edge_index[every_row, selected_indices]

# Materialise the edge_index tensors for the two splits from the shuffled
# column indices; E_all is the full edge_index moved to `device`.
E_train = create_edge_index_subset(data.edge_index, train_edge_indices)
E_val = create_edge_index_subset(data.edge_index, val_edge_indices)
E_all = data.edge_index.to(device)

# Now, `E_train` contains the edges for the training set,
# and `E_val` contains the edges for the validation set.


## Model Related


### Classical Backbone Model (GCN)

In [9]:


# Define a simple GNN model
class GCN(torch.nn.Module):
    """Three-layer GCN encoder with an MLP edge scorer for link prediction.

    Parameters:
    - num_features: width of the input node features (`data.x`).
    - hidden_dim: width of every GCN layer and of the scorer's hidden layer
      (default 256).
    - dropout: dropout probability applied after the first two GCN layers
      (default 0.5).
    """

    def __init__(self, num_features, hidden_dim=256, dropout=0.5):
        super().__init__()
        self.dropout = dropout

        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Scores an edge from the concatenation of its two endpoint embeddings.
        self.scoring = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, data, edge_index):
        """Compute node embeddings.

        Parameters:
        - data: object exposing the node features as `data.x`.
        - edge_index: edges used for message passing (passed separately so a
          subset of the graph's edges can be used).

        Returns:
        - Node embedding tensor produced by the third GCN layer (no activation
          or dropout is applied after it).
        """
        x = data.x
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv3(x, edge_index)

    def decode(self, z, indices):
        """Score node pairs.

        Parameters:
        - z: node embeddings, indexed by node id along the first dimension.
        - indices: pair of index tensors (start nodes, end nodes), e.g. a
          [2, num_pairs] edge_index tensor.

        Returns:
        - One logit per pair, shape [num_pairs].
        """
        start, end = indices
        edge_features = torch.cat([z[start], z[end]], dim=1)
        return self.scoring(edge_features).squeeze(-1)


def bpr_loss(pos_logit, neg_logit):
    """Bayesian Personalized Ranking loss, summed over pairs.

    Computes ``-sum(log(sigmoid(pos_logit - neg_logit)))``. The value is
    non-negative and decreases as positive logits rise above negative ones, so
    it is suitable for minimisation.

    Parameters:
    - pos_logit: logits of positive pairs.
    - neg_logit: logits of negative pairs (broadcast-compatible with
      `pos_logit`).

    Returns:
    - Scalar tensor with the summed loss.
    """
    # logsigmoid is the numerically stable form of log(sigmoid(x)); the leading
    # minus turns the log-likelihood into a loss.
    return -F.logsigmoid(pos_logit - neg_logit).sum()



## Training, Validation, and Test


### Train

In [10]:


def train(model,V, data, train_edges, val_edges, optimizer, patience=10, epochs = 1000, test_active = True):
    """Train `model` for link prediction with BPR loss and early stopping.

    Parameters:
    - model: module exposing `model(data, edge_index)` and `model.decode(z, pairs)`.
    - V: node ids forwarded to `test` / `validation` as the candidate node set.
    - data: graph object holding the node features.
    - train_edges: [2, E] edges used for message passing and as positives.
    - val_edges: [2, E'] held-out edges used for evaluation.
    - optimizer: optimizer over the model parameters.
    - patience: number of consecutive validation checks without improvement
      after which training stops.
    - epochs: maximum number of epochs.
    - test_active: when True, `test` is evaluated after every epoch.

    Returns:
    - The best validation loss seen (``inf`` if validation never ran).
    """
    best_val_loss = float('inf')
    epochs_no_improve = 0

    data, train_edges, val_edges = data.to(device), train_edges.to(device), val_edges.to(device)

    for epoch in range(epochs):
        print("epoch ", epoch)

        model.train()
        optimizer.zero_grad()

        # Embeddings are computed by message passing over the training edges.
        z_train = model(data, train_edges)
        pos_edge_index = train_edges
        neg_edge_index = negative_sampling(edge_index=pos_edge_index, num_nodes=z_train.size(0))

        # The sampler may return a different number of negatives than there are
        # positives; BPR needs the two sets paired one-to-one.
        num_pairs = min(pos_edge_index.size(1), neg_edge_index.size(1))
        pos_logit = model.decode(z_train, pos_edge_index[:, :num_pairs])
        neg_logit = model.decode(z_train, neg_edge_index[:, :num_pairs])

        loss = bpr_loss(pos_logit, neg_logit)

        loss.backward()
        optimizer.step()

        print("train loss: ", loss.item())

        if test_active:
            res = test(model, V, val_edges,z_train, 50)
            print("recall@50: ", res)

        # Validation every 5 epochs.
        if (epoch + 1) % 5 == 0:
            # Recompute embeddings in eval mode (no dropout, no gradients) and
            # pass them with the arguments `validation` / `test` expect:
            # (model, nodes, val_edges, z, k).
            model.eval()
            with torch.no_grad():
                z_val = model(data, train_edges)
            val_loss = validation(model, V, val_edges, z_val, 50)
            recall_50 = test(model, V, val_edges, z_val, 50)
            print(f'Validation Loss: {val_loss}, Recall@50: {recall_50}')

            # Early stopping on the validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve == patience:
                    print(f'Early stopping triggered after {epoch+1} epochs.')
                    break

    return best_val_loss




### Validation


In [11]:


def validation(model, nodes, val_edges, z, k=50):
    """BPR validation loss on held-out edges restricted to `nodes`.

    Positives are the validation edges whose two endpoints both belong to
    `nodes`. For every positive (u, v) one negative (u, w) is drawn with w
    taken from `nodes`, so positives and negatives are paired one-to-one as
    the BPR loss requires. Negatives are drawn from a fixed-seed generator so
    that losses from successive calls are comparable.

    Parameters:
    - model: module exposing `decode(z, pairs)`; its train/eval mode is not
      changed here.
    - nodes: ids of the nodes that may appear in evaluated pairs.
    - val_edges: [2, E] tensor of validation edges.
    - z: node embeddings with one row per node id.
    - k: unused; kept for interface compatibility.

    Returns:
    - The summed BPR loss as a Python float (0.0 when no pair can be evaluated).
    """
    with torch.no_grad():
        dev = val_edges.device
        nodes = torch.as_tensor(nodes, dtype=torch.long, device=dev)
        total_nodes = z.size(0)

        # Boolean membership mask for fast lookup of "is this node in `nodes`".
        v_mask = torch.zeros(total_nodes, dtype=torch.bool, device=dev)
        v_mask[nodes] = True

        # Keep only edges with both endpoints inside `nodes`.
        can_exist_in_V = v_mask[val_edges[0, :]] & v_mask[val_edges[1, :]]
        positive_pairs = val_edges[:, can_exist_in_V]
        if positive_pairs.size(1) == 0:
            return 0.0

        sources = positive_pairs[0]
        # Encode each (source, target) pair as a single integer for set tests.
        pos_keys = sources * total_nodes + positive_pairs[1]

        generator = torch.Generator().manual_seed(0)

        def draw(count):
            picks = torch.randint(nodes.numel(), (count,), generator=generator)
            return nodes[picks.to(dev)]

        # One candidate negative target per positive; redraw the ones that
        # collide with a known positive edge a bounded number of times.
        neg_targets = draw(sources.numel())
        clash = torch.isin(sources * total_nodes + neg_targets, pos_keys)
        for _ in range(10):
            if not bool(clash.any()):
                break
            neg_targets[clash] = draw(int(clash.sum()))
            clash = torch.isin(sources * total_nodes + neg_targets, pos_keys)

        # Drop any pair that still collides so no positive is scored as negative.
        usable = ~clash
        positive_pairs = positive_pairs[:, usable]
        negative_pairs = torch.stack([sources[usable], neg_targets[usable]])
        if positive_pairs.size(1) == 0:
            return 0.0

        pos_logit_val = model.decode(z, positive_pairs)
        neg_logit_val = model.decode(z, negative_pairs)

        val_loss = bpr_loss(pos_logit_val, neg_logit_val)

    return val_loss.item()

### Test

In [12]:
def test(model, nodes, val_edges, z, k=50, num_sampled_edges=1):
    """Recall@k on a memory-bounded sample of validation start nodes.

    A few validation edges are sampled; their source nodes become the start
    nodes. For those start nodes, every validation edge inside `nodes` is a
    positive, and every (start node, node in `nodes`) pair is a candidate.
    Scores, edges, and the positive indicator are all built from the same
    pairs, so they always have matching lengths.

    Parameters:
    - model: module exposing `decode(z, pairs)`; switched to eval mode.
    - nodes: ids of the nodes that may appear in evaluated pairs.
    - val_edges: [2, E] tensor of validation edges.
    - z: node embeddings with one row per node id.
    - k: cut-off for recall@k.
    - num_sampled_edges: number of validation edges sampled to choose the start
      nodes (default 1, to bound memory).

    Returns:
    - Scalar tensor with the recall averaged over start nodes (0.0 when no
      validation edge lies inside `nodes`).
    """
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        dev = val_edges.device
        nodes = torch.as_tensor(nodes, dtype=torch.long, device=dev)

        # Boolean membership mask for fast lookup of "is this node in `nodes`".
        v_mask = torch.zeros(z.size(0), dtype=torch.bool, device=dev)
        v_mask[nodes] = True

        # Keep only edges with both endpoints inside `nodes`.
        can_exist_in_V = v_mask[val_edges[0, :]] & v_mask[val_edges[1, :]]
        valid_edges_in_V = val_edges[:, can_exist_in_V]
        if valid_edges_in_V.size(1) == 0:
            return torch.tensor(0.0, device=dev)

        # FOR MEMORY: evaluate only the start nodes of a few sampled edges.
        picks = torch.randint(valid_edges_in_V.size(1), (num_sampled_edges,)).to(dev)
        start_nodes = torch.unique(valid_edges_in_V[0, picks])

        # Positives: every valid edge leaving one of the sampled start nodes.
        from_sampled = torch.isin(valid_edges_in_V[0, :], start_nodes)
        positive_pairs = valid_edges_in_V[:, from_sampled]

        # Candidates: each start node paired with every node in `nodes`.
        candidate_pairs = torch.stack([
            start_nodes.repeat_interleave(nodes.numel()),
            nodes.repeat(start_nodes.numel()),
        ])

        # Negatives are the candidates with the known positives filtered out;
        # the filtering itself is delegated to remove_common_edges.
        negative_pairs = remove_common_edges(E_all=positive_pairs, B=candidate_pairs)

        positive_scores = model.decode(z, positive_pairs)
        negative_scores = model.decode(z, negative_pairs)

        # Edges, scores, and labels are concatenated in the same order.
        all_edges = torch.cat([positive_pairs, negative_pairs], dim=1)
        all_scores = torch.cat([positive_scores, negative_scores])
        positive_edge_indicator = torch.cat([
            torch.ones(positive_pairs.size(1), dtype=torch.long, device=dev),
            torch.zeros(negative_pairs.size(1), dtype=torch.long, device=dev),
        ])

        recall = calculate_recall_per_node(all_edges, all_scores, positive_edge_indicator, k)

        return recall


## Utils

### Recall calculation


In [13]:
import torch

def calculate_recall_per_node(all_edges, all_scores, positive_edge_indicator, K):
    """
    Mean recall@K over start nodes.

    For each distinct start node, its edges are ranked by score (descending)
    and recall is (# positives among its top K edges) / (# of its positives).
    Start nodes without any positive edge have undefined recall and are left
    out of the average.

    Parameters:
    - all_edges: Tensor of shape [2, num_edges], containing edges (source -> target).
    - all_scores: Tensor of shape [num_edges], containing scores for each edge.
    - positive_edge_indicator: Tensor of shape [num_edges], containing 1 for positive edges and 0 for negative edges.
    - K: The number of top edges to consider for calculating recall.

    Returns:
    - Scalar tensor with the recall averaged over start nodes that have at
      least one positive edge (0.0 if there are none).
    """
    # Rank all edges once; per-node selections below preserve this order.
    order = torch.argsort(all_scores, descending=True)
    ranked_sources = all_edges[0, order]
    ranked_indicator = positive_edge_indicator[order]

    recalls = []
    for start_node in torch.unique(ranked_sources):
        # The ranked tensors are never modified inside the loop, so the mask
        # and the indicator always have the same length.
        node_indicator = ranked_indicator[ranked_sources == start_node]
        num_positive = node_indicator.sum()
        if num_positive == 0:
            continue
        recalls.append(node_indicator[:K].sum().float() / num_positive.float())

    if not recalls:
        return torch.zeros((), device=all_scores.device)
    return torch.stack(recalls).mean()

# Example usage remains the same


### Find intersection, remove it


In [14]:
import torch

def remove_common_edges(E_all, B):
    """Return the columns of `B` that are not edges of `E_all` (B minus E_all).

    Edges are directed: (u, v) and (v, u) are different.

    Parameters:
    - E_all: [2, n] integer tensor of known edges.
    - B: [2, m] integer tensor of candidate edges.

    Returns:
    - A [2, m'] tensor with the columns of `B` that do not occur in `E_all`,
      in their original order. `B` itself is returned when either input is
      empty.

    Assumes node ids are non-negative integers.
    """
    if E_all.numel() == 0 or B.numel() == 0:
        return B

    # Encode each (source, target) column as one integer so membership can be
    # tested in a single pass instead of an n x m pairwise comparison.
    base = max(int(E_all.max()), int(B.max())) + 1
    known_keys = (E_all[0] * base + E_all[1]).to(B.device)
    candidate_keys = B[0] * base + B[1]

    return B[:, ~torch.isin(candidate_keys, known_keys)]

  # Display the intersection and B without intersection
  #print("Intersection:")
  #print(intersection.cpu())
  #print("B without intersection:")
  #print(B_without_intersection.cpu())


### TuneUP: Synthesizing tail nodes

In [15]:
from torch_geometric.utils import degree

def renormalize(edge_index, num_nodes):
    """Compute symmetric degree-normalisation weights for a set of edges.

    Each edge (row, col) gets the weight deg(row)^-1/2 * deg(col)^-1/2, where
    the degree is counted from the first row of `edge_index`. Entries whose
    inverse square root is infinite (degree 0) are set to 0.

    Parameters:
    - edge_index: [2, E] tensor of edges.
    - num_nodes: number of nodes used to size the degree vector.

    Returns:
    - (edges, edge_weight): a detached copy of `edge_index` and one weight per
      edge.
    """
    # Work on a detached copy so the caller's tensor is untouched.
    edges = edge_index.detach().clone()
    src, dst = edges

    # D^{-1/2}, with the infinities produced by zero-degree nodes zeroed out.
    inv_sqrt_degree = degree(src, num_nodes, dtype=edges.dtype).pow(-0.5)
    inv_sqrt_degree[inv_sqrt_degree == float('inf')] = 0

    return edges, inv_sqrt_degree[src] * inv_sqrt_degree[dst]


def random_edge_sampler(data, percent):
    """Randomly partition the edges of `data` into a kept and a dropped graph.

    Parameters:
    - data: graph object exposing `edge_index` and `num_nodes`.
    - percent: fraction of edges to keep.

    Returns:
    - (data_kept, data_dropped): two `Data` objects, each holding its share of
      the edges as `edge_index` and the renormalised weights as `edge_attr`.
    """
    all_edges = data.edge_index
    node_count = data.num_nodes

    def _as_graph(columns):
        # Select the edges, renormalise them, and wrap them in a Data object.
        edges, weights = renormalize(all_edges[:, columns], node_count)
        return Data(edge_index=edges, edge_attr=weights)

    total = all_edges.size(1)
    shuffled = torch.randperm(total)
    cut = int(total * percent)

    # Head of the permutation is kept, tail is dropped.
    data_kept = _as_graph(shuffled[:cut])
    data_dropped = _as_graph(shuffled[cut:])
    return data_kept, data_dropped


"""
## USAGE:

# percent: rate of edges to keep
data_kept, data_dropped = random_edge_sampler(data, percent)

## INPUT:
Data(x=[169343, 128], edge_index=[2, 1335586], y=[169343, 1])

## RETURNS:
(Data(edge_index=[2, 468243]), Data(edge_index=[2, 962188]))

"""

'\n## USAGE:\n\n# percent: rate of edges to keep\ndata_kept, data_dropped = random_edge_sampler(data, percent)\n\n## INPUT:\nData(x=[169343, 128], edge_index=[2, 1335586], y=[169343, 1])\n\n## RETURNS:\n(Data(edge_index=[2, 468243]), Data(edge_index=[2, 962188]))\n\n'

# Execution



### To Device

In [16]:
# Re-pin the device to the CPU right before the model and tensors are moved
# (the configuration cell near the top already assigns the same value).
device = torch.device("cpu")

In [17]:
# Size the input layer from the node-feature matrix instead of hard-coding it.
model = GCN(data.x.size(1))

# Move the model, the graph, and both edge splits to the selected device.
model = model.to(device)
data = data.to(device)
E_train = E_train.to(device)
E_val = E_val.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.001)  # L2 regularization

# Training entry point, currently disabled.
#train(model, V, data, train_edges=E_train, val_edges=E_val, optimizer=optimizer)


In [18]:
V.shape

torch.Size([160875])

In [19]:
v_mask = torch.zeros(num_nodes, dtype=torch.bool)
v_mask.size()

torch.Size([169343])

In [20]:
num_nodes

169343

In [21]:
# Diagnostic cell: report MPS (Apple GPU backend) support.
# NOTE(review): `torch` and `math` are already imported in the setup cell;
# `math` is not used in this cell.
import torch
import math
# This checks that the current macOS version is at least 12.3.
print(torch.backends.mps.is_available())
# This checks that the current PyTorch installation was built with MPS activated.
print(torch.backends.mps.is_built())

True
True


data