# Tutorial 04: Forward-Forward Learning for MNIST with SOEN

This notebook demonstrates the **Forward-Forward (FF) algorithm** (Hinton, 2022) for training SOEN networks on MNIST, as an alternative to backpropagation.

---

## 🔊 NOISE CONFIGURATION: ENABLED (Default)

> **This tutorial runs with NOISE INJECTION (documented defaults).**
>
> | Parameter | Default | Description |
> |-----------|---------|-------------|
> | `phi` | **0.01** | Noise on input flux |
> | `s` | **0.005** | Noise on state |
> | `relative` | **False** | Absolute scaling |
>
> **To toggle noise on/off:** Use the `NOISE_ENABLED` variable below.
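
The table gives only noise amplitudes. The toy sketch below shows what "absolute" versus "relative" scaling usually means. It assumes zero-mean Gaussian noise whose standard deviation is either fixed or proportional to the signal magnitude. This convention is an assumption and is not taken from `soen_toolkit`; `inject_noise` is a hypothetical helper.

```python
import random


def inject_noise(values, std, relative=False, rng=None):
    """Add zero-mean Gaussian noise to a list of values (toy illustration).

    Assumed convention (NOT taken from soen_toolkit):
    - relative=False: every element gets noise with the same std ("absolute").
    - relative=True:  the std is scaled by |value|, so zero inputs stay zero.
    """
    rng = rng or random.Random(0)  # seeded so the example is reproducible
    noisy = []
    for v in values:
        scale = std * abs(v) if relative else std
        noisy.append(v + rng.gauss(0.0, scale))
    return noisy


phi = [0.0, 0.1, 0.5]
print(inject_noise(phi, std=0.01))                 # every entry is perturbed
print(inject_noise(phi, std=0.01, relative=True))  # the 0.0 entry stays 0.0
```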

---

## Why Forward-Forward for SOEN?

| Aspect | Backpropagation | Forward-Forward |
|--------|-----------------|------------------|
| **Computation** | Requires backward pass | Forward only |
| **Memory** | Store all activations | No caching needed |
| **Weight transport** | Symmetric weights required | Local weights only |
| **Hardware fit** | Needs external computation | Matches SOEN physics |
| **Noise tolerance** | Gradients compound noise | Local learning is robust |

## The Forward-Forward Algorithm

Instead of propagating errors backward, each layer learns locally:

1. **Positive pass**: Real data with the correct label → maximize "goodness"
2. **Negative pass**: Real data with a wrong label → minimize "goodness"
3. **Goodness**: Sum of squared activations (Hinton, 2022); this notebook uses the mean, which only rescales it by the layer width
4. **Threshold**: Each layer pushes positive goodness above θ and negative goodness below θ

```
Positive data: Image + correct_label_embedding → High goodness ✓
Negative data: Image + wrong_label_embedding   → Low goodness  ✗
```
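
Each layer's objective can be written with the logistic function: the probability that an input is positive is σ(goodness − θ), and the loss is the negative log-likelihood of that classification, which equals softplus(θ − g_pos) + softplus(g_neg − θ). The plain-Python sketch below (the helper names are illustrative, not from the notebook) computes it for a well-separated pair and for the same pair with roles swapped:

```python
import math


def goodness(activations):
    """Mean of squared activations for one sample (Hinton (2022) uses the sum)."""
    return sum(a * a for a in activations) / len(activations)


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def ff_layer_loss(pos_acts, neg_acts, theta=2.0):
    """Per-sample FF loss: push positive goodness above theta, negative below it."""
    return softplus(theta - goodness(pos_acts)) + softplus(goodness(neg_acts) - theta)


# Well-separated pair: positive goodness 4.0, negative goodness 0.25
print(ff_layer_loss([2.0, 2.0], [0.5, 0.5]))  # ≈ 0.287
# Same activations with roles swapped: the loss is much larger
print(ff_layer_loss([0.5, 0.5], [2.0, 2.0]))  # ≈ 4.037
```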

## Setup

In [None]:
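# Configure tqdm through environment variables BEFORE it is imported.
# Progress bars are enabled by default, so TQDM_DISABLE is deliberately left unset.
# Text-mode bars come from importing plain `tqdm` (not `tqdm.notebook`) below.
import os
os.environ["TQDM_MININTERVAL"] = "1"  # refresh progress bars at most once per second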

# Setup: Ensure soen_toolkit is importable
import sys
from pathlib import Path

# Add src directory to path if running from notebook location
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import gzip
import urllib.request
import struct
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass

# Use standard tqdm (not notebook version to avoid widget errors)
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        return iterable

# Set torch precision
torch.set_float32_matmul_precision('high')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

In [None]:
# ==============================================================================
# NOISE CONFIGURATION TOGGLE
# ==============================================================================
# Set NOISE_ENABLED = False to run with ideal conditions (no noise)
# Set NOISE_ENABLED = True for noise injection (default)

NOISE_ENABLED = True  # Toggle this to enable/disable noise

# Default noise parameters (documented defaults)
NOISE_DEFAULTS = {
    "phi": 0.01,           # Noise on input flux
    "s": 0.005,            # Noise on state
    "g": 0.0,              # Source function noise
    "bias_current": 0.0,   # Bias current noise
    "j": 0.0,              # Connection weight noise
    "relative": False,     # Absolute scaling
}

def set_model_noise(model, enabled=True, noise_values=None):
    """
    Toggle noise injection on/off for a SOEN model.
    
    Args:
        model: SOENModelCore instance
        enabled: If True, apply noise; if False, set all noise to 0
        noise_values: Dict of noise parameters (uses NOISE_DEFAULTS if None)
    
    Returns:
        model: The modified model (for chaining)
    """
    from soen_toolkit.core.configs import NoiseConfig
    
    if noise_values is None:
        noise_values = NOISE_DEFAULTS
    
    # Update layer noise configurations
    for cfg in model.layers_config:
        if enabled:
            cfg.noise = NoiseConfig(
                phi=noise_values.get("phi", 0.01),
                s=noise_values.get("s", 0.005),
                g=noise_values.get("g", 0.0),
                bias_current=noise_values.get("bias_current", 0.0),
                j=noise_values.get("j", 0.0),
                relative=noise_values.get("relative", False),
                extras=getattr(cfg.noise, "extras", {}),
            )
        else:
            cfg.noise = NoiseConfig(
                phi=0.0, s=0.0, g=0.0, bias_current=0.0, j=0.0,
                relative=False,
                extras=getattr(cfg.noise, "extras", {}),
            )
    
    # Update connection noise configurations
    for conn_cfg in model.connections_config:
        if enabled:
            conn_cfg.noise = NoiseConfig(
                phi=0.0, g=0.0, s=0.0, bias_current=0.0,
                j=noise_values.get("j", 0.0),
                relative=noise_values.get("relative", False),
                extras={},
            )
        else:
            conn_cfg.noise = NoiseConfig(
                phi=0.0, g=0.0, s=0.0, bias_current=0.0, j=0.0,
                relative=False, extras={},
            )
    
    status = "ENABLED" if enabled else "DISABLED"
    print(f"✓ Noise injection {status}")
    if enabled:
        print(f"  phi={noise_values['phi']}, s={noise_values['s']}, "
              f"relative={noise_values['relative']}")
    
    return model

print(f"Noise injection: {'ENABLED' if NOISE_ENABLED else 'DISABLED'}")
if NOISE_ENABLED:
    print(f"  Default values: phi={NOISE_DEFAULTS['phi']}, s={NOISE_DEFAULTS['s']}, "
          f"relative={NOISE_DEFAULTS['relative']}")

## 1. Forward-Forward Configuration

In [None]:
@dataclass
class FFConfig:
    """Configuration for Forward-Forward training."""
    
    # Architecture
    input_dim: int = 28 * 28  # Flattened MNIST; the one-hot label overwrites the first 10 pixels
    hidden_dims: Optional[List[int]] = None  # Defaults to [500, 500] in __post_init__
    num_classes: int = 10
    
    # Training
    batch_size: int = 128
    num_epochs: int = 60
    learning_rate: float = 0.03
    
    # Forward-Forward specific
    threshold: float = 2.0  # Goodness threshold
    negative_label_method: str = "random"  # "random" or "next"
    
    # SOEN specific
    use_soen_layers: bool = True  # Use SOEN dynamics vs standard layers
    dt: float = 100.0  # SOEN timestep
    num_timesteps: int = 28  # Process as sequence (one row at a time)
    
    def __post_init__(self):
        if self.hidden_dims is None:
            self.hidden_dims = [500, 500]


# Default configuration
config = FFConfig()
print("Forward-Forward Configuration:")
print(f"  Input dim: {config.input_dim}")
print(f"  Hidden dims: {config.hidden_dims}")
print(f"  Threshold: {config.threshold}")
print(f"  Use SOEN layers: {config.use_soen_layers}")

## 2. Prepare MNIST Dataset with Label Embedding

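For Forward-Forward, the label is embedded directly into the input: the first 10 pixels of the flattened image (part of MNIST's almost always blank top border) are overwritten with a one-hot vector.
- **Positive data**: Image pixels with the one-hot encoding of the **correct** label
- **Negative data**: Image pixels with the one-hot encoding of a **wrong** label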
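
The embedding step can be sketched without PyTorch on a plain list. `embed_label` is a hypothetical helper that mirrors what `_embed_label` in the dataset class below does on tensors:

```python
def embed_label(flat_image, label, num_classes=10):
    """Return a copy of a flattened image whose first num_classes pixels hold a one-hot label."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label must be in [0, {num_classes}), got {label}")
    one_hot = [1.0 if k == label else 0.0 for k in range(num_classes)]
    return one_hot + list(flat_image[num_classes:])


image = [0.0] * 784            # stand-in for a flattened 28x28 image
image[400] = 0.8               # one "ink" pixel away from the top border
positive = embed_label(image, label=3)
negative = embed_label(image, label=7)  # any label other than the true one
print(positive[:10])  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Both versions share every pixel after index 9, so the only thing that distinguishes a positive from a negative sample is the label slot.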

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/"):
    """Download a single MNIST file if not already present."""
    data_dir = Path("./data/mnist")
    data_dir.mkdir(parents=True, exist_ok=True)
    
    filepath = data_dir / filename
    if not filepath.exists():
        url = base_url + filename
        print(f"Downloading {filename}...")
        urllib.request.urlretrieve(url, filepath)
    return filepath


def read_mnist_images(filepath):
    """Read MNIST image file (idx3-ubyte format)."""
    with gzip.open(filepath, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        images = np.frombuffer(f.read(), dtype=np.uint8)
        images = images.reshape(num_images, rows, cols)
    return images


def read_mnist_labels(filepath):
    """Read MNIST label file (idx1-ubyte format)."""
    with gzip.open(filepath, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    return labels


def load_mnist():
    """Load MNIST dataset."""
    # Download files
    train_images_file = download_mnist_file("train-images-idx3-ubyte.gz")
    train_labels_file = download_mnist_file("train-labels-idx1-ubyte.gz")
    test_images_file = download_mnist_file("t10k-images-idx3-ubyte.gz")
    test_labels_file = download_mnist_file("t10k-labels-idx1-ubyte.gz")
    
    # Read data
    train_images = read_mnist_images(train_images_file).astype(np.float32) / 255.0
    train_labels = read_mnist_labels(train_labels_file)
    test_images = read_mnist_images(test_images_file).astype(np.float32) / 255.0
    test_labels = read_mnist_labels(test_labels_file)
    
    return train_images, train_labels, test_images, test_labels


# Load data
train_images, train_labels, test_images, test_labels = load_mnist()
print(f"Train: {train_images.shape}, Test: {test_images.shape}")
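
`read_mnist_images` and `read_mnist_labels` unpack the big-endian IDX header but ignore its magic number. A minimal sketch of validating it follows; `parse_idx_header` is a hypothetical helper, and the magic values 2051 (images) and 2049 (labels) come from the IDX format used by the MNIST files:

```python
import struct


def parse_idx_header(raw: bytes):
    """Parse the big-endian IDX header of an (already decompressed) MNIST file.

    Images (idx3-ubyte): magic 2051, then count, rows, cols (four uint32).
    Labels (idx1-ubyte): magic 2049, then count (two uint32).
    """
    magic = struct.unpack(">I", raw[:4])[0]
    if magic == 2051:
        _, count, rows, cols = struct.unpack(">IIII", raw[:16])
        return {"kind": "images", "count": count, "rows": rows, "cols": cols}
    if magic == 2049:
        _, count = struct.unpack(">II", raw[:8])
        return {"kind": "labels", "count": count}
    raise ValueError(f"unexpected IDX magic number {magic}")


# Synthetic headers matching the MNIST training files
print(parse_idx_header(struct.pack(">IIII", 2051, 60000, 28, 28)))
print(parse_idx_header(struct.pack(">II", 2049, 60000)))
```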

In [None]:
class MNISTForwardForwardDataset(Dataset):
    """MNIST dataset prepared for Forward-Forward learning.
    
    Each sample returns both positive and negative versions:
    - Positive: image + correct label one-hot
    - Negative: image + wrong label one-hot
    """
    
    def __init__(self, images, labels, num_classes=10, negative_method="random"):
        """
        Args:
            images: (N, 28, 28) image array
            labels: (N,) label array
            num_classes: Number of classes (10 for MNIST)
            negative_method: "random" for random wrong label, "next" for (label+1)%10
        """
        self.images = torch.tensor(images, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.num_classes = num_classes
        self.negative_method = negative_method
        
    def __len__(self):
        return len(self.labels)
    
    def _embed_label(self, image, label):
        """Embed label into image.
        
        We replace the first 10 pixels of the flattened image with the one-hot label.
        This is Hinton's approach from the original paper.
        """
        flat_image = image.flatten()  # (784,)
        one_hot = F.one_hot(label, num_classes=self.num_classes).float()  # (10,)
        
        # Replace first 10 pixels with label
        embedded = flat_image.clone()
        embedded[:self.num_classes] = one_hot
        
        return embedded
    
    def _get_negative_label(self, true_label):
        """Generate a wrong label for negative data."""
        if self.negative_method == "next":
            return (true_label + 1) % self.num_classes
        else:  # random
            wrong_label = torch.randint(0, self.num_classes - 1, (1,)).item()
            if wrong_label >= true_label:
                wrong_label += 1
            return wrong_label
    
    def __getitem__(self, idx):
        image = self.images[idx]
        true_label = self.labels[idx]
        
        # Positive: correct label
        positive_data = self._embed_label(image, true_label)
        
        # Negative: wrong label
        wrong_label = self._get_negative_label(true_label.item())
        negative_data = self._embed_label(image, torch.tensor(wrong_label))
        
        return {
            "positive": positive_data,
            "negative": negative_data,
            "label": true_label,
            "image": image,
        }


# Create datasets
train_dataset = MNISTForwardForwardDataset(
    train_images, train_labels, 
    negative_method=config.negative_label_method
)
test_dataset = MNISTForwardForwardDataset(
    test_images, test_labels,
    negative_method=config.negative_label_method
)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

# Verify
sample = train_dataset[0]
print(f"Positive data shape: {sample['positive'].shape}")
print(f"Negative data shape: {sample['negative'].shape}")
print(f"True label: {sample['label']}")
print(f"First 10 pixels of positive (should be one-hot): {sample['positive'][:10]}")
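
The `"random"` branch of `_get_negative_label` samples a wrong label uniformly without rejection sampling: it draws from the `num_classes - 1` other labels and shifts every draw at or above the true label up by one. A standalone check of that trick (`sample_wrong_label` is a hypothetical name for the same logic):

```python
import random
from collections import Counter


def sample_wrong_label(true_label, num_classes=10, rng=random):
    """Uniformly sample a label != true_label without rejection sampling."""
    wrong = rng.randrange(num_classes - 1)  # 0 .. num_classes - 2
    # Skip over the true label by shifting the upper part of the range by one
    return wrong + 1 if wrong >= true_label else wrong


rng = random.Random(0)
counts = Counter(sample_wrong_label(3, rng=rng) for _ in range(9000))
print(sorted(counts))  # [0, 1, 2, 4, 5, 6, 7, 8, 9]
```

Each of the nine wrong labels is drawn with probability 1/9, so the counts come out roughly equal.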

## 3. Visualize Positive vs Negative Data

In [None]:
def visualize_ff_data(dataset, n_samples=5):
    """Visualize positive and negative data pairs."""
    
    fig, axes = plt.subplots(3, n_samples, figsize=(3*n_samples, 9))
    fig.suptitle('Forward-Forward Data: Positive vs Negative', fontsize=14, fontweight='bold')
    
    for i in range(n_samples):
        sample = dataset[i]
        
        # Original image
        axes[0, i].imshow(sample['image'], cmap='gray')
        axes[0, i].set_title(f"True label: {sample['label'].item()}")
        axes[0, i].axis('off')
        
        # Positive embedding (first 10 pixels show one-hot)
        pos_reshaped = sample['positive'].reshape(28, 28)
        axes[1, i].imshow(pos_reshaped, cmap='gray')
        embedded_label = sample['positive'][:10].argmax().item()
        axes[1, i].set_title(f"Positive (label={embedded_label})")
        axes[1, i].axis('off')
        
        # Negative embedding
        neg_reshaped = sample['negative'].reshape(28, 28)
        axes[2, i].imshow(neg_reshaped, cmap='gray')
        wrong_label = sample['negative'][:10].argmax().item()
        axes[2, i].set_title(f"Negative (label={wrong_label})", color='red')
        axes[2, i].axis('off')
    
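    # axis('off') also hides y-labels, so re-enable the first column's axes
    # (without ticks) before labelling the rows
    row_names = ['Original', 'Positive\n(correct label)', 'Negative\n(wrong label)']
    for row, name in enumerate(row_names):
        axes[row, 0].axis('on')
        axes[row, 0].set_xticks([])
        axes[row, 0].set_yticks([])
        axes[row, 0].set_ylabel(name, fontsize=12)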
    
    plt.tight_layout()
    plt.show()


visualize_ff_data(train_dataset)

## 4. Forward-Forward Layer

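Each layer in Forward-Forward has its own local objective:
- **Goodness** = mean of the squared activations (Hinton (2022) uses the sum; the mean divides it by the layer width, so the threshold is chosen on that scale)
- **Goal**: goodness > threshold for positive data, goodness < threshold for negative data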

We'll create both a standard FF layer and a SOEN-based FF layer.
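
Hinton (2022) measures goodness on a layer's raw activity and normalizes the vector's length only before passing it to the next layer. The plain-Python sketch below (helper names are illustrative) shows why that order matters: after a LayerNorm without affine parameters, the mean square of any vector is about 1, so goodness measured there can no longer tell weak activity from strong activity. Length normalization keeps only the direction for the next layer.

```python
import math


def layer_norm(v, eps=1e-5):
    """LayerNorm without affine parameters: zero mean, unit variance per sample."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def length_normalize(v, eps=1e-4):
    """Divide by the L2 norm, keeping only the direction of the activity vector."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + eps) for x in v]


def mean_square(v):
    return sum(x * x for x in v) / len(v)


weak = [0.1, 0.0, 0.2, 0.05]    # low-goodness activity
strong = [3.0, 0.0, 6.0, 1.5]   # same direction, 30x larger
print(mean_square(weak), mean_square(strong))                    # differ by a factor of 900
print(mean_square(layer_norm(weak)), mean_square(layer_norm(strong)))  # both about 1.0
print(length_normalize(weak)[:2], length_normalize(strong)[:2])  # same direction
```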

In [None]:
class FFLayer(nn.Module):
    """Standard Forward-Forward layer with a local learning objective.
    
    Following Hinton (2022), goodness is measured on the layer's raw ReLU
    activity, and the *input* is length-normalized so that a layer cannot
    simply read the previous layer's goodness from the vector length.
    (Measuring goodness after a LayerNorm would not work: with its initial
    affine parameters, a LayerNorm output has a mean square of about 1 for
    every input.)
    """
    
    def __init__(self, in_features: int, out_features: int, threshold: float = 2.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.threshold = threshold
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the input length, then apply Linear + ReLU."""
        x = x / (x.norm(p=2, dim=1, keepdim=True) + 1e-4)
        return self.relu(self.linear(x))
    
    def compute_goodness(self, x: torch.Tensor) -> torch.Tensor:
        """Compute goodness (mean of squared activations)."""
        return (x ** 2).mean(dim=1)  # Mean over features, shape: (batch,)
    
    def ff_loss(self, pos_goodness: torch.Tensor, neg_goodness: torch.Tensor) -> torch.Tensor:
        """Forward-Forward loss: push positive above threshold, negative below.
        
        Loss = log(1 + exp(-(pos_goodness - threshold))) + log(1 + exp(neg_goodness - threshold)),
        computed with softplus for numerical stability.
        """
        pos_loss = F.softplus(-(pos_goodness - self.threshold))
        neg_loss = F.softplus(neg_goodness - self.threshold)
        return (pos_loss + neg_loss).mean()


# Test
test_layer = FFLayer(784, 500)
test_input = torch.randn(32, 784)
test_output = test_layer(test_input)
test_goodness = test_layer.compute_goodness(test_output)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
print(f"Goodness shape: {test_goodness.shape}")
print(f"Mean goodness: {test_goodness.mean().item():.4f}")

## 5. SOEN Forward-Forward Layer

This layer replaces the linear map with simulated SOEN SingleDendrite dynamics from `soen_toolkit`, so its forward pass follows the simulated neuron physics instead of a matrix multiply followed by a ReLU.
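
The layer below feeds the image to the SOEN model one row per timestep and max-pools the neuron states over time. The sketch shows only that data flow. `toy_leaky_integrator` is a hypothetical stand-in with a simple leaky update; it is **not** the SingleDendrite model, and all names here are illustrative.

```python
def rows_as_sequence(flat_image, num_timesteps=28):
    """Split a flattened image into num_timesteps equal chunks (one image row each)."""
    step = len(flat_image) // num_timesteps
    return [flat_image[t * step:(t + 1) * step] for t in range(num_timesteps)]


def toy_leaky_integrator(sequence, weights, leak=0.1, dt=1.0):
    """Hypothetical stand-in: s <- s + dt * (relu(W x_t) - leak * s).

    NOT the SingleDendrite equations; it only mimics stateful neurons driven
    by one input row per timestep. Returns the state at every timestep.
    """
    state = [0.0] * len(weights)
    trajectory = []
    for x_t in sequence:
        drive = [max(0.0, sum(w * x for w, x in zip(row, x_t))) for row in weights]
        state = [s + dt * (d - leak * s) for s, d in zip(state, drive)]
        trajectory.append(state)
    return trajectory


def max_over_time(trajectory):
    """Per-neuron maximum over timesteps (the analogue of output.max(dim=1))."""
    return [max(step[j] for step in trajectory) for j in range(len(trajectory[0]))]


image = [0.0] * 784
image[14 * 28 + 5] = 1.0                       # one bright pixel in row 14
weights = [[1.0] * 28, [-1.0] * 28]            # one excitatory, one inhibitory toy neuron
pooled = max_over_time(toy_leaky_integrator(rows_as_sequence(image), weights))
print(pooled)  # [1.0, 0.0]
```

Max-pooling keeps each neuron's peak response even if its state has decayed by the last row, which a final-state readout would miss.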

In [None]:
from soen_toolkit.core import (
    ConnectionConfig,
    LayerConfig,
    SimulationConfig,
    SOENModelCore,
)


class SOENFFLayer(nn.Module):
    """SOEN-based Forward-Forward layer using SingleDendrite dynamics.
    
    The input is processed as a sequence (28 timesteps, one MNIST row each),
    and goodness is computed from the states max-pooled over time.
    """
    
    def __init__(
        self, 
        in_features: int, 
        out_features: int, 
        threshold: float = 2.0,
        dt: float = 100.0,
        num_timesteps: int = 28,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold
        self.num_timesteps = num_timesteps
        
        # Build SOEN model
        sim_cfg = SimulationConfig(
            dt=dt,
            input_type="state",
            track_phi=False,
            track_power=False,
        )
        
        # Input layer
        layer0 = LayerConfig(
            layer_id=0,
            layer_type="Input",
            params={"dim": in_features // num_timesteps},  # Features per timestep
        )
        
        # SOEN hidden layer with SingleDendrite dynamics
        layer1 = LayerConfig(
            layer_id=1,
            layer_type="SingleDendrite",
            params={
                "dim": out_features,
                "solver": "FE",
                "source_func": "Heaviside_fit_state_dep",
                "phi_offset": 0.02,
                "bias_current": {"distribution": "uniform", "params": {"min": 1.8, "max": 2.1}},
                "gamma_plus": {"distribution": "constant", "params": {"value": 0.001}},
                "gamma_minus": {"distribution": "constant", "params": {"value": 0.0001}},
            },
        )
        
        # Connection
        conn = ConnectionConfig(
            from_layer=0,
            to_layer=1,
            connection_type="dense",
            params={"init": "xavier_uniform"},
            learnable=True,
        )
        
        self.soen_model = SOENModelCore(
            sim_config=sim_cfg,
            layers_config=[layer0, layer1],
            connections_config=[conn],
        )
        
        # Apply the notebook-wide noise setting (zeroes all noise when NOISE_ENABLED is False)
        set_model_noise(self.soen_model, enabled=NOISE_ENABLED, noise_values=NOISE_DEFAULTS)
        
        # No output normalization: goodness is measured on the raw pooled SOEN
        # states. With its initial affine parameters, a LayerNorm output has a
        # mean square of about 1 for every input, so goodness could barely
        # depend on the data. Normalization belongs at the next layer's input.
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process input through SOEN dynamics.
        
        Args:
            x: Input tensor of shape (batch, features) where features = 784
        
        Returns:
            Output tensor of shape (batch, out_features): SOEN states max-pooled over time
        """
        batch_size = x.shape[0]
        
        # Reshape to sequence: (batch, timesteps, features_per_step)
        # For MNIST: (batch, 28, 28), treating each image row as one timestep
        x_seq = x.view(batch_size, self.num_timesteps, -1)
        
        # Process through SOEN
        output, all_states = self.soen_model(x_seq)
        
        # output shape: (batch, timesteps, out_features)
        if output.dim() == 3:
            # Max-pool over time to keep each neuron's peak state
            output = output.max(dim=1).values  # (batch, out_features)
        
        return output
    
    def compute_goodness(self, x: torch.Tensor) -> torch.Tensor:
        """Compute goodness (mean squared state).
        
        SOEN states may lie on a different scale than ReLU activity, so
        `threshold` may need retuning for this layer.
        """
        return (x ** 2).mean(dim=1)
    
    def ff_loss(self, pos_goodness: torch.Tensor, neg_goodness: torch.Tensor) -> torch.Tensor:
        """Forward-Forward loss, using softplus(z) = log(1 + exp(z)) for numerical stability."""
        pos_loss = F.softplus(-(pos_goodness - self.threshold))
        neg_loss = F.softplus(neg_goodness - self.threshold)
        return (pos_loss + neg_loss).mean()


# Test SOEN FF layer
print("Testing SOEN FF Layer...")
soen_layer = SOENFFLayer(784, 256, threshold=2.0)
test_input = torch.randn(4, 784)
test_output = soen_layer(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
print(f"Goodness: {soen_layer.compute_goodness(test_output).mean().item():.4f}")

## 6. Forward-Forward Network

Stack multiple FF layers, each trained with its own local objective.
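
At inference time there is no output layer: the network tries every candidate label and keeps the one whose embedding produces the highest total goodness. A plain-Python sketch of that search, where `predict_label` and `toy_goodness` are hypothetical and `toy_goodness` stands in for "run all layers and sum their goodness":

```python
def predict_label(flat_image, goodness_fn, num_classes=10):
    """Return the label whose one-hot embedding yields the highest total goodness."""
    scores = []
    for label in range(num_classes):
        one_hot = [1.0 if k == label else 0.0 for k in range(num_classes)]
        embedded = one_hot + list(flat_image[num_classes:])
        scores.append(goodness_fn(embedded))
    return max(range(num_classes), key=scores.__getitem__)


def toy_goodness(x):
    """Hypothetical trained network: goodness is high only when label slot 4 is on."""
    return 5.0 * x[4] + 0.1 * sum(x[10:])


print(predict_label([0.2] * 784, toy_goodness))  # 4
```

The cost is one full forward pass per candidate label, i.e. 10 passes per MNIST image.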

In [None]:
class ForwardForwardNet(nn.Module):
    """Multi-layer Forward-Forward network.
    
    Each layer is trained independently with its own optimizer.
    For inference, we evaluate goodness across all possible labels.
    """
    
    def __init__(
        self, 
        input_dim: int = 784,
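        hidden_dims: Optional[List[int]] = None,
        num_classes: int = 10,
        threshold: float = 2.0,
        use_soen: bool = False,
    ):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [500, 500]  # set here to avoid a mutable default argument
        self.num_classes = num_classes
        self.use_soen = use_soen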
        
        # Build layers
        self.layers = nn.ModuleList()
        
        dims = [input_dim] + hidden_dims
        for i in range(len(hidden_dims)):
            if use_soen and i == 0:  # Only first layer uses SOEN for now
                layer = SOENFFLayer(dims[i], dims[i+1], threshold=threshold)
            else:
                layer = FFLayer(dims[i], dims[i+1], threshold=threshold)
            self.layers.append(layer)
        
        print(f"Created FF Network with {len(self.layers)} layers:")
        for i, layer in enumerate(self.layers):
            layer_type = "SOEN" if isinstance(layer, SOENFFLayer) else "Standard"
            print(f"  Layer {i}: {dims[i]} -> {dims[i+1]} ({layer_type})")
    
    def forward_one_layer(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
        """Forward pass through layers up to and including layer_idx."""
        for i in range(layer_idx + 1):
            x = self.layers[i](x)
        return x
    
    def compute_layer_goodness(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
        """Compute goodness at a specific layer."""
        output = self.forward_one_layer(x, layer_idx)
        return self.layers[layer_idx].compute_goodness(output)
    
    def predict(self, images: torch.Tensor) -> torch.Tensor:
        """Predict labels by finding which label embedding maximizes total goodness.
        
        For each image, we try all 10 possible labels and pick the one
        that results in highest summed goodness across all layers.
        """
        batch_size = images.shape[0]
        device = images.device
        
        # Flatten images
        flat_images = images.view(batch_size, -1)  # (batch, 784)
        
        # Store goodness for each label
        all_goodness = torch.zeros(batch_size, self.num_classes, device=device)
        
        for label in range(self.num_classes):
            # Create one-hot label
            one_hot = F.one_hot(torch.tensor([label], device=device), self.num_classes)
            one_hot = one_hot.float().expand(batch_size, -1)  # (batch, 10)
            
            # Embed label into image (replace first 10 pixels)
            embedded = flat_images.clone()
            embedded[:, :self.num_classes] = one_hot
            
            # Compute total goodness across all layers
            total_goodness = torch.zeros(batch_size, device=device)
            x = embedded
            for layer in self.layers:
                x = layer(x)
                total_goodness += layer.compute_goodness(x)
            
            all_goodness[:, label] = total_goodness
        
        # Return predicted labels (highest goodness)
        return all_goodness.argmax(dim=1)
    
    def get_layer_optimizer(self, layer_idx: int, lr: float = 0.03):
        """Get optimizer for a specific layer."""
        return torch.optim.Adam(self.layers[layer_idx].parameters(), lr=lr)


# Create network
ff_net = ForwardForwardNet(
    input_dim=784,
    hidden_dims=config.hidden_dims,
    num_classes=config.num_classes,
    threshold=config.threshold,
    use_soen=config.use_soen_layers,
).to(DEVICE)

## 7. Training Loop

The key difference from backprop: **each layer is trained independently** using its own forward passes and local loss function.
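
The mechanics can be seen in a toy with scalar "layers" and hand-derived gradients. Each layer computes out = relu(w·x), with goodness = out². Every layer is updated on every step, and it treats its input as a constant (the analogue of `.detach()`). No gradient ever crosses a layer boundary. The names and numbers are illustrative only; the positive input is simply larger than the negative one.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def ff_local_step(w, x_pos, x_neg, theta=2.0, lr=0.05):
    """One local FF update of a scalar toy layer: out = relu(w * x), goodness = out**2.

    x_pos / x_neg are constants here, like the detached inputs in the training loop.
    Returns the updated weight and this layer's outputs (computed before the update).
    """
    out_pos, out_neg = max(0.0, w * x_pos), max(0.0, w * x_neg)
    g_pos, g_neg = out_pos ** 2, out_neg ** 2
    # Loss = softplus(theta - g_pos) + softplus(g_neg - theta)
    # dLoss/dg_pos = -sigmoid(theta - g_pos); dLoss/dg_neg = sigmoid(g_neg - theta)
    # dg/dw = 2 * out * x (zero when the ReLU is inactive, because out = 0)
    grad = (-sigmoid(theta - g_pos) * 2.0 * out_pos * x_pos
            + sigmoid(g_neg - theta) * 2.0 * out_neg * x_neg)
    return w - lr * grad, out_pos, out_neg


def train_greedy(num_layers=2, steps=3000, x_pos=2.0, x_neg=0.5):
    """Update every layer on each step, feeding it the previous layer's constant outputs."""
    weights = [0.5] * num_layers
    for _ in range(steps):
        p, n = x_pos, x_neg
        for i in range(num_layers):
            weights[i], p, n = ff_local_step(weights[i], p, n)
    return weights


def layer_goodness(weights, x_pos=2.0, x_neg=0.5):
    """(positive, negative) goodness of each layer for the fixed toy inputs."""
    result, p, n = [], x_pos, x_neg
    for w in weights:
        p, n = max(0.0, w * p), max(0.0, w * n)
        result.append((p * p, n * n))
    return result


for i, (g_pos, g_neg) in enumerate(layer_goodness(train_greedy())):
    print(f"layer {i}: pos goodness {g_pos:.2f}, neg goodness {g_neg:.2f}")
```

After training, both layers place positive goodness above θ = 2 and negative goodness below it, even though each weight only ever saw its own local loss.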

In [None]:
def train_forward_forward(
    model: ForwardForwardNet,
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_epochs: int = 60,
    lr: float = 0.03,
    device: torch.device = DEVICE,
):
    """Train using Forward-Forward algorithm.
    
    Key insight: Each layer is trained independently!
    No backward pass through the entire network.
    """
    model.to(device)
    
    # Create separate optimizer for each layer
    optimizers = [model.get_layer_optimizer(i, lr) for i in range(len(model.layers))]
    
    # Training history
    history = {
        "train_loss": [],
        "test_accuracy": [],
        "layer_losses": [[] for _ in range(len(model.layers))],
        "layer_pos_goodness": [[] for _ in range(len(model.layers))],
        "layer_neg_goodness": [[] for _ in range(len(model.layers))],
    }
    
    print(f"\nStarting Forward-Forward Training")
    print(f"="*60)
    print(f"Epochs: {num_epochs}, LR: {lr}, Threshold: {model.layers[0].threshold}")
    print(f"Layers: {len(model.layers)}, Device: {device}")
    print(f"="*60 + "\n")
    
    for epoch in range(num_epochs):
        model.train()
        epoch_losses = [0.0 for _ in range(len(model.layers))]
        epoch_pos_goodness = [0.0 for _ in range(len(model.layers))]
        epoch_neg_goodness = [0.0 for _ in range(len(model.layers))]
        num_batches = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in pbar:
            pos_data = batch["positive"].to(device)
            neg_data = batch["negative"].to(device)
            
            # Train each layer independently (this is the key FF insight!)
            pos_input = pos_data
            neg_input = neg_data
            
            batch_loss = 0.0
            for layer_idx, (layer, optimizer) in enumerate(zip(model.layers, optimizers)):
                optimizer.zero_grad()
                
                # Forward pass for this layer
                pos_output = layer(pos_input)
                neg_output = layer(neg_input)
                
                # Compute goodness
                pos_goodness = layer.compute_goodness(pos_output)
                neg_goodness = layer.compute_goodness(neg_output)
                
                # Local loss for this layer
                loss = layer.ff_loss(pos_goodness, neg_goodness)
                
                # Backward (only for this layer's parameters!)
                loss.backward()
                optimizer.step()
                
                # Track metrics
                epoch_losses[layer_idx] += loss.item()
                epoch_pos_goodness[layer_idx] += pos_goodness.mean().item()
                epoch_neg_goodness[layer_idx] += neg_goodness.mean().item()
                batch_loss += loss.item()
                
                # Detach for next layer (no gradient flow between layers!)
                pos_input = pos_output.detach()
                neg_input = neg_output.detach()
            
            num_batches += 1
            pbar.set_postfix({"loss": f"{batch_loss/len(model.layers):.4f}"})
        
        # Average metrics
        for i in range(len(model.layers)):
            history["layer_losses"][i].append(epoch_losses[i] / num_batches)
            history["layer_pos_goodness"][i].append(epoch_pos_goodness[i] / num_batches)
            history["layer_neg_goodness"][i].append(epoch_neg_goodness[i] / num_batches)
        
        avg_loss = sum(epoch_losses) / (num_batches * len(model.layers))
        history["train_loss"].append(avg_loss)
        
        # Evaluate on test set
        if (epoch + 1) % 5 == 0 or epoch == 0:
            test_acc = evaluate_ff(model, test_loader, device)
            history["test_accuracy"].append(test_acc)
            print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Test Acc={test_acc:.2%}")
            
            # Print layer-wise goodness
            for i in range(len(model.layers)):
                pos_g = history["layer_pos_goodness"][i][-1]
                neg_g = history["layer_neg_goodness"][i][-1]
                print(f"  Layer {i}: pos_goodness={pos_g:.3f}, neg_goodness={neg_g:.3f}")
    
    return history


def evaluate_ff(model: ForwardForwardNet, test_loader: DataLoader, device: torch.device):
    """Evaluate Forward-Forward model."""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            
            predictions = model.predict(images)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
    return correct / total

In [None]:
# Train the model
history = train_forward_forward(
    model=ff_net,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=config.num_epochs,
    lr=config.learning_rate,
    device=DEVICE,
)

## 8. Visualize Training Progress

In [None]:
def plot_ff_training_history(history: Dict):
    """Plot Forward-Forward training metrics."""
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Forward-Forward Training Progress', fontsize=14, fontweight='bold')
    
    # 1. Overall loss
    axes[0, 0].plot(history["train_loss"], 'b-', linewidth=2)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Test accuracy: evaluated after the first epoch and every 5th epoch,
    #    matching the schedule in train_forward_forward (epochs are 0-indexed here)
    epochs_with_acc = [e for e in range(len(history["train_loss"]))
                       if e == 0 or (e + 1) % 5 == 0]
    axes[0, 1].plot(epochs_with_acc[:len(history["test_accuracy"])], 
                    history["test_accuracy"], 'g-o', linewidth=2, markersize=6)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].set_title('Test Accuracy')
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].set_ylim([0, 1])
    
    # 3. Layer-wise goodness (positive)
    colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(history["layer_pos_goodness"])))
    for i, (pos_g, neg_g) in enumerate(zip(history["layer_pos_goodness"], 
                                           history["layer_neg_goodness"])):
        axes[1, 0].plot(pos_g, color=colors[i], linestyle='-', 
                        linewidth=2, label=f'Layer {i} (pos)')
        axes[1, 0].plot(neg_g, color=colors[i], linestyle='--', 
                        linewidth=2, alpha=0.7, label=f'Layer {i} (neg)')
    
    # Add threshold line
    threshold = config.threshold
    axes[1, 0].axhline(y=threshold, color='red', linestyle=':', linewidth=2, label=f'Threshold ({threshold})')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Goodness')
    axes[1, 0].set_title('Layer-wise Goodness (solid=positive, dashed=negative)')
    axes[1, 0].legend(loc='best', fontsize=8)
    axes[1, 0].grid(True, alpha=0.3)
    
    # 4. Layer-wise loss
    for i, layer_loss in enumerate(history["layer_losses"]):
        axes[1, 1].plot(layer_loss, color=colors[i], linewidth=2, label=f'Layer {i}')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Loss')
    axes[1, 1].set_title('Layer-wise FF Loss')
    axes[1, 1].legend(loc='best')
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print final summary
    print(f"\n{'='*50}")
    print("TRAINING SUMMARY")
    print(f"{'='*50}")
    print(f"Final test accuracy: {history['test_accuracy'][-1]:.2%}")
    print(f"Best test accuracy: {max(history['test_accuracy']):.2%}")
    print(f"\nFinal layer-wise goodness:")
    for i in range(len(history['layer_pos_goodness'])):
        pos = history['layer_pos_goodness'][i][-1]
        neg = history['layer_neg_goodness'][i][-1]
        separation = pos - neg
        print(f"  Layer {i}: pos={pos:.3f}, neg={neg:.3f}, separation={separation:.3f}")


plot_ff_training_history(history)

## 9. Visualize Predictions

In [None]:
def visualize_ff_predictions(model: ForwardForwardNet, test_dataset, n_samples=20):
    """Visualize Forward-Forward predictions with goodness values."""
    
    model.eval()
    
    # Random samples
    indices = np.random.choice(len(test_dataset), n_samples, replace=False)
    
    n_cols = 5
    n_rows = (n_samples + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3.5*n_rows))
    axes = axes.flatten()
    
    fig.suptitle('Forward-Forward Predictions', fontsize=14, fontweight='bold')
    
    correct = 0
    for i, idx in enumerate(indices):
        sample = test_dataset[idx]
        image = sample['image'].unsqueeze(0).to(DEVICE)
        true_label = sample['label'].item()
        
        # Get prediction
        with torch.no_grad():
            pred = model.predict(image).item()
        
        is_correct = pred == true_label
        correct += is_correct
        
        # Plot
        ax = axes[i]
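        # MNIST images are 28x28; reshape handles (784,), (28, 28) and (1, 28, 28) tensors
        ax.imshow(sample['image'].cpu().numpy().reshape(28, 28), cmap='gray')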
        
        color = 'green' if is_correct else 'red'
        symbol = '✓' if is_correct else '✗'
        ax.set_title(f"{symbol} Pred: {pred}\nTrue: {true_label}", 
                     color=color, fontsize=10, fontweight='bold' if not is_correct else 'normal')
        ax.axis('off')
    
    # Hide empty subplots
    for i in range(n_samples, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nSample accuracy: {correct}/{n_samples} ({correct/n_samples:.1%})")


visualize_ff_predictions(ff_net, test_dataset)
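`ForwardForwardNet.predict` is defined earlier in the notebook and is not shown in this section. A common way to classify with a trained FF network, following Hinton (2022), is to embed each candidate label in turn, run the forward pass, accumulate goodness over the layers, and pick the label with the highest total. The NumPy sketch below illustrates that idea under stated assumptions: the one-hot embedding in the first 10 pixels, the ReLU stand-in layers, and the helper names `embed_label` and `predict_by_goodness` are hypothetical and may not match this notebook's SOEN implementation.

```python
import numpy as np


def embed_label(images: np.ndarray, label: int, num_classes: int = 10) -> np.ndarray:
    """Hypothetical label embedding: overwrite the first `num_classes` pixels with a one-hot code."""
    x = images.reshape(len(images), -1).astype(float)  # (batch, 784); astype returns a copy
    x[:, :num_classes] = 0.0
    x[:, label] = 1.0
    return x


def goodness(h: np.ndarray) -> np.ndarray:
    """Mean squared activation per sample (one of the goodness definitions used in this tutorial)."""
    return np.mean(h ** 2, axis=1)


def predict_by_goodness(images: np.ndarray, weights: list, num_classes: int = 10) -> np.ndarray:
    """Try every label and return the one whose accumulated goodness is highest."""
    scores = np.zeros((len(images), num_classes))
    for label in range(num_classes):
        x = embed_label(images, label, num_classes)
        for W in weights:
            h = np.maximum(x @ W, 0.0)  # stand-in ReLU layer, not a SOEN layer
            scores[:, label] += goodness(h)
            # Length-normalize so the next layer cannot simply read goodness off the norm
            x = h / (np.linalg.norm(h, axis=1, keepdims=True) + 1e-8)
    return scores.argmax(axis=1)


# Toy check: a layer that only responds to pixel 3 should always predict label 3
W = np.zeros((784, 16))
W[3, :] = 1.0
print(predict_by_goodness(np.zeros((2, 28, 28)), [W]))  # → [3 3]
```

Prediction therefore costs one forward pass per class (10 for MNIST), which is the price FF pays for not having a dedicated output layer.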

## 10. Analyze Goodness Separation

In [None]:
def analyze_goodness_separation(model: ForwardForwardNet, test_loader: DataLoader):
    """Analyze how well positive and negative goodness are separated."""
    
    model.eval()
    
    all_pos_goodness = [[] for _ in range(len(model.layers))]
    all_neg_goodness = [[] for _ in range(len(model.layers))]
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Analyzing goodness"):
            pos_data = batch["positive"].to(DEVICE)
            neg_data = batch["negative"].to(DEVICE)
            
            pos_input = pos_data
            neg_input = neg_data
            
            for layer_idx, layer in enumerate(model.layers):
                pos_output = layer(pos_input)
                neg_output = layer(neg_input)
                
                pos_g = layer.compute_goodness(pos_output)
                neg_g = layer.compute_goodness(neg_output)
                
                all_pos_goodness[layer_idx].extend(pos_g.cpu().numpy())
                all_neg_goodness[layer_idx].extend(neg_g.cpu().numpy())
                
                pos_input = pos_output
                neg_input = neg_output
    
    # Plot distributions
    fig, axes = plt.subplots(1, len(model.layers), figsize=(6*len(model.layers), 5))
    if len(model.layers) == 1:
        axes = [axes]
    
    fig.suptitle('Goodness Distributions by Layer', fontsize=14, fontweight='bold')
    
    threshold = config.threshold
    
    for i, ax in enumerate(axes):
        pos_g = np.array(all_pos_goodness[i])
        neg_g = np.array(all_neg_goodness[i])
        
        # Histograms
        ax.hist(pos_g, bins=50, alpha=0.6, color='green', label=f'Positive (μ={pos_g.mean():.2f})', density=True)
        ax.hist(neg_g, bins=50, alpha=0.6, color='red', label=f'Negative (μ={neg_g.mean():.2f})', density=True)
        
        # Threshold
        ax.axvline(x=threshold, color='black', linestyle='--', linewidth=2, label=f'Threshold ({threshold})')
        
        ax.set_xlabel('Goodness', fontsize=12)
        ax.set_ylabel('Density', fontsize=12)
        ax.set_title(f'Layer {i}', fontsize=12)
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)
        
        # Separation stats: mean gap, and fraction of all samples on the wrong side of the threshold
        separation = pos_g.mean() - neg_g.mean()
        misclassified = (np.sum(pos_g < threshold) + np.sum(neg_g > threshold)) / (len(pos_g) + len(neg_g))
        print(f"Layer {i}: separation={separation:.3f}, wrong side of threshold={misclassified:.1%}")
    
    plt.tight_layout()
    plt.show()


analyze_goodness_separation(ff_net, test_loader)
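The statistics printed above depend on the chosen threshold θ. A threshold-free complement is the probability that a randomly drawn positive sample has higher goodness than a randomly drawn negative one (the ROC AUC of goodness used as a score). The sketch below computes both from plain NumPy arrays; `goodness_separation_stats` is a hypothetical helper, not part of `soen_toolkit`.

```python
import numpy as np


def goodness_separation_stats(pos_g, neg_g, threshold: float) -> dict:
    """Summarize how well one layer separates positive from negative goodness."""
    pos_g = np.asarray(pos_g, dtype=float)
    neg_g = np.asarray(neg_g, dtype=float)

    # Fraction of all samples that land on the wrong side of the threshold
    wrong = np.sum(pos_g < threshold) + np.sum(neg_g > threshold)
    wrong_rate = wrong / (pos_g.size + neg_g.size)

    # P(pos > neg) over all (pos, neg) pairs, counting ties as 1/2; independent of threshold.
    # The pairwise matrix is N x M, so subsample large test sets before calling this.
    diff = pos_g[:, None] - neg_g[None, :]
    pairwise = (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size

    return {
        "separation": float(pos_g.mean() - neg_g.mean()),
        "wrong_side_rate": float(wrong_rate),
        "p_pos_above_neg": float(pairwise),
    }


stats = goodness_separation_stats([3.0, 2.5, 1.0], [0.5, 1.5, 2.0], threshold=2.0)
print(stats)
```

A value of `p_pos_above_neg` near 1.0 means the layer ranks positives above negatives almost perfectly even if θ is poorly placed; a high `wrong_side_rate` alongside it would suggest the threshold, not the representation, is the problem.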

## 11. Compare with Backpropagation (Optional)

For reference, the cell below trains a conventional (non-SOEN) ReLU MLP with the same hidden-layer sizes using standard backprop.

In [None]:
class BackpropNet(nn.Module):
    """Standard backprop network for comparison."""
    
    def __init__(self, input_dim=784, hidden_dims=(500, 500), num_classes=10):
        super().__init__()
        
        layers = []
        # Tuple default avoids a shared mutable default; list() accepts either lists or tuples
        dims = [input_dim] + list(hidden_dims) + [num_classes]
        
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            if i < len(dims) - 2:  # No activation after last layer
                layers.append(nn.ReLU())
                layers.append(nn.LayerNorm(dims[i+1]))
        
        self.net = nn.Sequential(*layers)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        return self.net(x)


def train_backprop(model, train_loader, test_loader, num_epochs=20, lr=0.001):
    """Train with standard backpropagation."""
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    history = {"train_loss": [], "test_accuracy": []}
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Backprop Epoch {epoch+1}"):
            images = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()  # This is the key difference from FF!
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        history["train_loss"].append(avg_loss)
        
        # Evaluate
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in test_loader:
                images = batch["image"].to(DEVICE)
                labels = batch["label"].to(DEVICE)
                outputs = model(images)
                _, predicted = outputs.max(1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        
        acc = correct / total
        history["test_accuracy"].append(acc)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Test Acc={acc:.2%}")
    
    return history


# Uncomment to train backprop comparison
# bp_net = BackpropNet(hidden_dims=config.hidden_dims)
# bp_history = train_backprop(bp_net, train_loader, test_loader, num_epochs=20)

## Summary

### What We Learned

1. **Forward-Forward Algorithm**:
   - Each layer learns **independently** using local goodness
   - No backward pass required → natural fit for neuromorphic hardware
   - Uses positive/negative contrastive data with embedded labels

2. **Key Differences from Backprop**:
   | Aspect | Backprop | Forward-Forward |
   |--------|----------|------------------|
   | Gradient flow | Global (through all layers) | Local (per layer) |
   | Memory | Store all activations | Only the current layer's activations |
   | Data | Standard (image, label) | Contrastive (pos/neg pairs) |
   | Hardware | Requires external compute | On-chip learning possible |

3. **SOEN Advantages**:
   - FF aligns with SOEN's forward-only physics
   - Local learning avoids compounding noise through a long backward pass
   - Potential for on-chip training without external GPUs
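The local learning rule in item 1 can be made concrete. In Hinton's formulation each layer treats σ(g − θ) as the probability that an input is positive, which gives the per-layer loss softplus(θ − g_pos) + softplus(g_neg − θ). The NumPy sketch below pairs that loss with one common way to build negative data (the same image with a randomly chosen wrong label). The function names are hypothetical, and this notebook's `ForwardForwardNet` may use a different loss scaling or goodness definition.

```python
import numpy as np


def sample_wrong_labels(labels: np.ndarray, num_classes: int, rng: np.random.Generator) -> np.ndarray:
    """Draw a label different from the true one for every sample (uniform over the other classes)."""
    offsets = rng.integers(1, num_classes, size=labels.shape)  # values in 1..num_classes-1, never 0
    return (labels + offsets) % num_classes


def ff_layer_loss(g_pos: np.ndarray, g_neg: np.ndarray, threshold: float) -> float:
    """Per-layer FF loss: push positive goodness above θ and negative goodness below θ."""
    # softplus(z) = log(1 + e^z), evaluated stably as logaddexp(0, z)
    pos_term = np.logaddexp(0.0, threshold - g_pos)  # large when g_pos < θ
    neg_term = np.logaddexp(0.0, g_neg - threshold)  # large when g_neg > θ
    return float(np.mean(pos_term) + np.mean(neg_term))


rng = np.random.default_rng(0)
labels = np.array([0, 3, 9, 9])
print(sample_wrong_labels(labels, 10, rng))  # never equal to `labels` elementwise

# The loss needs only this layer's goodness values: no signal from later layers
print(ff_layer_loss(np.array([2.0]), np.array([2.0]), threshold=2.0))  # 2*log(2) ≈ 1.386
print(ff_layer_loss(np.array([8.0]), np.array([0.1]), threshold=2.0))  # much smaller
```

Because the loss depends only on one layer's goodness, each layer can update its weights as soon as its own forward pass finishes, which is the property the SOEN advantages above rely on.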

### Expected Results

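- Forward-Forward typically reaches ~98% on MNIST with fully connected networks (vs ~99% for backprop); the SOEN version here may land differently depending on noise and configuration
- The slight accuracy gap may be an acceptable trade-off for the hardware advantages listed above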
- Goodness separation should show clear positive/negative clusters

### Next Steps

- Try different threshold values
- Experiment with more SOEN layers
- Implement symmetric negative data generation
- Test on more complex datasets (Fashion-MNIST, CIFAR-10)