# 02 — Train ResNet (Transfer Learning)

Fine-tune an ImageNet-pretrained ResNet-34 or ResNet-50 backbone using **dual learning rates** (a smaller LR for the pretrained backbone,
a larger LR for the new classification head), a **StepLR** schedule, **early stopping** on validation loss, and **best-model saving**.
For both training and validation we log **top-1** and **top-5** accuracy, **macro F1**, and **mAP** (macro-averaged one-vs-rest average precision).

In [None]:
# %pip install torch torchvision torchaudio
# %pip install numpy pandas scikit-learn matplotlib tqdm

import os, json, time, numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import models, datasets, transforms
from sklearn.metrics import f1_score, average_precision_score, top_k_accuracy_score

# ---- Paths -----------------------------------------------------------------
DATA_ROOT = Path("/path/to/Places2_simp")  # <-- EDIT: ImageFolder root (one sub-folder per class)
SAVE_DIR = Path("../checkpoints")
SAVE_DIR.mkdir(parents=True, exist_ok=True)  # checkpoints and history CSVs are written here

# ---- DataLoader settings ---------------------------------------------------
NUM_WORKERS = 4
BATCH_TRAIN = 256
BATCH_VAL = 1024

# ---- Training hyper-parameters ---------------------------------------------
_device = "cuda" if torch.cuda.is_available() else "cpu"
CFG = {
    "arch": "resnet34",       # 'resnet34' or 'resnet50'
    "epochs": 30,             # upper bound; early stopping may end sooner
    "patience": 10,           # epochs without val-loss improvement before stopping
    "lr_backbone": 1e-4,      # learning rate for the pretrained layers
    "lr_head": 1e-3,          # learning rate for the new `fc` head
    "weight_decay": 8e-4,
    "step_size": 6,           # StepLR period (in scheduler steps)
    "gamma": 0.6,             # StepLR multiplicative decay
    "device": _device,
}
print(CFG)

In [None]:
# Training augmentation: resize, random crop/flip/rotation, then ImageNet normalisation.
train_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=20, fill=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
# Validation preprocessing is deterministic: resize + center crop, same normalisation.
val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# Scan the folder once without transforms, only to read the labels for the split.
bare = datasets.ImageFolder(root=str(DATA_ROOT))
targets = np.array(bare.targets)
num_classes = len(bare.classes)

# Stratified split: per class, VAL_FRACTION of the images go to validation.
# A dedicated seeded generator (instead of the global np.random state) makes the split
# identical across kernel restarts, so a saved checkpoint is always evaluated on the
# same held-out images it was selected on (no silent train/val leakage between runs).
SPLIT_SEED = 42
VAL_FRACTION = 0.2
split_rng = np.random.default_rng(SPLIT_SEED)
train_idx, val_idx = [], []
for cls, idxs in pd.Series(np.arange(len(targets))).groupby(targets).groups.items():
    idxs = split_rng.permutation(np.asarray(idxs))
    n_val = int(VAL_FRACTION * len(idxs))  # rounds down: classes with < 5 images get no val samples
    val_idx.extend(idxs[:n_val].tolist())
    train_idx.extend(idxs[n_val:].tolist())

# Two ImageFolder views over the same files, differing only in their transform.
train_ds = Subset(datasets.ImageFolder(root=str(DATA_ROOT), transform=train_tf), train_idx)
val_ds   = Subset(datasets.ImageFolder(root=str(DATA_ROOT), transform=val_tf), val_idx)

train_loader = DataLoader(train_ds, batch_size=BATCH_TRAIN, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_VAL,   shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

num_classes

In [None]:
def build_model(arch, num_classes):
    """Return an ImageNet-pretrained ResNet with its classifier replaced for ``num_classes``.

    Args:
        arch: ``"resnet34"`` (IMAGENET1K_V1 weights) or ``"resnet50"`` (IMAGENET1K_V2 weights).
        num_classes: number of logits produced by the new head.

    Raises:
        ValueError: if ``arch`` is not one of the two supported names.
    """
    if arch not in ("resnet34", "resnet50"):
        raise ValueError("arch must be resnet34 or resnet50")

    net = (
        models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        if arch == "resnet34"
        else models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    )
    # The new head keeps the attribute name `fc`; the optimizer setup selects head
    # parameters by the "fc." prefix to give them their own learning rate.
    feature_dim = net.fc.in_features
    net.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(feature_dim, num_classes))
    return net

model = build_model(CFG["arch"], num_classes).to(CFG["device"])

# Partition parameters by name: everything under the replaced `fc` head is "head",
# everything else is the pretrained "backbone". Order within each group follows
# model.named_parameters().
named = list(model.named_parameters())
backbone_params = [p for n, p in named if not n.startswith("fc.")]
head_params = [p for n, p in named if n.startswith("fc.")]

param_groups = [
    {"params": backbone_params, "lr": CFG["lr_backbone"]},  # smaller LR for pretrained layers
    {"params": head_params, "lr": CFG["lr_head"]},          # larger LR for the new head
]
optimizer = torch.optim.AdamW(param_groups, weight_decay=CFG["weight_decay"])

# Every `step_size` scheduler steps, each group's LR is multiplied by `gamma`.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=CFG["step_size"], gamma=CFG["gamma"])
criterion = nn.CrossEntropyLoss()

In [None]:
def epoch_run(loader, train=True):
    """Run one pass over ``loader`` and return the loss plus classification metrics.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer``, ``CFG`` and
    ``num_classes``. With ``train=True`` the model is in training mode and takes
    one optimizer step per batch. Otherwise it runs in eval mode with gradients
    disabled.

    Returns:
        ``(mean_loss, top1, top5, macro_f1, mAP)`` over the whole pass. ``mAP`` is
        the mean one-vs-rest average precision over the classes that have at
        least one sample in this pass. A class with no positives has no
        meaningful AP; sklearn scores it 0 (or nan), which would deflate or
        break the average.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    device = CFG["device"]
    model.train(train)
    loss_sum, y_true, y_prob = 0.0, [], []
    for x, y in loader:
        # non_blocking lets the copy from pinned memory (pin_memory=True) overlap with compute.
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        # criterion returns a per-batch mean; re-weight by batch size for an exact dataset mean.
        loss_sum += loss.item() * y.size(0)
        y_true.append(y.cpu().numpy())
        y_prob.append(torch.softmax(logits.detach(), dim=1).cpu().numpy())
    if not y_true:
        raise ValueError("epoch_run: loader yielded no batches")

    y_true = np.concatenate(y_true)
    y_prob = np.concatenate(y_prob)
    y_pred = y_prob.argmax(1)
    top1 = (y_pred == y_true).mean()
    top5 = top_k_accuracy_score(y_true, y_prob, k=5, labels=np.arange(num_classes))
    f1 = f1_score(y_true, y_pred, average="macro")

    # One-hot targets (int8 rather than float64 keeps the N x C matrix small).
    y_true_ovr = np.eye(num_classes, dtype=np.int8)[y_true]
    ap_per_class = average_precision_score(y_true_ovr, y_prob, average=None)
    present = y_true_ovr.any(axis=0)  # classes with at least one positive sample
    mAP = float(np.mean(ap_per_class[present]))
    return loss_sum / len(loader.dataset), top1, top5, f1, mAP

# Train with early stopping on validation loss. The weights with the lowest val loss
# are checkpointed to `best_path` and restored into `model` once training ends.
best_path = SAVE_DIR / f"best_{CFG['arch']}.pth"
best_val_loss = float("inf"); best_epoch = None; patience = 0; history = []
for epoch in range(1, CFG["epochs"] + 1):
    tr = epoch_run(train_loader, True)   # (loss, top1, top5, f1, mAP)
    va = epoch_run(val_loader, False)
    scheduler.step()  # stepped once per epoch, after that epoch's updates
    history.append({"epoch":epoch, "train_loss":tr[0], "val_loss":va[0],
                    "train_top1":tr[1], "val_top1":va[1],
                    "train_top5":tr[2], "val_top5":va[2],
                    "train_f1":tr[3],   "val_f1":va[3],
                    "train_mAP":tr[4],  "val_mAP":va[4]})
    print(f"Epoch {epoch:02d} | tr loss {tr[0]:.4f} top1 {tr[1]*100:.2f} top5 {tr[2]*100:.2f} | "
          f"va loss {va[0]:.4f} top1 {va[1]*100:.2f} top5 {va[2]*100:.2f} f1 {va[3]:.3f} mAP {va[4]:.3f}")
    if va[0] < best_val_loss:
        best_val_loss, best_epoch, patience = va[0], epoch, 0
        torch.save(model.state_dict(), best_path)
    else:
        patience += 1
        if patience >= CFG["patience"]:
            print("Early stopping."); break

# When the loop ends (early stop or last epoch), `model` holds the most recent weights.
# These can be worse than the checkpoint. Restore the best weights so later cells
# evaluate the same model that was saved.
if best_epoch is not None:
    model.load_state_dict(torch.load(best_path, map_location=CFG["device"]))
    print(f"Restored best weights from epoch {best_epoch} (val loss {best_val_loss:.4f}).")

hist_df = pd.DataFrame(history)
hist_df.to_csv(SAVE_DIR / f"history_{CFG['arch']}.csv", index=False)
hist_df.tail()

In [None]:
# Left panel: train vs. val loss per epoch. Right panel: validation top-1/top-5 accuracy in percent.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax_loss, ax_acc = ax

for column, label in (("train_loss", "train"), ("val_loss", "val")):
    ax_loss.plot(hist_df["epoch"], hist_df[column], label=label)
ax_loss.set_title("Loss")
ax_loss.legend()

for column, label in (("val_top1", "Top-1"), ("val_top5", "Top-5")):
    ax_acc.plot(hist_df["epoch"], hist_df[column] * 100, label=label)
ax_acc.set_title("Accuracies")
ax_acc.legend()

plt.tight_layout()
plt.show()