## CIFAR 10

In [2]:
# Show matplotlib figures inline, and reload edited imported modules (e.g. fastai)
# automatically before executing each cell.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from fastai.conv_learner import *
from fastai.models.cifar10.wideresnet import wrn_22
# Let cuDNN benchmark convolution algorithms and reuse the fastest one
# (worthwhile here because every image is cropped to the same size).
torch.backends.cudnn.benchmark = True
# CIFAR-10 lives under ~/data/cifar10; create the folder if it is missing.
PATH = Path.home() / 'data' / 'cifar10'
PATH.mkdir(parents=True, exist_ok=True)

In [4]:
# The 10 CIFAR-10 class names.
classes = 'airplane automobile bird cat deer dog frog horse ship truck'.split()
# (per-channel means, per-channel stds) handed to tfms_from_stats for normalization.
stats = tuple(map(np.array, ([0.4914, 0.48216, 0.44653],
                             [0.24703, 0.24349, 0.26159])))

# Batch size and image size (CIFAR-10 images are 32x32).
bs, sz = 128, 32

# Cutout

Cutout is already implemented in the fastai library. Its arguments are the number of holes `n_holes` (1 in the paper), the size of each hole, and the probability of applying it (default: 0.5). Applying Cutout after normalization or before it (i.e. blanking the holes with 0s or with the means) doesn't change the results.

In [4]:
# Augmentations: random crop (with pad=sz//8), random flip, and Cutout with 1 hole of size 16.
# The CIFAR-10 'test' folder is used as the validation set.
tfms = tfms_from_stats(stats, sz, pad=sz // 8,
                       aug_tfms=[RandomCrop(sz), RandomFlip(), Cutout(1, 16)])
data = ImageClassifierData.from_paths(PATH, bs=bs, tfms=tfms, val_name='test')

In [9]:
# Fresh WRN-22 trained with Adam (betas 0.95/0.99) and plain cross-entropy;
# `wd` is the weight decay handed to learn.fit in the next cell.
wd = 0.03
opt_fn = partial(optim.Adam, betas=(0.95, 0.99))
m = wrn_22()
learn = ConvLearner.from_model_data(m, data, opt_fn=opt_fn)
learn.metrics = [accuracy]
learn.crit = nn.CrossEntropyLoss()

In [10]:
# One cycle of 34 epochs at max lr 3e-3, with weight decay `wd` from the setup cell.
# use_clr_beta=(10, 7.5, 0.95, 0.85): presumably (lr divisor, % of iterations for the final
# annealing, max momentum, min momentum) — TODO confirm against fastai's CircularLR_beta.
# NOTE(review): use_wd_sched=True presumably selects fastai's decoupled (AdamW-style) weight decay.
%time learn.fit(3e-3, 1, cycle_len=34, use_clr_beta=(10,7.5,0.95,0.85), wds=wd, use_wd_sched=True)

HBox(children=(IntProgress(value=0, description='Epoch', max=34), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.265698   1.131748   0.5859    
    1      1.011879   0.925887   0.6743                     
    2      0.874358   0.83785    0.7146                      
    3      0.756501   0.661968   0.7674                      
    4      0.702329   0.576785   0.8056                      
    5      0.679879   0.568449   0.8077                      
    6      0.624821   0.590153   0.7986                      
    7      0.584256   0.636968   0.7912                      
    8      0.555692   0.508859   0.8294                      
    9      0.521765   0.60031    0.8008                      
    10     0.516555   0.498995   0.835                       
    11     0.494465   0.593156   0.8126                      
    12     0.478316   0.486834   0.8367                      
    13     0.453475   0.383599   0.8663                      
    14     0.429461   0.397569   0.8682                      
    15     0.413949   0.4585

[0.18503666452169418, 0.9481]

I did obtain 95% in 34 epochs, but I'm not sure which hyper-parameters I used — possibly a bit more weight decay. Anyway, you get the idea.

# Mixup

In [5]:
class MixUpDataLoader(object):
    """
    Creates a new data loader with mixup from a given dataloader.
    
    Mixup is applied between a batch and a shuffled version of itself: image i becomes
    lambda_i * x_i + (1 - lambda_i) * x_shuffled_i, with one lambda_i ~ Beta(alpha, alpha) per image.
    With a plain Beta sample, row i could be ~1 * original + 0 * shuffled while another row is
    ~0 * original + 1 * shuffled, i.e. two near copies of the same image in one batch. To avoid
    these near duplicates we use max(lambda, 1 - lambda), so every mixed image stays weighted
    towards its own original.
    
    Each batch is yielded as (mixed_x, [y, y_shuffled, lambd]), where lambd holds the per-image
    weights of the original targets (see MixUpLoss). Attributes not defined here (e.g. `dataset`)
    are forwarded to the wrapped loader, so the wrapper can stand in for it.
    
    Arguments:
    dl (DataLoader): the data loader to mix up
    alpha (float): value of the parameter to use in the beta distribution.
    """
    def __init__(self, dl, alpha):
        self.dl, self.alpha = dl, alpha
        
    def __len__(self):
        return len(self.dl)
    
    def __getattr__(self, name):
        # Only reached when normal lookup fails: forward to the wrapped loader so code expecting
        # a regular loader (e.g. reading `.dataset`) keeps working once this replaces trn_dl.
        # Never forward our own fields; this prevents infinite recursion while they are not set
        # yet (e.g. during copy or unpickling).
        if name in ('dl', 'alpha'): raise AttributeError(name)
        return getattr(self.dl, name)
    
    def __iter__(self):
        for (x, y) in iter(self.dl):
            n = y.size(0)
            # Taking one different lambda per image speeds up training.
            lambd = np.random.beta(self.alpha, self.alpha, n)
            # Trick to avoid near duplicates (see class docstring).
            lambd = np.maximum(lambd, 1 - lambd)
            lambd = to_gpu(VV(lambd))
            shuffle = torch.randperm(n)
            x1, y1 = x[shuffle], y[shuffle]
            # Broadcast the per-image weights over every non-batch dimension (any input rank).
            w = lambd.view(n, *([1] * (x.dim() - 1)))
            yield (x * w + x1 * (1 - w), [y, y1, lambd])

In [6]:
class MixUpLoss(nn.Module):
    """
    Adapts the loss function to go with mixup.
    
    Since the targets aren't one-hot encoded, we use the linearity of the loss function with
    regards to the target to mix up the loss instead of one-hot encoded targets:
    loss = mean(lambd * crit(output, y) + (1 - lambd) * crit(output, y_shuffled)).
    
    Argument:
    crit: a callable (typically a loss class or a functools.partial of one) that is called once,
        with no arguments, to build the loss. The resulting loss must return one value per
        element, e.g. `partial(nn.CrossEntropyLoss, reduce=False)`
        (`reduction='none'` on newer PyTorch versions).
    
    forward(output, target):
    - if target is a list [y, y_shuffled, lambd] (as yielded by MixUpDataLoader), returns the
      lambd-weighted mix of the two per-element losses, averaged;
    - otherwise (e.g. validation batches with plain targets), returns the per-element loss, averaged.
    """
    def __init__(self, crit):
        super().__init__()
        self.crit = crit()
        
    def forward(self, output, target):
        if not isinstance(target, list): return self.crit(output, target).mean()
        y, y_shuffled, lambd = target[0], target[1], target[2]
        loss1, loss2 = self.crit(output, y), self.crit(output, y_shuffled)
        return (loss1 * lambd + loss2 * (1 - lambd)).mean()

In [7]:
# No Cutout here: only random crop (pad=sz//8) and random flip; mixup is added on top of this loader.
tfms = tfms_from_stats(stats, sz, pad=sz // 8, aug_tfms=[RandomCrop(sz), RandomFlip()])
data = ImageClassifierData.from_paths(PATH, tfms=tfms, val_name='test', bs=bs)

In [8]:
# Wrap the *plain* training loader with mixup (alpha=0.6). If this cell is re-run after
# learn.data.trn_dl was replaced by a MixUpDataLoader, unwrap it first so mixup is not applied twice.
base_trn_dl = data.trn_dl
while isinstance(base_trn_dl, MixUpDataLoader): base_trn_dl = base_trn_dl.dl
mixup_dl = MixUpDataLoader(base_trn_dl, 0.6)

In [9]:
# Fresh WRN-22 trained with mixup: Adam (betas 0.95/0.99), the mixup training loader,
# and a loss that mixes the per-image losses of both targets.
m = wrn_22()
opt_fn = partial(optim.Adam, betas=(0.95,0.99))
# Pass opt_fn at construction, as the other experiments in this notebook do.
learn = ConvLearner.from_model_data(m, data, opt_fn=opt_fn)
learn.metrics = [accuracy]
# Weight decay actually used by the fit cells below (they pass wds=0.1); the former
# wd=1e-4 was never used and misrepresented the run.
wd = 0.1
# NOTE: learn.data is presumably the same object as `data`, so this also replaces data.trn_dl;
# re-create `data` to train without mixup.
learn.data.trn_dl = mixup_dl
# reduce=False makes CrossEntropyLoss return one loss per image, which MixUpLoss then weights.
learn.crit = MixUpLoss(partial(nn.CrossEntropyLoss, reduce=False))

In [10]:
# One 30-epoch cycle on the mixup loader (alpha=0.6). Weight decay is passed as the literal 0.1
# rather than the `wd` variable from the setup cell — keep the two in sync.
%time learn.fit(3e-3, 1, cycle_len=30, use_clr_beta=(10,7.5,0.95,0.85), wds=0.1, use_wd_sched=True)

HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.567097   1.294256   0.553     
 48%|████▊     | 186/391 [00:22<00:24,  8.36it/s, loss=1.46]

KeyboardInterrupt: 

Training is as fast as without mixup (23 minutes).

This second run uses a different alpha: here `mixup_dl = MixUpDataLoader(data.trn_dl, 0.4)`.

In [22]:
# Second mixup run, reported as using alpha=0.4.
# NOTE(review): the cell that rebuilt mixup_dl (alpha=0.4) and `learn` is not shown. Without it,
# this call keeps training the interrupted alpha=0.6 learner.
%time learn.fit(3e-3, 1, cycle_len=30, use_clr_beta=(10,7.5,0.95,0.85), wds=0.1, use_wd_sched=True)

HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))

 20%|██        | 79/391 [00:08<00:33,  9.32it/s, loss=1.88]
epoch      trn_loss   val_loss   accuracy                   
    0      1.472878   1.345358   0.5291    
    1      1.269906   0.84233    0.7231                     
    2      1.179544   0.839368   0.7125                     
    3      1.112833   0.738385   0.7596                     
    4      1.055841   0.98404    0.6708                     
    5      1.06348    0.654288   0.7998                     
    6      1.015946   0.515297   0.8349                     
    7      0.996191   0.493962   0.8435                      
    8      0.979429   0.525252   0.8424                      
    9      0.966817   0.509232   0.8455                      
    10     0.957474   0.571616   0.8228                      
    11     0.926887   0.495118   0.8441                      
    12     0.938132   0.487569   0.8488                      
    13     0.930551   0.472688   0.8567                      
    14     0.912599   0.442748   0.

[0.1936487452030182, 0.9493]

# Logs

Regular training with AdamW, using the LogResults callback.

In [27]:
class LogResults(Callback):
    """
    Callback to log all the results of the training:
    - at the end of each epoch: training loss, validation loss and metrics
    - on the first batch of every epoch, and during the first epoch also on every batch whose
      (index + 1) is a power of two: the fraction of positive/negative values and percentiles of
      the positive and negative values of every optimized parameter, of its gradient, and of the
      output of every module (except nn.Sequential/nn.ModuleList containers and the model itself)
    
    The text log is appended to `fname + '.txt'` and the statistics are pickled to
    `fname + '.pkl'` when training ends.
    
    Arguments:
    learn: the learner being trained.
    fname (str): path prefix of the two output files.
    init_text (str): text written before the column headers of the text log.
    """
    
    def __init__(self, learn, fname, init_text=''):
        super().__init__()
        self.learn, self.fname, self.init_text = learn, fname, init_text
        # Quantiles given as fractions of 1; separate_pcts converts them for np.percentile.
        self.pcts = [0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.999]
        # Parameter -> its name in the model, used as key in self.deciles.
        self.pnames = {p:n for n,p in learn.model.named_parameters()}
        # Module -> 'ClassName-k' label, used as key for the activation statistics.
        self.module_names = get_module_names(learn.model)
        self.hooks = []
        
    def on_train_begin(self):
        self.logs, self.epoch, self.n = self.init_text + "\n", 0, 0
        # Drop any forward hooks left over from an interrupted previous fit.
        self._remove_hooks()
        self.deciles = {}
        for name in self.pnames.values(): 
            self.deciles[name] = collections.defaultdict(list)
            self.deciles[name + '.grad'] = collections.defaultdict(list)
        for name in self.module_names.values(): self.deciles[name] = collections.defaultdict(list)
        names = ["epoch", "trn_loss", "val_loss", "metric"]
        layout = "{!s:10} " * len(names)
        self.logs += layout.format(*names) + "\n"
    
    def _is_logged_batch(self):
        """True for the first batch of each epoch and, in epoch 0, for batch indices 1, 3, 7, 15, ..."""
        return self.n == 0 or (self.epoch == 0 and is_power_of_two(self.n+1))
    
    def _remove_hooks(self):
        """Detach every forward hook registered by register_hook."""
        for h in self.hooks: h.remove()
        self.hooks = []
    
    def on_batch_begin(self):
        if self._is_logged_batch():
            self._remove_hooks()
            # The hooks record activations of this batch's forward pass and are removed in on_batch_end.
            self.learn.model.apply(self.register_hook)
    
    def on_batch_end(self, metrics):
        self.loss = metrics
        if self._is_logged_batch(): self.save_deciles()
        self._remove_hooks()
        self.n += 1
    
    def on_epoch_end(self, metrics):
        # metrics: validation values from the fit loop, logged after the last training loss.
        self.save_stats(self.epoch, [self.loss] + metrics)
        self.epoch += 1
        self.n=0
        
    def save_stats(self, epoch, values, decimals=6):
        layout = "{!s:^10}" + " {!s:10}" * len(values)
        values = [epoch] + list(np.round(values, decimals))
        self.logs += layout.format(*values) + "\n"
    
    def save_deciles(self):
        for group_param in self.learn.sched.layer_opt.opt_params():
            for param in group_param['params']:
                name = self.pnames[param]
                self.add_deciles(name, to_np(param))
                # A parameter that took no part in the backward pass has no gradient.
                if param.grad is not None: self.add_deciles(name + '.grad', to_np(param.grad))
    
    def separate_pcts(self, arr):
        """Return (fraction > 0, fraction < 0, percentiles of positives, percentiles of negatives).

        Zeros count towards neither fraction; an empty array gives fractions of 0.
        """
        n = arr.size
        pos, neg = arr[arr > 0], arr[arr < 0]
        # np.percentile expects q in [0, 100], while self.pcts holds fractions of 1.
        qs = [100 * q for q in self.pcts]
        pos_pcts = np.percentile(pos, qs) if len(pos) > 0 else np.array([])
        neg_pcts = np.percentile(neg, qs) if len(neg) > 0 else np.array([])
        if n == 0: return 0., 0., pos_pcts, neg_pcts
        return len(pos)/n, len(neg)/n, pos_pcts, neg_pcts
    
    def add_deciles(self, name, arr):
        pos, neg, pct_pos, pct_neg = self.separate_pcts(arr)
        self.deciles[name]['sgn'].append([pos, neg])
        self.deciles[name]['pos'].append(pct_pos)
        self.deciles[name]['neg'].append(pct_neg)
                                                        
    def on_train_end(self):
        self._remove_hooks()
        with open(self.fname + '.txt', 'a') as f: f.write(self.logs)
        # Use a context manager so the pickle file is flushed and closed.
        with open(self.fname + '.pkl', 'wb') as f: pickle.dump(self.deciles, f)
        
    def register_hook(self, module):
        """Attach a forward hook that records `module`'s output statistics (containers and the model are skipped)."""
        def hook_save_act(module, input, output):
            self.add_deciles(self.module_names[module], to_np(output))
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and module is not self.learn.model:
            self.hooks.append(module.register_forward_hook(hook_save_act))

def get_module_names(model):
    """Return a dict mapping each submodule of `model` to a label '<ClassName>-<k>'.

    nn.Sequential / nn.ModuleList containers and `model` itself get no label. `k` counts the
    labelled modules in the order `model.apply` visits them (children before their parent).
    """
    names = {}

    def _label(mod):
        if isinstance(mod, (nn.Sequential, nn.ModuleList)) or mod is model:
            return
        cls_name = str(mod.__class__).split('.')[-1].split("'")[0]
        names[mod] = f'{cls_name}-{len(names) + 1}'

    model.apply(_label)
    return names

def is_power_of_two(n):
    """Return True if the integer `n` is a power of two (1, 2, 4, 8, ...).

    0 and negative numbers return False (the previous loop wrongly returned True for them).
    A power of two has a single set bit, so clearing the lowest set bit (n & (n - 1)) gives 0.
    """
    return n > 0 and (n & (n - 1)) == 0

In [31]:
# Baseline augmentation only (random crop with pad=sz//8, random flip). Rebuilding `data`
# also drops the mixup loader that the previous section put into data.trn_dl.
tfms = tfms_from_stats(stats, sz,
                       aug_tfms=[RandomCrop(sz), RandomFlip()],
                       pad=sz // 8)
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs,
                                      val_name='test')

In [32]:
# Fresh WRN-22 for the logged run: Adam (betas 0.95/0.99) and plain cross-entropy.
wd = 0.1
m = wrn_22()
opt_fn = partial(optim.Adam, betas=(0.95, 0.99))
learn = ConvLearner.from_model_data(m, data, opt_fn=opt_fn)
learn.metrics = [accuracy]
learn.crit = nn.CrossEntropyLoss()

In [33]:
# Logs go to <PATH>/cifar10.txt (per-epoch losses/metrics) and <PATH>/cifar10.pkl (statistics).
log_cb = LogResults(learn, fname=str(PATH.joinpath('cifar10')))

In [34]:
%time learn.fit(3e-3, 1, cycle_len=30, use_clr_beta=(10,7.5,0.95,0.85), wds=0.1, use_wd_sched=True, callbacks=[log_cb])

HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.099058   1.092296   0.6064    
    1      0.840501   0.879823   0.6933                      
    2      0.691204   0.750623   0.7482                      
    3      0.617986   0.590799   0.8001                      
    4      0.553649   0.562221   0.808                       
    5      0.503341   0.857572   0.7261                      
    6      0.460281   0.526698   0.8185                      
    7      0.440226   0.519977   0.8228                      
    8      0.413289   0.56235    0.8171                      
    9      0.393606   0.447091   0.8463                      
    10     0.370249   0.484564   0.8476                      
    11     0.353899   0.442357   0.8554                      
    12     0.34476    0.385429   0.8705                      
    13     0.343156   0.456771   0.8485                      
    14     0.304945   0.429652   0.8579                      
    15     0.274984   0.331

[0.22669122726917268, 0.9409]

Logging the results slows training down a bit (especially in the first epoch, where we also log results at every 2^n-th batch): here we go from 23 minutes to 26 min 36 s.