# Loops and Conditionals in BrainState

This tutorial covers state-aware control flow primitives in `brainstate.transform`. These APIs provide JAX-compatible loops and conditionals while safely handling `State` objects.

We'll explore three categories of control flow:

1. **Loop Transformations**: `scan`, `checkpointed_scan`, `for_loop`, `checkpointed_for_loop`
2. **While Loops**: `while_loop`, `bounded_while_loop`
3. **Conditional Control Flow**: `cond`, `switch`, `ifelse`

Each API is designed to work seamlessly with BrainState's state management system while maintaining JAX's functional programming paradigm.

## Imports and Setup

In [32]:
import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import (
    scan,
    checkpointed_scan,
    for_loop,
    checkpointed_for_loop,
    while_loop,
    bounded_while_loop,
    cond,
    switch,
    ifelse,
)

In [33]:
# Import ProgressBar
from brainstate.transform import ProgressBar

## 1. Loop Transformations

Loop transformations provide efficient iteration over sequences with state tracking. The loop body is traced once and lowered to a single JAX loop primitive rather than being unrolled in Python, so compilation cost does not grow with the sequence length.

### 1.1 `scan`: Stateful Scanning with Carry

`scan` is the fundamental loop primitive that:
- Iterates over a sequence along the leading axis
- Maintains a "carry" value that threads through iterations
- Collects outputs at each step
- Properly handles `State` objects

**Function signature:**
```python
scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | int | None = None,
) -> Tuple[Carry, Y]
```

**Parameters:**
- `f`: Function of type `(carry, x) -> (new_carry, output)`
- `init`: Initial carry value
- `xs`: Sequence to iterate over (along axis 0)
- `length`: Optional iteration count (inferred from `xs` if not provided)
- `reverse`: If True, iterate in reverse order
- `unroll`: Number of iterations to unroll (1=no unrolling, True=full unrolling)
- `pbar`: Optional progress bar
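
The semantics listed above can be summarized by a plain-Python reference loop. This is only a sketch of the documented behavior (carry threading, per-step outputs, `reverse`). `scan_reference` is a hypothetical helper, not part of BrainState, and it ignores `State` handling, `unroll`, and `pbar`.

```python
def scan_reference(f, init, xs, reverse=False):
    """Plain-Python sketch of scan semantics (hypothetical helper, not the BrainState API)."""
    items = list(xs)
    ys = [None] * len(items)
    carry = init
    # With reverse=True the items are visited back-to-front,
    # but each output is still stored at the index of its input.
    order = reversed(range(len(items))) if reverse else range(len(items))
    for i in order:
        carry, ys[i] = f(carry, items[i])
    return carry, ys


def cumsum_body(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


values = [1.0, 2.0, 3.0, 4.0, 5.0]
final_sum, forward = scan_reference(cumsum_body, 0.0, values)
_, backward = scan_reference(cumsum_body, 0.0, values, reverse=True)
print(final_sum, forward)  # 15.0 [1.0, 3.0, 6.0, 10.0, 15.0]
print(backward)            # [15.0, 14.0, 12.0, 9.0, 5.0]
```

The real `scan` stacks the per-step outputs into arrays instead of returning a Python list.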

In [34]:
# Example 1: Basic scan with carry
def cumsum_body(carry, x):
    """Accumulate sum and return both new carry and current sum."""
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
final_sum, cumulative_sums = scan(cumsum_body, init=0.0, xs=xs)

print("Input sequence:", xs)
print("Final sum:", final_sum)
print("Cumulative sums:", cumulative_sums)

Input sequence: [1. 2. 3. 4. 5.]
Final sum: 15.0
Cumulative sums: [ 1.  3.  6. 10. 15.]


In [35]:
# Example 2: Scan with stateful computation
class RunningStats(brainstate.nn.Module):
    """Maintain running mean and variance."""

    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.array(0))
        self.mean = brainstate.ShortTermState(jnp.array(0.0))
        self.m2 = brainstate.ShortTermState(jnp.array(0.0))  # sum of squared differences

    def update(self, x):
        """Update statistics with new value using Welford's algorithm."""
        self.count.value = self.count.value + 1
        delta = x - self.mean.value
        self.mean.value = self.mean.value + delta / self.count.value
        delta2 = x - self.mean.value
        self.m2.value = self.m2.value + delta * delta2

        variance = self.m2.value / self.count.value
        return {'mean': self.mean.value, 'var': variance}


stats = RunningStats()


def stats_body(carry, x):
    result = stats.update(x)
    return carry, result


data = jnp.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
_, history = scan(stats_body, init=None, xs=data)

print("Data:", data)
print("\nRunning mean:", history['mean'])
print("Running variance:", history['var'])
print("\nFinal statistics:")
print(f"  Count: {stats.count.value}")
print(f"  Mean: {stats.mean.value}")
print(f"  Variance: {stats.m2.value / stats.count.value}")

Data: [2. 4. 4. 4. 5. 5. 7. 9.]

Running mean: [2.        3.        3.3333333 3.5       3.8       4.        4.428571
 5.       ]
Running variance: [0.         1.         0.8888889  0.75       0.96000004 1.
 1.9591838  4.        ]

Final statistics:
  Count: 8
  Mean: 5.0
  Variance: 4.0
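
As a sanity check on the update rule in `RunningStats.update`, the same Welford recurrence can be written in plain Python and compared against the standard library's population variance. `welford` is an illustrative helper introduced here, not a BrainState function.

```python
import statistics


def welford(values):
    """Return the running (mean, population variance) after each value."""
    count, mean, m2 = 0, 0.0, 0.0
    history = []
    for x in values:
        count += 1
        delta = x - mean           # distance to the old mean
        mean += delta / count
        m2 += delta * (x - mean)   # uses both the old and the new mean
        history.append((mean, m2 / count))
    return history


samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
running = welford(samples)
final_mean, final_var = running[-1]
print(final_mean, final_var)
print(statistics.mean(samples), statistics.pvariance(samples))
```

Dividing `m2` by `count` gives the population variance; dividing by `count - 1` would give the sample variance instead.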


In [36]:
# Example 3: Reverse scan
def reverse_cumsum(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
_, forward_sums = scan(reverse_cumsum, 0.0, xs, reverse=False)
_, backward_sums = scan(reverse_cumsum, 0.0, xs, reverse=True)

print("Input:", xs)
print("Forward cumsum:", forward_sums)
print("Backward cumsum:", backward_sums)

Input: [1. 2. 3. 4. 5.]
Forward cumsum: [ 1.  3.  6. 10. 15.]
Backward cumsum: [15. 14. 12.  9.  5.]


#### Progress Bar with `scan`

The `pbar` parameter enables progress tracking during long-running scans. You can:
- Pass a `ProgressBar` instance for full control over display options
- Pass an integer for quick setup (updates every N iterations)
- Customize the description with static or dynamic messages
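
A dynamic description pairs a format template with a function that returns the template's fields. The sketch below shows one plausible way such a pair could be rendered with `str.format`. `render_description` is a hypothetical helper and makes no claim about how `ProgressBar` works internally; the keys `"i"` and `"y"` mirror those used in the dynamic example later in this section.

```python
def render_description(desc, data):
    """Render a static string or a (template, formatter) pair (illustrative only)."""
    if isinstance(desc, str):
        return desc                        # static description: used as-is
    template, formatter = desc
    fields = formatter(data)               # dict of values for the template
    return template.format(**fields)


def format_progress(data):
    return {"iter": data["i"], "loss": data["y"]}


static_text = render_description("Processing sequence", {"i": 0, "y": 0.0})
dynamic_text = render_description(
    ("Iter {iter:3d} | Loss: {loss:.4f}", format_progress),
    {"i": 15, "y": 0.25},
)
print(static_text)   # Processing sequence
print(dynamic_text)  # Iter  15 | Loss: 0.2500
```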

In [37]:
# Example 4: Progress bar with scan - simple integer freq
print("\n=== Simple progress bar (update every 20 iterations) ===")


def expensive_computation(carry, x):
    """Simulate expensive computation."""
    # Some computation
    result = carry + jnp.sin(x) * jnp.cos(x)
    return result, result


# Create long sequence
long_sequence = jnp.linspace(0, 10 * jnp.pi, 100)

# Use integer for simple progress bar (updates every 20 iterations)
final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=20)
print(f"\nFinal result: {final}")


=== Simple progress bar (update every 20 iterations) ===


  0%|          | 0/100 [00:00<?, ?it/s]


Final result: -4.0076361074170563e-07


In [38]:
# Example 5: Progress bar with custom ProgressBar instance
print("\n=== Custom progress bar with ProgressBar ===")

# Create ProgressBar with custom settings
pbar = ProgressBar(freq=10, desc="Processing sequence")

final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=pbar)
print(f"\nCompleted! Final result: {final}")


=== Custom progress bar with ProgressBar ===


  0%|          | 0/100 [00:00<?, ?it/s]


Completed! Final result: -4.0076361074170563e-07


In [39]:
# Example 6: Dynamic progress bar description based on loop state
print("\n=== Dynamic progress bar with loop state ===")


class OptimizationTracker(brainstate.nn.Module):
    """Track optimization progress."""

    def __init__(self):
        super().__init__()
        self.best_loss = brainstate.ShortTermState(jnp.array(float('inf')))

    def step(self, params, x):
        # Compute loss
        loss = jnp.sum((params - x) ** 2)
        # Update best
        self.best_loss.value = jnp.minimum(self.best_loss.value, loss)
        # Update parameters
        new_params = params - 0.1 * 2 * (params - x)
        return new_params, loss


tracker = OptimizationTracker()


def scan_body_with_tracking(params, x):
    return tracker.step(params, x)


# Define dynamic description
def format_progress(data):
    """Format progress with current loss and best loss."""
    return {
        "iter": data["i"],
        "loss": data["y"],
        "best": tracker.best_loss.value
    }


pbar_dynamic = ProgressBar(
    freq=15,
    desc=("Iter {iter:3d} | Loss: {loss:.4f} | Best: {best:.4f}", format_progress)
)

targets = jax.random.normal(jax.random.PRNGKey(42), (100,))
init_params = jnp.array(0.0)

final_params, loss_history = scan(
    scan_body_with_tracking,
    init=init_params,
    xs=targets,
    pbar=pbar_dynamic
)

print(f"\nOptimization completed!")
print(f"Final parameters: {final_params}")
print(f"Final loss: {loss_history[-1]}")
print(f"Best loss achieved: {tracker.best_loss.value}")


=== Dynamic progress bar with loop state ===


  0%|          | 0/100 [00:00<?, ?it/s]


Optimization completed!
Final parameters: -0.04794257506728172
Final loss: 1.41942298412323
Best loss achieved: 3.858334093820304e-05


### 1.2 `checkpointed_scan`: Memory-Efficient Scanning

`checkpointed_scan` is a memory-optimized version of `scan` that uses gradient checkpointing. It stores checkpoints only at regular intervals during the forward pass and recomputes the intermediate values during the backward pass when they are needed. This is crucial for:
- Long sequences where storing all intermediate activations is memory-prohibitive **during gradient computation**
- Trading computation time for memory during backpropagation

**Function signature:**
```python
checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Tuple[Carry, Y]
```

**Key parameter:**
- `base`: Checkpointing base (default=16). Smaller values save more memory but increase recomputation during backward pass. The implementation uses a hierarchical checkpointing scheme where `max_steps = base^k` for some k.

**Memory savings during gradient computation:**
- Regular `scan`: Stores **all** intermediate activations → O(n) memory for sequence length n
- `checkpointed_scan`: Stores only checkpoints → O(log_base(n)) memory
- During backward pass: Recomputes intermediate values between checkpoints as needed
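
The store-then-recompute idea can be illustrated with a single level of checkpoints in plain Python. This is a simplified sketch: it keeps every `every`-th carry rather than using the hierarchical scheme described above, and both function names are hypothetical.

```python
def forward_with_checkpoints(step, init, xs, every):
    """Run the loop, keeping only the carry at every `every`-th step."""
    checkpoints = {0: init}
    carry = init
    for i, x in enumerate(xs):
        carry = step(carry, x)
        if (i + 1) % every == 0:
            checkpoints[i + 1] = carry
    return carry, checkpoints


def recompute_carry(step, xs, checkpoints, t, every):
    """Recover the carry after t steps by replaying from the nearest earlier checkpoint."""
    start = (t // every) * every
    carry = checkpoints[start]
    for x in xs[start:t]:
        carry = step(carry, x)
    return carry


def decay_step(carry, x):
    return 0.5 * carry + x


inputs = [float(i) for i in range(1, 21)]
final, saved = forward_with_checkpoints(decay_step, 0.0, inputs, every=4)
print(len(saved), "stored carries instead of", len(inputs) + 1)
print(recompute_carry(decay_step, inputs, saved, 10, every=4))
```

A backward pass needs the intermediate carries in reverse order, so each segment between two checkpoints is replayed when the backward sweep reaches it.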

In [40]:
# Example: Memory-efficient scan for gradient computation
class RecurrentCell(brainstate.nn.Module):
    """Simple recurrent cell with hidden state."""

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.weight = brainstate.ParamState(jax.random.normal(
            jax.random.PRNGKey(0), (hidden_size, hidden_size)
        ))

    def step(self, hidden, x):
        """Single recurrent step."""
        new_hidden = jnp.tanh(jnp.dot(self.weight.value, hidden) + x)
        return new_hidden


# Create a cell and input sequence
cell = RecurrentCell(hidden_size=32)
sequence_length = 100
inputs = jax.random.normal(jax.random.PRNGKey(1), (sequence_length, 32))


def rnn_body(hidden, x):
    new_hidden = cell.step(hidden, x)
    return new_hidden, new_hidden


# Use checkpointed scan for memory efficiency during gradient computation
init_hidden = jnp.zeros(32)
final_hidden, all_hiddens = checkpointed_scan(
    rnn_body,
    init=init_hidden,
    xs=inputs,
    base=8  # Checkpoint every 8 steps
)

print(f"Sequence length: {sequence_length}")
print(f"Hidden size: {cell.hidden_size}")
print(f"Final hidden shape: {final_hidden.shape}")
print(f"All hiddens shape: {all_hiddens.shape}")
print(f"\nCheckpointing configuration:")
print(f"  Base: 8 (stores checkpoint every 8 steps)")
print(f"  Memory saved: Stores ~{sequence_length // 8} checkpoints instead of {sequence_length} activations")
print(f"  During backprop: Recomputes activations between checkpoints as needed")

Sequence length: 100
Hidden size: 32
Final hidden shape: (32,)
All hiddens shape: (100, 32)

Checkpointing configuration:
  Base: 8 (stores checkpoint every 8 steps)
  Memory saved: Stores ~12 checkpoints instead of 100 activations
  During backprop: Recomputes activations between checkpoints as needed


#### Progress Bar with `checkpointed_scan`

`checkpointed_scan` also supports progress bars, which is especially useful for very long sequences where you want to monitor progress.

In [41]:
# Example: Progress bar with checkpointed_scan
print("\n=== Checkpointed scan with progress bar ===")


class LongRunningComputation(brainstate.nn.Module):
    """Simulate a long-running computation."""

    def __init__(self):
        super().__init__()
        self.total_ops = brainstate.ShortTermState(jnp.array(0))

    def process(self, state, x):
        self.total_ops.value = self.total_ops.value + 1
        # Some computation
        new_state = state + jnp.tanh(x)
        output = jnp.sin(new_state) * jnp.cos(x)
        return new_state, output


long_comp = LongRunningComputation()


def body(state, x):
    return long_comp.process(state, x)


# Long sequence
very_long_sequence = jnp.linspace(0, 20 * jnp.pi, 500)

# Progress bar that updates every 50 iterations
pbar_checkpointed = ProgressBar(
    freq=50,
    desc="Checkpointed scan progress"
)

final_state, results = checkpointed_scan(
    body,
    init=0.0,
    xs=very_long_sequence,
    base=10,
    pbar=pbar_checkpointed
)

print(f"\nProcessed {long_comp.total_ops.value} operations")
print(f"Final state: {final_state}")
print(f"Results shape: {results.shape}")


=== Checkpointed scan with progress bar ===


  0%|          | 0/500 [00:00<?, ?it/s]


Processed 500 operations
Final state: 493.9846496582031
Results shape: (500,)


### 1.3 `for_loop`: Simplified Loop Without Carry

`for_loop` provides a simpler interface when you don't need an explicit carry value. It:
- Accepts variadic arguments that are sliced along axis 0
- Collects the return value of every iteration and stacks the results into the final output
- Internally uses `scan` with `None` as the carry

**Function signature:**
```python
for_loop(
    f: Callable[..., Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar | int] = None
) -> Y
```

**Key differences from scan:**
- Function signature is `(*xs) -> output` instead of `(carry, x) -> (carry, output)`
- No carry value to manage
- **Important**: The return value at **each iteration** is collected and stacked along axis 0 to form the final output. This means if your function returns a scalar at each step, `for_loop` returns a 1D array; if it returns a vector of shape `(d,)`, the output will be shape `(n, d)` where `n` is the number of iterations.
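
The relationship to `scan` can be sketched in plain Python: the inputs are zipped along their leading axis, the carry is `None`, and only the per-step outputs are kept. Both helpers below are hypothetical reference functions, not the BrainState implementation, and they return a list where the real API returns a stacked array.

```python
def scan_reference(f, init, xs):
    """Minimal scan sketch: thread a carry and collect one output per step."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def for_loop_reference(f, *xs):
    """for_loop expressed as a scan whose carry is always None."""
    def body(carry, args):
        return carry, f(*args)

    _, ys = scan_reference(body, None, zip(*xs))
    return ys


def compute(x, y, z):
    return x * y + z


outputs = for_loop_reference(
    compute,
    [1.0, 2.0, 3.0, 4.0],
    [2.0, 3.0, 4.0, 5.0],
    [0.5, 1.0, 1.5, 2.0],
)
print(outputs)  # [2.5, 7.0, 13.5, 22.0]
```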

In [42]:
# Example 1: Understanding output collection in for_loop
def compute(x, y, z):
    """Combine three inputs."""
    return x * y + z


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
ys = jnp.array([2.0, 3.0, 4.0, 5.0])
zs = jnp.array([0.5, 1.0, 1.5, 2.0])

# for_loop collects the output from EACH iteration
results = for_loop(compute, xs, ys, zs)

print("x:", xs)
print("y:", ys)
print("z:", zs)
print("x * y + z:", results)
print(f"\nNotice: for_loop collected {len(results)} outputs (one per iteration)")
print(f"Each element results[i] = xs[i] * ys[i] + zs[i]")
print(f"Output shape: {results.shape} (stacked along axis 0)")

x: [1. 2. 3. 4.]
y: [2. 3. 4. 5.]
z: [0.5 1.  1.5 2. ]
x * y + z: [ 2.5  7.  13.5 22. ]

Notice: for_loop collected 4 outputs (one per iteration)
Each element results[i] = xs[i] * ys[i] + zs[i]
Output shape: (4,) (stacked along axis 0)


In [43]:
# Example 2: Stateful for_loop
class Accumulator(brainstate.nn.Module):
    """Simple accumulator that tracks total and count."""

    def __init__(self):
        super().__init__()
        self.total = brainstate.ShortTermState(jnp.array(0.0))
        self.count = brainstate.ShortTermState(jnp.array(0))

    def process(self, x):
        self.total.value = self.total.value + x
        self.count.value = self.count.value + 1
        return self.total.value / self.count.value  # running average


acc = Accumulator()

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
running_averages = for_loop(acc.process, data)

print("Data:", data)
print("Running averages:", running_averages)
print(f"\nFinal state: total={acc.total.value}, count={acc.count.value}")
print(f"Final average: {acc.total.value / acc.count.value}")

Data: [1. 2. 3. 4. 5. 6.]
Running averages: [1.  1.5 2.  2.5 3.  3.5]

Final state: total=21.0, count=6
Final average: 3.5


#### Progress Bar with `for_loop`

`for_loop` also supports progress bars. This is particularly useful when processing large batches of data.

In [44]:
# Example 3: Progress bar with for_loop - simple case
print("\n=== For loop with progress bar ===")


class DataProcessor(brainstate.nn.Module):
    """Process data with progress tracking."""

    def __init__(self):
        super().__init__()
        self.processed_count = brainstate.ShortTermState(jnp.array(0))
        self.sum_val = brainstate.ShortTermState(jnp.array(0.0))

    def process_item(self, x):
        self.processed_count.value = self.processed_count.value + 1
        self.sum_val.value = self.sum_val.value + x
        # Simulate some processing
        result = jnp.exp(x) / (1 + jnp.exp(x))  # sigmoid
        return result


processor = DataProcessor()

# Create dataset
dataset = jax.random.normal(jax.random.PRNGKey(123), (200,))

# Use simple integer for progress updates
processed = for_loop(processor.process_item, dataset, pbar=25)

print(f"\nProcessed {processor.processed_count.value} items")
print(f"Sum of inputs: {processor.sum_val.value}")
print(f"Processed data shape: {processed.shape}")


=== For loop with progress bar ===


  0%|          | 0/200 [00:00<?, ?it/s]


Processed 200 items
Sum of inputs: 9.437446594238281
Processed data shape: (200,)


In [45]:
# Example 4: For loop with dynamic progress description
print("\n=== For loop with dynamic progress description ===")


class BatchProcessor(brainstate.nn.Module):
    """Process batches with statistics."""

    def __init__(self):
        super().__init__()
        self.mean = brainstate.ShortTermState(jnp.array(0.0))
        self.variance = brainstate.ShortTermState(jnp.array(0.0))
        self.count = brainstate.ShortTermState(jnp.array(0))

    def update(self, x):
        self.count.value = self.count.value + 1
        delta = x - self.mean.value
        self.mean.value = self.mean.value + delta / self.count.value
        # `variance` accumulates the sum of squared differences (M2); divide by count to get the variance
        self.variance.value = self.variance.value + delta * (x - self.mean.value)
        return x ** 2


batch_proc = BatchProcessor()


def format_batch_progress(data):
    """Show current statistics."""
    return {
        "n": data["i"],
        "mean": batch_proc.mean.value,
        "var": batch_proc.variance.value / jnp.maximum(batch_proc.count.value, 1)
    }


pbar_batch = ProgressBar(
    freq=20,
    desc=("Batch {n:3d} | Mean: {mean:+.3f} | Var: {var:.3f}", format_batch_progress)
)

batch_data = jax.random.normal(jax.random.PRNGKey(456), (150,)) * 2.0 + 1.0

squared = for_loop(batch_proc.update, batch_data, pbar=pbar_batch)

print(f"\nFinal statistics:")
print(f"  Mean: {batch_proc.mean.value}")
print(f"  Variance: {batch_proc.variance.value / batch_proc.count.value}")
print(f"  Count: {batch_proc.count.value}")


=== For loop with dynamic progress description ===


  0%|          | 0/150 [00:00<?, ?it/s]


Final statistics:
  Mean: 0.9223809242248535
  Variance: 3.488231897354126
  Count: 150


### 1.4 `checkpointed_for_loop`: Memory-Efficient For Loop

The checkpointed version of `for_loop` combines the simplicity of `for_loop` with the memory efficiency of checkpointing **during gradient computation**.

**Function signature:**
```python
checkpointed_for_loop(
    f: Callable[..., Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Y
```

**Memory efficiency during gradient computation:**
- Like `checkpointed_scan`, this variant significantly reduces memory usage during backpropagation
- Essential for training models with very long sequences where storing all intermediate activations would cause out-of-memory errors
- The `base` parameter controls the memory/computation tradeoff: smaller base = less memory but more recomputation

In [46]:
# Example: Processing long sequence with state
class ExpMovingAverage(brainstate.nn.Module):
    """Exponential moving average."""

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        self.ema = brainstate.ShortTermState(jnp.array(0.0))
        self.initialized = brainstate.ShortTermState(jnp.array(False))

    def update(self, x):
        # Initialize with first value
        self.ema.value = jnp.where(
            self.initialized.value,
            self.alpha * x + (1 - self.alpha) * self.ema.value,
            x
        )
        self.initialized.value = True
        return self.ema.value


ema = ExpMovingAverage(alpha=0.3)

# Generate noisy signal
signal = jnp.sin(jnp.linspace(0, 4 * jnp.pi, 200)) + 0.2 * brainstate.random.normal(size=(200,))

# Process with checkpointed for_loop
smoothed = checkpointed_for_loop(ema.update, signal, base=10)

print(f"Signal length: {len(signal)}")
print(f"Smoothed signal shape: {smoothed.shape}")
print(f"Original signal range: [{signal.min():.3f}, {signal.max():.3f}]")
print(f"Smoothed signal range: [{smoothed.min():.3f}, {smoothed.max():.3f}]")

Signal length: 200
Smoothed signal shape: (200,)
Original signal range: [-1.373, 1.384]
Smoothed signal range: [-1.167, 1.081]


#### Progress Bar with `checkpointed_for_loop`

`checkpointed_for_loop` supports progress bars to help track processing of very long sequences.

In [47]:
# Example: Progress bar with checkpointed_for_loop
print("\n=== Checkpointed for loop with progress bar ===")


class StreamProcessor(brainstate.nn.Module):
    """Process streaming data."""

    def __init__(self, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        self.running_avg = brainstate.ShortTermState(jnp.array(0.0))

    def process(self, x):
        # Update exponential moving average
        self.running_avg.value = (
            self.momentum * self.running_avg.value + (1 - self.momentum) * x
        )
        return self.running_avg.value


stream_proc = StreamProcessor(momentum=0.95)

# Generate long data stream
data_stream = jax.random.normal(jax.random.PRNGKey(789), (1000,))

# Progress bar with count parameter (updates exactly 10 times)
pbar_stream = ProgressBar(
    count=10,
    desc="Processing data stream"
)

smoothed_stream = checkpointed_for_loop(
    stream_proc.process,
    data_stream,
    base=20,
    pbar=pbar_stream
)

print(f"\nStream processed!")
print(f"Final running average: {stream_proc.running_avg.value}")
print(f"Smoothed stream shape: {smoothed_stream.shape}")
print(f"First 5 values: {smoothed_stream[:5]}")
print(f"Last 5 values: {smoothed_stream[-5:]}")


=== Checkpointed for loop with progress bar ===


  0%|          | 0/1000 [00:00<?, ?it/s]


Stream processed!
Final running average: 0.050981856882572174
Smoothed stream shape: (1000,)
First 5 values: [ 0.03755957 -0.01195641  0.0011875   0.03787947  0.04933156]
Last 5 values: [ 0.10296391  0.07546234 -0.00500105  0.06068815  0.05098186]


### 1.5 Comparison: `scan` vs `for_loop`

When to use each:

**Use `scan` when:**
- You need to thread a carry value through iterations
- Implementing recurrent patterns (RNNs, state machines)
- You want explicit control over the accumulator

**Use `for_loop` when:**
- No carry value is needed
- Processing independent items with side effects (state updates)
- Simpler, more Pythonic syntax is preferred

In [48]:
# Comparison example: Computing powers of 2

# Using scan: carry explicitly tracks the power
def scan_version(n):
    def body(carry, _):
        return carry * 2, carry

    _, powers = scan(body, init=1, xs=jnp.arange(n))
    return powers


# Using for_loop with state: state tracks the power
class PowerTracker(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.current = brainstate.ShortTermState(jnp.array(1))

    def next_power(self, _):
        result = self.current.value
        self.current.value = self.current.value * 2
        return result


def forloop_version(n):
    tracker = PowerTracker()
    return for_loop(tracker.next_power, jnp.arange(n))


n = 10
print(f"Powers of 2 (first {n} values):")
print("scan result:    ", scan_version(n))
print("for_loop result:", forloop_version(n))

Powers of 2 (first 10 values):
scan result:     [  1   2   4   8  16  32  64 128 256 512]
for_loop result: [  1   2   4   8  16  32  64 128 256 512]


## 2. While Loops

While loops provide conditional iteration where the number of iterations is not known in advance.

### 2.1 `while_loop`: Dynamic Conditional Iteration

`while_loop` executes a body function repeatedly while a condition remains true. This is the stateful version of `jax.lax.while_loop`.

**Function signature:**
```python
while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T
```

**Parameters:**
- `cond_fun`: Function that returns True to continue looping
- `body_fun`: Function that updates the loop value
- `init_val`: Initial loop value

**Important constraints:**
- `cond_fun` must not modify any `State`; it may only read them
- Loop value must maintain fixed shape and dtype
- Not reverse-mode differentiable (use `bounded_while_loop` instead)
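
In plain Python, the control flow of `while_loop` corresponds to the loop below. `while_loop_reference` is a hypothetical helper that sketches the semantics only; the real function is traced by JAX, which is where the fixed shape and dtype requirement comes from.

```python
def while_loop_reference(cond_fun, body_fun, init_val):
    """Plain-Python sketch of while_loop semantics (not the BrainState API)."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


# The loop value can be any fixed structure, here a (n, steps) tuple.
def collatz_cond(state):
    n, _ = state
    return n > 1


def collatz_body(state):
    n, steps = state
    next_n = n // 2 if n % 2 == 0 else 3 * n + 1
    return next_n, steps + 1


power = while_loop_reference(lambda v: v < 1000, lambda v: v * 2, 1)
final_n, total_steps = while_loop_reference(collatz_cond, collatz_body, (27, 0))
print(power)                 # 1024
print(final_n, total_steps)  # 1 111
```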

In [49]:
# Example 1: Simple while loop - find first power of 2 above threshold
def find_power_of_2_above(threshold):
    def cond_fn(val):
        return val < threshold

    def body(val):
        return val * 2

    return while_loop(cond_fn, body, init_val=1)


threshold = 1000
result = find_power_of_2_above(threshold)
print(f"First power of 2 above {threshold}: {result}")

First power of 2 above 1000: 1024


In [50]:
# Example 2: Stateful while loop - iterative refinement
class IterativeRefiner(brainstate.nn.Module):
    """Iteratively refine an estimate using Newton's method."""

    def __init__(self, target):
        super().__init__()
        self.target = target
        self.iterations = brainstate.ShortTermState(jnp.array(0))

    def refine(self, x):
        """Newton's method step for computing sqrt(target)."""
        self.iterations.value = self.iterations.value + 1
        return 0.5 * (x + self.target / x)


# Compute square root of 2 using Newton's method
refiner = IterativeRefiner(target=2.0)


def cond_f(x):
    # Continue until error is small enough
    return jnp.abs(x * x - refiner.target) > 1e-6


def body(x):
    return refiner.refine(x)


result = while_loop(cond_f, body, init_val=1.0)

print(f"Computing sqrt(2)...")
print(f"Result: {result}")
print(f"Actual sqrt(2): {jnp.sqrt(2.0)}")
print(f"Error: {jnp.abs(result - jnp.sqrt(2.0))}")
print(f"Iterations: {refiner.iterations.value}")

Computing sqrt(2)...
Result: 1.4142135381698608
Actual sqrt(2): 1.4142135381698608
Error: 0.0
Iterations: 4


In [51]:
# Example 3: Complex loop value (pytree)
class Collatz(brainstate.nn.Module):
    """Track Collatz sequence statistics."""

    def __init__(self):
        super().__init__()
        self.max_value = brainstate.ShortTermState(jnp.array(0))

    def step(self, n):
        self.max_value.value = jnp.maximum(self.max_value.value, n)
        return jnp.where(n % 2 == 0, n // 2, 3 * n + 1)


collatz = Collatz()


def collatz_cond(state):
    n, steps = state
    return n > 1


def collatz_body(state):
    n, steps = state
    return collatz.step(n), steps + 1


start_value = 27
final_n, total_steps = while_loop(
    collatz_cond,
    collatz_body,
    init_val=(start_value, 0)
)

print(f"Collatz sequence starting from {start_value}:")
print(f"  Converged to: {final_n}")
print(f"  Steps taken: {total_steps}")
print(f"  Maximum value reached: {collatz.max_value.value}")

Collatz sequence starting from 27:
  Converged to: 1
  Steps taken: 111
  Maximum value reached: 9232


### 2.2 `bounded_while_loop`: While Loop with Maximum Steps

`bounded_while_loop` adds a maximum iteration limit to while loops. This is important for:
- Preventing infinite loops
- Enabling reverse-mode differentiation (unlike `while_loop`)
- Making compilation cost predictable, since it depends on `max_steps` and `base` rather than on the data

**Function signature:**
```python
bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
)
```

**Key parameters:**
- `max_steps`: Maximum number of iterations before termination
- `base`: Compilation/runtime tradeoff (default=16)
  - Larger base = faster compilation, slightly slower runtime
  - Smaller base = slower compilation, faster runtime
  - Compile time scales with `math.ceil(math.log(max_steps, base))`
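
The intended contract of `max_steps` and the nesting depth implied by `base` can be sketched in plain Python. This is an assumption-laden illustration: `bounded_while_loop_reference` and `nesting_depth` are hypothetical helpers that describe the semantics stated above, not the BrainState implementation.

```python
import math


def bounded_while_loop_reference(cond_fun, body_fun, init_val, *, max_steps):
    """Sketch: behave like a while loop, but stop after at most max_steps iterations."""
    val = init_val
    for _ in range(max_steps):
        if not cond_fun(val):
            break
        val = body_fun(val)
    return val


def nesting_depth(max_steps, base):
    """Number of nested levels, following the formula quoted above."""
    return math.ceil(math.log(max_steps, base))


def below_1000(v):
    return v < 1000


def double(v):
    return v * 2


unbounded_enough = bounded_while_loop_reference(below_1000, double, 1, max_steps=100)
cut_short = bounded_while_loop_reference(below_1000, double, 1, max_steps=5)
depths = {base: nesting_depth(100, base) for base in (2, 8, 16)}
print(unbounded_enough)  # 1024
print(cut_short)         # 32
print(depths)            # {2: 7, 8: 3, 16: 2}
```

When the bound is reached first, the loop returns the current value even though the condition is still true, so `max_steps` should be chosen with some margin.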

In [52]:
# Example 1: Gradient descent with bounded iterations
class GradientDescent(brainstate.nn.Module):
    """Simple gradient descent optimizer."""

    def __init__(self, learning_rate=0.1):
        super().__init__()
        self.lr = learning_rate
        self.steps = brainstate.ShortTermState(jnp.array(0))

    def step(self, x):
        # Gradient of f(x) = (x - 3)^2
        grad = 2 * (x - 3.0)
        self.steps.value = self.steps.value + 1
        return x - self.lr * grad


optimizer = GradientDescent(learning_rate=0.1)


def not_converged(x):
    # Loop condition: keep iterating while x is still far from the optimum
    return jnp.abs(x - 3.0) > 1e-4


result = bounded_while_loop(
    not_converged,
    optimizer.step,
    init_val=0.0,
    max_steps=100
)

print(f"Minimizing f(x) = (x - 3)^2")
print(f"Starting from x = 0.0")
print(f"Final x: {result}")
print(f"Target x: 3.0")
print(f"Error: {jnp.abs(result - 3.0)}")
print(f"Iterations used: {optimizer.steps.value} / 100")

Minimizing f(x) = (x - 3)^2
Starting from x = 0.0
Final x: 3.001088857650757
Target x: 3.0
Error: 0.001088857650756836
Iterations used: 252 / 100


In [53]:
# Example 2: Comparing different base values
class Counter(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.array(0))

    def increment(self, x):
        self.count.value = self.count.value + 1
        return x + 1


def compare_base_values():
    max_steps = 100

    for base in [2, 8, 16]:
        counter = Counter()

        result = bounded_while_loop(
            lambda x: x < 50,
            counter.increment,
            init_val=0,
            max_steps=max_steps,
            base=base
        )

        recursion_depth = jnp.ceil(jnp.log(max_steps) / jnp.log(base))
        print(f"Base {base:2d}: result={result}, iterations={counter.count.value}, "
              f"recursion_depth≈{int(recursion_depth)}")


compare_base_values()

Base  2: result=206, iterations=50, recursion_depth≈7
Base  8: result=3746, iterations=50, recursion_depth≈3
Base 16: result=3346, iterations=50, recursion_depth≈2


In [54]:
# Example 3: Differentiable bounded_while_loop
def smooth_threshold(x, threshold=5.0, lr=0.5, max_steps=20):
    """Smoothly approach threshold using gradient descent."""

    def cond_fn(val):
        return val < threshold - 0.1

    def body(val):
        # Gradient of loss = (val - threshold)^2
        grad = 2 * (val - threshold)
        return val - lr * grad

    return bounded_while_loop(cond_fn, body, x, max_steps=max_steps)


# Compute gradient
x = 0.0
value, grad = jax.value_and_grad(smooth_threshold)(x)

print(f"Input: {x}")
print(f"Output: {value}")
print(f"Gradient: {grad}")
print(f"\nbounded_while_loop is differentiable!")

Input: 0.0
Output: 4085.0
Gradient: 0.0

bounded_while_loop is differentiable!


### 2.3 Comparison: `while_loop` vs `bounded_while_loop`

**Use `while_loop` when:**
- The number of iterations is truly unknown
- You are not computing reverse-mode gradients
- You want standard JAX while-loop semantics

**Use `bounded_while_loop` when:**
- You need reverse-mode gradients
- You want protection against infinite loops
- You can provide a reasonable upper bound on the number of iterations
- You need predictable compilation characteristics

## 3. Conditional Control Flow

Conditional primitives enable branching logic that compiles efficiently and handles state properly.

### 3.1 `cond`: Binary Conditional (If/Else)

`cond` selectively executes one of two branches based on a boolean predicate. This is the stateful version of `jax.lax.cond`.

**Function signature:**
```python
cond(
    pred,
    true_fun: Callable,
    false_fun: Callable,
    *operands
)
```

**Parameters:**
- `pred`: Boolean scalar (or numeric, where non-zero is True)
- `true_fun`: Function called when `pred` is True
- `false_fun`: Function called when `pred` is False
- `*operands`: Arguments passed to the selected function

**Key properties:**
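- Only the selected branch is executed at runtime, although both branches are traced; as with `jax.lax.cond`, a batched predicate under `vmap` turns the conditional into a select that evaluates both branches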
- Both branches must return the same pytree structure
- State modifications in branches are properly tracked
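
The "same pytree structure" requirement can be illustrated with the underlying `jax.lax.cond`, which the text above names as the stateless counterpart. This is a minimal sketch in plain JAX; the branch functions are hypothetical, and BrainState's `cond` is assumed, not verified here, to report a mismatch in a similar way.

```python
import jax
import jax.numpy as jnp


def scalar_branch(x):
    return x * 2.0


def pair_branch(x):
    return x, x  # a tuple: different pytree structure from scalar_branch


x = jnp.float32(3.0)

# Mismatched output structures are rejected while the branches are traced.
try:
    jax.lax.cond(x > 0, scalar_branch, pair_branch, x)
    mismatch_error = None
except TypeError as err:
    mismatch_error = err

print("Mismatch rejected:", mismatch_error is not None)


def padded_branch(x):
    # Same structure as pair_branch: a tuple of two float32 scalars.
    return x * 2.0, jnp.zeros_like(x)


first, second = jax.lax.cond(x > 0, padded_branch, pair_branch, x)
print(first, second)
```

A common fix, as in `padded_branch`, is to pad the smaller result with placeholder values so both branches return the same structure.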

In [55]:
# Example 1: Simple conditional
def positive_branch(x):
    return x ** 2


def negative_branch(x):
    return -x


for value in [-5.0, 3.0, 0.0]:
    result = cond(value >= 0, positive_branch, negative_branch, value)
    print(f"cond({value} >= 0): {result}")

cond(-5.0 >= 0): 5.0
cond(3.0 >= 0): 9.0
cond(0.0 >= 0): 0.0


In [56]:
# Example 2: Stateful conditional
class BranchTracker(brainstate.nn.Module):
    """Track which branches were taken."""

    def __init__(self):
        super().__init__()
        self.true_count = brainstate.ShortTermState(jnp.array(0))
        self.false_count = brainstate.ShortTermState(jnp.array(0))

    def true_branch(self, x):
        self.true_count.value = self.true_count.value + 1
        return x * 2

    def false_branch(self, x):
        self.false_count.value = self.false_count.value + 1
        return x / 2


tracker = BranchTracker()

# Test multiple values
values = jnp.array([1.0, -2.0, 3.0, -4.0, 5.0])
results = []

for v in values:
    result = cond(v > 0, tracker.true_branch, tracker.false_branch, v)
    results.append(result)

print("Values:", values)
print("Results:", jnp.array(results))
print("\nBranch statistics:")
print(f"  True branch taken: {tracker.true_count.value} times")
print(f"  False branch taken: {tracker.false_count.value} times")

Values: [ 1. -2.  3. -4.  5.]
Results: [ 2. -1.  6. -2. 10.]

Branch statistics:
  True branch taken: 3 times
  False branch taken: 2 times


In [57]:
# Example 3: Nested conditionals
class Classifier(brainstate.nn.Module):
    """Classify numbers into categories."""

    def __init__(self):
        super().__init__()
        self.classification_counts = brainstate.ShortTermState({
            'large_positive': jnp.array(0),
            'small_positive': jnp.array(0),
            'small_negative': jnp.array(0),
            'large_negative': jnp.array(0),
        })

    def classify_positive(self, x):
        def large(x):
            counts = self.classification_counts.value
            counts['large_positive'] = counts['large_positive'] + 1
            self.classification_counts.value = counts
            return 'large_positive'

        def small(x):
            counts = self.classification_counts.value
            counts['small_positive'] = counts['small_positive'] + 1
            self.classification_counts.value = counts
            return 'small_positive'

        return cond(x > 5.0, large, small, x)

    def classify_negative(self, x):
        def small(x):
            counts = self.classification_counts.value
            counts['small_negative'] = counts['small_negative'] + 1
            self.classification_counts.value = counts
            return 'small_negative'

        def large(x):
            counts = self.classification_counts.value
            counts['large_negative'] = counts['large_negative'] + 1
            self.classification_counts.value = counts
            return 'large_negative'

        return cond(x > -5.0, small, large, x)

    def classify(self, x):
        return cond(x >= 0, self.classify_positive, self.classify_negative, x)


classifier = Classifier()

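# The branches return Python strings, which are not valid JAX outputs,
# so this example only runs with JIT disabled.
with jax.disable_jit():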
    test_values = jnp.array([10.0, 2.0, -3.0, -8.0, 7.0, -1.0])
    classifications = [classifier.classify(v) for v in test_values]

    print("Values:", test_values)
    print("Classifications:", classifications)
    print("\nCategory counts:")
    for category, count in classifier.classification_counts.value.items():
        print(f"  {category}: {count}")

Values: [10.  2. -3. -8.  7. -1.]
Classifications: ['large_positive', 'small_positive', 'small_negative', 'large_negative', 'large_positive', 'small_negative']

Category counts:
  large_positive: 2
  small_positive: 1
  small_negative: 2
  large_negative: 1


### 3.2 `switch`: Multi-Way Branching

`switch` generalizes `cond` to multiple branches, similar to a switch/case statement.

**Function signature:**
```python
switch(
    index,
    branches: Sequence[Callable],
    *operands
)
```

**Parameters:**
- `index`: Integer scalar selecting which branch to execute
- `branches`: Sequence of callables (at least 1)
- `*operands`: Arguments passed to the selected branch

**Index handling:**
- Out-of-bounds indices are clamped to `[0, len(branches) - 1]`
- Negative indices are clamped to 0
- Indices >= len(branches) are clamped to len(branches) - 1
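
The same clamping rule can be observed with plain `jax.lax.switch`, which also clamps its index into bounds. The sketch below is JAX-only; `dispatch` is a hypothetical helper, and it says nothing about how BrainState's `switch` handles `State` objects.

```python
import jax
import jax.numpy as jnp

branches = [
    lambda x: x + 1.0,
    lambda x: x * 2.0,
    lambda x: x ** 2,
    lambda x: -x,
]


@jax.jit
def dispatch(index, x):
    # `index` is a traced value here, so the clamping happens in compiled code.
    return jax.lax.switch(index, branches, x)


x = jnp.float32(5.0)
too_large = dispatch(7, x)   # clamped to the last branch (-x)
negative = dispatch(-3, x)   # clamped to the first branch (x + 1)
print(too_large, negative)
```

Clamping avoids an out-of-bounds error inside compiled code, but it also hides indexing bugs, so validate indices beforehand when an invalid index should be treated as an error.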

In [58]:
# Example 1: Simple multi-way branch
def operation_0(x):
    return x + 1


def operation_1(x):
    return x * 2


def operation_2(x):
    return x ** 2


def operation_3(x):
    return -x


operations = [operation_0, operation_1, operation_2, operation_3]

x = 5.0
for i in range(len(operations)):
    result = switch(i, operations, x)
    print(f"Operation {i} on {x}: {result}")

# Test clamping
print(f"\nOut of bounds (index={len(operations)}): {switch(len(operations), operations, x)}")
print(f"Out of bounds (index={-1}): {switch(-1, operations, x)}")

Operation 0 on 5.0: 6.0
Operation 1 on 5.0: 10.0
Operation 2 on 5.0: 25.0
Operation 3 on 5.0: -5.0

Out of bounds (index=4): -5.0
Out of bounds (index=-1): 6.0


In [59]:
# Example 2: Stateful switch - activation function selector
class ActivationSelector(brainstate.nn.Module):
    """Select and apply different activation functions."""

    def __init__(self):
        super().__init__()
        self.usage_counts = brainstate.ShortTermState(jnp.zeros(5, dtype=jnp.int32))

    def _track_usage(self, index):
        counts = self.usage_counts.value
        counts = counts.at[index].add(1)
        self.usage_counts.value = counts

    def relu(self, x):
        self._track_usage(0)
        return jnp.maximum(0, x)

    def sigmoid(self, x):
        self._track_usage(1)
        return 1 / (1 + jnp.exp(-x))

    def tanh(self, x):
        self._track_usage(2)
        return jnp.tanh(x)

    def softplus(self, x):
        self._track_usage(3)
        return jnp.log(1 + jnp.exp(x))

    def identity(self, x):
        self._track_usage(4)
        return x

    def apply(self, index, x):
        return switch(
            index,
            [self.relu, self.sigmoid, self.tanh, self.softplus, self.identity],
            x
        )


selector = ActivationSelector()
activation_names = ['ReLU', 'Sigmoid', 'Tanh', 'Softplus', 'Identity']

# Test all activations
test_input = 2.0
print(f"Input: {test_input}\n")

for i in range(len(activation_names)):
    result = selector.apply(i, test_input)
    print(f"{activation_names[i]:10s}: {result:.4f}")

print(f"\nUsage counts: {selector.usage_counts.value}")

Input: 2.0

ReLU      : 2.0000
Sigmoid   : 0.8808
Tanh      : 0.9640
Softplus  : 2.1269
Identity  : 2.0000

Usage counts: [1 1 1 1 1]


In [60]:
# Example 3: Dynamic policy selection
class PolicySelector(brainstate.nn.Module):
    """Select different action policies based on state."""

    def __init__(self):
        super().__init__()
        self.total_reward = brainstate.ShortTermState(jnp.array(0.0))

    def aggressive_policy(self, state):
        action = state * 2.0
        reward = jnp.abs(action) * 0.5
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'aggressive'}

    def conservative_policy(self, state):
        action = state * 0.5
        reward = jnp.abs(action) * 1.0
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'conservative'}

    def random_policy(self, state):
        action = state * 1.0
        reward = jnp.abs(action) * 0.3
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'random'}

    def select_and_act(self, policy_index, state):
        return switch(
            policy_index,
            [self.aggressive_policy, self.conservative_policy, self.random_policy],
            state
        )


policy_selector = PolicySelector()

# Simulate decision-making over time
states = jnp.array([1.0, -0.5, 2.0, -1.5, 0.8])
policies = jnp.array([0, 1, 0, 1, 2], dtype=jnp.int32)  # policy choices

with jax.disable_jit():
    print("Simulation results:")
    for i, (policy_idx, state) in enumerate(zip(policies, states)):
        result = policy_selector.select_and_act(policy_idx, state)
        print(f"Step {i}: state={state:5.1f}, policy={result['policy']:12s}, "
              f"action={result['action']:5.2f}, reward={result['reward']:.2f}")

    print(f"\nTotal reward: {policy_selector.total_reward.value:.2f}")

Simulation results:
Step 0: state=  1.0, policy=aggressive  , action= 2.00, reward=1.00
Step 1: state= -0.5, policy=conservative, action=-0.25, reward=0.25
Step 2: state=  2.0, policy=aggressive  , action= 4.00, reward=2.00
Step 3: state= -1.5, policy=conservative, action=-0.75, reward=0.75
Step 4: state=  0.8, policy=random      , action= 0.80, reward=0.24

Total reward: 4.24


### 3.3 `ifelse`: Multi-Condition If/Elif/Else

`ifelse` provides a high-level interface for multi-condition branching, similar to Python's if/elif/else.

**Function signature:**
```python
ifelse(
    conditions,
    branches,
    *operands,
    check_cond: bool = True
)
```

**Parameters:**
- `conditions`: Sequence of boolean predicates (should be mutually exclusive)
- `branches`: Sequence of callables (same length as conditions)
- `*operands`: Arguments passed to the selected branch
- `check_cond`: If True, verify exactly one condition is True

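**Common pattern:**
Make the last condition `True` to create a default/else branch. These conditions overlap (for `x = 15` all three hold), so by the `check_cond` description above this pattern likely requires `check_cond=False`: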
```python
ifelse(
    [x > 10, x > 5, True],  # last condition is always True
    [large_fn, medium_fn, small_fn],
    x
)
```
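
One plausible way to reduce if/elif/else to an indexed branch is to pick the index of the first true condition and hand it to a switch. The sketch below does this with plain JAX. `first_true_index`, `ifelse_sketch`, and `size_class` are hypothetical names, the sketch skips the `check_cond` validation, and it is not claimed to match BrainState's implementation.

```python
import jax
import jax.numpy as jnp


def first_true_index(conditions):
    """Index of the first True condition (0 if none is True)."""
    flags = jnp.stack([jnp.asarray(c, dtype=jnp.int32) for c in conditions])
    # argmax returns the first position of the maximum value.
    return jnp.argmax(flags)


def ifelse_sketch(conditions, branches, *operands):
    return jax.lax.switch(first_true_index(conditions), branches, *operands)


@jax.jit
def size_class(x):
    # Encoded result: 0 = large, 1 = medium, 2 = small
    return ifelse_sketch(
        [x > 10, x > 5, True],
        [lambda v: jnp.int32(0), lambda v: jnp.int32(1), lambda v: jnp.int32(2)],
        x,
    )


classes = [int(size_class(v)) for v in (15.0, 7.0, 2.0)]
print(classes)
```

With first-true semantics the overlapping conditions behave like Python's `if`/`elif`/`else`: earlier conditions take priority and the trailing `True` catches everything else.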

In [61]:
# Example 1: Simple if/elif/else
def classify_number(x):
    def large():
        return "large"

    def medium():
        return "medium"

    def small():
        return "small"

    return ifelse(
        [x > 10, jnp.logical_and(x > 5, x <= 10), x <= 5],  # mutually exclusive; the last one plays the 'else' role
        [large, medium, small]
    )


with jax.disable_jit():
    for value in [15.0, 7.0, 2.0, 10.5, 5.0]:
        category = classify_number(value)
        print(f"{value:5.1f} -> {category}")

 15.0 -> large
  7.0 -> medium
  2.0 -> small
 10.5 -> large
  5.0 -> small


In [62]:
# Example 2: Stateful grade calculator
class GradeCalculator(brainstate.nn.Module):
    """Calculate letter grades and track statistics."""

    def __init__(self):
        super().__init__()
        self.grade_counts = brainstate.ShortTermState({
            'A': jnp.array(0),
            'B': jnp.array(0),
            'C': jnp.array(0),
            'D': jnp.array(0),
            'F': jnp.array(0),
        })

    def _record_grade(self, letter):
        counts = self.grade_counts.value
        counts[letter] = counts[letter] + 1
        self.grade_counts.value = counts

    def grade_A(self):
        return self._record_grade('A')

    def grade_B(self):
        return self._record_grade('B')

    def grade_C(self):
        return self._record_grade('C')

    def grade_D(self):
        return self._record_grade('D')

    def grade_F(self):
        return self._record_grade('F')

    def calculate_grade(self, score):
        return ifelse(
            [
                score >= 90,
                jnp.logical_and(score >= 80, score < 90),
                jnp.logical_and(score >= 70, score < 80),
                jnp.logical_and(score >= 60, score < 70),
                score < 60
            ],
            [
                self.grade_A,
                self.grade_B,
                self.grade_C,
                self.grade_D,
                self.grade_F
            ]
        )


calculator = GradeCalculator()

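# Process student scores. Each branch returns None, so `grades` holds no
# letters; the result lives in the `grade_counts` state.
scores = jnp.array([95, 87, 76, 82, 59, 91, 68, 45, 88, 93])
grades = [calculator.calculate_grade(score) for score in scores]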

print("\nGrade distribution:")
for letter, count in calculator.grade_counts.value.items():
    print(f"  {letter}: {'*' * int(count)}")


Grade distribution:
  A: ***
  B: ***
  C: *
  D: *
  F: **


## Summary

This tutorial covered all control flow primitives in `brainstate.transform`:

### Loop Transformations
- **`scan`**: Fundamental loop with carry and outputs
  - Use for: Recurrent patterns, accumulation, sequential processing
  - Collects outputs at each iteration
  - Key params: `reverse`, `unroll`, `pbar`
- **`checkpointed_scan`**: Memory-efficient scan with gradient checkpointing
  - Use for: Long sequences, memory constraints **during gradient computation**
  - **Key benefit**: Stores only O(log_base(n)) checkpoints instead of O(n) activations during backpropagation
  - Trades computation (recomputation during backward pass) for memory savings
  - Key param: `base` (checkpointing granularity)
- **`for_loop`**: Simplified loop without explicit carry
  - Use for: Simple iteration, state updates
  - **Important**: Return value at **each timestep is saved and stacked** into the final output array
  - Variadic inputs, no carry management
  - Output shape: stacks results along axis 0 (e.g., scalar→1D, vector→2D)
- **`checkpointed_for_loop`**: Memory-efficient for loop with gradient checkpointing
  - Combines simplicity of for_loop with memory efficiency **during gradient computation**
  - Essential for training with very long sequences
  - Same memory benefits as `checkpointed_scan`
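
The recomputation trade-off can be sketched in plain JAX by wrapping a scan body in `jax.checkpoint`. This is per-step rematerialization, a simpler technique than the recursive checkpointing scheme summarized above, and it is not claimed to be what `checkpointed_scan` does internally; `step` and `final_state` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    new_carry = jnp.tanh(0.9 * carry + x)
    return new_carry, new_carry


def final_state(xs, body):
    final, _ = jax.lax.scan(body, jnp.float32(0.0), xs)
    return final


xs = jnp.linspace(-1.0, 1.0, 64)

grad_plain = jax.grad(lambda s: final_state(s, step))(xs)
# jax.checkpoint marks the body for rematerialization: its intermediates are
# recomputed in the backward pass instead of being stored.
grad_remat = jax.grad(lambda s: final_state(s, jax.checkpoint(step)))(xs)

print("Gradients match:", bool(jnp.allclose(grad_plain, grad_remat, atol=1e-6)))
```

Both versions compute the same gradient; only the memory and compute profile of the backward pass differs.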

### While Loops
- **`while_loop`**: Dynamic iteration with condition
  - Use for: Unknown iteration count, no gradients needed
  - Constraint: `cond_fun` must be read-only
- **`bounded_while_loop`**: While loop with maximum steps
  - Use for: Gradients, safety, predictable compilation
  - Key params: `max_steps`, `base`

### Conditional Control Flow
- **`cond`**: Binary conditional (if/else)
  - Use for: Two-way decisions
  - Lazy evaluation, state-safe
- **`switch`**: Multi-way branching (switch/case)
  - Use for: Multiple branches with integer index
  - Index clamping for safety
- **`ifelse`**: Multi-condition branching (if/elif/else)
  - Use for: Complex conditions, default branches
  - Use `True` for else branch

### Key Principles
1. **State Safety**: All APIs properly track state reads and writes
2. **Lazy Evaluation**: Conditionals only execute selected branches
3. **JAX Compatibility**: Compile to efficient JAX primitives
4. **Output Collection**: `for_loop` and `scan` collect outputs at each iteration into the final result
5. **Memory Efficiency**: Checkpointed variants save memory **during gradient computation** by storing only checkpoints and recomputing intermediate activations during backpropagation
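6. **Differentiability**: Most APIs support gradients; `while_loop` is the exception, so use `bounded_while_loop` when a data-dependent loop must be differentiated. Checkpointed variants keep memory use manageable when differentiating through long sequences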

These primitives enable complex control flow while maintaining BrainState's stateful programming model and JAX's performance benefits.