# optim

> Optimizers that update model parameters from their gradients: a base `Optimizer` class plus `SGD` and `Adam`.

In [None]:
#| default_exp optim

In [None]:
"""Optimization module"""
import minima as mi
from minima.nn import Parameter
from minima.autograd import Tensor
from minima import init
import numpy as np

In [None]:
class Optimizer:
    """Base class for optimizers: holds the parameters and clears their gradients."""

    def __init__(self, params):
        # Stored as given; subclasses iterate over it in `step` and `zero_grad`.
        self.params = params

    def step(self):
        """Apply one update to the parameters; concrete optimizers must override this."""
        raise NotImplementedError()

    def zero_grad(self):
        """Reset the `grad` attribute of every stored parameter to `None`."""
        for param in self.params:
            param.grad = None

In [None]:
class SGD(Optimizer):
    """Stochastic gradient descent with a momentum average and weight decay.

    For every parameter `p` that has a gradient, `step` performs

        p <- p * (1 - lr * wd)                      (only when wd != 0)
        u <- momentum * u + (1 - momentum) * grad
        p <- p - lr * u

    where `u` is a per-parameter running average of the gradients, kept in
    `self.u` and keyed by the parameter's position in `self.params`.
    """

    def __init__(
        self,
        params, # parameters to optimize
        lr=0.01, # learning rate
        momentum=0.0, # weight given to the previous running average `u`
        wd=0.0, # weight-decay coefficient; 0 disables the decay step
    ):
        super().__init__(params)

        self.lr = lr
        self.momentum = momentum
        self.u = {}
        self.wd = wd

    def step(self):
        """Apply weight decay and then the momentum update to each parameter.

        Parameters whose `grad` is `None` are left untouched instead of
        raising `AttributeError` on `p.grad.data`.
        """
        for idx, p in enumerate(self.params):
            if p.grad is None:
                continue
            # `_opt_step` looks up the running average through `self.idx`.
            self.idx = idx
            self._reg_step(p)
            self._opt_step(p)

    def _opt_step(self, p):
        """Update the running average for `p` and move `p.data` against it."""
        if self.idx not in self.u:
            # First update for this parameter: start the average at zero.
            self.u[self.idx] = init.zeros(*p.shape)
        self.u[self.idx] = self.momentum * self.u[self.idx] + (1 - self.momentum) * p.grad.data
        p.data = p.data - self.lr * self.u[self.idx]

    def _reg_step(self, p):
        """Shrink `p.data` by the factor `1 - lr * wd` when weight decay is enabled."""
        if self.wd != 0:
            # Same as p.data = p.data - lr * wd * p.data.
            p.data *= (1 - self.lr * self.wd)

## Adam Optimizer

This is a `minima` implementation of the popular *Adam* optimizer from the paper
 [Adam: A Method for Stochastic Optimization](https://papers.labml.ai/paper/1412.6980).

*Adam* update is,
$$
\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}
$$
where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyperparameters.
$m_t$ and $v_t$ are the first- and second-moment estimates.
$\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments.
$\epsilon$ guards against division by zero, and also acts as a hyperparameter
that dampens the effect of variance in the gradients.

Assuming $\epsilon = 0$, the effective step taken is
$$\Delta_t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$
This is bounded by
$$\vert \Delta_t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$
when $1-\beta_1 \gt \sqrt{1-\beta_2}$,
and by
$$\vert \Delta_t \vert \le \alpha$$
otherwise.
In most common scenarios,
$$\vert \Delta_t \vert \approx \alpha$$

In [None]:
class Adam(Optimizer):
    """Adam optimizer with bias-corrected moment estimates and weight decay.

    Each call to `step` advances the timestep `t` by one and, for every
    parameter `p` that has a gradient `g`, performs

        p     <- p * (1 - lr * weight_decay)        (only when weight_decay != 0)
        m     <- beta1 * m + (1 - beta1) * g
        v     <- beta2 * v + (1 - beta2) * g**2
        m_hat <- m / (1 - beta1**t)
        v_hat <- v / (1 - beta2**t)
        p     <- p - lr * m_hat / (sqrt(v_hat) + eps)

    The moment estimates `m` and `v` are kept per parameter in `self.exp_avg`
    and `self.exp_avg_sq`, keyed by the parameter's position in `self.params`.
    """

    def __init__(
        self,
        params, # `params` is the list of parameters
        lr=0.01, # `lr` is the learning rate $\alpha$
        beta1=0.9, # decay rate $\beta_1$ of the first-moment estimate
        beta2=0.999, # decay rate $\beta_2$ of the second-moment estimate
        eps=1e-8, # `eps` is $\epsilon$, added to the denominator of the update
        weight_decay=0.0, # weight-decay coefficient; 0 disables the decay step
    ):
        super().__init__(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.wd = weight_decay
        # Number of completed `step` calls; drives the bias correction.
        self.t = 0

        self.exp_avg = {}
        self.exp_avg_sq = {}

    def step(self):
        """Advance the timestep and update every parameter that has a gradient.

        Parameters whose `grad` is `None` are left untouched instead of
        raising `AttributeError` on `p.grad.data`.
        """
        # The bias correction depends on the update count, not on the
        # parameter's position, so the counter advances once per call.
        self.t += 1
        for idx, p in enumerate(self.params):
            if p.grad is None:
                continue
            # `_opt_step` looks up the moment estimates through `self.idx`.
            self.idx = idx
            self._reg_step(p)
            self._opt_step(p)

    def _opt_step(self, p):
        """Update the moment estimates for `p` and apply the Adam update.

        Expects `self.t >= 1`, i.e. to be called from `step`.
        """
        if self.idx not in self.exp_avg:
            # First update for this parameter: start both moments at zero.
            self.exp_avg[self.idx] = init.zeros(*p.shape)
            self.exp_avg_sq[self.idx] = init.zeros(*p.shape)

        # Update biased first and second moment estimates
        self.exp_avg[self.idx] = self.beta1 * self.exp_avg[self.idx] + (1 - self.beta1) * p.grad.data
        self.exp_avg_sq[self.idx] = self.beta2 * self.exp_avg_sq[self.idx] + (1 - self.beta2) * p.grad.data**2

        # Compute bias-corrected first and second moment estimates
        exp_avg_hat = self.exp_avg[self.idx] / (1 - self.beta1 ** self.t)
        exp_avg_sq_hat = self.exp_avg_sq[self.idx] / (1 - self.beta2 ** self.t)
        p.data = p.data - self.lr * exp_avg_hat / (exp_avg_sq_hat ** 0.5 + self.eps)

    def _reg_step(self, p):
        """Shrink `p.data` by the factor `1 - lr * weight_decay` when decay is enabled."""
        if self.wd != 0:
            # Same as p.data = p.data - lr * wd * p.data.
            p.data *= (1 - self.lr * self.wd)

## Export

## Export

In [None]:
# Run nbdev's export step for this notebook project.
import nbdev

nbdev.nbdev_export()