# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import PillowWriter
import matplotlib.animation as animation

# Define the Rosenbrock function
# General form: (a - x)**2 + b * (y - x**2)**2
# Global minimum at (x, y) = (a, a**2)
# Here: a = 1, b = 10 (matching the code below; the classic choice is b = 100)
def L(theta, a=1, b=10):
    """Evaluate the Rosenbrock function ``(a - x)**2 + b * (y - x**2)**2``.

    Parameters
    ----------
    theta : sequence of length 2
        The point ``(x, y)``. Each component may be a scalar or a NumPy
        array; with arrays the function is evaluated element-wise.
    a, b : float, optional
        Shape parameters of the function. The defaults (``a=1``, ``b=10``)
        are the values that were previously hard-coded, so existing calls
        of the form ``L(theta)`` return the same result as before.

    Returns
    -------
    The function value, with the same shape as the components of ``theta``.
    """
    x, y = theta
    return (a - x)**2 + b * (y - x**2)**2

def grad_L(theta, a=1, b=10):
    """Return the gradient of the Rosenbrock function at ``theta``.

    The function being differentiated is
    ``(a - x)**2 + b * (y - x**2)**2``, whose partial derivatives are

    * d/dx = ``-2 * (a - x) - 4 * b * x * (y - x**2)``
    * d/dy = ``2 * b * (y - x**2)``

    Parameters
    ----------
    theta : sequence of length 2
        The point ``(x, y)`` at which to evaluate the gradient.
    a, b : float, optional
        Shape parameters; the defaults (``a=1``, ``b=10``) match the values
        used by the loss in this file.

    Returns
    -------
    numpy.ndarray
        ``np.array([grad_x, grad_y])``.
    """
    x, y = theta
    residual = y - x**2
    grad_x = -2 * (a - x) - 4 * b * x * residual
    # The y-derivative is 2*b*(y - x**2). It used to be written with the
    # literal 200 (2*b for b = 100), which disagreed with the b = 10 used
    # everywhere else and made this vector not the gradient of the loss.
    grad_y = 2 * b * residual
    return np.array([grad_x, grad_y])

# Gradient descent hyperparameters
alpha = 0.003                   # step size
theta = np.array([-1.3, 0.5])   # starting point
iterations = 200

# History buffers: row 0 holds the starting point, row k the iterate after
# k update steps, so both end up with iterations + 1 entries.
theta_values = np.empty((iterations + 1, theta.size))
loss_values = np.empty(iterations + 1)
theta_values[0] = theta
loss_values[0] = L(theta)

for step in range(1, iterations + 1):
    gradient = grad_L(theta)
    theta = theta - alpha * gradient
    theta_values[step] = theta
    loss_values[step] = L(theta)

# Evaluate the loss on a 400 x 400 grid to draw the surface
x = np.linspace(-2, 2, 400)
y = np.linspace(-1, 3, 400)
X, Y = np.meshgrid(x, y)
Z = L((X, Y))

# Single 3D axes, horizontally centred in a 10 x 7 inch figure
fig = plt.figure(figsize=(10, 7))
fig.subplots_adjust(left=0.25, right=0.75, bottom=0, top=1)
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Semi-transparent loss surface
ax.plot_surface(X, Y, Z, alpha=0.8, cmap='viridis')

# Empty line artists that the animation fills in frame by frame:
# the red trajectory and a red marker for the current iterate.
path = ax.plot([], [], [], 'r-', linewidth=2)[0]
scatter = ax.plot([], [], [], 'ro')[0]

# Animation update function
def update(frame):
    """Redraw the descent trajectory up to and including step ``frame``.

    Reads the module-level ``theta_values`` / ``loss_values`` histories and
    updates the ``path`` line and the ``scatter`` marker in place.

    Returns the two modified artists as a tuple ``(path, scatter)``.
    """
    end = frame + 1
    xs = theta_values[:end, 0]
    ys = theta_values[:end, 1]
    zs = loss_values[:end]

    # Trajectory so far: x/y through set_data, z through set_3d_properties.
    path.set_data(xs, ys)
    path.set_3d_properties(zs)

    # Marker sits on the most recent iterate.
    scatter.set_data([xs[-1]], [ys[-1]])
    scatter.set_3d_properties([zs[-1]])

    return path, scatter

# Axis ranges and labels
ax.set_xlim(-2, 2)
ax.set_ylim(-1, 2.5)
ax.set_zlim(0, 250)
ax.set_xlabel(r"$\theta_1$")
ax.set_ylabel(r"$\theta_2$")
ax.set_zlabel("L")

# One animation frame per stored iterate (starting point included)
n_frames = len(theta_values)
ani = animation.FuncAnimation(fig, update, frames=n_frames, blit=False)

# Write the animation to a GIF at 50 frames per second, then display it
gif_writer = PillowWriter(fps=50)
ani.save('gradient_descent_3d_rosenbrock.gif', writer=gif_writer)
plt.show()
