In [None]:
# Run this cell to hide it.
# It defines plot_progress() for plotting an optimization trace.

import matplotlib
from matplotlib import pyplot as plt

def plot_progress(bowl, track, losses):
    """Plot an optimization trace in two side-by-side panels.

    The left panel shows the path taken by ``x`` on top of level sets of the
    objective; the right panel shows the objective value at each iteration.

    Args:
        bowl: 2x2 tensor. The contours drawn are the points ``x`` for which
            ``bowl @ x`` has norm 0.1, 0.2, ..., 1.0, so ``bowl`` must be
            invertible.
        track: list of 1-D tensors (one per iteration); components 0 and 1
            are plotted as ``x[0]`` and ``x[1]``.
        losses: sequence of objective values, one per iteration.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # A unit circle mapped through inverse(bowl) is the level set where
    # norm(bowl @ x) == 1; scaling it by `size` gives the other level sets.
    # Both the circle and the inverse are loop-invariant, so compute them once.
    angle = torch.linspace(0, 6.3, 100)
    circle = torch.stack([angle.sin(), angle.cos()])
    unit_contour = torch.mm(torch.inverse(bowl), circle)
    for size in torch.linspace(0.1, 1.0, 10):
        ellipse = unit_contour * size
        ax1.plot(ellipse[0, :], ellipse[1, :], color='skyblue')

    # Stack the visited points into a (2, n_steps) tensor so each row is one
    # coordinate of x over time.
    track = torch.stack(track).t()
    ax1.set_title('progress of x')
    ax1.plot(track[0, :], track[1, :], marker='o')
    ax1.set_ylim(-1, 1)
    ax1.set_xlim(-1.6, 1.6)
    ax1.set_ylabel('x[1]')
    ax1.set_xlabel('x[0]')

    ax2.set_title('progress of y')
    # Iterations are whole numbers, so only allow integer tick positions.
    ax2.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
    ax2.plot(range(len(losses)), losses, marker='o')
    ax2.set_ylabel('objective')
    ax2.set_xlabel('iteration')
    fig.show()

from IPython.display import HTML
# Display an HTML snippet as this cell's output: a script that defines
# toggle_code(), which toggles the input area of the selected rendered cell and
# focuses its textarea, followed by a "Toggle" link that calls it. The script
# also passes toggle_code to `$`, so it is invoked once when the output loads.
# NOTE(review): relies on `$` being available (presumably jQuery from the
# classic Notebook front end) -- confirm for the front end you target.
HTML('''<script>function toggle_code(){$('.rendered.selected div.input').toggle().find('textarea').focus();} $(toggle_code)</script>
<a href="javascript:toggle_code()">Toggle</a> the code for plot_progress.''')

Pytorch Optimizers
==================

Pytorch includes several optimization algorithms.

Optimizers have a simple job: given gradients of an objective with respect to a set of input parameters, adjust the parameters to reduce the objective.  They do this by modifying each parameter by a small amount in the direction opposite to its gradient, which is the direction in which the objective decreases.

The actual optimization algorithms employ a number of techniques to make the process faster and more robust as repeated steps are taken, by trying to adapt to the shape of the objective surface as it is explored.  The simplest method is SGD-with-momentum, which is implemented in PyTorch as `torch.optim.SGD`.

Using SGD
---------

To use SGD, you need to calculate your objective and fill in gradients on all the parameters before it can take a step.

  1. Mark your parameters (x in this case) with `x.requires_grad = True` so autograd tracks them (line 5).
  2. Create the optimizer and tell it about the parameters to adjust (`[x]` here) (line 6).
  3. In a loop, compute your objective, then call `objective.backward()` to fill in `x.grad` and then `optimizer.step()` to adjust `x` accordingly (lines 12-15).
  
**Remember to zero the gradients.** Notice that we call `optimizer.zero_grad()` each time to reset `x.grad` to zero before recomputing gradients; if we do not do this, then the new gradient will be added to the old one.


In [None]:
import torch

x_init = torch.randn(2)  # random starting point; the Adam cell below reuses it
x = x_init.clone()  # optimize a copy so x_init itself is left untouched
x.requires_grad = True
optimizer = torch.optim.SGD([x], lr=0.1, momentum=0.5)

bowl = torch.tensor([[ 0.4410, -1.0317], [-0.2844, -0.1035]])  # objective is norm(bowl @ x)
track, losses = [], []

for step in range(21):  # named `step`, not `iter`, so the builtin iter() is not shadowed
    objective = torch.mm(bowl, x[:,None]).norm()
    optimizer.zero_grad()
    objective.backward()
    optimizer.step()
    track.append(x.detach().clone())  # snapshot x after the update
    losses.append(objective.item())  # plain float: the objective before this update

plot_progress(bowl, track, losses)

Using other optimizers
----------------------

Other optimizers are similar.  Adam is a popular adaptive method that does well without much tuning, and can be dropped in to replace plain SGD.

Some other fancy optimizers, such as LBFGS, need to be given an objective function that they can call repeatedly to probe gradients themselves.  Examples can be found elsewhere.

In [None]:
# Same experiment as the SGD cell above, but optimized with Adam.
# Relies on `x_init` and `bowl` defined in the SGD cell, so both optimizers
# start from the same point on the same objective.
x = x_init.clone()
x.requires_grad = True
optimizer = torch.optim.Adam([x], lr=0.1)

track, losses = [], []

# The loop variable is `step`, not `iter`, so the builtin iter() is not
# shadowed in the notebook's global namespace.
for step in range(21):
    objective = torch.mm(bowl, x[:,None]).norm()
    optimizer.zero_grad()   # clear the gradient left over from the previous step
    objective.backward()    # fill in x.grad
    optimizer.step()        # update x in place
    track.append(x.detach().clone())  # snapshot x after the update
    losses.append(objective.item())   # plain float: the objective before this update

plot_progress(bowl, track, losses)

Adjusting learning rate schedules
---------------------------------

One of the easiest ways to improve convergence is to reduce the learning rate as more iterations are done.
