In [None]:
# ==============================
# 📘 Minimal HGT Example in DGL
# ==============================

# 📦 Install DGL
!pip install dgl -f https://data.dgl.ai/wheels/repo.html --quiet
!pip install transformers --quiet

# ==============================
# 📚 Imports
# ==============================
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import HeteroGraphConv, HGTConv
from transformers import BertTokenizer, BertModel
import numpy as np

# ==============================
# 🔧 Define Toy Heterogeneous Graph
# ==============================
# 3 node types: paper, author, venue
# 3 edge types: writes, published_in, cites

num_papers = 5
num_authors = 3
num_venues = 2

# Edges are given as (source IDs, destination IDs) per canonical edge type.
graph_data = {
    ('author', 'writes', 'paper'): ([0, 1, 2], [0, 1, 2]),
    ('paper', 'published_in', 'venue'): ([0, 1, 2, 3], [0, 1, 0, 1]),
    ('paper', 'cites', 'paper'): ([0, 1, 2], [1, 2, 3])
}

# Node counts are passed explicitly. Without them DGL infers each count from
# the largest node ID that appears in an edge, which would yield only 4 paper
# nodes (paper 4 has no edges) and break the later assignment of one feature
# row / label per paper.
hg = dgl.heterograph(
    graph_data,
    num_nodes_dict={
        'paper': num_papers,
        'author': num_authors,
        'venue': num_venues,
    },
)

# ==============================
# 📐 Node Feature Initialization
# ==============================
# We'll use BERT for paper abstracts, random for authors/venues

# Tokenizer and encoder are loaded from the same pretrained checkpoint.
_BERT_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
# eval() switches the module to inference mode and returns the module itself.
bert = BertModel.from_pretrained(_BERT_CHECKPOINT).eval()

# One short text per paper node, listed in paper-ID order.
paper_abstracts = [
    "Graph neural networks for citation prediction",
    "Attention mechanisms in transformers",
    "LLMs for graph completion tasks",
    "Knowledge graphs and neural architectures",
    "Survey on graph transformers",
]

def encode_bert(texts):
    """Embed a batch of strings with the module-level BERT model.

    The texts are tokenized together (padded to a common length, truncated
    if needed), passed through ``bert`` with gradient tracking disabled, and
    the hidden state at sequence position 0 is returned for every text.

    Args:
        texts: Strings to embed, as accepted by the module-level ``tokenizer``.

    Returns:
        A 2-D tensor with one row per input text.
    """
    batch = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        hidden_states = bert(**batch).last_hidden_state
    # Position 0 of every sequence is the CLS token.
    return hidden_states[:, 0, :]

# Papers get BERT embeddings; authors and venues get random vectors of the
# same width (768) so every node type shares one input feature size.
paper_feats = encode_bert(paper_abstracts)
author_feats = torch.randn(num_authors, 768)
venue_feats = torch.randn(num_venues, 768)

# Store each feature matrix on the graph under the common key 'h'.
for _ntype, _feats in (
    ('paper', paper_feats),
    ('author', author_feats),
    ('venue', venue_feats),
):
    hg.nodes[_ntype].data['h'] = _feats

# ==============================
# 🧠 HGT Model Definition
# ==============================
class HGTModel(nn.Module):
    """Stack of DGL ``HGTConv`` layers followed by a shared linear head.

    DGL's ``HGTConv`` operates on a homogeneous graph accompanied by integer
    node-type and edge-type ID tensors, so ``forward`` flattens the
    heterograph with ``dgl.to_homogeneous`` and splits the result back into
    a per-node-type dictionary.

    Args:
        metadata: ``(node_types, edge_types)`` pair of sequences. Only their
            lengths are used, to size the type-specific parameters.
        in_dim: Feature size of the input node features (shared by all types).
        hidden_dim: Output size of every HGT layer. Must be divisible by
            ``num_heads`` because DGL takes the per-head size.
        out_dim: Size of the final linear projection.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked ``HGTConv`` layers.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, metadata, in_dim, hidden_dim, out_dim, num_heads, num_layers):
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        node_types, edge_types = metadata
        # Each layer outputs head_size * num_heads == hidden_dim features.
        head_size = hidden_dim // num_heads
        self.layers = nn.ModuleList([
            HGTConv(
                in_dim if i == 0 else hidden_dim,  # in_size
                head_size,                         # head_size
                num_heads,                         # num_heads
                len(node_types),                   # num_ntypes
                len(edge_types),                   # num_etypes
                dropout=0.2,
                use_norm=True,
            )
            for i in range(num_layers)
        ])
        self.linear = nn.Linear(hidden_dim, out_dim)

    def forward(self, g, inputs):
        """Run the HGT stack and return per-type outputs.

        Args:
            g: DGL heterograph.
            inputs: Dict mapping every node type of ``g`` to its 2-D feature
                tensor, rows ordered by node ID.

        Returns:
            Dict mapping each node type to its ``out_dim``-wide output tensor.
        """
        node_types = g.ntypes
        # to_homogeneous places the nodes of each type in one contiguous ID
        # range, in node-type order, and records the type IDs on the result.
        homo = dgl.to_homogeneous(g)
        ntype_ids = homo.ndata[dgl.NTYPE]
        etype_ids = homo.edata[dgl.ETYPE]

        # Concatenate features in the same node-type order as the flat graph.
        h = torch.cat([inputs[ntype] for ntype in node_types], dim=0)
        for layer in self.layers:
            h = layer(homo, h, ntype_ids, etype_ids)
        out = self.linear(h)

        # Undo the concatenation: one chunk per node type.
        counts = [g.num_nodes(ntype) for ntype in node_types]
        return dict(zip(node_types, torch.split(out, counts, dim=0)))

# ==============================
# 🔁 Training Loop (Example Task)
# ==============================
# Toy task: classify paper topics into 3 categories
# One class label per paper node; only the first three papers are trained on.
labels = torch.tensor([0, 1, 2, 0, 1])
train_mask = torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DGL heterographs have no PyG-style metadata() method; the equivalent
# (node types, edge types) pair is built from the graph's own attributes.
metadata = (hg.ntypes, hg.canonical_etypes)
# 768-d inputs, 256-d hidden, 3 output classes.
model = HGTModel(metadata, 768, 256, 3, num_heads=4, num_layers=2).to(device)
hg = hg.to(device)
inputs = {k: hg.nodes[k].data['h'].to(device) for k in hg.ntypes}
labels = labels.to(device)
train_mask = train_mask.to(device)

opt = torch.optim.AdamW(model.parameters(), lr=0.005)

for epoch in range(50):
    model.train()
    out_dict = model(hg, inputs)
    out = out_dict['paper']
    # Loss is computed on the masked (training) paper nodes only.
    loss = F.cross_entropy(out[train_mask], labels[train_mask])
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

# ==============================
# ✅ Inference on Paper Nodes
# ==============================
model.eval()
# Prediction needs no gradients, so skip building the autograd graph.
with torch.no_grad():
    out = model(hg, inputs)['paper']
# Predicted class = index of the largest logit for each paper node.
preds = out.argmax(dim=1)
print("Predictions:", preds.tolist())
