# Homework 6

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D

## Problem 1.2

In [None]:
# Objective f(x1, x2) and its gradient (via JAX autodiff)
def func(x):
    """Evaluate f(x) = x1^3 / 3 - 2 x1 x2 + x2^2 / 2 - 8 x1 + 4 x2.

    ``x`` is any indexable pair (JAX/NumPy array or tuple). Its two entries may
    be scalars or same-shaped arrays, in which case f is evaluated elementwise.
    """
    a, b = x[0], x[1]
    return (1 / 3) * a ** 3 - 2 * a * b + (1 / 2) * b ** 2 - 8 * a + 4 * b


def grad_func(x):
    """Return the gradient of ``func`` at the point ``x``, computed by ``jax.grad``."""
    df = jax.grad(func)
    return df(x)

In [None]:
x1_range = np.linspace(-5.5, 5.5, 200)
x2_range = np.linspace(-5.5, 5.5, 200)
X1, X2 = np.meshgrid(x1_range, x2_range)
# func uses only elementwise arithmetic, so the whole grid is evaluated in one
# vectorized call instead of 200 * 200 separate JAX dispatches.
Z = func((X1, X2))

iterations = 100
alpha = 1e-1

# initial point
initial_point = jnp.array([1.0, 1.0])

x = initial_point
fig, ax = plt.subplots(figsize=(8, 6))
# Pass `levels` once: a positional level count alongside levels= is silently ignored.
CS = ax.contour(X1, X2, Z, levels=30, cmap="plasma", linewidths=2)
ax.clabel(CS, inline=1, fontsize=10)
ax.plot(x[0], x[1], "or", markersize=4)
ax.text(x[0] + 0.1, x[1] + 0.1, r"$x^0$")

# Fixed-step gradient descent: x_{k+1} = x_k - alpha * grad f(x_k).
grad_x = grad_func(x)
for k in range(iterations):
    next_x = x - alpha * grad_x
    # Gradient at the new point: reported now and reused as the next step
    # direction, so each iteration needs only one gradient evaluation.
    next_grad = grad_func(next_x)
    grad_norm = jnp.linalg.norm(next_grad)
    if k % 10 == 0:
        print(f"Iteration #{k}: x = {next_x}, L2 gradient norm = {grad_norm:.4f}")

    # Plot arrow
    arrow = FancyArrowPatch(
        (x[0], x[1]),
        (next_x[0], next_x[1]),
        arrowstyle="simple",
        color="k",
        mutation_scale=10,
    )
    ax.add_patch(arrow)

    ax.plot(next_x[0], next_x[1], "or", markersize=4)
    # Doubled braces emit literal TeX braces, so multi-digit superscripts such as
    # x^{12} render fully instead of as x^1 followed by a normal-size 2.
    ax.text(next_x[0] + 0.1, next_x[1] + 0.1, f"$x^{{{k + 1}}}$")

    x, grad_x = next_x, next_grad

# Set plot details
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.grid()
plt.savefig("grad_desc_1.png", dpi=150)
plt.show()

In [None]:
x1_range = np.linspace(-5.5, 5.5, 200)
x2_range = np.linspace(-5.5, 5.5, 200)
X1, X2 = np.meshgrid(x1_range, x2_range)
# Vectorized evaluation: func is elementwise, so no per-point Python loop is needed.
Z = func((X1, X2))

fig, ax = plt.subplots(figsize=(8, 6))
contour = ax.contour(X1, X2, Z, levels=30, cmap="viridis")
ax.clabel(contour, inline=1, fontsize=10)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")

# Roots of grad f = (x1^2 - 2 x2 - 8, -2 x1 + x2 + 4) = 0.
stationary_points_numeric = [(0, -4), (4, 4)]
for point in stationary_points_numeric:
    ax.plot(point[0], point[1], "ro")  # Red dots for stationary points
plt.savefig("stat_point_1.png", dpi=150)
plt.show()

## Problem 1.3

In [None]:
x1, x2 = sp.symbols("x1 x2")
# Exact rationals: in Python `1 / 3` is already a float before SymPy sees it, which
# would turn every derived coefficient into an inexact sympy.Float.
f = sp.Rational(1, 3) * x1**3 - 2 * x1 * x2 + sp.Rational(1, 2) * x2**2 - 8 * x1 + 4 * x2

Hessian = sp.hessian(f, (x1, x2))
Hessian_eval_1 = Hessian.subs({x1: 0, x2: -4})  # At (0, -4)
Hessian_eval_2 = Hessian.subs({x1: 4, x2: 4})  # At (4, 4)

# Check definiteness by eigenvalues. eigenvals() returns {eigenvalue: multiplicity};
# iterating yields the distinct exact roots, converted to floats for display.
# Exact values: (1 +/- sqrt(17)) / 2 at (0, -4) (mixed signs -> saddle point) and
# (9 +/- sqrt(65)) / 2 at (4, 4) (both positive -> local minimum).
Hessian_eval_1_eigenvalues = [sp.N(eig) for eig in Hessian_eval_1.eigenvals()]
Hessian_eval_2_eigenvalues = [sp.N(eig) for eig in Hessian_eval_2.eigenvals()]

In [None]:
# Display the numeric Hessian eigenvalues at the stationary point (0, -4).
Hessian_eval_1_eigenvalues

In [None]:
# Display the numeric Hessian eigenvalues at the stationary point (4, 4).
Hessian_eval_2_eigenvalues

In [None]:
# Objective value at each stationary point (heights used by the 3D scatter below).
stationary_points = [(0, -4), (4, 4)]
stationary_values = list(map(func, stationary_points))

In [None]:
x1_vals = np.linspace(-5, 5, 100)
x2_vals = np.linspace(-5, 5, 100)
X1, X2 = np.meshgrid(x1_vals, x2_vals)
# func uses only elementwise arithmetic, so it evaluates the whole grid at once.
F = func((X1, X2))

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
surf = ax.plot_surface(X1, X2, F, cmap="viridis", edgecolor="none")

# Pair each point with its value via zip. Distinct loop names avoid clobbering
# the module-level `x` assigned by the gradient-descent cell.
for idx, ((pt_x1, pt_x2), pt_f) in enumerate(
    zip(stationary_points, stationary_values), start=1
):
    ax.scatter(pt_x1, pt_x2, pt_f, color="darkred", s=50, label=f"Stationary Point {idx}")

fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$f(x)$")

ax.legend(loc="upper left")
plt.savefig("3d_plot_1.png", dpi=150)
plt.show()

## Problem 2.2

In [None]:
def stationary_point(Q, g):
    """Return the point x satisfying 2 Q x + g = 0, i.e. x = -0.5 * Q^{-1} g.

    For symmetric Q this is the zero of the gradient of f(x) = x^T Q x + g^T x
    (presumably the quadratic form used in this problem -- the -0.5 factor implies it).

    Args:
        Q: square coefficient matrix (assumed symmetric).
        g: vector with ``len(g) == Q.shape[0]``.

    Returns:
        The stationary point as a float ndarray.

    Raises:
        numpy.linalg.LinAlgError: if ``Q`` is (exactly) singular.
    """
    # Solving Q y = g is cheaper and numerically more stable than forming Q^{-1}.
    return -0.5 * np.linalg.solve(Q, g)


def is_invertible(Q):
    """Return True if ``Q`` is a square matrix of full (numerical) rank.

    Uses the SVD-based rank instead of ``det(Q) != 0``: the floating-point
    determinant of a singular matrix is often a tiny nonzero number (~1e-16),
    which the exact comparison would wrongly report as invertible.
    Non-square or non-2-D input is never invertible, so it returns False.
    """
    Q = np.asarray(Q)
    if Q.ndim != 2 or Q.shape[0] != Q.shape[1]:
        return False
    return bool(np.linalg.matrix_rank(Q) == Q.shape[0])

In [None]:
# Problem data for q1: matrix Q_q1 and vector g_q1 (inputs to stationary_point).
Q_q1, g_q1 = (
    np.array([[2, 2], [2, -1]]),
    np.array([0, 6]),
)

In [None]:
# Display whether Q_q1 is invertible, i.e. whether a unique stationary point exists.
is_invertible(Q_q1)

In [None]:
# Stationary point of q1, computed as -0.5 * Q_q1^{-1} g_q1; displayed below.
x_q1 = stationary_point(Q_q1, g_q1)
x_q1

In [None]:
# Problem data for q2: 3x3 matrix Q_q2 and vector g_q2 (inputs to stationary_point).
Q_q2, g_q2 = (
    np.array([[2, 2, 0], [2, 4, 3], [0, 3, 1]]),
    np.array([4, -8, 0]),
)

In [None]:
# Display whether Q_q2 is invertible, i.e. whether a unique stationary point exists.
is_invertible(Q_q2)

In [None]:
# Stationary point of q2, computed as -0.5 * Q_q2^{-1} g_q2; displayed below.
x_q2 = stationary_point(Q_q2, g_q2)
x_q2

## Problem 2.3

In [None]:
# Q_q1 and Q_q2 are symmetric, so use the symmetric eigensolver: it returns real
# eigenvalues in ascending order (no spurious complex dtype), which makes the
# definiteness check a simple sign inspection. Note eigvalsh reads only one
# triangle of its input, so it must not be used on non-symmetric matrices.
eigenvalues_q1 = np.linalg.eigvalsh(Q_q1)
eigenvalues_q2 = np.linalg.eigvalsh(Q_q2)

In [None]:
# Display the eigenvalues of Q_q1 (their signs determine the definiteness of q1).
eigenvalues_q1

In [None]:
# Display the eigenvalues of Q_q2 (their signs determine the definiteness of q2).
eigenvalues_q2

## Problem 3

In [None]:
def f(x):
    """Inner map R^2 -> R^2: (x1, x2) -> (cos(x1) + x2, (x1 + x2)^2)."""
    first, second = x
    return jnp.array([jnp.cos(first) + second, (first + second) ** 2])


def g(y):
    """Outer scalar function of a 2-vector: g(y) = y1^2 + y1 * y2."""
    first, second = y
    return first**2 + first * second


def h(x):
    """Composite h = g o f (the scalar function whose gradient is taken below)."""
    inner = f(x)
    return g(inner)

In [None]:
# Reverse-mode gradient of the scalar composite h with respect to its argument x.
grad_h = jax.grad(h, argnums=0)

In [None]:
# Sample point at which the autodiff gradient of h is evaluated and printed.
x_val = jnp.asarray([1.0, 2.0])
gradient = grad_h(x_val)
print("Gradient of h(x) at x =", x_val, "is", gradient, sep=" ")

In [None]:
# Manual (hand-derived) gradient of h = g(f(x)) via the chain rule, with
# y1 = cos(x1) + x2, y2 = (x1 + x2)^2, dg/dy1 = 2 y1 + y2 and dg/dy2 = y1.
def gradient_h_component1(x1, x2):
    """Return dh/dx1 = (2 y1 + y2) * (-sin x1) + y1 * 2 (x1 + x2)."""
    y1 = np.cos(x1) + x2
    s = x1 + x2
    dg_dy1 = 2 * y1 + s ** 2
    return -np.sin(x1) * dg_dy1 + 2 * s * y1


def gradient_h_component2(x1, x2):
    """Return dh/dx2 = (2 y1 + y2) * 1 + y1 * 2 (x1 + x2)."""
    y1 = np.cos(x1) + x2
    s = x1 + x2
    dg_dy1 = 2 * y1 + s ** 2
    return dg_dy1 + 2 * s * y1


def gradient_h(x):
    """Return the hand-derived gradient of h at x = (x1, x2) as a length-2 ndarray."""
    a, b = x
    return np.array([gradient_h_component1(a, b), gradient_h_component2(a, b)])

In [None]:
# Evaluate the hand-derived gradient at (1.0, 2.0), the same sample point used
# with the JAX autodiff version above, so the two results can be compared.
x_val = np.array([1.0, 2.0])
gradient = gradient_h(x_val)
gradient