# Chapter 4 - Optimization
## Deep Learning Curriculum - Jacob Hilton

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

* Set up a testbed using the setup from the NQM paper, where the covariance matrix of the gradient and the Hessian are both diagonal. You can use the same defaults for these matrices as in the paper, i.e., diagonal entries of 1, 1/2, 1/3, ... for both (in the paper they go up to 10^4, you can reduce this to 10^3 if experiments are taking too long to run). Implement both SGD with momentum and Adam.

#### Notes:

Consider that the parameters $\theta_i$ are initialized such that $\theta_i \sim N(0, 1)$. Since the gradient noise $\epsilon_i$ has zero mean, $E(\theta_i^{1}) = E(\theta_i^0 - \alpha g(\theta_i^0)) = E(\theta_i^0) - \alpha E(h_i \theta_i^0 + \epsilon_i) = (1 - \alpha h_i)E(\theta_i^0) = 0$, so gradient descent does not alter the mean of the parameter distribution. The only thing that changes over the course of the updates is therefore the covariance matrix, which evolves because of the noise that the stochasticity of SGD injects into the parameters. Since the Hessian and the noise covariance are both diagonal, each coordinate evolves independently, so to follow the evolution of the parameters we only need to track the variance of each $\theta_i$.

$$Var(\theta_i^t) = v_i^t \\ v_i^t = (1 - \alpha h_i)^2 v_{i}^{t-1} + \alpha^2 \sigma_i^2$$

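where $\sigma_i^2$ is the variance of the gradient noise $\epsilon_i$. This follows from $\theta_i^t = (1 - \alpha h_i)\theta_i^{t-1} - \alpha \epsilon_i$ with $\epsilon_i$ independent of $\theta_i^{t-1}$. Unrolling the recursion: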

$$0: v_i^0 \\
1: v_i^1 = (1 - \alpha h_i)^2 v_i^0 + \alpha^2 \sigma_i^2 \\
2: v_i^2 = (1 - \alpha h_i)^4 v_i^0 + (1 - \alpha h_i)^2 \alpha^2 \sigma_i^2 + \alpha^2 \sigma_i^2 \\
\vdots \\
t: v_i^t = (1 - \alpha h_i)^{2t} v_i^0 + \sum_{k=0}^{t-1} \alpha^2 \sigma_i^2 (1 - \alpha h_i)^{2 k }$$

Summing the geometric series, $\sum_{k=0}^{t-1} \alpha^2 \sigma_i^2 (1 - \alpha h_i)^{2 k } = \frac{\alpha^2 \sigma_i^2(1 - (1 - \alpha h_i)^{2t})}{1 - (1 - \alpha h_i)^2}$, so the variance $t$ steps ahead is:

$$v_i^t = (1 - \alpha h_i)^{2t} \left(v_i^0 - \frac{\alpha^2 \sigma_i^2}{1 - (1 - \alpha h_i)^2}\right) + \frac{\alpha^2 \sigma_i^2}{1 - (1 - \alpha h_i)^2} \\ v_i^t = \beta^t (v_i^0 - v^*) + v^*$$ where $$\beta = (1 - \alpha h_i)^2 \\
v^* =  \frac{\alpha^2 \sigma_i^2}{1 - (1 - \alpha h_i)^2}$$ The variance converges to $v^*$ when $\beta < 1$, i.e. $0 < \alpha h_i < 2$.

Given the above formula and the expected loss $L^t = \frac{1}{2}\sum_i h_i v_i^t$, we can compute how many steps it takes to reach a desired loss. We need this to compare the impact of batch size on convergence for the various optimization algorithms.

In [None]:
def make_gradients(size=1e3):
  """Build the NQM Hessian diag(1, 1/2, ..., 1/size) as a dense torch tensor.

  Despite its name, this returns the (diagonal) Hessian as a full
  ``size x size`` float64 matrix. The element-wise closed-form helpers in this
  notebook expect only its diagonal, e.g. ``torch.diagonal(make_gradients())``.

  Args:
    size: number of dimensions (may be given as a float such as ``1e3``).
  """
  eigenvalues = 1 / np.arange(1, size + 1, 1)
  return torch.from_numpy(np.diag(eigenvalues))

In [None]:
def loss(h, v):
  """Return the NQM loss 0.5 * sum(h * v) for Hessian entries h and parameter variances v.

  If h is the dense diagonal matrix from make_gradients, broadcasting multiplies
  column j by v[j] and the zero off-diagonal entries drop out of the sum.
  """
  return 0.5 * torch.sum(h * v)

In [None]:
def t_step_ahead_var(alpha, sigma_sq, hi, t, v0):
  """Return Var(theta_i) after t constant-learning-rate SGD steps (closed form).

  v_t = beta**t * (v0 - v*) + v*, with beta = (1 - alpha * hi)**2 and
  v* = alpha**2 * sigma_sq / (1 - beta). Works element-wise on Python numbers,
  NumPy arrays or torch tensors.

  The formula divides by zero when beta == 1, i.e. when alpha * hi is 0 or 2.
  That includes the zero off-diagonal entries of a dense diagonal Hessian, so
  pass the Hessian diagonal only.
  """
  contraction = (1 - alpha * hi) ** 2
  steady_state = alpha ** 2 * sigma_sq / (1 - contraction)
  return steady_state + contraction ** t * (v0 - steady_state)

In [None]:
def steps_to_loss(alpha, sigma_sq, hi, v0, loss):
  """Return how many constant-learning-rate SGD steps the NQM needs to get below a target loss.

  Bisects over the step count using the closed-form variance from
  ``t_step_ahead_var``. The loss after ``t`` steps is ``0.5 * sum(hi * v_t)``.
  Bisection assumes the loss decreases monotonically with ``t``.

  Args:
    alpha: learning rate.
    sigma_sq: per-dimension gradient-noise variances.
    hi: per-dimension Hessian entries. Use the diagonal, not the dense matrix:
      its zero off-diagonal entries make the closed form 0/0.
    v0: initial per-dimension parameter variances.
    loss: the target loss value (a number). It shadows the module-level
      ``loss`` function, so the loss is computed inline below.

  Returns:
    The smallest ``t`` in ``[1, 10**20)`` whose loss is below the target, or
    ``10**20`` if the target is not reached in that range.
  """
  target = loss
  low, high = 0, 10**20
  # Integer bisection always shrinks the interval, so it terminates.
  while high - low > 1:
    mid = (low + high) // 2
    # Float exponent: step counts up to 10**20 exceed the int64 range.
    v_mid = t_step_ahead_var(alpha, sigma_sq, hi, float(mid), v0)
    if 0.5 * torch.sum(torch.as_tensor(hi * v_mid)) < target:
      # Target reached after `mid` steps: the answer is at most `mid`.
      high = mid
    else:
      # Not reached yet (or NaN/inf): more steps are needed.
      low = mid
  return high

In [None]:
def loss_curve(alpha, sigma_sq, hi, max_t, v0):
  """Return the NQM loss after each step t = 1, ..., max_t (inclusive) as a list.

  Each entry is ``loss(hi, t_step_ahead_var(alpha, sigma_sq, hi, t, v0))``,
  i.e. the closed-form loss of constant-learning-rate SGD after t steps.
  """
  return [
      loss(hi, t_step_ahead_var(alpha, sigma_sq, hi, step, v0))
      for step in range(1, max_t + 1)
  ]

In [None]:
def grid_search_alphas(sigma_sq, hi, v0, loss, low_a = 1e-10, high_a = 0.75, num_alphas=1000):
  """Grid-search the constant SGD learning rate that reaches a target loss in the fewest steps.

  Args:
    sigma_sq: per-dimension gradient-noise variances (forwarded to steps_to_loss).
    hi: per-dimension Hessian entries (forwarded to steps_to_loss).
    v0: initial per-dimension variances (forwarded to steps_to_loss).
    loss: the target loss value (a number).
    low_a: smallest learning rate tried (inclusive).
    high_a: largest learning rate tried (inclusive).
    num_alphas: number of evenly spaced learning rates in [low_a, high_a].

  Returns:
    ``(best_alpha, best_steps)``. Ties go to the smaller learning rate.
    ``best_alpha`` is None only when no learning rate is tried.
  """
  best_alpha, best_steps = None, 1e50
  # np.linspace, not np.arange: arange's third argument is the step size, so
  # np.arange(low_a, high_a, 1000) produced only the single value low_a.
  for a in np.linspace(low_a, high_a, num_alphas).tolist():
    steps = steps_to_loss(a, sigma_sq, hi, v0, loss)
    if steps < best_steps:
      best_alpha, best_steps = a, steps
  return best_alpha, best_steps

### Momentum

Following Appendix B, and the implementation: https://github.com/gd-zhang/noisy-quadratic-model/blob/master/nqm.ipynb

In [None]:
import math

In [None]:
def simulate_momentum_dynamics(h, sigma_sq, s_init, lrate, k, momentum=0.0, prec=None):
    """Return E[theta_i^2] after ``k`` steps of SGD with momentum on the NQM.

    Models ``v <- momentum * v - lrate * g`` followed by ``theta <- theta + v``,
    with ``g = h * theta + noise``. For each dimension, the second moments
    ``p = [E theta^2, E v^2, E theta*v]`` obey the linear recursion
    ``p_{t+1} = A p_t + lrate^2 * sigma_sq * [1, 1, 1]``. The recursion is solved
    in closed form by eigendecomposing ``A``.

    Args:
        h: 1-D tensor of Hessian diagonal entries.
        sigma_sq: 1-D tensor of gradient-noise variances, same shape as ``h``.
        s_init: initial E[theta^2] (scalar or tensor broadcastable to ``h``);
            the velocity starts at zero.
        lrate: learning rate.
        k: number of steps.
        momentum: momentum coefficient.
        prec: optional preconditioner power ``p``. When given, simulates
            preconditioned SGD with ``P = H^p``. This is plain SGD in the
            rescaled coordinates ``P^{1/2} theta``, with Hessian ``h^(1-p)``
            and noise variance ``sigma_sq * h^(-p)``. Requires ``h > 0``.

    Returns:
        1-D real tensor of E[theta_i^2] after ``k`` steps, in the original
        coordinates. Eigenvalues of ``A`` equal to 1 make the fixed point
        below divide by zero.
    """
    if prec is not None:
        # theta~ = P^{1/2} theta with diagonal P = h^prec:
        # Hessian h / P, noise sigma_sq / P, Var(theta~) = P * Var(theta).
        scale = h ** prec
        s_tilde = simulate_momentum_dynamics(
            h / scale, sigma_sq / scale, s_init * scale, lrate, k, momentum=momentum)
        return s_tilde / scale

    ndim = h.shape[0]

    init_p = torch.zeros((ndim, 3))
    init_p[:, 0] = s_init

    x = lrate*h
    # The gradient noise adds lrate^2 * sigma_sq to all three second moments.
    noise = ((lrate)**2 * sigma_sq).unsqueeze(-1).expand(ndim, 3)

    # defining the transition matrix (see eqn.(21) in the paper)
    trans_mat = torch.zeros((ndim, 3, 3))
    trans_mat[:, :, 1] = momentum ** 2
    trans_mat[:, 0, 0] = (1 - x) ** 2
    trans_mat[:, 1, 0] = x ** 2
    trans_mat[:, 2, 0] = - (1 - x) * x
    trans_mat[:, 0, 2] = 2 * (1 - x) * momentum
    trans_mat[:, 1, 2] = -2 * x * momentum
    trans_mat[:, 2, 2] = (1 - 2 * x) * momentum

    # eigen-decomposition
    e, v = torch.linalg.eig(trans_mat)
    v_inv = torch.linalg.inv(v)

    # squeeze(-1), not squeeze(): with ndim == 1 a bare squeeze() also dropped
    # the dimension axis and the final indexing failed.
    trans_init_p = torch.matmul(v_inv, init_p.unsqueeze(-1).type(torch.complex64)).squeeze(-1)
    trans_noise = torch.matmul(v_inv, noise.unsqueeze(-1).type(torch.complex64)).squeeze(-1)

    # Fixed point of the recursion, expressed in the eigenbasis.
    p_star = trans_noise / (1 - e)

    trans_final_p = e ** k * (trans_init_p - p_star) + p_star
    final_p = torch.matmul(v, trans_final_p.unsqueeze(-1)).squeeze(-1)
    return final_p[:, 0].real

In [None]:
def time_to_loss_wm(h, sigma_sq, s_init, lrate, momentum, C, counts=None, prec=None):
  """Return how many steps of SGD with momentum the NQM needs to get below loss ``C``.

  Bisects over the step count with ``simulate_momentum_dynamics``. The loss after
  ``t`` steps is ``0.5 * sum(h * s_t)``, optionally weighted by ``counts``.
  Bisection assumes the loss decreases monotonically with ``t``.

  Args:
    h: Hessian diagonal (forwarded to simulate_momentum_dynamics).
    sigma_sq: gradient-noise variances (forwarded to simulate_momentum_dynamics).
    s_init: initial variances (forwarded to simulate_momentum_dynamics).
    lrate: learning rate (forwarded to simulate_momentum_dynamics).
    momentum: momentum coefficient (forwarded to simulate_momentum_dynamics).
    C: target loss value.
    counts: optional per-dimension weights multiplying ``h * s`` in the loss.
    prec: optional preconditioner argument, forwarded to
      simulate_momentum_dynamics only when not None.

  Returns:
    The smallest step count in ``[1, 10000000)`` whose loss is below ``C``, or
    ``10000000`` if it is not reached in that range.
  """
  # Forward `prec` only when it is actually used.
  extra = {} if prec is None else {"prec": prec}
  low, high = 0, 10000000
  while high - low > 1:
    mid = (low + high) // 2
    s = simulate_momentum_dynamics(h, sigma_sq, s_init, lrate, mid, momentum=momentum, **extra)
    # The module-level `loss` takes only (h, v), so the optionally weighted
    # loss is computed inline.
    weighted = h * s if counts is None else h * s * counts
    if 0.5 * torch.sum(weighted) < C:
      high = mid
    else:
      low = mid
  return high

In [None]:
def grid_search_alphas_and_momentum(h, sigma_sq, s_init, C, low_a = 1e-10, high_a = 0.75, max_momentum = 0.95, num_alphas=100, num_momentums=20):
  """Grid-search learning rate and momentum for the fewest steps to reach loss ``C``.

  Args:
    h: Hessian diagonal (forwarded to time_to_loss_wm).
    sigma_sq: gradient-noise variances (forwarded to time_to_loss_wm).
    s_init: initial variances (forwarded to time_to_loss_wm).
    C: target loss value.
    low_a: smallest learning rate tried (inclusive).
    high_a: largest learning rate tried (inclusive).
    max_momentum: largest momentum tried (inclusive); the grid starts at 0.
    num_alphas: number of evenly spaced learning rates.
    num_momentums: number of evenly spaced momentum values.

  Returns:
    ``(best_alpha, best_moment, best_steps)``. Ties keep the first pair found
    (smaller learning rate, then smaller momentum).
  """
  # np.linspace, not np.arange: arange's third argument is the step size, so
  # the original grids each contained a single value.
  alphas = np.linspace(low_a, high_a, num_alphas).tolist()
  momentums = np.linspace(0, max_momentum, num_momentums).tolist()

  best_steps = 1e50
  best_alpha = None
  best_moment = None

  for a in alphas:
    for m in momentums:
      steps = time_to_loss_wm(h, sigma_sq, s_init, a, m, C)
      if steps < best_steps:
        best_steps, best_alpha, best_moment = steps, a, m

  return best_alpha, best_moment, best_steps

### Adam

As in Section 3.3, with the preconditioner $P = H^p$ the preconditioned case works exactly like the non-preconditioned case, but with Hessian $\tilde{H} = P^{-\frac{1}{2}} H P^{-\frac{1}{2}}$ and gradient covariance $\tilde{C} = P^{-\frac{1}{2}} C P^{-\frac{1}{2}}$. For the diagonal matrices used here, this simplifies to $\tilde{H} = H^{1-p}$ and $\tilde{C} = C H^{-p}$.

In [None]:
def new_hessian(h, p=0.5):
  """Return the effective Hessian P^{-1/2} H P^{-1/2} for the preconditioner P = H^p.

  Args:
    h: either a 1-D tensor of Hessian diagonal entries, or a symmetric
      positive-definite Hessian matrix (or a batch of them).
    p: preconditioner power.

  Returns:
    For 1-D input, the diagonal entries ``h ** (1 - p)``. For matrix input,
    the matrix ``P^{-1/2} H P^{-1/2}``.
  """
  if h.ndim == 1:
    # Diagonal Hessian given by its entries: the sandwich reduces to h^(1 - p).
    return h ** (1 - p)
  # Matrix powers must act on the eigenvalues, not element-wise. An element-wise
  # power turned the zero off-diagonal entries into inf.
  eigvals, eigvecs = torch.linalg.eigh(h)
  p_inv_sqrt = eigvecs @ torch.diag_embed(eigvals ** (-p / 2)) @ eigvecs.transpose(-2, -1)
  return p_inv_sqrt @ h @ p_inv_sqrt

* Create a method for optimizing learning rate schedules. You can either use dynamic programming using equation (3) as in the paper (see footnote on page 7), or a simpler empirical method such as black-box optimization (perhaps with simpler schedule).

Notes: We want to find the optimal piecewise-constant learning rate schedule, e.g. $lr_1$ for the first 20 epochs, then $lr_2$ for the next 30, etc. In the paper this is done in two stages. First, for a fixed number of steps, an optimizer finds the learning rate schedule that minimizes the loss. Then binary search over the number of steps finds the smallest number of steps for which the optimized schedule reaches the target loss. The implementation below splits the steps into pieces of equal length and optimizes the log learning rate of each piece.

The algorithm used to find the optimal learning rate schedule is BFGS — here its limited-memory, bounded variant L-BFGS-B, which `scipy.optimize.minimize` uses by default when bounds are given. BFGS is a quasi-Newton method: it does not compute the Hessian directly, but builds an approximation of it from successive gradients.

In [None]:
from scipy.optimize import Bounds, minimize
from jax import grad
import jax.numpy as jnp

In [None]:
def loss_jax(h, v):
  """Return the NQM loss 0.5 * sum(h * v) using jax.numpy (differentiable with jax.grad)."""
  return 0.5 * jnp.sum(h * v)

In [None]:
def make_gradients_jax(size=1e3):
  """Build the NQM Hessian diag(1, 1/2, ..., 1/size) as a dense jax array.

  Like make_gradients, this returns the full ``size x size`` diagonal matrix.
  Element-wise helpers expect only its diagonal (``jnp.diagonal``).
  """
  eigenvalues = 1 / np.arange(1, size + 1, 1)
  return jnp.diag(eigenvalues)

In [None]:
def t_step_ahead_var_jax(alpha, sigma_sq, hi, t, v0):
  """Return Var(theta_i) after t constant-learning-rate SGD steps (closed form, jax-friendly).

  Same formula as t_step_ahead_var: v_t = beta**t * (v0 - v*) + v*, with
  beta = (1 - alpha * hi)**2 and v* = alpha**2 * sigma_sq / (1 - beta).
  Uses only arithmetic operators, so jax can trace and differentiate it.
  The formula divides by zero when alpha * hi is 0 or 2.
  """
  contraction = (1 - alpha * hi) ** 2
  steady_state = alpha ** 2 * sigma_sq / (1 - contraction)
  return steady_state + contraction ** t * (v0 - steady_state)

In [None]:
def t_ahead_var_with_lrs(alphas, sigma_sq, hi, t_intervals, v0):
  """Return the NQM loss after running a piecewise-constant learning-rate schedule.

  Despite the name, this returns the loss (via loss_jax), not the variance.

  Args:
    alphas: log learning rates, one per piece. They are exponentiated here so
      the optimizer can search in log space.
    sigma_sq: per-dimension gradient-noise variances.
    hi: per-dimension Hessian entries.
    t_intervals: number of steps in each piece.
    v0: initial per-dimension variances.
  """
  learning_rates = jnp.exp(alphas)
  variance = v0
  for lr, n_steps in zip(learning_rates, t_intervals):
    # Each piece starts from the variance the previous piece ended with.
    variance = t_step_ahead_var_jax(lr, sigma_sq, hi, n_steps, variance)
  return loss_jax(hi, variance)

In [None]:
def optimal_alphas(sigma_sq, hi, t_intervals, v0, max_bound):
  """Optimize each piece's log learning rate to minimize the final NQM loss.

  Every piece starts at log learning rate -1. The loss is minimized with
  scipy.optimize.minimize, using the jax gradient of the loss as the Jacobian.
  No method is named and bounds are given, so scipy uses its default for
  bounded problems.

  Args:
    sigma_sq: per-dimension gradient-noise variances.
    hi: per-dimension Hessian entries.
    t_intervals: number of steps in each piece.
    v0: initial per-dimension variances.
    max_bound: upper bound on each log learning rate (there is no lower bound).

  Returns:
    ``result.x``, the optimized log learning rates.
  """
  def objective(log_lrs):
    return t_ahead_var_with_lrs(log_lrs, sigma_sq, hi, t_intervals, v0)

  result = minimize(
      objective,
      -1 * jnp.ones(len(t_intervals)),
      jac=grad(objective),
      bounds=Bounds(-np.inf, max_bound),
  )
  return result.x

In [None]:
def steps_to_loss_pwc(sigma_sq, hi, v0, loss, max_bound, pieces=50):
  """Find the fewest steps per piece for which an optimized piecewise-constant schedule reaches a target loss.

  For each candidate step count, the schedule has ``pieces`` pieces of that
  many steps each. Their log learning rates are optimized with
  ``optimal_alphas``. Bisection assumes the optimized loss decreases
  monotonically with the step count.

  Args:
    sigma_sq: per-dimension gradient-noise variances.
    hi: per-dimension Hessian entries.
    v0: initial per-dimension variances.
    loss: the target loss value (a number, not the loss function).
    max_bound: upper bound on each log learning rate (see optimal_alphas).
    pieces: number of equal-length pieces in the schedule.

  Returns:
    ``(min_steps, best_alphas)``. ``min_steps`` is the smallest steps-per-piece
    found whose optimized loss is below the target, so total steps are
    ``min_steps * pieces``. ``best_alphas`` are the corresponding log learning
    rates. Both are None if no step count below ``2**31`` reaches the target.
  """
  target = loss
  # Exclusive upper bound 2**31 keeps every probed step count within int32.
  # Integer bisection always shrinks the interval, so the loop terminates.
  low, high = 0, 2**31
  best_alphas = None
  min_steps = None

  while high - low > 1:
    mid = (low + high) // 2
    t_intervals = mid * jnp.ones(pieces)
    alphas = optimal_alphas(sigma_sq, hi, t_intervals, v0, max_bound)
    l = t_ahead_var_with_lrs(alphas, sigma_sq, hi, t_intervals, v0)

    if l < target:
      # Reached the target: remember this schedule and try fewer steps.
      high = mid
      best_alphas = alphas
      min_steps = mid
    else:
      # Not reached (or NaN): more steps are needed.
      low = mid

  return min_steps, best_alphas

In [None]:
hi = make_gradients_jax(size=1e3)  # dense 1000 x 1000 diagonal Hessian diag(1, 1/2, ..., 1/1000)

In [1]:
batch_sizes = 2 ** jnp.arange(4, 12)
# The closed-form NQM helpers work element-wise. If `hi` is the dense diagonal
# matrix from make_gradients_jax(), its zero off-diagonal entries give 0/0 = NaN,
# so work on the diagonal only.
h_diag = jnp.diagonal(hi) if hi.ndim == 2 else hi
s_init = jnp.ones(h_diag.shape[0])
target_loss = 0.01

for bs in batch_sizes:
  sigma_sq = h_diag / bs
  _, log_lrates = steps_to_loss_pwc(sigma_sq, h_diag, s_init, target_loss, 2)
  if log_lrates is None:
    # Target loss not reached within the search range for this batch size.
    continue
  # optimal_alphas returns scipy's result.x: a NumPy array of *log* learning
  # rates (there is no .detach()), so exponentiate before the log-scale plot.
  # `basey` was removed from Matplotlib; `base` is the current keyword.
  plt.semilogy(np.exp(np.asarray(log_lrates)), label="BS " + str(bs), base=2)
plt.xlabel('Pieces')
plt.ylabel('Learning Rate')
plt.ylim(top=4)
plt.tight_layout()
plt.legend(loc=3);

* Check that at very small batch sizes, the optimal learning rate scales with batch size as expected: proportional to the batch size for SGD, proportional to the square root of the batch size for Adam.

In [None]:
batch_sizes = [2 ** i for i in range(2, 22)]
# make_gradients() returns the dense diagonal Hessian. The closed-form helpers
# work element-wise, and the zero off-diagonal entries would give 0/0 = NaN,
# so keep only the diagonal.
hi = torch.diagonal(make_gradients())
v0 = torch.ones(10**3)
target_loss = 0.1

num_steps = []
optimal_lrs = []

for bs in batch_sizes:
    # Gradient-noise variance scales inversely with batch size.
    sigma_sq = hi / bs
    optimal_alpha, steps = grid_search_alphas(sigma_sq, hi, v0, target_loss)
    num_steps.append(steps)
    optimal_lrs.append(optimal_alpha)

In [None]:
# `basex`/`basey` were deprecated in Matplotlib 3.3 and later removed;
# `base` sets the log base of both axes.
plt.loglog(batch_sizes, optimal_lrs, base=2)
plt.xlabel("Batch size")
plt.ylabel("Optimal Learning Rate")
plt.tight_layout()

In [None]:
batch_sizes = [2 ** i for i in range(2, 22)]
# simulate_momentum_dynamics expects 1-D h and sigma_sq (one entry per
# dimension), so pass the diagonal of the dense Hessian from make_gradients().
hi = torch.diagonal(make_gradients())
v0 = torch.ones(10**3)
target_loss = 0.1

num_steps_mom = []
optimal_lrs_mom = []
optimal_moments_mom = []

for bs in batch_sizes:
    # Gradient-noise variance scales inversely with batch size.
    sigma_sq = hi / bs
    optimal_alpha, optimal_moment, steps_to_target = grid_search_alphas_and_momentum(hi, sigma_sq, v0, target_loss)
    num_steps_mom.append(steps_to_target)
    optimal_lrs_mom.append(optimal_alpha)
    optimal_moments_mom.append(optimal_moment)

In [None]:
# `base` replaces the removed `basex`/`basey` keywords (applies to both axes).
plt.loglog(batch_sizes, optimal_lrs_mom, base=2)
plt.ylim(2 ** (-14), 2 ** 3)
plt.xlabel("Batch size")
plt.ylabel("Optimal Learning Rate")
plt.tight_layout()

* Look at the relationship between the batch size and the number of steps to reach a target loss. Study the effects of momentum and using Adam on this relationship.

In [None]:
# `base` replaces the removed `basex`/`basey` keywords (applies to both axes).
plt.loglog(batch_sizes, num_steps, base=2)
plt.xlabel("Batch size")
plt.ylabel("Steps until target loss")
plt.tight_layout()

In [None]:
# `base` replaces the removed `basex`/`basey` keywords (applies to both axes).
plt.loglog(batch_sizes, num_steps_mom, base=2)
plt.xlabel("Batch size")
plt.ylabel("Steps until target loss")
plt.tight_layout()