# Newton method


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax

# We enable double precision (64-bit floats) in JAX, so that the automatic
# differentiation results below can be compared with the NumPy reference
# values up to machine precision.
# NOTE(review): JAX uses 32-bit floats unless this flag is set -- confirm
# against the JAX version in use.
jax.config.update("jax_enable_x64", True)

We consider a random matrix $A \in \mathbb{R}^{n\times n}$, with $n = 100$ and a random vector $\mathbf{x}_{\text{ex}} \in \mathbb{R}^n$.
We then define $\mathbf{b} = A \, \mathbf{x}_{\text{ex}}$.


In [None]:
# Problem size: A is n-by-n and the vectors have n entries.
n = 100

# Fixed seed, so that every run draws the same matrix and exact solution.
np.random.seed(0)
A = np.random.standard_normal((n, n))
x_ex = np.random.standard_normal(n)

# Right-hand side built from the known exact solution.
b = A.dot(x_ex)

Define the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_2^2
$$


In [None]:
def loss(x):
    """Return the sum of the squared entries of the residual ``A @ x - b``."""
    residual = A @ x - b
    return jnp.square(residual).sum()


# Evaluate the loss at x_ex.
print(loss(x_ex))

By using the `jax` library, implement and compile functions returning the gradient ($\nabla \mathcal{L}(\mathbf{x})$) and the Hessian ($\nabla^2 \mathcal{L}(\mathbf{x})$) of the loss function.


In [None]:
# Differentiation operators for the loss. jax.hessian is the forward-mode
# Jacobian of the reverse-mode Jacobian, i.e. jacfwd(jacrev(loss)).
grad = jax.grad(loss)
hess = jax.hessian(loss)

# Compiled versions of the loss, its gradient and its Hessian.
loss_jit, grad_jit, hess_jit = (jax.jit(fn) for fn in (loss, grad, hess))

Check that the results are correct (up to machine precision).


In [None]:
# Reproducible test point at which the derivatives are compared.
np.random.seed(0)
x_guess = np.random.standard_normal(n)

# Gradient: automatic differentiation vs. the closed form 2 A^T (A x - b).
G_ex = 2 * A.T @ (A @ x_guess - b)
G_ad = grad_jit(x_guess)
print(np.linalg.norm(G_ad - G_ex))

# Hessian: automatic differentiation vs. the closed form 2 A^T A.
H_ex = 2 * A.T @ A
H_ad = hess_jit(x_guess)
print(np.linalg.norm(H_ad - H_ex))


* JAX enables us to compute the Hessian $\nabla^2 \mathcal{J}(\mathbf{x})$ of a scalar function $\mathcal{J}$ (here, the loss $\mathcal{L}$). This is a **function**.
* Then, we use it by evaluating the Hessian at one point, $\nabla^2 \mathcal{J}(\mathbf{x})|_{\mathbf{x}_0}$, for instance when we apply Newton's method. This is a **matrix**.
* However, when we use this matrix (the evaluation of the Hessian at a point), we rarely need the whole matrix. In most cases it is sufficient to know the action of the matrix on a vector $\mathbf{v}_0$. **Avoiding storing a full matrix greatly diminishes the cost in memory (which is quadratic in $n$).**

In other words, do we need to compute $\nabla^2 \mathcal{J}(\mathbf{x})|_{\mathbf{x}_0}$ in order to compute $\nabla^2 \mathcal{J}(\mathbf{x})|_{\mathbf{x}_0} \mathbf{v}_0$? 

From calculus we know that

$$
\nabla^2 \mathcal{J}(\mathbf{x}) \mathbf{v} = \nabla_{\mathbf{x}} \phi(\mathbf{x}, \mathbf{v})
$$

where

$$
\phi(\mathbf{x}, \mathbf{v}) := \nabla \mathcal{J}(\mathbf{x}) \cdot \mathbf{v}
$$

because

$$
\left( \nabla_{\mathbf x}\phi(\mathbf x,\mathbf v) \right)_i
= \frac{\partial}{\partial x_i}
\left( \sum_{j=1}^n \frac{\partial \mathcal J}{\partial x_j}(\mathbf x)\, v_j \right)
= \sum_{j=1}^n \frac{\partial^2 \mathcal J}{\partial x_i \partial x_j}(\mathbf x)\, v_j .
$$

Then, in JAX we can compute $\nabla^2 \mathcal{J}(\mathbf{x})|_{\mathbf{x}_0} \mathbf{v}_0$ by evaluating at $(\mathbf{x}_0, \mathbf{v}_0)$ the function
$$
(\mathbf{x}, \mathbf{v}) \mapsto \nabla_{\mathbf{x}} (\nabla \mathcal{J}(\mathbf{x}) \cdot \mathbf{v}).
$$

Notice that this implementation never forms a matrix: since $\mathcal{J}$ is scalar-valued, $\phi$ is scalar-valued too, so only gradients (i.e. vectors) are ever computed.

Implement this matrix-free Hessian-vector product and compare its computational performance w.r.t. the product obtained from the full Hessian computation.

In [None]:
# Reproducible direction along which the Hessian action is evaluated.
np.random.seed(1)
v = np.random.standard_normal(n)


def hvp_basic(x, v):
    """Hessian-vector product obtained by assembling the full Hessian first."""
    return hess(x) @ v


# SOLUTION-BEGIN
def gvp(x, v):
    """Scalar product between the gradient at ``x`` and the vector ``v``."""
    return jnp.dot(grad(x), v)


def hvp(x, v):
    """Hessian-vector product, computed as the gradient of ``gvp`` w.r.t. ``x``."""
    return jax.grad(gvp, argnums=0)(x, v)


hvp_basic_jit = jax.jit(hvp_basic)
hvp_jit = jax.jit(hvp)
# SOLUTION-END

# Compare the matrix-free product with the closed-form Hessian applied to v.
Hv_ex = H_ex @ v
Hv_ad = hvp_jit(x_guess, v)
print(np.linalg.norm(Hv_ad - Hv_ex))

In [None]:
%timeit hvp_basic_jit(x_guess, v)
%timeit hvp_jit(x_guess, v)

Implement the Newton method for the minimization of the loss function $\mathcal{L}$. Set a maximum number of 100 iterations and a tolerance on the increment norm of $\epsilon = 10^{-8}$.


In [None]:
# Solver settings: iteration cap and tolerance on the norm of the increment.
num_epochs, eps = 100, 1e-8
x = x_guess.copy()

for epoch in range(num_epochs):
    # SOLUTION-BEGIN
    # Newton step: solve H * incr = -G with the full Hessian, then update x.
    G = grad_jit(x)
    H = hess_jit(x)
    incr = np.linalg.solve(H, -G)
    x += incr
    # SOLUTION-END

    incr_norm = np.linalg.norm(incr)

    # The loss is evaluated after the update, the gradient before it.
    print("============ epoch %d" % epoch)
    print("loss: %1.3e" % loss_jit(x))
    print("grad: %1.3e" % np.linalg.norm(G))
    print("incr: %1.3e" % incr_norm)

    # Stop as soon as the increment drops below the tolerance.
    if incr_norm < eps:
        break

err_norm = np.linalg.norm(x - x_ex)
rel_err = err_norm / np.linalg.norm(x_ex)
print(f"Relative error: {rel_err:1.3e}")

Solve the system with `jax.scipy.sparse.linalg.cg` and use the "matrix-free" version of the Hessian.

In [None]:
# Solver settings: iteration cap and tolerance (used both for the stopping
# criterion on the increment and as the tolerance passed to CG).
num_epochs, eps = 100, 1e-8
x = x_guess.copy()

for epoch in range(num_epochs):
    # SOLUTION-BEGIN
    G = grad_jit(x)

    def hess_action(y):
        """Apply the Hessian at the current ``x`` to ``y`` without assembling it."""
        return hvp_jit(x, y)

    # Newton step: solve H * incr = -G with conjugate gradient, using only
    # Hessian-vector products.
    incr, info = jax.scipy.sparse.linalg.cg(hess_action, -G, tol=eps)
    x += incr
    # SOLUTION-END

    incr_norm = np.linalg.norm(incr)

    # The loss is evaluated after the update, the gradient before it.
    print("============ epoch %d" % epoch)
    print("loss: %1.3e" % loss_jit(x))
    print("grad: %1.3e" % np.linalg.norm(G))
    print("incr: %1.3e" % incr_norm)

    # Stop as soon as the increment drops below the tolerance.
    if incr_norm < eps:
        break

err_norm = np.linalg.norm(x - x_ex)
rel_err = err_norm / np.linalg.norm(x_ex)
print(f"Relative error: {rel_err:1.3e}")

Repeat the optimization loop for the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_4^4 = \sum_{i=1}^n \left( \mathbf{b} - A \, \mathbf{x} \right)_i^4
$$


In [None]:
def loss(x):
    """Return the sum of the fourth powers of the entries of ``A @ x - b``."""
    residual = A @ x - b
    return jnp.sum(residual**4)


# Differentiation operators for the new loss. jax.hessian is the forward-mode
# Jacobian of the reverse-mode Jacobian, i.e. jacfwd(jacrev(loss)).
grad = jax.grad(loss)
hess = jax.hessian(loss)

# Compiled versions of the loss, its gradient and its Hessian.
loss_jit, grad_jit, hess_jit = (jax.jit(fn) for fn in (loss, grad, hess))

In [None]:
x = x_guess.copy()
num_epochs = 100
eps = 1e-8

# hist[k] is the loss after k Newton updates; hist[0] is the loss at the
# initial guess. The loss is appended *after* each update, so the initial value
# is stored only once and the final iterate is included in the plot.
hist = [loss_jit(x)]
for epoch in range(num_epochs):
    # Newton step: solve H * incr = -G with the full Hessian, then update x.
    H = hess_jit(x)
    G = grad_jit(x)
    incr = np.linalg.solve(H, -G)
    x += incr

    hist.append(loss_jit(x))

    # Stop as soon as the increment drops below the tolerance.
    if np.linalg.norm(incr) < eps:
        print("convergence reached!")
        break

plt.semilogy(hist, "o-")
# Number of Newton updates actually performed (one history entry per update,
# plus the initial one).
print("epochs: %d" % (len(hist) - 1))
print("relative error: %1.3e" % (np.linalg.norm(x - x_ex) / np.linalg.norm(x_ex)))

# Quasi-Newton Optimization with BFGS Update

This algorithm minimizes an objective function $ f(\mathbf{x}) $ using a quasi-Newton method with the BFGS update. The goal is to iteratively update the solution $ \mathbf{x} $ and approximate the inverse Hessian until convergence criteria are met.

## Algorithm

1. **Initialization**:

   - Set the initial guess $ \mathbf{x} = \mathbf{x}_{\text{guess}} $.
   - Let $ \mathbf{I} $ be the identity matrix, and initialize $ B^{-1} = \mathbf{I} $.
   - Compute the initial gradient $ \nabla f = \nabla f(\mathbf{x}_{\text{guess}}) $.
   - Initialize the loss history: $ \text{history} = [f(\mathbf{x}_{\text{guess}})] $.
   - Set $ \text{epoch} = 0 $.

   $$
   B^{-1} = \mathbf{I}, \quad \nabla f = \nabla f(\mathbf{x}_{\text{guess}}), \quad \text{history} = [f(\mathbf{x}_{\text{guess}})]
   $$

2. **Iterative Updates**:

   - While $ \|\nabla f\| > \text{tol} $ and epoch $<$ max epoch:

     - Increment the epoch counter:

       $$
       \text{epoch} \leftarrow \text{epoch} + 1
       $$

     - Compute the search direction:

       $$
       \mathbf{p} = -B^{-1} \nabla f
       $$

     - Perform a line search to find the step size $ \alpha $ using `sp.optimize.line_search`:

       $$
       \alpha \leftarrow \text{line\_search}(f, \nabla f, \mathbf{x}, \mathbf{p})
       $$

       If $ \alpha $ is not found (scipy returns `None`), set $ \alpha = 10^{-8} $.

     - Update the solution vector:

       $$
       \mathbf{x}_{\text{new}} = \mathbf{x} + \alpha \mathbf{p}
       $$

     - Compute the displacement:

       $$
       \mathbf{s} = \mathbf{x}_{\text{new}} - \mathbf{x}
       $$

     - Update $ \mathbf{x} $:

       $$
       \mathbf{x} \leftarrow \mathbf{x}_{\text{new}}
       $$

     - Compute the new gradient and gradient difference:

       $$
       \nabla f_{\text{new}} = \nabla f(\mathbf{x}), \quad \mathbf{y} = \nabla f_{\text{new}} - \nabla f
       $$

       Update $ \nabla f $:

       $$
       \nabla f \leftarrow \nabla f_{\text{new}}
       $$

     - Compute the scalar $ \rho $:

       $$
       \rho = \frac{1}{\mathbf{y}^\top \mathbf{s}}
       $$

     - Update the inverse Hessian approximation $ B^{-1} $ using the Sherman–Morrison formula:

       $$
       \mathbf{E} = \mathbf{I} - \rho \mathbf{y} \mathbf{s}^\top
       $$

       $$
       B^{-1} \leftarrow \mathbf{E}^\top B^{-1} \mathbf{E} + \rho \mathbf{s} \mathbf{s}^\top
       $$

     - Append the current loss value to history:
       $$
       \text{history.append}(f(\mathbf{x}))
       $$


In [None]:
import scipy as sp
import scipy.optimize  # load the submodule explicitly: sp.optimize is used below

max_epochs = 1000
tol = 1e-8

np.random.seed(0)

epoch = 0
x = x_guess.copy()
I = np.eye(x.size)
# Inverse Hessian approximation, initialized to the identity (a copy, so that
# it never aliases I).
Binv = I.copy()
# Current gradient *value*. It is deliberately not called `grad`, so that the
# gradient function `grad` defined earlier in the notebook is not overwritten.
g = grad_jit(x)
history = [loss_jit(x)]

while np.linalg.norm(g) > tol and epoch < max_epochs:
    epoch += 1

    # SOLUTION-BEGIN
    # search direction
    p = -Binv @ g

    # line search; scipy returns None as step size when it does not find one,
    # in which case a very small step is taken
    alpha = sp.optimize.line_search(loss_jit, grad_jit, x, p)[0]
    alpha = 1e-8 if alpha is None else alpha
    x_new = x + alpha * p

    # computing y and s
    s = x_new - x
    x = x_new
    g_new = grad_jit(x)
    y = g_new - g
    g = g_new

    # BFGS update of the inverse Hessian approximation. It is skipped when the
    # curvature y^T s is not strictly positive: rho would be infinite or
    # negative, and Binv would become non-finite or lose positive definiteness.
    ys = np.dot(y, s)
    if ys > 0.0:
        rho = 1.0 / ys
        E = I - rho * np.outer(y, s)
        Binv = E.T @ Binv @ E + rho * np.outer(s, s)
    # SOLUTION-END

    history.append(loss_jit(x))

plt.semilogy(history, "o-")