# Fine‑Tuning for Graph Neural Networks (DGL)
Pretrain a GCN on Cora, then fine‑tune it on Citeseer by freezing the pretrained message‑passing (MP) layers and training only a new output (readout) layer.

In [None]:
!pip -q install -U dgl torch


In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset
import numpy as np

# Use the GPU when PyTorch reports one as available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

class GCN(nn.Module):
    """Two-layer graph convolutional network with a pluggable readout.

    Args:
        in_feats: Size of each input node feature vector.
        hid: Width of the hidden representation produced by ``conv1``.
        out_feats: Size of the per-node output produced by ``conv2``.
    """

    def __init__(self, in_feats, hid, out_feats):
        super().__init__()
        self.conv1 = dgl.nn.GraphConv(in_feats, hid, activation=F.relu)
        self.conv2 = dgl.nn.GraphConv(hid, out_feats)
        # Identity by default, so the network output is exactly conv2's
        # output. It has no parameters, so state_dict keys are unaffected.
        # Replace it with a trainable module to add a task-specific head.
        self.readout = nn.Identity()

    def forward(self, g, x):
        """Return per-node outputs for graph ``g`` with node features ``x``."""
        h = self.conv1(g, x)
        h = self.conv2(g, h)
        # Route through the readout so that swapping it actually takes effect
        # (previously the attribute was defined but never applied).
        return self.readout(h)


In [None]:
# Pretrain on Cora (node classification)
cora = CoraGraphDataset()
g = cora[0].to(device)
feat = g.ndata["feat"].to(device)
labels = g.ndata["label"].to(device)
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

model = GCN(feat.shape[1], 64, cora.num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

for epoch in range(100):
    # Full-batch training step: the loss is taken on training nodes only.
    model.train()
    logits = model(g, feat)
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    opt.zero_grad()
    loss.backward()
    opt.step()

    if epoch % 10 == 0:
        # Re-run the forward pass in eval mode so the reported accuracy
        # reflects the weights *after* this step, not the pre-update
        # training-mode logits computed above.
        model.eval()
        with torch.no_grad():
            eval_logits = model(g, feat)
            val_acc = (eval_logits[val_mask].argmax(1) == labels[val_mask]).float().mean().item()
        print(f"Epoch {epoch} | loss {loss.item():.3f} | val acc {val_acc:.3f}")

# Persist the pretrained weights for the fine-tuning stage.
torch.save(model.state_dict(), "gcn_cora.pt")


In [None]:
# Fine‑tune on Citeseer (freeze transferred conv1 weights, train a new output layer)
cit = CiteseerGraphDataset()
# GraphConv rejects graphs with zero in-degree nodes unless self-loops are
# added (or allow_zero_in_degree=True is set). Citeseer is expected to contain
# isolated nodes, so normalise the graph to have exactly one self-loop per node.
g2 = dgl.add_self_loop(dgl.remove_self_loop(cit[0])).to(device)
feat2 = g2.ndata["feat"].to(device)
labels2 = g2.ndata["label"].to(device)
train2 = g2.ndata["train_mask"]
val2 = g2.ndata["val_mask"]
test2 = g2.ndata["test_mask"]

# conv2 is freshly initialised here with Citeseer's class count; it is never
# overwritten from the checkpoint, so it serves as the new trainable output layer.
ft_model = GCN(feat2.shape[1], 64, cit.num_classes).to(device)

# load_state_dict(strict=False) only tolerates missing/unexpected keys; it
# still raises on size mismatches. The Cora checkpoint has a different input
# feature dimension (and class count), so copy over only the conv1 tensors
# whose shapes match the new model.
pretrained = torch.load("gcn_cora.pt", map_location=device)
target_state = ft_model.state_dict()
transferable = {
    name: tensor
    for name, tensor in pretrained.items()
    if name.startswith("conv1.")
    and name in target_state
    and tensor.shape == target_state[name].shape
}
skipped = sorted(set(pretrained) - set(transferable))
ft_model.load_state_dict(transferable, strict=False)
print(f"[FT] transferred {sorted(transferable)} | skipped {skipped}")

# Freeze only what was actually restored from the checkpoint. Tensors that
# could not be transferred are randomly initialised and must stay trainable,
# otherwise the model would be stuck with random frozen features.
for name, param in ft_model.named_parameters():
    if name in transferable:
        param.requires_grad = False

opt = torch.optim.Adam([p for p in ft_model.parameters() if p.requires_grad], lr=5e-3)

for epoch in range(60):
    ft_model.train()
    logits = ft_model(g2, feat2)
    loss = F.cross_entropy(logits[train2], labels2[train2])
    opt.zero_grad()
    loss.backward()
    opt.step()

    if epoch % 10 == 0:
        # Evaluate with a fresh forward pass in eval mode so the accuracy
        # reflects the post-update weights rather than stale training logits.
        ft_model.eval()
        with torch.no_grad():
            eval_logits = ft_model(g2, feat2)
            val_acc = (eval_logits[val2].argmax(1) == labels2[val2]).float().mean().item()
        print(f"[FT] Epoch {epoch} | loss {loss.item():.3f} | val acc {val_acc:.3f}")
