# In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display
from matplotlib import cm
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D


# Three-hump camel function definition
def camel(x, y):
    """Evaluate the three-hump camel test function at ``(x, y)``.

    The value is ``2*x**2 - 1.05*x**4 + x**6/6 + x*y + y**2``. Only
    arithmetic operators are applied, so the inputs may be scalars or
    NumPy arrays (the surface plot passes mesh-grid arrays).

    Args:
        x: First coordinate(s).
        y: Second coordinate(s).

    Returns:
        The function value, with the same shape as the inputs.
    """
    quadratic = 2 * x**2
    quartic = 1.05 * x**4
    sextic = (x**6) / 6
    coupling = x * y
    # Terms are combined in the original left-to-right order so the
    # floating-point result is unchanged.
    return quadratic - quartic + sextic + coupling + y**2


# Gradient of the three-hump camel function
def gradient_camel(params):
    """Return the analytic gradient of ``camel`` at a 2-D point.

    Args:
        params: Two-element sequence or array ``(x, y)``.

    Returns:
        ``np.ndarray`` ``[d/dx, d/dy]`` where
        ``d/dx = 4x - 4.2x^3 + x^5 + y`` and ``d/dy = x + 2y``.
    """
    x, y = params
    return np.array(
        [
            4 * x - 4.2 * x**3 + x**5 + y,  # partial derivative w.r.t. x
            x + 2 * y,  # partial derivative w.r.t. y
        ]
    )


# Optimizer functions
def gradient_descent(params, grads, lr):
    """Take one plain gradient-descent step.

    Args:
        params: Current parameter value(s).
        grads: Gradient evaluated at ``params``.
        lr: Learning rate (step size).

    Returns:
        The updated parameters, ``params - lr * grads``.
    """
    step = lr * grads
    return params - step


def rmsprop(params, grads, cache, lr=0.01, beta=0.9, epsilon=1e-8):
    """Take one RMSProp step.

    Args:
        params: Current parameter array.
        grads: Gradient evaluated at ``params``.
        cache: Running average of squared gradients from the previous step.
        lr: Learning rate.
        beta: Decay factor of the running average.
        epsilon: Small constant added to the denominator.

    Returns:
        Tuple ``(new_params, new_cache)``.
    """
    squared_grads = grads**2
    new_cache = beta * cache + (1 - beta) * squared_grads
    # Scale each coordinate's step by the root of its running average.
    denominator = np.sqrt(new_cache) + epsilon
    new_params = params - lr * grads / denominator
    return new_params, new_cache


def adam(params, grads, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Take one Adam step.

    Args:
        params: Current parameter array.
        grads: Gradient evaluated at ``params``.
        m: First-moment (mean of gradients) estimate from the previous step.
        v: Second-moment (mean of squared gradients) estimate from the
            previous step.
        t: Step counter used for bias correction; the caller in this file
            starts it at 1.
        lr: Learning rate.
        beta1: Decay factor of the first-moment estimate.
        beta2: Decay factor of the second-moment estimate.
        epsilon: Small constant added to the denominator.

    Returns:
        Tuple ``(new_params, new_m, new_v)``; the returned moments are the
        uncorrected running averages.
    """
    first_moment = beta1 * m + (1 - beta1) * grads
    second_moment = beta2 * v + (1 - beta2) * (grads**2)

    # Bias correction: divide each average by 1 - beta**t.
    corrected_first = first_moment / (1 - beta1**t)
    corrected_second = second_moment / (1 - beta2**t)

    step = lr * corrected_first / (np.sqrt(corrected_second) + epsilon)
    return params - step, first_moment, second_moment


# Function to plot the loss surface
def plot_surface(ax):
    """Draw the ``camel`` surface over ``[-2, 2] x [-2, 2]`` on ``ax``.

    Args:
        ax: A 3-D matplotlib axes (anything exposing ``plot_surface``).

    Returns:
        The same ``ax``, so the call can be chained.
    """
    # Both axes share the same 100 sample points.
    samples = np.linspace(-2, 2, 100)
    grid_x, grid_y = np.meshgrid(samples, samples)
    heights = camel(grid_x, grid_y)
    ax.plot_surface(
        grid_x,
        grid_y,
        heights,
        cmap="viridis",
        alpha=0.6,
        edgecolor="none",
    )
    return ax


# Function to run the optimizer and animate
def animate_optimizers(
    optimizer="Gradient Descent",
    learning_rate=0.1,
    decay_rate=0.9,
    params_X=1.5,
    params_Y=1.5,
    num_iterations=50,
):
    """Run one optimizer on the camel surface and draw its trajectory.

    Despite the name, no animation object is created: the function opens a
    new figure, draws the loss surface, runs the chosen optimizer for
    ``num_iterations`` steps and plots the visited points as one red
    line-with-markers path, then calls ``plt.show()``.

    Args:
        optimizer: One of ``"Gradient Descent"``, ``"RMSProp"`` or
            ``"Adam"``.
        learning_rate: Step size passed to the chosen optimizer.
        decay_rate: Used as ``beta`` for RMSProp and as ``beta1`` for Adam
            (Adam's ``beta2`` is fixed at 0.999); ignored by plain
            gradient descent.
        params_X: Starting x coordinate.
        params_Y: Starting y coordinate.
        num_iterations: Number of optimizer steps to take.

    Raises:
        ValueError: If ``optimizer`` is not one of the supported names.
    """
    # Fail loudly instead of silently plotting a path that never moves.
    if optimizer not in ("Gradient Descent", "RMSProp", "Adam"):
        raise ValueError(f"Unknown optimizer: {optimizer!r}")

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax = plot_surface(ax)

    # Initial parameters
    params = np.array([params_X, params_Y], dtype=float)

    # Seed the path with the starting point so the drawn trajectory begins
    # where the optimizer begins, not one step later.
    path_x = [params[0]]
    path_y = [params[1]]
    path_z = [camel(params[0], params[1])]

    # Optimizer state for RMSProp and Adam
    cache_rmsprop = np.zeros_like(params)
    m_adam = np.zeros_like(params)
    v_adam = np.zeros_like(params)

    # t starts at 1 because Adam's bias correction divides by 1 - beta**t.
    for t in range(1, num_iterations + 1):
        grads = gradient_camel(params)

        if optimizer == "Gradient Descent":
            params = gradient_descent(params, grads, learning_rate)
        elif optimizer == "RMSProp":
            params, cache_rmsprop = rmsprop(
                params, grads, cache_rmsprop, lr=learning_rate, beta=decay_rate
            )
        else:  # "Adam" (validated above)
            params, m_adam, v_adam = adam(
                params,
                grads,
                m_adam,
                v_adam,
                t,
                lr=learning_rate,
                beta1=decay_rate,
                beta2=0.999,
            )

        height = camel(params[0], params[1])

        # Once the run has diverged to inf/NaN every later step is
        # non-finite too; stop rather than collect undrawable points.
        if not (np.all(np.isfinite(params)) and np.isfinite(height)):
            break

        path_x.append(params[0])
        path_y.append(params[1])
        path_z.append(height)

    # Draw the trajectory once. Plotting inside the loop re-drew the whole
    # accumulated path on every iteration and stacked one artist per step.
    ax.plot(path_x, path_y, path_z, "ro-", markersize=5)

    plt.show()


# Interactive controls with ipywidgets
def run_interactive():
    """Build the control widgets and display them with the linked plot.

    Creates a dropdown for the optimizer and sliders for the learning
    rate, the decay rate and the starting point, binds them to
    ``animate_optimizers`` through ``widgets.interactive_output`` and
    displays the controls followed by the output area.
    """
    optimizer_widget = widgets.Dropdown(
        options=["Gradient Descent", "RMSProp", "Adam"],
        value="Gradient Descent",
        description="Optimizer:",
    )

    learning_rate_widget = widgets.FloatSlider(
        value=0.1, min=0.001, max=1.0, step=0.01, description="Learning Rate:"
    )

    # The decay rate must stay strictly below 1. At exactly 1.0, Adam's
    # bias correction divides by 1 - beta1**t == 0 (giving NaN) and
    # RMSProp's cache never leaves zero, so its step is divided by epsilon
    # alone.
    decay_rate_widget = widgets.FloatSlider(
        value=0.9, min=0.5, max=0.99, step=0.01, description="Decay Rate:"
    )
    init_param_X_widget = widgets.FloatSlider(
        value=-2, min=-2, max=2, step=0.01, description="Initial X"
    )
    init_param_Y_widget = widgets.FloatSlider(
        value=-2, min=-2, max=2, step=0.01, description="Initial Y"
    )

    ui = widgets.VBox(
        [
            optimizer_widget,
            learning_rate_widget,
            decay_rate_widget,
            init_param_X_widget,
            init_param_Y_widget,
        ]
    )

    # Keys must match the keyword-argument names of animate_optimizers.
    out = widgets.interactive_output(
        animate_optimizers,
        {
            "optimizer": optimizer_widget,
            "learning_rate": learning_rate_widget,
            "decay_rate": decay_rate_widget,
            "params_X": init_param_X_widget,
            "params_Y": init_param_Y_widget,
        },
    )

    display(ui, out)


# Build the widget controls and display the interactive plot. This call
# runs as soon as the cell/module is executed.
run_interactive()