# Image classification on CIFAR-10 with ResNet18


### Imports

In [None]:
import random
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from topography.training import train, evaluate, Writer
from topography.models import resnet

### Hyperparameters

In [1]:
# --- Reproducibility and output location ---------------------------------
seed: int = 0  # Random seed for Python, NumPy and PyTorch
root: str = './cifar10'  # Output directory (dataset in data/, logs in runs/)

# --- Optimisation schedule -----------------------------------------------
epochs: int = 400  # Number of training epochs
batch_size: int = 256  # Mini-batch size used by both the train and test loaders
lr: float = 0.1  # Base (initial) SGD learning rate
# NOTE(review): 5e-3 is 10x the 5e-4 commonly used for ResNet on CIFAR-10;
# presumably deliberate given the momentum-free SGD setup — confirm.
weight_decay: float = 5e-3  # L2 weight decay passed to the optimizer

### Set random seed

In [None]:
# Seed every RNG this notebook relies on, in a fixed order:
# Python's `random`, NumPy's global RNG, PyTorch's CPU RNG and all CUDA RNGs.
for _seed_fn in (random.seed, np.random.seed,
                 torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn  # keep the notebook namespace free of the loop variable

# Prefer repeatable cuDNN kernels over autotuned (potentially faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Data loading

In [None]:
# Per-channel normalisation statistics, defined once so the train and test
# pipelines cannot drift apart.
# NOTE(review): this std triple is the widely copied CIFAR-10 value; the
# per-pixel std over the training set is reported to be closer to
# (0.2470, 0.2435, 0.2616). Kept unchanged so results stay comparable with
# earlier runs — confirm which is intended.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training-time augmentation: random 32x32 crop of the image zero-padded by
# 2 pixels, random horizontal flip, then tensor conversion + normalisation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation: no augmentation, only the same conversion + normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Page-locked host memory only speeds up host-to-GPU copies; requesting it
# on a CPU-only machine is pointless (and warns in recent PyTorch versions).
pin_memory = torch.cuda.is_available()

train_set = torchvision.datasets.CIFAR10(
    root=f'{root}/data', train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True,
    num_workers=2, pin_memory=pin_memory)
test_set = torchvision.datasets.CIFAR10(
    root=f'{root}/data', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False,
    num_workers=2, pin_memory=pin_memory)


### Main components

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.CrossEntropyLoss()

# ResNet-18 trained from scratch, with 10 outputs (one per CIFAR-10 class),
# placed on the selected device.
model = resnet(out_features=10, pretrained=False, num_layers=18).to(device)

# Plain SGD with L2 weight decay.
# NOTE(review): no momentum is used, unlike the usual momentum=0.9 ResNet
# recipe; the larger weight_decay may compensate — confirm this is intended.
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

# Decay the LR by 10x at 50% and 75% of training. Deriving the milestones
# from `epochs` keeps the schedule meaningful if `epochs` is changed; for
# epochs=400 this is exactly [200, 300]. The scheduler is expected to be
# stepped once per epoch.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[epochs // 2, 3 * epochs // 4], gamma=0.1)

### Training loop

In [None]:
# Train for `epochs` epochs, evaluating on the test split after each one.
# try/finally guarantees the writer is closed even if training raises or
# the cell is interrupted (e.g. KeyboardInterrupt in the notebook).
writer = Writer(f'{root}/runs')
try:
    for epoch in range(1, epochs + 1):
        train(model, train_loader, optimizer, criterion, device, writer, epoch)
        evaluate(model, test_loader, criterion, device, writer, 'test', epoch)
        # NOTE(review): presumably checkpoints keyed on the 'acc' metric; the
        # scheduler state is not passed, so confirm resuming restores the LR.
        writer.save(model, optimizer, 'acc')
        writer.step()
        # Step the epoch-based LR schedule after this epoch's train() call,
        # which presumably performs the optimizer updates.
        scheduler.step()
finally:
    writer.close()