
# 10‑Minute Quickstart Challenge: Escaping Saddle Points with a Minimal PSD (Perturbed Saddle-escape Descent)

**Goal:** Build intuition for **saddle points** and implement a hands-on **minimal PSD** (PSD-Lite) that escapes them on a small 2D example.  
Complete the following challenges:
1. Visualize a classic saddle surface and see why gradient descent can stall there.
2. Implement a minimal PSD-Lite with finite-difference curvature probes (no Hessians).
3. Watch PSD-Lite perturb and escape once the gradient becomes small.
4. Inspect the escape logs.
5. (Optional) Integrate the curvature probe into a deep-learning training loop.

> This is an **educational quickstart**: simple, readable code that matches the *spirit* of PSD. It uses finite differences and conservative defaults, not exact constants from any paper.



## Challenge 0: Setup

This notebook uses only standard Python, NumPy, and Matplotlib. PyTorch is needed only for the optional training-loop template at the end.


In [None]:

# If running on Google Colab, you can optionally ensure dependencies:
# !pip install --quiet numpy matplotlib torch torchvision
import sys, math, time, random
import numpy as np
import matplotlib.pyplot as plt

# Render figures inline (IPython/Jupyter magic; remove the next line when running as a plain .py script)
%matplotlib inline



## Challenge 1: Visual intuition—what is a saddle point?

We'll start with a classic saddle function:  
\[ f(x, y) = x^2 - y^2. \]

- At the origin (0,0): gradient is zero → a stationary point.  
- Curvature is **up** along x and **down** along y → **saddle**.
- Pure **gradient descent** gets **stuck** if started exactly at (0,0) (since the gradient is exactly zero there).
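Starting exactly at the origin is not the only way to stall. For \(f(x, y) = x^2 - y^2\) the \(y\)-component of the gradient is \(-2y\), so any start on the \(x\)-axis keeps \(y = 0\) forever, and plain GD converges *to* the saddle. The minimal NumPy sketch below (the helpers `grad_saddle` and `gd` are illustrative names, not from any library) compares such a start with one nudged by \(10^{-6}\) in \(y\).

```python
import numpy as np

def grad_saddle(xy):
    # Gradient of f(x, y) = x^2 - y^2
    x, y = xy
    return np.array([2.0 * x, -2.0 * y])

def gd(x0, eta=0.1, steps=100):
    # Plain gradient descent with a fixed step size
    x = np.array(x0, dtype=float)
    for _ in range(steps):
        x = x - eta * grad_saddle(x)
    return x

on_axis = gd([1.0, 0.0])   # x shrinks by (1 - 2*eta) = 0.8 per step; y stays exactly 0
nudged = gd([1.0, 1e-6])   # the y-offset grows by (1 + 2*eta) = 1.2 per step
print("start (1, 0)    ->", on_axis)
print("start (1, 1e-6) ->", nudged)
```

With `eta = 0.1`, the unperturbed run ends within about \(10^{-10}\) of the saddle, while the tiny nudge is amplified by roughly \(1.2^{100} \approx 10^8\). A small push off the \(x\)-axis (the saddle's stable manifold here) is exactly what a perturbation step provides.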


In [None]:

# Plot the saddle surface and a contour view
xs = np.linspace(-2, 2, 200)
ys = np.linspace(-2, 2, 200)
X, Y = np.meshgrid(xs, ys)
Z = X**2 - Y**2

# 3D surface
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
fig = plt.figure(figsize=(7,5))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, alpha=0.85, edgecolor='none')
ax.set_title("Saddle: f(x,y) = x^2 - y^2")
ax.set_xlabel("x"); ax.set_ylabel("y"); ax.set_zlabel("f(x,y)")
plt.show()

# 2D contour
plt.figure(figsize=(6,5))
cs = plt.contour(X, Y, Z, levels=20)
plt.clabel(cs, inline=True, fontsize=8)
plt.title("Contours of f(x,y) = x^2 - y^2 (Saddle)")
plt.xlabel("x"); plt.ylabel("y")
plt.gca().set_aspect('equal', 'box')
plt.show()



## Challenge 2: Minimal PSD-Lite (finite-difference probe; no Hessians)

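**Key idea:** When the gradient becomes **small**, we **probe curvature** along random directions using a simple **central difference** of the function values:
\[ q(v) = \frac{f(x + h v) - 2 f(x) + f(x - h v)}{h^2}. \]
For a unit vector \(v\), \(q(v)\) approximates the directional curvature \(v^\top \nabla^2 f(x)\, v\) with an error of order \(h^2\), so only function evaluations are needed and the Hessian is never formed.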

- If any probe yields **negative** curvature below a threshold (\( q(v) < -\gamma \)), we **perturb** along that direction and take a short sequence of **gradient steps** (an **escape episode**).
- Otherwise, we declare an approximate second-order stationary point (SOSP): no strong negative curvature was found.
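A quick sanity check of the probe: for a quadratic such as the Challenge 1 saddle \(f(x, y) = x^2 - y^2\), whose Hessian is \(\mathrm{diag}(2, -2)\) everywhere, the central difference has no truncation error, so \(q(v)\) should match \(v^\top \nabla^2 f\, v\) up to rounding at any point. This sketch uses a standalone helper, `q_probe` (a hypothetical name, similar in spirit to `central_diff_curvature` defined below), and an arbitrary test point.

```python
import numpy as np

def f_quad(xy):
    # Challenge 1 saddle: f(x, y) = x^2 - y^2, Hessian diag(2, -2) everywhere
    x, y = xy
    return x**2 - y**2

def q_probe(f, xy, v, h=1e-2):
    # Central second difference of f at xy along the unit vector v / ||v||
    xy = np.asarray(xy, dtype=float)
    v = np.asarray(v, dtype=float)
    v = v / np.linalg.norm(v)
    return (f(xy + h * v) - 2.0 * f(xy) + f(xy - h * v)) / h**2

H = np.diag([2.0, -2.0])
point = np.array([0.3, -0.7])   # arbitrary point; the Hessian is the same everywhere
results = {}
for v in ([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
    u = np.array(v) / np.linalg.norm(v)
    q, exact = q_probe(f_quad, point, v), u @ H @ u
    results[tuple(v)] = (q, exact)
    print(f"v={v}:  q={q:+.6f}   v^T H v={exact:+.6f}")
```

The diagonal direction \((1, 1)/\sqrt{2}\) gives zero curvature: probes along such mixed directions are why several random directions are tried rather than one.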

We'll use a slightly non‑quadratic function with a saddle at the origin, so curvature changes with position (more realistic than a pure quadratic):
\[ f(x,y) = \tfrac{1}{4}x^4 - \tfrac{1}{4}y^4 - \tfrac{1}{2}(x^2 - y^2) + 0.1\,x\,y. \]

> **Why this function?** It has a saddle at (0,0), non‑constant curvature (so finite‑difference probing is meaningful), and simple derivatives for gradient descent.
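To see what the probe should find, differentiate twice: \(\nabla^2 f(x,y) = \begin{pmatrix} 3x^2 - 1 & 0.1 \\ 0.1 & 1 - 3y^2 \end{pmatrix}\). At the origin this is \(\begin{pmatrix} -1 & 0.1 \\ 0.1 & 1 \end{pmatrix}\), with eigenvalues \(\pm\sqrt{1.01} \approx \pm 1.005\). It is a saddle whose downhill curvature points mostly along \(x\), the opposite of the Challenge 1 saddle. The sketch below uses an illustrative helper, `hess_f_xy`, which PSD-Lite itself never uses. It checks that a finite-difference probe along the most negative eigenvector recovers \(\lambda_{\min}\) up to an \(O(h^2)\) error.

```python
import numpy as np

def f_xy(xy):
    x, y = xy
    return 0.25*x**4 - 0.25*y**4 - 0.5*(x**2 - y**2) + 0.1*x*y

def hess_f_xy(xy):
    # Analytic Hessian: [[3x^2 - 1, 0.1], [0.1, 1 - 3y^2]]
    x, y = xy
    return np.array([[3*x**2 - 1.0, 0.1],
                     [0.1, 1.0 - 3*y**2]])

origin = np.zeros(2)
lams, vecs = np.linalg.eigh(hess_f_xy(origin))   # eigenvalues in ascending order
print("eigenvalues at origin:", lams)            # one negative, one positive -> saddle

v_neg = vecs[:, 0]                                # unit direction of most negative curvature
h = 1e-2
q = (f_xy(origin + h*v_neg) - 2*f_xy(origin) + f_xy(origin - h*v_neg)) / h**2
print("finite-difference q along v_neg:", q)      # close to lams[0]; quartic terms add O(h^2)
```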


In [None]:

# Define the example function and its gradient
def f_xy(xy):
    x, y = xy
    return 0.25*x**4 - 0.25*y**4 - 0.5*(x**2 - y**2) + 0.1*x*y

def grad_f_xy(xy):
    x, y = xy
    dfx = x**3 - x + 0.1*y
    dfy = -y**3 + y + 0.1*x
    return np.array([dfx, dfy], dtype=float)

def central_diff_curvature(xy, v, h):
    # q(v) = [f(x + h v) - 2 f(x) + f(x - h v)] / h^2
    v = v / np.linalg.norm(v)
    return (f_xy(xy + h*v) - 2.0*f_xy(xy) + f_xy(xy - h*v)) / (h**2)



### PSD-Lite algorithm (educational defaults)

We implement a minimal version with easy parameters:
- Gradient threshold: `eps_g` (enter the escape logic when \(\|\nabla f(x)\| \le \epsilon_g\)).  
- Probe radius: `h` (finite‑difference step).  
- Negative‑curvature threshold: `gamma`.  
- Perturbation radius: `r`.  
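- Escape episode length: `T` gradient steps; `eta` is the step size for every gradient step, inside and outside episodes.
- Number of random probe directions per check: `m_probes`; overall iteration budget: `max_iters`.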

> These are **teaching defaults**. For production or theoretical guarantees, you'd calibrate parameters more carefully.
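How many probes are enough? In 2D a random direction lands in the negative-curvature cone often, so a handful of probes is plenty. The sketch below estimates the per-probe hit rate at the origin of the example function, whose Hessian there is \(\begin{pmatrix} -1 & 0.1 \\ 0.1 & 1 \end{pmatrix}\) (obtained by differentiating \(f\) twice). It assumes the exact quadratic form \(v^\top H v\) as a stand-in for the finite-difference probe, since the two agree up to \(O(h^2)\).

```python
import numpy as np

# Hessian of f(x,y) = x^4/4 - y^4/4 - (x^2 - y^2)/2 + 0.1*x*y at the origin
H0 = np.array([[-1.0, 0.1],
               [0.1,  1.0]])
gamma = 1e-2  # same negative-curvature threshold as the PSD-Lite default

rng = np.random.default_rng(0)
V = rng.normal(size=(100_000, 2))
V /= np.linalg.norm(V, axis=1, keepdims=True)   # uniformly random unit directions
q = np.einsum('ni,ij,nj->n', V, H0, V)          # v^T H0 v for every direction

p_hit = np.mean(q < -gamma)  # estimated chance that one probe detects negative curvature
print(f"per-probe hit rate ~ {p_hit:.3f}")
for m in (1, 4, 16):
    print(f"m_probes={m:>2}: P(all probes miss) ~ {(1 - p_hit)**m:.1e}")
```

With a hit rate near one half, all 16 default probes miss with probability on the order of \(10^{-5}\). When only a few of many directions curve downward, the per-probe hit rate can be far lower, so more probes (or a smarter search) would likely be needed.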


In [None]:

def psd_lite(
    x0=np.array([0.0, 0.0]),
    eps_g=1e-3,
    h=1e-2,
    gamma=1e-2,
    r=5e-2,
    T=30,
    eta=0.1,
    max_iters=2000,
    m_probes=16,
    seed=0,
    verbose=True
):
    rng = np.random.default_rng(seed)
    x = x0.astype(float).copy()
    traj = [x.copy()]
    episodes = 0
    it = 0
    logs = []

    while it < max_iters:
        g = grad_f_xy(x)
        gnorm = np.linalg.norm(g)

        if gnorm > eps_g:
            # plain gradient step
            x = x - eta * g
            traj.append(x.copy())
            it += 1
            continue

        # Gradient is small: probe curvature along m_probes random directions
        best_q = float('inf')
        best_v = None
        for _ in range(m_probes):
            v = rng.normal(size=2)
            q = central_diff_curvature(x, v, h)
            if q < best_q:
                best_q, best_v = q, v

        if verbose:
            logs.append(dict(iter=it, gnorm=gnorm, best_q=best_q))

        if best_q < -gamma:
            # Negative curvature detected: run an escape episode
            episodes += 1
            # Perturb by radius r along the most negative probe direction;
            # q(v) = q(-v), so the sign is chosen at random
            sign = rng.choice([-1.0, 1.0])
            xi = sign * r * best_v / np.linalg.norm(best_v)
            y = x + xi

            for _ in range(T):
                y = y - eta * grad_f_xy(y)
                traj.append(y.copy())

            x = y
            it += T
        else:
            # Approximate SOSP
            break

    return np.array(traj), episodes, logs, x



## Challenge 3: Run PSD-Lite vs. plain gradient descent (starting at the saddle)

- **GD** starting at (0,0) does **nothing** (gradient is exactly zero).  
- **PSD‑Lite** perturbs and runs a short **escape episode**, then continues with gradient steps.
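Why does the perturbation radius `r` matter? Near the saddle, GD multiplies the component along the negative-curvature direction by about \(1 + \eta\,|\lambda_{\min}|\) per step, where \(\lambda_{\min} \approx -1.005\) is the smallest Hessian eigenvalue at the origin. Each tenfold reduction of the initial offset should therefore cost roughly \(\ln 10 / \ln 1.1005 \approx 24\) extra steps, and an offset of exactly zero never escapes. The sketch below counts plain GD steps until the iterate leaves a disk of radius 0.5. The helper `gd_escape_steps` and the 0.5 cut-off are illustrative choices, not part of PSD-Lite.

```python
import numpy as np

def grad_f_xy(xy):
    # Gradient of f(x,y) = x^4/4 - y^4/4 - (x^2 - y^2)/2 + 0.1*x*y
    x, y = xy
    return np.array([x**3 - x + 0.1*y, -y**3 + y + 0.1*x])

def gd_escape_steps(x0, eta=0.1, radius=0.5, max_steps=1000):
    """Plain GD steps until ||x|| > radius, or None if it never happens."""
    x = np.asarray(x0, dtype=float)
    for k in range(1, max_steps + 1):
        x = x - eta * grad_f_xy(x)
        if np.linalg.norm(x) > radius:
            return k
    return None

for r in (0.0, 5e-2, 5e-3, 5e-4):
    print(f"offset r={r:.0e}: escape after {gd_escape_steps([r, 0.0])} steps")
```

Offsets 10 times smaller take about 24 more steps each, which is one reason to pair a small `r` with an escape episode long enough (`T`) to actually leave the saddle region.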


In [None]:

def run_gd(x0, eta=0.1, steps=60):
    x = x0.astype(float).copy()
    traj = [x.copy()]
    for _ in range(steps):
        x = x - eta * grad_f_xy(x)
        traj.append(x.copy())
    return np.array(traj)

x0 = np.array([0.0, 0.0])  # exactly at the saddle

# Plain GD
gd_traj = run_gd(x0, eta=0.1, steps=60)

# PSD-Lite
traj, episodes, logs, x_final = psd_lite(
    x0=x0, eps_g=1e-6, h=1e-2, gamma=1e-2, r=5e-2, T=30, eta=0.1,
    max_iters=2000, m_probes=16, seed=42, verbose=True
)

print(f"Episodes triggered: {episodes}")
print(f"Final point (PSD-Lite): {x_final}, f = {f_xy(x_final):.6f}")

# Plot contours + trajectories
xs = np.linspace(-1.5, 1.5, 300)
ys = np.linspace(-1.5, 1.5, 300)
X, Y = np.meshgrid(xs, ys)
Z = f_xy((X, Y))  # f_xy unpacks (x, y), so it evaluates the whole grid at once

plt.figure(figsize=(6,5))
cs = plt.contour(X, Y, Z, levels=30)
plt.clabel(cs, inline=True, fontsize=8)
plt.plot(gd_traj[:,0], gd_traj[:,1], marker='.', linewidth=1, label='GD from (0,0)')
plt.plot(traj[:,0], traj[:,1], marker='.', linewidth=1, label='PSD-Lite from (0,0)')
plt.scatter([0],[0], s=60, marker='x', label='Start (saddle)')
plt.title('GD vs PSD-Lite (contours)')
plt.xlabel('x'); plt.ylabel('y')
plt.gca().set_aspect('equal', 'box')
plt.legend()
plt.show()



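## Challenge 4: Inspect escape logs
Look at the iterations where the gradient became tiny, and at the most negative curvature estimate (`best_q`) that the probes found there. If the run stopped before `max_iters`, its last log entry is the check that found no `q < -gamma` and ended the run.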


In [None]:

# Show a few log lines
for rec in logs[:10]:
    print(f"iter={rec['iter']:>4}  ||grad||={rec['gnorm']:.2e}  best_q={rec['best_q']:.3e}")


---

## Optional Challenge: Integrate a curvature probe into a DL training loop (template)

Below is a **template** for adding a finite‑difference curvature probe to a PyTorch training loop.
It samples random parameter directions, uses a small central‑difference step `h` on the **loss**,
and if it detects strong negative curvature (`q < -gamma`), it applies a small **parameter perturbation**.

> This is deliberately **minimal** and intended as a learning scaffold — not a drop‑in optimizer.
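One caveat before using the probe on a network: it divides a difference of three nearly equal loss values by \(h^2\), so its rounding error grows roughly like \(\varepsilon\,|L|/h^2\). Here \(\varepsilon\) is the machine epsilon of the parameter dtype, about \(1.2\times10^{-7}\) for float32, the PyTorch default. The NumPy sketch below assumes a hypothetical one-dimensional loss slice \(L(t) = 1 + \tfrac{c}{2}t^2 + \tfrac14 t^4\) with true curvature \(c = -2\) at \(t = 0\), and evaluates the same central difference in float32 and float64.

```python
import numpy as np

def slice_loss(t, c, dtype):
    # Hypothetical 1-D loss slice L(t) = 1 + (c/2) t^2 + t^4/4; true curvature at t = 0 is c.
    # Every constant is cast to `dtype`, so the arithmetic really runs in that precision.
    t = dtype(t)
    return dtype(1.0) + dtype(0.5) * dtype(c) * t * t + dtype(0.25) * t * t * t * t

def central_q(c, h, dtype):
    # Central second difference at t = 0: [L(h) - 2 L(0) + L(-h)] / h^2, all in `dtype`
    hh = dtype(h)
    num = slice_loss(hh, c, dtype) - dtype(2.0) * slice_loss(0.0, c, dtype) + slice_loss(-hh, c, dtype)
    return num / (hh * hh)

c_true = -2.0
for dtype in (np.float32, np.float64):
    for h in (1e-2, 1e-3, 1e-4):
        q = float(central_q(c_true, h, dtype))
        print(f"{dtype.__name__:>7}  h={h:.0e}  q={q: .4f}  (true {c_true})")
```

In float32 with `h = 1e-4`, every loss value in this sketch rounds to exactly 1.0, so the probe reports `q = 0` and misses the negative curvature entirely. By the same rough estimate, a float32 loss near 1 with `h = 1e-3` carries noise of order 0.1 in `q`, far above the template's `gamma = 1e-3`. It is therefore likely that `h` and `gamma` need to be tuned together, or that the probe should run in float64.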


In [None]:

# (Optional) PyTorch template for adding a curvature probe to training
# Note: This is a template; it will run, but it's configured for tiny demos / unit tests.

import torch
import torch.nn as nn
import torch.nn.functional as F

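def flatten_params(model):
    # Concatenate the trainable parameters into one detached 1-D vector
    # (same requires_grad filter as the `like_params` lists used below)
    return torch.cat([p.detach().flatten() for p in model.parameters() if p.requires_grad])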

def assign_flat_params_(model, flat, like_params):
    # Write a flat vector into model parameters (in-place)
    idx = 0
    with torch.no_grad():
        for p in like_params:
            num = p.numel()
            p.copy_(flat[idx:idx+num].view_as(p))
            idx += num

def loss_closure(model, x, y):
    # Simple supervised loss
    logits = model(x)
    return F.cross_entropy(logits, y)

@torch.no_grad()
def probe_negative_curvature(model, data_batch, h=1e-3, m_dirs=4, gamma=1e-3, device='cpu'):
    # Central-difference probe on the loss along random parameter directions
    # Remember the caller's mode so probing does not leave the model in eval mode
    was_training = model.training
    model.eval()
    (x, y) = data_batch
    x = x.to(device); y = y.to(device)

    # Flatten parameters
    like_params = [p for p in model.parameters() if p.requires_grad]
    base = flatten_params(model).to(device)

    # Evaluate base loss
    base_loss = loss_closure(model, x, y).item()

    best_q = float('inf')
    best_dir = None

    for _ in range(m_dirs):
        # Random direction with unit norm
        d = torch.randn_like(base)
        d = d / (d.norm() + 1e-12)

        # Central difference on loss
        assign_flat_params_(model, base + h*d, like_params)
        lp = loss_closure(model, x, y).item()
        assign_flat_params_(model, base - h*d, like_params)
        lm = loss_closure(model, x, y).item()
        assign_flat_params_(model, base, like_params)

        q = (lp - 2*base_loss + lm) / (h*h)
        if q < best_q:
            best_q, best_dir = q, d

    model.train(was_training)  # restore train/eval mode
    # `gamma` is not used here; the caller compares best_q < -gamma
    return best_q, best_dir  # if best_q < -gamma, consider perturbation



### Optional Challenge: Minimal training loop sketch with probe (for small toy models)

- Every few steps, if the gradient norm is small, we call `probe_negative_curvature`.
- If `q < -gamma`, we **perturb parameters** a little along `best_dir`, then continue training.


In [None]:

# Sketch of usage (CPU; tiny batch; for illustration only).
# You can adapt this into your own training script.

class TinyMLP(nn.Module):
    def __init__(self, in_dim=2, hid=16, out_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hid)
        self.fc2 = nn.Linear(hid, out_dim)
    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

device = 'cpu'
model = TinyMLP().to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.0)

# Fake data (replace with real dataloader)
x = torch.randn(64, 2).to(device)
y = (x[:,0] > x[:,1]).long().to(device)  # toy labels
batch = (x, y)

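# With these settings the gradient norm may never drop below eps_g within 60 steps,
# so the probe may not fire; loosen eps_g or train longer to exercise it.
eps_g = 1e-4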
gamma = 1e-3
h = 1e-3
r = 1e-2   # perturbation magnitude in parameter space
check_every = 10

for step in range(60):
    opt.zero_grad()
    loss = loss_closure(model, x, y)
    loss.backward()
    gnorm = torch.sqrt(sum((p.grad**2).sum() for p in model.parameters())).item()
    opt.step()

    if step % check_every == 0 and gnorm < eps_g:
        q, d = probe_negative_curvature(model, batch, h=h, m_dirs=4, gamma=gamma, device=device)
        if q < -gamma and d is not None:
            # Apply a small parameter perturbation along d
            base = flatten_params(model)
            base = base + r * d
            assign_flat_params_(model, base, [p for p in model.parameters() if p.requires_grad])
            print(f"[step {step}] Escape episode triggered: q={q:.3e}, gnorm={gnorm:.2e}")
        else:
            print(f"[step {step}] Likely near SOSP: q={q:.3e}, gnorm={gnorm:.2e}")

print("Done. (This was a tiny sketch; adapt to your real model and data.)")
