# Introduction to JAX: from warp-speed NumPy to neural networks

JAX is NumPy on rocket fuel: it runs blazingly fast on CPUs, GPUs, and TPUs by automatically compiling your computational graph to [XLA](https://github.com/openxla/xla). Oh, and it also has autograd and some other cool tricks.

**What makes JAX special:**
- **Speed**: Your code runs where it's fastest (CPU/GPU/TPU)
- **Autograd**: Calculate derivatives of pretty much any function
- **Compilation**: Transforms Python into optimized machine code
- **Functional core**: Pure functions guarantee reproducibility
- **Easy parallelization**: Write code for a single item and run it over multiple dimensions, data structures, and devices

Whether you're training massive language models, solving differential equations, or just trying to make your data science workflow faster, JAX delivers performance without sacrificing readability.

**TL;DR**: It's NumPy, but with superpowers.

## What about PyTorch?

### When to use JAX:
- When PyTorch is too slow
  - Lots of tiny GPU operations: the compiler will fuse them all together
  - Speeding up (large) neural networks, such as [LLMs](https://github.com/xai-org/grok-1)
- Projects requiring perfect reproducibility and determinism
- When working with TPUs *(JAX was literally made for this)*

### When to use PyTorch:
- Quick & dirty prototyping (i.e., most research code)
- When using lots of random numbers that don't need to be reproducible
- When you need easy debugging
- Projects requiring extensive ecosystem support
- Production-grade deep learning (often exported to ONNX)

## Let's get started with JAX
First: some handy imports!

Notice `jnp = jax.numpy` and `jrand = jax.random`.

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jrand
from jax import grad, jit, vmap, lax
import numpy as np
import time
from typing import Any
_counter = 123
print(f"JAX version: {jax.__version__}")

## JAX = Accelerated NumPy

JAX provides a NumPy-like interface that runs on accelerated hardware. Most NumPy operations have direct JAX equivalents.

**TL;DR**: Write NumPy code, change `np` into `jnp`, profit.

In [None]:
# NumPy array - lives a peaceful life on your CPU
np_array = np.array([[1, 2], [3, 4]])

# JAX array - same syntax, but secretly lives on your GPU/TPU (if available)
jax_array = jnp.array([[1, 2], [3, 4]])

# Operations look similar
print("NumPy sum:", np_array.sum())
print("JAX sum:", jax_array.sum())

# But JAX arrays have a special home
print(f"JAX array device: {jax_array.device}")
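One caveat before you go swapping every `np` for `jnp`: by default JAX uses 32-bit floats (64-bit needs the `jax_enable_x64` flag), whereas NumPy uses 64-bit. A quick check, assuming a default JAX configuration:

```python
import jax
import jax.numpy as jnp
import numpy as np

np_floats = np.array([1.0, 2.0])
jax_floats = jnp.array([1.0, 2.0])

print("NumPy dtype:", np_floats.dtype)  # float64
print("JAX dtype:", jax_floats.dtype)   # float32 (unless jax_enable_x64 is set)

# jnp functions happily accept NumPy arrays and return JAX arrays
mixed = jnp.add(np_floats, 1.0)
print(type(mixed), mixed.dtype)
```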

### Automatic device detection in JAX
Every 15 seconds, somewhere in the world, a PyTorch dev encounters this error:\
`RuntimeError: Expected all tensors to be on the same device, but found at least two devices`

With JAX: never again!

JAX automatically picks the 'best' available device for you. It follows a simple priority hierarchy: TPU > GPU > CPU.\
No more need for tedious `.to(device)` calls. Hallelujah!

**TL;DR**: JAX handles device placement for you. You can forget it even exists.

**Pro Tip:** Always check that JAX detects your GPU. Otherwise it quietly falls back to the CPU, and the only hint is a warning that's easy to miss. *(ask me how I know)*

In [None]:
print("Running JAX on", jax.default_backend())
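You can still inspect, and override, placement when you need to. A minimal sketch using `jax.devices()` and `jax.device_put`; on a CPU-only machine both arrays simply end up on the single CPU device:

```python
import jax
import jax.numpy as jnp

# All devices JAX can see on the default backend (e.g. several GPUs)
devices = jax.devices()
print("Available devices:", devices)

# New arrays land on the default device, i.e. devices[0]
x = jnp.arange(4.0)
print("x lives on:", x.devices())

# If you ever *do* want control, jax.device_put commits an array to a device
y = jax.device_put(x, devices[-1])
print("y lives on:", y.devices())
```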

## The foundation of JAX: functional programming
JAX's transformations (`jit`, `grad`, `vmap`, ...) assume pure functions: no hidden global state, and every function must return the same output when given the same input. This makes the code more predictable and easier to interpret.

**TL;DR**: Functions should depend only on their inputs, not on the phase of the moon.

**Pro Tip**: In desperate need of a global variable? Provide it as an extra input to the function and return it as an extra output.
\
**Pro Tip #2**: Global variables *are* allowed as long as they remain constant throughout the entire program. This can be useful for setting configs and hyperparameters.
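Why only *constant* globals? `jit` reads a global once, while tracing, and bakes its value into the compiled function. A minimal sketch (the global `scale` is purely illustrative):

```python
import jax
import jax.numpy as jnp

scale = 2.0  # hypothetical module-level value read inside a jitted function

@jax.jit
def apply_scale(x):
    return x * scale  # `scale` is captured when the function is first traced

print(apply_scale(jnp.array(3.0)))  # 6.0

scale = 10.0  # change the global after compilation...
print(apply_scale(jnp.array(3.0)))  # ...still 6.0: the cached trace used scale=2.0
```

The second call hits the compilation cache, so it never sees the new value. That's how a changing global turns into silently stale results.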

In [None]:
# Bad: Global state makes your code unpredictable
def bad_increment(x):
  """Global states are like a box of chocolates.
  You never know what you're gonna get..."""
  return x + _counter

In [None]:
# Any predictions for what you'll see here?
bad_increment(4)

In [None]:
# Good: Pure function, like a reliable vending machine
def good_increment(x, counter):
  return x + counter, counter + 1

In [None]:
# Any predictions for what you'll see here?
good_increment(4, 1)

### Random numbers without global seeds

One global state we all like to ignore is the **Pseudo-Random Number Generator (PRNG) seed**.

Since JAX forbids global state, you have to manage the PRNG state *manually*: every time, everywhere in your code that needs randomness. Yes, this sucks. But it must be done.

**TL;DR**: Random numbers need explicit keys. Always split, never reuse.

**Pro Tip**: Create one master key at the start of your program, then split it whenever you need randomness. This keeps your code reproducible.

#### Creating random numbers in NumPy
Easy peasy, but impossible to predict.
Run this cell twice and you'll get different results.

In [None]:
a_np = np.random.randn(1, 2, 3)
print(a_np)

#### Creating random numbers in JAX
A major hassle, but completely deterministic.
Run these cells twice and you'll get the same results.

In [None]:
# 1. Summon the ancient PRNG key exactly *once*
key = jrand.PRNGKey(0)

# 2. Never use the ancient key directly! Split it like a coconut
key, subkey = jrand.split(key) # = two fresh coconut halves

# 3. Use the subkey to generate random numbers
a_jax = jrand.normal(subkey, (1, 2, 3))
print(a_jax)

# 4. Repeat for a new random vector: split again and use the *fresh* subkey
key, subkey = jrand.split(key)
b_jax = jrand.normal(subkey, (1, 2, 3))
print(b_jax)

# What definitely NOT to do:
subkey = jrand.PRNGKey(0)
numbers1 = jrand.normal(subkey, (3,))  # This will always be the same...
numbers2 = jrand.normal(subkey, (3,))  # ...as this. Oops!
print(numbers1, numbers2)
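Two conveniences that make key handling less painful: `jrand.split(key, n)` hands out `n` keys in one call, and `jrand.fold_in(key, i)` derives a key from an integer such as a step counter. A small sketch (the variable names are illustrative):

```python
import jax.random as jrand

master_key = jrand.PRNGKey(0)

# Split off several subkeys in one go instead of calling split repeatedly
master_key, *layer_keys = jrand.split(master_key, 4)  # 1 new master key + 3 subkeys
weights = [jrand.normal(k, (2,)) for k in layer_keys]

# fold_in derives a fresh key from an integer, e.g. a training-step counter
step_key = jrand.fold_in(master_key, 7)
noise = jrand.uniform(step_key, (3,))

# Same master key in, same numbers out: fully reproducible
again = jrand.uniform(jrand.fold_in(master_key, 7), (3,))
print(noise, again)
```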

### Array immutability, and how to deal with it
In-place array modifications are also a form of global state. Therefore, JAX arrays are immutable.

**TL;DR**: Don't try to change arrays in-place; create new ones instead.

**Pro Tip**: Use the `x.at[idx].set(value)` syntax to create modified copies of arrays.

In [None]:
x = jnp.array([1, 2, 3])

# This creates a new array
y = x.at[0].set(10)

print("Original array:", x)
print("New array:", y)
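Direct assignment fails loudly, and `.at[...]` supports more than `.set`. A short sketch of a few of its update methods:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# In-place assignment is rejected outright
try:
    x[0] = 10
except TypeError as err:
    print("TypeError:", err)

# The .at[...] family covers more than set; each call returns a new array
added = x.at[1].add(100)                 # [1, 102, 3]
clipped = x.at[:2].max(2)                # elementwise max on the first two entries: [2, 2, 3]
zeroed = x.at[jnp.array([0, 2])].set(0)  # fancy indexing works too: [0, 2, 0]
print(added, clipped, zeroed, x)         # x itself is untouched
```

Under `jit`, XLA can often perform these updates in place anyway (when the original array isn't used again), so the copies are usually cheaper than they look.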

## JIT Compilation
JIT (Just-In-Time) compilation is JAX's secret sauce. It fuses operations together, removes dead or redundant code, and optimizes your code for the available hardware.\
Think of JIT as meal-prepping for the week vs. cooking each meal from scratch — there's an upfront cost, but huge savings afterward.

**TL;DR**: JIT makes your code fast, but needs a warmup run. It's worth it for any code you'll run more than once.

**Pro Tip**: When timing JAX code, use `.block_until_ready()`. Otherwise, [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html) will skew your results.

In [None]:
def coffee_maker(beans):
    """Simulates a complex morning coffee routine"""
    return jnp.sum(4 * beans + 2 * jnp.sin(beans) ** 2)

# First, let's time it without JIT
morning_beans = jnp.ones((1000,))
print("Slooooooowwwww:")
%timeit coffee_maker(morning_beans).block_until_ready()

# Now with JIT, but re-wrapping on every call: this adds overhead each time
# (and the very first call also pays for compilation)
print("Quite slow:")
%timeit jit(coffee_maker)(morning_beans).block_until_ready()

# But every cup after that is lightning fast!
jitted_coffee = jit(coffee_maker)  # or use @jit if you're fancy
_ = jitted_coffee(morning_beans)  # First slow cup
print("Super fast:")
%timeit jitted_coffee(morning_beans).block_until_ready()  # Fast cups forever
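Outside a notebook there's no `%timeit`. A minimal sketch of the same measurement with `time.perf_counter`, doing one warmup call first and calling `.block_until_ready()` on every result (`time_jax` and `brew` are illustrative names):

```python
import time

import jax
import jax.numpy as jnp

def time_jax(fn, *args, repeats=100):
    """Average wall-clock seconds per call, excluding the compile/warmup run."""
    fn(*args).block_until_ready()  # warmup: triggers compilation for jitted fns
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait for the async result before timing
    return (time.perf_counter() - start) / repeats

@jax.jit
def brew(beans):  # same computation as coffee_maker above
    return jnp.sum(4 * beans + 2 * jnp.sin(beans) ** 2)

beans = jnp.ones((1000,))
print(f"{time_jax(brew, beans) * 1e6:.1f} µs per call")
```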

### Control Flow in JIT

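Python is loved for its dynamic control flow. But `jit` traces your function with abstract placeholder values, so Python can't branch or loop on anything that depends on the *values* of the inputs. If you want to use `jit`, you'll have to write that control flow the JAX way.

**TL;DR**: Don't use Python's `if/else` and `for/while` on traced values. Use their JAX equivalents.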

#### If / else ⟹ cond / switch / where
Use `lax.cond`, `lax.switch` or `jnp.where` instead of `if/else`.

In [None]:
# Cheap coffee on Wednesdays!
def discount_bad(day_num, price):
    return price * 0.2 if day_num == 3 else 0

# This fails with JIT - conditional depends on runtime value
try:
  jit(discount_bad)(1, 25.0)
except jax.errors.TracerBoolConversionError as e:
  print("BIG ERROR: DON'T USE IF IN JIT")
  print(e)

In [None]:
# The JAX way: compute both options, then select one with jnp.where
def discount_good(day_num, price):
    return jnp.where(day_num == 3, price * 0.2, 0)

# This works, but still no discounts on Mondays :'(
jit(discount_good)(1, 25.0)
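`lax.cond` and `lax.switch` express the same kind of logic as explicit branches. A minimal sketch (the discount tiers in `discount_switch` are made up):

```python
import jax
from jax import lax

@jax.jit
def discount_cond(day_num, price):
    # lax.cond(pred, true_fn, false_fn, *operands): both branches must return
    # the same shape/dtype, and only the selected branch's result is returned
    return lax.cond(day_num == 3, lambda p: p * 0.2, lambda p: p * 0.0, price)

@jax.jit
def discount_switch(tier, price):
    # lax.switch picks branches[tier]; out-of-range indices are clamped
    branches = [
        lambda p: p * 0.0,  # tier 0: no discount
        lambda p: p * 0.1,  # tier 1: 10% off
        lambda p: p * 0.2,  # tier 2: 20% off
    ]
    return lax.switch(tier, branches, price)

print(discount_cond(3, 25.0), discount_cond(1, 25.0))
print(discount_switch(1, 25.0), discount_switch(7, 25.0))  # 7 is clamped to tier 2
```

`jnp.where` always computes both sides and selects elementwise, which is fine for cheap expressions like these. `lax.cond` typically runs only the chosen branch, except under `vmap`, where it is converted into a select that evaluates both.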

#### Loops ⟹ scan
Python loops are allowed, but `jit` unrolls them while tracing. This makes the computation graph unnecessarily large and can cause compilation times to explode (~4th-order power law). Using `lax.scan` avoids this.

That's why DeepMind's slogan is: ["Always scan when you can!"](https://github.com/jax-ml/jax/discussions/3850#discussioncomment-44785)

In [None]:
# Grind beans with multiple steps
def grind_bad(beans):
    result = beans
    for i in range(4096):  # Yes, I like my beans reduced to dust
        result = result * 0.8 + 2  # Each grinding step
    return result

grind_bad = jit(grind_bad)
# For loops work but take long to compile
grind_bad(jnp.array([4.0, 8.0]))

In [None]:
# Only the compilation itself is slow. Afterwards, it's super fast.
grind_bad(jnp.array([5.0, 7.0]))

In [None]:
# Better: Using scan for iterative processes
# (but unfortunately it's a bit ugly)
def grind_good(beans):
    def grind_step(carry, _):
        return carry * 0.8 + 2, None
    final_brew, _ = lax.scan(grind_step, beans, None, length=4096)
    return final_brew

grind_good = jit(grind_good)
grind_good(jnp.array([4.0, 8.0]))

In [None]:
# Ironically, scan (like fori_loop) can be a teeny-tiny bit slower than an unrolled for-loop
grind_good(jnp.array([5.0, 7.0]))

Sometimes unrolling loops is actually beneficial! That is because the compiler can't optimize across the boundaries between `scan` iterations. Unrolling creates a larger computational graph, allowing better operator fusion.

For example, if your loop starts with `y = x**2` and ends with `x = b**2`, unrolling lets the compiler optimize to `y = b**4`.

**Pro Tip**: Use `scan(..., unroll=True)` to get this optimization when needed, though be prepared for slower compilation times.
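Two `scan` features the grinder example doesn't use: the second return value collects per-step outputs, and `unroll` also accepts an integer for partial unrolling. A short sketch that checks all three loop styles agree (the 5-step grinder is illustrative):

```python
import jax.numpy as jnp
from jax import lax

def grind_step(carry, _):
    new = carry * 0.8 + 2
    return new, new  # (next carry, per-step output that scan stacks for us)

beans = jnp.array([4.0, 8.0])

# scan can also return the full history of per-step outputs (shape: [steps, 2])
final, history = lax.scan(grind_step, beans, None, length=5)

# unroll=k unrolls k iterations per loop step: a middle ground between a
# compact loop (unroll=1, the default) and full unrolling (unroll=True)
final_unrolled, _ = lax.scan(grind_step, beans, None, length=5, unroll=5)

# fori_loop is a convenient alternative when you only need the final carry
final_fori = lax.fori_loop(0, 5, lambda i, c: c * 0.8 + 2, beans)

print(history.shape)  # (5, 2)
print(final, final_unrolled, final_fori)
```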

#### Summary

| Control Structure | JAX Equivalent | Use When |
|-------------------|----------------|----------|
| `if/else` | `lax.cond`, `lax.switch`, `jnp.where` | Conditionals on traced values |
| `for` loops | `lax.scan`, `lax.fori_loop` | Iterative operations with a fixed number of steps |
| `while` loops | `lax.while_loop` | Dynamic number of iterations |
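For loops whose iteration count depends on the data, `lax.while_loop(cond_fun, body_fun, init_val)` is the tool. A minimal sketch that halves a (made-up) grind size until it drops below 1:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def steps_until_fine(size):
    """Halve the grind size until it drops below 1.0; count the steps needed."""
    def keep_grinding(state):
        size, _ = state
        return size >= 1.0  # loop condition, evaluated on traced values

    def grind(state):
        size, steps = state
        return size / 2, steps + 1  # carry must keep the same shape/dtype

    final_size, steps = lax.while_loop(keep_grinding, grind, (size, 0))
    return final_size, steps

size, steps = steps_until_fine(jnp.float32(100.0))
print(f"{float(size)} after {int(steps)} steps")  # 0.78125 after 7 steps
```

One caveat: `lax.while_loop` supports forward-mode but not reverse-mode differentiation, so inside anything you plan to `grad`, prefer `scan` or a `fori_loop` with fixed bounds.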

## Effortless vectorization with `vmap`
PyTorch is very much a "batch first" sort of framework. For everything you do, you need to take into account that there will be an extra batch dimension added to the data.
Sometimes, this can be a hassle. If you've ever struggled with broadcasting dimensions in NumPy/PyTorch, you're in for a treat.

JAX's `vmap` simplifies batching. Write code for one item, and `vmap` scales it to many — just like that.

It's like having a sous-chef that perfectly scales your recipe from serving 1 person to serving 100 without changing how you write the recipe.

**TL;DR**: Single-item code, batch-ready with `vmap`.

**Pro Tip**: Use `vmap`'s `in_axes` parameter to control which arguments get vectorized. It's like telling your sous-chef "100 servings, but no need for 100 different pepper mills".\
**Pro Tip #2**: Abuse the collective reductions `lax.pmean` and `lax.psum` to reduce over your vmapped axis.

In [None]:
def brew_coffee(beans, water):
    """Brew a cup of coffee with given beans and water ratio"""
    return jnp.dot(beans, water)

# Make a single cup
beans = jnp.array([1, 2, 3])  # different bean types (arabica, robusta, etc.)
water = jnp.array([0.5, 0.3, 0.2])  # water temperature/pressure ratios
print("Single cup:", brew_coffee(beans, water))

Let's say you want to apply a function to a whole batch of inputs.

In PyTorch, you'd typically have to wrap your head around the resulting matmul dimensions (or use [einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html), or the newer `torch.vmap`).

In JAX, it's as easy as you'd hope:

In [None]:
# Make coffee for the whole family
batched_brewing = vmap(brew_coffee)
family_beans = jnp.stack([beans, beans * 2])  # Some want stronger coffee
family_water = jnp.stack([water, water / 2])  # Different water profiles
print("Family coffee:", batched_brewing(family_beans, family_water))

# Make different coffees with same water settings
cafe_brewing = vmap(brew_coffee, in_axes=(0, None))
print("Different beans, same water:", cafe_brewing(family_beans, water))
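Because `vmap` returns an ordinary function, it composes with itself. Nesting two `vmap`s with complementary `in_axes` brews every bean blend with every water profile, the kind of all-pairs computation that usually calls for careful broadcasting (the arrays below are made up for illustration):

```python
import jax.numpy as jnp
from jax import vmap

def brew_coffee(beans, water):
    return jnp.dot(beans, water)

all_beans = jnp.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]])                       # 2 bean blends
all_water = jnp.array([[0.5, 0.3, 0.2], [0.25, 0.15, 0.1], [1.0, 0.0, 0.0]])   # 3 water profiles

# Inner vmap: one bean blend vs. every water profile -> shape (3,)
# Outer vmap: repeat that for every bean blend     -> shape (2, 3)
pairwise = vmap(vmap(brew_coffee, in_axes=(None, 0)), in_axes=(0, None))
every_combo = pairwise(all_beans, all_water)
print(every_combo.shape)  # (2, 3)
```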

But wait, what if you want to do a reduction after your vmap? (e.g., sum / mean)

No worries, JAX has got you covered:

In [None]:
from functools import partial

# Create a batch of coffee beans and water profiles
batch_beans = jnp.array([[1, 2], [3, 4]])
batch_water = jnp.array([[0.5, 0.5], [0.7, 0.3]])

# To be sure about which axis to reduce over, pmean needs an axis_name arg.
my_axis_name: str = "coffee_batch"

@partial(vmap, axis_name=my_axis_name)
def mean_coffee_quality(beans, water):
    """Calculate mean quality across a batch of coffee brews"""
    brews = brew_coffee(beans, water)
    return lax.pmean(brews, axis_name=my_axis_name)

print(mean_coffee_quality(batch_beans, batch_water)) # note that the dims were preserved

# Adding out_axes=None tells vmap to collapse this unnecessary dim into a scalar
@partial(vmap, axis_name=my_axis_name, out_axes=None)
def mean_coffee_quality_scalar(beans, water):
    """Also mean quality, but as a scalar this time"""
    brews = brew_coffee(beans, water)
    return lax.pmean(brews, axis_name=my_axis_name)

print(mean_coffee_quality_scalar(batch_beans, batch_water)) # collapsed to scalar

**Disclaimer:** `pmean` and `psum` were originally meant for reductions over `pmap` operations (hence the `p`). `pmap` **p**arallelizes a big computation over multiple devices. Think of it as having several sous-chefs working together to cook a single huge buffet!

## Automatic Differentiation with `grad`

JAX makes calculating derivatives as easy as making a cup of coffee. Actually, it's much easier: JAX does all the work for you! Just wrap your function with `grad` and you're done.

**TL;DR**: `grad` turns any scalar-valued function into a function that computes its gradient. No calculus required!

In [None]:
# Let's differentiate a coffee pricing model
def coffee_price(beans, milk):
    """Compute price of a coffee based on ingredients"""
    return 0.5 * beans**2 + 2.0 * milk**2  # Fancy quadratic pricing

# Get derivative wrt beans
d_beans = grad(coffee_price)  # Default: differentiate wrt first arg
print("If beans go up by $1, coffee price changes by $",
      d_beans(5.0, 3.0))  # Evaluate at beans=$5, milk=$3

# Get both derivatives at once
d_both = grad(coffee_price, argnums=(0,1))
print("Price sensitivity to ingredients:", d_both(5.0, 3.0))
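Two companions to `grad` worth knowing: `jax.value_and_grad` returns the function value together with its gradient, and `jax.jacrev` handles vector-valued outputs, which plain `grad` rejects. A sketch reusing the pricing model (`price_per_cup` is illustrative):

```python
import jax
import jax.numpy as jnp

def coffee_price(beans, milk):
    return 0.5 * beans**2 + 2.0 * milk**2

# Price *and* gradient in one pass, the usual pattern for loss functions
price, (d_beans, d_milk) = jax.value_and_grad(coffee_price, argnums=(0, 1))(5.0, 3.0)
print(price, d_beans, d_milk)  # 30.5 5.0 12.0

# grad requires a scalar output; for vector outputs, use a Jacobian instead
def price_per_cup(beans):
    return 0.5 * beans**2  # one price per cup

jac = jax.jacrev(price_per_cup)(jnp.array([1.0, 2.0]))
print(jac)  # diagonal matrix: [[1, 0], [0, 2]]
```

Also pass floats: `grad` raises a `TypeError` for integer inputs.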

## Power combo: stacking transformations

One of JAX's most powerful features is how seamlessly you can combine its fundamental building blocks. Like LEGOs, you can stack `jit`, `vmap`, `grad`, and more to create complex functions with minimal code.

**Pro Tip:** use `@` decorators above a function definition to keep your code clean.

In [None]:
# Let's do a fancy market analysis

# Stack transformations to analyze price elasticity
@jit  # Make it fast
@vmap  # Apply to multiple prices at once
@grad  # Get price sensitivity
def profit_sensitivity(beans_price):
    """How sensitive is profit to bean price (with milk price fixed at 1)?"""
    return coffee_price(beans_price, jnp.ones_like(beans_price))

# Analyze for 5 different bean costs
bean_prices = jnp.array([4.0, 5.0, 6.0, 7.0, 8.0])
sensitivity = profit_sensitivity(bean_prices)
print(f"Profit sensitivity to bean prices: {sensitivity}")

# Want second derivatives? Just add another grad!
@jit
@vmap
@grad
@grad
def profit_elasticity(beans_price):
    """How quickly does sensitivity change with price?"""
    return coffee_price(beans_price, jnp.ones_like(beans_price))

elasticity = profit_elasticity(bean_prices)
print(f"Profit elasticity: {elasticity}")
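The order of the stack matters: `grad` needs a scalar output, so it goes innermost and `vmap` maps it over the batch. For functions that act elementwise, like the pricing model, an equivalent trick is to `grad` the sum of the outputs. A quick check:

```python
import jax
import jax.numpy as jnp

def coffee_price(beans, milk):
    return 0.5 * beans**2 + 2.0 * milk**2

bean_prices = jnp.array([4.0, 5.0, 6.0, 7.0, 8.0])

# Per-element derivatives via vmap(grad(...))
per_price = jax.vmap(jax.grad(lambda b: coffee_price(b, 1.0)))(bean_prices)

# Same numbers from a single grad of the *summed* output: because each price
# only depends on its own input, d(sum)/d(b_i) = d(price_i)/d(b_i)
via_sum = jax.grad(lambda b: jnp.sum(coffee_price(b, 1.0)))(bean_prices)

print(per_price, via_sum)  # both [4. 5. 6. 7. 8.]
```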

## PyTrees: Python data structures on steroids

PyTrees are JAX's way of handling complex data structures. Think of them as Python containers (dicts, lists, tuples) that JAX can magically understand and operate on.

**TL;DR**: PyTrees let you use JAX with structured data, not just arrays.

**Pro Tip**: Use `jax.tree.map` to apply functions to every array in a PyTree. It's like `vmap`, but for nested structures instead of batches.

In [None]:
from typing import NamedTuple

class CoffeeOrder(NamedTuple):
    """A very sophisticated coffee order"""
    drinks: dict[str, list[tuple[jax.Array, jax.Array]]]  # (beans, milk) per cup
    extras: tuple[dict[str, float]]  # sugar, cinnamon, etc.

# A complex order for a very picky coffee shop
my_coffee_order = CoffeeOrder(
    drinks = {
        "espresso": [(jnp.array([20.0]), jnp.array([0.0]))],  # No milk!
        "latte": [(jnp.array([15.0]), jnp.array([150.0]))],  # Lots of milk
    },
    extras = ({"sugar": 2.0, "cinnamon": 0.5},)
)
print("Your complex order:", my_coffee_order)

# Calculate prices for everything in one go
def apply_coffee_price(bean_milk_tuple):
    beans, milk = bean_milk_tuple
    return coffee_price(beans, milk)

# Get prices for the drinks only (I refuse to pay for extras)
# Also, tell tree.map to treat each (beans, milk) tuple as a single leaf
# instead of recursing into it (by default, tuples are containers)
print("Prices:", jax.tree.map(apply_coffee_price,
                              my_coffee_order.drinks,
                              is_leaf=lambda x: isinstance(x, tuple)))

This feature is incredibly useful. You can define a function on a simple array and apply it across arbitrary PyTrees. Powerful stuff!

Doing this in PyTorch would probably require a whole bunch of nested for loops.
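`jax.tree.map` can also walk several PyTrees with the same structure side by side, which is exactly what a gradient-descent update needs. A minimal sketch of plain SGD on a made-up parameter tree (this is the idea behind optimizer libraries like Optax, not their actual implementation):

```python
import jax
import jax.numpy as jnp

# A toy parameter PyTree and a matching tree of gradients (values are made up)
params = {"beans": jnp.array([1.0, 2.0]), "milk": {"steamed": jnp.array(3.0)}}
grads = {"beans": jnp.array([0.5, -0.5]), "milk": {"steamed": jnp.array(1.0)}}

# tree.map walks several trees in lockstep, as long as their structures match
learning_rate = 0.1
new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)

print(jax.tree.leaves(new_params))     # the flat list of arrays, in a fixed order
print(jax.tree.structure(new_params))  # the container "skeleton" without the data
```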

## Neural Networks with Equinox

JAX is at its strongest when we have some big computation repeating itself for many iterations. Oh, and also when we need gradients and batches.

So obviously, there's one great application for JAX: **neural networks**

But there's a problem. Remember how JAX hates state? Well, neural networks are basically big bags of state (weights, biases, batch statistics...).

This is where [Equinox](https://github.com/patrick-kidger/equinox) comes in: it represents your model, with all that state, as a *single* PyTree that JAX can handle.

**TL;DR:** Use Equinox when you want PyTorch-like simplicity with JAX-like speed.

**Pro Tip:** Use Equinox's `filter_{jit,grad,vmap}` instead of the JAX building blocks. It'll save you some headaches.

In [None]:
!pip install --no-deps equinox jaxtyping

In [None]:
import equinox as eqx

class MLP(eqx.Module):
    """A PyTree pretending to be a Multi-Layer Perceptron"""
    layers: list

    def __init__(self, key):
        # Split our key into three (like a responsible adult)
        keys = jrand.split(key, 3)

        # Use keys for random weight init
        self.layers = [
            eqx.nn.Linear(784, 512, key=keys[0]),  # Big layer
            eqx.nn.Linear(512, 256, key=keys[1]),  # Medium layer
            eqx.nn.Linear(256, 10, key=keys[2])    # Small layer
        ]

    def __call__(self, x):
        # Feed forward, with extra GELU sauce
        for layer in self.layers[:-1]:
            x = jax.nn.gelu(layer(x))
        return self.layers[-1](x)  # Final layer, no activation

# Let's see what this PyTree looks like
model = MLP(jrand.PRNGKey(42))  # 42 is the answer to everything
print("Your model's family tree:", model)
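Under the hood there is no magic: a model is just weights in a PyTree plus a function that uses them. For comparison, a rough pure-JAX sketch of the same 784-512-256-10 GELU network, with parameters stored as a list of `(weights, bias)` tuples (`init_mlp` and `mlp_forward` are illustrative names, and the scaled-normal initialization is a simplification, not Equinox's):

```python
import jax
import jax.numpy as jnp
import jax.random as jrand

def init_mlp(key, sizes=(784, 512, 256, 10)):
    """Hypothetical helper: one (weights, bias) tuple per layer, in a plain list."""
    keys = jrand.split(key, len(sizes) - 1)
    return [
        (jrand.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp_forward(params, x):
    for w, b in params[:-1]:
        x = jax.nn.gelu(w @ x + b)
    w, b = params[-1]
    return w @ x + b  # final layer, no activation

params = init_mlp(jrand.PRNGKey(42))
logits = mlp_forward(params, jnp.ones(784))
print(logits.shape)  # (10,)
print([leaf.shape for leaf in jax.tree.leaves(params)])
```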

## MNIST Training Example
Alright, let's put everything together and train an MLP on MNIST!

### Data
**Plot twist:** JAX has no official data loading library.

You have to load *another* Deep Learning library, just for data :'(

In [None]:
# You'll need to borrow PyTorch's dataloaders (yes, this takes a while to load)
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

def load_mnist(batch_size=64):
    # Normalize and flatten
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: x.reshape(-1)),
    ])

    # Load datasets
    train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader
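If your dataset already fits in memory as NumPy arrays, you can skip the PyTorch import entirely. A minimal sketch of a shuffling batch iterator driven by a JAX key, so the data order is reproducible too (`iterate_batches` is a hypothetical helper, not part of JAX):

```python
import jax.random as jrand
import numpy as np

def iterate_batches(key, x, y, batch_size=64):
    """Hypothetical loader: yield shuffled (x, y) NumPy batches, reproducibly."""
    perm = np.asarray(jrand.permutation(key, len(x)))  # order depends only on key
    for start in range(0, len(x) - batch_size + 1, batch_size):  # drops the ragged tail
        idx = perm[start:start + batch_size]
        yield x[idx], y[idx]

# Synthetic stand-in data; real MNIST arrays would be (60000, 784) and (60000,)
x_data = np.arange(10 * 784, dtype=np.float32).reshape(10, 784)
y_data = np.arange(10)
batches = list(iterate_batches(jrand.PRNGKey(0), x_data, y_data, batch_size=4))
print(len(batches), batches[0][0].shape)  # 2 (4, 784)
```

Dropping the last incomplete batch keeps every batch the same shape, so a jitted update step compiles only once.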

### Useful training functions
**Pro Tip**: With nested functions, only the outermost JIT counts. Any other JITs get ignored.

In [None]:
import optax  # loss_fn below uses optax's cross-entropy

# Not an outer function, so no need for JIT here (though it couldn't hurt)
def loss_fn(model, x, y):
    """A Cross-Entropy Loss that needs to go down"""
    pred = vmap(model)(x)
    y_onehot = jax.nn.one_hot(y, num_classes=10)
    return jnp.mean(optax.softmax_cross_entropy(pred, y_onehot))

@eqx.filter_jit  # We want this crazy fast!
def update_step(model, optimizer, opt_state, x, y):
    """You give me model, I give you better model"""
    grads = eqx.filter_grad(loss_fn)(model, x, y)  # Only grads for first arg
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state
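For reference, the quantity `loss_fn` computes is the batch mean of $-\log p(y \mid x)$ under the model's softmax. A pure-JAX sketch of the same formula using `jax.nn.log_softmax`, which should match `optax.softmax_cross_entropy` for one-hot labels (`cross_entropy` is an illustrative name):

```python
import jax
import jax.numpy as jnp

def cross_entropy(logits, labels, num_classes=10):
    """Mean softmax cross-entropy with integer labels, written out by hand."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)           # log p(class | x)
    y_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return -jnp.mean(jnp.sum(y_onehot * log_probs, axis=-1))  # mean of -log p(true class)

logits = jnp.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
labels = jnp.array([0, 2])
print(cross_entropy(logits, labels, num_classes=3))
```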

### Useful evaluation functions

In [None]:
@eqx.filter_jit
def get_loss_and_acc(model, x, y):
    """Combined loss+acc calculation for speedy JIT"""
    pred = vmap(model)(x)  # JIT is smart enough to reuse this in loss_fn
    loss = loss_fn(model, x, y)
    acc = jnp.mean(jnp.argmax(pred, axis=1) == y)
    return loss, acc

def evaluate(model, data_loader):
    """Check how good a model is on a given dataset"""
    total_loss, total_acc, count = 0.0, 0.0, 0

    for data, target in data_loader:
        x, y = jnp.array(data.numpy()), jnp.array(target.numpy())

        loss, acc = get_loss_and_acc(model, x, y)

        # Accumulate metrics
        total_loss += loss
        total_acc += acc
        count += 1

    return total_loss/count, total_acc/count

### Training loop

In [None]:
import optax
from tqdm.auto import tqdm

def train_mnist(key, num_epochs=3):
    # Setup
    model = MLP(key)
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
    train_loader, test_loader = load_mnist()

    # Training loop
    for epoch in range(num_epochs):
        # Train
        for data, target in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            x, y = jnp.array(data.numpy()), jnp.array(target.numpy())
            model, opt_state = update_step(model, optimizer, opt_state, x, y)

        # Evaluate
        _, train_acc = evaluate(model, train_loader)
        _, test_acc = evaluate(model, test_loader)

        print(f"Epoch {epoch+1}: Train acc={train_acc:.4f}, Test acc={test_acc:.4f}")

    return model  # hand back the trained model

### Running the training

In [None]:
# Step 1: Choose an initial key (the responsible thing to do, remember?)
key = jrand.PRNGKey(0)

# Step 2: Train the model (AKA coffee break time!)
model = train_mnist(key)

# Step 3: World domination

## Wrapping up: how to master JAX

Together, we've explored the basics of JAX, from random keys to training a neural network.
JAX's learning curve might feel steep at first, but the performance gains are worth it once you get comfortable with the paradigm shift.

**So, what's next?**
- **Get a coffee:** You've earned it. #priorities
- **Practice the JAX mindset:** Functional programming with immutable data takes time to internalize. Start by converting some PyTorch scripts to JAX using your favorite LLM, and play around with the possibilities.
- **Master the key transformations:** The `jit`, `grad`, `vmap` trio is insanely powerful. Once you get used to it, you can't live without it.
- **Explore the ecosystem:**
  - **[Equinox](https://github.com/patrick-kidger/equinox)** (obviously)
  - **[Optax](https://github.com/deepmind/optax)** for optimizers and loss functions
  - **[Diffrax](https://github.com/patrick-kidger/diffrax)** for differential equations
  - **[Jumanji](https://github.com/instadeepai/jumanji)** for reinforcement learning
  - **[And many many more...](https://github.com/n2cholas/awesome-jax)**

**Common pitfalls to avoid:**
- Reusing PRNG keys (leads to identical "random" numbers)
- In-place mutations (doesn't work on JAX arrays)
- Relying on Python control flow inside JIT (use JAX-specific alternatives)
- Expecting JIT to speed up single-use functions (first run is always slow)

**Remember:** JAX shines brightest on complex, performance-critical tasks. For quick prototyping, PyTorch might still be your best friend. Choose the right tool for your job!

**Fall in love:** Once you get to know JAX and see your first 10x speedup, you'll understand why it's worth the effort. Happy JAXing!