# Optimizers - When, Where and How to Tweak Them

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
from exp.nb_08 import *

## Getting Imagenette Data From the DataBlock NB

In [3]:
# Imagenette, 160px version; `path` is the dataset location returned by `untar_data`.
IMAGENETTE_URL = datasets.URLs.IMAGENETTE_160
path = datasets.untar_data(IMAGENETTE_URL)

In [4]:
# Item transforms, applied in order: convert to RGB, resize to a fixed 128px,
# then turn the image into a byte tensor and finally a float tensor.
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]
bs = 128

img_list = ImageList.from_files(path, tfms=tfms)

# Train/valid split via `grandparent_splitter`, with the validation folder named 'val'.
split_data = SplitData.split_by_func(
    img_list, partial(grandparent_splitter, valid_name='val'))

# Targets come from `parent_labeler` and are processed by `CategoryProcessor`.
labels = label_by_func(split_data, parent_labeler, proc_y=CategoryProcessor())

# Batches of `bs` items; 3 input channels (RGB) and 10 outputs (Imagenette's 10 classes).
data = labels.to_databunch(bs, c_in=3, c_out=10, num_workers=4)

In [5]:
# Number of filters (output channels) for each successive conv layer.
# No model is created here: `nfs` is passed to `get_learn_run`, which builds it.
nfs = [32, 64, 128, 256]

In [6]:
# Callbacks handed to the learner/runner: running stats (loss plus accuracy),
# device placement (`CudaCallback`), and per-batch normalisation with the
# Imagenette statistics (`norm_imagenette`).
callbacks = [
    partial(AvgStatsCallback, accuracy),
    CudaCallback,
    partial(BatchTransformXCallback, norm_imagenette),
]

In [7]:
# Baseline with vanilla SGD: no `opt_func` is passed, so `get_learn_run` falls back
# to its default optimizer rather than the custom `Optimizer` defined below.
learn, run = get_learn_run(
    nfs,
    data,
    lr=0.4,
    layer=conv_layer,
    cbs=callbacks,
)

In [8]:
# Train for a single epoch; the resulting metrics are the baseline for later optimizers.
run.fit(1, learn)

train: [1.7948843347370367, tensor(0.3808, device='cuda:0')]
valid: [1.6955538415605096, tensor(0.4425, device='cuda:0')]


## Refining the Optimizer

**`NOTES`**

- The base PyTorch optimizer in `torch.optim` keeps its parameters in `param_groups`, a list of dictionaries. Each dictionary holds references to one group of the model's parameters we want to train, together with that group's hyper-parameters (learning rate, momentum, etc.).

- It provides a `step` method, which updates the parameters using their gradients, and a `zero_grad` method, which detaches and zeroes those gradients.

- We will build a more flexible equivalent from scratch. Its `step` method loops over all parameters that have gradients and updates each one with the stepper functions we pass in when creating the optimizer.

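- Because each parameter group gets its own hyper-parameter dictionary, we can later set different hyper-parameters for different parameter (layer) groups, e.g. discriminative learning rates.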

In [9]:
class Optimizer():
    """Minimal optimizer that delegates the actual update rule to stepper functions.

    Args:
        params: either an iterable of parameter tensors (treated as one group), or
            an iterable of lists/tuples of parameter tensors (one per parameter
            group). Generators are accepted; they are consumed once here.
        steppers: a stepper function or a list of them. `step` runs them through
            `compose` for every parameter that has a gradient, passing that
            parameter's group hyper-parameters as keyword arguments.
        **defaults: initial hyper-parameters (e.g. `lr`). Each parameter group gets
            its own copy, so groups can later be given different values.

    Raises:
        ValueError: if `params` is empty.
    """
    def __init__(self, params, steppers, **defaults):
        # Materialise first: `params` may be a one-shot generator.
        param_groups = list(params)
        if not param_groups:
            raise ValueError("Optimizer got an empty parameter list")
        # A flat sequence of tensors is a single parameter group; lists or tuples
        # of tensors are explicit groups.
        if not isinstance(param_groups[0], (list, tuple)):
            param_groups = [param_groups]
        self.param_groups = [list(pg) for pg in param_groups]
        # One independent hyper-parameter dict per group (enables e.g.
        # discriminative learning rates or per-group schedules).
        self.hypers = [{**defaults} for _ in self.param_groups]
        self.steppers = listify(steppers)

    def grad_params(self):
        """Return `(param, hyper)` pairs for every parameter whose `.grad` is set.

        `hyper` is the hyper-parameter dict of the group the parameter belongs to.
        """
        return [(p, hyper) for pg, hyper in zip(self.param_groups, self.hypers)
                for p in pg if p.grad is not None]

    def zero_grad(self):
        """Detach every existing gradient in place and reset it to zero."""
        for p, _ in self.grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        """Update each parameter that has a gradient by composing the steppers.

        The optimizer implements no update rule itself; each stepper (SGD,
        weight decay, momentum, ...) receives the parameter plus its group's
        hyper-parameters.
        """
        for p, hyper in self.grad_params():
            compose(p, self.steppers, **hyper)

In [10]:
def sgd_step(param, lr, **kwargs):
    """Vanilla SGD stepper: update `param` in place as `param -= lr * param.grad`.

    Args:
        param: parameter tensor whose `.grad` has been populated.
        lr: learning rate of the parameter's group.
        **kwargs: any other hyper-parameters of the group; ignored here so this
            stepper can be composed with steppers that do use them.

    Returns:
        The same (updated) `param`, so steppers can be chained.
    """
    # `add_(other, alpha=...)` replaces the deprecated `add_(Number, Tensor)` overload.
    param.data.add_(param.grad.data, alpha=-lr)
    return param

In [12]:
# Optimizer factory with plain SGD as the only stepper: call it with the model's
# parameters and hyper-parameters (e.g. `lr=...`) to get an `Optimizer`.
opt_func = partial(Optimizer, steppers=[sgd_step])

**`NOTES`**

- Since we replaced the PyTorch optimizer, the callbacks that read or set hyper-parameters through its properties have to be updated.

- In our optimizer, the hyper-parameters live in `opt.hypers`, a list with one dictionary per parameter group.

In [13]:
# Redefine Recorder, ParamScheduler and LR_Find so they read and write hyper-parameters through `opt.hypers`
class Recorder(Callback):
    """Record the learning rate and loss after every training batch, and plot them.

    The learning rate is read from the custom optimizer's `opt.hypers`. Only the
    last parameter group's `lr` is recorded.
    """
    def begin_fit(self):
        self.lrs, self.losses = [], []

    def after_batch(self):
        # Validation batches are not recorded.
        if not self.in_train:
            return
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.loss.detach().cpu())

    def plot_lr(self):
        """Plot the recorded learning rate against training iteration."""
        plt.plot(self.lrs)

    def plot_loss(self):
        """Plot the recorded training loss against training iteration."""
        plt.plot(self.losses)

    def plot(self, skip_last=0):
        """Plot loss against learning rate, with a log-scaled x axis.

        Args:
            skip_last: number of final recorded points to leave out. If it is
                larger than the number of recorded points, nothing is plotted.
        """
        losses = [o.item() for o in self.losses]
        # Clamp at 0: a negative stop index would slice from the end instead.
        n = max(len(losses) - skip_last, 0)
        plt.xscale('log')
        plt.plot(self.lrs[:n], losses[:n])
        

class ParamScheduler(Callback):
    """Before each training batch, set hyper-parameter `pname` on every parameter group.

    `sched_funcs` is a single schedule function or a list with one per parameter
    group. Each function maps `n_epochs / epochs` (presumably the fraction of
    training done) to a value. A single function is shared by all groups.
    """
    _order = 1

    def __init__(self, pname, sched_funcs):
        self.pname = pname
        self.sched_funcs = listify(sched_funcs)

    def begin_batch(self):
        # Hyper-parameters are only scheduled during training.
        if self.in_train:
            schedules = self.sched_funcs
            if len(schedules) == 1:
                schedules = schedules * len(self.opt.param_groups)
            progress = self.n_epochs / self.epochs
            for sched, group_hypers in zip(schedules, self.opt.hypers):
                group_hypers[self.pname] = sched(progress)
                

class LR_Find(Callback):
    """Learning-rate finder: raise the LR exponentially each training batch and stop early.

    The LR goes geometrically from `min_lr` (iteration 0) to `max_lr` (iteration
    `max_iter`) and is written into every parameter group's `lr` in `opt.hypers`.
    Training is cancelled, via `CancelTrainException`, in either of two cases:
    `max_iter` iterations have been reached, or the loss exceeds 10x the best
    loss seen so far.

    Args:
        max_iter: number of iterations in the sweep.
        min_lr: starting learning rate.
        max_lr: learning rate reached at `max_iter`.
    """
    _order = 1

    # Must be spelled `__init__`: a method named `__init` is name-mangled and never
    # runs, which would leave the attributes below unset.
    def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10):
        self.max_iter, self.min_lr, self.max_lr = max_iter, min_lr, max_lr
        # Start very high so the first recorded loss becomes the best one.
        self.best_loss = 1e9

    def begin_batch(self):
        if not self.in_train:
            return
        pos = self.n_iter / self.max_iter
        lr = self.min_lr * (self.max_lr / self.min_lr) ** pos
        for hyper in self.opt.hypers:
            hyper['lr'] = lr

    def after_step(self):
        # Stop once the sweep is complete or the loss has diverged.
        if self.n_iter >= self.max_iter or self.loss > self.best_loss * 10:
            raise CancelTrainException()
        if self.loss < self.best_loss:
            self.best_loss = self.loss
        