# Training CIFAR-10 in 75 seconds on a single GPU

In [None]:
# Run the following on the command line to create a conda environment with FFCV installed:

# conda create -n ffcv python=3.9 cupy pkg-config compilers libjpeg-turbo \
# opencv pytorch torchvision cudatoolkit=11.3 numba -c conda-forge -c pytorch \
# && conda activate ffcv && conda update ffmpeg && pip install ffcv

In [None]:
from typing import List

import torch
from torch import nn
import torchvision

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter

### Step 1: Create an FFCV-compatible CIFAR-10 dataset

In [None]:
# Root directory passed to torchvision for the raw CIFAR-10 files; the
# generated .beton files are written under the same directory.
cifar_dir = '/home/jovyan/work/DataLocal-w/ffcv-cifar10/'

# One torchvision dataset per split, keyed by split name ('train' first,
# then 'test'); both are requested with download=True.
datasets = {
    split: torchvision.datasets.CIFAR10(cifar_dir, train=is_train, download=True)
    for split, is_train in (('train', True), ('test', False))
}


In [None]:
# Serialize each split into its own FFCV .beton file, storing every sample
# as an RGB image field plus an integer label field.
for name, ds in datasets.items():
    beton_path = f'{cifar_dir}/cifar_{name}.beton'
    field_spec = {'image': RGBImageField(), 'label': IntField()}
    writer = DatasetWriter(beton_path, field_spec)
    writer.from_indexed_dataset(ds)

### Step 2: Create data loaders

In [None]:
# Per-channel statistics are expressed in the uint8 pixel range [0, 255].
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
# NOTE(review): 'cuda:1' selects the GPU with index 1; use 'cuda:0' on a
# machine that only has one GPU.
device = 'cuda:1'
num_workers = 128

BATCH_SIZE = 400

# Build one FFCV loader per split, keyed by split name.
loaders = {}
for name in ['train', 'test']:
    is_train = name == 'train'

    # Labels: decode, convert to a tensor, move to the device, then squeeze.
    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
    image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]

    # Data augmentation is applied to the training split only.
    if is_train:
        image_pipeline.extend([
            RandomHorizontalFlip(),
            RandomTranslate(padding=2),
            # Cutout runs before normalization, so the fill value is the
            # (integer) dataset mean in uint8 pixel units.
            Cutout(8, tuple(map(int, CIFAR_MEAN))),
        ])

    # Shared tail: tensor conversion, device transfer, layout change,
    # half-precision cast and per-channel normalization.
    image_pipeline.extend([
        ToTensor(),
        ToDevice(device, non_blocking=True),
        ToTorchImage(),
        Convert(torch.float16),
        torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    # Only training needs shuffling. The test split is read sequentially so
    # evaluation is deterministic and does not pay for a pointless shuffle.
    ordering = OrderOption.RANDOM if is_train else OrderOption.SEQUENTIAL

    # Create loaders; the last partial batch is dropped for training only,
    # so every test sample is still evaluated.
    loaders[name] = Loader(f'{cifar_dir}/cifar_{name}.beton',
                           batch_size=BATCH_SIZE,
                           num_workers=num_workers,
                           order=ordering,
                           drop_last=is_train,
                           pipelines={'image': image_pipeline,
                                      'label': label_pipeline})

### Step 3: Set up model architecture and optimization parameters

In [None]:
class Mul(nn.Module):
    """Multiply the input by the value supplied at construction time."""

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        """Return ``x`` multiplied by ``self.weight``."""
        scaled = x * self.weight
        return scaled

class Flatten(nn.Module):
    """Reshape the input to 2-D, keeping the first dimension."""

    def forward(self, x):
        """Return a view of ``x`` with shape ``(x.size(0), -1)``."""
        leading = x.size(0)
        return x.view(leading, -1)

class Residual(nn.Module):
    """Wrap a module and add its output back onto its input."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """Return ``x + self.module(x)``."""
        branch = self.module(x)
        return x + branch

def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
    """Return an ``nn.Sequential`` of Conv2d -> BatchNorm2d -> ReLU.

    The convolution is created without a bias term and the ReLU operates
    in place. All arguments are forwarded to ``nn.Conv2d``;
    ``channels_out`` also sets the number of BatchNorm features.
    """
    conv = nn.Conv2d(
        channels_in,
        channels_out,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,
    )
    norm = nn.BatchNorm2d(channels_out)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)

NUM_CLASSES = 10

# Convolutional classifier with two residual stages. 3x3 / stride-1 /
# padding-1 layers rely on conv_bn's defaults; only deviations are spelled out.
model = nn.Sequential(
    # Stem: RGB input -> 64 channels.
    conv_bn(3, 64),
    # Strided 5x5 convolution widens to 128 channels.
    conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
    Residual(nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))),
    conv_bn(128, 256),
    nn.MaxPool2d(2),
    Residual(nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))),
    # Unpadded 3x3 convolution back down to 128 channels.
    conv_bn(256, 128, padding=0),
    # Global max pool to 1x1, then flatten to (batch, 128) for the classifier.
    nn.AdaptiveMaxPool2d((1, 1)),
    Flatten(),
    nn.Linear(128, NUM_CLASSES, bias=False),
    # Scale the logits by a constant factor.
    Mul(0.2),
).to(device, memory_format=torch.channels_last)

### Step 4: Train and evaluate the model

In [None]:
import numpy as np
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler

EPOCHS = 10

opt = SGD(model.parameters(), lr=.5, momentum=0.9, weight_decay=5e-4)

# Optimizer steps per epoch, computed from a training-set size of 50000.
iters_per_epoch = 50000 // BATCH_SIZE

# Piecewise-linear multiplier on the base learning rate, indexed by step:
# rises from 0 to 1 over the first 5 epochs, falls back to 0 at the end of
# epoch EPOCHS, and is held at 0 for the remaining table entries.
knot_steps = [0, 5 * iters_per_epoch, EPOCHS * iters_per_epoch]
knot_values = [0, 1, 0]
step_index = np.arange((EPOCHS + 1) * iters_per_epoch)
lr_schedule = np.interp(step_index, knot_steps, knot_values)

# LambdaLR looks the multiplier up directly in the precomputed table.
scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
scaler = GradScaler()
loss_fn = CrossEntropyLoss(label_smoothing=0.1)

In [None]:
from fastprogress import progress_bar
from tqdm import tqdm
import time

# Wall-clock reference for the elapsed-time figure printed after each epoch.
start = time.time()

for ep in progress_bar(range(EPOCHS)):
    # ---- Training pass: one optimizer step and one scheduler step per batch ----
    model.train()
    for ims, labs in loaders['train']:
        opt.zero_grad(set_to_none=True)
        # Forward pass and loss run under mixed precision.
        with autocast():
            out = model(ims)
            loss = loss_fn(out, labs)

        # GradScaler scales the loss before backward and drives the optimizer step.
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        scheduler.step()

    # ---- Evaluation pass over the test loader ----
    model.eval()
    with torch.no_grad():
        total_correct, total_num = 0., 0.
        for ims, labs in loaders['test']:
            with autocast():
                # Test-time augmentation: average the logits of each batch
                # and its horizontal mirror. The batches feed Conv2d layers,
                # so they are laid out (N, C, H, W) and width is the last
                # axis. torch.fliplr flips dim 1, which would reverse the
                # channel order instead of mirroring the image.
                mirrored = torch.flip(ims, dims=[-1])
                out = (model(ims) + model(mirrored)) / 2.
            # Accuracy bookkeeping does not need the autocast context.
            total_correct += out.argmax(1).eq(labs).sum().cpu().item()
            total_num += ims.shape[0]

        print(f'Epoch {ep} – Accuracy: {total_correct / total_num * 100:.1f}%. Time elapsed: {round(time.time() - start,4)} sec')
        