# IEMS 351 Lab 4: Automatic Differentiation via JAX

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import grad, hessian, jacfwd, jacrev
from matplotlib.patches import FancyArrowPatch

## Example 1 
Let $f:\mathbb{R}^n \to \mathbb{R}$ be defined by
$$
f(x) = \frac{1}{2} x^\top x.
$$
Let $y, d \in \mathbb{R}^n$ and define $h: \mathbb{R} \to \mathbb{R}^n$ by
$$
h(t) = y + t \cdot d.
$$
Define the composition $g(t) = f(h(t))$.

In [None]:
# Dimension n used in Example 1 (the vectors defined in this example have 4 entries).
# NOTE(review): n is not referenced by the code visible in this section.
n = 4


def f(x):
    """Return half the squared Euclidean norm of ``x``, i.e. ``0.5 * x^T x``."""
    half_x = 0.5 * x
    return half_x @ x


#
def h(t):
    """Return the point ``y + t * d`` on the line through ``y`` with direction ``d``.

    The base point ``y = (1, 1, 1, 1)`` and the direction ``d = (1, 2, 3, 4)``
    are fixed inside the function, so the result depends only on ``t``.

    Args:
        t: Scalar step along the direction ``d``.

    Returns:
        A length-4 JAX array equal to ``y + t * d``.
    """
    y = jnp.array([1.0, 1.0, 1.0, 1.0])
    d = jnp.array([1.0, 2.0, 3.0, 4.0])
    # Use the local base point y rather than a module-level variable, so h
    # keeps working regardless of what other cells assign to global names.
    return y + t * d


#
def g(t):
    """Evaluate the composition ``f(h(t))``."""
    point = h(t)
    return f(point)

## Compute the gradient of $f(x)$ at $x = (1,1,1,1)$

In [None]:
# Evaluation point x = (1, 1, 1, 1).
x = jnp.array([1.0, 1.0, 1.0, 1.0])
# grad(f) builds the gradient function of f; evaluate it at x.
grad_f = grad(f)(x)
print(grad_f)

## Compute the Jacobian (transpose of the gradient) of $h(t)$ at $t = 0$

In [None]:
# Point t = 0 at which h is differentiated.
t = 0.0
# jacfwd(h) builds the Jacobian function of h (forward-mode autodiff); evaluate it at t.
grad_h = jacfwd(h)(t)
print(grad_h)

## Compute the gradient of $g(t) = f(h(t))$ at $t=0$ using chain rule

In [None]:
# Chain rule: combine the derivative of h (grad_h) with the gradient of f
# (grad_f) through the @ product to obtain the derivative of g = f(h(t)).
grad_g_chain_rule = grad_h @ grad_f
print("Gradient of g(t) = f(h(t)) at t=0 using chain rule: {}".format(grad_g_chain_rule))

## Compute the gradient of $g(t) = f(h(t))$ at $t=0$ using JAX

In [None]:
# Differentiate the composition g directly with grad and evaluate at t,
# for comparison with the chain-rule result.
grad_g_jax = grad(g)(t)
print("Gradient of g(t) = f(h(t)) at t=0 using JAX: {}".format(grad_g_jax))

## Rosenbrock Gradient

In [None]:
def rosenbrock(x, a=1, b=100):
    """Evaluate the Rosenbrock function ``(a - x0)^2 + b * (x1 - x0^2)^2``.

    ``x`` is indexed as ``x[0]`` and ``x[1]``; ``a`` and ``b`` are the two
    coefficients of the function (defaults 1 and 100).
    """
    x0, x1 = x[0], x[1]
    offset_term = (a - x0) ** 2
    valley_term = b * (x1 - x0 ** 2) ** 2
    return offset_term + valley_term


# Gradient function of rosenbrock built with jax.grad; by default it
# differentiates with respect to the first argument (x).
rosenbrock_grad = grad(rosenbrock)

In [None]:
# Next, let us plot the progress of the gradient method on the Rosenbrock
# function: draw its contours, then one arrow per gradient step from x0.
x = np.arange(2, 15, 0.1)
y = np.arange(-5, 20, 0.1)

X, Y = np.meshgrid(x, y)
a = 1
b = 100
# Same formula as rosenbrock(), evaluated on the whole grid at once.
Z = (a - X) ** 2 + b * (Y - X**2) ** 2

x0 = np.array([1.5, 1.5])

fig, ax = plt.subplots(1, 1)
CS = ax.contour(X, Y, Z, 10, cmap="jet", linewidths=2)
ax.clabel(CS, inline=1, fontsize=10)

# max number of iterations
max_iterations = 100
# reset the step size alpha
# NOTE(review): with alpha = 0.5 the first step from x0 = (1.5, 1.5) already
# lands near (-224, 76.5) and the iterates overflow within a few steps; a much
# smaller step size is presumably what should be explored here.
alpha = 0.5
# initialize current x
cur_x = x0
ax.text(cur_x[0] + 0.1, cur_x[1] + 0.1, r"$x^0$")
ax.plot(cur_x[0], cur_x[1], "or", markersize=4)
for i in range(max_iterations):
    cur_grad = rosenbrock_grad(jnp.array(cur_x))  # Compute gradient
    # Update step
    next_x = cur_x - alpha * cur_grad

    # Stop as soon as the iterate is no longer finite: inf/NaN coordinates
    # cannot be drawn and every later gradient would be NaN as well.
    if not bool(jnp.all(jnp.isfinite(next_x))):
        print("Gradient method diverged at iteration {}; stopping.".format(i + 1))
        break

    # Plot the arrow showing the step
    arrow = FancyArrowPatch(
        tuple(cur_x),  # Convert points to tuples
        tuple(next_x),
        arrowstyle="simple",
        color="k",
        mutation_scale=5,
    )
    ax.add_patch(arrow)
    ax.plot(next_x[0], next_x[1], "or", markersize=4)
    # Braces around the index keep multi-digit iteration numbers entirely in
    # the superscript (e.g. x^{10} instead of x^1 followed by 0).
    ax.text(next_x[0] + 0.1, next_x[1] + 0.1, r"$x^{{{}}}$".format(i + 1))

    # Update current point
    cur_x = next_x

ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()