# Vectorization with `brainstate.transform.vmap`

Vectorization is a fundamental technique for efficient computation in machine learning and scientific computing. BrainState provides `brainstate.transform.vmap` as a state-aware wrapper around JAX's `jax.vmap`, enabling seamless vectorization of stateful computations.

This tutorial covers:

1. **Basic usage of `vmap`** with detailed parameter explanations and examples
2. **Random number semantics** and how `vmap` automatically handles `RandomState`
3. **Understanding `StatefulMapping`**, the underlying abstraction that powers `vmap`

In [1]:
import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import vmap2 as vmap
from brainstate.util.filter import OfType

## 1. Basic Usage: Understanding `vmap` Parameters

### 1.1 The `in_axes` Parameter

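The `in_axes` parameter controls how batch dimensions are mapped over function arguments. Its semantics match `jax.vmap`. An integer selects which axis of the corresponding argument is mapped. `None` broadcasts the argument unchanged to every batch element. A tuple gives one entry per positional argument.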

In [2]:
# Example 1: Single scalar-to-scalar function
def square(x):
    return x ** 2


# Vectorize over the first axis (default)
vmap_square = vmap(square, in_axes=0)

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
print("Input shape:", xs.shape)
print("Output:", vmap_square(xs))
print("Output shape:", vmap_square(xs).shape)

Input shape: (4,)
Output: [ 1.  4.  9. 16.]
Output shape: (4,)


In [3]:
# Example 2: Multiple arguments with different in_axes
def weighted_sum(x, weight):
    """Scale x by weight (a single weighted term): x * weight"""
    return x * weight


# Vectorize over x (batch), but broadcast weight (single value)
vmap_weighted = vmap(weighted_sum, in_axes=(0, None))

batch_x = jnp.array([1.0, 2.0, 3.0])
single_weight = 2.0

result = vmap_weighted(batch_x, single_weight)
print("Batched x:", batch_x)
print("Single weight:", single_weight)
print("Result:", result)

Batched x: [1. 2. 3.]
Single weight: 2.0
Result: [2. 4. 6.]


In [4]:
# Example 3: Vectorizing along different axes
def matrix_vector_product(matrix, vector):
    return matrix @ vector


# Batch of matrices: shape (batch, m, n)
# Batch of vectors: shape (batch, n)
batch_matrices = jnp.ones((4, 3, 2))  # 4 matrices of shape (3, 2)
batch_vectors = jnp.ones((4, 2))  # 4 vectors of shape (2,)

# Map over the first axis of both arguments
vmap_matmul = vmap(matrix_vector_product, in_axes=(0, 0))
result = vmap_matmul(batch_matrices, batch_vectors)

print("Input shapes:", batch_matrices.shape, batch_vectors.shape)
print("Output shape:", result.shape)  # (4, 3)

Input shapes: (4, 3, 2) (4, 2)
Output shape: (4, 3)
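
All of the examples above map over axis 0. As in `jax.vmap`, an integer entry in `in_axes` can name any axis. For a pytree argument such as a dict, the entry can itself be a matching dict. The sketch below uses plain `jax.vmap` rather than BrainState, and it checks the result against an explicit Python loop. The `column_norm` and `affine` helpers are illustrative only.

```python
import jax
import jax.numpy as jnp

def column_norm(col):
    # L2 norm of a single 1-D slice
    return jnp.sqrt(jnp.sum(col ** 2))

m = jnp.arange(6.0).reshape(2, 3)  # 2 rows, 3 columns

# in_axes=1: map over columns, so each call sees a slice of shape (2,)
col_norms = jax.vmap(column_norm, in_axes=1)(m)
loop_norms = jnp.stack([column_norm(m[:, j]) for j in range(m.shape[1])])

def affine(params, x):
    return params["w"] * x + params["b"]

# Pytree in_axes: batch "w" along axis 0, broadcast "b", batch x along axis 0
params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(10.0)}
xs = jnp.array([1.0, 1.0, 1.0])
ys = jax.vmap(affine, in_axes=({"w": 0, "b": None}, 0))(params, xs)
print(col_norms.shape, ys)  # (3,) [11. 12. 13.]
```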


### 1.2 The `out_axes` Parameter

The `out_axes` parameter controls where the batch dimension appears in the output.

In [5]:
def create_vector(scalar):
    """Create a 3D vector from a scalar."""
    return jnp.array([scalar, scalar * 2, scalar * 3])


# Default: batch dimension at axis 0
vmap_default = vmap(create_vector, in_axes=0, out_axes=0)
result_axis0 = vmap_default(jnp.array([1.0, 2.0]))
print("out_axes=0, shape:", result_axis0.shape)  # (2, 3)
print(result_axis0)

# Batch dimension at axis 1
vmap_axis1 = vmap(create_vector, in_axes=0, out_axes=1)
result_axis1 = vmap_axis1(jnp.array([1.0, 2.0]))
print("\nout_axes=1, shape:", result_axis1.shape)  # (3, 2)
print(result_axis1)

out_axes=0, shape: (2, 3)
[[1. 2. 3.]
 [2. 4. 6.]]

out_axes=1, shape: (3, 2)
[[1. 2.]
 [2. 4.]
 [3. 6.]]
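
`out_axes` only changes where the batch axis is placed, not the values. The quick pure-JAX check below reuses the `create_vector` helper from this section. It confirms that `out_axes=1` is equivalent to moving axis 0 of the default result to position 1. It also shows that negative values count from the end of the batched output.

```python
import jax
import jax.numpy as jnp

def create_vector(scalar):
    # Same helper as above: a 3-element vector per scalar
    return jnp.array([scalar, scalar * 2, scalar * 3])

scalars = jnp.array([1.0, 2.0])
out0 = jax.vmap(create_vector, out_axes=0)(scalars)  # (2, 3)
out1 = jax.vmap(create_vector, out_axes=1)(scalars)  # (3, 2)

# out_axes only chooses where the batch axis lands; the values are the same
same = jnp.array_equal(out1, jnp.moveaxis(out0, 0, 1))
# Negative axes count from the end of the batched output, so -1 acts like 1 here
out_neg = jax.vmap(create_vector, out_axes=-1)(scalars)
print(out0.shape, out1.shape, same)  # (2, 3) (3, 2) True
```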


### 1.3 The `axis_name` Parameter

The `axis_name` parameter allows you to name the mapped axis, enabling collective operations like `jax.lax.pmean`.

In [6]:
def normalize_batch(x):
    """Normalize by subtracting the batch mean."""
    # Compute mean across the 'batch' axis
    batch_mean = jax.lax.pmean(x, axis_name='batch')
    return x - batch_mean


# Name the mapped axis as 'batch'
vmap_normalize = vmap(normalize_batch, in_axes=0, axis_name='batch')

batch_data = jnp.array([1.0, 2.0, 3.0, 4.0])
normalized = vmap_normalize(batch_data)

print("Input:", batch_data)
print("Batch mean:", jnp.mean(batch_data))
print("Normalized:", normalized)
print("New mean:", jnp.mean(normalized))  # Should be ~0

Input: [1. 2. 3. 4.]
Batch mean: 2.5
Normalized: [-1.5 -0.5  0.5  1.5]
New mean: 0.0
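
Any collective that takes an `axis_name` works the same way inside the mapped function. In the pure-JAX sketch below, `jax.lax.psum` gives every batch element the batch-wide total, and `jax.lax.pmean` reproduces the centering from the cell above. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp

def share_of_total(x):
    # psum over the named axis returns the same batch-wide total to every element
    total = jax.lax.psum(x, axis_name="batch")
    return x / total

def centered(x):
    # pmean is the batch-wide mean over the named axis
    return x - jax.lax.pmean(x, axis_name="batch")

data = jnp.array([1.0, 2.0, 3.0, 4.0])
shares = jax.vmap(share_of_total, axis_name="batch")(data)
centers = jax.vmap(centered, axis_name="batch")(data)
print(shares, centers)
```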


### 1.4 The `axis_size` Parameter

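The `axis_size` parameter explicitly specifies the size of the mapped axis. It is required when no argument is mapped (every `in_axes` entry is `None`), because the size cannot be inferred in that case. Otherwise it is optional and is inferred from the mapped arguments.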

In [7]:
def generate_sequence(unused=None):
    """Generate a sequence (for demonstration)."""
    return jnp.arange(3)


# No input is mapped (in_axes=None), so the batch size must be given via axis_size
vmap_generate = vmap(generate_sequence, in_axes=None, axis_size=5)

result = vmap_generate()
print("Generated sequences:")
print(result)
print("Shape:", result.shape)  # (5, 3)

Generated sequences:
[[0 1 2]
 [0 1 2]
 [0 1 2]
 [0 1 2]
 [0 1 2]]
Shape: (5, 3)
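
Plain `jax.vmap` follows the same rule. The sketch below shows two cases. With nothing mapped, `axis_size` must be supplied, and the unbatched result is broadcast to that length along `out_axes`. With a mapped argument, `axis_size` can be omitted. The `add_offset` function is illustrative only.

```python
import jax
import jax.numpy as jnp

def add_offset(x):
    return x + 1.0

x = jnp.zeros(3)

# Nothing is mapped, so the batch size has to be provided; the unbatched
# result is broadcast along out_axes to length axis_size.
tiled = jax.vmap(add_offset, in_axes=None, axis_size=4)(x)

# When an argument is mapped, axis_size may be omitted and is inferred from it.
inferred = jax.vmap(add_offset, in_axes=0)(jnp.zeros((4, 3)))
print(tiled.shape, inferred.shape)  # (4, 3) (4, 3)
```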


### 1.5 State-Aware Parameters: `state_in_axes` and `state_out_axes`

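These BrainState-specific parameters control how `State` objects are batched. Each one is a dict that maps an axis to a filter (here `OfType(...)`). In `state_in_axes`, the filter selects the states whose values are split along that axis when the mapped function is entered. In `state_out_axes`, it selects the states whose updated values are stacked along that axis on exit. In the example below, the counter holds one slot per batch element, so each element reads and updates only its own slot.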

In [8]:
class Counter(brainstate.nn.Module):
    """A simple counter using ShortTermState."""

    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.zeros(4))

    def __call__(self, delta):
        """Increment counter by delta."""
        self.count.value = self.count.value + delta
        return self.count.value


counter = Counter()

# Vectorize with state batching
vmap_counter = vmap(
    counter,
    in_axes=0,  # Batch over input deltas
    out_axes=0,  # Batch over output counts
    # Batch the counter state along axis 0
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
)

deltas = jnp.array([1.0, 2.0, 3.0, 4.0])
counts = vmap_counter(deltas)

print("Deltas:", deltas)
print("Counts:", counts)
print("Final counter value:", counter.count.value)  # Each slot holds its own element's update

Deltas: [1. 2. 3. 4.]
Counts: [1. 2. 3. 4.]
Final counter value: [1. 2. 3. 4.]
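
One way to read `state_in_axes` and `state_out_axes` is through a pure-JAX analogue. The state value is threaded through `jax.vmap` as an explicit argument and return value, and the result is then written back. This is a conceptual sketch, not BrainState's implementation. `counter_step` is a hypothetical pure rewrite of `Counter.__call__`.

```python
import jax
import jax.numpy as jnp

def counter_step(count, delta):
    # Pure version of Counter.__call__: the state value is an explicit input and output
    new_count = count + delta
    return new_count, new_count  # (new state value, returned value)

count = jnp.zeros(4)                      # one slot per batch element
deltas = jnp.array([1.0, 2.0, 3.0, 4.0])

# state_in_axes={0: ...}  ~ in_axes=0 for the state value
# state_out_axes={0: ...} ~ out_axes=0 for the new state value
new_count, outputs = jax.vmap(counter_step, in_axes=(0, 0), out_axes=(0, 0))(count, deltas)
count = new_count  # "write back" the new state value after the call
print(outputs, count)  # [1. 2. 3. 4.] [1. 2. 3. 4.]
```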


### 1.6 Working with Module States

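States that no `state_in_axes` filter selects are broadcast: every batch element reads the same value. For an `nn.Module` whose parameters are only read, such as the linear layer below, this default is what you want, so no state axes need to be given.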

In [9]:
class LinearLayer(brainstate.nn.Module):
    """Simple linear layer."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Parameters are ParamState
        self.weight = brainstate.ParamState(jnp.ones((in_features, out_features)))
        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))

    def __call__(self, x):
        return x @ self.weight.value + self.bias.value


layer = LinearLayer(3, 2)

# Vectorize over batch of inputs
# Parameters are shared (broadcast) across the batch
vmap_layer = vmap(layer, in_axes=0, out_axes=0)

batch_inputs = jnp.ones((4, 3))  # Batch of 4 inputs
batch_outputs = vmap_layer(batch_inputs)

print("Input shape:", batch_inputs.shape)  # (4, 3)
print("Output shape:", batch_outputs.shape)  # (4, 2)
print("Output:")
print(batch_outputs)

Input shape: (4, 3)
Output shape: (4, 2)
Output:
[[3. 3.]
 [3. 3.]
 [3. 3.]
 [3. 3.]]
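
In plain JAX, this broadcast behavior corresponds to `in_axes=None` for the parameters. The sketch below uses a hypothetical pure `linear` function with a parameter dict. It checks that mapping over inputs with shared parameters gives the same result as one batched matrix product.

```python
import jax
import jax.numpy as jnp

def linear(params, x):
    # Pure counterpart of LinearLayer.__call__
    return x @ params["weight"] + params["bias"]

params = {"weight": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}
batch_inputs = jnp.ones((4, 3))

# in_axes=None for params: every batch element sees the same weight and bias
batch_outputs = jax.vmap(linear, in_axes=(None, 0))(params, batch_inputs)

# For a linear layer, broadcasting parameters matches one batched matmul
direct = batch_inputs @ params["weight"] + params["bias"]
print(batch_outputs.shape)  # (4, 2)
```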


### 1.7 The `unexpected_out_state_mapping` Parameter

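This parameter controls what happens when the mapped function writes a state whose new value becomes batched, but no filter in `state_out_axes` covers that state. As the error message below shows, the accepted values are `'raise'` (the default), `'warn'`, and `'ignore'`. With `'ignore'`, the state simply keeps the batched value, so its shape can change. In the second example, `write_state2` goes from a scalar to shape `(3,)`.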

In [10]:
temp_state = brainstate.ShortTermState(jnp.zeros(3))
write_state = brainstate.LongTermState(jnp.asarray(0.))


def update_temp(x):
    """Function that writes to a state."""
    temp_state.value = temp_state.value + x
    write_state.value = temp_state.value
    return temp_state.value


# Example 1: write_state (a LongTermState) is written but not covered by
# state_out_axes, so the default 'raise' policy raises an error
vmap_proper = vmap(
    update_temp,
    in_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
    unexpected_out_state_mapping='raise',  # Default
)

try:
    result = vmap_proper(jnp.array([1.0, 2.0, 3.0]))
except Exception as e:
    print(e)


State
 LongTermState(
  value=ShapedArray(float32[])
) 
 was not expected to be batched on output. Please adjust state_out_axes or set unexpected_out_state_mapping to "warn" or "ignore".


In [11]:

# Example 2: Using 'ignore' to allow unexpected states
temp_state2 = brainstate.ShortTermState(jnp.array(0.0))
write_state2 = brainstate.LongTermState(jnp.asarray(0.))


def update_temp2(x):
    temp_state2.value = temp_state2.value + x
    write_state2.value = temp_state2.value
    return temp_state2.value


print('Before vmapping, original write state value:', write_state2.value)

vmap_ignore = vmap(
    update_temp2,
    in_axes=0,
    # Note: not specifying state_in_axes/state_out_axes
    unexpected_out_state_mapping='ignore',
)

result2 = vmap_ignore(jnp.array([1.0, 2.0, 3.0]))
print("With 'ignore' policy:", result2)
print("With 'ignore' policy, write state value after vmapping:", write_state2.value)

Before vmapping, original write state value: 0.0
With 'ignore' policy: [1. 2. 3.]
With 'ignore' policy, write state value after vmapping: [1. 2. 3.]
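
The shape change under `'ignore'` follows from ordinary `jax.vmap` semantics: a value computed from a batched input becomes batched itself. In the pure-JAX analogue below, declaring the written state unbatched on output fails. That failure loosely plays the role of the `'raise'` policy. The `update` function is hypothetical.

```python
import jax
import jax.numpy as jnp

def update(state, x):
    # The new state depends on x, so it inherits x's batch dimension
    new_state = state + x
    return new_state, new_state

state = jnp.array(0.0)            # unbatched, like write_state2
xs = jnp.array([1.0, 2.0, 3.0])

# Broadcast the state in, but let JAX stack the written value on the way out
new_state, ys = jax.vmap(update, in_axes=(None, 0), out_axes=(0, 0))(state, xs)
print(state.shape, "->", new_state.shape)  # () -> (3,)

# Declaring the state as unbatched on output fails, because it is batched
try:
    jax.vmap(update, in_axes=(None, 0), out_axes=(None, 0))(state, xs)
    raised = False
except ValueError:
    raised = True
```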


## 2. Random Number Semantics

### 2.1 Automatic Key Splitting for `RandomState`

**Important**: `brainstate.transform.vmap` automatically splits PRNG keys for `brainstate.random.RandomState`, ensuring each batch element receives a unique random key.

In [12]:
# Reset random state
brainstate.random.seed(42)


def sample_normal(scale):
    """Sample from a normal distribution."""
    return brainstate.random.normal(0.0, scale)


# Vectorize the sampling function
vmap_sample = vmap(
    sample_normal,
    in_axes=0,
    # RandomState is automatically handled!
    # state_in_axes={0: OfType(brainstate.random.RandomState)},
    # state_out_axes={0: OfType(brainstate.random.RandomState)},
)

scales = jnp.array([1.0, 2.0, 3.0, 4.0])
samples = vmap_sample(scales)

print("Scales:", scales)
print("Samples:", samples)
print("\nNote: Each sample is different (independent random key per batch element)")

Scales: [1. 2. 3. 4.]
Samples: [-1.0413289 -1.4796011  2.222502   6.412178 ]

Note: Each sample is different (independent random key per batch element)


In [13]:
# Example 2: Multiple random operations
brainstate.random.seed(123)


def sample_multiple(mean):
    """Sample multiple random numbers."""
    sample1 = brainstate.random.uniform(0.0, 1.0)
    sample2 = brainstate.random.normal(mean, 1.0)
    return sample1 + sample2


vmap_multiple = vmap(sample_multiple, in_axes=0)

means = jnp.array([0.0, 1.0, 2.0])
results = vmap_multiple(means)

print("Means:", means)
print("Results:", results)
print("\nEach batch element uses independent random keys for both operations")

Means: [0. 1. 2.]
Results: [1.063001  2.0858884 3.2780576]

Each batch element uses independent random keys for both operations
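
In plain JAX, the same independence has to be arranged by hand. Each batch element gets its own key, and that key is split again inside the function for each draw. The sketch below makes this explicit. It only illustrates the idea behind automatic key splitting and is not how BrainState derives its keys, so its numbers differ from the output above.

```python
import jax
import jax.numpy as jnp

def sample_multiple(key, mean):
    # Split the element's key so the two draws are independent of each other
    k_uniform, k_normal = jax.random.split(key)
    u = jax.random.uniform(k_uniform, (), minval=0.0, maxval=1.0)
    n = mean + jax.random.normal(k_normal, ())
    return u + n

means = jnp.array([0.0, 1.0, 2.0])
keys = jax.random.split(jax.random.PRNGKey(123), means.shape[0])  # one key per element

results = jax.vmap(sample_multiple)(keys, means)
# Same keys -> same numbers: the batch is reproducible
again = jax.vmap(sample_multiple)(keys, means)
```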


### 2.2 Controlling Random Keys: Using JAX's Random API

If you need **shared random keys** across batch elements (same random numbers), use `jax.random` APIs and set `in_axes=None` for the key.

In [14]:
def sample_with_jax_key(key, scale):
    """Sample using JAX's random API."""
    return jax.random.normal(key, ()) * scale


# Shared key across all batch elements
vmap_shared_key = vmap(
    sample_with_jax_key,
    in_axes=(None, 0),  # key is None (broadcast), scale is batched
)

shared_key = jax.random.PRNGKey(0)
scales = jnp.array([1.0, 2.0, 3.0, 4.0])
samples_shared = vmap_shared_key(shared_key, scales)

print("Samples with shared key:", samples_shared)
print("Notice: All samples use the same base random number, just scaled differently")


# Compare with unique keys per batch element: same function, but the key is batched too
vmap_unique_keys = vmap(
    sample_with_jax_key,
    in_axes=(0, 0),  # Both key and scale are batched
)

# Split key into batch
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, len(scales))
samples_unique = vmap_unique_keys(keys, scales)

print("\nSamples with unique keys:", samples_unique)
print("Notice: Each sample is independent")

Samples with shared key: [1.6226422 3.2452843 4.8679266 6.4905686]
Notice: All samples use the same base random number, just scaled differently

Samples with unique keys: [ 1.0040143 -4.8849115  3.8869078 -2.4877744]
Notice: Each sample is independent
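
A middle ground between the two cells above is to pass one shared key and derive each element's key inside the mapped function with `jax.random.fold_in`, using the element's index. This avoids building a split key array up front. The sketch uses plain JAX, and the function name is illustrative only.

```python
import jax
import jax.numpy as jnp

def sample_with_index(key, index, scale):
    # Derive a distinct key for this element from the shared key and its index
    element_key = jax.random.fold_in(key, index)
    return jax.random.normal(element_key, ()) * scale

shared_key = jax.random.PRNGKey(0)
scales = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = jnp.arange(scales.shape[0])

# The key is broadcast (in_axes=None); only the index and scale are batched
samples = jax.vmap(sample_with_index, in_axes=(None, 0, 0))(shared_key, indices, scales)
```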


### 2.3 Practical Example: Dropout with Reproducibility

In [15]:
class Dropout(brainstate.nn.Module):
    """Dropout layer using BrainState random."""

    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate

    def __call__(self, x, training=True):
        if not training:
            return x
        # Under vmap, each batch element draws its own random mask
        keep_mask = brainstate.random.uniform(0.0, 1.0, x.shape) > self.rate
        return jnp.where(keep_mask, x / (1 - self.rate), 0.0)


brainstate.random.seed(456)
dropout = Dropout(rate=0.3)

# Vectorize dropout application
vmap_dropout = vmap(
    lambda x: dropout(x, training=True),
    in_axes=0,
)

batch_data = jnp.ones((4, 5))  # 4 samples, 5 features
dropped = vmap_dropout(batch_data)

print("Original data:")
print(batch_data)
print("\nAfter dropout:")
print(dropped)
print("\nNote: Each row has a different dropout pattern")

Original data:
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]

After dropout:
[[0.        1.4285715 1.4285715 0.        1.4285715]
 [1.4285715 1.4285715 1.4285715 1.4285715 0.       ]
 [1.4285715 0.        1.4285715 0.        0.       ]
 [0.        1.4285715 1.4285715 1.4285715 0.       ]]

Note: Each row has a different dropout pattern
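
The same per-row independence can be spelled out in pure JAX by giving each row its own key. This is a sketch with a hypothetical `dropout_row` helper. Fixing the key makes the masks reproducible across calls, much as `brainstate.random.seed(456)` fixes the global state above. The specific masks will differ from BrainState's.

```python
import jax
import jax.numpy as jnp

def dropout_row(key, x, rate):
    # Inverted dropout: zero out entries with probability `rate`, rescale the rest
    keep = jax.random.uniform(key, x.shape) > rate
    return jnp.where(keep, x / (1.0 - rate), 0.0)

rate = 0.3
batch = jnp.ones((4, 5))
row_keys = jax.random.split(jax.random.PRNGKey(456), batch.shape[0])  # one key per row

dropped = jax.vmap(dropout_row, in_axes=(0, 0, None))(row_keys, batch, rate)
dropped_again = jax.vmap(dropout_row, in_axes=(0, 0, None))(row_keys, batch, rate)
```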


## 3. Under the Hood: `StatefulMapping`

`brainstate.transform.vmap` is actually a thin wrapper around `brainstate.transform.StatefulMapping`, which provides the core state-aware mapping functionality.

### 3.1 Understanding the Architecture

`StatefulMapping` performs several key operations:

1. **State Discovery**: Identifies all `State` objects accessed by the function
2. **In/Out Axis Mapping**: Determines which states are batched and along which axes
3. **IR Compilation**: Compiles the function to JAX's intermediate representation (Jaxpr)
4. **State Management**: Manages state values before and after execution
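
To make the four steps concrete, here is a heavily simplified, hypothetical `toy_stateful_vmap` written in pure JAX. It assumes the states are passed explicitly as a dict, which sidesteps real state discovery. BrainState instead finds `State` objects by tracing, so treat this only as a mental model.

```python
import jax
import jax.numpy as jnp

def toy_stateful_vmap(fn, batched_names):
    """Hypothetical, much-simplified state-aware vmap (not BrainState's code).

    `fn(states, x) -> (new_states, y)` is pure over a dict of state values.
    """
    def wrapped(states, xs):
        # 1. State discovery: here the states are simply the keys of the dict.
        # 2. Axis mapping: batched states use axis 0, all others are broadcast.
        state_axes = {name: (0 if name in batched_names else None) for name in states}
        # 3. Tracing: jax.vmap traces `fn` into a Jaxpr and batches it.
        mapped = jax.vmap(fn, in_axes=(state_axes, 0), out_axes=(state_axes, 0))
        new_states, ys = mapped(states, xs)
        # 4. State management: return the new values so the caller can write them back.
        return new_states, ys
    return wrapped

def step(states, delta):
    new = dict(states)
    new["count"] = states["count"] + delta * states["scale"]
    return new, new["count"]

states = {"count": jnp.zeros(4), "scale": jnp.array(1.0)}
vstep = toy_stateful_vmap(step, batched_names={"count"})
states, ys = vstep(states, jnp.array([1.0, 2.0, 3.0, 4.0]))
print(ys, states["scale"].shape)  # [1. 2. 3. 4.] ()
print(jax.make_jaxpr(vstep)(states, jnp.ones(4)))  # the traced IR
```

A broadcast state that the function writes with a batched value would make `jax.vmap` reject `out_axes=None`. This is the toy counterpart of the `unexpected_out_state_mapping` check.
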

In [16]:
# Example: Inspecting StatefulMapping
accumulator = brainstate.ShortTermState(jnp.zeros(4))


def accumulate(x):
    accumulator.value = accumulator.value + x
    return accumulator.value


# vmap returns a StatefulMapping object
mapped_accumulate = vmap(
    accumulate,
    in_axes=0,
    out_axes=0,
    axis_size=4,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
)

# Inspect the StatefulMapping object
print("Type:", type(mapped_accumulate))
print("Origin function:", mapped_accumulate.origin_fun)
print("in_axes:", mapped_accumulate.in_axes)
print("out_axes:", mapped_accumulate.out_axes)
print("state_in_axes:", mapped_accumulate.state_in_axes)
print("state_out_axes:", mapped_accumulate.state_out_axes)
print("axis_name:", mapped_accumulate.axis_name)
print("axis_size:", mapped_accumulate.axis_size)

Type: <class 'brainstate.transform.StatefulMapping'>
Origin function: <function accumulate at 0x000001B087DBF100>
in_axes: 0
out_axes: 0
state_in_axes: {0: OfType(<class 'brainstate.ShortTermState'>)}
state_out_axes: {}
axis_name: None
axis_size: 4


### 3.2 Compilation and Caching

`StatefulMapping` compiles the function and caches:
- The Jaxpr (JAX intermediate representation)
- State traces (which states are accessed)
- Batch axis mappings

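This compilation happens lazily on the first call. As the cell below shows, the Python body runs only while the function is traced (here twice during the first call). Later calls that reuse the cached compilation do not execute the Python body again.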

In [17]:
# Example: Observing compilation
call_count = [0]


def counting_function(x):
    call_count[0] += 1
    return x * 2


vmap_counting = vmap(counting_function, in_axes=0)

# First call: triggers compilation
print("Before first call, count:", call_count[0])
result1 = vmap_counting(jnp.array([1.0, 2.0, 3.0]))
print("After first call, count:", call_count[0], "(compilation trace)")

# Second call: uses cached compilation
call_count[0] = 0
result2 = vmap_counting(jnp.array([4.0, 5.0, 6.0]))
print("After second call, count:", call_count[0], "(no recompilation)")

print("\nResults:")
print("First:", result1)
print("Second:", result2)

Before first call, count: 0
After first call, count: 2 (compilation trace)
After second call, count: 0 (no recompilation)

Results:
First: [2. 4. 6.]
Second: [ 8. 10. 12.]
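
Plain `jax.jit` follows the same caching rule, which helps explain the counter above. Python side effects run only during tracing, and a new trace happens when the input shapes or dtypes change. This is a pure-JAX sketch. It does not show how BrainState keys its own cache.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def doubled(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces
    return x * 2

fast = jax.jit(jax.vmap(doubled))

fast(jnp.array([1.0, 2.0, 3.0]))       # first call with shape (3,): traced
fast(jnp.array([4.0, 5.0, 6.0]))       # same shape and dtype: cached, no trace
after_same_shape = trace_count          # 1
fast(jnp.array([1.0, 2.0, 3.0, 4.0]))  # new shape (4,): traced again
after_new_shape = trace_count           # 2
```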


### 3.3 State Axis Inference

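`StatefulMapping` decides which states are batched based on:
1. Explicit `state_in_axes` filters, which select the states split along a mapped axis on input
2. The states the function reads and writes, as recorded during tracing
3. Whether a written state's new value carries the batch dimension. If it does, the state must be covered by `state_out_axes`; otherwise the `unexpected_out_state_mapping` policy from Section 1.7 applies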

In [18]:
# Example: Complex state interactions
class StatefulComputation(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        # Different types of states
        self.temp = brainstate.ShortTermState(jnp.zeros(3))
        self.param = brainstate.ParamState(jnp.array(1.0))

    def __call__(self, x):
        # temp is batched (accumulates per batch element)
        self.temp.value = self.temp.value + x
        # param is shared (broadcast across batch)
        return self.temp.value * self.param.value


model = StatefulComputation()

# Only batch ShortTermState, ParamState is shared
vmap_model = vmap(
    model,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
)

inputs = jnp.array([1.0, 2.0, 3.0])
outputs = vmap_model(inputs)

print("Inputs:", inputs)
print("Outputs:", outputs)
print("Final temp state:", model.temp.value)  # Each slot holds its own element's update
print("Param (unchanged):", model.param.value)

Inputs: [1. 2. 3.]
Outputs: [1. 2. 3.]
Final temp state: [1. 2. 3.]
Param (unchanged): 1.0


### 3.4 Direct Use of `StatefulMapping`

Advanced users can instantiate `StatefulMapping` directly and pass their own `mapping_fn`, the primitive that performs the actual mapping.

In [19]:
from brainstate.transform import StatefulMapping
import functools

# Example: Using a custom mapping function
counter_state = brainstate.ShortTermState(jnp.zeros(3))


def increment(delta):
    counter_state.value = counter_state.value + delta
    return counter_state.value


# Create StatefulMapping with custom mapping_fn
# (Here mapping_fn is still jax.vmap; another mapping primitive with a
# compatible signature could in principle be substituted.)
custom_mapping = StatefulMapping(
    increment,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
    name="custom_increment",
    mapping_fn=functools.partial(jax.vmap, spmd_axis_name=None),
)

deltas = jnp.array([1.0, 2.0, 3.0])
results = custom_mapping(deltas)

print("Custom mapping results:", results)
print("Final counter:", counter_state.value)

Custom mapping results: [1. 2. 3.]
Final counter: [1. 2. 3.]
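
The `mapping_fn` argument makes the mapping primitive pluggable. Below is a minimal pure-JAX illustration of that idea. It uses a hypothetical `loop_map` stand-in that mimics the `jax.vmap` call signature for the axis-0 case. Whether a given primitive is accepted by `StatefulMapping` depends on BrainState's own requirements, which this sketch does not check.

```python
import jax
import jax.numpy as jnp

def loop_map(fn, in_axes=0, out_axes=0):
    """Hypothetical mapping primitive: a Python loop over axis 0.

    It accepts the same keyword arguments as jax.vmap but supports only axis 0.
    """
    if in_axes != 0 or out_axes != 0:
        raise NotImplementedError("loop_map only supports axis 0")

    def mapped(xs):
        return jnp.stack([fn(x) for x in xs])
    return mapped

def build(fn, mapping_fn):
    # Code written against `mapping_fn` does not care which primitive it gets
    return mapping_fn(fn, in_axes=0, out_axes=0)

def increment(x):
    return x + 1.0

xs = jnp.array([1.0, 2.0, 3.0])
via_vmap = build(increment, jax.vmap)(xs)
via_loop = build(increment, loop_map)(xs)
```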


### 3.5 Understanding the IR (Intermediate Representation)

`StatefulMapping` compiles your function to JAX's Jaxpr (JAX expression), an intermediate representation that:
- Represents the computation as a functional program
- Explicitly tracks all inputs and outputs (including state values)
- Enables optimizations and transformations

In [20]:
# Example: Inspecting the Jaxpr
simple_state = brainstate.State(jnp.array(1.0))


def simple_op(x):
    result = x + simple_state.value
    simple_state.value = result
    return result * 2


# Create a simple mapping
simple_vmap = vmap(
    simple_op,
    in_axes=0,
    state_out_axes={0: OfType(brainstate.State)},
)

# Call once to trigger compilation
test_input = jnp.array([1.0, 2.0])
_ = simple_vmap(test_input)

# Access the compiled Jaxpr
cache_key = simple_vmap.get_arg_cache_key(test_input)
jaxpr = simple_vmap.get_jaxpr_by_cache(cache_key)

print("Compiled Jaxpr:")
print(jaxpr)
print("\nThis represents the function's computation graph at an abstract level")

Compiled Jaxpr:
{ lambda ; a:f32[2] b:f32[]. let
    c:key<fry>[] = random_seed[impl=fry] 0:i32[]
    d:u32[2] = random_unwrap c
    e:key<fry>[] = random_wrap[impl=fry] d
    f:key<fry>[2] = random_split[shape=(2,)] e
    _:u32[2,2] = random_unwrap f
    g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    h:f32[2] = add a g
    _:f32[2] = mul h 2.0:f32[]
    i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    j:f32[2] = add a i
    k:f32[2] = mul j 2.0:f32[]
  in (k, j) }

This represents the function's computation graph at an abstract level
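
In the Jaxpr above, `a:f32[2]` is the batched input `x` and `b:f32[]` is the unbatched state value. The outputs `(k, j)` are the returned value `result * 2` and the new state value `result`. The pure-JAX sketch below produces the analogous IR for a hypothetical functional rewrite of `simple_op`, which makes that input/output threading explicit. The extra `random_*` equations in BrainState's version are not reproduced here.

```python
import jax
import jax.numpy as jnp

def simple_op_pure(state, x):
    # Pure rewrite of simple_op: the state value is an explicit input and output
    result = x + state
    return result * 2, result  # (returned value, new state value)

mapped = jax.vmap(simple_op_pure, in_axes=(None, 0), out_axes=(0, 0))
closed = jax.make_jaxpr(mapped)(jnp.array(1.0), jnp.array([1.0, 2.0]))
print(closed)
n_inputs = len(closed.jaxpr.invars)    # 2: the state value and the batched x
n_outputs = len(closed.jaxpr.outvars)  # 2: the returned value and the new state
```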


## 4. Advanced Patterns and Best Practices

### 4.1 Nested `vmap`

You can nest multiple `vmap` calls for multi-dimensional batching.

In [21]:
def matrix_elem_product(x, y):
    """Element-wise product."""
    return x * y


# Inner vmap: maps over the elements of a single row
vmap_rows = vmap(matrix_elem_product, in_axes=(0, 0))

# Outer vmap: maps over the rows of the matrices
vmap_matrix = vmap(vmap_rows, in_axes=(0, 0))

# Create 2D inputs
matrix_a = jnp.ones((3, 4))
matrix_b = jnp.arange(12).reshape(3, 4)

result = vmap_matrix(matrix_a, matrix_b)

print("Matrix A shape:", matrix_a.shape)
print("Matrix B shape:", matrix_b.shape)
print("Result shape:", result.shape)
print("Result:")
print(result)

Matrix A shape: (3, 4)
Matrix B shape: (3, 4)
Result shape: (3, 4)
Result:
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
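
Nesting becomes more useful when the two levels map over different arguments. Combining `None` and `0` in `in_axes` builds an outer product. The pure-JAX sketch below checks this against `jnp.outer`.

```python
import jax
import jax.numpy as jnp

def mul(x, y):
    return x * y

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([10.0, 20.0])

# Inner vmap walks over b while holding one element of a fixed;
# outer vmap walks over a while passing the whole b to the inner map.
inner = jax.vmap(mul, in_axes=(None, 0))
outer = jax.vmap(inner, in_axes=(0, None))

table = outer(a, b)  # shape (3, 2): table[i, j] = a[i] * b[j]
```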


### 4.2 Combining with Other Transforms

`vmap` composes with BrainState's other transforms, such as `brainstate.transform.jit` and `brainstate.transform.grad`.

In [22]:
from brainstate.transform import grad, jit


# Define a loss function
def loss_fn(x, target):
    pred = x ** 2
    return jnp.sum((pred - target) ** 2)


# Compose: jit(vmap(grad(loss_fn))) gives jitted per-example gradients
batched_grad = vmap(
    grad(loss_fn, argnums=0),
    in_axes=(0, 0),
)
batched_grad_jit = jit(batched_grad)

# Batch of inputs and targets
batch_x = jnp.array([1.0, 2.0, 3.0])
batch_targets = jnp.array([2.0, 4.0, 6.0])

gradients = batched_grad_jit(batch_x, batch_targets)

print("Inputs:", batch_x)
print("Targets:", batch_targets)
print("Gradients:", gradients)

Inputs: [1. 2. 3.]
Targets: [2. 4. 6.]
Gradients: [-4.  0. 36.]
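
Here is a pure-JAX cross-check of the cell above. The per-example gradients from `vmap(grad(...))` match the analytic derivative `4 * x * (x**2 - t)`. Because this loss is a sum of independent per-example terms, they also match the gradient of the summed batch loss. That last equality need not hold for losses that couple batch elements.

```python
import jax
import jax.numpy as jnp

def loss_fn(x, target):
    pred = x ** 2
    return jnp.sum((pred - target) ** 2)

xs = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([2.0, 4.0, 6.0])

# Per-example gradients, as in the cell above: vmap over grad
per_example = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(0, 0)))(xs, targets)

# For a loss that is a sum of independent per-example terms, the gradient of
# the summed batch loss with respect to xs gives the same vector
def batch_loss(xs, targets):
    return jnp.sum(jax.vmap(loss_fn)(xs, targets))

summed = jax.grad(batch_loss)(xs, targets)
analytic = 4 * xs * (xs ** 2 - targets)  # d/dx (x^2 - t)^2 = 4x(x^2 - t)
```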


## Summary

In this tutorial, we covered:

### 1. **`vmap` Parameters**
- `in_axes`: Controls how inputs are batched
- `out_axes`: Controls where batch dimension appears in outputs
- `axis_name`: Names the mapped axis for collective operations
- `axis_size`: Explicitly specifies batch size when needed
- `state_in_axes` / `state_out_axes`: Control state batching (BrainState-specific)
- `unexpected_out_state_mapping`: Sets the policy (`'raise'`, `'warn'`, or `'ignore'`) for written states not covered by `state_out_axes`

### 2. **Random Number Semantics**
- **Automatic key splitting**: `brainstate.random.RandomState` is automatically split per batch element
- **Shared keys**: Use `jax.random` APIs with `in_axes=None` for shared random numbers
- Each batch element gets independent random streams by default

### 3. **`StatefulMapping` Architecture**
- `vmap` is a wrapper around `StatefulMapping`
- Performs state discovery, axis mapping, and IR compilation
- Compiles to Jaxpr (JAX intermediate representation)
- Caches compilations for reuse
- Manages state values before and after execution

### Key Takeaways

- BrainState's `vmap` seamlessly handles stateful computations
- Random states are split automatically, so batch elements draw independent samples, while results stay reproducible under a fixed seed
- The underlying `StatefulMapping` provides powerful abstractions for state-aware transformations
- Understanding the IR compilation helps debug and optimize vectorized code