In [3]:
import sys
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import random_split, DataLoader

# Preprocessing matching torchvision's ImageNet-pretrained ResNet weights:
# resize, center-crop to 224x224, convert to tensor, and normalize with the
# ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageFolder assigns one class label per subdirectory of ../images.
dataset = torchvision.datasets.ImageFolder(
    root='../images',
    transform=preprocess
)

# 80/20 train/validation split. The seeded generator makes the split
# reproducible across kernel restarts, so validation results from different
# runs are measured on the same images.
train_dataset, val_dataset = random_split(
    dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
)

# Shuffle only the training data. Validation order does not affect metrics,
# and a fixed order keeps evaluation deterministic.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Pretrained ResNet-50 used as a feature extractor. Replacing the final
# fully-connected layer with Identity makes the network return its pooled
# feature vector (2048-d) instead of classification logits.
resnet50_model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)
resnet50_model.fc = nn.Identity()
# eval(): BatchNorm layers use their stored running statistics rather than
# per-batch statistics, so embeddings do not depend on batch composition.
resnet50_model.eval()

# From last class: run the backbone to get 2048-d embeddings.
# The batch is forwarded once (not once per print), and no_grad() avoids
# building an autograd graph that is never used for backprop.
with torch.no_grad():
    for X, y in train_dataloader:
        embeddings = resnet50_model(X)
        print(embeddings.shape)
        print(embeddings)
        break

# Freeze all ResNet-50 parameters (transfer learning: reuse pretrained features).
for param in resnet50_model.parameters():
    param.requires_grad = False
# Keep the backbone's BatchNorm layers in inference mode.
# NOTE(review): calling model.train() on the combined model below would switch
# these BatchNorm layers back to training mode. They would then update their
# running statistics even though requires_grad is False, so call
# resnet50_model.eval() again after any model.train().
resnet50_model.eval()

# Small classification head on top of the 2048-d embeddings:
# Linear(2048->1024) + ReLU + Linear(1024->1).
# The single output unit is one logit per image. This presumably targets a
# binary task (e.g. trained with BCEWithLogitsLoss); confirm against
# len(dataset.classes).
fc_model = nn.Sequential(
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1)
)

# Chain the frozen backbone and the new head with nn.Sequential(backbone, head).
model = nn.Sequential(
    resnet50_model,
    fc_model
)

# Create an optimizer that updates ONLY the new head's parameters (not the backbone).
optimizer = torch.optim.Adam(fc_model.parameters(), lr=0.001)

# Print the full model to verify the architecture.
print(model)

# Take one batch from train_dataloader, run it through the combined model, and
# print out.shape (expected (batch, 1) given the head's final Linear(1024, 1)).
# This is only a shape check, so no_grad() skips building an autograd graph.
with torch.no_grad():
    for X, y in train_dataloader:
        out = model(X)
        print(out.shape)
        break

torch.Size([32, 2048])
tensor([[0.2105, 0.1632, 0.8361,  ..., 0.7645, 0.4049, 0.4315],
        [0.0963, 0.0567, 0.0736,  ..., 0.9347, 0.2326, 0.4182],
        [0.0402, 0.0584, 0.5249,  ..., 0.1643, 0.0000, 0.6844],
        ...,
        [0.0016, 0.3008, 0.0000,  ..., 0.0104, 0.0488, 0.0315],
        [0.0483, 0.1513, 0.0421,  ..., 0.2123, 0.1874, 0.4891],
        [0.2730, 0.4396, 0.0441,  ..., 0.4272, 0.0016, 0.4955]],
       grad_fn=<ViewBackward0>)
Sequential(
  (0): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T