# JIT Compilation in BrainState

Just-In-Time (JIT) compilation is one of JAX's most powerful features, enabling dramatic performance improvements by compiling Python functions to optimized machine code. BrainState extends JAX's `jit` with seamless support for stateful modules and the state management system.

## What You'll Learn

This tutorial covers:

- 🚀 **Quick Start**: Basic JIT decoration for pure functions
- 🔄 **Stateful Modules**: How JIT handles mutable state
- ⚙️ **Static Arguments**: Specializing compilation for specific values
- 🎯 **Performance**: Measuring speedups from compilation
- 🔧 **Advanced Control**: Manual compilation and cache management
- 💡 **Best Practices**: When and how to use JIT effectively

## Why JIT Compilation?

JIT compilation provides several key benefits:

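1. **Performance**: Compiled functions avoid per-operation Python overhead; speedups of 10-100x are possible for workloads made of many small operations, but the gain depends on the workload and hardware (the benchmark in Section 2 measures about 1.3x)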
2. **Fusion**: Multiple operations fuse into efficient kernels
3. **Memory**: Reduced intermediate allocations
4. **Specialization**: Optimized code for specific shapes and types

In [7]:
import time
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import brainstate
from brainstate.transform import jit

# Set random seed
brainstate.random.seed(42)

## 1. Quick Start: Decorating Pure Functions

The simplest use of JIT is decorating a pure function (no side effects). Just add `@jit` above your function definition:

In [8]:
@jit
def softplus(x: jax.Array) -> jax.Array:
    """Smooth approximation to ReLU: log(1 + exp(x))"""
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)

# First call compiles the function
xs = jnp.linspace(-5.0, 5.0, 7)
result = softplus(xs)

print("Input:", xs)
print("Output:", result)
print(f"\n✅ Function compiled and executed successfully!")

Input: [-5.0000000e+00 -3.3333333e+00 -1.6666663e+00  1.1920929e-07
  1.6666670e+00  3.3333337e+00  5.0000000e+00]
Output: [0.00671535 0.03505242 0.17300805 0.69314724 1.839675   3.368386
 5.0067153 ]

✅ Function compiled and executed successfully!


### Compilation is Cached

Subsequent calls with the same input shapes and dtypes reuse the compiled executable; a new shape or dtype triggers a fresh compilation:

In [9]:
# Second call reuses compiled code
second_call = softplus(xs * 2.0)
print("Second call output:", second_call)

# Different shape triggers recompilation
larger_input = jnp.linspace(-5.0, 5.0, 100)
third_call = softplus(larger_input)
print(f"\nThird call with shape {larger_input.shape}: compiled new version")
print("Output shape:", third_call.shape)

Second call output: [4.5398901e-05 1.2718249e-03 3.5052441e-02 6.9314730e-01 3.3683863e+00
 6.6679392e+00 1.0000046e+01]

Third call with shape (100,): compiled new version
Output shape: (100,)
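
One way to observe the cache is to record a Python side effect inside the function, since Python code in the body runs only while the function is being traced. The sketch below uses plain `jax.jit` rather than BrainState's `jit`, and the names `trace_log` and `scaled_sum` are illustrative only; it assumes BrainState follows the same shape-based caching described above.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function


@jax.jit
def scaled_sum(x):
    # Python side effect: executes during tracing, not on cached calls
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)


scaled_sum(jnp.ones(7))        # traces and compiles for shape (7,)
scaled_sum(jnp.ones(7) * 3.0)  # same shape and dtype: cached, no new trace
scaled_sum(jnp.ones(100))      # new shape: traces and compiles again

print("Shapes that triggered a trace:", trace_log)
```

Three calls produce only two entries in the log, one per distinct input shape.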


## 2. Performance Comparison: JIT vs Non-JIT

Let's measure the speedup from JIT compilation with a realistic example:

In [11]:
# Define a more complex function
def complex_computation(x):
    """Multi-step computation: good candidate for JIT."""
    for _ in range(10):
        x = jnp.sin(x) + jnp.cos(x)
        x = jnp.tanh(x * 0.5)
        x = x @ x.T
    return x

# JIT-compiled version
complex_computation_jit = jit(complex_computation)

# Test data
test_data = brainstate.random.randn(100, 100)

# Warm up (compile)
_ = complex_computation_jit(test_data)

# Benchmark non-JIT
start = time.time()
for _ in range(10):
    result_plain = complex_computation(test_data).block_until_ready()
time_plain = (time.time() - start) / 10

# Benchmark JIT
start = time.time()
for _ in range(10):
    result_jit = complex_computation_jit(test_data).block_until_ready()
time_jit = (time.time() - start) / 10

# Results
speedup = time_plain / time_jit
print("Performance Comparison:")
print("=" * 50)
print(f"Without JIT: {time_plain*1000:.2f} ms")
print(f"With JIT:    {time_jit*1000:.2f} ms")
print(f"\n🚀 Speedup: {speedup:.1f}x faster with JIT!")

# Verify correctness
print(f"\n✅ Results match: {jnp.allclose(result_plain, result_jit)}")

Performance Comparison:
Without JIT: 2.60 ms
With JIT:    2.00 ms

🚀 Speedup: 1.3x faster with JIT!

✅ Results match: True
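
The measured speedup here is modest, and the first JIT call also pays a one-time compilation cost that the warm-up hides. The following sketch, written against plain `jax.jit` with a hypothetical `step` function and `mean_time` helper, shows one way to report the first-call time separately from steady-state time. Absolute numbers will vary by machine, so none are shown.

```python
import time

import jax
import jax.numpy as jnp


def step(x):
    # A few element-wise ops that XLA may fuse into one kernel
    return jnp.tanh(jnp.sin(x) + jnp.cos(x))


step_jit = jax.jit(step)
x = jnp.ones((256, 256))

# First call: includes tracing and compilation
t0 = time.perf_counter()
step_jit(x).block_until_ready()
first_call = time.perf_counter() - t0


def mean_time(fn, n=20):
    """Average wall time per call after one warm-up call."""
    fn(x).block_until_ready()
    t0 = time.perf_counter()
    for _ in range(n):
        fn(x).block_until_ready()
    return (time.perf_counter() - t0) / n


eager_time = mean_time(step)
jit_time = mean_time(step_jit)

print(f"First JIT call (with compile): {first_call * 1e3:.2f} ms")
print(f"Eager, steady state:           {eager_time * 1e3:.3f} ms")
print(f"JIT, steady state:             {jit_time * 1e3:.3f} ms")
```

`time.perf_counter` is used instead of `time.time` because it is a monotonic, higher-resolution clock, which matters for millisecond-scale measurements.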


## 3. Stateful Modules Under JIT

BrainState's JIT automatically handles modules with mutable state. The key is that `State` objects are tracked and updated correctly even inside compiled functions.

### Example: Running Statistics Tracker

In [12]:
class RunningMean(brainstate.nn.Module):
    """Tracks running mean of data batches."""
    
    def __init__(self):
        super().__init__()
        self.sum = brainstate.HiddenState(jnp.array(0.0))
        self.count = brainstate.HiddenState(jnp.array(0))

    def __call__(self, batch: jax.Array) -> jax.Array:
        """Update running mean with new batch."""
        self.sum.value += jnp.sum(batch)
        self.count.value += batch.size
        return self.sum.value / self.count.value
    
    def reset(self):
        """Reset statistics."""
        self.sum.value = jnp.array(0.0)
        self.count.value = jnp.array(0)


# Create tracker
tracker = RunningMean()

# JIT-compile the update function
@jit
def update_running_mean(batch: jax.Array) -> jax.Array:
    return tracker(batch)

# Simulate data stream
print("Streaming data batches:")
print("=" * 50)
for step in range(5):
    batch = jnp.arange(4.0) + step * 2
    mean = update_running_mean(batch)
    print(f"Step {step}: batch = {batch}, running mean = {float(mean):.2f}")

print(f"\nFinal statistics:")
print(f"  Total sum: {float(tracker.sum.value):.1f}")
print(f"  Sample count: {int(tracker.count.value)}")
print(f"  Overall mean: {float(tracker.sum.value / tracker.count.value):.2f}")

Streaming data batches:
Step 0: batch = [0. 1. 2. 3.], running mean = 1.50
Step 1: batch = [2. 3. 4. 5.], running mean = 2.50
Step 2: batch = [4. 5. 6. 7.], running mean = 3.50
Step 3: batch = [6. 7. 8. 9.], running mean = 4.50
Step 4: batch = [ 8.  9. 10. 11.], running mean = 5.50

Final statistics:
  Total sum: 110.0
  Sample count: 20
  Overall mean: 5.50
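
For comparison, the same running mean can be written in plain JAX by threading the state through the function explicitly. This is only a sketch of the functional pattern that a stateful wrapper saves you from writing by hand; it is not a description of how BrainState implements state tracking, and the `update` function is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def update(state, batch):
    """Pure function: takes the old state, returns the new state and the mean."""
    total, count = state
    total = total + jnp.sum(batch)
    count = count + batch.size  # batch.size is a static Python int
    return (total, count), total / count


state = (jnp.array(0.0), jnp.array(0))
means = []
for step in range(5):
    batch = jnp.arange(4.0) + step * 2
    state, mean = update(state, batch)  # caller must carry the state forward
    means.append(float(mean))

print("Running means:", means)
print("Final sum and count:", float(state[0]), int(state[1]))
```

The results match the `RunningMean` module above, but the caller is responsible for passing and reassigning `state` on every step.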


### Example: Batch Normalization Layer

A more realistic example with running statistics:

In [13]:
class BatchNorm(brainstate.nn.Module):
    """Batch normalization with running statistics."""
    
    def __init__(self, num_features: int, momentum: float = 0.1):
        super().__init__()
        self.momentum = momentum
        
        # Learnable parameters
        self.gamma = brainstate.ParamState(jnp.ones(num_features))
        self.beta = brainstate.ParamState(jnp.zeros(num_features))
        
        # Running statistics
        self.running_mean = brainstate.HiddenState(jnp.zeros(num_features))
        self.running_var = brainstate.HiddenState(jnp.ones(num_features))
        
        # Training flag
        self.training = True
    
    def __call__(self, x: jax.Array) -> jax.Array:
        if self.training:
            # Compute batch statistics
            mean = jnp.mean(x, axis=0)
            var = jnp.var(x, axis=0)
            
            # Update running statistics
            self.running_mean.value = (
                (1 - self.momentum) * self.running_mean.value + 
                self.momentum * mean
            )
            self.running_var.value = (
                (1 - self.momentum) * self.running_var.value + 
                self.momentum * var
            )
        else:
            # Use running statistics
            mean = self.running_mean.value
            var = self.running_var.value
        
        # Normalize
        x_norm = (x - mean) / jnp.sqrt(var + 1e-5)
        
        # Scale and shift
        return self.gamma.value * x_norm + self.beta.value


# Create batch norm layer
bn = BatchNorm(num_features=3)

@jit
def forward_pass(x):
    return bn(x)

# Training phase
print("Training phase (updating running stats):")
print("=" * 50)
for i in range(3):
    x = brainstate.random.randn(32, 3) + i  # Batches with different means
    output = forward_pass(x)
    print(f"Batch {i}: input mean = {jnp.mean(x, axis=0)}, "
          f"output mean = {jnp.mean(output, axis=0)}")

print(f"\nRunning mean: {bn.running_mean.value}")
print(f"Running var:  {bn.running_var.value}")

# Inference phase
bn.training = False
print("\n✅ Switch to inference mode (using running stats)")
test_x = brainstate.random.randn(5, 3)
test_out = forward_pass(test_x)
print(f"Test output shape: {test_out.shape}")

Training phase (updating running stats):
Batch 0: input mean = [-0.17944504  0.16966473  0.09837674], output mean = [ 2.2351742e-08 -1.8626451e-09  2.7939677e-09]
Batch 1: input mean = [0.98484504 1.1770446  0.93761075], output mean = [ 1.0244548e-08  2.2351742e-08 -7.4505806e-09]
Batch 2: input mean = [2.2120051 2.0268552 2.1104944], output mean = [ 3.4831464e-07 -4.0978193e-08 -3.4645200e-07]

Running mean: [0.29530153 0.3223624  0.30340293]
Running var:  [1.0351619  0.97919095 0.9962434 ]

✅ Switch to inference mode (using running stats)
Test output shape: (5, 3)
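
The inference call above only prints the output shape, so it does not show whether the compiled function actually switched branches after `bn.training` was changed. In plain `jax.jit`, a Python flag read inside the function is fixed at trace time, and later changes are ignored on cache hits. The sketch below demonstrates that behaviour with a hypothetical `settings` dictionary standing in for a module attribute; whether BrainState's `jit` detects such attribute changes is not established by this tutorial, so it is worth checking outputs after toggling a mode.

```python
from functools import partial

import jax
import jax.numpy as jnp

settings = {"training": True}  # hypothetical stand-in for a module attribute


@jax.jit
def forward(x):
    # Plain Python branch: evaluated once, while tracing
    if settings["training"]:
        return x * 2.0
    return x


x = jnp.ones(3)
first = forward(x)            # traced with training=True
settings["training"] = False
stale = forward(x)            # cache hit: still runs the training branch


# Passing the flag as a static argument makes it part of the cache key
@partial(jax.jit, static_argnames="training")
def forward_static(x, training):
    return x * 2.0 if training else x


fresh = forward_static(x, training=False)

print("After toggling the global flag:", stale)
print("With a static argument:        ", fresh)
```

Making the mode an explicit static argument gives one compilation per mode and removes the dependence on when the function was first traced.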


## 4. Static Arguments

Use `static_argnums` or `static_argnames` to treat certain arguments as **compile-time constants**. This is crucial when:

- Arguments control loop iterations or conditional branches
- You want specialized code for different configurations
- The argument is a hashable Python value that should not be traced (e.g., a `str`, or an `int`/`bool` that drives Python control flow)

### Example: Polynomial Evaluation

In [14]:
@jit(static_argnums=1)
def polynomial_series(x: jax.Array, degree: int) -> jax.Array:
    """Compute polynomial series up to given degree.
    
    P(x) = 1*x + 2*x^2 + 3*x^3 + ... + degree*x^degree
    """
    powers = [x ** (i + 1) for i in range(degree)]
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    stacked = jnp.stack(powers, axis=0)
    return jnp.tensordot(coeffs, stacked, axes=1)

# Different degrees trigger different compilations
x = jnp.array([1.0, 2.0])

y3 = polynomial_series(x, 3)  # Compiles for degree=3
y3_again = polynomial_series(x, 3)  # Reuses compilation
y5 = polynomial_series(x, 5)  # New compilation for degree=5

print("Polynomial evaluations:")
print(f"  degree=3: {y3}")
print(f"  degree=3 (cached): {y3_again}")
print(f"  degree=5: {y5}")
print(f"\n✅ Each degree gets its own specialized compilation")

Polynomial evaluations:
  degree=3: [ 6. 34.]
  degree=3 (cached): [ 6. 34.]
  degree=5: [ 15. 258.]

✅ Each degree gets its own specialized compilation
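
To see why `degree` has to be static, the sketch below runs the same computation through plain `jax.jit` with and without `static_argnums`. Without it, `degree` becomes a tracer and `range(degree)` cannot convert it to a Python integer. The function name `series` is illustrative, and the example assumes BrainState's `jit` behaves like `jax.jit` in this respect.

```python
import jax
import jax.numpy as jnp


def series(x, degree):
    # range(degree) needs a concrete Python int
    powers = [x ** (i + 1) for i in range(degree)]
    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)
    return jnp.tensordot(coeffs, jnp.stack(powers), axes=1)


x = jnp.array([1.0, 2.0])

# Without static_argnums, degree is traced and range() rejects it.
# JAX's tracer conversion errors are subclasses of TypeError.
try:
    jax.jit(series)(x, 3)
    traced_ok = True
except TypeError:
    traced_ok = False

# With static_argnums=1, degree is a compile-time constant
static_result = jax.jit(series, static_argnums=1)(x, 3)

print("Traced degree worked:", traced_ok)
print("Static degree result:", static_result)
```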


### Example: Matrix Operations with Static Configuration

In [15]:
@jit(static_argnames=['transpose', 'normalize'])
def matrix_transform(x: jax.Array, transpose: bool = False, normalize: bool = False) -> jax.Array:
    """Apply matrix transformations based on flags."""
    if transpose:
        x = x.T
    
    result = x @ x.T
    
    if normalize:
        result = result / jnp.linalg.norm(result)
    
    return result

# Test different configurations
mat = brainstate.random.randn(3, 4)

print("Matrix transform configurations:")
print("=" * 50)
print(f"Original shape: {mat.shape}")
print(f"No flags: {matrix_transform(mat, transpose=False, normalize=False).shape}")
print(f"Transpose: {matrix_transform(mat, transpose=True, normalize=False).shape}")
print(f"Normalize: {matrix_transform(mat, transpose=False, normalize=True).shape}")
print(f"Both: {matrix_transform(mat, transpose=True, normalize=True).shape}")
print(f"\n✅ Four different compilations for four configurations")

Matrix transform configurations:
Original shape: (3, 4)
No flags: (3, 3)
Transpose: (4, 4)
Normalize: (3, 3)
Both: (4, 4)

✅ Four different compilations for four configurations


## 5. Real-World Example: Neural Network Training

Let's combine everything we've learned in a complete training loop:

In [17]:
class SimpleNN(brainstate.nn.Module):
    """Simple feedforward neural network."""
    
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = brainstate.nn.Linear(input_dim, hidden_dim)
        self.fc2 = brainstate.nn.Linear(hidden_dim, output_dim)
        
        # Track training steps
        self.step_count = brainstate.HiddenState(jnp.array(0))
    
    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)
        return x


# Create model
model = SimpleNN(input_dim=10, hidden_dim=32, output_dim=3)

# Create synthetic dataset
X_train = brainstate.random.randn(100, 10)
y_train = jnp.argmax(brainstate.random.randn(100, 3), axis=1)

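print(f"Model created:")
# Each ParamState may hold a pytree (e.g., weight and bias), so count array leaves
n_params = sum(
    leaf.size
    for p in brainstate.graph.states(model, brainstate.ParamState).values()
    for leaf in jax.tree.leaves(p.value)
)
print(f"  Parameters: {n_params:,}")
print(f"  Training data: {X_train.shape}")
print(f"  Labels: {y_train.shape}")

Model created:
  Parameters: 451
  Training data: (100, 10)
  Labels: (100,)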
  Training data: (100, 10)
  Labels: (100,)


In [18]:
# Define loss function
def loss_fn(model, x, y):
    logits = model(x)
    # Simple cross-entropy
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.mean(log_probs[jnp.arange(len(y)), y])
    return loss

# JIT-compile the training step
@jit
def train_step(x_batch, y_batch):
    """Single training step with gradient descent."""
    # Increment step counter
    model.step_count.value += 1
    
    # Split graph for JAX transformations
    graphdef, params, others = brainstate.graph.treefy_split(
        model, brainstate.ParamState, ...
    )
    
    # Compute loss and gradients
    def compute_loss(params):
        temp_model = brainstate.graph.treefy_merge(graphdef, params, others)
        return loss_fn(temp_model, x_batch, y_batch)
    
    loss, grads = jax.value_and_grad(compute_loss)(params)
    
    # Update parameters (simple SGD)
    learning_rate = 0.01
    new_params = jax.tree.map(
        lambda p, g: p - learning_rate * g,
        params, grads
    )
    
    # Update model
    brainstate.graph.update_states(model, new_params)
    
    return loss

# Training loop
print("\nTraining:")
print("=" * 50)
batch_size = 32
n_epochs = 5

for epoch in range(n_epochs):
    epoch_losses = []
    
    # Mini-batch training
    for i in range(0, len(X_train), batch_size):
        x_batch = X_train[i:i+batch_size]
        y_batch = y_train[i:i+batch_size]
        
        loss = train_step(x_batch, y_batch)
        epoch_losses.append(float(loss))
    
    avg_loss = jnp.mean(jnp.array(epoch_losses))
    print(f"Epoch {epoch+1}/{n_epochs}: loss = {avg_loss:.4f}")

print(f"\n✅ Training complete! Total steps: {int(model.step_count.value)}")


Training:
Epoch 1/5: loss = 1.5220
Epoch 2/5: loss = 1.5220
Epoch 3/5: loss = 1.5220
Epoch 4/5: loss = 1.5220
Epoch 5/5: loss = 1.5220

✅ Training complete! Total steps: 20
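
The loss printed above is identical in every epoch, which suggests that the parameter update in `train_step` is not reaching the model in this run. As a reference point for what a working jitted step looks like, here is a minimal sketch in plain JAX, with parameters passed in and returned explicitly. It does not use BrainState's graph API; the names `init_params`, `LEARNING_RATE`, and the synthetic data are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value, not taken from the tutorial


def init_params(key, in_dim=10, hidden=32, out_dim=3):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, out_dim)) * 0.1,
        "b2": jnp.zeros(out_dim),
    }


def loss_fn(params, x, y):
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    logits = h @ params["w2"] + params["b2"]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])


@jax.jit
def train_step(params, x, y):
    """One SGD step: returns the updated parameters and the loss."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


kx, ky, kp = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(kx, (100, 10))
y = jax.random.randint(ky, (100,), 0, 3)

params = init_params(kp)
losses = []
for _ in range(50):
    params, loss = train_step(params, X, y)  # reassign params every step
    losses.append(float(loss))

print(f"First loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

A quick check that the loss changes between the first and last step is a cheap way to catch an update that silently has no effect.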


## 6. Advanced Control: Manual Compilation

`JittedFunction` provides additional methods for fine-grained control:

- **`compile(*args, **kwargs)`**: Pre-compile for specific inputs
- **`clear_cache()`**: Clear all cached compilations
- **`origin_fun`**: Access the original uncompiled function

In [19]:
@jit(static_argnums=1)
def power_function(x: jax.Array, n: int) -> jax.Array:
    """Compute x^n using repeated multiplication."""
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

# Pre-compile for specific cases
print("Manual compilation:")
print("=" * 50)

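test_x = jnp.array([2.0])  # same shape as the inputs used below, so the cached entries are reused

# Warm the cache by calling once per static value of n
print("Compiling for n=2, n=3, n=4...")
for n in [2, 3, 4]:
    _ = power_function(test_x, n)  # First call per n triggers compilation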

print("\nCache populated. Now computing:")
print(f"  2^2 = {power_function(jnp.array([2.0]), 2)}")
print(f"  3^3 = {power_function(jnp.array([3.0]), 3)}")
print(f"  4^4 = {power_function(jnp.array([4.0]), 4)}")

# Clear cache
print("\nClearing cache...")
power_function.clear_cache()
print("✅ Cache cleared. Next call will recompile.")

# Access original function
print("\nOriginal (non-JIT) function:")
original_result = power_function.origin_fun(jnp.array([5.0]), 2)
print(f"  5^2 = {original_result} (computed without JIT)")

Manual compilation:
Compiling for n=2, n=3, n=4...

Cache populated. Now computing:
  2^2 = [4.]
  3^3 = [27.]
  4^4 = [256.]

Clearing cache...
✅ Cache cleared. Next call will recompile.

Original (non-JIT) function:
  5^2 = [25.] (computed without JIT)
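
The example above fills the cache by calling the function. For reference, plain JAX also exposes an explicit ahead-of-time path through `lower` and `compile` on a jitted function. The sketch below uses that JAX API with a hypothetical `cube` function; it is not a demonstration of BrainState's `compile` method, whose exact behaviour should be checked against the BrainState documentation.

```python
import jax
import jax.numpy as jnp


def cube(x):
    return x * x * x


x = jnp.array([2.0, 3.0, 4.0])

# Stage 1: trace and lower for this input's shape and dtype
lowered = jax.jit(cube).lower(x)

# Stage 2: compile to an executable without running it
compiled = lowered.compile()

# Stage 3: call the executable with inputs of the same shape and dtype
result = compiled(x)
print("Compiled result:", result)
```

Unlike a jitted function, the compiled executable is specialized to one input signature and does not retrace for other shapes.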


## 7. Best Practices and Common Pitfalls

### ✅ When to Use JIT

1. **Computational Functions**: Pure numerical computations benefit most
2. **Training Steps**: Compile forward pass + gradient computation
3. **Inference**: Batch prediction functions
4. **Repeated Calls**: Functions called many times with similar shapes

### ❌ When NOT to Use JIT

1. **I/O Operations**: File reading and printing are Python side effects; under JIT they run only while the function is traced, not on every call
2. **One-Time Operations**: Setup code that runs once
3. **Debugging**: Disable JIT when debugging (errors are clearer)
4. **Dynamic Shapes**: Functions with highly variable input shapes
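
The point about I/O can be seen in a small sketch using plain `jax.jit`: a Python side effect runs once at trace time, while `jax.debug.print` is staged into the compiled program and runs on every call. The names `python_side_effects` and `noisy_double` are illustrative.

```python
import jax
import jax.numpy as jnp

python_side_effects = []  # records Python-level side effects


@jax.jit
def noisy_double(x):
    # Python side effect: runs only while tracing
    python_side_effects.append("traced")
    # Staged print: runs each time the compiled function executes
    jax.debug.print("runtime sum = {s}", s=jnp.sum(x))
    return x * 2


out1 = noisy_double(jnp.ones(3)).block_until_ready()
out2 = noisy_double(jnp.ones(3)).block_until_ready()

print("Python side effects recorded:", len(python_side_effects))
```

The function is called twice, but the Python list receives a single entry because the second call reuses the cached compilation.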

### Common Pitfalls

In [20]:
# ❌ BAD: Using Python control flow with dynamic values
@jit
def bad_example(x):
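    # This tries to use x's VALUE for Python control flow.
    # Under JIT, x is an abstract tracer with no concrete value,
    # so the commented-out branch below would raise an error at trace time.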
    # if x > 0:  # ❌ Don't do this!
    #     return x * 2
    # else:
    #     return x / 2
    pass

# ✅ GOOD: Use jax.lax.cond for dynamic control flow
@jit
def good_example(x):
    return jax.lax.cond(
        jnp.sum(x) > 0,
        lambda x: x * 2,
        lambda x: x / 2,
        x
    )

test = jnp.array([1.0, 2.0, 3.0])
print("Correct conditional execution:")
print(f"  Input: {test}")
print(f"  Output: {good_example(test)}")
print(f"\n✅ Use jax.lax.cond/switch for conditional logic in JIT")

Correct conditional execution:
  Input: [1. 2. 3.]
  Output: [2. 4. 6.]

✅ Use jax.lax.cond/switch for conditional logic in JIT
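
`jax.lax.cond` takes a single scalar predicate, so it selects one branch for the whole array. When the condition should apply per element, `jnp.where` is a common alternative. A minimal sketch with plain `jax.jit` and a hypothetical `elementwise_branch` function:

```python
import jax
import jax.numpy as jnp


@jax.jit
def elementwise_branch(x):
    # Both branches are computed; where() picks per element
    return jnp.where(x > 0, x * 2, x / 2)


out = elementwise_branch(jnp.array([-2.0, 0.0, 3.0]))
print("Element-wise result:", out)
```

Because both branches are evaluated for every element, this form suits cheap expressions; `jax.lax.cond` is the better fit when one branch is expensive and the predicate is a scalar.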


### Debugging JIT-Compiled Code

When debugging, temporarily disable JIT:

In [21]:
@jit
def complex_function(x):
    y = jnp.sin(x)
    z = jnp.exp(y)
    # In non-JIT mode, you can inspect intermediate values
    return z

# Option 1: Use origin_fun for debugging
print("Debugging with origin_fun:")
result = complex_function.origin_fun(jnp.array([1.0, 2.0]))
print(f"  Result: {result}")

# Option 2: Temporarily disable JIT globally
# jax.config.update('jax_disable_jit', True)
# ... debug your code ...
# jax.config.update('jax_disable_jit', False)

print("\n💡 Tip: Use .origin_fun or disable JIT globally when debugging")

Debugging with origin_fun:
  Result: [2.3197768 2.4825778]

💡 Tip: Use .origin_fun or disable JIT globally when debugging
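
JAX also provides a scoped alternative to the global config flag: the `jax.disable_jit()` context manager. The sketch below applies it to a function decorated with plain `jax.jit`; whether BrainState's `jit` honours this context manager is an assumption to verify. The function `branchy` deliberately uses a Python `if` on an array value, which only works when tracing is disabled.

```python
import jax
import jax.numpy as jnp


@jax.jit
def branchy(x):
    # Python `if` on a value: legal only when not tracing
    if jnp.sum(x) > 0:
        return x * 2
    return x / 2


x = jnp.array([1.0, 2.0, 3.0])

# Inside the context, the function runs eagerly with concrete values
with jax.disable_jit():
    debug_result = branchy(x)

# Outside the context, tracing hits the Python `if` and raises.
# JAX's tracer conversion errors are subclasses of TypeError.
try:
    branchy(x)
    jit_ok = True
except TypeError:
    jit_ok = False

print("Result with JIT disabled:", debug_result)
print("Same function under JIT succeeded:", jit_ok)
```

The context manager limits the effect to one block, so there is no flag to remember to switch back on.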


## Summary

This tutorial covered JIT compilation in BrainState:

### Key Concepts

✅ **Basic JIT**: Decorate functions with `@jit` for automatic compilation

✅ **Performance**: Speedups range from modest (about 1.3x in the Section 2 benchmark) to 10-100x for workloads dominated by many small operations

✅ **Stateful Modules**: BrainState handles mutable state seamlessly

✅ **Static Arguments**: Use `static_argnums`/`static_argnames` for compile-time specialization

✅ **Advanced Control**: Manual compilation, cache management, debugging tools

### Quick Reference

```python
# Basic usage
@jit
def my_function(x):
    return x * 2

# Static arguments
@jit(static_argnums=1)
def my_function(x, n):
    return x ** n

# Named static arguments
@jit(static_argnames=['mode', 'training'])
def my_function(x, mode='train', training=True):
    ...

# Manual control
my_function.compile(*args)  # Pre-compile for these inputs
my_function.clear_cache()  # Clear cache
my_function.origin_fun(*args)  # Call without JIT
```

### Best Practices

1. ⚡ **JIT training steps** for maximum performance
2. 🎯 **Use static arguments** for configuration flags
3. 🔍 **Disable JIT when debugging** for clearer error messages
4. 📊 **Profile before optimizing** to identify bottlenecks
5. 🧪 **Test with and without JIT** to ensure correctness

### Next Steps

- **Vectorization (vmap)**: Automatically batch operations
- **Automatic Differentiation (grad)**: Compute gradients
- **Parallelization**: Distribute computation across devices
- **Checkpointing**: Trade compute for memory in deep networks