## Train a Network

We are now going to train a network using the complete dataset and data loader pipeline that we have covered so far.
This will run on CIFAR-10, which is harder to classify correctly than MNIST.
The images are still very small, so the download stays fairly small too.

In [None]:
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim

import numpy as np
from torchvision import datasets, transforms

from tqdm import tqdm

import PIL

In [None]:
# Detect a usable GPU once; the rest of the notebook branches on this flag.
CUDA_AVAILABLE = torch.cuda.is_available()

# Training on the CPU is very slow, so point the user at the GPU runtime setting.
if not CUDA_AVAILABLE:
    for hint in (
        "If you are running this on Google Colab then",
        "Menu -> Runtime -> Change runtime type -> Hardware Accelerator -> GPU",
        "Then try this again...",
    ):
        print(hint)

In [None]:
def to_image(image: torch.Tensor) -> "PIL.Image.Image":
    """Convert an (already normalized) RGB image tensor into a PIL image for display.

    The tensor is min-max rescaled into ``[0, 1]``, which reverses the dataset
    normalization closely enough for viewing purposes.

    Args:
        image: Image tensor in the ``C x H x W`` layout expected by
            ``to_pil_image``. It is NOT modified; a float copy is rescaled.

    Returns:
        The image as an RGB ``PIL.Image.Image``.
    """
    # Rescale a detached float copy so the caller's tensor (for example, an
    # element of a batch still in use) is not mutated in place.
    rescaled = image.detach().to(dtype=torch.float32, copy=True)
    rescaled -= rescaled.min()
    value_range = rescaled.max()
    # A constant image has zero range; dividing by it would fill the image with NaNs.
    if value_range > 0:
        rescaled /= value_range
    return transforms.functional.to_pil_image(rescaled.cpu(), 'RGB')

---

Because the images are so very small (32x32) we can't use any augmentation that would rotate or shift the image, as that could well make the image impossible to recognize.
We can use color jittering as that alters the entire image consistently, and we can flip horizontally.

It's generally a bad idea to flip photos vertically, as you don't usually take a photo upside down.
There are certain situations where that could be appropriate; satellite images and images of individual cells are good examples.
Good augmentation relies on your knowledge of the problem space.

In [None]:
# 341 mb dataset: the CIFAR-10 training split, fetched into ./data if it is not there yet.
train_ds = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transforms.Compose([
        # RandomApply makes ONE coin flip (p=0.5) for the whole list it wraps:
        # either every transform in the list runs, or none of them do.
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)],
            p=0.5,
        ),

        # Mirror the image left-to-right half of the time
        # (effectively RandomApply + horizontal flip in a single step).
        transforms.RandomHorizontalFlip(p=0.5),

        # The network consumes tensors, so turn the PIL image into one.
        transforms.ToTensor(),

        # Normalizing is well worth doing: with these per-channel statistics each
        # channel ends up with roughly zero mean and unit standard deviation.
        # For well-known datasets the values can be looked up; otherwise compute
        # them from the training data as shown earlier.
        transforms.Normalize(
            mean=(0.49139968, 0.48215841, 0.44653091),
            std=(0.24703223, 0.24348513, 0.26158784),
        ),
    ]),
)

In [None]:
# Batches of 128 training images, reshuffled every epoch and loaded by 4 worker processes.
train_dl = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Validation loader: no augmentation, only the tensor conversion and the same
# normalization as the training set, so both splits share one input distribution.
valid_dl = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        'data',
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # Must apply the same normalization as the training set!
            transforms.Normalize(mean=(0.49139968, 0.48215841, 0.44653091), std=(0.24703223, 0.24348513, 0.26158784)),
        ]),
    ),
    batch_size=128,
    # Evaluation gains nothing from shuffling. A fixed order keeps validation
    # batches identical between runs, so metrics are reproducible even if the
    # model is (mistakenly) evaluated with BatchNorm in training mode.
    shuffle=False,
    num_workers=4,
)

In [None]:
# Same resnet model as before, this shows how to use torch.hub.
# The repository and entry-point name are passed positionally. Later PyTorch
# releases renamed torch.hub.load's first parameter from ``github`` to
# ``repo_or_dir``, so the ``github=`` keyword form raises a TypeError there.
# NOTE(review): torchvision's pretrained resnet18 is an ImageNet model with a
# 1000-way output layer, while CIFAR-10 only uses labels 0-9. Consider
# replacing ``model.fc`` with a 10-way ``nn.Linear`` (confirm later cells first).
model = torch.hub.load(
    'pytorch/vision:v0.6.0',
    'resnet18',
    pretrained=True,
)

In [None]:
# Move the model's parameters and buffers to the GPU when one is present; otherwise keep it on the CPU.
model = model.to('cuda') if CUDA_AVAILABLE else model

In [None]:
def _train_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    desc: str,
) -> float:
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    # Training mode: layers such as BatchNorm use batch statistics and update
    # their running averages.
    model.train()
    total_loss, batches = 0., 0

    for inputs, targets in tqdm(loader, desc=desc):
        if CUDA_AVAILABLE:
            inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        loss_value = criterion(model(inputs), targets)
        loss_value.backward()
        optimizer.step()

        total_loss += loss_value.item()
        batches += 1

    # An empty loader has no meaningful mean; report NaN instead of dividing by zero.
    return total_loss / batches if batches else float('nan')


def _evaluate_loss(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    desc: str,
) -> float:
    """Return the mean per-batch loss of ``model`` over ``loader`` without training it."""
    # Evaluation mode: BatchNorm uses its running statistics instead of updating
    # them from validation data.
    model.eval()
    total_loss, batches = 0., 0

    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc=desc):
            if CUDA_AVAILABLE:
                inputs, targets = inputs.cuda(), targets.cuda()

            total_loss += criterion(model(inputs), targets).item()
            batches += 1

    return total_loss / batches if batches else float('nan')


def train(
    model: nn.Module,
    train: torch.utils.data.DataLoader,
    valid: torch.utils.data.DataLoader,
    epochs: int = 4,
    lr: float = 0.001
) -> None:
    """Train ``model`` with AdamW and cross-entropy loss, printing losses each epoch.

    Each epoch makes one optimisation pass over ``train`` followed by one
    gradient-free pass over ``valid``. The mean per-batch loss of each pass is
    printed. Both totals start from zero every epoch.

    Args:
        model: Network to train; its parameters are updated in place.
        train: Loader yielding ``(inputs, targets)`` training batches.
        valid: Loader yielding ``(inputs, targets)`` validation batches.
        epochs: Number of passes over the training data.
        lr: Learning rate for AdamW.

    Note:
        When ``epochs > 0`` the model is left in evaluation mode on return.
    """
    optimizer = optim.AdamW(params=model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        train_loss = _train_epoch(model, train, optimizer, criterion, desc=f"train {epoch}")
        valid_loss = _evaluate_loss(model, valid, criterion, desc=f"valid {epoch}")

        # remember tensorboardx makes pretty graphs
        print(f"\rtrain loss: {train_loss:.2e}")
        print(f"valid loss: {valid_loss:.2e}")

In [None]:
train(model=model, train=train_dl, valid=valid_dl)

In [None]:
def score(
    model: nn.Module,
    valid: torch.utils.data.DataLoader,
) -> float:
    """Return the classification accuracy of ``model`` over ``valid``.

    The prediction for each sample is the arg-max of the model output along
    dim 1. It counts as correct when it equals the target label.

    Args:
        model: Classifier whose output has class scores along dim 1.
        valid: Loader yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of samples classified correctly, in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``valid`` yields no samples.
    """
    correct, total = 0, 0

    # Score in evaluation mode so BatchNorm uses (and does not modify) its
    # running statistics, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in tqdm(valid, desc="validation"):
                if CUDA_AVAILABLE:
                    inputs, targets = inputs.cuda(), targets.cuda()

                predictions = model(inputs).argmax(dim=1)
                correct += torch.eq(targets, predictions).sum().item()
                total += inputs.shape[0]
    finally:
        model.train(was_training)

    return correct / total

In [None]:
score(model=model, valid=valid_dl)

This isn't a great result.
The CIFAR-10 results listed on Wikipedia suggest this would have been state of the art around 2010.

The ResNet-18 architecture is fairly basic, though.
What we need to do now is improve it!