# Newton method

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time
import jax.numpy as jnp
import jax

# Enable double precision in JAX.
# The flag is set through the top-level `jax.config` object: importing it as
# `from jax.config import config` relies on the `jax.config` submodule, which
# was deprecated and later removed from JAX.
jax.config.update("jax_enable_x64", True)

We consider a random matrix $A \in \mathbb{R}^{n\times n}$, with $n = 100$, and a random vector $\mathbf{x}_{\text{ex}} \in \mathbb{R}^n$.
We then define $\mathbf{b} = A \, \mathbf{x}_{\text{ex}}$.

In [2]:
n = 100  # problem size: A is n-by-n and x_ex has n entries

# Seed NumPy's global generator so that the random problem is reproducible.
np.random.seed(0)
A = np.random.standard_normal(size=(n, n))
x_ex = np.random.standard_normal(size=n)
# Right-hand side built from the known solution x_ex.
b = np.matmul(A, x_ex)

Define the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_2^2
$$

In [3]:
def loss(x):
    """Return the sum of the squared entries of the residual A @ x - b."""
    residual = A @ x - b
    return jnp.square(residual).sum()

# Evaluate the loss at x_ex (the cell output).
loss(x_ex)

Array(0., dtype=float64)

By using the `jax` library, implement and compile functions returning the gradient ($\nabla \mathcal{L}(\mathbf{x})$) and the Hessian ($\nabla^2 \mathcal{L}(\mathbf{x})$) of the loss function (*Hint*: use the `jacrev` or the `jacfwd` function).

In [7]:
# Gradient of the scalar loss.
grad = jax.grad(loss)
# Hessian: forward-mode Jacobian of the reverse-mode Jacobian of the loss.
hess = jax.jacfwd(jax.jacrev(loss))

# Compiled (jitted) versions of the loss, its gradient and its Hessian.
loss_jit, grad_jit, hess_jit = (jax.jit(fn) for fn in (loss, grad, hess))

Check that the results are correct (up to machine precision).

In [30]:
# Random evaluation point for the checks.
np.random.seed(0)
x_guess = np.random.randn(n)

# Automatic-differentiation results ...
G_ad, H_ad = grad_jit(x_guess), hess_jit(x_guess)
# ... and the closed-form expressions they are compared against.
G_ex = 2 * A.T @ (A @ x_guess - b)
H_ex = 2 * A.T @ A

# Print the norm of each discrepancy (gradient first, then Hessian).
for autodiff, exact in ((G_ad, G_ex), (H_ad, H_ex)):
    print(np.linalg.norm(autodiff - exact))

2.1550089998180016e-12
4.829664679334261e-13


Exploit the formula
$$
\nabla^2 \mathcal{L}(\mathbf{x}) \mathbf{v} = \nabla_{\mathbf{x}} \phi(\mathbf{x}, \mathbf{v})
$$
where 
$$
\phi(\mathbf{x}, \mathbf{v}) := \nabla \mathcal{L}(\mathbf{x}) \cdot \mathbf{v}
$$
to write an optimized function returning the Hessian-vector product
$$
(\mathbf{x}, \mathbf{v}) \mapsto \nabla^2 \mathcal{L}(\mathbf{x}) \mathbf{v}.
$$
Compare the computational performance w.r.t. the full Hessian computation.

In [14]:
np.random.seed(1)
v = np.random.randn(n)


def hvp_basic(x, v):
    """Baseline: assemble the full Hessian at x, then multiply it by v."""
    return hess(x) @ v


def gvp(x, v):
    """Directional derivative of the loss at x along v, computed with jax.jvp."""
    _, tangent_out = jax.jvp(loss, [x], [v])
    return tangent_out


def hvp(x, v):
    """Hessian-vector product: gradient of gvp(x, v) with respect to x."""
    return jax.grad(gvp, argnums=0)(x, v)


hvp_basic_jit = jax.jit(hvp_basic)
hvp_jit = jax.jit(hvp)

# Compare against the closed-form Hessian applied to v.
Hv_ad = hvp_jit(x_guess, v)
Hv_ex = H_ex @ v
print(np.linalg.norm(Hv_ad - Hv_ex))

1.2744887647117243e-12


In [15]:
%timeit hvp_basic_jit(x_guess, v)
%timeit hvp_jit(x_guess, v)

224 µs ± 4.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
16.4 µs ± 54.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Implement the Newton method for the minimization of the loss function $\mathcal{L}$. Set a maximum number of 100 iterations and a tolerance on the increment norm of $\epsilon = 10^{-8}$.

In [31]:
# Newton iteration starting from a copy of x_guess (x_guess itself is left untouched).
x = x_guess.copy()
n_epochs = 100
tolerance = 1e-8

for epoch in range(n_epochs):
    # Loss, gradient and Hessian at the current iterate (before the update).
    loss_val = loss_jit(x)
    grad_val = grad_jit(x)
    hess_val = hess_jit(x)

    # Newton step: solve H(x) * increment = -grad(x).
    increment = jnp.linalg.solve(hess_val, -grad_val)

    x += increment
    norm_increment = jnp.linalg.norm(increment)
    print(f'======Epoch {epoch}========')
    print(f'Loss = {loss_val:3e}')
    print(f'Increment = {norm_increment:3e}')

    # Stop once the step is below the tolerance (reuse the norm computed above).
    if norm_increment < tolerance:
        break
else:
    # The loop ran out of iterations without ever meeting the tolerance.
    print(f'Newton method did not converge within {n_epochs} iterations')
    

Loss = 3.738843e+04
Increment = 1.547522e+01
Loss = 1.434584e-22
Increment = 1.252722e-09


Repeat the optimization loop for the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_4^4
$$

In [32]:
def loss_4(x):
    """Return the sum of the fourth powers of the entries of the residual A @ x - b."""
    residual = A @ x - b
    return jnp.sum(residual ** 4)

# Evaluate the loss at x_ex (the cell output).
loss_4(x_ex)

Array(0., dtype=float64)

In [33]:
# Gradient of the quartic loss.
grad_4 = jax.grad(loss_4)
# Hessian: forward-mode Jacobian of the reverse-mode Jacobian of the loss.
hess_4 = jax.jacfwd(jax.jacrev(loss_4))

# Compiled (jitted) versions of the loss, its gradient and its Hessian.
loss_4_jit, grad_4_jit, hess_4_jit = (jax.jit(fn) for fn in (loss_4, grad_4, hess_4))

In [34]:
# Newton iteration starting from a copy of x_guess (x_guess itself is left untouched).
x = x_guess.copy()
n_epochs = 100
tolerance = 1e-8

for epoch in range(n_epochs):
    # Loss, gradient and Hessian at the current iterate (before the update).
    loss_val = loss_4_jit(x)
    grad_val = grad_4_jit(x)
    hess_val = hess_4_jit(x)

    # Newton step: solve H(x) * increment = -grad(x).
    increment = jnp.linalg.solve(hess_val, -grad_val)

    x += increment
    norm_increment = jnp.linalg.norm(increment)
    print(f'======Epoch {epoch}========')
    print(f'Loss = {loss_val:3e}')
    print(f'Increment = {norm_increment:3e}')

    # Stop once the step is below the tolerance (reuse the norm computed above).
    if norm_increment < tolerance:
        break
else:
    # The loop ran out of iterations without ever meeting the tolerance.
    print(f'Newton method did not converge within {n_epochs} iterations')
    

Loss = 2.432931e+08
Increment = 5.158408e+00
Loss = 4.805790e+07
Increment = 3.438938e+00
Loss = 9.492918e+06
Increment = 2.292626e+00
Loss = 1.875144e+06
Increment = 1.528416e+00
Loss = 3.703989e+05
Increment = 1.018945e+00
Loss = 7.316521e+04
Increment = 6.792965e-01
Loss = 1.445239e+04
Increment = 4.528642e-01
Loss = 2.854793e+03
Increment = 3.019095e-01
Loss = 5.639097e+02
Increment = 2.012729e-01
Loss = 1.113896e+02
Increment = 1.341820e-01
Loss = 2.200288e+01
Increment = 8.945465e-02
Loss = 4.346247e+00
Increment = 5.963644e-02
Loss = 8.585180e-01
Increment = 3.975763e-02
Loss = 1.695838e-01
Increment = 2.650508e-02
Loss = 3.349803e-02
Increment = 1.767006e-02
Loss = 6.616896e-03
Increment = 1.178005e-02
Loss = 1.307041e-03
Increment = 7.853358e-03
Loss = 2.581810e-04
Increment = 5.235572e-03
Loss = 5.099871e-05
Increment = 3.490383e-03
Loss = 1.007382e-05
Increment = 2.326920e-03
Loss = 1.989890e-06
Increment = 1.551280e-03
Loss = 3.930647e-07
Increment = 1.034187e-03
Loss = 7.7