# Parallelization with ``pmap``

In this tutorial, we'll explore vectorization in BrainState using `vmap` (vectorized map) and `pmap` (parallel map) to write efficient, scalable code.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Understand the concept of automatic vectorization
- Use `vmap` to vectorize functions automatically
- Specify which axes to map over
- Handle batched operations efficiently
- Use `pmap` for multi-device parallelism
- Work with `StatefulMapping` for stateful transformations
- Apply vectorization to neural networks

## Why Vectorization?

Instead of writing explicit loops, vectorization allows you to:
- Write code for single examples that automatically works for batches
- Achieve better performance through parallelization
- Keep code clean and readable
- Leverage hardware acceleration (GPU/TPU)

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import time

# Set random seed
bst.random.seed(42)

## 1. The Problem: Manual Batching

Let's start by seeing why vectorization is useful.

In [None]:
# Function that works on a single vector
def compute_norm_squared(x):
    """Compute squared L2 norm of a vector."""
    return jnp.sum(x ** 2)

# Single example
single_x = jnp.array([1.0, 2.0, 3.0])
result = compute_norm_squared(single_x)
print(f"Single example: {single_x}")
print(f"Result: {result}")

# Batch of examples - manual loop (slow!)
batch_x = jnp.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0],
                     [7.0, 8.0, 9.0]])

# Bad approach: Python loop
results_loop = []
for i in range(len(batch_x)):
    results_loop.append(compute_norm_squared(batch_x[i]))
results_loop = jnp.array(results_loop)

print(f"\nBatch results (manual loop): {results_loop}")
print("Problem: Python loops are slow and don't parallelize well!")

## 2. Solution: vmap (Vectorized Map)

`vmap` automatically transforms a function to work on batches.

In [None]:
# Vectorize the function
compute_norm_squared_batched = bst.transform.vmap(compute_norm_squared)

# Now it works on batches automatically!
results_vmap = compute_norm_squared_batched(batch_x)
print(f"Batch results (vmap): {results_vmap}")
print(f"Results match: {jnp.allclose(results_loop, results_vmap)}")

# Performance comparison
large_batch = bst.random.randn(1000, 100)

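# Time manual loop (block_until_ready waits for JAX's asynchronous dispatch)
start = time.perf_counter()
_ = jax.block_until_ready(jnp.array([compute_norm_squared(x) for x in large_batch]))
loop_time = time.perf_counter() - start

# Time vmap
vmapped_fn = bst.transform.vmap(compute_norm_squared)
vmapped_fn = jax.jit(vmapped_fn)  # JIT-compile the vectorized function
_ = jax.block_until_ready(vmapped_fn(large_batch))  # Warmup (triggers compilation)
start = time.perf_counter()
_ = jax.block_until_ready(vmapped_fn(large_batch))
vmap_time = time.perf_counter() - start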

print(f"\nPerformance comparison (1000 samples):")
print(f"Manual loop: {loop_time*1000:.2f} ms")
print(f"vmap (JIT):  {vmap_time*1000:.2f} ms")
print(f"Speedup: {loop_time/vmap_time:.1f}x")

## 3. Controlling the Mapped Axis

Use `in_axes` and `out_axes` to specify which axes to map over.

In [None]:
# Function that takes two arguments
def weighted_sum(x, weight):
    """Compute weighted sum: sum(x * weight)"""
    return jnp.sum(x * weight)

# Batch of vectors
batch_x = jnp.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])  # Shape: (3, 2)

# Single weight vector (shared across batch)
weight = jnp.array([0.5, 1.5])  # Shape: (2,)

# Map over first argument only (in_axes=(0, None))
batched_weighted_sum = bst.transform.vmap(weighted_sum, in_axes=(0, None))
results = batched_weighted_sum(batch_x, weight)

print("Batch:", batch_x)
print("Weight:", weight)
print("Results:", results)
print("Expected:", jnp.array([1.0*0.5 + 2.0*1.5,
                               3.0*0.5 + 4.0*1.5,
                               5.0*0.5 + 6.0*1.5]))
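
The cell above only exercises `in_axes`. As a minimal sketch of `out_axes`, the example below uses plain `jax.vmap` (an assumption made to keep the block self-contained; the helper `scale_sample` is hypothetical) and shows how the batch axis can be placed at a different position in the output.

```python
import jax
import jax.numpy as jnp

def scale_sample(x, weight):
    """Elementwise product of one sample (shape (2,)) with a shared weight."""
    return x * weight

batch_x = jnp.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])  # Shape: (3, 2)
weight = jnp.array([0.5, 1.5])     # Shape: (2,)

# Default out_axes=0: the batch axis leads the output, shape (3, 2)
out_first = jax.vmap(scale_sample, in_axes=(0, None))(batch_x, weight)

# out_axes=1: the batch axis is placed second, shape (2, 3)
out_last = jax.vmap(scale_sample, in_axes=(0, None), out_axes=1)(batch_x, weight)

print(out_first.shape, out_last.shape)
```

The two results hold the same values; `out_last` is simply the transpose of `out_first`.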

### Different Axes Example

In [None]:
# Same data with the batch dimension in different positions
x_batch_first = jnp.arange(12).reshape(3, 4)  # (batch=3, features=4)
x_batch_last = x_batch_first.T                # (features=4, batch=3)

def sum_features(x):
    """Sum all features for one sample."""
    return jnp.sum(x)

# Map over axis 0 (default)
vmap_axis0 = bst.transform.vmap(sum_features, in_axes=0)
result_axis0 = vmap_axis0(x_batch_first)

# Map over axis 1
vmap_axis1 = bst.transform.vmap(sum_features, in_axes=1)
result_axis1 = vmap_axis1(x_batch_last)

print("Batch-first shape:", x_batch_first.shape)
print("Results (axis 0):", result_axis0)
print("\nBatch-last shape:", x_batch_last.shape)
print("Results (axis 1):", result_axis1)
print("\nResults match:", jnp.allclose(result_axis0, result_axis1))

## 4. Nested vmap

You can nest `vmap` calls for multi-dimensional operations.

In [None]:
# Function for single pair of vectors
def dot_product(x, y):
    return jnp.dot(x, y)

# Compute all pairwise dot products
vectors_a = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 vectors
vectors_b = jnp.array([[1.0, 1.0], [1.0, -1.0]])             # 2 vectors

# First vmap over vectors_a, then over vectors_b
vmap_outer = bst.transform.vmap(lambda a: bst.transform.vmap(lambda b: dot_product(a, b))(vectors_b))
pairwise_dots = vmap_outer(vectors_a)

print("Vectors A (3x2):")
print(vectors_a)
print("\nVectors B (2x2):")
print(vectors_b)
print("\nPairwise dot products (3x2):")
print(pairwise_dots)
print("\nInterpretation: pairwise_dots[i,j] = dot(vectors_a[i], vectors_b[j])")

### Simpler Syntax for Nested vmap

In [None]:
# Alternative: use in_axes with None to broadcast
# vmap over first arg, then vmap over second arg
compute_all_dots = bst.transform.vmap(
    bst.transform.vmap(dot_product, in_axes=(None, 0)),
    in_axes=(0, None)
)

pairwise_dots_v2 = compute_all_dots(vectors_a, vectors_b)
print("Same result:", jnp.allclose(pairwise_dots, pairwise_dots_v2))
print(pairwise_dots_v2)
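
A quick way to sanity-check a nested `vmap` is to compare it against an equivalent closed-form expression. The sketch below uses plain `jax.vmap` (assumed here so the block runs on its own) and checks the pairwise dot products against a matrix product.

```python
import jax
import jax.numpy as jnp

vectors_a = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # (3, 2)
vectors_b = jnp.array([[1.0, 1.0], [1.0, -1.0]])             # (2, 2)

# Inner vmap maps over rows of b with a held fixed;
# outer vmap maps over rows of a with b held fixed.
pairwise = jax.vmap(
    jax.vmap(jnp.dot, in_axes=(None, 0)),
    in_axes=(0, None),
)(vectors_a, vectors_b)

# The same quantity written as a single matrix product
reference = vectors_a @ vectors_b.T

print(pairwise.shape)
print(jnp.allclose(pairwise, reference))
```

For dot products the matrix form is the simpler choice; nested `vmap` pays off when the pairwise function has no such closed form.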

## 5. vmap with Neural Networks

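BrainState layers such as `Linear` already accept a leading batch dimension, so a plain forward pass needs no explicit `vmap`. The first cell below shows this; `vmap` becomes relevant for per-example quantities such as gradients.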

In [None]:
# Simple MLP
class MLP(bst.graph.Node):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = bst.nn.Linear(input_dim, hidden_dim)
        self.layer2 = bst.nn.Linear(hidden_dim, output_dim)
    
    def __call__(self, x):
        x = self.layer1(x)
        x = jnp.tanh(x)
        x = self.layer2(x)
        return x

model = MLP(input_dim=4, hidden_dim=8, output_dim=2)

# Single input
single_input = bst.random.randn(4)
single_output = model(single_input)
print(f"Single input shape: {single_input.shape}")
print(f"Single output shape: {single_output.shape}")

# Batch input - the model already handles batches!
# BrainState layers are designed to work with batched inputs
batch_input = bst.random.randn(32, 4)
batch_output = model(batch_input)
print(f"\nBatch input shape: {batch_input.shape}")
print(f"Batch output shape: {batch_output.shape}")

### Using vmap for Per-Example Gradients

In [None]:
# Compute per-example gradients (useful for privacy, adversarial training, etc.)
def loss_single(x, y):
    """Loss for a single example."""
    pred = model(x)
    return jnp.sum((pred - y) ** 2)

# Generate batch data
x_batch = bst.random.randn(4, 4)  # 4 samples
y_batch = bst.random.randn(4, 2)

# Get per-example gradients
def compute_per_example_grads(x, y):
    # Use vector_grad which is optimized for this use case
    grad_fn = bst.transform.vector_grad(
        lambda xs, ys: jnp.array([loss_single(xs[i], ys[i]) for i in range(len(xs))])
    )
    return grad_fn(model.states(bst.ParamState), x, y)

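# Note: compute_per_example_grads is defined for illustration and is not called here.
# In practice, per-example gradients are usually obtained by composing vmap with a
# gradient transform rather than looping over examples in Python.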
print("Per-example gradients allow fine-grained analysis of training dynamics")
print("Useful for:")
print("  - Differential privacy")
print("  - Identifying influential examples")
print("  - Adversarial robustness")
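
One common way to obtain per-example gradients is to compose `vmap` with `grad`. The sketch below uses plain JAX with an explicit parameter dictionary; the tiny linear model and the names `predict` and `params` are hypothetical stand-ins, not BrainState's module or state API.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    """Tiny linear model; `params` is a hypothetical dict of arrays."""
    return x @ params["w"] + params["b"]

def loss_single(params, x, y):
    """Squared error for a single example."""
    return jnp.sum((predict(params, x) - y) ** 2)

key_w, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"w": jax.random.normal(key_w, (4, 2)), "b": jnp.zeros(2)}
x_batch = jax.random.normal(key_x, (4, 4))  # 4 samples
y_batch = jax.random.normal(key_y, (4, 2))

# grad differentiates w.r.t. params (argument 0); vmap maps over the batch
# axis of x and y while sharing params (in_axes=None).
per_example_grads = jax.vmap(
    jax.grad(loss_single), in_axes=(None, 0, 0)
)(params, x_batch, y_batch)

print(per_example_grads["w"].shape)  # (4, 4, 2): one (4, 2) gradient per sample
print(per_example_grads["b"].shape)  # (4, 2)
```

Summing the result over its leading axis recovers the gradient of the summed batch loss, which is a useful consistency check.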

## 6. StatefulMapping for Stateful Transformations

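When you need to maintain state across mapped computations, use `StatefulMapping`. The cell below does not call it directly; it builds an RNN cell whose hidden state is updated on every call and steps it through a sequence with a Python loop, which illustrates the kind of state such a mapping has to track.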

In [None]:
# RNN cell that maintains hidden state
class RNNCell(bst.graph.Node):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.Wxh = bst.ParamState(bst.random.randn(input_size, hidden_size) * 0.1)
        self.Whh = bst.ParamState(bst.random.randn(hidden_size, hidden_size) * 0.1)
        self.bh = bst.ParamState(jnp.zeros(hidden_size))
        self.h = bst.ShortTermState(jnp.zeros(hidden_size))
    
    def __call__(self, x):
        # Update hidden state
        h_new = jnp.tanh(
            x @ self.Wxh.value + 
            self.h.value @ self.Whh.value + 
            self.bh.value
        )
        self.h.value = h_new
        return h_new
    
    def reset(self):
        self.h.value = jnp.zeros(self.hidden_size)

# Create RNN cell
rnn_cell = RNNCell(input_size=3, hidden_size=5)

# Process a sequence
sequence = bst.random.randn(10, 3)  # 10 time steps, 3 features

rnn_cell.reset()
outputs = []
for t in range(len(sequence)):
    output = rnn_cell(sequence[t])
    outputs.append(output)

outputs = jnp.stack(outputs)
print(f"Sequence length: {len(sequence)}")
print(f"Outputs shape: {outputs.shape}")
print("Hidden state evolves over time, maintained in rnn_cell.h")
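
For comparison, the same recurrence can be written in plain JAX by threading the hidden state explicitly through `jax.lax.scan`. This is a sketch under stated assumptions: the parameter dictionary and `rnn_step` are hypothetical, and it does not use BrainState's state objects or `StatefulMapping`.

```python
import jax
import jax.numpy as jnp

def rnn_step(params, h, x):
    """One RNN step: returns the new carry and the per-step output."""
    h_new = jnp.tanh(x @ params["Wxh"] + h @ params["Whh"] + params["bh"])
    return h_new, h_new

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "Wxh": jax.random.normal(k1, (3, 5)) * 0.1,
    "Whh": jax.random.normal(k2, (5, 5)) * 0.1,
    "bh": jnp.zeros(5),
}
sequence = jax.random.normal(k3, (10, 3))  # 10 time steps, 3 features
h0 = jnp.zeros(5)

# scan carries h from step to step and stacks the per-step outputs
final_h, outputs = jax.lax.scan(
    lambda h, x: rnn_step(params, h, x), h0, sequence
)

print(outputs.shape)  # (10, 5)
```

Here the hidden state is an explicit value passed between steps rather than an attribute mutated in place, which is the main design difference from the class-based cell.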

## 7. Parallel Map (pmap) for Multi-Device Execution

`pmap` distributes computation across multiple devices (GPUs/TPUs).

In [None]:
# Check available devices
devices = jax.devices()
print(f"Available devices: {len(devices)}")
for i, device in enumerate(devices):
    print(f"  Device {i}: {device}")

# Example: parallel computation across devices
def expensive_computation(x):
    """Some expensive operation."""
    # Matrix multiplication chain, rescaled each step so values stay finite
    result = x
    for _ in range(10):
        result = result @ result.T @ result
        result = result / jnp.linalg.norm(result)
    return jnp.mean(result)

# Create data with first dimension = number of devices
n_devices = len(devices)
data = bst.random.randn(n_devices, 50, 50)

if n_devices > 1:
    # Parallelize across devices
    pmap_fn = jax.pmap(expensive_computation)
    results_parallel = pmap_fn(data)
    print(f"\nParallel results shape: {results_parallel.shape}")
    print(f"Results computed on {n_devices} devices simultaneously")
else:
    print("\nOnly 1 device available, so pmap has nothing to distribute.")
    print("On multi-GPU/TPU systems, pmap runs one slice of the data per device.")

# vmap for comparison (vectorized on a single device)
vmap_fn = bst.transform.vmap(expensive_computation)
results_single_device = vmap_fn(data)
print(f"\nSingle-device (vmap) also works: {results_single_device.shape}")
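
Mapped functions often need to combine values across devices. The sketch below, using plain `jax.pmap`, shows a cross-device sum with `jax.lax.psum`; the axis name `"devices"` and the helper `normalize` are illustrative choices, and the code also runs on a single device, where the leading axis has size 1.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def normalize(x):
    """Divide each device's slice by the sum over all devices."""
    # psum adds the per-device sums across the mapped axis
    total = jax.lax.psum(jnp.sum(x), axis_name="devices")
    return x / total

# Leading dimension must equal the number of devices
data = jnp.arange(1.0, n_devices * 4 + 1.0).reshape(n_devices, 4)
normalized = jax.pmap(normalize, axis_name="devices")(data)

print(normalized.shape)
print(jnp.sum(normalized))  # close to 1.0 after global normalization
```

Each device only sees its own slice of `data`; the collective is what makes the global total available inside the mapped function.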

## 8. Practical Example: Ensemble Predictions

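An ensemble evaluates several independently initialized models on the same inputs. The cell below loops over the models in Python for clarity; `vmap` can run the members in parallel once their parameters are stacked along a leading axis.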

In [None]:
# Create an ensemble of models with different initializations
class SimpleClassifier(bst.graph.Node):
    def __init__(self, seed):
        super().__init__()
        bst.random.seed(seed)
        self.layer1 = bst.nn.Linear(2, 4)
        self.layer2 = bst.nn.Linear(4, 1)
    
    def __call__(self, x):
        x = jnp.tanh(self.layer1(x))
        x = self.layer2(x)
        return jnp.squeeze(x)

# Create ensemble
ensemble_size = 5
models = [SimpleClassifier(seed=i) for i in range(ensemble_size)]

# Test data
test_x = bst.random.randn(100, 2)

# Get predictions from all models
predictions = jnp.array([model(test_x) for model in models])

# Ensemble prediction: average
ensemble_pred = jnp.mean(predictions, axis=0)
ensemble_std = jnp.std(predictions, axis=0)

print(f"Predictions shape: {predictions.shape}")  # (5, 100)
print(f"Ensemble prediction shape: {ensemble_pred.shape}")  # (100,)

# Visualize uncertainty
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Individual model predictions
for i in range(ensemble_size):
    ax1.scatter(test_x[:, 0], predictions[i], alpha=0.3, s=10, label=f'Model {i+1}')
ax1.set_xlabel('Feature 1')
ax1.set_ylabel('Prediction')
ax1.set_title('Individual Model Predictions')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Ensemble prediction with uncertainty
ax2.scatter(test_x[:, 0], ensemble_pred, c=ensemble_std, cmap='viridis', s=20)
ax2.set_xlabel('Feature 1')
ax2.set_ylabel('Ensemble Prediction')
ax2.set_title('Ensemble Prediction (color = uncertainty)')
plt.colorbar(ax2.collections[0], ax=ax2, label='Std Dev')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
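
As a sketch of the vectorized alternative, the example below stacks the parameters of all ensemble members along a leading axis and maps over that axis with plain `jax.vmap`. The functions `init_params` and `forward` are hypothetical stand-ins for the `SimpleClassifier` above and do not use BrainState modules.

```python
import jax
import jax.numpy as jnp

def init_params(key):
    """Parameters for one 2 -> 4 -> 1 network (hypothetical layout)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (2, 4)) * 0.5,
        "b1": jnp.zeros(4),
        "w2": jax.random.normal(k2, (4, 1)) * 0.5,
        "b2": jnp.zeros(1),
    }

def forward(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return jnp.squeeze(h @ params["w2"] + params["b2"], axis=-1)

ensemble_size = 5
keys = jax.random.split(jax.random.PRNGKey(0), ensemble_size)

# vmap over keys: every leaf gains a leading axis of size ensemble_size
stacked_params = jax.vmap(init_params)(keys)

test_x = jax.random.normal(jax.random.PRNGKey(1), (100, 2))

# Map over the members' parameters while sharing the inputs
predictions = jax.vmap(forward, in_axes=(0, None))(stacked_params, test_x)

ensemble_pred = predictions.mean(axis=0)
ensemble_std = predictions.std(axis=0)
print(predictions.shape)    # (5, 100)
print(ensemble_pred.shape)  # (100,)
```

This requires all members to share the same architecture, since their parameters must stack into arrays of equal shape.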

## 9. Advanced: Monte Carlo Dropout with vmap

In [None]:
# Network with dropout
class BayesianNet(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.layer1 = bst.nn.Linear(5, 10)
        self.dropout1 = bst.nn.Dropout(0.3)
        self.layer2 = bst.nn.Linear(10, 1)
    
    def __call__(self, x):
        x = jnp.tanh(self.layer1(x))
        x = self.dropout1(x)
        x = self.layer2(x)
        return x

model = BayesianNet()
test_input = bst.random.randn(20, 5)

# Multiple forward passes with dropout enabled (Monte Carlo sampling)
n_samples = 50
mc_predictions = []

with bst.environ.context(fit=True):  # Keep dropout enabled
    for _ in range(n_samples):
        pred = model(test_input)
        mc_predictions.append(pred)

mc_predictions = jnp.array(mc_predictions)  # Shape: (50, 20, 1)

# Compute mean and uncertainty
mc_mean = jnp.mean(mc_predictions, axis=0)
mc_std = jnp.std(mc_predictions, axis=0)

print(f"MC predictions shape: {mc_predictions.shape}")
print(f"Mean prediction shape: {mc_mean.shape}")
print(f"Uncertainty (std) range: [{mc_std.min():.3f}, {mc_std.max():.3f}]")
print("\nMonte Carlo dropout gives an approximate estimate of predictive uncertainty")
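
The sampling loop above can also be vectorized by mapping over random keys. The sketch below does this in plain JAX with a hand-written dropout; the names `dropout`, `forward`, and `params` are hypothetical and do not reflect how `bst.nn.Dropout` manages randomness.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate):
    """Inverted dropout: zero units with probability `rate`, rescale the rest."""
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def forward(params, key, x, rate=0.3):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    h = dropout(key, h, rate)
    return h @ params["w2"] + params["b2"]

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w1": jax.random.normal(k1, (5, 10)) * 0.5,
    "b1": jnp.zeros(10),
    "w2": jax.random.normal(k2, (10, 1)) * 0.5,
    "b2": jnp.zeros(1),
}
test_input = jax.random.normal(k3, (20, 5))

# One key per Monte Carlo sample; vmap maps over the keys only
n_samples = 50
sample_keys = jax.random.split(k4, n_samples)
mc_predictions = jax.vmap(forward, in_axes=(None, 0, None))(
    params, sample_keys, test_input
)

mc_mean = jnp.mean(mc_predictions, axis=0)
mc_std = jnp.std(mc_predictions, axis=0)
print(mc_predictions.shape)  # (50, 20, 1)
```

Passing an explicit key per sample makes each forward pass draw a different dropout mask while keeping the parameters and inputs shared.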

## 10. Performance Tips and Best Practices

### Tip 1: Combine vmap with jit

In [None]:
# Tip 1: Combine vmap with jit for best performance
def slow_function(x):
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

data = bst.random.randn(1000, 100)

# Version 1: Just vmap
vmap_only = bst.transform.vmap(slow_function)
jax.block_until_ready(vmap_only(data))  # Warmup
start = time.perf_counter()
jax.block_until_ready(vmap_only(data))  # Wait for async dispatch to finish
time_vmap = time.perf_counter() - start

# Version 2: vmap + jit
vmap_jit = jax.jit(bst.transform.vmap(slow_function))
jax.block_until_ready(vmap_jit(data))  # Warmup (compilation)
start = time.perf_counter()
jax.block_until_ready(vmap_jit(data))
time_vmap_jit = time.perf_counter() - start

print("Performance comparison:")
print(f"vmap only:     {time_vmap*1000:.2f} ms")
print(f"vmap + jit:    {time_vmap_jit*1000:.2f} ms")
print(f"Speedup:       {time_vmap/time_vmap_jit:.1f}x")
print("\nAlways combine vmap with jit for production code!")

### Tip 2: Choose the Right Batch Dimension

In [None]:
# Batch-first (standard): (batch, features)
batch_first = bst.random.randn(1000, 128)

# Feature-first: (features, batch)
feature_first = bst.random.randn(128, 1000)

def process_sample(x):
    return jnp.mean(x ** 2)

# For batch-first (standard)
vmap_batch_first = jax.jit(bst.transform.vmap(process_sample, in_axes=0))

# For feature-first
vmap_feature_first = jax.jit(bst.transform.vmap(process_sample, in_axes=1))

# Both work, but batch-first is more common
_ = vmap_batch_first(batch_first)
_ = vmap_feature_first(feature_first)

print("Recommendation: Use batch-first layout (batch, ...) as the standard")
print("  - More intuitive")
print("  - Matches PyTorch/TensorFlow conventions")
print("  - Better for distributed training")

## Summary

In this tutorial, we covered:

1. **vmap Basics**: Automatic vectorization of functions
2. **Axis Control**: Using `in_axes` and `out_axes` for flexible mapping
3. **Nested vmap**: Handling multi-dimensional operations
4. **Neural Networks**: Batched operations with BrainState modules
5. **Per-Example Gradients**: Fine-grained gradient computation
6. **StatefulMapping**: Maintaining state across mapped operations
7. **pmap**: Multi-device parallelism
8. **Ensemble Methods**: Efficient multi-model predictions
9. **Monte Carlo Dropout**: Uncertainty estimation
10. **Performance Tips**: Combining vmap with JIT
