# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
===========================================================
 Gradient Descent Surface GUI
===========================================================

An interactive visualization of gradient descent on common 
2D objective functions including Rosenbrock, Quadratic, 
Himmelblau, and a multi-minima quadratic surface. 

Users can:
- Visualize both 3D surface and 2D contour plots.
- Adjust learning rate, gradient noise, and step speed.
- Interactively run, pause, or step through optimization.
- Observe convergence behavior in real time.

-----------------------------------------------------------
 Author:  Burak Demirel
 Title:   MRI Clinical Scientist
 Company: Philips Healthcare
 Contact: burak.demirel@philips.com
-----------------------------------------------------------

Run with:
    python gradient_descent_gui.py

Requirements:
    - numpy
    - matplotlib
    - tkinter (built-in)

Tested on:
    Python 3.9+, Windows/Linux/macOS

-----------------------------------------------------------
"""

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import tkinter as tk
from tkinter import ttk

# ==========================
# Objective functions & grads
# ==========================

def rosenbrock(xy, a=1.0, b=100.0):
    """Evaluate the Rosenbrock function and its gradient.

    f(x, y) = (a - x)^2 + b * (y - x^2)^2

    Args:
        xy: Array whose last axis holds the (x, y) coordinates; any leading
            axes are treated as batch dimensions.
        a: Offset of the first squared term.
        b: Weight of the curved-valley term.

    Returns:
        A tuple ``(f, grad)``. ``f`` has the shape of the leading axes of
        ``xy``; ``grad`` stacks (df/dx, df/dy) along a new last axis.
    """
    x = xy[..., 0]
    y = xy[..., 1]

    # Both the value and the gradient are built from these two residuals.
    offset = a - x
    valley = y - x**2

    value = offset**2 + b * valley**2
    grad_x = -2*offset - 4*b*x*valley
    grad_y = 2*b*valley
    return value, np.stack((grad_x, grad_y), axis=-1)

def quadratic(xy, A=None, bvec=None, c=0.0):
    """Evaluate a general quadratic form and its gradient.

    f(x) = 0.5 * x^T A x + b^T x + c

    The gradient of the quadratic term is ``0.5 * (A + A^T) x``, which only
    reduces to ``A x`` when ``A`` is symmetric. ``A`` is therefore symmetrised
    before use: this leaves ``f`` mathematically unchanged for any ``A`` and
    makes the returned gradient correct for non-symmetric matrices too.

    Args:
        xy: Array whose last axis holds the (x, y) coordinates; any leading
            axes are treated as batch dimensions.
        A: 2x2 matrix of the quadratic term. Defaults to [[3, 1], [1, 2]].
        bvec: Length-2 linear coefficient vector. Defaults to [-1, 0.5].
        c: Constant offset added to the value.

    Returns:
        A tuple ``(f, grad)``. ``f`` has the shape of the leading axes of
        ``xy``; ``grad`` has the same leading axes with a last axis of 2.
    """
    if A is None:
        A = np.array([[3.0, 1.0], [1.0, 2.0]])
    if bvec is None:
        bvec = np.array([-1.0, 0.5])
    A = np.asarray(A)
    bvec = np.asarray(bvec)

    x = xy[..., 0]
    y = xy[..., 1]
    X = np.stack([x, y], axis=-1)

    # For a symmetric A this is exactly A again, so the default behaves as before.
    A_sym = 0.5 * (A + A.T)
    # Rows of X are points, so X @ A_sym holds (A_sym x)^T for every point.
    AX = X @ A_sym

    f = 0.5 * np.sum(X * AX, axis=-1) + np.sum(bvec * X, axis=-1) + c
    grad = AX + bvec
    return f, grad

def himmelblau(xy):
    """Evaluate Himmelblau's function and its gradient.

    f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2

    Args:
        xy: Array whose last axis holds the (x, y) coordinates; any leading
            axes are treated as batch dimensions.

    Returns:
        A tuple ``(f, grad)``. ``f`` has the shape of the leading axes of
        ``xy``; ``grad`` stacks (df/dx, df/dy) along a new last axis.
    """
    x = xy[..., 0]
    y = xy[..., 1]

    # The two residuals whose squares make up the function.
    p = x**2 + y - 11
    q = x + y**2 - 7

    value = p**2 + q**2
    grad_x = 4*x*p + 2*q
    grad_y = 2*p + 4*y*q
    return value, np.stack((grad_x, grad_y), axis=-1)

def multi_minima_quadratic(xy):
    """Evaluate a quadratic bowl with sinusoidal bumps, plus its gradient.

    f(x, y) = 0.5 * (x^2 + y^2) + 0.3 * sin(3x) * sin(3y)

    The sinusoidal term is what creates the multiple minima.

    Args:
        xy: Array whose last axis holds the (x, y) coordinates; any leading
            axes are treated as batch dimensions.

    Returns:
        A tuple ``(f, grad)``. ``f`` has the shape of the leading axes of
        ``xy``; ``grad`` stacks (df/dx, df/dy) along a new last axis.
    """
    x = xy[..., 0]
    y = xy[..., 1]

    sin_x, cos_x = np.sin(3*x), np.cos(3*x)
    sin_y, cos_y = np.sin(3*y), np.cos(3*y)

    value = 0.5 * (x**2 + y**2) + 0.3 * sin_x * sin_y
    # 0.9 is the bump amplitude 0.3 times the inner derivative 3.
    grad_x = x + 0.9 * cos_x * sin_y
    grad_y = y + 0.9 * sin_x * cos_y
    return value, np.stack((grad_x, grad_y), axis=-1)

# Registry of selectable objectives: display name -> callable returning
# (value, gradient). Insertion order is the order shown in the GUI selector.
FUNCTIONS = dict(
    (
        ("Rosenbrock", rosenbrock),
        ("Quadratic", quadratic),
        ("Multi-minima Quadratic", multi_minima_quadratic),
        ("Himmelblau", himmelblau),
    )
)

# ==========================
# Gradient Descent Engine
# ==========================
class GDEngine:
    """Plain gradient-descent stepper for one of the registered objectives.

    Attributes are public and are written directly by the GUI:
        func_name: Key into ``FUNCTIONS`` naming the objective.
        f: The objective callable; returns ``(value, gradient)``.
        lr: Learning rate applied to every step.
        noise: Standard deviation of Gaussian noise added to the gradient
            (no noise is added unless this is > 0).
        t: Number of steps taken since the last reset.
        x: Current iterate as a length-2 float array.
        history: List of all iterates, starting with the start point.
    """

    def __init__(self, func_name="Rosenbrock", lr=0.001, noise=0.0):
        """Look up the objective by name and start from the default point."""
        self.func_name = func_name
        self.f = FUNCTIONS[func_name]
        self.lr = lr
        self.noise = noise
        self.reset_state()

    def reset_state(self, x0=None):
        """Restart the run at ``x0`` (default ``[-1.5, 1.8]``) with empty history."""
        if x0 is None:
            start = np.array([-1.5, 1.8])
        else:
            start = np.array(x0, dtype=float)
        self.t = 0
        self.x = start
        # Store a copy so later updates of self.x never alias the history entry.
        self.history = [start.copy()]

    def step(self):
        """Take one (optionally noisy) gradient step.

        Returns:
            A tuple ``(x, g)`` with copies of the new iterate and of the
            gradient that was actually applied (including any noise).
        """
        # The objectives expect a batch axis, so evaluate on a 1-point batch.
        _, batch_grad = self.f(self.x[np.newaxis, :])
        direction = batch_grad[0]
        if self.noise > 0:
            direction = direction + self.noise * np.random.randn(*direction.shape)

        self.x = self.x - self.lr * direction
        self.t += 1
        self.history.append(self.x.copy())
        return self.x.copy(), direction.copy()

# ==========================
# GUI
# ==========================
class GDGUI:
    """Tkinter window that animates gradient descent on a selectable 2D objective.

    The left-hand panel holds the controls (objective, learning rate, gradient
    noise, start point, start/step/pause/reset buttons, status line and speed).
    The right-hand side embeds a Matplotlib figure with a 3D surface and a 2D
    contour/quiver plot on which the optimisation path is drawn.

    Args:
        root: The Tk window that hosts the widgets.
    """

    def __init__(self, root):
        self.root = root
        root.title("Gradient Descent on a Surface")

        # State
        self.running = False
        # Id returned by root.after() for the next scheduled step, or None.
        # It is tracked so pause/reset can cancel it; otherwise a quick
        # Pause -> Start leaves two step chains running at the same time.
        self._after_id = None
        self.engine = GDEngine()

        # Layout frames
        self.ctrl_frame = ttk.Frame(root, padding=8)
        self.ctrl_frame.pack(side=tk.LEFT, fill=tk.Y)

        self.fig = plt.Figure(figsize=(8.8, 5.6), dpi=100)
        self.ax3d = self.fig.add_subplot(1, 2, 1, projection='3d')
        self.ax2d = self.fig.add_subplot(1, 2, 2)
        self.canvas = FigureCanvasTkAgg(self.fig, master=root)
        self.canvas.get_tk_widget().pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        self._build_controls()
        self._build_surface()
        self._redraw_all()

    # ---------- Controls ----------
    def _build_controls(self):
        """Create all widgets of the control panel and their Tk variables."""
        # Function selector
        ttk.Label(self.ctrl_frame, text="Function", font=("Segoe UI", 10, "bold")).pack(anchor=tk.W)
        self.func_var = tk.StringVar(value="Rosenbrock")
        func_menu = ttk.Combobox(
            self.ctrl_frame,
            textvariable=self.func_var,
            values=list(FUNCTIONS.keys()),
            state="readonly"
        )
        func_menu.pack(fill=tk.X)
        func_menu.bind("<<ComboboxSelected>>", lambda e: self.on_change_function())

        # Learning rate
        ttk.Label(self.ctrl_frame, text="Learning rate").pack(anchor=tk.W, pady=(8, 0))
        self.lr_var = tk.DoubleVar(value=0.001)
        ttk.Scale(
            self.ctrl_frame,
            from_=1e-5, to=1e-0,
            orient=tk.HORIZONTAL,
            variable=self.lr_var,
            command=lambda v: self.on_change_lr()
        ).pack(fill=tk.X)
        # Typing in Entry doesn't trigger on_change_lr, so we sync on actions too.
        self.lr_entry = ttk.Entry(self.ctrl_frame, textvariable=self.lr_var)
        self.lr_entry.pack(fill=tk.X, pady=(2, 0))

        # Gradient noise
        ttk.Label(self.ctrl_frame, text="Gradient noise (std)").pack(anchor=tk.W, pady=(8, 0))
        self.noise_var = tk.DoubleVar(value=0.0)
        ttk.Scale(
            self.ctrl_frame,
            from_=0.0, to=1.0,
            orient=tk.HORIZONTAL,
            variable=self.noise_var,
            command=lambda v: self.on_change_noise()
        ).pack(fill=tk.X)
        self.noise_entry = ttk.Entry(self.ctrl_frame, textvariable=self.noise_var)
        self.noise_entry.pack(fill=tk.X, pady=(2, 0))

        # Start point
        ttk.Label(self.ctrl_frame, text="Start x, y").pack(anchor=tk.W, pady=(8, 0))
        sp_frame = ttk.Frame(self.ctrl_frame)
        sp_frame.pack(fill=tk.X)
        self.x0_var = tk.DoubleVar(value=-1.5)
        self.y0_var = tk.DoubleVar(value=1.8)
        ttk.Entry(sp_frame, textvariable=self.x0_var, width=8).pack(side=tk.LEFT, padx=(0, 4))
        ttk.Entry(sp_frame, textvariable=self.y0_var, width=8).pack(side=tk.LEFT)

        # Buttons
        btn_frame = ttk.Frame(self.ctrl_frame)
        btn_frame.pack(fill=tk.X, pady=(10, 0))
        self.start_btn = ttk.Button(btn_frame, text="Start", command=self.on_start)
        self.start_btn.pack(side=tk.LEFT, expand=True, fill=tk.X, padx=(0, 4))
        ttk.Button(btn_frame, text="Step", command=self.on_step).pack(side=tk.LEFT, expand=True, fill=tk.X, padx=4)
        self.pause_btn = ttk.Button(btn_frame, text="Pause", command=self.on_pause, state=tk.DISABLED)
        self.pause_btn.pack(side=tk.LEFT, expand=True, fill=tk.X, padx=(4, 0))

        ttk.Button(self.ctrl_frame, text="Reset", command=self.on_reset).pack(fill=tk.X, pady=(6, 0))

        # Status labels
        self.status = tk.StringVar(value="iter: 0 | x: [-1.50, 1.80] | f(x): ...")
        ttk.Separator(self.ctrl_frame, orient=tk.HORIZONTAL).pack(fill=tk.X, pady=8)
        ttk.Label(self.ctrl_frame, textvariable=self.status, wraplength=260).pack(anchor=tk.W)

        # Speed
        ttk.Label(self.ctrl_frame, text="Speed (ms per step)").pack(anchor=tk.W, pady=(8, 0))
        self.speed_var = tk.IntVar(value=60)
        ttk.Scale(self.ctrl_frame, from_=5, to=500, orient=tk.HORIZONTAL, variable=self.speed_var).pack(fill=tk.X)

    # ---------- Helpers ----------
    def _read_float(self, var, fallback):
        """Return the numeric value of ``var``, or ``fallback`` if it is not a number.

        The Entry widgets write free text into their DoubleVar, so ``get()``
        raises when the field is empty or malformed. In that case the field is
        restored to ``fallback`` so the UI shows the value actually in use.
        """
        try:
            return float(var.get())
        except (tk.TclError, ValueError):
            var.set(float(fallback))
            return float(fallback)

    def _start_point(self):
        """Return the start point typed by the user as ``[x0, y0]``.

        An unparsable coordinate falls back to the engine's current start
        point (the first history entry).
        """
        prev_x, prev_y = self.engine.history[0]
        return [
            self._read_float(self.x0_var, prev_x),
            self._read_float(self.y0_var, prev_y),
        ]

    def sync_hyperparams(self):
        """Pull current UI values into the engine (handles Entry edits)."""
        self.engine.lr = self._read_float(self.lr_var, self.engine.lr)
        self.engine.noise = self._read_float(self.noise_var, self.engine.noise)

    def _cancel_pending_step(self):
        """Cancel the step scheduled with root.after(), if there is one."""
        if self._after_id is not None:
            self.root.after_cancel(self._after_id)
            self._after_id = None

    def _halt(self):
        """Stop the animation and put the Start/Pause buttons in the idle state."""
        self.running = False
        self._cancel_pending_step()
        self.start_btn.configure(state=tk.NORMAL)
        self.pause_btn.configure(state=tk.DISABLED)

    # ---------- Surface ----------
    def _build_surface(self):
        """Redraw the static surface, contours and gradient field for the selected function."""
        # Plot window (xmin, xmax, ymin, ymax) per objective.
        ranges = {
            "Rosenbrock": (-4, 4, -4, 10),
            "Quadratic": (-5, 5, -5, 5),
            "Multi-minima Quadratic": (-6, 6, -6, 6),
            "Himmelblau": (-12, 12, -12, 12),
        }

        fx = self.func_var.get()
        xmin, xmax, ymin, ymax = ranges.get(fx, (-4, 4, -4, 4))
        self.grid_x = np.linspace(xmin, xmax, 200)
        self.grid_y = np.linspace(ymin, ymax, 200)
        X, Y = np.meshgrid(self.grid_x, self.grid_y)
        XY = np.stack([X, Y], axis=-1)
        # One evaluation yields both the heights and the gradient field.
        Z, G = FUNCTIONS[fx](XY)

        # clear axes
        self.ax3d.cla()
        self.ax2d.cla()

        # 3D surface
        self.ax3d.plot_surface(X, Y, Z, rstride=4, cstride=4, linewidth=0.0, alpha=0.8)
        self.ax3d.set_xlabel('x')
        self.ax3d.set_ylabel('y')
        self.ax3d.set_zlabel('f(x,y)')
        self.ax3d.set_title(f"{fx} Surface")
        self.ax3d.view_init(elev=35, azim=-60)

        # 2D contour
        self.ax2d.contour(X, Y, Z, levels=30)
        self.ax2d.set_xlabel('x')
        self.ax2d.set_ylabel('y')
        self.ax2d.set_title(f"{fx} Contours & Path")

        # Initialize path lines (filled in by _redraw_all)
        self.path3d, = self.ax3d.plot([], [], [], marker='o', linestyle='-', linewidth=1.5)
        self.path2d, = self.ax2d.plot([], [], marker='o', linestyle='-', linewidth=1.5)

        # Gradient quiver (2D): arrows show the descent direction -grad on a
        # subsampled grid (every 12th point in each direction).
        skip = (slice(None, None, 12), slice(None, None, 12))
        U = -G[..., 0][skip]
        V = -G[..., 1][skip]
        self.ax2d.quiver(X[skip], Y[skip], U, V, angles='xy')

    # ---------- Drawing ----------
    def _redraw_all(self):
        """Update both path lines and the status text from the engine history."""
        hist = np.array(self.engine.history)
        if len(hist) > 0:
            Z, _ = FUNCTIONS[self.func_var.get()](hist)
            self.path2d.set_data(hist[:, 0], hist[:, 1])
            self.path3d.set_data(hist[:, 0], hist[:, 1])
            self.path3d.set_3d_properties(Z)

            # The last history entry is the current iterate, so its height is
            # the current objective value; no second evaluation is needed.
            self.status.set(f"iter: {self.engine.t} | x: [{self.engine.x[0]:.4f}, {self.engine.x[1]:.4f}] | f(x): {Z[-1]:.6f}")

        self.canvas.draw_idle()

    # ---------- Events ----------
    def on_change_function(self):
        """Switch objective: new engine at the typed start point, redraw everything."""
        # Keep current hyperparams in sync when swapping functions
        self.sync_hyperparams()
        # Read the start point before replacing the engine so that a bad entry
        # falls back to the previous run's start point.
        start = self._start_point()
        self.engine = GDEngine(func_name=self.func_var.get(), lr=self.engine.lr, noise=self.engine.noise)
        self.engine.reset_state(start)
        self._build_surface()
        self._redraw_all()

    def on_change_lr(self):
        """Slider callback for the learning rate."""
        # Slider changes come here; entry typing handled by sync on actions
        self.engine.lr = float(self.lr_var.get())

    def on_change_noise(self):
        """Slider callback for the gradient-noise standard deviation."""
        self.engine.noise = float(self.noise_var.get())

    def on_start(self):
        """Start (or resume) the automatic stepping loop."""
        if self.running:
            return
        # ensure hyperparams reflect what's in the entries
        self.sync_hyperparams()
        # reset to starting point typed by user if we've never run
        if self.engine.t == 0:
            self.engine.reset_state(self._start_point())
            self._redraw_all()
        self.running = True
        self.start_btn.configure(state=tk.DISABLED)
        self.pause_btn.configure(state=tk.NORMAL)
        # Make sure exactly one step chain exists before starting a new one.
        self._cancel_pending_step()
        self._run_loop()

    def on_pause(self):
        """Pause the automatic stepping loop."""
        self._halt()

    def on_step(self):
        """Take a single step (also works while the loop is running)."""
        # make sure we pick up any edits before stepping
        self.sync_hyperparams()
        if self.engine.t == 0 and not self.running:
            self.engine.reset_state(self._start_point())
        self.engine.step()
        self._redraw_all()

    def on_reset(self):
        """Stop the loop and restart from the typed start point."""
        self._halt()
        # also sync here so resets use current LR/noise on the next start
        self.sync_hyperparams()
        self.engine.reset_state(self._start_point())
        self._build_surface()
        self._redraw_all()

    def _run_loop(self):
        """Take one step, redraw, and schedule the next step via root.after().

        The loop stops itself when the iterate becomes non-finite (inf/NaN),
        since further steps cannot recover from that.
        """
        # The callback that brought us here has fired; its id is now stale.
        self._after_id = None
        if not self.running:
            return
        self.engine.step()
        self._redraw_all()
        if not np.all(np.isfinite(self.engine.x)):
            self._halt()
            self.status.set(self.status.get() + " | diverged, paused")
            return
        self._after_id = self.root.after(int(self.speed_var.get()), self._run_loop)

if __name__ == "__main__":
    # Script entry point: create the Tk window, attach the GUI to it and
    # hand control to the Tk event loop until the window is closed.
    window = tk.Tk()
    gui = GDGUI(window)
    window.mainloop()
