# Tutorial 13: Advanced Transformations

In this final tutorial on program transformations, we'll explore advanced techniques for optimizing memory, debugging, and monitoring your BrainState programs.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Use gradient checkpointing (remat) to reduce memory usage
- Initialize models abstractly without allocating memory
- Inspect compiled computations with `make_jaxpr`
- Add progress bars to long-running computations
- Understand computational graphs
- Profile and optimize transformation pipelines

## What Are Advanced Transformations?

Beyond the core transformations (JIT, grad, vmap), JAX and BrainState provide specialized tools for:
- Memory optimization
- Debugging and introspection
- User experience improvements
- Performance profiling

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
import time

# Set random seed
bst.random.seed(42)

## 1. Gradient Checkpointing (Rematerialization)

Gradient checkpointing trades computation for memory by recomputing intermediate values during backprop instead of storing them.

### The Memory Problem

In [None]:
# Deep network: autodiff keeps intermediate values for the backward pass
def deep_network(x, n_layers=10):
    """Each layer's activations are stored for gradient computation."""
    for _ in range(n_layers):
        x = jnp.tanh(x)  # tanh output is saved for its derivative, 1 - tanh(x)**2
        x = x @ jnp.eye(x.shape[-1])  # Linear transform (identity weights, for illustration)
    return jnp.sum(x ** 2)

# Gradient function
grad_fn = jax.grad(deep_network)

n_layers = 100
x = bst.random.randn(100, 100)
grads = grad_fn(x, n_layers=n_layers)

activation_bytes = x.size * x.dtype.itemsize  # one (100, 100) activation
print("Without checkpointing: all intermediate activations are stored in memory")
print("Memory usage grows linearly with network depth")
print(f"For {n_layers} layers, storing ~{n_layers * activation_bytes / 1e6:.1f} MB of activations")

### Solution: Gradient Checkpointing

In [None]:
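# jax.checkpoint (formerly jax.remat) must wrap the computation whose internals
# should be recomputed; checkpointing an identity function saves nothing.
def deep_network_checkpointed(x, n_layers=10, checkpoint_every=10):
    """Checkpoint every `checkpoint_every` layers to save memory."""
    def segment(h, n):
        for _ in range(n):
            h = jnp.tanh(h)
            h = h @ jnp.eye(h.shape[-1])
        return h

    # n (layers per segment) is a Python int, so mark it static
    segment_ckpt = jax.checkpoint(segment, static_argnums=(1,))

    # Only each segment's input is stored; its internals are recomputed in backprop
    for start in range(0, n_layers, checkpoint_every):
        x = segment_ckpt(x, min(checkpoint_every, n_layers - start))

    return jnp.sum(x ** 2)

# Gradient with checkpointing
grad_fn_checkpoint = jax.grad(deep_network_checkpointed)
grads_checkpoint = grad_fn_checkpoint(x, n_layers=100)

print("With checkpointing: only the 10 segment inputs are stored")
print("Activations inside a segment are recomputed during backprop")
print("Trade: roughly one extra forward pass, in exchange for keeping ~10 segment inputs")
print("       plus one segment's activations at a time instead of all 100")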
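A quick sanity check, sketched with plain JAX: checkpointing changes what is stored, not what is computed, so gradients with and without `jax.checkpoint` should agree to floating-point tolerance. The per-layer `layer` function and the use of `jax.checkpoint_policies.dots_saveable` (a policy that saves matmul results instead of recomputing them) are illustrative choices, not part of the tutorial's network.

```python
import jax
import jax.numpy as jnp

def layer(x):
    # One tanh + matmul layer, like the tutorial's deep_network
    return jnp.tanh(x) @ jnp.eye(x.shape[-1])

def loss_plain(x, n_layers=20):
    for _ in range(n_layers):
        x = layer(x)
    return jnp.sum(x ** 2)

def loss_remat(x, n_layers=20, policy=None):
    # Each layer's internals are recomputed in the backward pass; `policy`
    # can instead mark some intermediates (here: matmul results) as saveable
    ckpt_layer = jax.checkpoint(layer, policy=policy)
    for _ in range(n_layers):
        x = ckpt_layer(x)
    return jnp.sum(x ** 2)

x = jax.random.normal(jax.random.PRNGKey(0), (64, 64))
g_plain = jax.grad(loss_plain)(x)
g_remat = jax.grad(loss_remat)(x)
g_dots = jax.grad(loss_remat)(x, policy=jax.checkpoint_policies.dots_saveable)

print("max |g_plain - g_remat|:", float(jnp.max(jnp.abs(g_plain - g_remat))))
print("max |g_plain - g_dots| :", float(jnp.max(jnp.abs(g_plain - g_dots))))
```

Policies let you choose a point between "store everything" and "recompute everything" without restructuring the model.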

### Practical Example: Transformer Block with Checkpointing

In [None]:
# Simplified transformer block
class TransformerBlock(bst.graph.Node):
    def __init__(self, dim, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint
        self.W_q = bst.ParamState(bst.random.randn(dim, dim) * 0.1)
        self.W_k = bst.ParamState(bst.random.randn(dim, dim) * 0.1)
        self.W_v = bst.ParamState(bst.random.randn(dim, dim) * 0.1)
        self.W_o = bst.ParamState(bst.random.randn(dim, dim) * 0.1)
        self.W_ff1 = bst.ParamState(bst.random.randn(dim, 4*dim) * 0.1)
        self.W_ff2 = bst.ParamState(bst.random.randn(4*dim, dim) * 0.1)
    
    def attention(self, x):
        """Self-attention mechanism."""
        Q = x @ self.W_q.value
        K = x @ self.W_k.value
        V = x @ self.W_v.value
        
        scores = Q @ K.T / jnp.sqrt(self.dim)
        attn = jax.nn.softmax(scores, axis=-1)
        out = attn @ V
        return out @ self.W_o.value
    
    def feedforward(self, x):
        """Feedforward network."""
        h = jax.nn.relu(x @ self.W_ff1.value)
        return h @ self.W_ff2.value
    
    def __call__(self, x):
        # Attention with residual
        attn_fn = self.attention
        if self.use_checkpoint:
            attn_fn = jax.checkpoint(self.attention)
        x = x + attn_fn(x)
        
        # Feedforward with residual
        ff_fn = self.feedforward
        if self.use_checkpoint:
            ff_fn = jax.checkpoint(self.feedforward)
        x = x + ff_fn(x)
        
        return x

# Compare memory usage (conceptual)
block_normal = TransformerBlock(dim=64, use_checkpoint=False)
block_checkpoint = TransformerBlock(dim=64, use_checkpoint=True)

x = bst.random.randn(32, 64)  # (seq_len, dim)

print("Transformer Block Comparison:")
print("  Without checkpoint: Stores all attention scores and FF activations")
print("  With checkpoint: Recomputes attention and FF during backprop")
print("\nRecommendation: use for deep transformer stacks when activation memory, not compute, is the bottleneck")
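For deep stacks of identical blocks, a common JAX pattern is to stack the per-layer weights and run the layers with `jax.lax.scan`, checkpointing the scan body so that only each layer's input is saved. This sketch uses plain arrays and a hypothetical residual tanh `layer` instead of the `TransformerBlock` above, since combining `jax.lax.scan` with brainstate States is a separate topic.

```python
from functools import partial

import jax
import jax.numpy as jnp

def layer(h, W):
    # Hypothetical residual layer: h + tanh(h @ W)
    return h + jnp.tanh(h @ W)

def forward(Ws, x, remat):
    step = jax.checkpoint(layer) if remat else layer

    def body(h, W):
        return step(h, W), None  # carry the activation; no per-layer output

    h, _ = jax.lax.scan(body, x, Ws)  # iterates over the leading (layer) axis of Ws
    return h

def loss(Ws, x, remat):
    return jnp.mean(forward(Ws, x, remat) ** 2)

n_layers, dim = 24, 32
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
Ws = 0.1 * jax.random.normal(k_w, (n_layers, dim, dim))  # one (dim, dim) matrix per layer
x = jax.random.normal(k_x, (8, dim))

grad_remat = jax.jit(jax.grad(partial(loss, remat=True)))(Ws, x)
grad_plain = jax.jit(jax.grad(partial(loss, remat=False)))(Ws, x)
print(grad_remat.shape, bool(jnp.allclose(grad_remat, grad_plain, atol=1e-5)))
```

Because `scan` traces the body once, compile time also stays flat as the number of layers grows.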

## 2. Abstract Initialization

Use `jax.eval_shape` to trace a computation with abstract inputs (shapes and dtypes only), so no arrays are allocated. This is useful for inspecting shapes and model structure.

In [None]:
# Abstract shapes without concrete values
from jax import ShapeDtypeStruct

def init_model_abstract(input_shape):
    """Trace a small MLP forward pass with abstract shapes only."""
    # Abstract input: a shape and dtype with no data behind it
    abstract_x = ShapeDtypeStruct(input_shape, jnp.float32)

    def forward(x):
        # Weights created inside the traced function are abstract too
        h1 = jax.nn.relu(x @ jnp.zeros((x.shape[-1], 128)))  # First layer
        h2 = jax.nn.relu(h1 @ jnp.zeros((128, 64)))          # Second layer
        out = h2 @ jnp.zeros((64, 10))                      # Output layer
        # Return the arrays themselves: eval_shape maps each to a ShapeDtypeStruct
        return {'h1': h1, 'h2': h2, 'out': out}

    # Get shapes without allocating memory or running the computation
    return jax.eval_shape(forward, abstract_x)

# Get model structure
model_shapes = init_model_abstract((32, 784))
print("Model architecture (shapes only, no memory allocated):")
for name, struct in model_shapes.items():
    print(f"  {name}: {struct.shape} {struct.dtype}")

print("\nUse case: Validate model architecture before training")
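The same idea applies to parameter initialization: passing an initializer to `jax.eval_shape` yields the shape and dtype of every parameter, so the parameter count can be read off before anything is allocated. `init_mlp` is a hypothetical plain-JAX initializer used only for illustration.

```python
import math

import jax
import jax.numpy as jnp

def init_mlp(key, sizes=(784, 128, 64, 10)):
    """Hypothetical initializer: a list of {'W', 'b'} dicts, one per layer."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        params.append({
            "W": jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in),
            "b": jnp.zeros((n_out,)),
        })
    return params

# Only shapes and dtypes are computed; no parameter arrays are materialized
abstract_params = jax.eval_shape(init_mlp, jax.random.PRNGKey(0))

for i, layer in enumerate(abstract_params):
    print(f"layer {i}: W {layer['W'].shape}, b {layer['b'].shape}, dtype {layer['W'].dtype}")

n_params = sum(math.prod(s.shape) for s in jax.tree_util.tree_leaves(abstract_params))
print(f"Total parameters: {n_params:,}")
```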

### Practical: Shape Inference for Complex Models

In [None]:
from jax import lax

def max_pool_2x2(x):
    """2x2 max pooling, stride 2, over the spatial axes of a channels-last (N, H, W, C) array."""
    # jax.nn has no max_pool; lax.reduce_window implements pooling directly
    return lax.reduce_window(x, -jnp.inf, lax.max, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

# Model whose flattened feature size depends on the input resolution
class DynamicNet(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.conv1 = bst.nn.Conv2d(3, 16, kernel_size=(3, 3))
        self.conv2 = bst.nn.Conv2d(16, 32, kernel_size=(3, 3))
        # No Linear layer yet: its input size depends on the input resolution

    def __call__(self, x):
        x = jax.nn.relu(self.conv1(x))
        x = max_pool_2x2(x)
        x = jax.nn.relu(self.conv2(x))
        x = max_pool_2x2(x)
        x = x.reshape(x.shape[0], -1)  # Flatten
        return x

model = DynamicNet()

# Find output shapes for different input sizes (channels-last: N, H, W, C)
input_sizes = [(1, 32, 32, 3), (1, 64, 64, 3), (1, 128, 128, 3)]

print("Output shapes for different input sizes:")
for input_size in input_sizes:
    abstract_input = ShapeDtypeStruct(input_size, jnp.float32)
    output_shape = jax.eval_shape(model, abstract_input)
    print(f"  Input {input_size} -> Output {output_shape.shape}")
    print(f"    Flattened size: {output_shape.shape[1]} features")
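The flattened size can also be worked out by hand. Along one spatial axis, a 3x3 convolution maps n to n - 2 under `'VALID'` (unpadded) convolution and keeps n under `'SAME'`, while 2x2/stride-2 pooling maps n to floor((n - 2) / 2) + 1. The sketch below checks that arithmetic against `jax.eval_shape` using plain `jax.lax` ops on channels-last arrays; it is independent of how `bst.nn.Conv2d` is configured, and `conv_out`, `flat_size` and `features` are illustrative helpers.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

def conv_out(n, k, stride=1, padding="VALID"):
    """Output length along one spatial axis of a conv or pooling window."""
    if padding == "VALID":
        return (n - k) // stride + 1
    if padding == "SAME":
        return -(-n // stride)  # ceil(n / stride)
    raise ValueError(f"unknown padding {padding!r}")

def flat_size(n, padding, channels=32):
    """Flattened features after conv3x3 -> pool2 -> conv3x3 -> pool2 on an n x n input."""
    n = conv_out(n, 3, padding=padding)
    n = conv_out(n, 2, stride=2)  # pooling windows are always VALID here
    n = conv_out(n, 3, padding=padding)
    n = conv_out(n, 2, stride=2)
    return n * n * channels

def max_pool_2x2(h):
    return lax.reduce_window(h, -jnp.inf, lax.max, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")

def features(x, w1, w2, padding):
    dn = ("NHWC", "HWIO", "NHWC")  # channels-last input, (kh, kw, in, out) kernels
    x = max_pool_2x2(jax.nn.relu(lax.conv_general_dilated(x, w1, (1, 1), padding, dimension_numbers=dn)))
    x = max_pool_2x2(jax.nn.relu(lax.conv_general_dilated(x, w2, (1, 1), padding, dimension_numbers=dn)))
    return x.reshape(x.shape[0], -1)

x = jax.ShapeDtypeStruct((1, 32, 32, 3), jnp.float32)
w1 = jax.ShapeDtypeStruct((3, 3, 3, 16), jnp.float32)
w2 = jax.ShapeDtypeStruct((3, 3, 16, 32), jnp.float32)

for padding in ("VALID", "SAME"):
    out = jax.eval_shape(partial(features, padding=padding), x, w1, w2)
    print(f"{padding:>5}: traced {out.shape}, formula {flat_size(32, padding)}")
```

With `'VALID'` padding a 32x32 input gives 32 * 6 * 6 = 1152 features, while `'SAME'` gives 2048, so a hard-coded Linear input size only matches one convention.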

## 3. Make Jaxpr: Inspecting Compiled Computations

`jax.make_jaxpr` traces a function and returns its jaxpr (JAX expression), the intermediate representation that JAX's transformations and compiler operate on.

In [None]:
# Simple function
def simple_fn(x, y):
    return x ** 2 + y ** 2

# Get jaxpr (JAX expression)
jaxpr = jax.make_jaxpr(simple_fn)(3.0, 4.0)
print("Jaxpr for x² + y²:")
print(jaxpr)
print("\nJaxpr shows primitive operations JAX will execute")

### Understanding Jaxpr for Debugging

In [None]:
# More complex example
def neural_layer(x, W, b):
    return jax.nn.relu(x @ W + b)

# Create jaxpr
x_ex = jnp.ones((2, 3))
W_ex = jnp.ones((3, 4))
b_ex = jnp.zeros(4)

jaxpr_layer = jax.make_jaxpr(neural_layer)(x_ex, W_ex, b_ex)
print("Jaxpr for ReLU(xW + b):")
print(jaxpr_layer)
print("\nUse cases:")
print("  - Verify computation is what you expect")
print("  - Debug performance issues")
print("  - Understand how transforms affect code")

### Jaxpr with Transformations

In [None]:
# Compare the original function with its vmapped version
def square_sum(x):
    return jnp.sum(x ** 2)

x_single = jnp.array([1.0, 2.0, 3.0])
x_batch = jnp.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])

print("Original function:")
print(jax.make_jaxpr(square_sum)(x_single))

print("\n" + "="*50)
print("With vmap:")
vmapped = jax.vmap(square_sum)
print(jax.make_jaxpr(vmapped)(x_batch))

print("\nNotice: vmap adds a batch dimension (f32[3] -> f32[2,3]) and reduce_sum now runs over axes=(1,)")
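The same inspection works for `jax.grad`: the jaxpr of the gradient function contains the forward computation followed by the backward pass JAX derived automatically. For `sum(x ** 2)` the result should amount to multiplying by 2x, which the printed jaxpr and the numeric gradient in this sketch both reflect.

```python
import jax
import jax.numpy as jnp

def square_sum(x):
    return jnp.sum(x ** 2)

x_single = jnp.array([1.0, 2.0, 3.0])

grad_fn = jax.grad(square_sum)
grad_jaxpr = jax.make_jaxpr(grad_fn)(x_single)
print(grad_jaxpr)          # forward ops, then the derived backward ops
print(grad_fn(x_single))   # gradient of sum(x**2) is 2x
```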

## 4. Progress Bars for Long Computations

Add progress tracking to training loops and long-running operations.

In [None]:
# Manual progress tracking
def train_with_progress(n_epochs):
    """Training loop with progress updates."""
    print("Training Progress:")
    print("[" + " " * 50 + "] 0%", end='\r')
    
    for epoch in range(n_epochs):
        # Simulate training
        time.sleep(0.05)
        
        # Update progress
        progress = (epoch + 1) / n_epochs
        filled = int(50 * progress)
        bar = "#" * filled + " " * (50 - filled)
        print(f"[{bar}] {progress*100:.0f}%", end='\r')
    
    print("\nTraining complete!")

train_with_progress(20)
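The loop above runs in Python, so printing is straightforward. When the whole loop is compiled (for example `jax.lax.fori_loop` under `jax.jit`), a Python `print` runs only once, at trace time. One workaround, sketched here: `jax.debug.callback` calls a host function from inside the compiled loop on every iteration, and that function decides whether to report. `report`, `N_STEPS` and `REPORT_EVERY` are illustrative names, and the toy update is plain gradient descent on `sum(x ** 2)`.

```python
import jax
import jax.numpy as jnp

N_STEPS = 50
REPORT_EVERY = 10
progress_log = []  # host-side record of (step, loss) pairs

def report(step, loss):
    # Runs on the host and receives concrete NumPy values, not tracers
    step = int(step)
    if step % REPORT_EVERY == 0:
        progress_log.append((step, float(loss)))
        print(f"step {step:3d}/{N_STEPS}  loss {float(loss):.6f}")

@jax.jit
def optimize(x):
    def body(i, x):
        x = x - 0.1 * (2.0 * x)  # gradient step on sum(x ** 2)
        jax.debug.callback(report, i, jnp.sum(x ** 2))
        return x
    return jax.lax.fori_loop(0, N_STEPS, body, x)

x_final = optimize(jnp.ones(3))
jax.effects_barrier()  # make sure every pending host callback has run
```

Each callback is a device-to-host round trip, so for very tight loops it may be worth reporting from an outer Python loop over compiled chunks instead.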

### Using tqdm for Better Progress Bars

In [None]:
# Try to import tqdm
try:
    from tqdm import tqdm
    has_tqdm = True
except ImportError:
    has_tqdm = False
    print("tqdm not available. Install with: pip install tqdm")

if has_tqdm:
    # Training with tqdm
    def train_with_tqdm(n_epochs):
        losses = []
        
        with tqdm(total=n_epochs, desc="Training") as pbar:
            for epoch in range(n_epochs):
                # Simulate training
                loss = 1.0 / (epoch + 1)  # Decreasing loss
                losses.append(loss)
                
                # Update progress with metrics
                pbar.set_postfix({'loss': f'{loss:.4f}'})
                pbar.update(1)
                
                time.sleep(0.02)
        
        return losses
    
    losses = train_with_tqdm(30)
    
    # Plot training curve
    plt.figure(figsize=(8, 4))
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("Skipping tqdm example")

## 5. Host Callback for Debugging

Use host callbacks to inspect values during JIT execution.

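In [None]:
# Debug with io_callback, a replacement for the deprecated host_callback API
from jax.experimental import io_callback

def debug_print(x, name):
    """Print value during JIT execution."""
    def _print(x):
        print(f"{name}: {x}")

    # result_shape_dtypes=None: the callback returns nothing.
    # ordered=True keeps the prints in program order.
    io_callback(_print, None, x, ordered=True)
    return x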

@jax.jit
def computation_with_debug(x):
    x = x * 2
    x = debug_print(x, "After multiply")
    
    x = jnp.tanh(x)
    x = debug_print(x, "After tanh")
    
    return jnp.sum(x)

result = computation_with_debug(jnp.array([1.0, 2.0, 3.0]))
print(f"\nFinal result: {result}")
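For plain value printing, `jax.debug.print` is a shorter alternative to a hand-written `io_callback`: it formats its arguments on the host when the compiled code runs, using `str.format`-style placeholders, and `ordered=True` keeps the messages in program order. A minimal sketch of the same computation:

```python
import jax
import jax.numpy as jnp

@jax.jit
def computation_with_debug(x):
    x = x * 2
    # Formatted on the host when the compiled code runs (not at trace time)
    jax.debug.print("After multiply: {x}", x=x, ordered=True)
    x = jnp.tanh(x)
    jax.debug.print("After tanh: {x}", x=x, ordered=True)
    return jnp.sum(x)

result = computation_with_debug(jnp.array([1.0, 2.0, 3.0]))
jax.effects_barrier()  # flush pending prints before continuing
print(f"Final result: {float(result):.6f}")
```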

## 6. Custom Pretty Printing for Models

In [None]:
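from jax import lax

def max_pool_2x2(x):
    """2x2 max pooling, stride 2, over the spatial axes of a channels-last (N, H, W, C) array."""
    # jax.nn has no max_pool; lax.reduce_window implements pooling directly
    return lax.reduce_window(x, -jnp.inf, lax.max, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

# Model summary utility
class SummaryNet(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.conv1 = bst.nn.Conv2d(3, 16, kernel_size=(3, 3))
        self.conv2 = bst.nn.Conv2d(16, 32, kernel_size=(3, 3))
        # 32 * 6 * 6 assumes 32x32 inputs and unpadded ('VALID') convolutions:
        # 32 -> conv 30 -> pool 15 -> conv 13 -> pool 6
        self.linear1 = bst.nn.Linear(32 * 6 * 6, 128)
        self.linear2 = bst.nn.Linear(128, 10)

    def __call__(self, x):
        x = jax.nn.relu(self.conv1(x))
        x = max_pool_2x2(x)
        x = jax.nn.relu(self.conv2(x))
        x = max_pool_2x2(x)
        x = x.reshape(x.shape[0], -1)
        x = jax.nn.relu(self.linear1(x))
        x = self.linear2(x)
        return x

def print_model_summary(model):
    """Print every ParamState of a model with its shape(s) and parameter count."""
    print("=" * 70)
    print(f"Model: {model.__class__.__name__}")
    print("=" * 70)
    print(f"{'Parameter':<20} {'Shape':<25} {'Param #':<15}")
    print("-" * 70)

    # Get all parameters
    total_params = 0
    total_bytes = 0
    param_states = model.states(bst.ParamState)

    for name, state in param_states.items():
        # A ParamState may hold one array or a pytree of arrays (e.g. weight and bias)
        leaves = jax.tree_util.tree_leaves(state.value)
        params = sum(leaf.size for leaf in leaves)
        total_params += params
        total_bytes += sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)
        shapes = jax.tree_util.tree_map(lambda a: a.shape, state.value)
        # str() because state names may be tuples, which reject format specs like :<20
        print(f"{str(name):<20} {str(shapes):<25} {params:<15,}")

    print("=" * 70)
    print(f"Total parameters: {total_params:,}")
    print(f"Total memory: ~{total_bytes / 1024 / 1024:.2f} MB")
    print("=" * 70)

# Test
model = SummaryNet()
# Parameters already exist after __init__; this forward pass only checks that the shapes fit
_ = model(bst.random.randn(1, 32, 32, 3))  # channels-last input (N, H, W, C)
print_model_summary(model)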
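Parameter counting does not depend on brainstate specifically: any pytree of arrays can be measured with `jax.tree_util.tree_leaves`, and using each leaf's `dtype.itemsize` avoids assuming float32. `count_params` is a hypothetical helper; to apply it to the model above you would pass it the parameter values, for example `[s.value for s in model.states(bst.ParamState).values()]`, assuming the states collection behaves like a dict as it does in `print_model_summary`.

```python
import jax
import jax.numpy as jnp

def count_params(tree):
    """Return (number of parameters, bytes) summed over all array leaves of a pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    n_params = sum(leaf.size for leaf in leaves)
    n_bytes = sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)
    return n_params, n_bytes

# Example pytree: nested dicts with mixed dtypes
params = {
    "linear": {"weight": jnp.zeros((10, 5)), "bias": jnp.zeros(5)},
    "embedding": jnp.zeros((4, 4), dtype=jnp.float16),
}
n, nbytes = count_params(params)
print(f"{n:,} parameters, {nbytes:,} bytes")
```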

## 7. Profiling and Performance Analysis

In [None]:
# Simple profiling utility
class Timer:
    """Context manager that prints the wall-clock time of its block."""
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()  # monotonic, high-resolution clock
        return self

    def __exit__(self, *args):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.name}: {self.elapsed*1000:.2f} ms")

# Profile different operations
x = bst.random.randn(1000, 1000)
y = bst.random.randn(1000, 1000)

@jax.jit
def matmul(a, b):
    return a @ b

# The first call includes tracing and XLA compilation
with Timer("Matrix multiplication (JIT, first call incl. compile)"):
    result = matmul(x, y)
    result.block_until_ready()

# Later calls reuse the cached executable
with Timer("Matrix multiplication (JIT, cached)"):
    result = matmul(x, y)
    result.block_until_ready()

# These run op-by-op (no jit); block_until_ready waits for async dispatch to finish
with Timer("Element-wise operations"):
    result = jnp.sin(x) ** 2 + jnp.cos(x) ** 2
    result.block_until_ready()

with Timer("Reduction operations"):
    result = jnp.sum(x, axis=0)
    result.block_until_ready()
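Single timings are noisy. A slightly more robust sketch, under the same principles as `Timer`: warm up first so compilation is excluded, wait on the result with `jax.block_until_ready` (which accepts any pytree of outputs), and report the median over several runs. `benchmark` is an illustrative helper, not a JAX or brainstate API.

```python
import statistics
import time

import jax
import jax.numpy as jnp

def benchmark(fn, *args, n_warmup=2, n_runs=20):
    """Median wall-clock time of fn(*args), in milliseconds."""
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args))  # absorbs tracing/compilation
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for the device, not just dispatch
        times.append(time.perf_counter() - start)
    return statistics.median(times) * 1000

matmul = jax.jit(lambda a, b: a @ b)
a = jnp.ones((512, 512))
t_ms = benchmark(matmul, a, a)
print(f"jitted 512x512 matmul: {t_ms:.3f} ms (median of 20 runs)")
```

The median is less sensitive than the mean to occasional slow runs caused by other processes.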

### Comparative Benchmark

In [None]:
# Benchmark suite
def benchmark_transforms(size=1000, n_runs=10):
    """Compare performance of different transformation strategies."""
    x = bst.random.randn(size, size)

    # Function to test
    def compute(x):
        for _ in range(10):
            x = jnp.tanh(x)
        return jnp.sum(x)

    def time_it(fn, arg):
        jax.block_until_ready(fn(arg))  # warm-up (compiles if fn is jitted)
        times = []
        for _ in range(n_runs):
            start = time.perf_counter()
            # Without blocking, only the async dispatch would be timed
            jax.block_until_ready(fn(arg))
            times.append(time.perf_counter() - start)
        return np.mean(times) * 1000

    results = {}

    # 1. No JIT (op-by-op execution)
    results['No JIT'] = time_it(compute, x)

    # 2. With JIT
    results['JIT'] = time_it(jax.jit(compute), x)

    # 3. JIT + vmap: one call processes 10 inputs, so divide by 10 for a per-input time
    x_batched = bst.random.randn(10, size, size)
    results['JIT + vmap (per input)'] = time_it(jax.jit(jax.vmap(compute)), x_batched) / 10

    return results

# Run benchmark
print("Running benchmarks...")
benchmark_results = benchmark_transforms(size=500, n_runs=10)

print("\nBenchmark Results (average time):")
print("-" * 50)
for name, time_ms in benchmark_results.items():
    print(f"{name:<30}: {time_ms:>8.2f} ms")

# Visualize
plt.figure(figsize=(10, 5))
names = list(benchmark_results.keys())
times = list(benchmark_results.values())
plt.bar(names, times)
plt.ylabel('Time (ms)')
plt.title('Transformation Performance Comparison')
plt.xticks(rotation=15, ha='right')
plt.tight_layout()
plt.show()
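Averages over warm runs hide how much of the first call is compilation. JAX's ahead-of-time API makes the split explicit: `jax.jit(f).lower(x)` traces and lowers for the given argument shapes, `.compile()` runs XLA, and the resulting object executes without recompiling. A minimal sketch reusing the benchmark's `compute` function:

```python
import time

import jax
import jax.numpy as jnp

def compute(x):
    for _ in range(10):
        x = jnp.tanh(x)
    return jnp.sum(x)

x = jnp.ones((500, 500))

t0 = time.perf_counter()
lowered = jax.jit(compute).lower(x)   # trace + lower for this shape/dtype
compiled = lowered.compile()          # XLA compilation
t1 = time.perf_counter()
out = compiled(x).block_until_ready() # execution only
t2 = time.perf_counter()

print(f"trace + compile: {(t1 - t0) * 1e3:.1f} ms, run: {(t2 - t1) * 1e3:.2f} ms")
```

The compiled object only accepts arguments with the shapes and dtypes it was lowered for; new shapes need a new `lower(...).compile()`.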

## 8. Putting It All Together: Production Training Loop

In [None]:
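# Training loop combining JIT, optional checkpointing, and progress reporting
class ProductionTrainer:
    def __init__(self, model, use_checkpoint=True, show_progress=True):
        self.model = model
        self.use_checkpoint = use_checkpoint
        self.show_progress = show_progress

        # The step reads and writes brainstate States, so compile it with
        # brainstate's jit rather than jax.jit
        self.train_step_fn = bst.compile.jit(self._train_step)

    def _train_step(self, x, y, learning_rate):
        """Single training step."""
        params = self.model.states(bst.ParamState)

        def loss_fn():
            forward = self.model
            if self.use_checkpoint:
                # Recompute the forward pass during backprop instead of storing it
                forward = jax.checkpoint(lambda inp: self.model(inp))
            pred = forward(x)
            return jnp.mean((pred - y) ** 2)

        # With return_value=True, brainstate's grad returns (grads, loss)
        grads, loss = bst.augment.grad(loss_fn, params, return_value=True)()

        # SGD update; tree_map also handles States holding dicts (e.g. weight + bias)
        for key, grad in grads.items():
            params[key].value = jax.tree_util.tree_map(
                lambda p, g: p - learning_rate * g, params[key].value, grad
            )

        return loss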
    
    def train(self, train_data, n_epochs, learning_rate=0.01):
        """Full training loop."""
        history = {'loss': [], 'time': []}
        
        iterator = range(n_epochs)
        if self.show_progress and has_tqdm:
            iterator = tqdm(iterator, desc="Training")
        
        for epoch in iterator:
            epoch_start = time.time()
            
            # Training step
            x, y = train_data
            loss = self.train_step_fn(x, y, learning_rate)
            
            # Track metrics
            history['loss'].append(float(loss))
            history['time'].append(time.time() - epoch_start)
            
            # Update progress
            if self.show_progress and has_tqdm:
                iterator.set_postfix({'loss': f'{loss:.4f}'})
        
        return history

# Demo
class SimpleModel(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.linear = bst.nn.Linear(10, 5)
    
    def __call__(self, x):
        return self.linear(x)

model = SimpleModel()
trainer = ProductionTrainer(model, use_checkpoint=True, show_progress=has_tqdm)

# Generate dummy data
x_train = bst.random.randn(32, 10)
y_train = bst.random.randn(32, 5)

# Train
history = trainer.train((x_train, y_train), n_epochs=50, learning_rate=0.01)

# Plot results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(history['loss'])
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.grid(True, alpha=0.3)

ax2.plot(history['time'])
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Time (s)')
ax2.set_title('Time per Epoch')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nTraining completed!")
print(f"Final loss: {history['loss'][-1]:.4f}")
print(f"Average time per epoch: {np.mean(history['time'])*1000:.2f} ms")
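For comparison, the same ideas can be written as a purely functional JAX step, with parameters passed in and returned instead of stored in States. This is a sketch under stated assumptions: `predict` is a hypothetical linear model, the data are synthetic, and SGD is applied leaf by leaf with `jax.tree_util.tree_map`.

```python
import functools

import jax
import jax.numpy as jnp

def predict(params, x):
    # Hypothetical linear model: x @ W + b
    return x @ params["W"] + params["b"]

def loss_fn(params, x, y, use_checkpoint):
    forward = jax.checkpoint(predict) if use_checkpoint else predict
    return jnp.mean((forward(params, x) - y) ** 2)

@functools.partial(jax.jit, static_argnames="use_checkpoint")
def train_step(params, x, y, lr, use_checkpoint=True):
    # value_and_grad differentiates w.r.t. the first argument (params)
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y, use_checkpoint)
    # Plain SGD applied leaf by leaf across the params pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
x_train = jax.random.normal(k_x, (32, 10))
y_train = x_train @ jax.random.normal(k_w, (10, 5))  # synthetic linear targets
params = {"W": jnp.zeros((10, 5)), "b": jnp.zeros(5)}

losses = []
for epoch in range(50):
    params, loss = train_step(params, x_train, y_train, 0.3)
    losses.append(float(loss))

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Because nothing is mutated inside the step, plain `jax.jit` is sufficient here; the trade-off is that parameters must be threaded through every call explicitly.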

## Summary

In this tutorial, we covered advanced transformations:

1. **Gradient Checkpointing**: Trade computation for memory in deep networks
2. **Abstract Initialization**: Inspect shapes without allocating memory
3. **Make Jaxpr**: Understand compiled computation graphs
4. **Progress Bars**: Improve user experience during training
5. **Host Callbacks**: Debug JIT-compiled code
6. **Model Summaries**: Pretty-print model architecture
7. **Profiling**: Measure and optimize performance
8. **Production Training**: Combine all techniques

## Key Takeaways

- **Gradient checkpointing** trades extra computation for lower activation memory, which matters most in very deep networks
- **Abstract initialization** helps validate architectures efficiently
- **make_jaxpr** is invaluable for debugging transformations
- **Progress bars** greatly improve user experience
- **Profiling** helps identify bottlenecks
- Combine techniques for production-ready code

## Best Practices

1. Use gradient checkpointing when activation memory, rather than compute, limits depth or batch size
2. Always profile before optimizing
3. Add progress bars for long-running operations
4. Use abstract initialization to validate shapes
5. Combine JIT + vmap for batched workloads
6. Monitor memory usage in training loops

## Congratulations!

You've completed all the transformation tutorials! You now know:
- JIT compilation and optimization
- Automatic differentiation
- Vectorization with vmap/pmap
- Control flow primitives
- Advanced transformations

These tools form the foundation for building efficient, scalable neural networks with BrainState!