In [None]:
import sys
sys.path.append("../../")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision import models, transforms
from torchvision.datasets import CIFAR10

from utils.train import train

In [None]:
# The pretrained ResNet weights come from ImageNet, so inputs are normalized
# with the ImageNet per-channel mean / std.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

# Values shared by the training and validation pipelines.
image_size = (224, 224)  # target size handed to transforms.Resize
cifar_root = "../../assets/cifar10"  # where CIFAR-10 is stored / downloaded
batch_size = 64

# Training pipeline: resize, random horizontal flip (augmentation), convert to
# a tensor, then normalize.
train_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# Validation pipeline: identical, minus the random flip, so evaluation is
# deterministic.
val_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# CIFAR-10 datasets; `download=True` fetches the data into `cifar_root` if it
# is not already there. `train=False` selects the split used for validation.
train_dataset = CIFAR10(
    root=cifar_root,
    train=True,
    download=True,
    transform=train_transforms,
)
val_dataset = CIFAR10(
    root=cifar_root,
    train=False,
    download=True,
    transform=val_transforms,
)

# Only the training loader shuffles; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

## Transfer Learning

Transfer learning is a technique where a model trained on a large dataset (such as ImageNet) is reused for a different but related task.

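Instead of training a network from scratch, we start from pretrained weights that have already learned general visual features like edges, textures, and shapes. These features transfer well to many image classification problems. In this notebook we freeze the pretrained ResNet-18 layers and train only a new final layer that outputs the 10 CIFAR-10 classes.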

<p align="center">
    <img src="../../assets/img/accuracy/transfer_learning.png" width="400">
</p>

In [None]:
# Load ResNet-18 initialized with the ImageNet-pretrained weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze the whole pretrained network: no gradients are computed for these
# parameters.
# NOTE(review): requires_grad=False does not by itself stop BatchNorm layers
# from updating their running statistics if the model is put in train mode —
# confirm how utils.train.train handles this.
model.requires_grad_(False)

# Swap the final fully connected layer for a fresh one sized for our task.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 10 classes

# Explicitly mark the new head as trainable; it is the only part that learns.
model.fc.requires_grad_(True)

In [None]:
# Pick the compute device in order of preference: CUDA, then Apple's Metal
# backend (MPS), then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the model to the device *before* constructing the optimizer, so the
# parameters handed to the optimizer already live where they will be trained.
model = model.to(device)

# Only the new classification head is optimized; the rest of the network is
# frozen.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

Expect a final accuracy of **~81%**.

That is slightly better than the previous approaches, and it is reached in less than half the epochs!



In [None]:
# Run training with the project's helper from utils.train.
# NOTE(review): the literal 10 is presumably the number of epochs — confirm
# against the signature of utils.train.train.
train(
    model,
    10,
    train_loader,
    val_loader,
    optimizer,
    loss_fn,
    device,
)