# Optimizers

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_08 import *

## Imagenette data

In [3]:
# Local path of the Imagenette-160 dataset (untar_data presumably downloads
# and extracts it on first use — it comes from the fastai `datasets` module).
path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)

In [5]:
# Per-item transforms, presumably applied in list order. Going by their names:
# convert to RGB, resize to a fixed 128 px, then turn into a byte tensor and
# finally a float tensor.
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]

In [6]:
# Batch size passed to `to_databunch` below.
bs = 128

In [7]:
# Item list built from the files found under `path`, carrying the transforms.
il = ImageList.from_files(path, tfms=tfms)

In [8]:
# Train/validation split decided by each file's grandparent folder; the
# validation folder is named 'val'.
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))

In [9]:
# Labels come from each file's parent folder; CategoryProcessor processes the
# targets (presumably maps class names to integer ids — defined in exp.nb_08).
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())

In [10]:
# Wrap the labeled split into a DataBunch; c_in/c_out presumably mean 3 input
# channels and 10 output classes; 4 loader worker processes.
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=4)

In [11]:
# Filter counts for the conv layers (assumed one entry per layer — see
# get_learn_run in exp.nb_08).
nfs = [32, 64, 128, 256]

In [12]:
# Callback factories: averaged stats tracking the `accuracy` metric, a CUDA
# callback (presumably moves model/batches to the GPU), and a per-batch
# transform applying `norm_imagenette` to the inputs.
cbfs = [partial(AvgStatsCallback, accuracy),
        CudaCallback,
        partial(BatchTransformXCallback, norm_imagenette)]

In [17]:
# Build the learner and runner; 0.4 is presumably the learning rate.
learn, run = get_learn_run(nfs, data, 0.4, conv_layer, cbs=cbfs)

In [18]:
# Train for one epoch (the output below shows one train and one valid line).
run.fit(1, learn)

train: [1.7471634699278735, tensor(0.3892, device='cuda:0')]
valid: [1.483586669921875, tensor(0.4980, device='cuda:0')]


## Refining the optimizer

In PyTorch, an optimizer from `torch.optim` is essentially a dictionary that stores the hyper-parameters and references to the parameters of the model we want to train, organized in parameter groups.

Our optimizer needs a `step` method and a `zero_grad` method. Ours will be more generic, since the actual work is done by stepper functions. Want a different optimizer? Just write a new stepper function.

In [19]:
class Optimizer():
    """Generic optimizer whose update rule is supplied by stepper functions.

    Parameters are stored in ``param_groups`` (a list of lists of tensors).
    Each group has its own dict of hyper-parameters in ``hypers``. ``step``
    delegates the actual update to ``self.steppers``, so writing a different
    optimizer only requires writing different stepper functions.

    Args:
        params: either an iterable of tensors (treated as one parameter group)
            or an iterable of lists of tensors (one list per parameter group).
        steppers: a stepper function or a collection of them (normalized to a
            list with ``listify``).
        **defaults: hyper-parameters (e.g. ``lr``) copied into every group's
            hyper-parameter dict.

    Raises:
        IndexError: if ``params`` is empty (its first element is inspected).
    """
    def __init__(self, params, steppers, **defaults):
        self.param_groups = list(params)
        # Normalize to a list of lists of tensors: a flat list of tensors is
        # wrapped into a single parameter group.
        if not isinstance(self.param_groups[0], list):
            self.param_groups = [self.param_groups]

        # One dict per group, each a fresh copy of `defaults`, so changing a
        # hyper-parameter of one group does not change it for the others
        # (reusing `defaults` itself would share a single dict).
        self.hypers = [{**defaults} for _ in self.param_groups]

        self.steppers = listify(steppers)

    def grad_params(self):
        """Return ``(param, hyper_dict)`` pairs for every parameter with a gradient.

        Parameters whose ``.grad`` is ``None`` are skipped.
        """
        return [(p, hyper)
                for pg, hyper in zip(self.param_groups, self.hypers)
                for p in pg if p.grad is not None]

    def zero_grad(self):
        """Zero, in place, the gradients of all parameters that have one."""
        for p, _ in self.grad_params():
            # Must be the in-place `detach_()`: it cuts `p.grad` out of any
            # autograd graph (e.g. one built by `backward(create_graph=True)`).
            # The out-of-place `detach()` would return a new tensor and leave
            # `p.grad` untouched, so `zero_()` would then record a new op on it.
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        """Run the steppers on every parameter that has a gradient.

        Each parameter's group hyper-parameters are passed as keyword arguments.
        """
        for p, hyper in self.grad_params():
            # NOTE(review): the return value of `compose` is discarded, so the
            # steppers are presumably expected to update `p` in place — confirm
            # against `compose` in exp.nb_08.
            compose(p, self.steppers, **hyper)