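## CIFAR-10

Trains a WideResNet-22 (`wrn_22`) on CIFAR-10 with mixup, first on its own and then combined with Cutout. The last printed result of the mixup + Cutout run is `[0.2145, 0.9473]` (presumably validation loss and accuracy).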

In [1]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# (Re)load the autoreload extension and reload imported modules automatically,
# so edits to library code are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai.conv_learner import *
from fastai.models.cifar10.wideresnet import wrn_22
# Dataset root. ImageClassifierData.from_paths reads the images from here;
# the directory is created if it does not exist yet.
PATH = Path.home() / 'data' / 'cifar10'
PATH.mkdir(parents=True, exist_ok=True)
# Enable cuDNN benchmark mode (autotunes convolution algorithms).
torch.backends.cudnn.benchmark = True

In [3]:
# CIFAR-10 class names (alphabetical).
classes = 'airplane automobile bird cat deer dog frog horse ship truck'.split()
# Per-channel normalisation statistics passed to tfms_from_stats
# (presumably in RGB channel order).
stats = (
    np.array([0.4914, 0.48216, 0.44653]),   # mean
    np.array([0.24703, 0.24349, 0.26159]),  # std
)

bs = 128  # batch size for the data loaders
sz = 32   # image size passed to the transforms

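## Mixup

Each training batch is blended with a shuffled copy of itself, and the loss is blended with the same per-sample weights.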

In [5]:
class MixUpDataLoader(object):
    """
    Creates a new data loader with mixup from a given dataloader.

    Mixup is applied between a batch and a shuffled version of itself.
    If we use a regular beta distribution, this can create near duplicates as some lines might be
    1 * original + 0 * shuffled while others could be 0 * original + 1 * shuffled, this is why
    there is a trick where we take the maximum of lambda and 1-lambda.

    Arguments:
    dl (DataLoader): the data loader to mix up. Each item it yields must be an (x, y) pair
        whose first dimension is the batch dimension.
    alpha (float): value of the parameter to use in the beta distribution.

    Yields:
    (mixed_x, [y, y_shuffled, lambd]) where lambd holds one weight per sample in [0.5, 1]:
    the weight given to the original (unshuffled) sample. Designed to be paired with a loss
    that understands this three-element target list (MixUpLoss).
    """
    def __init__(self, dl, alpha):
        self.dl, self.alpha = dl, alpha

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        for x, y in self.dl:
            n = y.size(0)
            # One lambda per image instead of one per batch (per the original note,
            # this speeds up training).
            lambd = np.random.beta(self.alpha, self.alpha, n)
            # Fold into [0.5, 1] so the original sample always dominates; this avoids
            # the near-duplicates described in the class docstring.
            lambd = np.maximum(lambd, 1 - lambd)
            lambd = to_gpu(VV(lambd))
            shuffle = torch.randperm(n)
            x1, y1 = x[shuffle], y[shuffle]
            # Shape lambd as (n, 1, ..., 1) so it broadcasts over every non-batch dimension
            # of x. For 4-D image batches this is view(n, 1, 1, 1), as before; other ranks
            # now work too.
            weights = lambd.view(n, *([1] * (x.dim() - 1)))
            yield (x * weights + x1 * (1 - weights), [y, y1, lambd])

In [6]:
class MixUpLoss(nn.Module):
    """
    Adapts the loss function to go with mixup.

    Since the targets aren't one-hot encoded, we use the linearity of the loss function with
    regards to the target to mix up the loss instead of one-hot encoded targets.

    Argument:
    crit: a zero-argument factory (a loss class or a functools.partial of one) that returns
        an *unreduced* loss, i.e. one value per element, e.g.
        partial(nn.CrossEntropyLoss, reduce=False). The per-element values are weighted by
        lambda before averaging.

    forward(output, target):
    - target is a tensor (e.g. batches from an un-mixed loader): returns
      crit(output, target).mean().
    - target is a list or tuple [y, y_shuffled, lambd] (as yielded by MixUpDataLoader):
      returns mean(lambd * crit(output, y) + (1 - lambd) * crit(output, y_shuffled)).
    """
    def __init__(self, crit):
        super().__init__()
        self.crit = crit()

    def forward(self, output, target):
        # Accept tuples as well as lists for the mixed target.
        if not isinstance(target, (list, tuple)):
            return self.crit(output, target).mean()
        lambd = target[2]
        loss1, loss2 = self.crit(output, target[0]), self.crit(output, target[1])
        return (loss1 * lambd + loss2 * (1 - lambd)).mean()

In [7]:
# Augmentation: random crop (with pad=sz//8, i.e. 4 when sz=32) and random flip.
tfms = tfms_from_stats(stats, sz, pad=sz // 8,
                       aug_tfms=[RandomCrop(sz), RandomFlip()])
# The CIFAR-10 'test' folder is used as the validation set.
data = ImageClassifierData.from_paths(PATH, bs=bs, tfms=tfms, val_name='test')

In [8]:
# Wrap the plain training loader with mixup (beta parameter alpha = 0.6).
# The learner cell assigns `learn.data.trn_dl = mixup_dl`; since the learner is built from
# `data`, that presumably replaces `data.trn_dl` as well. Unwrap any existing mixup wrapper
# first, so that re-running this cell never nests one MixUpDataLoader inside another.
# A nested wrapper would call `.size(0)` on the [y, y1, lambd] target list and crash.
trn_dl_plain = data.trn_dl
while isinstance(trn_dl_plain, MixUpDataLoader):
    trn_dl_plain = trn_dl_plain.dl
mixup_dl = MixUpDataLoader(trn_dl_plain, 0.6)

In [9]:
# Fresh WideResNet-22 learner using Adam, the mixup training loader and the mixup loss.
m = wrn_22()
opt_fn = partial(optim.Adam, betas=(0.95,0.99))
learn = ConvLearner.from_model_data(m, data)
learn.metrics = [accuracy]
# Weight decay is not set here: it is passed directly to `learn.fit(..., wds=...)`.
learn.opt_fn = opt_fn
learn.data.trn_dl = mixup_dl
# Unreduced (per-sample) cross-entropy, so MixUpLoss can weight each sample by its lambda.
learn.crit = MixUpLoss(partial(nn.CrossEntropyLoss, reduce=False))

In [10]:
%time learn.fit(3e-3, 1, cycle_len=30, use_clr_beta=(10,7.5,0.95,0.85), wds=0.1, use_wd_sched=True)

HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.567097   1.294256   0.553     
 48%|████▊     | 186/391 [00:22<00:24,  8.36it/s, loss=1.46]

KeyboardInterrupt: 

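Training with mixup is as fast as training without it: about 23 minutes for the full 30-epoch cycle (≈47 s per epoch at the ~8.4 it/s shown above). The run above was interrupted partway through its second epoch; see the Mixup + Cutout run below for complete results.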

### Mixup + Cutout

In [7]:
# Same augmentation as before plus Cutout(1, 16).
tfms = tfms_from_stats(stats, sz, pad=sz // 8,
                       aug_tfms=[RandomCrop(sz), RandomFlip(), Cutout(1, 16)])
# The CIFAR-10 'test' folder is used as the validation set.
data = ImageClassifierData.from_paths(PATH, bs=bs, tfms=tfms, val_name='test')

In [8]:
# Wrap the plain training loader with mixup (beta parameter alpha = 0.6).
# The learner cell assigns `learn.data.trn_dl = mixup_dl`; since the learner is built from
# `data`, that presumably replaces `data.trn_dl` as well. Unwrap any existing mixup wrapper
# first, so that re-running this cell never nests one MixUpDataLoader inside another.
# A nested wrapper would call `.size(0)` on the [y, y1, lambd] target list and crash.
trn_dl_plain = data.trn_dl
while isinstance(trn_dl_plain, MixUpDataLoader):
    trn_dl_plain = trn_dl_plain.dl
mixup_dl = MixUpDataLoader(trn_dl_plain, 0.6)

In [14]:
# Fresh WideResNet-22 learner (mixup + Cutout data) using Adam, the mixup loader and the mixup loss.
m = wrn_22()
opt_fn = partial(optim.Adam, betas=(0.95,0.99))
learn = ConvLearner.from_model_data(m, data)
learn.metrics = [accuracy]
# Weight decay is not set here: it is passed directly to `learn.fit(..., wds=...)`.
learn.opt_fn = opt_fn
learn.data.trn_dl = mixup_dl
# Unreduced (per-sample) cross-entropy, so MixUpLoss can weight each sample by its lambda.
learn.crit = MixUpLoss(partial(nn.CrossEntropyLoss, reduce=False))

In [15]:
%time learn.fit(3e-3, 1, cycle_len=30, use_clr_beta=(10,7.5,0.95,0.85), wds=0.1, use_wd_sched=True)

HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.629132   1.317984   0.5234    
    1      1.465925   0.968519   0.6787                     
    2      1.367382   1.053036   0.6409                     
    3      1.301879   0.83975    0.7229                     
    4      1.258994   0.779681   0.745                      
    5      1.232516   0.832038   0.7198                     
    6      1.202696   0.673424   0.7904                     
    7      1.192234   0.544152   0.8312                     
    8      1.156781   0.649854   0.7988                     
    9      1.149639   0.599015   0.8301                     
    10     1.129813   0.502171   0.8594                     
    11     1.127827   0.539411   0.8464                     
    12     1.131102   0.535542   0.8322                     
    13     1.109017   0.474734   0.8585                     
    14     1.10992    0.471096   0.864                      
    15     1.068623   0.46983    0.861   

[0.21446209721565246, 0.9473]