# Tutorial 12: Control Flow - Loops and Conditions

In this tutorial, we'll explore control flow primitives in BrainState for writing efficient loops and conditional operations that work with JIT compilation.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Use `scan` for efficient sequential computations
- Implement loops with `fori_loop` and `while_loop`
- Write conditional logic with `cond` and `switch`
- Understand when to use each control flow primitive
- Optimize recurrent neural networks with scan
- Combine control flow with other transformations

## Why Special Control Flow?

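Regular Python control flow (`if`/`for`/`while`) runs at trace time, which causes problems with:
- JIT compilation: a Python `if` or `while` cannot branch on a traced value that is only known at run time
- Automatic differentiation: Python loops are unrolled, so long loops produce large graphs that are slow to compile and differentiate
- Hardware acceleration: an unrolled loop is not compiled into a single compact loop for GPU/TPU execution

BrainState provides functional control flow primitives that solve these problems. The examples below call the corresponding `jax.lax` primitives directly.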
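
To make the JIT problem concrete, here is a minimal sketch in plain JAX. The function names `relu_python_if` and `relu_cond` are hypothetical and chosen only for illustration. The first version branches with a Python `if` on a traced value and fails under `jax.jit`; the second expresses the same branch with `jax.lax.cond`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_python_if(x):
    # `x > 0` is a tracer here, so Python cannot turn it into True/False
    if x > 0:
        return x
    return jnp.zeros_like(x)

@jax.jit
def relu_cond(x):
    # lax.cond records both branches and selects one at run time
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

try:
    relu_python_if(jnp.array(2.0))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print("Python `if` failed under JIT:", python_if_failed)
print("lax.cond results:", relu_cond(jnp.array(2.0)), relu_cond(jnp.array(-3.0)))
```

The Python `if` version would work if `x` were marked static, but it would then be recompiled for every distinct value.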

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import time

# Set random seed
bst.random.seed(42)

## 1. The Problem with Regular Python Loops

In [None]:
# Example: Fibonacci sequence
def fibonacci_python(n):
    """Regular Python loop - doesn't JIT well."""
    a, b = 0, 1
    result = []
    for i in range(n):
        result.append(a)
        a, b = b, a + b
    return jnp.array(result)

# This works but can't be efficiently JIT compiled
fib_10 = fibonacci_python(10)
print("Fibonacci (Python loop):", fib_10)

# Try to JIT it - the Python loop is unrolled at trace time!
from functools import partial

@partial(jax.jit, static_argnums=0)  # range(n) needs a concrete n, so n must be static
def fibonacci_jit_bad(n):
    a, b = 0, 1
    result = []
    for i in range(n):  # unrolled into n copies of the loop body
        result.append(a)
        a, b = b, a + b
    return jnp.array(result)

# Every new value of n is traced and compiled from scratch
print("\nProblem: Each different n triggers recompilation")
print("We need better control flow primitives!")
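
The recompilation cost can be observed directly. The sketch below assumes a hypothetical helper `add_range` and a `trace_log` list, neither of which is part of the tutorial. A Python side effect inside a jitted function runs only while JAX traces it, so the log shows how often tracing happened.

```python
from functools import partial
import jax

trace_log = []  # records every time the Python body is traced

@partial(jax.jit, static_argnums=0)
def add_range(n, x):
    trace_log.append(n)  # runs only while tracing, not on cached calls
    total = x
    for i in range(n):  # unrolled: n separate additions in the compiled graph
        total = total + i
    return total

add_range(3, 1.0)  # traces and compiles for n=3
add_range(3, 2.0)  # reuses the n=3 compilation
add_range(4, 1.0)  # new n, so a new trace and compilation
print("Traced for n =", trace_log)
```

Changing only `x` reuses the cached program, while every new `n` adds an entry to the log.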

## 2. Scan: The Swiss Army Knife of Loops

`scan` is the most versatile loop primitive. It iterates over a sequence, carrying state forward.

In [None]:
# Fibonacci with scan
def fibonacci_step(carry, x):
    """Single step: carry = (a, b), x is unused."""
    a, b = carry
    return (b, a + b), a  # new_carry, output

# Use scan
def fibonacci_scan(n):
    carry_init = (0, 1)
    xs = jnp.arange(n)  # Input sequence (we don't use values, just length)
    final_carry, outputs = jax.lax.scan(fibonacci_step, carry_init, xs)
    return outputs

fib_scan = fibonacci_scan(10)
print("Fibonacci (scan):", fib_scan)
print("Matches Python version:", jnp.allclose(fib_10, fib_scan))

# Scan is JIT-friendly and efficient!
# n sets the length of xs (an array shape), so it must be marked static
fibonacci_scan_jit = jax.jit(fibonacci_scan, static_argnums=0)
print("\nJIT-compiled version works great:")
print(fibonacci_scan_jit(15))
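
When the input values are not needed, `jax.lax.scan` also accepts `xs=None` together with a `length` argument, which avoids building a placeholder array. A minimal sketch follows; the name `fibonacci_by_length` is hypothetical.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)  # n fixes the loop length, so it is static
def fibonacci_by_length(n):
    def step(carry, _):  # the per-step input is None when xs=None
        a, b = carry
        return (b, a + b), a

    init = (jnp.int32(0), jnp.int32(1))
    _, outputs = jax.lax.scan(step, init, xs=None, length=n)
    return outputs

print(fibonacci_by_length(10))
```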

### Understanding Scan

```python
carry, outputs = scan(f, carry_init, xs)
```

- **f(carry, x)**: Function that processes one element
  - Input: current carry, current element x
  - Output: (new_carry, output_value)
- **carry_init**: Initial state
- **xs**: Sequence to iterate over
- Returns: (final_carry, all_outputs)
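
The semantics described above can be written as a plain Python loop. The sketch below is a reference model for intuition only, not how JAX implements `scan`. The names `scan_reference` and `running_max` are hypothetical.

```python
import jax
import jax.numpy as jnp

def scan_reference(f, init, xs):
    """Plain-Python model of scan: thread the carry, collect the outputs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

def running_max(carry, x):
    new_max = jnp.maximum(carry, x)
    return new_max, new_max  # (new_carry, output)

xs = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
init = jnp.array(-jnp.inf, dtype=jnp.float32)

ref_carry, ref_ys = scan_reference(running_max, init, xs)
lax_carry, lax_ys = jax.lax.scan(running_max, init, xs)

print("Reference:", ref_ys)
print("lax.scan: ", lax_ys)
```

The practical difference is that `jax.lax.scan` traces `f` once and compiles a single loop, whereas the Python version executes (or, under JIT, unrolls) every step.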

### Example: Cumulative Sum with Scan

In [None]:
# Cumulative sum
def cumsum_step(carry, x):
    new_sum = carry + x
    return new_sum, new_sum  # carry is the running sum

data = jnp.array([1, 2, 3, 4, 5])
final_sum, cumulative = jax.lax.scan(cumsum_step, 0, data)

print("Input:", data)
print("Cumulative sum:", cumulative)
print("Final sum:", final_sum)
print("Compare to jnp.cumsum:", jnp.cumsum(data))
print("Match:", jnp.allclose(cumulative, jnp.cumsum(data)))

## 3. Scan for RNN Sequences

The most common use of scan: processing sequences through recurrent networks.

In [None]:
# Simple RNN with scan
class SimpleRNN(bst.graph.Node):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.Wxh = bst.ParamState(bst.random.randn(input_size, hidden_size) * 0.1)
        self.Whh = bst.ParamState(bst.random.randn(hidden_size, hidden_size) * 0.1)
        self.bh = bst.ParamState(jnp.zeros(hidden_size))
    
    def step(self, h, x):
        """Single RNN step."""
        h_new = jnp.tanh(
            x @ self.Wxh.value + 
            h @ self.Whh.value + 
            self.bh.value
        )
        return h_new, h_new  # (new_hidden, output)
    
    def __call__(self, sequence):
        """Process entire sequence."""
        h_init = jnp.zeros(self.hidden_size)
        final_h, all_h = jax.lax.scan(self.step, h_init, sequence)
        return all_h

# Create RNN and test
rnn = SimpleRNN(input_size=5, hidden_size=10)
sequence = bst.random.randn(20, 5)  # 20 time steps, 5 features

outputs = rnn(sequence)
print(f"Sequence shape: {sequence.shape}")
print(f"Outputs shape: {outputs.shape}")

# Visualize hidden states
plt.figure(figsize=(12, 4))
plt.imshow(outputs.T, aspect='auto', cmap='RdBu_r', interpolation='nearest')
plt.colorbar(label='Activation')
plt.xlabel('Time Step')
plt.ylabel('Hidden Unit')
plt.title('RNN Hidden States Over Time')
plt.tight_layout()
plt.show()
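
Because `scan` is an ordinary JAX primitive, it composes with other transformations. The sketch below uses plain JAX with an explicit parameter tuple instead of the `SimpleRNN` class, and the function `rnn_forward` is a hypothetical stand-in. It applies `jax.vmap` to process a batch of sequences with the same recurrence, then JIT-compiles the result.

```python
import jax
import jax.numpy as jnp

def rnn_forward(params, sequence):
    """Run a tanh RNN over one sequence of shape (time, features)."""
    Wxh, Whh, bh = params

    def step(h, x):
        h_new = jnp.tanh(x @ Wxh + h @ Whh + bh)
        return h_new, h_new

    h0 = jnp.zeros(Whh.shape[0])
    _, all_h = jax.lax.scan(step, h0, sequence)
    return all_h

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    0.1 * jax.random.normal(k1, (5, 10)),
    0.1 * jax.random.normal(k2, (10, 10)),
    jnp.zeros(10),
)
batch = jax.random.normal(k3, (4, 20, 5))  # 4 sequences, 20 steps, 5 features

# Share params across the batch (None) and map over the leading axis of `batch`
batched_forward = jax.jit(jax.vmap(rnn_forward, in_axes=(None, 0)))
batch_h = batched_forward(params, batch)
print("Batched hidden states:", batch_h.shape)
```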

### Performance: Scan vs Python Loop

In [None]:
# Compare scan vs manual loop for RNN
long_sequence = bst.random.randn(1000, 5)

# Version 1: Manual loop
def rnn_manual_loop(rnn, sequence):
    h = jnp.zeros(rnn.hidden_size)
    outputs = []
    for x in sequence:
        h, out = rnn.step(h, x)
        outputs.append(out)
    return jnp.stack(outputs)

# Version 2: Scan
def rnn_scan(rnn, sequence):
    return rnn(sequence)

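# Time them (block_until_ready waits for JAX's asynchronous dispatch to finish)
start = time.perf_counter()
_ = rnn_manual_loop(rnn, long_sequence).block_until_ready()
time_loop = time.perf_counter() - start

rnn_scan_jit = jax.jit(rnn_scan, static_argnums=0)
_ = rnn_scan_jit(rnn, long_sequence).block_until_ready()  # Warmup (includes compilation)
start = time.perf_counter()
_ = rnn_scan_jit(rnn, long_sequence).block_until_ready()
time_scan = time.perf_counter() - start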

print(f"Manual loop: {time_loop*1000:.2f} ms")
print(f"Scan (JIT):  {time_scan*1000:.2f} ms")
print(f"Speedup:     {time_loop/time_scan:.1f}x")

## 4. While Loop

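`while_loop` repeatedly applies a body function to a carried value for as long as a condition function returns true, so the number of iterations does not need to be known in advance.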

In [None]:
# Example: Find first power of 2 >= target
def find_power_of_2(target):
    def cond_fn(carry):
        power, value = carry
        return value < target
    
    def body_fn(carry):
        power, value = carry
        return (power + 1, value * 2)
    
    init_val = (0, 1)
    final_power, final_value = jax.lax.while_loop(cond_fn, body_fn, init_val)
    return final_power, final_value

target = 1000
power, value = find_power_of_2(target)
print(f"Target: {target}")
print(f"First power of 2 >= target: 2^{power} = {value}")

# Works with JIT
find_power_of_2_jit = jax.jit(find_power_of_2)
power2, value2 = find_power_of_2_jit(500)
print(f"\nTarget: 500")
print(f"Result: 2^{power2} = {value2}")

### Example: Convergence Detection

In [None]:
# Iteratively refine until convergence
def sqrt_newton(x, tolerance=1e-6, max_iters=100):
    """Newton's method for square root."""
    def cond_fn(carry):
        guess, iteration = carry
        error = jnp.abs(guess * guess - x)
        return (error > tolerance) & (iteration < max_iters)
    
    def body_fn(carry):
        guess, iteration = carry
        # Newton update: x_{n+1} = (x_n + a/x_n) / 2
        new_guess = (guess + x / guess) / 2
        return (new_guess, iteration + 1)
    
    init_guess = x / 2  # Initial guess
    final_guess, num_iters = jax.lax.while_loop(cond_fn, body_fn, (init_guess, 0))
    return final_guess, num_iters

# Test
x = 2.0
result, iters = sqrt_newton(x)
print(f"sqrt({x}) ≈ {result} (converged in {iters} iterations)")
print(f"True value: {jnp.sqrt(x)}")
print(f"Error: {jnp.abs(result - jnp.sqrt(x))}")
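
One caveat is that `jax.lax.while_loop` supports forward-mode differentiation but not reverse mode, because the number of iterations is unknown when the program is built. The sketch below uses hypothetical function names to compare a `while_loop` version of Newton's method with a fixed-iteration `fori_loop` version. The second can be differentiated with `jax.grad` because its bounds are Python integers.

```python
import jax
import jax.numpy as jnp

def sqrt_while(a, tolerance=1e-6, max_iters=100):
    def cond_fn(carry):
        guess, i = carry
        return (jnp.abs(guess * guess - a) > tolerance) & (i < max_iters)

    def body_fn(carry):
        guess, i = carry
        return (guess + a / guess) / 2, i + 1

    guess, _ = jax.lax.while_loop(cond_fn, body_fn, (a / 2, 0))
    return guess

def sqrt_fixed(a, n_iters=20):
    # Static Python bounds let JAX differentiate this loop in reverse mode
    return jax.lax.fori_loop(0, n_iters, lambda i, g: (g + a / g) / 2, a / 2)

a = jnp.float32(2.0)

# Forward mode (jvp) works through while_loop
_, tangent = jax.jvp(sqrt_while, (a,), (jnp.float32(1.0),))

# Reverse mode (grad) is rejected for while_loop
try:
    jax.grad(sqrt_while)(a)
    reverse_mode_ok = True
except Exception:
    reverse_mode_ok = False

grad_fixed = jax.grad(sqrt_fixed)(a)
print("d sqrt(a)/da via jvp through while_loop:", tangent)
print("Reverse mode through while_loop worked:", reverse_mode_ok)
print("d sqrt(a)/da via grad through fori_loop:", grad_fixed)
```

Both derivatives should be close to the analytic value 1 / (2 * sqrt(2)).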

## 5. For Loop

`fori_loop` is JAX's fixed-count loop. It is similar to `scan` but has a simpler interface for cases where you only need the final value, not the per-step outputs.

In [None]:
# For loop example: Matrix power
def matrix_power(A, n):
    """Compute A^n using fori_loop."""
    def body_fn(i, carry):
        return carry @ A

    # Start from the identity so that n multiplications give exactly A^n
    identity = jnp.eye(A.shape[0], dtype=A.dtype)
    result = jax.lax.fori_loop(0, n, body_fn, identity)
    return result

A = jnp.array([[1, 1], [1, 0]], dtype=float)  # Fibonacci matrix
A_power_5 = matrix_power(A, 5)

print("Matrix A (Fibonacci matrix):")
print(A)
print(f"\nA^5:")
print(A_power_5)
print(f"\nNote: A^n = [[F(n+1), F(n)], [F(n), F(n-1)]] generates Fibonacci numbers!")
print(f"A_power_5[0,1] = {int(A_power_5[0,1])} (F(5), the 5th Fibonacci number)")
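
A loop like this is easy to get off by one, so it helps to check it against a library routine. The sketch below is self-contained, with the hypothetical name `matrix_power_fori`. It starts the accumulator at the identity matrix, so `n` multiplications give exactly `A^n`, and compares the result with `jnp.linalg.matrix_power`.

```python
import jax
import jax.numpy as jnp

def matrix_power_fori(A, n):
    # fori_loop(lower, upper, body, init) behaves like:
    #   val = init
    #   for i in range(lower, upper): val = body(i, val)
    identity = jnp.eye(A.shape[0], dtype=A.dtype)
    return jax.lax.fori_loop(0, n, lambda i, acc: acc @ A, identity)

A = jnp.array([[1.0, 1.0], [1.0, 0.0]])
A5 = matrix_power_fori(A, 5)
A5_ref = jnp.linalg.matrix_power(A, 5)

print("fori_loop result:\n", A5)
print("Matches jnp.linalg.matrix_power:", bool(jnp.allclose(A5, A5_ref)))
```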

## 6. Conditional Execution: cond

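`cond` selects one of two branch functions based on a scalar boolean predicate. Both branches must return outputs with the same shape and dtype.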

In [None]:
# Simple conditional
def absolute_value(x):
    return jax.lax.cond(
        x >= 0,
        lambda x: x,      # true branch
        lambda x: -x,     # false branch
        x
    )

print(f"abs(5.0) = {absolute_value(5.0)}")
print(f"abs(-3.0) = {absolute_value(-3.0)}")

# Works with JIT
absolute_value_jit = jax.jit(absolute_value)
print(f"\nJIT version works: {absolute_value_jit(-7.0)}")

### Example: Piecewise Function

In [None]:
# Piecewise activation function
def piecewise_activation(x):
    """Custom activation: linear below 0, quadratic 0-1, constant above 1."""
    return jax.lax.cond(
        x < 0,
        lambda x: x,                                    # Linear: x
        lambda x: jax.lax.cond(
            x < 1,
            lambda x: x ** 2,                          # Quadratic: x²
            lambda x: jnp.ones_like(x),                # Constant: 1
            x
        ),
        x
    )

# Vectorize for plotting
piecewise_vec = jax.vmap(piecewise_activation)

x_vals = jnp.linspace(-2, 2, 200)
y_vals = piecewise_vec(x_vals)

plt.figure(figsize=(8, 5))
plt.plot(x_vals, y_vals, linewidth=2)
plt.axhline(0, color='k', linewidth=0.5)
plt.axvline(0, color='k', linewidth=0.5)
plt.axvline(1, color='r', linestyle='--', alpha=0.3, label='Transition points')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.title('Piecewise Activation Function')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()
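
When `cond` is vmapped over a batch of predicates, JAX converts it into a `select`, so both branches are evaluated for every element. For cheap elementwise functions like this one, nested `jnp.where` is therefore an equally valid and often simpler formulation. The sketch below uses the hypothetical names `piecewise_where` and `piecewise_cond` to compare the two.

```python
import jax
import jax.numpy as jnp

def piecewise_where(x):
    # Elementwise selection: works directly on arrays, no vmap needed
    return jnp.where(x < 0, x, jnp.where(x < 1, x ** 2, 1.0))

def piecewise_cond(x):
    # Scalar version with nested cond, as in the cell above
    return jax.lax.cond(
        x < 0,
        lambda v: v,
        lambda v: jax.lax.cond(v < 1, lambda u: u ** 2, lambda u: jnp.ones_like(u), v),
        x,
    )

xs = jnp.linspace(-2, 2, 9)
y_where = piecewise_where(xs)
y_cond = jax.vmap(piecewise_cond)(xs)
print("where:", y_where)
print("cond: ", y_cond)
```

`cond` remains the better choice when the predicate is a single scalar and the branches are expensive, since only the selected branch then runs.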

## 7. Switch: Multi-Way Branching

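`switch` generalizes `cond` to any number of branches, like a switch/case statement: an integer index selects which function from a list is applied.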

In [None]:
# Example: Different operations based on index
def apply_operation(index, x):
    """Apply different operations based on index."""
    branches = [
        lambda x: x,              # 0: identity
        lambda x: x ** 2,         # 1: square
        lambda x: jnp.sqrt(jnp.abs(x)),  # 2: sqrt of abs
        lambda x: jnp.sin(x),     # 3: sine
    ]
    return jax.lax.switch(index, branches, x)

x = 2.0
for i in range(4):
    result = apply_operation(i, x)
    print(f"Operation {i} on {x}: {result}")
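
Unlike Python list indexing, `jax.lax.switch` does not raise an error for an out-of-range index. It clamps the index into the valid range instead. The short sketch below uses a hypothetical `apply_op` function to show this under JIT, where the index is a traced value.

```python
import jax
import jax.numpy as jnp

branches = [
    lambda v: v,           # 0: identity
    lambda v: v ** 2,      # 1: square
    lambda v: jnp.sin(v),  # 2: sine
]

@jax.jit
def apply_op(index, x):
    return jax.lax.switch(index, branches, x)

x = jnp.array(3.0, dtype=jnp.float32)
print("index 1: ", apply_op(1, x))
print("index 99:", apply_op(99, x))  # clamped to the last branch (sine)
print("index -5:", apply_op(-5, x))  # clamped to the first branch (identity)
```

If an invalid index should be treated as an error, validate it before calling `switch`.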

### Example: Activation Function Selection

In [None]:
# Dynamic activation function selection
def dynamic_activation(x, activation_id):
    """Select activation function at runtime."""
    activations = [
        lambda x: x,                    # 0: linear
        lambda x: jnp.maximum(0, x),    # 1: ReLU
        lambda x: jnp.tanh(x),          # 2: tanh
        lambda x: 1 / (1 + jnp.exp(-x)),  # 3: sigmoid
    ]
    return jax.lax.switch(activation_id, activations, x)

# Visualize all activations
x_vals = jnp.linspace(-3, 3, 100)
activation_names = ['Linear', 'ReLU', 'Tanh', 'Sigmoid']

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for i, (ax, name) in enumerate(zip(axes, activation_names)):
    y_vals = jax.vmap(lambda x: dynamic_activation(x, i))(x_vals)
    ax.plot(x_vals, y_vals, linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_title(f'{name} (id={i})')
    ax.grid(True, alpha=0.3)
    ax.axhline(0, color='k', linewidth=0.5)
    ax.axvline(0, color='k', linewidth=0.5)

plt.tight_layout()
plt.show()

## 8. Practical Example: Sequence Classification with Early Stopping

In [None]:
# RNN with early stopping based on confidence
class EarlyStopRNN(bst.graph.Node):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.rnn = SimpleRNN(input_size, hidden_size)
        self.classifier = bst.nn.Linear(hidden_size, num_classes)
        self.confidence_threshold = 0.95
    
    def __call__(self, sequence, use_early_stop=False):
        if not use_early_stop:
            # Normal processing
            h = self.rnn(sequence)  # All hidden states
            logits = self.classifier(h[-1])  # Use final state
            return logits, len(sequence)
        else:
            # Early stopping version. Note that scan always runs every step,
            # so this freezes the state after stopping rather than saving compute.
            def process_step(carry, x):
                h, stopped = carry

                # Update hidden state
                h_new, _ = self.rnn.step(h, x)

                # Check confidence
                logits = self.classifier(h_new)
                probs = jax.nn.softmax(logits)
                max_prob = jnp.max(probs)

                # Stop if confident
                should_stop = max_prob > self.confidence_threshold
                stopped_new = stopped | should_stop

                # Keep the old state once we have already stopped
                h_next = jnp.where(stopped, h, h_new)

                return (h_next, stopped_new), stopped_new

            h_init = jnp.zeros(self.rnn.hidden_size)
            (final_h, _), stop_flags = jax.lax.scan(
                process_step,
                (h_init, False),
                sequence
            )

            # Count steps until stop (capped at the sequence length if never confident)
            steps_used = jnp.minimum(jnp.sum(~stop_flags) + 1, len(sequence))

            logits = self.classifier(final_h)
            return logits, steps_used

# Test
model = EarlyStopRNN(input_size=5, hidden_size=10, num_classes=3)
test_sequence = bst.random.randn(50, 5)

logits_full, steps_full = model(test_sequence, use_early_stop=False)
logits_early, steps_early = model(test_sequence, use_early_stop=True)

print(f"Full sequence: {steps_full} steps")
print(f"Early stopping: {steps_early} steps")
print(f"Savings: {(1 - steps_early/steps_full)*100:.1f}%")
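
Because `scan` always executes every step, an early exit that actually skips computation needs `while_loop`. The sketch below is written in plain JAX with an explicit parameter tuple. The function `run_until_confident` and the readout matrix `Wout` are hypothetical stand-ins for the classes above. It stops iterating at the first step whose maximum class probability exceeds the threshold.

```python
import jax
import jax.numpy as jnp

def run_until_confident(params, sequence, threshold=0.95):
    Wxh, Whh, bh, Wout = params
    n_steps = sequence.shape[0]

    def cond_fn(carry):
        t, h, confident = carry
        return (t < n_steps) & ~confident  # continue while input remains and not confident

    def body_fn(carry):
        t, h, _ = carry
        h_new = jnp.tanh(sequence[t] @ Wxh + h @ Whh + bh)
        probs = jax.nn.softmax(h_new @ Wout)
        return t + 1, h_new, jnp.max(probs) > threshold

    init = (jnp.int32(0), jnp.zeros(Whh.shape[0]), jnp.array(False))
    steps_used, h, _ = jax.lax.while_loop(cond_fn, body_fn, init)
    return h @ Wout, steps_used

keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = (
    0.1 * jax.random.normal(keys[0], (5, 10)),
    0.1 * jax.random.normal(keys[1], (10, 10)),
    jnp.zeros(10),
    jax.random.normal(keys[2], (10, 3)),
)
sequence = jax.random.normal(keys[3], (50, 5))

logits, steps_used = jax.jit(run_until_confident)(params, sequence)
print("Steps actually executed:", steps_used)
```

The trade-off is that this version cannot be differentiated in reverse mode, so it is better suited to inference than to training.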

## 9. Combining Control Flow Primitives

In [None]:
# Example: Adaptive learning with scan + cond
def adaptive_gradient_descent(loss_fn, params_init, learning_rate_init, n_steps):
    """Gradient descent with adaptive learning rate."""
    
    def step_fn(carry, _):
        params, lr, prev_loss = carry
        
        # Compute loss and gradient
        loss, grad = jax.value_and_grad(loss_fn)(params)
        
        # Adaptive learning rate
        lr_new = jax.lax.cond(
            loss < prev_loss,
            lambda lr: lr * 1.1,  # Increase if improving
            lambda lr: lr * 0.5,  # Decrease if getting worse
            lr
        )
        
        # Update parameters
        params_new = params - lr_new * grad
        
        return (params_new, lr_new, loss), loss
    
    # Run optimization
    carry_init = (params_init, learning_rate_init, float('inf'))
    (final_params, final_lr, _), loss_history = jax.lax.scan(
        step_fn,
        carry_init,
        jnp.arange(n_steps)
    )
    
    return final_params, loss_history

# Test on simple quadratic
def quadratic_loss(x):
    return (x - 5) ** 2

final_x, losses = adaptive_gradient_descent(
    quadratic_loss,
    params_init=0.0,
    learning_rate_init=0.1,
    n_steps=50
)

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Adaptive Gradient Descent')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Final x: {final_x} (target: 5.0)")
print(f"Final loss: {quadratic_loss(final_x)}")

## 10. Performance Best Practices

In [None]:
# Comparison: Different loop strategies
def benchmark_loops():
    # Setup
    n = 100
    
    # 1. Python loop (slowest, doesn't JIT well)
    def python_loop(x):
        result = x
        for i in range(n):
            result = result * 0.99 + 0.01
        return result
    
    # 2. fori_loop (better)
    def fori_version(x):
        def body(i, val):
            return val * 0.99 + 0.01
        return jax.lax.fori_loop(0, n, body, x)
    
    # 3. scan (best for when you need intermediates)
    def scan_version(x):
        def step(carry, _):
            new_val = carry * 0.99 + 0.01
            return new_val, new_val
        final, _ = jax.lax.scan(step, x, jnp.arange(n))
        return final
    
    x = jnp.array(1.0)
    
    # JIT compile
    fori_jit = jax.jit(fori_version)
    scan_jit = jax.jit(scan_version)
    
    # Warmup
    _ = fori_jit(x)
    _ = scan_jit(x)
    
    # Benchmark
    times = {}
    
    start = time.time()
    for _ in range(1000):
        _ = python_loop(x).block_until_ready()  # wait for async dispatch, as below
    times['Python loop'] = (time.time() - start) / 1000
    
    start = time.time()
    for _ in range(1000):
        _ = fori_jit(x).block_until_ready()
    times['fori_loop (JIT)'] = (time.time() - start) / 1000
    
    start = time.time()
    for _ in range(1000):
        _ = scan_jit(x).block_until_ready()
    times['scan (JIT)'] = (time.time() - start) / 1000
    
    return times

times = benchmark_loops()
print("Average time per iteration (1000 runs):")
for name, t in times.items():
    print(f"{name:20s}: {t*1e6:.2f} μs")

print("\nKey takeaways:")
print("  - Use scan when you need all intermediate values")
print("  - Use fori_loop when you only need final result")
print("  - Avoid Python loops in JIT-compiled code")
print("  - Always combine with JIT for best performance")

## Summary

In this tutorial, we covered:

1. **scan**: Most versatile loop primitive for sequential operations
2. **while_loop**: Condition-based iteration
3. **fori_loop**: Fixed-count iteration
4. **cond**: Efficient if-else branching
5. **switch**: Multi-way branching
6. **RNN Processing**: Using scan for recurrent networks
7. **Early Stopping**: Combining loops with conditions
8. **Adaptive Algorithms**: Composing control flow primitives
9. **Performance**: Best practices for each primitive

## Key Takeaways

- **Python control flow on traced values fails under JIT**, and Python loops are unrolled at trace time
- Use **scan** for RNNs and sequential processing
- Use **while_loop** for convergence/iteration problems
- Use **fori_loop** when you only need final result
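- Use **cond/switch** instead of if/elif/else when the branch depends on traced values
- All primitives are **JIT-compatible**; `scan`, `cond`, `switch`, and `fori_loop` with a static trip count support reverse-mode differentiation, while `while_loop` supports forward mode only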
- Combine primitives for complex control flow

## When to Use What

| Primitive | Use Case | Example |
|-----------|----------|----------|
| `scan` | Sequential with state | RNN, cumulative sum |
| `fori_loop` | Fixed iteration count | Matrix power, warmup |
| `while_loop` | Convergence/condition | Newton's method, search |
| `cond` | If-else branching | Activation functions |
| `switch` | Multi-way branching | Operation selection |

## Next Steps

In the next tutorial, we'll explore other advanced transformations:
- Gradient checkpointing (remat)
- Abstract initialization
- Computation tracing with make_jaxpr
- Progress bars for long computations