# Heterogeneous Graph Neural Network for Link Prediction
This notebook demonstrates how to implement a Graph Neural Network (GNN) for link prediction on heterogeneous graphs. The model is built and trained with PyTorch, PyTorch Geometric, and DeepSNAP, and the notebook walks through the process step by step: argument parsing, data transformation, model definition, and training.


In [3]:
import argparse
import copy
import time

import pickle
import numpy as np
import networkx as nx 
import sklearn.metrics as metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import TUDataset
import torch_geometric.transforms as T
import torch_geometric.nn as pyg_nn

from deepsnap.hetero_graph import HeteroGraph
from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch
from deepsnap.hetero_gnn import (
    HeteroSAGEConv,
    HeteroConv,
    forward_op
)

ModuleNotFoundError: No module named 'torch_sparse'

In [4]:
!nvidia-smi

Tue Dec 12 23:51:57 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 531.41                 Driver Version: 531.41       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                      TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3050 T...  WDDM | 00000000:F3:00.0 Off |                  N/A |
| N/A   39C    P8                6W /  N/A|      0MiB /  4096MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Cell 3: Argument Parsing Function

In [None]:
def arg_parse(argv=None):
    """Build the link-prediction argument parser and return the parsed options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process arguments (``sys.argv[1:]``) are used.

    Returns:
        argparse.Namespace: holds ``device``, ``data_path``, ``epochs``,
        ``mode``, ``edge_message_ratio``, ``hidden_dim``, ``lr`` and
        ``weight_decay``. Any option not supplied takes the default
        registered below.

    Note:
        Unrecognised arguments are ignored rather than rejected. Inside a
        Jupyter kernel ``sys.argv`` carries the kernel's own
        ``-f <connection-file>`` pair, which a strict ``parse_args()`` call
        would reject with ``SystemExit``.
    """
    parser = argparse.ArgumentParser(description='Link pred arguments.')
    parser.add_argument('--device', type=str,
                        help='CPU / GPU device.')
    parser.add_argument('--data_path', type=str,
                        help='Path to wordnet nx gpickle file.')
    parser.add_argument('--epochs', type=int,
                        help='Number of epochs to train.')
    parser.add_argument('--mode', type=str,
                        help='Link prediction mode. Disjoint or all.')
    parser.add_argument('--edge_message_ratio', type=float,
                        help='Ratio of edges used for message-passing (only in disjoint mode).')
    parser.add_argument('--hidden_dim', type=int,
                        help='Hidden dimension of GNN.')
    parser.add_argument('--lr', type=float,
                        help='The learning rate.')
    parser.add_argument('--weight_decay', type=float,
                        help='Weight decay.')

    parser.set_defaults(
        device='cuda:0',
        data_path='data/oncourse.pkl',
        epochs=200,
        mode='disjoint',
        edge_message_ratio=0.8,
        hidden_dim=32,
        lr=0.01,
        weight_decay=1e-4,
    )

    # parse_known_args tolerates the extra arguments a notebook kernel injects;
    # the list of unrecognised arguments is deliberately discarded.
    args, _unknown = parser.parse_known_args(argv)
    return args


### Cell 4: Data Transformation Function

In [None]:
def WN_transform(G, num_edge_types, input_dim=5):
    """Rebuild ``G`` as a ``nx.MultiDiGraph`` carrying node/edge type and feature attributes.

    Args:
        G: Input graph whose edges are keyed and carry an ``'e_label'``
            attribute (presumably a networkx MultiGraph/MultiDiGraph --
            confirm against the caller). Each label must be an integer or a
            single-element integer tensor in ``[0, num_edge_types)``.
        num_edge_types: Length of the one-hot edge feature vector.
        input_dim: Length of the constant all-ones node feature vector.

    Returns:
        nx.MultiDiGraph: contains every node of ``G`` with
        ``node_type='n1'`` and ``node_feature=torch.ones(input_dim)``, and
        one edge per edge of ``G`` with a one-hot ``edge_feature`` and
        ``edge_type`` set to the label rendered as a string. Other
        attributes of ``G`` and its edge keys are not copied.
    """
    H = nx.MultiDiGraph()

    # Every node gets the same type and a constant feature vector; any
    # node attributes stored on G are ignored.
    for node in G.nodes():
        H.add_node(node, node_type='n1', node_feature=torch.ones(input_dim))

    for u, v, _edge_key, attrs in G.edges(keys=True, data=True):
        # int() accepts both plain integers and single-element tensors.
        label = int(attrs['e_label'])
        e_feat = torch.zeros(num_edge_types)
        e_feat[label] = 1.
        H.add_edge(u, v, edge_feature=e_feat, edge_type=str(label))
    return H