# JAX / XLA Whole-Program Training Optimizer
This notebook demonstrates how to use `jax.jit` to compile an entire training step into a single XLA-compiled computation, which significantly reduces per-step execution latency compared with dispatching each operation individually.

In [None]:
import jax
import jax.numpy as jnp
import time
from core.model import init_params
from core.trainer import step

# Initialize data and parameters with the same shapes as the project setup.
# JAX PRNG keys must never be reused: drawing x, y and the parameters from the
# same key makes the samples statistically dependent. Derive one independent
# subkey per consumer from the root key instead.
key = jax.random.PRNGKey(0)
x_key, y_key, params_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (128, 128))  # model inputs (presumably batch x features -- confirm against core.model)
y = jax.random.normal(y_key, (128, 10))   # regression-style targets (presumably batch x outputs -- confirm)
params = init_params(128, 10, params_key)

print(f"Devices detected: {jax.devices()}")

In [None]:
# Number of timed iterations; raise for a more stable average.
N_STEPS = 50

# Warmup: the first call triggers the XLA compiler. JAX dispatches work
# asynchronously, so wait for the result explicitly; otherwise compilation and
# the first execution could still be in flight when the timer starts.
print("Compiling training step...")
jax.block_until_ready(step(params, x, y))

# Performance benchmark.
# Each iteration blocks on its outputs so the measurement covers the actual
# device execution rather than only the time needed to enqueue the work.
# perf_counter is monotonic and higher resolution than time.time().
print(f"Running benchmark for {N_STEPS} steps...")
t0 = time.perf_counter()
for _ in range(N_STEPS):
    loss, grads = jax.block_until_ready(step(params, x, y))
t1 = time.perf_counter()

print("Optimization Complete.")
print(f"Average JIT time per step: {(t1 - t0) / N_STEPS * 1000:.4f} ms")