[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bsc-life/ml4br-ml-course/blob/main/nbs/day_3/02_gradient_descent_demo.ipynb)


# Interactive demo for Gradient Descent

In [None]:
!pip install -q ipywidgets plotly
from google.colab import output
output.enable_custom_widget_manager()

This demo explores how different combinations of surfaces, starting points, and learning rates affect the behavior of gradient descent. We examine three types of functions commonly used in optimization research:

- **Quadratic Surface:** A simple convex function with smooth, symmetric gradients.
- **Rosenbrock Function:** A non-convex function with a narrow, curved valley—commonly used to test optimization algorithms.
- **Saddle Surface:** A function with a saddle point that illustrates challenges in escaping flat or unstable regions.

For each function, we simulate gradient descent from different initial positions and learning rates to highlight behaviors such as fast convergence, overshooting, slow progress, and oscillation. These examples demonstrate the importance of surface geometry and hyperparameter choice in optimization.
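
As a minimal sketch of the update rule this demo simulates, the loop below runs plain gradient descent on the quadratic bowl without any widgets. The helper name `run_gd` and the fixed step count are illustrative assumptions, not part of the notebook's own code.

```python
def run_gd(grad, x0, y0, lr, steps=50):
    """Plain gradient descent: repeatedly step against the gradient."""
    x, y = x0, y0
    path = [(x, y)]
    for _ in range(steps):
        dx, dy = grad(x, y)
        x, y = x - lr * dx, y - lr * dy
        path.append((x, y))
    return path

# Gradient of f(x, y) = x**2 + y**2
quad_grad = lambda x, y: (2 * x, 2 * y)

path = run_gd(quad_grad, x0=1.0, y0=-1.0, lr=0.1)
final_x, final_y = path[-1]
print(f"after {len(path) - 1} steps: ({final_x:.2e}, {final_y:.2e})")
```

Changing `lr` or the starting point in this sketch is the same experiment the sliders perform interactively.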


In [1]:
import numpy as np
import plotly.graph_objs as go
from ipywidgets import FloatSlider, VBox, Dropdown, HBox, BoundedFloatText
from IPython.display import display

In [None]:
# Define surfaces
def quadratic(x, y):
    return x**2 + y**2

def rosenbrock(x, y):
    return (1 - x)**2 + 100*(y - x**2)**2

def saddle(x, y):
    return x**2 - y**2

# Gradient and surface definitions
surfaces = {
    "Quadratic": {
        "func": quadratic,
        "grad": lambda x, y: (2*x, 2*y),
        "range": (-2, 2),
        "lr": (0.001, 0.2)
    },
    "Rosenbrock": {
        "func": rosenbrock,
        "grad": lambda x, y: (
            -2*(1 - x) - 400*x*(y - x**2),
            200*(y - x**2)
        ),
        "range": (-1.5, 1.5),
        "lr": (0.0001, 0.05)
    },
    "Saddle": {
        "func": saddle,
        "grad": lambda x, y: (2*x, -2*y),
        "range": (-2, 2),
        "lr": (0.001, 0.1)
    }
}

# Optimization algorithms
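def gradient_descent_step(x, y, grad_func, lr, **kwargs):
    # Plain gradient descent: step against the gradient, scaled by the learning rate
    dx, dy = grad_func(x, y)
    return x - lr * dx, y - lr * dy

def momentum_step(x, y, grad_func, lr, v_prev, beta=0.9, **kwargs):
    # Momentum as an exponential moving average (EMA) of gradients:
    # v = beta * v_prev + (1 - beta) * grad, then step along -v
    dx, dy = grad_func(x, y)
    v_x = beta * v_prev[0] + (1 - beta) * dx
    v_y = beta * v_prev[1] + (1 - beta) * dy
    x_new = x - lr * v_x
    y_new = y - lr * v_y
    # Return the updated velocity so the caller can pass it to the next step
    return x_new, y_new, (v_x, v_y)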

optimizers = {
    "Gradient Descent": gradient_descent_step,
    "Momentum": momentum_step
}

# Plot generation
def plot_surface_with_path(func, grad, optimizer_name, x0, y0, lr, surface_range, lr_range, max_steps=1000, tol=1e-6):
    x = np.linspace(surface_range[0], surface_range[1], 100)
    y = np.linspace(surface_range[0], surface_range[1], 100)
    X, Y = np.meshgrid(x, y)
    Z = func(X, Y)

    path_x, path_y, path_z = [x0], [y0], [func(x0, y0)]
    x_curr, y_curr = x0, y0
    optimizer = optimizers[optimizer_name]
    v = (0, 0)

    for _ in range(max_steps):
        if optimizer_name == "Momentum":
            x_new, y_new, v = optimizer(x_curr, y_curr, grad, lr, v_prev=v)
        else:
            x_new, y_new = optimizer(x_curr, y_curr, grad, lr)
        z_new = func(x_new, y_new)

        path_x.append(x_new)
        path_y.append(y_new)
        path_z.append(z_new)

        if abs(z_new - path_z[-2]) < tol:
            break

        x_curr, y_curr = x_new, y_new

        # Break on divergence
        if abs(z_new) > 1e6:
            break

    surface = go.Surface(x=X, y=Y, z=Z, colorscale='Viridis', opacity=0.7)
    path = go.Scatter3d(x=path_x, y=path_y, z=path_z,
                        mode='lines+markers',
                        marker=dict(size=3, color='red'),
                        line=dict(color='red', width=4))
    layout = go.Layout(
        scene=dict(
            xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
            camera=dict(eye=dict(x=1.5, y=1.5, z=1.2))
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    fig = go.Figure(data=[surface, path], layout=layout)
    return fig

# UI widgets
surface_dropdown = Dropdown(options=list(surfaces.keys()), value='Quadratic', description='Surface')
optimizer_dropdown = Dropdown(options=list(optimizers.keys()), value='Gradient Descent', description='Optimizer')

# Initial values match the Quadratic defaults set in on_surface_change;
# without them the widgets start at (0, 0), which is already the minimum
x0_slider = FloatSlider(description='x₀', value=1.0, min=-2, max=2, step=0.001, readout_format='.4f')
y0_slider = FloatSlider(description='y₀', value=-1.0, min=-2, max=2, step=0.001, readout_format='.4f')
lr_slider = FloatSlider(description='lr', value=0.1, min=0.001, max=0.2, step=0.0001, readout_format='.4f')

x0_text = BoundedFloatText(description='x₀', value=1.0, min=-2, max=2, step=0.001)
y0_text = BoundedFloatText(description='y₀', value=-1.0, min=-2, max=2, step=0.001)
lr_text = BoundedFloatText(description='lr', value=0.1, min=0.001, max=0.2, step=0.0001)


# Sync sliders and text inputs
def link_slider_text(slider, text):
    slider.observe(lambda c: setattr(text, 'value', c['new']), names='value')
    text.observe(lambda c: setattr(slider, 'value', c['new']), names='value')

link_slider_text(x0_slider, x0_text)
link_slider_text(y0_slider, y0_text)
link_slider_text(lr_slider, lr_text)

# Output area
import ipywidgets as widgets
plot_output = widgets.Output()

def update_plot(*args):
    surface_info = surfaces[surface_dropdown.value]
    func = surface_info["func"]
    grad = surface_info["grad"]
    rng = surface_info["range"]
    lr_rng = surface_info["lr"]

    x0_value = x0_text.value
    y0_value = y0_text.value
    lr_value = lr_text.value

    fig = plot_surface_with_path(
        func=func,
        grad=grad,
        optimizer_name=optimizer_dropdown.value,
        x0=x0_value,
        y0=y0_value,
        lr=lr_value,
        surface_range=rng,
        lr_range=lr_rng
    )

    plot_output.clear_output(wait=True)
    with plot_output:
        display(fig)


# Reactivity
x0_slider.observe(update_plot, names='value')
y0_slider.observe(update_plot, names='value')
lr_slider.observe(update_plot, names='value')
optimizer_dropdown.observe(update_plot, names='value')

# Reset on surface change
def get_surface_params(surface_name):
    if surface_name == "Quadratic":
        return -2, 2, -2, 2, 0.001, 0.2
    elif surface_name == "Saddle":
        return -2, 2, -2, 2, 0.001, 0.1
    elif surface_name == "Rosenbrock":
        return -2, 2, -1, 3, 0.0001, 0.01
    elif surface_name == "Rastrigin":
        return -5.12, 5.12, -5.12, 5.12, 0.0005, 0.05
    else:
        # default fallback
        return -2, 2, -2, 2, 0.001, 0.1

# Reset sliders and update plot when surface changes
def on_surface_change(change):
    surface_name = change['new']
    x_min, x_max, y_min, y_max, lr_min, lr_max = get_surface_params(surface_name)

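    # Update slider and text-box ranges together; update_plot reads the text boxes,
    # so their bounds must admit the same values as the sliders
    for w in (x0_slider, x0_text):
        w.min, w.max = x_min, x_max
    for w in (y0_slider, y0_text):
        w.min, w.max = y_min, y_max
    for w in (lr_slider, lr_text):
        w.min, w.max = lr_min, lr_max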

    # Set default values depending on surface
    if surface_name == "Rosenbrock":
        x0_slider.value = -1.2
        y0_slider.value = 1.0
        lr_slider.value = 0.002
    elif surface_name == "Quadratic":
        x0_slider.value = 1.0
        y0_slider.value = -1.0
        lr_slider.value = 0.1
    elif surface_name == "Saddle":
        x0_slider.value = -1.0
        y0_slider.value = 1.0
        lr_slider.value = 0.05
    elif surface_name == "Rastrigin":
        x0_slider.value = 0.5
        y0_slider.value = 0.5
        lr_slider.value = 0.01
    else:
        x0_slider.value = (x_min + x_max) / 2
        y0_slider.value = (y_min + y_max) / 2
        lr_slider.value = (lr_min + lr_max) / 2

    update_plot()  # Redraw the plot with new values

surface_dropdown.observe(on_surface_change, names='value')

# Layout
ui = VBox([
    HBox([surface_dropdown, optimizer_dropdown]),
    HBox([x0_slider, x0_text]),
    HBox([y0_slider, y0_text]),
    HBox([lr_slider, lr_text]),
    plot_output
])

display(ui)
update_plot()
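
The difference between the two optimizers can be sketched on a one-dimensional slice of the quadratic surface, which is enough because `x² + y²` treats each coordinate independently. The function names `gd_path` and `momentum_path` are hypothetical helpers for illustration; the momentum update mirrors the exponential-moving-average form used in `momentum_step` above, with the same default `beta=0.9`.

```python
def gd_path(lr, steps, x0=1.0):
    """Plain gradient descent on f(x) = x**2 (gradient 2x)."""
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] - lr * 2 * xs[-1])
    return xs

def momentum_path(lr, steps, beta=0.9, x0=1.0):
    """EMA-style momentum: v = beta*v + (1 - beta)*grad, then x -= lr*v."""
    x, v = x0, 0.0
    xs = [x]
    for _ in range(steps):
        v = beta * v + (1 - beta) * 2 * x
        x = x - lr * v
        xs.append(x)
    return xs

gd = gd_path(lr=0.1, steps=200)
mom = momentum_path(lr=0.1, steps=200)
print("GD crosses the minimum:      ", min(gd) < 0)
print("Momentum crosses the minimum:", min(mom) < 0)
```

With these settings plain gradient descent approaches the minimum from one side, while the accumulated velocity carries the momentum iterate past it before it settles.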



## ✅ Quadratic Surface (`f(x, y) = x² + y²`)

| Starting Point (x₀, y₀) | Learning Rate | Behavior                 | Explanation                                                                 |
|-------------------------|---------------|--------------------------|-----------------------------------------------------------------------------|
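| (1.0, -1.0)             | 0.1           | Fast convergence         | Convex bowl with its minimum at the origin. Each step multiplies both coordinates by 1 - 2·lr = 0.8, so the path heads straight to the minimum and converges quickly. |
| (1.5, 1.5)              | 0.2           | Faster convergence, no overshoot | The slider maximum for this surface. Each step multiplies both coordinates by 0.6, so plain gradient descent converges faster than at 0.1 without crossing the minimum. Overshooting would need lr > 0.5 and divergence lr > 1, both outside the slider range; the Momentum optimizer can still overshoot here. |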
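
For plain gradient descent on this surface, the update `x - lr * 2x` multiplies each coordinate by `1 - 2·lr` at every step, so the behaviour at a given learning rate can be read directly from that factor. The sketch below classifies a few learning rates; the helper name `quadratic_gd_factor` and the two larger rates (0.6 and 1.1, which lie outside the demo's slider range) are illustrative assumptions.

```python
def quadratic_gd_factor(lr):
    """Per-step multiplier for each coordinate of GD on f(x, y) = x**2 + y**2."""
    return 1 - 2 * lr

for lr in (0.1, 0.2, 0.6, 1.1):
    factor = quadratic_gd_factor(lr)
    if abs(factor) >= 1:
        regime = "does not converge"
    elif factor < 0:
        regime = "converges while overshooting (sign flips each step)"
    else:
        regime = "converges monotonically"
    print(f"lr={lr}: factor={factor:+.2f} -> {regime}")
```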

---

## 🌀 Rosenbrock Function (`f(x, y) = (1 - x)² + 100(y - x²)²`)

| Starting Point (x₀, y₀) | Learning Rate | Behavior                   | Explanation                                                                 |
|-------------------------|---------------|----------------------------|-----------------------------------------------------------------------------|
| (-1.2, 1.0)             | 0.002         | Slow convergence along valley | Classical starting point. Progress is slow because the valley is narrow and curved: steps zigzag across the steep walls while advancing little along the flat floor. Good for visualizing zigzag motion. |
| (-1.2, 1.0)             | 0.0005        | Even slower, more stable   | Smaller steps reduce the zigzag, but the optimizer needs many more iterations to travel along the valley. Highlights the need for adaptive or momentum-based methods. |
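
One way to quantify "narrow valley" is the Hessian at the minimum (1, 1): the ratio of its largest to smallest eigenvalue measures how much steeper the walls are than the floor, and `2 / λ_max` is the step size above which plain gradient descent becomes locally unstable. The sketch below computes both with NumPy. The function name `rosenbrock_hessian` is a hypothetical helper, and this linearised analysis only describes behaviour close to the minimum, not along the whole path.

```python
import numpy as np

def rosenbrock_hessian(x, y):
    """Analytic Hessian of f(x, y) = (1 - x)**2 + 100*(y - x**2)**2."""
    return np.array([
        [2 - 400 * (y - x**2) + 800 * x**2, -400 * x],
        [-400 * x, 200.0],
    ])

# eigvalsh returns the eigenvalues of a symmetric matrix in ascending order
lam_min, lam_max = np.linalg.eigvalsh(rosenbrock_hessian(1.0, 1.0))

print(f"eigenvalues at (1, 1): {lam_min:.3f}, {lam_max:.1f}")
print(f"condition number: {lam_max / lam_min:.0f}")
print(f"local stability limit 2/lam_max: {2 / lam_max:.5f}")
```

The limit comes out close to 0.002, which suggests why that learning rate zigzags visibly while 0.0005 looks much calmer.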

---

## ♾️ Saddle Surface (`f(x, y) = x² - y²`)

| Starting Point (x₀, y₀) | Learning Rate | Behavior             | Explanation                                                                 |
|-------------------------|---------------|----------------------|-----------------------------------------------------------------------------|
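| (-1.0, 1.0)             | 0.05          | Converges in x, escapes along y | Each step multiplies x by 1 - 2·lr = 0.9 and y by 1 + 2·lr = 1.1, so x shrinks toward 0 while the magnitude of y grows steadily, without oscillation. The function is unbounded below, so the run ends only when the divergence guard (z magnitude above 1e6) triggers. |
| (0.5, 0.5)              | 0.01          | Same escape, more slowly | x shrinks by a factor of 0.98 per step and y grows by 1.02, so the path drifts toward the y-axis (x = 0) and away from the saddle. A lower learning rate slows the escape but does not prevent it. Demonstrates how gradient descent behaves near a saddle point: it is attracted along one axis and repelled along the other. |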
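
The saddle's gradient is `(2x, -2y)`, so a plain gradient descent step scales `x` by `1 - 2·lr` and `y` by `1 + 2·lr`. The sketch below counts how many steps each setting needs before the function value exceeds the same `1e6` magnitude the demo uses as its divergence guard. The helper name `saddle_steps_until_guard` is an assumption for illustration, and the convergence-tolerance check from the demo is omitted here.

```python
def saddle_steps_until_guard(x0, y0, lr, z_limit=1e6, max_steps=1000):
    """Run plain GD on f(x, y) = x**2 - y**2 and count steps until |f| > z_limit."""
    x, y = x0, y0
    for step in range(1, max_steps + 1):
        # Gradient is (2x, -2y): x contracts, y expands
        x, y = x - lr * 2 * x, y + lr * 2 * y
        if abs(x**2 - y**2) > z_limit:
            return step, x, y
    return None, x, y

fast_steps, fast_x, fast_y = saddle_steps_until_guard(-1.0, 1.0, lr=0.05)
slow_steps, slow_x, slow_y = saddle_steps_until_guard(0.5, 0.5, lr=0.01)
print(f"lr=0.05: guard hit after {fast_steps} steps, x={fast_x:.2e}, y={fast_y:.2e}")
print(f"lr=0.01: guard hit after {slow_steps} steps, x={slow_x:.2e}, y={slow_y:.2e}")
```

Both runs end at the guard; the smaller learning rate only delays it.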
