# Forward-Forward MNIST Classification with Cross-Neuron Recurrence (W_hh)

**UPGRADE**: Adds hidden-to-hidden connections (W_hh) to enable cross-neuron temporal communication.

## Key Difference from Original Temporal Version

| Version | Dynamics | Description |
|---------|----------|-------------|
| Original | `s[t] = α×s[t-1] + (1-α)×g(W_ih×x[t])` | Self-recurrence only (decay) |
| **This** | `s[t] = α×s[t-1] + (1-α)×g(W_ih×x[t] + W_hh×s[t-1])` | Cross-neuron recurrence |

## Why W_hh Matters for Sequence Classification

**Without W_hh (original)**:
- Each neuron only remembers its own past (leaky integration)
- Neurons cannot share information through time
- A pattern at t=0 can only influence neuron j at t=27 if neuron j itself responded to that pattern

**With W_hh (this version)**:
- Neurons can communicate temporal patterns to each other
- Early patterns can be "passed" to specialized neurons for later processing
- Enables true sequence memory, not just signal averaging

## Hardware Compatibility (SOEN)

From Shainline's 2021 paper "Optoelectronic Intelligence":
- Cross-neuron connections ARE possible via optical waveguide routing
- Photons from neuron i can be routed to synaptic input of neuron j
- Constraint: Topology must be fixed at fabrication time
- This notebook assumes W_hh topology is pre-determined (all-to-all or structured)

## Expected Improvement

Previous best (without W_hh): ~30.5% test accuracy  
Target with W_hh: Significant improvement through temporal pattern sharing

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import gzip
import urllib.request

# Seed the PyTorch and NumPy global RNGs before anything random happens.
torch.manual_seed(42)
np.random.seed(42)

# Banner: PyTorch version followed by the notebook title between two rules.
print(f"PyTorch version: {torch.__version__}")
print("\n" + "="*70)
print("FORWARD-FORWARD WITH CROSS-NEURON RECURRENCE (W_hh)")
print("="*70)

## 1. Load MNIST Dataset

In [None]:
# Direct MNIST download without torchvision
def download_mnist(data_dir='./data/mnist'):
    """Download the four gzipped MNIST IDX files if they are not cached.

    Each missing file is first fetched to ``<name>.part`` and only moved to
    its final name once the transfer has finished, so an interrupted download
    never leaves a truncated file that a later run would mistake for a
    complete one.

    Args:
        data_dir: Directory used as the local cache (created if missing).

    Returns:
        Dict mapping 'train_images', 'train_labels', 'test_images' and
        'test_labels' to the local file paths.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises when a download fails.
    """
    os.makedirs(data_dir, exist_ok=True)

    base_url = 'https://ossci-datasets.s3.amazonaws.com/mnist/'
    files = {
        'train_images': 'train-images-idx3-ubyte.gz',
        'train_labels': 'train-labels-idx1-ubyte.gz',
        'test_images': 't10k-images-idx3-ubyte.gz',
        'test_labels': 't10k-labels-idx1-ubyte.gz',
    }

    paths = {}
    for key, filename in files.items():
        filepath = os.path.join(data_dir, filename)
        if not os.path.exists(filepath):
            print(f"Downloading {filename}...")
            partial_path = filepath + '.part'
            try:
                urllib.request.urlretrieve(base_url + filename, partial_path)
                # Publish the file under its real name only when complete.
                os.replace(partial_path, filepath)
            finally:
                # Only present here if the download or the move failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
        paths[key] = filepath

    return paths


def load_mnist_images(filepath):
    """Load MNIST images from a gzipped IDX file, keeping the 2-D layout.

    The header is four big-endian 32-bit integers: magic number, image count,
    rows and columns. The remaining bytes are uint8 pixels.

    Args:
        filepath: Path to a gzipped IDX3 image file.

    Returns:
        float32 array of shape [n_images, n_rows, n_cols] with values in [0, 1].

    Raises:
        ValueError: If the magic number is not 2051 (IDX3 unsigned-byte file),
            or if the payload size does not match the header (raised by reshape).
    """
    with gzip.open(filepath, 'rb') as f:
        magic = int.from_bytes(f.read(4), 'big')
        if magic != 2051:
            # Fail early instead of reinterpreting e.g. a label file as pixels.
            raise ValueError(
                f"{filepath}: not an IDX image file (magic={magic}, expected 2051)"
            )
        n_images = int.from_bytes(f.read(4), 'big')
        n_rows = int.from_bytes(f.read(4), 'big')
        n_cols = int.from_bytes(f.read(4), 'big')
        data = np.frombuffer(f.read(), dtype=np.uint8)
        # astype() copies, so the result does not alias the read-only buffer.
        return data.reshape(n_images, n_rows, n_cols).astype(np.float32) / 255.0


def load_mnist_labels(filepath):
    """Load MNIST labels from a gzipped IDX file.

    The header is two big-endian 32-bit integers (magic number, label count)
    followed by one uint8 per label.

    Args:
        filepath: Path to a gzipped IDX1 label file.

    Returns:
        Writable uint8 array of shape [n_labels].

    Raises:
        ValueError: If the magic number is not 2049 (IDX1 unsigned-byte file)
            or the payload length differs from the count in the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic = int.from_bytes(f.read(4), 'big')
        if magic != 2049:
            raise ValueError(
                f"{filepath}: not an IDX label file (magic={magic}, expected 2049)"
            )
        n_labels = int.from_bytes(f.read(4), 'big')
        labels = np.frombuffer(f.read(), dtype=np.uint8)

    if labels.size != n_labels:
        # A truncated or padded file would otherwise misalign labels and images.
        raise ValueError(
            f"{filepath}: header declares {n_labels} labels, found {labels.size}"
        )
    # frombuffer yields a read-only view; copy so torch.from_numpy gets a
    # writable array (it warns on non-writable input).
    return labels.copy()


# Download and load MNIST
# Images come back from load_mnist_images as float32 in [0, 1]; labels are
# uint8 and are cast to int64 with .long().
paths = download_mnist()
X_train_full = torch.from_numpy(load_mnist_images(paths['train_images']))
y_train_full = torch.from_numpy(load_mnist_labels(paths['train_labels'])).long()
X_test_full = torch.from_numpy(load_mnist_images(paths['test_images']))
y_test_full = torch.from_numpy(load_mnist_labels(paths['test_labels'])).long()

print(f"Full dataset: Train={X_train_full.shape}, Test={X_test_full.shape}")

# Use training data subset
N_TRAIN = 20000
N_TEST = 2000

# Re-seed so the random subset does not depend on earlier RNG consumption.
torch.manual_seed(42)
train_idx = torch.randperm(len(X_train_full))[:N_TRAIN]
test_idx = torch.randperm(len(X_test_full))[:N_TEST]

X_train = X_train_full[train_idx]
y_train = y_train_full[train_idx]
X_test = X_test_full[test_idx]
y_test = y_test_full[test_idx]

# Scale to SOEN operating range [0.025, 0.275]
# (affine map of [0, 1]: multiply by 0.25, then add 0.025)
X_train = X_train * 0.25 + 0.025
X_test = X_test * 0.25 + 0.025

print(f"\nUsing subset:")
print(f"  Training set: {X_train.shape} (N × rows × cols)")
print(f"  Test set: {X_test.shape}")
print(f"  X range: [{X_train.min():.3f}, {X_train.max():.3f}]")

## 2. Constants and Label Embedding

In [None]:
# Shared constants for the temporal (row-by-row) encoding of MNIST.
N_CLASSES = 10
N_ROWS = 28      # Number of timesteps (rows in image)
N_COLS = 28      # Pixels per row
# Value written at the true-class position of the label vector.
LABEL_SCALE = 0.25

# Input dimension: 28 pixels + 10 label = 38 per timestep
INPUT_DIM_PER_ROW = N_COLS + N_CLASSES


def embed_label_temporal(X, y, n_classes=N_CLASSES, label_scale=LABEL_SCALE):
    """
    Append a scaled one-hot label vector to every timestep of a sequence.

    The number of timesteps is taken from ``X`` itself, so sequences of any
    length are supported (MNIST rows give 28).

    Args:
        X: [N, T, C] sequences (for MNIST: N samples, 28 rows, 28 cols)
        y: [N] integer class labels in [0, n_classes); must be int64, as
            required by ``scatter_``
        n_classes: Length of the label vector
        label_scale: Value placed at the labelled class (all others are 0)

    Returns:
        X_embedded: [N, T, C + n_classes] - each timestep carries its C
        features followed by the label vector
    """
    n_samples, n_steps = X.shape[0], X.shape[1]
    device = X.device

    # Scaled one-hot labels [N, n_classes]; index must live on the same device.
    one_hot = torch.zeros(n_samples, n_classes, device=device)
    one_hot.scatter_(1, y.to(device).unsqueeze(1), label_scale)

    # Broadcast (no copy) to [N, T, n_classes] - same label at each timestep.
    one_hot_expanded = one_hot.unsqueeze(1).expand(-1, n_steps, -1)

    # Concatenate along the feature axis: [N, T, C] + [N, T, n_classes].
    return torch.cat([X, one_hot_expanded], dim=2)


def create_positive_negative_pairs_temporal(X, y, n_classes=N_CLASSES, label_scale=LABEL_SCALE):
    """
    Build the positive / negative inputs used by Forward-Forward training.

    Positive sequences carry the true label; negative sequences carry a label
    drawn uniformly from the other ``n_classes - 1`` classes.

    Returns:
        (X_pos, X_neg): label-embedded sequences from embed_label_temporal.
    """
    # A shift in [1, n_classes - 1] taken modulo n_classes can never map a
    # label back onto itself, so every negative label is guaranteed wrong.
    shift = torch.randint(1, n_classes, (X.shape[0],), device=X.device)
    wrong_labels = (y + shift) % n_classes

    positives = embed_label_temporal(X, y, n_classes, label_scale)
    negatives = embed_label_temporal(X, wrong_labels, n_classes, label_scale)
    return positives, negatives


# Report the per-timestep input width and the sequence length.
print(f"Input dimension per timestep: {INPUT_DIM_PER_ROW} ({N_COLS} pixels + {N_CLASSES} label)")
print(f"Number of timesteps: {N_ROWS}")

## 3. SOEN Layer with Cross-Neuron Recurrence (W_hh)

This is the key upgrade: adding W_hh connections to enable neurons to influence each other through time.

### Dynamics

**Original (self-recurrence only)**:
```
pre_activation[t] = W_ih @ x[t]
s[t] = α × s[t-1] + (1-α) × g(pre_activation[t])
```

**Upgraded (cross-neuron recurrence)**:
```
pre_activation[t] = W_ih @ x[t] + W_hh @ s[t-1] + b  # <-- W_hh added! (b = input-layer bias)
s[t] = α × s[t-1] + (1-α) × g(pre_activation[t])
```

The W_hh term allows neuron i's current state to depend on neuron j's previous state.

In [None]:
class SOENRecurrentLayer(nn.Module):
    """
    SOEN-inspired recurrent layer with cross-neuron connections (W_hh).

    Dynamics:
        pre[t] = W_ih @ x[t] + (W_hh * hh_mask) @ s[t-1] + bias
        s[t] = alpha * s[t-1] + (1-alpha) * activation(pre[t])

    Where:
        - W_ih: input-to-hidden weights [hidden_dim, input_dim] (with bias)
        - W_hh: hidden-to-hidden weights [hidden_dim, hidden_dim]; entry
          [i, j] is the connection from neuron j to neuron i
        - hh_mask: fixed 0/1 buffer selecting which W_hh entries are active
        - alpha: retention factor (from gamma_minus)
        - activation: squashing function chosen by `activation`

    Hardware notes:
        - W_hh requires optical routing from neuron outputs back to synapses
        - This is achievable in SOEN via waveguide design (Shainline 2021)
        - Topology fixed at fabrication; weights are synaptic efficacies
    """

    def __init__(self, input_dim, hidden_dim, alpha=0.95, activation='tanh',
                 w_hh_scale=0.1, sparse_hh=False, sparsity=0.5):
        """
        Args:
            input_dim: Input features per timestep (38 for MNIST row + label)
            hidden_dim: Number of hidden neurons (<=24 for constraint)
            alpha: Retention factor (1 - dt*gamma_minus), default 0.95
            activation: 'tanh', 'relu' or 'soen' (Heaviside approximation);
                any other value falls back to tanh
            w_hh_scale: Multiplier applied to the Xavier-initialised W_hh
                (smaller = more stable)
            sparse_hh: If True, use sparse W_hh (more hardware-realistic)
            sparsity: Probability that each W_hh connection is kept
                (only used if sparse_hh=True)
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.alpha = alpha
        self.activation_type = activation
        self.sparse_hh = sparse_hh

        # Input-to-hidden weights (standard)
        self.W_ih = nn.Linear(input_dim, hidden_dim, bias=True)
        nn.init.xavier_uniform_(self.W_ih.weight)
        nn.init.zeros_(self.W_ih.bias)

        # Hidden-to-hidden weights, scaled down to keep the recurrence stable.
        self.W_hh = nn.Linear(hidden_dim, hidden_dim, bias=False)
        nn.init.xavier_uniform_(self.W_hh.weight)
        self.W_hh.weight.data *= w_hh_scale

        # Connectivity mask: a buffer, so it follows .to()/.double() and is
        # saved in the state dict but never trained.
        if sparse_hh:
            mask = (torch.rand(hidden_dim, hidden_dim) < sparsity).float()
            # Guarantee every neuron has at least one incoming connection.
            for i in range(hidden_dim):
                if mask[i].sum() == 0:
                    mask[i, torch.randint(0, hidden_dim, (1,))] = 1.0
        else:
            mask = torch.ones(hidden_dim, hidden_dim)
        self.register_buffer('hh_mask', mask)

    def soen_activation(self, x):
        """SOEN-like activation: approximates Heaviside with smooth transition."""
        # Steep sigmoid (gain 10) centred on a threshold of 0.1.
        return torch.sigmoid(10.0 * (x - 0.1))

    def _resolve_activation(self):
        """Return the callable for `activation_type` (tanh when unrecognised)."""
        if self.activation_type == 'relu':
            return F.relu
        if self.activation_type == 'soen':
            return self.soen_activation
        return torch.tanh

    def forward(self, x):
        """
        Forward pass through temporal sequence.

        Args:
            x: [batch, seq_len, input_dim] - temporal input sequence

        Returns:
            states: [batch, seq_len, hidden_dim] - hidden states at all timesteps
        """
        batch_size, _seq_len, _ = x.shape

        # Zero initial state with the dtype and device of the input, so the
        # layer also works after .double()/.half() conversion.
        h = x.new_zeros(batch_size, self.hidden_dim)

        # Masked recurrent weights and the activation are loop-invariant.
        W_hh_effective = self.W_hh.weight * self.hh_mask
        activation = self._resolve_activation()

        states = []
        for x_t in x.unbind(dim=1):
            # Input drive plus cross-neuron recurrent drive: [batch, hidden_dim]
            pre_activation = self.W_ih(x_t) + F.linear(h, W_hh_effective)

            # Leaky integration (SOEN temporal dynamics)
            h = self.alpha * h + (1 - self.alpha) * activation(pre_activation)
            states.append(h)

        # Stack: [batch, seq_len, hidden_dim]
        return torch.stack(states, dim=1)


class TemporalFFNetwork(nn.Module):
    """
    Stack of SOENRecurrentLayer modules for temporal Forward-Forward training.

    Each layer consumes the full state sequence produced by the layer before
    it; the first layer consumes the raw input sequence.
    """

    def __init__(self, input_dim, hidden_dims, alpha=0.95, activation='tanh',
                 w_hh_scale=0.1, sparse_hh=False, sparsity=0.5):
        """
        Args:
            input_dim: Input features per timestep
            hidden_dims: Hidden width of each layer, in order
            alpha: Retention factor shared by all layers
            activation: Activation function ('tanh', 'relu', 'soen')
            w_hh_scale: Scale for W_hh initialization
            sparse_hh: Use sparse W_hh connections
            sparsity: Fraction of W_hh connections to keep
        """
        super().__init__()

        # widths[k] feeds layer k, which outputs widths[k + 1].
        widths = [input_dim, *hidden_dims]
        stack = [
            SOENRecurrentLayer(
                input_dim=fan_in,
                hidden_dim=fan_out,
                alpha=alpha,
                activation=activation,
                w_hh_scale=w_hh_scale,
                sparse_hh=sparse_hh,
                sparsity=sparsity,
            )
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """
        Run the sequence through every layer and keep each layer's states.

        Args:
            x: [batch, seq_len, input_dim]

        Returns:
            all_states: List with one [batch, seq_len, hidden_dim] tensor per
            layer, in layer order
        """
        per_layer = []
        signal = x
        for block in self.layers:
            signal = block(signal)  # [batch, seq_len, hidden_dim]
            per_layer.append(signal)
        return per_layer


# Test the network
print("Testing TemporalFFNetwork with W_hh...")
test_net = TemporalFFNetwork(
    input_dim=INPUT_DIM_PER_ROW,
    hidden_dims=[24],
    alpha=0.95,
    w_hh_scale=0.1
)

# Read the counts from the model itself rather than hard-coding 38 * 24 + 24,
# so they stay correct if INPUT_DIM_PER_ROW changes.
n_params = sum(p.numel() for p in test_net.parameters() if p.requires_grad)
first_layer = test_net.layers[0]
n_params_wih = sum(p.numel() for p in first_layer.W_ih.parameters())  # W_ih + bias
n_params_whh = first_layer.W_hh.weight.numel()                        # W_hh (no bias)

print(f"\nArchitecture: {INPUT_DIM_PER_ROW} → [24] → goodness")
print(f"Total parameters: {n_params}")
print(f"  W_ih parameters: {n_params_wih} ({INPUT_DIM_PER_ROW} × 24 + 24 bias)")
print(f"  W_hh parameters: {n_params_whh} (24 × 24) ← NEW!")
print(f"\nW_hh enables cross-neuron temporal communication!")

## 4. Forward-Forward Loss Functions

In [None]:
def compute_goodness(activations):
    """
    Goodness of a layer: the mean squared activation of each sample.
    Hardware-compatible: measures mean power in the layer.

    Args:
        activations: [batch, hidden_dim]

    Returns:
        goodness: [batch] - squared activations averaged over dim 1
    """
    power = torch.square(activations)
    return torch.mean(power, dim=1)


def forward_forward_loss(goodness_pos, goodness_neg, margin=0.01):
    """
    Contrastive Forward-Forward loss.

    Applies softplus to (margin - separation), where separation is
    G_pos - G_neg, and averages over the batch. The loss shrinks as positive
    goodness exceeds negative goodness by more than the margin.
    """
    separation = goodness_pos - goodness_neg
    per_sample = F.softplus(margin - separation)
    return per_sample.mean()

## 5. Training with W_hh

In [None]:
def evaluate_ff_temporal(model, X, y, batch_size=100, goodness_mode='all'):
    """
    Evaluate temporal Forward-Forward model.

    For each sample, every one of the N_CLASSES label hypotheses is embedded
    and run through the model; the hypothesis with the highest total goodness
    (summed over layers) is the prediction.

    Args:
        model: Module whose forward returns a list of
            [batch, seq_len, hidden_dim] state tensors, one per layer
        X: [N, rows, cols] images without label embedding
        y: [N] true labels (any device)
        batch_size: Samples per evaluation batch (each expands to
            batch_size * N_CLASSES sequences)
        goodness_mode: 'final' scores only the last timestep; any other value
            sums goodness over all timesteps

    Returns:
        Accuracy as a Python float.

    The model is put in eval mode for the duration of the call and is returned
    to whatever mode it was in beforehand, even if evaluation raises.
    """
    was_training = model.training
    model.eval()
    n_samples = X.shape[0]
    all_predictions = []
    device = next(model.parameters()).device

    try:
        with torch.no_grad():
            for start in range(0, n_samples, batch_size):
                end = min(start + batch_size, n_samples)
                X_batch = X[start:end].to(device)  # [B, rows, cols]
                B = X_batch.shape[0]

                # Repeat each sample N_CLASSES times, keeping copies adjacent.
                X_repeated = X_batch.unsqueeze(1).expand(-1, N_CLASSES, -1, -1)
                X_repeated = X_repeated.reshape(B * N_CLASSES, *X_batch.shape[1:])

                # Matching hypotheses: 0..N_CLASSES-1 for every sample.
                y_hypotheses = torch.arange(N_CLASSES, device=device)
                y_hypotheses = y_hypotheses.unsqueeze(0).expand(B, -1).reshape(B * N_CLASSES)

                X_embedded = embed_label_temporal(X_repeated, y_hypotheses)

                # Forward pass
                layer_states = model(X_embedded)

                # Total goodness per (sample, hypothesis) across layers.
                total_goodness = torch.zeros(B * N_CLASSES, device=device)
                for states in layer_states:
                    if goodness_mode == 'final':
                        total_goodness += compute_goodness(states[:, -1, :])
                    else:  # 'all'
                        for t in range(states.shape[1]):
                            total_goodness += compute_goodness(states[:, t, :])

                # One row per sample, one column per hypothesis.
                goodness_matrix = total_goodness.reshape(B, N_CLASSES)
                all_predictions.append(goodness_matrix.argmax(dim=1).cpu())
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    all_predictions = torch.cat(all_predictions)
    # Predictions are on the CPU; bring the labels there before comparing.
    return (all_predictions == y.cpu()).float().mean().item()


def _gradient_compensation_weights(seq_len, alpha):
    """Return per-timestep loss weights that favour early timesteps.

    The raw weight of timestep t is (1/alpha) ** (seq_len - 1 - t); the
    weights are rescaled so that they sum to seq_len.
    """
    raw = [(1.0 / alpha) ** (seq_len - 1 - t) for t in range(seq_len)]
    normalizer = sum(raw)
    return [w / normalizer * seq_len for w in raw]


def train_forward_forward_whh(model, X_train, y_train, X_test, y_test,
                               n_epochs=100, lr=0.01, margin=0.01,
                               batch_size=64, eval_subset=1000, verbose=True,
                               weight_decay=1e-4, lr_decay=0.98,
                               local_in_time=True, goodness_mode='all',
                               gradient_compensation=True):
    """
    Train temporal Forward-Forward with W_hh.

    Each batch is embedded once with the true labels (positive) and once with
    random wrong labels (negative). The contrastive loss is summed over
    layers and selected timesteps, then optimised with Adam, gradient-norm
    clipping at 1.0 and a per-epoch exponential learning-rate decay.

    Args:
        model: TemporalFFNetwork with W_hh (must expose `.layers`, each with
            an `.alpha` attribute)
        X_train, y_train: Training images [N, rows, cols] and labels [N]
        X_test, y_test: Test images and labels, evaluated in full every epoch
        n_epochs: Number of passes over the training set
        lr: Initial Adam learning rate
        margin: Margin passed to forward_forward_loss
        batch_size: Mini-batch size
        eval_subset: Number of training samples used for the per-epoch
            training accuracy
        verbose: Print progress while training
        weight_decay: Adam weight decay
        lr_decay: Multiplicative learning-rate decay applied each epoch
        local_in_time: Compute loss at each timestep (hardware-compatible)
        goodness_mode: 'final' or 'all' timesteps (used only when
            local_in_time is False)
        gradient_compensation: Weight early timesteps more heavily (only
            with local_in_time); weights use each layer's own alpha

    Returns:
        history: dict of per-epoch lists with keys 'loss', 'train_acc',
        'test_acc', 'goodness_pos', 'goodness_neg' and 'lr'.
    """
    device = next(model.parameters()).device
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)

    history = {
        'loss': [],
        'train_acc': [],
        'test_acc': [],
        'goodness_pos': [],
        'goodness_neg': [],
        'lr': [],
    }

    N = X_train.shape[0]
    n_batches = (N + batch_size - 1) // batch_size

    # Fixed training subset used for the per-epoch training accuracy.
    eval_idx = torch.randperm(N)[:min(eval_subset, N)]
    X_train_eval = X_train[eval_idx]
    y_train_eval = y_train[eval_idx]

    best_test_acc = 0

    # Compensation weights depend only on (seq_len, alpha): compute them once
    # instead of rebuilding the normalizer for every timestep of every batch.
    weight_cache = {}

    for epoch in range(n_epochs):
        epoch_loss = 0
        epoch_g_pos = []
        epoch_g_neg = []

        perm = torch.randperm(N)
        X_shuffled = X_train[perm]
        y_shuffled = y_train[perm]

        for batch_idx in range(n_batches):
            start = batch_idx * batch_size
            end = min(start + batch_size, N)

            X_batch = X_shuffled[start:end].to(device)
            y_batch = y_shuffled[start:end].to(device)

            X_pos, X_neg = create_positive_negative_pairs_temporal(X_batch, y_batch)

            optimizer.zero_grad()

            # Forward pass
            states_pos = model(X_pos)  # List of [batch, seq_len, hidden_dim]
            states_neg = model(X_neg)

            total_loss = 0
            batch_g_pos_list = []
            batch_g_neg_list = []

            for layer, layer_states_pos, layer_states_neg in zip(
                    model.layers, states_pos, states_neg):
                seq_len = layer_states_pos.shape[1]

                # Which timesteps contribute to the loss.
                if local_in_time or goodness_mode != 'final':
                    timesteps = range(seq_len)
                else:
                    timesteps = (seq_len - 1,)  # final timestep only

                # Optional per-timestep weights (vanishing-gradient mitigation).
                weights = None
                if local_in_time and gradient_compensation:
                    cache_key = (seq_len, layer.alpha)
                    if cache_key not in weight_cache:
                        weight_cache[cache_key] = _gradient_compensation_weights(
                            seq_len, layer.alpha)
                    weights = weight_cache[cache_key]

                for t in timesteps:
                    g_pos = compute_goodness(layer_states_pos[:, t, :])
                    g_neg = compute_goodness(layer_states_neg[:, t, :])

                    batch_g_pos_list.append(g_pos.mean().item())
                    batch_g_neg_list.append(g_neg.mean().item())

                    timestep_loss = forward_forward_loss(g_pos, g_neg, margin)
                    if weights is not None:
                        timestep_loss = weights[t] * timestep_loss

                    total_loss = total_loss + timestep_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += total_loss.item()
            epoch_g_pos.append(np.mean(batch_g_pos_list) if batch_g_pos_list else 0)
            epoch_g_neg.append(np.mean(batch_g_neg_list) if batch_g_neg_list else 0)

            if verbose and batch_idx % 50 == 0:
                print(f"\rEpoch {epoch+1}/{n_epochs} | Batch {batch_idx+1}/{n_batches} | "
                      f"Loss: {total_loss.item():.4f}", end="")

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Evaluate with the same timestep selection that was trained on.
        eval_mode = 'all' if local_in_time else goodness_mode
        train_acc = evaluate_ff_temporal(model, X_train_eval, y_train_eval, goodness_mode=eval_mode)
        test_acc = evaluate_ff_temporal(model, X_test, y_test, goodness_mode=eval_mode)

        if test_acc > best_test_acc:
            best_test_acc = test_acc

        history['loss'].append(epoch_loss / n_batches)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        history['goodness_pos'].append(np.mean(epoch_g_pos))
        history['goodness_neg'].append(np.mean(epoch_g_neg))
        history['lr'].append(current_lr)

        if verbose:
            sep = np.mean(epoch_g_pos) - np.mean(epoch_g_neg)
            print(f"\rEpoch {epoch+1}/{n_epochs} | Loss: {epoch_loss/n_batches:.4f} | "
                  f"Train: {train_acc:.4f} | Test: {test_acc:.4f} | "
                  f"Best: {best_test_acc:.4f} | Sep: {sep:.4f}    ")

    return history

## 6. Train Model with W_hh

In [None]:
# Hyperparameters
# NOTE(review): the neuron budget is quoted as 26 here but as "<=24" in the
# SOENRecurrentLayer docstring - confirm which limit applies.
HIDDEN_DIMS = [24]  # Best under 26 constraint
ALPHA = 0.95        # Retention factor (1 - dt*gamma_minus)
W_HH_SCALE = 0.1    # W_hh initialization scale
ACTIVATION = 'tanh' # Activation function
MARGIN = 0.01
N_EPOCHS = 100
LR = 0.01
BATCH_SIZE = 64
WEIGHT_DECAY = 1e-4
LR_DECAY = 0.98

# Training mode
LOCAL_IN_TIME = True
GRADIENT_COMPENSATION = True

# Print the run configuration.
print("="*80)
print("TRAINING FORWARD-FORWARD WITH CROSS-NEURON RECURRENCE (W_hh)")
print("="*80)
print(f"\nArchitecture: {INPUT_DIM_PER_ROW} → {HIDDEN_DIMS} → goodness")
print(f"Timesteps: {N_ROWS} (one per row)")
print(f"Alpha (retention): {ALPHA}")
print(f"W_hh scale: {W_HH_SCALE}")
print(f"\nKEY UPGRADE:")
print(f"  Original: s[t] = α×s[t-1] + (1-α)×g(W_ih×x[t])")
print(f"  This:     s[t] = α×s[t-1] + (1-α)×g(W_ih×x[t] + W_hh×s[t-1])")
print(f"  → W_hh enables cross-neuron temporal communication!")
print(f"\nTraining: {N_TRAIN} samples, Testing: {N_TEST} samples")
print("="*80)

# Build model
# Re-seed so weight initialisation does not depend on earlier cells.
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")

model = TemporalFFNetwork(
    input_dim=INPUT_DIM_PER_ROW,
    hidden_dims=HIDDEN_DIMS,
    alpha=ALPHA,
    activation=ACTIVATION,
    w_hh_scale=W_HH_SCALE,
    sparse_hh=False,
).to(device)

# The W_ih / W_hh breakdown below describes the first layer only.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {n_params}")
print(f"  W_ih: {INPUT_DIM_PER_ROW * HIDDEN_DIMS[0] + HIDDEN_DIMS[0]}")
print(f"  W_hh: {HIDDEN_DIMS[0] * HIDDEN_DIMS[0]} ← NEW!")

# Train
# The data tensors stay on the CPU; the training loop moves each batch to
# the model's device.
history = train_forward_forward_whh(
    model, X_train, y_train, X_test, y_test,
    n_epochs=N_EPOCHS, lr=LR, margin=MARGIN,
    batch_size=BATCH_SIZE, verbose=True,
    weight_decay=WEIGHT_DECAY, lr_decay=LR_DECAY,
    local_in_time=LOCAL_IN_TIME,
    gradient_compensation=GRADIENT_COMPENSATION,
)

# Summary. "Best" is the maximum over per-epoch test accuracies.
print("="*80)
print(f"Final train accuracy: {history['train_acc'][-1]:.4f}")
print(f"Final test accuracy: {history['test_acc'][-1]:.4f}")
print(f"Best test accuracy: {max(history['test_acc']):.4f}")
print(f"\nComparison:")
print(f"  Previous best (without W_hh): ~30.5%")
print(f"  This version (with W_hh): {max(history['test_acc'])*100:.1f}%")
print(f"  Random baseline: 10%")

## 7. Training Curves

In [None]:
# 2x3 dashboard of the per-epoch series recorded in `history`.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Loss
ax1 = axes[0, 0]
ax1.plot(history['loss'], color='steelblue', lw=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Contrastive Loss')
ax1.set_title('Training Loss')
ax1.grid(True, alpha=0.3)

# Goodness
ax2 = axes[0, 1]
ax2.plot(history['goodness_pos'], label='Positive (G+)', color='green', lw=2)
ax2.plot(history['goodness_neg'], label='Negative (G-)', color='red', lw=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Mean Goodness')
ax2.set_title('Goodness Values')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Learning rate
ax3 = axes[0, 2]
ax3.plot(history['lr'], color='orange', lw=2)
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Learning Rate')
ax3.set_title('Learning Rate Decay')
ax3.grid(True, alpha=0.3)

# Accuracy
# Horizontal guides: 10% chance level and the 30.5% previous best; the
# vertical line marks the epoch with the highest test accuracy.
ax4 = axes[1, 0]
ax4.plot(history['train_acc'], label='Train', color='coral', lw=2)
ax4.plot(history['test_acc'], label='Test', color='steelblue', lw=2)
ax4.axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Random (10%)')
ax4.axhline(y=0.305, color='purple', linestyle=':', alpha=0.7, label='Previous best (30.5%)')
best_epoch = np.argmax(history['test_acc'])
ax4.axvline(x=best_epoch, color='green', linestyle=':', alpha=0.7)
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Accuracy')
ax4.set_title(f'Classification Accuracy (Best: {max(history["test_acc"]):.2%})')
ax4.legend()
ax4.grid(True, alpha=0.3)
ax4.set_ylim(0, 1.0)

# Separation
# Positive minus negative mean goodness, per epoch.
ax5 = axes[1, 1]
separation = [p - n for p, n in zip(history['goodness_pos'], history['goodness_neg'])]
ax5.plot(separation, color='purple', lw=2)
ax5.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax5.set_xlabel('Epoch')
ax5.set_ylabel('G+ - G-')
ax5.set_title('Goodness Separation')
ax5.grid(True, alpha=0.3)

# Train vs Test gap
# Train accuracy is measured on the evaluation subset, test on the full
# test split passed to the trainer.
ax6 = axes[1, 2]
gap = [t - v for t, v in zip(history['train_acc'], history['test_acc'])]
ax6.plot(gap, color='brown', lw=2)
ax6.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax6.set_xlabel('Epoch')
ax6.set_ylabel('Train - Test')
ax6.set_title('Generalization Gap')
ax6.grid(True, alpha=0.3)

plt.suptitle(f'Forward-Forward with W_hh ({sum(HIDDEN_DIMS)} neurons)', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Visualize W_hh Learned Patterns

In [None]:
# Visualize the learned W_hh matrix
# Use the effective weights (raw weight times connectivity mask), i.e. what
# the forward pass applies. For a dense layer the mask is all ones.
first_layer = model.layers[0]
W_hh = (first_layer.W_hh.weight.data * first_layer.hh_mask).cpu().numpy()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# W_hh heatmap
# nn.Linear stores weight as [out_features, in_features], so rows are the
# receiving ("to") neuron and columns the sending ("from") neuron.
ax1 = axes[0]
im1 = ax1.imshow(W_hh, cmap='RdBu_r', aspect='auto', vmin=-np.abs(W_hh).max(), vmax=np.abs(W_hh).max())
ax1.set_xlabel('From neuron')
ax1.set_ylabel('To neuron')
ax1.set_title('W_hh (Hidden-to-Hidden Weights)')
plt.colorbar(im1, ax=ax1)

# W_hh distribution
ax2 = axes[1]
ax2.hist(W_hh.flatten(), bins=50, color='steelblue', edgecolor='black', alpha=0.7)
ax2.axvline(x=0, color='red', linestyle='--')
ax2.set_xlabel('Weight value')
ax2.set_ylabel('Count')
ax2.set_title(f'W_hh Distribution (mean={W_hh.mean():.4f}, std={W_hh.std():.4f})')

# Top connections
ax3 = axes[2]
# The 20 largest-magnitude entries, strongest first.
flat_idx = np.argsort(np.abs(W_hh).flatten())[-20:][::-1]
top_connections = []
for i in flat_idx:
    to_neuron, from_neuron = divmod(int(i), W_hh.shape[1])
    top_connections.append((from_neuron, to_neuron, W_hh[to_neuron, from_neuron]))
# Label as source→target so the arrow agrees with the heat-map axes
# (previously printed row→column, i.e. target→source).
labels = [f'{src}→{dst}' for src, dst, _ in top_connections]
values = [v for _, _, v in top_connections]
colors = ['green' if v > 0 else 'red' for v in values]
ax3.barh(range(len(values)), values, color=colors, edgecolor='black')
ax3.set_yticks(range(len(labels)))
ax3.set_yticklabels(labels)
ax3.axvline(x=0, color='black', linestyle='-')
ax3.set_xlabel('Weight')
ax3.set_title('Strongest W_hh Connections')

plt.suptitle('Cross-Neuron Recurrence Analysis', fontsize=12)
plt.tight_layout()
plt.show()

# Statistics
print(f"\nW_hh Statistics:")
print(f"  Shape: {W_hh.shape}")
print(f"  Mean: {W_hh.mean():.6f}")
print(f"  Std: {W_hh.std():.6f}")
print(f"  Max: {W_hh.max():.6f}")
print(f"  Min: {W_hh.min():.6f}")
print(f"  Sparsity (|w|<0.01): {(np.abs(W_hh) < 0.01).mean()*100:.1f}%")

## 9. Compare W_hh Scales and Configurations

In [None]:
# Compare different W_hh configurations
# NOTE: this cell rebinds `model` and `history`, replacing the objects
# produced by the main training run.
configs = [
    {'name': 'No W_hh (baseline)', 'w_hh_scale': 0.0},
    {'name': 'W_hh scale=0.05', 'w_hh_scale': 0.05},
    {'name': 'W_hh scale=0.1', 'w_hh_scale': 0.1},
    {'name': 'W_hh scale=0.2', 'w_hh_scale': 0.2},
    {'name': 'W_hh scale=0.3', 'w_hh_scale': 0.3},
]

comparison_results = []

print("Comparing W_hh configurations...")
print("="*80)

for config in configs:
    # Same seed for every configuration.
    torch.manual_seed(42)
    
    model = TemporalFFNetwork(
        input_dim=INPUT_DIM_PER_ROW,
        hidden_dims=[24],
        alpha=0.95,
        activation='tanh',
        w_hh_scale=config['w_hh_scale'],
    ).to(device)
    
    # Zero out W_hh for baseline
    # (a scale of 0.0 already zeroes the initial weights; disabling the
    # gradient keeps them at zero for the whole run)
    if config['w_hh_scale'] == 0.0:
        with torch.no_grad():
            model.layers[0].W_hh.weight.zero_()
            model.layers[0].W_hh.weight.requires_grad = False
    
    history = train_forward_forward_whh(
        model, X_train, y_train, X_test, y_test,
        n_epochs=50, lr=0.01, margin=0.01,
        batch_size=64, verbose=False,
        weight_decay=1e-4, lr_decay=0.98,
        local_in_time=True,
        gradient_compensation=True,
    )
    
    # Record final-epoch accuracies and the best test accuracy of the run.
    best_test = max(history['test_acc'])
    comparison_results.append({
        'config': config['name'],
        'w_hh_scale': config['w_hh_scale'],
        'train_acc': history['train_acc'][-1],
        'test_acc': history['test_acc'][-1],
        'best_test': best_test,
    })
    
    print(f"{config['name']:25s} | Final: {history['test_acc'][-1]:.4f} | Best: {best_test:.4f}")

print("="*80)

# Find best configuration
best_config = max(comparison_results, key=lambda x: x['best_test'])
baseline = [r for r in comparison_results if r['w_hh_scale'] == 0.0][0]

# The improvement is a difference of accuracies, i.e. percentage points.
print(f"\nBest configuration: {best_config['config']} with {best_config['best_test']:.2%}")
print(f"Improvement over baseline: {(best_config['best_test'] - baseline['best_test'])*100:.1f}%")

## 10. Try Different Alpha Values

In [None]:
# Compare different alpha (retention) values
# NOTE: this cell rebinds `model` and `history`.
alpha_configs = [
    {'alpha': 0.90, 'name': 'α=0.90 (fast decay)'},
    {'alpha': 0.95, 'name': 'α=0.95 (moderate)'},
    {'alpha': 0.97, 'name': 'α=0.97 (slow decay)'},
    {'alpha': 0.99, 'name': 'α=0.99 (very slow)'},
]

alpha_results = []

print("Comparing alpha (retention factor) values with W_hh...")
print("="*80)

for config in alpha_configs:
    # Same seed for every alpha.
    torch.manual_seed(42)
    
    model = TemporalFFNetwork(
        input_dim=INPUT_DIM_PER_ROW,
        hidden_dims=[24],
        alpha=config['alpha'],
        activation='tanh',
        w_hh_scale=0.1,
    ).to(device)
    
    history = train_forward_forward_whh(
        model, X_train, y_train, X_test, y_test,
        n_epochs=50, lr=0.01, margin=0.01,
        batch_size=64, verbose=False,
        weight_decay=1e-4, lr_decay=0.98,
        local_in_time=True,
        gradient_compensation=True,
    )
    
    # Keep the full history so the test-accuracy curves can be plotted below.
    best_test = max(history['test_acc'])
    alpha_results.append({
        'config': config['name'],
        'alpha': config['alpha'],
        'best_test': best_test,
        'history': history,
    })
    
    print(f"{config['name']:25s} | Best: {best_test:.4f}")

print("="*80)

# Plot comparison
# One test-accuracy curve per alpha, with the 10% chance level for reference.
fig, ax = plt.subplots(figsize=(10, 5))
for result in alpha_results:
    ax.plot(result['history']['test_acc'], label=result['config'], lw=2)
ax.axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Random')
ax.set_xlabel('Epoch')
ax.set_ylabel('Test Accuracy')
ax.set_title('Effect of Alpha (Retention Factor) with W_hh')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 11. Sparse W_hh (Hardware-Realistic)

In [None]:
# Try sparse W_hh (more hardware-realistic).
# NOTE(review): the labels suggest `sparsity` is the fraction of W_hh
# connections that are kept (1.0 = dense) — confirm against TemporalFFNetwork.
sparsity_configs = [
    {'sparse_hh': False, 'sparsity': 1.0, 'name': 'Dense W_hh (100%)'},
    {'sparse_hh': True, 'sparsity': 0.5, 'name': 'Sparse W_hh (50%)'},
    {'sparse_hh': True, 'sparsity': 0.3, 'name': 'Sparse W_hh (30%)'},
    {'sparse_hh': True, 'sparsity': 0.1, 'name': 'Sparse W_hh (10%)'},
]

sparsity_results = []

print("Comparing sparse vs dense W_hh (hardware constraints)...")
print("="*80)

for config in sparsity_configs:
    # Re-seed before every run so each configuration is built from the same seed.
    torch.manual_seed(42)
    
    # Sweep-local names (`sweep_model`, `sweep_history`): this loop used to rebind
    # the notebook-level `model` / `history`, so the conclusions cell ended up
    # reporting the last sparsity run when the cells were executed in order.
    sweep_model = TemporalFFNetwork(
        input_dim=INPUT_DIM_PER_ROW,
        hidden_dims=[24],
        alpha=0.95,
        activation='tanh',
        w_hh_scale=0.1,
        sparse_hh=config['sparse_hh'],
        sparsity=config['sparsity'],
    ).to(device)
    
    sweep_history = train_forward_forward_whh(
        sweep_model, X_train, y_train, X_test, y_test,
        n_epochs=50, lr=0.01, margin=0.01,
        batch_size=64, verbose=False,
        weight_decay=1e-4, lr_decay=0.98,
        local_in_time=True,
        gradient_compensation=True,
    )
    
    # Only the best test accuracy is kept for this sweep.
    best_test = max(sweep_history['test_acc'])
    sparsity_results.append({
        'config': config['name'],
        'sparsity': config['sparsity'],
        'best_test': best_test,
    })
    
    print(f"{config['name']:25s} | Best: {best_test:.4f}")

print("="*80)
print("\nNote: Sparse W_hh is more hardware-realistic because optical routing")
print("      is expensive. 10-30% connectivity may be practical.")

## 12. Conclusions

In [None]:
# Final summary cell: prints the banner and six numbered sections.
# Each section is emitted by one print call; `sep="\n"` puts every argument on
# its own line, so the output matches one print per line.
print(
    "="*70,
    "CONCLUSIONS: FORWARD-FORWARD WITH CROSS-NEURON RECURRENCE (W_hh)",
    "="*70,
    sep="\n",
)

print(
    "\n1. KEY UPGRADE:",
    "   Original: s[t] = α×s[t-1] + (1-α)×g(W_ih×x[t])",
    "   This:     s[t] = α×s[t-1] + (1-α)×g(W_ih×x[t] + W_hh×s[t-1])",
    "   → W_hh allows neurons to share temporal patterns!",
    sep="\n",
)

# Architecture figures come from the notebook-level constants; the W_ih / W_hh
# counts use the first entry of HIDDEN_DIMS only.
print(
    "\n2. ARCHITECTURE:",
    f"   Input: {INPUT_DIM_PER_ROW} ({N_COLS} pixels + {N_CLASSES} label)",
    f"   Timesteps: {N_ROWS}",
    f"   Hidden: {sum(HIDDEN_DIMS)} neurons",
    f"   W_ih: {INPUT_DIM_PER_ROW} × {HIDDEN_DIMS[0]} = {INPUT_DIM_PER_ROW * HIDDEN_DIMS[0]} weights",
    f"   W_hh: {HIDDEN_DIMS[0]} × {HIDDEN_DIMS[0]} = {HIDDEN_DIMS[0]**2} weights (NEW!)",
    sep="\n",
)

# NOTE(review): `history` is read from notebook scope, i.e. whichever run
# assigned it last. Confirm it is the main training run and has not been
# rebound by one of the sweep cells executed before this one.
print(
    "\n3. PERFORMANCE:",
    "   Previous best (without W_hh): ~30.5%",
    f"   This version (with W_hh):    {max(history['test_acc'])*100:.1f}%",
    "   Random baseline: 10%",
    sep="\n",
)

print(
    "\n4. HARDWARE COMPATIBILITY:",
    "   ✓ W_hh implemented via optical waveguide routing",
    "   ✓ Topology fixed at fabrication (Shainline 2021)",
    "   ✓ Sparse W_hh reduces routing complexity",
    "   ✓ Local-in-time learning (no BPTT hardware needed)",
    sep="\n",
)

print(
    "\n5. WHY W_hh HELPS:",
    "   Without W_hh: Each neuron can only remember its own past",
    "   With W_hh:    Neurons can share temporal information",
    "   Example: Neuron 1 detects top rows → tells Neuron 2 for bottom rows",
    sep="\n",
)

# NOTE(review): the first limitation mentions BPTT while section 4 advertises
# local-in-time learning — verify which statement matches the training setup.
print(
    "\n6. LIMITATIONS:",
    "   - Still learning with BPTT (for gradient flow)",
    "   - Hardware would need different learning rule",
    "   - Fixed topology may limit flexibility",
    sep="\n",
)

print("\n" + "="*70)