In [1]:
import torch
from pathlib import Path
import os
import numpy as np
import torch.nn as nn
from datetime import datetime

import torch.nn.functional as F

from fp16util import *
from resnet import *

In [2]:
from fastai.conv_learner import Learner, TrainingPhase, ModelData, accuracy, DecayType
from functools import partial
from PIL import Image

In [3]:
import argparse, os, shutil, time, warnings

In [4]:
# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input shapes do not change from batch to batch.
torch.backends.cudnn.benchmark = True

# Root directory where torchvision downloads / caches CIFAR-10.
PATH = Path.home() / 'data' / 'cifar10'
PATH.mkdir(parents=True, exist_ok=True)

In [5]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
def pad(img, p=4, padding_mode='reflect'):
    """Pad a PIL image by ``p`` pixels on every spatial border.

    Only the height and width axes are padded. A trailing channel axis (e.g.
    RGB) is left untouched, so both multi-channel and single-channel images
    work (the previous hard-coded 3-axis ``pad_width`` failed on 2-D arrays).

    Args:
        img: PIL image (anything ``np.asarray`` turns into an HxW or HxWxC array).
        p: number of pixels added to each of the four borders.
        padding_mode: any ``mode`` accepted by ``numpy.pad``
            (e.g. 'reflect', 'constant', 'edge').

    Returns:
        A new PIL image whose width and height are each larger by ``2 * p``.
    """
    arr = np.asarray(img)
    # Pad rows and columns; leave any remaining (channel) axes unpadded.
    pad_width = ((p, p), (p, p)) + ((0, 0),) * (arr.ndim - 2)
    return Image.fromarray(np.pad(arr, pad_width, mode=padding_mode))

In [6]:
workers = 7  # DataLoader worker processes used for both train and val loaders in torch_loader

## Model

In [7]:
# --
# Model definition
# Derived from models in `https://github.com/kuangliu/pytorch-cifar`

class PreActBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice, plus a skip path.

    When the block changes resolution (``stride != 1``) or width
    (``in_channels != out_channels``), the skip path is a strided 1x1
    convolution applied to the pre-activated input. Otherwise the raw input
    is added back unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Attribute names define the state_dict keys; keep them (and their
        # creation order) stable.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        activated = F.relu(self.bn1(x))
        # nn.Module raises AttributeError for an unregistered submodule, so the
        # getattr default distinguishes "projection" from "identity" skip paths.
        projection = getattr(self, 'shortcut', None)
        residual = x if projection is None else projection(activated)
        branch = self.conv2(F.relu(self.bn2(self.conv1(activated))))
        return branch + residual


class ResNet18(nn.Module):
    """ResNet-18-style CIFAR classifier built from residual blocks.

    Layout: a 3x3 conv -> BN -> ReLU stem with 64 channels, then four stages of
    ``block`` with widths 64/128/256/256 (stages 2-4 downsample by 2). The
    result goes through global average- and max-pooling, which are concatenated
    and fed to a linear classifier.

    Args:
        block: residual block class, instantiated as
            ``block(in_channels=..., out_channels=..., stride=...)``.
            ``None`` (the default) selects ``PreActBlock``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.
    """

    def __init__(self, block=None, num_blocks=(2, 2, 2, 2), num_classes=10):
        super().__init__()

        # `_make_layer` used to look up a bare global `block`, so this argument
        # was silently ignored (NameError unless some global `block` existed).
        # Keep the chosen class on the instance instead. The default is resolved
        # here rather than in the signature so that re-running the PreActBlock
        # cell is picked up instead of a stale class captured at definition time.
        self.block = PreActBlock if block is None else block

        self.in_channels = 64

        self.prep = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.layers = nn.Sequential(
            self._make_layer(64, 64, num_blocks[0], stride=1),
            self._make_layer(64, 128, num_blocks[1], stride=2),
            self._make_layer(128, 256, num_blocks[2], stride=2),
            self._make_layer(256, 256, num_blocks[3], stride=2),
        )

        # Average- and max-pooled features of the final 256-channel stage are
        # concatenated, hence twice the channel count.
        final_channels = 256
        self.classifier = nn.Linear(2 * final_channels, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(self.block(in_channels=in_channels, out_channels=out_channels, stride=block_stride))
            # Only the first block changes width; later blocks map out -> out.
            in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to logits (N, num_classes)."""
        x = self.prep(x)

        x = self.layers(x)

        x_avg = F.adaptive_avg_pool2d(x, (1, 1))
        x_avg = x_avg.view(x_avg.size(0), -1)

        x_max = F.adaptive_max_pool2d(x, (1, 1))
        x_max = x_max.view(x_max.size(0), -1)

        x = torch.cat([x_avg, x_max], dim=-1)

        x = self.classifier(x)

        return x

In [8]:

def torch_loader(data_path, size, bs, val_bs=None, prefetcher=False, padding=4, num_workers=None):
    """Build fastai ``ModelData`` wrapping CIFAR-10 train/validation DataLoaders.

    Args:
        data_path: directory where torchvision stores / downloads CIFAR-10.
        size: side length of the random crop taken from the padded train image.
        bs: training batch size.
        val_bs: validation batch size; falls back to ``bs`` when falsy.
        prefetcher: wrap both loaders in ``DataPrefetcher`` (which creates a
            CUDA stream, so a GPU is required).
        padding: reflect-padding in pixels per side, applied before the random crop.
        num_workers: DataLoader worker processes; ``None`` uses the
            notebook-level ``workers`` setting.

    Returns:
        ``ModelData`` with an extra ``sz`` attribute set to ``size``.
    """
    val_bs = val_bs or bs
    num_workers = workers if num_workers is None else num_workers

    # Normalization constants (presumably CIFAR-10 per-channel mean / std).
    tfms = [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.24703, 0.24349, 0.26159))]

    train_tfms = transforms.Compose([
        partial(pad, p=padding),
        transforms.RandomCrop(size),
        transforms.RandomHorizontalFlip(),
    ] + tfms)
    val_tfms = transforms.Compose(tfms)

    train_dataset = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_tfms)
    val_dataset = datasets.CIFAR10(root=data_path, train=False, download=True, transform=val_tfms)

    # pin_memory lets DataPrefetcher's non-blocking host->device copies overlap compute.
    train_loader = DataLoader(
        train_dataset, batch_size=bs, shuffle=True,
        num_workers=num_workers, pin_memory=True)

    val_loader = DataLoader(
        val_dataset, batch_size=val_bs, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    if prefetcher:
        train_loader = DataPrefetcher(train_loader)
        val_loader = DataPrefetcher(val_loader)

    data = ModelData(data_path, train_loader, val_loader)
    data.sz = size
    return data

# Seems to speed up training by ~2%
class DataPrefetcher():
    """Wrap a loader so the next batch is copied to the GPU on a side CUDA stream.

    While the consumer works on batch *i*, batch *i+1* is transferred
    host->device on ``self.stream``. Non-blocking copies only overlap with
    compute when the loader returns pinned memory (``pin_memory=True``).

    Args:
        loader: iterable of ``(input, target)`` tensor pairs that exposes
            ``.dataset`` and ``len()`` (e.g. ``torch.utils.data.DataLoader``).
        stop_after: if an ``int``, yield at most this many batches per pass
            (handy for quick debugging runs); otherwise the full loader is used.
    """

    def __init__(self, loader, stop_after=None):
        self.loader = loader
        self.dataset = loader.dataset
        self.stream = torch.cuda.Stream()
        self.stop_after = stop_after
        self.next_input = None
        self.next_target = None

    def __len__(self):
        """Number of batches one iteration will yield (respects ``stop_after``)."""
        n = len(self.loader)
        if type(self.stop_after) is int:
            n = min(n, max(self.stop_after, 0))
        return n

    def preload(self):
        """Fetch the next CPU batch and start its GPU copy on the side stream.

        Sets ``next_input``/``next_target`` to ``None`` once the loader is exhausted.
        """
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # `async=` is a SyntaxError on Python 3.7+; `non_blocking=` is the
            # supported spelling of the same option.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def __iter__(self):
        count = 0
        self.loaditer = iter(self.loader)
        self.preload()
        while self.next_input is not None:
            # Check before yielding so stop_after=N yields exactly N batches
            # (the previous post-yield `count > N` check yielded N + 1).
            if type(self.stop_after) is int and count >= self.stop_after:
                break
            current = torch.cuda.current_stream()
            # Compute must not start before this batch's side-stream copy ends.
            current.wait_stream(self.stream)
            batch_input = self.next_input
            batch_target = self.next_target
            # These tensors were allocated on self.stream but are used on the
            # current stream; record that so the caching allocator does not
            # hand their memory out again while the consumer still needs it.
            batch_input.record_stream(current)
            batch_target.record_stream(current)
            self.preload()
            count += 1
            yield batch_input, batch_target

In [9]:
model = ResNet18()
model = model.cuda()

# Convert the network to fp16 with fp16util's helper.
model = network_to_half(model)

# AS: todo: don't copy over weights as it seems to help performance

# Hyper-parameters
wd = 5e-4
lr = 1e-1  # NOTE(review): not used below; the schedule is driven by `lrs`.
momentum = 0.9
bs = 256
# Schedule breakpoints. Each phase takes a (start, end) slice, presumably
# interpolated linearly by DecayType.LINEAR:
# 0 -> 0.2, then 0.2 -> 0.01, then 0.01 -> 0.
lrs = (0, 2e-1, 1e-2, 0)
sz = 32

data = torch_loader(PATH, sz, bs, bs*2)

learn = Learner.from_model_data(model, data)
# learn.half() and learn.clip (1e-1) are left disabled in this experiment.
learn.crit = F.cross_entropy
learn.metrics = [accuracy]
# Use the `momentum` constant everywhere so changing it actually takes effect.
learn.opt_fn = partial(torch.optim.SGD, nesterov=True, momentum=momentum)
def_phase = {'opt_fn': learn.opt_fn, 'wds': wd, 'momentum': momentum}

# 15 + 15 + 5 = 35 epochs in total.
phases = [
    TrainingPhase(**def_phase, epochs=15, lr=lrs[:2], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=15, lr=lrs[1:3], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=5, lr=lrs[-2:], lr_decay=DecayType.LINEAR),
]

learn.fit_opt_sched(phases)

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, description='Epoch', max=35), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.432283   1.46675    0.4937    
    1      0.952438   1.132214   0.6185                      
    2      0.7389     1.031066   0.6632                      
    3      0.63537    0.85653    0.7174                      
    4      0.564924   0.634587   0.7829                      
    5      0.506821   0.730412   0.7593                      
    6      0.471517   0.842502   0.7353                      
    7      0.434544   0.76498    0.7479                      
    8      0.415716   0.629073   0.7945                      
    9      0.403843   0.677437   0.7815                      
    10     0.396373   0.609323   0.7963                      
    11     0.384319   0.922041   0.7339                      
    12     0.382333   0.644001   0.7885                      
    13     0.387898   0.577145   0.8122                      
    14     0.385755   0.863038   0.7432                      
    15     0.364287   0.544

[0.1986947265625, 0.9416]