# Performance Optimization: GPU and JIT Compilation

Benchmark and optimize rheological computations using JAX GPU acceleration and just-in-time (JIT) compilation.

## Learning Objectives
- Benchmark CPU vs GPU performance for rheological computations
- Understand JAX JIT compilation benefits and overhead
- Profile code to identify bottlenecks
- Optimize memory usage for large datasets
- Scale to 10K+ data points efficiently

## Prerequisites
- JAX basics (safe imports, vmap)
- Model fitting experience

**Estimated Time:** 40-45 minutes

In [None]:
import numpy as np
import time
import matplotlib.pyplot as plt
from rheo.models.maxwell import Maxwell
from rheo.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()

print('Device info:')
print(f'  JAX devices: {jax.devices()}')
print(f'  Default backend: {jax.default_backend()}')

## JIT Compilation Overhead vs Speedup

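`jax.jit` traces and compiles a function the first time it is called with a given combination of input shapes and dtypes, so that first call is slow. Later calls with the same shapes and dtypes reuse the cached compiled code and are fast.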
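
The cache is keyed on argument shapes and dtypes, so a new input shape triggers a fresh trace and compile. A minimal sketch that makes this visible: the Python-level `traced_shapes.append(...)` runs only while JAX traces the function, never on cached calls. The function name `sum_of_squares` is illustrative.

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # one entry per trace, i.e. per compilation

@jax.jit
def sum_of_squares(x):
    # Python side effects execute only during tracing, not on cached calls
    traced_shapes.append(x.shape)
    return jnp.sum(x ** 2)

sum_of_squares(jnp.ones(100))  # shape (100,): trace + compile
sum_of_squares(jnp.ones(100))  # same shape and dtype: cached, no trace
sum_of_squares(jnp.ones(200))  # new shape (200,): retrace + recompile

print(traced_shapes)  # [(100,), (200,)]
```

In a fitting loop, keeping array shapes fixed (for example by padding datasets to a common length) avoids paying the compile cost repeatedly.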

In [None]:
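@jax.jit
def compute_residuals(params, t, G_data):
    """Sum of squared residuals for a Maxwell relaxation modulus G(t) = G0*exp(-t/tau), tau = eta/G0."""
    G0, eta = params
    G_pred = G0 * jnp.exp(-t * G0 / eta)
    return jnp.sum((G_pred - G_data)**2)

# Synthetic data with G0 = 1, eta = 1 (tau = 1)
t = jnp.linspace(0.1, 10, 1000)
G_data = jnp.exp(-t)
params = jnp.array([1.0, 1.0])

# First call: trace + compile + execute.
# JAX dispatches work asynchronously, so block_until_ready() is needed
# to time the computation rather than just its dispatch.
start = time.perf_counter()
compute_residuals(params, t, G_data).block_until_ready()
time_compile = time.perf_counter() - start

# Subsequent calls with the same shapes/dtypes reuse the cached executable
start = time.perf_counter()
for _ in range(100):
    compute_residuals(params, t, G_data).block_until_ready()
time_cached = (time.perf_counter() - start) / 100

print(f'First call (compile + run): {time_compile*1000:.2f} ms')
print(f'Cached calls: {time_cached*1000:.3f} ms')
print(f'First call / cached call: {time_compile/time_cached:.0f}x')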
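
The ratio printed above compares the first call with cached calls; it does not measure how much faster the jitted function is than the same code run without `jax.jit`. A sketch of that comparison, using the same residual (named `maxwell_sse` here) and a hypothetical `mean_call_time` helper that warms up once and blocks on every result:

```python
import time
import jax
import jax.numpy as jnp

def maxwell_sse(params, t, G_data):
    """Sum of squared residuals for G(t) = G0 * exp(-t * G0 / eta)."""
    G0, eta = params[0], params[1]
    G_pred = G0 * jnp.exp(-t * G0 / eta)
    return jnp.sum((G_pred - G_data) ** 2)

maxwell_sse_jit = jax.jit(maxwell_sse)

def mean_call_time(fn, *args, n_repeat=100):
    """Average wall time per call, excluding the first (warm-up/compile) call."""
    fn(*args).block_until_ready()  # warm-up: triggers compilation if fn is jitted
    start = time.perf_counter()
    for _ in range(n_repeat):
        fn(*args).block_until_ready()  # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / n_repeat

t = jnp.linspace(0.1, 10, 1000)
G_data = jnp.exp(-t)
params = jnp.array([1.0, 1.0])

time_eager = mean_call_time(maxwell_sse, params, t, G_data)
time_jit = mean_call_time(maxwell_sse_jit, params, t, G_data)
print(f'eager: {time_eager*1e6:.1f} µs/call, jit: {time_jit*1e6:.1f} µs/call')
print(f'jit speedup over eager: {time_eager/time_jit:.1f}x')
```

In eager mode each `jnp` operation is dispatched separately from Python, while the jitted version runs as a single compiled XLA computation; the resulting ratio depends on hardware and array size.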

## CPU vs GPU Benchmark

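Time a full `Maxwell` fit for increasing dataset sizes. The cell runs on JAX's default backend (the GPU if one is available, otherwise the CPU), so run it once per backend to compare them; setting the environment variable `JAX_PLATFORMS=cpu` before JAX is imported forces the CPU backend.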

In [None]:
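# Time Maxwell.fit on the default JAX backend for increasing dataset sizes
rng = np.random.default_rng(0)  # fixed seed for reproducible noise
sizes = [1000, 10000, 100000]
for N in sizes:
    t_large = np.logspace(-2, 2, N)
    G_large = np.exp(-t_large) + rng.normal(0, 0.01, N)

    # Warm-up fit: if fit is JIT-compiled internally, each new N recompiles,
    # so compile once per size and time only the second fit
    Maxwell().fit(t_large, G_large)

    model = Maxwell()
    start = time.perf_counter()
    model.fit(t_large, G_large)
    time_fit = time.perf_counter() - start

    print(f'N={N:6d}: {time_fit:.3f}s ({time_fit/N*1e6:.1f}µs/point)')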
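
The cell above times a full fit on whichever backend JAX selected. To compare devices within one session, arrays can be committed to a specific device with `jax.device_put`, and jitted computations on those arrays then run on that device. A sketch that times only the residual evaluation (not the full `Maxwell.fit`) on the CPU and, if present, the first GPU; `time_on_device` is a hypothetical helper:

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def maxwell_sse(params, t, G_data):
    G0, eta = params[0], params[1]
    G_pred = G0 * jnp.exp(-t * G0 / eta)
    return jnp.sum((G_pred - G_data) ** 2)

def time_on_device(device, N, n_repeat=20):
    """Mean time per residual evaluation with all inputs committed to `device`."""
    rng = np.random.default_rng(0)
    t_np = np.logspace(-2, 2, N)
    G_np = np.exp(-t_np) + rng.normal(0, 0.01, N)
    # device_put commits the arrays, so the jitted call executes on `device`
    t_dev = jax.device_put(t_np, device)
    G_dev = jax.device_put(G_np, device)
    p_dev = jax.device_put(np.array([1.0, 1.0]), device)
    maxwell_sse(p_dev, t_dev, G_dev).block_until_ready()  # compile for this shape/device
    start = time.perf_counter()
    for _ in range(n_repeat):
        maxwell_sse(p_dev, t_dev, G_dev).block_until_ready()
    return (time.perf_counter() - start) / n_repeat

devices = {'cpu': jax.devices('cpu')[0]}
gpus = [d for d in jax.devices() if d.platform == 'gpu']
if gpus:
    devices['gpu'] = gpus[0]

results = {}
for N in [1000, 10000, 100000]:
    for name, device in devices.items():
        results[(name, N)] = time_on_device(device, N)
        print(f'{name}  N={N:6d}: {results[(name, N)]*1e6:8.1f} µs/call')
```

These timings exclude the host-to-device transfer, which a real fit pays once per dataset.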

## Scaling Analysis

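**Observations** (indicative; actual numbers depend on hardware, JAX version, and problem size):
- JAX on CPU: roughly 2-10x faster than equivalent NumPy code
- JAX on GPU: an additional 5-10x speedup, mainly for N > 10K
- JIT compilation: on the order of 100 ms one-time overhead per input shape, then cached
- Memory: O(N) for data, O(P) for the P model parameters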
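
The O(N) data term is easy to check with `.nbytes`. A short sketch for a 100K-point grid like the one benchmarked above: storing it as float32 instead of float64 halves the footprint, at the cost of precision (roughly 7 versus 16 significant decimal digits). JAX works in float32 by default unless 64-bit mode (`jax_enable_x64`) is turned on.

```python
import numpy as np
import jax.numpy as jnp

N = 100_000
t64 = np.logspace(-2, 2, N)      # NumPy defaults to float64: 8 bytes per value
t32 = t64.astype(np.float32)     # float32: 4 bytes per value

print(t64.nbytes, t32.nbytes)    # 800000 400000

# A JAX array built from float32 data keeps float32 and the same footprint
t_jax = jnp.asarray(t32)
print(t_jax.dtype, t_jax.nbytes)  # float32 400000
```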

**Best Practices:**
1. Use a GPU for N > 10K data points; for smaller arrays, kernel-launch and host-device transfer overhead can outweigh the gain
2. Use JIT for repeated operations
3. Use vmap for batch processing
4. Profile before optimizing
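
Practices 2 and 3 combine naturally: `jax.vmap` turns a residual for one parameter set into one that evaluates a whole batch, and `jax.jit` compiles the batched version once. A sketch that scores a hypothetical grid of 11 `G0` values against synthetic data generated with G0 = eta = 1:

```python
import jax
import jax.numpy as jnp

def maxwell_sse(params, t, G_data):
    G0, eta = params[0], params[1]
    G_pred = G0 * jnp.exp(-t * G0 / eta)
    return jnp.sum((G_pred - G_data) ** 2)

# Map over the leading axis of params only; t and G_data are shared (in_axes=None)
batched_sse = jax.jit(jax.vmap(maxwell_sse, in_axes=(0, None, None)))

t = jnp.linspace(0.1, 10, 1000)
G_data = jnp.exp(-t)  # synthetic data: G0 = 1, eta = 1

G0_grid = jnp.linspace(0.5, 1.5, 11)
param_batch = jnp.stack([G0_grid, jnp.ones_like(G0_grid)], axis=1)  # shape (11, 2)

sse = batched_sse(param_batch, t, G_data)  # shape (11,), one SSE per candidate
best = param_batch[jnp.argmin(sse)]
print(best)  # G0 close to 1.0, eta = 1.0
```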

## Key Takeaways

- **JIT Compilation:** one-time compile cost per input shape; cached calls are often 10-100x faster than the first call
- **GPU Acceleration:** typically a further 5-10x speedup for large datasets (hardware-dependent)
- **Scaling:** GPU benefit increases with dataset size
- **Best Practices:** Profile first, optimize bottlenecks
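
A first profiling step is to separate one-time compilation cost from per-call execution cost. A sketch using JAX's ahead-of-time API (`jax.jit(f).lower(...)` followed by `.compile()`) on the same Maxwell residual; for op-level detail, `jax.profiler` can record traces instead.

```python
import time
import jax
import jax.numpy as jnp

def maxwell_sse(params, t, G_data):
    G0, eta = params[0], params[1]
    G_pred = G0 * jnp.exp(-t * G0 / eta)
    return jnp.sum((G_pred - G_data) ** 2)

t = jnp.linspace(0.1, 10, 100_000)
G_data = jnp.exp(-t)
params = jnp.array([1.0, 1.0])

start = time.perf_counter()
lowered = jax.jit(maxwell_sse).lower(params, t, G_data)  # trace + lower (StableHLO in recent JAX)
time_lower = time.perf_counter() - start

start = time.perf_counter()
compiled = lowered.compile()                              # XLA compilation
time_compile = time.perf_counter() - start

compiled(params, t, G_data).block_until_ready()           # first execution
start = time.perf_counter()
for _ in range(100):
    result = compiled(params, t, G_data)
    result.block_until_ready()                            # wait for async dispatch
time_run = (time.perf_counter() - start) / 100

print(f'trace/lower: {time_lower*1e3:.2f} ms')
print(f'compile:     {time_compile*1e3:.2f} ms')
print(f'execute:     {time_run*1e3:.3f} ms/call')
```

If compile time dominates, the fix is usually fewer distinct input shapes; if execution dominates, larger batches or a GPU are the levers.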

## Next Steps
- **GPU Installation:** See CLAUDE.md for CUDA setup (Linux only)
- **[02-batch-processing.ipynb](02-batch-processing.ipynb):** Apply these optimizations to batch processing
- **[01-multi-technique-fitting.ipynb](01-multi-technique-fitting.ipynb):** Optimize multi-technique fitting