# Tutorial 25: Performance Optimization Guide

In this tutorial, we'll explore techniques to optimize BrainState models for maximum performance.

## Learning Objectives

By the end of this tutorial, you will:
- Profile and analyze model performance
- Optimize memory usage
- Speed up computation with JAX transformations
- Implement parallel processing strategies
- Understand JIT compilation best practices
- Choose batch sizes by benchmarking throughput
- Apply advanced optimization techniques

## Introduction

Performance optimization is crucial for:
- **Training speed**: Faster iterations and experimentation
- **Memory efficiency**: Handling larger models and datasets
- **Inference latency**: Real-time applications
- **Scalability**: Deploying to production systems

BrainState and JAX provide powerful optimization tools through functional transformations such as `jit`, `vmap`, and `grad`.

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import time
from typing import Dict, Callable, Any, Tuple
from functools import partial

bst.random.seed(42)

## 1. Performance Profiling

### 1.1 Basic Timing

In [None]:
class Timer:
    """Simple context manager for timing code blocks."""
    
    def __init__(self, name: str = "Block"):
        self.name = name
        self.start_time = None
        self.elapsed = None
    
    def __enter__(self):
        # perf_counter is monotonic and has higher resolution than time.time
        self.start_time = time.perf_counter()
        return self
    
    def __exit__(self, *args):
        self.elapsed = time.perf_counter() - self.start_time
        print(f"{self.name}: {self.elapsed:.4f} seconds")

# Example usage
def expensive_operation():
    x = bst.random.randn(1000, 1000)
    return jnp.linalg.svd(x)

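with Timer("SVD computation"):
    # JAX dispatches work asynchronously; block so the timer covers the computation
    result = jax.block_until_ready(expensive_operation())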

### 1.2 Benchmarking Functions

In [None]:
def benchmark(func: Callable, *args, num_runs: int = 100, warmup: int = 10, **kwargs) -> Dict[str, float]:
    """
    Benchmark a function.
    
    Args:
        func: Function to benchmark
        *args: Positional arguments
        num_runs: Number of runs for benchmarking
        warmup: Number of warmup runs
        **kwargs: Keyword arguments
        
    Returns:
        Dictionary with timing statistics
    """
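    # Warmup (also triggers any JIT compilation)
    for _ in range(warmup):
        jax.block_until_ready(func(*args, **kwargs))
    
    # Benchmark; block on the result because JAX dispatches asynchronously
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        jax.block_until_ready(func(*args, **kwargs))
        times.append(time.perf_counter() - start)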
    
    times = np.array(times)
    return {
        'mean': np.mean(times),
        'std': np.std(times),
        'min': np.min(times),
        'max': np.max(times),
        'median': np.median(times),
        'p95': np.percentile(times, 95),
        'p99': np.percentile(times, 99),
    }

# Example: Compare matrix multiplication implementations
def matmul_numpy(a, b):
    return np.matmul(a, b)

def matmul_jax(a, b):
    return jnp.matmul(a, b)

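# Use float32 on both sides so the comparison is like-for-like
# (JAX defaults to float32 unless x64 is enabled)
a_np = np.random.randn(500, 500).astype(np.float32)
b_np = np.random.randn(500, 500).astype(np.float32)
a_jax = jnp.array(a_np)
b_jax = jnp.array(b_np)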

print("Matrix Multiplication Benchmark (500x500)")
print("=" * 60)

numpy_stats = benchmark(matmul_numpy, a_np, b_np, num_runs=50)
print(f"NumPy:  {numpy_stats['mean']*1000:.2f} ± {numpy_stats['std']*1000:.2f} ms")

jax_stats = benchmark(matmul_jax, a_jax, b_jax, num_runs=50)
print(f"JAX:    {jax_stats['mean']*1000:.2f} ± {jax_stats['std']*1000:.2f} ms")

speedup = numpy_stats['mean'] / jax_stats['mean']
print(f"\nSpeedup: {speedup:.2f}x")
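
The warmup runs in `benchmark` hide the one-off cost of tracing and compilation. As a minimal sketch, the snippet below reports that cost separately from steady-state latency. The helper `time_call` and the function `scaled_matmul` are hypothetical names introduced only for this illustration.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args):
    """Return wall-clock seconds for one call, waiting for the result."""
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    return time.perf_counter() - start


@jax.jit
def scaled_matmul(a, b):
    return jnp.tanh(a @ b) * 0.5


a = jnp.ones((256, 256))
b = jnp.ones((256, 256))

# The first call includes tracing and compilation
first_call = time_call(scaled_matmul, a, b)

# Later calls reuse the cached executable
later_calls = [time_call(scaled_matmul, a, b) for _ in range(20)]
steady_state = min(later_calls)

print(f"first call:   {first_call * 1e3:.2f} ms")
print(f"steady state: {steady_state * 1e3:.2f} ms")
```

Reporting both numbers matters when a function is called only a few times, since compilation can then dominate the total cost.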

### 1.3 Profiling Model Forward Pass

In [None]:
class SimpleModel(bst.graph.Node):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = bst.nn.Linear(input_dim, hidden_dim)
        self.fc2 = bst.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = bst.nn.Linear(hidden_dim, output_dim)
    
    def __call__(self, x):
        x = jax.nn.relu(self.fc1(x))
        x = jax.nn.relu(self.fc2(x))
        return self.fc3(x)

# Create model
model = SimpleModel(784, 512, 10)
x = bst.random.randn(128, 784)
_ = model(x)  # Initialize

# Profile different configurations
print("Model Forward Pass Profiling")
print("=" * 60)

batch_sizes = [1, 16, 32, 64, 128, 256]
results = []

for bs in batch_sizes:
    x_test = bst.random.randn(bs, 784)
    stats = benchmark(model, x_test, num_runs=50)
    results.append(stats['mean'])
    throughput = bs / stats['mean']
    print(f"Batch size {bs:3d}: {stats['mean']*1000:6.2f} ms, "
          f"Throughput: {throughput:8.1f} samples/sec")

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(batch_sizes, [r*1000 for r in results], 'o-', linewidth=2)
ax1.set_xlabel('Batch Size')
ax1.set_ylabel('Latency (ms)')
ax1.set_title('Forward Pass Latency')
ax1.grid(True, alpha=0.3)

throughputs = [bs/t for bs, t in zip(batch_sizes, results)]
ax2.plot(batch_sizes, throughputs, 's-', linewidth=2, color='green')
ax2.set_xlabel('Batch Size')
ax2.set_ylabel('Throughput (samples/sec)')
ax2.set_title('Throughput vs Batch Size')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
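
Latency alone does not say how close the model is to the hardware limit. The sketch below converts a measured latency into achieved GFLOP/s, using the common approximation of 2 FLOPs per weight per sample for a dense layer. The helper names are hypothetical, and the 1 ms latency is a made-up placeholder, not a measurement.

```python
def mlp_flops_per_sample(layer_dims):
    """Approximate forward-pass FLOPs for one sample of a dense MLP.

    Counts 2 * d_in * d_out per Linear layer (one multiply and one add per
    weight); biases and activations are ignored as lower-order terms.
    """
    return sum(
        2 * d_in * d_out
        for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
    )


def achieved_gflops(layer_dims, batch_size, latency_seconds):
    """Convert a measured forward-pass latency into GFLOP/s."""
    total_flops = mlp_flops_per_sample(layer_dims) * batch_size
    return total_flops / latency_seconds / 1e9


dims = [784, 512, 512, 10]  # layer widths of SimpleModel(784, 512, 10)
flops = mlp_flops_per_sample(dims)
print(f"{flops:,} FLOPs per sample")

# Placeholder latency: a batch of 128 processed in 1 ms
rate = achieved_gflops(dims, batch_size=128, latency_seconds=1e-3)
print(f"{rate:.1f} GFLOP/s")
```

Substituting `stats['mean']` from the benchmark for the placeholder gives a figure that can be compared against the device's peak rate.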

## 2. JIT Compilation Optimization

### 2.1 JIT vs Non-JIT Comparison

In [None]:
print("JIT Compilation Impact")
print("=" * 60)

# Create a moderately complex function
def complex_computation(x):
    for _ in range(10):
        x = jnp.tanh(x)
        x = jnp.matmul(x, x.T)
        x = jnp.maximum(x, 0)
    return jnp.sum(x)

# JIT version
complex_computation_jit = jax.jit(complex_computation)

# Test data
x_test = bst.random.randn(100, 100)

# Benchmark
print("\nBenchmarking (100x100 matrix):")
no_jit_stats = benchmark(complex_computation, x_test, num_runs=20)
print(f"Without JIT: {no_jit_stats['mean']*1000:.2f} ± {no_jit_stats['std']*1000:.2f} ms")

jit_stats = benchmark(complex_computation_jit, x_test, num_runs=20)
print(f"With JIT:    {jit_stats['mean']*1000:.2f} ± {jit_stats['std']*1000:.2f} ms")

speedup = no_jit_stats['mean'] / jit_stats['mean']
print(f"\nSpeedup: {speedup:.1f}x")
print(f"\n⚡ JIT compilation can provide {speedup:.1f}x speedup!")
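
A jitted function is traced once per combination of input shapes and dtypes, and the compiled result is cached. The following sketch makes this visible with a Python-side counter, which only runs while JAX traces the function. The function `normalize` and the counter are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def normalize(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return (x - x.mean()) / (x.std() + 1e-6)


normalize(jnp.ones((8, 4)))
normalize(jnp.zeros((8, 4)))   # same shape and dtype: cache hit, no retrace
after_same_shape = trace_count

normalize(jnp.ones((16, 4)))   # new shape: traced and compiled again
after_new_shape = trace_count

print(after_same_shape, after_new_shape)
```

Keeping input shapes stable (for example by padding variable-length batches) is therefore one way to avoid repeated compilation.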

### 2.2 Static vs Dynamic Arguments

In [None]:
print("Static vs Dynamic Arguments in JIT")
print("=" * 60)

# Function with both static and dynamic args
def power_sum(x, n):
    """Compute sum of x raised to power n."""
    return jnp.sum(x ** n)

# Dynamic n: traced as an abstract scalar, so one compilation serves every n,
# but n cannot drive Python control flow inside the function
power_sum_jit_dynamic = jax.jit(power_sum)

# Static n: baked into the compiled code, so each new value of n recompiles
power_sum_jit_static = jax.jit(power_sum, static_argnums=(1,))

x = bst.random.randn(1000, 1000)

print("\nn as dynamic (compiles on the first call only):")
for n in [2, 3, 4]:
    with Timer(f"  n={n}"):
        _ = power_sum_jit_dynamic(x, n).block_until_ready()

print("\nn as static (compiles once per unique n, then caches):")
for n in [2, 3, 4]:
    with Timer(f"  n={n}"):
        _ = power_sum_jit_static(x, n).block_until_ready()

print("\n✓ Use static_argnums for arguments that determine control flow")
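
Timing alone can make it hard to see when recompilation happens. As a small sketch, the code below counts traces directly: a traced (dynamic) Python integer reuses one compilation, while a static argument triggers a new trace for each distinct value. The factory `make_power_sum` and the `traces` dictionary are hypothetical helpers for this illustration.

```python
import jax
import jax.numpy as jnp

traces = {"dynamic": 0, "static": 0}


def make_power_sum(label):
    def power_sum(x, n):
        traces[label] += 1  # incremented only while JAX traces the function
        return jnp.sum(x ** n)
    return power_sum


dynamic_fn = jax.jit(make_power_sum("dynamic"))
static_fn = jax.jit(make_power_sum("static"), static_argnums=(1,))

x = jnp.linspace(0.0, 1.0, 16)
for n in [2, 3, 4, 2]:  # the repeated 2 hits the cache in both versions
    dynamic_fn(x, n)
    static_fn(x, n)

print(traces)
```

Static arguments are worth their recompilation cost when the value changes rarely or must control Python-level branching or shapes.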

### 2.3 JIT Compilation Best Practices

In [None]:
print("JIT Best Practices")
print("=" * 60)

# 1. JIT the outermost function
print("\n1. JIT the outermost function:")

# ❌ BAD: JIT inner functions (they are redefined, and so re-traced, on every call)
def bad_approach(x):
    @jax.jit
    def inner1(y):
        return jnp.sin(y)
    
    @jax.jit
    def inner2(y):
        return jnp.cos(y)
    
    return inner1(x) + inner2(x)

# ✓ GOOD: JIT the outer function
@jax.jit
def good_approach(x):
    def inner1(y):
        return jnp.sin(y)
    
    def inner2(y):
        return jnp.cos(y)
    
    return inner1(x) + inner2(x)

x = bst.random.randn(1000)
bad_stats = benchmark(bad_approach, x, num_runs=50)
good_stats = benchmark(good_approach, x, num_runs=50)

print(f"Bad approach:  {bad_stats['mean']*1000:.3f} ms")
print(f"Good approach: {good_stats['mean']*1000:.3f} ms")
print(f"Speedup: {bad_stats['mean']/good_stats['mean']:.1f}x")

# 2. Avoid Python control flow on traced values inside JIT
print("\n2. Avoid Python control flow on traced values (use JAX control flow):")
print("❌ BAD: if/while whose condition depends on array values")
print("✓ GOOD: jax.lax.cond, jax.lax.fori_loop, jax.lax.while_loop")

# 3. Minimize data transfer
print("\n3. Minimize host-device data transfer:")
print("✓ Keep data on device as long as possible")
print("✓ Use jnp arrays, not numpy arrays inside JIT")
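
The second practice above is easier to follow with a concrete case. The sketch below shows `jax.lax.cond` and `jax.lax.fori_loop` standing in for a Python `if` and `for` whose condition or bound depends on a traced value. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def flip_if_negative_mean(x):
    # A Python `if x.mean() < 0:` would fail here because the condition is traced
    return jax.lax.cond(x.mean() < 0, lambda v: -v, lambda v: v, x)


@jax.jit
def sum_below(n):
    # n is traced, so `for i in range(n)` is not possible; fori_loop is
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))


flipped = flip_if_negative_mean(jnp.array([-1.0, -2.0, 1.0]))
total = sum_below(5)  # 0 + 1 + 2 + 3 + 4

print(flipped, total)
```

Python control flow remains fine when it depends only on static values such as shapes or arguments marked with `static_argnums`.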

## 3. Memory Optimization

### 3.1 Memory Profiling

In [None]:
def estimate_model_memory(model: bst.graph.Node) -> Dict[str, float]:
    """
    Estimate memory usage of a model.
    
    Returns:
        Dictionary with memory estimates in MB
    """
    params = model.states(bst.ParamState)
    
    # Calculate parameter memory
    param_memory = sum(p.value.nbytes for p in params.values())
    
    # Calculate state memory (non-param states)
    all_states = model.states()
    state_memory = sum(
        s.value.nbytes for s in all_states.values()
        if not isinstance(s, bst.ParamState)
    )
    
    total_memory = param_memory + state_memory
    
    return {
        'parameters_mb': param_memory / 1024 / 1024,
        'states_mb': state_memory / 1024 / 1024,
        'total_mb': total_memory / 1024 / 1024,
        'num_parameters': sum(p.value.size for p in params.values())
    }

# Analyze model memory
model = SimpleModel(784, 512, 10)
x = bst.random.randn(1, 784)
_ = model(x)

mem_info = estimate_model_memory(model)
print("Model Memory Usage")
print("=" * 60)
print(f"Parameters:        {mem_info['parameters_mb']:.2f} MB")
print(f"States:            {mem_info['states_mb']:.2f} MB")
print(f"Total:             {mem_info['total_mb']:.2f} MB")
print(f"Number of params:  {mem_info['num_parameters']:,}")
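
The same accounting works for any pytree of arrays, which can be handy for optimizer state or activations. Below is a minimal sketch that assumes the parameters live in a plain nested dict rather than in BrainState states; `pytree_memory` and the dict layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def pytree_memory(tree):
    """Return (element count, size in MB) over all array leaves of a pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    num_elements = sum(leaf.size for leaf in leaves)
    num_bytes = sum(leaf.nbytes for leaf in leaves)
    return num_elements, num_bytes / 1024 / 1024


# Stand-in parameters with the same shapes as SimpleModel(784, 512, 10)
params = {
    "fc1": {"w": jnp.zeros((784, 512)), "b": jnp.zeros((512,))},
    "fc2": {"w": jnp.zeros((512, 512)), "b": jnp.zeros((512,))},
    "fc3": {"w": jnp.zeros((512, 10)), "b": jnp.zeros((10,))},
}

count, size_mb = pytree_memory(params)
print(f"{count:,} parameters, {size_mb:.2f} MB")
```

Note that this counts stored arrays only; activations and gradients created during training add to the peak footprint.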

### 3.2 Gradient Checkpointing

In [None]:
print("Gradient Checkpointing for Memory Efficiency")
print("=" * 60)

# Without checkpointing: stores all intermediate activations
def deep_network_no_checkpoint(x, weights):
    """Deep network without checkpointing."""
    for w in weights:
        x = jnp.tanh(jnp.matmul(x, w))
    return jnp.sum(x)

# With checkpointing: recomputes activations during backward pass
@jax.checkpoint
def checkpointed_layer(y, w):
    return jnp.tanh(jnp.matmul(y, w))

def deep_network_with_checkpoint(x, weights):
    """Deep network with checkpointing."""
    for w in weights:
        # Activations inside the layer are recomputed in the backward pass
        x = checkpointed_layer(x, w)
    return jnp.sum(x)

# Create deep network (many layers)
num_layers = 50
layer_size = 100
weights = [bst.random.randn(layer_size, layer_size) * 0.01 for _ in range(num_layers)]
x_input = bst.random.randn(10, layer_size)

print(f"\nDeep network: {num_layers} layers, {layer_size}x{layer_size} each")
print("\nWithout checkpointing:")
print("  Memory: High (stores all activations)")
print("  Speed: Fast (forward and backward)")

print("\nWith checkpointing:")
print("  Memory: Low (recomputes activations)")
print("  Speed: Slower (recomputation overhead)")

# Benchmark gradient computation
grad_no_cp = jax.grad(lambda x: deep_network_no_checkpoint(x, weights))
grad_with_cp = jax.grad(lambda x: deep_network_with_checkpoint(x, weights))

print("\nGradient computation time:")
stats_no_cp = benchmark(grad_no_cp, x_input, num_runs=10)
print(f"  No checkpoint:   {stats_no_cp['mean']*1000:.2f} ms")

stats_cp = benchmark(grad_with_cp, x_input, num_runs=10)
print(f"  With checkpoint: {stats_cp['mean']*1000:.2f} ms")

print(f"\nTradeoff: {stats_cp['mean']/stats_no_cp['mean']:.1f}x slower, but uses less memory")
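
Checkpointing should change only memory and runtime, not the gradients themselves. The sketch below checks this on a small network with an assumed size of 8 layers of width 32; `layer` and `network` are illustrative helpers.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    return jnp.tanh(x @ w)


def network(x, weights, layer_fn):
    for w in weights:
        x = layer_fn(x, w)
    return jnp.sum(x)


keys = jax.random.split(jax.random.PRNGKey(0), 9)
weights = [0.5 * jax.random.normal(k, (32, 32)) for k in keys[:8]]
x = jax.random.normal(keys[8], (4, 32))

# Same network, once with plain layers and once with rematerialized layers
grad_plain = jax.jit(jax.grad(lambda x: network(x, weights, layer)))
grad_ckpt = jax.jit(
    jax.grad(lambda x: network(x, weights, jax.checkpoint(layer)))
)

g_plain = grad_plain(x)
g_ckpt = grad_ckpt(x)
max_diff = float(jnp.max(jnp.abs(g_plain - g_ckpt)))
print(f"max gradient difference: {max_diff:.2e}")
```

A check like this is a cheap safeguard before enabling checkpointing in a larger training loop.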

### 3.3 Mixed Precision Training

In [None]:
print("Mixed Precision Training")
print("=" * 60)

# Full precision (float32)
def model_fp32(x, w):
    return jnp.matmul(x.astype(jnp.float32), w.astype(jnp.float32))

# Half precision (float16): halves memory; faster only on hardware with native fp16 support
def model_fp16(x, w):
    result = jnp.matmul(x.astype(jnp.float16), w.astype(jnp.float16))
    return result.astype(jnp.float32)  # Cast back to fp32

# Test data
x = bst.random.randn(1000, 1000)
w = bst.random.randn(1000, 1000)

# Memory comparison
fp32_memory = x.astype(jnp.float32).nbytes + w.astype(jnp.float32).nbytes
fp16_memory = x.astype(jnp.float16).nbytes + w.astype(jnp.float16).nbytes

print("\nMemory Usage:")
print(f"FP32: {fp32_memory / 1024 / 1024:.2f} MB")
print(f"FP16: {fp16_memory / 1024 / 1024:.2f} MB")
print(f"Memory savings: {(1 - fp16_memory/fp32_memory)*100:.1f}%")

# Speed comparison
print("\nSpeed Comparison:")
fp32_stats = benchmark(model_fp32, x, w, num_runs=20)
print(f"FP32: {fp32_stats['mean']*1000:.2f} ms")

fp16_stats = benchmark(model_fp16, x, w, num_runs=20)
print(f"FP16: {fp16_stats['mean']*1000:.2f} ms")
print(f"Speedup: {fp32_stats['mean']/fp16_stats['mean']:.2f}x")

# Accuracy comparison
result_fp32 = model_fp32(x, w)
result_fp16 = model_fp16(x, w)
error = jnp.abs(result_fp32 - result_fp16).max()
print(f"\nMax error: {error:.2e}")
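
The maximum absolute error depends on the scale of the outputs, so a relative measure is often easier to interpret. As a sketch, the code below computes a Frobenius-norm relative error between the float32 and float16 results; `relative_error` is a hypothetical helper, and the 256x256 size is chosen only to keep the example quick.

```python
import jax
import jax.numpy as jnp


def relative_error(reference, approx):
    """Frobenius-norm relative error of `approx` with respect to `reference`."""
    diff_norm = jnp.linalg.norm(approx - reference)
    return float(diff_norm / jnp.linalg.norm(reference))


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (256, 256), dtype=jnp.float32)
w = jax.random.normal(key_w, (256, 256), dtype=jnp.float32)

reference = x @ w
half = (x.astype(jnp.float16) @ w.astype(jnp.float16)).astype(jnp.float32)

rel_err = relative_error(reference, half)
print(f"relative error of fp16 matmul: {rel_err:.2e}")
```

Whether a given relative error is acceptable depends on the model; comparing validation metrics between precisions is the more reliable test.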

## 4. Vectorization with vmap

### 4.1 vmap vs Loop Performance

In [None]:
print("Vectorization with vmap")
print("=" * 60)

# Function that operates on single example
def process_single(x, w):
    return jnp.tanh(jnp.dot(x, w))

# Process batch with Python loop
def process_batch_loop(batch, w):
    results = []
    for x in batch:
        results.append(process_single(x, w))
    return jnp.stack(results)

# Process batch with vmap
process_batch_vmap = jax.vmap(process_single, in_axes=(0, None))

# Test data
batch = bst.random.randn(100, 128)
w = bst.random.randn(128, 64)

print("\nProcessing 100 samples:")

loop_stats = benchmark(process_batch_loop, batch, w, num_runs=50)
print(f"Loop:  {loop_stats['mean']*1000:.2f} ms")

vmap_stats = benchmark(process_batch_vmap, batch, w, num_runs=50)
print(f"vmap:  {vmap_stats['mean']*1000:.2f} ms")

speedup = loop_stats['mean'] / vmap_stats['mean']
print(f"\nSpeedup: {speedup:.1f}x")
print(f"\n⚡ vmap provides automatic vectorization!")
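
`vmap` composes with `jit`, and its result should match a hand-written batched formulation. A short sketch of both points follows, with random test data generated through `jax.random` instead of `bst.random` to keep it self-contained.

```python
import jax
import jax.numpy as jnp


def process_single(x, w):
    return jnp.tanh(jnp.dot(x, w))


# Map over axis 0 of the batch, share w across samples, then compile
batched = jax.jit(jax.vmap(process_single, in_axes=(0, None)))

key_b, key_w = jax.random.split(jax.random.PRNGKey(0))
batch = jax.random.normal(key_b, (100, 128))
w = jax.random.normal(key_w, (128, 64))

out_vmap = batched(batch, w)
out_direct = jnp.tanh(batch @ w)  # hand-batched equivalent

print(out_vmap.shape, bool(jnp.allclose(out_vmap, out_direct, atol=1e-4)))
```

Writing the per-sample function and letting `vmap` add the batch axis tends to be less error-prone than tracking batch dimensions by hand.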

### 4.2 Nested vmap for Batch Processing

In [None]:
print("Nested vmap Example")
print("=" * 60)

# Function for single pair of vectors
def pairwise_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

# Compute all pairwise distances
# Inner vmap maps over the rows of y; outer vmap maps over the rows of x
pairwise_distances = jax.vmap(
    jax.vmap(pairwise_distance, in_axes=(None, 0)),
    in_axes=(0, None)
)

# Test
X = bst.random.randn(50, 10)  # 50 points in 10D
Y = bst.random.randn(30, 10)  # 30 points in 10D

distances = pairwise_distances(X, Y)
print(f"Input shapes: X={X.shape}, Y={Y.shape}")
print(f"Distance matrix shape: {distances.shape}")
assert distances.shape == (50, 30), distances.shape
print("Shape matches the expected (50, 30) ✓")

# Benchmark
stats = benchmark(pairwise_distances, X, Y, num_runs=50)
print(f"\nComputed {distances.size} distances in {stats['mean']*1000:.2f} ms")
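
For comparison, the same distance matrix can be written with explicit broadcasting. The sketch below checks that the two formulations agree; `pairwise_broadcast` is an illustrative name, and the data comes from `jax.random` to keep the snippet standalone.

```python
import jax
import jax.numpy as jnp


def pairwise_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))


nested = jax.vmap(
    jax.vmap(pairwise_distance, in_axes=(None, 0)),  # over rows of Y
    in_axes=(0, None),                               # over rows of X
)


def pairwise_broadcast(X, Y):
    # (50, 1, 10) - (1, 30, 10) broadcasts to (50, 30, 10)
    diff = X[:, None, :] - Y[None, :, :]
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (50, 10))
Y = jax.random.normal(key_y, (30, 10))

d_nested = nested(X, Y)
d_broadcast = pairwise_broadcast(X, Y)
print(d_nested.shape, bool(jnp.allclose(d_nested, d_broadcast, atol=1e-5)))
```

The nested `vmap` version keeps the scalar definition readable, while the broadcast version makes the intermediate `(50, 30, 10)` array explicit.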

## 5. Parallel Processing Strategies

### 5.1 pmap for Multi-Device Parallelism

In [None]:
print("Parallel Processing with pmap")
print("=" * 60)

# Check available devices
devices = jax.devices()
print(f"\nAvailable devices: {len(devices)}")
for i, device in enumerate(devices):
    print(f"  Device {i}: {device}")

if len(devices) > 1:
    print("\nUsing pmap for multi-device parallelism:")
    
    # Function to process on single device
    def train_step(batch):
        return jnp.sum(batch ** 2)
    
    # Parallelize across devices
    train_step_pmap = jax.pmap(train_step)
    
    # Create data sharded across devices
    num_devices = len(devices)
    batch_per_device = 32
    data = bst.random.randn(num_devices, batch_per_device, 100)
    
    # Process in parallel
    results = train_step_pmap(data)
    print(f"Results shape: {results.shape}")
    print(f"One result per device: {len(results)} results")
else:
    print("\nSingle device detected; pmap would run on that one device with no parallelism.")
    print("With multiple GPUs/TPUs, pmap is mainly used for:")
    print("  - Data parallelism (each device processes its own shard of the batch)")
    print("  - Cross-device reductions via collectives such as jax.lax.psum")
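
In data-parallel training the per-device results usually need to be combined. The following sketch, which also runs on a single device, shards the leading axis across local devices and sums the partial results with `jax.lax.psum`. The axis name `"devices"` and the function name are arbitrary choices for this example.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
per_device = 32


def shard_sum_of_squares(shard):
    local = jnp.sum(shard ** 2)
    # psum adds the per-device partial sums; every device receives the total
    return jax.lax.psum(local, axis_name="devices")


parallel_sum = jax.pmap(shard_sum_of_squares, axis_name="devices")

# The leading axis must equal the number of devices taking part
data = jnp.arange(n_dev * per_device * 4, dtype=jnp.float32)
data = data.reshape(n_dev, per_device, 4) / 100.0

totals = parallel_sum(data)      # shape (n_dev,), same value on each device
expected = jnp.sum(data ** 2)    # single-device reference
print(totals, expected)
```

The same pattern, with gradients in place of the sum of squares, is the core of a data-parallel training step.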

### 5.2 Batch Size Optimization

In [None]:
print("Finding Optimal Batch Size")
print("=" * 60)

model = SimpleModel(784, 256, 10)
x_init = bst.random.randn(1, 784)
_ = model(x_init)

# Test different batch sizes
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
throughputs = []
latencies = []

print("\nBatch Size Analysis:")
print(f"{'Batch Size':<12} {'Latency (ms)':<15} {'Throughput':<20} {'Efficiency'}")
print("-" * 70)

for bs in batch_sizes:
    x = bst.random.randn(bs, 784)
    stats = benchmark(model, x, num_runs=30, warmup=5)
    
    latency = stats['mean'] * 1000  # ms
    throughput = bs / stats['mean']  # samples/sec
    efficiency = throughput / bs  # batches per second (equals 1 / latency in seconds)
    
    latencies.append(latency)
    throughputs.append(throughput)
    
    print(f"{bs:<12} {latency:<15.2f} {throughput:<20.1f} {efficiency:.3f}")

# Find optimal batch size (highest throughput)
optimal_idx = np.argmax(throughputs)
optimal_bs = batch_sizes[optimal_idx]
print(f"\n✓ Optimal batch size: {optimal_bs}")

# Visualize
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(batch_sizes, throughputs, 'o-', linewidth=2, markersize=8)
ax.axvline(optimal_bs, color='red', linestyle='--', alpha=0.7, 
           label=f'Optimal: {optimal_bs}')
ax.set_xlabel('Batch Size')
ax.set_ylabel('Throughput (samples/sec)')
ax.set_title('Throughput vs Batch Size')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()
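
Picking the batch size with the highest throughput can select a very large batch for a marginal gain. One alternative, sketched here as an assumption and not as part of the analysis above, is to take the smallest batch size within a tolerance of the peak. The function name, the 5% tolerance, and the throughput numbers are all hypothetical.

```python
def smallest_near_peak(batch_sizes, throughputs, tolerance=0.05):
    """Return the smallest batch size whose throughput is within
    `tolerance` (as a fraction) of the best observed throughput."""
    peak = max(throughputs)
    candidates = [
        bs for bs, tp in zip(batch_sizes, throughputs)
        if tp >= (1.0 - tolerance) * peak
    ]
    return min(candidates)


# Made-up measurements in samples/sec, for illustration only
sizes = [1, 2, 4, 8, 16, 32, 64]
measured = [100.0, 190.0, 360.0, 650.0, 900.0, 980.0, 1000.0]

chosen = smallest_near_peak(sizes, measured)
strict = smallest_near_peak(sizes, measured, tolerance=0.0)
print(chosen, strict)
```

A smaller batch near the peak keeps latency and memory lower while giving up little throughput.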

## 6. Advanced Optimization Techniques

### 6.1 Fused Operations

In [None]:
print("Fused Operations for Better Performance")
print("=" * 60)

x = bst.random.randn(1000, 1000)
w = bst.random.randn(1000, 1000)
b = bst.random.randn(1000)

# Step by step: each operation on its own line
def unfused_linear(x, w, b):
    temp = jnp.matmul(x, w)
    temp = temp + b
    temp = jax.nn.relu(temp)
    return temp

# Single expression: the same operations, written inline
def fused_linear(x, w, b):
    return jax.nn.relu(jnp.matmul(x, w) + b)

# Under JIT both trace to the same computation, and XLA decides what to fuse
unfused_jit = jax.jit(unfused_linear)
fused_jit = jax.jit(fused_linear)

print("\nOperation fusion (with JIT):")
unfused_stats = benchmark(unfused_jit, x, w, b, num_runs=50)
print(f"Unfused: {unfused_stats['mean']*1000:.2f} ms")

fused_stats = benchmark(fused_jit, x, w, b, num_runs=50)
print(f"Fused:   {fused_stats['mean']*1000:.2f} ms")

print("\n✓ JIT compiler automatically fuses operations")
print("  Write clean code; let JIT optimize!")
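
One way to see why the two versions perform alike is to compare what JAX traces for each. The sketch below lists the primitives in each jaxpr with `jax.make_jaxpr`; `primitive_names` is a hypothetical helper, and small shapes are used because only the structure matters.

```python
import jax
import jax.numpy as jnp


def unfused_linear(x, w, b):
    temp = jnp.matmul(x, w)
    temp = temp + b
    temp = jax.nn.relu(temp)
    return temp


def fused_linear(x, w, b):
    return jax.nn.relu(jnp.matmul(x, w) + b)


def primitive_names(fn, *args):
    """List the primitive of each equation in the traced jaxpr."""
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


x = jnp.ones((8, 16))
w = jnp.ones((16, 4))
b = jnp.ones((4,))

names_unfused = primitive_names(unfused_linear, x, w, b)
names_fused = primitive_names(fused_linear, x, w, b)
same_trace = names_unfused == names_fused
print(names_unfused)
print("same primitives:", same_trace)
```

Since both functions hand the compiler the same sequence of operations, any timing difference between them is measurement noise.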

### 6.2 Efficient Scanning with lax.scan

In [None]:
print("Efficient Loops with lax.scan")
print("=" * 60)

# Inefficient under JIT: the Python loop is unrolled at trace time, one copy of the body per step
def rnn_loop_python(inputs, h0, w_h, w_x):
    h = h0
    outputs = []
    for x in inputs:
        h = jnp.tanh(jnp.dot(h, w_h) + jnp.dot(x, w_x))
        outputs.append(h)
    return jnp.stack(outputs), h

# Efficient: lax.scan
def rnn_scan(inputs, h0, w_h, w_x):
    def step(h, x):
        h_new = jnp.tanh(jnp.dot(h, w_h) + jnp.dot(x, w_x))
        return h_new, h_new
    
    final_h, outputs = jax.lax.scan(step, h0, inputs)
    return outputs, final_h

# Test data
seq_length = 100
hidden_size = 128
input_size = 64

inputs = bst.random.randn(seq_length, input_size)
h0 = bst.random.randn(hidden_size)
w_h = bst.random.randn(hidden_size, hidden_size) * 0.01
w_x = bst.random.randn(input_size, hidden_size) * 0.01

print(f"\nProcessing sequence of length {seq_length}:")

# Note: a jitted Python loop is unrolled, so its compile time grows with sequence length
scan_jit = jax.jit(rnn_scan)

# Warmup
_ = scan_jit(inputs, h0, w_h, w_x)

scan_stats = benchmark(scan_jit, inputs, h0, w_h, w_x, num_runs=30)
print(f"lax.scan: {scan_stats['mean']*1000:.2f} ms")

print("\n✓ lax.scan is the idiomatic way to write loops in JAX")
print("  Benefits:")
print("  - JIT-friendly")
print("  - Compiled program size independent of sequence length")
print("  - Fast compilation")
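
The compilation benefit can be observed without timing anything. As a sketch, the code below counts the equations in the traced program for a Python loop and for `lax.scan` at two sequence lengths. The simplified RNN functions return only the final state, and `num_equations` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def rnn_loop(inputs, h0, w_h, w_x):
    h = h0
    for t in range(inputs.shape[0]):  # unrolled at trace time
        h = jnp.tanh(h @ w_h + inputs[t] @ w_x)
    return h


def rnn_scan(inputs, h0, w_h, w_x):
    def step(h, x):
        h = jnp.tanh(h @ w_h + x @ w_x)
        return h, None

    final_h, _ = jax.lax.scan(step, h0, inputs)
    return final_h


def num_equations(fn, seq_length, hidden=8, features=4):
    """Number of top-level equations in the jaxpr traced for `fn`."""
    args = (
        jnp.ones((seq_length, features)),
        jnp.zeros(hidden),
        jnp.eye(hidden),
        jnp.ones((features, hidden)),
    )
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


loop_sizes = [num_equations(rnn_loop, n) for n in (10, 20)]
scan_sizes = [num_equations(rnn_scan, n) for n in (10, 20)]
print("python loop:", loop_sizes)
print("lax.scan:   ", scan_sizes)
```

The loop version grows with the sequence length, while the scan version stays the same size, which is why scan compiles faster on long sequences.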

## 7. Performance Checklist

### Summary of Optimization Techniques

In [None]:
print("Performance Optimization Checklist")
print("=" * 80)

checklist = [
    ("✓", "Use @jax.jit or @bst.transform.jit for performance-critical code"),
    ("✓", "Mark control-flow arguments as static (static_argnums)"),
    ("✓", "Use vmap instead of Python loops for batch processing"),
    ("✓", "Use lax.scan for sequential operations"),
    ("✓", "Profile code to identify bottlenecks"),
    ("✓", "Choose appropriate batch sizes (benchmark different sizes)"),
    ("✓", "Consider mixed precision (FP16) for memory/speed"),
    ("✓", "Use gradient checkpointing for very deep networks"),
    ("✓", "Minimize host-device data transfers"),
    ("✓", "Use pmap for multi-device parallelism when available"),
    ("✓", "Avoid unnecessary data copies (use views when possible)"),
    ("✓", "Fuse operations (let JIT compiler optimize)"),
    ("✓", "Reuse compiled functions (avoid recompilation)"),
    ("✓", "Monitor memory usage and optimize state management"),
]

for check, item in checklist:
    print(f"{check} {item}")

print("\n" + "=" * 80)
print("Performance Tips by Priority:")
print("=" * 80)

# Rough, workload- and hardware-dependent figures; benchmark your own code
priorities = [
    ("High", "JIT compilation", "~5-100x speedup"),
    ("High", "Vectorization (vmap)", "~10-100x speedup"),
    ("High", "Batch size optimization", "~2-5x speedup"),
    ("Medium", "Mixed precision", "~1.5-2x speedup, ~50% memory"),
    ("Medium", "Gradient checkpointing", "~50% memory, ~20% slower"),
    ("Medium", "lax.scan for loops", "~2-10x speedup"),
    ("Low", "Multi-device (pmap)", "up to Nx speedup (N devices)"),
]

print(f"{'Priority':<10} {'Technique':<30} {'Rough impact':<30}")
print("-" * 70)
for priority, technique, impact in priorities:
    print(f"{priority:<10} {technique:<30} {impact:<20}")

## Summary

In this tutorial, we covered:

1. **Performance Profiling**:
   - Timing and benchmarking functions
   - Model memory estimation
   - Identifying bottlenecks

2. **JIT Compilation**:
   - JIT vs non-JIT comparison
   - Static vs dynamic arguments
   - Best practices for JIT

3. **Memory Optimization**:
   - Memory profiling
   - Gradient checkpointing
   - Mixed precision training

4. **Vectorization**:
   - vmap for batch processing
   - Nested vmap
   - Performance gains

5. **Parallel Processing**:
   - pmap for multi-device
   - Batch size optimization
   - Throughput analysis

6. **Advanced Techniques**:
   - Fused operations
   - Efficient scanning with lax.scan
   - Performance checklist

### Key Takeaways:

- **Use JIT** for performance-critical code (often 5-100x faster, depending on the workload)
- **Vectorize with vmap** instead of Python loops
- **Profile before optimizing** to find real bottlenecks
- **Batch size matters** - benchmark to find optimal
- **Trade compute for memory** with checkpointing when needed
- **Use lax.scan** for sequential operations

## Next Steps

- Profile your own models
- Experiment with different optimization techniques
- Monitor performance in production
- Learn advanced JAX transformations

For more information:
- [Thinking in JAX (JAX documentation)](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)
- [BrainState Documentation](https://brainstate.readthedocs.io/)