# CIFAR-10 Image Classification
This notebook implements a training and testing pipeline for an image classification task on the CIFAR-10 dataset. CIFAR-10 contains 60,000 32x32 RGB images distributed evenly across 10 image classes (6,000 images per class). The provided dataset splits consist of a train set with 50,000 images and a test set with 10,000 images. Here, the train set is further split into a train set with 45,000 images and a validation set with 5,000 images to allow for model evaluation throughout the training process. The model architecture is periodically updated with progress (see Milestones): it started as a barebones CNN, then a ResNet, and the current pipeline fine-tunes torchvision's ViT-B/16 pretrained on ImageNet. A from-scratch Vision Transformer implementation is also included below.
## Milestones
- 10/23: init + setup basic training pipeline
- 10/24: achieve ~50% test accuracy with a barebones CNN
- 10/27: achieve ~91% test accuracy with resnet architecture
- 11/7: implement basic vision transformer architecture
- 11/8: achieve ~97.5% test accuracy with fine-tuned vision transformer (pretrained on ImageNet)

## Setup
Import essential libraries (PyTorch, numpy, wandb) and declare hyperparameters.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts, CosineAnnealingLR
from torchvision.models import vit_b_16, ViT_B_16_Weights
import torchvision
from einops.layers.torch import Rearrange
from typing import Tuple
import numpy as np
import math
import random
import os
import wandb
from tqdm import tqdm

# Hyperparameters for ViT-B/16 fine-tuning; trailing comments give the values used for the ResNet runs.
BATCH_SIZE = 32  # 128
EPOCHS = 100
LEARNING_RATE = 1e-4  # 0.1
N_CLASSES = 10
CHECKPOINT_FOLDER = 'checkpoint'
# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the notebook
# also runs on machines without MPS instead of failing at the first .to(DEVICE).
if torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
LOG_INTERVAL = 16  # 4 -- number of training batches between wandb loss logs
SEED = 2023

# Seed every RNG the notebook uses (Python, NumPy, PyTorch).
random.seed(SEED)
# Only affects child processes; this interpreter's hash seed was fixed at startup.
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Deterministic cuDNN kernels require autotuning to be disabled: benchmark=True lets cuDNN
# pick (possibly nondeterministic) algorithms per run, defeating deterministic=True.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

wandb.init(project="cifar")

## Data
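Load the CIFAR-10 dataset using torchvision's dataset utilities. Resize every image to 224x224 (the input resolution of ViT-B/16). For the training split only, apply random horizontal flips and small random rotations. Normalize with the per-channel mean and std of the training set (computed in the commented-out cell below). Initialize the train, validation, and test dataloaders. The validation loader reads the same 50,000 training images without augmentation, but samples only the 5,000 held-out indices.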

In [51]:
# Per-channel statistics of the CIFAR-10 training images (see the statistics cell below).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
VIT_INPUT_SIZE = (224, 224)  # input resolution expected by ViT-B/16
N_VAL = 5000  # images held out of the 50,000-image train split for validation
tvt = torchvision.transforms

# Evaluation preprocessing: upscale, convert to a tensor, normalize.
transforms = tvt.Compose([
    tvt.Resize(VIT_INPUT_SIZE),
    tvt.ToTensor(),
    tvt.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Training preprocessing: same steps plus light augmentation on the tensor before normalizing.
# (RandomVerticalFlip and RandomResizedCrop((32, 32), (0.8, 1)) are currently disabled.)
augmentation_transforms = tvt.Compose([
    tvt.Resize(VIT_INPUT_SIZE),
    tvt.ToTensor(),
    tvt.RandomHorizontalFlip(),
    tvt.RandomRotation(10),
    tvt.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])


def _cifar10(train: bool, transform):
    """Build a CIFAR-10 split rooted at ./data, downloading it if it is not present."""
    return torchvision.datasets.CIFAR10(root='./data', train=train, download=True, transform=transform)


# Train and validation read the same 50,000 images; only the transform differs.
train_iter = _cifar10(True, augmentation_transforms)
val_iter = _cifar10(True, transforms)

# Disjoint random split of the train images, driven by the global NumPy seed set in the config cell.
indices = list(range(len(train_iter)))
np.random.shuffle(indices)
val_idxs = indices[-N_VAL:]
train_idxs = indices[:-N_VAL]

train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(train_idxs))
val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(val_idxs))

test_iter = _cifar10(False, transforms)
test_dataloader = DataLoader(test_iter, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [52]:
# Computes the per-channel mean/std used by Normalize above, over all 50,000 CIFAR-10 train images.
# print(train_iter.data.shape)
# data = torch.tensor(train_iter.data, dtype=torch.float32)
# mean = torch.mean(data, dim=(0, 1, 2)) / 255
# std = torch.std(data, dim=(0, 1, 2)) / 255
# print(mean, std)

## Classification Model
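Implements a Vision Transformer (ViT) from scratch. Images are split into patches and embedded. The embeddings pass through a stack of multi-head self-attention layers, which help the model find relationships between image patches. Note that the training cell below currently fine-tunes torchvision's pretrained `vit_b_16` rather than this `ImageClassifier` (see the TODO there).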

In [53]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence of token embeddings.

    Args:
        d_model: embedding size of each token (even or odd).
        n_patches: length of the precomputed table, i.e. the maximum sequence length.
    """

    def __init__(self, d_model: int, n_patches: int):
        super(PositionalEncoding, self).__init__()

        position = torch.arange(n_patches).unsqueeze(1)  # (n_patches, 1)
        # One frequency per even column: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, n_patches, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, one fewer than div_term when d_model
        # is odd, so trim the frequencies to fit.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: saved in state_dict and moved by .to(), but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings of the first x.size(1) positions; x is expected as (batch, seq_len, d_model)."""
        return x + self.pe[:, : x.size(1)]
    

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects each d_model-dim token to per-head queries/keys/values of size d_head, attends
    over all tokens, and projects the concatenated heads back to d_model.

    Args:
        d_model: input/output token embedding size.
        d_head: per-head query/key/value size.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model: int, d_head: int, n_heads: int):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        d_embed = d_head * n_heads

        # NOTE(review): not used in forward; kept so this module's state_dict keys are unchanged.
        self.layer_norm = nn.LayerNorm(d_model)
        self.qkv_proj = nn.Linear(d_model, 3 * d_embed)

        self.out_proj = nn.Linear(d_embed, d_model)

    def forward(self, x):
        """x: (batch, tokens, d_model) -> (batch, tokens, d_model)."""
        b, p, _ = x.shape

        # (b, p, 3*d_embed) -> (b, n_heads, p, 3*d_head); each head's last dim is laid out [q | k | v].
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(b, p, self.n_heads, 3 * self.d_head).permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)

        # Scaled dot product: multiply by d_head ** -0.5, i.e. divide by sqrt(d_head).
        # (Dividing by d_head ** -0.5 would multiply by sqrt(d_head) and over-sharpen the softmax.)
        logits = (q @ k.transpose(-2, -1)) * self.d_head ** -0.5

        out = F.softmax(logits, dim=-1) @ v  # (b, n_heads, p, d_head)
        out = out.transpose(1, 2).reshape(b, p, -1)  # concatenate heads -> (b, p, d_embed)

        return self.out_proj(out)


class EncoderLayer(nn.Module):
    """Pre-norm Transformer encoder block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        d_model: token embedding size.
        d_head: per-head attention size.
        n_heads: number of attention heads.
        d_feedforward: hidden size of the MLP.
        dropout: probability applied to the MLP hidden activations and to both residual branches.
    """

    def __init__(self, d_model: int, d_head: int, n_heads: int, d_feedforward: int, dropout: float = 0.0):
        super(EncoderLayer, self).__init__()
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, d_head, n_heads)

        self.layer_norm2 = nn.LayerNorm(d_model)
        # Linear -> GELU -> Dropout -> Linear: dropout acts on the activated hidden units, as in
        # the standard ViT MLP block. The Linear layers keep indices 0 and 3, so state_dict keys
        # (fcn.0.*, fcn.3.*) are unchanged.
        self.fcn = nn.Sequential(
            nn.Linear(d_model, d_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_feedforward, d_model)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the attention and MLP sub-blocks, each with a residual connection; shape is preserved."""
        attention_out = self.attention(self.layer_norm1(x))
        x = x + self.dropout(attention_out)

        fcn_out = self.fcn(self.layer_norm2(x))
        x = x + self.dropout(fcn_out)
        return x


class ImageClassifier(nn.Module):
    """Vision Transformer classifier built from the modules above.

    The image is split into non-overlapping patches, and each patch is embedded linearly.
    A learned [CLS] token is prepended and sinusoidal position encodings are added.
    The sequence then runs through `n_layers` encoder blocks, and the final [CLS]
    embedding is classified.

    Args:
        n_patches: patches per image; must equal (H // ph) * (W // pw) for the inputs used.
        patch_size: (ph, pw) patch height and width in pixels.
        d_model, d_head, n_heads, d_feedforward: encoder dimensions (see EncoderLayer).
        n_layers: number of encoder blocks.
        n_classes: number of output logits (defaults to N_CLASSES).
        dropout: dropout probability passed to every EncoderLayer (defaults to 0.0).
    """

    def __init__(self, n_patches: int, patch_size: Tuple[int, int],
                 d_model: int, d_head: int, n_heads: int, d_feedforward: int, n_layers: int,
                 n_classes: int = N_CLASSES, dropout: float = 0.0):
        super(ImageClassifier, self).__init__()
        # Each flattened patch holds ph * pw pixels x 3 channels.
        patch_area = 3 * patch_size[0] * patch_size[1]

        self.patchify = nn.Sequential(
            # rearrange operation from einops: (b, c, H, W) -> (b, n_patches, ph*pw*c)
            Rearrange('b c (h ph) (w pw) -> b (h w) (ph pw c)', ph=patch_size[0], pw=patch_size[1]),
            nn.LayerNorm(patch_area),
            nn.Linear(patch_area, d_model)
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_encoding = PositionalEncoding(d_model, n_patches + 1)  # +1 for the [CLS] token

        self.encoder = nn.ModuleList(
            [EncoderLayer(d_model, d_head, n_heads, d_feedforward, dropout) for _ in range(n_layers)]
        )

        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, n_classes)
        )

    def forward(self, x):
        """x: (batch, 3, H, W) images -> (batch, n_classes) logits."""
        patches = self.patchify(x)  # (batch, n_patches, d_model)
        # Broadcast view of the single learned token; torch.cat materializes it per sample.
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        embeddings = self.pos_encoding(torch.cat((cls, patches), dim=1))

        for layer in self.encoder:
            embeddings = layer(embeddings)

        return self.head(embeddings[:, 0])  # classify from the [CLS] position

## Train, Evaluate, Score
- train_epoch: implements the training loop. It feeds images into the model, computes the cross-entropy loss between the model outputs and the labels, and backpropagates. It steps the optimizer and the learning-rate scheduler after every batch.
- evaluate: implements the validation loop, which currently uses cross-entropy loss as its only metric
- score: implements the testing loop, which computes the trained model's loss and accuracy on the unseen test data

In [54]:
def train_epoch(epoch: int, train_dataloader: DataLoader, model: nn.Module, optimizer, scheduler):
    """Train `model` for one pass over `train_dataloader` and return the mean batch loss.

    Args:
        epoch: zero-based epoch index; combined with the batch index to give the scheduler a
            fractional epoch.
        train_dataloader: yields (images, labels) batches.
        model: classifier returning (batch, n_classes) logits (e.g. ImageClassifier or vit_b_16).
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler stepped after every batch as
            `scheduler.step(epoch + i / n_batches)`, so it must accept a (fractional) epoch
            argument. ReduceLROnPlateau, whose step() takes a metric instead, is NOT suitable here.

    Returns:
        Mean cross-entropy loss over all batches of the epoch.
    """
    model.train()
    n_batches = len(train_dataloader)
    losses = 0.0

    for i, (images, labels) in enumerate(tqdm(train_dataloader)):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Clear gradients before backward so stale gradients (e.g. left by an interrupted
        # earlier run in this kernel) are never folded into the first update.
        optimizer.zero_grad()
        logits = model(images)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        # One device->host sync per batch, reused for the running total and for logging.
        loss_value = loss.item()
        losses += loss_value

        if ((i + 1) % LOG_INTERVAL == 0) or (i + 1 == n_batches):
            wandb.log({ "loss": loss_value, "lr": optimizer.param_groups[0]['lr'] })

        scheduler.step(epoch + i / n_batches)

    return losses / n_batches

In [55]:
def evaluate(val_dataloader: DataLoader, model: nn.Module):
    """Return the mean per-batch cross-entropy loss of `model` over `val_dataloader`.

    The model is put in eval mode and no gradients are tracked.

    Args:
        val_dataloader: yields (images, labels) batches.
        model: classifier returning (batch, n_classes) logits.
    """
    model.eval()
    losses = 0.0

    # Validation needs no autograd graph; without no_grad every forward pass would record one,
    # keeping intermediate activations alive and adding overhead for nothing.
    with torch.no_grad():
        for images, labels in tqdm(val_dataloader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images)
            losses += F.cross_entropy(logits, labels).item()

    return losses / len(val_dataloader)

In [58]:
def score(test_dataloader: DataLoader, model: nn.Module):
    """Evaluate `model` on `test_dataloader` without tracking gradients.

    Args:
        test_dataloader: yields (images, labels) batches.
        model: classifier returning (batch, n_classes) logits.

    Returns:
        (mean per-batch cross-entropy loss, accuracy = correct predictions / samples seen).
    """
    model.eval()
    losses = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in tqdm(test_dataloader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images)
            losses += F.cross_entropy(logits, labels).item()

            preds = logits.argmax(dim=-1)
            n_correct += (preds == labels).sum().item()
            # Count the samples actually drawn, so accuracy is right for any dataloader or
            # sampler, not only the global test set.
            n_seen += labels.size(0)

    return losses / len(test_dataloader), n_correct / n_seen

In [None]:
# TODO: pretrain from scratch? instead of using pytorch checkpoint
# model = ImageClassifier()
# Fine-tune torchvision's ViT-B/16 initialised from ImageNet-1k weights.
# NOTE(review): the pretrained head still has 1000 outputs; cross-entropy only ever targets
# indices 0-9. Swapping in a 10-way head would also require changing how the evaluation
# cell rebuilds the model (vit_b_16(weights=None)), so it is left unchanged here.
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.to(DEVICE)
wandb.watch(model, log_freq=LOG_INTERVAL)

# optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0005, nesterov=True)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-7)
# scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, threshold=0.001)
# NOTE(review): train_epoch calls scheduler.step(fractional_epoch) every batch. For
# CosineAnnealingLR this uses the closed-form cosine. The LR therefore reaches its minimum at
# epoch T_max and then rises again (period 2 * T_max = 20 epochs), rather than decaying once
# over EPOCHS. Confirm this is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=10)

# torch.save does not create directories; make sure the checkpoint folder exists.
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)

for epoch in range(EPOCHS):
    train_loss = train_epoch(epoch, train_dataloader, model, optimizer, scheduler)
    val_loss = evaluate(val_dataloader, model)
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}")
    wandb.log({ "lr": optimizer.param_groups[0]['lr'] })
    # Checkpoint every 5 epochs, and always after the final epoch so the end of training is kept.
    if epoch % 5 == 0 or epoch == EPOCHS - 1:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss
        }, os.path.join(CHECKPOINT_FOLDER, f'cifar_epoch{epoch}.pt'))

In [None]:
# model = ImageClassifier()
# Rebuild the architecture without downloading weights, then restore the fine-tuned parameters.
model = vit_b_16(weights=None)
# NOTE(review): the training loop above saves on `epoch % 5 == 0`, so cifar_epoch3.pt is not
# produced by the current save schedule -- confirm which checkpoint should be scored.
checkpoint_path = os.path.join(CHECKPOINT_FOLDER, 'cifar_epoch3.pt')
# map_location=DEVICE (not a hard-coded 'mps') so the checkpoint loads on whatever device is configured.
model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE)['model_state_dict'])
model.to(DEVICE)

test_loss, test_acc = score(test_dataloader, model)
print(test_loss, test_acc)