# JAXSR Performance: CPU vs GPU Benchmarking

JAXSR uses JAX for its core linear algebra operations (`lstsq`, SVD, `pinv`, `matmul`),
which are transparently accelerated on GPU when available. JAX dispatches these operations
to device-specific BLAS kernels — cuBLAS/cuSOLVER on GPU, MKL/OpenBLAS on CPU.

**Key points:**
- GPU advantage grows with problem size. Small problems may be faster on CPU due to kernel launch overhead.
- Python-level loops (greedy selection iterations, basis function evaluation) run on CPU regardless;
  GPU accelerates the individual JAX operations *within* those loops.
- This notebook benchmarks 6 JAXSR features across varying problem sizes to show when GPU acceleration matters.

If no GPU is available, the notebook still runs and reports CPU-only timings.
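"Transparently accelerated" comes down to placement: in plain JAX, an operation runs on the device its inputs are committed to. The sketch below uses only JAX itself, not JAXSR, and commits inputs with `jax.device_put`. The helper name `lstsq_on` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def lstsq_on(device, Phi, y):
    """Solve min ||Phi @ c - y|| on `device` (illustrative helper)."""
    # Commit both inputs to the chosen device; the solve then runs there
    Phi_d = jax.device_put(Phi, device)
    y_d = jax.device_put(y, device)
    coef, *_ = jnp.linalg.lstsq(Phi_d, y_d)
    return coef


rng = np.random.default_rng(0)
Phi = rng.normal(size=(1_000, 4)).astype(np.float32)
y = Phi @ np.array([1.0, -2.0, 0.5, 3.0], dtype=np.float32)

cpu = jax.devices("cpu")[0]
coef = lstsq_on(cpu, Phi, y)
print(coef.devices())  # the set of devices holding the result
```

Swapping `cpu` for `jax.devices("gpu")[0]` runs the same solve on the GPU, provided one is present.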

In [None]:
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

from jaxsr import (
    BasisLibrary,
    SymbolicRegressor,
    cross_validate,
    bootstrap_model_selection,
    discover_dynamics,
)

# --- Device detection ---
cpu_device = jax.devices("cpu")[0]
try:
    gpu_device = jax.devices("gpu")[0]
    HAS_GPU = True
except RuntimeError:
    HAS_GPU = False
    gpu_device = None


# --- Benchmark utility ---
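def benchmark(fn, device, warmup=1, repeats=5):
    """Time a function on the given JAX device.

    Runs `warmup` untimed calls to trigger JIT compilation, then times
    `repeats` runs and returns the median wall-clock time in seconds.
    JAX dispatches work asynchronously, so each run waits for any JAX
    arrays returned by `fn` before the timer stops.
    """

    def run_and_sync():
        # Block on fn's own outputs (a no-op if fn returns None or non-arrays)
        jax.block_until_ready(fn())
        # Coarse barrier (not a strict guarantee) for functions such as fit()
        # that return nothing: this tiny op usually completes only after
        # earlier work queued on the same device
        jnp.zeros(1).block_until_ready()

    with jax.default_device(device):
        # Warmup (JIT compilation)
        for _ in range(warmup):
            run_and_sync()

        # Timed runs
        times = []
        for _ in range(repeats):
            start = time.perf_counter()
            run_and_sync()
            times.append(time.perf_counter() - start)

    return float(np.median(times))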


# --- Results collector ---
results = []

In [None]:
import os
import platform
import subprocess

print("System Information")
print("=" * 60)

# OS info
print(f"OS:           {platform.system()} {platform.release()}")
print(f"Platform:     {platform.platform()}")
print(f"Python:       {platform.python_version()}")

# CPU info
print(f"\nCPU:          {platform.processor() or 'unknown'}")
cpu_count_logical = os.cpu_count()  # logical CPUs (includes SMT/hyperthreads)
print(f"Logical CPUs: {cpu_count_logical}")
try:
    with open("/proc/cpuinfo") as f:
        for line in f:
            if line.startswith("model name"):
                print(f"CPU model:    {line.split(':')[1].strip()}")
                break
except FileNotFoundError:
    pass

# Memory info
try:
    with open("/proc/meminfo") as f:
        for line in f:
            if line.startswith("MemTotal"):
                mem_kb = int(line.split()[1])
                print(f"\nMemory:       {mem_kb / 1024 / 1024:.1f} GiB")
                break
except FileNotFoundError:
    pass

# GPU info
print(f"\nJAX version:  {jax.__version__}")
print(f"JAX backend:  {jax.default_backend()}")
print(f"CPU device:   {cpu_device}")
if HAS_GPU:
    print(f"GPU device:   {gpu_device}")
    try:
        nvidia_out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=name,memory.total,driver_version", "--format=csv,noheader"],
            text=True,
        ).strip()
        for line in nvidia_out.split("\n"):
            parts = [p.strip() for p in line.split(",")]
            if len(parts) >= 3:
                print(f"GPU model:    {parts[0]}")
                print(f"GPU memory:   {parts[1]}")
                print(f"NVIDIA driver: {parts[2]}")
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("GPU details:  nvidia-smi not available")
else:
    print("GPU device:   Not available")

print(f"\nNumPy:        {np.__version__}")

## Benchmark 1: Basis Library Evaluation

**What:** `BasisLibrary.evaluate(X)` constructs the design matrix $\Phi$ by evaluating
each basis function on the input data.

**Why it matters:** This is the first step in every JAXSR workflow. The `evaluate()` call
loops over the basis functions in Python and combines the resulting columns with
`jnp.column_stack()`. Each elementwise op (e.g., `jnp.log`, `jnp.exp`, `x**3`) runs on the
device, so the GPU can win once `n_samples` is large enough to amortize kernel launch
overhead. Because `X_np` is a NumPy array, the GPU timings also include copying it to device memory.
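Here is a stripped-down sketch of that pattern. It assumes basis functions are plain callables on the columns of `X`; the list `basis_fns` and the helper names are hypothetical and do not reflect JAXSR's internals. The eager version dispatches one operation per column. Wrapping the same code in `jax.jit` lets XLA compile all columns into a single program, which is one way to reduce per-op overhead.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical basis functions for 2 input features
basis_fns = [
    lambda X: jnp.ones(X.shape[0]),   # constant
    lambda X: X[:, 0],                # x0
    lambda X: X[:, 1],                # x1
    lambda X: X[:, 0] ** 3,           # x0^3
    lambda X: X[:, 0] * X[:, 1],      # x0*x1
    lambda X: jnp.log(X[:, 1]),       # log(x1)
]


def evaluate_eager(X):
    """Build the design matrix column by column (one dispatch per op)."""
    X = jnp.asarray(X)
    return jnp.column_stack([f(X) for f in basis_fns])


# Same Python code, traced once and compiled into a single XLA program
evaluate_jit = jax.jit(evaluate_eager)

X = np.random.default_rng(0).uniform(0.1, 5.0, size=(10_000, 2))
Phi = evaluate_jit(X)
print(Phi.shape)  # (10000, 6)
```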

In [None]:
print("Benchmark 1: Basis Library Evaluation")
print("=" * 50)

# Build a large library with 5 features
library = (
    BasisLibrary(n_features=5)
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=4)
    .add_interactions(max_order=3)
    .add_transcendental(["log", "exp", "sqrt", "inv"])
)
print(f"Library size: {len(library.names)} basis functions")

sizes = [1_000, 10_000, 100_000]

for n in sizes:
    rng = np.random.default_rng(42)
    # Positive values needed for log/sqrt/inv
    X_np = rng.uniform(0.1, 5.0, size=(n, 5))

    cpu_time = benchmark(lambda: library.evaluate(X_np), cpu_device, warmup=1, repeats=5)
    gpu_time = benchmark(lambda: library.evaluate(X_np), gpu_device, warmup=1, repeats=5) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Basis Evaluation",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 2: Model Fitting — Greedy Forward Selection

**What:** `SymbolicRegressor.fit()` with `strategy="greedy_forward"` iteratively adds
basis functions that most improve the fit. Each iteration evaluates all remaining candidates
via `lstsq` calls.

**Why it matters:** This is the primary fitting workflow. With a few dozen basis functions and
`max_terms=8`, greedy forward selection issues hundreds of `lstsq` calls. At large `n_samples`,
the GPU `lstsq` kernel is expected to outperform the CPU one, provided per-call overhead does not dominate.
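The sketch below shows greedy forward selection on a precomputed design matrix `Phi`. It scores candidates by residual sum of squares, which is an assumption: JAXSR may use an information criterion instead, and `greedy_forward` is a hypothetical name. Each outer step issues one small `lstsq` per remaining candidate, and each score is pulled back to the host so Python can compare it.

```python
import jax.numpy as jnp
import numpy as np


def greedy_forward(Phi, y, max_terms):
    """Return column indices chosen greedily by residual sum of squares."""
    Phi = jnp.asarray(Phi)
    y = jnp.asarray(y)
    selected = []
    n_calls = 0
    for _ in range(max_terms):
        best_j, best_sse = None, np.inf
        for j in range(Phi.shape[1]):
            if j in selected:
                continue
            A = Phi[:, jnp.array(selected + [j])]  # current terms plus candidate j
            coef, *_ = jnp.linalg.lstsq(A, y)       # one small device call
            n_calls += 1
            sse = float(jnp.sum((y - A @ coef) ** 2))  # float() forces a host sync
            if sse < best_sse:
                best_j, best_sse = j, sse
        selected.append(best_j)
    return selected, n_calls


rng = np.random.default_rng(0)
Phi = rng.normal(size=(500, 10))
y = 2.0 * Phi[:, 3] - 1.0 * Phi[:, 7]
terms, n_calls = greedy_forward(Phi, y, max_terms=2)
print(sorted(terms), n_calls)  # [3, 7] 19
```

The host sync inside the loop means the GPU sits idle while Python decides the next candidate, even when each solve is fast.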

In [None]:
print("Benchmark 2: Greedy Forward Selection")
print("=" * 50)

sizes = [500, 5_000, 50_000]

for n in sizes:
    rng = np.random.default_rng(42)
    X_np = rng.uniform(0.1, 5.0, size=(n, 4))
    x0, x1, x2, x3 = X_np[:, 0], X_np[:, 1], X_np[:, 2], X_np[:, 3]
    y_np = 2.0 * x0 + 1.5 * x1**2 - 0.8 * x2 * x3 + 0.3 + rng.normal(0, 0.1, n)

    lib = (
        BasisLibrary(n_features=4)
        .add_constant()
        .add_linear()
        .add_polynomials(max_degree=3)
        .add_interactions(max_order=2)
        .add_transcendental(["log", "exp", "sqrt", "inv"])
    )

    def run_greedy():
        model = SymbolicRegressor(
            basis_library=lib, max_terms=8, strategy="greedy_forward",
        )
        model.fit(X_np, y_np)

    cpu_time = benchmark(run_greedy, cpu_device, warmup=1, repeats=3)
    gpu_time = benchmark(run_greedy, gpu_device, warmup=1, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Greedy Forward",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 3: Exhaustive Model Search

**What:** `SymbolicRegressor.fit()` with `strategy="exhaustive"` evaluates every subset of the
$B$ basis functions, i.e. $\binom{B}{k}$ subsets for each $k = 1, \ldots, \text{max\_terms}$.

**Why it matters:** This is the most computation-dense benchmark. For example, with $B = 10$ basis
functions and `max_terms=5` there are $\binom{10}{1} + \cdots + \binom{10}{5} = 637$ `lstsq` calls
(the cell prints the actual library size). The calls run back to back with minimal Python overhead
between them, which is the best case for a GPU advantage.
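The call counts for both strategies can be checked with a few lines of standard-library Python. The greedy count assumes one `lstsq` per remaining candidate per step, as described for Benchmark 2. The helper names are illustrative.

```python
from math import comb


def exhaustive_calls(n_basis, max_terms):
    """Number of subsets of size 1..max_terms, i.e. lstsq calls in exhaustive search."""
    return sum(comb(n_basis, k) for k in range(1, max_terms + 1))


def greedy_calls(n_basis, max_terms):
    """lstsq calls if step k scores each of the n_basis - k + 1 remaining candidates."""
    return sum(n_basis - k + 1 for k in range(1, max_terms + 1))


print(exhaustive_calls(10, 5))  # 637
print(greedy_calls(10, 5))      # 40
print(exhaustive_calls(20, 5))  # 21699
```

The exhaustive count grows combinatorially with library size, while the greedy count grows only linearly.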

In [None]:
print("Benchmark 3: Exhaustive Model Search")
print("=" * 50)

sizes = [1_000, 10_000, 100_000]

for n in sizes:
    rng = np.random.default_rng(42)
    X_np = rng.uniform(0.1, 5.0, size=(n, 2))
    x0, x1 = X_np[:, 0], X_np[:, 1]
    y_np = 3.0 * x0**2 - 1.5 * x0 * x1 + 0.5 + rng.normal(0, 0.1, n)

    lib = (
        BasisLibrary(n_features=2)
        .add_constant()
        .add_linear()
        .add_polynomials(max_degree=3)
        .add_interactions(max_order=2)
    )
    print(f"  Library size: {len(lib.names)} basis functions")

    def run_exhaustive():
        model = SymbolicRegressor(
            basis_library=lib, max_terms=5, strategy="exhaustive",
        )
        model.fit(X_np, y_np)

    cpu_time = benchmark(run_exhaustive, cpu_device, warmup=1, repeats=3)
    gpu_time = benchmark(run_exhaustive, gpu_device, warmup=1, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Exhaustive Search",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 4: Cross-Validation

**What:** `cross_validate(model, X, y, cv=10)` performs 10-fold cross-validation.
Each fold clones the model and does a full `fit()` on ~90% of the data.

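**Why it matters:** Ten independent model fits multiply the total work by roughly 10x, so whatever
per-fit gain or loss the GPU showed in Benchmark 2 is multiplied by about 10 in absolute time; the
speedup ratio itself should stay similar.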
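The phrase "10 independent fits on ~90% of the data" can be made concrete with a minimal NumPy sketch of shuffled k-fold splitting. This is not JAXSR's `cross_validate` implementation, and the function name `kfold_indices` is hypothetical.

```python
import numpy as np


def kfold_indices(n_samples, cv=10, seed=42):
    """Yield (train_idx, test_idx) pairs for shuffled k-fold cross-validation."""
    perm = np.random.default_rng(seed).permutation(n_samples)
    folds = np.array_split(perm, cv)  # cv nearly equal test folds
    for i in range(cv):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(cv) if j != i])
        yield train_idx, test_idx


splits = list(kfold_indices(1_000, cv=10))
print(len(splits), len(splits[0][0]), len(splits[0][1]))  # 10 900 100
```

Every sample lands in exactly one test fold, and each of the 10 fits sees the other 90%.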

In [None]:
print("Benchmark 4: Cross-Validation (10-fold)")
print("=" * 50)

sizes = [1_000, 10_000, 50_000]

for n in sizes:
    rng = np.random.default_rng(42)
    X_np = rng.uniform(0.1, 5.0, size=(n, 4))
    x0, x1, x2, x3 = X_np[:, 0], X_np[:, 1], X_np[:, 2], X_np[:, 3]
    y_np = 2.0 * x0 + 1.5 * x1**2 - 0.8 * x2 * x3 + 0.3 + rng.normal(0, 0.1, n)

    lib = (
        BasisLibrary(n_features=4)
        .add_constant()
        .add_linear()
        .add_polynomials(max_degree=3)
        .add_interactions(max_order=2)
    )

    model = SymbolicRegressor(
        basis_library=lib, max_terms=8, strategy="greedy_forward",
    )

    def run_cv():
        cross_validate(model, X_np, y_np, cv=10, random_state=42)

    cpu_time = benchmark(run_cv, cpu_device, warmup=0, repeats=3)
    gpu_time = benchmark(run_cv, gpu_device, warmup=0, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Cross-Validation",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 5: Bootstrap Model Stability

**What:** `bootstrap_model_selection(model, X, y, n_bootstrap=N)` resamples the data
N times and refits the model each time to assess selection stability.

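**Why it matters:** Each bootstrap iteration clones the model and calls `fit()` on
a resampled dataset. This is similar to cross-validation but with more iterations (20 and 50
refits here versus 10 folds). The dataset size is fixed at 2,000 samples, so the "Size" column
in the summary reports `n_bootstrap` for this benchmark.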
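This is a minimal sketch of the resample, refit, and count loop. To stay self-contained, it swaps in a deliberately simple selector that keeps the `k` columns with the largest absolute least-squares coefficients. JAXSR's actual selection runs the full `fit()` instead. `bootstrap_selection_frequency` and `top_k_by_coef` are hypothetical names.

```python
import jax.numpy as jnp
import numpy as np


def top_k_by_coef(Phi, y, k):
    """Stand-in selector: indices of the k largest |coefficients| from lstsq."""
    coef, *_ = jnp.linalg.lstsq(jnp.asarray(Phi), jnp.asarray(y))
    return np.argsort(-np.abs(np.asarray(coef)))[:k]


def bootstrap_selection_frequency(Phi, y, k, n_bootstrap=20, seed=42):
    """Fraction of bootstrap refits in which each column is selected."""
    rng = np.random.default_rng(seed)
    n_samples, n_basis = Phi.shape
    counts = np.zeros(n_basis)
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n_samples, size=n_samples)  # resample rows with replacement
        counts[top_k_by_coef(Phi[idx], y[idx], k)] += 1   # one full refit per resample
    return counts / n_bootstrap


rng = np.random.default_rng(0)
Phi = rng.normal(size=(2_000, 6))
y = 2.0 * Phi[:, 0] - 1.5 * Phi[:, 4] + rng.normal(0, 0.1, 2_000)
freq = bootstrap_selection_frequency(Phi, y, k=2)
print(np.round(freq, 2))
```

Because the design matrix changes with every resample, each iteration needs its own solve. This is why the cost scales with `n_bootstrap`.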

In [None]:
print("Benchmark 5: Bootstrap Model Stability")
print("=" * 50)

n = 2_000
rng = np.random.default_rng(42)
X_np = rng.uniform(0.1, 5.0, size=(n, 4))
x0, x1, x2, x3 = X_np[:, 0], X_np[:, 1], X_np[:, 2], X_np[:, 3]
y_np = 2.0 * x0 + 1.5 * x1**2 - 0.8 * x2 * x3 + 0.3 + rng.normal(0, 0.1, n)

lib = (
    BasisLibrary(n_features=4)
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=3)
    .add_interactions(max_order=2)
)

model = SymbolicRegressor(
    basis_library=lib, max_terms=8, strategy="greedy_forward",
)
# Fit once so bootstrap_model_selection can clone from a fitted model
with jax.default_device(cpu_device):
    model.fit(X_np, y_np)

bootstrap_sizes = [20, 50]

for n_boot in bootstrap_sizes:
    def run_bootstrap():
        bootstrap_model_selection(model, X_np, y_np, n_bootstrap=n_boot, seed=42)

    cpu_time = benchmark(run_bootstrap, cpu_device, warmup=0, repeats=3)
    gpu_time = benchmark(run_bootstrap, gpu_device, warmup=0, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n_bootstrap={n_boot:>3}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Bootstrap Stability",
        "size": n_boot,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 6: ODE/Dynamics Discovery

**What:** `discover_dynamics(X, t, ...)` estimates derivatives from time-series data,
then fits one `SymbolicRegressor` per state variable.

**Setup:** Lotka-Volterra predator-prey system:
$$\frac{dx}{dt} = \alpha x - \beta xy, \quad \frac{dy}{dt} = \delta xy - \gamma y$$

**Why it matters:** Mixed workload — derivative estimation uses NumPy/SciPy (always CPU),
but the symbolic regression fits use JAX. Shows a realistic scientific workflow.
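The sketch below makes the two stages concrete using NumPy and SciPy. It uses central finite differences (`np.gradient`) in place of whatever derivative estimator `discover_dynamics` uses. It also regresses each derivative on the fixed library $[x, y, xy]$ rather than running a symbolic search. With clean, densely sampled data, it should recover $\alpha, \beta, \delta, \gamma$ closely.

```python
import numpy as np
from scipy.integrate import solve_ivp

alpha, beta, delta, gamma = 1.0, 0.1, 0.075, 1.5


def lotka_volterra(t, z):
    x, y = z
    return [alpha * x - beta * x * y, delta * x * y - gamma * y]


# Stage 1 (CPU): simulate, then estimate derivatives by finite differences
t = np.linspace(0.0, 15.0, 2_000)
sol = solve_ivp(lotka_volterra, (0.0, 15.0), [10.0, 5.0], t_eval=t, rtol=1e-9, atol=1e-9)
Z = sol.y.T                                    # shape (2000, 2)
dZ = np.gradient(Z, t, axis=0, edge_order=2)   # estimated dx/dt, dy/dt

# Stage 2: one least-squares fit per state variable on the library [x, y, x*y]
x, y = Z[:, 0], Z[:, 1]
Phi = np.column_stack([x, y, x * y])
coef_x, *_ = np.linalg.lstsq(Phi, dZ[:, 0], rcond=None)
coef_y, *_ = np.linalg.lstsq(Phi, dZ[:, 1], rcond=None)

print(np.round(coef_x, 3))  # approximately [alpha, 0, -beta]
print(np.round(coef_y, 3))  # approximately [0, -gamma, delta]
```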

In [None]:
print("Benchmark 6: ODE/Dynamics Discovery")
print("=" * 50)

# Lotka-Volterra parameters
alpha, beta, delta, gamma = 1.0, 0.1, 0.075, 1.5


def lotka_volterra(t, z):
    x, y = z
    return [alpha * x - beta * x * y, delta * x * y - gamma * y]


sizes = [500, 5_000, 50_000]

for n_pts in sizes:
    t_span = (0.0, 15.0)
    t_eval = np.linspace(*t_span, n_pts)
    sol = solve_ivp(lotka_volterra, t_span, [10.0, 5.0], t_eval=t_eval, method="RK45")
    X_dyn = sol.y.T  # shape (n_pts, 2)
    t_arr = sol.t

    def run_dynamics():
        discover_dynamics(
            X_dyn, t_arr,
            state_names=["prey", "predator"],
            max_terms=5,
            strategy="greedy_forward",
        )

    cpu_time = benchmark(run_dynamics, cpu_device, warmup=0, repeats=3)
    gpu_time = benchmark(run_dynamics, gpu_device, warmup=0, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n_pts={n_pts:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "ODE Discovery",
        "size": n_pts,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Summary

In [None]:
# --- Summary Table ---
print("\nPerformance Summary")
print("=" * 75)
header = f"{'Benchmark':<22} {'Size':>10} {'CPU (s)':>10} {'GPU (s)':>10} {'Speedup':>10}"
print(header)
print("-" * 75)
for r in results:
    gpu_str = f"{r['gpu']:.4f}" if r["gpu"] is not None else "N/A"
    speedup = r["cpu"] / r["gpu"] if r["gpu"] else None
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"{r['benchmark']:<22} {r['size']:>10,} {r['cpu']:>10.4f} {gpu_str:>10} {sp_str:>10}")

# --- Visualization ---
# Use the largest problem size for each benchmark
benchmarks_seen = []
largest = {}
for r in results:
    name = r["benchmark"]
    if name not in largest or r["size"] > largest[name]["size"]:
        largest[name] = r
    if name not in benchmarks_seen:
        benchmarks_seen.append(name)

bench_names = benchmarks_seen
cpu_times = [largest[b]["cpu"] for b in bench_names]
gpu_times = [largest[b]["gpu"] for b in bench_names]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: CPU vs GPU bar chart
ax = axes[0]
x_pos = np.arange(len(bench_names))
bar_width = 0.35
ax.bar(x_pos - bar_width / 2, cpu_times, bar_width, label="CPU", color="steelblue")
if HAS_GPU:
    gpu_vals = [g if g is not None else 0 for g in gpu_times]
    ax.bar(x_pos + bar_width / 2, gpu_vals, bar_width, label="GPU", color="coral")
ax.set_yscale("log")
ax.set_ylabel("Time (s, log scale)")
ax.set_title("CPU vs GPU — Largest Problem Size")
ax.set_xticks(x_pos)
ax.set_xticklabels(bench_names, rotation=30, ha="right", fontsize=8)
ax.legend()
ax.grid(axis="y", alpha=0.3)

# Plot 2: Speedup bar chart
ax = axes[1]
if HAS_GPU:
    speedups = [
        largest[b]["cpu"] / largest[b]["gpu"]
        if largest[b]["gpu"] is not None
        else 0
        for b in bench_names
    ]
    colors = ["seagreen" if s > 1 else "indianred" for s in speedups]
    ax.barh(bench_names, speedups, color=colors)
    ax.axvline(x=1.0, color="black", linestyle="--", linewidth=1, label="Break-even")
    ax.set_xlabel("Speedup (CPU time / GPU time)")
    ax.set_title("GPU Speedup — Largest Problem Size")
    ax.legend()
    ax.grid(axis="x", alpha=0.3)
else:
    ax.text(
        0.5, 0.5, "No GPU available\nSpeedup chart requires GPU",
        ha="center", va="center", transform=ax.transAxes, fontsize=12,
    )
    ax.set_title("GPU Speedup — N/A")

plt.tight_layout()
plt.show()

## Key Takeaways

1. **GPU overhead dominates for small problems.** JAXSR's core workflow involves many
   small `lstsq` calls inside Python loops (greedy selection, exhaustive search). Each
   call dispatched to the GPU pays a fixed cost for dispatch and kernel launches
   (~0.1–1 ms), and when the matrices are small, this overhead exceeds the computation
   time. The CPU backend has much lower per-call overhead.

2. **GPU only helps at very large `n_samples`.** In the run these takeaways summarize,
   basis evaluation at 100K samples showed a 1.7x speedup and exhaustive search at 100K
   showed 1.4x. The crossover point where GPU matches CPU was roughly 50K–100K samples
   for most workflows; the exact point depends on the CPU and GPU models.

3. **Python-level loops are the real bottleneck (Amdahl's law).** Greedy forward selection
   iterates in Python over candidate basis functions. Even with instant linear algebra,
   the serial loop overhead caps the achievable speedup. When the matrices are also small,
   per-call GPU overhead makes each fit slower than on CPU: bootstrap at 2,000 samples
   showed 0.26x, and running 50 full fits multiplies the absolute time lost.

4. **Basis evaluation benefits most.** This is the most "GPU-friendly" operation: each
   basis function is an elementwise op on a large array, with minimal Python loop overhead
   relative to computation.

5. **The honest conclusion: for typical JAXSR workloads, CPU is faster.** Unless you
   are fitting models with >50K samples, stick with CPU. Set `JAX_PLATFORMS=cpu` to
   avoid GPU kernel launch overhead; the variable must be set before JAX is imported:
   ```python
   import os

   # Must run before `import jax` (or `import jaxsr`, which imports JAX)
   os.environ["JAX_PLATFORMS"] = "cpu"
   ```

6. **Where GPU *would* help.** If JAXSR's inner loops were replaced with batched/vmapped
   JAX operations (e.g., vmapping `lstsq` over all candidate subsets of the same size at
   once), the GPU advantage would likely be much larger. This is a potential future
   optimization.

7. **Vectorized bootstrap functions are already efficient.** `bootstrap_coefficients()` and
   `bootstrap_predict()` compute the pseudo-inverse once and apply it to all bootstrap
   samples in a single matmul — so the per-iteration cost is negligible regardless of device.
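Takeaway 3's Amdahl argument can be made quantitative. Suppose a fraction $p$ of a workflow's runtime is device-accelerable linear algebra, sped up by a factor $s$, and the rest (Python loops, host-side bookkeeping) is unchanged. Then the overall speedup is $S = 1 / ((1 - p) + p / s)$. The fractions below are illustrative inputs, not measurements from this notebook.

```python
def amdahl_speedup(p, s):
    """Overall speedup when a fraction p of runtime is accelerated by factor s."""
    return 1.0 / ((1.0 - p) + p / s)


# Illustrative: half the runtime is linear algebra that becomes 10x faster
print(round(amdahl_speedup(0.5, 10.0), 3))   # 1.818
# Upper bound as s grows without limit: 1 / (1 - p)
print(round(amdahl_speedup(0.5, 1e12), 3))   # 2.0
```

Amdahl's law alone never predicts a slowdown when $s \ge 1$. A measured speedup below 1, such as the bootstrap result, corresponds to $s < 1$ for those small operations: per-call overhead outweighs the faster kernels.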
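Takeaway 6's batched alternative can be sketched in plain JAX. Every column subset of one size $k$ is gathered into a stacked array, and `jax.vmap` maps a least-squares solve over it, so one compiled call scores all $\binom{B}{k}$ candidates. This is a hypothetical design, not JAXSR's current code, and `batched_sse` and `best_subset_of_size` are illustrative names. Subsets of different sizes need separate calls. The gathered array has shape `(n_subsets, n_samples, k)`, so memory grows with both the subset count and the sample count.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def batched_sse(Phi, y, subsets):
    """Residual sum of squares for every column subset in one compiled call.

    Phi: (n_samples, n_basis); y: (n_samples,);
    subsets: (n_subsets, k) integer column indices, all of the same size k.
    """
    def sse_one(idx):
        A = Phi[:, idx]                    # (n_samples, k) sub-design matrix
        coef, *_ = jnp.linalg.lstsq(A, y)
        r = y - A @ coef
        return jnp.sum(r * r)

    # vmap turns the per-subset solve into one batched computation
    return jax.vmap(sse_one)(subsets)


def best_subset_of_size(Phi, y, k):
    subsets = jnp.array(list(itertools.combinations(range(Phi.shape[1]), k)))
    sse = batched_sse(Phi, y, subsets)
    best = int(jnp.argmin(sse))
    return tuple(int(i) for i in subsets[best]), float(sse[best])


rng = np.random.default_rng(0)
Phi = rng.normal(size=(200, 6)).astype(np.float32)
y = 2.0 * Phi[:, 1] - 0.5 * Phi[:, 3]
print(best_subset_of_size(Phi, y, 2))  # expected: subset (1, 3) with SSE close to 0
```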
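Takeaway 7's "one pseudo-inverse, one matmul" pattern works when the design matrix stays fixed across resamples, as in a residual bootstrap. The sketch below assumes that variant and is not taken from JAXSR's source; `residual_bootstrap_coefs` is a hypothetical name. Residuals are resampled and added back to the fitted values. All pseudo-responses are then solved at once as the columns of a single matrix.

```python
import jax
import jax.numpy as jnp
import numpy as np


def residual_bootstrap_coefs(Phi, y, n_boot, key):
    """Bootstrap coefficients with one pinv and one matmul (residual bootstrap)."""
    Phi = jnp.asarray(Phi)
    y = jnp.asarray(y)
    P = jnp.linalg.pinv(Phi)                          # (n_basis, n_samples), computed once
    coef = P @ y
    fitted = Phi @ coef
    resid = y - fitted
    n = y.shape[0]
    idx = jax.random.randint(key, (n_boot, n), 0, n)  # resampled residual indices
    Y_boot = fitted[:, None] + resid[idx].T           # (n_samples, n_boot) pseudo-responses
    return (P @ Y_boot).T                             # (n_boot, n_basis), one matmul for all


rng = np.random.default_rng(0)
Phi = rng.normal(size=(500, 3))
y = Phi @ np.array([1.0, -2.0, 0.5]) + rng.normal(0, 0.1, 500)
coefs = residual_bootstrap_coefs(Phi, y, n_boot=200, key=jax.random.PRNGKey(0))
print(coefs.shape)                      # (200, 3)
print(np.round(coefs.std(axis=0), 3))   # bootstrap standard errors
```

Row resampling, as in `bootstrap_model_selection`, changes `Phi` on every iteration. It therefore cannot reuse a single pseudo-inverse this way.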