In [None]:
import argparse
import os
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from fp16util import *
from resnet import *

In [None]:

def get_parser():
    """Build the command-line parser used to configure a training run.

    Returns:
        argparse.ArgumentParser: parser with one positional argument (`data`)
        and options for logging, data loading, optimisation, the LR schedule
        (`--phases`), fp16 training and distributed training.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--save-dir', type=str, default=Path.cwd(), help='Directory to save logs and models.')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    # '-b' is registered explicitly so it no longer depends on argparse's
    # prefix matching against '-b-size'; both older spellings keep working.
    parser.add_argument('-b', '-b-size', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size (default: 256)')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print every this many steps (default: 200)')
    # The schedule is kept as a string here; the training cell turns it into a
    # list of (start_lr, end_lr, num_epochs) tuples for the Scheduler.
    parser.add_argument('--phases', type=str,
                        default='[(0,2e-1,15),(2e-1,1e-2,15),(1e-2,0,5)]',
                        help='LR schedule as a list of (start_lr, end_lr, num_epochs) tuples')
    parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode.')
    parser.add_argument('--loss-scale', type=float, default=1,
                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
    parser.add_argument('--distributed', action='store_true', help='Run distributed training')
    parser.add_argument('--world-size', default=-1, type=int,
                        help='total number of processes (machines*gpus)')
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='Used for multi-process training. Can either be manually set ' +
                        'or automatically set by using \'python -m multiproc\'.')
    return parser


In [None]:
import argparse, os, shutil, time, warnings

In [28]:
# Command line emulated inside the notebook: this list is handed to
# get_parser().parse_args() in place of sys.argv.
#   positional  -> dataset directory (`data`)
#   '-b'        -> mini-batch size
#   '--phases'  -> LR schedule string, later eval()'d into
#                  (start_lr, end_lr, num_epochs) tuples for the Scheduler
# NOTE(review): the paths below are absolute and machine-specific.
args_input = [
    '/home/paperspace/data/cifar10', 
    '--save-dir', '/home/paperspace/data/cifar_training/preact_test',
    '-b', '256', 
    '--loss-scale', '512',
    '--fp16',
    '--wd', '5e-4',
    '--momentum', '0.9',
    '--phases', '[(0,2e-1,15),(2e-1,1e-2,15),(1e-2,0,5)]'
#     '--train-half' # With fp16, iterations are so fast this doesn't matter
]

In [29]:
# Parse the emulated command line into the module-level `args` namespace that
# the loaders and the train/validate loops read directly.
# (`global` at module level has no effect; `args` is already a module global.)
global args
args = get_parser().parse_args(args_input)

In [30]:
# from fastai.models.cifar10.wideresnet import wrn_22_cat, wrn_22, WideResNetConcat
# Turn on cuDNN benchmark mode (per the PyTorch docs this lets cuDNN pick conv
# algorithms by timing them; presumably chosen because input sizes are fixed).
torch.backends.cudnn.benchmark = True
# Dataset root used by the loaders below.
# NOTE(review): this is derived from the home directory and ignores `args.data`.
PATH = Path.home()/'data/cifar10/'
os.makedirs(PATH,exist_ok=True)

In [31]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
def pad(img, p=4, padding_mode='reflect'):
    """Pad an image by `p` pixels on every side of its two spatial axes.

    Args:
        img: image convertible with `np.asarray` (e.g. a PIL image), laid out
            as (height, width) or (height, width, channels).
        p: number of pixels added to each of the four borders.
        padding_mode: `mode` argument forwarded to `np.pad`.

    Returns:
        PIL.Image.Image built from the padded array.
    """
    pixels = np.asarray(img)
    # Only the first two (spatial) axes are padded; trailing axes such as the
    # channel axis get a zero pad width, so 2-D grayscale input works as well.
    widths = ((p, p), (p, p)) + ((0, 0),) * (pixels.ndim - 2)
    return Image.fromarray(np.pad(pixels, widths, padding_mode))

## Model

In [32]:
# --
# Model definition
# Derived from models in `https://github.com/kuangliu/pytorch-cifar`

class PreActBlock(nn.Module):
    """Pre-activation residual block: BN -> ReLU -> 3x3 conv, applied twice.

    When the block changes resolution (`stride != 1`) or width
    (`in_channels != out_channels`) a 1x1 convolution, stored as
    `self.shortcut`, projects the skip path; otherwise the input is added back
    unchanged and no `shortcut` attribute exists.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)

        needs_projection = (stride != 1) or (in_channels != out_channels)
        if needs_projection:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """Return `conv2(relu(bn2(conv1(relu(bn1(x)))))) + skip`."""
        activated = F.relu(self.bn1(x))

        # The projection (when present) is fed the pre-activated tensor,
        # whereas the identity skip uses the raw input.
        if hasattr(self, 'shortcut'):
            skip = self.shortcut(activated)
        else:
            skip = x

        hidden = self.conv1(activated)
        hidden = F.relu(self.bn2(hidden))
        return self.conv2(hidden) + skip


class ResNet18(nn.Module):
    """Pre-activation ResNet built from four stages of `PreActBlock`s.

    The head concatenates global average pooling and global max pooling of the
    final 256-channel feature map, which is why the classifier takes 512
    input features.
    """

    def __init__(self, num_blocks=[2, 2, 2, 2], num_classes=10):
        super().__init__()

        self.in_channels = 64

        # Stem: 3x3 conv -> BN -> ReLU on the 3-channel input.
        self.prep = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # (in_channels, out_channels, stride of the first block) per stage.
        stage_specs = [(64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 256, 2)]
        stages = []
        for idx, (c_in, c_out, first_stride) in enumerate(stage_specs):
            stages.append(self._make_layer(c_in, c_out, num_blocks[idx], stride=first_stride))
        self.layers = nn.Sequential(*stages)

        self.classifier = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack blocks; only the first one uses `stride` and `in_channels`."""
        block_strides = [stride] + [1] * (num_blocks-1)
        block_inputs = [in_channels] + [out_channels] * (len(block_strides) - 1)
        blocks = [
            PreActBlock(in_channels=c_in, out_channels=out_channels, stride=s)
            for c_in, s in zip(block_inputs, block_strides)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.layers(self.prep(x))

        pooled = (
            F.adaptive_avg_pool2d(features, (1, 1)),
            F.adaptive_max_pool2d(features, (1, 1)),
        )
        # Flatten each (N, C, 1, 1) pooling result to (N, C) before joining them.
        flattened = [p.view(p.size(0), -1) for p in pooled]

        return self.classifier(torch.cat(flattened, dim=-1))

### Torch loader

In [33]:

def fast_collate(batch):
    """Collate (image, target) samples into uint8 tensors without normalising.

    Args:
        batch: list of `(img, target)` pairs. `img` must expose `.size` as
            `(width, height)` and be convertible with `np.asarray` to an
            (H, W) or (H, W, C) array; all images are assumed to share the
            size of the first one.

    Returns:
        tuple: `(images, targets)` where `images` is a uint8 tensor of shape
        (N, 3, H, W) and `targets` an int64 tensor of shape (N,). An empty
        batch yields two empty tensors.
    """
    if not batch: return torch.tensor([]), torch.tensor([])
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Grayscale: add a channel axis; the add below broadcasts it
            # across the three output channels.
            nump_array = np.expand_dims(nump_array, axis=-1)
        # HWC -> CHW
        nump_array = np.moveaxis(nump_array, 2, 0)
        tensor[i] += torch.from_numpy(nump_array)

        # Seems to be slower for our pipeline. Need to ask Sylvain
        # tensor[i] += pil2tensor(img)

    return tensor, targets


def torch_loader(data_path, size, bs, val_bs=None):
    """Create the CIFAR-10 training and validation `DataLoader`s.

    Training images are padded, randomly cropped to `size` and randomly
    flipped; validation images get no transform. Both loaders batch with
    `fast_collate`, so they emit un-normalised uint8 tensors.

    Reads `args.distributed` and `args.workers` from the module-level `args`.

    Args:
        data_path: dataset root (downloaded there if missing).
        size: side length of the random training crop.
        bs: training batch size.
        val_bs: validation batch size; falls back to `bs` when falsy.

    Returns:
        tuple: `(train_loader, val_loader)`.
    """
    if not val_bs:
        val_bs = bs

    augmentation = transforms.Compose([
        pad, # TODO: use `padding` rather than assuming 4
        transforms.RandomCrop(size),
        transforms.RandomHorizontalFlip(),
    ])

    train_dataset = datasets.CIFAR10(root=data_path, train=True, download=True, transform=augmentation)
    val_dataset = datasets.CIFAR10(root=data_path, train=False, download=True)

    # Only the training set is sharded across processes; validation runs in full.
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    shared_kwargs = dict(num_workers=args.workers, pin_memory=True, collate_fn=fast_collate)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=(train_sampler is None),
        sampler=train_sampler, **shared_kwargs)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=val_bs, shuffle=False,
        sampler=None, **shared_kwargs)

    return train_loader, val_loader


# Seems to speed up training by ~2%
class DataPrefetcher():
    """Iterate a loader while moving and normalising batches on the GPU.

    Each `(input, target)` batch from `loader` is copied to the GPU, cast to
    float (or half when `fp16`), and normalised with per-channel mean/std.
    The statistics are scaled by 255 because the loaders in this notebook
    emit raw uint8 pixel values (see `fast_collate`).

    With `prefetch=True` the next batch is fetched and processed on a side
    CUDA stream while the caller works on the current one. The underlying
    iterator is created in `__init__`, so an instance supports a single pass.

    Args:
        loader: iterable of `(input, target)` tensor pairs with a `__len__`.
        prefetch: overlap loading of the next batch with the current step.
        fp16: cast inputs and statistics to half precision.
    """
    def __init__(self, loader, prefetch=True, fp16=False):
        self.loader = loader
        self.prefetch = prefetch
        self.mean = torch.tensor([0.4914 * 255, 0.4822 * 255, 0.4465 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.24703 * 255, 0.24349 * 255, 0.26159 * 255]).cuda().view(1,3,1,1)
        self.fp16 = fp16
        self.loaditer = iter(self.loader)
        if self.fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if self.prefetch:
            self.stream = torch.cuda.Stream()
            self.next_input = None
            self.next_target = None
            self.preload()

    def __len__(self): return len(self.loader)

    def preload(self):
        """Fetch the next batch and start moving it to the GPU on the side stream.

        Sets `next_input` / `next_target` to `None` once the loader is
        exhausted; any other loader error propagates to the caller.
        """
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.process_input(self.next_input)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def process_input(self, input, non_blocking=True):
        """Move `input` to the GPU, cast it, and normalise it in place.

        Inputs with fewer than 3 dimensions are returned un-normalised.
        """
        input = input.cuda(non_blocking=non_blocking)
        if self.fp16: input = input.half()
        else: input = input.float()
        if len(input.shape) < 3: return input
        return input.sub_(self.mean).div_(self.std)

    def __iter__(self):
        if not self.prefetch:
            for input, target in self.loaditer:
                yield self.process_input(input), target.cuda()
            return
        # `next_input is None` marks exhaustion (including an empty loader).
        while self.next_input is not None:
            # Make sure the side-stream copy of this batch has completed.
            torch.cuda.current_stream().wait_stream(self.stream)
            input = self.next_input
            target = self.next_target
            # Kick off the following batch before handing out the current one.
            self.preload()
            yield input, target

In [36]:
# orig submission params
# bs = 128
# lrs = (0, 1e-1, 5e-3, 0)

# higher batch size - able to converge around epoch 31 ~ 3:48
# bs = 256
# lrs = (0, 2e-1, 1e-2, 0)

# wd=5e-4
# # lr=1e-1
# momentum = 0.9

# Crop size handed to torch_loader (RandomCrop size for training images).
sz = 32
# Validation uses twice the training batch size.
trn_loader, val_loader = torch_loader(PATH, sz, args.batch_size, args.batch_size*2)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Scratch check: the schedule string evaluates to a list of 3-tuples.
eval('[(0,2e-1,15),(2e-1,1e-2,15),(1e-2,0,5)]')

In [None]:
# Scratch check of sum() on a list (same call shape as Scheduler.tot_epochs).
sum([1,2,3])

In [37]:

# lrs = (0, 2e-1, 1e-2, 0)
# start_lr, end_lr, num_epochs
# [(0,2e-1,15),(2e-1,1e-2,15),(1e-2,0,5)]


class Scheduler():
    """Piecewise-linear learning-rate schedule applied per batch.

    Args:
        optimizer: object exposing `param_groups`, a list of dicts with an
            `'lr'` entry (e.g. a `torch.optim` optimizer).
        phases: sequence of `(start_lr, end_lr, num_epochs)` tuples. Within a
            phase the LR moves linearly from `start_lr` to `end_lr`; phases
            run back to back.

    Epochs are 0-based (as produced by `range(tot_epochs)`); batch numbers
    passed to `update_lr` are expected to be 1-based so that the last batch of
    a phase lands exactly on `end_lr`.
    """
    def __init__(self, optimizer, phases=[(0,2e-1,15),(2e-1,1e-2,15),(1e-2,0,5)]):
        self.optimizer = optimizer
        self.current_lr = None
        self.phases = phases
        self.tot_epochs = sum([p[2] for p in phases])

    def linear_lr(self, start_lr, end_lr, epoch_curr, batch_curr, epoch_tot, batch_tot):
        """Linearly interpolate between `start_lr` and `end_lr` at the given step."""
        step_tot = epoch_tot * batch_tot
        step_curr = epoch_curr * batch_tot + batch_curr
        step_size = (end_lr - start_lr)/step_tot
        return start_lr + step_curr * step_size

    def get_current_phase(self, epoch):
        """Return `(start_lr, end_lr, num_epochs, epoch_within_phase)` for `epoch`.

        Raises:
            ValueError: if `epoch` lies beyond the last phase.
        """
        epoch_accum = 0
        for start_lr, end_lr, num_epochs in self.phases:
            # Strict comparison: with 0-based epochs a phase of N epochs owns
            # epochs [epoch_accum, epoch_accum + N). Using `<=` kept the first
            # epoch of the next phase in this one and extrapolated the LR
            # beyond `end_lr`.
            if epoch < epoch_accum+num_epochs: return start_lr, end_lr, num_epochs, epoch - epoch_accum
            epoch_accum += num_epochs
        raise ValueError('Epoch out of range')

    def get_lr(self, epoch, batch_curr, batch_tot):
        """Learning rate for batch `batch_curr` (of `batch_tot`) in `epoch`."""
        start_lr, end_lr, num_epochs, relative_epoch = self.get_current_phase(epoch)
        return self.linear_lr(start_lr, end_lr, relative_epoch, batch_curr, num_epochs, batch_tot)

    def update_lr(self, epoch, batch_num, batch_tot):
        """Compute the LR for this step and write it into every param group."""
        lr = self.get_lr(epoch, batch_num, batch_tot)
        # Only report changes on the first and last batch to keep the log short.
        if (self.current_lr != lr) and ((batch_num == 1) or (batch_num == batch_tot)):
            print(f'Changing LR from {self.current_lr} to {lr}')

        self.current_lr = lr

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
            # Momentum correction on LR changes ("Trick 4",
            # https://github.com/pytorch/examples/pull/262) is intentionally
            # not applied here.


In [38]:

def str_to_num_array(argstr, num_type=int):
    """Split a comma-separated string and convert every field with `num_type`."""
    fields = argstr.split(',')
    return list(map(num_type, fields))

# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    """Convert a one-element value to a plain Python number.

    Objects with an `item()` method return `t.item()`; anything else (e.g. a
    list holding a single tensor) is converted with `float(t[0])`.
    """
    if not hasattr(t, 'item'):
        return float(t[0])
    return t.item()

def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    """Train `model` for one epoch over `trn_loader`.

    Relies on module-level state: `args` (parsed options) and, when
    `args.fp16` is set, `model_params` / `master_params`.

    For every batch the learning rate is updated through `scheduler`, loss and
    top-1 accuracy are accumulated in running meters, and one optimizer step
    is taken. A progress line is printed and appended to
    `<args.save_dir>/full.log` every `args.print_freq` batches and on the last
    batch, by the process with `args.local_rank == 0` only.

    Args:
        trn_loader: training loader; wrapped in a `DataPrefetcher` here.
        model: network to train.
        criterion: callable `(output, target) -> loss`.
        optimizer: optimizer stepping the (master) parameters.
        scheduler: object with `update_lr(epoch, batch_num, batch_tot)`.
        epoch: 0-based epoch index, used for the schedule and for logging.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()
    trn_len = len(trn_loader)

    for i,(input,target) in enumerate(DataPrefetcher(trn_loader)):
        batch_size = input.size(0)
        batch_num = i+1

        # measure data loading time
        data_time.update(time.time() - end)
        scheduler.update_lr(epoch, batch_num, trn_len)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        if args.distributed:
            # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
            corr1 = correct(output.data, target)
            metrics = torch.tensor([batch_size, loss, corr1]).float().cuda()
            batch_total, reduced_loss, corr1 = sum_tensor(metrics)
            reduced_loss = reduced_loss/dist.get_world_size()
            prec1 = corr1*(100.0/batch_total)
        else:
            reduced_loss = loss.data
            batch_total = input.size(0)
            prec1 = accuracy(output.data, target) # measure accuracy and record loss

        losses.update(to_python_float(reduced_loss), batch_total)
        top1.update(to_python_float(prec1), batch_total)

        # compute gradient and do SGD step
        if args.fp16:
            # Loss scaling belongs to the fp16 path only: gradients are
            # divided by the same factor below before the optimizer step.
            # (Scaling in the fp32 path as well would leave gradients
            # multiplied by `loss_scale` with nothing undoing it.)
            loss = loss*args.loss_scale
            model.zero_grad()
            loss.backward()
            # These helpers are not defined in this file (presumably from
            # fp16util): grads go to the master copy, updated weights come back.
            model_grads_to_master_grads(model_params, master_params)
            for param in master_params:
                param.grad.data = param.grad.data/args.loss_scale
            optimizer.step()
            master_params_to_model_params(model_params, master_params)
            torch.cuda.synchronize()
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        should_print = (batch_num%args.print_freq == 0) or (batch_num==trn_len)
        if args.local_rank == 0 and should_print:
            log_line = ('Epoch: [{0}][{1}/{2}]\t' \
                    + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                    + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                    + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                    + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})').format(
                    epoch, batch_num, trn_len, batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1)
            print(log_line)
            with open(f'{args.save_dir}/full.log', 'a') as f:
                f.write(log_line + '\n')
    
def validate(val_loader, model, criterion, epoch, start_time):
    """Evaluate `model` on `val_loader` and log running loss / top-1 accuracy.

    Uses the module-level `args` for the print frequency, rank and log
    directory. `epoch` and `start_time` are accepted but not used in the body.
    Nothing is returned; results are printed and appended to
    `<args.save_dir>/full.log`.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    end = time.time()
    val_len = len(val_loader)

    template = ('Test: [{0}/{1}]\t' \
              + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
              + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
              + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')

    for batch_num, (input, target) in enumerate(DataPrefetcher(val_loader), start=1):
        if args.distributed and False: # (AS) Remove this later
            prec1, loss, batch_total = distributed_predict(input, target, model, criterion)
        else:
            with torch.no_grad():
                output = model(input)
                loss = criterion(output, target).data
            batch_total = input.size(0)
            prec1 = accuracy(output.data, target)

        losses.update(to_python_float(loss), batch_total)
        top1.update(to_python_float(prec1), batch_total)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Log every `print_freq` batches and always on the final batch.
        should_print = (batch_num%args.print_freq == 0) or (batch_num==val_len)
        if not (args.local_rank == 0 and should_print):
            continue

        message = template.format(batch_num, val_len, batch_time=batch_time,
                                  loss=losses, top1=top1)
        print(message)
        with open(f'{args.save_dir}/full.log', 'a') as f:
            f.write(message + '\n')


def distributed_predict(input, target, model, criterion):
    """Evaluate one batch and combine the statistics across all processes.

    A process whose batch is empty contributes zeros, so the loss is averaged
    over the processes that actually had data.

    Returns:
        tuple: `(prec1, reduced_loss, batch_total)` computed from the
        all-reduced sums.
    """
    batch_size = input.size(0)

    if batch_size:
        # compute output
        with torch.no_grad():
            # using module instead of model because DistributedDataParallel forward function has a sync point.
            # with distributed validation sampler, we don't always have data for each gpu
            assert(is_distributed_model(model))
            output = model.module(input)
            loss = criterion(output, target).data
        # measure accuracy and record loss
        valid_batches = 1
        corr1 = correct(output.data, target)
    else:
        loss = corr1 = valid_batches = 0

    local_stats = torch.tensor([batch_size, valid_batches, loss, corr1]).float().cuda()
    batch_total, valid_batches, loss_sum, corr1 = sum_tensor(local_stats)
    reduced_loss = loss_sum/valid_batches

    prec1 = corr1*(100.0/batch_total)
    return prec1, reduced_loss, batch_total

def is_distributed_model(model):
    """Return True when `model` is wrapped in `DistributedDataParallel`."""
    ddp_cls = nn.parallel.DistributedDataParallel
    # or (args.c10d and isinstance(model, distributed_c10d._DistributedDataParallelC10d))
    return isinstance(model, ddp_cls)

class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes (all reset to 0 by `reset`):
        val: most recent value passed to `update`.
        sum: weighted sum of all values seen so far.
        count: total weight seen so far.
        avg: `sum / count`, or 0 while `count` is still 0.
    """
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = self.avg = self.sum = self.count = 0
    def update(self, val, n=1):
        """Record `val` with weight `n` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on an empty meter (e.g. an empty batch) must
        # not divide by zero.
        self.avg = self.sum / self.count if self.count else 0

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy in percent, one tensor per entry of `topk`.

    The per-k hit counts come from `correct`; each is scaled by
    `100 / batch_size`, with the batch size taken from `target`.
    """
    hits_per_k = correct(output, target, topk)
    batch_size = target.size(0)
    percentages = []
    for hits in hits_per_k:
        percentages.append(hits.float().mul_(100.0 / batch_size))
    return percentages

def correct(output, target, topk=(1,)):
    """Count the samples whose target is among the top-k predictions.

    Args:
        output: score tensor of shape (batch, classes).
        target: class-index tensor of shape (batch,).
        topk: iterable of k values to evaluate.

    Returns:
        list: one tensor of shape (1,) per k, holding the number (not the
        percentage) of correctly classified samples.
    """
    maxk = max(topk)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1).expand_as(pred))

    # `hits` is a transposed (non-contiguous) tensor, so `.view(-1)` raises
    # for k > 1; `.reshape(-1)` copies when it has to.
    return [hits[:k].reshape(-1).sum(0, keepdim=True) for k in topk]


def sum_tensor(tensor):
    """Return a copy of `tensor` summed element-wise across all processes.

    The input is cloned first, so the caller's tensor is left untouched.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.reduce_op.SUM)
    return summed

def reduce_tensor(tensor):
    """Return a copy of `tensor` averaged across processes.

    The all-reduced sum is divided in place by `args.world_size` (note that
    this option defaults to -1 when it is not given on the command line).
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.reduce_op.SUM)
    averaged.div_(args.world_size)
    return averaged


In [41]:
# Optional multi-process setup: pin this process to its GPU and join the
# torch.distributed process group described by the --dist-* options.
if args.distributed:
    print('Distributed: initializing process group')
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size)
    assert(args.world_size == dist.get_world_size())
    print("Distributed: success (%d/%d)"%(args.local_rank, args.world_size))

model = ResNet18()
model = model.cuda()

# AS: todo: don't copy over weights as it seems to help performance

# NOTE(review): because of the `elif`, the DistributedDataParallel wrapper is
# skipped whenever --fp16 is set, even if --distributed was requested --
# confirm that this is intended.
# `network_to_half` is not defined in this file (presumably from fp16util).
if args.fp16: model = network_to_half(model)
elif args.distributed: model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)


# `global` at module level is a no-op; train() reads these two names as
# module globals. `model_params` is only bound in fp16 mode, which is also the
# only mode in which train() uses it.
global model_params, master_params
if args.fp16: model_params, master_params = prep_param_lists(model)
else: master_params = list(model.parameters())

# define loss function (criterion) and optimizer
# criterion = nn.CrossEntropyLoss().cuda()
criterion = F.cross_entropy
# The optimizer starts with lr=0; the Scheduler overwrites it on every batch.
optimizer = torch.optim.SGD(master_params, 0, nesterov=True, momentum=args.momentum, weight_decay=args.weight_decay)
# NOTE(review): eval() executes whatever expression is passed via --phases;
# ast.literal_eval would parse the same list of tuples without that risk.
scheduler = Scheduler(optimizer, phases=eval(args.phases))


print(args)
print("~~epoch\thours\ttop1Accuracy\n")

start_time = datetime.now() # Loading start to after everything is loaded
print("Begin training")
# One training pass followed by one validation pass per scheduled epoch.
for epoch in range(scheduler.tot_epochs):
    train(trn_loader, model, criterion, optimizer, scheduler, epoch)
    validate(val_loader, model, criterion, epoch, start_time)


Namespace(batch_size=256, data='/home/paperspace/data/cifar10', dist_backend='nccl', dist_url='env://', distributed=False, fp16=True, local_rank=0, loss_scale=512.0, momentum=0.9, phases='[(0,2e-1,15),(2e-1,1e-2,15),(1e-2,0,5)]', print_freq=200, save_dir='/home/paperspace/data/cifar_training/preact_test', weight_decay=0.0005, workers=8, world_size=-1)
~~epoch	hours	top1Accuracy

Begin training
Changing LR from None to 0.0
Epoch: [0][196/196]	Time 0.022 (0.036)	Data 0.001 (0.003)	Loss 2.3730 (2.3325)	Prec@1 7.500 (10.116)
Test: [20/20]	Time 0.011 (0.042)	Loss 2.3555 (2.3343)	Prec@1 11.397 (10.320)
Changing LR from 0.0 to 0.013401360544217686
Epoch: [1][196/196]	Time 0.037 (0.037)	Data 0.001 (0.004)	Loss 0.9814 (1.4291)	Prec@1 66.250 (47.498)
Test: [20/20]	Time 0.045 (0.041)	Loss 1.4111 (1.3104)	Prec@1 46.691 (53.550)
Changing LR from 0.013401360544217686 to 0.02680272108843537
Epoch: [2][196/196]	Time 0.022 (0.037)	Data 0.000 (0.004)	Loss 0.7490 (0.9464)	Prec@1 68.750 (66.256)
Test: [20

Changing LR from 0.005979591836734694 to 0.003969387755102041
Epoch: [33][196/196]	Time 0.058 (0.037)	Data 0.003 (0.004)	Loss 0.0168 (0.0292)	Prec@1 100.000 (99.136)
Test: [20/20]	Time 0.013 (0.038)	Loss 0.1936 (0.2049)	Prec@1 94.853 (93.820)
Changing LR from 0.003969387755102041 to 0.0019591836734693877
Epoch: [34][196/196]	Time 0.032 (0.036)	Data 0.002 (0.003)	Loss 0.0118 (0.0267)	Prec@1 100.000 (99.236)
Test: [20/20]	Time 0.010 (0.041)	Loss 0.1993 (0.2041)	Prec@1 94.118 (93.940)


In [14]:
from fastai.conv_learner import Learner, TrainingPhase, ModelData, accuracy, DecayType
from functools import partial
from PIL import Image

In [15]:

def old_torch_loader(data_path, size, bs, val_bs=None, prefetcher=False):
    """Build a fastai `ModelData` over CIFAR-10 with CPU-side normalisation.

    Three loaders are created: training (augmented, shuffled), validation
    (normalised only) and an "aug" loader that applies the training
    augmentation to the test split. Reads `args.workers` from the
    module-level `args`.

    Args:
        data_path: dataset root (downloaded there if missing).
        size: side length of the random crop.
        bs: batch size for the training and aug loaders.
        val_bs: validation batch size; falls back to `bs` when falsy.
        prefetcher: wrap all three loaders in `OldDataPrefetcher`.

    Returns:
        ModelData with `sz` and `aug_dl` attached.
    """
    val_bs = val_bs or bs

    to_normalised_tensor = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.24703,0.24349,0.26159)),
    ]
    augmentation = [
        pad, # TODO: use `padding` rather than assuming 4
        transforms.RandomCrop(size),
        transforms.RandomHorizontalFlip(),
    ]
    train_tfms = transforms.Compose(augmentation + to_normalised_tensor)
    val_tfms = transforms.Compose(to_normalised_tensor)

    def cifar(is_train, tfms):
        return datasets.CIFAR10(root=data_path, train=is_train, download=True, transform=tfms)

    def make_loader(dataset, batch_size, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=args.workers, pin_memory=True)

    train_dataset = cifar(True, train_tfms)
    val_dataset = cifar(False, val_tfms)
    aug_dataset = cifar(False, train_tfms)

    train_loader = make_loader(train_dataset, bs, True)
    val_loader = make_loader(val_dataset, val_bs, False)
    aug_loader = make_loader(aug_dataset, bs, False)

    if prefetcher:
        train_loader, val_loader, aug_loader = (
            OldDataPrefetcher(train_loader),
            OldDataPrefetcher(val_loader),
            OldDataPrefetcher(aug_loader),
        )

    data = ModelData(data_path, train_loader, val_loader)
    data.sz = size
    data.aug_dl = aug_loader
    return data

# Seems to speed up training by ~2%
class OldDataPrefetcher():
    """Wrap a loader and copy the next batch to the GPU on a side CUDA stream.

    Unlike `DataPrefetcher`, no casting or normalisation is done here, and a
    fresh iterator over `loader` is created on every `__iter__` call.

    Args:
        loader: iterable of `(input, target)` tensor pairs exposing `.dataset`.
        stop_after: optional int; iteration stops once more than this many
            batches have been yielded.
    """
    def __init__(self, loader, stop_after=None):
        self.loader = loader
        self.dataset = loader.dataset
        self.stream = torch.cuda.Stream()
        self.stop_after = stop_after
        self.next_input = None
        self.next_target = None

    def __len__(self):
        return len(self.loader)

    def preload(self):
        """Fetch the next batch; sets both slots to `None` when exhausted."""
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # `async` is a reserved word since Python 3.7, so `.cuda(async=True)`
            # no longer parses; `non_blocking` is the equivalent keyword.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def __iter__(self):
        count = 0
        self.loaditer = iter(self.loader)
        self.preload()
        while self.next_input is not None:
            # Wait for the side-stream copy of the current batch to finish.
            torch.cuda.current_stream().wait_stream(self.stream)
            input = self.next_input
            target = self.next_target
            self.preload()
            count += 1
            yield input, target
            # NOTE(review): with `>` this yields stop_after + 1 batches --
            # confirm whether `>=` was intended.
            if type(self.stop_after) is int and (count > self.stop_after):
                break

In [16]:
# data.trn_ds[0]

In [17]:
# next(iter(data.trn_dl))

In [18]:
# Baseline run through the fastai Learner API, using old_torch_loader
# (normalisation done by torchvision transforms on the CPU).
model = ResNet18()
model = model.cuda()
if args.fp16: model = network_to_half(model)

# AS: todo: don't copy over weights as it seems to help performance

# Hyper-parameters. Only `wd`, `bs`, `lrs` and `sz` are read below; `lr` and
# `momentum` are not (the momentum value is repeated as a literal).
wd=5e-4
lr=1e-1
momentum = 0.9
# learn.clip = 1e-1
bs = 256
lrs = (0, 2e-1, 1e-2, 0)
sz=32


data = old_torch_loader(PATH, sz, bs, bs*2)
    
learn = Learner.from_model_data(model, data)
# learn.half()
learn.crit = F.cross_entropy
learn.metrics = [accuracy]
learn.opt_fn = partial(torch.optim.SGD, nesterov=True, momentum=0.9)
def_phase = {'opt_fn':learn.opt_fn, 'wds':wd, 'momentum':0.9}

# Three linear LR phases (15 + 15 + 5 epochs), each spanning one consecutive
# pair of values from `lrs`: 0 -> 2e-1, 2e-1 -> 1e-2, 1e-2 -> 0.
phases = [
    TrainingPhase(**def_phase, epochs=15, lr=lrs[:2], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=15, lr=lrs[1:3], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=5, lr=lrs[-2:], lr_decay=DecayType.LINEAR),
]

learn.fit_opt_sched(phases)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, description='Epoch', max=35), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.412447   2.159091   0.3951    
    1      0.970583   1.630464   0.5208                      
    2      0.749944   0.852949   0.7058                      
    3      0.652545   1.217766   0.6298                      
    4      0.556323   0.867725   0.7102                      
    5      0.495791   0.74362    0.7525                      
    6      0.46937    0.846975   0.7384                      
    7      0.434401   0.677632   0.7743                      
    8      0.417817   0.830285   0.7379                      
    9      0.40224    0.617939   0.8005                      
    10     0.392207   0.863409   0.732                       
    11     0.39357    0.874344   0.7123                      
    12     0.394713   0.571084   0.8028                      
    13     0.384292   0.734388   0.7597                      
    14     0.396824   0.559785   0.8139                      
    15     0.372283   0.756

[0.200433984375, 0.9422000000953674]

In [42]:
# Same fastai training recipe as the baseline cell, but fed by torch_loader
# (uint8 batches collated with fast_collate) wrapped in a ModelData.
model = ResNet18()
model = model.cuda()

# AS: todo: don't copy over weights as it seems to help performance

# Hyper-parameters. Only `wd`, `bs`, `lrs` and `sz` are read below; `lr` and
# `momentum` are not (the momentum value is repeated as a literal).
wd=5e-4
lr=1e-1
momentum = 0.9
# learn.clip = 1e-1
bs = 256
lrs = (0, 2e-1, 1e-2, 0)
sz=32

if args.fp16: model = network_to_half(model)

# Training and validation share the same batch size here.
trn_loader, val_loader = torch_loader(PATH, sz, bs, bs)
data = ModelData(PATH, trn_loader, val_loader)

learn = Learner.from_model_data(model, data)
# learn.half()
learn.crit = F.cross_entropy
# NOTE(review): `accuracy` is both imported from fastai and defined in this
# notebook; whichever cell ran last decides which one is used here.
learn.metrics = [accuracy]
learn.opt_fn = partial(torch.optim.SGD, nesterov=True, momentum=0.9)
def_phase = {'opt_fn':learn.opt_fn, 'wds':wd, 'momentum':0.9}

# Three linear LR phases (15 + 15 + 5 epochs): 0 -> 2e-1, 2e-1 -> 1e-2, 1e-2 -> 0.
phases = [
    TrainingPhase(**def_phase, epochs=15, lr=lrs[:2], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=15, lr=lrs[1:3], lr_decay=DecayType.LINEAR),
    TrainingPhase(**def_phase, epochs=5, lr=lrs[-2:], lr_decay=DecayType.LINEAR),
]

learn.fit_opt_sched(phases)

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(IntProgress(value=0, description='Epoch', max=35), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.455398   1.86548    36.12     
    1      0.972145   2.053288   45.67                       
    2      0.769381   0.941118   66.78                       
    3      0.643727   1.122708   63.73                       
    4      0.580281   0.987419   67.95                       
    5      0.524518   0.812247   73.51                       
    6      0.472833   1.092777   69.43                       
    7      0.447318   0.613691   79.17                       
    8      0.435333   1.137866   66.48                       
    9      0.417069   1.191509   61.16                       
    10     0.402149   1.107984   66.65                       
    11     0.40548    0.664285   78.02                       
    12     0.396633   1.074928   71.39                       
    13     0.386369   0.764332   75.42                       
    14     0.394707   0.966597   71.23                       
    15     0.38009    1.273

[0.20989140625, array([93.78])]