In [26]:
import torch
from pathlib import Path
import os
import numpy as np
import torch.nn as nn
from datetime import datetime

import torch.nn.functional as F

from fp16util import *
from resnet import *

In [27]:
from fastai.conv_learner import Learner, TrainingPhase, ModelData, accuracy, DecayType
from functools import partial
from PIL import Image

In [28]:
import argparse, os, shutil, time, warnings

In [29]:
# from fastai.models.cifar10.wideresnet import wrn_22_cat, wrn_22, WideResNetConcat
torch.backends.cudnn.benchmark = True
PATH = Path.home()/'data/cifar10/'
os.makedirs(PATH,exist_ok=True)

In [30]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
def pad(img, p=4, padding_mode='reflect'):
    return Image.fromarray(np.pad(np.asarray(img), ((p, p), (p, p), (0, 0)), padding_mode))

In [31]:
workers = 7

## Model

In [32]:
# --
# Model definition
# Derived from models in `https://github.com/kuangliu/pytorch-cifar`

class PreActBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        
        self.bn1   = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            )
            
    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        return out + shortcut


class ResNet18(nn.Module):
    def __init__(self, num_blocks=[2, 2, 2, 2], num_classes=10):
        super().__init__()
        
        self.in_channels = 64
        
        self.prep = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        
        self.layers = nn.Sequential(
            self._make_layer(64, 64, num_blocks[0], stride=1),
            self._make_layer(64, 128, num_blocks[1], stride=2),
            self._make_layer(128, 256, num_blocks[2], stride=2),
            self._make_layer(256, 256, num_blocks[3], stride=2),
        )
        
        self.classifier = nn.Linear(512, num_classes)
        
    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        
        strides = [stride] + [1] * (num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(PreActBlock(in_channels=in_channels, out_channels=out_channels, stride=stride))
            in_channels = out_channels
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.prep(x)
        
        x = self.layers(x)
        
        x_avg = F.adaptive_avg_pool2d(x, (1, 1))
        x_avg = x_avg.view(x_avg.size(0), -1)
        
        x_max = F.adaptive_max_pool2d(x, (1, 1))
        x_max = x_max.view(x_max.size(0), -1)
        
        x = torch.cat([x_avg, x_max], dim=-1)
        
        x = self.classifier(x)
        
        return x

In [42]:
from autoaugment import CIFAR10Policy

In [45]:

def torch_loader(data_path, size, bs, val_bs=None):

    val_bs = val_bs or bs
    # Data loading code
    tfms = [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.24703,0.24349,0.26159))]

    train_tfms = transforms.Compose([
        pad, # TODO: use `padding` rather than assuming 4
        transforms.RandomCrop(size),
        transforms.RandomHorizontalFlip(), CIFAR10Policy(), 
    ] + tfms)
    val_tfms = transforms.Compose(tfms)

    train_dataset = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_tfms)
    val_dataset  = datasets.CIFAR10(root=data_path, train=False, download=True, transform=val_tfms)

    train_loader = DataLoader(
        train_dataset, batch_size=bs, shuffle=True,
        num_workers=workers, pin_memory=True)

    val_loader = DataLoader(
        val_dataset, batch_size=val_bs, shuffle=False,
        num_workers=workers, pin_memory=True)
    
    train_loader = DataPrefetcher(train_loader)
    val_loader = DataPrefetcher(val_loader)
    
    data = ModelData(data_path, train_loader, val_loader)
    data.sz = size
    return data

# Seems to speed up training by ~2%
class DataPrefetcher():
    def __init__(self, loader, stop_after=None):
        self.loader = loader
        self.dataset = loader.dataset
        self.stream = torch.cuda.Stream()
        self.stop_after = stop_after
        self.next_input = None
        self.next_target = None

    def __len__(self):
        return len(self.loader)

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(async=True)
            self.next_target = self.next_target.cuda(async=True)

    def __iter__(self):
        count = 0
        self.loaditer = iter(self.loader)
        self.preload()
        while self.next_input is not None:
            torch.cuda.current_stream().wait_stream(self.stream)
            input = self.next_input
            target = self.next_target
            self.preload()
            count += 1
            yield input, target
            if type(self.stop_after) is int and (count > self.stop_after):
                break

In [46]:
model = ResNet18()
model = model.cuda()

model = network_to_half(model)

# AS: todo: don't copy over weights as it seems to help performance

wd=5e-4
lr=1e-1
momentum = 0.9
# learn.clip = 1e-1
bs = 256
lrs = (0, 2e-1, 1e-2, 0)
sz=32


data = torch_loader(PATH, sz, bs, bs*2)
    
learn = Learner.from_model_data(model, data)
# learn.half()
learn.crit = F.cross_entropy
learn.metrics = [accuracy]
learn.opt_fn = partial(torch.optim.SGD, nesterov=True, momentum=0.9)
def_phase = {'opt_fn':learn.opt_fn, 'wds':wd, 'momentum':0.9}

phases = [
    TrainingPhase(**def_phase, epochs=15, lr=lrs[:2], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=15, lr=lrs[1:3], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=5, lr=lrs[-2:], lr_decay=DecayType.LINEAR),
]

learn.fit_opt_sched(phases)

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, description='Epoch', max=35), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.762563   1.453778   0.466     
    1      1.273694   1.134766   0.591                      
    2      1.01924    1.052695   0.6355                     
    3      0.888116   1.402075   0.5848                      
    4      0.830706   1.238211   0.6275                      
    5      0.772058   0.711582   0.7628                      
    6      0.708016   0.863311   0.7174                      
    7      0.674272   0.820344   0.7429                      
    8      0.666573   0.580635   0.8102                      
    9      0.645891   0.761352   0.76                        
    10     0.628814   0.537986   0.81                        
    11     0.650979   0.78332    0.7417                      
    12     0.617117   0.575993   0.808                       
    13     0.624168   0.729434   0.7594                      
    14     0.625508   1.409106   0.6228                      
    15     0.613905   0.78321

[0.179841015625, 0.9385000004768371]

### BN0

In [38]:
from torch.nn.parameter import Parameter
def init_dist_weights(model):
    # https://arxiv.org/pdf/1706.02677.pdf
    # https://github.com/pytorch/examples/pull/262
    for m in model.modules():
        if isinstance(m, BasicBlock): m.bn2.weight = Parameter(torch.zeros_like(m.bn2.weight))
        if isinstance(m, Bottleneck): m.bn3.weight = Parameter(torch.zeros_like(m.bn3.weight))
        if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)


In [39]:
model = ResNet18()
model = model.cuda()

init_dist_weights(model)

# model = network_to_half(model)

# AS: todo: don't copy over weights as it seems to help performance

wd=5e-4
lr=1e-1
momentum = 0.9
# learn.clip = 1e-1
bs = 256
lrs = (0, 2e-1, 1e-2, 0)
sz=32


data = torch_loader(PATH, sz, bs, bs*2)
    
learn = Learner.from_model_data(model, data)

Files already downloaded and verified
Files already downloaded and verified


In [41]:
learn.lr_find2(wds=5e-4)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

  0%|          | 0/196 [00:00<?, ?it/s]


RuntimeError: input and target shapes do not match: input [256 x 10], target [256] at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/THCUNN/generic/MSECriterion.cu:15

In [None]:
# learn.half()
learn.crit = F.cross_entropy
learn.metrics = [accuracy]
learn.opt_fn = partial(torch.optim.SGD, nesterov=True, momentum=0.9)
def_phase = {'opt_fn':learn.opt_fn, 'wds':wd, 'momentum':0.9}

phases = [
    TrainingPhase(**def_phase, epochs=15, lr=lrs[:2], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=15, lr=lrs[1:3], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=5, lr=lrs[-2:], lr_decay=DecayType.LINEAR),
]

learn.fit_opt_sched(phases)