In [1]:
import os, time
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms, models

In [2]:
from torchvision.datasets import LFWPeople

In [3]:
test_set = LFWPeople(root='./data', download=True)

URLError: <urlopen error [Errno 8] nodename nor servname provided, or not known>

## Fine-tuning (training loop)

In [1]:
# finetune_resnet18_cifar10_mps_noscaler.py
import os, time
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms, models

# -----------------------------
# Device (Apple Silicon MPS)
# -----------------------------
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print("Using device:", device)

# -----------------------------
# Hyperparams
# -----------------------------
batch_size   = 128
epochs       = 10
lr           = 5e-4
weight_decay = 1e-4
val_size     = 5000  # from 50k train

# -----------------------------
# Weights & Normalization
# -----------------------------
weights = models.ResNet18_Weights.DEFAULT  # ImageNet-pretrained
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# Train-time augmentation (to 224x224) + ImageNet normalization
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Eval-time transforms (resize/center-crop) + ImageNet normalization
val_tf = transforms.Compose([
    transforms.Resize(232),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# -----------------------------
# Datasets & Splits
# -----------------------------
root = "./data"
train_full_aug  = torchvision.datasets.CIFAR10(root=root, train=True,  download=True, transform=train_tf)
train_full_eval = torchvision.datasets.CIFAR10(root=root, train=True,  download=False, transform=val_tf)
N = len(train_full_aug)  # 50_000

g = torch.Generator().manual_seed(1337)
perm = torch.randperm(N, generator=g)
val_idx   = perm[:val_size]
train_idx = perm[val_size:]

train_set = Subset(train_full_aug,  train_idx)
val_set   = Subset(train_full_eval, val_idx)

test_set  = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=val_tf)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,  num_workers=4, persistent_workers=True)
val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True)
test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True)

# -----------------------------
# Model
# -----------------------------
model = models.resnet18(weights=weights)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # CIFAR-10 classes
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

# -----------------------------
# Warm-up (prime MPS/allocations)
# -----------------------------
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224, device=device)
    _ = model(dummy)

# -----------------------------
# Eval helper (FP32)
# -----------------------------
@torch.no_grad()
def evaluate(loader):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        loss_sum += loss.item() * y.size(0)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return (loss_sum / total), (100.0 * correct / total)

# -----------------------------
# Training loop (FP32, no scaler)
# -----------------------------
best_val_acc = 0.0
ckpt_path = "resnet18_cifar10_mps_best.pt"

for epoch in range(3):
    model.train()
    start = time.time()
    running_loss, running_correct, seen = 0.0, 0, 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss   += loss.item() * y.size(0)
        running_correct += (logits.argmax(1) == y).sum().item()
        seen           += y.size(0)

    scheduler.step()

    if device.type == "mps":
        torch.mps.synchronize()

    train_loss = running_loss / seen
    train_acc  = 100.0 * running_correct / seen

    val_loss, val_acc = evaluate(val_loader)

    elapsed = time.time() - start
    print(f"Epoch {epoch:02d}/{epochs} | "
          f"train loss {train_loss:.4f} acc {train_acc:.2f}% | "
          f"val loss {val_loss:.4f} acc {val_acc:.2f}% | {elapsed:.1f}s")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({"model": model.state_dict(),
                    "best_val_acc": best_val_acc,
                    "epoch": epoch}, ckpt_path)
        print(f"  ✅ Saved new best to {ckpt_path} (val acc {best_val_acc:.2f}%)")


Using device: mps
Files already downloaded and verified
Files already downloaded and verified
Epoch 00/10 | train loss 0.4924 acc 83.02% | val loss 0.3228 acc 89.00% | 179.3s
  ✅ Saved new best to resnet18_cifar10_mps_best.pt (val acc 89.00%)
Epoch 01/10 | train loss 0.3075 acc 89.37% | val loss 0.2795 acc 90.78% | 167.8s
  ✅ Saved new best to resnet18_cifar10_mps_best.pt (val acc 90.78%)
Epoch 02/10 | train loss 0.2387 acc 91.73% | val loss 0.2056 acc 92.82% | 168.2s
  ✅ Saved new best to resnet18_cifar10_mps_best.pt (val acc 92.82%)


## Inference Loop

In [2]:

# -----------------------------
# Test evaluation (best ckpt)
# -----------------------------
if os.path.exists(ckpt_path):
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model"])
    print(f"Loaded best checkpoint from epoch {state['epoch']} (val acc {state['best_val_acc']:.2f}%).")

test_loss, test_acc = evaluate(test_loader)
print(f"Test: loss {test_loss:.4f} | acc {test_acc:.2f}%")

Loaded best checkpoint from epoch 2 (val acc 92.82%).


  state = torch.load(ckpt_path, map_location=device)


Test: loss 0.2148 | acc 92.73%


In [None]:
coins = {1,3,5}
def dp(target):
    if target in coins:
        return 1
    if target > 0:
        return min(dp(target-5), dp(target-3), dp(target-1)) + 1
    else:
        return 0