In [None]:
# ======================================
# 🧠 LLM + GNN (with Edge Type Awareness)
# ======================================

# 📦 Install Dependencies
!pip install torch torchvision torchaudio --quiet
!pip install torch-geometric --quiet
!pip install transformers --quiet

# =======================
# 📚 Import Libraries
# =======================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import RGCNConv
from transformers import BertTokenizer, BertModel
from sklearn.preprocessing import normalize
import matplotlib.pyplot as plt
import networkx as nx

# =======================
# 📥 Load Cora Dataset
# =======================
# Load the Cora citation graph through PyG's Planetoid loader (stored under
# data/Cora). The dataset holds a single graph, taken at index 0.
dataset = Planetoid(name='Cora', root='data/Cora')
data = dataset[0]

# `// 2` reports each undirected edge once, assuming edge_index stores both
# directions of every edge.
print(
    f"Nodes: {data.num_nodes}, Edges: {data.num_edges // 2}, "
    f"Classes: {dataset.num_classes}"
)

# =======================
# 🏷️ Simulate Edge Types
# =======================
# Simulated relation labels (e.g., 0 = 'cites', 1 = 'extends', 2 = 'contradicts').
# The types are random placeholders, so two safeguards keep the experiment sane:
#   * a fixed seed makes every run reproducible;
#   * both stored directions of an undirected edge, (u, v) and (v, u), share a
#     single type instead of each direction being drawn independently.
import random  # kept for backward compatibility; no longer used here
num_edge_types = 3
EDGE_TYPE_SEED = 0

_src, _dst = data.edge_index
# Canonical id per undirected pair: (min, max) packed into one integer.
_pair_id = torch.minimum(_src, _dst) * data.num_nodes + torch.maximum(_src, _dst)
_unique_pairs, _edge_to_pair = torch.unique(_pair_id, return_inverse=True)
_gen = torch.Generator().manual_seed(EDGE_TYPE_SEED)
_pair_type = torch.randint(0, num_edge_types, (_unique_pairs.numel(),), generator=_gen)
# Every directed edge inherits the type of its undirected pair.
edge_type = _pair_type[_edge_to_pair]

# =======================
# 🤖 Encode Node Texts with BERT
# =======================
# No paper text is available in this script, so a placeholder text is
# synthesised per node. Writing *every* node's ground-truth label into its text
# would leak test labels straight into the input features and make the test
# accuracy meaningless. Only labels the model may legitimately see (training
# nodes) are exposed; all other nodes get a label-free placeholder.
# NOTE(review): training nodes still see their own label, so the training loss
# is optimistic; the test nodes' features carry no label information.
texts = [
    f"paper about topic {int(label)}" if is_train else "paper about unknown topic"
    for label, is_train in zip(data.y.tolist(), data.train_mask.tolist())
]

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
bert.eval()  # BERT is used for inference only (eval mode disables dropout)

def embed_texts(texts, batch_size=16, max_length=32):
    """Encode ``texts`` with the global BERT model and return their [CLS] vectors.

    Args:
        texts: Sequence of strings to encode.
        batch_size: Number of texts tokenised and encoded per forward pass.
        max_length: Token limit per text; longer texts are truncated.

    Returns:
        A tensor with one row per input text, holding
        ``last_hidden_state[:, 0, :]`` (the first-token vector), located on the
        same device as ``bert``. Empty input yields an empty
        ``(0, hidden_size)`` tensor instead of raising.
    """
    model_device = next(bert.parameters()).device
    if not texts:
        # torch.cat([]) would raise; return a correctly-shaped empty tensor.
        return torch.empty(0, bert.config.hidden_size, device=model_device)

    embeddings = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            inputs = tokenizer(batch, return_tensors='pt', padding=True,
                               truncation=True, max_length=max_length)
            # Keep token tensors on the model's device so a GPU-hosted BERT works.
            inputs = inputs.to(model_device)
            outputs = bert(**inputs)
            embeddings.append(outputs.last_hidden_state[:, 0, :])
    return torch.cat(embeddings, dim=0)

# The synthesised texts repeat heavily, so BERT runs once per *distinct* text
# and the vectors are then gathered back to every node in original order.
unique_texts = list(dict.fromkeys(texts))
_unique_embeds = embed_texts(unique_texts)
_rows = torch.tensor([{t: i for i, t in enumerate(unique_texts)}[t] for t in texts]
                     if False else [0] * 0, dtype=torch.long)  # placeholder, replaced below
_text_row = {text: row for row, text in enumerate(unique_texts)}
_rows = torch.tensor([_text_row[text] for text in texts], dtype=torch.long,
                     device=_unique_embeds.device)
# Row-wise L2 normalisation, the same F.normalize(p=2, dim=1) used for the
# LLM query vector, done directly in torch (no numpy round-trip).
bert_embeds = F.normalize(_unique_embeds[_rows].float(), p=2, dim=1)

# =======================
# 🔧 Define R-GCN Model
# =======================
class RGCN(torch.nn.Module):
    """Two-layer relational GCN that produces one output row per node.

    ``conv1`` maps ``in_channels`` -> ``hidden_channels`` and is followed by ReLU
    and dropout (active only in training mode); ``conv2`` maps to
    ``out_channels``. Both layers are ``RGCNConv`` over ``num_relations`` edge
    types. The outputs are raw scores: no softmax is applied.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_relations):
        super().__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations)

    def forward(self, x, edge_index, edge_type):
        """Run both relational convolutions and return the per-node scores."""
        hidden = self.conv1(x, edge_index, edge_type).relu()
        # p=0.5 spelled out explicitly; dropout is a no-op in eval mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index, edge_type)

# =======================
# 🏋️ Train R-GCN
# =======================
# Full-batch training: each epoch scores every node, but the loss only covers
# the nodes selected by train_mask.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RGCN(
    in_channels=bert_embeds.size(1),
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_relations=num_edge_types,
).to(device)

# Everything the forward pass touches must live on the model's device.
data, bert_embeds, edge_type = data.to(device), bert_embeds.to(device), edge_type.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn = nn.CrossEntropyLoss()

NUM_EPOCHS, LOG_EVERY = 100, 10
train_mask = data.train_mask
for epoch in range(NUM_EPOCHS):
    model.train()
    optimizer.zero_grad()
    out = model(bert_embeds, data.edge_index, edge_type)
    loss = loss_fn(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# =======================
# 🎯 Evaluate
# =======================
# Accuracy on the held-out test nodes. The forward pass runs under
# torch.no_grad() so no autograd graph is built for the full graph.
model.eval()  # RGCN.forward only applies dropout when self.training is True
with torch.no_grad():
    pred = model(bert_embeds, data.edge_index, edge_type).argmax(dim=1)
correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
num_test = int(data.test_mask.sum())
# Guard against an empty test split instead of raising ZeroDivisionError.
acc = correct / num_test if num_test else float('nan')
print(f"Test Accuracy: {acc:.4f}")

# =======================
# 🧠 Simulate LLM-Inferred Node
# =======================
# Text for a new node (hard-coded here to stand in for LLM output). It is
# encoded through the same embed_texts() path as the existing nodes, so pooling
# (first-token vector) and truncation settings match theirs, then it is
# L2-normalised like the node features.
llm_text = "Paper on transformer architectures applied to citation graphs"
llm_embed = embed_texts([llm_text])  # one row: the single input text
llm_embed = F.normalize(llm_embed, p=2, dim=1).to(device)

# =======================
# 🔎 Link Prediction via Similarity
# =======================
# The model's outputs have dataset.num_classes columns while llm_embed is a raw
# BERT vector, so multiplying the two directly fails with a shape mismatch.
# Instead, the new node is appended as an *isolated* vertex and pushed through
# the trained model, which gives it an output vector in the same space as every
# existing node. Because it has no edges, the existing nodes' outputs are
# unchanged. Candidates are the nodes with the highest cosine similarity.
NUM_CANDIDATES = 5
model.eval()
with torch.no_grad():
    candidate_x = torch.cat([bert_embeds, llm_embed], dim=0)
    all_out = F.normalize(model(candidate_x, data.edge_index, edge_type), p=2, dim=1)
    out_embed, query_embed = all_out[:-1], all_out[-1:]
    sim = (out_embed @ query_embed.T).squeeze(1)  # one score per existing node
    topk = sim.topk(min(NUM_CANDIDATES, sim.numel())).indices

# =======================
# 🔗 Add Node and Edges
# =======================
# Append the new node's features, then connect it to every candidate in both
# directions so messages can flow either way. All new edges get relation 0
# ('cites').
extended_x = torch.cat([bert_embeds, llm_embed], dim=0)
new_node_idx = extended_x.size(0) - 1
# full_like keeps topk's dtype and device; a CPU torch.full() cannot be
# stacked with a CUDA topk tensor.
new_node_col = torch.full_like(topk, new_node_idx)
new_edges = torch.stack([new_node_col, topk], dim=0)
rev_edges = torch.stack([topk, new_node_col], dim=0)

extended_edge_index = torch.cat([data.edge_index, new_edges, rev_edges], dim=1)
new_edge_types = torch.zeros(2 * topk.numel(), dtype=edge_type.dtype,
                             device=edge_type.device)  # assume 'cites'
extended_edge_type = torch.cat([edge_type, new_edge_types], dim=0)

# =======================
# 🕸️ Visualize Subgraph
# =======================
# Draw the new node (red) together with its candidate neighbours and every edge
# of the extended graph whose endpoints both lie in that node set.
sub_nodes = topk.tolist() + [new_node_idx]
src, tgt = extended_edge_index
# Vectorised membership test. Calling int() on each element would force one
# device->host sync per edge when the tensors live on the GPU.
node_set = torch.tensor(sub_nodes, dtype=src.dtype, device=src.device)
mask = torch.isin(src, node_set) & torch.isin(tgt, node_set)
edge_sub = extended_edge_index[:, mask]

G = nx.Graph()  # undirected: (u, v) and (v, u) collapse into one drawn edge
for i in sub_nodes:
    G.add_node(i, label="LLM" if i == new_node_idx else f"Node {i}")
G.add_edges_from(edge_sub.t().tolist())

plt.figure(figsize=(8, 6))
pos = nx.spring_layout(G, seed=42)
nx.draw(G, pos, with_labels=True, labels=nx.get_node_attributes(G, 'label'),
        node_color=["red" if n == new_node_idx else "skyblue" for n in G.nodes()], node_size=700)
plt.title("Subgraph Around LLM-Generated Node")
plt.show()