
## ResNet18 & CIFAR-100

Demo File to test fast training one ResNet18 Model takes on CIFAR100, on a subset of 10 classes.

Source: Most of this code is taken from FFCV CIFAR10 Example!
https://docs.ffcv.io/ffcv_examples/cifar10.html

### Downloading and storing CIFAR-100

In [1]:
from argparse import ArgumentParser
from typing import List
import time
import numpy as np
from tqdm import tqdm

import torch as ch
import torch    
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
import torchvision

from fastargs import get_current_config, Param, Section
from fastargs.decorators import param
from fastargs.validation import And, OneOf

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if not torch.cuda.is_available():
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f'Using device {device}')

Using device cuda:0


Helper function that loads the cifar100 datasets and stores it as a .beton (ffcv) file


Source: https://docs.ffcv.io/writing_datasets.html

In [3]:

def load_cifar100(train_dataset="./data/cifar_train.beton", val_dataset="./data/cifar_test.beton"):
    datasets = {
        'train': torchvision.datasets.CIFAR100('./data', train=True, download=True),
        'test': torchvision.datasets.CIFAR100('./data', train=False, download=True)
        }

    for (name, ds) in datasets.items():
        path = train_dataset if name == 'train' else val_dataset
        writer = DatasetWriter(path, {
            'image': RGBImageField(),
            'label': IntField()
        })
        writer.from_indexed_dataset(ds)

In [4]:
load_cifar100()

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 50000/50000 [00:00<00:00, 122966.22it/s]
100%|██████████| 10000/10000 [00:00<00:00, 33110.80it/s]


### Data Loader

In [5]:
from argparse import ArgumentParser
from typing import List
import time
import numpy as np
from tqdm import tqdm

import torch as ch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
import torchvision

from fastargs import get_current_config, Param, Section
from fastargs.decorators import param
from fastargs.validation import And, OneOf

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter

In [6]:
def make_dataloaders(train_dataset="./data/cifar_train.beton", val_dataset="./data/cifar_test.beton", batch_size=256, num_workers=12):
    paths = {
        'train': train_dataset,
        'test': val_dataset

    }

    start_time = time.time()
    # https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151#file-cifar100_mean_std-py
    # this source details how and what mean and std of the datasets are
    # took the values from the source above and multiplied by 255
    # not sure this properly translates to std dev of the dataset TODO: check this	
    CIFAR_MEAN = [129.310, 124.108, 112.404]
    CIFAR_STD = [68.2125, 65.4075, 70.4055]
    loaders = {}

    for name in ['train', 'test']:
        label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice(ch.device(device)), Squeeze()]
        image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]
        if name == 'train':
            image_pipeline.extend([
                RandomHorizontalFlip(),
                RandomTranslate(padding=2, fill=tuple(map(int, CIFAR_MEAN))),
                Cutout(4, tuple(map(int, CIFAR_MEAN))),
            ])
        image_pipeline.extend([
            ToTensor(),
            ToDevice(ch.device(device), non_blocking=True),
            ToTorchImage(),
            Convert(ch.float32), # TODO check what the impact for float16 is (it was the initial value, and why it crashes with float16)	
            torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
        
        ordering = OrderOption.RANDOM if name == 'train' else OrderOption.SEQUENTIAL

        loaders[name] = Loader(paths[name], batch_size=batch_size, num_workers=num_workers,
                               order=ordering, drop_last=(name == 'train'),
                               pipelines={'image': image_pipeline, 'label': label_pipeline})

    return loaders, start_time

### ResNet18

In [7]:
def generate_model(output_dim:int = 100):
    
    #ResNet general source: https://pytorch.org/vision/master/models/resnet.html
    
    model = torchvision.models.resnet18(pretrained=False)
    # make fc a sequential layer
    model.fc = ch.nn.Sequential(ch.nn.Linear(model.fc.in_features, output_dim), ch.nn.Softmax(dim=1))
    model = model.to(device=device)
    return model

In [8]:
def train(model, loaders, lr=0.1, epochs=50, momentum=0.9, weight_decay=0.0001, lr_peak_epoch=5):
    
    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loaders['train'])
    # Cyclic LR with single triangle
    lr_schedule = np.interp(np.arange((epochs+1) * iters_per_epoch),
                            [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                            [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()

    for _ in range(epochs):
        for ims, labs in tqdm(loaders['train']):
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(ims)
                loss = loss_fn(out, labs)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            scheduler.step()
            

In [9]:
def evaluate(model, loaders, lr_tta=False):
    # lr_tta: whether to use test-time augmentation by flipping images horizontally
    model.eval()
    with ch.no_grad():
        for name in ['train', 'test']:
            total_correct, total_num = 0., 0.
            for ims, labs in tqdm(loaders[name]):
                with autocast():
                    out = model(ims)
                    if lr_tta:
                        out += model(ims.flip(-1))
                    total_correct += out.argmax(1).eq(labs).sum().cpu().item()
                    total_num += ims.shape[0]
            print(f'{name} accuracy: {total_correct / total_num * 100:.1f}%')

## Running the code

In [10]:
loaders, start_time = make_dataloaders()
model = generate_model()
train(model, loaders)
print(f'Total time: {time.time() - start_time:.5f}')
evaluate(model, loaders)

100%|██████████| 195/195 [00:13<00:00, 14.66it/s]
100%|██████████| 195/195 [00:02<00:00, 68.88it/s]
100%|██████████| 195/195 [00:02<00:00, 72.98it/s]
100%|██████████| 195/195 [00:02<00:00, 67.73it/s]
100%|██████████| 195/195 [00:02<00:00, 66.83it/s]
100%|██████████| 195/195 [00:02<00:00, 65.25it/s]
100%|██████████| 195/195 [00:02<00:00, 71.55it/s]
100%|██████████| 195/195 [00:02<00:00, 71.03it/s]
100%|██████████| 195/195 [00:02<00:00, 96.30it/s] 
100%|██████████| 195/195 [00:02<00:00, 68.21it/s]
100%|██████████| 195/195 [00:02<00:00, 68.91it/s]
100%|██████████| 195/195 [00:02<00:00, 90.12it/s]
100%|██████████| 195/195 [00:02<00:00, 97.10it/s] 
100%|██████████| 195/195 [00:02<00:00, 68.95it/s]
100%|██████████| 195/195 [00:02<00:00, 72.97it/s]
100%|██████████| 195/195 [00:02<00:00, 74.84it/s]
100%|██████████| 195/195 [00:02<00:00, 92.67it/s]
100%|██████████| 195/195 [00:02<00:00, 76.57it/s]
100%|██████████| 195/195 [00:02<00:00, 72.81it/s]
100%|██████████| 195/195 [00:03<00:00, 60.68it/s

Total time: 146.40217


100%|██████████| 195/195 [00:00<00:00, 199.56it/s]


train accuracy: 20.6%


100%|██████████| 40/40 [00:00<00:00, 56.57it/s]

test accuracy: 17.8%



