## Import Libs

In [None]:
import os 
import torch 
import einops
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from mkl_random import geometric

from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn import Linear
from triton.ops import attention
import torch_geometric

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Import Data

In [None]:
dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
data = dataset[0]

# normalize data
data = T.ToUndirected()(data)
data = T.AddSelfLoops()(data)
data = T.NormalizeFeatures()(data)

In [None]:
print(data)

In [None]:
data.metadata()

## Design Model Architecture

In [10]:
class HeteroGATLayer(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 metadata,
                 n_heads=4,
                 dropout=0.5,
                 device=None):

        """
        Custom GAT Layer for heterogeneous data.

        Args:
        - in_features: input dimensions for each node type, 
                       assuming that each not type has the same dimension
        - out_features: Output dimension per head attention.
        - metadata: Tuple (node_types, edge_types) for the heterogeneous graph.
        - n_heads: Number of attention heads.
        - dropout: Dropout rate for attention coefficients.
        """
        super().__init__()

        self.node_types, self.edge_types = metadata
        self.n_heads = n_heads

        # Learnable weight matrices for each edge type
        self.edge_transforms = nn.ParameterDict({
            repr(edge_type): nn.Parameter(torch.randn(size=(in_features, n_heads, out_features), device=device))
            for edge_type in self.edge_types
        })
        for edge_type in self.edge_transforms:
            nn.init.xavier_normal_(self.edge_transforms[edge_type])

        # Learnable weight matrices for attention mechanism for each edge type
        self.attention_weights = nn.ParameterDict({
            repr(edge_type): nn.Parameter(torch.randn(size=(n_heads, 2 * out_features, 1), device=device))
            for edge_type in self.edge_types
        })
        for edge_type in self.attention_weights:
            nn.init.xavier_normal_(self.attention_weights[edge_type])

        # Learnable weight matrices for each node types 
        self.node_transforms = nn.ModuleDict({
            node_type: Linear(in_channels=-1, out_channels=out_features).to(device)
            for node_type in self.node_types
        })

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.dropout = dropout

        self.device = device

    def _get_attention_scores(self, h_src, h_dst, edge_index, edge_type):
        """
        Compute attention scores of source nodes for each destination nodes
        """

        # Calculate attention mechanism for message passing -> dst nodes receive message from src nodes
        # Target: attention scores: [n_heads, n_edge] 
        
        src_edge, dst_edge = edge_index

        src_features = h_src[src_edge]                                                      # [n_edge, n_heads, dim]
        
        # prepare input features for attention 
        a_input = torch.cat([h_src[src_edge], h_dst[dst_edge]], dim=-1)                     # [n_edge, n_heads, 2*dim]

        # calculate attention scores base on edge_type weight matrix 
        attention_scores = torch.matmul(a_input.permute(1, 0, 2),                           # [num_edges, n_heads, 1]
                                        self.attention_weights[repr(edge_type)])  
        
        attention_scores = self.leaky_relu(attention_scores).permute(1, 0, 2).squeeze(-1)   # [num_edges, n_heads]

        # compute softmax by index cluster (softmax for each destination node) 
        attention_scores = torch_geometric.utils.softmax(attention_scores, index=dst_edge)  # [num_edges, n_heads]
        
        return attention_scores
        

    def forward(self, x_dict, edge_index_dict):

        out_dict = {}

        # Iterate over edge types
        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            src_edge, dst_edge = edge_index
        
            if edge_index.size()[1] == 0 or src_type not in x_dict or dst_type not in x_dict:
                continue

            # step 1. apply linear transformation to src and dst features base on edge type 

            if isinstance(x_dict[src_type], list):
                print(src_type, edge_type)
            
            h_src = torch.matmul(x_dict[src_type],                                          # [n_heads, num_src, out_dim]
                                 self.edge_transforms[repr(edge_type)].permute(1, 0, 2))
            h_src = h_src.permute(1, 0, 2)                                                  # [num_src, n_heads, out_dim]

            h_dst = torch.matmul(x_dict[dst_type],                                          # [n_heads, num_dst, out_dim]
                                 self.edge_transforms[repr(edge_type)].permute(1, 0, 2))
            h_dst = h_dst.permute(1, 0, 2)                                                  # [num_dst, n_heads, out_dim]

            h_src = F.dropout(h_src, self.dropout, training=self.training)
            h_dst = F.dropout(h_dst, self.dropout, training=self.training)

            # step 2. compute message passing using attention mechanism 
            src_features = h_src[src_edge]
            attention_scores = self._get_attention_scores(h_src, 
                                                          h_dst, 
                                                          edge_index, 
                                                          edge_type)
            
            messages = torch.einsum('ij,ijl -> ijl', attention_scores, src_features)        # [num_edges, n_heads, dim]

            
            # step 3. aggregate message of neighbors by average messages 
            
            aggregated_features = torch.zeros((h_dst.size(0), 
                                               self.n_heads, messages.size(-1)), device=self.device)
            aggregated_features = aggregated_features.index_add_(0, dst_edge, messages)
    
            edge_count = torch.zeros((h_dst.size(0),), device=self.device)
            edge_count = edge_count.index_add_(0, dst_edge, torch.ones_like(dst_edge, dtype=torch.float))
            edge_count = edge_count.clamp(min=1)
    
            aggregated_features = aggregated_features / edge_count.view(-1, 1, 1)
            
            # aggregate multi-head attention
            aggregated_features = torch.mean(aggregated_features, dim=1)  

            # step 4. append to the output for this node type, be aware that one edge type might be involved in multiple types of edge 
            if dst_type not in out_dict:
                out_dict[dst_type] = [aggregated_features]
            else:
                out_dict[dst_type].append(aggregated_features)

        # step 5. combine all the edge type per node type and transform one more time. 
        for node_type in out_dict:
            # aggregate by concatenating/sum/average
            out_dict[node_type] = torch.cat(out_dict[node_type], dim=1)
            # out_dict[node_type] = torch.stack(out_dict[node_type], dim=1)
            # out_dict[node_type] = torch.sum(out_dict[node_type], dim=1)

            out_dict[node_type] = self.node_transforms[node_type](out_dict[node_type])

        return out_dict                  

In [11]:
class HeteroGAT(nn.Module):
    def __init__(self, 
                 in_features, 
                 hidden_dim, 
                 num_classes, 
                 n_heads=4,
                 metadata,
                 device):
        super().__init__()
        
        self.gat_layer1 = HeteroGATLayer(in_features=in_features,
                                         out_features=hidden_dim,
                                         n_heads=n_heads,
                                         dropout=0.5,
                                         metadata=metadata,
                                         device=device)
        
        self.gat_layer2 = HeteroGATLayer(in_features=hidden_dim,
                                         out_features=num_classes,
                                         n_heads=n_heads,
                                         dropout=0.5,
                                         metadata=metadata,
                                         device=device)
        
        node_types, _ = metadata
        self.fc = {
            node_type: Linear(-1, hidden_dim).to(device)
            for node_type in node_types
        }
        
        self.device = device
        
    def forward(self, x_dict, edge_index_dict):
        
        out_dict = self.gat_layer1(x_dict, edge_index_dict)        
        
        for node_type in out_dict:
            out_dict[node_type] = out_dict[node_type] + self.fc[node_type](out_dict[node_type])
            out_dict[node_type] = F.elu(out_dict[node_type])

        out_dict = self.gat_layer2(out_dict, edge_index_dict)
        
        return out_dict 
        

In [14]:
train_loader = NeighborLoader(data,
                              num_neighbors=[50,50],
                              batch_size=32,
                              input_nodes=('paper', data['paper'].train_mask))

In [None]:
from tqdm import tqdm

in_features = 128
hidden_dim = 128
num_classes = dataset.num_classes
metadata = data.metadata()

hetero_gat = HeteroGAT(in_features=in_features,
                       hidden_dim=hidden_dim,
                       num_classes=num_classes,
                       metadata=metadata,
                       device=device).to(device)

optim = torch.optim.Adam(hetero_gat.parameters(), lr=0.001, weight_decay=5e-4)

for epoch in range(1):
    epoch_loss = []
    with tqdm(train_loader, desc=f'Train. Epoch {epoch}', unit='batch') as t:
        for batch in t:
            optim.zero_grad()
            batch = batch.to(device)
            batch_size = batch['paper']['batch_size']
        
            output = hetero_gat(batch.x_dict, batch.edge_index_dict)
            loss = F.cross_entropy(output['paper'][:batch_size],
                                   batch['paper']['y'][:batch_size])
            
            epoch_loss.append(loss.item())
            loss.backward()
            optim.step()
            
            t.set_postfix(loss=sum(epoch_loss)/len(epoch_loss))

Train. Epoch 0:  49%|████▉     | 9702/19675 [22:28<33:09,  5.01batch/s, loss=4.77]

In [None]:
num_src = 10
num_dst = 6
num_edge = 8

dim = 3 
n_heads = 2

src_nodes = torch.randn(num_src, n_heads, dim)
dst_nodes = torch.randn(num_dst, n_heads, dim)

Wa = nn.Parameter(torch.randn(n_heads, 2*dim, 1))



In [None]:
src_edge = torch.randint(0, num_src, (num_edge,))
dst_edge = torch.randint(0, num_dst, (num_edge,))

edge_index = torch.cat([src_edge.unsqueeze(0), dst_edge.unsqueeze(0)], dim=0)
edge_index

In [None]:
# calculate attention 

src_features = src_nodes[src_edge]
dst_features = dst_nodes[dst_edge]


In [None]:
print(src_features.shape)
print(dst_features.shape)

In [None]:
a_input = torch.cat([src_features, dst_features], dim=-1)
a_input.shape

In [None]:
attention_scores = torch.matmul(a_input.permute(1,0,2), Wa).squeeze(-1)  # [n_heads, num_edge]
# calculate attention score base on dst edge cluster 
attention_scores = torch_geometric.utils.softmax(attention_scores, index=dst_edge, dim=-1)
attention_scores.shape

In [None]:
messages = torch.einsum('ij,jil->ijl', attention_scores, src_features).permute(1,0,2)
print(messages.shape)

aggregated_features = torch.zeros((num_dst, n_heads, dim))
print(aggregated_features.shape)

aggregated_features = aggregated_features.index_add_(0, dst_edge, messages)

edge_count = torch.zeros((num_dst,))
edge_count = edge_count.index_add_(0, dst_edge, torch.ones_like(dst_edge, dtype=torch.float))
edge_count = edge_count.clamp(min=1)

aggregated_features/edge_count.view(-1,1,1)