# Understanding Adam

The purpose of this notebook is to understand the Adam optimizer as it is implemented in PyTorch.
The original paper describing this algorithm can be found at: https://arxiv.org/pdf/1412.6980.pdf

First of all, here is the pseudocode for the algorithm, taken from the original paper. It is essentially vanilla gradient descent with some added moving averages and correction terms. We'll get into the intuition for these extra steps momentarily.
![Adam Pseudocode](pics/Adam.png)

So what is the idea here? Well, $g_t$ is just your normal computation of gradients using the chain rule. The term $m_t$ is an exponential moving average of the gradients over $t$ steps (an estimate of the gradient's mean), biased towards the initial vector. Since we chose the initial vector to be 0, this means that the estimate is biased towards a zero gradient. This will get corrected later.

The term $v_t$ is similarly a zero-biased exponential moving average of the squared gradients, that is, an estimate of the gradient's second raw moment (its uncentered variance). The bias correction applied later is exact only if this second moment is stationary over the averaging window; otherwise it is an approximation.
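To make the zero bias concrete, here is a minimal sketch that runs both moving averages on a constant gradient, where the true mean is $g$ and the true second moment is $g^2$. The function name `raw_moments` and the chosen values are illustrative assumptions, not part of PyTorch or the paper.

```python
def raw_moments(g, steps, beta1=0.9, beta2=0.999):
    """Return (m_t, v_t) after `steps` updates with a constant gradient g."""
    m, v = 0.0, 0.0  # both averages start at zero, as in the pseudocode
    for _ in range(steps):
        m = beta1 * m + (1 - beta1) * g        # moving average of g
        v = beta2 * v + (1 - beta2) * g * g    # moving average of g^2
    return m, v

g = 2.0
for t in (1, 10, 100):
    m, v = raw_moments(g, t)
    # Both ratios stay below 1: the estimates are pulled towards zero.
    print(f"t={t:3d}  m_t/g={m / g:.4f}  v_t/g^2={v / g**2:.4f}")
```

For a constant gradient the two printed ratios work out to $1-\beta_1^t$ and $1-\beta_2^t$, so with $\beta_2$ close to 1 the second-moment estimate stays far from its true value for many steps.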

The bias-corrected estimates for $m_t$ and $v_t$ are $\hat{m}_t$ and $\hat{v}_t$. The paper claims that the ratio $\hat{m}_t/\sqrt{\hat{v}_t}$ can be thought of as a signal-to-noise ratio (SNR), so that the smaller the SNR, the smaller the step size in parameter space will be. This acts somewhat like automatic annealing because the SNR typically decreases close to an optimal point in parameter space.
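As a rough illustration of the SNR reading, the sketch below computes $\hat{m}_t/\sqrt{\hat{v}_t}$ for two made-up gradient streams: one that always points the same way and one that flips sign every step. The helper `adam_ratio` and both streams are assumptions chosen for illustration; it expects a non-empty list of scalar gradients.

```python
import math

def adam_ratio(grads, beta1=0.9, beta2=0.999):
    """Bias-corrected ratio m_hat / sqrt(v_hat) after consuming `grads`."""
    m, v = 0.0, 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)  # undo the bias towards zero
    v_hat = v / (1 - beta2 ** t)
    return m_hat / math.sqrt(v_hat)

consistent = [1.0] * 100                    # same direction every step
noisy = [(-1.0) ** k for k in range(100)]   # sign flips every step

print("consistent:", adam_ratio(consistent))
print("noisy:     ", adam_ratio(noisy))
```

The consistent stream gives a ratio of magnitude close to 1, so the step is close to the full learning rate, while the sign-flipping stream gives a much smaller ratio and therefore a much smaller step.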

## Implementation in PyTorch

The implementation of Adam in PyTorch handles weight decay following https://arxiv.org/pdf/1711.05101.pdf, which describes two ways of adding weight decay to Adam.

The first is to use $$g_t \leftarrow \nabla_\theta f_t(\theta_{t-1})+\lambda\theta_{t-1}$$ in the first update rule. The function of this term is essentially to apply $L_2$ regularization in concert with the Adam update rule.

The second is to use $$\theta_t \leftarrow \theta_{t-1}-\eta_t(\alpha \hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon)+\lambda\theta_{t-1})$$ with scheduled annealing parameter $\eta_t$. Adding $\lambda\theta_{t-1}$ here applies weight decay separately from the adaptive update rule, which is the decoupled variant that paper calls AdamW.
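The difference between the two rules is easiest to see on a single scalar parameter. The following is a minimal sketch that follows the two formulas exactly as written above; `adam_step`, its `decoupled` flag, and the numbers used are hypothetical and are not PyTorch's API (libraries may scale the decay term differently).

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=0.0, decoupled=False, eta=1.0):
    """One scalar Adam step with either L2 or decoupled weight decay."""
    if weight_decay != 0 and not decoupled:
        grad = grad + weight_decay * theta   # first rule: decay enters g_t
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    update = lr * m_hat / (math.sqrt(v_hat) + eps)
    if weight_decay != 0 and decoupled:
        update += weight_decay * theta       # second rule: decay added after
    return theta - eta * update, m, v

# Zero loss gradient, so only the weight decay term moves the parameter.
theta_l2, _, _ = adam_step(10.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1)
theta_dec, _, _ = adam_step(10.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1,
                            decoupled=True)
print("L2 variant:       ", theta_l2)
print("decoupled variant:", theta_dec)
```

In the first rule the decay term passes through the $\hat{m}_t/\sqrt{\hat{v}_t}$ normalization, so on this first step the parameter moves by roughly the learning rate regardless of $\lambda$. In the second rule the parameter shrinks by $\lambda\theta_{t-1}$ directly.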

Below is the source code from PyTorch implementing `Adam`. Note that it applies only the first of the two modifications above (the $L_2$ term folded into the gradient); the decoupled weight decay of AdamW is provided by the separate class `torch.optim.AdamW`.

I have removed all the safety checks and extra options from the code to make it as easy to read as possible. Please don't use this code outside this notebook!

In [2]:
import math
import torch
from torch.optim import Optimizer
from torch import Tensor
from typing import List

def adam(params: List[Tensor],           # list of parameters that contribute gradients
         grads: List[Tensor],            # g_t
         exp_avgs: List[Tensor],         # m_t
         exp_avg_sqs: List[Tensor],      # v_t
         max_exp_avg_sqs: List[Tensor],  # max(v_t)
         state_steps: List[int],
         amsgrad: bool,                  # I have removed this option, please ignore this parameter
         beta1: float,                   # from pseudocode
         beta2: float,                   # also from pseudocode
         lr: float,                      # alpha from pseudocode
         weight_decay: float,            # lambda from the weight decay paper (not in the Adam pseudocode)
         eps: float):                    # from pseudocode
    r"""Functional API that performs Adam algorithm computation.
    See :class:`~torch.optim.Adam` for details.
    """

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]
        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            # L2 regularization: fold the weight decay term into the gradient, g_t <- g_t + lambda * theta
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)                     # update rule for m_t
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)        # update rule for v_t
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) # sqrt(v_hat_t) + eps

        step_size = lr / bias_correction1                                   # alpha / (1 - beta1^t): learning rate with the m_t bias correction folded in

        param.addcdiv_(exp_avg, denom, value=-step_size)  # update params according to p -= step_size*exp_avg/denom
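
The update in the loop above is arranged differently from the pseudocode: the bias corrections are folded into `denom` and `step_size` rather than applied to $m_t$ and $v_t$ directly. Here is a small scalar sketch checking that the two arrangements give the same step. Both helper functions and the sample values of `m` and `v` are illustrative assumptions.

```python
import math

def update_paper_form(m, v, t, lr, beta1, beta2, eps):
    """Step size as written in the pseudocode: lr * m_hat / (sqrt(v_hat) + eps)."""
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

def update_pytorch_form(m, v, t, lr, beta1, beta2, eps):
    """Same quantity, arranged like the loop body above."""
    bias_correction1 = 1 - beta1 ** t
    bias_correction2 = 1 - beta2 ** t
    denom = math.sqrt(v) / math.sqrt(bias_correction2) + eps  # sqrt(v_hat) + eps
    step_size = lr / bias_correction1                         # lr / (1 - beta1^t)
    return step_size * m / denom

args = dict(lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8)
for t in (1, 10, 1000):
    a = update_paper_form(0.02, 0.0005, t, **args)
    b = update_pytorch_form(0.02, 0.0005, t, **args)
    print(t, a, b, math.isclose(a, b, rel_tol=1e-9))
```

The rearrangement avoids creating temporary tensors for $\hat{m}_t$ and $\hat{v}_t$, which matters when the same arithmetic runs on large parameter tensors.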

With the algorithm implemented, we just need to wrap it in a class `Adam` that has all the methods PyTorch expects from an optimizer.

In [3]:
class Adam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)  # initialize using the inherited Optimizer class

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()  # Makes every computation have requires_grad=False, since we already have all the grads we need
    def step(self):
        """Performs a single optimization step."""
        loss = None

        for group in self.param_groups:
            params_with_grad = []  # list of parameters that contribute gradients
            grads = []             # g_t
            exp_avgs = []          # m_t
            exp_avg_sqs = []       # v_t
            max_exp_avg_sqs = []   # running max of v_t (only filled when amsgrad=True)
            state_steps = []       # step count t for each parameter

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)  # g_t for this parameter

                    state = self.state[p] # the optimizer keeps a state dict per parameter, keyed by the parameter itself
                    
                    # Lazy state initialization, used in first round of updates. Can ignore this if reading code to learn.
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])
                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            beta1, beta2 = group['betas'] # get the beta values from this parameter group's 'betas' entry
            
            # call the adam update algorithm
            adam(params_with_grad,
                   grads,
                   exp_avgs,
                   exp_avg_sqs,
                   max_exp_avg_sqs,
                   state_steps,
                   group['amsgrad'],
                   beta1,
                   beta2,
                   group['lr'],
                   group['weight_decay'],
                   group['eps']
                   )
        return loss

## Quick example

Let's quickly check that this implementation is working. The following code should run without errors if our Adam code is working.

In [4]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use our implementation of Adam
learning_rate = 1e-4
optimizer = Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

99 63.04720687866211
199 1.6846511363983154
299 0.009795689955353737
399 0.00013946012768428773
499 1.3764803952653892e-05
