```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
import flax
from typing import Any
import numpy as np
```

# Second-Order Optimization

Let's discuss optimization in the context of **second-order gradient descent**. The idea here is to combine second-order gradient information with first-order information to take more accurate update steps. The basic algorithm is also known as **Newton's method**. Compare the first- and second-order update rules:

$$
\begin{align}
\theta & \leftarrow \theta - \alpha \nabla_\theta \; L(\theta)  & & \text{(First-order gradient descent)}\\
\theta & \leftarrow \theta - \alpha H(\theta)^{-1} \nabla_\theta \; L(\theta)  & & \text{(Second-order gradient descent)}\\
\end{align}
$$

The only difference is the $H(\theta)^{-1}$ term. The matrix $H(\theta) = \nabla^2_\theta L(\theta)$ is so important that we give it a name, the **Hessian**: it is the matrix of all second-order partial derivatives of the loss, $H_{ij} = \partial^2 L / \partial \theta_i \partial \theta_j$. By scaling our gradients with the inverse Hessian, we get a number of nice properties (which we will examine shortly). The downside, of course, is the cost: for $n$ parameters the Hessian has $n^2$ entries, and inverting it with standard dense methods takes $O(n^3)$ time.
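As a minimal sketch (not part of the original notes), the two update rules can be compared on a small, deliberately ill-conditioned quadratic loss. The matrix `A`, the vector `b`, and the helpers `first_order_step` and `second_order_step` are illustrative assumptions; `jax.grad` and `jax.hessian` supply $\nabla_\theta L$ and $H(\theta)$. Rather than forming $H(\theta)^{-1}$ explicitly, the sketch solves the linear system $H(\theta)\,x = \nabla_\theta L(\theta)$, which is cheaper and numerically more stable.

```python
import jax
import jax.numpy as jnp

# Hypothetical ill-conditioned quadratic loss: L(theta) = 0.5 theta^T A theta - b^T theta.
A = jnp.array([[10.0, 0.0],
               [0.0, 0.1]])
b = jnp.array([1.0, 1.0])

def loss(theta):
    return 0.5 * theta @ A @ theta - b @ theta

grad_fn = jax.grad(loss)
hess_fn = jax.hessian(loss)

def first_order_step(theta, alpha):
    # theta <- theta - alpha * grad L(theta)
    return theta - alpha * grad_fn(theta)

def second_order_step(theta, alpha):
    # theta <- theta - alpha * H(theta)^{-1} grad L(theta),
    # computed by solving H x = grad instead of inverting H.
    return theta - alpha * jnp.linalg.solve(hess_fn(theta), grad_fn(theta))

theta_star = jnp.linalg.solve(A, b)  # exact minimizer, [0.1, 10.0]

theta_gd = jnp.zeros(2)
for _ in range(100):
    # For this loss, gradient descent is only stable for alpha < 2 / (largest eigenvalue of A) = 0.2.
    theta_gd = first_order_step(theta_gd, alpha=0.15)

# A single Newton step with alpha = 1.
theta_newton = second_order_step(jnp.zeros(2), alpha=1.0)

print("minimizer:      ", theta_star)
print("GD, 100 steps:  ", theta_gd)
print("Newton, 1 step: ", theta_newton)
```

On a quadratic loss the second-order model is exact, so one Newton step with $\alpha = 1$ lands on the minimizer. Gradient descent, whose step size is capped by the largest curvature, is still well short of the minimizer along the low-curvature direction after 100 steps.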

In the rest of this page, we'll look at:
- Interpretations of second-order descent
- Approximations to computing $H(\theta)^{-1}$ in classical optimization
- Approximations to computing $H(\theta)^{-1}$ in deep learning

## Second-order descent as preconditioning

A black-box way to view second-order descent is as a specific type of **preconditioning**. Recall that a preconditioner is a linear transformation of the gradient, often written as a matrix:

$$
\theta \leftarrow \theta - P \nabla_\theta \; L(\theta).
$$

Preconditioning is useful when we want to update certain parameters at different speeds. For example, a diagonal preconditioner specifies per-parameter learning rates. In classical optimization problems, this can help if we know that, e.g., certain inputs have a much larger magnitude than others. Second-order descent is the special case $P = \alpha H(\theta)^{-1}$, and the inverse Hessian turns out to be the 'correct' way to precondition a gradient update using second-order information, as shown next.
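A minimal sketch of a diagonal preconditioner acting as per-parameter learning rates, on a hypothetical quadratic whose two parameters have curvatures 10 and 0.1 (as if one input were much larger in scale than the other). `preconditioned_step`, `P_scalar`, and `P_diag` are illustrative names, and the diagonal entries of `P_diag` are hand-picked as 0.5 divided by each parameter's curvature.

```python
import jax
import jax.numpy as jnp

# Hypothetical quadratic whose two parameters have very different curvature.
A = jnp.array([[10.0, 0.0],
               [0.0, 0.1]])
b = jnp.array([1.0, 1.0])

def loss(theta):
    return 0.5 * theta @ A @ theta - b @ theta

grad_fn = jax.grad(loss)

def preconditioned_step(theta, P):
    # theta <- theta - P grad L(theta), for an arbitrary matrix P.
    return theta - P @ grad_fn(theta)

# One global learning rate is the same as P = 0.15 * I.
P_scalar = 0.15 * jnp.eye(2)
# Per-parameter learning rates: a diagonal P with entries 0.5 / curvature.
P_diag = jnp.diag(jnp.array([0.05, 5.0]))

theta_scalar = jnp.zeros(2)
theta_diag = jnp.zeros(2)
for _ in range(20):
    theta_scalar = preconditioned_step(theta_scalar, P_scalar)
    theta_diag = preconditioned_step(theta_diag, P_diag)

theta_star = jnp.linalg.solve(A, b)
print("scalar lr error:  ", float(jnp.linalg.norm(theta_scalar - theta_star)))
print("diagonal P error: ", float(jnp.linalg.norm(theta_diag - theta_star)))
```

Because the Hessian of this loss is the diagonal matrix `A`, the hand-picked `P_diag` is exactly $0.5\,H^{-1}$, and using $P = H^{-1}$ itself would reach the minimizer in a single step. When the Hessian has off-diagonal terms, no diagonal $P$ can reproduce $H^{-1}$ in general.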

## Second-order descent as solving a quadratic approximation

We can approximate the true loss function using a second-order Taylor series expansion:

$$
\tilde{L}(\theta + \theta') = L(\theta) + \nabla L(\theta)^{T}\theta' + \dfrac{1}{2} \theta'^{T} \nabla^2 L(\theta) \theta'.
$$
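The expansion can be checked numerically. In this sketch, `loss` is a hypothetical smooth, non-quadratic function and `quadratic_model` is an illustrative helper that builds $\tilde{L}(\theta + \theta')$ as a function of the step $\theta'$, using `jax.grad` and `jax.hessian` at a fixed $\theta$. For a smooth loss the approximation error should shrink roughly like $\|\theta'\|^3$, since the first neglected term is third order.

```python
import jax
import jax.numpy as jnp

def loss(theta):
    # Hypothetical smooth, non-quadratic loss with a cross term between the parameters.
    return jnp.exp(theta[0]) + jnp.exp(-theta[1]) + theta[0] * theta[1]

def quadratic_model(loss_fn, theta):
    """Return d -> L(theta) + grad L(theta)^T d + 0.5 d^T H(theta) d."""
    value = loss_fn(theta)
    g = jax.grad(loss_fn)(theta)
    H = jax.hessian(loss_fn)(theta)

    def model(d):
        return value + g @ d + 0.5 * d @ H @ d

    return model

theta = jnp.array([0.0, 0.0])
model = quadratic_model(loss, theta)
direction = jnp.array([1.0, 0.5])

errs = []
for t in [0.4, 0.2, 0.1]:
    d = t * direction  # the step theta'
    err = float(jnp.abs(loss(theta + d) - model(d)))
    errs.append(err)
    print(f"step scale {t}: |L - L_tilde| = {err:.2e}")
```

Each halving of the step shrinks the error by roughly a factor of 8, as expected from a cubic remainder.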

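Assuming $\nabla^2 L(\theta)$ is positive definite (and hence invertible), the approximate loss is a convex quadratic in $\theta'$, so we can find the $\theta'$ that minimizes it by setting its gradient with respect to $\theta'$ to zero: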

$$
\begin{align}
& \nabla_{\theta'} \tilde{L}(\theta + \theta') = 0 \\
& \nabla L(\theta) + \nabla^2 L(\theta) \theta' = 0 \\
& \theta' = -(\nabla^2 L(\theta))^{-1} \nabla L(\theta) \\
& \theta' = -H(\theta)^{-1} \nabla L(\theta) \\
\end{align}
$$

Taking the step $\theta \leftarrow \theta + \theta'$ recovers our original second-order descent method with step size $\alpha = 1$.
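A minimal sketch of repeatedly minimizing the quadratic model on a hypothetical strictly convex, non-quadratic loss. `newton_step` is an illustrative helper: it computes $\theta' = -H(\theta)^{-1}\nabla L(\theta)$ with a linear solve and also returns $\nabla L(\theta) + H(\theta)\theta'$, the gradient of the quadratic model at the chosen step, which should vanish up to rounding.

```python
import jax
import jax.numpy as jnp

def loss(theta):
    # Hypothetical strictly convex, non-quadratic loss (its Hessian is positive definite everywhere).
    return jnp.exp(theta[0]) + jnp.exp(-theta[1]) + 0.5 * (theta[0] - theta[1]) ** 2

grad_fn = jax.grad(loss)
hess_fn = jax.hessian(loss)

def newton_step(theta):
    g = grad_fn(theta)
    H = hess_fn(theta)
    # theta' = -H^{-1} grad L, computed with a linear solve instead of an explicit inverse.
    step = -jnp.linalg.solve(H, g)
    # Gradient of the quadratic model at theta': grad L + H theta', which should be ~0.
    model_grad = g + H @ step
    return theta + step, model_grad

theta = jnp.array([1.0, -1.0])
grad_norms = []
for i in range(6):
    theta, model_grad = newton_step(theta)
    grad_norms.append(float(jnp.linalg.norm(grad_fn(theta))))
    print(f"iter {i}: |grad L| = {grad_norms[-1]:.2e}, "
          f"|model grad| = {float(jnp.linalg.norm(model_grad)):.1e}")
```

Because this loss is not quadratic, a single step no longer lands on the minimizer, but close to the minimum the gradient norm collapses within a few iterations (Newton's method converges quadratically there) until it reaches float32 rounding.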

## Second-order descent as knowing how far to step

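A key property of the inverse Hessian in a region where the loss is strictly convex is that it is positive definite: $v^T H(\theta)^{-1} v > 0$ for every nonzero vector $v$. In other words, multiplying by $H(\theta)^{-1}$ can stretch a vector and change its direction, but never by 90 degrees or more, so the second-order step $-H(\theta)^{-1} \nabla_\theta L(\theta)$ still points downhill. Away from such regions, e.g. near a saddle point, the Hessian can be indefinite and this guarantee is lost.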
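A quick numerical check of this property, as a sketch: `H` is a hypothetical random symmetric positive definite matrix (built as $MM^T + 0.1I$) standing in for a Hessian. For many random vectors $v$, both $v^T H^{-1} v$ and the cosine of the angle between $v$ and $H^{-1}v$ should be positive. The second half uses an illustrative saddle-shaped loss $\theta_0^2 - \theta_1^2$, whose Hessian is indefinite, to show the Newton direction pointing uphill.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_mat, k_vec = jax.random.split(key)

# Hypothetical positive definite "Hessian": M M^T + 0.1 I is symmetric positive definite.
M = jax.random.normal(k_mat, (5, 5))
H = M @ M.T + 0.1 * jnp.eye(5)
H_inv = jnp.linalg.inv(H)

# For many random vectors v, check v^T H^{-1} v and the angle between v and H^{-1} v.
vs = jax.random.normal(k_vec, (1000, 5))
Hinv_v = vs @ H_inv  # row i is H^{-1} v_i, since H_inv is symmetric
inner = jnp.sum(vs * Hinv_v, axis=1)
cos = inner / (jnp.linalg.norm(vs, axis=1) * jnp.linalg.norm(Hinv_v, axis=1))
print("min v^T H^{-1} v:", float(inner.min()))  # positive
print("min cos(angle):  ", float(cos.min()))    # positive, but generally below 1

# Counterexample: with an indefinite Hessian, the Newton direction can point uphill.
def saddle_loss(theta):
    return theta[0] ** 2 - theta[1] ** 2

theta = jnp.array([0.1, 1.0])
g = jax.grad(saddle_loss)(theta)
H_saddle = jax.hessian(saddle_loss)(theta)
d = -jnp.linalg.solve(H_saddle, g)
print("directional derivative g^T d:", float(g @ d))  # > 0: d increases the loss
print("loss before / after:", float(saddle_loss(theta)), float(saddle_loss(theta + d)))
```

At $\theta = (0.1, 1)$ the step lands on the saddle point at the origin, raising the loss from $-0.99$ to $0$.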

**kevin todo notes**.