# Tutorial 7: Large-Scale Simulations

**Duration:** ~35 minutes | **Prerequisites:** All Basic Tutorials

## Learning Objectives

By the end of this tutorial, you will:

- ✅ Optimize memory usage for large networks
- ✅ Apply JIT compilation best practices
- ✅ Use batching strategies effectively
- ✅ Leverage GPU/TPU acceleration
- ✅ Profile and optimize performance
- ✅ Implement sparse connectivity

## Overview

Scaling neural simulations to thousands or millions of neurons requires careful optimization. BrainPy leverages JAX for high-performance computing on CPUs, GPUs, and TPUs.

**Key concepts:**
- **JIT compilation**: Compile Python code to optimized machine code
- **Memory efficiency**: Minimize state storage and intermediate computations
- **Sparse operations**: Only compute where connections exist
- **Batching**: Process multiple trials simultaneously
- **Device acceleration**: Utilize GPU/TPU parallelism

Let's learn how to build efficient large-scale simulations!

In [None]:
import brainpy as bp
import brainstate
import brainunit as u
import braintools
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import time

# Set random seed
brainstate.random.seed(42)

# Configure environment
brainstate.environ.set(dt=0.1 * u.ms)

# Check available devices
print("🖥️  Available devices:")
print(f"   {jax.devices()}")
print(f"   Default backend: {jax.default_backend()}")

## Part 1: JIT Compilation Basics

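Just-In-Time (JIT) compilation traces a Python function once and compiles it to optimized machine code. For simulation loops this often yields speedups of 10-100×, although the exact gain depends on the model and the hardware.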

**Benefits of JIT:**
- Eliminates Python interpreter overhead
- Enables compiler optimizations (loop fusion, vectorization)
- Required for GPU/TPU execution

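**Rules for JIT:**
- Functions must be pure: no side effects other than updates to `brainstate` states, which `brainstate.compile.jit` tracks for you
- Array shapes must be static (known at compile time); a new input shape triggers recompilation
- Avoid Python control flow that depends on array values (e.g., `if v > v_th:` or loops whose length depends on data)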
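The rules above can be illustrated with plain JAX, which BrainPy builds on. The following is a minimal sketch, not BrainPy code: `lif_step` and its parameter values are hypothetical and units are omitted. It shows how a data-dependent branch is written with `jnp.where` so the function can be traced once and reused for every input of the same shape.

```python
import jax
import jax.numpy as jnp

@jax.jit
def lif_step(v, current):
    """One Euler step of a toy leaky integrator (hypothetical parameters)."""
    v_rest, v_th, tau, dt = -65.0, -50.0, 10.0, 0.1
    v = v + dt / tau * (v_rest - v + current)
    spiked = v >= v_th
    # A Python `if spiked:` would fail under JIT because `spiked` is a tracer
    # with no concrete value; jnp.where expresses the branch as an array op.
    v = jnp.where(spiked, v_rest, v)
    return v, spiked

v = jnp.full((4,), -65.0)
current = jnp.array([0.0, 10.0, 2000.0, 5000.0])
v_next, spiked = lif_step(v, current)
print(v_next, spiked)
```

Calling `lif_step` again with arrays of shape `(4,)` reuses the compiled code, while an input of a different shape would trigger a fresh compilation.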

Let's compare JIT vs non-JIT performance!

In [None]:
# Simple network without JIT
class SimpleNetwork(brainstate.nn.Module):
    def __init__(self, n_neurons=1000):
        super().__init__()
        self.neurons = bp.LIF(
            n_neurons,
            V_rest=-65.0*u.mV, V_th=-50.0*u.mV, tau=10.0*u.ms
        )
    
    def update(self, inp):
        self.neurons(inp)
        return self.neurons.get_spike()

# Test without JIT
net_no_jit = SimpleNetwork(n_neurons=1000)
brainstate.nn.init_all_states(net_no_jit)

# Warmup
inp = brainstate.random.rand(1000) * 2.0 * u.nA
_ = net_no_jit(inp)

# Time execution
n_steps = 1000
start = time.time()
for _ in range(n_steps):
    inp = brainstate.random.rand(1000) * 2.0 * u.nA
    out = net_no_jit(inp)
jax.block_until_ready(out)  # JAX dispatches asynchronously; wait before stopping the clock
time_no_jit = time.time() - start

print(f"⏱️  Without JIT: {time_no_jit:.3f} seconds for {n_steps} steps")
print(f"   ({time_no_jit/n_steps*1000:.2f} ms/step)")

In [None]:
# Same network WITH JIT
net_jit = SimpleNetwork(n_neurons=1000)
brainstate.nn.init_all_states(net_jit)

# Apply JIT compilation
@brainstate.compile.jit
def run_step_jit(net, inp):
    return net(inp)

# Warmup (compilation happens here)
inp = brainstate.random.rand(1000) * 2.0 * u.nA
_ = run_step_jit(net_jit, inp)

# Time execution
start = time.time()
for _ in range(n_steps):
    inp = brainstate.random.rand(1000) * 2.0 * u.nA
    out = run_step_jit(net_jit, inp)
jax.block_until_ready(out)  # JAX dispatches asynchronously; wait before stopping the clock
time_jit = time.time() - start

print(f"⏱️  With JIT: {time_jit:.3f} seconds for {n_steps} steps")
print(f"   ({time_jit/n_steps*1000:.2f} ms/step)")
print(f"\n🚀 Speedup: {time_no_jit/time_jit:.1f}×")

## Part 2: Memory Optimization

Large networks require careful memory management. Key strategies:

1. **Use appropriate data types**: `float32` instead of `float64` halves the memory per value
2. **Minimize state storage**: Only keep necessary variables
3. **Avoid unnecessary copies**: JAX arrays are immutable, but inside JIT-compiled code updates such as `x.at[i].set(v)` are usually performed in place
4. **Clear intermediate results**: Don't accumulate large histories
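Two of these strategies are easy to quantify with NumPy alone. The sketch below is illustrative: the spike generator is a random stand-in for a real simulation step, and the 0.1 ms time step is an assumption matching the environment configured earlier. It compares `float64` with `float32` storage, then a full spike history with a constant-size running count.

```python
import numpy as np

n_neurons, n_steps = 10_000, 1_000

# 1. Data type: the same state vector in float64 vs float32
v64 = np.zeros(n_neurons, dtype=np.float64)
v32 = v64.astype(np.float32)
print(f"float64: {v64.nbytes / 1024:.1f} KiB, float32: {v32.nbytes / 1024:.1f} KiB")

# 2. History: storing every spike vs keeping a running count per neuron
full_history_bytes = n_steps * n_neurons * np.dtype(np.bool_).itemsize
spike_count = np.zeros(n_neurons, dtype=np.int32)  # constant-size accumulator

rng = np.random.default_rng(0)
for _ in range(n_steps):
    spikes = rng.random(n_neurons) < 0.01  # stand-in for one simulation step
    spike_count += spikes                  # accumulate, then discard `spikes`

rate_hz = spike_count / (n_steps * 0.1e-3)  # assumes dt = 0.1 ms
print(f"Full history: {full_history_bytes / 1e6:.1f} MB, "
      f"running count: {spike_count.nbytes / 1e6:.2f} MB")
```

The running count grows with the number of neurons only, so its size stays fixed no matter how many steps are simulated.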

Let's compare memory usage for different approaches.

In [None]:
# Estimate memory usage
def estimate_memory_mb(n_neurons, n_synapses, dtype_bytes=4):
    """Estimate memory requirements.
    
    Args:
        n_neurons: Number of neurons
        n_synapses: Number of synaptic connections
        dtype_bytes: Bytes per element (4 for float32, 8 for float64)
    """
    # Neuron states (V, spike)
    neuron_memory = n_neurons * 2 * dtype_bytes
    
    # Synapse states (g, x for plasticity)
    synapse_memory = n_synapses * 2 * dtype_bytes
    
    # Connection weights
    weight_memory = n_synapses * dtype_bytes
    
    total_bytes = neuron_memory + synapse_memory + weight_memory
    total_mb = total_bytes / (1024 * 1024)
    
    return total_mb

# Compare different network sizes
sizes = [100, 1000, 10000, 100000, 1000000]
connectivity = 0.1

print("📊 Memory Requirements (Float32):")
print("="*60)
print(f"{'Neurons':<12} {'Synapses':<15} {'Memory (MB)':<15} {'Memory (GB)'}")
print("="*60)

for n in sizes:
    n_syn = int(n * n * connectivity)
    mem_mb = estimate_memory_mb(n, n_syn, dtype_bytes=4)
    mem_gb = mem_mb / 1024
    print(f"{n:<12,} {n_syn:<15,} {mem_mb:<15.2f} {mem_gb:<.3f}")

print("\n💡 Optimization tips:")
print("   • Use sparse connectivity to reduce synapse count")
print("   • Use float32 instead of float64 (2× memory savings)")
print("   • Don't store full spike history (record only what you need)")
print("   • Process in batches if memory-constrained")

## Part 3: Sparse Connectivity

Biological networks are sparsely connected (~1-10% connectivity). Using sparse matrices dramatically reduces memory and computation.

**Dense vs Sparse:**
- Dense: Store all $N \times N$ connections (even zeros)
- Sparse: Store only non-zero connections

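**Memory savings (weights only, before index overhead):**
- 10% connectivity → 90% fewer stored weights
- 1% connectivity → 99% fewer stored weights

Sparse formats also store connection indices (typically 4-8 bytes per connection), so the net saving is somewhat smaller than these figures.

The `EventFixedProb` connection (`brainstate.nn.EventFixedProb` in the code below) automatically uses sparse representations.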
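To make the idea concrete, here is a minimal NumPy sketch of a compressed sparse row (CSR) layout with event-driven propagation. It is an illustration under simple assumptions (a single constant weight of 0.5, random spikes) and is not a description of how `EventFixedProb` is implemented internally; `propagate` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post, prob = 1000, 1000, 0.05

# Dense reference: dense[i, j] is the weight from pre-neuron i to post-neuron j
mask = rng.random((n_pre, n_post)) < prob
dense = np.where(mask, 0.5, 0.0).astype(np.float32)

# CSR: for pre-neuron i, indices[indptr[i]:indptr[i + 1]] lists its targets
# and `data` holds the matching weights.
indptr = np.concatenate([[0], np.cumsum(mask.sum(axis=1))]).astype(np.int32)
indices = np.nonzero(mask)[1].astype(np.int32)
data = dense[mask]

def propagate(spikes):
    """Event-driven update: only rows of neurons that spiked are visited."""
    post = np.zeros(n_post, dtype=np.float32)
    for i in np.flatnonzero(spikes):
        lo, hi = indptr[i], indptr[i + 1]
        np.add.at(post, indices[lo:hi], data[lo:hi])
    return post

spikes = rng.random(n_pre) < 0.02
sparse_bytes = indptr.nbytes + indices.nbytes + data.nbytes
print(f"dense: {dense.nbytes / 1e6:.2f} MB, CSR: {sparse_bytes / 1e6:.2f} MB")
```

The work done by `propagate` scales with the number of spikes times the average out-degree, rather than with `n_pre * n_post`, which is why sparse, event-driven updates help with both memory and runtime.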

In [None]:
# Compare dense vs sparse connectivity
n_pre = 1000
n_post = 1000
prob = 0.05  # 5% connectivity

# Dense connection matrix
dense_matrix = (np.random.rand(n_post, n_pre) < prob).astype(np.float32)
dense_size_mb = dense_matrix.nbytes / (1024 * 1024)

# Sparse representation (only store indices and values)
# int32 indices suffice here and halve the index overhead of the int64 default
indices = np.argwhere(dense_matrix > 0).astype(np.int32)
values = dense_matrix[dense_matrix > 0]
sparse_size_mb = (indices.nbytes + values.nbytes) / (1024 * 1024)

print("🔍 Dense vs Sparse Comparison:")
print(f"   Network size: {n_pre} → {n_post} neurons")
print(f"   Connectivity: {prob*100}%")
print(f"   Actual connections: {len(values):,}")
print()
print(f"   Dense storage: {dense_size_mb:.2f} MB")
print(f"   Sparse storage: {sparse_size_mb:.2f} MB")
print(f"   Memory savings: {(1 - sparse_size_mb/dense_size_mb)*100:.1f}%")
print(f"   Space ratio: {dense_size_mb/sparse_size_mb:.1f}×")

In [None]:
# Build large sparse network
class LargeSparseNetwork(brainstate.nn.Module):
    """Large network with sparse connectivity."""
    
    def __init__(self, n_exc=4000, n_inh=1000, p_conn=0.02):
        super().__init__()
        
        # Neurons
        self.E = bp.LIF(n_exc, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=15.*u.ms)
        self.I = bp.LIF(n_inh, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)
        
        # Sparse projections with EventFixedProb
        self.E2E = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=p_conn, weight=0.6*u.mS),
            syn=bp.Expon.desc(n_exc, tau=5.*u.ms),
            out=bp.COBA.desc(E=0.*u.mV),
            post=self.E
        )
        
        self.E2I = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_inh, prob=p_conn, weight=0.6*u.mS),
            syn=bp.Expon.desc(n_inh, tau=5.*u.ms),
            out=bp.COBA.desc(E=0.*u.mV),
            post=self.I
        )
        
        self.I2E = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_exc, prob=p_conn, weight=6.7*u.mS),
            syn=bp.Expon.desc(n_exc, tau=10.*u.ms),
            out=bp.COBA.desc(E=-80.*u.mV),
            post=self.E
        )
        
        self.I2I = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_inh, prob=p_conn, weight=6.7*u.mS),
            syn=bp.Expon.desc(n_inh, tau=10.*u.ms),
            out=bp.COBA.desc(E=-80.*u.mV),
            post=self.I
        )
    
    def update(self, inp_e, inp_i):
        spk_e = self.E.get_spike()
        spk_i = self.I.get_spike()
        
        self.E2E(spk_e)
        self.E2I(spk_e)
        self.I2E(spk_i)
        self.I2I(spk_i)
        
        self.E(inp_e)
        self.I(inp_i)
        
        return spk_e, spk_i

# Create large network
large_net = LargeSparseNetwork(n_exc=4000, n_inh=1000, p_conn=0.02)
brainstate.nn.init_all_states(large_net)

print("✅ Created large sparse network:")
print(f"   Excitatory neurons: 4,000")
print(f"   Inhibitory neurons: 1,000")
print(f"   Total neurons: 5,000")
print(f"   Connectivity: 2%")
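n_conn = int(5000 * 5000 * 0.02)
print(f"   Approximate connections: {n_conn:,}")
# Rough estimate: ~8 bytes per connection (4-byte index + 4-byte weight)
print(f"   Estimated connectivity memory: ~{n_conn * 8 / 1e6:.0f} MB (sparse) "
      f"vs ~{5000 * 5000 * 4 / 1e6:.0f} MB (dense float32)")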

## Part 4: Batching for Parallelism

Running multiple independent simulations (trials) can be done in parallel using batching. This is especially efficient on GPUs.

**Batching benefits:**
- Run multiple trials simultaneously
- Amortize compilation cost
- Better GPU utilization
- Faster parameter sweeps

**How it works:**
- Add batch dimension: `(batch_size, n_neurons)`
- Operations automatically vectorized
- Each trial independent
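The effect of adding a batch dimension can be seen with a toy update rule in NumPy. This is a sketch only: `lif_step` and its parameters are hypothetical and units are omitted. The same function is called once per trial and then once on the whole `(batch_size, n_neurons)` array, and the two results are compared.

```python
import numpy as np

def lif_step(v, current, v_rest=-65.0, v_th=-50.0, tau=10.0, dt=0.1):
    """Toy leaky-integrator step; works for shape (n,) or (batch, n)."""
    v = v + dt / tau * (v_rest - v + current)
    spiked = v >= v_th
    return np.where(spiked, v_rest, v), spiked

rng = np.random.default_rng(0)
n_trials, n_neurons = 10, 1000
currents = rng.uniform(0.0, 3000.0, size=(n_trials, n_neurons))

# Sequential: one call per trial
v_seq = np.stack([lif_step(np.full(n_neurons, -65.0), c)[0] for c in currents])

# Batched: a single call on the (batch_size, n_neurons) array
v_batch, spk_batch = lif_step(np.full((n_trials, n_neurons), -65.0), currents)

print(v_batch.shape, np.allclose(v_seq, v_batch))
```

Because every operation in `lif_step` is elementwise, the trials never interact, and the batched call replaces ten separate function calls with one.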

In [None]:
# Single trial simulation
def simulate_single_trial(n_steps=1000):
    net = SimpleNetwork(n_neurons=1000)
    brainstate.nn.init_all_states(net)
    
    @brainstate.compile.jit
    def step(net, inp):
        return net(inp)
    
    for _ in range(n_steps):
        inp = brainstate.random.rand(1000) * 2.0 * u.nA
        _ = step(net, inp)

# Batched simulation
def simulate_batched_trials(n_trials=10, n_steps=1000):
    net = SimpleNetwork(n_neurons=1000)
    brainstate.nn.init_all_states(net, batch_size=n_trials)
    
    @brainstate.compile.jit
    def step(net, inp):
        return net(inp)
    
    for _ in range(n_steps):
        inp = brainstate.random.rand(n_trials, 1000) * 2.0 * u.nA
        _ = step(net, inp)

# Compare timing
n_trials = 10

# Sequential trials
start = time.time()
for _ in range(n_trials):
    simulate_single_trial(n_steps=100)
time_sequential = time.time() - start

# Batched trials
start = time.time()
simulate_batched_trials(n_trials=n_trials, n_steps=100)
time_batched = time.time() - start

print(f"⏱️  Sequential (10 trials): {time_sequential:.3f} seconds")
print(f"⏱️  Batched (10 trials): {time_batched:.3f} seconds")
print(f"\n🚀 Batching speedup: {time_sequential/time_batched:.1f}×")
print(f"\n💡 Batching is especially effective on GPUs!")

## Part 5: GPU Acceleration

GPUs excel at parallel operations on large arrays. BrainPy automatically uses GPUs when available via JAX.

**GPU benefits:**
- Massive parallelism (1000s of cores)
- High memory bandwidth
- Fast matrix operations
- Often 10-100× speedup for large networks (model- and hardware-dependent)

**Best practices:**
- Use large batch sizes
- Minimize CPU-GPU data transfer
- Keep data on GPU between operations
- Use JIT compilation
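The practices on data transfer can be sketched in plain JAX as follows. This is an illustration rather than BrainPy-specific code, and `decay` is a hypothetical stand-in for a simulation step. The input is placed on the device once, repeated JIT-compiled updates keep it there, and only a scalar summary is copied back to the host.

```python
import jax
import jax.numpy as jnp
import numpy as np

# First device of the default backend (a GPU if one is present, else the CPU)
device = jax.devices()[0]

# Transfer the input once, up front
x_host = np.ones((1000, 1000), dtype=np.float32)
x_dev = jax.device_put(x_host, device)

@jax.jit
def decay(x):
    return 0.99 * x

# Intermediate results stay on the device between calls
for _ in range(100):
    x_dev = decay(x_dev)

# Bring back only a small summary, not the full array
mean_value = float(jnp.mean(x_dev))
print(f"Ran on {device.platform}; mean = {mean_value:.4f}")
```

Converting `x_dev` to a NumPy array inside the loop would force a device-to-host copy on every iteration, which is the pattern to avoid.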

In [None]:
# Check if GPU is available
try:
    gpu_device = jax.devices('gpu')[0]
    has_gpu = True
    print("✅ GPU detected:", gpu_device)
except RuntimeError:  # raised by JAX when no GPU backend is available
    has_gpu = False
    print("ℹ️  No GPU detected, using CPU")

if has_gpu:
    # Compare CPU vs GPU for large operation
    n = 10000
    
    # CPU
    with jax.default_device(jax.devices('cpu')[0]):
        x = jax.random.normal(jax.random.PRNGKey(0), (n, n))
        jnp.dot(x, x).block_until_ready()  # Warmup: exclude one-off compilation
        
        start = time.time()
        y = jnp.dot(x, x)
        y.block_until_ready()  # Wait for computation
        time_cpu = time.time() - start
    
    # GPU
    with jax.default_device(gpu_device):
        x = jax.random.normal(jax.random.PRNGKey(0), (n, n))
        jnp.dot(x, x).block_until_ready()  # Warmup: exclude one-off compilation
        
        start = time.time()
        y = jnp.dot(x, x)
        y.block_until_ready()
        time_gpu = time.time() - start
    
    print(f"\n🖥️  CPU time: {time_cpu:.4f} seconds")
    print(f"🎮 GPU time: {time_gpu:.4f} seconds")
    print(f"🚀 GPU speedup: {time_cpu/time_gpu:.1f}×")
else:
    print("\n💡 To use GPU:")
    print("   1. Install JAX with GPU support")
    print("   2. Install CUDA drivers")
    print("   3. BrainPy will automatically use GPU")

## Part 6: Performance Profiling

To optimize performance, you need to identify bottlenecks. Use profiling to find where time is spent.

**Profiling strategies:**
1. **Time individual operations**: Find slow components
2. **Use JAX profiler**: Detailed GPU/TPU profiling
3. **Monitor memory**: Detect memory leaks
4. **Check compilation**: Ensure JIT is working
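The first strategy can be wrapped in a small reusable helper. The sketch below uses only the standard library; `time_calls` and `summarize` are hypothetical names, and the optional `sync` argument is where a call such as `jax.block_until_ready` would go when timing asynchronously dispatched JAX work.

```python
import statistics
import time

def time_calls(fn, *args, warmup=1, repeats=20, sync=None):
    """Return per-call wall times in milliseconds.

    `sync` is an optional callable applied to the result before the clock
    stops, e.g. `jax.block_until_ready` for asynchronously dispatched work.
    """
    for _ in range(warmup):  # exclude one-off costs such as compilation
        out = fn(*args)
        if sync is not None:
            sync(out)
    times_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        if sync is not None:
            sync(out)
        times_ms.append((time.perf_counter() - start) * 1000)
    return times_ms

def summarize(times_ms):
    """Summary statistics that are robust to a few slow outliers."""
    return {
        "mean": statistics.mean(times_ms),
        "median": statistics.median(times_ms),
        "max": max(times_ms),
    }

stats = summarize(time_calls(sum, range(100_000)))
print({k: f"{v:.3f} ms" for k, v in stats.items()})
```

A large gap between the median and the maximum usually points to occasional stalls (recompilation, garbage collection, host-device transfers) rather than to a uniformly slow step.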

In [None]:
# Simple profiling example
class ProfilingNetwork(brainstate.nn.Module):
    def __init__(self, n_neurons=5000):
        super().__init__()
        self.lif = bp.LIF(n_neurons, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)
        self.proj = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_neurons, n_neurons, prob=0.01, weight=0.5*u.mS),
            syn=bp.Expon.desc(n_neurons, tau=5.*u.ms),
            out=bp.CUBA.desc(),
            post=self.lif
        )
    
    def update(self, inp):
        spk = self.lif.get_spike()
        self.proj(spk)
        self.lif(inp)
        return spk

# Profile simulation
net = ProfilingNetwork(n_neurons=5000)
brainstate.nn.init_all_states(net)

@brainstate.compile.jit
def run_step(net, inp):
    return net(inp)

# Warmup
inp = brainstate.random.rand(5000) * 2.0 * u.nA
_ = run_step(net, inp)

# Profile multiple steps
n_steps = 100
step_times = []

for _ in range(n_steps):
    inp = brainstate.random.rand(5000) * 2.0 * u.nA
    
    start = time.time()
    _ = jax.block_until_ready(run_step(net, inp))  # wait for async dispatch to finish
    step_times.append(time.time() - start)

step_times = np.array(step_times) * 1000  # Convert to ms

# Plot timing distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Time series
axes[0].plot(step_times, 'b-', linewidth=1, alpha=0.7)
axes[0].axhline(np.mean(step_times), color='r', linestyle='--', 
                label=f'Mean: {np.mean(step_times):.2f} ms')
axes[0].set_xlabel('Step', fontsize=12)
axes[0].set_ylabel('Time (ms)', fontsize=12)
axes[0].set_title('Step-by-Step Timing', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Histogram
axes[1].hist(step_times, bins=30, color='blue', alpha=0.7, edgecolor='black')
axes[1].axvline(np.mean(step_times), color='r', linestyle='--', linewidth=2,
               label=f'Mean: {np.mean(step_times):.2f} ms')
axes[1].set_xlabel('Time (ms)', fontsize=12)
axes[1].set_ylabel('Frequency', fontsize=12)
axes[1].set_title('Timing Distribution', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print(f"📊 Performance Statistics:")
print(f"   Mean time/step: {np.mean(step_times):.2f} ms")
print(f"   Std deviation: {np.std(step_times):.2f} ms")
print(f"   Min time: {np.min(step_times):.2f} ms")
print(f"   Max time: {np.max(step_times):.2f} ms")
print(f"   Throughput: {1000/np.mean(step_times):.1f} steps/second")

## Part 7: Optimization Checklist

Here's a comprehensive checklist for optimizing large-scale simulations.

### Before Optimization
1. **Profile first**: Identify actual bottlenecks
2. **Set target**: Define performance goals
3. **Baseline**: Measure current performance

### Code Optimizations
- ✅ Use JIT compilation (`@brainstate.compile.jit`)
- ✅ Use sparse connectivity (`EventFixedProb`)
- ✅ Use float32 instead of float64
- ✅ Batch multiple trials together
- ✅ Avoid Python loops over time steps (use `for_loop` or `scan`)
- ✅ Minimize state storage
- ✅ Use the largest time step that keeps the integration accurate (larger = faster, but less precise)

### Hardware Optimizations
- ✅ Use GPU/TPU when available
- ✅ Increase batch size for better GPU utilization
- ✅ Monitor GPU memory usage
- ✅ Keep data on accelerator (avoid CPU-GPU transfers)

### Algorithm Optimizations
- ✅ Simplify neuron models if possible
- ✅ Use event-driven dynamics where appropriate
- ✅ Reduce synaptic computations (sparse updates)
- ✅ Cache frequently computed values
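The checklist item on avoiding Python loops can be sketched with `jax.lax.scan`, the plain-JAX primitive for compiled loops. This is an illustration with a toy model: `lif_step`, `run`, and the parameter values are hypothetical and units are omitted. The whole time loop is compiled as one program, and only a spike count per step is recorded instead of the full spike history.

```python
import jax
import jax.numpy as jnp

def lif_step(v, current):
    """Toy leaky-integrator step with hypothetical parameters."""
    v_rest, v_th, tau, dt = -65.0, -50.0, 10.0, 0.1
    v = v + dt / tau * (v_rest - v + current)
    spiked = v >= v_th
    return jnp.where(spiked, v_rest, v), spiked

@jax.jit
def run(v0, currents):
    # scan carries `v` from step to step and stacks the per-step outputs,
    # so no Python-level loop runs at execution time.
    def body(v, current):
        v, spiked = lif_step(v, current)
        return v, spiked.sum()  # record only a spike count per step
    return jax.lax.scan(body, v0, currents)

n_steps, n_neurons = 200, 1000
key = jax.random.PRNGKey(0)
currents = jax.random.uniform(key, (n_steps, n_neurons), minval=0.0, maxval=40.0)
v_final, spikes_per_step = run(jnp.full((n_neurons,), -65.0), currents)
print(spikes_per_step.shape, int(spikes_per_step.sum()))
```

Compared with calling a JIT-compiled step from a Python `for` loop, this removes the per-step dispatch overhead, which matters most when each step is cheap.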

In [None]:
# Demonstrate optimization impact
def benchmark_configurations():
    """Benchmark different optimization strategies."""
    
    n_neurons = 2000
    n_steps = 100
    results = {}
    
    # 1. Baseline (no optimizations)
    print("Testing: Baseline (no optimizations)...")
    net1 = SimpleNetwork(n_neurons)
    brainstate.nn.init_all_states(net1)
    
    start = time.time()
    for _ in range(n_steps):
        inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA
        _ = net1(inp)
    results['Baseline'] = time.time() - start
    
    # 2. With JIT
    print("Testing: With JIT...")
    net2 = SimpleNetwork(n_neurons)
    brainstate.nn.init_all_states(net2)
    
    @brainstate.compile.jit
    def step_jit(net, inp):
        return net(inp)
    
    # Warmup
    inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA
    _ = step_jit(net2, inp)
    
    start = time.time()
    for _ in range(n_steps):
        inp = brainstate.random.rand(n_neurons) * 2.0 * u.nA
        _ = step_jit(net2, inp)
    results['JIT'] = time.time() - start
    
    # 3. With JIT + Batching
    print("Testing: JIT + Batching...")
    batch_size = 10
    net3 = SimpleNetwork(n_neurons)
    brainstate.nn.init_all_states(net3, batch_size=batch_size)
    
    # Warmup
    inp = brainstate.random.rand(batch_size, n_neurons) * 2.0 * u.nA
    _ = step_jit(net3, inp)
    
    start = time.time()
    for _ in range(n_steps):
        inp = brainstate.random.rand(batch_size, n_neurons) * 2.0 * u.nA
        _ = step_jit(net3, inp)
    results['JIT+Batch'] = time.time() - start
    
    return results

# Run benchmark
results = benchmark_configurations()

# Visualize results
fig, ax = plt.subplots(figsize=(10, 6))

configs = list(results.keys())
times = list(results.values())
speedups = [times[0] / t for t in times]

bars = ax.bar(configs, times, color=['red', 'orange', 'green'], alpha=0.7)

# Add speedup labels
for i, (bar, speedup) in enumerate(zip(bars, speedups)):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
           f'{speedup:.1f}× faster\n{times[i]:.2f}s',
           ha='center', va='bottom', fontsize=11, fontweight='bold')

ax.set_ylabel('Time (seconds)', fontsize=12)
ax.set_title('Optimization Impact (2000 neurons, 100 steps)', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n📊 Optimization Results:")
for config, t in results.items():
    speedup = times[0] / t
    print(f"   {config:15s}: {t:.3f}s  ({speedup:.1f}× faster)")

## Part 8: Complete Large-Scale Example

Let's put it all together with a fully optimized large-scale simulation.

In [None]:
# Optimized large-scale network
class OptimizedLargeNetwork(brainstate.nn.Module):
    """Fully optimized large-scale E-I network."""
    
    def __init__(self, n_exc=8000, n_inh=2000, p_conn=0.02):
        super().__init__()
        
        self.n_exc = n_exc
        self.n_inh = n_inh
        
        # LIF neurons (using default float32)
        self.E = bp.LIF(n_exc, V_rest=-65.*u.mV, V_th=-50.*u.mV, 
                       V_reset=-65.*u.mV, tau=15.*u.ms)
        self.I = bp.LIF(n_inh, V_rest=-65.*u.mV, V_th=-50.*u.mV,
                       V_reset=-65.*u.mV, tau=10.*u.ms)
        
        # Sparse connectivity
        self.E2E = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_exc, prob=p_conn, weight=0.05*u.mS),
            syn=bp.Expon.desc(n_exc, tau=5.*u.ms),
            out=bp.CUBA.desc(),
            post=self.E
        )
        
        self.E2I = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_inh, prob=p_conn, weight=0.05*u.mS),
            syn=bp.Expon.desc(n_inh, tau=5.*u.ms),
            out=bp.CUBA.desc(),
            post=self.I
        )
        
        self.I2E = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_exc, prob=p_conn, weight=0.4*u.mS),
            syn=bp.Expon.desc(n_exc, tau=10.*u.ms),
            out=bp.CUBA.desc(),
            post=self.E
        )
        
        self.I2I = bp.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_inh, prob=p_conn, weight=0.4*u.mS),
            syn=bp.Expon.desc(n_inh, tau=10.*u.ms),
            out=bp.CUBA.desc(),
            post=self.I
        )
    
    def update(self, inp_e, inp_i):
        # Get spikes
        spk_e = self.E.get_spike()
        spk_i = self.I.get_spike()
        
        # Update projections
        self.E2E(spk_e)
        self.E2I(spk_e)
        self.I2E(spk_i)
        self.I2I(spk_i)
        
        # Update neurons
        self.E(inp_e)
        self.I(inp_i)
        
        return spk_e, spk_i

# Create and simulate
print("Creating large-scale network...")
large_net = OptimizedLargeNetwork(n_exc=8000, n_inh=2000, p_conn=0.02)
brainstate.nn.init_all_states(large_net)

print("\n📊 Network Statistics:")
print(f"   Total neurons: {large_net.n_exc + large_net.n_inh:,}")
print(f"   Excitatory: {large_net.n_exc:,} (80%)")
print(f"   Inhibitory: {large_net.n_inh:,} (20%)")
print(f"   Connectivity: 2%")
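n_conn_est = int(10000 * 10000 * 0.02)
print(f"   Estimated connections: {n_conn_est:,}")
# Rough estimate: ~8 bytes per connection (4-byte index + 4-byte weight)
print(f"   Estimated connectivity memory: ~{n_conn_est * 8 / 1e6:.0f} MB")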

# JIT-compiled simulation
@brainstate.compile.jit
def simulate_step(net, inp_e, inp_i):
    return net(inp_e, inp_i)

# Warmup
print("\nCompiling (this takes a moment)...")
inp_e = brainstate.random.rand(large_net.n_exc) * 1.0 * u.nA
inp_i = brainstate.random.rand(large_net.n_inh) * 1.0 * u.nA
_ = simulate_step(large_net, inp_e, inp_i)
print("✅ Compilation complete!")

# Run simulation
print("\nRunning simulation...")
n_steps = 500
spike_history_e = []
spike_history_i = []

start = time.time()
for i in range(n_steps):
    inp_e = brainstate.random.rand(large_net.n_exc) * 1.0 * u.nA
    inp_i = brainstate.random.rand(large_net.n_inh) * 1.0 * u.nA
    spk_e, spk_i = simulate_step(large_net, inp_e, inp_i)
    
    # Downsample recording (save memory)
    if i % 5 == 0:
        spike_history_e.append(spk_e)
        spike_history_i.append(spk_i)

sim_time = time.time() - start

print(f"\n⏱️  Simulation complete:")
print(f"   Real time: {sim_time:.2f} seconds")
print(f"   Simulated time: {n_steps * 0.1} ms")
print(f"   Real-time factor: {(n_steps * 0.1 / 1000) / sim_time:.3f}× real-time")
print(f"   Throughput: {n_steps / sim_time:.1f} steps/second")

# Visualize downsampled activity
spike_history_e = jnp.array(spike_history_e)
spike_history_i = jnp.array(spike_history_i)

fig, axes = plt.subplots(2, 1, figsize=(14, 8), sharex=True)

# Excitatory raster (subsample neurons for visibility)
n_show = 500
times_ms = np.arange(len(spike_history_e)) * 5 * 0.1  # Downsampled times

# Vectorized raster: find all (time, neuron) spike pairs in one pass
spk_e_np = np.asarray(spike_history_e)[:, :n_show] > 0
t_idx, n_idx = np.nonzero(spk_e_np)
axes[0].scatter(times_ms[t_idx], n_idx, s=0.5, c='blue', alpha=0.5)

axes[0].set_ylabel('Excitatory Neuron', fontsize=12)
axes[0].set_title(f'Large-Scale Network Activity ({large_net.n_exc + large_net.n_inh:,} neurons)', 
                 fontsize=14, fontweight='bold')
axes[0].set_ylim(0, n_show)

# Inhibitory raster
spk_i_np = np.asarray(spike_history_i) > 0
t_idx, n_idx = np.nonzero(spk_i_np)
axes[1].scatter(times_ms[t_idx], n_idx, s=0.5, c='red', alpha=0.5)

axes[1].set_xlabel('Time (ms)', fontsize=12)
axes[1].set_ylabel('Inhibitory Neuron', fontsize=12)
axes[1].set_title('Inhibitory Population', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n✅ Successfully simulated 10,000 neuron network!")

## Summary

In this tutorial, you learned:

✅ **JIT compilation**
   - Use `@brainstate.compile.jit`; speedups of 10-100× are common but depend on the model and hardware
   - Functions must be pure and have static shapes
   - Essential for large-scale simulations

✅ **Memory optimization**
   - Use float32 instead of float64 (2× savings)
   - Minimize state storage
   - Don't accumulate full histories

✅ **Sparse connectivity**
   - Use `EventFixedProb` for automatic sparse operations
   - 90-99% fewer stored weights at biological connectivity levels (index storage adds some overhead)
   - Faster computation (skip zero connections)

✅ **Batching**
   - Run multiple trials simultaneously
   - Better hardware utilization
   - Faster parameter sweeps

✅ **GPU/TPU acceleration**
   - Automatic via JAX when available
   - Often 10-100× speedup for large networks
   - Keep data on device

✅ **Performance profiling**
   - Identify bottlenecks before optimizing
   - Monitor memory usage
   - Track throughput metrics

**Optimization workflow:**

```python
# 1. Create network with sparse connectivity
net = OptimizedLargeNetwork(n_exc=8000, n_inh=2000, p_conn=0.02)  # Sparse!

# 2. Initialize states (pass batch_size=... to run several trials at once)
brainstate.nn.init_all_states(net)

# 3. JIT compile the simulation step
@brainstate.compile.jit
def simulate_step(net, inp_e, inp_i):
    return net(inp_e, inp_i)

# 4. Run on GPU (automatic if available)
for i in range(n_steps):
    inp_e = brainstate.random.rand(net.n_exc) * 1.0 * u.nA
    inp_i = brainstate.random.rand(net.n_inh) * 1.0 * u.nA
    spk_e, spk_i = simulate_step(net, inp_e, inp_i)
```

**Scale guidelines (rough, hardware-dependent):**
- ✅ 10,000 neurons: Easy on CPU (demonstrated in Part 8)
- 100,000 neurons: Typically needs a GPU
- 1,000,000+ neurons: Multi-GPU or TPU

**Next steps:**
- Try your own large-scale models
- Experiment with different connectivity patterns
- Profile and optimize your specific use case
- Use specialized tutorials for specific applications
- Explore multi-GPU scaling (advanced)

**References:**
- JAX documentation: https://jax.readthedocs.io/
- BrainPy optimization guide: https://brainpy.readthedocs.io/
- Neuromorphic computing benchmarks
- Large-scale brain simulation papers (Spaun, Blue Brain Project)

## Exercises

Test your understanding:

### Exercise 1: JIT Compilation
Take a non-JIT network and apply JIT compilation. Measure the speedup. What happens if you violate JIT rules (e.g., use a Python `if` that depends on an array value, or change the input shape between calls)?

### Exercise 2: Memory Analysis
Estimate memory requirements for a 100,000 neuron network with 1% connectivity. Will it fit in 16GB RAM?

### Exercise 3: Sparse vs Dense
Implement the same network with dense and sparse connectivity. Compare memory usage and runtime.

### Exercise 4: Batching Strategy
Run 100 independent trials. Compare: (a) sequential, (b) batched 10×10, (c) batched 100×1. Which is fastest?

### Exercise 5: Profiling
Profile a large network and identify the slowest operation. Optimize it and measure improvement.

**Bonus Challenge:** Scale up to the largest network your hardware can handle. How many neurons can you simulate in real-time?