# Tutorial 5: Training Spiking Neural Networks

**Duration:** ~45 minutes | **Prerequisites:** Basic Tutorials 1-4

## Learning Objectives

By the end of this tutorial, you will:

- ✅ Understand surrogate gradient methods for training SNNs
- ✅ Implement backpropagation through time (BPTT) for SNNs
- ✅ Use appropriate loss functions for spike-based learning
- ✅ Configure optimizers and learning rates
- ✅ Train an SNN classifier on real datasets
- ✅ Evaluate and visualize training progress

## Overview

Training spiking neural networks is challenging because spike generation is a discrete, non-differentiable operation. In this tutorial, we'll learn how to overcome this using **surrogate gradient methods**, which allow us to train SNNs using standard gradient-based optimization.

**Key Concepts:**
- **The gradient problem**: Spike generation has zero gradient almost everywhere
- **Surrogate gradients**: Use smooth approximations during backpropagation
- **BPTT for SNNs**: Unroll network dynamics through time
- **Rate-based losses**: Train on spike rates or membrane potentials
- **Temporal credit assignment**: Learn when to spike

Let's start by understanding why training SNNs is difficult and how we can solve it!

In [None]:
import brainpy as bp
import brainstate
import brainunit as u
import braintools
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Set random seeds for reproducibility (the synthetic data below uses NumPy's RNG)
brainstate.random.seed(42)
np.random.seed(42)

# Configure environment
brainstate.environ.set(dt=1.0 * u.ms)

## Part 1: The Gradient Problem

Let's visualize why training SNNs is challenging. The spike generation function is a Heaviside step function:

$$
S(V) = \begin{cases}
1 & \text{if } V \geq V_{th} \\
0 & \text{if } V < V_{th}
\end{cases}
$$

Its derivative is zero everywhere except at the threshold, where it is unbounded (formally, a Dirac delta):

$$
\frac{dS}{dV} = \begin{cases}
\infty & \text{at } V = V_{th} \\
0 & \text{everywhere else}
\end{cases}
$$

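Because the gradient is zero almost everywhere, backpropagation receives no learning signal through the spike function, so plain gradient descent cannot update the weights. Let's see this visually.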
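
The same conclusion can also be checked with automatic differentiation. The sketch below is an illustration in plain JAX, not part of BrainPy; `hard_spike` and `v_samples` are names made up for this example. It differentiates a hard threshold at a few voltages below, at, and above the threshold.

```python
import jax
import jax.numpy as jnp

def hard_spike(v, v_th=0.0):
    """Heaviside spike function: 1.0 where v >= v_th, else 0.0."""
    return (v >= v_th).astype(jnp.float32)

# Derivative of the spike output with respect to the membrane potential,
# evaluated below, at, and above the threshold.
v_samples = jnp.array([-1.0, -0.01, 0.0, 0.01, 1.0])
hard_grads = jax.vmap(jax.grad(hard_spike))(v_samples)
print(hard_grads)
```

Autodiff reports a zero derivative at every sample, so no error signal can pass through `hard_spike` to the weights that produced `v`.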

In [None]:
# Heaviside step function
def heaviside(x, threshold=0.0):
    return (x >= threshold).astype(float)

# Voltage values
V = np.linspace(-2, 2, 1000)
V_th = 0.0

# Spike function and its "gradient"
spikes = heaviside(V, V_th)
# Numerical gradient (will be mostly zeros)
grad_spike = np.gradient(spikes, V)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Spike function
axes[0].plot(V, spikes, 'b-', linewidth=2)
axes[0].axvline(V_th, color='r', linestyle='--', label='Threshold')
axes[0].set_xlabel('Membrane Potential (V)', fontsize=12)
axes[0].set_ylabel('Spike Output', fontsize=12)
axes[0].set_title('Spike Generation Function', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Gradient (problematic!)
axes[1].plot(V, grad_spike, 'r-', linewidth=2)
axes[1].axvline(V_th, color='r', linestyle='--', label='Threshold')
axes[1].set_xlabel('Membrane Potential (V)', fontsize=12)
axes[1].set_ylabel('Gradient dS/dV', fontsize=12)
axes[1].set_title('Gradient (Problematic!)', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim(-0.1, 0.6)

plt.tight_layout()
plt.show()

print("❌ Problem: Gradient is zero almost everywhere!")
print("   This prevents gradient descent from working.")

## Part 2: Surrogate Gradient Solution

The solution is **surrogate gradients**: Use the true spike function in the forward pass, but use a smooth approximation during backpropagation.

**Common surrogate gradient functions:**

1. **Sigmoid**: $\sigma'(\beta(V - V_{th}))$
2. **ReLU**: $\max(0, 1 - |V - V_{th}|)$
3. **SuperSpike**: $\frac{1}{(1 + |\beta(V - V_{th})|)^2}$

BrainPy provides these in `braintools.surrogate`. Let's visualize them!
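
Before plotting, it may help to see the forward/backward split spelled out. The following is a minimal sketch in plain JAX using `jax.custom_vjp`; it is not how `braintools.surrogate` is implemented, and the names `surrogate_spike` and `BETA` are assumptions chosen for illustration. The forward pass is a hard threshold, while the backward pass substitutes the derivative of a sigmoid.

```python
import jax
import jax.numpy as jnp

BETA = 4.0  # hypothetical steepness parameter for this sketch

@jax.custom_vjp
def surrogate_spike(x):
    """Forward pass: hard threshold on x = V - V_th."""
    return (x >= 0.0).astype(x.dtype)

def _spike_fwd(x):
    # Save x so the backward pass can evaluate the surrogate derivative.
    return surrogate_spike(x), x

def _spike_bwd(x, g):
    # Backward pass: derivative of sigmoid(BETA * x), smooth and nonzero.
    s = jax.nn.sigmoid(BETA * x)
    return (g * BETA * s * (1.0 - s),)

surrogate_spike.defvjp(_spike_fwd, _spike_bwd)

x = jnp.array([-1.0, 0.0, 1.0])
spikes = surrogate_spike(x)                      # forward pass: binary spikes
grads = jax.vmap(jax.grad(surrogate_spike))(x)   # backward pass: surrogate derivative
```

In this sketch the surrogate derivative peaks at `BETA / 4` at the threshold and decays symmetrically on both sides, so neurons close to firing receive the largest updates.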

In [None]:
from functools import partial
import jax

# Surrogate spike functions with their steepness fixed.
# Assumption: these functions take the input as their first argument and return
# hard 0/1 spikes in the forward pass, so the smooth surrogate is obtained by
# differentiating them rather than by calling them.
surrogates = {
    'Sigmoid': (partial(braintools.surrogate.sigmoid, alpha=4.0), 'g'),
    'ReLU': (partial(braintools.surrogate.relu_grad, alpha=1.0), 'b'),
    # slayer_grad is a SLAYER-style exponential surrogate; it is plotted here
    # in place of the SuperSpike formula listed above.
    'SLAYER': (partial(braintools.surrogate.slayer_grad, alpha=4.0), 'm'),
}

# Voltage range
V_range = jnp.linspace(-2, 2, 1000)
V_th = 0.0

def surrogate_gradient(spike_fun, x):
    """Backward-pass derivative of an elementwise spike function, via autodiff."""
    return jax.grad(lambda v: jnp.sum(spike_fun(v)))(x)

# Plot one panel per surrogate
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (name, (spike_fun, color)) in zip(axes, surrogates.items()):
    grad_values = surrogate_gradient(spike_fun, V_range - V_th)
    ax.plot(V_range, grad_values, color=color, linewidth=2, label=f'{name} surrogate')
    ax.axvline(V_th, color='r', linestyle='--', alpha=0.5)
    ax.set_xlabel('V - V_th', fontsize=12)
    ax.set_ylabel('Surrogate Gradient', fontsize=12)
    ax.set_title(f'{name} Surrogate', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✅ Solution: Smooth surrogate gradients enable learning!")
print("   Forward pass: Use real spikes")
print("   Backward pass: Use smooth gradient approximation")

## Part 3: Creating a Trainable SNN

Now let's create an SNN classifier. We'll build a simple network:

**Architecture:**
- Input layer: 784 neurons (28×28 image)
- Hidden layer: 128 LIF neurons
- Output layer: 10 LIF neurons (digits 0-9)

**Key for training:**
- Use LIF neurons with surrogate gradient spike functions
- Use `bp.Readout` to convert spikes to logits
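
To make the neuron model concrete, here is a rough NumPy sketch of one leaky integrate-and-fire update with the same parameter values as the network below. It assumes forward-Euler integration and a unit membrane resistance; `bp.LIF` may integrate differently, and `lif_step` and the drive value are hypothetical names and numbers for this example only.

```python
import numpy as np

def lif_step(v, current, v_rest=-65.0, v_th=-50.0, v_reset=-65.0, tau=10.0, dt=1.0):
    """One forward-Euler step of a leaky integrate-and-fire neuron (units: mV, ms)."""
    v = v + (dt / tau) * (-(v - v_rest) + current)  # leak toward v_rest plus input drive
    spike = (v >= v_th).astype(np.float64)          # spike when the threshold is crossed
    v = np.where(spike > 0, v_reset, v)             # hard reset after a spike
    return v, spike

v = np.full(1, -65.0)
spike_train = []
for _ in range(50):
    # Constant drive; 20.0 is a hypothetical value in mV-equivalent units.
    v, s = lif_step(v, current=20.0)
    spike_train.append(s[0])
spike_train = np.array(spike_train)
```

With a constant drive the membrane potential charges toward threshold, fires, resets, and repeats, which is the behavior the surrogate gradient has to differentiate through during training.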

In [None]:
class TrainableSNN(brainstate.nn.Module):
    """Simple feedforward SNN for classification."""
    
    def __init__(self, n_input=784, n_hidden=128, n_output=10):
        super().__init__()
        
        # Input to hidden projection
        self.fc1 = brainstate.nn.Linear(n_input, n_hidden, w_init=brainstate.init.KaimingNormal())
        
        # Hidden LIF neurons with surrogate gradient
        self.lif1 = bp.LIF(
            n_hidden,
            V_rest=-65.0 * u.mV,
            V_th=-50.0 * u.mV,
            V_reset=-65.0 * u.mV,
            tau=10.0 * u.ms,
            spike_fun=braintools.surrogate.ReluGrad()  # Surrogate gradient!
        )
        
        # Hidden to output projection
        self.fc2 = brainstate.nn.Linear(n_hidden, n_output, w_init=brainstate.init.KaimingNormal())
        
        # Output LIF neurons with surrogate gradient
        self.lif2 = bp.LIF(
            n_output,
            V_rest=-65.0 * u.mV,
            V_th=-50.0 * u.mV,
            V_reset=-65.0 * u.mV,
            tau=10.0 * u.ms,
            spike_fun=braintools.surrogate.ReluGrad()  # Surrogate gradient!
        )
        
        # Readout layer to convert spikes to logits
        self.readout = bp.Readout(n_output, n_output)
    
    def update(self, x):
        """Forward pass for one time step.
        
        Args:
            x: Input current (batch_size, n_input) with physical units
        
        Returns:
            logits: Output logits (batch_size, n_output)
        """
        # Input to hidden
        current1 = self.fc1(x)
        self.lif1(current1)
        hidden_spikes = self.lif1.get_spike()
        
        # Hidden to output
        current2 = self.fc2(hidden_spikes)
        self.lif2(current2)
        output_spikes = self.lif2.get_spike()
        
        # Convert spikes to logits
        logits = self.readout(output_spikes)
        
        return logits

# Create network
net = TrainableSNN(n_input=784, n_hidden=128, n_output=10)
brainstate.nn.init_all_states(net, batch_size=32)

print("✅ Created trainable SNN with surrogate gradients")
print(f"   Input: 784 neurons")
print(f"   Hidden: 128 LIF neurons")
print(f"   Output: 10 LIF neurons")
print(f"   Total parameters: {sum(p.size for p in net.states(brainstate.ParamState).values())}")

## Part 4: Loss Functions for SNNs

For classification, we typically use **cross-entropy loss** on the output logits. The logits are computed by integrating spikes over time.

**Loss computation:**
1. Run the network for `T` time steps
2. Accumulate output logits over time
3. Compute cross-entropy loss: $L = -\sum_i y_i \log\big(\text{softmax}(\text{logits})_i\big)$, where $y$ is the one-hot label and the logits are the accumulated outputs from step 2
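
As a framework-independent illustration of steps 2 and 3, the sketch below sums per-step logits over time and evaluates the cross-entropy with integer labels in plain JAX. The function name `cross_entropy_from_steps` and the toy tensors are assumptions for this example; the tutorial itself uses `braintools.metric.softmax_cross_entropy_with_integer_labels`.

```python
import jax
import jax.numpy as jnp

def cross_entropy_from_steps(logits_per_step, labels):
    """logits_per_step: (T, batch, n_classes); labels: (batch,) integer class ids."""
    logits_sum = jnp.sum(logits_per_step, axis=0)        # step 2: accumulate over time
    log_probs = jax.nn.log_softmax(logits_sum, axis=-1)  # log(softmax(logits))
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    return -jnp.mean(picked)                             # step 3: mean over the batch

# Toy check: T=3 steps, batch of 2, 4 classes. All-zero logits give a uniform
# prediction, so the loss should equal log(4).
toy_logits = jnp.zeros((3, 2, 4))
toy_labels = jnp.array([1, 3])
toy_loss = cross_entropy_from_steps(toy_logits, toy_labels)
```

A loss of `log(n_classes)` is a useful reference value: an untrained classifier that predicts uniformly should start near it.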

Let's implement the training step!

In [None]:
def loss_fn(network, inputs, labels, n_steps=25):
    """Compute loss for SNN classification.
    
    Args:
        network: SNN model
        inputs: Input data (batch_size, n_features)
        labels: True labels (batch_size,)
        n_steps: Number of simulation time steps
    
    Returns:
        loss: Cross-entropy loss
    """
    # Reset network state, sized to the current batch
    brainstate.nn.init_all_states(network, batch_size=inputs.shape[0])
    
    # Add physical units to input (convert to current)
    inputs_with_units = inputs * u.nA
    
    # Simulate for n_steps and accumulate output
    def run_step(i):
        return network(inputs_with_units)
    
    # Run simulation and accumulate logits
    logits_sum = brainstate.transform.for_loop(run_step, jnp.arange(n_steps))
    logits_sum = jnp.sum(logits_sum, axis=0)  # Sum over time
    
    # Compute cross-entropy loss
    loss = braintools.metric.softmax_cross_entropy_with_integer_labels(
        logits_sum, labels
    ).mean()
    
    return loss

def accuracy_fn(network, inputs, labels, n_steps=25):
    """Compute accuracy for SNN classification."""
    # Reset network state, sized to the current batch
    brainstate.nn.init_all_states(network, batch_size=inputs.shape[0])
    
    # Add physical units
    inputs_with_units = inputs * u.nA
    
    # Simulate and accumulate logits
    def run_step(i):
        return network(inputs_with_units)
    
    logits_sum = brainstate.transform.for_loop(run_step, jnp.arange(n_steps))
    logits_sum = jnp.sum(logits_sum, axis=0)
    
    # Compute accuracy
    predictions = jnp.argmax(logits_sum, axis=1)
    accuracy = jnp.mean(predictions == labels)
    
    return accuracy

print("✅ Defined loss and accuracy functions")
print("   Loss: Cross-entropy on accumulated logits")
print("   Accuracy: Argmax of accumulated logits")

## Part 5: Optimizers and Training Loop

Now we'll set up the optimizer and training loop. BrainPy uses `braintools.optim`, which provides standard optimizers such as Adam and SGD.

**Training loop:**
1. Get batch of data
2. Compute gradients using `brainstate.transform.grad()`
3. Update parameters using optimizer
4. Track loss and accuracy

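We'll use synthetic data for this demo (in practice, you'd use MNIST). Because the synthetic labels are drawn at random, independently of the features, the network can only memorize the training set: expect test accuracy to stay near chance (about 10% for 10 classes).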

In [None]:
# Create synthetic dataset (in practice, use real data like MNIST)
def create_synthetic_data(n_samples=1000, n_features=784, n_classes=10):
    """Create synthetic classification data."""
    X = np.random.randn(n_samples, n_features).astype(np.float32) * 0.5
    y = np.random.randint(0, n_classes, size=n_samples)
    return X, y

# Generate data
X_train, y_train = create_synthetic_data(n_samples=1000)
X_test, y_test = create_synthetic_data(n_samples=200)

print("✅ Created synthetic dataset")
print(f"   Training: {X_train.shape[0]} samples")
print(f"   Test: {X_test.shape[0]} samples")
print(f"   Features: {X_train.shape[1]}")
print(f"   Classes: {len(np.unique(y_train))}")

In [None]:
# Reset network and create optimizer
net = TrainableSNN(n_input=784, n_hidden=128, n_output=10)
brainstate.nn.init_all_states(net, batch_size=32)

# Create Adam optimizer
optimizer = braintools.optim.Adam(learning_rate=1e-3)
optimizer.register_trainable_weights(net.states(brainstate.ParamState))

print("✅ Created optimizer")
print(f"   Type: Adam")
print(f"   Learning rate: 1e-3")
print(f"   Trainable parameters: {len(net.states(brainstate.ParamState))}")

In [None]:
# Training loop
n_epochs = 5
batch_size = 32
n_steps = 25  # Simulation steps per sample

train_losses = []
train_accs = []
test_accs = []

print("🚀 Starting training...\n")

for epoch in range(n_epochs):
    # Shuffle training data
    indices = np.random.permutation(len(X_train))
    X_shuffled = X_train[indices]
    y_shuffled = y_train[indices]
    
    epoch_losses = []
    epoch_accs = []
    
    # Mini-batch training
    n_batches = len(X_train) // batch_size
    for i in range(n_batches):
        # Get batch
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        X_batch = X_shuffled[start_idx:end_idx]
        y_batch = y_shuffled[start_idx:end_idx]
        
        # Compute gradients
        grads, loss = brainstate.transform.grad(
            loss_fn,
            net.states(brainstate.ParamState),
            return_value=True
        )(net, X_batch, y_batch, n_steps)
        
        # Update parameters
        optimizer.update(grads)
        
        # Track metrics
        epoch_losses.append(float(loss))
        
        # Compute accuracy every 10 batches
        if i % 10 == 0:
            acc = accuracy_fn(net, X_batch, y_batch, n_steps)
            epoch_accs.append(float(acc))
    
    # Epoch statistics
    avg_loss = np.mean(epoch_losses)
    avg_train_acc = np.mean(epoch_accs) if epoch_accs else 0.0
    
    # Test accuracy
    test_acc = float(accuracy_fn(net, X_test, y_test, n_steps))
    
    train_losses.append(avg_loss)
    train_accs.append(avg_train_acc)
    test_accs.append(test_acc)
    
    print(f"Epoch {epoch+1}/{n_epochs}:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Train Acc: {avg_train_acc:.2%}")
    print(f"  Test Acc: {test_acc:.2%}\n")

print("✅ Training complete!")

## Part 6: Visualizing Training Progress

Let's visualize how the loss and accuracy evolved during training.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

epochs_range = np.arange(1, n_epochs + 1)

# Plot loss
axes[0].plot(epochs_range, train_losses, 'b-o', linewidth=2, markersize=8, label='Training Loss')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot accuracy
axes[1].plot(epochs_range, train_accs, 'g-o', linewidth=2, markersize=8, label='Train Accuracy')
axes[1].plot(epochs_range, test_accs, 'r-s', linewidth=2, markersize=8, label='Test Accuracy')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy', fontsize=12)
axes[1].set_title('Classification Accuracy', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim(0, 1)

plt.tight_layout()
plt.show()

print(f"📊 Final Results:")
print(f"   Final train accuracy: {train_accs[-1]:.2%}")
print(f"   Final test accuracy: {test_accs[-1]:.2%}")

## Part 7: Understanding BPTT for SNNs

Let's visualize what happens during backpropagation through time (BPTT). The network processes input over multiple time steps, and gradients flow backward through time.

**BPTT process:**
1. **Forward pass**: Simulate network for T steps, accumulate outputs
2. **Backward pass**: Compute gradients backward through all T steps
3. **Surrogate gradients**: Used at spike generation points
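
The three steps can be sketched on a single toy neuron in plain JAX. This is an illustrative example under simplifying assumptions (one weight, leak factor `decay`, reset to zero, a SuperSpike-style surrogate with steepness 1), not the BrainPy implementation; `spike_fn` and `spike_count` are hypothetical names. `jax.lax.scan` unrolls the dynamics over time, and `jax.value_and_grad` backpropagates through every step.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike_fn(x):
    return (x >= 0.0).astype(x.dtype)

def _fwd(x):
    return spike_fn(x), x

def _bwd(x, g):
    # Surrogate derivative 1 / (1 + |x|)^2 (SuperSpike-style, steepness 1 assumed).
    return (g / (1.0 + jnp.abs(x)) ** 2,)

spike_fn.defvjp(_fwd, _bwd)

def spike_count(w, inputs, decay=0.9, v_th=1.0):
    """Unroll one leaky neuron over time and return its total spike count."""
    def step(v, x_t):
        v = decay * v + w * x_t   # leaky integration of the weighted input
        s = spike_fn(v - v_th)    # hard spike forward, surrogate backward
        v = v * (1.0 - s)         # reset to zero after a spike
        return v, s
    _, spikes = jax.lax.scan(step, jnp.zeros(()), inputs)
    return jnp.sum(spikes)

inputs = jnp.ones(20)  # T = 20 steps of constant input
count, dcount_dw = jax.value_and_grad(spike_count)(0.3, inputs)
```

Although the spike count is an integer-valued function of `w`, the surrogate yields a nonzero `dcount_dw`, which tells the optimizer in which direction to move the weight to change the firing rate.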

Let's examine the gradient flow!

In [None]:
# Analyze gradient magnitudes during training
def analyze_gradients(network, inputs, labels, n_steps=25):
    """Compute and analyze gradient magnitudes."""
    grads = brainstate.transform.grad(
        loss_fn,
        network.states(brainstate.ParamState)
    )(network, inputs, labels, n_steps)
    
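    # Compute one gradient norm per parameter state. A state's gradient may be a
    # single array or a pytree (e.g. weight and bias), so flatten it to leaves first.
    import jax
    grad_norms = {}
    for name, grad in grads.items():
        leaves = jax.tree_util.tree_leaves(getattr(grad, 'value', grad))
        grad_norms[name] = float(jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves)))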
    
    return grad_norms

# Analyze gradients on a batch
sample_X = X_train[:32]
sample_y = y_train[:32]
grad_norms = analyze_gradients(net, sample_X, sample_y)

# Plot gradient magnitudes
fig, ax = plt.subplots(figsize=(10, 6))

layer_names = list(grad_norms.keys())
grad_values = list(grad_norms.values())

colors = ['blue' if 'fc1' in name else 'green' if 'fc2' in name else 'red' for name in layer_names]

bars = ax.bar(range(len(layer_names)), grad_values, color=colors, alpha=0.7)
ax.set_xticks(range(len(layer_names)))
ax.set_xticklabels(layer_names, rotation=45, ha='right')
ax.set_ylabel('Gradient Norm', fontsize=12)
ax.set_title('Gradient Magnitudes Across Layers', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Add legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor='blue', alpha=0.7, label='fc1 (input → hidden)'),
    Patch(facecolor='green', alpha=0.7, label='fc2 (hidden → output)'),
    Patch(facecolor='red', alpha=0.7, label='Readout')
]
ax.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.show()

print("📊 Gradient Analysis:")
for name, norm in grad_norms.items():
    print(f"   {name}: {norm:.6f}")
print("\n✅ Surrogate gradients enable backpropagation through spike generation!")

## Part 8: Real-World Example - MNIST Classification

Now let's see how to train on real data. Here's the complete workflow for MNIST (or Fashion-MNIST):

**Steps:**
1. Load and preprocess MNIST data
2. Convert images to rate-coded spike trains (or use pixel intensities as currents)
3. Train SNN classifier
4. Evaluate on test set
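
Step 2 can be sketched with NumPy alone. The helpers below are hypothetical (`images_to_currents` and `rate_code` are not BrainPy functions), and random arrays stand in for real MNIST images so that the example runs without downloading anything. The first helper follows the "pixel intensities as currents" route; the second shows simple Bernoulli rate coding.

```python
import numpy as np

def images_to_currents(images, scale=5.0):
    """Flatten uint8 images (N, 28, 28) to (N, 784) and scale [0, 1] intensities."""
    flat = images.reshape(images.shape[0], -1).astype(np.float32) / 255.0
    return flat * scale  # physical units (u.nA) are attached later, inside loss_fn

def rate_code(intensities, n_steps=25, rng=None):
    """Bernoulli rate coding: each pixel spikes with probability equal to its intensity."""
    rng = np.random.default_rng(0) if rng is None else rng
    draws = rng.random((n_steps,) + intensities.shape)
    return (draws < intensities).astype(np.float32)

# Stand-in for real MNIST: random uint8 images, used only to check shapes.
fake_images = np.random.default_rng(0).integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
currents = images_to_currents(fake_images)   # shape (4, 784), values in [0, 5]
spike_trains = rate_code(currents / 5.0)     # shape (25, 4, 784), values in {0, 1}
```

Constant currents are the simpler choice and match the `loss_fn` defined earlier; rate coding adds input noise and would require feeding a different input at each time step.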

Below is a template you can use with real MNIST data.

In [None]:
# Template for MNIST training (requires torchvision or tensorflow)

def load_mnist_data():
    """Load and preprocess MNIST data.
    
    In practice, use:
    from torchvision import datasets, transforms
    
    train_dataset = datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.ToTensor()  # scales pixels to [0, 1]
    )
    
    Keep the pixels in [0, 1] (no mean/std normalization) so that they can be
    scaled to non-negative input currents.
    """
    pass

def train_on_mnist():
    """Complete MNIST training workflow."""
    
    # 1. Load data
    # X_train, y_train, X_test, y_test = load_mnist_data()
    
    # 2. Create network
    net = TrainableSNN(n_input=784, n_hidden=256, n_output=10)
    brainstate.nn.init_all_states(net, batch_size=128)
    
    # 3. Create optimizer
    optimizer = braintools.optim.Adam(learning_rate=1e-3)
    optimizer.register_trainable_weights(net.states(brainstate.ParamState))
    
    # 4. Training loop (epochs, batches, gradient updates)
    # for epoch in range(n_epochs):
    #     for batch in data_loader:
    #         grads, loss = compute_gradients(...)
    #         optimizer.update(grads)
    
    # 5. Evaluation
    # test_acc = evaluate(net, X_test, y_test)
    
    return net

print("📝 MNIST Training Template:")
print("""\n1. Load MNIST: Use torchvision.datasets.MNIST or tensorflow.keras.datasets.mnist
2. Preprocess: Flatten images (28×28 → 784), normalize to [0,1]
3. Convert to currents: Multiply by scaling factor (e.g., 5 nA)
4. Train: Use same loss_fn and training loop as above
5. Expected accuracy: 95-98% on MNIST with proper hyperparameters

Key hyperparameters to tune:
- Learning rate: Try 1e-3, 5e-4, 1e-4
- Hidden size: Try 128, 256, 512
- Simulation steps: Try 25, 50, 100
- Batch size: Try 32, 64, 128
""")

## Part 9: Advanced Training Techniques

Here are some advanced techniques to improve SNN training:

### 1. Learning Rate Scheduling

Reduce learning rate during training for better convergence.

In [None]:
# Example: Exponential decay learning rate schedule
def create_lr_schedule(initial_lr=1e-3, decay_rate=0.95, decay_steps=1000):
    """Create exponential decay learning rate schedule."""
    def lr_schedule(step):
        return initial_lr * (decay_rate ** (step / decay_steps))
    return lr_schedule

# Usage:
# lr_schedule = create_lr_schedule()
# optimizer = braintools.optim.Adam(learning_rate=lr_schedule)

print("✅ Learning rate scheduling helps with convergence")

### 2. Gradient Clipping

Prevent gradient explosion by clipping large gradients.

In [None]:
def clip_gradients(grads, max_norm=1.0):
    """Clip gradients by global norm."""
    # Compute global norm
    global_norm = jnp.sqrt(
        sum(jnp.sum(g.value ** 2) for g in grads.values())
    )
    
    # Clip if necessary
    clip_coef = max_norm / (global_norm + 1e-6)
    clip_coef = jnp.minimum(1.0, clip_coef)
    
    # Apply clipping
    clipped_grads = {}
    for name, grad in grads.items():
        clipped_grads[name] = brainstate.ParamState(grad.value * clip_coef)
    
    return clipped_grads

# Usage in training loop:
# grads = compute_gradients(...)
# grads = clip_gradients(grads, max_norm=1.0)
# optimizer.update(grads)

print("✅ Gradient clipping prevents training instabilities")

### 3. Regularization

Add L2 regularization to prevent overfitting.

In [None]:
def loss_with_regularization(network, inputs, labels, n_steps=25, l2_weight=1e-4):
    """Loss function with L2 regularization."""
    # Standard loss
    ce_loss = loss_fn(network, inputs, labels, n_steps)
    
    # L2 regularization
    l2_loss = 0.0
    for param in network.states(brainstate.ParamState).values():
        l2_loss += jnp.sum(param.value ** 2)
    
    total_loss = ce_loss + l2_weight * l2_loss
    return total_loss

print("✅ L2 regularization improves generalization")

## Summary

In this tutorial, you learned:

✅ **The gradient problem**: Spike generation is non-differentiable

✅ **Surrogate gradients**: Use smooth approximations during backprop
   - Forward: Real spikes
   - Backward: Smooth surrogate

✅ **SNN architecture**: Create trainable networks with LIF neurons

✅ **Loss functions**: Cross-entropy on accumulated spike outputs

✅ **Training loop**: BPTT with gradient descent
   ```python
   grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)(net, X, y)
   optimizer.update(grads)
   ```

✅ **Advanced techniques**: LR scheduling, gradient clipping, regularization

**Key code pattern:**
```python
# 1. Create network with surrogate gradients
lif = bp.LIF(..., spike_fun=braintools.surrogate.ReluGrad())

# 2. Define loss over time
def loss_fn(net, X, y, n_steps):
    logits = simulate_for_n_steps(net, X, n_steps)
    return cross_entropy(logits, y)

# 3. Compute gradients and update
grads = brainstate.transform.grad(loss_fn, params)(...)
optimizer.update(grads)
```

**Next steps:**
- Try training on real MNIST/Fashion-MNIST
- Experiment with different surrogate functions
- Tune hyperparameters (learning rate, hidden size, simulation steps)
- Add recurrent connections for temporal tasks
- See Tutorial 6 for incorporating synaptic plasticity

**References:**
- Neftci et al. (2019): "Surrogate Gradient Learning in Spiking Neural Networks"
- Zenke & Ganguli (2018): "SuperSpike: Supervised learning in multilayer spiking neural networks"
- Shrestha & Orchard (2018): "SLAYER: Spike Layer Error Reassignment in Time"
- Wu et al. (2018): "Spatio-Temporal Backpropagation for Training High-Performance Spiking Neural Networks"

## Exercises

Test your understanding:

### Exercise 1: Surrogate Function Comparison
Compare training with different surrogate gradient functions (Sigmoid, ReLU, SuperSpike). Which works best?

### Exercise 2: Simulation Steps
How does the number of simulation steps (n_steps) affect accuracy and training time? Plot the trade-off.

### Exercise 3: Network Architecture
Add a second hidden layer. Does deeper architecture improve performance?

### Exercise 4: Learning Rate Tuning
Implement learning rate scheduling and compare convergence with fixed learning rate.

### Exercise 5: Real MNIST
Load real MNIST data and train a classifier. Aim for >95% test accuracy!

**Bonus Challenge:** Implement online learning where the network is trained on streaming data one sample at a time (no batches).