# Tutorial 03 — MNIST with Equilibrium Propagation (EP)

This tutorial trains the SOEN network with **Equilibrium Propagation (EP)** as a remedy for vanishing gradients over long (112-step) sequences.

## Why EP?

| Problem with BPTT | EP Solution |
|-------------------|-------------|
| Gradients through 112 steps → vanish | No backprop through time |
| Gradient ∝ product of 112 step Jacobians → exponential decay | State differences encode the gradient |
| Non-local, expensive | Local Hebbian-like updates |

## Architecture (unchanged)

```
Input (7) → Hidden (28) → Output (10)
               ↺ recurrent
```

## EP Algorithm

```
FREE PHASE:     Process input → Settle to equilibrium → Record states
NUDGED PHASE:   Push output toward target → Re-settle → Record states  
WEIGHT UPDATE:  ΔW ∝ (1/β) × (post_nudged ⊗ pre_nudged − post_free ⊗ pre_free)
```

In [None]:
import os
import sys
from pathlib import Path

# Locate the repository's ``src`` directory (the first ancestor of the current
# working directory that contains ``src/soen_toolkit``) and make it importable.
notebook_dir = Path.cwd()
for parent in [notebook_dir, *notebook_dir.parents]:
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        # Re-executing the cell must not stack duplicate sys.path entries.
        if str(candidate) not in sys.path:
            sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import gzip
import urllib.request
import struct
from tqdm import tqdm

# Allow TF32-style faster float32 matmuls where the backend supports them.
torch.set_float32_matmul_precision('high')

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")

## 1. Prepare Data (112×7)

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/",
                        data_dir="./data/mnist"):
    """Return the local path of an MNIST archive, downloading it if missing.

    Args:
        filename: Archive name, e.g. ``"train-images-idx3-ubyte.gz"``.
        base_url: URL prefix that ``filename`` is appended to.
        data_dir: Cache directory (created if needed). Defaults to the
            original hard-coded ``./data/mnist``.

    Returns:
        ``pathlib.Path`` of the cached file.

    Raises:
        OSError (including ``urllib.error.URLError``): if the download fails;
            no partial file is left behind in that case.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if not filepath.exists():
        print(f"Downloading {filename}...")
        # Download to a side file and rename only on success: an interrupted
        # download must never leave a truncated file that later runs would
        # accept as a valid cache hit.
        tmp_path = filepath.with_name(filepath.name + ".part")
        try:
            urllib.request.urlretrieve(base_url + filename, tmp_path)
            tmp_path.replace(filepath)
        finally:
            if tmp_path.exists():
                tmp_path.unlink()
    return filepath

def read_mnist_images(filepath):
    """Read a gzipped IDX3 image file into a ``(num, rows, cols)`` uint8 array.

    Args:
        filepath: Path to an ``*-images-idx3-ubyte.gz`` file.

    Returns:
        ``np.ndarray`` of dtype uint8 with shape ``(num, rows, cols)`` taken
        from the file header.

    Raises:
        ValueError: if the magic number is not 2051 (IDX, unsigned bytes,
            3 dimensions) or the pixel payload does not match the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(f"{filepath}: bad IDX3 magic number {magic} (expected 2051)")
        data = np.frombuffer(f.read(), dtype=np.uint8)
    expected = num * rows * cols
    if data.size != expected:
        raise ValueError(f"{filepath}: expected {expected} pixel bytes, found {data.size}")
    return data.reshape(num, rows, cols)

def read_mnist_labels(filepath):
    """Read a gzipped IDX1 label file into a 1-D uint8 array.

    Args:
        filepath: Path to an ``*-labels-idx1-ubyte.gz`` file.

    Returns:
        ``np.ndarray`` of dtype uint8 with ``num`` entries (from the header).

    Raises:
        ValueError: if the magic number is not 2049 (IDX, unsigned bytes,
            1 dimension) or the label count does not match the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(f"{filepath}: bad IDX1 magic number {magic} (expected 2049)")
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != num:
        raise ValueError(f"{filepath}: header declares {num} labels, found {labels.size}")
    return labels

def load_mnist_112x7(n_val=6000, seed=42):
    """Load MNIST as (N, 112, 7) sequences plus a held-out validation split.

    Each 28x28 image is flattened row-major and regrouped into 112 timesteps
    of 7 pixels (four timesteps per image row). Pixels are scaled to [0, 1].

    Args:
        n_val: Number of training images moved into the validation split.
        seed: Seed for the train/validation permutation.

    Returns:
        ``((train_x, train_y), (val_x, val_y), (test_x, test_y))`` with
        float32 image arrays and int64 label arrays.

    Raises:
        ValueError: if ``n_val`` is negative or larger than the training set.
    """
    train_img = read_mnist_images(download_mnist_file("train-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    train_lbl = read_mnist_labels(download_mnist_file("train-labels-idx1-ubyte.gz")).astype(np.int64)
    test_img = read_mnist_images(download_mnist_file("t10k-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    test_lbl = read_mnist_labels(download_mnist_file("t10k-labels-idx1-ubyte.gz")).astype(np.int64)

    # (N, 28, 28) -> (N, 112, 7): identical to flattening to 784 first.
    train_img = train_img.reshape(-1, 112, 7)
    test_img = test_img.reshape(-1, 112, 7)

    if not 0 <= n_val <= len(train_img):
        raise ValueError(f"n_val must be in [0, {len(train_img)}], got {n_val}")

    # A private RandomState yields the same permutation as the former
    # np.random.seed(seed) + np.random.permutation(...) pair, without
    # resetting NumPy's global RNG for the rest of the notebook.
    idx = np.random.RandomState(seed).permutation(len(train_img))
    val_idx, train_idx = idx[:n_val], idx[n_val:]

    val_img, val_lbl = train_img[val_idx], train_lbl[val_idx]
    train_img, train_lbl = train_img[train_idx], train_lbl[train_idx]

    print(f"Train: {train_img.shape}, Val: {val_img.shape}, Test: {test_img.shape}")
    return (train_img, train_lbl), (val_img, val_lbl), (test_img, test_lbl)

(train_data, train_labels), (val_data, val_labels), (test_data, test_labels) = load_mnist_112x7()

## 2. SOEN-EP Model

SingleDendrite dynamics:
$$\frac{ds}{dt} = \gamma^+ g(\phi) - \gamma^- s$$

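where $\phi = W_{in} x + W_{rec} s + b_h$ and $g$ is the source function, here the smooth approximation $g(\phi) = \sigma(5\phi)$. The ODE is integrated with forward Euler: $s \leftarrow s + \Delta t \, \frac{ds}{dt}$.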

In [None]:
class SOEN_EP_Model(nn.Module):
    """
    SOEN recurrent model trained with Equilibrium Propagation (EP).

    Architecture: input_dim -> hidden_dim (recurrent) -> output_dim,
    i.e. 7 -> 28 -> 10 with the defaults.

    Key difference from standard training:
    - Two-phase forward: free phase + nudged phase
    - Weight updates come from free/nudged state differences (computed by the
      trainer), not from backprop through time.

    Args:
        input_dim: Features per timestep.
        hidden_dim: Number of recurrent hidden units.
        output_dim: Number of output classes.
        dt: Forward-Euler integration step.
        gamma_plus: Drive gain applied to the source function.
        gamma_minus: Leak rate of the hidden state.
        source_gain: Slope of the sigmoid source function
            g(phi) = sigmoid(source_gain * phi); 5.0 reproduces the original
            hard-coded slope.
    """

    def __init__(self, input_dim=7, hidden_dim=28, output_dim=10, dt=0.1,
                 gamma_plus=0.1, gamma_minus=0.01, source_gain=5.0):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dt = dt
        self.gamma_plus = gamma_plus
        self.gamma_minus = gamma_minus
        self.source_gain = source_gain

        # Weights stored as (out_features, in_features), the layout F.linear expects.
        self.W_i2h = nn.Parameter(torch.empty(hidden_dim, input_dim))   # e.g. 28 x 7
        self.W_h2h = nn.Parameter(torch.empty(hidden_dim, hidden_dim))  # e.g. 28 x 28
        self.W_h2o = nn.Parameter(torch.empty(output_dim, hidden_dim))  # e.g. 10 x 28

        # Biases
        self.bias_h = nn.Parameter(torch.zeros(hidden_dim))
        self.bias_o = nn.Parameter(torch.zeros(output_dim))

        self._init_weights()

    def _init_weights(self):
        """Randomly initialise the weight matrices and zero recurrent self-connections."""
        # Keep this initialisation order: it determines RNG consumption and
        # therefore reproducibility under a fixed torch seed.
        nn.init.uniform_(self.W_i2h, -0.3, 0.3)
        nn.init.normal_(self.W_h2h, 0, 0.1)
        nn.init.normal_(self.W_h2o, 0, 0.2)
        with torch.no_grad():
            self.W_h2h.fill_diagonal_(0)

    def source_function(self, phi):
        """Smooth (sigmoid) approximation of the SOEN source function."""
        return torch.sigmoid(self.source_gain * phi)

    def hidden_dynamics(self, s, phi):
        """
        One forward-Euler step of ds/dt = gamma_plus * g(phi) - gamma_minus * s.
        """
        g = self.source_function(phi)
        dsdt = self.gamma_plus * g - self.gamma_minus * s
        return s + self.dt * dsdt

    def compute_phi(self, x, s, feedback=None):
        """
        Total flux: phi = W_i2h @ x + bias_h + W_h2h @ s (+ feedback if given).
        """
        phi = F.linear(x, self.W_i2h, self.bias_h) + F.linear(s, self.W_h2h)
        if feedback is not None:
            phi = phi + feedback
        return phi

    def compute_output(self, s):
        """Linear readout of the hidden state."""
        return F.linear(s, self.W_h2o, self.bias_o)

    def free_phase(self, x, settle_steps=20, history_every=10):
        """
        FREE PHASE: process the input sequence, then settle with zero input.

        Args:
            x: Input sequence, indexed as (batch, time, feature).
            settle_steps: Extra zero-input steps after the sequence.
            history_every: Record a detached copy of the hidden state every
                this many input steps; ``None`` or 0 disables recording.

        Returns:
            s_free: Hidden state after settling, (batch, hidden_dim).
            o_free: Readout of ``s_free``, (batch, output_dim).
            s_history: List of recorded hidden states (analysis only).
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        # new_zeros follows x's dtype and device, so e.g. a model converted
        # with .double() and fed float64 inputs works end to end.
        s = x.new_zeros(batch_size, self.hidden_dim)
        s_history = []

        for t in range(seq_len):
            phi = self.compute_phi(x[:, t], s)
            s = self.hidden_dynamics(s, phi)
            if history_every and t % history_every == 0:
                s_history.append(s.detach().clone())

        # Additional settling with no new input.
        x_zero = x.new_zeros(batch_size, self.input_dim)
        for _ in range(settle_steps):
            phi = self.compute_phi(x_zero, s)
            s = self.hidden_dynamics(s, phi)

        s_free = s
        o_free = self.compute_output(s_free)
        return s_free, o_free, s_history

    def nudged_phase(self, x, s_free, o_free, target, beta=0.5, settle_steps=20):
        """
        NUDGED PHASE: push the output toward the target and re-settle.

        Instead of backprop, the output error is fed back to the hidden layer
        through the transpose of W_h2o while the network settles.

        Args:
            x: Input sequence; only its batch size, dtype and device are used.
            s_free: Hidden state from the free phase (not modified).
            o_free: Output from the free phase. Unused; kept for API compatibility.
            target: One-hot target, (batch, output_dim).
            beta: Nudging strength.
            settle_steps: Number of re-settling steps.

        Returns:
            s_nudged: Hidden state after nudging.
            o_nudged: Readout of ``s_nudged``.
        """
        s = s_free.clone()
        x_zero = x.new_zeros(x.shape[0], self.input_dim)

        for _ in range(settle_steps):
            o = self.compute_output(s)
            # Error signal: beta * (target - output).
            output_error = beta * (target - o)
            # Project the error back through W_h2o^T: the "symmetric feedback"
            # that EP relies on.
            feedback = F.linear(output_error, self.W_h2o.t())
            phi = self.compute_phi(x_zero, s, feedback=feedback)
            s = self.hidden_dynamics(s, phi)

        s_nudged = s
        o_nudged = self.compute_output(s_nudged)
        return s_nudged, o_nudged

    def forward(self, x, target=None, beta=0.5, settle_steps=20):
        """
        Full EP forward pass.

        If ``target`` is None only the free phase runs (inference); otherwise
        both phases run (training).

        Returns:
            ``(o_free, states)``. ``states`` always contains 's_free', 'o_free'
            and 's_history'; with a target it also contains 's_nudged',
            'o_nudged' and 'x_last' (the final input frame, used by the
            trainer for the input-weight update).
        """
        s_free, o_free, s_history = self.free_phase(x, settle_steps)

        if target is None:
            return o_free, {'s_free': s_free, 'o_free': o_free, 's_history': s_history}

        s_nudged, o_nudged = self.nudged_phase(x, s_free, o_free, target, beta, settle_steps)

        return o_free, {
            's_free': s_free,
            'o_free': o_free,
            's_nudged': s_nudged,
            'o_nudged': o_nudged,
            's_history': s_history,
            'x_last': x[:, -1],
        }

# Smoke-test: build a default model and report its architecture and size.
model = SOEN_EP_Model().to(device)
# Read the dimensions from the model so the message cannot drift from the code.
print(f"Model created: {model.input_dim} → {model.hidden_dim} → {model.output_dim}")
print(f"Parameters: {sum(p.numel() for p in model.parameters())}")

## 3. EP Trainer

**Contrastive Hebbian Learning Rule:**

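$$\Delta W = \frac{1}{\beta} \left\langle \text{post}_{\text{nudged}} \otimes \text{pre}_{\text{nudged}} - \text{post}_{\text{free}} \otimes \text{pre}_{\text{free}} \right\rangle_{\text{batch}}$$

where $\langle\cdot\rangle_{\text{batch}}$ denotes the mean over the mini-batch. For the input weights $W_{i2h}$, the presynaptic activity is the last input frame of the sequence.

This is **local**: each synapse only needs the activities of its own pre- and post-synaptic neurons!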

In [None]:
class EPTrainer:
    """
    Equilibrium Propagation trainer.

    Weight updates are computed from the DIFFERENCE between free and nudged
    equilibrium states (contrastive Hebbian rule) and written directly into
    the model's parameters. No autograd and no backprop through time.

    Args:
        model: Module exposing W_i2h, W_h2h, W_h2o, bias_h, bias_o parameters
            and ``forward(x, target, beta, settle_steps) -> (output, states)``.
        lr: Step size of the EP update.
        beta: Nudging strength; also the 1/beta scale of the update.
        settle_steps: Settling steps per phase.
        weight_decay: Multiplicative decay on the weight matrices (not biases).
    """

    def __init__(self, model, lr=0.01, beta=0.5, settle_steps=20,
                 weight_decay=1e-4):
        self.model = model
        self.lr = lr
        self.beta = beta
        self.settle_steps = settle_steps
        self.weight_decay = weight_decay

    def compute_ep_gradients(self, states):
        """
        Compute weight updates with the contrastive Hebbian rule.

        dW = (1 / (beta * batch)) * (post_nudged^T @ pre_nudged - post_free^T @ pre_free)

        Args:
            states: Dict from the model's training forward pass with keys
                's_free', 's_nudged', 'o_free', 'o_nudged', 'x_last'.

        Returns:
            Dict of update tensors keyed 'W_h2o', 'W_h2h', 'W_i2h', 'bias_h',
            'bias_o', each shaped like the matching parameter.
        """
        s_free = states['s_free']      # (batch, hidden)
        s_nudged = states['s_nudged']  # (batch, hidden)
        o_free = states['o_free']      # (batch, out)
        o_nudged = states['o_nudged']  # (batch, out)
        x_last = states['x_last']      # (batch, in)

        batch_size = s_free.shape[0]
        # 1/beta from the EP estimator, 1/batch to average over the mini-batch.
        scale = 1.0 / (self.beta * batch_size)

        # Output weights: post = output, pre = hidden.
        grad_W_h2o = scale * (o_nudged.T @ s_nudged - o_free.T @ s_free)

        # Recurrent weights: post = pre = hidden; no self-connections.
        grad_W_h2h = scale * (s_nudged.T @ s_nudged - s_free.T @ s_free)
        grad_W_h2h.fill_diagonal_(0)

        # Input weights: the last input frame stands in for the presynaptic
        # activity (an approximation; the full sequence is not used).
        grad_W_i2h = scale * (s_nudged.T @ x_last - s_free.T @ x_last)

        # Biases: postsynaptic state difference only.
        grad_bias_h = scale * (s_nudged.sum(0) - s_free.sum(0))
        grad_bias_o = scale * (o_nudged.sum(0) - o_free.sum(0))

        return {
            'W_h2o': grad_W_h2o,
            'W_h2h': grad_W_h2h,
            'W_i2h': grad_W_i2h,
            'bias_h': grad_bias_h,
            'bias_o': grad_bias_o
        }

    def update_weights(self, grads):
        """Apply the EP updates (ascent step), weight decay, and stability constraints."""
        with torch.no_grad():
            self.model.W_h2o += self.lr * grads['W_h2o']
            self.model.W_h2h += self.lr * grads['W_h2h']
            self.model.W_i2h += self.lr * grads['W_i2h']
            self.model.bias_h += self.lr * grads['bias_h']
            self.model.bias_o += self.lr * grads['bias_o']

            # Weight decay (biases are intentionally not decayed).
            if self.weight_decay > 0:
                self.model.W_h2o *= (1 - self.lr * self.weight_decay)
                self.model.W_h2h *= (1 - self.lr * self.weight_decay)
                self.model.W_i2h *= (1 - self.lr * self.weight_decay)

            # Enforce no self-connections.
            self.model.W_h2h.fill_diagonal_(0)

            # Clamp weights for stability.
            self.model.W_h2o.clamp_(-1.0, 1.0)
            self.model.W_h2h.clamp_(-0.5, 0.5)
            self.model.W_i2h.clamp_(-1.0, 1.0)

    @torch.no_grad()
    def train_step(self, x, labels):
        """
        Single EP training step.

        1. Free phase: process input, settle
        2. Nudged phase: push toward target, re-settle
        3. Compute EP updates from state differences
        4. Apply the updates

        Runs under ``torch.no_grad()``: EP never calls ``backward()``, so
        recording an autograd graph across every timestep of both phases
        would only waste memory and time.

        Returns:
            ``(loss, accuracy)`` of the free-phase output as Python floats
            (monitoring only; not used for the update).
        """
        # Size the one-hot target from the model instead of assuming 10 classes.
        num_classes = getattr(self.model, 'output_dim', 10)
        target = F.one_hot(labels, num_classes=num_classes).float()

        output, states = self.model(x, target, self.beta, self.settle_steps)

        grads = self.compute_ep_gradients(states)
        self.update_weights(grads)

        loss = F.cross_entropy(output, labels)
        acc = (output.argmax(dim=1) == labels).float().mean()
        return loss.item(), acc.item()

    @torch.no_grad()
    def evaluate(self, dataloader):
        """
        Evaluate on a dataset using the free phase only.

        Batches are moved to the device of the model's parameters. The model's
        previous train/eval mode is restored afterwards.

        Returns:
            ``(mean loss, accuracy)`` over all samples, or ``(nan, nan)`` when
            the dataloader yields no samples.
        """
        was_training = self.model.training
        model_device = next(self.model.parameters()).device
        self.model.eval()

        total_loss = 0.0
        total_correct = 0.0
        total_samples = 0
        try:
            for x, labels in dataloader:
                x, labels = x.to(model_device), labels.to(model_device)
                output, _ = self.model(x)  # Free phase only

                total_loss += F.cross_entropy(output, labels).item() * len(x)
                total_correct += (output.argmax(dim=1) == labels).float().sum().item()
                total_samples += len(x)
        finally:
            self.model.train(was_training)

        if total_samples == 0:
            return float('nan'), float('nan')
        return total_loss / total_samples, total_correct / total_samples

print("EPTrainer ready")

## 4. Train with EP

In [None]:
def train_ep(model, train_data, train_labels, val_data, val_labels,
             epochs=30, batch_size=64, lr=0.05, beta=0.5, settle_steps=20):
    """
    Train ``model`` with Equilibrium Propagation and restore the best checkpoint.

    Args:
        model: Model to train in place; batches are moved to the global ``device``.
        train_data, train_labels: Training arrays, e.g. the (N, 112, 7) images
            and (N,) labels from ``load_mnist_112x7``.
        val_data, val_labels: Validation arrays used for model selection.
        epochs: Number of passes over the training set.
        batch_size: Training mini-batch size (validation uses 256).
        lr, beta, settle_steps: Forwarded to ``EPTrainer``.

    Returns:
        Dict of per-epoch lists: 'train_loss', 'train_acc', 'val_loss',
        'val_acc'. On return ``model`` holds the weights of the epoch with the
        highest validation accuracy, or its final weights if no epoch produced
        a comparable accuracy (e.g. ``epochs == 0`` or NaN metrics).
    """
    train_dataset = TensorDataset(
        torch.tensor(train_data, dtype=torch.float32),
        torch.tensor(train_labels, dtype=torch.long)
    )
    val_dataset = TensorDataset(
        torch.tensor(val_data, dtype=torch.float32),
        torch.tensor(val_labels, dtype=torch.long)
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    trainer = EPTrainer(model, lr=lr, beta=beta, settle_steps=settle_steps)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    # Start at -inf (not 0) and with no checkpoint: previously a run whose
    # accuracy never exceeded 0 (or was NaN, or had zero epochs) reached
    # load_state_dict(best_state) with best_state unbound -> NameError.
    best_val_acc = float('-inf')
    best_state = None

    print("="*60)
    print("EQUILIBRIUM PROPAGATION TRAINING")
    print("="*60)
    print(f"Beta (nudge strength): {beta}")
    print(f"Settle steps: {settle_steps}")
    print(f"Learning rate: {lr}")
    print(f"Batch size: {batch_size}")
    print("="*60)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_acc = 0.0
        n_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for x, labels in pbar:
            x, labels = x.to(device), labels.to(device)

            loss, acc = trainer.train_step(x, labels)

            epoch_loss += loss
            epoch_acc += acc
            n_batches += 1

            pbar.set_postfix({'loss': f'{loss:.4f}', 'acc': f'{acc:.3f}'})

        # Epoch averages (guard against an empty training loader).
        denom = max(n_batches, 1)
        train_loss = epoch_loss / denom
        train_acc = epoch_acc / denom

        val_loss, val_acc = trainer.evaluate(val_loader)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Mark only strict improvements (ties with an earlier best are not new bests).
        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, train_acc={train_acc:.3f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f} {'*' if improved else ''}")

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\nBest validation accuracy: {best_val_acc:.4f}")
    else:
        print("\nNo valid validation accuracy recorded; keeping final weights.")

    return history

# Initialize fresh model
model = SOEN_EP_Model(
    input_dim=7,
    hidden_dim=28,
    output_dim=10,
    dt=0.1,
    gamma_plus=0.1,
    gamma_minus=0.01
).to(device)

# Train!
history = train_ep(
    model,
    train_data, train_labels,
    val_data, val_labels,
    epochs=30,
    batch_size=64,
    lr=0.05,        # Higher LR for EP
    beta=0.5,       # Nudge strength
    settle_steps=20  # Steps to equilibrium
)

## 5. Visualize Training

In [None]:
def plot_training_history(history):
    """Plot train/validation loss (left) and accuracy (right) per epoch.

    Args:
        history: Dict with per-epoch lists under 'train_loss', 'val_loss',
            'train_acc' and 'val_acc'.
    """
    # One row per panel: (train key, val key, y-axis label, title).
    panels = (
        ('train_loss', 'val_loss', 'Loss', 'Loss (EP Training)'),
        ('train_acc', 'val_acc', 'Accuracy', 'Accuracy (EP Training)'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, (train_key, val_key, ylabel, title) in zip(axes, panels):
        ax.plot(history[train_key], label='Train')
        ax.plot(history[val_key], label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

plot_training_history(history)

## 6. Evaluate on Test Set

In [None]:
@torch.no_grad()
def evaluate_test(model, test_data, test_labels):
    """
    Report free-phase accuracy of ``model`` on the test split.

    Returns:
        ``(accuracy, predictions, labels)``: accuracy as a float, the other
        two as NumPy arrays in dataset order.
    """
    model.eval()

    loader = DataLoader(
        TensorDataset(
            torch.tensor(test_data, dtype=torch.float32),
            torch.tensor(test_labels, dtype=torch.long),
        ),
        batch_size=256,
        shuffle=False,
    )

    pred_chunks, label_chunks = [], []
    for batch_x, batch_y in tqdm(loader, desc="Testing"):
        logits, _ = model(batch_x.to(device))
        pred_chunks.append(logits.argmax(dim=1).cpu())
        label_chunks.append(batch_y)

    preds = torch.cat(pred_chunks)
    truth = torch.cat(label_chunks)
    accuracy = (preds == truth).float().mean().item()

    print(f"\n{'='*60}")
    print(f"TEST ACCURACY (Equilibrium Propagation): {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"{'='*60}")

    return accuracy, preds.numpy(), truth.numpy()

test_acc, test_preds, test_true = evaluate_test(model, test_data, test_labels)

## 7. Visualize Predictions

In [None]:
def visualize_predictions_ep(model, test_data, test_labels, n=20, ncols=5):
    """
    Plot ``n`` randomly chosen test digits with the model's predictions.

    Each panel title shows a check/cross, the predicted class with its softmax
    confidence, and the true label. The grid has ``ncols`` columns and as many
    rows as needed; leftover panels are hidden.
    """
    # RandomState(42).choice selects exactly the samples the former
    # np.random.seed(42) + np.random.choice pair did, without resetting
    # NumPy's global RNG as a side effect.
    idx = np.random.RandomState(42).choice(len(test_data), n, replace=False)
    samples = test_data[idx]
    labels = test_labels[idx]
    images = samples.reshape(n, 28, 28)

    model.eval()
    with torch.no_grad():
        x = torch.tensor(samples, dtype=torch.float32).to(device)
        output, _ = model(x)
        probs = F.softmax(output, dim=1)
        preds = probs.argmax(dim=1).cpu().numpy()
        conf = probs.max(dim=1)[0].cpu().numpy()

    # The grid used to be a fixed 4x5: n > 20 raised IndexError and n < 20
    # left empty framed panels. Size it from n instead (n=20 -> same 4x5, 12x10).
    nrows = max(1, -(-n // ncols))  # ceil(n / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 2.5 * nrows), squeeze=False)
    axes = axes.flatten()
    fig.suptitle('Predictions (Equilibrium Propagation)', fontsize=14, fontweight='bold')

    for i in range(n):
        axes[i].imshow(images[i], cmap='gray')
        ok = preds[i] == labels[i]
        axes[i].set_title(f"{'✓' if ok else '✗'} {preds[i]} ({conf[i]:.0%})\nTrue: {labels[i]}",
                          color='green' if ok else 'red', fontsize=9)
        axes[i].axis('off')
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

visualize_predictions_ep(model, test_data, test_labels)

## 8. Analyze EP Dynamics

In [None]:
@torch.no_grad()
def analyze_ep_dynamics(model, sample_data, sample_label, beta=0.5, settle_steps=20):
    """
    Visualize how hidden states evolve during the free and nudged EP phases.

    Re-runs the model's free/nudged loops step by step for a single sample so
    that intermediate states can be recorded: one snapshot every 5 input
    steps, then one per settling step.

    Args:
        model: Trained model exposing compute_phi / hidden_dynamics /
            compute_output, W_h2o and the hidden/input/output dims.
        sample_data: One input sequence, indexed (time, feature).
        sample_label: Integer class label of the sample.
        beta: Nudging strength for the nudged phase.
        settle_steps: Settling steps per phase.
    """
    # no_grad is required: the recorded states are computed from parameters
    # that require grad, and calling .numpy() on such tensors raises
    # RuntimeError ("Can't call numpy() on Tensor that requires grad").
    model.eval()
    x = torch.tensor(sample_data, dtype=torch.float32).unsqueeze(0).to(device)
    target = F.one_hot(torch.tensor([int(sample_label)]), num_classes=model.output_dim).float().to(device)
    seq_len = x.shape[1]

    # Free phase, input presentation: snapshot every 5 steps.
    s = torch.zeros(1, model.hidden_dim, device=device)
    free_states = [s.clone()]
    for t in range(seq_len):
        phi = model.compute_phi(x[:, t], s)
        s = model.hidden_dynamics(s, phi)
        if t % 5 == 0:
            free_states.append(s.clone())
    # Index of the last snapshot taken while input was still being presented.
    input_end = len(free_states) - 1

    # Free phase, settling with zero input: snapshot every step.
    x_zero = torch.zeros(1, model.input_dim, device=device)
    for _ in range(settle_steps):
        phi = model.compute_phi(x_zero, s)
        s = model.hidden_dynamics(s, phi)
        free_states.append(s.clone())

    s_free = s.clone()
    o_free = model.compute_output(s_free)

    # Nudged phase: feed the output error back through W_h2o^T while settling.
    nudged_states = [s.clone()]
    for _ in range(settle_steps):
        o = model.compute_output(s)
        output_error = beta * (target - o)
        feedback = F.linear(output_error, model.W_h2o.t())
        phi = model.compute_phi(x_zero, s, feedback=feedback)
        s = model.hidden_dynamics(s, phi)
        nudged_states.append(s.clone())

    s_nudged = s
    o_nudged = model.compute_output(s_nudged)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # squeeze(1) drops only the batch axis -> (snapshots, hidden_dim).
    free_arr = torch.stack(free_states).squeeze(1).cpu().numpy()
    axes[0, 0].imshow(free_arr.T, aspect='auto', cmap='viridis')
    # Boundary between the last input snapshot and the first settling
    # snapshot (the old hard-coded x=22 sat one snapshot too early).
    axes[0, 0].axvline(x=input_end + 0.5, color='r', linestyle='--', label='Input ends')
    # Snapshots are every 5 steps during input, then every step: not uniform time.
    axes[0, 0].set_xlabel('Snapshot index')
    axes[0, 0].set_ylabel('Hidden neuron')
    axes[0, 0].set_title('Free Phase: Hidden State Evolution')
    axes[0, 0].legend()

    nudged_arr = torch.stack(nudged_states).squeeze(1).cpu().numpy()
    axes[0, 1].imshow(nudged_arr.T, aspect='auto', cmap='viridis')
    axes[0, 1].set_xlabel('Settle step')
    axes[0, 1].set_ylabel('Hidden neuron')
    axes[0, 1].set_title('Nudged Phase: Hidden State Evolution')

    # State difference: this is what drives the EP weight updates.
    diff = (s_nudged - s_free)[0].cpu().numpy()
    axes[1, 0].bar(range(len(diff)), diff)
    axes[1, 0].set_xlabel('Hidden neuron')
    axes[1, 0].set_ylabel('Δs = s_nudged - s_free')
    axes[1, 0].set_title('State Difference (drives weight updates)')
    axes[1, 0].axhline(y=0, color='k', linestyle='-', linewidth=0.5)

    x_pos = np.arange(model.output_dim)
    width = 0.25
    axes[1, 1].bar(x_pos - width, o_free[0].cpu().numpy(), width, label='Free', alpha=0.7)
    axes[1, 1].bar(x_pos, o_nudged[0].cpu().numpy(), width, label='Nudged', alpha=0.7)
    axes[1, 1].bar(x_pos + width, target[0].cpu().numpy(), width, label='Target', alpha=0.7)
    axes[1, 1].set_xlabel('Class')
    axes[1, 1].set_ylabel('Activation')
    axes[1, 1].set_title(f'Output: Free vs Nudged vs Target (True: {sample_label})')
    axes[1, 1].legend()
    axes[1, 1].set_xticks(x_pos)

    plt.tight_layout()
    plt.show()

    print(f"\nFree phase prediction: {o_free.argmax().item()}")
    print(f"Nudged phase prediction: {o_nudged.argmax().item()}")
    print(f"True label: {sample_label}")

# Analyze a sample
analyze_ep_dynamics(model, test_data[0], test_labels[0])

## Summary

| Aspect | BPTT | Equilibrium Propagation |
|--------|------|-------------------------|
| Gradient computation | Backprop through 112 steps | State differences |
| Vanishing gradients | Yes (exponential decay) | **Avoided** |
| Memory requirement | O(T × hidden) | O(hidden) |
| Hardware compatibility | Requires external compute | **Local, Hebbian-like** |
| Training phases | 1 (forward + backward) | 2 (free + nudged) |

### Key Insight

EP computes gradients from the **difference** between free and nudged equilibrium states.
This is fundamentally different from BPTT and avoids the vanishing gradient problem
because we don't chain multiply through 112 timesteps.

### Tuning Tips

- **β (beta)**: Controls nudge strength. Too small → noisy gradient estimates (the free/nudged state difference is tiny); too large → biased gradient estimates.
- **settle_steps**: More steps → better equilibrium, but slower
- **Learning rate**: EP often needs higher LR than backprop