## CIFAR 10

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai.conv_learner import *
# Root directory for the CIFAR-10 data; created up front if it does not exist yet.
PATH = "../fp16/data/cifar10/"
os.makedirs(PATH,exist_ok=True)

In [3]:
from fp16util import *

### Load Data

In [4]:
# Label names for the ten classes.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Pair of length-3 arrays handed to tfms_from_stats in get_data;
# presumably per-channel mean and standard deviation -- TODO confirm.
stats = (np.array([ 0.4914 ,  0.48216,  0.44653]), np.array([ 0.24703,  0.24349,  0.26159]))

In [5]:
def get_data(sz,bs):
    """Build the image-classification data object for images of size `sz`.

    Transforms are derived from the module-level `stats`, with a random flip
    as the only augmentation and a padding of one eighth of `sz`. Data is read
    from `PATH`, using the 'test' folder as the validation set, in batches of
    `bs`.
    """
    padding = sz//8
    augmentations = [RandomFlip()]
    transforms = tfms_from_stats(stats, sz, aug_tfms=augmentations, pad=padding)
    dataset = ImageClassifierData.from_paths(PATH, val_name='test', tfms=transforms, bs=bs)
    return dataset

In [6]:
# Base batch size; get_data is called with bs*4 further down.
bs=128

## Initial model

In [7]:
from fastai.models.cifar10.resnext import resnext29_8_64

m = resnext29_8_64()
# Hard-wired toggle: the network is always wrapped with FP16 (from fp16util).
if True:
    m = FP16(m)
# Move the model to the GPU and wrap it in BasicModel under the name 'cifar10_rn29_8_64'.
bm = BasicModel(m.cuda(), name='cifar10_rn29_8_64')

In [8]:
# Size-8 images, with four times the base batch size.
data = get_data(8,bs*4)

In [9]:
# Build a learner from the data and the wrapped model, then call unfreeze() on it.
learn = ConvLearner(data, bm)
learn.unfreeze()

In [10]:
class StepperFP16():
    """Training stepper with an optional mixed-precision (fp16) update path.

    With ``fp16=False`` a step is the usual forward / backward / optimizer
    update on the model's own parameters. With ``fp16=True`` the optimizer is
    pointed at fp32 copies of the parameters (built by ``copy_model_to_fp32``);
    each step back-propagates a scaled loss through the model, transfers the
    gradients to the fp32 copies, unscales them, updates the copies and writes
    the result back into the model.

    Args:
        m: the model; called as ``m(*xs)``.
        opt: optimizer driving the update.
        crit: loss function, called as ``crit(output, y)``.
        clip: gradient-norm clipping threshold; falsy disables clipping.
        reg_fn: optional callable ``reg_fn(output, xtra, raw_loss)`` returning
            the loss that is actually back-propagated.
        loss_scale: factor applied to the loss before backward in fp16 mode;
            ignored (forced to 1) when ``fp16`` is False.
        fp16: select the mixed-precision path.
    """
    def __init__(self, m, opt, crit, clip=0, reg_fn=None, loss_scale=1, fp16=False):
        self.m,self.opt,self.crit,self.clip,self.reg_fn = m,opt,crit,clip,reg_fn
        # These must be set before reset(): reset() reads self.fp16 whenever
        # the model defines its own reset() method.
        self.fp16 = fp16
        self.loss_scale = loss_scale if fp16 else 1
        self.reset(True)
        if self.fp16: self.fp32_params = copy_model_to_fp32(m, opt)

    def reset(self, train=True):
        """Put the model in train or eval mode and reset it if it supports that."""
        if train: apply_leaf(self.m, set_train_mode)
        else: self.m.eval()
        if hasattr(self.m, 'reset'):
            self.m.reset()
            # Rebuild the fp32 copies after the model has been reset.
            if self.fp16: self.fp32_params = copy_model_to_fp32(self.m, self.opt)

    def step(self, xs, y, epoch):
        """Run one training step on the batch and return the unregularised loss value."""
        if self.fp16: return self.step_fp16(xs, y, epoch)
        xtra = []
        output = self.m(*xs)
        # Models may return (output, *extras); extras are only passed to reg_fn.
        if isinstance(output,tuple): output,*xtra = output
        self.opt.zero_grad()
        loss = raw_loss = self.crit(output, y)
        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
        loss.backward()
        if self.clip:   # Gradient clipping
            nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip)
        self.opt.step()
        return raw_loss.data[0]

    def step_fp16(self, xs, y, epoch):
        """Mixed-precision variant of `step`; returns the unregularised loss value."""
        xtra = []
        output = self.m(*xs)
        if isinstance(output,tuple): output,*xtra = output
        # Gradients are accumulated on the model's parameters, not on the
        # optimizer's fp32 copies, so clear them on the model.
        self.m.zero_grad()
        loss = raw_loss = self.crit(output, y)
        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
        # Scale after reg_fn: reg_fn derives its result from raw_loss, so
        # scaling first would be discarded while the gradients below are
        # still divided by loss_scale.
        if self.loss_scale != 1: loss = loss*self.loss_scale
        loss.backward()
        update_fp32_grads(self.fp32_params, self.m)
        if self.loss_scale != 1:
            # Undo the loss scaling on the fp32 gradients before the update.
            for param in self.fp32_params: param.grad.data.div_(self.loss_scale)
        if self.clip:   # Gradient clipping
            # fp32_params is a plain list rather than a module, so filter it
            # directly instead of going through trainable_params_.
            nn.utils.clip_grad_norm([p for p in self.fp32_params if p.requires_grad], self.clip)
        self.opt.step()
        copy_fp32_to_model(self.m, self.fp32_params)
        return raw_loss.data[0]

    def evaluate(self, xs, y):
        """Return ``(predictions, loss)`` for the batch without updating anything."""
        preds = self.m(*xs)
        if isinstance(preds,tuple): preds=preds[0]
        return preds, self.crit(preds, y)

### Let's try to copy the fp16 model params into fp32

In [None]:
%pdb on

In [11]:
def copy_model_to_fp32(m, optim):
    """Create fp32 copies of the model parameters and point the optimizer at them.

    The optimizer param group whose parameter count equals the model's is
    located, and each of its entries is replaced in place by the matching
    fp32 copy, carrying over the original ``requires_grad`` flag.

    Args:
        m: model whose parameters are copied.
        optim: optimizer whose ``param_groups`` are rewritten in place.

    Returns:
        The list of detached fp32 copies, in ``m.parameters()`` order.

    Raises:
        AssertionError: if there is not exactly one param group with the same
            number of parameters as the model, or if a copy's shape differs
            from the optimizer entry it replaces.
    """
    # .float() keeps each copy on the device of the parameter it was cloned from.
    fp32_params = [m_param.clone().float().detach() for m_param in m.parameters()]
    optimizer_groups = [group for group in optim.param_groups
                        if len(group['params']) == len(fp32_params)]
    assert len(optimizer_groups) == 1, 'Unable to locate matching optimizer and model parameters'
    optimizer_params = optimizer_groups[0]['params']
    for i, fp32_param in enumerate(fp32_params):
        assert optimizer_params[i].shape == fp32_param.shape, 'fp32 param copy out of sync'
        # detach() yields a tensor that does not require grad; restore the flag
        # of the parameter being replaced.
        fp32_param.requires_grad = optimizer_params[i].requires_grad
        optimizer_params[i] = fp32_param

    ## Sanity Check
    m_0 = next(m.parameters())
    print('Model first layer:', type(m_0.data), m_0.shape)
    opt_0 = optimizer_params[0]
    print('Optim first layer:', type(opt_0.data), opt_0.shape)

    return fp32_params

In [None]:
# Scratch optimizer over the model's parameters after m.half().
# NOTE(review): presumably .half() converts m in place, which matters because m is reused below -- confirm.
new_optim = torch.optim.SGD(m.half().parameters(), lr=.05)

In [None]:
# Inspect which keys the first optimizer param group carries.
new_optim.param_groups[0].keys()

In [None]:
# Look at the fourth parameter held by the optimizer: type of its .data and its shape.
t1 = new_optim.param_groups[0]['params'][3]
type(t1.data), t1.shape

In [None]:
# Type of the parameter object itself, as opposed to its .data.
type(t1)

In [None]:
# Try overwriting the optimizer's first slot in place with the model's first parameter.
new_optim.param_groups[0]['params'][0] = next(m.parameters())

In [None]:
# NOTE: the trailing [0] indexes into the parameter itself, so t2 is its
# first slice rather than the whole parameter.
t2 = new_optim.param_groups[0]['params'][0][0]
type(t2.data), t2.shape

In [None]:
# Fresh iterator over the model's parameters.
it = m.parameters()

In [None]:
# Next parameter yielded by the iterator.
a = next(it)

In [None]:
# Type of that parameter's .data.
type(a.data)

## END TEST

In [None]:
# learn.fit_gen()

In [None]:
# Learning rate and weight decay for the fit calls below.
lr=1e-2; wd=5e-4

In [None]:
# Time a fit call driven by the custom stepper class.
# NOTE(review): no fp16/loss_scale arguments are passed here, so unless learn.fit
# supplies them, the stepper's defaults (fp16=False, loss_scale=1) apply -- confirm.
%time learn.fit(lr, 1, stepper=StepperFP16)

In [None]:
# Further training with cycle_len=1; no custom stepper is passed to this call.
learn.fit(lr, 2, cycle_len=1)

In [None]:
# Longer run: cycle_len=1 with cycle_mult=2, and weight decay wd passed via wds.
learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)

In [None]:
# Save the learner under the name '8x8_8'.
learn.save('8x8_8')