# Other BrainState Transforms

This tutorial covers utilities for introspection, optimization, and debugging. The first two come from `brainstate.transform`; the third is part of JAX itself:

1. **`checkpoint`**: Memory-efficient gradient computation through rematerialization
2. **`make_jaxpr` and `StatefulFunction`**: Inspect and understand compiled computation graphs
3. **`jax.debug.print`**: Runtime debugging in JIT-compiled code

Most examples highlight how these tools interact with BrainState's `State` objects, which is what distinguishes them from their vanilla JAX counterparts.

## Imports and Setup

In [1]:
import jax
import jax.numpy as jnp
import brainstate
from brainstate.transform import checkpoint, make_jaxpr, StatefulFunction

## 1. `checkpoint`: Memory-Efficient Gradient Computation

`checkpoint` (also known as rematerialization or gradient checkpointing) is a standard tool for training deep neural networks and processing long sequences when activation memory is the bottleneck. It trades extra computation for lower memory during backpropagation.

### How Gradient Computation Works

**Without checkpointing:**
- Forward pass: Computes outputs and stores **all intermediate activations**
- Backward pass: Uses stored activations to compute gradients
- Memory usage: O(n) where n is the number of layers/steps

**With checkpointing:**
- Forward pass: Computes outputs, stores **only inputs** at checkpoints
- Backward pass: **Recomputes** intermediate activations from checkpoints as needed
- Memory usage: O(√n) when checkpoints are placed every ~√n layers
- Computation: roughly one extra forward pass (recomputation during backward)

**Key principle: Trade extra computation for reduced memory**
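The difference can be made concrete by listing the *residuals*, the arrays the forward pass keeps for the backward pass. The sketch below uses plain `jax.checkpoint` and `jax.ad_checkpoint.print_saved_residuals` from JAX rather than BrainState; we assume `brainstate.transform.checkpoint` behaves the same way for a function that touches no `State`. `chain` and `residual_report` are illustrative names of our own.

```python
import contextlib
import io

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals


def chain(x):
    # Same sin -> exp -> tanh -> square chain as Example 1 below.
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)


def residual_report(f, *args):
    """Capture the lines printed by print_saved_residuals as a list of strings."""
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        print_saved_residuals(f, *args)
    return [line for line in buf.getvalue().splitlines() if line.strip()]


x = jnp.linspace(0.0, 1.0, 8)
plain = residual_report(chain, x)
ckpt = residual_report(jax.checkpoint(chain), x)

print(f"Without checkpoint: {len(plain)} residuals")
for line in plain:
    print("   ", line)
print(f"With jax.checkpoint: {len(ckpt)} residuals")
for line in ckpt:
    print("   ", line)
```

The checkpointed version should list little more than its argument `x`, while the plain version keeps several intermediate arrays that the backward pass reads.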

### 1.1 Basic Usage with Gradient Computation

In [2]:
# Example 1: Memory-efficient gradient computation
print("=== Example 1: Basic checkpoint usage ===")

# Without checkpoint: stores all intermediate activations
def expensive_forward(x):
    """Chain of expensive operations."""
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)

# With checkpoint: only stores inputs, recomputes during backward
@checkpoint
def checkpointed_forward(x):
    """Same computation, but memory-efficient."""
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)

x = jnp.linspace(0, 10, 1000)

# Both produce same results
value1, grad1 = jax.value_and_grad(expensive_forward)(x)
value2, grad2 = jax.value_and_grad(checkpointed_forward)(x)

print(f"Values match: {jnp.allclose(value1, value2)}")
print(f"Gradients match: {jnp.allclose(grad1, grad2)}")
print("\nMemory: checkpoint does not keep y, z, w for the backward pass")
print("Cost: they are recomputed, roughly one extra forward pass")

=== Example 1: Basic checkpoint usage ===
Values match: True
Gradients match: True

Memory: checkpoint does not keep y, z, w for the backward pass
Cost: they are recomputed, roughly one extra forward pass


### 1.2 Checkpointing Stateful Computations

BrainState's `checkpoint` properly handles `State` objects during gradient computation.

In [3]:
# Example 2: Checkpoint with stateful neural network
print("\n=== Example 2: Checkpointed neural network ===")

class DeepNetwork(brainstate.nn.Module):
    """Deep network with many layers."""
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = []
        for i in range(len(layer_sizes) - 1):
            self.layers.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i), 
                    (layer_sizes[i], layer_sizes[i+1])
                ))
            )
    
    def forward(self, x, use_checkpoint=False):
        """Forward pass through all layers."""
        def layer_fn(x):
            h = x
            for W in self.layers[:-1]:
                h = jnp.tanh(h @ W.value)
            # Output layer (no activation)
            return h @ self.layers[-1].value
        
        if use_checkpoint:
            return checkpoint(layer_fn)(x)
        else:
            return layer_fn(x)

# Create a deep network: 10 layers
net = DeepNetwork([128, 256, 256, 256, 256, 256, 256, 256, 256, 128, 10])
x_batch = jax.random.normal(jax.random.PRNGKey(42), (32, 128))

# Define loss function
def loss_fn(use_checkpoint):
    y_pred = net.forward(x_batch, use_checkpoint=use_checkpoint)
    return jnp.mean(y_pred ** 2)

# Get parameters
params = net.states(brainstate.ParamState)

# Compute gradients with and without checkpoint
grads_normal = brainstate.transform.grad(lambda: loss_fn(False), params)()
grads_checkpointed = brainstate.transform.grad(lambda: loss_fn(True), params)()

# Compare
print(f"Number of layers: {len(net.layers)}")
print(f"Gradient shapes match: {jax.tree.map(lambda a, b: a.shape == b.shape, grads_normal, grads_checkpointed)}")
print("\nWithout checkpoint: stores the activations of all 10 layers")
print("With checkpoint: stores the input, recomputes activations during backward")
print("Note: a single checkpoint around the whole network recomputes every activation at once;")
print("      peak memory drops when smaller segments are checkpointed separately")


=== Example 2: Checkpointed neural network ===
Number of layers: 10
Gradient shapes match: {('layers', 0): True, ('layers', 1): True, ('layers', 2): True, ('layers', 3): True, ('layers', 4): True, ('layers', 5): True, ('layers', 6): True, ('layers', 7): True, ('layers', 8): True, ('layers', 9): True}

Without checkpoint: stores the activations of all 10 layers
With checkpoint: stores the input, recomputes activations during backward
Note: a single checkpoint around the whole network recomputes every activation at once;
      peak memory drops when smaller segments are checkpointed separately


### 1.3 Sequential Layer Checkpointing

For very deep networks, checkpoint individual layers or groups of layers.

In [5]:
# Example 3: Per-layer checkpointing
print("\n=== Example 3: Granular checkpointing ===")

class CheckpointedDeepNetwork(brainstate.nn.Module):
    """Network with per-layer checkpointing."""
    def __init__(self, layer_sizes, checkpoint_every=2):
        super().__init__()
        self.checkpoint_every = checkpoint_every
        self.weights = []
        for i in range(len(layer_sizes) - 1):
            self.weights.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i), 
                    (layer_sizes[i], layer_sizes[i+1])
                ) * 0.1)
            )
    
    def __call__(self, x):
        h = x
        for i, W in enumerate(self.weights):
            # Define layer computation
            def layer_forward(h):
                return jnp.tanh(h @ W.value)
            
            # Checkpoint every N-th layer (only that single layer is rematerialized)
            if (i + 1) % self.checkpoint_every == 0:
                h = checkpoint(layer_forward)(h)
            else:
                h = layer_forward(h)
        return h

# Create network: checkpoint every 2 layers
ckpt_net = CheckpointedDeepNetwork(
    [64, 128, 128, 128, 128, 128, 32],  # 6 layers
    checkpoint_every=2
)

x_in = jax.random.normal(jax.random.PRNGKey(123), (16, 64))

# Forward and backward
def forward_loss():
    return jnp.sum(ckpt_net(x_in) ** 2)

grads, value = brainstate.transform.grad(
    forward_loss, 
    ckpt_net.states(brainstate.ParamState),
    return_value=True
)()

print(f"Network depth: {len(ckpt_net.weights)} layers")
print(f"Checkpoint frequency: every {ckpt_net.checkpoint_every} layers")
print(f"Checkpoints created: {len(ckpt_net.weights) // ckpt_net.checkpoint_every}")
print(f"Loss: {value:.4f}")
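print("\nMemory: only the checkpointed layers recompute their activations;")
print("wrapping groups of layers in one checkpoint saves more")


=== Example 3: Granular checkpointing ===
Network depth: 6 layers
Checkpoint frequency: every 2 layers
Checkpoints created: 3
Loss: 85.1143

Memory: only the checkpointed layers recompute their activations;
wrapping groups of layers in one checkpoint saves more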
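Checkpointing *groups* of consecutive layers, with roughly √n layers per group, is the usual way to approach the O(√n) memory figure from the start of this section. A minimal sketch in plain JAX, assuming functional weights (a list of arrays) instead of `ParamState` objects; `run_block`, `forward`, and `segment_size` are hypothetical names:

```python
import math

import jax
import jax.numpy as jnp


def run_block(weights, h):
    # A block of consecutive tanh layers.
    for W in weights:
        h = jnp.tanh(h @ W)
    return h


def forward(weights, x, segment_size=None):
    """Apply all layers; if segment_size is set, wrap each segment in jax.checkpoint."""
    if segment_size is None:
        return jnp.sum(run_block(weights, x) ** 2)
    h = x
    for start in range(0, len(weights), segment_size):
        segment = weights[start:start + segment_size]
        # Only the segment input h is kept; activations inside are recomputed in backward.
        h = jax.checkpoint(run_block)(segment, h)
    return jnp.sum(h ** 2)


n_layers, width = 16, 32
keys = jax.random.split(jax.random.PRNGKey(0), n_layers)
mlp_params = [0.1 * jax.random.normal(k, (width, width)) for k in keys]
x_seg = jax.random.normal(jax.random.PRNGKey(1), (4, width))
seg = math.isqrt(n_layers)  # about sqrt(n) layers per segment

g_plain = jax.grad(forward)(mlp_params, x_seg)
g_seg = jax.grad(forward)(mlp_params, x_seg, segment_size=seg)
match = all(jnp.allclose(a, b, atol=1e-5) for a, b in zip(g_plain, g_seg))
print(f"{n_layers} layers in segments of {seg}; gradients match: {match}")
```

During backward, each segment is recomputed on its own, so peak memory grows with the number of segment boundaries plus the activations of one segment rather than with the full depth.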


### 1.4 Memory-Computation Tradeoff

The cell below contrasts a shallow and a deep network to illustrate when checkpointing pays off.

In [6]:
# Example 4: Illustrating the tradeoff
print("\n=== Example 4: When to use checkpoint ===")

class BenchmarkNet(brainstate.nn.Module):
    def __init__(self, n_layers, hidden_size):
        super().__init__()
        self.layers = []
        for i in range(n_layers):
            self.layers.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i), 
                    (hidden_size, hidden_size)
                ) * 0.1)
            )
    
    def forward_normal(self, x):
        h = x
        for W in self.layers:
            h = jnp.tanh(h @ W.value)
        return jnp.sum(h)
    
    def forward_checkpointed(self, x):
        def layer_block(h):
            for W in self.layers:
                h = jnp.tanh(h @ W.value)
            return jnp.sum(h)
        return checkpoint(layer_block)(x)

# Small network: checkpoint overhead not worth it
small_net = BenchmarkNet(n_layers=3, hidden_size=64)
x_small = jax.random.normal(jax.random.PRNGKey(0), (64,))

# Large network: checkpoint saves significant memory
large_net = BenchmarkNet(n_layers=20, hidden_size=512)
x_large = jax.random.normal(jax.random.PRNGKey(0), (512,))

print("Small network (3 layers, 64 hidden):")
print("  → Normal gradient: Fast, low memory")
print("  → Checkpoint: Overhead not justified\n")

print("Large network (20 layers, 512 hidden):")
print("  → Normal gradient: Stores ~20 activations (high memory)")
print("  → Checkpoint: Recomputes activations (saves memory)")
print("  → Recommended: Use checkpoint for deep/wide networks\n")

print("Rule of thumb (rough heuristic; profile your own model):")
print("  Use checkpoint when: depth > 10 OR width > 256")
print("  Skip checkpoint when: shallow networks (< 5 layers)")


=== Example 4: When to use checkpoint ===
Small network (3 layers, 64 hidden):
  → Normal gradient: Fast, low memory
  → Checkpoint: Overhead not justified

Large network (20 layers, 512 hidden):
  → Normal gradient: Stores ~20 activations (high memory)
  → Checkpoint: Recomputes activations (saves memory)
  → Recommended: Use checkpoint for deep/wide networks

Rule of thumb (rough heuristic; profile your own model):
  Use checkpoint when: depth > 10 OR width > 256
  Skip checkpoint when: shallow networks (< 5 layers)
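Between "store everything" and "recompute everything", JAX's `jax.checkpoint` accepts a `policy` that decides which intermediates are kept. The plain-JAX sketch below compares three settings on a small tanh MLP and checks that they give the same gradients; only memory use and recomputation differ. Whether `brainstate.transform.checkpoint` forwards a `policy` argument is an assumption to verify against its documentation.

```python
import jax
import jax.numpy as jnp


def mlp_loss(weights, x):
    h = x
    for W in weights:
        h = jnp.tanh(h @ W)  # a matmul ("dot") followed by an elementwise op
    return jnp.sum(h ** 2)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
policy_weights = [0.1 * jax.random.normal(k, (64, 64)) for k in keys]
x_policy = jax.random.normal(jax.random.PRNGKey(1), (8, 64))

variants = {
    "no checkpoint": mlp_loss,
    "recompute everything": jax.checkpoint(mlp_loss),
    # Keep matmul outputs, recompute only the cheap elementwise tanh.
    "keep matmul outputs": jax.checkpoint(
        mlp_loss, policy=jax.checkpoint_policies.dots_saveable
    ),
}

policy_grads = {
    name: jax.jit(jax.grad(f))(policy_weights, x_policy) for name, f in variants.items()
}
reference = policy_grads["no checkpoint"]
for name, g in policy_grads.items():
    same = all(jnp.allclose(a, b, atol=1e-5) for a, b in zip(g, reference))
    print(f"{name:>22}: gradients match reference = {same}")
```

`dots_saveable` is a common middle ground because matmuls are expensive to recompute while elementwise ops are cheap.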


## 2. `make_jaxpr` and `StatefulFunction`: Inspecting Compiled Code

`make_jaxpr` converts a function into JAX's intermediate representation (Jaxpr), showing the primitive operations JAX traced from your code before XLA compiles and optimizes it. `StatefulFunction` is the underlying mechanism that enables state-aware transformations.

### What is Jaxpr?

Jaxpr is JAX's intermediate representation based on a simply-typed first-order lambda calculus with let-bindings. It shows:
- Primitive operations (add, mul, sin, etc.)
- Data dependencies
- How state reads/writes are handled
- Shapes and dtypes of every intermediate value
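For comparison, plain `jax.make_jaxpr` returns a single `ClosedJaxpr` (no state list). Its `.jaxpr` attribute exposes the input variables, the equations, and the outputs programmatically, which is handy for quick checks. A small sketch:

```python
import jax
import jax.numpy as jnp


def cos_sin_double(x):
    return jnp.cos(jnp.sin(x)) * 2


closed = jax.make_jaxpr(cos_sin_double)(3.0)  # ClosedJaxpr
inner = closed.jaxpr                           # the Jaxpr proper
prims = [eqn.primitive.name for eqn in inner.eqns]

print(closed)
print("inputs:    ", inner.invars)
print("primitives:", prims)
print("outputs:   ", inner.outvars)
```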

### 2.1 Basic Jaxpr Inspection

In [7]:
# Example 1: Simple function jaxpr
print("=== Example 1: Basic jaxpr ===")

def simple_fn(x):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return z * 2

# Create jaxpr
jaxpr_fn = make_jaxpr(simple_fn)
jaxpr, states = jaxpr_fn(3.0)

print("Function: z = cos(sin(x)) * 2")
print("\nJaxpr representation:")
print(jaxpr)
print(f"\nStates used: {len(states)} (none for this simple function)")

=== Example 1: Basic jaxpr ===
Function: z = cos(sin(x)) * 2

Jaxpr representation:
{ lambda ; a:f32[]. let
    b:f32[] = sin a
    c:f32[] = cos b
    d:f32[] = mul c 2.0:f32[]
  in (d,) }

States used: 0 (none for this simple function)


### 2.2 Stateful Jaxpr: Tracking State Reads and Writes

BrainState's `make_jaxpr` reveals how states are accessed.

In [8]:
# Example 2: Jaxpr with states
print("\n=== Example 2: Stateful jaxpr ===")

# Create states
counter = brainstate.ShortTermState(jnp.array(0))
accumulator = brainstate.ShortTermState(jnp.array(0.0))

def stateful_fn(x):
    # Read states
    count = counter.value
    accum = accumulator.value
    
    # Update states
    counter.value = count + 1
    accumulator.value = accum + x
    
    return accumulator.value / counter.value

# Inspect jaxpr
jaxpr_fn = make_jaxpr(stateful_fn)
jaxpr, states = jaxpr_fn(5.0)

print("Function: running average tracker")
print(f"\nStates accessed: {len(states)}")
for i, state in enumerate(states):
    print(f"  [{i}] {type(state).__name__}: {state.value}")

print("\nJaxpr (state operations visible):")
print(jaxpr)
print("\nNote: Jaxpr shows state reads as inputs, writes as outputs")


=== Example 2: Stateful jaxpr ===
Function: running average tracker

States accessed: 2
  [0] ShortTermState: 0
  [1] ShortTermState: 0.0

Jaxpr (state operations visible):
{ lambda ; a:f32[] b:i32[] c:f32[]. let
    d:i32[] = add b 1:i32[]
    e:f32[] = add c a
    f:f32[] = convert_element_type[new_dtype=float32 weak_type=True] d
    g:f32[] = div e f
  in (g, d, e) }

Note: Jaxpr shows state reads as inputs, writes as outputs
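Reading the jaxpr above: `a` is the argument `x`, `b` and `c` are the current values of `counter` and `accumulator`, and the outputs `(g, d, e)` are the return value followed by the new state values. The hand-written pure function below illustrates that contract (it is not code BrainState generates): the caller threads the state through explicitly.

```python
import jax
import jax.numpy as jnp


@jax.jit
def running_average_pure(x, count, accum):
    """Pure version of stateful_fn: old state values in, new state values out."""
    new_count = count + 1
    new_accum = accum + x
    return new_accum / new_count, new_count, new_accum


# The caller carries the "state" from call to call.
count, accum = jnp.array(0), jnp.array(0.0)
averages = []
for value in [5.0, 1.0, 3.0]:
    avg, count, accum = running_average_pure(value, count, accum)
    averages.append(float(avg))

print("running averages:", averages)
print("final count:", int(count), "| final sum:", float(accum))
```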


### 2.3 Understanding `StatefulFunction`

`StatefulFunction` is the core abstraction that enables all BrainState transformations. It:
1. **Identifies states** accessed during function execution
2. **Compiles to Jaxpr** with explicit state inputs/outputs
3. **Manages state values** before and after execution
4. **Caches compilations** for efficient repeated calls
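The mechanism behind steps 2 and 3 can be sketched in miniature. The toy below is hypothetical (`ToyState` and `toy_stateful_call` are not BrainState APIs): the states are listed by hand rather than discovered, and nothing is cached, but it shows how reads become explicit inputs and writes become explicit outputs of a pure, jitted function.

```python
import jax
import jax.numpy as jnp


class ToyState:
    """Hypothetical stand-in for a BrainState State: a mutable box holding one value."""

    def __init__(self, value):
        self.value = value


def toy_stateful_call(fn, states, *args):
    """Hypothetical helper: run `fn` as a pure jitted function of (state values, args).

    Unlike StatefulFunction, states must be listed by hand and nothing is cached
    (a fresh `pure` closure is jitted on every call).
    """

    def pure(state_vals, *fn_args):
        saved = [s.value for s in states]
        for s, v in zip(states, state_vals):
            s.value = v                           # state reads become explicit inputs
        try:
            out = fn(*fn_args)
            new_vals = [s.value for s in states]  # state writes become explicit outputs
        finally:
            for s, v in zip(states, saved):
                s.value = v                       # never leave tracers inside the states
        return out, new_vals

    out, new_vals = jax.jit(pure)([s.value for s in states], *args)
    for s, v in zip(states, new_vals):
        s.value = v                               # write the concrete results back
    return out


toy_counter = ToyState(jnp.array(0))


def increment(step):
    toy_counter.value = toy_counter.value + step
    return toy_counter.value * 2


toy_out = toy_stateful_call(increment, [toy_counter], 3)
print("returned:", toy_out, "| counter now:", toy_counter.value)
```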

In [9]:
# Example 3: Using StatefulFunction directly
print("\n=== Example 3: StatefulFunction mechanics ===")

# Create a module with state
class NeuralCell(brainstate.nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W = brainstate.ParamState(jax.random.normal(
            jax.random.PRNGKey(0), (input_size, hidden_size)
        ))
        self.h = brainstate.ShortTermState(jnp.zeros(hidden_size))
    
    def __call__(self, x):
        # Update hidden state
        self.h.value = jnp.tanh(x @ self.W.value + self.h.value)
        return self.h.value

cell = NeuralCell(input_size=10, hidden_size=20)

# Wrap in StatefulFunction
sf = StatefulFunction(cell)

# Example input
x = jax.random.normal(jax.random.PRNGKey(1), (10,))

# Step 1: Compile and inspect
sf.make_jaxpr(x)
print("Step 1: Compilation")
print(f"  Compiled for input shape: {x.shape}")

# Step 2: Get tracked states
states = sf.get_states(x)
read_states = sf.get_read_states(x)
write_states = sf.get_write_states(x)

print(f"\nStep 2: State identification")
print(f"  Total states: {len(states)}")
print(f"  Read states: {len(read_states)}")
for s in read_states:
    print(f"    - {type(s).__name__}: shape {s.value.shape}")
print(f"  Write states: {len(write_states)}")
for s in write_states:
    print(f"    - {type(s).__name__}: shape {s.value.shape}")

# Step 3: Get jaxpr
jaxpr = sf.get_jaxpr(x)
print(f"\nStep 3: Jaxpr compilation")
print(f"  Jaxpr variables: {len(jaxpr.jaxpr.invars)} inputs, {len(jaxpr.jaxpr.outvars)} outputs")
print(f"  Jaxpr equations: {len(jaxpr.jaxpr.eqns)} operations")

# Step 4: Execute
output = sf(x)
print(f"\nStep 4: Execution")
print(f"  Output shape: {output.shape}")
print(f"  Hidden state updated: {cell.h.value.shape}")


=== Example 3: StatefulFunction mechanics ===
Step 1: Compilation
  Compiled for input shape: (10,)

Step 2: State identification
  Total states: 2
  Read states: 1
    - ParamState: shape (10, 20)
  Write states: 1
    - ShortTermState: shape (20,)

Step 3: Jaxpr compilation
  Jaxpr variables: 3 inputs, 2 outputs
  Jaxpr equations: 3 operations

Step 4: Execution
  Output shape: (20,)
  Hidden state updated: (20,)


### 2.4 Jaxpr for Gradient Computation

Inspect how autodiff transforms your code.

In [10]:
# Example 4: Gradient jaxpr
print("\n=== Example 4: Gradient computation jaxpr ===")

# Simple loss function
params = brainstate.ParamState(jnp.array([1.0, 2.0, 3.0]))

def loss_fn(x):
    return jnp.sum((params.value - x) ** 2)

# Original function jaxpr
print("Original function jaxpr:")
jaxpr_orig, _ = make_jaxpr(loss_fn)(jnp.array([0.5, 1.0, 1.5]))
print(jaxpr_orig)

# Gradient function jaxpr
print("\nGradient function jaxpr:")
grad_fn = brainstate.transform.grad(loss_fn, params)
jaxpr_grad, _ = make_jaxpr(grad_fn)(jnp.array([0.5, 1.0, 1.5]))
print(jaxpr_grad)

print("\nNote: Gradient jaxpr includes:")
print("  - Forward pass operations")
print("  - Backward pass (VJP) operations")
print("  - Much more complex than original")


=== Example 4: Gradient computation jaxpr ===
Original function jaxpr:
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub b a
    d:f32[3] = integer_pow[y=2] c
    e:f32[] = reduce_sum[axes=(0,)] d
  in (e, b) }

Gradient function jaxpr:
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub b a
    d:f32[3] = integer_pow[y=2] c
    e:f32[3] = integer_pow[y=1] c
    f:f32[3] = mul 2.0:f32[] e
    _:f32[] = reduce_sum[axes=(0,)] d
    g:f32[3] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(3,)
      sharding=None
    ] 1.0:f32[]
    h:f32[3] = mul g f
  in (h, b) }

Note: Gradient jaxpr includes:
  - Forward pass operations
  - Backward pass (VJP) operations
  - Much more complex than original


### 2.5 Jaxpr for Transformed Functions

See how transformations affect the compiled code.

In [11]:
# Example 5: Transformation jaxpr
print("\n=== Example 5: Transformed function jaxpr ===")

def simple_fn(x):
    return x ** 2

# Original
print("Original function:")
jaxpr1, _ = make_jaxpr(simple_fn)(jnp.array([1.0, 2.0, 3.0]))
print(jaxpr1)

# Vmapped version
print("\nVmapped function:")
vmapped_fn = brainstate.transform.vmap2(simple_fn)
jaxpr2, _ = make_jaxpr(vmapped_fn)(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(jaxpr2)

print("\nNote: for an elementwise op, the batched jaxpr applies integer_pow to the whole [2,2] array.")
print("Equations bound to `_` are unused; the random_* ones appear to come from vmap2 splitting RNG keys.")


=== Example 5: Transformed function jaxpr ===
Original function:
{ lambda ; a:f32[3]. let b:f32[3] = integer_pow[y=2] a in (b,) }

Vmapped function:
{ lambda ; a:f32[2,2]. let
    b:key<fry>[] = random_seed[impl=fry] 0:i32[]
    c:u32[2] = random_unwrap b
    d:key<fry>[] = random_wrap[impl=fry] c
    e:key<fry>[2] = random_split[shape=(2,)] d
    _:u32[2,2] = random_unwrap e
    _:f32[2,2] = integer_pow[y=2] a
    f:f32[2,2] = integer_pow[y=2] a
  in (f,) }

Note: for an elementwise op, the batched jaxpr applies integer_pow to the whole [2,2] array.
Equations bound to `_` are unused; the random_* ones appear to come from vmap2 splitting RNG keys.
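For comparison, plain `jax.vmap` on the same elementwise function produces a single batched `integer_pow`. For a reduction, the batching rule instead shifts the reduced axis past the new batch dimension. A plain-JAX sketch:

```python
import jax
import jax.numpy as jnp


def square(x):
    return x ** 2


def total(x):
    return jnp.sum(x)  # reduces over axis 0 of a single (unbatched) example


batch = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

sq_jaxpr = jax.make_jaxpr(jax.vmap(square))(batch)
sum_jaxpr = jax.make_jaxpr(jax.vmap(total))(batch)
print(sq_jaxpr)
print(sum_jaxpr)

sq_prims = [e.primitive.name for e in sq_jaxpr.jaxpr.eqns]
sum_eqn = next(e for e in sum_jaxpr.jaxpr.eqns if e.primitive.name == "reduce_sum")
print("square primitives:", sq_prims)
print("reduce_sum axes after vmap:", sum_eqn.params["axes"])
```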


### 2.6 StatefulFunction Caching

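`StatefulFunction` caches compiled jaxprs, keyed on the abstract input signature (shapes and dtypes). In the run below, reuse shows up as the cache `size` staying at 1 after a same-shape call; the `hits`/`misses` counters stayed at 0 throughout, so they do not appear to count `make_jaxpr` calls.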
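Plain `jax.jit` follows the same shape-keyed rule, which gives a quick way to check it: Python side effects inside a jitted function run only while JAX traces it. A minimal sketch (`summed` and `trace_count` are illustrative names); we assume `StatefulFunction` follows the same rule for its jaxpr cache.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def summed(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return jnp.sum(x)


summed(jnp.array([1.0, 2.0]))       # new shape (2,): traces
summed(jnp.array([3.0, 4.0]))       # same shape and dtype: cached, no retrace
summed(jnp.array([1.0, 2.0, 3.0]))  # new shape (3,): traces again
print("traces:", trace_count)
```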

In [12]:
# Example 6: Understanding compilation caching
print("\n=== Example 6: Compilation caching ===")

state = brainstate.ShortTermState(jnp.array(0.0))

def cached_fn(x):
    state.value = state.value + jnp.sum(x)
    return state.value

sf = StatefulFunction(cached_fn)

# First call: compile
sf.make_jaxpr(jnp.array([1.0, 2.0]))
stats1 = sf.get_cache_stats()
print("After first compilation:")
print(f"  Jaxpr cache: {stats1['jaxpr_cache']}")

# Same shape: reuses the cached jaxpr (cache size stays at 1)
sf.make_jaxpr(jnp.array([3.0, 4.0]))
stats2 = sf.get_cache_stats()
print("\nAfter same-shape call:")
print(f"  Jaxpr cache: {stats2['jaxpr_cache']}")
print(f"  Hit rate: {stats2['jaxpr_cache']['hit_rate']:.1f}%")

# Different shape: new compilation
sf.make_jaxpr(jnp.array([1.0, 2.0, 3.0]))
stats3 = sf.get_cache_stats()
print("\nAfter different-shape call:")
print(f"  Jaxpr cache: {stats3['jaxpr_cache']}")
print(f"  Cache size: {stats3['jaxpr_cache']['size']} entries")

print("\nCaching strategy:")
print("  - Different shapes → new compilation")
print("  - Same shapes → cache reuse")
print("  - Cache size limited to 128 entries (LRU)")


=== Example 6: Compilation caching ===
After first compilation:
  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}

After same-shape call:
  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}
  Hit rate: 0.0%

After different-shape call:
  Jaxpr cache: {'size': 2, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}
  Cache size: 2 entries

Caching strategy:
  - Different shapes → new compilation
  - Same shapes → cache reuse
  - Cache size limited to 128 entries (LRU)


## 3. Debugging with `jax.debug.print`

`jax.debug.print` enables runtime debugging in JIT-compiled code. Unlike regular `print`, it:
- Executes during runtime (not tracing)
- Works inside `@jit`, `vmap`, `grad`, etc.
- Supports formatted output
- Can print array values and shapes

**Key principle: Debug prints happen at execution time, not trace time**
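The difference is easy to observe by putting a regular `print` and `jax.debug.print` side by side in a jitted function. A plain-JAX sketch; `double` and `trace_events` are illustrative names:

```python
import jax
import jax.numpy as jnp

trace_events = []


@jax.jit
def double(x):
    # Ordinary Python code: runs only while JAX traces the function,
    # and `x` is an abstract tracer here, not a number.
    print("print (trace time):", x)
    trace_events.append("traced")
    # Staged into the compiled program: runs on every call with the concrete value.
    jax.debug.print("jax.debug.print (run time): {}", x)
    return x * 2


first = double(jnp.float32(1.0))   # traces and runs: both lines appear
second = double(jnp.float32(2.0))  # cache hit: only the jax.debug.print line appears
jax.effects_barrier()              # wait for pending debug callbacks to finish
```

On the first call, `print` shows an abstract tracer rather than `1.0`; the second call reuses the compiled function, so only the `jax.debug.print` line appears.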

### 3.1 Basic Debug Printing

In [13]:
# Example 1: Basic debug printing in JIT
print("=== Example 1: Debug printing in JIT ===")

@brainstate.transform.jit
def compute_with_debug(x):
    jax.debug.print("Input: {x}", x=x)
    y = x ** 2
    jax.debug.print("After square: {y}", y=y)
    z = jnp.sum(y)
    jax.debug.print("Sum: {z}", z=z)
    return z

result = compute_with_debug(jnp.array([1.0, 2.0, 3.0]))
print(f"\nFinal result: {result}")
print("\nNote: Debug prints appear during execution, not compilation")

=== Example 1: Debug printing in JIT ===
Input: [1. 2. 3.]
After square: [1. 4. 9.]
Sum: 14.0

Final result: 14.0

Note: Debug prints appear during execution, not compilation


### 3.2 Debugging State Updates

In [14]:
# Example 2: Debugging stateful computations
print("\n=== Example 2: Debug state updates ===")

class DebuggableCell(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.state = brainstate.ShortTermState(jnp.zeros(size))
        self.weight = brainstate.ParamState(jax.random.normal(jax.random.PRNGKey(0), (size, size)))
    
    def step(self, x):
        jax.debug.print("Before update - state: {s}", s=self.state.value)
        
        # Update
        new_state = jnp.tanh(x @ self.weight.value + self.state.value)
        jax.debug.print("Computed new state: {s}", s=new_state)
        
        self.state.value = new_state
        jax.debug.print("After update - state: {s}", s=self.state.value)
        
        return new_state

cell = DebuggableCell(size=3)

@brainstate.transform.jit
def update_step(x):
    return cell.step(x)

x = jnp.array([1.0, 0.0, -1.0])
output = update_step(x)
print(f"\nOutput: {output}")


=== Example 2: Debug state updates ===
Before update - state: [0. 0. 0.]
Computed new state: [ 0.97147846  0.9105761  -0.79975927]
After update - state: [ 0.97147846  0.9105761  -0.79975927]

Output: [ 0.97147846  0.9105761  -0.79975927]


### 3.3 Debugging Gradients

In [15]:
# Example 3: Debug gradient computation
print("\n=== Example 3: Debug gradients ===")

param = brainstate.ParamState(jnp.array([2.0, 3.0]))

def loss_with_debug(x):
    jax.debug.print("Forward - param: {p}, input: {x}", p=param.value, x=x)
    
    pred = param.value * x
    jax.debug.print("Forward - prediction: {pred}", pred=pred)
    
    loss = jnp.sum(pred ** 2)
    jax.debug.print("Forward - loss: {loss}", loss=loss)
    
    return loss

# Gradient computation
x = jnp.array([0.5, 1.0])
grad_fn = brainstate.transform.grad(loss_with_debug, param)

print("\nComputing gradients:")
grads = grad_fn(x)
print(f"\nGradients: {grads}")
print("\nNote: Debug prints show forward pass values during gradient computation")


=== Example 3: Debug gradients ===

Computing gradients:
Forward - param: [2. 3.], input: [0.5 1. ]
Forward - prediction: [1. 3.]
Forward - loss: 10.0

Gradients: [1. 6.]

Note: Debug prints show forward pass values during gradient computation


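### 3.4 Debugging Vectorized Code

Under vmap, `jax.debug.print` fires once per batch element. In the output below every line appears twice. This is most likely because the `vmap2` jaxpr contains the function body twice (compare the duplicated `integer_pow` in Section 2.5), and debug prints are side effects that are never pruned as dead code.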

In [16]:
# Example 4: Debug vmap
print("\n=== Example 4: Debug vectorized code ===")

def process_item(x, index):
    jax.debug.print("Processing item {i}: {x}", i=index, x=x)
    return x ** 2

# Vmap over both arguments
vmapped_fn = brainstate.transform.vmap2(process_item)

batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = jnp.arange(len(batch_x))

print("\nProcessing batch:")
results = vmapped_fn(batch_x, indices)
print(f"\nResults: {results}")
print("\nNote: Debug prints execute for each element in the batch")


=== Example 4: Debug vectorized code ===

Processing batch:
Processing item 0: 1.0
Processing item 1: 2.0
Processing item 2: 3.0
Processing item 3: 4.0
Processing item 0: 1.0
Processing item 1: 2.0
Processing item 2: 3.0
Processing item 3: 4.0

Results: [ 1.  4.  9. 16.]

Note: Debug prints execute for each element in the batch


### 3.5 Conditional Debugging

In [18]:
# Example 5: Conditional debug prints
print("\n=== Example 5: Conditional debugging ===")

iteration = brainstate.ShortTermState(jnp.array(0))

def training_step_with_debug(x, debug_every=5):
    # Update iteration
    iteration.value = iteration.value + 1
    it = iteration.value
    
    # Conditional debug print: lax.cond selects the branch at run time,
    # so the print only fires when `it` is a multiple of `debug_every`
    jax.lax.cond(
        it % debug_every == 0,
        lambda it, x: jax.debug.print("Iteration {iter}: x={x}", iter=it, x=x),
        lambda it, x: None,
        it, x,
    )
    
    loss = jnp.sum(x ** 2)
    return loss

@brainstate.transform.jit
def train_step(x):
    return training_step_with_debug(x, debug_every=3)

print("\nRunning 10 training steps:")
for i in range(10):
    x = jax.random.normal(jax.random.PRNGKey(i), (5,))
    loss = train_step(x)

print("\nNote: Debug prints only at iterations 3, 6, 9")


=== Example 5: Conditional debugging ===

Running 10 training steps:
Iteration 3: x=[ 0.36057416  1.2849895  -0.73873436  1.1830745  -0.20641916]
Iteration 6: x=[-0.08437306  1.4110229   0.63048154 -1.3100973   1.3689315 ]
Iteration 9: x=[-0.55150557 -1.369112    2.7549403   0.5639917  -1.0112009 ]

Note: Debug prints only at iterations 3, 6, 9
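An alternative is to make the decision on the host: `jax.debug.callback` hands concrete values to a Python function, where an ordinary `if` works. A plain-JAX sketch with hypothetical names (`maybe_log`, `logged_train_step`); the step counter is passed in explicitly instead of living in a `ShortTermState`:

```python
import jax
import jax.numpy as jnp

logged = []


def maybe_log(step, loss, every=3):
    # Runs on the host with concrete values, so a plain Python `if` is fine here.
    step = int(step)
    if step % every == 0:
        logged.append((step, float(loss)))
        print(f"step {step}: loss={float(loss):.4f}")


@jax.jit
def logged_train_step(step, x):
    loss = jnp.sum(x ** 2)
    jax.debug.callback(maybe_log, step, loss)
    return loss


for i in range(1, 11):
    logged_train_step(jnp.int32(i), jax.random.normal(jax.random.PRNGKey(i), (5,)))
jax.effects_barrier()  # make sure all callbacks have run before reading `logged`
```

The trade-off: this callback fires on every step and only the logging is filtered, whereas an in-graph `jax.lax.cond` around `jax.debug.print` skips the callback entirely on non-logging steps.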


### 3.6 Advanced: Custom Debug Callbacks

In [None]:
# Example 6: Custom debugging with callbacks
print("\n=== Example 6: Custom debug callbacks ===")

import functools
import numpy as np

def custom_debug_callback(name, value):
    """Custom callback for detailed debugging (runs on the host with concrete values)."""
    value = np.asarray(value)  # use NumPy here instead of calling back into JAX
    print(f"[DEBUG {name}]:")
    print(f"  Shape: {value.shape}")
    print(f"  Dtype: {value.dtype}")
    print(f"  Min: {value.min():.4f}")
    print(f"  Max: {value.max():.4f}")
    print(f"  Mean: {value.mean():.4f}")
    print(f"  Std: {value.std():.4f}")

@brainstate.transform.jit
def compute_with_callback(x):
    # Strings are not valid JAX arguments, so bind each label with functools.partial
    jax.debug.callback(functools.partial(custom_debug_callback, "input"), x)
    
    y = jnp.tanh(x)
    jax.debug.callback(functools.partial(custom_debug_callback, "after_tanh"), y)
    
    z = y @ y.T
    jax.debug.callback(functools.partial(custom_debug_callback, "output"), z)
    
    return z

x = jax.random.normal(jax.random.PRNGKey(42), (5, 5))
print("\nExecuting with custom debug callbacks:")
result = compute_with_callback(x)
print(f"\nFinal result shape: {result.shape}")
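The best practices listed in the summary include checking for NaN/Inf in critical operations. One way to do that inside compiled code is to compute a finiteness flag and print only when it is set. A sketch with hypothetical helper names (`check_finite`, `checked_log`); for whole-program checks, `jax.config.update("jax_debug_nans", True)` is JAX's built-in alternative.

```python
import jax
import jax.numpy as jnp


def check_finite(name, x):
    """Print a warning at run time only when `x` holds NaN or Inf; also return the flag."""
    has_bad = jnp.logical_not(jnp.all(jnp.isfinite(x)))
    jax.lax.cond(
        has_bad,
        lambda v: jax.debug.print("WARNING: non-finite values in " + name + ": {v}", v=v),
        lambda v: None,
        x,
    )
    return has_bad


@jax.jit
def checked_log(x):
    y = jnp.log(x)
    return y, check_finite("log(x)", y)


y_ok, bad_ok = checked_log(jnp.array([1.0, 2.0, 3.0]))    # no warning printed
y_bad, bad_bad = checked_log(jnp.array([1.0, 0.0, -1.0])) # log(0) = -inf, log(-1) = nan
jax.effects_barrier()
print("finite input flagged:", bool(bad_ok), "| bad input flagged:", bool(bad_bad))
```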

## Summary

This tutorial covered three essential utilities for working with BrainState:

### 1. `checkpoint`: Memory-Efficient Gradients
- **Purpose**: Reduce memory usage during gradient computation
- **Mechanism**: Recompute activations during backward pass instead of storing them
- **Tradeoff**: roughly one extra forward pass in exchange for significant memory savings (O(√n) vs O(n) with segment checkpointing)
- **When to use** (rough heuristic): deep networks (>10 layers), wide networks (>256 hidden), long sequences
- **Advanced**: Custom policies control what to save vs. recompute
- **State-aware**: Works seamlessly with BrainState's `State` objects

### 2. `make_jaxpr` and `StatefulFunction`: Code Inspection
- **Purpose**: See the primitive operations JAX traces from your code before XLA compiles it
- **Jaxpr**: JAX's intermediate representation showing primitive operations and data flow
- **StatefulFunction**: Core mechanism enabling all BrainState transformations
  - Identifies state reads and writes
  - Compiles to Jaxpr with explicit state handling
  - Caches compilations for efficiency (LRU cache, 128 entries)
  - Manages state values automatically
- **Use cases**: Debugging compilation issues, understanding transformations, optimization analysis

### 3. `jax.debug.print`: Runtime Debugging
- **Purpose**: Debug JIT-compiled code during execution
- **Key features**:
  - Prints at runtime (not trace time)
  - Works inside `@jit`, `vmap`, `grad`, etc.
  - Supports formatted output and array inspection
- **Best practices**:
  - Use debug flags to enable/disable
  - Print statistics, not full arrays
  - Check for NaN/Inf in critical ops
  - Use callbacks for complex debugging
  - Disable in production

### Integration with BrainState
All three tools work with BrainState's `State` objects:
- `checkpoint` preserves state semantics during rematerialization
- `make_jaxpr` reveals state reads/writes in compiled code
- `jax.debug.print` can inspect state values during execution

These utilities are essential for developing, optimizing, and debugging complex stateful models in BrainState.