In [None]:
# =============================
# 📦 Install Libraries
# =============================
!pip install torch torchvision torchaudio --quiet
!pip install torch-geometric --quiet
!pip install transformers --quiet

# =============================
# 📚 Imports
# =============================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
from transformers import BertTokenizer, BertModel
from sklearn.preprocessing import normalize
import matplotlib.pyplot as plt

# =============================
# 📥 Load Cora Dataset
# =============================
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]

# =============================
# 🧠 Generate BERT Texts (Placeholder Texts)
# =============================
# NOTE(review): the original template was f"Paper about topic {label}", built
# from data.y. That writes every node's ground-truth class into its own input
# text, test nodes included, so any test accuracy measured downstream reflects
# label leakage rather than learning. The placeholder below depends only on the
# node index. Replace it with real per-paper text (e.g. title/abstract) when
# available. Never derive node features from `data.y`. Expect accuracy to drop
# relative to the leaky version; that drop is the honest number.
node_ids = [f"Paper {i}" for i in range(data.num_nodes)]

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
bert.eval()  # inference mode: disables BERT's dropout layers

def embed_texts(texts, batch_size=16, max_length=32):
    """Embed each string in ``texts`` with the module-level BERT model.

    Texts are tokenized in chunks of ``batch_size``. Each chunk is padded to
    its longest member and truncated to ``max_length`` tokens. The final-layer
    hidden state at position 0 (the [CLS] token) is used as the embedding.

    Args:
        texts: List of strings to embed.
        batch_size: Number of texts per BERT forward pass.
        max_length: Maximum number of tokens kept per text.

    Returns:
        A CPU tensor of shape ``(len(texts), hidden_size)``. Empty input
        yields a ``(0, hidden_size)`` tensor instead of raising.
    """
    hidden_size = bert.config.hidden_size
    if not texts:
        # torch.cat([]) raises, so handle the empty case explicitly.
        return torch.empty(0, hidden_size)

    # Send inputs to wherever the model lives so this keeps working if
    # `bert` has been moved to a GPU.
    model_device = next(bert.parameters()).device

    chunks = []
    with torch.no_grad():  # pure inference: no autograd graph needed
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(
                batch,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}
            outputs = bert(**inputs)
            # Bring results back to CPU so callers get a consistent device.
            chunks.append(outputs.last_hidden_state[:, 0, :].cpu())
    return torch.cat(chunks, dim=0)

# Embed each node's text, then L2-normalise every row so all feature vectors
# have unit length. Rows whose norm is zero stay zero. Doing this natively in
# torch avoids a numpy round-trip and works regardless of the tensor's device.
x = F.normalize(embed_texts(node_ids), p=2, dim=1).float()
data.x = x  # Replace the dataset's original node features with the embeddings

# =============================
# 🧠 Define GAT Model
# =============================
class GAT(torch.nn.Module):
    """Two-layer Graph Attention Network producing per-node logits.

    Layer 1 runs ``heads`` attention heads, concatenates their outputs
    (``hidden_channels * heads`` features) and applies ELU. Layer 2 uses a
    single head without concatenation and emits ``out_channels`` raw logits.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of each first-layer attention head.
        out_channels: Number of logits per node (e.g. number of classes).
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability. It is used on the features entering each
            layer and is also passed to each ``GATConv``. Defaults to 0.6, the
            value previously hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # Input width is hidden_channels * heads because layer 1 concatenates heads.
        self.gat2 = GATConv(
            hidden_channels * heads, out_channels, heads=1, concat=False, dropout=dropout
        )

    def forward(self, data):
        """Return raw logits for every node in ``data``.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index`` (connectivity), as read below.
        """
        x, edge_index = data.x, data.edge_index
        # Dropout is active only in training mode (self.training).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.gat2(x, edge_index)

# =============================
# 🏋️ Train Model
# =============================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Size the output layer from the dataset rather than hard-coding Cora's class
# count, so the script adapts to other Planetoid datasets.
model = GAT(
    in_channels=x.shape[1],
    hidden_channels=32,
    out_channels=dataset.num_classes,
).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
loss_fn = nn.CrossEntropyLoss()

NUM_EPOCHS = 100
LOG_EVERY = 10  # print the loss every LOG_EVERY epochs

# Full-batch training: every step runs the whole graph but computes the loss
# only on the training-mask nodes.
losses = []
for epoch in range(NUM_EPOCHS):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    loss_value = loss.item()  # single device->host sync per step
    losses.append(loss_value)
    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, Loss: {loss_value:.4f}")

# =============================
# 📈 Plot Loss
# =============================
# Draw the per-epoch training loss curve on the current Axes, then display it.
loss_ax = plt.gca()
loss_ax.plot(losses)
loss_ax.set_title("Training Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("CrossEntropy Loss")
loss_ax.grid(True)
plt.show()

# =============================
# ✅ Evaluate
# =============================
# Evaluate on the held-out test nodes. eval() disables dropout, and no_grad()
# skips building an autograd graph that would never be used.
model.eval()
with torch.no_grad():
    logits = model(data)
pred = logits.argmax(dim=1)

test_mask = data.test_mask
correct = int((pred[test_mask] == data.y[test_mask]).sum())
acc = correct / int(test_mask.sum())
print(f"Test Accuracy: {acc:.4f}")
