# SLSQP-JAX: CPU vs GPU Benchmark

This notebook benchmarks the [slsqp-jax](https://github.com/lucianopaz/slsqp-jax) constrained optimizer across problem sizes to compare **CPU** and **GPU** execution times.

We solve constrained quadratic problems at dimensions **n = 5, 20, 100, 500** and
report wall-clock times for both devices. Finally, we tackle a **10,000-dimensional constrained portfolio allocation** on GPU to demonstrate
large-scale capability.

> **Requirements** — Run this notebook on Google Colab with a **GPU runtime**.\
> Go to **Runtime → Change runtime type → T4 GPU** (or better).

In [None]:
# Install slsqp-jax from GitHub
!pip install "slsqp-jax @ git+https://github.com/lucianopaz/slsqp-jax.git" -q

In [None]:
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optimistix as optx

from slsqp_jax import SLSQP

# Run JAX in 64-bit mode so the solver works in double precision.
jax.config.update("jax_enable_x64", True)

# --- Hardware verification ---
print(f"JAX version : {jax.__version__}")
print(f"All devices : {jax.devices()}")

# The CPU device is also used as the stand-in "GPU" device when no GPU exists,
# so later cells can reference gpu_device unconditionally.
cpu_device = jax.devices("cpu")[0]

try:
    # Requesting a backend that is not present raises RuntimeError.
    gpu_devices = jax.devices("gpu")
except RuntimeError:
    gpu_available = False
    gpu_device = cpu_device
    print("\u274c No GPU found. Enable GPU in Runtime > Change runtime type.")
else:
    gpu_available = True
    gpu_device = gpu_devices[0]
    print(f"\u2705 GPU detected: {gpu_devices}")

## CPU vs GPU Benchmark

For each problem dimension **n**, we create a constrained weighted quadratic:

$$
\min_x \sum_{i=1}^n w_i\,(x_i - 1)^2
\quad\text{s.t.}\quad
\sum_i x_i = n,\;\; x_i \ge 0
$$

where $w_i$ are linearly spaced in $[1,\,10]$. The analytical solution is
$x^* = (1,\dots,1)$.

Each benchmark:
1. JIT-compiles the full `optimistix.minimise` call (all SLSQP iterations compiled into a single XLA program).
2. Runs a **warmup** pass to trigger compilation.
3. Times multiple post-compilation runs and reports the average.

In [None]:
def make_benchmark_problem(n):
    """Build the n-dimensional constrained weighted-quadratic test problem.

    The objective is ``sum(w * (x - 1)**2)`` with weights ``w`` linearly
    spaced in [1, 10]. It is paired with one equality constraint
    (``sum(x) - n``) and n inequality constraints (``x`` itself, i.e.
    ``x_i >= 0``).

    Returns (solver, objective, x0): the configured SLSQP solver, an
    ``objective(x, args)`` callable returning ``(value, None)``, and a
    length-n starting vector filled with 2.0.
    """
    w = jnp.linspace(1.0, 10.0, n)
    target_sum = float(n)

    def objective(x, args):
        deviation = x - 1.0
        # Second tuple element is the (unused) auxiliary output.
        return jnp.sum(w * deviation**2), None

    def sum_equals_n(x, args):
        # Single residual, wrapped in a length-1 array.
        return jnp.array([jnp.sum(x) - target_sum])

    def nonnegativity(x, args):
        # x_i >= 0, one inequality per coordinate
        return x

    solver_options = {
        "rtol": 1e-6,
        "atol": 1e-6,
        "max_steps": 200,
        "eq_constraint_fn": sum_equals_n,
        "n_eq_constraints": 1,
        "ineq_constraint_fn": nonnegativity,
        "n_ineq_constraints": n,
        "lbfgs_memory": 10,
    }
    solver = SLSQP(**solver_options)

    return solver, objective, jnp.full(n, 2.0)


def benchmark_on_device(solve_fn, x0, device, n_loops=10):
    """Time repeated calls of *solve_fn* with its input placed on *device*.

    One untimed warmup call is made first, then *n_loops* timed calls. Every
    call is followed by ``block_until_ready()`` so the measured interval
    covers the actual computation and not just its dispatch.

    Returns (avg_seconds, std_seconds, solution_array), where the solution is
    the last result converted to a NumPy array.
    """
    start_point = jax.device_put(x0, device)

    def run_to_completion():
        out = solve_fn(start_point)
        out.block_until_ready()  # critical for async GPU dispatch
        return out

    # Warmup — triggers JIT compilation for this device
    result = run_to_completion()

    # Timed runs
    durations = []
    for _ in range(n_loops):
        tic = time.perf_counter()
        result = run_to_completion()
        durations.append(time.perf_counter() - tic)

    mean_s = float(np.mean(durations))
    std_s = float(np.std(durations))
    return mean_s, std_s, np.asarray(result)

In [None]:
# ---- Run benchmarks --------------------------------------------------------
problem_sizes = [5, 20, 100, 500]
n_loops = 10

# all_results[n] maps "cpu" (and "gpu" when a GPU is available) to a dict with
# the mean time "avg" and spread "std" in seconds plus the solution "sol".
all_results = {}

for n in problem_sizes:
    print(f"\n{'=' * 55}")
    print(f" Problem dimension  n = {n}")
    print(f"{'=' * 55}")

    solver, objective, x0 = make_benchmark_problem(n)

    # Build a single JIT-compiled function for this problem.
    # Passing solver/objective via default-arg capture so each
    # iteration of the loop gets its own compiled function.
    @jax.jit
    def solve(x0, _s=solver, _o=objective):
        sol = optx.minimise(
            _o,
            _s,
            x0,
            has_aux=True,
            max_steps=200,
            throw=False,
        )
        return sol.value

    entry = {}

    # --- CPU ---
    print("  Compiling for CPU ...", end=" ", flush=True)
    cpu_avg, cpu_std, cpu_sol = benchmark_on_device(
        solve,
        x0,
        cpu_device,
        n_loops,
    )
    print(f"done  \u2192  {cpu_avg * 1000:8.2f} \u00b1 {cpu_std * 1000:.2f} ms")
    entry["cpu"] = {"avg": cpu_avg, "std": cpu_std, "sol": cpu_sol}

    # --- GPU ---
    if gpu_available:
        print("  Compiling for GPU ...", end=" ", flush=True)
        gpu_avg, gpu_std, gpu_sol = benchmark_on_device(
            solve,
            x0,
            gpu_device,
            n_loops,
        )
        print(f"done  \u2192  {gpu_avg * 1000:8.2f} \u00b1 {gpu_std * 1000:.2f} ms")
        entry["gpu"] = {"avg": gpu_avg, "std": gpu_std, "sol": gpu_sol}
        speedup = cpu_avg / gpu_avg
        print(f"  Speedup: {speedup:.2f}x")

    # Quick sanity check — solution should be all-ones. The solutions are
    # already NumPy arrays, so the check is done with NumPy directly.
    err = float(np.max(np.abs(cpu_sol - 1.0)))
    print(f"  Solution check: max|x* - 1| = {err:.2e}")
    if gpu_available:
        # Check the GPU result too: a speedup figure is only meaningful if
        # the GPU run reached the same solution as the CPU run.
        gpu_err = float(np.max(np.abs(gpu_sol - 1.0)))
        print(f"  Solution check (GPU): max|x* - 1| = {gpu_err:.2e}")

    all_results[n] = entry

In [None]:
# ---- Summary table ---------------------------------------------------------
print()
header = f"{'n':>6} | {'CPU (ms)':>14} | {'GPU (ms)':>14} | {'Speedup':>10}"
print(header)
print("-" * len(header))

for n in problem_sizes:
    timings = all_results[n]
    # Stored averages are in seconds; the table shows milliseconds.
    cpu_ms = timings["cpu"]["avg"] * 1000
    # Each cell is padded to its header's width including the suffix:
    # 11 digits + " ms" = 14 characters, 9 digits + "x" = 10 characters.
    # This keeps numeric rows aligned with the header and with "N/A" rows.
    row = f"{n:>6} | {cpu_ms:>11.2f} ms"
    if gpu_available and "gpu" in timings:
        gpu_ms = timings["gpu"]["avg"] * 1000
        sp = timings["cpu"]["avg"] / timings["gpu"]["avg"]
        row += f" | {gpu_ms:>11.2f} ms | {sp:>9.2f}x"
    else:
        row += f" | {'N/A':>14} | {'N/A':>10}"
    print(row)

In [None]:
# ---- Plots -----------------------------------------------------------------
fig, axes = plt.subplots(1, 2, figsize=(14, 5), layout="constrained")
sizes = np.array(problem_sizes)

# Timings are stored in seconds; scale to milliseconds for display.
cpu_times = np.array([all_results[n]["cpu"]["avg"] for n in problem_sizes]) * 1000
cpu_stds = np.array([all_results[n]["cpu"]["std"] for n in problem_sizes]) * 1000

# Styling shared by the CPU and GPU timing curves.
errorbar_style = {"capsize": 5, "linewidth": 2, "markersize": 8}

# -- Left panel: execution times --
ax = axes[0]
ax.errorbar(sizes, cpu_times, yerr=cpu_stds, marker="o", label="CPU", **errorbar_style)
if gpu_available:
    gpu_times = np.array([all_results[n]["gpu"]["avg"] for n in problem_sizes]) * 1000
    gpu_stds = np.array([all_results[n]["gpu"]["std"] for n in problem_sizes]) * 1000
    ax.errorbar(
        sizes, gpu_times, yerr=gpu_stds, marker="s", label="GPU", **errorbar_style
    )

ax.set_title("SLSQP-JAX: CPU vs GPU Execution Time", fontsize=13)
ax.set_xlabel("Problem dimension (n)", fontsize=12)
ax.set_ylabel("Execution time (ms)", fontsize=12)
ax.legend(fontsize=11)
# Log-log axes; the explicit ticks are set after the scale so that the
# benchmarked sizes are the labelled positions.
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid(True, alpha=0.3)
ax.set_xticks(problem_sizes)
ax.set_xticklabels(problem_sizes)

# -- Right panel: speedup --
ax = axes[1]
if not gpu_available:
    # Nothing to compare against: show a centred placeholder message.
    ax.text(
        0.5,
        0.5,
        "No GPU available",
        transform=ax.transAxes,
        fontsize=14,
        ha="center",
        va="center",
    )
else:
    speedups = cpu_times / gpu_times
    bar_labels = [str(s) for s in sizes]
    bars = ax.bar(
        range(len(sizes)),
        speedups,
        tick_label=bar_labels,
        color="#2196F3",
        edgecolor="#1565C0",
        alpha=0.8,
    )
    # A ratio of 1 means CPU and GPU took the same time.
    ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.7, label="Break-even")
    ax.set_ylabel("Speedup (CPU / GPU)", fontsize=12)
    ax.set_title("GPU Speedup Factor", fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3, axis="y")
    # Annotate bars
    for bar, sp in zip(bars, speedups):
        label_x = bar.get_x() + bar.get_width() / 2
        label_y = bar.get_height() + 0.05
        ax.text(
            label_x,
            label_y,
            f"{sp:.2f}x",
            ha="center",
            va="bottom",
            fontsize=11,
            fontweight="bold",
        )
ax.set_xlabel("Problem dimension (n)", fontsize=12)

plt.show()

## Large-Scale: 10,000-Dimensional Portfolio Allocation

We now solve a constrained optimization problem with **n = 10,000** decision
variables, modelling a simplified portfolio allocation over a universe of assets.

| Component | Formula |
|---|---|
| **Objective** | $\min_x\; \sum_i \sigma_i^2 x_i^2 - \sum_i \mu_i x_i$ |
| **Budget** (equality) | $\sum_i x_i = 1$ |
| **Sector caps** (inequality) | $\sum_{i \in S_k} x_i \le 0.15$ for each of $K{=}20$ sectors |

- $\sigma_i \in [0.05, 0.50]$ : per-asset volatility (random)
- $\mu_i \in [0.01, 0.15]$ : expected return (random)
- Each asset is randomly assigned to one of 20 sectors.

This gives **10,000 decision variables** with only
**21 constraints** (1 equality + 20 inequality), which keeps the QP
subproblem's active-set projection fast ($21 \times 21$ system) while
exercising the L-BFGS Hessian and projected CG at scale.

In [None]:
N_LARGE = 10_000
K_SECTORS = 20
SECTOR_CAP = 0.15  # max 15 % in any one sector

# --- Generate random market parameters ---
# A fixed seed makes the generated universe reproducible; it is split into
# three independent sub-keys, one per random quantity.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

asset_shape = (N_LARGE,)

# Sector index in [0, K_SECTORS) for every asset.
sector_ids = jax.random.randint(k1, shape=asset_shape, minval=0, maxval=K_SECTORS)
# Per-asset volatility drawn uniformly from [0.05, 0.50).
volatilities = jax.random.uniform(k2, shape=asset_shape, minval=0.05, maxval=0.50)
# Per-asset expected return drawn uniformly from [0.01, 0.15).
expected_returns = jax.random.uniform(
    k3, shape=asset_shape, minval=0.01, maxval=0.15
)


def portfolio_objective(x, args):
    """Return ``(variance - expected return, None)`` for allocation *x*.

    The variance term is ``sum(volatilities**2 * x**2)`` and the return term
    is ``sum(expected_returns * x)``. *args* is unused, and the second tuple
    element is an empty auxiliary output.
    """
    risk_term = jnp.sum(volatilities**2 * x**2)
    reward_term = jnp.sum(expected_returns * x)
    objective_value = risk_term - reward_term
    return objective_value, None


def budget_constraint(x, args):
    """Budget residual ``sum(x) - 1`` as a length-1 array (zero when fully invested).

    *args* is unused.
    """
    invested = jnp.sum(x)
    return jnp.array([invested - 1.0])


def sector_constraint(x, args):
    """Remaining headroom under the cap for each sector.

    Sums the allocations of the assets in each of the K_SECTORS sectors and
    returns ``SECTOR_CAP - sector_sum``, which is non-negative exactly when
    the sector respects its cap. *args* is unused.
    """
    # Scatter-add every asset's weight into the slot of its sector.
    per_sector_total = jnp.zeros(K_SECTORS).at[sector_ids].add(x)
    headroom = SECTOR_CAP - per_sector_total
    return headroom


# Solver for the large problem: one equality (budget) and K_SECTORS
# inequalities (sector caps).
solver_10k = SLSQP(
    eq_constraint_fn=budget_constraint,
    n_eq_constraints=1,
    ineq_constraint_fn=sector_constraint,
    n_ineq_constraints=K_SECTORS,
    rtol=1e-5,
    atol=1e-5,
    max_steps=500,
    lbfgs_memory=15,
    qp_max_iter=200,
    qp_max_cg_iter=100,
)

# Equal-weight starting point (feasible: each sector gets ~5 %)
x0_10k = jnp.ones(N_LARGE) / N_LARGE

# Prefer the GPU when one was detected, otherwise run on the CPU.
if gpu_available:
    target_device = gpu_device
    device_name = "GPU"
else:
    target_device = cpu_device
    device_name = "CPU"

print("Problem setup")
print(f"  Decision variables      : {N_LARGE:,}")
print("  Equality constraints    : 1")
print(f"  Inequality constraints  : {K_SECTORS}")
print(f"  Total constraints       : {K_SECTORS + 1}")
print(f"  Target device           : {device_name}")


@jax.jit
def solve_10k(x0):
    """Minimise the portfolio objective from *x0* and return the solution.

    The whole ``optx.minimise`` call is JIT-compiled. ``has_aux=True``
    matches the ``(value, aux)`` pair returned by ``portfolio_objective``.
    """
    solution = optx.minimise(
        portfolio_objective,
        solver_10k,
        x0,
        throw=False,
        max_steps=500,
        has_aux=True,
    )
    return solution.value


# Place the starting point on the chosen device before calling the solver.
x0_dev = jax.device_put(x0_10k, target_device)

# Compile
# The first call is timed end to end, so this figure covers JIT compilation
# plus one complete solve.
print("\nCompiling (this may take a minute or two) ...")
compile_start = time.perf_counter()
result_10k = solve_10k(x0_dev)
result_10k.block_until_ready()
compile_time = time.perf_counter() - compile_start
print(f"Compilation time: {compile_time:.1f} s")

# Timed runs
n_runs = 5
print(f"\nRunning {n_runs} timed solves ...")
times_10k = []
for _ in range(n_runs):
    run_start = time.perf_counter()
    result_10k = solve_10k(x0_dev)
    # Wait for the result so the full computation is inside the timed span.
    result_10k.block_until_ready()
    run_elapsed = time.perf_counter() - run_start
    times_10k.append(run_elapsed)

avg_10k = np.mean(times_10k)
std_10k = np.std(times_10k)
avg_ms = avg_10k * 1000
std_ms = std_10k * 1000
print(f"Average solve time ({device_name}): {avg_ms:.2f} \u00b1 {std_ms:.2f} ms")

In [None]:
# ---- Analyse the 10 K solution ---------------------------------------------
x_opt = np.asarray(result_10k)

# Weights whose magnitude is at or below this tolerance are treated as zero.
# The same tolerance drives the printed count and the histogram so the two
# reported "non-zero" figures always agree.
nonzero_tol = 1e-6

obj_val = float(portfolio_objective(result_10k, None)[0])
budget_viol = float(jnp.abs(jnp.sum(result_10k) - 1.0))
# Total weight per sector: scatter-add each asset's weight into its sector.
sector_allocs = np.asarray(
    jnp.zeros(K_SECTORS).at[np.asarray(sector_ids)].add(result_10k)
)
# Only overshoots above the cap count as violations; headroom clips to 0.
max_sector_viol = float(np.max(np.maximum(sector_allocs - SECTOR_CAP, 0.0)))
# The solver is only given the budget and sector-cap constraints, so weights
# are not bounded below and may be negative. Compare magnitudes so short
# positions are counted as non-zero too.
nonzero_mask = np.abs(x_opt) > nonzero_tol
n_active = int(np.sum(nonzero_mask))

print(f"{'=' * 60}")
print(f"{'10 K PORTFOLIO OPTIMISATION RESULTS':^60}")
print(f"{'=' * 60}")
print(f"  Objective value            : {obj_val:.6f}")
print(f"  Budget violation |sum-1|   : {budget_viol:.2e}")
print(f"  Max sector-cap violation   : {max_sector_viol:.2e}")
print(f"  Non-zero allocations       : {n_active:,} / {N_LARGE:,}")
print(f"  Total portfolio weight     : {float(jnp.sum(result_10k)):.6f}")
print(
    f"  Expected return            : {float(jnp.sum(expected_returns * result_10k)):.6f}"
)
print(
    f"  Portfolio variance         : {float(jnp.sum(volatilities**2 * result_10k**2)):.6f}"
)

# ---- Visualise -------------------------------------------------------------
fig, axes = plt.subplots(1, 3, figsize=(18, 5), layout="constrained")

# 1) Sector allocations vs caps
ax = axes[0]
sector_labels = [f"S{i}" for i in range(K_SECTORS)]
bars = ax.bar(
    sector_labels,
    sector_allocs,
    color="#4CAF50",
    alpha=0.85,
    edgecolor="#388E3C",
    label="Allocation",
)
ax.axhline(
    y=SECTOR_CAP,
    color="red",
    linestyle="--",
    linewidth=1.5,
    label=f"Cap ({SECTOR_CAP})",
)
ax.set_ylabel("Total sector weight", fontsize=12)
ax.set_title("Sector Allocations vs Cap", fontsize=13)
ax.legend(fontsize=10)
ax.tick_params(axis="x", rotation=45)
ax.grid(True, alpha=0.3, axis="y")

# 2) Allocation histogram
ax = axes[1]
# Same magnitude-based mask as the printed count, so the number in the title
# matches the "Non-zero allocations" line above.
nonzero = x_opt[nonzero_mask]
ax.hist(nonzero, bins=50, color="#2196F3", alpha=0.8, edgecolor="#1565C0")
ax.set_xlabel("Allocation weight", fontsize=12)
ax.set_ylabel("Count", fontsize=12)
ax.set_title(f"Distribution of Non-Zero Weights ({len(nonzero):,} assets)", fontsize=13)
ax.grid(True, alpha=0.3)

# 3) Return vs volatility scatter, coloured by allocation
ax = axes[2]
vol_np = np.asarray(volatilities)
ret_np = np.asarray(expected_returns)
sc = ax.scatter(vol_np, ret_np, c=x_opt, cmap="YlOrRd", s=1, alpha=0.6)
plt.colorbar(sc, ax=ax, label="Allocation weight")
ax.set_xlabel("Volatility (\u03c3)", fontsize=12)
ax.set_ylabel("Expected return (\u03bc)", fontsize=12)
ax.set_title("Asset Universe: Allocation by Risk-Return", fontsize=13)
ax.grid(True, alpha=0.3)

plt.show()