# Optimizers - When, Where and How to Tweak Them

In [1]:
# Re-import edited modules (e.g. the exported exp.nb_XX scripts) automatically
# before executing code, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline, directly below the cell that produced them.
%matplotlib inline

In [2]:
from exp.nb_08 import *

## Getting Imagenette Data From the DataBlock NB

In [3]:
# Fetch the Imagenette dataset (IMAGENETTE_160 variant) and keep its local path.
# NOTE(review): untar_data comes from an earlier notebook; presumably it skips
# the download when the data is already extracted locally — confirm.
path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)

In [4]:
# Per-item transforms, applied in this order: make_rgb, fixed resize to 128,
# conversion to a byte tensor, then to a float tensor.
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]
bs = 128  # batch size handed to to_databunch below

# Pipeline: image files -> train/valid split on the grandparent folder
# (folder named 'val' is validation) -> labels from the parent folder
# -> DataBunch with c_in=3 input channels and c_out=10 classes.
img_list = ImageList.from_files(path, tfms=tfms)
split_data = SplitData.split_by_func(
    img_list, partial(grandparent_splitter, valid_name='val'))
labels = label_by_func(split_data, parent_labeler,
                       proc_y=CategoryProcessor())
data = labels.to_databunch(bs, c_in=3, c_out=10, num_workers=4)

In [5]:
# Number of filters per conv layer (consumed by get_learn_run below); each
# width doubles the previous one: 32, 64, 128, 256. This only configures the
# architecture — the model itself is built later.
nfs = [32 * 2**i for i in range(4)]

In [6]:
# Runner callbacks, kept in their original order:
#   - AvgStatsCallback with the `accuracy` metric (produces the train/valid stats),
#   - CudaCallback (the printed metrics live on device='cuda:0'),
#   - BatchTransformXCallback applying norm_imagenette to each input batch
#     (presumably per-channel Imagenette normalization — defined elsewhere).
callbacks = [
    partial(AvgStatsCallback, accuracy),
    CudaCallback,
    partial(BatchTransformXCallback, norm_imagenette),
]

In [7]:
# Baseline (intended as plain SGD): build the learner and runner only;
# the training itself is started in the next cell.
learn, run = get_learn_run(nfs, data,
                           lr=0.4,
                           layer=conv_layer,
                           cbs=callbacks)

In [8]:
# Fit the baseline; the first argument is presumably the number of epochs.
# The train/valid lines printed below come from the stats callback —
# presumably [loss, accuracy] for each split.
run.fit(1, learn)

train: [1.7948843347370367, tensor(0.3808, device='cuda:0')]
valid: [1.6955538415605096, tensor(0.4425, device='cuda:0')]


## Refining the Optimizer

**`NOTES`**

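- A PyTorch optimizer from `torch.optim` is not itself a dictionary: it keeps its state in `param_groups`, a list of dictionaries. Each one holds references to a group of the model's parameters (under the key `'params'`) together with that group's hyper-parameters (learning rate, momentum, …).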

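- It provides a `step` method, which updates the parameters using their gradients, and a `zero_grad` method, which detaches and zeroes those gradients. (Recent PyTorch versions set the gradients to `None` by default instead — see the `set_to_none` argument.)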

- We will build a more flexible equivalent from scratch. Its `step` method loops over every parameter that has a gradient and updates it with *stepper* functions, which we pass in when creating the optimizer. The hyper-parameters are kept in a separate list of dictionaries, one per parameter group.

In [None]:
class Optimizer():
    """Minimal, extensible re-implementation of a PyTorch-style optimizer.

    Parameters are kept as a list of parameter groups (each a list of
    parameters), with a parallel list ``self.hypers`` holding one
    hyper-parameter dict per group. The update rule itself is not hard-coded:
    ``step()`` hands every parameter that has a gradient to ``compose``
    together with ``self.steppers`` and that group's hyper-parameters
    (``compose`` is assumed to apply the steppers to the parameter in turn —
    it is defined in an earlier notebook; confirm).

    Args:
        params: the parameters of a single group (any iterable, including a
            one-shot generator such as ``model.parameters()``), or an iterable
            of lists of parameters for several groups.
        steppers: a stepper function or a collection of them; normalized with
            ``listify``.
        **defaults: hyper-parameters; each group receives its own copy.

    Raises:
        ValueError: if ``params`` is empty.
    """
    def __init__(self, params, steppers, **defaults):
        # `params` may be a generator, so materialize it before inspecting it.
        self.param_groups = list(params)
        if not self.param_groups:
            raise ValueError("Optimizer got an empty parameter list")
        # A flat list of parameters is wrapped into a single group so that
        # param_groups is always a list of lists.
        if not isinstance(self.param_groups[0], list):
            self.param_groups = [self.param_groups]
        # Independent copy of the defaults per group, so tweaking one group's
        # hyper-parameters never affects another group.
        self.hypers = [{**defaults} for _ in self.param_groups]
        self.steppers = listify(steppers)

    def grad_params(self):
        """Return ``(param, hypers)`` pairs for every parameter whose ``grad`` is not None."""
        return [(p, hyper)
                for pg, hyper in zip(self.param_groups, self.hypers)
                for p in pg if p.grad is not None]

    def zero_grad(self):
        """Detach every existing gradient from the autograd graph and zero it in place."""
        for p, _ in self.grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        """Update each parameter that has a gradient using the steppers and its group's hypers."""
        for p, hyper in self.grad_params():
            # The original called compose() with no arguments, which never
            # applied the steppers; pass the parameter, the steppers and the
            # group's hyper-parameters.
            compose(p, self.steppers, **hyper)