In [1]:
#IMPORTS

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as tt
import matplotlib.pyplot as plt

from fastai.vision.all import *


In [2]:
#DATASET + TRANSFORMS

# Per-channel normalisation statistics: first tuple = means, second = stds.
# NOTE(review): the std triple (0.2023, 0.1994, 0.2010) is a commonly quoted
# value but differs from the std taken over all CIFAR-10 training pixels
# (roughly 0.247, 0.243, 0.262); confirm which convention is intended.
stats = (
    (0.4914, 0.4822, 0.4465),  # mean
    (0.2023, 0.1994, 0.2010),  # std
)


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize(stats) steps, shared by both splits."""
    return [tt.ToTensor(), tt.Normalize(*stats)]


# Training-time augmentation: reflect-padded random 32x32 crop + horizontal flip,
# followed by tensor conversion and normalisation.
train_tfms = tt.Compose(
    [tt.RandomCrop(32, padding=4, padding_mode='reflect'), tt.RandomHorizontalFlip()]
    + _tensor_and_normalize()
)
# Validation: no augmentation, only tensor conversion and normalisation.
valid_tfms = tt.Compose(_tensor_and_normalize())


def _cifar10(train, transform):
    """Build one CIFAR-10 split rooted at ./data (download=True fetches it if absent)."""
    return torchvision.datasets.CIFAR10(
        root="data", train=train, download=True, transform=transform
    )


trainset = _cifar10(True, train_tfms)
validset = _cifar10(False, valid_tfms)

# Batch size 256 with two worker processes; `.cuda()` sets the loaders' target device.
dls = DataLoaders.from_dsets(trainset, validset, bs=256, num_workers=2).cuda()


100%|██████████| 170M/170M [00:07<00:00, 24.2MB/s] 
  entry = pickle.load(f, encoding="latin1")


In [4]:
#MODEL

def conv2d(ni, nf, stride=1, ks=3):
    """Bias-free square convolution padded by ``ks // 2``.

    Args:
        ni: input channels.
        nf: output channels.
        stride: convolution stride (default 1).
        ks: kernel size (default 3); odd sizes keep the spatial size at stride 1.

    Returns:
        An ``nn.Conv2d`` without a bias term (a BatchNorm is expected to
        provide the shift elsewhere in this network).
    """
    return nn.Conv2d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=stride,
        padding=ks // 2,
        bias=False,
    )

def batchnrelu(ni, nf):
    """Pre-activation unit: BatchNorm(ni) -> in-place ReLU -> 3x3 stride-1 conv(ni -> nf).

    Returns the three layers wrapped in a single ``nn.Sequential``.
    """
    steps = [
        nn.BatchNorm2d(ni),
        nn.ReLU(inplace=True),
        conv2d(ni, nf),
    ]
    return nn.Sequential(*steps)

class ResidualBlock(nn.Module):
    """Pre-activation residual block.

    Main branch: BN(ni) -> ReLU -> conv(ni -> nf, ``stride``) -> BN -> ReLU -> conv(nf -> nf),
    scaled by 0.2 before being added to the shortcut. The shortcut is an
    identity when shape is preserved, otherwise a 1x1 strided projection.

    Args:
        ni: input channels.
        nf: output channels.
        stride: stride of the first conv (and of the projection shortcut).
    """

    def __init__(self, ni, nf, stride=1):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv2d(ni, nf, stride)
        self.conv2 = batchnrelu(nf, nf)
        # A projection is needed whenever channels or spatial size change.
        needs_projection = ni != nf or stride != 1
        self.shortcut = conv2d(ni, nf, stride, ks=1) if needs_projection else nn.Identity()

    def forward(self, x):
        pre = F.relu(self.bn(x), inplace=True)
        # NOTE(review): the shortcut (even the identity case) receives the
        # BN+ReLU'd input, not the raw `x` used by the canonical
        # pre-activation ResNet; confirm this is intended.
        skip = self.shortcut(pre)
        branch = self.conv2(self.conv1(pre))
        # 0.2 presumably damps the residual branch to stabilise training.
        return branch * 0.2 + skip

def make_group(N, ni, nf, stride):
    """Build one stage of ``N`` residual blocks as a plain list.

    Only the first block changes width (ni -> nf) and applies ``stride``;
    the remaining blocks map nf -> nf at stride 1. At least one block is
    always returned, even when ``N < 1``.
    """
    head = ResidualBlock(ni, nf, stride)
    tail = [ResidualBlock(nf, nf) for _ in range(N - 1)]
    return [head] + tail

class WideResNet(nn.Module):
    """Wide ResNet built from pre-activation ``ResidualBlock`` stages.

    Layout: 3x3 stem conv (3 -> ``nstart`` channels), then ``ngroups`` stages
    of ``N`` blocks each, with stage ``g`` having ``nstart * 2**g * k`` channels.
    Stage 0 runs at stride 1 and every later stage at stride 2. The head is
    BN -> ReLU -> global average pool -> flatten -> linear(``nclasses``).

    Input is assumed to be a batch of 3-channel NCHW images (the stem is
    hard-wired to 3 input channels); output is one logit per class.
    """

    def __init__(self, ngroups=3, N=3, nclasses=10, k=6, nstart=16):
        super().__init__()
        widths = [nstart] + [nstart * (2 ** g) * k for g in range(ngroups)]

        body = [conv2d(3, nstart)]
        for g in range(ngroups):
            first_stride = 1 if g == 0 else 2
            body.extend(make_group(N, widths[g], widths[g + 1], first_stride))

        final = widths[-1]
        head = [
            nn.BatchNorm2d(final),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(final, nclasses),
        ]

        self.net = nn.Sequential(*body, *head)

    def forward(self, x):
        return self.net(x)

# Place the network on the same device as the DataLoaders instead of
# hard-coding `.cuda()`, which raises on machines without CUDA. fastai's
# `dls.cuda()` picks its device through `default_device()`, which falls back
# to another device when CUDA is unavailable.
model = WideResNet().to(dls.device)


In [5]:
#LEARNER

# Maximum global (L2) gradient norm, enforced before each optimizer step.
GRAD_CLIP = 0.1

# fastai v2's Learner never reads a `clip` attribute, so the previous
# `learn.clip = 0.1` was silently ignored and no clipping happened.
# Clipping is installed as the `GradientClip` callback instead
# (exported by `fastai.vision.all`).
learn = Learner(
    dls,
    model,
    loss_func=F.cross_entropy,  # the model's final nn.Linear emits raw logits
    metrics=accuracy,
    cbs=GradientClip(max_norm=GRAD_CLIP),
)


In [None]:
#LEARNING RATE FINDER

# Run fastai's LR range test; the returned suggestions are the cell's output.
lr_suggestions = learn.lr_find()
lr_suggestions


In [None]:
#TRAINING

MAX_LR = 5e-3
WEIGHT_DECAY = 1e-4

# NOTE(review): these are two separate one-cycle schedules (10 epochs, then
# 30 more), so the LR warms up and anneals twice. If a single 40-epoch
# schedule was intended, replace the loop with one `fit_one_cycle(40, ...)`.
for n_epochs in (10, 30):
    learn.fit_one_cycle(n_epochs, MAX_LR, wd=WEIGHT_DECAY)
