# Newton method

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import time
import jax.numpy as jnp
import jax

# Enable double precision (64-bit floats) in JAX, so that the automatic
# derivatives computed below can be compared against the NumPy reference
# values at double-precision accuracy.
from jax import config
config.update("jax_enable_x64", True)

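We consider a random matrix $A \in \mathbb{R}^{n\times n}$, with $n = 100$, and a random vector $\mathbf{x}_{\text{ex}} \in \mathbb{R}^n$.
We then define $\mathbf{b} = A \, \mathbf{x}_{\text{ex}}$, so that $\mathbf{x}_{\text{ex}}$ is the exact solution of the linear system $A \, \mathbf{x} = \mathbf{b}$.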

In [3]:
# Problem size: A is n x n, x_ex and b have length n.
n = 100

# Fix the NumPy seed so that A and x_ex are reproducible from run to run.
np.random.seed(0)
A = np.random.randn(n,n)  # random system matrix with standard-normal entries
x_ex = np.random.randn(n)  # exact solution, used later to measure the error
b = A @ x_ex  # right-hand side built from x_ex, so A x = b is solved by x_ex

Define the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_2^2
$$

In [4]:
def loss(x):
    """Squared Euclidean norm of the residual of the linear system A x = b.

    Reads the module-level matrix ``A`` and vector ``b``.

    Args:
        x: candidate solution vector.

    Returns:
        Scalar sum of the squared entries of ``A @ x - b``.
    """
    residual = A @ x - b
    return jnp.sum(jnp.square(residual))

By using the `jax` library, implement and compile functions returning the gradient ($\nabla \mathcal{L}(\mathbf{x})$) and the Hessian ($\nabla^2 \mathcal{L}(\mathbf{x})$) of the loss function (*Hint*: use the `jacrev` or the `jacfwd` function).

In [5]:
# Gradient of the scalar loss with respect to x.
grad = jax.grad(loss)
# Hessian obtained by differentiating twice: jacrev on the inside,
# jacfwd on the outside.
hess = jax.jacfwd(jax.jacrev(loss))

# Compiled versions of the loss and of its first and second derivatives.
loss_jit, grad_jit, hess_jit = (jax.jit(fn) for fn in (loss, grad, hess))

Check that the results are correct (up to machine precision).

In [6]:
# NOTE(review): the seed is reset to the same value used when A was generated,
# so x_guess repeats the first n draws of that random stream (i.e. the first
# row of A). Use a different seed if an independent initial guess is wanted.
np.random.seed(0)
x_guess = np.random.randn(n)

# Gradient: automatic differentiation vs. the closed form 2 A^T (A x - b).
G_ad = grad_jit(x_guess)
G_ex = 2 * A.T @ (A @ x_guess - b)
print(np.linalg.norm(G_ad - G_ex))

# Hessian: automatic differentiation vs. the closed form 2 A^T A
# (constant in x, since the loss is quadratic).
H_ad = hess_jit(x_guess)
H_ex = 2 * A.T @ A
print(np.linalg.norm(H_ad - H_ex))

2.1550089998180016e-12
4.829664679334261e-13


Exploit the formula
$$
\nabla^2 \mathcal{L}(\mathbf{x}) \mathbf{v} = \nabla_{\mathbf{x}} \phi(\mathbf{x}, \mathbf{v})
$$
where
$$
\phi(\mathbf{x}, \mathbf{v}) := \nabla \mathcal{L}(\mathbf{x}) \cdot \mathbf{v}
$$
to write an optimized function returning the Hessian-vector product
$$
(\mathbf{x}, \mathbf{v}) \mapsto \nabla^2 \mathcal{L}(\mathbf{x}) \mathbf{v}.
$$
Compare the computational performance w.r.t. the full Hessian computation.

### Why This Formula Works
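This formula relies on the fact that differentiation is a linear operator and that $\mathbf{v}$ does not depend on $\mathbf{x}$. Writing $\mathbf{g}(\mathbf{x})$ for the gradient of the loss, we have $\phi(\mathbf{x}, \mathbf{v}) = \sum_i g_i(\mathbf{x}) \, v_i$; differentiating with respect to $x_j$ gives $\sum_i \frac{\partial g_i}{\partial x_j}(\mathbf{x}) \, v_i$, which is the $j$-th component of the Hessian applied to $\mathbf{v}$ (the Hessian is symmetric). By chaining gradients, we compute exactly the derivative needed for the Hessian-vector product without ever constructing the Hessian itself.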

In [8]:
np.random.seed(1)
v = np.random.randn(n)


def hvp_basic(x, v):
    """Reference Hessian-vector product: form the full Hessian, then apply it to v."""
    return hess(x) @ v


def gvp(x, v):
    """Scalar phi(x, v): dot product between the gradient at x and v."""
    return jnp.dot(grad(x), v)


def hvp(x, v):
    """Hessian-vector product obtained as the gradient of ``gvp`` w.r.t. x.

    The full Hessian matrix is never assembled.
    """
    return jax.grad(gvp, argnums=0)(x, v)


hvp_basic_jit = jax.jit(hvp_basic)
hvp_jit = jax.jit(hvp)

# Compare against the closed-form Hessian applied to v.
Hv_ad = hvp_jit(x_guess, v)
Hv_ex = H_ex @ v
print(np.linalg.norm(Hv_ad - Hv_ex))

1.2744887647117243e-12


In [9]:
%timeit hvp_basic_jit(x_guess, v)
%timeit hvp_jit(x_guess, v)

172 µs ± 24.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
17.3 µs ± 4.96 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Implement the Newton method for the minimization of the loss function $\mathcal{L}$. Set a maximum number of 100 iterations and a tolerance on the increment norm of $\epsilon = 10^{-8}$.

In [None]:
# Newton iteration for the minimization of the loss, starting from x_guess.
x = x_guess.copy()
num_epochs = 100  # maximum number of Newton iterations
eps = 1e-8  # stopping tolerance on the norm of the increment

elapsed_time = 0.0
for epoch in range(num_epochs):
    t0 = time.perf_counter()
    # Newton step: solve H(x) incr = -G(x). With -G on the right-hand side the
    # increment can be added directly to x.
    # NOTE: here the Hessian is assembled explicitly and the system is solved
    # directly. For large n the Hessian should never be formed: solve the same
    # system matrix-free (e.g. conjugate gradient driven by hvp_jit) instead.
    G = grad_jit(x)
    H = hess_jit(x)
    incr = np.linalg.solve(H, -G)
    # Only the computation of the step is timed, not the diagnostics below.
    elapsed_time += time.perf_counter() - t0
    x += incr
    incr_norm = np.linalg.norm(incr)

    print("========== epoch %d" % epoch)
    # The loss is evaluated at the updated x; G is the gradient at the
    # iterate before the update.
    print("loss: %1.3e" % loss_jit(x))
    print("grad: %1.3e" % np.linalg.norm(G))
    print("incr: %1.3e" % incr_norm)

    if incr_norm < eps:
        break


print(f"Elapsed time: {elapsed_time:.4f} [s]")
# Relative error of the computed minimizer with respect to the exact solution.
real_err = np.linalg.norm(x - x_ex) / np.linalg.norm(x_ex)
print(f"Relative error: {real_err:.3e}")

Repeat the optimization loop for the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_4^4
$$