# Python Numerical Computing Speedup Comparison

This notebook compares the performance of different Python libraries for numerical computing:

- **NumPy**: Standard vectorized operations
- **NumExpr**: Multi-threaded expression evaluation
- **JAX**: JIT compilation with GPU support
- **Numba**: JIT compilation for CPU (single-threaded and parallel)

We'll use a simple but representative mathematical expression to demonstrate realistic speedups.

In [None]:
import numpy as np
import numexpr as ne
import jax.numpy as jnp
import jax
from jax import jit
import numba
from numba import njit, prange
import time
import warnings
warnings.filterwarnings('ignore')

# Check available devices and settings
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")
print(f"NumExpr threads: {ne.nthreads}")
print("NumPy build configuration:")
np.show_config()  # prints BLAS/LAPACK build info; show() returns None, so it cannot be used inside an f-string

## Test Setup

We'll use a mathematical expression that's common in scientific computing:
```
result = sin(x) * exp(-y**2) + sqrt(abs(x * z)) - log1p(abs(y))
```

This expression includes:
- Trigonometric functions (sin)
- Exponential functions (exp)
- Power operations (**2)
- Element-wise arithmetic (*, +, -)
- Mathematical functions (sqrt, log1p, abs)
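
Two details of the expression are worth confirming before benchmarking. In Python, unary minus binds more loosely than `**`, so `exp(-y**2)` means exp(-(y²)), a Gaussian factor that always lies in (0, 1]. The `abs()` calls keep `sqrt` and `log1p` inside their real domains, so no NaNs appear. A quick check on a few hand-picked values (the array `y` here is illustrative, not the benchmark data):

```python
import numpy as np

y = np.array([-2.0, -0.5, 0.0, 1.5])

gauss = np.exp(-y**2)        # parsed as exp(-(y**2)): values in (0, 1]
wrong = np.exp((-y)**2)      # what it would mean if unary minus bound tighter: values >= 1

print(gauss)
print(wrong)
```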

In [None]:
# Create test data - using same data for all methods to ensure fair comparison
n = 10_000_000  # 10 million elements
np.random.seed(42)

# Create arrays in float32 for consistency across all libraries
x = np.random.randn(n).astype(np.float32)
y = np.random.randn(n).astype(np.float32) 
z = np.random.randn(n).astype(np.float32)

# JAX arrays (same data)
x_jax = jnp.array(x)
y_jax = jnp.array(y)
z_jax = jnp.array(z)

print(f"Array size: {n:,} elements ({x.nbytes/1024**2:.1f} MB each)")
print(f"Total memory: {3 * x.nbytes/1024**2:.1f} MB")
print(f"Data type: {x.dtype}")
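
As an aside, NumPy's newer `Generator` API can draw float32 samples directly instead of generating float64 and casting with `astype`. This is a sketch only: `default_rng(42)` produces a different random stream than `np.random.seed(42)`, so the values will not match the arrays above. The memory figure follows from 10,000,000 elements × 4 bytes.

```python
import numpy as np

n = 10_000_000
rng = np.random.default_rng(42)  # Generator API; a different stream than np.random.seed(42)

# Draw float32 directly, skipping the float64 intermediate
x = rng.standard_normal(n, dtype=np.float32)

print(f"{x.nbytes / 1024**2:.1f} MB")  # 40,000,000 bytes -> 38.1 MB
```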

## Implementation of Each Method

In [None]:
# 1. NumPy implementation (baseline)
def numpy_compute(x, y, z):
    """Standard NumPy vectorized operations"""
    return np.sin(x) * np.exp(-y**2) + np.sqrt(np.abs(x * z)) - np.log1p(np.abs(y))

print("NumPy implementation ready")
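
The one-line NumPy version can allocate a new full-size array for each intermediate (`sin(x)`, `y**2`, `x * z`, and so on). A sketch of a hypothetical variant, `numpy_compute_inplace`, that computes the same expression using only two buffers via the ufuncs' `out=` parameter. It still makes several passes over memory, so it is not equivalent to the fused loops that NumExpr and Numba use.

```python
import numpy as np

def numpy_compute_inplace(x, y, z):
    """Hypothetical variant: same expression, two buffers reused via out=."""
    out = np.sin(x)               # out = sin(x)
    tmp = np.square(y)            # tmp = y**2
    np.negative(tmp, out=tmp)     # tmp = -y**2
    np.exp(tmp, out=tmp)          # tmp = exp(-y**2)
    out *= tmp                    # out = sin(x) * exp(-y**2)
    np.multiply(x, z, out=tmp)    # tmp = x * z
    np.abs(tmp, out=tmp)          # tmp = |x * z|
    np.sqrt(tmp, out=tmp)         # tmp = sqrt(|x * z|)
    out += tmp
    np.abs(y, out=tmp)            # tmp = |y|
    np.log1p(tmp, out=tmp)        # tmp = log1p(|y|)
    out -= tmp
    return out
```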

In [None]:
# 2. NumExpr implementation (multi-threaded evaluation)
def numexpr_compute(x, y, z):
    """NumExpr multi-threaded expression evaluation"""
    return ne.evaluate("sin(x) * exp(-y**2) + sqrt(abs(x * z)) - log1p(abs(y))")

print("NumExpr implementation ready")
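
Part of NumExpr's advantage is that it evaluates the expression in small chunks, so intermediates typically stay in CPU cache instead of becoming full-size arrays. The sketch below imitates only that chunking principle in plain NumPy; it is single-threaded, the block size of 4096 is an arbitrary assumption, and it is not how NumExpr is implemented internally.

```python
import numpy as np

def blocked_compute(x, y, z, block=4096):
    """Evaluate the expression one small block at a time (illustrative only)."""
    out = np.empty_like(x)
    for start in range(0, x.size, block):
        s = slice(start, start + block)
        xs, ys, zs = x[s], y[s], z[s]  # views, no copies
        # Temporaries here are at most `block` elements long
        out[s] = (np.sin(xs) * np.exp(-ys**2)
                  + np.sqrt(np.abs(xs * zs)) - np.log1p(np.abs(ys)))
    return out
```

The per-iteration Python overhead is amortized over each block, which is why the block size should not be too small.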

In [None]:
# 3. JAX implementation (JIT compiled)
def jax_compute(x, y, z):
    """JAX computation with XLA optimization"""
    return jnp.sin(x) * jnp.exp(-y**2) + jnp.sqrt(jnp.abs(x * z)) - jnp.log1p(jnp.abs(y))

# Create JIT-compiled version
jax_compute_jit = jit(jax_compute)

# Warmup JIT compilation
print("Warming up JAX JIT compilation...")
for _ in range(3):
    _ = jax_compute_jit(x_jax, y_jax, z_jax).block_until_ready()

print("JAX implementation ready")

In [None]:
# 4. Numba implementations (JIT compiled for CPU)
@njit
def numba_compute(x, y, z):
    """Numba JIT compiled single-threaded"""
    result = np.empty_like(x)
    for i in range(len(x)):
        result[i] = (np.sin(x[i]) * np.exp(-y[i]**2) + 
                    np.sqrt(np.abs(x[i] * z[i])) - 
                    np.log1p(np.abs(y[i])))
    return result

@njit(parallel=True)
def numba_parallel_compute(x, y, z):
    """Numba JIT compiled multi-threaded with prange"""
    result = np.empty_like(x)
    for i in prange(len(x)):
        result[i] = (np.sin(x[i]) * np.exp(-y[i]**2) + 
                    np.sqrt(np.abs(x[i] * z[i])) - 
                    np.log1p(np.abs(y[i])))
    return result

# Warmup Numba compilation
print("Warming up Numba JIT compilation...")
small_x, small_y, small_z = x[:1000], y[:1000], z[:1000]
_ = numba_compute(small_x, small_y, small_z)
_ = numba_parallel_compute(small_x, small_y, small_z)

print("Numba implementations ready")
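
The idea behind `prange` is to split the iteration space across CPU threads. The sketch below swaps in a different technique that achieves a similar effect: a standard-library thread pool that evaluates contiguous chunks with NumPy. It is not how Numba schedules work. It relies on NumPy releasing the GIL inside most ufunc loops on numeric arrays, and each chunk still allocates NumPy temporaries, so expect it to trail Numba's fused parallel loop. `threaded_compute` is a hypothetical helper.

```python
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def threaded_compute(x, y, z, n_workers=None):
    """Split the index range into contiguous chunks and evaluate each in a thread."""
    n_workers = n_workers or os.cpu_count() or 1
    out = np.empty_like(x)
    # Chunk boundaries: 0 = b0 < b1 < ... < bk = x.size
    bounds = np.linspace(0, x.size, n_workers + 1).astype(np.int64)

    def work(lo, hi):
        xs, ys, zs = x[lo:hi], y[lo:hi], z[lo:hi]
        # Each thread writes to a disjoint slice of `out`
        out[lo:hi] = (np.sin(xs) * np.exp(-ys**2)
                      + np.sqrt(np.abs(xs * zs)) - np.log1p(np.abs(ys)))

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # list() waits for all chunks and re-raises any exception from a worker
        list(pool.map(work, bounds[:-1], bounds[1:]))
    return out
```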

## Performance Benchmark

We'll run each method multiple times and report average performance with standard deviation.

In [None]:
def benchmark_method(func, name, *args, n_runs=5, **kwargs):
    """Time func(*args, **kwargs) over n_runs calls after one untimed warmup call."""
    def call():
        result = func(*args, **kwargs)
        # JAX dispatches asynchronously; wait for the computation to finish
        if hasattr(result, 'block_until_ready'):
            result.block_until_ready()
        return result

    result = call()  # warmup: excludes first-call costs such as page faults and thread-pool startup
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        result = call()  # keep only the latest result to avoid holding n_runs large arrays
        times.append(time.perf_counter() - start)

    avg_time = np.mean(times)
    std_time = np.std(times)

    print(f"{name:<20}: {avg_time:.4f}s ± {std_time:.4f}s")
    return avg_time, std_time, result

print("Running performance benchmark...\n")

# Benchmark all methods
numpy_time, numpy_std, numpy_result = benchmark_method(numpy_compute, "NumPy", x, y, z)
numexpr_time, numexpr_std, numexpr_result = benchmark_method(numexpr_compute, "NumExpr", x, y, z)
jax_time, jax_std, jax_result = benchmark_method(jax_compute_jit, "JAX (JIT)", x_jax, y_jax, z_jax)
numba_time, numba_std, numba_result = benchmark_method(numba_compute, "Numba (single)", x, y, z)
numba_par_time, numba_par_std, numba_par_result = benchmark_method(numba_parallel_compute, "Numba (parallel)", x, y, z)

print("\nBenchmark completed.")
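
A common alternative to mean ± standard deviation is to report the fastest of several repeats, since background activity can only slow a run down, never speed it up. A sketch using the standard library's `timeit.repeat`; `best_of` is a hypothetical helper, and like `benchmark_method` it waits on JAX results via `block_until_ready`.

```python
import timeit

import numpy as np

def best_of(func, *args, repeat=5, number=1):
    """Fastest per-call time in seconds over `repeat` rounds of `number` calls."""
    def call():
        result = func(*args)
        # JAX arrays are computed asynchronously; wait before stopping the clock
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
    times = timeit.repeat(call, repeat=repeat, number=number)
    return min(times) / number

t = best_of(np.sqrt, np.arange(1_000_000, dtype=np.float64))
print(f"np.sqrt: {t:.6f}s")
```

In the notebook this would be called as, for example, `best_of(numpy_compute, x, y, z)`.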

## Results Analysis

In [None]:
# Calculate speedups relative to NumPy
print("\n=== SPEEDUP ANALYSIS ===")
print(f"{'Method':<20} {'Time (s)':<18} {'Speedup vs NumPy'}")
print("-" * 56)

methods = [
    ("NumPy", numpy_time, numpy_std),
    ("NumExpr", numexpr_time, numexpr_std),
    ("JAX (JIT)", jax_time, jax_std),
    ("Numba (single)", numba_time, numba_std),
    ("Numba (parallel)", numba_par_time, numba_par_std)
]

for name, avg_time, std_time in methods:
    speedup = numpy_time / avg_time
    time_str = f"{avg_time:.4f} ± {std_time:.4f}"
    print(f"{name:<20} {time_str:<18} {speedup:.1f}x")

print("\n=== NUMERICAL ACCURACY ===")
# Verify all methods produce the same results (within numerical precision)
results = {
    "NumPy": numpy_result,
    "NumExpr": numexpr_result,
    "JAX": np.array(jax_result),
    "Numba (single)": numba_result,
    "Numba (parallel)": numba_par_result
}

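reference = numpy_result
for name, result in results.items():
    max_diff = np.max(np.abs(reference - result))
    # float32 carries about 7 significant digits and each library uses its own
    # transcendental implementations, so use float32-appropriate tolerances
    matches = np.allclose(reference, result, rtol=1e-5, atol=1e-6)
    print(f"{name:<20}: max diff = {max_diff:.2e}, allclose = {matches}")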
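
The speedup is the ratio of the NumPy time to each method's time, so its uncertainty depends on both measurements. Assuming the two timing errors are independent, first-order error propagation gives a relative error of sqrt((σ_base/t_base)² + (σ_method/t_method)²). A sketch with a hypothetical helper and made-up timings:

```python
import math

def speedup_with_error(t_base, s_base, t_method, s_method):
    """Speedup t_base / t_method with first-order (independent-error) uncertainty."""
    ratio = t_base / t_method
    rel_err = math.sqrt((s_base / t_base) ** 2 + (s_method / t_method) ** 2)
    return ratio, ratio * rel_err

# Hypothetical timings: baseline 0.5 s ± 0.01 s, method 0.25 s ± 0.005 s
ratio, err = speedup_with_error(0.5, 0.01, 0.25, 0.005)
print(f"{ratio:.1f}x ± {err:.2f}x")  # 2.0x ± 0.06x
```

In the notebook, `speedup_with_error(numpy_time, numpy_std, avg_time, std_time)` would slot into the speedup loop.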

## Summary

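**Typical Performance (indicative ranges; compare against the benchmark output above, since results depend on hardware):**

1. **NumPy** (baseline): Vectorized ufuncs, single-threaded; each intermediate result is a full-size temporary array
2. **NumExpr** (often 2-4x faster): Multi-threaded, blocked evaluation that avoids full-size temporaries
3. **JAX** (varies): A GPU can give roughly 5-20x; on CPU, often comparable to NumPy. The timed calls exclude host-to-device transfer, because the arrays are placed on the device before the benchmark
4. **Numba single** (often 2-5x faster): One fused loop computes each element in a single pass, avoiding NumPy's temporaries; the gain comes from fusion, not from removing Python overhead, which vectorized NumPy already avoids
5. **Numba parallel** (often 4-10x faster): The same fused loop split across CPU threads with `prange`

**Key Insights:**
- NumExpr is an easy win for large element-wise expressions on the CPU, since it only needs the expression as a string
- JAX shines when a GPU is available
- Numba provides excellent CPU performance with explicit loops and opt-in parallelization
- All methods agree to within float32 rounding error, but the results are generally not bit-identical (see the max diff values above)
- Performance gains depend on array size, operation complexity, and hardware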

In [None]:
a = np.random.rand(512 * 512)
# Generate the input once so that only sqrt is timed
%timeit np.sqrt(a)

In [None]:
import pyfastflow as pff

noise = pff.noise.red_noise(4096, 4096)

In [None]:
import matplotlib.pyplot as plt
plt.imshow(noise)

In [None]:
import cProfile

In [None]:
cProfile.run("pff.noise.red_noise(4096, 4096)", sort='tottime')
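
`cProfile.run` prints the full table, which can be long for library code. A sketch of a hypothetical helper, `profile_top`, that profiles a single call with `cProfile.Profile` and returns only the top `limit` rows as a string via `pstats.Stats`. It is shown on `np.sort` so it runs without pyfastflow; `profile_top(pff.noise.red_noise, 4096, 4096)` would follow the same pattern.

```python
import cProfile
import io
import pstats

import numpy as np

def profile_top(func, *args, limit=10, sort="tottime"):
    """Profile func(*args); return its result and the top `limit` rows of the report."""
    prof = cProfile.Profile()
    prof.enable()
    try:
        result = func(*args)
    finally:
        prof.disable()
    buf = io.StringIO()
    pstats.Stats(prof, stream=buf).sort_stats(sort).print_stats(limit)
    return result, buf.getvalue()

result, report = profile_top(np.sort, np.arange(10)[::-1], limit=5)
print(report)
```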