

* Node-Level Attention: Learns the importance between a node and its meta-path-based neighbors.

* Semantic-Level Attention: Learns the importance of different meta-paths.
* HAN Model: Combines both attention mechanisms to generate node embeddings.
* Training Loop: Uses a semi-supervised approach for node classification.




In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class NodeLevelAttention(nn.Module):
    """Node-level attention for the neighbourhood of a single meta-path.

    Wraps one multi-head ``GATConv`` layer and applies an ELU
    non-linearity to its output.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Requested embedding size; every head is given
            ``out_dim // heads`` channels.
        heads: Number of attention heads handed to ``GATConv``.
        dropout: Dropout value handed to ``GATConv``.
    """

    def __init__(self, in_dim, out_dim, heads=8, dropout=0.6):
        super().__init__()
        # NOTE(review): floor division -- when out_dim is not a multiple of
        # heads the layer's output width is presumably smaller than out_dim
        # (assuming GATConv concatenates heads by default); confirm.
        channels_per_head = out_dim // heads
        self.gat = GATConv(
            in_dim,
            channels_per_head,
            heads=heads,
            dropout=dropout,
        )

    def forward(self, x, edge_index):
        """Run the GAT layer on ``x`` over ``edge_index`` and return its ELU."""
        attended = self.gat(x, edge_index)
        return F.elu(attended)

class SemanticLevelAttention(nn.Module):
    """Fuse per-meta-path embeddings with a learned attention vector.

    Every embedding is scored by a dot product with one learnable vector.
    The scores are softmax-normalised along dim 0 (the stacked meta-path
    axis) and used as weights for a sum over that same axis.

    Args:
        in_dim: Size of the last dimension of the embeddings to fuse.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.attn_vector = nn.Parameter(torch.empty(in_dim, 1))
        nn.init.xavier_uniform_(self.attn_vector, gain=1.414)

    def forward(self, semantic_embeddings):
        """Return the attention-weighted sum over dim 0.

        Args:
            semantic_embeddings: Tensor whose dim 0 indexes the stacked
                meta-paths and whose last dim has size ``in_dim``.

        Returns:
            Tensor with dim 0 reduced away, i.e. shape
            ``semantic_embeddings.shape[1:]``.
        """
        # The matmul leaves a trailing axis of size 1. Remove only that axis:
        # a bare ``squeeze()`` would also drop a size-1 meta-path or node
        # axis, making the softmax run over the wrong dimension (single
        # meta-path) or mis-broadcasting the weights (single node).
        scores = torch.matmul(semantic_embeddings, self.attn_vector).squeeze(-1)
        attention_weights = F.softmax(scores, dim=0)
        weighted = attention_weights.unsqueeze(-1) * semantic_embeddings
        return torch.sum(weighted, dim=0)

class HAN(nn.Module):
    """Heterogeneous graph attention network for node classification.

    One ``NodeLevelAttention`` encoder is created per meta-path. Their
    outputs are stacked, fused by ``SemanticLevelAttention`` and passed
    through a linear classifier.

    Args:
        meta_paths: Iterable with one entry per meta-path; only its length
            is used to size the encoder list, the object itself is stored
            on ``self.meta_paths``.
        in_dim: Size of each input node feature vector.
        hidden_dim: Size of the fused node embedding fed to the classifier.
        out_dim: Number of output classes.
        heads: Number of attention heads per node-level encoder.
        dropout: Dropout value forwarded to every node-level encoder.

    Raises:
        ValueError: If ``hidden_dim`` is not a multiple of ``heads``.
    """

    def __init__(self, meta_paths, in_dim, hidden_dim, out_dim, heads=8, dropout=0.6):
        super().__init__()
        # Each encoder gives every head hidden_dim // heads channels, while
        # the semantic attention and classifier are sized with hidden_dim.
        # A remainder would make those widths disagree and surface only as a
        # shape error inside forward(), so reject it here with a clear message.
        if heads and hidden_dim % heads:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})"
            )
        self.meta_paths = meta_paths
        self.node_attention = nn.ModuleList(
            [NodeLevelAttention(in_dim, hidden_dim, heads, dropout) for _ in meta_paths]
        )
        self.semantic_attention = SemanticLevelAttention(hidden_dim)
        self.classifier = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_indices):
        """Return per-node log-probabilities over the output classes.

        Args:
            x: Node feature tensor passed unchanged to every encoder.
            edge_indices: Sequence of edge-index tensors; the i-th entry is
                processed by the i-th node-level encoder.

        Returns:
            ``log_softmax`` (along dim 1) of the classifier output.
        """
        per_path_embeddings = [
            self.node_attention[i](x, edge_index)
            for i, edge_index in enumerate(edge_indices)
        ]
        # Stack along a new leading axis so the semantic attention can
        # normalise and reduce over meta-paths (dim 0).
        stacked = torch.stack(per_path_embeddings, dim=0)
        final_embedding = self.semantic_attention(stacked)
        return F.log_softmax(self.classifier(final_embedding), dim=1)

def train_han(model, data, optimizer, epochs=100):
    """Train ``model`` full-batch on the nodes selected by ``data.train_mask``.

    Each epoch performs one forward pass over the whole graph, computes the
    negative log-likelihood loss on the masked nodes only, back-propagates,
    takes one optimizer step and prints the loss.

    Args:
        model: Module called as ``model(data.x, data.edge_indices)``; its
            output is fed to ``F.nll_loss``.
        data: Object exposing ``x``, ``edge_indices``, ``y`` and
            ``train_mask`` attributes.
        optimizer: Optimizer whose ``zero_grad`` and ``step`` are invoked
            once per epoch.
        epochs: Number of passes to run.
    """
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        predictions = model(data.x, data.edge_indices)
        # Only the training nodes contribute to the loss.
        train_predictions = predictions[data.train_mask]
        train_labels = data.y[data.train_mask]
        loss = F.nll_loss(train_predictions, train_labels)

        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

class ExampleData:
    """Randomly generated toy graph used to exercise the model.

    Attributes are filled with random values on construction:
    ``x`` (10 x 16 float features), ``edge_indices`` (two 2 x 20 integer
    tensors with entries in [0, 10)), ``y`` (10 integer labels in [0, 3))
    and ``train_mask`` (boolean mask selecting the first 7 of 10 entries).
    """

    def __init__(self):
        num_nodes, num_features, num_edges = 10, 16, 20

        # Random draws happen in this order: features, edge lists, labels.
        self.x = torch.rand((num_nodes, num_features))

        edge_lists = []
        for _ in range(2):
            edge_lists.append(torch.randint(0, num_nodes, (2, num_edges)))
        self.edge_indices = edge_lists

        self.y = torch.randint(0, 3, (num_nodes,))

        # First 7 entries are training nodes, the remaining 3 are held out.
        self.train_mask = torch.arange(num_nodes) < 7

# Build the random toy graph and a HAN sized to match it: two meta-paths,
# 16 input features per node, 32 hidden units and 3 output classes.
data = ExampleData()
model = HAN(
    meta_paths=[0, 1],
    in_dim=16,
    hidden_dim=32,
    out_dim=3,
)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=0.001,
)

# Short demonstration run; the loss is printed once per epoch.
train_han(model, data, optimizer, epochs=10)



Epoch 1, Loss: 1.0966864824295044
Epoch 2, Loss: 1.0392301082611084
Epoch 3, Loss: 0.9214492440223694
Epoch 4, Loss: 0.9888289570808411
Epoch 5, Loss: 0.9607367515563965
Epoch 6, Loss: 0.8389010429382324
Epoch 7, Loss: 0.9290308356285095
Epoch 8, Loss: 0.8736441731452942
Epoch 9, Loss: 1.0188474655151367
Epoch 10, Loss: 0.8146460652351379
