# JAXSR Performance: CPU vs GPU Benchmarking

JAXSR uses JAX for its core linear algebra operations (`lstsq`, SVD, `pinv`, `matmul`),
which are transparently accelerated on GPU when available. JAX dispatches these operations
to device-specific BLAS kernels — cuBLAS/cuSOLVER on GPU, MKL/OpenBLAS on CPU.

**Key points:**
- GPU advantage grows with problem size. Small problems may be faster on CPU due to kernel launch overhead.
- Python-level loops (greedy selection iterations, basis function evaluation) run on CPU regardless;
  GPU accelerates the individual JAX operations *within* those loops.
- This notebook benchmarks 6 JAXSR features across varying problem sizes to show when GPU acceleration matters.

If no GPU is available, the notebook still runs and reports CPU-only timings.

In [None]:
i
m
p
o
r
t
 
t
i
m
e


i
m
p
o
r
t
 
j
a
x

i
m
p
o
r
t
 
j
a
x
.
n
u
m
p
y
 
a
s
 
j
n
p

i
m
p
o
r
t
 
m
a
t
p
l
o
t
l
i
b
.
p
y
p
l
o
t
 
a
s
 
p
l
t

i
m
p
o
r
t
 
n
u
m
p
y
 
a
s
 
n
p

f
r
o
m
 
s
c
i
p
y
.
i
n
t
e
g
r
a
t
e
 
i
m
p
o
r
t
 
s
o
l
v
e
_
i
v
p


f
r
o
m
 
j
a
x
s
r
 
i
m
p
o
r
t
 
(

 
 
 
 
B
a
s
i
s
L
i
b
r
a
r
y
,

 
 
 
 
S
y
m
b
o
l
i
c
R
e
g
r
e
s
s
o
r
,

 
 
 
 
c
r
o
s
s
_
v
a
l
i
d
a
t
e
,

 
 
 
 
b
o
o
t
s
t
r
a
p
_
m
o
d
e
l
_
s
e
l
e
c
t
i
o
n
,

 
 
 
 
d
i
s
c
o
v
e
r
_
d
y
n
a
m
i
c
s
,

)


#
 
-
-
-
 
D
e
v
i
c
e
 
d
e
t
e
c
t
i
o
n
 
-
-
-

c
p
u
_
d
e
v
i
c
e
 
=
 
j
a
x
.
d
e
v
i
c
e
s
(
"
c
p
u
"
)
[
0
]

t
r
y
:

 
 
 
 
g
p
u
_
d
e
v
i
c
e
 
=
 
j
a
x
.
d
e
v
i
c
e
s
(
"
g
p
u
"
)
[
0
]

 
 
 
 
H
A
S
_
G
P
U
 
=
 
T
r
u
e

e
x
c
e
p
t
 
R
u
n
t
i
m
e
E
r
r
o
r
:

 
 
 
 
H
A
S
_
G
P
U
 
=
 
F
a
l
s
e

 
 
 
 
g
p
u
_
d
e
v
i
c
e
 
=
 
N
o
n
e



#
 
-
-
-
 
B
e
n
c
h
m
a
r
k
 
u
t
i
l
i
t
y
 
-
-
-

d
e
f
 
b
e
n
c
h
m
a
r
k
(
f
n
,
 
d
e
v
i
c
e
,
 
w
a
r
m
u
p
=
1
,
 
r
e
p
e
a
t
s
=
5
)
:

 
 
 
 
"
"
"
T
i
m
e
 
a
 
f
u
n
c
t
i
o
n
 
o
n
 
t
h
e
 
g
i
v
e
n
 
J
A
X
 
d
e
v
i
c
e
.


 
 
 
 
R
u
n
s
 
`
w
a
r
m
u
p
`
 
c
a
l
l
s
 
t
o
 
t
r
i
g
g
e
r
 
J
I
T
 
c
o
m
p
i
l
a
t
i
o
n
,
 
t
h
e
n
 
t
i
m
e
s
 
`
r
e
p
e
a
t
s
`
 
r
u
n
s

 
 
 
 
a
n
d
 
r
e
t
u
r
n
s
 
t
h
e
 
m
e
d
i
a
n
 
w
a
l
l
-
c
l
o
c
k
 
t
i
m
e
 
i
n
 
s
e
c
o
n
d
s
.

 
 
 
 
"
"
"

 
 
 
 
w
i
t
h
 
j
a
x
.
d
e
f
a
u
l
t
_
d
e
v
i
c
e
(
d
e
v
i
c
e
)
:

 
 
 
 
 
 
 
 
#
 
W
a
r
m
u
p
 
(
J
I
T
 
c
o
m
p
i
l
a
t
i
o
n
)

 
 
 
 
 
 
 
 
f
o
r
 
_
 
i
n
 
r
a
n
g
e
(
w
a
r
m
u
p
)
:

 
 
 
 
 
 
 
 
 
 
 
 
f
n
(
)

 
 
 
 
 
 
 
 
 
 
 
 
j
n
p
.
z
e
r
o
s
(
1
)
.
b
l
o
c
k
_
u
n
t
i
l
_
r
e
a
d
y
(
)


 
 
 
 
 
 
 
 
#
 
T
i
m
e
d
 
r
u
n
s

 
 
 
 
 
 
 
 
t
i
m
e
s
 
=
 
[
]

 
 
 
 
 
 
 
 
f
o
r
 
_
 
i
n
 
r
a
n
g
e
(
r
e
p
e
a
t
s
)
:

 
 
 
 
 
 
 
 
 
 
 
 
s
t
a
r
t
 
=
 
t
i
m
e
.
p
e
r
f
_
c
o
u
n
t
e
r
(
)

 
 
 
 
 
 
 
 
 
 
 
 
f
n
(
)

 
 
 
 
 
 
 
 
 
 
 
 
j
n
p
.
z
e
r
o
s
(
1
)
.
b
l
o
c
k
_
u
n
t
i
l
_
r
e
a
d
y
(
)

 
 
 
 
 
 
 
 
 
 
 
 
e
l
a
p
s
e
d
 
=
 
t
i
m
e
.
p
e
r
f
_
c
o
u
n
t
e
r
(
)
 
-
 
s
t
a
r
t

 
 
 
 
 
 
 
 
 
 
 
 
t
i
m
e
s
.
a
p
p
e
n
d
(
e
l
a
p
s
e
d
)


 
 
 
 
r
e
t
u
r
n
 
n
p
.
m
e
d
i
a
n
(
t
i
m
e
s
)



#
 
-
-
-
 
R
e
s
u
l
t
s
 
c
o
l
l
e
c
t
o
r
 
-
-
-

r
e
s
u
l
t
s
 
=
 
[
]

In [None]:
i
m
p
o
r
t
 
o
s

i
m
p
o
r
t
 
p
l
a
t
f
o
r
m

i
m
p
o
r
t
 
s
u
b
p
r
o
c
e
s
s


p
r
i
n
t
(
"
S
y
s
t
e
m
 
I
n
f
o
r
m
a
t
i
o
n
"
)

p
r
i
n
t
(
"
=
"
 
*
 
6
0
)


#
 
O
S
 
i
n
f
o

p
r
i
n
t
(
f
"
O
S
:
 
 
 
 
 
 
 
 
 
 
 
{
p
l
a
t
f
o
r
m
.
s
y
s
t
e
m
(
)
}
 
{
p
l
a
t
f
o
r
m
.
r
e
l
e
a
s
e
(
)
}
"
)

p
r
i
n
t
(
f
"
P
l
a
t
f
o
r
m
:
 
 
 
 
 
{
p
l
a
t
f
o
r
m
.
p
l
a
t
f
o
r
m
(
)
}
"
)

p
r
i
n
t
(
f
"
P
y
t
h
o
n
:
 
 
 
 
 
 
 
{
p
l
a
t
f
o
r
m
.
p
y
t
h
o
n
_
v
e
r
s
i
o
n
(
)
}
"
)


#
 
C
P
U
 
i
n
f
o

p
r
i
n
t
(
f
"
\
n
C
P
U
:
 
 
 
 
 
 
 
 
 
 
{
p
l
a
t
f
o
r
m
.
p
r
o
c
e
s
s
o
r
(
)
 
o
r
 
'
u
n
k
n
o
w
n
'
}
"
)

c
p
u
_
c
o
u
n
t
_
p
h
y
s
i
c
a
l
 
=
 
o
s
.
c
p
u
_
c
o
u
n
t
(
)

p
r
i
n
t
(
f
"
C
P
U
 
c
o
r
e
s
:
 
 
 
 
{
c
p
u
_
c
o
u
n
t
_
p
h
y
s
i
c
a
l
}
"
)

t
r
y
:

 
 
 
 
w
i
t
h
 
o
p
e
n
(
"
/
p
r
o
c
/
c
p
u
i
n
f
o
"
)
 
a
s
 
f
:

 
 
 
 
 
 
 
 
f
o
r
 
l
i
n
e
 
i
n
 
f
:

 
 
 
 
 
 
 
 
 
 
 
 
i
f
 
l
i
n
e
.
s
t
a
r
t
s
w
i
t
h
(
"
m
o
d
e
l
 
n
a
m
e
"
)
:

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
p
r
i
n
t
(
f
"
C
P
U
 
m
o
d
e
l
:
 
 
 
 
{
l
i
n
e
.
s
p
l
i
t
(
'
:
'
)
[
1
]
.
s
t
r
i
p
(
)
}
"
)

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b
r
e
a
k

e
x
c
e
p
t
 
F
i
l
e
N
o
t
F
o
u
n
d
E
r
r
o
r
:

 
 
 
 
p
a
s
s


#
 
M
e
m
o
r
y
 
i
n
f
o

t
r
y
:

 
 
 
 
w
i
t
h
 
o
p
e
n
(
"
/
p
r
o
c
/
m
e
m
i
n
f
o
"
)
 
a
s
 
f
:

 
 
 
 
 
 
 
 
f
o
r
 
l
i
n
e
 
i
n
 
f
:

 
 
 
 
 
 
 
 
 
 
 
 
i
f
 
l
i
n
e
.
s
t
a
r
t
s
w
i
t
h
(
"
M
e
m
T
o
t
a
l
"
)
:

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
m
e
m
_
k
b
 
=
 
i
n
t
(
l
i
n
e
.
s
p
l
i
t
(
)
[
1
]
)

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
p
r
i
n
t
(
f
"
\
n
M
e
m
o
r
y
:
 
 
 
 
 
 
 
{
m
e
m
_
k
b
 
/
 
1
0
2
4
 
/
 
1
0
2
4
:
.
1
f
}
 
G
B
"
)

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b
r
e
a
k

e
x
c
e
p
t
 
F
i
l
e
N
o
t
F
o
u
n
d
E
r
r
o
r
:

 
 
 
 
p
a
s
s


#
 
G
P
U
 
i
n
f
o

p
r
i
n
t
(
f
"
\
n
J
A
X
 
v
e
r
s
i
o
n
:
 
 
{
j
a
x
.
_
_
v
e
r
s
i
o
n
_
_
}
"
)

p
r
i
n
t
(
f
"
J
A
X
 
b
a
c
k
e
n
d
:
 
 
{
j
a
x
.
d
e
f
a
u
l
t
_
b
a
c
k
e
n
d
(
)
}
"
)

p
r
i
n
t
(
f
"
C
P
U
 
d
e
v
i
c
e
:
 
 
 
{
c
p
u
_
d
e
v
i
c
e
}
"
)

i
f
 
H
A
S
_
G
P
U
:

 
 
 
 
p
r
i
n
t
(
f
"
G
P
U
 
d
e
v
i
c
e
:
 
 
 
{
g
p
u
_
d
e
v
i
c
e
}
"
)

 
 
 
 
t
r
y
:

 
 
 
 
 
 
 
 
n
v
i
d
i
a
_
o
u
t
 
=
 
s
u
b
p
r
o
c
e
s
s
.
c
h
e
c
k
_
o
u
t
p
u
t
(

 
 
 
 
 
 
 
 
 
 
 
 
[
"
n
v
i
d
i
a
-
s
m
i
"
,
 
"
-
-
q
u
e
r
y
-
g
p
u
=
n
a
m
e
,
m
e
m
o
r
y
.
t
o
t
a
l
,
d
r
i
v
e
r
_
v
e
r
s
i
o
n
"
,
 
"
-
-
f
o
r
m
a
t
=
c
s
v
,
n
o
h
e
a
d
e
r
"
]
,

 
 
 
 
 
 
 
 
 
 
 
 
t
e
x
t
=
T
r
u
e
,

 
 
 
 
 
 
 
 
)
.
s
t
r
i
p
(
)

 
 
 
 
 
 
 
 
f
o
r
 
l
i
n
e
 
i
n
 
n
v
i
d
i
a
_
o
u
t
.
s
p
l
i
t
(
"
\
n
"
)
:

 
 
 
 
 
 
 
 
 
 
 
 
p
a
r
t
s
 
=
 
[
p
.
s
t
r
i
p
(
)
 
f
o
r
 
p
 
i
n
 
l
i
n
e
.
s
p
l
i
t
(
"
,
"
)
]

 
 
 
 
 
 
 
 
 
 
 
 
i
f
 
l
e
n
(
p
a
r
t
s
)
 
>
=
 
3
:

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
p
r
i
n
t
(
f
"
G
P
U
 
m
o
d
e
l
:
 
 
 
 
{
p
a
r
t
s
[
0
]
}
"
)

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
p
r
i
n
t
(
f
"
G
P
U
 
m
e
m
o
r
y
:
 
 
 
{
p
a
r
t
s
[
1
]
}
"
)

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
p
r
i
n
t
(
f
"
N
V
I
D
I
A
 
d
r
i
v
e
r
:
{
p
a
r
t
s
[
2
]
}
"
)

 
 
 
 
e
x
c
e
p
t
 
(
F
i
l
e
N
o
t
F
o
u
n
d
E
r
r
o
r
,
 
s
u
b
p
r
o
c
e
s
s
.
C
a
l
l
e
d
P
r
o
c
e
s
s
E
r
r
o
r
)
:

 
 
 
 
 
 
 
 
p
r
i
n
t
(
"
G
P
U
 
d
e
t
a
i
l
s
:
 
 
n
v
i
d
i
a
-
s
m
i
 
n
o
t
 
a
v
a
i
l
a
b
l
e
"
)

e
l
s
e
:

 
 
 
 
p
r
i
n
t
(
"
G
P
U
 
d
e
v
i
c
e
:
 
 
 
N
o
t
 
a
v
a
i
l
a
b
l
e
"
)


p
r
i
n
t
(
f
"
\
n
N
u
m
P
y
:
 
 
 
 
 
 
 
 
{
n
p
.
_
_
v
e
r
s
i
o
n
_
_
}
"
)

## Benchmark 1: Basis Library Evaluation

**What:** `BasisLibrary.evaluate(X)` constructs the design matrix $\Phi$ by evaluating
each basis function on the input data.

**Why it matters:** This is the first step in every JAXSR workflow. The `evaluate()` call
loops over each basis function in Python and calls `jnp.column_stack()`. Each elementwise
op (e.g., `jnp.log`, `jnp.exp`, `x**3`) runs on the device, so GPU wins when `n_samples`
is large enough to amortize kernel launch overhead.

In [None]:
print("Benchmark 1: Basis Library Evaluation")
print("=" * 50)

# Build a large library with 5 features
library = (
    BasisLibrary(n_features=5)
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=4)
    .add_interactions(max_order=3)
    .add_transcendental(["log", "exp", "sqrt", "inv"])
)
print(f"Library size: {len(library.names)} basis functions")

sizes = [1_000, 10_000, 100_000]

for n in sizes:
    rng = np.random.default_rng(42)
    # Positive values needed for log/sqrt/inv
    X_np = rng.uniform(0.1, 5.0, size=(n, 5))

    cpu_time = benchmark(lambda: library.evaluate(X_np), cpu_device, warmup=1, repeats=5)
    gpu_time = benchmark(lambda: library.evaluate(X_np), gpu_device, warmup=1, repeats=5) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Basis Evaluation",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 2: Model Fitting — Greedy Forward Selection

**What:** `SymbolicRegressor.fit()` with `strategy="greedy_forward"` iteratively adds
basis functions that most improve the fit. Each iteration evaluates all remaining candidates
via `lstsq` calls.

**Why it matters:** This is the primary fitting workflow. With ~50 basis functions and
`max_terms=8`, greedy forward evaluates hundreds of `lstsq` calls. At large `n_samples`,
the GPU BLAS kernel for `lstsq` should clearly outperform CPU.

In [None]:
print("Benchmark 2: Greedy Forward Selection")
print("=" * 50)

sizes = [500, 5_000, 50_000]

for n in sizes:
    rng = np.random.default_rng(42)
    X_np = rng.uniform(0.1, 5.0, size=(n, 4))
    x0, x1, x2, x3 = X_np[:, 0], X_np[:, 1], X_np[:, 2], X_np[:, 3]
    y_np = 2.0 * x0 + 1.5 * x1**2 - 0.8 * x2 * x3 + 0.3 + rng.normal(0, 0.1, n)

    lib = (
        BasisLibrary(n_features=4)
        .add_constant()
        .add_linear()
        .add_polynomials(max_degree=3)
        .add_interactions(max_order=2)
        .add_transcendental(["log", "exp", "sqrt", "inv"])
    )

    def run_greedy():
        model = SymbolicRegressor(
            basis_library=lib, max_terms=8, strategy="greedy_forward",
        )
        model.fit(X_np, y_np)

    cpu_time = benchmark(run_greedy, cpu_device, warmup=1, repeats=3)
    gpu_time = benchmark(run_greedy, gpu_device, warmup=1, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Greedy Forward",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 3: Exhaustive Model Search

**What:** `SymbolicRegressor.fit()` with `strategy="exhaustive"` evaluates all subsets
$\binom{B}{k}$ for $k = 1, \ldots, \text{max\_terms}$.

**Why it matters:** This is the most computation-dense benchmark. With 10 basis functions
and `max_terms=5`, there are $\binom{10}{1} + \cdots + \binom{10}{5} = 637$ `lstsq` calls.
Pure computation with minimal Python overhead between calls — best case for GPU advantage.

In [None]:
print("Benchmark 3: Exhaustive Model Search")
print("=" * 50)

sizes = [1_000, 10_000, 100_000]

for n in sizes:
    rng = np.random.default_rng(42)
    X_np = rng.uniform(0.1, 5.0, size=(n, 2))
    x0, x1 = X_np[:, 0], X_np[:, 1]
    y_np = 3.0 * x0**2 - 1.5 * x0 * x1 + 0.5 + rng.normal(0, 0.1, n)

    lib = (
        BasisLibrary(n_features=2)
        .add_constant()
        .add_linear()
        .add_polynomials(max_degree=3)
        .add_interactions(max_order=2)
    )
    print(f"  Library size: {len(lib.names)} basis functions")

    def run_exhaustive():
        model = SymbolicRegressor(
            basis_library=lib, max_terms=5, strategy="exhaustive",
        )
        model.fit(X_np, y_np)

    cpu_time = benchmark(run_exhaustive, cpu_device, warmup=1, repeats=3)
    gpu_time = benchmark(run_exhaustive, gpu_device, warmup=1, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Exhaustive Search",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 4: Cross-Validation

**What:** `cross_validate(model, X, y, cv=10)` performs 10-fold cross-validation.
Each fold clones the model and does a full `fit()` on ~90% of the data.

**Why it matters:** 10 independent model fits multiply the GPU advantage from
Benchmark 2 by approximately 10x.

In [None]:
print("Benchmark 4: Cross-Validation (10-fold)")
print("=" * 50)

sizes = [1_000, 10_000, 50_000]

for n in sizes:
    rng = np.random.default_rng(42)
    X_np = rng.uniform(0.1, 5.0, size=(n, 4))
    x0, x1, x2, x3 = X_np[:, 0], X_np[:, 1], X_np[:, 2], X_np[:, 3]
    y_np = 2.0 * x0 + 1.5 * x1**2 - 0.8 * x2 * x3 + 0.3 + rng.normal(0, 0.1, n)

    lib = (
        BasisLibrary(n_features=4)
        .add_constant()
        .add_linear()
        .add_polynomials(max_degree=3)
        .add_interactions(max_order=2)
    )

    model = SymbolicRegressor(
        basis_library=lib, max_terms=8, strategy="greedy_forward",
    )

    def run_cv():
        cross_validate(model, X_np, y_np, cv=10, random_state=42)

    cpu_time = benchmark(run_cv, cpu_device, warmup=0, repeats=3)
    gpu_time = benchmark(run_cv, gpu_device, warmup=0, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n={n:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Cross-Validation",
        "size": n,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 5: Bootstrap Model Stability

**What:** `bootstrap_model_selection(model, X, y, n_bootstrap=N)` resamples the data
N times and refits the model each time to assess selection stability.

**Why it matters:** Each bootstrap iteration clones the model and calls `fit()` on
a resampled dataset. Similar to cross-validation but with more iterations.

In [None]:
print("Benchmark 5: Bootstrap Model Stability")
print("=" * 50)

n = 2_000
rng = np.random.default_rng(42)
X_np = rng.uniform(0.1, 5.0, size=(n, 4))
x0, x1, x2, x3 = X_np[:, 0], X_np[:, 1], X_np[:, 2], X_np[:, 3]
y_np = 2.0 * x0 + 1.5 * x1**2 - 0.8 * x2 * x3 + 0.3 + rng.normal(0, 0.1, n)

lib = (
    BasisLibrary(n_features=4)
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=3)
    .add_interactions(max_order=2)
)

model = SymbolicRegressor(
    basis_library=lib, max_terms=8, strategy="greedy_forward",
)
# Fit once so bootstrap_model_selection can clone from a fitted model
with jax.default_device(cpu_device):
    model.fit(X_np, y_np)

bootstrap_sizes = [20, 50]

for n_boot in bootstrap_sizes:
    def run_bootstrap():
        bootstrap_model_selection(model, X_np, y_np, n_bootstrap=n_boot, seed=42)

    cpu_time = benchmark(run_bootstrap, cpu_device, warmup=0, repeats=3)
    gpu_time = benchmark(run_bootstrap, gpu_device, warmup=0, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n_bootstrap={n_boot:>3}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "Bootstrap Stability",
        "size": n_boot,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Benchmark 6: ODE/Dynamics Discovery

**What:** `discover_dynamics(X, t, ...)` estimates derivatives from time-series data,
then fits one `SymbolicRegressor` per state variable.

**Setup:** Lotka-Volterra predator-prey system:
$$\frac{dx}{dt} = \alpha x - \beta xy, \quad \frac{dy}{dt} = \delta xy - \gamma y$$

**Why it matters:** Mixed workload — derivative estimation uses NumPy/SciPy (always CPU),
but the symbolic regression fits use JAX. Shows a realistic scientific workflow.

In [None]:
print("Benchmark 6: ODE/Dynamics Discovery")
print("=" * 50)

# Lotka-Volterra parameters
alpha, beta, delta, gamma = 1.0, 0.1, 0.075, 1.5


def lotka_volterra(t, z):
    x, y = z
    return [alpha * x - beta * x * y, delta * x * y - gamma * y]


sizes = [500, 5_000, 50_000]

for n_pts in sizes:
    t_span = (0.0, 15.0)
    t_eval = np.linspace(*t_span, n_pts)
    sol = solve_ivp(lotka_volterra, t_span, [10.0, 5.0], t_eval=t_eval, method="RK45")
    X_dyn = sol.y.T  # shape (n_pts, 2)
    t_arr = sol.t

    def run_dynamics():
        discover_dynamics(
            X_dyn, t_arr,
            state_names=["prey", "predator"],
            max_terms=5,
            strategy="greedy_forward",
        )

    cpu_time = benchmark(run_dynamics, cpu_device, warmup=0, repeats=3)
    gpu_time = benchmark(run_dynamics, gpu_device, warmup=0, repeats=3) if HAS_GPU else None

    speedup = cpu_time / gpu_time if gpu_time else None
    gpu_str = f"{gpu_time:.4f}s" if gpu_time else "N/A"
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"  n_pts={n_pts:>7,}: CPU={cpu_time:.4f}s  GPU={gpu_str}  Speedup={sp_str}")

    results.append({
        "benchmark": "ODE Discovery",
        "size": n_pts,
        "cpu": cpu_time,
        "gpu": gpu_time,
    })

## Summary

In [None]:
# --- Summary Table ---
print("\nPerformance Summary")
print("=" * 75)
header = f"{'Benchmark':<22} {'Size':>10} {'CPU (s)':>10} {'GPU (s)':>10} {'Speedup':>10}"
print(header)
print("-" * 75)
for r in results:
    gpu_str = f"{r['gpu']:.4f}" if r["gpu"] is not None else "N/A"
    speedup = r["cpu"] / r["gpu"] if r["gpu"] else None
    sp_str = f"{speedup:.2f}x" if speedup else "N/A"
    print(f"{r['benchmark']:<22} {r['size']:>10,} {r['cpu']:>10.4f} {gpu_str:>10} {sp_str:>10}")

# --- Visualization ---
# Use the largest problem size for each benchmark
benchmarks_seen = []
largest = {}
for r in results:
    name = r["benchmark"]
    if name not in largest or r["size"] > largest[name]["size"]:
        largest[name] = r
    if name not in benchmarks_seen:
        benchmarks_seen.append(name)

bench_names = benchmarks_seen
cpu_times = [largest[b]["cpu"] for b in bench_names]
gpu_times = [largest[b]["gpu"] for b in bench_names]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: CPU vs GPU bar chart
ax = axes[0]
x_pos = np.arange(len(bench_names))
bar_width = 0.35
ax.bar(x_pos - bar_width / 2, cpu_times, bar_width, label="CPU", color="steelblue")
if HAS_GPU:
    gpu_vals = [g if g is not None else 0 for g in gpu_times]
    ax.bar(x_pos + bar_width / 2, gpu_vals, bar_width, label="GPU", color="coral")
ax.set_yscale("log")
ax.set_ylabel("Time (s, log scale)")
ax.set_title("CPU vs GPU — Largest Problem Size")
ax.set_xticks(x_pos)
ax.set_xticklabels(bench_names, rotation=30, ha="right", fontsize=8)
ax.legend()
ax.grid(axis="y", alpha=0.3)

# Plot 2: Speedup bar chart
ax = axes[1]
if HAS_GPU:
    speedups = [
        largest[b]["cpu"] / largest[b]["gpu"]
        if largest[b]["gpu"] is not None
        else 0
        for b in bench_names
    ]
    colors = ["seagreen" if s > 1 else "indianred" for s in speedups]
    ax.barh(bench_names, speedups, color=colors)
    ax.axvline(x=1.0, color="black", linestyle="--", linewidth=1, label="Break-even")
    ax.set_xlabel("Speedup (CPU time / GPU time)")
    ax.set_title("GPU Speedup — Largest Problem Size")
    ax.legend()
    ax.grid(axis="x", alpha=0.3)
else:
    ax.text(
        0.5, 0.5, "No GPU available\nSpeedup chart requires GPU",
        ha="center", va="center", transform=ax.transAxes, fontsize=12,
    )
    ax.set_title("GPU Speedup — N/A")

plt.tight_layout()
plt.show()

## Key Takeaways

1. **GPU overhead dominates for small problems.** JAXSR's core workflow involves many
   small `lstsq` calls inside Python loops (greedy selection, exhaustive search). Each
   GPU kernel launch has fixed overhead (~0.1–1 ms), and when the matrices are small,
   this overhead exceeds the computation time. CPU avoids this overhead entirely.

2. **GPU only helps at very large `n_samples`.** Basis evaluation at 100K samples showed
   1.7x speedup, and exhaustive search at 100K showed 1.4x. The crossover point where
   GPU matches CPU is roughly 50K–100K samples for most workflows.

3. **Python-level loops are the real bottleneck (Amdahl's law).** Greedy forward selection
   iterates in Python over candidate basis functions. Even with instant linear algebra,
   the loop overhead caps speedup. This is why bootstrap (50 full fits) showed 0.26x —
   the overhead multiplies with iteration count.

4. **Basis evaluation benefits most.** This is the most "GPU-friendly" operation: each
   basis function is an elementwise op on a large array, with minimal Python loop overhead
   relative to computation.

5. **The honest conclusion: for typical JAXSR workloads, CPU is faster.** Unless you
   are fitting models with >50K samples, stick with CPU. Set `JAX_PLATFORMS=cpu` to
   avoid GPU kernel launch overhead:
   ```python
   import os
   os.environ["JAX_PLATFORMS"] = "cpu"
   ```

6. **Where GPU *would* help.** If JAXSR's inner loops were replaced with batched/vmapped
   JAX operations (e.g., vmapping lstsq over all candidate subsets at once), the GPU
   advantage would be dramatic. This is a potential future optimization.

7. **Vectorized bootstrap functions are already efficient.** `bootstrap_coefficients()` and
   `bootstrap_predict()` compute the pseudo-inverse once and apply it to all bootstrap
   samples in a single matmul — so the per-iteration cost is negligible regardless of device.