In [1]:
%matplotlib inline

In [2]:
import jax
import numpy as np
import seaborn as sns
import jax.numpy as jnp

import matplotlib.pyplot as plt
import ipywidgets as widgets

from IPython.display import display
from ipywidgets import VBox, HBox, interactive_output

from jax import random, grad

In [3]:
def f_xy_selection(name):
    """Return the two-variable test function registered under ``name``.

    Every returned callable has the signature ``f(x, y)`` and is built from
    arithmetic and ``jax.numpy`` operations, so it can be evaluated on scalars
    or arrays and differentiated with ``jax.grad``.

    Args:
        name: Display name of the function; must match one of the keys of the
            table below exactly.

    Returns:
        The callable ``f(x, y)`` associated with ``name``.

    Raises:
        ValueError: If ``name`` is not a known function name.
    """
    functions = {
        "sqrt(abs(sin(x) + cos(y)))":
            lambda x, y: jnp.sqrt(jnp.abs(jnp.sin(x) + jnp.cos(y))),
        "x^2 + y^2":
            lambda x, y: x**2 + y**2,
        "x^2 + y^2 + 4":
            lambda x, y: x**2 + y**2 + 4,
        "Three-Hump Camel Function":
            lambda x, y: 2 * x**2 - 1.05 * x**4 + x**6 / 6 + x * y + y**2,
        "Matyas Function":
            lambda x, y: 0.26 * (x**2 + y**2) - 0.48 * x * y,
        "Bukin Function No. 6":
            lambda x, y: (
                100 * jnp.sqrt(jnp.abs(y - 0.01 * x**2))
                + 0.01 * jnp.abs(x + 10)
            ),
        "Ackley's Function":
            lambda x, y: (
                -20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x**2 + y**2)))
                - jnp.exp(0.5 * (jnp.cos(2 * jnp.pi * x) + jnp.cos(2 * jnp.pi * y)))
                + jnp.e
                + 20
            ),
        # Easom is -cos(x) * cos(y) * exp(...): both factors are cosines, which
        # puts the global minimum of -1 at (pi, pi).
        "Easom Function":
            lambda x, y: (
                -jnp.cos(x) * jnp.cos(y)
                * jnp.exp(-((x - jnp.pi)**2 + (y - jnp.pi)**2))
            ),
        "McCormick Function":
            lambda x, y: jnp.sin(x + y) + (x - y)**2 - 1.5 * x + 2.5 * y + 1,
        "Drop-Wave Function":
            lambda x, y: (
                -1 * (1 + jnp.cos(12 * jnp.sqrt(x**2 + y**2)))
                / (0.5 * (x**2 + y**2) + 2)
            ),
    }

    try:
        return functions[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which the lookup cannot handle;
        # report them the same way as any other unknown name.
        raise ValueError("Function not found") from None


def update(lr, x, y, f_fn):
    """Take one gradient-descent step on ``f_fn`` starting from ``(x, y)``.

    Args:
        lr: Step size multiplied with each partial derivative.
        x: Current first coordinate.
        y: Current second coordinate.
        f_fn: Callable ``f(x, y)`` to differentiate with ``jax.grad``.

    Returns:
        Tuple ``(new_x, new_y)`` where each coordinate has been moved against
        its partial derivative, scaled by ``lr``.
    """
    # Differentiate with respect to both positional arguments in one call.
    partials = grad(f_fn, argnums=(0, 1))
    grad_x, grad_y = partials(x, y)
    return x - lr * grad_x, y - lr * grad_y


def show_gradient_desc_process(
        lr,
        iterations,
        x_range: tuple,
        y_range: tuple,
        number_of_points: int = 500,
        seed: int = 69,
        function_name: str = "sqrt(abs(sin(x) + cos(y)))"
):
    """Run gradient descent on a test function and plot the visited points.

    The function selected by ``function_name`` is evaluated on a regular grid
    spanning ``x_range`` x ``y_range`` and drawn as a 3D surface (left) and a
    filled contour plot (right). A start point is sampled uniformly inside the
    ranges from ``seed``; ``iterations`` calls to ``update`` are then applied,
    clipping the point back into the plotted ranges after every step. The
    visited points are overlaid on both panels with the start and end points
    highlighted.

    Args:
        lr: Step size passed to ``update``.
        iterations: Number of gradient steps to take.
        x_range: ``(min, max)`` bounds for the x axis and the x coordinate.
        y_range: ``(min, max)`` bounds for the y axis and the y coordinate.
        number_of_points: Grid resolution along each axis.
        seed: Seed for the JAX PRNG key used to draw the start point.
        function_name: Name understood by ``f_xy_selection``.

    Raises:
        ValueError: Propagated from ``f_xy_selection`` for unknown names.
    """
    selected_fn = f_xy_selection(function_name)

    # Grid on which the function is rendered.
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_range[0], x_range[1], number_of_points),
        np.linspace(y_range[0], y_range[1], number_of_points),
    )
    grid_z = selected_fn(grid_x, grid_y)

    # Independent keys for the two coordinates of the random start point.
    x_key, y_key = random.split(random.key(seed))
    point_x = random.uniform(x_key, minval=x_range[0], maxval=x_range[1])
    point_y = random.uniform(y_key, minval=y_range[0], maxval=y_range[1])

    path_x, path_y = [point_x], [point_y]
    path_f = [selected_fn(point_x, point_y)]

    for _ in range(iterations):
        point_x, point_y = update(lr, point_x, point_y, selected_fn)

        # Keep the iterate inside the plotted window.
        point_x = jnp.clip(point_x, x_range[0], x_range[1])
        point_y = jnp.clip(point_y, y_range[0], y_range[1])

        path_x.append(point_x)
        path_y.append(point_y)
        path_f.append(selected_fn(point_x, point_y))

    # Plain 1-D float arrays are what matplotlib ultimately works with.
    path_x = np.asarray(path_x, dtype=float)
    path_y = np.asarray(path_y, dtype=float)
    path_f = np.asarray(path_f, dtype=float)

    animation_fig = plt.figure(figsize=(16, 6))
    surface_ax = animation_fig.add_subplot(121, projection='3d')
    contour_ax = animation_fig.add_subplot(122)

    surface_ax.plot_surface(
        grid_x, grid_y, grid_z, linewidth=0, cmap="viridis", rstride=1, cstride=1
    )
    contour_ax.contourf(grid_x, grid_y, grid_z, cmap="viridis")

    # Draw the same trajectory once on each panel so the two views agree.
    _plot_descent_path(surface_ax, path_x, path_y, path_f)
    _plot_descent_path(contour_ax, path_x, path_y)

    animation_fig.suptitle(function_name)
    surface_ax.set_xlabel("x")
    surface_ax.set_ylabel("y")
    surface_ax.set_zlabel("f(x, y)")
    contour_ax.set_xlabel("x")
    contour_ax.set_ylabel("y")

    plt.show()


def _plot_descent_path(ax, *coords):
    """Draw the full descent path and its start/end markers on ``ax``.

    ``coords`` are equally long 1-D arrays, one per axis dimension (x, y and,
    for a 3D axes, z). The start and end markers get a higher ``zorder`` so
    they are not covered by the path markers on the 2D panel.
    """
    ax.plot(
        *(c[:1] for c in coords),
        color="cyan", marker="o", label="start point", zorder=3,
    )
    ax.plot(
        *coords,
        color="orange", marker="x", label="gradient point", zorder=2,
    )
    ax.plot(
        *(c[-1:] for c in coords),
        color="red", marker="o", label="end point", zorder=3,
    )
    ax.legend()

In [5]:
# Dropdown listing every function name that f_xy_selection can resolve.
# The strings must match the names checked there exactly.
CHOOSE_FUNCTION_INPUT = widgets.Dropdown(
    options=[
        "sqrt(abs(sin(x) + cos(y)))",
        "x^2 + y^2",
        "x^2 + y^2 + 4",
        "Three-Hump Camel Function",
        "Matyas Function",
        "Bukin Function No. 6",
        "Ackley's Function",
        "Easom Function",
        "McCormick Function",
        "Drop-Wave Function"
    ],
    value="Ackley's Function",
    description='Function:',
    disabled=False
)


# Step size for gradient descent, entered as a bounded number.
LEARNING_RATE_INPUT = widgets.BoundedFloatText(
    description='Learning Rate:',
    value=1e-3,
    min=1e-6,
    max=100.0,
    step=1e-3,
    disabled=False,
    style=dict(description_width='initial'),
)

# Number of gradient steps; the value is only reported when the handle is released.
ITERATIONS_SLIDER = widgets.IntSlider(
    description='Iterations:',
    value=500,
    min=10,
    max=5000,
    step=10,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=dict(description_width='initial'),
)


def _make_axis_range_slider(label):
    """Build a (min, max) range slider for one axis, labelled ``label``."""
    return widgets.FloatRangeSlider(
        description=label,
        value=[-10, 10],
        min=-100,
        max=100,
        step=1,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
        style=dict(description_width='initial'),
    )


# Both axes use identical limits, step and formatting; only the label differs.
X_MIN_MAX_RANGE_SLIDER = _make_axis_range_slider('X Range:')

Y_MIN_MAX_RANGE_SLIDER = _make_axis_range_slider('Y Range:')

# Grid resolution per axis for the rendered surface.
# continuous_update=False matches the other sliders: the value is only sent
# when the handle is released, so the plot is not recomputed for every
# intermediate position while dragging.
NUMBER_OF_POINT_TO_SAMPLE_SLIDER = widgets.IntSlider(
    value=500,
    min=100,
    max=1000,
    step=10,
    description="Sample points:",
    continuous_update=False,
    style={'description_width': 'initial'}
)

# Seed for the random start point of the descent.
SEED_SLIDER = widgets.IntSlider(
    value=42,
    min=1,
    max=10000,
    step=1,
    description="Seed:",
    continuous_update=False,
    style={'description_width': 'initial'}
)

# Bind each keyword argument of show_gradient_desc_process to the widget
# that supplies its value.
gradient_display_out = interactive_output(
    show_gradient_desc_process,
    dict(
        lr=LEARNING_RATE_INPUT,
        iterations=ITERATIONS_SLIDER,
        x_range=X_MIN_MAX_RANGE_SLIDER,
        y_range=Y_MIN_MAX_RANGE_SLIDER,
        number_of_points=NUMBER_OF_POINT_TO_SAMPLE_SLIDER,
        seed=SEED_SLIDER,
        function_name=CHOOSE_FUNCTION_INPUT,
    ),
)

# Two columns of controls side by side, followed by the plot output.
display(
    HBox(children=[
        VBox(
            children=[
                CHOOSE_FUNCTION_INPUT,
                LEARNING_RATE_INPUT,
                ITERATIONS_SLIDER,
            ],
            layout=widgets.Layout(width='100%'),
        ),
        VBox(
            children=[
                NUMBER_OF_POINT_TO_SAMPLE_SLIDER,
                X_MIN_MAX_RANGE_SLIDER,
                Y_MIN_MAX_RANGE_SLIDER,
                SEED_SLIDER,
            ],
            layout=widgets.Layout(width='100%'),
        ),
    ]),
    gradient_display_out,
)

HBox(children=(VBox(children=(Dropdown(description='Function:', index=6, options=('sqrt(abs(sin(x) + cos(y)))'…

Output()