In [None]:
# imports
import os
import torch, torchvision
import numpy as np
from importlib import reload
import resnets
reload(resnets)
from resnets import train, evaluate, ResNet18

In [None]:
# Batch size - number of images within a training batch of one training iteration, e.g. 64
N_BATCH = 25

# Training epochs - number of passes through the full training dataset, e.g. 20
N_EPOCH = 100

# Learning rate - step size used to update parameters, e.g. 1e-1
LEARNING_RATE = 1e-3

# Learning rate decay - scaling factor applied to the learning rate at the end of each decay period, e.g. 0.10
LEARNING_RATE_DECAY = 0.05

# Learning rate decay period - number of epochs between successive learning-rate decays, e.g. 5
LEARNING_RATE_DECAY_PERIOD = 10

# Random seed - seeds torch (weight init, DataLoader shuffling, torchvision augmentations)
# and numpy so that re-running the notebook is repeatable.
# NOTE: GPU kernels may still introduce some nondeterminism.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# --- Training data -----------------------------------------------------------
# On-the-fly augmentation for every training image, then conversion to a tensor.
# Brightness, contrast and saturation are each jittered by a random factor drawn
# from [0.8, 1.2], and images are randomly flipped left-right.
# Reference: https://pytorch.org/docs/stable/torchvision/transforms.html
transforms_train = torchvision.transforms.Compose(
    [
        torchvision.transforms.ColorJitter(
            brightness=(0.8, 1.2),
            contrast=(0.8, 1.2),
            saturation=(0.8, 1.2),
        ),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ]
)

# CIFAR10 training split, fetched into ./data when it is not already present
cifar10_train = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True, transform=transforms_train
)

# Reshuffle every epoch, drop the trailing partial batch so each iteration sees
# exactly N_BATCH images, and decode with 2 background worker processes
dataloader_train = torch.utils.data.DataLoader(
    dataset=cifar10_train,
    batch_size=N_BATCH,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

# Human-readable names of the CIFAR10 categories, in the dataset's label-index
# order (label i -> class_names[i]); passed to the evaluation report later.
class_names = 'plane car bird cat deer dog frog horse ship truck'.split()

# Number of output classes the network predicts (10 for CIFAR10)
n_class = len(class_names)

# Input images are RGB, so the network takes 3 input channels
n_input_channels = 3

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the network and move it to `device` *before* constructing the optimizer,
# as recommended by the torch.optim docs, so the optimizer is created over the
# parameters on the device they will be trained on.
net = ResNet18(n_input_channels, n_class).to(device)

# SGD with momentum 0.9 and weight decay 5e-4; the base learning rate comes from the config cell
# https://pytorch.org/docs/stable/optim.html?#torch.optim.SGD
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)

# Put layers such as batch norm / dropout into training mode, then train.
# resnets.train returns the trained model; the learning-rate schedule
# (decay factor and period) is applied inside it.
net.train()

net = train(
    model=net,
    dataloader=dataloader_train,
    n_epochs=N_EPOCH,
    optimizer=optimizer,
    lr_decay=LEARNING_RATE_DECAY,
    lr_decay_period=LEARNING_RATE_DECAY_PERIOD,
    device=device
)

# Save weights; the directory is derived from the path so the two cannot drift apart
checkpoint_path = 'checkpoints/resnet.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(net.state_dict(), checkpoint_path)

In [None]:
# --- Test data ---------------------------------------------------------------
# No augmentation at test time: images are only converted to tensors.
transforms_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# CIFAR10 held-out split (train=False), fetched into ./data when not already present
cifar10_test = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True, transform=transforms_test
)

# Fixed order and every sample kept (no shuffling, final partial batch retained)
dataloader_test = torch.utils.data.DataLoader(
    dataset=cifar10_test,
    batch_size=N_BATCH,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

# Load model and evaluate.
# Relies on checkpoint_path, n_input_channels, n_class and class_names defined
# in the training cell above.

# Evaluate on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = ResNet18(n_input_channels, n_class)
# map_location remaps stored tensors onto `device`, so a checkpoint written from
# GPU tensors can still be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(checkpoint)
net = net.to(device)
# Switch layers such as batch norm / dropout to inference behaviour
net.eval()

accuracy = evaluate(
    model=net,
    dataloader=dataloader_test,
    class_names=class_names,
    device=device
)