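## CIFAR-10

Mixed-precision training of a ResNeXt-29 (8×64d) on CIFAR-10: the network runs in FP16 while the optimizer updates FP32 master copies of the weights, with optional loss scaling.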

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai.conv_learner import *
PATH = "../fp16/data/cifar10/"
os.makedirs(PATH,exist_ok=True)

In [3]:
from fp16util import *

### Load Data

In [4]:
# CIFAR-10 class names; not referenced elsewhere in this part of the notebook.
# NOTE(review): presumably in label-index order — confirm against the folder order from_paths uses.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# (mean, std) pair of 3-element arrays passed to tfms_from_stats in get_data for normalization.
# Presumably per-channel (R, G, B) statistics of the CIFAR-10 training set — TODO confirm.
stats = (np.array([ 0.4914 ,  0.48216,  0.44653]), np.array([ 0.24703,  0.24349,  0.26159]))

In [5]:
def get_data(sz, bs):
    """Return CIFAR-10 ImageClassifierData for square images of side `sz`.

    Augmentation is `[RandomFlip()]` with `pad=sz // 8`; normalization uses the
    module-level `stats`. The 'test' folder under PATH is used as the
    validation set, and `bs` is the batch size.
    """
    pad_amount = sz // 8
    augmentations = [RandomFlip()]
    transforms = tfms_from_stats(stats, sz, aug_tfms=augmentations, pad=pad_amount)
    return ImageClassifierData.from_paths(PATH, val_name='test', tfms=transforms, bs=bs)

In [6]:
bs=128

## Initial model

In [16]:
from fastai.models.cifar10.resnext import resnext29_8_64

# Mixed-precision switch: when True the network is wrapped in FP16 (presumably the
# wrapper from fp16util) before being moved to the GPU.
USE_FP16 = True

m = resnext29_8_64()
if USE_FP16:
    m = FP16(m)
bm = BasicModel(m.cuda(), name='cifar10_rn29_8_64')

In [28]:
m2 = resnext29_8_64()

In [30]:
m2.type(torch.cuda.HalfTensor)

CifarResNeXt(
  (conv_1_3x3): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (stage_1): Sequential(
    (0): ResNeXtBottleneck(
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      )
      (conv_reduce): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    )
    (1): ResNeXtBottleneck(
      (conv_reduce): Conv2d(256, 512, kernel_size=(1, 1), stride=

In [27]:
m.type

<bound method Module.type of FP16(
  (module): CifarResNeXt(
    (conv_1_3x3): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (stage_1): Sequential(
      (0): ResNeXtBottleneck(
        (downsample): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
        )
        (conv_reduce): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        (conv_expand): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      )
      (1): ResNe

In [26]:
m.type(torch.cuda.HalfTensor)

FP16(
  (module): CifarResNeXt(
    (conv_1_3x3): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (stage_1): Sequential(
      (0): ResNeXtBottleneck(
        (downsample): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
        )
        (conv_reduce): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        (conv_expand): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      )
      (1): ResNeXtBottleneck(
        (conv_r

In [17]:
# sz=8 (8x8 inputs) with batch size bs*4 (512 when bs=128); weights are saved as '8x8_8' at the end.
data = get_data(8,bs*4)

In [18]:
learn = ConvLearner(data, bm)
# NOTE(review): presumably makes every layer trainable; copy_params matches the optimizer
# param group by parameter count, so this likely needs to run first — confirm.
learn.unfreeze()

In [25]:
hasattr(m, 'reset')

False

In [19]:
class StepperFP16():
    """Training stepper for an FP16 model backed by FP32 master weights.

    The optimizer is re-pointed (via copy_params) at FP32 copies of the model
    parameters. Each step runs forward/backward on the FP16 model with the loss
    multiplied by `loss_scale`, moves the gradients onto the master copies,
    un-scales them, optionally clips them, steps the optimizer, and writes the
    master weights back into the model.

    Args:
        m: the (FP16-wrapped) model.
        opt: optimizer; its matching param group is replaced by FP32 copies.
        crit: loss function `crit(output, y)`.
        clip: max gradient norm; 0 disables clipping.
        reg_fn: optional `reg_fn(output, xtra, raw_loss)` returning the
            regularized loss.
        loss_scale: factor applied to the loss before backward and divided
            out of the gradients afterwards.
    """

    def __init__(self, m, opt, crit, clip=0, reg_fn=None, loss_scale=1):
        self.m, self.opt, self.crit, self.clip, self.reg_fn = m, opt, crit, clip, reg_fn
        # Honour the requested scale (it was previously hard-coded to 1, so
        # e.g. loss_scale=128 passed from learn.fit had no effect).
        self.loss_scale = loss_scale
        self.fp16 = True

        if self.fp16:
            # Swaps the optimizer's params for FP32 master copies.
            self.param_copy = copy_params(m, opt)

        self.reset(True)

    def reset(self, train=True):
        """Switch the model to train (via apply_leaf) or eval mode.

        If the model exposes `reset()`, call it and rebuild the master copies.
        NOTE(review): rebuilding clones from the model's current (FP16) params,
        which discards FP32 precision and replaces the optimizer's param
        objects — confirm this is intended for models with `reset()`.
        """
        if train: apply_leaf(self.m, set_train_mode)
        else: self.m.eval()
        if hasattr(self.m, 'reset'):
            self.m.reset()
            self.param_copy = copy_params(self.m, self.opt)

    def step(self, xs, y, epoch):
        """Run one scaled training step and return the unregularized loss as a number."""
        xtra = []
        output = self.m(*xs)
        if isinstance(output, tuple): output, *xtra = output
        self.m.zero_grad()
        loss = raw_loss = self.crit(output, y)
        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
        # Scale AFTER regularization so the whole objective is scaled; the
        # gradients are divided by the same factor below. (Scaling before
        # reg_fn let reg_fn discard the scale, shrinking the gradients.)
        loss = loss * self.loss_scale
        loss.backward()

        # Presumably copies the model's gradients into the master params'
        # .grad (set_grad comes from fp16util) — confirm.
        set_grad(self.param_copy, list(self.m.parameters()))

        if self.loss_scale != 1:
            for param in self.param_copy:
                # Frozen master params may have no gradient.
                if param.grad is not None:
                    param.grad.data = param.grad.data / self.loss_scale

        if self.clip:   # Gradient clipping on the master params
            # param_copy is a list of tensors, not a module, so filter directly
            # (trainable_params_ expects a module with .parameters()).
            nn.utils.clip_grad_norm([p for p in self.param_copy if p.requires_grad], self.clip)
        self.opt.step()
        # Presumably writes the updated master weights back into the model (fp16util).
        copy_in_params(self.m, self.param_copy)
        # .item() on newer PyTorch; fall back to .data[0] where .item is absent.
        return raw_loss.item() if hasattr(raw_loss, 'item') else raw_loss.data[0]

    def evaluate(self, xs, y):
        """Return `(preds, loss)` for a batch; the first element of tuple outputs is used."""
        preds = self.m(*xs)
        if isinstance(preds, tuple): preds = preds[0]
        return preds, self.crit(preds, y)

### Copy the FP16 model parameters into FP32 master copies for the optimizer

In [20]:
# %pdb on

In [21]:
def copy_params(model, optim):
    """Create FP32 master copies of `model`'s parameters and swap them into `optim`.

    The model keeps its own (e.g. FP16) parameters; the optimizer param group
    whose length equals the model's parameter count has its entries replaced,
    position by position, by detached FP32 clones. Each clone's
    `requires_grad` mirrors its source parameter, so frozen layers stay frozen.

    Args:
        model: network whose parameters are copied.
        optim: optimizer with a param group holding the model's parameters in
            `model.parameters()` order (shapes are asserted).

    Returns:
        list of the FP32 master tensors, in `model.parameters()` order.

    Raises:
        ValueError: if the model has no parameters or no param group of
            `optim` has as many entries as the model has parameters.
    """
    model_params = list(model.parameters())
    if not model_params:
        raise ValueError('copy_params: model has no parameters to copy')

    # .float() converts to FP32 on the tensor's current device; for a CUDA
    # model this matches the previous hard-coded torch.cuda.FloatTensor.
    param_copy = [p.detach().clone().float() for p in model_params]

    matching = [g for g in optim.param_groups if len(g['params']) == len(param_copy)]
    if not matching:
        raise ValueError(
            'copy_params: no optimizer param group has %d params (group sizes: %s)'
            % (len(param_copy), [len(g['params']) for g in optim.param_groups]))
    model_group = matching[0]

    for i, (src, param) in enumerate(zip(model_params, param_copy)):
        # Mirror trainability so frozen layers remain frozen in the master copy.
        param.requires_grad = src.requires_grad
        assert model_group['params'][i].shape == param.shape
        model_group['params'][i] = param

    ## Sanity check: model keeps its dtype, optimizer now holds FP32 copies.
    m_0 = model_params[0]
    print('Model first layer:', type(m_0.data), m_0.shape)
    opt_0 = model_group['params'][0]
    print('Optim first layer:', type(opt_0.data), opt_0.shape)

    return param_copy

## End of FP16 test code

In [22]:
# lr: learning rate passed to learn.fit; wd: weight decay passed as `wds` in the cycle_mult run.
lr=1e-2; wd=5e-4

In [24]:
%time learn.fit(lr, 1, stepper=StepperFP16, cycle_len=0.2, loss_scale=128)

Model first layer: <class 'torch.cuda.HalfTensor'> torch.Size([64, 3, 3, 3])
Optim first layer: <class 'torch.cuda.FloatTensor'> torch.Size([64, 3, 3, 3])


HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

20it [00:14,  1.41it/s, loss=2.04]                        epoch      trn_loss   val_loss   accuracy   
    0      2.038359   1.962012   0.283973  

CPU times: user 22 s, sys: 10.2 s, total: 32.2 s
Wall time: 18.1 s


[1.96201171875, 0.28397288620471955]

In [27]:
learn.fit(lr, 2, cycle_len=1)

A Jupyter Widget

[ 0.       1.43256  1.39725  0.49746]                       
[ 1.       1.35969  1.3411   0.51602]                       



In [28]:
learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)

A Jupyter Widget

[ 0.       1.31041  1.29344  0.53174]                       
[ 1.       1.30313  1.31292  0.53418]                       
[ 2.       1.15682  1.22019  0.5668 ]                       
[ 3.       1.26632  1.34606  0.54121]                       
[ 4.       1.14698  1.18958  0.57598]                       
[ 5.       1.02205  1.13905  0.60254]                       
[ 6.       0.93291  1.13761  0.60596]                        



In [30]:
learn.save('8x8_8')