# CS224W Project

In [6]:
import copy
import numpy as np
import math
import pandas as pd
from tqdm import trange
import torch_geometric 
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.data import DataLoader

from ogb.linkproppred import PygLinkPropPredDataset


# Load Data

In [7]:
# Load the ogbl-wikikg2 link-prediction benchmark through the OGB PyG wrapper.
dataset = PygLinkPropPredDataset(name="ogbl-wikikg2")

# Fetch the edge split once, then pull out each partition by key.
split_edge = dataset.get_edge_split()
train_edge = split_edge["train"]
valid_edge = split_edge["valid"]
test_edge = split_edge["test"]

# First (and, presumably, only) PyG graph object; per the original note it
# contains only the training edges.
graph = dataset[0]


  self.data, self.slices = torch.load(self.processed_paths[0])
  train = replace_numpy_with_torchtensor(torch.load(osp.join(path, 'train.pt')))
  valid = replace_numpy_with_torchtensor(torch.load(osp.join(path, 'valid.pt')))
  test = replace_numpy_with_torchtensor(torch.load(osp.join(path, 'test.pt')))


# Create Model

In [8]:
class GNNStack(torch.nn.Module):
    """Stack of message-passing layers followed by a two-layer MLP head.

    Args:
        input_dim: Width of the node features fed to the first layer.
        hidden_dim: Width produced by every message-passing layer.
        output_dim: Width of the final per-node output.
        args: Namespace providing ``model_type``, ``num_layers``, ``heads``,
            ``dropout``, ``num_relations``, ``text_dim`` and
            ``num_communities``.
        emb: If true, ``forward`` returns the raw head output instead of
            log-probabilities.

    Every layer is expected to return ``hidden_dim`` features per node, which
    is what the ``GAT`` layer in this file produces (its final MLP ends in a
    ``Linear(..., out_channels)``), independent of the number of heads.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):
        super(GNNStack, self).__init__()
        # Validate before building anything so a bad config fails early.
        assert (args.num_layers >= 1), 'Number of layers is not >=1'

        conv_model = self.build_conv_model(args.model_type)
        self.convs = nn.ModuleList()
        layer_in = input_dim
        for _ in range(args.num_layers):
            # Pass the configured head count explicitly; otherwise the layer
            # silently falls back to its own default.
            self.convs.append(conv_model(layer_in, hidden_dim,
                                         args.num_relations, args.text_dim,
                                         args.num_communities,
                                         heads=args.heads))
            # Layers emit hidden_dim features, so that is the next input width.
            layer_in = hidden_dim

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(args.dropout),
            nn.Linear(hidden_dim, output_dim))

        self.dropout = args.dropout
        self.num_layers = args.num_layers
        self.input_dim = input_dim
        self.text_dim = args.text_dim

        self.emb = emb

    def build_conv_model(self, model_type):
        """Return the layer class for ``model_type``.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        if model_type == "GAT":
            return GAT
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    def forward(self, data):
        """Run the stack on one graph (or batch of graphs).

        Reads ``data.x``, ``data.edge_index`` and ``data.edge_reltype``.
        ``data.text_emb`` and ``data.community_assign`` are used when present.
        Anything missing is replaced by a placeholder sized to the graph:
        random node features / text embeddings (resampled on every call) and
        a single all-zero community. Placeholders carry no signal; supply real
        tensors for meaningful results.

        Returns:
            ``[num_nodes, output_dim]`` tensor: raw outputs when ``emb`` is
            set, otherwise log-probabilities over ``dim=1``.
        """
        x, edge_index, edge_reltype = data.x, data.edge_index, data.edge_reltype
        device = edge_index.device

        if x is None:
            # No node features on the graph: fall back to random ones.
            x = torch.randn(data.num_nodes, self.input_dim, device=device)
        num_nodes = x.shape[0]

        text_emb = getattr(data, 'text_emb', None)
        if text_emb is None:
            text_emb = torch.randn(num_nodes, self.text_dim, device=device)

        community_assign = getattr(data, 'community_assign', None)
        if community_assign is None:
            community_assign = torch.zeros(num_nodes, dtype=torch.long,
                                           device=device)

        # Ensure edge_reltype values are within bounds of the layers'
        # relation tables; flatten in case ids arrive as an [E, 1] column.
        num_relations = self.convs[0].num_relations
        edge_reltype = (edge_reltype % num_relations).long().reshape(-1)

        for conv in self.convs:
            x = conv(x, edge_index, text_emb, edge_reltype, community_assign)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        if self.emb:
            return x

        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        """Negative log-likelihood of ``label`` under log-probabilities ``pred``."""
        return F.nll_loss(pred, label)

    

# Enhancements to the Graph Neural Network Architecture:

## 1. Community-based Attention using Leiden Algorithm
- Detects densely connected communities in the graph
- Allows nodes to attend differently to nodes in same vs different communities  
- Leiden algorithm provides high-quality, hierarchical community structure

## 2. Hierarchical Attention Mechanism
- Local attention: Node-to-node interactions within neighborhoods
- Global attention: Node-to-community interactions across graph
- Combines both levels for richer graph representations

## 3. Text Embedding Integration  
- Processes text associated with nodes using embeddings
- Projects text features into same space as structural features
- Enables multi-modal learning from both graph and text

## 4. Relationship-specific Processing
- Different weight matrices for different edge types
- Allows model to learn relationship-specific transformations
- Better handles heterogeneous graph structures

In [9]:
class GAT(MessagePassing):
    """Relational attention layer combining local, community and text signals.

    For each node the layer computes:

    * local attention over incoming edges (dot-product scores between the
      target's query and the source's key, plus a learned per-relation,
      per-head bias ``M_vu``); messages are the source's value vector plus a
      relation-specific transform ``W_r`` of the source's features;
    * global attention from every node to the mean feature vector of each
      non-empty community, with a learned per-community encoding ``P_vc``
      added to the keys;
    * a text branch: a shared projection of the node's text embedding plus the
      mean over incoming edges of the relation-specific transform ``W_text``
      of the source's text embedding.

    Heads are averaged, so every branch yields ``out_channels`` features. The
    summed structural branches go through ``ffn``; the result is concatenated
    with the text branch and the input features and reduced by ``final_mlp``.

    Args:
        in_channels: Width of the input node features.
        out_channels: Width of the output node features.
        num_relations: Number of distinct edge types (ids in
            ``[0, num_relations)``).
        text_dim: Width of the per-node text embeddings.
        num_communities: Number of communities (ids in
            ``[0, num_communities)``).
        heads: Number of attention heads.
        dropout: Dropout probability on attention weights and inside the MLPs.
        concat, negative_slope, add_self_loops, bias: Accepted for signature
            compatibility; not used by the computation.
        **kwargs: Forwarded to ``MessagePassing``.

    NOTE: ``ffn`` and ``final_mlp`` input widths are derived from what
    ``forward`` actually feeds them (``out_channels`` and
    ``2 * out_channels + in_channels`` respectively).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_relations,
                 text_dim,
                 num_communities,
                 heads=2,
                 concat=True,
                 negative_slope=0.2,
                 dropout=0.,
                 add_self_loops=True,
                 bias=True,
                 **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.num_relations = num_relations
        self.text_dim = text_dim
        self.num_communities = num_communities

        # Local attention components, one [in, out] matrix per head
        self.W_Q = nn.Parameter(torch.empty(heads, in_channels, out_channels))
        self.W_K = nn.Parameter(torch.empty(heads, in_channels, out_channels))
        self.W_V = nn.Parameter(torch.empty(heads, in_channels, out_channels))

        # Global (community) attention components
        self.V_Q = nn.Parameter(torch.empty(heads, in_channels, out_channels))
        self.V_K = nn.Parameter(torch.empty(heads, in_channels, out_channels))
        self.V_V = nn.Parameter(torch.empty(heads, in_channels, out_channels))

        # Relation-specific components
        self.W_r = nn.ParameterList([
            nn.Parameter(torch.empty(in_channels, out_channels))
            for _ in range(num_relations)
        ])

        # Text embedding processing
        self.text_proj = nn.Linear(text_dim, out_channels)
        self.W_text = nn.ParameterList([
            nn.Parameter(torch.empty(text_dim, out_channels))
            for _ in range(num_relations)
        ])

        # Position encodings for communities
        self.P_vc = nn.Parameter(torch.empty(num_communities, out_channels))

        # Edge-type specific attention bias, one scalar per (relation, head)
        self.M_vu = nn.Parameter(torch.empty(num_relations, heads))

        # Feed-forward network applied to the summed local + global branches
        self.ffn = nn.Sequential(
            nn.Linear(out_channels, 4 * out_channels),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * out_channels, out_channels)
        )

        # Final aggregation MLP over [structure | text | input features]
        self.final_mlp = nn.Sequential(
            nn.Linear(2 * out_channels + in_channels, out_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_channels * 2, out_channels)
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the layer's own parameters.

        The ``nn.Linear`` sub-modules keep their default initialisation.
        """
        gain = nn.init.calculate_gain('relu')

        # Initialise each head's [in, out] slice separately: applying Xavier
        # to the stacked 3-D tensor would treat the trailing dimension as a
        # receptive field and shrink the scale.
        with torch.no_grad():
            for stacked in [self.W_Q, self.W_K, self.W_V,
                            self.V_Q, self.V_K, self.V_V]:
                for head in range(self.heads):
                    nn.init.xavier_normal_(stacked[head], gain=gain)

        # Initialize relation weights
        for w_r in self.W_r:
            nn.init.xavier_normal_(w_r, gain=gain)

        # Initialize text projections
        for w_text in self.W_text:
            nn.init.xavier_normal_(w_text, gain=gain)

        # Initialize position encodings and masks
        nn.init.xavier_normal_(self.P_vc, gain=gain)
        nn.init.xavier_normal_(self.M_vu, gain=gain)

    def forward(self, x, edge_index, text_emb, rel_type, community_assign, size=None):
        """Compute updated node features.

        Args:
            x: ``[N, in_channels]`` node features.
            edge_index: ``[2, E]`` long tensor; row 0 holds source nodes and
                row 1 target nodes.
            text_emb: ``[N, text_dim]`` per-node text embeddings.
            rel_type: ``[E]`` long tensor of relation ids.
            community_assign: ``[N]`` long tensor of community ids.
            size: Unused; kept for signature compatibility.

        Returns:
            ``[N, out_channels]`` tensor.
        """
        # Local attention
        local_out = self._compute_local_attention(x, edge_index, rel_type)

        # Global community attention
        global_out = self._compute_global_attention(x, community_assign)

        # Combine local and global attention
        combined_struct = self.ffn(local_out + global_out)

        # Process text embeddings
        text_out = self._process_text_embeddings(text_emb, edge_index, rel_type)

        # Final aggregation
        final_out = self.final_mlp(torch.cat([
            combined_struct,
            text_out,
            x  # Original features as residual
        ], dim=-1))

        return final_out

    def _relation_transform(self, feats, src, rel_type, weights):
        """Return ``feats[src[e]] @ weights[rel_type[e]]`` for every edge ``e``.

        Edges are processed one relation at a time so that only an
        ``[E, out_channels]`` buffer is materialised.
        """
        out = feats.new_zeros(src.size(0), self.out_channels)
        for rel, weight in enumerate(weights):
            mask = rel_type == rel
            if mask.any():
                out[mask] = feats[src[mask]] @ weight
        return out

    def _compute_local_attention(self, x, edge_index, rel_type):
        """Attention over each node's incoming edges; returns ``[N, out]``.

        Nodes without incoming edges receive a zero vector.
        """
        num_nodes = x.size(0)
        src, dst = edge_index[0], edge_index[1]

        # Per-head Q, K, V projections: [N, H, out]
        q = torch.einsum('ni,hio->nho', x, self.W_Q)
        k = torch.einsum('ni,hio->nho', x, self.W_K)
        v = torch.einsum('ni,hio->nho', x, self.W_V)

        # Scaled dot-product score per edge and head, plus relation bias: [E, H]
        attn_score = (q[dst] * k[src]).sum(dim=-1) / math.sqrt(self.out_channels)
        attn_score = attn_score + self.M_vu[rel_type]

        # Normalise over the edges that share a target node
        attn_weights = softmax(attn_score, dst, num_nodes=num_nodes)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        # Message = source value + relation-specific transform (shared by heads)
        rel_msg = self._relation_transform(x, src, rel_type, self.W_r)
        weighted = attn_weights.unsqueeze(-1) * (v[src] + rel_msg.unsqueeze(1))

        out = weighted.new_zeros(num_nodes, self.heads, self.out_channels)
        out.index_add_(0, dst, weighted)

        # Average heads so the result is out_channels wide
        return out.mean(dim=1)

    def _compute_global_attention(self, x, community_assign):
        """Attention from every node to every non-empty community; ``[N, out]``."""
        # Community embeddings: mean feature vector per community, [C, in]
        community_emb = torch_scatter.scatter_mean(
            x, community_assign, dim=0, dim_size=self.num_communities)

        # Global attention projections
        q_global = torch.einsum('ni,hio->nho', x, self.V_Q)              # [N, H, out]
        k_global = torch.einsum('ci,hio->cho', community_emb, self.V_K)  # [C, H, out]
        v_global = torch.einsum('ci,hio->cho', community_emb, self.V_V)  # [C, H, out]

        # Add position encodings: one per community, shared across heads
        k_global = k_global + self.P_vc.unsqueeze(1)

        # Scores of each node against each community: [N, H, C]
        attn_score = torch.einsum('nhd,chd->nhc', q_global, k_global) / math.sqrt(self.out_channels)

        # Communities without members have no embedding worth attending to
        occupied = torch.bincount(community_assign, minlength=self.num_communities) > 0
        attn_score = attn_score.masked_fill(~occupied.view(1, 1, -1), float('-inf'))

        attn_weights = F.softmax(attn_score, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        out = torch.einsum('nhc,chd->nhd', attn_weights, v_global)
        return out.mean(dim=1)

    def _process_text_embeddings(self, text_emb, edge_index, rel_type):
        """Shared text projection plus relation-aware neighbour text; ``[N, out]``."""
        # Project text embeddings
        text_proj = self.text_proj(text_emb)

        # Relation-specific text processing: transform each source node's text
        # by its edge's relation matrix, then average per target node.
        src, dst = edge_index[0], edge_index[1]
        edge_text = self._relation_transform(text_emb, src, rel_type, self.W_text)
        rel_text = torch_scatter.scatter_mean(
            edge_text, dst, dim=0, dim_size=text_emb.size(0))

        return text_proj + rel_text

In [None]:
import torch.optim as optim

def build_optimizer(args, params):
    """Create the optimizer and optional LR scheduler described by ``args``.

    Args:
        args: Namespace with ``opt`` ('adam', 'sgd', 'rmsprop' or 'adagrad'),
            ``lr``, ``weight_decay`` and ``opt_scheduler`` ('none', 'step' or
            'cos'). 'step' additionally reads ``opt_decay_step`` and
            ``opt_decay_rate``; 'cos' reads ``opt_restart``.
        params: Iterable of parameters; only those with ``requires_grad`` are
            optimised.

    Returns:
        ``(scheduler, optimizer)``; ``scheduler`` is ``None`` for 'none'.

    Raises:
        ValueError: If ``args.opt`` or ``args.opt_scheduler`` is not one of
            the supported names.
    """
    weight_decay = args.weight_decay
    trainable = [p for p in params if p.requires_grad]

    if args.opt == 'adam':
        optimizer = optim.Adam(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(trainable, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(trainable, lr=args.lr, weight_decay=weight_decay)
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(f"Unsupported optimizer: {args.opt!r}")

    if args.opt_scheduler == 'none':
        return None, optimizer
    if args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    else:
        raise ValueError(f"Unsupported scheduler: {args.opt_scheduler!r}")
    return scheduler, optimizer

def train(dataset, args):
    """Train a ``GNNStack`` on ``dataset`` and track test accuracy.

    Each batch must provide ``y`` and ``train_mask`` (read here) plus the
    masks that ``test`` reads. The same unshuffled loader is used for
    training and evaluation. Test accuracy is measured every 10th epoch; on
    the other epochs the previous value is repeated so that ``test_accs`` has
    one entry per epoch.

    Args:
        dataset: Dataset exposing ``num_node_features`` and ``num_classes``.
        args: Namespace with ``batch_size``, ``hidden_dim``, ``epochs`` and
            everything ``GNNStack`` and ``build_optimizer`` read.

    Returns:
        ``(test_accs, losses, best_model, best_acc, test_loader)``.
        ``best_model`` is a deep copy of the model at its best measured test
        accuracy, or ``None`` if accuracy never rose above 0.
    """
    test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # build model
    model = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes,
                     args)
    scheduler, opt = build_optimizer(args, model.parameters())

    # train
    losses = []
    test_accs = []
    best_acc = 0
    best_model = None
    for epoch in trange(args.epochs, desc="Training", unit="Epochs"):
        total_loss = 0
        model.train()
        for batch in loader:
            opt.zero_grad()
            pred = model(batch)
            label = batch.y
            # Only nodes in the training split contribute to the loss.
            pred = pred[batch.train_mask]
            label = label[batch.train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(loader.dataset)
        losses.append(total_loss)

        # Advance the LR schedule once per epoch; without this a configured
        # 'step' / 'cos' scheduler would never change the learning rate.
        if scheduler is not None:
            scheduler.step()

        if epoch % 10 == 0:
            test_acc = test(test_loader, model)
            test_accs.append(test_acc)
            if test_acc > best_acc:
                best_acc = test_acc
                best_model = copy.deepcopy(model)
        else:
            # Epoch 0 always evaluates, so there is a previous value to repeat.
            test_accs.append(test_accs[-1])

    return test_accs, losses, best_model, best_acc, test_loader


def test(loader, test_model, is_validation=False, save_model_preds=False, model_type=None):
    """Return the model's accuracy on the validation or test nodes.

    Args:
        loader: Iterable of batches providing ``y`` and ``val_mask`` /
            ``test_mask``.
        test_model: Model whose output has one row of class scores per node.
        is_validation: Evaluate on ``val_mask`` instead of ``test_mask``.
        save_model_preds: If true, write the masked predictions and labels of
            all batches to ``CORA-Node-<model_type>.csv`` in the working
            directory.
        model_type: Label used in the log line and the CSV file name.

    Returns:
        Fraction of masked nodes whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If the mask selects no nodes at all.
    """
    test_model.eval()

    correct = 0
    total = 0
    saved_preds = []
    saved_labels = []
    # Note that Cora is only one graph!
    for batch in loader:
        with torch.no_grad():
            # max(dim=1) returns values, indices tuple; only need indices
            pred = test_model(batch).max(dim=1)[1]
            label = batch.y

        mask = batch.val_mask if is_validation else batch.test_mask
        # node classification: only evaluate on nodes in the selected split
        pred = pred[mask]
        label = label[mask]

        if save_model_preds:
            saved_preds.append(pred.view(-1).cpu().detach())
            saved_labels.append(label.view(-1).cpu().detach())

        correct += pred.eq(label).sum().item()
        # Count evaluated nodes in the same pass, so the denominator always
        # matches the batches that were actually scored.
        total += torch.sum(mask).item()

    if save_model_preds and saved_preds:
        print ("Saving Model Predictions for Model Type", model_type)

        # Written once, after the loop, so that every batch ends up in the
        # file rather than only the last one.
        frame = pd.DataFrame(data={
            'pred': torch.cat(saved_preds).numpy(),
            'label': torch.cat(saved_labels).numpy(),
        })
        # Save locally as csv
        frame.to_csv(f'CORA-Node-{model_type}.csv', sep=',', index=False)

    return correct / total

# Run one training job per configuration in the list below. Each dict is
# turned into an attribute-style namespace before being handed to ``train``.
for args_dict in [
    {
        'model_type': 'GAT',
        'dataset': 'cora',
        'num_layers': 2,
        'heads': 1,
        'batch_size': 32,
        'hidden_dim': 32,
        'dropout': 0.5,
        'epochs': 500,
        'opt': 'adam',
        'opt_scheduler': 'none',
        'opt_restart': 0,
        'weight_decay': 5e-3,
        'lr': 0.01,
        'num_relations': 3,
        'text_dim': 100,
        'num_communities': 10,
    },
]:
    args = SimpleNamespace(**args_dict)
    train(dataset, args)