# Gradient Descent and Convexity: Advanced Visual Storytelling

This notebook upgrades the visuals with interactive Plotly charts, animated descent paths, and a 3D loss surface. The goal is to make convexity, local minima, and learning-rate stability feel tangible.


## Setup
We use Plotly for interactive, high-resolution visuals. If you open this in Jupyter, the figures are fully interactive (pan, zoom, rotate).


In [10]:
import numpy as np
import plotly.graph_objects as go
import plotly.figure_factory as ff
import plotly.io as pio
from plotly.subplots import make_subplots

pio.templates.default = "plotly_white"

from IPython.display import HTML, display

def show_plot(fig, use_cdn=True):
    """Render a Plotly figure reliably in classic Jupyter Notebook.

    The figure is serialized to an HTML fragment (``full_html=False``) and
    handed to IPython's ``display``.

    Parameters
    ----------
    fig : object
        Anything exposing ``to_html(include_plotlyjs=..., full_html=...)``;
        the notebook passes Plotly figures.
    use_cdn : bool, default True
        When truthy, ``include_plotlyjs="cdn"`` is passed to ``to_html``;
        otherwise ``include_plotlyjs=True`` is passed.
    """
    plotlyjs_mode = "cdn" if use_cdn else True
    fragment = fig.to_html(include_plotlyjs=plotlyjs_mode, full_html=False)
    display(HTML(fragment))

# Shared hex colors for the figures, keyed by the role each color plays
# (curve type, descent path, or neutral background curve).
PALETTE = dict(
    convex="#1f77b4",
    nonconvex="#d62728",
    path_left="#ff7f0e",
    path_right="#2ca02c",
    grid="#7f7f7f",
)

def f_convex(x):
    """Convex example loss: return ``x`` squared.

    Only the ``**`` operator is applied, so the result has whatever type
    ``x ** 2`` produces for the given input.
    """
    squared = x**2
    return squared

def f_non_convex(x):
    """Non-convex example loss ``x^4 - 2x^2 + 0.2x``.

    The terms are combined left to right as (quartic - quadratic) + linear.
    """
    quartic_term = x**4
    quadratic_term = 2 * x**2
    linear_term = 0.2 * x
    return quartic_term - quadratic_term + linear_term

def d_non_convex(x):
    """Return ``4x^3 - 4x + 0.2``, the derivative of ``x^4 - 2x^2 + 0.2x``.

    The terms are combined left to right as (cubic - linear) + constant.
    """
    cubic_term = 4 * x**3
    linear_term = 4 * x
    return cubic_term - linear_term + 0.2

def gradient_descent_1d(start_x, learning_rate, n_steps, grad_func):
    """Run fixed-step gradient descent on a single parameter.

    Parameters
    ----------
    start_x : initial parameter value.
    learning_rate : multiplier applied to the gradient at every step.
    n_steps : number of updates to perform.
    grad_func : callable taking the current value and returning the gradient.

    Returns
    -------
    numpy.ndarray
        The visited values in order: the start followed by one entry per
        update, i.e. ``n_steps + 1`` entries.
    """
    trajectory = [start_x]
    for _ in range(n_steps):
        current = trajectory[-1]
        # Step against the gradient, scaled by the learning rate.
        trajectory.append(current - learning_rate * grad_func(current))
    return np.array(trajectory)

def f_2d(x, y):
    """Two-parameter quadratic loss ``x^2 + 1.5 * y^2``."""
    x_term = x**2
    y_term = 1.5 * y**2
    return x_term + y_term

def grad_f_2d(x, y):
    """Gradient of ``x^2 + 1.5 * y^2``, returned as the tuple ``(2x, 3y)``."""
    df_dx = 2 * x
    df_dy = 3 * y
    return df_dx, df_dy

def gradient_descent_2d(start_x, start_y, lr, steps, max_radius=6.0, grad_func=None):
    """Run fixed-step gradient descent on two parameters.

    Parameters
    ----------
    start_x, start_y : initial parameter values.
    lr : learning rate applied to both gradient components.
    steps : maximum number of updates to perform.
    max_radius : once a point lies farther than this from the origin, that
        point is still recorded and then the run stops.
    grad_func : optional callable ``(x, y) -> (gx, gy)``. When omitted,
        the notebook's ``grad_f_2d`` is used, which is the behavior this
        function had before the parameter existed. Mirrors the
        ``grad_func`` argument of ``gradient_descent_1d``.

    Returns
    -------
    numpy.ndarray
        The visited ``(x, y)`` points in order, starting with the start
        point; at most ``steps + 1`` entries.
    """
    if grad_func is None:
        grad_func = grad_f_2d

    path = [(start_x, start_y)]
    x, y = start_x, start_y
    for _ in range(steps):
        gx, gy = grad_func(x, y)
        # Rebind instead of using ``-=`` so that array-valued start points
        # (already stored in ``path[0]``) are never modified in place.
        x = x - lr * gx
        y = y - lr * gy
        path.append((x, y))
        if x**2 + y**2 > max_radius**2:
            break
    return np.array(path)


## 1. Convex vs. Non-Convex Landscapes (1D)
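A convex function has a single basin, so any local minimum is also the global minimum. A non-convex function can have several basins, each with its own local minimum, and only the deepest one is global.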


In [11]:
# Sample both 1D losses on a shared grid.
x = np.linspace(-2.2, 2.2, 400)
y_convex = f_convex(x)
y_non_convex = f_non_convex(x)

# Stationary points of the non-convex loss are the roots of its derivative
# 4x^3 - 4x + 0.2; a real root is a minimum where the second derivative
# 12x^2 - 4 is positive.
coeffs = [4, 0, -4, 0.2]
roots = np.roots(coeffs)
real_roots = roots[np.isclose(roots.imag, 0)].real
mins = sorted(root for root in real_roots if 12 * root**2 - 4 > 0)
min_values = [(root, f_non_convex(root)) for root in mins]
global_min = min(min_values, key=lambda item: item[1]) if min_values else None

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Convex: single basin", "Non-Convex: multiple basins"),
)

# One curve per subplot: (sampled values, line color, trace name, column).
curve_specs = [
    (y_convex, PALETTE["convex"], "f(x)=x^2", 1),
    (y_non_convex, PALETTE["nonconvex"], "f(x)=x^4-2x^2+0.2x", 2),
]
for curve, curve_color, curve_name, col_index in curve_specs:
    fig.add_trace(
        go.Scatter(x=x, y=curve, mode="lines", line=dict(color=curve_color, width=3), name=curve_name),
        row=1, col=col_index,
    )

# Labelled minimum markers: (position, loss, label, color, column).
# The convex curve has its single minimum at 0; the non-convex minima were
# computed above and the lowest one is labelled as global.
marker_specs = [(0, f_convex(0), "Global min", PALETTE["convex"], 1)]
for root, loss in min_values:
    is_global = global_min and np.isclose(root, global_min[0])
    marker_specs.append(
        (root, loss, "Global min" if is_global else "Local min", PALETTE["nonconvex"], 2)
    )

for position, loss, label, marker_color, col_index in marker_specs:
    fig.add_trace(
        go.Scatter(x=[position], y=[loss], mode="markers+text", text=[label],
                   textposition="top center", marker=dict(size=10, color=marker_color)),
        row=1, col=col_index,
    )

fig.update_xaxes(title_text="Parameter theta")
fig.update_yaxes(title_text="Loss J(theta)")
fig.update_layout(height=420, width=1000, showlegend=False)
show_plot(fig)


## 2. Animated Gradient Descent on a Non-Convex Curve
Two hikers start from different points. Each step is a local slope decision, which can lock the path into different valleys.


In [12]:
# Descent settings and animation speed (milliseconds per frame).
learning_rate = 0.05
steps = 10
frame_duration_ms = 500

# Run the same descent from two mirrored starting points.
path_left, path_right = (
    gradient_descent_1d(start, learning_rate, steps, d_non_convex)
    for start in (-1.8, 1.8)
)

# Per-hiker drawing data shared by the dotted trails and the moving markers.
hikers = [
    {"path": path_left, "color": PALETTE["path_left"], "trail": "Left path", "marker": "Hiker left"},
    {"path": path_right, "color": PALETTE["path_right"], "trail": "Right path", "marker": "Hiker right"},
]

fig = go.Figure()

# Trace 0: the non-convex loss curve sampled in the previous section.
fig.add_trace(
    go.Scatter(x=x, y=y_non_convex, mode="lines", line=dict(color=PALETTE["grid"], width=2), name="Loss surface")
)

# Traces 1-2: each complete trajectory as a dotted line.
for hiker in hikers:
    fig.add_trace(
        go.Scatter(x=hiker["path"], y=f_non_convex(hiker["path"]), mode="lines",
                   line=dict(color=hiker["color"], width=2, dash="dot"), name=hiker["trail"])
    )

# Traces 3-4: one marker per hiker, initially at the starting point.
# These are the only traces the animation frames update.
for hiker in hikers:
    start_point = hiker["path"][0]
    fig.add_trace(
        go.Scatter(x=[start_point], y=[f_non_convex(start_point)], mode="markers",
                   marker=dict(size=10, color=hiker["color"]), name=hiker["marker"])
    )

# One frame per descent step; frame t moves both markers to their t-th position.
frames = [
    go.Frame(
        data=[
            go.Scatter(x=[trail[t]], y=[f_non_convex(trail[t])])
            for trail in (path_left, path_right)
        ],
        traces=[3, 4],
        name=str(t),
    )
    for t in range(steps + 1)
]
fig.frames = frames

play_button = {
    "label": "Play",
    "method": "animate",
    "args": [None, {"frame": {"duration": frame_duration_ms, "redraw": True}, "fromcurrent": True}],
}
pause_button = {
    "label": "Pause",
    "method": "animate",
    "args": [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
}
# Each slider step jumps straight to the frame whose name matches its label.
slider_steps = [
    {
        "args": [[str(t)], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
        "label": str(t),
        "method": "animate",
    }
    for t in range(steps + 1)
]

fig.update_layout(
    title="Gradient Descent: two starting points, two outcomes",
    xaxis_title="Parameter theta",
    yaxis_title="Loss J(theta)",
    height=500,
    updatemenus=[
        {
            "type": "buttons",
            "showactive": False,
            "x": 0.05,
            "y": 1.15,
            "buttons": [play_button, pause_button],
        }
    ],
    sliders=[{"currentvalue": {"prefix": "Step: "}, "steps": slider_steps}],
)
show_plot(fig)


## 3. 2D Contours + Gradient Field
We view a 2D convex function from above. Contours are level sets; arrows show the descent direction (negative gradient).


In [13]:
# Dense grid for the contour plot of the 2D loss.
grid = np.linspace(-3, 3, 140)
X, Y = np.meshgrid(grid, grid)
Z = f_2d(X, Y)

# Descent trajectory drawn on top of the contours.
path = gradient_descent_2d(-2.5, 2.0, lr=0.12, steps=25)

# Coarser grid for the arrows; the gradient is negated below so the arrows
# point in the descent direction.
q = np.linspace(-3, 3, 15)
Xq, Yq = np.meshgrid(q, q)
U, V = grad_f_2d(Xq, Yq)

quiver_fig = ff.create_quiver(
    Xq,
    Yq,
    -U,
    -V,
    scale=0.25,
    arrow_scale=0.3,
    line=dict(color="rgba(50,50,50,0.45)"),
)

contour_trace = go.Contour(
    x=grid, y=grid, z=Z,
    colorscale="Greys", contours=dict(showlabels=False), opacity=0.55, showscale=False,
)
path_trace = go.Scatter(
    x=path[:, 0], y=path[:, 1], mode="lines+markers",
    line=dict(color=PALETTE["path_right"], width=3),
    marker=dict(size=6, color=PALETTE["path_right"]),
    name="Descent path",
)

# Layering order: contours, then the quiver arrows, then the path on top.
fig = go.Figure(data=[contour_trace])
fig.add_traces(quiver_fig.data)
fig.add_trace(path_trace)

layout_options = dict(
    title="Contours with Gradient Field",
    xaxis_title="Parameter 1",
    yaxis_title="Parameter 2",
    height=520,
    # Anchor the x scale to y with ratio 1 so both axes use the same unit length.
    xaxis=dict(scaleanchor="y", scaleratio=1),
)
fig.update_layout(**layout_options)
show_plot(fig)


## 4. 3D Loss Surface + Descent Path
This reveals how a 2D parameter update becomes a 3D trajectory on the loss surface.


In [15]:
# Lift the 2D descent path onto the surface by evaluating the loss at each point.
path_z = f_2d(path[:, 0], path[:, 1])

surface_trace = go.Surface(x=grid, y=grid, z=Z, colorscale="Viridis", opacity=0.9, showscale=False)
trail_trace = go.Scatter3d(
    x=path[:, 0],
    y=path[:, 1],
    z=path_z,
    mode="lines+markers",
    line=dict(color="#ffffff", width=6),
    marker=dict(size=4, color="#ffffff"),
    name="Descent path",
)

# Surface first, then the white trail on top of it.
fig = go.Figure(data=[surface_trace, trail_trace])

scene_options = dict(
    xaxis_title="Parameter 1",
    yaxis_title="Parameter 2",
    zaxis_title="Loss",
    camera=dict(eye=dict(x=1.4, y=1.3, z=0.8)),
)
fig.update_layout(title="3D Loss Surface", scene=scene_options, height=560)
show_plot(fig)


## 5. Learning Rate Stability
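Small learning rates move safely but slowly. Oversized learning rates overshoot the minimum and oscillate across it. For $f(x, y) = x^2 + 1.5y^2$, each step multiplies $x$ by $(1 - 2\eta)$ and $y$ by $(1 - 3\eta)$, so $\eta = 0.6$ flips the sign of both coordinates on every step while still shrinking them; any $\eta > 2/3$ would make $y$ grow and the path diverge.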


In [16]:
# One row per run: (legend label, learning rate, step budget, line color).
# All runs start from the same point (-2.5, 2.0).
lr_runs = [
    ("Small LR (0.02)", 0.02, 40, "#1f77b4"),
    ("Balanced LR (0.12)", 0.12, 40, "#2ca02c"),
    ("Large LR (0.6)", 0.6, 15, "#d62728"),
]
paths = {
    label: gradient_descent_2d(-2.5, 2.0, lr=rate, steps=budget)
    for label, rate, budget, _ in lr_runs
}
colors = {label: shade for label, _, _, shade in lr_runs}

background = go.Contour(
    x=grid, y=grid, z=Z,
    colorscale="Greys", contours=dict(showlabels=False), opacity=0.5, showscale=False,
)
fig = go.Figure(data=[background])

# One trajectory per learning rate, in the order listed above.
for label, trajectory in paths.items():
    shade = colors[label]
    fig.add_trace(
        go.Scatter(x=trajectory[:, 0], y=trajectory[:, 1], mode="lines+markers",
                   line=dict(color=shade, width=3),
                   marker=dict(size=5, color=shade), name=label)
    )

# Mark the minimum of the quadratic at the origin.
fig.add_trace(
    go.Scatter(x=[0], y=[0], mode="markers", marker=dict(size=10, color="black", symbol="x"), name="Global min")
)

layout_options = dict(
    title="Learning Rate Comparison",
    xaxis_title="Parameter 1",
    yaxis_title="Parameter 2",
    height=520,
    # Anchor the x scale to y with ratio 1 so both axes use the same unit length.
    xaxis=dict(scaleanchor="y", scaleratio=1),
)
fig.update_layout(**layout_options)
show_plot(fig)
