SeHGNN Implementation
This implementation includes:

1. Simplified Neighbor Aggregation → Precomputes mean-aggregated neighbor features for each metapath.

2. Multi-layer Feature Projection → Uses a separate MLP per metapath to project features into a shared space.

3. Transformer-based Semantic Fusion → Uses Transformer-style self-attention to combine metapath information.


In [2]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118


Looking in indexes: https://download.pytorch.org/whl/cu118
INFO: pip is looking at multiple versions of torch to determine which version is compatible with other requirements. This could take a while.
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (27 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.2/23.2 MB 82.6 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (875 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 875.6/875.6 kB 48.4 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu11==11.8.87

In [3]:
!pip install torch-geometric


Collecting torch-geometric
  Using cached torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Using cached torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [4]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html


Looking in links: https://data.pyg.org/whl/torch-2.0.0+cu118.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_scatter-2.1.2%2Bpt20cu118-cp311-cp311-linux_x86_64.whl (10.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.2/10.2 MB 51.6 MB/s eta 0:00:00
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_sparse-0.6.18%2Bpt20cu118-cp311-cp311-linux_x86_64.whl (4.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.9/4.9 MB 65.9 MB/s eta 0:00:00
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_cluster-1.6.3%2Bpt20cu118-cp311-cp311-linux_x86_64.whl (3.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.3/3.3 MB 64.5 MB/s eta 0:00:00
Collecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_s

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
import math

# ---- Step 1: Simplified Neighbor Aggregation ----
class NeighborAggregation(nn.Module):
    """Parameter-free mean aggregation of neighbor features, once per metapath.

    The module owns no learnable weights: for every configured metapath it
    averages the feature rows of the nodes listed in the second row of that
    metapath's edge index into the nodes listed in the first row.

    Args:
        metapaths: Iterable of metapath keys; each must be a key of the
            ``edge_index_dict`` later passed to :meth:`forward`.
        in_channels: Feature width. Stored on the instance but not read by
            :meth:`forward`.
    """

    def __init__(self, metapaths, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.metapaths = metapaths

    def forward(self, x, edge_index_dict):
        """Return the mean neighbor feature of every node for each metapath.

        Args:
            x: Node feature tensor indexed by node along dim 0
                (assumed to be ``(num_nodes, in_channels)`` -- TODO confirm).
            edge_index_dict: Mapping ``metapath -> edge_index`` where
                ``edge_index`` unpacks into two index tensors
                (presumably a ``(2, num_edges)`` integer tensor).

        Returns:
            dict mapping each metapath in ``self.metapaths`` to a tensor shaped
            like ``x``. Rows for nodes that never appear in the first edge row
            are all zeros.

        Raises:
            KeyError: If a configured metapath is missing from
                ``edge_index_dict``.
        """
        num_nodes = x.size(0)
        per_metapath = {}
        for mp in self.metapaths:
            receivers, senders = edge_index_dict[mp]
            # Scatter-add: every edge contributes its sender's features to
            # the receiver's row of a zero-initialised buffer.
            summed = torch.zeros_like(x).index_add_(0, receivers, x[senders])
            # Number of incoming edges per receiver; clamping to 1 keeps
            # edge-less nodes from dividing by zero (they remain zero rows).
            counts = torch.bincount(receivers, minlength=num_nodes).clamp(min=1)
            summed /= counts.unsqueeze(1)
            per_metapath[mp] = summed
        return per_metapath

# ---- Step 2: Multi-layer Feature Projection (MLP) ----
class FeatureProjection(nn.Module):
    """Per-metapath two-layer MLP mapping features to a common width.

    Each metapath gets its own ``Linear -> ReLU -> Linear`` stack, registered
    under the metapath's name in ``self.mlps``.

    Args:
        in_channels: Width of the incoming feature vectors.
        out_channels: Width of the projected feature vectors (also the hidden
            width of each MLP).
        metapaths: Iterable of metapath names used as ``nn.ModuleDict`` keys.
    """

    def __init__(self, in_channels, out_channels, metapaths):
        super().__init__()
        self.metapaths = metapaths
        self.mlps = nn.ModuleDict()
        # Build the MLPs in metapath order so parameter registration (and any
        # seeded initialisation) follows the order the caller supplied.
        for name in metapaths:
            self.mlps[name] = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.ReLU(),
                nn.Linear(out_channels, out_channels),
            )

    def forward(self, aggregated_features):
        """Apply each metapath's MLP to that metapath's features.

        Args:
            aggregated_features: Mapping ``metapath -> tensor`` whose last
                dimension is ``in_channels``.

        Returns:
            dict with the same keys (in the same order) mapping to the
            projected tensors, whose last dimension is ``out_channels``.

        Raises:
            KeyError: If a key of ``aggregated_features`` has no MLP.
        """
        projected = {}
        for name, features in aggregated_features.items():
            projected[name] = self.mlps[name](features)
        return projected

# ---- Step 3: Transformer-based Semantic Fusion ----
class TransformerSemanticFusion(nn.Module):
    """Fuse per-metapath embeddings with scaled dot-product self-attention.

    The metapath embeddings of a node are treated as a short sequence.
    Attention is computed across that sequence, added back to the inputs
    through a learnable residual scale ``beta``, and the result is averaged
    over metapaths.

    Args:
        out_channels: Width of every metapath embedding.
        num_metapaths: Number of metapaths. Accepted for interface
            compatibility; the sequence length is taken from the input at
            call time, so this value is not used.
        num_heads: Number of attention heads. ``out_channels`` must be
            divisible by it. With the default of 1 the computation is plain
            single-head attention scaled by ``sqrt(out_channels)``.

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of
            ``out_channels``.
    """

    def __init__(self, out_channels, num_metapaths, num_heads=1):
        super().__init__()
        if num_heads < 1 or out_channels % num_heads != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be divisible by a "
                f"positive num_heads (got {num_heads})"
            )
        self.num_heads = num_heads
        self.out_channels = out_channels
        # Width of the channel slice each head attends over.
        self.head_dim = out_channels // num_heads

        # Transformer parameters
        self.query = nn.Linear(out_channels, out_channels)
        self.key = nn.Linear(out_channels, out_channels)
        self.value = nn.Linear(out_channels, out_channels)
        self.softmax = nn.Softmax(dim=-1)
        # Learnable scale on the attention branch of the residual sum.
        self.beta = nn.Parameter(torch.ones(1))

    def _split_heads(self, tensor):
        """Reshape ``(..., M, C)`` into ``(..., num_heads, M, head_dim)``."""
        split = tensor.reshape(*tensor.shape[:-1], self.num_heads, self.head_dim)
        return split.transpose(-3, -2)

    def forward(self, projected_features):
        """Merge metapath embeddings into one embedding per node.

        Args:
            projected_features: Mapping ``metapath -> tensor``; all tensors
                must share one shape whose last dimension is ``out_channels``
                (assumed ``(num_nodes, out_channels)`` -- TODO confirm).

        Returns:
            Tensor with the metapath dimension averaged out, i.e. the same
            shape as a single input tensor.

        Raises:
            RuntimeError: From ``torch.stack`` if ``projected_features`` is
                empty or the tensors have mismatched shapes.
        """
        # Stack along a new dim 1 so each node owns a sequence of metapaths.
        feature_matrix = torch.stack(list(projected_features.values()), dim=1)

        Q = self._split_heads(self.query(feature_matrix))
        K = self._split_heads(self.key(feature_matrix))
        V = self._split_heads(self.value(feature_matrix))

        # Per-head scaled dot-product attention over the metapath axis.
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = self.softmax(attention_scores)
        attended = torch.matmul(attention_weights, V)

        # Undo the head split: concatenate heads back into out_channels.
        attended = attended.transpose(-3, -2).reshape(feature_matrix.shape)

        fusion = self.beta * attended + feature_matrix
        return fusion.mean(dim=1)

# ---- Putting it all together ----
class SeHGNN(nn.Module):
    """End-to-end pipeline: neighbor aggregation -> projection -> fusion.

    Args:
        in_channels: Width of the raw node features.
        out_channels: Width of the returned node embeddings.
        metapaths: Sequence of metapath names; must support ``len()`` and is
            shared by all three stages.
    """

    def __init__(self, in_channels, out_channels, metapaths):
        super().__init__()
        num_metapaths = len(metapaths)
        # Stages are created in pipeline order; keep it so parameter
        # registration order stays stable.
        self.aggregation = NeighborAggregation(metapaths, in_channels)
        self.projection = FeatureProjection(in_channels, out_channels, metapaths)
        self.fusion = TransformerSemanticFusion(out_channels, num_metapaths)

    def forward(self, x, edge_index_dict):
        """Compute node embeddings.

        Args:
            x: Node feature tensor passed to the aggregation stage.
            edge_index_dict: Mapping ``metapath -> edge_index`` passed to the
                aggregation stage.

        Returns:
            The output of the fusion stage applied to the projected,
            aggregated features.
        """
        per_metapath = self.projection(self.aggregation(x, edge_index_dict))
        return self.fusion(per_metapath)

if __name__ == "__main__":
    # Smoke test on a small random graph: checks only that the pipeline runs
    # and yields one embedding row per node.
    num_nodes, in_channels, out_channels = 100, 64, 32
    metapaths = ["PA", "PS", "PAP", "PSP"]

    # Random input features
    x = torch.randn((num_nodes, in_channels))

    # 300 random edges per metapath, drawn in metapath order.
    edge_index_dict = {
        mp: torch.randint(0, num_nodes, (2, 300)) for mp in metapaths
    }

    # Initialize and run SeHGNN
    model = SeHGNN(in_channels, out_channels, metapaths)
    node_embeddings = model(x, edge_index_dict)
    print("Node Embeddings Shape:", node_embeddings.shape)




Node Embeddings Shape: torch.Size([100, 32])
