## Import Libs

In [1]:
import os 
import torch 
import einops
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T

from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn import Linear
from triton.ops import attention

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Import Data

In [2]:
dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
data = dataset[0]

# normalize data
data = T.ToUndirected()(data)
data = T.AddSelfLoops()(data)
data = T.NormalizeFeatures()(data)

In [3]:
print(data)

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389],
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 11529061] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)


In [4]:
data.metadata()

(['paper', 'author', 'institution', 'field_of_study'],
 [('author', 'affiliated_with', 'institution'),
  ('author', 'writes', 'paper'),
  ('paper', 'cites', 'paper'),
  ('paper', 'has_topic', 'field_of_study'),
  ('institution', 'rev_affiliated_with', 'author'),
  ('paper', 'rev_writes', 'author'),
  ('field_of_study', 'rev_has_topic', 'paper')])

## Design Model Architecture

In [5]:
class HeteroGATLayer(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 metadata,
                 n_heads=4,
                 dropout=0.5,
                 device=None):

        """
        Custom GAT Layer for heterogeneous data.

        Args:
        - in_features: input dimensions for each node type, 
                       assuming that each not type has the same dimension
        - out_features: Output dimension per head attention.
        - metadata: Tuple (node_types, edge_types) for the heterogeneous graph.
        - n_heads: Number of attention heads.
        - dropout: Dropout rate for attention coefficients.
        """
        super().__init__()

        self.node_types, self.edge_types = metadata
        self.n_heads = n_heads

        # Learnable weight matrices for each edge type
        self.edge_transforms = nn.ParameterDict({
            repr(edge_type): nn.Parameter(torch.randn(size=(in_features, n_heads, out_features), device=device))
            for edge_type in self.edge_types
        })
        for edge_type in self.edge_transforms:
            nn.init.xavier_normal_(self.edge_transforms[edge_type])

        # Learnable weight matrices for attention mechanism for each edge type
        self.attention_weights = nn.ParameterDict({
            repr(edge_type): nn.Parameter(torch.randn(size=(n_heads, 2 * out_features, 1), device=device))
            for edge_type in self.edge_types
        })
        for edge_type in self.attention_weights:
            nn.init.xavier_normal_(self.attention_weights[edge_type])

        # Learnable weight matrices for each node types 
        self.node_transforms = nn.ModuleDict({
            node_type: Linear(in_channels=-1, out_channels=out_features).to(device)
            for node_type in self.node_types
        })

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.dropout = dropout

        self.device = device

    def _get_attention_scores(self, h_src, h_dst, edge_index, edge_type):
        """
        Compute attention scores of source nodes for each destination nodes
        """

        # Calculate attention mechanism for message passing 
        # Target: attention scores: [n_heads, num_src, num_dst] 
        # -> dst nodes receive message from src nodes

        # step 1. broadcast src nodes and dst nodes and concat them to create attention features:
        num_dst = h_dst.size(0)
        num_src = h_src.size(0)
                
        h_src_repeat = h_src.repeat(1, num_dst, 1).view(num_src * num_dst, -1,
                                                        h_src.size(-1))                                 # [num_src*num_dst, n_heads, out_dim]

        h_dst_repeat = h_dst.repeat(num_src, 1, 1)                                                      # [num_src*num_dst, n_heads, out_dim]

        a_input = (torch.cat([h_src_repeat, h_dst_repeat], dim=-1)                                      # [num_dst, n_heads, num_src, 2*out_dim]
                   .view(num_dst, -1, num_src, h_src.size(-1) * 2))

        # step 2. calculate attention scores base on edge_type

        attention_scores = (self.leaky_relu(torch.matmul(a_input,                                       # [num_dst, n_heads, num_src]
                                                         self.attention_weights[repr(edge_type)]))
                            .squeeze(-1))

        # step 3. mask the neighbors using adjacency matrix 

        adjacency_matrix = torch.zeros((h_dst.size(0), h_src.size(0)),                                  # [num_dst, num_src] 
                                       device=device)
        
        adjacency_matrix[edge_index[1], edge_index[0]] = 1
        adjacency_matrix = einops.repeat(adjacency_matrix, "m n -> m k n",
                                         k=self.n_heads)                                                # [num_dst, n_heads, num_src]
        
        attention_scores = attention_scores.masked_fill(adjacency_matrix == 0, float('-inf'))
        
        # step 4. compute softmax on the source nodes dimension 
        attention_scores = F.softmax(attention_scores, dim=-1)                                          # [num_dst, n_heads, num_src] 
           
        # print(attention_scores.shape)
        # print(torch.sum(attention_scores, dim=-1)[0])
    
        return attention_scores

    def forward(self, x_dict, edge_index_dict):

        out_dict = {}
        
        # Iterate over edge types
        for edge_type, edge_index in edge_index_dict.items():
            if edge_index.size()[1] == 0:
                continue 
            
            src_type, _, dst_type = edge_type
            if src_type not in x_dict or dst_type not in x_dict:
                continue 
            
            # step 1. apply linear transformation to src and dst features base on edge type 
            
            if isinstance(x_dict[src_type], list):
                print(src_type, edge_type)
            
            h_src = torch.matmul(x_dict[src_type],  # [n_heads, num_src, out_dim]
                                 self.edge_transforms[repr(edge_type)].permute(1, 0, 2))
            # print('h_src', h_src[0][:10])
            h_src = h_src.permute(1, 0, 2)                                                              # [num_src, n_heads, out_dim]

            h_dst = torch.matmul(x_dict[dst_type],                                                      # [n_heads, num_dst, out_dim]
                                 self.edge_transforms[repr(edge_type)].permute(1, 0, 2))
            h_dst = h_dst.permute(1, 0, 2)                                                              # [num_dst, n_heads, out_dim]

            h_src = F.dropout(h_src, self.dropout, training=self.training)
            h_dst = F.dropout(h_dst, self.dropout, training=self.training)
            
            # step 2. compute message passing using attention mechanism 
            
            attention_scores = self._get_attention_scores(h_src, h_dst, edge_index, edge_type)

            # step 3. update dst node embeddings base on the message from src nodes 
            # step 3.a compute dst node embeddings base on attention scores
            updated_h_dst = torch.matmul(attention_scores.permute(1, 0, 2),
                                             h_src.permute(1, 0, 2))                                    # [n_heads, num_dst, out_dim]             
                        
            
            # step 3.b aggregate the features from multi head attentions. 
            updated_h_dst = torch.mean(updated_h_dst, dim=0).squeeze(0)                                  # [num_dst, out_dim]  
            # print(edge_type)
            # print('updated_h_dst', updated_h_dst.shape)
            # print(updated_h_dst)
            
            # step 3.c check nan in h_dst, nan means isolated nodes and replace the nan row by ....
            nan_mask = torch.isnan(updated_h_dst).all(dim=1)                                            # [num_dist,]
            # print(h_dst.shape)
            # h_dst = torch.sum(h_dst, dim=1)                                                           # [num_dist, out_dim]
            # print('h_dst', h_dst.shape)
            updated_h_dst[nan_mask] = torch.zeros_like(updated_h_dst)[nan_mask]
            # print(updated_h_dst)

            # step 4. append to the output for this node type. 
            # be aware that one edge type might be involved in multiple types of edge 
            if dst_type not in out_dict:
                out_dict[dst_type] = [updated_h_dst]
            else:
                out_dict[dst_type].append(updated_h_dst)

        # step 5. combine all the edge type per node type and transform one more time. 
        for node_type in out_dict:            
            # aggregate by concatenating/summing  
            # out_dict[node_type] = torch.cat(out_dict[node_type], dim=1)
            out_dict[node_type] = torch.stack(out_dict[node_type], dim=1)
            # print(out_dict[node_type].shape)
            out_dict[node_type] = torch.sum(out_dict[node_type], dim=1)
            # print(out_dict[node_type].shape)
            out_dict[node_type] = self.node_transforms[node_type](out_dict[node_type])

        return out_dict                  

In [6]:
class HeteroGAT(nn.Module):
    def __init__(self, in_features, hidden_dim, num_classes, metadata, device):
        super().__init__()
        
        self.gat_layer1 = HeteroGATLayer(in_features=in_features,
                                         out_features=num_classes,
                                         n_heads=1,
                                         dropout=0.5,
                                         metadata=metadata,
                                         device=device)
        
        self.gat_layer2 = HeteroGATLayer(in_features=hidden_dim,
                                         out_features=num_classes,
                                         n_heads=1,
                                         dropout=0.5,
                                         metadata=metadata,
                                         device=device)
        
        node_types, _ = metadata
        self.fc = {
            node_type: Linear(-1, hidden_dim).to(device)
            for node_type in node_types
        }
        
        self.device = device
        
    def forward(self, x_dict, edge_index_dict):
        
        out_dict = self.gat_layer1(x_dict, edge_index_dict)        
        
        for node_type in out_dict:
            # if not isinstance(out_dict[node_type], list):
            out_dict[node_type] = out_dict[node_type] + self.fc[node_type](out_dict[node_type])
            out_dict[node_type] = F.elu(out_dict[node_type])
        
        out_dict = self.gat_layer2(out_dict, edge_index_dict)
        
        return out_dict 
        

In [7]:
train_loader = NeighborLoader(data,
                              num_neighbors=[30,30],
                              batch_size=128,
                              input_nodes=('paper', data['paper'].train_mask))

In [8]:
from tqdm import tqdm

in_features = 128
hidden_dim = 64
num_classes = dataset.num_classes
metadata = data.metadata()

hetero_gat = HeteroGAT(in_features=in_features,
                       hidden_dim=hidden_dim,
                       num_classes=num_classes,
                       metadata=metadata,
                       device=device).to(device)

optim = torch.optim.Adam(hetero_gat.parameters(), lr=0.005, weight_decay=5e-4)

for epoch in range(10):
    epoch_loss = []
    with tqdm(train_loader, desc=f'Train. Epoch {epoch}', unit='batch') as t:
        for batch in t:
            optim.zero_grad()
            batch = batch.to(device)
            batch_size = batch['paper']['batch_size']
        
            output = hetero_gat(batch.x_dict, batch.edge_index_dict)
            loss = F.cross_entropy(output['paper'][:batch_size],
                                   batch['paper']['y'][:batch_size])
            
            epoch_loss.append(loss.item())
            loss.backward()
            optim.step()
            
            t.set_postfix(loss=sum(epoch_loss)/len(epoch_loss))


Train. Epoch 0: 100%|██████████| 4919/4919 [07:46<00:00, 10.54batch/s, loss=5.02]
Train. Epoch 1:   6%|▌         | 280/4919 [00:29<08:02,  9.61batch/s, loss=5]   


KeyboardInterrupt: 