# JAX Performance Demonstration

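JAX is a high-performance numerical computing library, widely used for machine learning, that provides:

- **Just-in-time (JIT) compilation** via the XLA compiler, which can fuse chains of array operations into fewer kernels
- **Automatic vectorization** with `vmap`, which turns a function written for one example into one that processes a whole batch
- **GPU/TPU acceleration** with the same code
- **Automatic differentiation** with `grad` and related transformations
- **Composable function transformations**: `jit`, `vmap`, and `grad` can be nested freely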

This notebook demonstrates JAX's performance advantages over NumPy for numerical computations through JIT compilation and vectorization.

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
from jax import jit, vmap
import time
import warnings
warnings.filterwarnings('ignore')

# Check JAX backend and devices
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")
print(f"JAX version: {jax.__version__}")

# Set random seeds for reproducibility
np.random.seed(42)
key = jax.random.PRNGKey(42)

## Array Creation and Basic Operations

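JAX arrays (`jax.Array`; older releases called them `DeviceArray`) are placed on the default device: a GPU or TPU when one is available, otherwise the CPU. We'll compare array operations between NumPy and JAX, building both from the same data: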

In [None]:
# Create large arrays for performance testing
n = 10_000_000  # 10 million elements
print(f"Array size: {n:,} elements")

# Create NumPy arrays first with consistent seed
np.random.seed(42)
np_a = np.random.randn(n).astype(np.float32)
np_b = np.random.randn(n).astype(np.float32)
np_c = np.random.randn(n).astype(np.float32)

# Create JAX arrays using the SAME data for fair comparison
jax_a = jnp.array(np_a)  # Same data as NumPy
jax_b = jnp.array(np_b)  # Same data as NumPy  
jax_c = jnp.array(np_c)  # Same data as NumPy

print(f"NumPy array type: {type(np_a)}")
print(f"JAX array type: {type(jax_a)}")
print(f"JAX array device: {jax_a.device}")

# Verify arrays contain identical data
print(f"Arrays contain same data: {np.array_equal(np_a, jax_a)}")
print(f"Data precision: {np_a.dtype} (NumPy) vs {jax_a.dtype} (JAX)")

## JIT Compilation Benefits

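JAX's key advantage comes from JIT compilation, but expectations should be realistic. The figures below are rough rules of thumb that vary widely with hardware, array size, and the operations involved:

**Expected performance gains:**
- **Roughly 2-10x speedup** for compute-heavy chains of operations
- **Larger gains** on GPU/TPU than on CPU
- **Best performance** only after warmup, since the first call pays for compilation
- **Diminishing returns** for single, simple operations, which NumPy already runs in optimized C

**Important caveats:**
- The first call with a new input shape or dtype includes tracing and compilation time (slower)
- JAX dispatches work asynchronously, so timings must call `block_until_ready()`
- GPU kernels have launch overhead
- Memory transfers (CPU ↔ GPU) cost time
- JAX defaults to float32 while NumPy defaults to float64, so precision differs unless 64-bit mode is enabled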
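
One way to separate the one-time compilation cost from the per-call execution cost is JAX's ahead-of-time API, `jax.jit(f).lower(...).compile()`. The sketch below assumes a recent JAX release that provides this API; the function `f` and the array size are illustrative, not taken from the benchmark.

```python
import time

import jax
import jax.numpy as jnp


def f(x):
    # Illustrative elementwise expression (hypothetical, not the benchmark function)
    return jnp.tanh(x) * jnp.exp(-x**2)


x = jnp.linspace(-3.0, 3.0, 1_000_000, dtype=jnp.float32)

# Trace + lower to XLA, then compile: the one-time cost normally hidden in the first jit call
t0 = time.perf_counter()
compiled = jax.jit(f).lower(x).compile()
compile_s = time.perf_counter() - t0

# Running the compiled executable skips tracing and compilation entirely
compiled(x).block_until_ready()  # first execution may still include allocation overhead
t0 = time.perf_counter()
out = compiled(x).block_until_ready()
run_s = time.perf_counter() - t0

print(f"compile: {compile_s * 1e3:.1f} ms, run: {run_s * 1e3:.2f} ms")
```

The compile figure is paid once per input shape and dtype, while the run figure is what every later call costs.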

In [None]:
# Define computation functions
def numpy_computation(a, b, c):
    """
    Standard NumPy computation - interpreted Python with vectorized operations.
    Each operation is dispatched individually to optimized C code.
    """
    # Complex mathematical expression
    result = np.tanh(a) * np.exp(-b**2) + np.sin(c) * np.cos(a)
    result = result + np.sqrt(np.abs(a * b)) - np.log1p(np.abs(c))
    return result

def jax_computation(a, b, c):
    """
    JAX computation - can be JIT compiled for optimization.
    Operations are fused and optimized by XLA compiler.
    """
    # Same mathematical expression as NumPy version
    result = jnp.tanh(a) * jnp.exp(-b**2) + jnp.sin(c) * jnp.cos(a)
    result = result + jnp.sqrt(jnp.abs(a * b)) - jnp.log1p(jnp.abs(c))
    return result

# Create JIT-compiled version
jax_computation_jit = jit(jax_computation)

print("Functions defined - ready for benchmarking")
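
To see what `jit` actually hands to XLA, `jax.make_jaxpr` prints the traced program (a "jaxpr") for given input shapes. A minimal sketch that re-declares the same expression so it runs standalone; the tiny input is arbitrary, because tracing depends only on shape and dtype, not on values.

```python
import jax
import jax.numpy as jnp


def jax_computation(a, b, c):
    # Same expression as jax_computation in this notebook
    result = jnp.tanh(a) * jnp.exp(-b**2) + jnp.sin(c) * jnp.cos(a)
    result = result + jnp.sqrt(jnp.abs(a * b)) - jnp.log1p(jnp.abs(c))
    return result


x = jnp.ones(4, dtype=jnp.float32)
jaxpr = jax.make_jaxpr(jax_computation)(x, x, x)

# One equation per primitive (tanh, exp, sin, cos, sqrt, log1p, ...); under jit,
# XLA is free to fuse these elementwise operations into a few kernels
print(jaxpr)
```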

In [None]:
# PROPER BENCHMARKING with synchronization and warmup

print("Starting proper benchmark with identical input data...")

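# Warmup: the first JIT call traces and compiles, and eager JAX also compiles
# and caches one kernel per primitive on first use, so warm up both paths
print("Warming up JAX (JIT and eager)...")
for _ in range(3):
    _ = jax_computation_jit(jax_a, jax_b, jax_c).block_until_ready()
    _ = jax_computation(jax_a, jax_b, jax_c).block_until_ready()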

# Multiple runs for stable timing
n_runs = 5
numpy_times = []
jax_times = []
jax_jit_times = []

print(f"Running {n_runs} iterations for stable timing...")

for i in range(n_runs):
    # Benchmark NumPy
    start = time.perf_counter()
    numpy_result = numpy_computation(np_a, np_b, np_c)
    numpy_time = time.perf_counter() - start
    numpy_times.append(numpy_time)
    
    # Benchmark JAX without JIT
    start = time.perf_counter()
    jax_result = jax_computation(jax_a, jax_b, jax_c)
    jax_result.block_until_ready()  # Critical: JAX dispatches asynchronously, so wait for the result
    jax_time = time.perf_counter() - start
    jax_times.append(jax_time)

    # Benchmark JAX with JIT (already compiled)
    start = time.perf_counter()
    jax_jit_result = jax_computation_jit(jax_a, jax_b, jax_c)
    jax_jit_result.block_until_ready()  # Critical: wait for the asynchronous computation to finish
    jax_jit_time = time.perf_counter() - start
    jax_jit_times.append(jax_jit_time)

# Calculate statistics
numpy_avg = np.mean(numpy_times)
numpy_std = np.std(numpy_times)
jax_avg = np.mean(jax_times)
jax_std = np.std(jax_times)
jax_jit_avg = np.mean(jax_jit_times)
jax_jit_std = np.std(jax_jit_times)

print(f"\nPerformance Results (average ± std over {n_runs} runs):")
print(f"NumPy time:           {numpy_avg:.4f}s ± {numpy_std:.4f}s")
print(f"JAX time (no JIT):    {jax_avg:.4f}s ± {jax_std:.4f}s")
print(f"JAX time (JIT):       {jax_jit_avg:.4f}s ± {jax_jit_std:.4f}s")

print("\nRealistic Speedup Analysis:")
speedup_jax = numpy_avg / jax_avg if jax_avg > 0 else 0
speedup_jit = numpy_avg / jax_jit_avg if jax_jit_avg > 0 else 0
jit_improvement = jax_avg / jax_jit_avg if jax_jit_avg > 0 else 0

print(f"JAX vs NumPy (no JIT):  {speedup_jax:.1f}x")
print(f"JAX vs NumPy (JIT):     {speedup_jit:.1f}x")
print(f"JIT improvement:        {jit_improvement:.1f}x")

# Verify numerical accuracy. XLA and NumPy use different float32 implementations
# of tanh/exp/sin/etc., so small rounding differences are expected; the tolerances
# are set accordingly rather than demanding bit-for-bit agreement.
max_diff = np.max(np.abs(numpy_result - np.array(jax_jit_result)))
print("\nNumerical verification (using IDENTICAL input data):")
print(f"Max difference: {max_diff:.2e}")
results_match = np.allclose(numpy_result, jax_jit_result, rtol=1e-5, atol=1e-6)
print(f"Results match (rtol=1e-5, atol=1e-6): {results_match}")

if not results_match:
    print(f"Float32 machine epsilon: {np.finfo(np.float32).eps:.2e}")
    print(f"Relative error: {max_diff/np.max(np.abs(numpy_result)):.2e}")

print("\nNote: identical inputs ensure any differences come from the math implementations, not the data")
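
The warmup, synchronization, and repeat pattern above can be factored into a reusable helper. A minimal sketch; `bench` is a hypothetical name, and it assumes `fn` returns JAX arrays (or a pytree of them) so that `jax.block_until_ready` can force completion.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def bench(fn, *args, n_warmup=3, n_runs=10):
    """Hypothetical helper: median wall-clock seconds per call of fn(*args)."""
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args))  # absorb compilation and kernel caching
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        # Force completion so asynchronous dispatch is not mistaken for speed
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - start)
    return statistics.median(times)


t = bench(jax.jit(lambda v: jnp.sin(v) * 2.0), jnp.ones(1_000_000))
print(f"median: {t * 1e3:.3f} ms")
```

The median is used instead of the mean because single outliers (garbage collection, OS scheduling) skew a mean over a handful of runs much more.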

In [None]:
# DETAILED NUMERICAL VERIFICATION - Let's debug the differences

print("=== DEBUGGING NUMERICAL DIFFERENCES ===")

# First, check whether the inputs themselves differ in dtype
print("\n1. Data type analysis:")
print(f"NumPy array dtype: {np_a.dtype}")
print(f"JAX array dtype: {jax_a.dtype}")

# Let's use a small subset for detailed analysis
n_small = 1000
np_small_a = np_a[:n_small]
np_small_b = np_b[:n_small] 
np_small_c = np_c[:n_small]
jax_small_a = jax_a[:n_small]
jax_small_b = jax_b[:n_small]
jax_small_c = jax_c[:n_small]

print(f"\n2. Small subset verification (first {n_small} elements):")

# Step-by-step computation verification
print("\nStep-by-step verification:")

# Step 1: tanh(a) * exp(-b^2)
np_step1 = np.tanh(np_small_a) * np.exp(-np_small_b**2)
jax_step1 = jnp.tanh(jax_small_a) * jnp.exp(-jax_small_b**2)
diff1 = np.max(np.abs(np_step1 - np.array(jax_step1)))
print(f"Step 1 max diff: {diff1:.2e}")

# Step 2: sin(c) * cos(a)  
np_step2 = np.sin(np_small_c) * np.cos(np_small_a)
jax_step2 = jnp.sin(jax_small_c) * jnp.cos(jax_small_a)
diff2 = np.max(np.abs(np_step2 - np.array(jax_step2)))
print(f"Step 2 max diff: {diff2:.2e}")

# Step 3: sqrt(abs(a * b))
np_step3 = np.sqrt(np.abs(np_small_a * np_small_b))
jax_step3 = jnp.sqrt(jnp.abs(jax_small_a * jax_small_b))
diff3 = np.max(np.abs(np_step3 - np.array(jax_step3)))
print(f"Step 3 max diff: {diff3:.2e}")

# Step 4: log1p(abs(c))
np_step4 = np.log1p(np.abs(np_small_c))
jax_step4 = jnp.log1p(jnp.abs(jax_small_c))
diff4 = np.max(np.abs(np_step4 - np.array(jax_step4)))
print(f"Step 4 max diff: {diff4:.2e}")

# Full computation on small subset
np_small_result = numpy_computation(np_small_a, np_small_b, np_small_c)
jax_small_result = jax_computation(jax_small_a, jax_small_b, jax_small_c)

diff_total = np.abs(np_small_result - np.array(jax_small_result))
print("\nTotal computation differences:")
print(f"Max diff: {np.max(diff_total):.2e}")
print(f"Mean diff: {np.mean(diff_total):.2e}")
print(f"Std diff: {np.std(diff_total):.2e}")

# Check whether differences are reasonable for float32. Scale the tolerance by
# the result magnitude, since float32 spacing grows with |value|.
eps32 = np.finfo(np.float32).eps
expected_eps = eps32 * 10  # Allow 10x machine epsilon (relative)
tolerance = expected_eps * np.maximum(np.abs(np_small_result), 1.0)
print(f"\nFloat32 machine epsilon: {eps32:.2e}")
print(f"Expected tolerance (10x eps, relative): {expected_eps:.2e}")
print(f"All differences within tolerance: {np.all(diff_total <= tolerance)}")

# Look at worst cases, largest difference first
worst_indices = np.argsort(diff_total)[::-1][:5]
print("\nWorst 5 differences:")
for i, idx in enumerate(worst_indices):
    print(f"{i+1}. Index {idx}: np={np_small_result[idx]:.6f}, jax={float(jax_small_result[idx]):.6f}, diff={diff_total[idx]:.2e}")
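
The step-by-step diffs show that NumPy and XLA disagree, not which one is closer to the true value. A sketch of one way to settle that, assuming a float64 NumPy evaluation is an adequate reference: recompute the expression in float64 from the same float32 inputs and measure each float32 result against it. The helper `expr`, the input size, and the seed are illustrative.

```python
import jax.numpy as jnp
import numpy as np


def expr(mod, a, b, c):
    # Same expression as numpy_computation / jax_computation; `mod` is np or jnp
    r = mod.tanh(a) * mod.exp(-b**2) + mod.sin(c) * mod.cos(a)
    return r + mod.sqrt(mod.abs(a * b)) - mod.log1p(mod.abs(c))


rng = np.random.default_rng(0)
a, b, c = (rng.standard_normal(1000).astype(np.float32) for _ in range(3))

# float64 reference computed from the *same* float32 inputs
ref = expr(np, a.astype(np.float64), b.astype(np.float64), c.astype(np.float64))

np_err = np.abs(expr(np, a, b, c) - ref)
jax_out = expr(jnp, jnp.asarray(a), jnp.asarray(b), jnp.asarray(c))
jax_err = np.abs(np.asarray(jax_out) - ref)

print(f"NumPy float32 max abs error vs float64: {np_err.max():.2e}")
print(f"JAX   float32 max abs error vs float64: {jax_err.max():.2e}")
```

If both errors are of the same order, neither library is wrong; they simply round differently.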

## Automatic Vectorization with vmap

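JAX's `vmap` transformation automatically vectorizes functions: you write the computation for a single example, and `vmap` produces a version that operates on a whole batch without an explicit Python loop. Combined with `jit`, this typically runs much faster than looping in Python: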
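
A minimal sketch of what `vmap` does with `in_axes=(0, None)`, the same setting used below: a function written for one vector is mapped over the rows of a matrix while the second argument is shared. For a dot product, the result is exactly a matrix-vector product. The names `row_dot`, `batch`, and `w` are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np


def row_dot(x, w):
    # Written for a single vector x of shape (d,)
    return jnp.dot(x, w)


batch = jnp.arange(12.0).reshape(4, 3)  # 4 vectors of size 3
w = jnp.array([1.0, 2.0, 3.0])

# in_axes=(0, None): map over axis 0 of `batch`, reuse the same `w` for every row
batched = jax.vmap(row_dot, in_axes=(0, None))
out = batched(batch, w)

print(out)  # values 8, 26, 44, 62, i.e. batch @ w
```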

In [None]:
# Create batch data for vectorization demo
batch_size = 1000
vector_size = 10000

# Create NumPy batch data first
np.random.seed(123)  # Different seed for batch demo
np_batch = np.random.randn(batch_size, vector_size).astype(np.float32)
np_weights = np.random.randn(vector_size).astype(np.float32)

# Create JAX batch data using SAME data
jax_batch = jnp.array(np_batch)
jax_weights = jnp.array(np_weights)

def single_computation(x, w):
    """Computation on a single vector"""
    return jnp.tanh(jnp.dot(x, w)) + jnp.sum(jnp.sin(x))

# NumPy approach: explicit loop
def numpy_batch_computation(batch, weights):
    """Process batch with explicit Python loop"""
    results = []
    for i in range(batch.shape[0]):
        x = batch[i]
        result = np.tanh(np.dot(x, weights)) + np.sum(np.sin(x))
        results.append(result)
    return np.array(results)

# JAX approach: vmap for automatic vectorization
jax_batch_computation = vmap(single_computation, in_axes=(0, None))
jax_batch_computation_jit = jit(jax_batch_computation)

print(f"Batch processing: {batch_size} vectors of size {vector_size}")
print(f"Total elements: {batch_size * vector_size:,}")
print(f"Using identical input data: {np.array_equal(np_batch, jax_batch)}")
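
The benchmark pits `vmap` against a Python loop, which largely measures interpreter overhead. A fairer NumPy baseline vectorizes over the batch directly. The sketch below (function names are hypothetical) uses one matrix-vector product and a row-wise sum, and checks the result against the loop form on a small input.

```python
import numpy as np


def numpy_batch_vectorized(batch, weights):
    # Whole-batch NumPy: one matrix-vector product and one row-wise sum
    return np.tanh(batch @ weights) + np.sin(batch).sum(axis=1)


def numpy_batch_loop(batch, weights):
    # Same logic as numpy_batch_computation in this notebook (explicit Python loop)
    return np.array([np.tanh(np.dot(x, weights)) + np.sum(np.sin(x)) for x in batch])


rng = np.random.default_rng(123)
batch = rng.standard_normal((8, 16)).astype(np.float32)
weights = rng.standard_normal(16).astype(np.float32)

vec = numpy_batch_vectorized(batch, weights)
loop = numpy_batch_loop(batch, weights)
print(np.max(np.abs(vec - loop)))  # only float32 rounding differences expected
```

Timing `numpy_batch_vectorized` alongside the other variants separates the benefit of vectorization itself from the benefit of XLA compilation.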

In [None]:
# PROPER VMAP BENCHMARKING with synchronization and warmup

print("Proper vmap benchmarking with identical input data...")

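# Warm up both paths: JIT compiles the whole batched function, and eager vmap
# compiles and caches one kernel per primitive on first use
print("Warming up vmap (JIT and eager)...")
for _ in range(3):
    _ = jax_batch_computation_jit(jax_batch, jax_weights).block_until_ready()
    _ = jax_batch_computation(jax_batch, jax_weights).block_until_ready()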

# Multiple runs for reliable timing
n_runs = 5
numpy_batch_times = []
jax_vmap_times = []
jax_vmap_jit_times = []

print(f"Running {n_runs} iterations...")

for i in range(n_runs):
    # NumPy with explicit loop
    start = time.perf_counter()
    numpy_batch_result = numpy_batch_computation(np_batch, np_weights)
    numpy_batch_time = time.perf_counter() - start
    numpy_batch_times.append(numpy_batch_time)
    
    # JAX with vmap (no JIT)  
    start = time.perf_counter()
    jax_batch_result = jax_batch_computation(jax_batch, jax_weights)
    jax_batch_result.block_until_ready()  # Critical: wait for computation
    jax_vmap_time = time.perf_counter() - start
    jax_vmap_times.append(jax_vmap_time)
    
    # JAX with vmap + JIT (already compiled)
    start = time.perf_counter()
    jax_batch_jit_result = jax_batch_computation_jit(jax_batch, jax_weights)
    jax_batch_jit_result.block_until_ready()  # Critical: wait for computation
    jax_vmap_jit_time = time.perf_counter() - start
    jax_vmap_jit_times.append(jax_vmap_jit_time)

# Calculate averages and standard deviations
numpy_avg = np.mean(numpy_batch_times)
numpy_std = np.std(numpy_batch_times)
vmap_avg = np.mean(jax_vmap_times)
vmap_std = np.std(jax_vmap_times)
vmap_jit_avg = np.mean(jax_vmap_jit_times)
vmap_jit_std = np.std(jax_vmap_jit_times)

print(f"\nBatch Processing Results (average ± std over {n_runs} runs):")
print(f"NumPy (explicit loop):    {numpy_avg:.4f}s ± {numpy_std:.4f}s")
print(f"JAX vmap (no JIT):        {vmap_avg:.4f}s ± {vmap_std:.4f}s")
print(f"JAX vmap (JIT):           {vmap_jit_avg:.4f}s ± {vmap_jit_std:.4f}s")

print("\nRealistic Vectorization Speedup:")
speedup_vmap = numpy_avg / vmap_avg if vmap_avg > 0 else 0
speedup_vmap_jit = numpy_avg / vmap_jit_avg if vmap_jit_avg > 0 else 0

print(f"vmap vs NumPy loop:       {speedup_vmap:.1f}x")
print(f"vmap+JIT vs NumPy loop:   {speedup_vmap_jit:.1f}x")

# Verify numerical accuracy with IDENTICAL input data. Each result sums 10,000
# float32 terms, and NumPy and XLA accumulate in different orders, so use looser
# tolerances than for the elementwise benchmark.
print("\nNumerical verification (using IDENTICAL input data):")
print(f"Results shape: {jax_batch_jit_result.shape}")

max_diff = np.max(np.abs(numpy_batch_result - np.array(jax_batch_jit_result)))
print(f"Max difference: {max_diff:.2e}")

results_match = np.allclose(numpy_batch_result, jax_batch_jit_result, rtol=1e-4, atol=1e-3)
print(f"Results match (rtol=1e-4, atol=1e-3): {results_match}")

if not results_match:
    print(f"Mean difference: {np.mean(np.abs(numpy_batch_result - np.array(jax_batch_jit_result))):.2e}")
    print(f"Float32 machine epsilon: {np.finfo(np.float32).eps:.2e}")

print("\nNote: identical inputs ensure any differences come from the implementations, not the data")

## Memory Efficiency and Device Management

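JAX places arrays on a default device (GPU/TPU when available, otherwise CPU) and moves data between host and device through explicit transfers. The cell below inspects array placement, default dtypes, and transfer cost: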
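
A minimal sketch of explicit host/device transfers with `jax.device_put` and `jax.device_get`, targeting whatever `jax.devices()[0]` is on the current machine (an assumption: the first device is the one you want).

```python
import jax
import numpy as np

host = np.arange(6, dtype=np.float32).reshape(2, 3)

# Host -> device: device_put starts an asynchronous copy to the target device
dev = jax.device_put(host, jax.devices()[0])
dev.block_until_ready()
print(dev.devices())  # set of devices holding this array's data

# Device -> host: device_get waits for the data and returns a NumPy array
back = jax.device_get(dev)
print(type(back), back.dtype)
```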

In [None]:
# Memory efficiency demonstration
print("Memory and Device Analysis:")
print(f"NumPy array location: CPU memory")
print(f"JAX array location: {jax_a.device}")  # `.device` is a property in recent JAX releases

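# Show dtype/memory defaults: NumPy creates float64 arrays, while JAX creates
# float32 unless 64-bit mode is enabled via jax.config.update("jax_enable_x64", True)
small_size = 1000
np_small = np.random.randn(small_size)              # float64: 8 bytes per element
jax_small = jax.random.normal(key, (small_size,))  # float32: 4 bytes per element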

print(f"\nSmall array sizes: {small_size} elements")
print(f"NumPy itemsize: {np_small.itemsize} bytes")
print(f"JAX itemsize: {jax_small.itemsize} bytes")

# Device transfer overhead. On a CPU-only backend this is a host-memory copy, and a
# single transfer of 1,000 elements mostly measures fixed per-call overhead.
print("\nDevice transfer (if applicable):")
start = time.perf_counter()
jax_to_numpy = np.array(jax_small)  # Device to CPU (blocks until the data arrives)
transfer_time = time.perf_counter() - start
print(f"JAX to NumPy transfer: {transfer_time:.6f}s")

start = time.perf_counter()
numpy_to_jax = jnp.array(np_small).block_until_ready()  # CPU to device; transfers are asynchronous
transfer_time2 = time.perf_counter() - start
print(f"NumPy to JAX transfer: {transfer_time2:.6f}s")

## Practical Usage Guidelines

**When to use JAX:**
- Computationally intensive numerical operations
- Machine learning and scientific computing
- When you need automatic differentiation
- Batch processing with consistent operations
- GPU/TPU acceleration available

**When to stick with NumPy:**
- Simple, one-off computations
- Extensive use of NumPy ecosystem (scipy, pandas)
- Frequent interaction with Python objects
- When compilation overhead outweighs benefits
- I/O heavy workloads
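
The "frequent interaction with Python objects" caveat comes from how `jit` traces: inside a jitted function, arrays are abstract tracers, so Python control flow cannot branch on their values. A minimal sketch of two common workarounds, `jnp.where` for data-dependent selection and `static_argnames` for Python-level configuration; the function names are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def relu_bad(x):
    # Python `if` needs a concrete bool, but x is a tracer under jit, so this raises
    if x > 0:
        return x
    return 0.0


@jax.jit
def relu_ok(x):
    # Data-dependent choice expressed as an array operation, traceable under jit
    return jnp.where(x > 0, x, 0.0)


@partial(jax.jit, static_argnames="power")
def scaled_pow(x, power):
    # `power` is a plain Python int fixed at trace time; each new value recompiles
    return x ** power


raised = False
try:
    relu_bad(jnp.float32(1.0))
except jax.errors.ConcretizationTypeError as err:  # recent versions raise a subclass
    raised = True
    print("jit cannot branch on traced values:", type(err).__name__)
```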

In [None]:
# Analysis of when JIT compilation pays off
sizes = [1000, 10000, 100000, 1000000]
compile_times = []
execution_speedups = []

print("JIT Compilation Payoff Analysis:")
print(f"{'Size':<10} {'Compile(ms)':<12} {'Speedup(x)':<11} {'Break-even':<12}")
print("-" * 50)

def test_function(x):
    return jnp.sin(x) * jnp.cos(x) + jnp.exp(-x**2)

test_function_jit = jit(test_function)

for size in sizes:
    # Create test array
    test_arr = jax.random.normal(key, (size,))

    # Warm up eager JAX: op-by-op execution also compiles and caches a kernel
    # per primitive for each new shape, which would inflate the first timing
    test_function(test_arr).block_until_ready()

    # Time regular (eager) JAX
    start = time.perf_counter()
    regular_result = test_function(test_arr)
    regular_result.block_until_ready()
    regular_time = time.perf_counter() - start

    # Time first JIT call: tracing + compilation + execution (each new shape recompiles)
    start = time.perf_counter()
    jit_result = test_function_jit(test_arr)
    jit_result.block_until_ready()
    compile_time = time.perf_counter() - start

    # Time JIT execution only (cached executable)
    start = time.perf_counter()
    jit_result2 = test_function_jit(test_arr)
    jit_result2.block_until_ready()
    execution_time = time.perf_counter() - start

    speedup = regular_time / execution_time
    compilation_overhead = compile_time - execution_time
    # Calls needed before the per-call savings repay the one-time compilation cost
    break_even_calls = compilation_overhead / (regular_time - execution_time) if execution_time < regular_time else float('inf')

    compile_times.append(compilation_overhead * 1000)  # Convert to ms
    execution_speedups.append(speedup)

    print(f"{size:<10} {compilation_overhead*1000:<12.2f} {speedup:<11.1f} {break_even_calls:<12.0f}")

print("\nKey Insights:")
print("- Compilation overhead matters less, relative to execution time, as arrays grow")
print("- JIT pays off most for repeated calls with the same input shapes and dtypes (each new shape recompiles)")
print("- The break-even column shows how many calls amortize compilation on this machine")
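
The payoff analysis hinges on the compilation cache: `jit` traces and compiles once per input shape and dtype, then reuses the executable. A small sketch that makes this visible by recording a Python side effect, which only runs during tracing; `f` and `traces` are illustrative names.

```python
import jax
import jax.numpy as jnp

traces = []  # Python side effects run only while tracing, not on cached calls


@jax.jit
def f(x):
    traces.append(x.shape)  # records each (re)trace
    return jnp.sin(x) * 2.0


f(jnp.ones(3))  # new shape: trace + compile
f(jnp.ones(3))  # same shape/dtype: cached executable, no trace
f(jnp.ones(4))  # new shape: retrace + recompile

print(traces)  # [(3,), (4,)]
```

This is also why each row of the payoff table pays its own compilation cost: every array size is a new shape.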

## Advanced JAX Features

Beyond `jit` and `vmap`, JAX provides transformations such as `grad` (automatic differentiation) that compose with `jit`, so derivative computations can be compiled too:
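
Before timing `grad`, it helps to confirm it computes what we expect. A minimal sketch that checks `jax.grad` of the same function as `complex_function` in the next cell against its hand-derived gradient; the evaluation points are arbitrary.

```python
import jax
import jax.numpy as jnp


def complex_function(x):
    # Same function as in the next cell
    return jnp.sum(x**3 - 2 * x**2 + jnp.sin(x))


x = jnp.linspace(-2.0, 2.0, 5)
auto = jax.grad(complex_function)(x)

# Hand-derived gradient: d/dx_i [x_i^3 - 2 x_i^2 + sin x_i] = 3 x_i^2 - 4 x_i + cos x_i
manual = 3 * x**2 - 4 * x + jnp.cos(x)
print(jnp.max(jnp.abs(auto - manual)))
```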

In [None]:
from jax import grad

# Automatic differentiation example
def complex_function(x):
    """Complex function for differentiation demo"""
    return jnp.sum(x**3 - 2*x**2 + jnp.sin(x))

# Get gradient function
grad_fn = grad(complex_function)
grad_fn_jit = jit(grad_fn)

# Test on sample data
x_test = jax.random.normal(key, (1000,))

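# Warm up both paths so the timings exclude compilation (eager grad also
# compiles and caches per-primitive kernels on first use)
grad_fn(x_test).block_until_ready()
grad_fn_jit(x_test).block_until_ready()

# Compare gradient computation times
start = time.perf_counter()
gradient = grad_fn(x_test)
gradient.block_until_ready()
grad_time = time.perf_counter() - start

start = time.perf_counter()
gradient_jit = grad_fn_jit(x_test)
gradient_jit.block_until_ready()
grad_jit_time = time.perf_counter() - start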

print("Automatic Differentiation Performance:")
print(f"Gradient (no JIT): {grad_time:.6f}s")
print(f"Gradient (JIT):    {grad_jit_time:.6f}s")
print(f"AD + JIT speedup:  {grad_time/grad_jit_time:.1f}x")

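# pmap distributes computation across multiple devices (if available).
# Recent JAX documentation steers users toward jax.jit with explicit shardings instead.
if len(jax.devices()) > 1:
    from jax import pmap
    print(f"\nMultiple devices available: {len(jax.devices())}")
    print("pmap can distribute computation across devices")
else:
    print("\nSingle device setup - pmap has nothing to distribute across")
    print("(the CPU backend exposes a single device by default, not one per core)")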

print("\nJAX ecosystem advantages:")
print("- Composable transformations: jit(vmap(grad(...)))")
print("- Same code runs on CPU/GPU/TPU")
print("- Functional programming model: pure functions and immutable arrays make transformations safe to compose")
print("- XLA optimizations such as operator fusion apply automatically under jit")
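
The `jit(vmap(grad(...)))` pattern is how per-example gradients are commonly written in JAX. A minimal sketch for a hypothetical linear model with squared loss, checked against the hand-derived gradient $2(x \cdot w - y)\,x$.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a linear model for ONE example (hypothetical model)
    return (jnp.dot(x, w) - y) ** 2


# grad w.r.t. w for one example -> vmap over examples -> compile the whole pipeline
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

key = jax.random.PRNGKey(0)
kx, ky = jax.random.split(key)
xs = jax.random.normal(kx, (8, 3))
ys = jax.random.normal(ky, (8,))
w = jnp.array([0.5, -1.0, 2.0])

g = per_example_grads(w, xs, ys)  # shape (8, 3): one gradient row per example

# Analytic check: d/dw (x.w - y)^2 = 2 (x.w - y) x
expected = 2 * (xs @ w - ys)[:, None] * xs
print(jnp.max(jnp.abs(g - expected)))
```

`in_axes=(None, 0, 0)` shares the parameters `w` across the batch while mapping over examples and labels.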