In [4]:
from ogb.linkproppred import PygLinkPropPredDataset

# Load the OGB WikiKG2 link-prediction dataset (downloaded and processed on first use).
dataset = PygLinkPropPredDataset(name="ogbl-wikikg2")

# Official edge splits shipped with the dataset.
split_edge = dataset.get_edge_split()
train_edge = split_edge["train"]
valid_edge = split_edge["valid"]
test_edge = split_edge["test"]

# PyG graph object containing only the training edges.
# NOTE(review): GNNStack.forward reads data.x — confirm this graph carries node features.
graph = dataset[0]



Downloading http://snap.stanford.edu/ogb/data/linkproppred/wikikg-v2.zip


Downloaded 4.13 GB: 100%|██████████| 4232/4232 [07:49<00:00,  9.01it/s]


Extracting dataset/wikikg-v2.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 36792.14it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 313.38it/s]

Saving...



Done!
  self.data, self.slices = torch.load(self.processed_paths[0])
  train = replace_numpy_with_torchtensor(torch.load(osp.join(path, 'train.pt')))
  valid = replace_numpy_with_torchtensor(torch.load(osp.join(path, 'valid.pt')))
  test = replace_numpy_with_torchtensor(torch.load(osp.join(path, 'test.pt')))


In [3]:
import torch_geometric 
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

In [5]:
class GNNStack(torch.nn.Module):
    """Stack of graph-convolution layers followed by a small MLP head.

    Args:
        input_dim: Size of the input node features.
        hidden_dim: Per-head output size of every conv layer.
        output_dim: Size of the final per-node output.
        args: Config object; this class reads ``model_type``, ``num_layers``,
            ``heads`` and ``dropout`` from it.
        emb: If truthy, ``forward`` returns the raw MLP outputs; otherwise it
            returns log-probabilities (``log_softmax`` over dim 1).

    Raises:
        AssertionError: if ``args.num_layers < 1``.
        ValueError: if ``args.model_type`` is not supported.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):
        super(GNNStack, self).__init__()
        assert (args.num_layers >= 1), 'Number of layers is not >=1'
        conv_model = self.build_conv_model(args.model_type)

        # With the conv's default concat=True, each layer emits
        # ``args.heads * hidden_dim`` features, which is what the next layer
        # and ``post_mp`` expect. The head count must therefore be passed to
        # the conv explicitly; otherwise it falls back to its own default
        # and the widths stop lining up whenever args.heads differs from it.
        self.convs = nn.ModuleList()
        self.convs.append(conv_model(input_dim, hidden_dim, heads=args.heads))
        for _ in range(args.num_layers - 1):
            self.convs.append(
                conv_model(args.heads * hidden_dim, hidden_dim, heads=args.heads))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout),
            nn.Linear(hidden_dim, output_dim))

        self.dropout = args.dropout
        self.num_layers = args.num_layers

        self.emb = emb

    def build_conv_model(self, model_type):
        """Return the conv layer class for ``model_type``.

        Raises:
            ValueError: if ``model_type`` is not supported (currently only "GAT").
        """
        if model_type == "GAT":
            return GAT
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    def forward(self, data):
        """Run all conv layers (ReLU + dropout after each), then the MLP head.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Raw MLP outputs if ``self.emb``; otherwise per-node
            log-probabilities of shape (num_nodes, output_dim).
        """
        x, edge_index = data.x, data.edge_index

        # self.convs holds exactly num_layers layers (built in __init__).
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        if self.emb:
            return x

        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        """Negative log-likelihood loss; expects ``pred`` from ``forward`` with emb=False."""
        return F.nll_loss(pred, label)

    

# What's new
## New Parameters
 - concat: Concatenate or average multi-head outputs
 - add_self_loops: Add self-loops to graph
 - bias: Use bias in linear layers  
 - residual: Add residual connections
 - share_weights: Share weights between source/target transformations

## Architectural Improvements
 - Layer normalization for training stability
 - Transformer-style feed-forward network (with its own residual connection) after attention
 - Xavier normal weight initialization
 - Optional residual projection (`res_fc`) whose output size matches the concatenated or averaged heads
 - GELU activation in the feed-forward network

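## Enhanced Attention
 - Optional weight sharing between the source and target linear transformations
 - Self-loop handling: existing self-loops are removed, then exactly one self-loop is added per node
 - Explicit sum aggregation of the attention-weighted messages via `torch_scatter.scatter`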

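## Structural Improvements
 - Xavier-normal parameter initialization scaled by the ReLU gain (`nn.init.calculate_gain('relu')`)
 - Configurable multi-head output: concatenate the heads (`concat=True`) or average them
 - Dropout applied both to the attention coefficients and inside the feed-forward network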

In [6]:
class GAT(MessagePassing):
    """Multi-head graph attention layer with a Transformer-style post-block.

    ``forward`` does, in order:
      1. optionally replaces self-loops so every node has exactly one,
      2. projects node features with ``lin_l`` (source side) and ``lin_r``
         (target side) into ``heads`` heads of ``out_channels`` each,
      3. scores each edge j -> i as LeakyReLU(<a_l, W_l x_j> + <a_r, W_r x_i>),
         softmax-normalizes the scores over the incoming edges of each target
         node, and sums the attention-weighted source features,
      4. concatenates (``concat=True``) or averages the heads,
      5. adds an optional residual projection of the input, applies
         LayerNorm, then adds a feed-forward network's output residually.

    Args:
        in_channels: Input feature size.
        out_channels: Per-head output size.
        heads: Number of attention heads.
        concat: If True the output has ``heads * out_channels`` features;
            otherwise heads are averaged to ``out_channels`` features.
        negative_slope: LeakyReLU slope used on the attention scores.
        dropout: Dropout probability on the attention coefficients and inside
            the feed-forward network.
        add_self_loops: Replace existing self-loops with one per node.
        bias: Use bias in the linear projections.
        residual: Add a linear projection of the input to the attention output.
        share_weights: Use the same projection for source and target nodes.
    """

    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 heads=2,
                 concat=True,
                 negative_slope=0.2, 
                 dropout=0., 
                 add_self_loops=True,
                 bias=True,
                 residual=False,
                 share_weights=False,
                 **kwargs):
        # node_dim=0: node features are laid out as (num_nodes, heads, channels).
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.residual = residual
        self.share_weights = share_weights

        # Width of the layer output after heads are concatenated or averaged.
        out_dim = heads * out_channels if concat else out_channels

        # Linear transformations for source (lin_l) and target (lin_r) nodes.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Per-head attention vectors for the source and target sides.
        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        # Optional residual projection, sized to match the layer output.
        if self.residual:
            self.res_fc = Linear(in_channels, out_dim, bias=bias)

        # Layer normalization
        self.layer_norm = nn.LayerNorm(out_dim)

        # Feed-forward network after attention (4x expansion, GELU).
        self.feed_forward = nn.Sequential(
            nn.Linear(out_dim, 4 * out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * out_dim, out_dim),
        )

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize all learnable parameters.

        Weights use Xavier-normal init scaled by the ReLU gain; feed-forward
        biases are zeroed and LayerNorm is reset to its defaults. Biases of
        ``lin_l``/``lin_r``/``res_fc`` keep their ``nn.Linear`` initialization.
        """
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.lin_l.weight, gain=gain)
        if not self.share_weights:
            nn.init.xavier_normal_(self.lin_r.weight, gain=gain)
        nn.init.xavier_normal_(self.att_l, gain=gain)
        nn.init.xavier_normal_(self.att_r, gain=gain)
        if self.residual:
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

        # Deterministic (ones/zeros), so it does not consume any RNG state.
        self.layer_norm.reset_parameters()

        # Initialize feed-forward layers
        for layer in self.feed_forward:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight, gain=gain)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x, edge_index, size=None):
        """Apply attention message passing followed by the post-block.

        Args:
            x: Node feature tensor with ``in_channels`` features per row.
            edge_index: Edge index ``Tensor`` or ``SparseTensor`` adjacency.
            size: Optional size forwarded to ``propagate``.

        Returns:
            Tensor of shape (num_nodes, heads * out_channels) if ``concat``,
            else (num_nodes, out_channels).
        """
        H, C = self.heads, self.out_channels

        # Make sure every node has exactly one self-loop.
        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x.size(0)
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                # Previously the flag was silently ignored for sparse adjacency.
                edge_index = set_diag(edge_index)

        # Linear transformations and multi-head splitting: (N, H, C).
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)

        # Per-node, per-head attention logits: (N, H).
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)

        # Tuples are (source, target): message() receives x_j from x_l and
        # alpha_j from alpha_l (source side), alpha_i from alpha_r (target side).
        # Passing x_l here is required; otherwise lin_l never shapes the messages.
        out = self.propagate(edge_index,
                             x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r),
                             size=size)

        # Merge heads: (N, H, C) -> (N, H*C) or (N, C).
        if self.concat:
            out = out.view(-1, H * C)
        else:
            out = out.mean(dim=1)

        # Residual connection
        if self.residual:
            out = out + self.res_fc(x)

        out = self.layer_norm(out)

        # Feed-forward network with its own residual connection.
        out = out + self.feed_forward(out)

        return out

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        """Weight each source feature ``x_j`` by its normalized attention score.

        Args:
            x_j: Source-node features lifted to edges, (E, H, C).
            alpha_j: Source-side attention logits per edge, (E, H).
            alpha_i: Target-side attention logits per edge, (E, H).
            index: Target node index of each edge (softmax grouping key).
            ptr: Optional CSR pointer forwarded to ``softmax``.
            size_i: Number of target nodes.

        Returns:
            Attention-weighted messages, (E, H, C).
        """
        alpha = alpha_i + alpha_j
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalize over all edges that point to the same target node.
        alpha = softmax(alpha, index, ptr, size_i)

        # Dropout on the attention coefficients (not on features).
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        alpha = alpha.unsqueeze(-1)

        return x_j * alpha

    def aggregate(self, inputs, index, dim_size=None):
        """Sum the incoming messages of each target node along dim 0 (``node_dim``)."""
        return torch_scatter.scatter(inputs, index, dim=0,
                                     dim_size=dim_size, reduce='sum')