# CIFAR-10 Image Classification
This notebook implements a training and testing pipeline for image classification on the CIFAR-10 dataset. CIFAR-10 contains 60,000 32x32 RGB images distributed evenly across 10 classes (6,000 images per class). The provided dataset splits consist of a train set with 50,000 images and a test set with 10,000 images. Here, the provided train set is further split into a train set with 45,000 images and a validation set with 5,000 images, so the model can be evaluated throughout training. The model currently used is a ResNet-18-style CNN (the architecture will be updated periodically as the project progresses).
## Milestones
- 10/23: init + setup basic training pipeline
- 10/24: achieve ~50% test accuracy with a barebones CNN
- 10/27: achieve ~91% test accuracy with a ResNet architecture

## Setup
Import essential libraries (PyTorch, NumPy, wandb), declare hyperparameters, select the compute device, and seed the random number generators for reproducibility.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import numpy as np
import random
import os
import wandb
from tqdm import tqdm

BATCH_SIZE = 128
EPOCHS = 100
LEARNING_RATE = 0.1
N_CLASSES = 10
CHECKPOINT_FOLDER = 'checkpoint'

# Pick the best available accelerator instead of hard-coding 'mps', so the
# notebook also runs on CUDA machines and on plain CPUs.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')

LOG_INTERVAL = 4  # log the training loss to wandb every LOG_INTERVAL batches
SEED = 2023

random.seed(SEED)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here does
# not change this process's hash seed; it only reaches subprocesses launched later.
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# cuDNN autotuning (benchmark=True) may select nondeterministic algorithms,
# which defeats deterministic=True; keep it off for reproducible runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

wandb.init(project="cifar")

## Data
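Load the CIFAR-10 dataset using torchvision's dataset utilities. Augment the training split only (random horizontal flip, small rotation, random resized crop). Normalize every split with the per-channel mean and std of the training set (computed in the commented-out cell below). Finally, initialize the train, validation, and test dataloaders.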

In [10]:
# Per-channel mean/std of the CIFAR-10 training images, shared by every split.
_NORMALIZE_MEAN = (0.4914, 0.4822, 0.4465)
_NORMALIZE_STD = (0.2470, 0.2435, 0.2616)

# Training split: tensor conversion, random augmentations, then normalization.
augmentation_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.RandomResizedCrop((32, 32), (0.8, 1)),
    torchvision.transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])

# Validation/test splits: deterministic tensor conversion + normalization only.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])


def _cifar10(train, transform):
    """Return the CIFAR-10 split stored under ./data (downloading it if absent)."""
    return torchvision.datasets.CIFAR10(root='./data', train=train, download=True, transform=transform)


# Two views of the same 50,000 training images: an augmented one used for
# training and an un-augmented one used for validation.
train_iter = _cifar10(True, augmentation_transforms)
val_iter = _cifar10(True, transforms)

# Shuffle once with the seeded NumPy RNG, then take disjoint 45k / 5k index sets.
indices = list(range(len(train_iter)))
np.random.shuffle(indices)
train_idxs = indices[:45000]
val_idxs = indices[-5000:]

train_dataloader = DataLoader(
    train_iter,
    batch_size=BATCH_SIZE,
    sampler=SubsetRandomSampler(train_idxs),
)
val_dataloader = DataLoader(
    val_iter,
    batch_size=BATCH_SIZE,
    sampler=SubsetRandomSampler(val_idxs),
)

test_iter = _cifar10(False, transforms)
test_dataloader = DataLoader(test_iter, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [11]:
# print(train_iter.data.shape)
# data = torch.tensor(train_iter.data, dtype=torch.float32)
# mean = torch.mean(data, dim=(0, 1, 2)) / 255
# std = torch.std(data, dim=(0, 1, 2)) / 255
# print(mean, std)

## Classification Model
Implements a ResNet-18-style architecture: residual (skip) connections let gradients flow directly through the network, which makes deeper models trainable.

In [12]:
class ResidualBlock(nn.Module):
    """Basic residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, added to a shortcut, then ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the projection
            shortcut. ``None`` (the default) keeps the original behaviour:
            1 when ``in_channels == out_channels``, otherwise 2, i.e. the block
            halves the spatial resolution whenever it changes the channel count.

    The shortcut is the identity when neither the channel count nor the
    resolution changes; otherwise it is a strided 1x1 conv + BatchNorm so its
    output can be added to the residual branch.
    """

    def __init__(self, in_channels: int, out_channels: int, stride=None):
        super(ResidualBlock, self).__init__()
        if stride is None:
            stride = 1 if in_channels == out_channels else 2

        # Layers are created in the original order with the original attribute
        # names, so parameter initialisation and state_dict keys are unchanged
        # and existing checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        if stride == 1 and in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            # Projection shortcut: matches both the channel count and the
            # (possibly reduced) spatial size of the residual branch.
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply the residual branch to ``x`` and add the shortcut before the final ReLU."""
        skip = self.skip(x)
        x = F.relu(self.batch_norm1(self.conv1(x)))
        x = self.batch_norm2(self.conv2(x))
        return F.relu(skip + x)


class ImageClassifier(nn.Module):
    """ResNet-18-style image classifier.

    Layout: 3x3 stem conv (3 -> 64 channels) + BatchNorm + ReLU, then four
    stages of two ``ResidualBlock``s each (64, 128, 256, 512 channels; the
    first block of stages 2-4 changes the channel count and therefore halves
    the spatial resolution), global average pooling and a linear head.

    Args:
        n_classes: number of output logits. ``None`` (the default) uses the
            module-level ``N_CLASSES`` constant, as before.
    """

    def __init__(self, n_classes=None):
        super(ImageClassifier, self).__init__()
        if n_classes is None:
            n_classes = N_CLASSES

        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(64)

        self.block1 = nn.Sequential(
            ResidualBlock(in_channels=64, out_channels=64),
            ResidualBlock(in_channels=64, out_channels=64),
        )
        self.block2 = nn.Sequential(
            ResidualBlock(in_channels=64, out_channels=128),
            ResidualBlock(in_channels=128, out_channels=128),
        )
        self.block3 = nn.Sequential(
            ResidualBlock(in_channels=128, out_channels=256),
            ResidualBlock(in_channels=256, out_channels=256),
        )
        self.block4 = nn.Sequential(
            ResidualBlock(in_channels=256, out_channels=512),
            ResidualBlock(in_channels=512, out_channels=512),
        )

        # Global average pooling. For 32x32 inputs the stage-4 map is 4x4, so
        # this matches the former AvgPool2d(4) exactly. Unlike that layer, it
        # reduces any input size to 1x1, keeping the flattened feature size at
        # 512 for the linear head. It has no parameters, so checkpoints are unaffected.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.linear = nn.Linear(512, n_classes)

    def forward(self, x):
        """Map a batch of 3-channel images (N, 3, H, W) to class logits (N, n_classes)."""
        x = F.relu(self.batch_norm(self.conv(x)))
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x

## Train, Evaluate, Score
- train_epoch: implements the training loop, which feeds batches of images through the model, computes the cross-entropy loss between the model outputs and the labels, and backpropagates to update the weights
- evaluate: implements the validation loop, which currently measures model performance using the cross-entropy loss only (temporary)
- score: implements the testing loop, which computes the trained model's loss and accuracy on the unseen test data

In [13]:
def train_epoch(train_dataloader: DataLoader, model: ImageClassifier, optimizer):
    """Run one optimisation pass over ``train_dataloader``; return the mean per-batch loss.

    Each batch is moved to DEVICE, scored with cross-entropy, backpropagated,
    and followed by an optimizer step. The batch loss is logged to wandb
    every LOG_INTERVAL batches and on the final batch.
    """
    model.train()
    n_batches = len(train_dataloader)
    running_total = 0

    for step, (inputs, targets) in enumerate(tqdm(train_dataloader), start=1):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        batch_loss = F.cross_entropy(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_value = batch_loss.item()
        running_total += loss_value

        should_log = step % LOG_INTERVAL == 0 or step == n_batches
        if should_log:
            wandb.log({ "loss": loss_value })

    return running_total / n_batches

In [14]:
def evaluate(val_dataloader: DataLoader, model: ImageClassifier):
    """Return the mean per-batch cross-entropy loss of ``model`` on ``val_dataloader``.

    The model is switched to eval mode, so BatchNorm uses its running statistics.
    Forward passes run under ``torch.no_grad()``; the caller is responsible for
    switching back to train mode (``train_epoch`` does this).
    """
    model.eval()
    losses = 0

    # Validation never calls backward(); disabling gradient tracking avoids
    # building and storing an autograd graph for every batch.
    with torch.no_grad():
        for images, labels in tqdm(val_dataloader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images)
            loss = F.cross_entropy(logits, labels)

            losses += loss.item()

    return losses / len(val_dataloader)

In [15]:
def score(test_dataloader: DataLoader, model: ImageClassifier):
    """Return ``(mean per-batch loss, accuracy)`` of ``model`` on ``test_dataloader``.

    Accuracy is the fraction of samples whose arg-max logit equals the label.
    It is counted over the samples the loader actually yields, so it is correct
    for any loader, including subset-sampled ones.
    """
    model.eval()
    losses = 0
    correct = 0
    n_samples = 0

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images)
            loss = F.cross_entropy(logits, labels)
            losses += loss.item()

            preds = torch.argmax(logits, dim=-1)
            correct += torch.sum(preds == labels).item()
            # Previously the denominator was the global ``len(test_iter)``,
            # which is wrong for any other loader (e.g. the val subset).
            n_samples += labels.size(0)

    return losses / len(test_dataloader), correct / n_samples

In [None]:
model = ImageClassifier()
model.to(DEVICE)
wandb.watch(model, log_freq=LOG_INTERVAL)

optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0005)
# Multiply the LR by 0.1 once the validation loss has gone more than
# `patience` epochs without a relative improvement of at least `threshold`.
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=3, threshold=0.001)

# torch.save does not create missing directories.
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)

for epoch in range(EPOCHS):
    train_loss = train_epoch(train_dataloader, model, optimizer)
    val_loss = evaluate(val_dataloader, model)
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}")
    scheduler.step(val_loss)
    wandb.log({ "lr": optimizer.param_groups[0]['lr'] })

    # Checkpoint every 5 completed epochs and always after the last one.
    # Previously only epochs 0, 5, ..., 95 were saved, so the final model was
    # lost. Files are named by completed-epoch count, so the final model lands
    # in cifar_epoch{EPOCHS}.pt, the file the evaluation cell loads.
    completed_epochs = epoch + 1
    if completed_epochs % 5 == 0 or completed_epochs == EPOCHS:
        torch.save({
            'epoch': completed_epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss
        }, os.path.join(CHECKPOINT_FOLDER, f'cifar_epoch{completed_epochs}.pt'))

In [29]:
model = ImageClassifier()
# Same file as before ('checkpoint/cifar_epoch100.pt' with the defaults), but
# derived from the configured folder and epoch count instead of hard-coded.
checkpoint_path = os.path.join(CHECKPOINT_FOLDER, f'cifar_epoch{EPOCHS}.pt')
# map_location lets a checkpoint written on one device (e.g. MPS) be loaded
# on whatever DEVICE is currently in use.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)

test_loss, test_acc = score(test_dataloader, model)
print(test_loss, test_acc)

0.2601817424727392 0.9174
