![JAX Slide](img/02_JAX/Folie1.PNG)

![JAX Slide](img/02_JAX/Folie2.PNG)

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import time
from functools import partial

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"JAX default backend: {jax.default_backend()}")



JAX version: 0.7.0
JAX devices: [CpuDevice(id=0)]
JAX default backend: cpu


![JAX Slide](img/02_JAX/Folie3.PNG)

In [2]:
# Create arrays in NumPy and JAX
np_array = np.array([1, 2, 3, 4, 5])
jax_array = jnp.array([1, 2, 3, 4, 5])

print(f"NumPy array: {np_array}, type: {type(np_array)}")
print(f"JAX array: {jax_array}, type: {type(jax_array)}")

# Converting between them
np_from_jax = np.array(jax_array)
jax_from_np = jnp.array(np_array)

print(f"\nConversion works seamlessly")
print(f"JAX->NumPy: {np_from_jax}")
print(f"NumPy->JAX: {jax_from_np}")

NumPy array: [1 2 3 4 5], type: <class 'numpy.ndarray'>
JAX array: [1 2 3 4 5], type: <class 'jaxlib._jax.ArrayImpl'>

Conversion works seamlessly
JAX->NumPy: [1 2 3 4 5]
NumPy->JAX: [1 2 3 4 5]


In [3]:
# Demonstrate immutability
x = jnp.array([1, 2, 3])
print(f"Original: {x}")

# This creates a NEW array
y = x.at[0].set(999)
print(f"After x.at[0].set(999): original x = {x}, new y = {y}")

# Compare with NumPy (mutable)
np_x = np.array([1, 2, 3])
np_x[0] = 999  # Modifies in-place
print(f"NumPy in-place modification: {np_x}")

Original: [1 2 3]
After x.at[0].set(999): original x = [1 2 3], new y = [999   2   3]
NumPy in-place modification: [999   2   3]


![JAX Slide](img/02_JAX/Folie4.PNG)

![JAX Slide](img/02_JAX/Folie5.PNG)

![JAX Slide](img/02_JAX/Folie6.PNG)

In [4]:
# Simple function without JIT
def slow_function(x):
    return jnp.sum(x**2) + jnp.mean(x**3) - jnp.std(x)

# Same function with JIT
@jax.jit
def fast_function(x):
    return jnp.sum(x**2) + jnp.mean(x**3) - jnp.std(x)

# Create test data
data = jnp.array(np.random.rand(100_000))

# Warm up the JIT function (compilation happens here)
_ = fast_function(data)

print("Performance comparison:")
%timeit slow_function(data).block_until_ready()
%timeit fast_function(data).block_until_ready()

# Verify results are the same
print(f"\nResults match: {jnp.allclose(slow_function(data), fast_function(data))}")

Performance comparison:
284 μs ± 14 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
252 μs ± 111 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Results match: True


### Understanding `.block_until_ready()`

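JAX dispatches operations asynchronously: a call like `jnp.sum(x)` can return an array before the computation behind it has finished. For accurate timing, call `.block_until_ready()` on the result so the clock stops only once the value has actually been computed.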

In [5]:
# Demonstrate asynchronous execution
large_data = jnp.array(np.random.rand(10_000_000))

print("Without .block_until_ready():")
start = time.time()
result = jnp.sum(large_data**2)  # Returns immediately
end = time.time()
print(f"Time: {end - start:.6f} seconds (misleading!)")

print("\nWith .block_until_ready():")
start = time.time()
result = jnp.sum(large_data**2).block_until_ready()  # Waits for completion
end = time.time()
print(f"Time: {end - start:.6f} seconds (accurate)")

Without .block_until_ready():
Time: 0.126603 seconds (misleading!)

With .block_until_ready():
Time: 0.028470 seconds (accurate)
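In the recorded output the unblocked timing is actually the larger of the two. A likely reason is that this first call also compiles the `**2` and `sum` operations for this array shape, while the second measurement reuses the cached kernels. A fairer version warms up once before timing. The sketch below does that and uses `time.perf_counter` (swapped in here for `time.time` because of its higher resolution):

```python
import time

import jax.numpy as jnp
import numpy as np

large_data = jnp.asarray(np.random.rand(10_000_000), dtype=jnp.float32)

# Warm-up: the first call compiles the `**2` and `sum` kernels for this shape/dtype
jnp.sum(large_data**2).block_until_ready()

# Without blocking: the call may return before the result has been computed
start = time.perf_counter()
r_async = jnp.sum(large_data**2)
t_async = time.perf_counter() - start
r_async.block_until_ready()  # drain pending work before the next measurement

# With blocking: the clock stops only after the value is ready
start = time.perf_counter()
r_sync = jnp.sum(large_data**2).block_until_ready()
t_sync = time.perf_counter() - start

print(f"Without .block_until_ready(): {t_async:.6f} s (may cover dispatch only)")
print(f"With .block_until_ready():    {t_sync:.6f} s")
```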


## 1.3 Automatic Differentiation

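JAX provides automatic differentiation with `grad()`, `jacrev()`, and `jacfwd()`. `grad` differentiates scalar-valued functions, while `jacfwd` and `jacrev` compute full Jacobians in forward and reverse mode, respectively. Gradients are what drive optimization and machine-learning training.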
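The cells below use only `grad`. As a sketch of the other two, assuming an illustrative function `f` from R² to R³ chosen here (not from the notebook): both return the same Jacobian, but per the JAX docs `jacfwd` tends to be cheaper for "tall" Jacobians (more outputs than inputs) and `jacrev` for "wide" ones.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map R^2 -> R^3
    return jnp.array([x[0] ** 2, x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])

J_fwd = jax.jacfwd(f)(x)  # forward-mode Jacobian, shape (3, 2)
J_rev = jax.jacrev(f)(x)  # reverse-mode Jacobian, shape (3, 2)

# Analytical Jacobian at x: [[2*x0, 0], [x1, x0], [0, cos(x1)]]
J_exact = jnp.array([[2.0, 0.0], [2.0, 1.0], [0.0, jnp.cos(2.0)]])

print(J_fwd.shape)  # (3, 2)
print(jnp.allclose(J_fwd, J_rev), jnp.allclose(J_fwd, J_exact))  # True True
```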

In [6]:
# Define a simple function
def quadratic(x):
    return x**2 + 3*x + 2

# Get its derivative
quadratic_grad = jax.grad(quadratic)

# Test values
x = 2.0
print(f"f({x}) = {quadratic(x)}")
print(f"f'({x}) = {quadratic_grad(x)}")
print(f"Analytical f'(x) = 2x + 3 = {2*x + 3}")

# Higher-order derivatives
quadratic_hessian = jax.grad(jax.grad(quadratic))
print(f"f''({x}) = {quadratic_hessian(x)}")
print(f"Analytical f''(x) = 2")

f(2.0) = 12.0
f'(2.0) = 7.0
Analytical f'(x) = 2x + 3 = 7.0
f''(2.0) = 2.0
Analytical f''(x) = 2


In [7]:
# More complex example: gradient of a loss function
def mse_loss(params, x, y):
    """Mean squared error loss for linear regression"""
    predictions = params['w'] * x + params['b']
    return jnp.mean((predictions - y)**2)

# Generate synthetic data
key = jax.random.PRNGKey(42)
x_data = jax.random.normal(key, (100,))
true_w, true_b = 2.5, -1.0
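# Note: reusing `key` repeats the draws used for x_data, so this "noise" is exactly
# 0.1 * x_data; independent noise needs a fresh key from jax.random.split(key)
y_data = true_w * x_data + true_b + 0.1 * jax.random.normal(key, (100,))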

# Initial parameters
params = {'w': 0.0, 'b': 0.0}

# Compute gradients
loss_grad = jax.grad(mse_loss)
grads = loss_grad(params, x_data, y_data)

print(f"Initial loss: {mse_loss(params, x_data, y_data):.4f}")
print(f"Gradients: {grads}")
print(f"True parameters: w={true_w}, b={true_b}")

Initial loss: 6.3319
Gradients: {'b': Array(1.8523035, dtype=float32, weak_type=True), 'w': Array(-4.158302, dtype=float32, weak_type=True)}
True parameters: w=2.5, b=-1.0
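Because `params` is a dict, the gradient comes back as a dict with the same structure (a pytree). Two conveniences follow from that: `jax.value_and_grad` returns the loss and the gradients from a single call, and `jax.tree_util.tree_map` applies one update rule to every leaf without naming each parameter. A sketch, where the data is regenerated with independent keys and the learning rate of 0.1 is an illustrative choice:

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    """Mean squared error loss for linear regression (as above)."""
    predictions = params["w"] * x + params["b"]
    return jnp.mean((predictions - y) ** 2)

key_x, key_noise = jax.random.split(jax.random.key(0))
x_data = jax.random.normal(key_x, (100,))
y_data = 2.5 * x_data - 1.0 + 0.1 * jax.random.normal(key_noise, (100,))

params = {"w": 0.0, "b": 0.0}

# One call returns (loss, grads); grads is a dict shaped like params
loss, grads = jax.value_and_grad(mse_loss)(params, x_data, y_data)

# Apply the same gradient-descent rule to every leaf of the params pytree
learning_rate = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

print(sorted(grads.keys()))  # ['b', 'w']
print(loss > mse_loss(new_params, x_data, y_data))  # one step lowers the loss
```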


### Exercise 1: Simple Gradient Descent

Implement gradient descent optimization using JAX autodiff:

In [8]:
# Exercise 1: Implement gradient descent
@jax.jit
def gradient_step(params, x, y, learning_rate):
    """Single gradient descent step"""
    # TODO: Compute gradients and update parameters
    # Hint: use jax.grad(mse_loss)(params, x, y)
    
    grads = jax.grad(mse_loss)(params, x, y)
    
    # Update parameters
    new_params = {
        'w': params['w'] - learning_rate * grads['w'],
        'b': params['b'] - learning_rate * grads['b']
    }
    
    return new_params

# Run optimization
params = {'w': 0.0, 'b': 0.0}
learning_rate = 0.1

print("Optimization progress:")
for i in range(100):
    if i % 20 == 0:
        loss = mse_loss(params, x_data, y_data)
        print(f"Step {i:3d}: loss={loss:.6f}, w={params['w']:.4f}, b={params['b']:.4f}")
    
    params = gradient_step(params, x_data, y_data, learning_rate)

final_loss = mse_loss(params, x_data, y_data)
print(f"\nFinal: loss={final_loss:.6f}, w={params['w']:.4f}, b={params['b']:.4f}")
print(f"True parameters: w={true_w:.4f}, b={true_b:.4f}")

Optimization progress:
Step   0: loss=6.331944, w=0.0000, b=0.0000
Step  20: loss=0.005290, w=2.5212, b=-0.9815
Step  40: loss=0.000005, w=2.5976, b=-0.9996
Step  60: loss=0.000000, w=2.5999, b=-1.0000
Step  80: loss=0.000000, w=2.6000, b=-1.0000

Final: loss=0.000000, w=2.6000, b=-1.0000
True parameters: w=2.5000, b=-1.0000
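The slope settles at 2.6 rather than the true 2.5, and the loss reaches zero, because cell [7] draws `x_data` and the noise from the same `key`. The two `jax.random.normal` calls return identical numbers, so the data follow y = 2.6x - 1 exactly. The sketch below confirms the duplicate draws and then regenerates the data from split keys; for brevity it fits the line with `jnp.linalg.lstsq` instead of gradient descent:

```python
import jax
import jax.numpy as jnp

true_w, true_b = 2.5, -1.0

# Reusing one key: both draws return the SAME numbers
key = jax.random.PRNGKey(42)
x_same = jax.random.normal(key, (100,))
noise_same = jax.random.normal(key, (100,))
print(jnp.array_equal(x_same, noise_same))  # True

# Split first so inputs and noise come from independent streams
key_x, key_noise = jax.random.split(jax.random.PRNGKey(42))
x_data = jax.random.normal(key_x, (100,))
y_data = true_w * x_data + true_b + 0.1 * jax.random.normal(key_noise, (100,))

# Closed-form least-squares fit of y = w*x + b
X = jnp.stack([x_data, jnp.ones_like(x_data)], axis=1)
(w_fit, b_fit), *_ = jnp.linalg.lstsq(X, y_data)
print(f"w={w_fit:.3f}, b={b_fit:.3f}")  # close to w=2.5, b=-1.0
```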


## 2.1 Function Transformations

JAX's strength lies in its composable function transformations: `jit`, `grad`, `vmap`, `pmap`, etc.
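"Composable" means the transformations can be stacked on one another. As a small sketch (the function `f` is an illustrative choice), `grad` gives the derivative of a scalar function, `vmap` evaluates that derivative over a whole batch, and `jit` compiles the combination:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar function of a scalar input
    return jnp.sin(x) * x**2

# grad: derivative of f; vmap: apply it over a batch; jit: compile the result
df_batched = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(-2.0, 2.0, 5)
analytic = jnp.cos(xs) * xs**2 + 2 * xs * jnp.sin(xs)  # d/dx [sin(x) * x^2]
print(jnp.allclose(df_batched(xs), analytic))  # True
```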

In [9]:
# vmap: Vectorization over batch dimensions
def compute_distance(point1, point2):
    """Euclidean distance between two points"""
    return jnp.sqrt(jnp.sum((point1 - point2)**2))

# Create test data: 1000 points in 3D
key = jax.random.key(0)
points = jax.random.normal(key, (1000, 3))
origin = jnp.zeros(3)

# Method 1: Loop (slow)
def distances_loop(points, origin):
    distances = []
    for point in points:
        distances.append(compute_distance(point, origin))
    return jnp.array(distances)

# Method 2: vmap (fast)
distances_vectorized = jax.vmap(compute_distance, in_axes=(0, None))

# Compare performance
print("Performance comparison:")
%timeit distances_loop(points, origin).block_until_ready()
%timeit distances_vectorized(points, origin).block_until_ready()

# Verify results match
result1 = distances_loop(points[:10], origin)
result2 = distances_vectorized(points[:10], origin)
print(f"\nResults match: {jnp.allclose(result1, result2)}")

Performance comparison:
254 ms ± 56.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.14 ms ± 222 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Results match: True


### Understanding `vmap` Arguments

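- `in_axes`: for each positional argument, the axis to map over (e.g. `0` for rows)
- `out_axes`: the axis along which the per-call results are stacked in the output (default `0`)
- `None` (as an entry of `in_axes`): don't map over this argument; the same value is passed to every call (broadcast)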
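The cells below never change `out_axes`. A minimal sketch of its effect, using an illustrative function `scale_and_shift` defined here: each call returns a length-3 vector, and `out_axes` decides whether the batch of 4 results is stacked as rows or as columns.

```python
import jax
import jax.numpy as jnp

def scale_and_shift(v, s):
    # v: (3,) vector, s: scalar -> (3,) vector
    return v * s + 1.0

vs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 vectors
s = 2.0

# in_axes=(0, None): map over rows of vs, pass the same s to every call
out_rows = jax.vmap(scale_and_shift, in_axes=(0, None))(vs, s)
# out_axes=1: stack the per-call results along axis 1 instead of axis 0
out_cols = jax.vmap(scale_and_shift, in_axes=(0, None), out_axes=1)(vs, s)

print(out_rows.shape, out_cols.shape)  # (4, 3) (3, 4)
print(jnp.array_equal(out_cols, out_rows.T))  # True
```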

In [10]:
# Different vmap configurations (separate keys so A and B are independent draws)
key_A, key_B = jax.random.split(key)
A = jax.random.normal(key_A, (5, 3))
B = jax.random.normal(key_B, (4, 3))

print(f"A shape: {A.shape}, B shape: {B.shape}")

# Compute distances between all pairs:
# the outer vmap maps over rows of A; the inner vmap maps over rows of B (captured by the lambda)
dist_A_to_all_B = jax.vmap(lambda a: jax.vmap(lambda b: compute_distance(a, b))(B))(A)
print(f"All pairwise distances shape: {dist_A_to_all_B.shape}")

# Alternative: the same nesting with explicit in_axes instead of closures
pairwise_distance = jax.vmap(jax.vmap(compute_distance, in_axes=(None, 0)), in_axes=(0, None))
result = pairwise_distance(A, B)
print(f"Nested vmap result shape: {result.shape}")
print(f"Results match: {jnp.allclose(dist_A_to_all_B, result)}")

A shape: (5, 3), B shape: (4, 3)
All pairwise distances shape: (5, 4)
Nested vmap result shape: (5, 4)
Results match: True


## 2.2 JAX and Random Numbers

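JAX has no global random state. Every sampling function takes an explicit key, and the same key always yields the same numbers, which keeps random code pure and therefore reproducible under `jit`, `vmap`, and parallel execution.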

In [11]:
# JAX requires explicit random state management
key = jax.random.key(42)
print(f"Initial key: {key}")

# Split key to generate independent streams
key, subkey1, subkey2 = jax.random.split(key, 3)
print(f"After split: key={key}")
print(f"Subkey1: {subkey1}")
print(f"Subkey2: {subkey2}")

# Use subkeys for random generation
sample1 = jax.random.normal(subkey1, (5,))
sample2 = jax.random.normal(subkey2, (5,))

print(f"\nSample 1: {sample1}")
print(f"Sample 2: {sample2}")

# Using same key gives same results (reproducibility)
sample1_repeat = jax.random.normal(subkey1, (5,))
print(f"\nRepeating subkey1: {sample1_repeat}")
print(f"Identical to sample1: {jnp.allclose(sample1, sample1_repeat)}")

Initial key: Array((), dtype=key<fry>) overlaying:
[ 0 42]
After split: key=Array((), dtype=key<fry>) overlaying:
[1832780943  270669613]
Subkey1: Array((), dtype=key<fry>) overlaying:
[  64467757 2916123636]
Subkey2: Array((), dtype=key<fry>) overlaying:
[2465931498  255383827]

Sample 1: [ 0.60576403  0.7990441  -0.908927   -0.63525754 -1.2226585 ]
Sample 2: [ 0.4323065   0.5872638  -1.1416743  -0.37379906 -0.19910173]

Repeating subkey1: [ 0.60576403  0.7990441  -0.908927   -0.63525754 -1.2226585 ]
Identical to sample1: True
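To draw independent samples for many examples or parallel workers, one common pattern is to split a single key into a batch of keys and `vmap` the sampler over them. A minimal sketch, where the batch size of 8 and the helper name `sample_one` are illustrative choices:

```python
import jax
import jax.numpy as jnp

def sample_one(key):
    # Each call receives its own key, so draws are independent across the batch
    return jax.random.normal(key, (5,))

key = jax.random.key(42)
keys = jax.random.split(key, 8)  # a batch of 8 independent keys

samples = jax.vmap(sample_one)(keys)  # shape (8, 5)
print(samples.shape)

# Rerunning with the same keys reproduces the batch exactly
print(jnp.array_equal(samples, jax.vmap(sample_one)(keys)))  # True
```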


### Exercise 2: Monte Carlo π Estimation

Use JAX to estimate π using Monte Carlo sampling:

In [12]:
# Exercise 2: Monte Carlo estimation of π
def estimate_pi(key, n_samples):
    """Estimate π by sampling random points in unit square"""
    # TODO: Generate random points in [-1, 1] x [-1, 1]
    # Count how many fall inside unit circle
    # π ≈ 4 * (points inside circle) / (total points)
    
    # Generate random points
    points = jax.random.uniform(key, (n_samples, 2), minval=-1.0, maxval=1.0)
    
    # Check which points are inside unit circle
    distances_squared = jnp.sum(points**2, axis=1)
    inside_circle = distances_squared <= 1.0
    
    # Estimate π
    pi_estimate = 4.0 * jnp.mean(inside_circle)
    
    return pi_estimate

# Test with different sample sizes
key = jax.random.key(123)
sample_sizes = [1000, 10000, 100000, 1000000]

print("Monte Carlo π estimation:")
for n in sample_sizes:
    key, subkey = jax.random.split(key)
    pi_est = estimate_pi(subkey, n)
    error = abs(pi_est - jnp.pi)
    print(f"N={n:7d}: π≈{pi_est:.6f}, error={error:.6f}")

print(f"\nTrue π: {jnp.pi:.6f}")

Monte Carlo π estimation:
N=   1000: π≈3.060000, error=0.081593
N=  10000: π≈3.162800, error=0.021207
N= 100000: π≈3.139520, error=0.002073
N=1000000: π≈3.142648, error=0.001055

True π: 3.141593
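`estimate_pi` runs uncompiled above. Wrapping it in `jax.jit` takes one extra step: `n_samples` sets an array shape, and shapes must be known at trace time, so it has to be marked static. A sketch using `functools.partial` with `static_argnames` (the name `estimate_pi_jit` is chosen here); note that each distinct `n_samples` value triggers its own compilation:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n_samples")
def estimate_pi_jit(key, n_samples):
    """Monte Carlo estimate of pi; n_samples is static because it sets a shape."""
    points = jax.random.uniform(key, (n_samples, 2), minval=-1.0, maxval=1.0)
    inside_circle = jnp.sum(points**2, axis=1) <= 1.0
    return 4.0 * jnp.mean(inside_circle)

key = jax.random.key(123)
pi_est = estimate_pi_jit(key, n_samples=1_000_000)
print(pi_est)
```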


![JAX Slide](img/02_JAX/Folie7.PNG)

## 2.3 Optimized Linear Algebra

JAX provides optimized linear algebra operations through XLA compilation.
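One way to see what JAX hands to XLA is `jax.make_jaxpr`, which traces a function and prints its intermediate representation (a jaxpr). For a matrix product the trace contains a `dot_general` primitive; the 2×2 inputs below are only placeholders for tracing:

```python
import jax
import jax.numpy as jnp

def matmul(a, b):
    return a @ b

a = jnp.ones((2, 2), dtype=jnp.float32)
b = jnp.ones((2, 2), dtype=jnp.float32)

jaxpr = jax.make_jaxpr(matmul)(a, b)
print(jaxpr)  # the printed jaxpr includes a dot_general equation
```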

In [13]:
# Compare JAX vs NumPy for matrix operations
size = 2000
np_A = np.random.rand(size, size).astype(np.float32)
np_B = np.random.rand(size, size).astype(np.float32)

jax_A = jnp.array(np_A)
jax_B = jnp.array(np_B)

print(f"Matrix size: {size}×{size} ({np_A.nbytes/1e6:.1f} MB each)")

print("\nMatrix multiplication performance:")
print("NumPy:")
%timeit np_A @ np_B

print("JAX (JIT):")
jax_matmul = jax.jit(lambda a, b: a @ b)
# Warm up
_ = jax_matmul(jax_A, jax_B).block_until_ready()
%timeit jax_matmul(jax_A, jax_B).block_until_ready()

# Verify results match
np_result = np_A @ np_B
jax_result = jax_matmul(jax_A, jax_B)
print(f"\nResults match: {np.allclose(np_result, jax_result, rtol=1e-5)}")

Matrix size: 2000×2000 (16.0 MB each)

Matrix multiplication performance:
NumPy:
41.5 ms ± 3.84 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
JAX (JIT):
25.1 ms ± 763 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Results match: True


## 3.1 Advanced Example: K-Means Clustering

Let's implement K-means clustering to showcase JAX's capabilities:

In [14]:
def kmeans_step(points, centroids):
    """Single step of K-means: assign points and update centroids"""
    # Compute distances from each point to each centroid
    # Using broadcasting: (N, 1, D) - (1, K, D) -> (N, K, D)
    distances = jnp.linalg.norm(
        points[:, None, :] - centroids[None, :, :], 
        axis=2
    )
    
    # Assign each point to closest centroid
    assignments = jnp.argmin(distances, axis=1)
    
    # Update centroids (note: an empty cluster would make jnp.mean return NaN)
    new_centroids = jnp.array([
        jnp.mean(points[assignments == k], axis=0)
        for k in range(len(centroids))
    ])
    
    return new_centroids, assignments

# Generate synthetic clustered data
key = jax.random.key(42)
n_points, n_dims, n_clusters = 5000, 2, 3

# Create 3 clusters
cluster_centers = jnp.array([[2, 2], [-2, 2], [0, -3]])
points = []
for center in cluster_centers:
    key, subkey = jax.random.split(key)
    cluster_points = center + 0.5 * jax.random.normal(subkey, (n_points//n_clusters, 2))
    points.append(cluster_points)

points = jnp.vstack(points)
print(f"Generated {len(points)} points in {n_dims}D with {n_clusters} clusters")

# Initialize centroids randomly
key, subkey = jax.random.split(key)
centroids = jax.random.normal(subkey, (n_clusters, n_dims))

# Run K-means
print("\nRunning K-means:")
for i in range(20):
    new_centroids, assignments = kmeans_step(points, centroids)
    
    # Check convergence
    centroid_shift = jnp.max(jnp.linalg.norm(new_centroids - centroids, axis=1))
    centroids = new_centroids
    
    if i % 5 == 0 or centroid_shift < 1e-4:
        print(f"Iteration {i:2d}: max centroid shift = {centroid_shift:.6f}")
    
    if centroid_shift < 1e-4:
        print(f"Converged after {i} iterations")
        break

print(f"\nFinal centroids:")
for i, centroid in enumerate(centroids):
    print(f"Cluster {i}: [{centroid[0]:6.3f}, {centroid[1]:6.3f}]")

print(f"\nTrue centers:")
for i, center in enumerate(cluster_centers):
    print(f"Cluster {i}: [{center[0]:6.3f}, {center[1]:6.3f}]")

Generated 4998 points in 2D with 3 clusters

Running K-means:
Iteration  0: max centroid shift = 2.402447
Iteration  3: max centroid shift = 0.000000
Converged after 3 iterations

Final centroids:
Cluster 0: [ 1.995,  1.996]
Cluster 1: [-1.995,  2.002]
Cluster 2: [-0.009, -2.994]

True centers:
Cluster 0: [ 2.000,  2.000]
Cluster 1: [-2.000,  2.000]
Cluster 2: [ 0.000, -3.000]


### Exercise 3: Optimize K-means

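The current implementation uses a Python loop with boolean-mask indexing in the centroid update. Masking produces arrays whose shape depends on the data, so this function also cannot be wrapped in `jax.jit`. Make it fully vectorized: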

In [15]:
# Exercise 3: Vectorized K-means centroid update
def kmeans_step_optimized(points, centroids):
    """Fully vectorized K-means step"""
    # Compute distances (same as before)
    distances = jnp.linalg.norm(
        points[:, None, :] - centroids[None, :, :], 
        axis=2
    )
    
    # Assign points to clusters
    assignments = jnp.argmin(distances, axis=1)
    
    # TODO: Vectorized centroid update
    # Hint: Use jnp.eye() to create one-hot encoding of assignments
    # Then use matrix operations to compute means
    
    # One-hot encode assignments: (N, K)
    one_hot = jnp.eye(len(centroids))[assignments]
    
    # Count points per cluster
    cluster_sizes = jnp.sum(one_hot, axis=0, keepdims=True)  # (1, K)
    
    # Sum points per cluster: (D, K)
    cluster_sums = points.T @ one_hot  # (D, N) @ (N, K) = (D, K)
    
    # Compute new centroids: (K, D)
    # Avoid division by zero: an empty cluster gets a zero centroid here,
    # whereas the loop version above returns NaN for it
    cluster_sizes = jnp.maximum(cluster_sizes, 1)
    new_centroids = (cluster_sums / cluster_sizes).T
    
    return new_centroids, assignments

# Test both versions
centroids_test = jax.random.normal(jax.random.key(0), (n_clusters, n_dims))

# Verify they give same results
centroids1, assignments1 = kmeans_step(points, centroids_test)
centroids2, assignments2 = kmeans_step_optimized(points, centroids_test)

print(f"Results match: {jnp.allclose(centroids1, centroids2, rtol=1e-5)}")
print(f"Assignments match: {jnp.array_equal(assignments1, assignments2)}")

# Performance comparison (jax.block_until_ready waits on both returned arrays)
print("\nPerformance comparison:")
print("Original version:")
%timeit jax.block_until_ready(kmeans_step(points, centroids_test))

print("Optimized version:")
%timeit jax.block_until_ready(kmeans_step_optimized(points, centroids_test))

Results match: True
Assignments match: True

Performance comparison:
Original version:
2.15 ms ± 47.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Optimized version:
1.05 ms ± 13.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
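Because the one-hot version uses only fixed-shape operations, it can also be wrapped in `jax.jit`. An alternative to the (N, K) one-hot matrix is `jax.ops.segment_sum`, which sums the rows belonging to each cluster id directly. A sketch under those assumptions: `kmeans_step_segment` and the small test data are illustrative, and empty clusters are guarded the same way as in the exercise:

```python
import jax
import jax.numpy as jnp

@jax.jit
def kmeans_step_segment(points, centroids):
    """K-means step using segment sums instead of a one-hot matrix."""
    k = centroids.shape[0]  # a plain Python int under jit: shapes are static
    # Squared distances give the same argmin and skip the sqrt
    sq_dist = jnp.sum((points[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    assignments = jnp.argmin(sq_dist, axis=1)

    # Per-cluster sums of points (K, D) and counts (K,)
    sums = jax.ops.segment_sum(points, assignments, num_segments=k)
    counts = jax.ops.segment_sum(jnp.ones(points.shape[0]), assignments, num_segments=k)

    new_centroids = sums / jnp.maximum(counts, 1)[:, None]
    return new_centroids, assignments

# Small self-contained check on three well-separated blobs
key_pts, key_init = jax.random.split(jax.random.key(0))
centers = jnp.array([[2.0, 2.0], [-2.0, 2.0], [0.0, -3.0]])
points = (centers[:, None, :] + 0.5 * jax.random.normal(key_pts, (3, 100, 2))).reshape(-1, 2)
centroids = jax.random.normal(key_init, (3, 2))

new_centroids, assignments = kmeans_step_segment(points, centroids)
print(new_centroids.shape, assignments.shape)  # (3, 2) (300,)
```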


## 3.2 Performance Analysis and Best Practices

### JIT Compilation Overhead

In [16]:
# Demonstrate JIT compilation overhead
def simple_computation(x):
    return jnp.sum(x**2) + jnp.mean(x)

simple_computation_jit = jax.jit(simple_computation)

data = jnp.array(np.random.rand(1000))

# First call includes compilation time
print("First call (includes compilation):")
%timeit -n1 -r1 simple_computation_jit(data).block_until_ready()

# Subsequent calls are fast
print("\nSubsequent calls (compiled):")
%timeit simple_computation_jit(data).block_until_ready()

# Compare with non-JIT version
print("\nNon-JIT version:")
%timeit simple_computation(data).block_until_ready()

First call (includes compilation):
25.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Subsequent calls (compiled):
10.2 μs ± 320 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Non-JIT version:
85.4 μs ± 4.38 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
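The first-call cost recurs whenever `jit` sees a new input shape or dtype, because the function is traced and compiled again for each such signature. The sketch below counts traces with a Python side effect (an illustrative trick for demonstration, not something to rely on in real code) and then uses the ahead-of-time `.lower(...).compile()` path to separate compilation from execution:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def sum_of_squares(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(x**2)

sum_of_squares_jit = jax.jit(sum_of_squares)

sum_of_squares_jit(jnp.ones(1000)).block_until_ready()   # traced + compiled for shape (1000,)
sum_of_squares_jit(jnp.zeros(1000)).block_until_ready()  # same shape/dtype: cached, no retrace
sum_of_squares_jit(jnp.ones(2000)).block_until_ready()   # new shape: traced + compiled again
n_traces = trace_count
print(f"Traces so far: {n_traces}")  # Traces so far: 2

# Ahead-of-time compilation: pay the compile cost explicitly, before any timing
x = jnp.ones(4000)
compiled = jax.jit(sum_of_squares).lower(x).compile()
print(compiled(x))  # 4000.0
```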


### When JIT Helps Most

In [17]:
# JIT is most beneficial for:
# 1. Complex computations with many operations
# 2. Repeated function calls
# 3. Operations that can be fused

def complex_computation(x):
    """Many operations that can be fused by XLA"""
    y = jnp.sin(x)
    z = jnp.cos(x**2)
    w = jnp.exp(-x)
    return jnp.sum(y * z * w) + jnp.mean(y + z + w)

def simple_operation(x):
    """Single operation - less benefit from JIT"""
    return jnp.sum(x)

# Create JIT versions
complex_jit = jax.jit(complex_computation)
simple_jit = jax.jit(simple_operation)

data = jnp.array(np.random.rand(100_000))

# Warm up
_ = complex_jit(data)
_ = simple_jit(data)

print("Complex computation (many operations):")
%timeit complex_computation(data).block_until_ready()
%timeit complex_jit(data).block_until_ready()

print("\nSimple operation (single operation):")
%timeit simple_operation(data).block_until_ready()
%timeit simple_jit(data).block_until_ready()

Complex computation (many operations):
887 μs ± 7.46 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
785 μs ± 3.86 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Simple operation (single operation):
24.2 μs ± 345 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
21.4 μs ± 278 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## 3.3 Memory Management and Arrays

JAX arrays are immutable: every operation, including an indexed update through `.at[]`, returns a new array and leaves its input unchanged.

In [18]:
# JAX arrays are immutable - operations create new arrays
x = jnp.array([1, 2, 3, 4, 5])
print(f"Original: {x}")

# These create NEW arrays
y = x + 1
z = x * 2

print(f"x + 1: {y}")
print(f"x * 2: {z}")
print(f"Original x unchanged: {x}")

# For updates, use .at[] syntax
x_updated = x.at[0].set(999)
print(f"\nUpdated at index 0: {x_updated}")
print(f"Original still: {x}")

# Multiple updates
x_multi = x.at[1:3].set(jnp.array([888, 777]))
print(f"Updated slice: {x_multi}")

# Addition to existing values
x_added = x.at[2].add(100)
print(f"Added to index 2: {x_added}")

Original: [1 2 3 4 5]
x + 1: [2 3 4 5 6]
x * 2: [ 2  4  6  8 10]
Original x unchanged: [1 2 3 4 5]

Updated at index 0: [999   2   3   4   5]
Original still: [1 2 3 4 5]
Updated slice: [  1 888 777   4   5]
Added to index 2: [  1   2 103   4   5]
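Beyond `set` and `add`, the `.at[]` property also offers methods such as `multiply`, `min`, `max`, and `get`. With an index array containing repeats, `add` applies every update, which makes it handy for counting. And although each call nominally returns a new array, the JAX documentation notes that inside `jit`-compiled code XLA can perform the update in place when the input array is not used afterwards. A short sketch (`bump_first` is an illustrative name):

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

print(x.at[1].multiply(10.0))  # [ 1. 20.  3.  4.  5.]
print(x.at[4].min(0.0))        # [1. 2. 3. 4. 0.]

# Repeated indices: add applies every update, so index 0 is incremented twice
counts = jnp.zeros(3).at[jnp.array([0, 0, 2])].add(1.0)
print(counts)  # [2. 0. 1.]

@jax.jit
def bump_first(v):
    # Functionally a new array; under jit, XLA may reuse the buffer in place
    return v.at[0].add(1.0)

print(bump_first(x))  # [2. 2. 3. 4. 5.]
print(x)              # unchanged: [1. 2. 3. 4. 5.]
```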


## Summary: JAX Best Practices

### ✅ Do:
- Use `@jax.jit` for repeated complex computations
- Use `vmap` instead of Python loops
- Manage random keys explicitly with `split()`
- Use `.block_until_ready()` for accurate timing
- Write pure functions (no side effects, no mutation of global state); `jit`, `grad`, and `vmap` assume them
- Use autodiff (`grad`) for optimization

### ❌ Don't:
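- Expect large gains from JIT-compiling a single operation (in the runs above, `jnp.sum` improved by only about 10%)
- Use Python loops for array operations
- Rely on global random state
- Assign into arrays in place (`x[0] = 1` raises a `TypeError`; use `.at[]` instead)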
- Ignore compilation overhead in timing

### JAX vs NumPy Performance:
- **Simple operations**: NumPy is often faster (less dispatch overhead)
- **Complex operations**: JAX can win when `jit` lets XLA fuse operations
- **Repeated computations**: JAX pays the compilation cost once and reuses the compiled code on later calls

In [19]:
# Final performance showcase
print("JAX Performance Showcase")
print("=" * 40)

# Large-scale computation
n = 10000
key = jax.random.key(42)
X = jax.random.normal(key, (n, 50))

@jax.jit
def complex_analysis(X):
    # Covariance matrix
    cov = (X.T @ X) / len(X)
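    # Eigenvalues (eigvals is the general solver and returns a complex dtype
    # even for a symmetric matrix, hence the +0.00j in the output below)
    eigenvals = jnp.linalg.eigvals(cov)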
    # Statistics
    return {
        'mean': jnp.mean(X, axis=0),
        'std': jnp.std(X, axis=0),
        'max_eigenval': jnp.max(eigenvals),
        'condition_number': jnp.max(eigenvals) / jnp.min(eigenvals)
    }

# Warm up
_ = complex_analysis(X)

# Time the analysis (jax.block_until_ready waits on every array in the result dict)
start = time.time()
result = complex_analysis(X)
jax.block_until_ready(result)
end = time.time()

print(f"Analyzed {n}×50 matrix in {end-start:.4f} seconds")
print(f"Condition number: {result['condition_number']:.2f}")
print(f"Max eigenvalue: {result['max_eigenval']:.4f}")

JAX Performance Showcase
Analyzed 10000×50 matrix in 0.0100 seconds
Condition number: 1.31+0.00j
Max eigenvalue: 1.1396+0.0000j
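The `+0.00j` above comes from `jnp.linalg.eigvals`, the general eigenvalue routine, which returns complex values even for a symmetric matrix such as `cov`. For symmetric (or Hermitian) matrices, `jnp.linalg.eigvalsh` returns real eigenvalues sorted in ascending order. A sketch of the same two statistics with that substitution; the function name `covariance_spectrum` and the smaller 1000×10 input are illustrative choices:

```python
import jax
import jax.numpy as jnp

@jax.jit
def covariance_spectrum(X):
    cov = (X.T @ X) / X.shape[0]  # symmetric (D, D) matrix
    eigenvals = jnp.linalg.eigvalsh(cov)  # real, sorted ascending
    return {
        "max_eigenval": eigenvals[-1],
        "condition_number": eigenvals[-1] / eigenvals[0],
    }

X = jax.random.normal(jax.random.key(42), (1000, 10))
result = jax.block_until_ready(covariance_spectrum(X))

print(f"Condition number: {result['condition_number']:.2f}")
print(f"Max eigenvalue: {result['max_eigenval']:.4f}")
```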
