In [None]:
# default_exp optimizer

In [None]:
#export
from local.torch_basics import *
from local.test import *

In [None]:
from local.notebook.showdoc import *

# Optimizer

> Define the general fastai optimizer and the variants

## Optimizer -

In [None]:
#export
class _BaseOptimizer():
    "Common functionality between `Optimizer` and `OptimWrapper`"
    def all_params(self, n=slice(None), with_grad=False):
        # One `(p, param_group, state, hypers)` tuple per parameter of the groups selected by `n`
        groups = zip(self.param_groups[n], self.hypers[n])
        res = L((p, pg, self.state[p], hyper) for pg,hyper in groups for p in pg)
        if not with_grad: return res
        # Only keep the parameters that currently have a gradient
        return L(o for o in res if o[0].grad is not None)

    def _set_require_grad(self, rg, p,pg,state,h):
        # A `force_train` flag in the param state keeps it trainable even inside a frozen group
        p.requires_grad_(rg or state.get('force_train', False))

    def freeze_to(self, n):
        n_groups = len(self.param_groups)
        self.frozen_idx = n if n >= 0 else n_groups + n
        if self.frozen_idx >= n_groups:
            warn(f"Freezing {self.frozen_idx} groups; model has {n_groups}; whole model is frozen.")
        # Groups from `n` on are trainable, groups before `n` are frozen
        for o in self.all_params(slice(n, None)): self._set_require_grad(True,  *o)
        for o in self.all_params(slice(None, n)): self._set_require_grad(False, *o)

    def freeze(self):
        assert len(self.param_groups) > 1
        self.freeze_to(-1)

    def unfreeze(self): self.freeze_to(0)

    def set_hypers(self, **kwargs):
        for k,v in kwargs.items(): self.set_hyper(k, v)

    def _set_hyper(self, k, v):
        # `v` holds exactly one value per param group
        for h,v_ in zip(self.hypers, v): h[k] = v_

    def set_hyper(self, k, v):
        n_groups = len(self.param_groups)
        if isinstance(v, slice):
            # slice(start, stop): values spread with `even_mults`; slice(stop): stop/10 except for the last group
            v = even_mults(v.start, v.stop, n_groups) if v.start else [v.stop/10]*(n_groups-1) + [v.stop]
        v = L(v, use_list=None)
        # A single value is broadcast to every group
        if len(v) == 1: v = v*n_groups
        assert len(v) == len(self.hypers), f"Trying to set {len(v)} values for {k} but there are {n_groups} parameter groups."
        self._set_hyper(k, v)

In [None]:
# Attach the docstrings of the public `_BaseOptimizer` methods
add_docs(_BaseOptimizer, **{
    'all_params': "List of param_groups, parameters, and hypers",
    'freeze_to':  "Freeze parameter groups up to `n`",
    'freeze':     "Freeze up to last parameter group",
    'unfreeze':   "Unfreeze the entire model",
    'set_hypers': "`set_hyper` for all `kwargs`",
    'set_hyper':  "Set the value(s) in `v` for hyper-parameter `k`"})

In [None]:
# export
class Optimizer(_BaseOptimizer):
    "Base optimizer class for the fastai library, updating `params` with `steppers`"
    # State entries that `clear_state` preserves
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self, params, steppers, stats=None, train_bn=True, **defaults):
        params = L(params)
        self.steppers,self.stats,self.state,self.train_bn = L(steppers),L(stats),defaultdict(dict),train_bn
        # Hyper-parameter defaults come from the stats, then the steppers, then the explicit kwargs
        # (later sources override earlier ones)
        defaults = merge(*self.stats.attrgot('defaults'), *self.steppers.attrgot('defaults'), defaults)
        # A collection of lists/`L`s gives one group per element; anything else (including no params) is one group
        multi_groups = len(params) > 0 and isinstance(params[0], (L,list))
        self.param_groups = L(L(p) for p in params) if multi_groups else L([params])
        # One hypers dict per param group, filled by `set_hypers`
        self.hypers = L({} for _ in range_of(self.param_groups))
        self.set_hypers(**defaults)
        self.frozen_idx = 0

    def zero_grad(self):
        for p,*_ in self.all_params(with_grad=True):
            # Detach the grad from any autograd graph before zeroing it in place
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        for p,pg,state,hyper in self.all_params(with_grad=True):
            # Stats update the state first; steppers then get state and hypers as kwargs (hypers win on clashes)
            for stat in self.stats:    state = stat(state, p, **hyper)
            for step in self.steppers: step(p, **{**state, **hyper})
            self.state[p] = state

    def clear_state(self):
        for p,pg,state,hyper in self.all_params():
            self.state[p] = {k: state[k] for k in self._keep_on_clear if k in state}

    def state_dict(self):
        # Per-param states in `all_params` order; `load_state_dict` matches them back by position
        state = [self.state[p] for p,*_ in self.all_params()]
        return {'state': state, 'hypers': self.hypers}

    def load_state_dict(self, sd):
        assert len(sd["hypers"]) == len(self.param_groups)
        assert len(sd["state"])  == sum([len(pg) for pg in self.param_groups])
        self.hypers = sd['hypers']
        # Keep a `defaultdict` as in `__init__`, so a param without a saved state still gets an empty dict
        self.state = defaultdict(dict, {p: s for p,s in zip(self.all_params().itemgot(0), sd['state'])})

In [None]:
# Attach the docstrings of the public `Optimizer` methods
add_docs(Optimizer, 
         zero_grad="Standard PyTorch API: Zero all the grad attributes of the parameters",
         step="Standard PyTorch API: Update the stats and execute the steppers on all parameters that have a grad",
         state_dict="Return the state of the optimizer in a dictionary",
         load_state_dict="Load the content of `sd`",
         clear_state="Reset the state of the optimizer")

### Initializing an Optimizer

`params` will be used to create the `param_groups` of the optimizer. If it's a collection (or a generator) of parameters, it will be an `L` containing one `L` with all the parameters. To define multiple parameter groups, `params` should be passed as a collection (or a generator) of `L`s or lists.

> Note: In PyTorch, `model.parameters()` returns a generator with all the parameters, that you can directly pass to `Optimizer`.

In [None]:
# A flat collection of params (list, range or generator) gives a single parameter group
opt = Optimizer([1,2,3], noop)
test_eq(opt.param_groups, [[1,2,3]])
opt = Optimizer(range(3), noop)
test_eq(opt.param_groups, [[0,1,2]])
# A collection (here also a generator) of lists gives one group per list
opt = Optimizer([[1,2],[3]], noop)
test_eq(opt.param_groups, [[1,2],[3]])
opt = Optimizer(([o,o+1] for o in range(0,4,2)), noop)
test_eq(opt.param_groups, [[0,1],[2,3]])

`steppers` is a list of functions that are called in order on each parameter when applying the step. For instance, you can combine a function making the SGD step with another one applying weight decay. Additionally, each `stepper` can have a `defaults` attribute that contains hyper-parameters and their default values. Those are all gathered at initialization, and new values can be passed to override those defaults with the `defaults` kwargs. The steppers will be called by `Optimizer.step` (which is the standard PyTorch name), and gradients can be cleared with `Optimizer.zero_grad` (also a standard PyTorch name).

Once the defaults have all been pulled off, they are copied as many times as there are `param_groups` and stored in `hypers`. To apply different hyper-parameters to different groups (differential learning rates, or no weight decay for certain layers for instance), you will need to adjust those values after the init (or pass a collection or a slice, see below). 

In [None]:
# Dummy steppers: they return `p` unchanged, but their `defaults` register hyper-parameters
def tst_arg(p, lr=0, **kwargs): return p
tst_arg.defaults = dict(lr=1e-2)

def tst_arg2(p, lr2=0, **kwargs): return p
tst_arg2.defaults = dict(lr2=1e-3)

def tst_arg3(p, mom=0, **kwargs): return p
tst_arg3.defaults = dict(mom=0.9)

# NOTE(review): `tst_arg4` has no `defaults` and is not used in this cell
def tst_arg4(p, **kwargs): return p

# Defaults are gathered from the steppers and from the stats (third positional argument)
opt = Optimizer([1,2,3], [tst_arg,tst_arg2], tst_arg3)
test_eq(opt.hypers, [{'lr2': 1e-3, 'mom': 0.9, 'lr': 1e-2}])
# Explicit keyword arguments override the gathered defaults
opt = Optimizer([1,2,3], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}])
# With several parameter groups, each group gets its own copy of the hyper-parameters
opt = Optimizer([[1,2],[3]], tst_arg)
test_eq(opt.hypers, [{'lr': 1e-2}, {'lr': 1e-2}])
opt = Optimizer([[1,2],[3]], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.1}])

For each hyper-parameter, when there are multiple parameter groups, you can pass a slice or a collection to set its values. A slice `slice(s,e)` will be converted to a log-uniform collection from `s` to `e`. A slice with only an end `e` is converted to a collection with as many values as there are parameter groups: `...,e/10,e/10,e`.

Setting a hyper-parameter with a collection that has a different number of elements than the optimizer has parameter groups will raise an error.

In [None]:
# A collection gives one value per parameter group
opt = Optimizer([[1,2],[3]], tst_arg, lr=[0.1,0.2])
test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.2}])
# slice(e): e/10 for every group except the last one, which gets e
opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-2))
test_eq(opt.hypers, [{'lr': 1e-3}, {'lr': 1e-3}, {'lr': 1e-2}])
# slice(s, e): values spread from s to e
opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-4,1e-2))
test_eq(opt.hypers, [{'lr': 1e-4}, {'lr': 1e-3}, {'lr': 1e-2}])
# 2 values for 3 parameter groups must fail
test_fail(lambda: Optimizer([[1,2],[3],[4]], tst_arg, lr=np.array([0.1,0.2])))

### Basic steppers

To be able to give examples of optimizer steps, we will need some steppers, like the following:

In [None]:
#export
def sgd_step(p, lr, **kwargs):
    "Vanilla SGD step: update `p` in place with `-lr * p.grad` and return it"
    # Keyword `alpha` replaces the deprecated `add_(scalar, tensor)` overload
    p.data.add_(p.grad.data, alpha=-lr)
    return p

In [None]:
def tst_param(val, grad=None):
    "Build a float test tensor `[val]` whose `.grad` is `[grad]` (defaults to `val/10`)"
    if grad is None: grad = val/10
    res = tensor([val]).float()
    res.grad = tensor([grad]).float()
    return res

In [None]:
# One SGD step with lr=1: 1 - 1*0.1 = 0.9; the gradient itself is left untouched
p = tst_param(1., 0.1)
p = sgd_step(p, 1.)
test_eq(p, tensor([0.9]))
test_eq(p.grad, tensor([0.1]))

In [None]:
#export
def weight_decay(p, lr, wd, do_wd=True, **kwargs):
    "True (decoupled) weight decay: shrink `p` in place by a factor `1 - lr*wd`"
    # Nothing to do when decay is disabled for this param or `wd` is zero
    if not do_wd or wd == 0: return p
    p.data.mul_(1 - lr*wd)
    return p
weight_decay.defaults = {'wd': 0.}

In [None]:
# Weight decay multiplies the weight by 1 - lr*wd = 0.9 and leaves the gradient untouched
p = tst_param(1., 0.1)
p = weight_decay(p, 1., 0.1)
test_eq(p, tensor([0.9]))
test_eq(p.grad, tensor([0.1]))

In [None]:
#export
def l2_reg(p, lr, wd, do_wd=True, **kwargs):
    "L2 regularization as adding `wd*p` to `p.grad`"
    # Keyword `alpha` replaces the deprecated `add_(scalar, tensor)` overload
    if do_wd and wd!=0: p.grad.data.add_(p.data, alpha=wd)
    return p
l2_reg.defaults = dict(wd=0.)

In [None]:
# L2 regularization leaves the weight untouched and adds wd*p = 0.1 to the gradient
p = tst_param(1., 0.1)
p = l2_reg(p, 1., 0.1)
test_eq(p, tensor([1.]))
test_eq(p.grad, tensor([0.2]))

> Warning: Weight decay and L2 regularization are the same thing for basic SGD, but for more complex optimizers they are very different. See [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) for more information.

### Making the step

In [None]:
# Render the documentation of `Optimizer.step`
show_doc(Optimizer.step)

<h4 id="Optimizer.step" class="doc_header"><code>Optimizer.step</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L20" class="source_link" style="float:right">[source]</a></h4>

> <code>Optimizer.step</code>()

Standard PyTorch API: Update the stats and execute the steppers on all parameters that have a grad

This method will loop over all param groups, then over all parameters for which `grad` is not None. For each one it first updates the parameter state with each function in `stats`, then calls each function in `steppers`, passing it the parameter `p` along with that state and the hyper-parameters in the corresponding dict in `hypers`.

In [None]:
#test basic step
r = L.range(4)
# Four single-value params 0..3, each with gradient i/10
def tst_params(): return r.map(tst_param)

params = tst_params()
opt = Optimizer(params, sgd_step, lr=0.1)
opt.step()
# i - 0.1*(0.1*i) = 0.99*i
test_close([p.item() for p in params], r.map(mul(0.99)))

In [None]:
#test two steps
params = tst_params()
opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)
opt.step()
# Weight decay gives 0.99*i, then the SGD step removes 0.1*(0.1*i): 0.98*i
test_close([p.item() for p in params], r.map(mul(0.98)))

In [None]:
#test None gradients are ignored
params = tst_params()
opt = Optimizer(params, sgd_step, lr=0.1)
params[-1].grad = None
opt.step()
# The last param has no gradient, so it keeps its value 3
test_close([p.item() for p in params], [0., 0.99, 1.98, 3.])

In [None]:
#test discriminative lrs
params = tst_params()
opt = Optimizer([params[:2], params[2:]], sgd_step, lr=0.1)
# First group uses lr=0.01 (i - 0.001*i), second keeps lr=0.1 (i - 0.01*i)
opt.hypers[0]['lr'] = 0.01
opt.step()
test_close([p.item() for p in params], [0., 0.999, 1.98, 2.97])

In [None]:
# Render the documentation of `Optimizer.zero_grad`
show_doc(Optimizer.zero_grad)

<h4 id="Optimizer.zero_grad" class="doc_header"><code>Optimizer.zero_grad</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L15" class="source_link" style="float:right">[source]</a></h4>

> <code>Optimizer.zero_grad</code>()

Standard PyTorch API: Zero all the grad attributes of the parameters

In [None]:
# After `zero_grad`, every gradient is zero
params = tst_params()
opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)
opt.zero_grad()
for p in params: test_eq(p.grad, tensor([0.]))

`Optimizer` has `stats`, which are functions taking the state associated with a parameter. `stats` use that parameter, plus the optimizer hyper-parameters, to update the state. 
That state can then be used by any stepper. The best example is a momentum calculation. 
The state of each parameter is an empty dictionary the first time it is accessed, so each `stat` function has to initialize the entries it needs (see `average_grad` below).

In [None]:
# Dummy stat: accumulates the values of `p` in `state['sum']`, starting from zeros
def tst_stat(state, p, **kwargs): 
    state['sum'] = state.get('sum', torch.zeros_like(p)) + p.data
    return state
# Like steppers, stats can register hyper-parameter defaults
tst_stat.defaults = {'mom': 0.9}

#Test Optimizer init
opt = Optimizer([1,2,3], noop, stats=tst_stat)
test_eq(opt.hypers, [{'mom': 0.9}])
opt = Optimizer([1,2,3], noop, stats=tst_stat, mom=0.99)
test_eq(opt.hypers, [{'mom': 0.99}])

#Test stat
x = torch.randn(4,5)
state = tst_stat({}, x)
assert 'sum' in state
test_eq(state['sum'], x)
state = tst_stat(state, x)
test_eq(state['sum'], 2*x)

## Statistics

In [None]:
# export
def average_grad(state, p, mom, dampening=False, **kwargs):
    "Keeps track of the avg grads of `p` in `state` with `mom`."
    if 'grad_avg' not in state: state['grad_avg'] = torch.zeros_like(p.grad.data)
    # Without dampening the new grad is added with weight 1 (classic SGD momentum)
    damp = 1-mom if dampening else 1.
    # Keyword `alpha` replaces the deprecated `add_(scalar, tensor)` overload
    state['grad_avg'].mul_(mom).add_(p.grad.data, alpha=damp)
    return state

average_grad.defaults = dict(mom=0.9)

`dampening=False` gives the classical formula for momentum in SGD: 
```
new_val = old_val * mom + grad
```
whereas `dampening=True` makes it an exponential moving average:
```
new_val = old_val * mom + grad * (1-mom)
```

In [None]:
p = tst_param([1,2,3], [4,5,6])
state = {}
# Without dampening: grad, then 0.9*grad + grad = 1.9*grad
state = average_grad(state, p, mom=0.9)
test_eq(state['grad_avg'], p.grad)
state = average_grad(state, p, mom=0.9)
test_eq(state['grad_avg'], p.grad * 1.9)
#Test dampening: 0.1*grad, then 0.9*(0.1*grad) + 0.1*grad
state = {}
state = average_grad(state, p,  mom=0.9, dampening=True)
test_eq(state['grad_avg'], 0.1*p.grad)
state = average_grad(state, p, mom=0.9, dampening=True)
test_eq(state['grad_avg'], (0.1*0.9+0.1)*p.grad)

In [None]:
# export
def average_sqr_grad(state, p, sqr_mom, dampening=True, **kwargs):
    "Keeps track of the avg squared grads of `p` in `state` with `sqr_mom`."
    if 'sqr_avg' not in state: state['sqr_avg'] = torch.zeros_like(p.grad.data)
    damp = 1-sqr_mom if dampening else 1.
    # Keyword `value` replaces the deprecated `addcmul_(scalar, t1, t2)` overload
    state['sqr_avg'].mul_(sqr_mom).addcmul_(p.grad.data, p.grad.data, value=damp)
    return state

average_sqr_grad.defaults = dict(sqr_mom=0.99)

`dampening=False` gives a plain (undamped) running sum of the squared gradients: 
```
new_val = old_val * sqr_mom + grad**2
```
whereas `dampening=True` (the default here) makes it an exponential moving average:
```
new_val = old_val * sqr_mom + (grad**2) * (1-sqr_mom)
```

In [None]:
p = tst_param([1,2,3], [4,5,6])
state = {}
# Without dampening: grad**2, then 0.99*grad**2 + grad**2 = 1.99*grad**2
state = average_sqr_grad(state, p, sqr_mom=0.99, dampening=False)
test_eq(state['sqr_avg'], p.grad.pow(2))
state = average_sqr_grad(state, p, sqr_mom=0.99, dampening=False)
test_eq(state['sqr_avg'], p.grad.pow(2) * 1.99)
#Test dampening (the default): 0.01*grad**2, then 0.99*(0.01*grad**2) + 0.01*grad**2
state = {}
state = average_sqr_grad(state, p,  sqr_mom=0.99)
test_close(state['sqr_avg'], 0.01*p.grad.pow(2))
state = average_sqr_grad(state, p, sqr_mom=0.99)
test_close(state['sqr_avg'], (0.01*0.99+0.01)*p.grad.pow(2))

### Freezing part of the model

In [None]:
# Render the documentation of `freeze` (inherited from `_BaseOptimizer`)
show_doc(Optimizer.freeze)

<h4 id="_BaseOptimizer.freeze" class="doc_header"><code>_BaseOptimizer.freeze</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L16" class="source_link" style="float:right">[source]</a></h4>

> <code>_BaseOptimizer.freeze</code>()

Freeze up to last parameter group

In [None]:
# Render the documentation of `freeze_to` (inherited from `_BaseOptimizer`)
show_doc(Optimizer.freeze_to)

<h4 id="_BaseOptimizer.freeze_to" class="doc_header"><code>_BaseOptimizer.freeze_to</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L9" class="source_link" style="float:right">[source]</a></h4>

> <code>_BaseOptimizer.freeze_to</code>(**`n`**)

Freeze parameter groups up to `n`

In [None]:
# Render the documentation of `unfreeze` (inherited from `_BaseOptimizer`)
show_doc(Optimizer.unfreeze)

<h4 id="_BaseOptimizer.unfreeze" class="doc_header"><code>_BaseOptimizer.unfreeze</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L20" class="source_link" style="float:right">[source]</a></h4>

> <code>_BaseOptimizer.unfreeze</code>()

Unfreeze the entire model

In [None]:
#Freezing the first layer
params = [tst_params(), tst_params(), tst_params()]
opt = Optimizer(params, sgd_step, lr=0.1)
opt.freeze_to(1)
req_grad = Self.requires_grad()
test_eq(L(params[0]).map(req_grad), [False]*4)
for i in {1,2}: test_eq(L(params[i]).map(req_grad), [True]*4)

#Unfreezing: all three groups, including the previously frozen first one, must be trainable again
opt.unfreeze()
for i in range(3): test_eq(L(params[i]).map(req_grad), [True]*4)

#TODO: test warning
# opt.freeze_to(3)

Parameters such as batchnorm weights/bias can be marked to always be in training mode: just put `force_train=True` in their state.

In [None]:
params = [tst_params(), tst_params(), tst_params()]
opt = Optimizer(params, sgd_step, lr=0.1)
# Flag params 1 and 3 of the second group as always trainable
for p in L(params[1])[[1,3]]: opt.state[p] = {'force_train': True}
# `freeze` freezes every group but the last one, except the flagged params
opt.freeze()
test_eq(L(params[0]).map(req_grad), [False]*4)
test_eq(L(params[1]).map(req_grad), [False, True, False, True])
test_eq(L(params[2]).map(req_grad), [True]*4)

### Serializing

In [None]:
# Render the documentation of `Optimizer.state_dict`
show_doc(Optimizer.state_dict)

<h4 id="Optimizer.state_dict" class="doc_header"><code>Optimizer.state_dict</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L30" class="source_link" style="float:right">[source]</a></h4>

> <code>Optimizer.state_dict</code>()

Return the state of the optimizer in a dictionary

In [None]:
# Render the documentation of `Optimizer.load_state_dict`
show_doc(Optimizer.load_state_dict)

<h4 id="Optimizer.load_state_dict" class="doc_header"><code>Optimizer.load_state_dict</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L34" class="source_link" style="float:right">[source]</a></h4>

> <code>Optimizer.load_state_dict</code>(**`sd`**)

Load the content of `sd`

In [None]:
p = tst_param([1,2,3], [4,5,6])
opt = Optimizer(p, noop, stats=average_grad)
opt.step()
test_eq(opt.state[p]['grad_avg'], tensor([[4., 5., 6.]]))

# Save that state, then build a new optimizer on another param with a different `mom`
sd = opt.state_dict()
p1 = tst_param([10,20,30], [40,50,60])
opt = Optimizer(p1, noop, stats=average_grad, mom=0.99)
test_eq(opt.hypers[0]['mom'], 0.99)
test_eq(opt.state, {})

# Loading restores the saved hypers and assigns the saved states to the new params by position
opt.load_state_dict(sd)
test_eq(opt.hypers[0]['mom'], 0.9)
test_eq(opt.state[p1]['grad_avg'], tensor([[4., 5., 6.]]))

In [None]:
# Render the documentation of `Optimizer.clear_state`
show_doc(Optimizer.clear_state)

<h4 id="Optimizer.clear_state" class="doc_header"><code>Optimizer.clear_state</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L26" class="source_link" style="float:right">[source]</a></h4>

> <code>Optimizer.clear_state</code>()

Reset the state of the optimizer

In [None]:
p = tst_param([1,2,3], [4,5,6])
opt = Optimizer(p, noop, stats=average_grad)
opt.state[p] = {'force_train': True}
opt.step()
test_eq(opt.state[p]['grad_avg'], tensor([[4., 5., 6.]]))

# `force_train` is listed in `_keep_on_clear`, so it survives while `grad_avg` is dropped
opt.clear_state()
test_eq(opt.state[p], {'force_train': True})

## Optimizers

### SGD with momentum

In [None]:
#export
def momentum_step(p, lr, grad_avg, **kwargs):
    "Step for SGD with momentum with `lr`"
    # Move along the running gradient average (from `average_grad`) instead of the raw gradient;
    # keyword `alpha` replaces the deprecated `add_(scalar, tensor)` overload
    p.data.add_(grad_avg, alpha=-lr)
    return p

In [None]:
#export
def SGD(params, lr, mom=0., wd=0., decouple_wd=True):
    "An `Optimizer` for SGD with learning rate `lr`, momentum `mom` and weight decay `wd` on `params`"
    # True weight decay on the weights, or L2 regularization added to the gradients
    wd_step = weight_decay if decouple_wd else l2_reg
    if mom == 0.: return Optimizer(params, [wd_step, sgd_step], lr=lr, wd=wd)
    # NOTE(review): `Optimizer.step` runs the stats before the steppers, so with `decouple_wd=False`
    # the L2 term `l2_reg` adds to `p.grad` is not part of the `grad_avg` used by `momentum_step`
    return Optimizer(params, [wd_step, momentum_step], stats=average_grad, lr=lr, mom=mom, wd=wd)

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#Vanilla SGD
params = tst_params()
opt = SGD(params, lr=0.1)
# One step: i - 0.1*(0.1*i) = 0.99*i (gradients are not recomputed between steps)
opt.step()
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
# Second step with the same gradients: 0.99*i - 0.01*i = 0.98*i
opt.step()
test_close([p.item() for p in params], [i*0.98 for i in range(4)])

In [None]:
#SGD with momentum
params = tst_params()
opt = SGD(params, lr=0.1, mom=0.9)
assert isinstance(opt, Optimizer)
# First step: grad_avg = grad = 0.1*i, same result as vanilla SGD
opt.step()
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
# Second step: grad_avg = 0.9*(0.1*i) + 0.1*i = 0.19*i
opt.step()
test_close([p.item() for p in params], [i*(1 - 0.1 * (0.1 + 0.1*1.9)) for i in range(4)])
for i,p in enumerate(params): test_close(opt.state[p]['grad_avg'].item(), i*0.19)

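Test weight decay and L2 regularization with SGD with momentum. Note that the second optimizer reuses `params` already updated by the first one. Also, `Optimizer.step` computes the stats before running the steppers. With momentum the update uses `grad_avg`, so the L2 term that `l2_reg` adds to `p.grad` does not affect this step.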

In [None]:
params = tst_params()
#Weight decay: weights are first multiplied by 0.99, then the momentum step (grad_avg = grad = 0.1*i) gives 0.98*i
opt = SGD(params, lr=0.1, mom=0.9, wd=0.1)
opt.step()
test_close([p.item() for p in params], [i*0.98 for i in range(4)])
#L2 reg: a fresh optimizer on the same `params` (now 0.98*i). Stats run before steppers in `Optimizer.step`,
# so `grad_avg` is the raw grad 0.1*i and the `l2_reg` change to `p.grad` is not used by `momentum_step`:
# 0.98*i - 0.1*(0.1*i) = 0.97*i
opt = SGD(params, lr=0.1, mom=0.9, wd=0.1, decouple_wd=False)
opt.step()
test_close([p.item() for p in params], [i*0.97 for i in range(4)])

### RMSProp

In [None]:
#export
def rms_prop_step(p, lr, sqr_avg, eps, grad_avg=None, **kwargs):
    "Step for RMSProp with `lr`: divide the (averaged) gradient by `sqrt(sqr_avg) + eps`"
    denom = sqr_avg.sqrt().add_(eps)
    # Use the momentum average when `average_grad` is among the stats, the raw gradient otherwise;
    # keyword `value` replaces the deprecated `addcdiv_(scalar, t1, t2)` overload
    p.data.addcdiv_((grad_avg if grad_avg is not None else p.grad), denom, value=-lr)
    return p

rms_prop_step.defaults = dict(eps=1e-8)

In [None]:
#export
def RMSProp(params, lr, sqr_mom=0.99, mom=0., wd=0., decouple_wd=True):
    "An `Optimizer` for RMSProp with learning rate `lr`, squared-grad momentum `sqr_mom` and momentum `mom` on `params`"
    wd_step = weight_decay if decouple_wd else l2_reg
    stats = [average_sqr_grad]
    # With momentum, `average_grad` provides the `grad_avg` that `rms_prop_step` uses instead of the raw grad
    if mom != 0.: stats = [average_grad] + stats
    return Optimizer(params, [wd_step, rms_prop_step], stats=stats, lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd)

RMSProp was introduced by Geoffrey Hinton in his [course](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). What is named `sqr_mom` here is the `alpha` in the course. Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#Without momentum
import math
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = RMSProp(params, lr=0.1)
# First step: sqr_avg = 0.01*grad**2, so each weight moves by -0.1*grad/(0.1*grad) = -1
opt.step()
test_close(params[0], tensor([0.,1.,2.]))
# Second step: sqr_avg = (0.01*0.99+0.01)*grad**2; the move is the same for every element
opt.step()
step = - 0.1 * 0.1 / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params[0], tensor([step, 1+step, 2+step]))

In [None]:
#With momentum
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = RMSProp(params, lr=0.1, mom=0.9)
# First step: grad_avg = grad (no dampening), so the move is again -1
opt.step()
test_close(params[0], tensor([0.,1.,2.]))
# Second step: grad_avg = 0.9*grad + grad
opt.step()
step = - 0.1 * (0.1 + 0.9*0.1) / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params[0], tensor([step, 1+step, 2+step]))

### Adam

In [None]:
#export
def step_stat(state, p, **kwargs):
    "Count in `state['step']` how many steps have been taken for `p`"
    state['step'] = state.get('step', 0) + 1
    return state

In [None]:
# `step_stat` increments the counter on every call
p = tst_param(1,0.1)
state = {}
state = step_stat(state, p)
test_eq(state['step'], 1)
for _ in range(5): state = step_stat(state, p)
test_eq(state['step'], 6)

In [None]:
#export
def debias(mom, damp, step):
    "Debiasing factor `damp * (1 - mom**step) / (1 - mom)` for a running average after `step` steps"
    numerator = damp * (1 - mom**step)
    return numerator / (1-mom)

In [None]:
#export
def adam_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
    "Step for Adam with `lr` on `p`"
    # Correct the bias of the (dampened) averages towards zero in the first steps
    debias1 = debias(mom,     1-mom,     step)
    debias2 = debias(sqr_mom, 1-sqr_mom, step)
    # Keyword `value` replaces the deprecated `addcdiv_(scalar, t1, t2)` overload
    p.data.addcdiv_(grad_avg, (sqr_avg/debias2).sqrt() + eps, value=-lr / debias1)
    return p

# `Optimizer` gathers hyper-parameter defaults from the `defaults` attribute (not `_defaults`)
adam_step.defaults = dict(eps=1e-5)

In [None]:
#export
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True):
    "An `Optimizer` for Adam with `lr`, `mom` (beta1), `sqr_mom` (beta2), `eps` and `params`"
    wd_step = weight_decay if decouple_wd else l2_reg
    # Dampened (EMA) averages of the grads and squared grads, plus the step count used for debiasing
    stats = [partial(average_grad, dampening=True), average_sqr_grad, step_stat]
    # NOTE(review): `Optimizer.step` runs the stats before the steppers, so with `decouple_wd=False` the
    # L2 term `l2_reg` adds to `p.grad` does not reach `grad_avg`/`sqr_avg` on that step — confirm intended
    hypers = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return Optimizer(params, [wd_step, adam_step], stats=stats, **hypers)

Adam was introduced by Diederik P. Kingma and Jimmy Ba in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980). For consistency across optimizers, we renamed `beta1` and `beta2` in the paper to `mom` and `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

> Note: Don't forget that `eps` is a hyper-parameter you can change. Some models won't train without a very high `eps` like 0.1 (intuitively, the higher `eps` is, the closer we are to normal SGD). The usual default of 1e-8 is often too extreme, in the sense that we don't manage to get as good results as with SGD. 

In [None]:
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = Adam(params, lr=0.1)
# After debiasing, the averages equal grad and grad**2, so each weight moves by about -lr = -0.1
opt.step()
step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params[0], tensor([1+step, 2+step, 3+step]))
# Same gradient again: another move of about -0.1 (looser tolerance because of `eps`)
opt.step()
test_close(params[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)

### LARS/LARC

In [None]:
#export
def larc_layer_lr(state, p, lr, trust_coeff, wd, eps, clip=True, **kwargs):
    "Store in `state['local_lr']` the per-layer lr computed before weight decay (capped at `lr` when `clip`)"
    p_norm = torch.norm(p.data)
    g_norm = torch.norm(p.grad.data)
    local_lr = lr*trust_coeff * (p_norm) / (g_norm + p_norm * wd + eps)
    if clip: local_lr = min(lr, local_lr)
    state['local_lr'] = local_lr
    return state
larc_layer_lr.defaults = {'trust_coeff': 0.02, 'wd': 0., 'eps': 1e-8}

In [None]:
#export
def larc_step(p, local_lr, grad_avg=None, **kwargs):
    "Step for LARC `local_lr` on `p`"
    # Use the momentum average when available, the raw gradient otherwise;
    # keyword `alpha` replaces the deprecated `add_(scalar, tensor)` overload
    p.data.add_(p.grad.data if grad_avg is None else grad_avg, alpha=-local_lr)
    return p

In [None]:
#export
def Larc(params, lr, mom=0.9, clip=True, trust_coeff=0.02, eps=1e-8, wd=0., decouple_wd=True):
    "A `Optimizer` for LARC (LARS when `clip=False`) with `lr`, `mom`, `trust_coeff`, `eps` and `params`"
    steppers = [weight_decay] if decouple_wd else [l2_reg]
    steppers.append(larc_step)
    # With momentum, `larc_step` uses `grad_avg` instead of the raw gradient
    stats = [] if mom==0. else [average_grad]
    # The per-layer lr is recomputed at every step from the current weights and gradients
    stats.append(partial(larc_layer_lr, clip=clip))
    return Optimizer(params, steppers, stats=stats, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)

The LARS optimizer was first introduced in [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888) and then refined in its LARC variant (original LARS is with `clip=False`). A learning rate is computed for each individual layer with a certain `trust_coeff`, then clipped to be at most `lr`.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
# Same weights, but the second param has 10x smaller gradients, hence a 10x larger local lr
params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt = Larc(params, lr=0.1)
opt.step()
#First param local lr is 0.02 < lr so it's not clipped
test_close(opt.state[params[0]]['local_lr'], 0.02)
#Second param local lr is 0.2 > lr so it's clipped
test_eq(opt.state[params[1]]['local_lr'], 0.1)
test_close(params[0], tensor([0.998,1.996,2.994]))
test_close(params[1], tensor([0.999,1.998,2.997]))

In [None]:
# Same setup with `clip=False` (LARS): the second param keeps its local lr of 0.2
params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt = Larc(params, lr=0.1, clip=False)
opt.step()
#No clipping
test_close(opt.state[params[0]]['local_lr'], 0.02)
test_close(opt.state[params[1]]['local_lr'], 0.2)
test_close(params[0], tensor([0.998,1.996,2.994]))
test_close(params[1], tensor([0.998,1.996,2.994]))

### LAMB

In [None]:
#export
def lamb_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
    "Step for LAMB with `lr` on `p`"
    debias1 = debias(mom,     1-mom,     step)
    debias2 = debias(sqr_mom, 1-sqr_mom, step)
    # RMS of the weights and of the Adam-style update; their ratio is the layer-wise trust ratio
    r1 = p.data.pow(2).mean().sqrt()
    upd = (grad_avg/debias1) / ((sqr_avg/debias2).sqrt()+eps)
    r2 = upd.pow(2).mean().sqrt()
    # Fall back to 1 when either RMS is zero; cap the ratio at 10
    q = 1 if r1 == 0 or r2 == 0 else min(r1/r2,10)
    # Keyword `alpha` replaces the deprecated `add_(scalar, tensor)` overload
    p.data.add_(upd, alpha=-lr * q)
    return p
# `Optimizer` gathers hyper-parameter defaults from the `defaults` attribute (not `_defaults`)
lamb_step.defaults = dict(eps=1e-6, wd=0.)

In [None]:
#export
def Lamb(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True):
    "A `Optimizer` for LAMB with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
    steppers = [weight_decay] if decouple_wd else [l2_reg]
    steppers.append(lamb_step)
    # Same statistics as Adam: dampened grad/squared-grad averages and the step count for debiasing
    # NOTE(review): stats run before steppers in `Optimizer.step`, so with `decouple_wd=False` the
    # L2 term added by `l2_reg` does not reach these averages on that step — confirm intended
    stats = [partial(average_grad, dampening=True), average_sqr_grad, step_stat]
    return Optimizer(params, steppers, stats=stats, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

LAMB was introduced in [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962). Intuitively, it's LARC applied to Adam. As in `Adam`, we renamed `beta1` and `beta2` in the paper to `mom` and  `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
# First step: the debiased update is about 1 per element and r1 = sqrt(mean([1,4,9])) ≈ 2.16,
# so each weight moves by about -0.1*2.16 = -0.216
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = Lamb(params, lr=0.1)
opt.step()
test_close(params[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)

## OptimWrapper -

In [None]:
#export
def detuplify_pg(d):
    "Copy of param-group dict `d` without 'params', where each listy value `v` under `k` becomes `k__0`, `k__1`, ..."
    res = {}
    for k,v in d.items():
        if k == 'params': continue
        if not is_listy(v):
            res[k] = v
            continue
        for i,v_ in enumerate(v): res[f'{k}__{i}'] = v_
    return res

In [None]:
# 'params' is dropped and tuple values are split into `name__i` entries
tst = {'lr': 1e-2, 'mom': 0.9, 'params':[0,1,2]}
test_eq(detuplify_pg(tst), {'lr': 1e-2, 'mom': 0.9})
tst = {'lr': 1e-2, 'betas': (0.9,0.999), 'params':[0,1,2]}
test_eq(detuplify_pg(tst), {'lr': 1e-2, 'betas__0': 0.9, 'betas__1': 0.999})

In [None]:
#export
def set_item_pg(pg, k, v):
    "Set `pg[k] = v` in place; a key `name__i` replaces only element `i` of the tuple `pg[name]`"
    if '__' in k:
        name,idx = k.split('__')
        pg[name] = tuple(v if i == int(idx) else o for i,o in enumerate(pg[name]))
    else: pg[k] = v
    return pg

In [None]:
# Plain keys are set directly; `betas__0` replaces element 0 of the `betas` tuple
tst = {'lr': 1e-2, 'mom': 0.9, 'params':[0,1,2]}
test_eq(set_item_pg(tst, 'lr', 1e-3), {'lr': 1e-3, 'mom': 0.9, 'params':[0,1,2]})
tst = {'lr': 1e-2, 'betas': (0.9,0.999), 'params':[0,1,2]}
test_eq(set_item_pg(tst, 'betas__0', 0.95), {'lr': 1e-2, 'betas': (0.95,0.999), 'params':[0,1,2]})

In [None]:
#export
# PyTorch hyper-parameter names (tuples flattened as `name__i`) -> fastai names
pytorch_hp_map = dict(momentum='mom', weight_decay='wd', alpha='sqr_mom', betas__0='mom', betas__1='sqr_mom')

In [None]:
#export
class OptimWrapper(_BaseOptimizer, GetAttr):
    "Wrap the PyTorch optimizer `opt` so its param groups and hyper-parameters are accessed with fastai names"
    # `GetAttr` settings: presumably attributes not found here are looked up on `self.opt` (`_default`),
    # with `_xtra` listing the forwarded names — confirm against `GetAttr`
    _xtra=['zero_grad', 'step', 'state_dict', 'load_state_dict']
    _default='opt'
    def __init__(self, opt, hp_map=None):
        self.opt = opt
        if hp_map is None: hp_map = pytorch_hp_map
        # PyTorch name -> fastai name for each hyper-parameter of the first group (unmapped names are kept)
        pt_names = detuplify_pg(opt.param_groups[0])
        self.fwd_map = {k: (hp_map[k] if k in hp_map else k) for k in pt_names}
        # fastai name -> PyTorch name, used by `_set_hyper`
        self.bwd_map = {fa: pt for pt,fa in self.fwd_map.items()}
        self.state = defaultdict(dict, {})
        self.frozen_idx = 0

    @property
    def param_groups(self): return [group['params'] for group in self.opt.param_groups]
    @param_groups.setter
    def param_groups(self, v):
        for group,params in zip(self.opt.param_groups, v): group['params'] = params

    @property
    def hypers(self):
        res = []
        for pg in self.opt.param_groups:
            flat = detuplify_pg(pg)
            res.append({self.fwd_map[k]: v for k,v in flat.items() if k != 'params'})
        return res

    def _set_hyper(self, k, v):
        # `set_item_pg` updates each PyTorch param group in place
        for pg,v_ in zip(self.opt.param_groups, v): set_item_pg(pg, self.bwd_map[k], v_)

    def clear_state(self): self.opt.state = defaultdict(dict, {})

In [None]:
# A wrapped torch SGD must expose the same param groups and hypers as the fastai SGD
sgd = SGD([tensor([1,2,3])], lr=1e-3, mom=0.9, wd=1e-2)
tst_sgd = OptimWrapper(torch.optim.SGD([tensor([1,2,3])], lr=1e-3, momentum=0.9, weight_decay=1e-2))
#Access to param_groups
test_eq(tst_sgd.param_groups, sgd.param_groups)
#Set param_groups
tst_sgd.param_groups = [[tensor([4,5,6])]]
test_eq(tst_sgd.opt.param_groups[0]['params'], [tensor(4,5,6)])
#Access to hypers (torch-only keys `dampening`/`nesterov` keep their names)
test_eq(tst_sgd.hypers, [{**sgd.hypers[0], 'dampening': 0., 'nesterov': False}])
#Set hypers (fastai `mom` is written to torch `momentum`)
tst_sgd.set_hyper('mom', 0.95)
test_eq(tst_sgd.opt.param_groups[0]['momentum'], 0.95)

In [None]:
# Same checks with two parameter groups that have different learning rates
tst_sgd = OptimWrapper(torch.optim.SGD([{'params': [tensor([1,2,3])], 'lr': 1e-3}, 
                                        {'params': [tensor([4,5,6])], 'lr': 1e-2}], momentum=0.9, weight_decay=1e-2))
sgd = SGD([[tensor([1,2,3])], [tensor([4,5,6])]], lr=[1e-3, 1e-2], mom=0.9, wd=1e-2)
#Access to param_groups
test_eq(tst_sgd.param_groups, sgd.param_groups)
#Set param_groups
tst_sgd.param_groups = [[tensor([4,5,6])], [tensor([1,2,3])]]
test_eq(tst_sgd.opt.param_groups[0]['params'], [tensor(4,5,6)])
test_eq(tst_sgd.opt.param_groups[1]['params'], [tensor(1,2,3)])
#Access to hypers
test_eq(tst_sgd.hypers, [{**sgd.hypers[i], 'dampening': 0., 'nesterov': False} for i in range(2)])
#Set hypers (a single value is broadcast, a list sets one value per group)
tst_sgd.set_hyper('mom', 0.95)
test_eq([pg['momentum'] for pg in tst_sgd.opt.param_groups], [0.95,0.95])
tst_sgd.set_hyper('lr', [1e-4,1e-3])
test_eq([pg['lr'] for pg in tst_sgd.opt.param_groups], [1e-4,1e-3])

In [None]:
#hide
#check it works with tuply hp names like in Adam
# `betas` is exposed as `mom` (betas__0) and `sqr_mom` (betas__1) through `pytorch_hp_map`
tst_adam = OptimWrapper(torch.optim.Adam([tensor([1,2,3])], lr=1e-2, betas=(0.9, 0.99)))
test_eq(tst_adam.hypers, [{'lr': 0.01, 'mom': 0.9, 'sqr_mom': 0.99, 'eps': 1e-08, 'wd': 0, 'amsgrad': False}])
# Setting one of them rewrites only its element of the `betas` tuple
tst_adam.set_hyper('mom', 0.95)
test_eq(tst_adam.opt.param_groups[0]['betas'], (0.95, 0.99))
tst_adam.set_hyper('sqr_mom', 0.9)
test_eq(tst_adam.opt.param_groups[0]['betas'], (0.95, 0.9))

In [None]:
def _mock_train(m, x, y, opt):
    m.train()
    for i in range(0, 100, 25):
        z = m(x[i:i+25])
        loss = F.mse_loss(z, y[i:i+25])
        loss.backward()
        opt.step()
        opt.zero_grad()

In [None]:
m = nn.Linear(4,5)
x = torch.randn(100, 3, 4)
y = torch.randn(100, 3, 5)
# Snapshot the initial weights in memory so both optimizers start from the same point (no temp file)
init_sd = {k: v.clone() for k,v in m.state_dict().items()}

# Reference: PyTorch AdamW (decoupled weight decay)
m.load_state_dict(init_sd)
opt1 = OptimWrapper(torch.optim.AdamW(m.parameters(), betas=(0.9, 0.99), eps=1e-5, weight_decay=1e-2))
_mock_train(m, x.clone(), y.clone(), opt1)
wgt1,bias1 = m.weight.data.clone(),m.bias.data.clone()

# fastai Adam with true weight decay (`decouple_wd=True` by default)
m.load_state_dict(init_sd)
opt2 = Adam(m.parameters(), 1e-3, wd=1e-2)
_mock_train(m, x.clone(), y.clone(), opt2)
wgt2,bias2 = m.weight.data.clone(),m.bias.data.clone()

test_close(wgt1,wgt2,eps=1e-3)
test_close(bias1,bias2,eps=1e-3)

In [None]:
m = nn.Linear(4,5)
x = torch.randn(100, 3, 4)
y = torch.randn(100, 3, 5)
# Snapshot the initial weights in memory so both optimizers start from the same point (no temp file)
init_sd = {k: v.clone() for k,v in m.state_dict().items()}

# Reference: PyTorch Adam, whose `weight_decay` is L2 regularization
m.load_state_dict(init_sd)
opt1 = OptimWrapper(torch.optim.Adam(m.parameters(), betas=(0.9, 0.99), eps=1e-5, weight_decay=1e-2))
_mock_train(m, x.clone(), y.clone(), opt1)
wgt1,bias1 = m.weight.data.clone(),m.bias.data.clone()

# fastai Adam with L2 regularization
# NOTE(review): `Optimizer.step` runs the stats before `l2_reg`, so the L2 term may not reach Adam's
# averages here; with lr=1e-3 and wd=1e-2 the gap is probably within `eps=1e-3` — confirm
m.load_state_dict(init_sd)
opt2 = Adam(m.parameters(), 1e-3, wd=1e-2, decouple_wd=False)
_mock_train(m, x.clone(), y.clone(), opt2)
wgt2,bias2 = m.weight.data.clone(),m.bias.data.clone()

test_close(wgt1,wgt2,eps=1e-3)
test_close(bias1,bias2,eps=1e-3)

## Export -

In [None]:
#hide
from local.notebook.export import *
# Convert the `#export` cells of all notebooks (`all_fs=True`) into library scripts; the output lists each converted notebook
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_utils.ipynb.
Converted 01b_dispatch.ipynb.
Converted 01c_transform.ipynb.
Converted 02_script.ipynb.
Converted 03_torch_core.ipynb.
Converted 03a_layers.ipynb.
Converted 04_dataloader.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 07_data_block.ipynb.
Converted 08_vision_core.ipynb.
Converted 09_vision_augment.ipynb.
Converted 10_pets_tutorial.ipynb.
Converted 11_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 13a_metrics.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 14a_callback_data.ipynb.
Converted 15_callback_hook.ipynb.
Converted 15a_vision_models_unet.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 21_vision_learner.ipynb.
Converted 22_tutorial_imagenette.ipynb.
Converted 23_tutorial_