# Forward-Forward MNIST Classification (Row-by-Row Temporal)

10-class MNIST classification using Forward-Forward algorithm with **temporal row scanning**.

## Architecture

```
Input: 28 pixels (one row) + 10 (one-hot label) = 38 per timestep
Timesteps: 28 (one per row, scanning top to bottom)
Hidden: ≤24 SingleDendrite neurons
Output: Goodness (mean squared activation; summed over all timesteps in the default 'all' mode, last timestep only in 'final' mode)
```

## Key Innovation: Temporal Image Scanning

Instead of feeding all 784 pixels at once:
- Feed image **row by row** (28 pixels per timestep)
- Network accumulates information over 28 timesteps
- Final state captures full image representation
- Temporal dynamics become meaningful!

## Inference (10 forward passes)

```
For each class c ∈ {0, 1, ..., 9}:
    goodness_c = 0
    For t in 0..27:
        X_t = [row_t_pixels, one_hot(c)]  # 38 dims
        goodness_c += compute_goodness(state_t)  # 'all' mode; 'final' mode scores only the last state
Predict: argmax(goodness_0, ..., goodness_9)
```

## Hardware Compatibility

- Goodness = mean(I²) = power measurement
- Label embedding = optical input modulation
- Temporal scanning = natural sequential processing
- No backward pass needed for inference

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import gzip
import urllib.request

from soen_toolkit.core import (
    ConnectionConfig,
    LayerConfig,
    SimulationConfig,
    SOENModelCore,
)

# Seed NumPy's and PyTorch's global RNGs so data subsets, weight init and
# negative-label sampling are reproducible from run to run.
np.random.seed(42)
torch.manual_seed(42)

print("PyTorch version: {}".format(torch.__version__))

## 1. Load MNIST Dataset

In [None]:
# Direct MNIST download without torchvision
def download_mnist(data_dir='./data/mnist'):
    """Download MNIST dataset without torchvision."""
    os.makedirs(data_dir, exist_ok=True)
    
    base_url = 'https://ossci-datasets.s3.amazonaws.com/mnist/'
    files = {
        'train_images': 'train-images-idx3-ubyte.gz',
        'train_labels': 'train-labels-idx1-ubyte.gz',
        'test_images': 't10k-images-idx3-ubyte.gz',
        'test_labels': 't10k-labels-idx1-ubyte.gz',
    }
    
    paths = {}
    for key, filename in files.items():
        filepath = os.path.join(data_dir, filename)
        if not os.path.exists(filepath):
            print(f"Downloading {filename}...")
            urllib.request.urlretrieve(base_url + filename, filepath)
        paths[key] = filepath
    
    return paths


def load_mnist_images(filepath):
    """Load MNIST images from a gzipped IDX3 file as float32 in [0, 1], kept 2-D.

    The result has shape ``(n_images, n_rows, n_cols)`` so that images can be
    fed to the network one row per timestep.

    Raises:
        ValueError: if the magic number is not 2051 (IDX3, unsigned bytes) or
            the pixel payload does not match the header dimensions (e.g. a
            truncated download).
    """
    with gzip.open(filepath, 'rb') as f:
        magic = int.from_bytes(f.read(4), 'big')
        if magic != 2051:
            raise ValueError(f"{filepath}: bad magic number {magic}, expected 2051 (IDX3 image file)")
        n_images = int.from_bytes(f.read(4), 'big')
        n_rows = int.from_bytes(f.read(4), 'big')
        n_cols = int.from_bytes(f.read(4), 'big')
        data = np.frombuffer(f.read(), dtype=np.uint8)

    expected = n_images * n_rows * n_cols
    if data.size != expected:
        raise ValueError(f"{filepath}: expected {expected} pixel bytes, found {data.size}")

    # Keep as [N, rows, cols] for row-by-row processing
    return data.reshape(n_images, n_rows, n_cols).astype(np.float32) / 255.0


def load_mnist_labels(filepath):
    """Load MNIST labels from a gzipped IDX1 file as a writable uint8 array.

    Raises:
        ValueError: if the magic number is not 2049 (IDX1, unsigned bytes) or
            the number of labels does not match the header count.
    """
    with gzip.open(filepath, 'rb') as f:
        magic = int.from_bytes(f.read(4), 'big')
        if magic != 2049:
            raise ValueError(f"{filepath}: bad magic number {magic}, expected 2049 (IDX1 label file)")
        n_labels = int.from_bytes(f.read(4), 'big')
        # np.frombuffer returns a read-only view of the bytes; copy it so the
        # array is writable (torch.from_numpy warns on non-writable arrays).
        labels = np.frombuffer(f.read(), dtype=np.uint8).copy()

    if labels.size != n_labels:
        raise ValueError(f"{filepath}: header declares {n_labels} labels, found {labels.size}")
    return labels


# Fetch the four MNIST archives (cached on disk) and decode them into tensors.
paths = download_mnist()
X_train_full, X_test_full = (
    torch.from_numpy(load_mnist_images(paths[split]))
    for split in ('train_images', 'test_images')
)
y_train_full, y_test_full = (
    torch.from_numpy(load_mnist_labels(paths[split])).long()
    for split in ('train_labels', 'test_labels')
)

print(f"Full dataset: Train={X_train_full.shape}, Test={X_test_full.shape}")
print(f"Image shape: {X_train_full.shape[1:]} (28 rows × 28 cols)")

# Work on random subsets to keep training time manageable.
N_TRAIN, N_TEST = 20000, 2000

# Re-seed right before sampling so the chosen subsets are reproducible.
torch.manual_seed(42)
train_idx = torch.randperm(X_train_full.shape[0])[:N_TRAIN]
test_idx = torch.randperm(X_test_full.shape[0])[:N_TEST]

X_train, y_train = X_train_full[train_idx], y_train_full[train_idx]
X_test, y_test = X_test_full[test_idx], y_test_full[test_idx]

# Affine map of pixel intensities [0, 1] -> [0.025, 0.275] (SOEN operating range).
X_train, X_test = (0.025 + 0.25 * images for images in (X_train, X_test))

print("\nUsing subset:")
print(f"  Training set: {X_train.shape} (N × rows × cols)")
print(f"  Test set: {X_test.shape}")
print(f"  X range: [{X_train.min():.3f}, {X_train.max():.3f}]")
print(f"  Class distribution (train): {torch.bincount(y_train)}")

In [None]:
# Grid of the first ten training digits.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for ax, image, label in zip(axes.flat, X_train[:10], y_train[:10]):
    ax.imshow(image.numpy(), cmap='gray')
    ax.set_title(f'Label: {label.item()}')
    ax.axis('off')
plt.suptitle('MNIST Samples (will be scanned row-by-row: 28 timesteps)')
plt.tight_layout()
plt.show()

# One image next to the pixel profiles of three of its rows; each profile is
# what the network receives at the corresponding timestep.
fig, axes = plt.subplots(1, 4, figsize=(14, 3))
sample_img = X_train[0].numpy()

image_ax, *row_axes = axes
image_ax.imshow(sample_img, cmap='gray')
image_ax.set_title('Full Image')
image_ax.axis('off')

for ax, row in zip(row_axes, (0, 13, 27)):
    ax.bar(range(28), sample_img[row], color='steelblue')
    ax.set_ylim(0, 0.3)
    ax.set_xlabel('Pixel')
    ax.set_ylabel('Value')
    ax.set_title(f't={row} (row {row})')

plt.suptitle('Row-by-Row Temporal Input: Each row = one timestep', fontsize=12)
plt.tight_layout()
plt.show()

## 2. Forward-Forward Functions

In [None]:
# Problem geometry: each MNIST image is scanned one row per timestep.
N_CLASSES = 10
N_ROWS, N_COLS = 28, 28  # timesteps (image rows), pixels per row
LABEL_SCALE = 0.25  # value placed at the label's position in the one-hot block

# Width of one timestep's input: a row of pixels followed by the label block (28 + 10 = 38).
INPUT_DIM_PER_ROW = N_COLS + N_CLASSES

def embed_label_temporal(X, y, n_classes=N_CLASSES, label_scale=LABEL_SCALE):
    """
    Append a scaled one-hot label to every row (timestep) of each image.

    The same label vector is repeated at every timestep, so the network sees
    the class hypothesis alongside each row it scans.

    Args:
        X: [N, T, C] images (N samples, T rows/timesteps, C pixels per row).
           T is read from X, so it need not equal N_ROWS.
        y: [N] integer class labels in [0, n_classes).
        n_classes: length of the one-hot label block.
        label_scale: value written at the label's position (other label
           entries are 0).

    Returns:
        X_embedded: [N, T, C + n_classes], same dtype and device as X.
    """
    n_samples, n_steps = X.shape[0], X.shape[1]

    # Build the label block with X's dtype/device so the concatenation below
    # also works for non-float32 or GPU inputs.
    one_hot = torch.zeros(n_samples, n_classes, dtype=X.dtype, device=X.device)
    one_hot.scatter_(1, y.to(device=X.device, dtype=torch.long).unsqueeze(1), label_scale)

    # Broadcast (no copy) the same label to every timestep: [N, T, n_classes]
    one_hot_expanded = one_hot.unsqueeze(1).expand(-1, n_steps, -1)

    # Concatenate along features: [N, T, C] + [N, T, n_classes]
    return torch.cat([X, one_hot_expanded], dim=2)


def create_positive_negative_pairs_temporal(X, y, n_classes=N_CLASSES, label_scale=LABEL_SCALE):
    """
    Create positive and negative temporal sequences for Forward-Forward.

    Positive: each image with its correct label embedded at every timestep.
    Negative: each image with ONE randomly drawn wrong label, embedded
              identically at every timestep. The wrong label is
              ``(y + k) % n_classes`` with ``k`` uniform in ``[1, n_classes)``,
              so it is uniform over the other classes and never equals ``y``.

    Returns:
        (X_pos, X_neg), each shaped like ``embed_label_temporal``'s output.

    Raises:
        ValueError: if ``n_classes < 2`` (there is no wrong label to draw).
    """
    if n_classes < 2:
        raise ValueError(f"need at least 2 classes to draw a wrong label, got {n_classes}")
    N = X.shape[0]

    # Positive: correct labels
    X_pos = embed_label_temporal(X, y, n_classes, label_scale)

    # Negative: random wrong labels (offset drawn on y's device so the sum is valid)
    offsets = torch.randint(1, n_classes, (N,), device=y.device)
    y_wrong = (y + offsets) % n_classes
    X_neg = embed_label_temporal(X, y_wrong, n_classes, label_scale)

    return X_pos, X_neg


# Quick sanity check of the temporal label embedding on five training images.
X_pos, X_neg = create_positive_negative_pairs_temporal(X_train[:5], y_train[:5])
print("\n".join([
    f"Input shape: {X_train[:5].shape} (N × rows × cols)",
    f"Embedded shape: {X_pos.shape} (N × timesteps × features)",
    f"Features per timestep: {N_COLS} pixels + {N_CLASSES} label = {INPUT_DIM_PER_ROW}",
    f"Timesteps: {N_ROWS} (one per row)",
]))

In [None]:
def compute_goodness(activations):
    """
    Return per-sample goodness: the mean of squared activations along dim 1.

    For a [B, D] batch of layer activations this gives a [B] vector; in
    hardware terms it corresponds to the mean power in the layer.
    """
    squared = activations.pow(2)
    return squared.mean(dim=1)


def forward_forward_loss(goodness_pos, goodness_neg, margin=0.01):
    """
    Contrastive Forward-Forward loss: mean of softplus(margin - (G+ - G-)).

    The loss falls toward zero as the positive goodness exceeds the negative
    goodness by more than ``margin``. A small margin (0.01) suits separations
    on the order of 0.1.
    """
    gap = goodness_pos - goodness_neg
    shortfall = margin - gap
    return F.softplus(shortfall).mean()

## 3. Build SOEN Model for MNIST

In [None]:
def build_ff_temporal_model(hidden_dims, input_dim=INPUT_DIM_PER_ROW, dt=1.0, gamma_minus=1e-6):
    """
    Assemble a feed-forward SOEN stack for temporal (row-scanning) Forward-Forward.

    The stack is an Input layer of width ``input_dim`` followed by one
    SingleDendrite layer per entry of ``hidden_dims``. Each hidden layer is fed
    all-to-all by the layer before it (learnable connections, Xavier-uniform
    init). The dendrite parameters (phi_offset, bias_current, gamma_plus,
    gamma_minus) are marked non-learnable.

    Compared with feeding the flat image, the per-timestep input is small
    (38 vs 794) and information is accumulated over the 28 row timesteps.

    Args:
        hidden_dims: widths of the hidden layers, e.g. [24] or [12, 12].
        input_dim: features per timestep (28 pixels + 10 label = 38 by default).
        dt: simulation time step.
        gamma_minus: state decay rate. 1e-6 makes each neuron a near-pure
            accumulator over the 28 timesteps; larger values forget early rows.
            For example, with a per-step retention of roughly (1 - gamma_minus*dt),
            0.05 keeps only 0.95**28 ≈ 24% (≈76% lost) after 28 steps.

    Returns:
        SOENModelCore built from this configuration.
    """
    sim_cfg = SimulationConfig(dt=dt, input_type="state", track_phi=False, track_power=False)

    # Layer 0 receives one image row plus the label block per timestep.
    layers = [LayerConfig(layer_id=0, layer_type="Input", params={"dim": input_dim})]
    connections = []

    for layer_id, width in enumerate(hidden_dims, start=1):
        dendrite_params = {
            "dim": width,
            "solver": "FE",
            "source_func": "Heaviside_fit_state_dep",
            "phi_offset": 0.02,
            "bias_current": 1.98,
            "gamma_plus": 1.0,
            "gamma_minus": gamma_minus,  # ~0 => accumulate without forgetting
            # Neuron physics stays fixed; only the connections are learnable.
            "learnable_params": dict.fromkeys(
                ("phi_offset", "bias_current", "gamma_plus", "gamma_minus"), False
            ),
        }
        layers.append(LayerConfig(layer_id=layer_id, layer_type="SingleDendrite", params=dendrite_params))

        # Dense feed from the previous layer (the Input layer for the first hidden layer).
        connections.append(ConnectionConfig(
            from_layer=layer_id - 1,
            to_layer=layer_id,
            connection_type="all_to_all",
            learnable=True,
            params={"init": "xavier_uniform"},
        ))

    return SOENModelCore(sim_config=sim_cfg, layers_config=layers, connections_config=connections)


# Instantiate the default configuration once to inspect its size.
HIDDEN_DIMS = [24]
test_model = build_ff_temporal_model(HIDDEN_DIMS, gamma_minus=1e-6)
n_params = sum(param.numel() for param in test_model.parameters() if param.requires_grad)

print("\n".join([
    f"Model architecture: {INPUT_DIM_PER_ROW} → {HIDDEN_DIMS} → goodness",
    f"Input: {INPUT_DIM_PER_ROW} = {N_COLS} pixels + {N_CLASSES} label (per timestep)",
    f"Timesteps: {N_ROWS} (rows)",
    f"Parameters: {n_params}",
    "gamma_minus=1e-6 (pure accumulator - NO information loss over 28 steps)",
    "\nWhy pure accumulator?",
    f"  With gamma_minus=0.05: decay^28 = 0.95^28 = {0.95**28:.2%} retention",
    "  With gamma_minus=1e-6: decay^28 ≈ 100% retention",
]))

## 4. Training Functions

In [None]:
def evaluate_ff_temporal_fast(model, X, y, batch_size=100, goodness_mode='final'):
    """
    Classification accuracy of a temporal Forward-Forward model.

    Every image is scored under all N_CLASSES label hypotheses in one batched
    forward pass: each image is repeated N_CLASSES times with a different label
    embedded. The prediction is the hypothesis with the highest goodness summed
    over all hidden layers.

    Args:
        model: SOEN model returning ``(output, layer_states)``, where
            ``layer_states[i]`` is indexed ``[sample, timestep, neuron]``.
            Layers whose ``layer_type`` is not ``'Input'`` are scored.
        X: [N, rows, cols] images.
        y: [N] integer class labels.
        batch_size: images per forward pass (each pass holds
            ``batch_size * N_CLASSES`` sequences).
        goodness_mode: 'final' = only last timestep, 'all' = sum all timesteps.

    Returns:
        Accuracy in [0, 1] as a Python float. The model's train/eval mode is
        restored on exit.

    Raises:
        ValueError: for an unknown ``goodness_mode``, an empty ``X``, or a
            model without hidden layers.
    """
    if goodness_mode not in ('final', 'all'):
        raise ValueError(f"goodness_mode must be 'final' or 'all', got {goodness_mode!r}")
    if X.shape[0] == 0:
        raise ValueError("cannot evaluate on an empty dataset")
    # Same hidden-layer selection as the trainer.
    hidden_layer_indices = [i for i, layer in enumerate(model.layers) if layer.layer_type != 'Input']
    if not hidden_layer_indices:
        raise ValueError("model has no hidden layers to measure goodness on")

    was_training = model.training
    model.eval()
    batch_predictions = []
    try:
        with torch.no_grad():
            for start in range(0, X.shape[0], batch_size):
                X_batch = X[start:start + batch_size]
                B = X_batch.shape[0]

                # Sample-major layout: rows b*N_CLASSES .. b*N_CLASSES + N_CLASSES-1
                # hold image b paired with hypotheses 0 .. N_CLASSES-1.
                X_repeated = X_batch.repeat_interleave(N_CLASSES, dim=0)
                y_hypotheses = torch.arange(N_CLASSES, device=X_batch.device).repeat(B)
                _, layer_states = model(embed_label_temporal(X_repeated, y_hypotheses))

                total_goodness = 0
                for layer_idx in hidden_layer_indices:
                    states = layer_states[layer_idx]
                    if goodness_mode == 'final':
                        total_goodness = total_goodness + compute_goodness(states[:, -1, :])
                    else:
                        # Goodness at every timestep (compute_goodness on the
                        # flattened [B*T, D] view), then summed over time.
                        per_step = compute_goodness(states.reshape(-1, states.shape[-1]))
                        total_goodness = total_goodness + per_step.reshape(states.shape[0], -1).sum(dim=1)

                batch_predictions.append(total_goodness.reshape(B, N_CLASSES).argmax(dim=1))
    finally:
        model.train(was_training)

    predictions = torch.cat(batch_predictions).to(y.device)
    return (predictions == y).float().mean().item()


def train_forward_forward_temporal(model, X_train, y_train, X_test, y_test,
                                    n_epochs=100, lr=0.01, margin=0.01,
                                    batch_size=64, eval_subset=1000, verbose=True,
                                    weight_decay=1e-4, lr_decay=0.98,
                                    goodness_mode='all'):
    """
    Train a SOEN model with Forward-Forward using temporal row scanning.

    Each minibatch is embedded twice: with the true label (positive) and with a
    random wrong label (negative). Both versions are run through all timesteps.
    For every hidden layer the contrastive ``forward_forward_loss(G+, G-, margin)``
    is computed. The per-layer losses are summed and optimised with one backward
    pass and a single Adam optimiser over all parameters. The gradient norm is
    clipped to 1.0, and the learning rate decays by ``lr_decay`` once per epoch.

    Args:
        model: SOEN model returning ``(output, layer_states)``, where
            ``layer_states[i]`` is indexed ``[sample, timestep, neuron]``.
            Layers whose ``layer_type`` is not ``'Input'`` are trained.
        X_train, y_train: training images [N, rows, cols] and labels [N].
        X_test, y_test: held-out data evaluated after every epoch.
        n_epochs, lr, margin, batch_size, weight_decay, lr_decay:
            optimisation hyper-parameters.
        eval_subset: number of random training samples used to estimate the
            per-epoch training accuracy.
        verbose: print progress every 50 batches and a summary per epoch.
        goodness_mode: 'final' = goodness of the last timestep only,
                       'all' = goodness summed over all timesteps.

    Returns:
        dict of per-epoch lists: 'loss', 'train_acc', 'test_acc',
        'goodness_pos', 'goodness_neg', 'lr'.

    Raises:
        ValueError: for an unknown ``goodness_mode``, an empty ``X_train``, or
            a model without hidden layers.
    """
    if goodness_mode not in ('final', 'all'):
        raise ValueError(f"goodness_mode must be 'final' or 'all', got {goodness_mode!r}")
    N = X_train.shape[0]
    if N == 0:
        raise ValueError("X_train is empty")
    # Track which layers are hidden
    hidden_layer_indices = [i for i, l in enumerate(model.layers) if l.layer_type != 'Input']
    if not hidden_layer_indices:
        raise ValueError("model has no hidden layers to train")

    def layer_goodness(states):
        """Per-sample goodness [B] of one hidden layer's states [B, T, D]."""
        if goodness_mode == 'final':
            return compute_goodness(states[:, -1, :])
        # 'all': goodness at every timestep (compute_goodness on the flattened
        # [B*T, D] view), summed over time.
        per_step = compute_goodness(states.reshape(-1, states.shape[-1]))
        return per_step.reshape(states.shape[0], -1).sum(dim=1)

    model.train()

    # Single optimizer over all parameters; per-layer losses are summed and
    # backpropagated once per batch.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)

    history = {key: [] for key in ('loss', 'train_acc', 'test_acc', 'goodness_pos', 'goodness_neg', 'lr')}

    n_batches = (N + batch_size - 1) // batch_size

    # Fixed random subset of the training set for fast per-epoch accuracy
    eval_idx = torch.randperm(N)[:min(eval_subset, N)]
    X_train_eval = X_train[eval_idx]
    y_train_eval = y_train[eval_idx]

    best_test_acc = 0

    for epoch in range(n_epochs):
        epoch_loss = 0
        epoch_g_pos = []
        epoch_g_neg = []

        # Shuffle data
        perm = torch.randperm(N)
        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]

        for batch_idx in range(n_batches):
            start = batch_idx * batch_size
            end = min(start + batch_size, N)

            X_batch = X_shuffled[start:end]  # [B, rows, cols]
            y_batch = y_shuffled[start:end]

            # Create pos/neg pairs with temporal embedding
            X_pos, X_neg = create_positive_negative_pairs_temporal(X_batch, y_batch)

            optimizer.zero_grad()

            # Forward pass through all timesteps
            _, layer_states_pos = model(X_pos)
            _, layer_states_neg = model(X_neg)

            # Accumulate loss from all hidden layers
            total_loss = 0
            batch_g_pos_list = []
            batch_g_neg_list = []

            for layer_idx in hidden_layer_indices:
                g_pos = layer_goodness(layer_states_pos[layer_idx])
                g_neg = layer_goodness(layer_states_neg[layer_idx])

                batch_g_pos_list.append(g_pos.mean().item())
                batch_g_neg_list.append(g_neg.mean().item())

                total_loss = total_loss + forward_forward_loss(g_pos, g_neg, margin)

            # Single backward pass for accumulated loss
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = total_loss.item()
            batch_g_pos = np.mean(batch_g_pos_list)
            batch_g_neg = np.mean(batch_g_neg_list)

            epoch_loss += batch_loss
            epoch_g_pos.append(batch_g_pos)
            epoch_g_neg.append(batch_g_neg)

            if verbose and batch_idx % 50 == 0:
                print(f"\rEpoch {epoch+1}/{n_epochs} | Batch {batch_idx+1}/{n_batches} | "
                      f"Loss: {batch_loss:.4f} | G+: {batch_g_pos:.4f} | G-: {batch_g_neg:.4f}", end="")

        # Step LR scheduler once per epoch
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Evaluate
        train_acc = evaluate_ff_temporal_fast(model, X_train_eval, y_train_eval, goodness_mode=goodness_mode)
        test_acc = evaluate_ff_temporal_fast(model, X_test, y_test, goodness_mode=goodness_mode)

        best_test_acc = max(best_test_acc, test_acc)

        history['loss'].append(epoch_loss / n_batches)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        history['goodness_pos'].append(np.mean(epoch_g_pos))
        history['goodness_neg'].append(np.mean(epoch_g_neg))
        history['lr'].append(current_lr)

        if verbose:
            sep = np.mean(epoch_g_pos) - np.mean(epoch_g_neg)
            print(f"\rEpoch {epoch+1}/{n_epochs} | Loss: {epoch_loss/n_batches:.4f} | "
                  f"Train: {train_acc:.4f} | Test: {test_acc:.4f} | "
                  f"Best: {best_test_acc:.4f} | Sep: {sep:.4f}    ")

    return history

## 5. Train the Model

In [None]:
# Hyper-parameters for the main run (24 hidden neurons, <26 allowed).
HIDDEN_DIMS = [24]
MARGIN, N_EPOCHS = 0.01, 100
LR, LR_DECAY = 0.01, 0.98
BATCH_SIZE, WEIGHT_DECAY = 64, 1e-4
GAMMA_MINUS = 1e-6  # Pure accumulator - critical for temporal processing!
GOODNESS_MODE = 'all'  # Sum all timesteps for better gradient signal

print("\n".join([
    "Training Forward-Forward MNIST with TEMPORAL ROW SCANNING...",
    f"Architecture: {INPUT_DIM_PER_ROW} → {HIDDEN_DIMS} → goodness",
    f"Input per timestep: {N_COLS} pixels + {N_CLASSES} label = {INPUT_DIM_PER_ROW}",
    f"Timesteps: {N_ROWS} (one per row)",
    f"Total neurons: {sum(HIDDEN_DIMS)} (constraint: <26)",
    f"Margin: {MARGIN}, Initial LR: {LR}",
    f"Training samples: {N_TRAIN}, Test samples: {N_TEST}",
    "\nKEY SETTINGS:",
    f"  gamma_minus = {GAMMA_MINUS} (pure accumulator, no information loss)",
    f"  goodness_mode = '{GOODNESS_MODE}' (all timesteps → better gradient)",
    "=" * 80,
]))

# Seed immediately before building so weight initialisation is reproducible.
torch.manual_seed(42)
model = build_ff_temporal_model(HIDDEN_DIMS, gamma_minus=GAMMA_MINUS)
n_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Parameters: {n_params} (much smaller than flat: {INPUT_DIM_PER_ROW}*{HIDDEN_DIMS[0]} vs 794*24)")

history = train_forward_forward_temporal(
    model, X_train, y_train, X_test, y_test,
    n_epochs=N_EPOCHS, lr=LR, margin=MARGIN, batch_size=BATCH_SIZE, verbose=True,
    weight_decay=WEIGHT_DECAY, lr_decay=LR_DECAY, goodness_mode=GOODNESS_MODE,
)

print("\n".join([
    "=" * 80,
    f"Final train accuracy: {history['train_acc'][-1]:.4f}",
    f"Final test accuracy: {history['test_acc'][-1]:.4f}",
    f"Best test accuracy: {max(history['test_acc']):.4f}",
    "Random baseline: 10%",
]))

## 6. Training Curves

In [None]:
# Six-panel summary of the run recorded in `history`.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
(ax1, ax2, ax3), (ax4, ax5, ax6) = axes

# Loss
ax1.plot(history['loss'], color='steelblue', lw=2)
ax1.set(xlabel='Epoch', ylabel='Contrastive Loss', title='Training Loss')

# Goodness of positive vs negative sequences
ax2.plot(history['goodness_pos'], label='Positive (G+)', color='green', lw=2)
ax2.plot(history['goodness_neg'], label='Negative (G-)', color='red', lw=2)
ax2.set(xlabel='Epoch', ylabel='Mean Goodness', title='Goodness Values')
ax2.legend()

# Learning-rate schedule
ax3.plot(history['lr'], color='orange', lw=2)
ax3.set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate Decay')

# Accuracy, with chance level and the best test epoch marked
best_epoch = np.argmax(history['test_acc'])
ax4.plot(history['train_acc'], label='Train', color='coral', lw=2)
ax4.plot(history['test_acc'], label='Test', color='steelblue', lw=2)
ax4.axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Random (10%)')
ax4.axvline(x=best_epoch, color='green', linestyle=':', alpha=0.7,
            label=f'Best ({max(history["test_acc"]):.2%})')
ax4.set(xlabel='Epoch', ylabel='Accuracy', title='Classification Accuracy', ylim=(0, 1.0))
ax4.legend()

# Separation G+ - G-
separation = np.subtract(history['goodness_pos'], history['goodness_neg'])
ax5.plot(separation, color='purple', lw=2)
ax5.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax5.set(xlabel='Epoch', ylabel='G+ - G-', title='Goodness Separation')

# Train vs test gap
gap = np.subtract(history['train_acc'], history['test_acc'])
ax6.plot(gap, color='brown', lw=2)
ax6.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax6.set(xlabel='Epoch', ylabel='Train - Test', title='Generalization Gap')

for ax in axes.flat:
    ax.grid(True, alpha=0.3)

plt.suptitle(f'Forward-Forward MNIST ({sum(HIDDEN_DIMS)} neurons, 10 classes)', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Confusion Matrix

In [None]:
def get_predictions_temporal(model, X, batch_size=100, goodness_mode='all'):
    """
    Predicted labels and per-hypothesis goodness for every image in ``X``.

    Each image is scored under all N_CLASSES label hypotheses (goodness summed
    over hidden layers, and over timesteps in 'all' mode).

    Args:
        model: SOEN model returning ``(output, layer_states)``, where
            ``layer_states[i]`` is indexed ``[sample, timestep, neuron]``.
            Layers whose ``layer_type`` is not ``'Input'`` are scored.
        X: [N, rows, cols] images.
        batch_size: images per forward pass.
        goodness_mode: 'final' = only last timestep, 'all' = sum all timesteps.

    Returns:
        (predictions [N], goodness [N, N_CLASSES]), where ``goodness[i, c]`` is
        the goodness of image ``i`` under hypothesis ``c`` and
        ``predictions[i] == goodness[i].argmax()``. The model's train/eval mode
        is restored on exit.

    Raises:
        ValueError: for an unknown ``goodness_mode``, an empty ``X``, or a
            model without hidden layers.
    """
    if goodness_mode not in ('final', 'all'):
        raise ValueError(f"goodness_mode must be 'final' or 'all', got {goodness_mode!r}")
    if X.shape[0] == 0:
        raise ValueError("cannot predict on an empty dataset")
    hidden_layer_indices = [i for i, layer in enumerate(model.layers) if layer.layer_type != 'Input']
    if not hidden_layer_indices:
        raise ValueError("model has no hidden layers to measure goodness on")

    was_training = model.training
    model.eval()
    all_predictions = []
    all_goodness = []
    try:
        with torch.no_grad():
            for start in range(0, X.shape[0], batch_size):
                X_batch = X[start:start + batch_size]
                B = X_batch.shape[0]

                # Image b occupies rows b*N_CLASSES .. b*N_CLASSES + N_CLASSES-1,
                # paired with hypotheses 0 .. N_CLASSES-1.
                X_repeated = X_batch.repeat_interleave(N_CLASSES, dim=0)
                y_hypotheses = torch.arange(N_CLASSES, device=X_batch.device).repeat(B)
                _, layer_states = model(embed_label_temporal(X_repeated, y_hypotheses))

                total_goodness = 0
                for layer_idx in hidden_layer_indices:
                    states = layer_states[layer_idx]
                    if goodness_mode == 'final':
                        total_goodness = total_goodness + compute_goodness(states[:, -1, :])
                    else:
                        # Goodness at every timestep, then summed over time.
                        per_step = compute_goodness(states.reshape(-1, states.shape[-1]))
                        total_goodness = total_goodness + per_step.reshape(states.shape[0], -1).sum(dim=1)

                goodness_matrix = total_goodness.reshape(B, N_CLASSES)
                all_goodness.append(goodness_matrix)
                all_predictions.append(goodness_matrix.argmax(dim=1))
    finally:
        model.train(was_training)

    return torch.cat(all_predictions), torch.cat(all_goodness)


def compute_confusion_matrix(y_true, y_pred, n_classes=N_CLASSES):
    """Tally (true, predicted) label pairs into a count matrix.

    Returns an ``int32`` array of shape ``(n_classes, n_classes)`` whose entry
    ``[t, p]`` counts samples with true label ``t`` that were predicted as
    ``p``. Rows are true classes and columns are predicted classes.
    """
    counts = np.zeros((n_classes, n_classes), dtype=np.int32)
    for pair in zip(y_true, y_pred):
        counts[pair] += 1
    return counts


# Score the test set under all ten hypotheses and tabulate the results.
test_preds, test_goodness = get_predictions_temporal(model, X_test, goodness_mode=GOODNESS_MODE)
cm = compute_confusion_matrix(y_test.numpy(), test_preds.numpy())

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(cm, cmap='Blues')
ax.set_xticks(range(N_CLASSES))
ax.set_yticks(range(N_CLASSES))
ax.set(xlabel='Predicted', ylabel='True',
       title=f'Confusion Matrix - Temporal Model (Test Acc: {history["test_acc"][-1]:.2%})')

# Write each count into its cell, using white text on the darker cells.
for (row, col), count in np.ndenumerate(cm):
    ax.text(col, row, count, ha='center', va='center', fontsize=10,
            color='white' if count > cm.max() / 2 else 'black')

plt.colorbar(im)
plt.tight_layout()
plt.show()

# Recall per digit (fraction of each true class predicted correctly)
print("\nPer-class accuracy:")
for digit in range(N_CLASSES):
    in_class = y_test == digit
    if not in_class.any():
        continue
    hit_rate = (test_preds[in_class] == digit).float().mean().item()
    print(f"  Digit {digit}: {hit_rate:.2%}")

## 8. Visualize Predictions

In [None]:
# 4x5 grid of test digits, titled green when correct and red when wrong.
n_show = 20
fig, axes = plt.subplots(4, 5, figsize=(15, 12))

for idx, ax in zip(range(n_show), axes.flat):
    truth, guess = y_test[idx].item(), test_preds[idx].item()
    ax.imshow(X_test[idx].numpy(), cmap='gray')
    ax.set_title(f'True: {truth}, Pred: {guess}', color='green' if guess == truth else 'red')
    ax.axis('off')

plt.suptitle('Temporal Forward-Forward Predictions (28 timesteps per image)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Per-hypothesis goodness for the first ten test images.
# Bars: green = true digit, red = a wrong prediction, gray = other hypotheses.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i, ax in enumerate(axes.flat):
    truth = y_test[i].item()
    guess = test_preds[i].item()
    correct = guess == truth

    bar_colors = ['lightgray'] * N_CLASSES
    bar_colors[truth] = 'green'
    bar_colors[guess] = 'green' if correct else 'red'

    ax.bar(range(N_CLASSES), test_goodness[i].numpy(), color=bar_colors)
    ax.set_xticks(range(N_CLASSES))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Goodness')
    ax.set_title(f"True: {truth}, Pred: {guess} {'✓' if correct else '✗'}")

plt.suptitle('Goodness Distribution (accumulated over 28 timesteps)', fontsize=14)
plt.tight_layout()
plt.show()

# Hidden-state trajectory of one test image presented with its correct label.
print("\nVisualizing temporal dynamics for sample 0...")
sample_idx = 0
sample_img = X_test[sample_idx:sample_idx+1]  # slicing keeps a batch dim of 1
true_label = y_test[sample_idx].item()

model.eval()
X_embedded = embed_label_temporal(sample_img, torch.tensor([true_label]))
with torch.no_grad():
    _, layer_states = model(X_embedded)
    hidden_states = layer_states[1]  # first hidden layer, [1, timesteps, neurons]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

# Left: the input image; each row is fed as one timestep.
ax1.imshow(sample_img[0].numpy(), cmap='gray')
ax1.set(title=f'Sample (label={true_label})',
        xlabel='Column (pixels per timestep)',
        ylabel='Row (timestep)')

# Right: neuron x timestep heat-map of the hidden state.
activations = hidden_states[0].numpy()
im = ax2.imshow(activations.T, aspect='auto', cmap='viridis')
ax2.set(xlabel='Timestep (row)', ylabel='Neuron', title='Hidden Layer Activation Over Time')
plt.colorbar(im, ax=ax2, label='Activation')

plt.suptitle('Temporal Dynamics: Network "Scans" Image Row by Row', fontsize=12)
plt.tight_layout()
plt.show()

## 9. Compare with Different Hidden Sizes

In [None]:
# Compare different temporal architectures (all <26 neurons).
# Each candidate is trained for COMPARISON_EPOCHS with the main run's
# hyper-parameters. Loop results use cmp_* names so the Section 5 `model`,
# `history` and `n_params` are NOT overwritten; the conclusions cell reports
# on the main run.
hidden_configs = [
    [8],       # 8 neurons
    [12],      # 12 neurons
    [16],      # 16 neurons
    [20],      # 20 neurons
    [24],      # 24 neurons
    [12, 12],  # Two-layer: 24 total
]
COMPARISON_EPOCHS = 50

comparison_results = []

print("Comparing TEMPORAL architectures (all <26 neurons)...")
print(f"Input: {INPUT_DIM_PER_ROW} features × {N_ROWS} timesteps")
print(f"gamma_minus = {GAMMA_MINUS} (pure accumulator)")
print(f"goodness_mode = '{GOODNESS_MODE}'")
print("=" * 80)

for hidden_dims in hidden_configs:
    torch.manual_seed(42)  # identical seed for every candidate
    cmp_model = build_ff_temporal_model(hidden_dims, gamma_minus=GAMMA_MINUS)
    cmp_params = sum(p.numel() for p in cmp_model.parameters() if p.requires_grad)
    total_neurons = sum(hidden_dims)

    cmp_history = train_forward_forward_temporal(
        cmp_model, X_train, y_train, X_test, y_test,
        n_epochs=COMPARISON_EPOCHS, lr=LR, margin=MARGIN,
        batch_size=BATCH_SIZE, verbose=False,
        weight_decay=WEIGHT_DECAY, lr_decay=LR_DECAY,
        goodness_mode=GOODNESS_MODE,
    )

    best_test = max(cmp_history['test_acc'])
    comparison_results.append({
        'hidden_dims': str(hidden_dims),
        'total_neurons': total_neurons,
        'n_params': cmp_params,
        'train_acc': cmp_history['train_acc'][-1],
        'test_acc': cmp_history['test_acc'][-1],
        'best_test': best_test,
    })

    print(f"Hidden={str(hidden_dims):12s} | Neurons={total_neurons:2d} | Params={cmp_params:5d} | "
          f"Final: {cmp_history['test_acc'][-1]:.4f} | Best: {best_test:.4f}")

print("=" * 80)

# Find best architecture
best_result = max(comparison_results, key=lambda x: x['best_test'])
print(f"\nBest temporal architecture: {best_result['hidden_dims']} with {best_result['best_test']:.2%} test accuracy")
print(f"Note: Parameters much smaller due to {INPUT_DIM_PER_ROW} input vs 794 (flat)")

## 10. Conclusions

In [None]:
# Summary of the main (Section 5) run: `model`, `history` and HIDDEN_DIMS.
# NOTE: if the comparison cell rebinds `model` / `history`, the numbers below
# describe its last candidate instead of the main run.
n_hidden = sum(HIDDEN_DIMS)
flat_input_dim = N_ROWS * N_COLS + N_CLASSES      # whole image + label in one step (794)
temporal_weights = INPUT_DIM_PER_ROW * n_hidden   # input-layer weights, row scanning
flat_weights = flat_input_dim * n_hidden          # input-layer weights, flat input
# Illustrative decay model: each step retains (1 - gamma_minus * dt) of the
# state, with dt = 1.0, so gamma_minus = 0.05 keeps 0.95 ** N_ROWS.
retention_005 = 0.95 ** N_ROWS

print("=" * 70)
print("CONCLUSIONS: TEMPORAL FORWARD-FORWARD MNIST")
print("=" * 70)

print(f"\n1. ARCHITECTURE (Temporal Row Scanning):")
print(f"   Input per timestep: {N_COLS} pixels + {N_CLASSES} label = {INPUT_DIM_PER_ROW}")
print(f"   Timesteps: {N_ROWS} (one per image row)")
print(f"   Hidden: {n_hidden} SingleDendrite neurons")
print(f"   Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

print(f"\n2. CRITICAL PARAMETERS:")
print(f"   gamma_minus = {GAMMA_MINUS} (pure accumulator)")
print(f"   goodness_mode = '{GOODNESS_MODE}' (all timesteps)")
print(f"")
print(f"   WHY THESE MATTER:")
print(f"   - gamma_minus=0.05 causes {1 - retention_005:.0%} info loss (0.95^{N_ROWS} = {retention_005:.1%})")
print(f"   - gamma_minus=1e-6 gives ~100% retention")
print(f"   - goodness_mode='all' provides gradient to ALL rows")
print(f"   - goodness_mode='final' only trains last few rows")

print(f"\n3. PARAMETER EFFICIENCY:")
print(f"   Flat input:     {flat_input_dim} dims × 1 timestep  = {flat_input_dim} input features")
print(f"   Temporal input: {INPUT_DIM_PER_ROW} dims × {N_ROWS} timesteps = {INPUT_DIM_PER_ROW * N_ROWS} total features")
print(f"   But parameters: {INPUT_DIM_PER_ROW}×{n_hidden} = {temporal_weights} vs {flat_input_dim}×{n_hidden} = {flat_weights}")
print(f"   → {flat_weights // temporal_weights}× fewer parameters!")

print(f"\n4. PERFORMANCE:")
print(f"   Train accuracy: {history['train_acc'][-1]:.2%}")
print(f"   Test accuracy:  {history['test_acc'][-1]:.2%}")
print(f"   Best test:      {max(history['test_acc']):.2%}")
print(f"   Random baseline: 10%")

print(f"\n5. HARDWARE ADVANTAGES:")
print(f"   ✓ Smaller input fanout ({INPUT_DIM_PER_ROW} vs {flat_input_dim})")
print(f"   ✓ Temporal dynamics meaningful (scanning)")
print(f"   ✓ Sequential processing = natural for hardware")
print(f"   ✓ Memory through state accumulation")
print(f"   ✓ Only {n_hidden} physical neurons needed!")

print(f"\n6. INFERENCE PIPELINE:")
print(f"   For each digit hypothesis ({N_CLASSES} total):")
print(f"     - Present image row-by-row ({N_ROWS} timesteps)")
print(f"     - Each timestep: {N_COLS} pixels + {N_CLASSES}-dim label = {INPUT_DIM_PER_ROW} inputs")
print(f"     - Measure goodness (power) at each timestep, sum all")
print(f"   Total: {N_CLASSES} × {N_ROWS} = {N_CLASSES * N_ROWS} timesteps per classification")

print("\n" + "=" * 70)