In [None]:
# =============================
# 📦 Install Required Libraries
# =============================
!pip install torch torchvision torchaudio --quiet
!pip install torch-geometric --quiet
!pip install transformers --quiet

# =============================
# 📚 Imports
# =============================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from transformers import BertTokenizer, BertModel
from sklearn.preprocessing import normalize
import matplotlib.pyplot as plt

# =============================
# 📥 Load the Cora Dataset
# =============================
# `root` is the local directory used for the dataset files; the graph object
# used throughout the rest of the file is the dataset's first (index 0) item.
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]

# =============================
# 🧠 Simulate Text Input per Node
# =============================
# NOTE(review): each node's text is generated from its ground-truth label in
# data.y, and this is done for every node (train and test alike). The BERT
# features therefore encode the target class, so the test accuracy printed at
# the end of this file reflects label leakage rather than real predictive
# performance. Swap in real per-node text for a meaningful evaluation.
texts = [f"paper about topic {int(label)}" for label in data.y.tolist()]

# =============================
# 🔡 Encode Texts with BERT
# =============================
# BERT is only used as a fixed feature extractor here: it is put in eval mode,
# called under torch.no_grad() in embed_texts, and its parameters are never
# handed to the optimizer.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
bert.eval()

def embed_texts(texts, batch_size=16):
    """Encode strings into BERT [CLS] embeddings.

    Relies on the module-level ``tokenizer`` and ``bert`` objects. Texts are
    tokenized in mini-batches (padded to the longest item of the batch and
    truncated to 32 tokens) and passed through BERT without gradient tracking.

    Args:
        texts: Sequence of strings to embed.
        batch_size: Number of texts per forward pass; must be positive.

    Returns:
        A CPU tensor of shape ``(len(texts), hidden_size)`` holding, for each
        text, the final-layer hidden state at position 0 (the [CLS] token).
        An empty ``texts`` yields a ``(0, hidden_size)`` float tensor.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # Fail early with a clear message; a negative step would otherwise produce
    # an empty range and an unrelated torch.cat error further down.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # torch.cat rejects an empty list, so handle "nothing to embed" explicitly.
    if not texts:
        return torch.empty((0, bert.config.hidden_size), dtype=torch.float)

    # Token tensors must live on the same device as the model weights.
    bert_device = next(bert.parameters()).device

    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True, max_length=32)
        inputs = {name: tensor.to(bert_device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = bert(**inputs)
        # Position 0 of the sequence axis is the [CLS] token. Results are moved
        # to the CPU so callers receive the same kind of tensor regardless of
        # where the model runs.
        embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())
    return torch.cat(embeddings, dim=0)

# Scale every embedding row to unit L2 norm. F.normalize operates directly on
# the torch tensor (keeping its device), which avoids the
# tensor -> NumPy -> tensor round trip through sklearn's normalize.
bert_embeds = embed_texts(texts)
bert_embeds = F.normalize(bert_embeds.to(torch.float), p=2, dim=1)

# =============================
# 🔧 Define GCN Model
# =============================
class GCN(nn.Module):
    """Two-layer graph convolutional network.

    Applies ``GCNConv -> ReLU -> dropout -> GCNConv`` and returns the output
    of the second convolution without a final activation.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
        dropout: Probability of zeroing a hidden activation during training.
            Defaults to 0.5, the value ``F.dropout`` uses when none is given.

    Raises:
        ValueError: If ``dropout`` is outside ``[0, 1]``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return per-node outputs for features ``x`` on graph ``edge_index``."""
        x = F.relu(self.conv1(x, edge_index))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)

# =============================
# 🧠 Define Fusion Model
# =============================
class BERT_GCN_Fusion(nn.Module):
    """Node classifier over concatenated text embeddings and GCN features.

    A ``GCN`` turns the graph's node features into ``gnn_out``-dimensional
    vectors; these are appended to the supplied per-node text embeddings and
    the result is classified by a two-layer MLP (hidden width 128).

    Args:
        gnn_in: Width of the node features fed to the GCN.
        gnn_hidden: Hidden width of the GCN.
        gnn_out: Output width of the GCN.
        bert_dim: Width of the text embeddings passed to ``forward``.
        num_classes: Number of output logits per node.
    """

    def __init__(self, gnn_in, gnn_hidden, gnn_out, bert_dim, num_classes):
        super().__init__()
        fused_dim = bert_dim + gnn_out
        # Submodules are created in the order graph branch -> classifier head.
        self.gcn = GCN(gnn_in, gnn_hidden, gnn_out)
        head_layers = [
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, data, bert_embed):
        """Return class logits for every node.

        ``data`` must expose ``x`` and ``edge_index``; ``bert_embed`` is
        concatenated with the GCN output along dimension 1, text part first.
        """
        graph_features = self.gcn(data.x, data.edge_index)
        fused = torch.cat((bert_embed, graph_features), dim=1)
        logits = self.mlp(fused)
        return logits

# =============================
# 🏋️ Training
# =============================
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Input widths are read off the feature matrices themselves.
model = BERT_GCN_Fusion(
    gnn_in=data.x.shape[1],
    gnn_hidden=64,
    gnn_out=64,
    bert_dim=bert_embeds.shape[1],
    num_classes=dataset.num_classes,
)
model = model.to(device)
data = data.to(device)
bert_embeds = bert_embeds.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn = nn.CrossEntropyLoss()

# Full-batch training: every epoch runs the whole graph through the model and
# computes the loss on the training nodes only.
train_mask = data.train_mask
losses = []
for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    out = model(data, bert_embeds)
    loss = loss_fn(out[train_mask], data.y[train_mask])

    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    # Log every tenth epoch.
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# =============================
# 📈 Plot Loss
# =============================
# Draw the per-epoch training loss on the current axes (created on demand),
# using the object-oriented Matplotlib API.
ax = plt.gca()
ax.plot(losses)
ax.set_title("Training Loss (BERT + GCN Fusion)")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()

# =============================
# ✅ Evaluation
# =============================
# eval() switches off dropout; no_grad() keeps the inference pass from
# recording an autograd graph, which saves memory and time.
model.eval()
with torch.no_grad():
    pred = model(data, bert_embeds).argmax(dim=1)

test_mask = data.test_mask
correct = int((pred[test_mask] == data.y[test_mask]).sum())
num_test = int(test_mask.sum())
# Report NaN rather than dividing by zero if the test split is empty.
acc = correct / num_test if num_test else float("nan")
print(f"Test Accuracy: {acc:.4f}")
