# Equilibrium Propagation Tutorial

This notebook implements **Equilibrium Propagation**, a learning algorithm for energy-based models proposed by Scellier & Bengio (2017).

## Key Concepts:
- **Energy-based Models**: Networks that define an energy function over states
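- **Free Phase**: The network relaxes to an equilibrium with only the input clamped (no target signal)
- **Weakly Clamped Phase**: The network relaxes again while the output is nudged toward the target with strength `β`
- **Learning Rule**: Parameter updates are computed from the difference between the energy gradients `∂E/∂θ` at the two equilibria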

**Paper**: [Equilibrium Propagation: Bridging the Gap between Energy-Based Models and Backpropagation](https://arxiv.org/abs/1602.05179)

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

In [None]:
class EquilibriumPropagationNetwork:
    """Modern PyTorch implementation of Equilibrium Propagation
    
    Based on the paper "Equilibrium Propagation: Bridging the Gap between 
    Energy-Based Models and Backpropagation" by Scellier & Bengio (2017)
    """
    
    def __init__(self, hidden_sizes=(500,), device='cpu'):
        self.device = device
        # Copy into a list so the default is never shared or mutated between instances
        self.hidden_sizes = list(hidden_sizes)
        
        # Network architecture: 784 (MNIST) -> hidden_sizes -> 10
        layer_sizes = [784] + self.hidden_sizes + [10]
        
        # Initialize weights and biases
        self.weights = []
        self.biases = []
        
        for i in range(len(layer_sizes) - 1):
            # Glorot/Xavier initialization
            fan_in, fan_out = layer_sizes[i], layer_sizes[i+1]
            bound = np.sqrt(6.0 / (fan_in + fan_out))
            W = torch.empty(fan_in, fan_out).uniform_(-bound, bound).to(device)
            W.requires_grad_(True)
            self.weights.append(W)
            
            # Initialize biases to zero
            b = torch.zeros(layer_sizes[i+1]).to(device)
            b.requires_grad_(True)
            self.biases.append(b)
    
    def rho(self, s):
        """Hard sigmoid activation function: rho(s) = max(0, min(1, s))"""
        return torch.clamp(s, 0.0, 1.0)
    
    def energy(self, layers):
        """Compute energy function E for the current state"""
        # Squared norm term: sum(rho(layer)^2) / 2
        squared_norm = sum([(self.rho(layer) ** 2).sum(dim=1) for layer in layers]) / 2.0
        
        # Linear terms: -sum(rho(layer) * bias) - note: skip input layer for biases
        linear_terms = -sum([torch.sum(self.rho(layer) * bias, dim=1) 
                           for layer, bias in zip(layers[1:], self.biases)])
        
        # Quadratic terms: -sum(rho(pre) @ W @ rho(post))
        quadratic_terms = -sum([torch.sum(self.rho(pre) @ W * self.rho(post), dim=1)
                               for pre, W, post in zip(layers[:-1], self.weights, layers[1:])])
        
        return squared_norm + linear_terms + quadratic_terms
    
    def cost(self, output_layer, target):
        """Compute cost function C (squared error summed over output units, per example)"""
        target_one_hot = F.one_hot(target, num_classes=10).float()
        return ((output_layer - target_one_hot) ** 2).sum(dim=1)
    
    def total_energy(self, layers, target, beta):
        """Compute total energy F = E + beta * C"""
        return self.energy(layers) + beta * self.cost(layers[-1], target)
    
    def free_phase(self, x, n_iterations=20, epsilon=0.5):
        """Run free phase dynamics to find equilibrium"""
        batch_size = x.shape[0]
        
        # Initialize layers
        layers = [x]  # Input layer (clamped)
        for size in self.hidden_sizes + [10]:
            layer = torch.zeros(batch_size, size, requires_grad=True, device=self.device)
            layers.append(layer)
        
        # Run dynamics
        for _ in range(n_iterations):
            # Compute energy gradient with respect to the state
            energy_val = self.energy(layers).sum()
            grads = torch.autograd.grad(energy_val, layers[1:])
            
            # Update layers (except input which is clamped)
            new_layers = [layers[0]]  # Keep input layer
            for i, grad in enumerate(grads):
                # Detach so the state is a fresh leaf each step: EP differentiates the
                # energy at the equilibrium state, not through the relaxation dynamics
                new_layer = self.rho(layers[i+1] - epsilon * grad).detach()
                new_layer.requires_grad_(True)
                new_layers.append(new_layer)
            layers = new_layers
        
        return layers
    
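    def weakly_clamped_phase(self, x, target, n_iterations=4, epsilon=0.5, beta=0.5):
        """Run weakly clamped phase dynamics
        
        Note: the state is initialized at zero here. Scellier & Bengio start this
        phase from the free-phase fixed point, so a few iterations from zero may
        not reach the nudged equilibrium.
        """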
        batch_size = x.shape[0]
        
        # Initialize layers
        layers = [x]  # Input layer (clamped)
        for size in self.hidden_sizes + [10]:
            layer = torch.zeros(batch_size, size, requires_grad=True, device=self.device)
            layers.append(layer)
        
        # Run dynamics
        for _ in range(n_iterations):
            # Compute total energy gradient with respect to the state
            total_energy_val = self.total_energy(layers, target, beta).sum()
            grads = torch.autograd.grad(total_energy_val, layers[1:])
            
            # Update layers (except input which is clamped)
            new_layers = [layers[0]]  # Keep input layer
            for i, grad in enumerate(grads):
                # Detach so the state is a fresh leaf each step: EP differentiates the
                # energy at the equilibrium state, not through the relaxation dynamics
                new_layer = self.rho(layers[i+1] - epsilon * grad).detach()
                new_layer.requires_grad_(True)
                new_layers.append(new_layer)
            layers = new_layers
        
        return layers
    
    def compute_gradients(self, layers_free, layers_clamped, beta):
        """Compute parameter gradients using equilibrium propagation"""
        # Energy at free equilibrium
        energy_free = self.energy(layers_free).mean()
        
        # Energy at weakly clamped equilibrium  
        energy_clamped = self.energy(layers_clamped).mean()
        
        # Gradient of energy difference w.r.t. parameters
        energy_diff = (energy_clamped - energy_free) / beta
        
        weight_grads = torch.autograd.grad(energy_diff, self.weights, retain_graph=True)
        bias_grads = torch.autograd.grad(energy_diff, self.biases)
        
        return weight_grads, bias_grads
    
    def update_parameters(self, weight_grads, bias_grads, alphas):
        """Update network parameters"""
        with torch.no_grad():
            for i, (W, grad) in enumerate(zip(self.weights, weight_grads)):
                W -= alphas[i] * grad
            
            for i, (b, grad) in enumerate(zip(self.biases, bias_grads)):
                b -= alphas[i] * grad
    
    def predict(self, x):
        """Make predictions using free phase"""
        layers = self.free_phase(x)
        return torch.argmax(layers[-1], dim=1)
    
    def measure(self, x, target):
        """Measure energy, cost, and error rate"""
        layers = self.free_phase(x)
        
        energy_val = self.energy(layers).mean()
        cost_val = self.cost(layers[-1], target).mean()
        
        predictions = torch.argmax(layers[-1], dim=1)
        error_rate = (predictions != target).float().mean()
        
        return energy_val.item(), cost_val.item(), error_rate.item()

In [None]:
def load_mnist_data(batch_size=20):
    """Load MNIST dataset"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))  # Flatten 28x28 to 784
    ])
    
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader

# Load data
print("Loading MNIST dataset...")
train_loader, test_loader = load_mnist_data(batch_size=20)
print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Quick demo with a single batch
print("Creating network...")
net = EquilibriumPropagationNetwork(hidden_sizes=[100], device=device)  # Smaller for demo

# Get a single batch
data_iter = iter(train_loader)
data, target = next(data_iter)
data, target = data.to(device), target.to(device)

print(f"Batch shape: {data.shape}")
print(f"Target shape: {target.shape}")

# Test free phase
print("\nRunning free phase...")
start_time = time.time()
layers_free = net.free_phase(data, n_iterations=10, epsilon=0.5)
free_time = time.time() - start_time
print(f"Free phase took {free_time:.2f} seconds")

# Test weakly clamped phase
print("\nRunning weakly clamped phase...")
start_time = time.time()
layers_clamped = net.weakly_clamped_phase(data, target, n_iterations=4, epsilon=0.5, beta=0.5)
clamped_time = time.time() - start_time
print(f"Weakly clamped phase took {clamped_time:.2f} seconds")

# Measure performance
energy, cost, error = net.measure(data, target)
print(f"\nInitial performance:")
print(f"Energy: {energy:.3f}")
print(f"Cost: {cost:.3f}")
print(f"Error rate: {error*100:.1f}%")

In [None]:
def train_network(hidden_sizes=[100], n_epochs=3, batch_size=20, 
                 n_it_neg=10, n_it_pos=4, epsilon=0.5, beta=0.5, 
                 alphas=[0.1, 0.05], device='cpu', max_batches=50):
    """Train equilibrium propagation network (simplified for demo)"""
    
    print(f"Architecture: 784-{'-'.join(map(str, hidden_sizes))}-10")
    print(f"Epochs: {n_epochs}, Batch size: {batch_size}")
    print(f"Free phase iterations: {n_it_neg}, Clamped phase iterations: {n_it_pos}")
    print(f"State step size (epsilon): {epsilon}, Beta: {beta}")
    print(f"Alphas: {alphas}")
    print(f"Max batches per epoch: {max_batches}\n")
    
    # Initialize network
    net = EquilibriumPropagationNetwork(hidden_sizes, device)
    
    # Load data
    train_loader, test_loader = load_mnist_data(batch_size)
    
    # Training curves
    training_errors = []
    validation_errors = []
    
    start_time = time.time()
    
    for epoch in range(n_epochs):
        print(f"Epoch {epoch + 1}/{n_epochs}")
        
        # Training phase
        train_error_sum = 0.0
        num_train_batches = 0
        
        pbar = tqdm(enumerate(train_loader), total=min(max_batches, len(train_loader)))
        for batch_idx, (data, target) in pbar:
            if batch_idx >= max_batches:
                break
                
            data, target = data.to(device), target.to(device)
            
            # Free phase
            layers_free = net.free_phase(data, n_it_neg, epsilon)
            
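            # Measure error at the free equilibrium, reusing layers_free instead of
            # re-running the free phase (net.measure would use its own default settings)
            predictions = torch.argmax(layers_free[-1], dim=1)
            error = (predictions != target).float().mean().item()
            train_error_sum += error
            num_train_batches += 1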
            
            # Update progress bar
            avg_error = train_error_sum / num_train_batches * 100
            pbar.set_description(f"Training Error: {avg_error:.1f}%")
            
            # Weakly clamped phase
            sign = 2 * np.random.randint(0, 2) - 1  # Random sign
            beta_signed = sign * beta
            
            layers_clamped = net.weakly_clamped_phase(data, target, n_it_pos, epsilon, beta_signed)
            
            # Compute and apply gradients
            weight_grads, bias_grads = net.compute_gradients(layers_free, layers_clamped, beta_signed)
            net.update_parameters(weight_grads, bias_grads, alphas)
        
        avg_train_error = train_error_sum / num_train_batches * 100
        training_errors.append(avg_train_error)
        
        # Validation phase (subset)
        val_error_sum = 0.0
        num_val_batches = 0
        
        # Do not wrap this loop in torch.no_grad(): free_phase calls
        # torch.autograd.grad to relax the state, which fails when autograd is disabled
        for batch_idx, (data, target) in enumerate(test_loader):
            if batch_idx >= 20:  # Only test on first 20 batches
                break
                
            data, target = data.to(device), target.to(device)
            
            _, _, error = net.measure(data, target)
            val_error_sum += error
            num_val_batches += 1
        
        avg_val_error = val_error_sum / num_val_batches * 100
        validation_errors.append(avg_val_error)
        
        duration = (time.time() - start_time) / 60.0
        print(f"  Training error: {avg_train_error:.2f}%")
        print(f"  Validation error: {avg_val_error:.2f}%")
        print(f"  Duration: {duration:.1f} min\n")
    
    return net, training_errors, validation_errors

In [None]:
# Train the network (reduced parameters for demo)
print("Starting training...")

net, train_errors, val_errors = train_network(
    hidden_sizes=[100],  # Smaller network
    n_epochs=3,          # Fewer epochs
    batch_size=20,
    n_it_neg=10,         # Fewer iterations
    n_it_pos=4,
    epsilon=0.5,
    beta=0.5,
    alphas=[0.1, 0.05],
    device=device,
    max_batches=50       # Limit batches per epoch
)

In [None]:
# Plot training curves
plt.figure(figsize=(6, 5))
plt.plot(train_errors, 'b-', label='Training Error', marker='o')
plt.plot(val_errors, 'r-', label='Validation Error', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Error (%)')
plt.title('Equilibrium Propagation Training Curves')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Test on a few examples
test_iter = iter(test_loader)
test_data, test_target = next(test_iter)
test_data, test_target = test_data.to(device), test_target.to(device)

# Get predictions
predictions = net.predict(test_data[:8])  # First 8 examples

# Show the examples in their own figure: a 2x4 grid drawn in the same figure
# would overwrite the training-curve axes
plt.figure(figsize=(12, 5))
for i in range(8):
    plt.subplot(2, 4, i+1)
    img = test_data[i].cpu().reshape(28, 28)
    plt.imshow(img, cmap='gray')
    plt.title(f'True: {test_target[i].item()}, Pred: {predictions[i].item()}')
    plt.axis('off')

plt.tight_layout()
plt.show()

print("\nTraining completed!")
print(f"Final training error: {train_errors[-1]:.2f}%")
print(f"Final validation error: {val_errors[-1]:.2f}%")

## How Equilibrium Propagation Works

### 1. Energy Function
The network defines an energy function:
```
E(s) = (1/2) Σ ρ(s_i)² - Σ b_i ρ(s_i) - Σ W_ij ρ(s_i) ρ(s_j)
```
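where `ρ(s) = max(0, min(1, s))` is the hard sigmoid activation function, `b_i` are the biases, and the last sum runs over pairs of connected units in adjacent layers.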
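
As a quick check of the formula, the sketch below evaluates the energy of a hypothetical three-unit chain (one input `x`, one hidden unit `h`, one output `y`) in plain Python. The scalar weights `w1`, `w2` and biases `b1`, `b2` are illustrative names and are not part of the notebook's class.

```python
def rho(s):
    """Hard sigmoid: clip s to [0, 1]."""
    return max(0.0, min(1.0, s))

def energy(x, h, y, w1, w2, b1=0.0, b2=0.0):
    """Energy of the chain x -> h -> y, term by term as in the formula above."""
    squared = 0.5 * sum(rho(u) ** 2 for u in (x, h, y))
    bias = b1 * rho(h) + b2 * rho(y)            # no bias on the input unit
    coupling = w1 * rho(x) * rho(h) + w2 * rho(h) * rho(y)
    return squared - bias - coupling

print(energy(1.0, 0.5, 0.5, w1=1.0, w2=1.0))  # → 0.0
# States outside [0, 1] are clipped by rho, so h = 2.0 and h = 1.0 give the same energy
print(energy(1.0, 2.0, 0.5, 1.0, 1.0) == energy(1.0, 1.0, 0.5, 1.0, 1.0))  # → True
```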

### 2. Free Phase
- Network relaxes to minimize energy: `ds/dt = -∂E/∂s`
- Input layer is clamped to data
- Hidden and output layers evolve to equilibrium
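
The relaxation can be sketched on the same kind of toy model: a hypothetical scalar chain `x -> h -> y` without biases, with the gradients written out by hand rather than taken from autograd. The weights and step size are arbitrary choices for illustration; for these values the energy should go down at every step and the state should settle near `h = 2/3`, `y = 1/3`.

```python
def rho(s):
    """Hard sigmoid: clip s to [0, 1]."""
    return max(0.0, min(1.0, s))

def rho_prime(s):
    """Derivative of the hard sigmoid, taken as 1 on the closed interval [0, 1]."""
    return 1.0 if 0.0 <= s <= 1.0 else 0.0

def energy(x, h, y, w1, w2):
    """Energy of the chain x -> h -> y without biases."""
    squared = 0.5 * (rho(x) ** 2 + rho(h) ** 2 + rho(y) ** 2)
    return squared - w1 * rho(x) * rho(h) - w2 * rho(h) * rho(y)

def free_phase(x, w1, w2, n_iterations=50, epsilon=0.5):
    """Discretized ds/dt = -dE/ds for the free units h and y; x stays clamped."""
    h, y = 0.0, 0.0
    energies = [energy(x, h, y, w1, w2)]
    for _ in range(n_iterations):
        dE_dh = rho_prime(h) * (rho(h) - w1 * rho(x) - w2 * rho(y))
        dE_dy = rho_prime(y) * (rho(y) - w2 * rho(h))
        # Simultaneous update, then clip the state back into [0, 1]
        h, y = rho(h - epsilon * dE_dh), rho(y - epsilon * dE_dy)
        energies.append(energy(x, h, y, w1, w2))
    return h, y, energies

h, y, energies = free_phase(x=1.0, w1=0.5, w2=0.5)
print(f"h = {h:.4f}, y = {y:.4f}")
print(f"energy: {energies[0]:.4f} -> {energies[-1]:.4f}")
```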

### 3. Weakly Clamped Phase  
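- The total energy becomes `F(s) = E(s) + β C(s)`
- `C(s)` is the cost function (here, the squared error between the output layer and the one-hot target)
- `β` controls how strongly the output is nudged toward the target; the theory assumes it is small
- The network settles to a new equilibrium under this modified energy; in the paper, this phase starts from the free-phase equilibrium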
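
To see the effect of the nudging term, the sketch below relaxes a hypothetical scalar chain `x -> h -> y` once with `β = 0` and once with `β = 0.5`. The function name `relax` and all numeric values are assumptions made for illustration. For simplicity it starts both runs from zero and iterates long enough to converge.

```python
def rho(s):
    """Hard sigmoid: clip s to [0, 1]."""
    return max(0.0, min(1.0, s))

def rho_prime(s):
    """Derivative of the hard sigmoid, taken as 1 on the closed interval [0, 1]."""
    return 1.0 if 0.0 <= s <= 1.0 else 0.0

def relax(x, w1, w2, beta=0.0, target=0.0, n_iterations=100, epsilon=0.5):
    """Gradient descent on F = E + beta * C over the state (h, y)."""
    h, y = 0.0, 0.0
    for _ in range(n_iterations):
        dF_dh = rho_prime(h) * (rho(h) - w1 * rho(x) - w2 * rho(y))
        # The cost C = (y - target)^2 only adds a term to the output unit's gradient
        dF_dy = rho_prime(y) * (rho(y) - w2 * rho(h)) + 2.0 * beta * (y - target)
        h, y = rho(h - epsilon * dF_dh), rho(y - epsilon * dF_dy)
    return h, y

_, y_free = relax(1.0, 0.5, 0.5)
_, y_nudged = relax(1.0, 0.5, 0.5, beta=0.5, target=1.0)
print(f"free output:   {y_free:.4f}")
print(f"nudged output: {y_nudged:.4f} (target 1.0)")
```

With a positive `β` the output equilibrium moves toward the target without being pinned to it, which is why the phase is called weakly clamped.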

### 4. Learning Rule
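Parameter updates are proportional to the difference between the energy gradients at the two equilibria:
```
Δθ ∝ -(1/β) [∂E/∂θ(s_clamped) - ∂E/∂θ(s_free)]
```

In the limit `β → 0`, this update follows the negative gradient of the cost evaluated at the free equilibrium. In `compute_gradients`, it is obtained by differentiating `(E_clamped - E_free) / β` with respect to the weights and biases.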
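
A numerical sanity check of this claim, under the same toy assumptions (a hypothetical scalar chain `x -> h -> y` with weights `w1`, `w2` and no biases): the sketch compares the EP estimate `(1/β) [∂E/∂w(nudged) - ∂E/∂w(free)]` with a finite-difference gradient of the cost at the free equilibrium. The helper names and values are illustrative only.

```python
def rho(s):
    """Hard sigmoid: clip s to [0, 1]."""
    return max(0.0, min(1.0, s))

def rho_prime(s):
    """Derivative of the hard sigmoid, taken as 1 on the closed interval [0, 1]."""
    return 1.0 if 0.0 <= s <= 1.0 else 0.0

def relax(x, w1, w2, beta=0.0, target=0.0, n_iterations=200, epsilon=0.5):
    """Gradient descent on F = E + beta * C over the state (h, y)."""
    h, y = 0.0, 0.0
    for _ in range(n_iterations):
        dF_dh = rho_prime(h) * (rho(h) - w1 * rho(x) - w2 * rho(y))
        dF_dy = rho_prime(y) * (rho(y) - w2 * rho(h)) + 2.0 * beta * (y - target)
        h, y = rho(h - epsilon * dF_dh), rho(y - epsilon * dF_dy)
    return h, y

def ep_gradient(x, target, w1, w2, beta):
    """EP estimate of dC/dw, using dE/dw1 = -rho(x) rho(h) and dE/dw2 = -rho(h) rho(y)."""
    h0, y0 = relax(x, w1, w2)                 # free equilibrium
    hb, yb = relax(x, w1, w2, beta, target)   # nudged equilibrium
    g1 = (-rho(x) * rho(hb) + rho(x) * rho(h0)) / beta
    g2 = (-rho(hb) * rho(yb) + rho(h0) * rho(y0)) / beta
    return g1, g2

def loss(x, target, w1, w2):
    """Cost evaluated at the free equilibrium."""
    _, y = relax(x, w1, w2)
    return (y - target) ** 2

def numerical_gradient(x, target, w1, w2, delta=1e-5):
    """Central finite differences of the loss with respect to each weight."""
    g1 = (loss(x, target, w1 + delta, w2) - loss(x, target, w1 - delta, w2)) / (2 * delta)
    g2 = (loss(x, target, w1, w2 + delta) - loss(x, target, w1, w2 - delta)) / (2 * delta)
    return g1, g2

ep_grad = ep_gradient(1.0, 1.0, 0.5, 0.5, beta=0.001)
num_grad = numerical_gradient(1.0, 1.0, 0.5, 0.5)
print("EP estimate:       ", ep_grad)
print("finite differences:", num_grad)
```

The two results should agree more closely as `β` shrinks, provided both relaxations have converged; the parameter update is then a step along the negative of this estimate.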

### Key Insights
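- **Local learning**: Each weight update depends only on the activities of the two units it connects, measured at the two equilibria
- **Biological plausibility**: Both phases run the same dynamics, so no separate backward pass is needed to carry error signals
- **Energy-based**: The same relaxation procedure applies to any number of hidden layers, although deeper networks typically need more iterations to settle
- **Equivalence to backprop**: In the limit `β → 0`, the EP update matches the gradient of the cost at the free equilibrium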

In [None]:
# Compare with a simple baseline (random predictions)
print("Comparison with random baseline:")
random_accuracy = 10.0  # 10% for 10-class classification
ep_accuracy = 100 - val_errors[-1]

print(f"Random baseline accuracy: {random_accuracy:.1f}%")
print(f"Equilibrium Propagation accuracy: {ep_accuracy:.1f}%")
print(f"Improvement: {ep_accuracy - random_accuracy:.1f} percentage points")

# Show training error against the random baseline
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(train_errors, 'b-', linewidth=2)
# The y-axis is error, so the random baseline sits at 100 - 10 = 90%
plt.axhline(y=100 - random_accuracy, color='r', linestyle='--', label='Random baseline error')
plt.xlabel('Epoch')
plt.ylabel('Error (%)')
plt.title('Training Progress')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
# Show final learned weights (first layer)
W1 = net.weights[0].detach().cpu().numpy()
plt.imshow(W1[:100, :100], cmap='RdBu', aspect='auto')
plt.title('Learned Weight Matrix (subset)')
plt.xlabel('Hidden units')
plt.ylabel('Input pixels')
plt.colorbar()

plt.tight_layout()
plt.show()