# Tutorial 23: Migration from PyTorch to BrainState

In this tutorial, we'll learn how to migrate models and code from PyTorch to BrainState.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Explain the key differences between PyTorch and BrainState
- Map common PyTorch APIs to their BrainState equivalents
- Convert PyTorch models to BrainState
- Migrate training loops and optimizers
- Avoid common pitfalls during migration
- Follow best practices for BrainState development

## Introduction

BrainState and PyTorch share similar goals but have different design philosophies:

**PyTorch**:
- Imperative, eager execution
- Object-oriented with mutable state
- Dynamic computation graphs
- CPU/GPU with manual device management

**BrainState**:
- Built on JAX (functional + transformations)
- Explicit state management: immutable arrays held in `State` objects
- JIT compilation for performance
- Automatic device placement
- Designed for brain modeling and neuroscience

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Optional, List

# Set random seed
bst.random.seed(42)
np.random.seed(42)

## 1. API Comparison

### 1.1 Module Definition

Let's compare how to define a simple neural network module.

#### PyTorch Version

```python
import torch
import torch.nn as nn

class PyTorchMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Usage
model = PyTorchMLP(784, 128, 10)
output = model(torch.randn(32, 784))
```

#### BrainState Version

In [None]:
class BrainStateMLP(bst.graph.Node):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = bst.nn.Linear(input_dim, hidden_dim)
        self.fc2 = bst.nn.Linear(hidden_dim, output_dim)
    
    def __call__(self, x):
        x = jax.nn.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Usage
model = BrainStateMLP(784, 128, 10)
output = model(bst.random.randn(32, 784))
print(f"Output shape: {output.shape}")

**Key Differences**:
1. **Base class**: `nn.Module` → `bst.graph.Node`
2. **Method name**: `forward()` → `__call__()`
3. **Activation functions**: `torch.relu()` → `jax.nn.relu()`
4. **Random numbers**: `torch.randn()` → `bst.random.randn()`

### 1.2 Layer Mappings

Common layer translations:

In [None]:
# Create a mapping table for visualization
layer_mappings = [
    ("nn.Linear", "bst.nn.Linear", "Fully connected layer"),
    ("nn.Conv2d", "bst.nn.Conv2d", "2D convolution"),
    ("nn.BatchNorm2d", "bst.nn.BatchNorm2d", "Batch normalization"),
    ("nn.LayerNorm", "bst.nn.LayerNorm", "Layer normalization"),
    ("nn.Dropout", "bst.nn.Dropout", "Dropout regularization"),
    ("nn.RNN", "bst.nn.RNNCell", "Recurrent layer (use Cell)"),
    ("nn.LSTM", "bst.nn.LSTMCell", "LSTM layer (use Cell)"),
    ("nn.GRU", "bst.nn.GRUCell", "GRU layer (use Cell)"),
]

print("PyTorch → BrainState Layer Mappings")
print("=" * 80)
print(f"{'PyTorch':<20} {'BrainState':<25} {'Description':<35}")
print("-" * 80)
for pytorch, brainstate, desc in layer_mappings:
    print(f"{pytorch:<20} {brainstate:<25} {desc:<35}")

### 1.3 Activation Functions

In [None]:
activation_mappings = [
    ("torch.relu", "jax.nn.relu"),
    ("torch.sigmoid", "jax.nn.sigmoid"),
    ("torch.tanh", "jax.nn.tanh"),
    ("torch.softmax", "jax.nn.softmax"),
    ("F.gelu", "jax.nn.gelu"),
    ("F.leaky_relu", "jax.nn.leaky_relu"),
    ("F.elu", "jax.nn.elu"),
    ("F.silu", "jax.nn.silu"),
]

print("\nActivation Function Mappings")
print("=" * 60)
print(f"{'PyTorch':<30} {'BrainState/JAX':<30}")
print("-" * 60)
for pytorch, jax_fn in activation_mappings:
    print(f"{pytorch:<30} {jax_fn:<30}")

# Test activations
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print("\nExample activations on [-2, -1, 0, 1, 2]:")
print(f"ReLU: {jax.nn.relu(x)}")
print(f"Sigmoid: {jax.nn.sigmoid(x)}")
print(f"Tanh: {jax.nn.tanh(x)}")

### 1.4 Tensor Operations

In [None]:
print("Common Tensor Operations")
print("=" * 80)

# Create sample tensors
x_jax = jnp.array([[1, 2, 3], [4, 5, 6]])

operations = [
    ("Tensor creation", "torch.tensor([1,2,3])", "jnp.array([1,2,3])"),
    ("Random normal", "torch.randn(2, 3)", "bst.random.randn(2, 3)"),
    ("Zeros", "torch.zeros(2, 3)", "jnp.zeros((2, 3))"),
    ("Ones", "torch.ones(2, 3)", "jnp.ones((2, 3))"),
    ("Reshape", "x.view(3, 2)", "x.reshape(3, 2)"),
    ("Transpose", "x.T", "x.T"),
    ("Concatenate", "torch.cat([x, y], dim=0)", "jnp.concatenate([x, y], axis=0)"),
    ("Stack", "torch.stack([x, y], dim=0)", "jnp.stack([x, y], axis=0)"),
    ("Mean", "x.mean()", "jnp.mean(x)"),
    ("Sum", "x.sum(dim=1)", "jnp.sum(x, axis=1)"),
    ("Max", "x.max(dim=1).values", "jnp.max(x, axis=1)"),
    ("Argmax", "x.argmax(dim=1)", "jnp.argmax(x, axis=1)"),
]

print(f"{'Operation':<20} {'PyTorch':<30} {'JAX/BrainState':<30}")
print("-" * 80)
for op_name, pytorch, jax_code in operations:
    print(f"{op_name:<20} {pytorch:<30} {jax_code:<30}")

# Demonstrate some operations
print("\nExamples:")
print(f"Original: \n{x_jax}")
print(f"\nReshape (3, 2): \n{x_jax.reshape(3, 2)}")
print(f"\nSum along axis 1: {jnp.sum(x_jax, axis=1)}")
print(f"Argmax along axis 1: {jnp.argmax(x_jax, axis=1)}")

**Important Note**: JAX/BrainState functions take `axis` where PyTorch takes `dim`, and `keepdims` where PyTorch takes `keepdim`.
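
Two further differences are easy to miss when renaming `dim` to `axis`. The sketch below uses plain `jax.numpy` and nothing BrainState-specific. It shows that PyTorch's `x.max(dim=1)` returns both values and indices, while `jnp.max` returns only the values, and that PyTorch's `keepdim` argument is spelled `keepdims` in JAX.

```python
import jax.numpy as jnp

x = jnp.array([[1, 5, 3], [4, 2, 6]])

# torch: values, indices = x.max(dim=1)
values = jnp.max(x, axis=1)       # row-wise maxima
indices = jnp.argmax(x, axis=1)   # positions of those maxima

# torch: x.sum(dim=1, keepdim=True)
row_sums = jnp.sum(x, axis=1, keepdims=True)  # shape (2, 1)

print(values, indices, row_sums.shape)
```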

## 2. Converting PyTorch Models

Let's convert a complete PyTorch CNN model to BrainState.

### 2.1 PyTorch CNN (Reference)

```python
import torch.nn as nn
import torch.nn.functional as F

class PyTorchCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)
    
    def forward(self, x):
        # Conv block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool(x)
        
        # Conv block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pool(x)
        
        # Flatten and FC
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
```

### 2.2 BrainState CNN (Converted)

In [None]:
class BrainStateCNN(bst.graph.Node):
    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional layers
        self.conv1 = bst.nn.Conv2d(1, 32, kernel_size=(3, 3), padding='SAME')
        self.bn1 = bst.nn.BatchNorm2d(32)
        self.conv2 = bst.nn.Conv2d(32, 64, kernel_size=(3, 3), padding='SAME')
        self.bn2 = bst.nn.BatchNorm2d(64)
        self.dropout = bst.nn.Dropout(0.25)
        
        # Fully connected layers
        self.fc1 = bst.nn.Linear(64 * 7 * 7, 128)
        self.fc2 = bst.nn.Linear(128, num_classes)
    
    def __call__(self, x):
        # Conv block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = jax.nn.relu(x)
        x = bst.functional.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        
        # Conv block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = jax.nn.relu(x)
        x = bst.functional.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        
        # Flatten and FC
        x = x.reshape(x.shape[0], -1)  # Flatten
        x = self.dropout(x)
        x = jax.nn.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Create and test the model
model = BrainStateCNN(num_classes=10)

# Test with dummy input (batch_size=2, channels=1, height=28, width=28)
dummy_input = bst.random.randn(2, 1, 28, 28)
output = model(dummy_input)
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Model parameters: {sum(p.value.size for p in model.states(bst.ParamState).values())}")

**Key Conversion Points**:

1. **Base class**: `nn.Module` → `bst.graph.Node`
2. **Method**: `forward()` → `__call__()`
3. **Kernel size**: `kernel_size=3` → `kernel_size=(3, 3)` (tuple required)
4. **Padding**: `padding=1` → `padding='SAME'` (equivalent for a 3×3 kernel with stride 1; other configurations may need explicit padding)
5. **Pooling**: `nn.MaxPool2d(2, 2)` → `bst.functional.max_pool(...)`
6. **Flatten**: `x.view(...)` → `x.reshape(...)`
7. **Activation**: `F.relu(x)` → `jax.nn.relu(x)`
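
The padding and pooling points can be checked with low-level JAX primitives. This is a sketch in plain `jax.lax`, not BrainState's layer implementation, and it assumes the NCHW layout used by the PyTorch reference model. It shows that `'SAME'` padding matches one pixel of explicit padding for a 3×3 kernel with stride 1, and that a 2×2 max pool with stride 2 halves each spatial dimension. Two such pools take 28 → 14 → 7, which is where the `64 * 7 * 7` input size of `fc1` comes from.

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 1, 28, 28))   # NCHW input, as in PyTorch
w = jax.random.normal(key_w, (32, 1, 3, 3))    # OIHW kernel

# 'SAME' with a 3x3 kernel and stride 1 pads one pixel per side,
# which matches PyTorch's padding=1.
y_same = jax.lax.conv_general_dilated(
    x, w, window_strides=(1, 1), padding="SAME"
)
y_explicit = jax.lax.conv_general_dilated(
    x, w, window_strides=(1, 1), padding=((1, 1), (1, 1))
)

# 2x2 max pooling with stride 2 over the two spatial axes.
pooled = jax.lax.reduce_window(
    y_same, -jnp.inf, jax.lax.max,
    window_dimensions=(1, 1, 2, 2),
    window_strides=(1, 1, 2, 2),
    padding="VALID",
)

print(y_same.shape, pooled.shape)
```

If a BrainState layer expects a channel-last layout instead, the same arithmetic applies with the spatial axes moved accordingly.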

## 3. Training Loop Migration

### 3.1 PyTorch Training Loop (Reference)

```python
import torch.optim as optim

# PyTorch training
model = PyTorchCNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for batch_x, batch_y in train_loader:
        # Forward pass
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

### 3.2 BrainState Training Loop

In [None]:
class Trainer:
    """BrainState trainer class."""
    
    def __init__(self, model: bst.graph.Node, learning_rate: float = 0.001):
        self.model = model
        self.learning_rate = learning_rate
        
        # Get trainable parameters
        self.params = model.states(bst.ParamState)
        
        # Initialize optimizer state (for Adam)
        self.m = {k: jnp.zeros_like(v.value) for k, v in self.params.items()}  # First moment
        self.v = {k: jnp.zeros_like(v.value) for k, v in self.params.items()}  # Second moment
        self.t = 0  # Time step
        
        # Adam hyperparameters
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.eps = 1e-8
    
    def loss_fn(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Compute cross-entropy loss."""
        logits = self.model(x)
        # Cross-entropy loss
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.mean(jnp.sum(y * log_probs, axis=-1))
        return loss
    
    def train_step(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
        """
        Single training step.
        
        Args:
            x: Input batch
            y: Target labels (one-hot encoded)
            
        Returns:
            Loss value
        """
        # Compute gradients and loss. With return_value=True the transform
        # returns (grads, loss), gradients first.
        grads, loss = bst.transform.grad(self.loss_fn, grad_states=self.params, return_value=True)(x, y)
        
        # Adam optimizer update
        self.t += 1
        for key in self.params.keys():
            g = grads[key]
            
            # Update biased first moment estimate
            self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * g
            
            # Update biased second moment estimate
            self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * (g ** 2)
            
            # Compute bias-corrected moment estimates
            m_hat = self.m[key] / (1 - self.beta1 ** self.t)
            v_hat = self.v[key] / (1 - self.beta2 ** self.t)
            
            # Update parameters
            self.params[key].value = self.params[key].value - self.learning_rate * m_hat / (jnp.sqrt(v_hat) + self.eps)
        
        return float(loss)
    
    def train_epoch(self, train_data: List[Tuple[np.ndarray, np.ndarray]]) -> float:
        """
        Train for one epoch.
        
        Args:
            train_data: List of (x, y) tuples
            
        Returns:
            Average loss
        """
        total_loss = 0.0
        for x_batch, y_batch in train_data:
            loss = self.train_step(jnp.array(x_batch), jnp.array(y_batch))
            total_loss += loss
        return total_loss / len(train_data)

# Example usage
model = BrainStateCNN(num_classes=10)
trainer = Trainer(model, learning_rate=0.001)

# Generate dummy training data
num_batches = 5
batch_size = 32
train_data = [
    (np.random.randn(batch_size, 1, 28, 28), 
     np.eye(10)[np.random.randint(0, 10, batch_size)])  # One-hot labels
    for _ in range(num_batches)
]

# Train for a few epochs
for epoch in range(3):
    avg_loss = trainer.train_epoch(train_data)
    print(f"Epoch {epoch + 1}: Average Loss = {avg_loss:.4f}")

**Key Differences in Training**:

1. **No `.backward()`**: Use `bst.transform.grad()` for automatic differentiation
2. **Explicit state management**: Access parameters via `model.states(bst.ParamState)`
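3. **Manual optimizer**: Implement the optimizer logic explicitly (or use the Optax library)
4. **Functional approach**: Gradients come from transforming a loss function, not from calling a method on the loss tensor
5. **One-hot encoding**: The hand-written cross-entropy above expects one-hot targets, whereas `nn.CrossEntropyLoss` takes integer class indices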
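
Since PyTorch datasets usually yield integer class labels, the labels need converting before they reach a loss written like `Trainer.loss_fn`. The following sketch uses `jax.nn.one_hot` for the conversion and, for comparison, computes the same loss by indexing the log-probabilities directly. The logits and labels are made-up values for illustration.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = jnp.array([0, 2])  # integer class indices, as nn.CrossEntropyLoss expects

# Convert to one-hot for a loss of the form used in Trainer.loss_fn.
one_hot = jax.nn.one_hot(labels, num_classes=3)
log_probs = jax.nn.log_softmax(logits, axis=-1)
loss_one_hot = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

# Equivalent form that picks out the log-probability of the true class.
picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
loss_index = -jnp.mean(picked)

print(loss_one_hot, loss_index)
```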

## 4. Common Pitfalls and Solutions

### 4.1 Mutable vs Immutable State

In [None]:
print("Pitfall 1: Mutable vs Immutable State")
print("=" * 60)

# WRONG: JAX arrays are immutable
print("\n❌ WRONG (PyTorch-style):")
print("""
x = jnp.array([1, 2, 3])
x[0] = 10  # TypeError: JAX arrays are immutable
""")

# CORRECT: Create new array
print("✓ CORRECT (BrainState way):")
x = jnp.array([1, 2, 3])
x_new = x.at[0].set(10)  # Returns new array
print(f"Original: {x}")
print(f"Modified: {x_new}")

# For model parameters, use .value assignment
print("\n✓ For model parameters:")
print("""
# Create a parameter state
param = bst.ParamState(jnp.zeros(10))

# Update parameter value
param.value = jnp.ones(10)  # This is correct
""")

### 4.2 Device Management

In [None]:
print("\nPitfall 2: Device Management")
print("=" * 60)

# PyTorch requires explicit device management
print("\nPyTorch (manual):")
print("""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x = x.to(device)
""")

# BrainState/JAX handles devices automatically
print("BrainState/JAX (automatic):")
print("""
# JAX uses its default device: a GPU/TPU if an accelerator build is installed, else CPU
model = BrainStateCNN()
x = bst.random.randn(32, 1, 28, 28)
output = model(x)  # Automatically runs on best available device
""")

# Check available devices
devices = jax.devices()
print(f"\nAvailable devices: {devices}")
print(f"Default device: {devices[0].device_kind}")
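
When explicit placement is needed, plain JAX offers `jax.device_put`, which plays roughly the role of `x.to(device)` in PyTorch. A minimal sketch, assuming only the first device returned by `jax.devices()`:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4))
print(x.devices())  # set of devices that hold the array

# Explicit placement, roughly analogous to x.to(device) in PyTorch.
target = jax.devices()[0]
x_placed = jax.device_put(x, target)
print(target in x_placed.devices())
```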

### 4.3 Dimension/Axis Naming

In [None]:
print("Pitfall 3: Dimension vs Axis")
print("=" * 60)

x = jnp.array([[1, 2, 3], [4, 5, 6]])

# PyTorch uses 'dim'
print("\nPyTorch:")
print("x.sum(dim=1)  # Sum along dimension 1")

# JAX uses 'axis'
print("\nJAX/BrainState:")
result = jnp.sum(x, axis=1)
print(f"jnp.sum(x, axis=1) = {result}")

# Multiple axes
x3d = bst.random.randn(2, 3, 4)
print(f"\nSum over axes (0, 2): {jnp.sum(x3d, axis=(0, 2)).shape}")

### 4.4 In-place Operations

In [None]:
print("Pitfall 4: In-place Operations")
print("=" * 60)

print("\nPyTorch (in-place allowed):")
print("""
x.add_(1)  # In-place addition
x.relu_()  # In-place ReLU
""")

print("JAX/BrainState (no in-place, functional):")
x = jnp.array([1.0, 2.0, 3.0])
x_new = x + 1  # Returns new array
x_relu = jax.nn.relu(x)  # Returns new array
print(f"Original: {x}")
print(f"After +1: {x_new}")
print(f"After ReLU: {x_relu}")

print("\n✓ For efficient updates, JAX has .at syntax:")
x = jnp.array([1, 2, 3, 4, 5])
x_updated = x.at[2:4].set(0)  # Set indices 2:4 to 0
print(f"Updated: {x_updated}")
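
A few common PyTorch in-place idioms and one possible JAX translation for each are sketched below. The `.at[...]` property also supports `add` and `multiply`, and boolean-mask assignment is usually expressed with `jnp.where`. The input values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([1.0, -2.0, 3.0, -4.0])

# torch: x[1] += 10
x_add = x.at[1].add(10.0)

# torch: x[x < 0] = 0  (boolean-mask assignment)
x_clipped = jnp.where(x < 0, 0.0, x)

# torch: x[::2] *= 2
x_scaled = x.at[::2].multiply(2.0)

print(x_add, x_clipped, x_scaled)
```

Each expression returns a new array and leaves `x` unchanged.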

### 4.5 Random Number Generation

In [None]:
print("Pitfall 5: Random Number Generation")
print("=" * 60)

print("\nPyTorch (global state):")
print("""
torch.manual_seed(42)
x = torch.randn(2, 3)
y = torch.randn(2, 3)  # Different from x
""")

print("\nJAX (explicit PRNG key - more complex):")
print("""
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (2, 3))
""")

print("\n✓ BrainState (simplified):")
bst.random.seed(42)
x = bst.random.randn(2, 3)
y = bst.random.randn(2, 3)
print(f"x:\n{x}")
print(f"y:\n{y}")
print("\nBrainState manages PRNG keys automatically!")
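
If migrated code calls `jax.random` directly rather than `bst.random`, the key discipline matters. The sketch below, in plain JAX, shows the behavior that most often surprises PyTorch users: reusing a key reproduces the same numbers, and only splitting the key yields fresh ones.

```python
import jax

key = jax.random.PRNGKey(42)

# Reusing a key reproduces the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting yields independent streams.
key_x, key_y = jax.random.split(key)
x = jax.random.normal(key_x, (3,))
y = jax.random.normal(key_y, (3,))

print(bool((a == b).all()), bool((x == y).all()))
```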

### 4.6 Gradient Computation

In [None]:
print("Pitfall 6: Gradient Computation")
print("=" * 60)

print("\nPyTorch (autograd with computational graph):")
print("""
x = torch.randn(10, requires_grad=True)
y = (x ** 2).sum()
y.backward()
grad = x.grad  # Access gradients
""")

print("\nBrainState (functional gradient):")

# Define a simple function
def f(x):
    return jnp.sum(x ** 2)

# Compute gradient
x = jnp.array([1.0, 2.0, 3.0])
grad_f = jax.grad(f)
gradient = grad_f(x)

print(f"x = {x}")
print(f"f(x) = sum(x^2) = {f(x)}")
print(f"grad f(x) = 2*x = {gradient}")

# For models, use bst.transform.grad
print("\n✓ For models with parameters:")
print("""
params = model.states(bst.ParamState)
grads, loss = bst.transform.grad(
    loss_fn, 
    grad_states=params,
    return_value=True
)(x, y)
""")

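For comparison, the same pattern in plain JAX passes the parameters as an explicit argument. This is a sketch with a hypothetical two-parameter linear model, not BrainState code. `jax.value_and_grad` differentiates with respect to the first argument, returns the value first and the gradients second, and the gradients mirror the structure of the parameter dictionary.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters for a tiny linear model.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

def loss_fn(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2

x = jnp.array([3.0, 4.0])
y = jnp.array(10.0)

# Differentiates with respect to the first argument (params).
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
print(loss, grads)
```
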
## 5. Best Practices for BrainState

### 5.1 Use JIT Compilation for Performance

In [None]:
import time

print("Best Practice 1: JIT Compilation")
print("=" * 60)

# Create a simple model
model = BrainStateMLP(784, 128, 10)
x = bst.random.randn(100, 784)

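# Without JIT. JAX dispatches work asynchronously, so block on the result
# before reading the clock.
start = time.perf_counter()
for _ in range(100):
    out = model(x)
jax.block_until_ready(out)
time_no_jit = time.perf_counter() - start

# With JIT
@bst.transform.jit
def predict_jit(x):
    return model(x)

# Warmup: the first call traces and compiles, so exclude it from the timing
jax.block_until_ready(predict_jit(x))

start = time.perf_counter()
for _ in range(100):
    out = predict_jit(x)
jax.block_until_ready(out)
time_jit = time.perf_counter() - start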

print(f"Without JIT: {time_no_jit:.4f}s")
print(f"With JIT: {time_jit:.4f}s")
print(f"Speedup: {time_no_jit / time_jit:.2f}x")
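
The warmup call matters because compilation is cached per input shape and dtype. The sketch below uses plain `jax.jit` to make the tracing visible; it is an assumption here that `bst.transform.jit` caches in a comparable way. The Python-side `append` runs only while the function is being traced, so the log records one entry per distinct shape.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)  # Python side effect: runs only while tracing
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(4))   # traces and compiles for shape (4,)
scaled_sum(jnp.ones(4))   # reuses the cached compilation
scaled_sum(jnp.ones(8))   # new shape, so it traces again

print(trace_log)
```

Feeding batches of varying size to a jitted function therefore triggers repeated compilation, which can erase the speedup.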

### 5.2 Organize State Management

In [None]:
print("\nBest Practice 2: Organize State Management")
print("=" * 60)

class WellOrganizedModel(bst.graph.Node):
    """Example of good state management."""
    
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        
        # Parameters (trainable)
        self.fc1 = bst.nn.Linear(input_dim, hidden_dim)
        self.fc2 = bst.nn.Linear(hidden_dim, output_dim)
        
        # Short-term state (reset between sequences)
        self.hidden = bst.ShortTermState(jnp.zeros(hidden_dim))
        
        # Long-term state (accumulated statistics)
        self.num_calls = bst.LongTermState(jnp.array(0))
    
    def __call__(self, x):
        # Update statistics
        self.num_calls.value = self.num_calls.value + 1
        
        # Use hidden state
        x = self.fc1(x) + self.hidden.value
        x = jax.nn.relu(x)
        
        # Update hidden state
        self.hidden.value = x
        
        # Output
        return self.fc2(x)
    
    def reset(self):
        """Reset short-term state."""
        self.hidden.value = jnp.zeros_like(self.hidden.value)

# Usage
model = WellOrganizedModel(10, 20, 5)
x = bst.random.randn(1, 10)

for i in range(3):
    y = model(x)
    print(f"Call {i+1}: num_calls={model.num_calls.value}")

print("\nState types:")
print(f"  ParamState: {len(model.states(bst.ParamState))} groups")
print(f"  ShortTermState: {len(model.states(bst.ShortTermState))} states")
print(f"  LongTermState: {len(model.states(bst.LongTermState))} states")

### 5.3 Use Vectorization (vmap)

In [None]:
print("Best Practice 3: Vectorization with vmap")
print("=" * 60)

# Function that processes single sample
def process_single(x):
    return jnp.sum(x ** 2)

# Process batch with loop (slow)
batch = bst.random.randn(100, 10)

start = time.perf_counter()
results_loop = jnp.array([process_single(x) for x in batch])
jax.block_until_ready(results_loop)  # wait for asynchronous dispatch
time_loop = time.perf_counter() - start

# Process batch with vmap (fast)
process_batch = jax.vmap(process_single)

start = time.perf_counter()
results_vmap = jax.block_until_ready(process_batch(batch))
time_vmap = time.perf_counter() - start

print(f"Loop: {time_loop:.6f}s")
print(f"vmap: {time_vmap:.6f}s")
print(f"Speedup: {time_loop / time_vmap:.2f}x")
print(f"Results match: {jnp.allclose(results_loop, results_vmap)}")
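
`vmap` also composes with other transformations. One use with no direct PyTorch one-liner is per-example gradients, sketched below in plain JAX with a made-up loss. The `in_axes=(None, 0)` argument shares `w` across the batch and maps over the rows of `batch`.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
batch = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# in_axes=(None, 0): share w across the batch, map over rows of batch.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, batch)
print(per_example_grads)
```

The result has one gradient row per example, here with shape `(3, 2)`.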

### 5.4 Debugging Tips

In [None]:
print("Best Practice 4: Debugging")
print("=" * 60)

# Check arrays for non-finite values
print("\n1. Check for NaN/Inf:")
x = jnp.array([1.0, 2.0, jnp.nan, 4.0])
print(f"  Has NaN: {jnp.any(jnp.isnan(x))}")
print(f"  Has Inf: {jnp.any(jnp.isinf(x))}")

print("\n2. Use jax.debug for printing inside JIT:")
print("""
@bst.transform.jit
def f(x):
    jax.debug.print("x = {}", x)  # Print inside JIT
    return x ** 2
""")

print("\n3. Disable JIT for debugging:")
print("""
with jax.disable_jit():
    output = model(x)  # Run without JIT compilation
""")

print("\n4. Check shapes:")
model = BrainStateMLP(784, 128, 10)
x = bst.random.randn(32, 784)
y = model(x)
print(f"  Input shape: {x.shape}")
print(f"  Output shape: {y.shape}")
print(f"  Expected: (32, 10)")
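
The NaN/Inf check extends naturally to whole collections of arrays, such as a dictionary of gradients. The helper below, `all_finite`, is a hypothetical name introduced for illustration. It is a sketch that walks any pytree with `jax.tree_util.tree_leaves` and is meant to be called outside of JIT-compiled code.

```python
import jax
import jax.numpy as jnp

def all_finite(tree):
    """Return True if every array leaf in a pytree is free of NaN and Inf."""
    leaves = jax.tree_util.tree_leaves(tree)
    return all(bool(jnp.all(jnp.isfinite(leaf))) for leaf in leaves)

good = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
bad = {"w": jnp.array([1.0, jnp.nan]), "b": jnp.zeros(2)}

print(all_finite(good), all_finite(bad))
```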

## 6. Complete Migration Example

Let's convert a complete PyTorch training script to BrainState.

In [None]:
# Complete BrainState training example
class CompleteTrainingExample:
    """Complete training pipeline in BrainState."""
    
    def __init__(self):
        # Model
        self.model = BrainStateCNN(num_classes=10)
        
        # Initialize with dummy input
        dummy = bst.random.randn(1, 1, 28, 28)
        _ = self.model(dummy)
        
        # Trainer
        self.trainer = Trainer(self.model, learning_rate=0.001)
        
        # Metrics
        self.train_losses = []
        self.val_accuracies = []
    
    def evaluate(self, val_data: List[Tuple[np.ndarray, np.ndarray]]) -> float:
        """Evaluate model accuracy."""
        correct = 0
        total = 0
        
        for x_batch, y_batch in val_data:
            x = jnp.array(x_batch)
            y = jnp.array(y_batch)
            
            logits = self.model(x)
            predictions = jnp.argmax(logits, axis=-1)
            labels = jnp.argmax(y, axis=-1)
            
            correct += jnp.sum(predictions == labels)
            total += len(labels)
        
        return float(correct / total)
    
    def train(self, train_data, val_data, num_epochs=5):
        """Train the model."""
        print("Training started...")
        print("=" * 60)
        
        for epoch in range(num_epochs):
            # Train
            train_loss = self.trainer.train_epoch(train_data)
            self.train_losses.append(train_loss)
            
            # Evaluate
            val_acc = self.evaluate(val_data)
            self.val_accuracies.append(val_acc)
            
            print(f"Epoch {epoch+1}/{num_epochs}: "
                  f"Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")
        
        print("=" * 60)
        print("Training completed!")
    
    def plot_metrics(self):
        """Plot training metrics."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        ax1.plot(self.train_losses, 'o-')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss')
        ax1.grid(True, alpha=0.3)
        
        ax2.plot(self.val_accuracies, 's-', color='green')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Validation Accuracy')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Generate dummy dataset
def generate_dummy_data(num_batches=10, batch_size=32):
    """Generate dummy MNIST-like data."""
    return [
        (np.random.randn(batch_size, 1, 28, 28),
         np.eye(10)[np.random.randint(0, 10, batch_size)])
        for _ in range(num_batches)
    ]

# Run training
train_data = generate_dummy_data(num_batches=10, batch_size=32)
val_data = generate_dummy_data(num_batches=5, batch_size=32)

trainer_example = CompleteTrainingExample()
trainer_example.train(train_data, val_data, num_epochs=5)
trainer_example.plot_metrics()

## Summary

### Key Takeaways for Migration:

1. **API Mappings**:
   - `nn.Module` → `bst.graph.Node`
   - `forward()` → `__call__()`
   - `torch.*` → `jax.* / jnp.*`
   - `dim` → `axis`

2. **Functional Programming**:
   - JAX arrays are immutable
   - Use functional transformations (jit, grad, vmap)
   - Explicit state management with State objects

3. **Training**:
   - Use `bst.transform.grad()` for gradients
   - Implement optimizer manually or use optax
   - Access parameters via `model.states(bst.ParamState)`

4. **Best Practices**:
   - Use JIT compilation for performance
   - Vectorize with vmap when possible
   - Organize state types appropriately
   - Debug without JIT first

5. **Common Pitfalls**:
   - No in-place operations
   - Explicit random key management (simplified in BrainState)
   - Use `axis` not `dim`
   - Tuple kernel sizes for Conv layers

## Next Steps

- Practice converting your own PyTorch models
- Explore BrainState-specific features (dynamics, brain modeling)
- Learn BrainPy integration (next tutorial)
- Optimize with JAX transformations (jit, vmap, pmap)

For more information, visit the [BrainState documentation](https://brainstate.readthedocs.io/).