# Tutorial 02-FF — Train a SOEN Model with Forward-Forward Learning

This tutorial demonstrates **Forward-Forward (FF) learning** integrated with **actual SOEN dynamics**.

Unlike the separate MLP-based FF implementation, this version:
- Uses **real SOEN SingleDendrite layers** with temporal processing
- Maintains **64-timestep sequential input** (not flattened)
- Computes **goodness from SOEN neuron states**
- Updates **J matrices** (connection weights) using local FF learning

---

## Architecture Comparison

| Aspect | SOEN Backprop | **SOEN + FF (This Tutorial)** |
|--------|---------------|-------------------------------|
| Layers | SingleDendrite | **SingleDendrite** (same!) |
| Input | 1D × 64 timesteps | **3D × 64 timesteps** (signal + label) |
| Hidden | 5 neurons | **5 neurons** (same!) |
| Recurrence | J_1_to_1 | **J_1_to_1** (same!) |
| Learning | Backprop | **Forward-Forward** (local) |
| Parameters | ~27 | **~45** (15 + 20 + 10) |
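
The parameter count in the last row follows from the connection shapes listed in the Step 5 comments. A quick arithmetic check, assuming the recurrent matrix excludes self-connections (its diagonal), as those comments state:

```python
# Connection shapes are (target dim, source dim), per the model description.
input_dim, hidden_dim, output_dim = 3, 5, 2

j_0_to_1 = hidden_dim * input_dim                 # 5 x 3 = 15
j_1_to_1 = hidden_dim * hidden_dim - hidden_dim   # 5 x 5 minus 5 diagonal entries = 20
j_1_to_2 = output_dim * hidden_dim                # 2 x 5 = 10

total = j_0_to_1 + j_1_to_1 + j_1_to_2
print(total)  # 45
```

The Step 5 cell reports `J.numel()` per connection, which for a dense 5 × 5 tensor is 25. If the toolkit stores the recurrent matrix densely and masks the diagonal, the printed total may therefore differ from 45.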

---

## How FF-SOEN Works

```
┌─────────────────────────────────────────────────────────────────┐
│                    FORWARD-FORWARD + SOEN                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Input at each timestep: [signal, label_0, label_1]             │
│                                                                 │
│  POSITIVE: signal + CORRECT label → SOEN dynamics → HIGH goodness│
│  NEGATIVE: signal + WRONG label   → SOEN dynamics → LOW goodness │
│                                                                 │
│  Goodness = mean(s²) from SingleDendrite neuron states          │
│  Local learning updates J matrices without backprop             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

---

## ML Task: Pulse Classification (Same as Tutorial 02)
- **Class 0**: Single pulse
- **Class 1**: Two pulses

## Step 1: Setup and Imports

We import the necessary libraries and set up the environment.

In [None]:
# ==============================================================================
# STEP 1: IMPORTS AND SETUP
# ==============================================================================
# Import SOEN toolkit and required libraries
# ==============================================================================

import sys
from pathlib import Path

# Add src directory to path
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import Dict, Tuple, List
from dataclasses import dataclass
from tqdm import tqdm

# Import SOEN model builder
from soen_toolkit.core.model_yaml import build_model_from_yaml

# Set device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch version: {torch.__version__}")
print(f"Using device: {DEVICE}")

## Step 2: Configuration

We define all hyperparameters and settings in a single configuration class.

### Hardware Mode Options

| Mode | Description | Accuracy | Hardware Ready |
|------|-------------|----------|----------------|
| `HARDWARE_MODE = False` | Uses autograd + Adam | Higher | No |
| `HARDWARE_MODE = True` | Uses Hebbian learning | Lower | Yes |

The training loop in Step 7 uses autograd with Adam, which corresponds to the `HARDWARE_MODE = False` row; `FFSOENConfig` as defined in the next cell has no `HARDWARE_MODE` field.

In [None]:
# ==============================================================================
# STEP 2: CONFIGURATION
# ==============================================================================
# Configuration for FF-SOEN training
#
# Key difference from standard FF:
#   - Uses actual SOEN model with temporal processing
#   - Label embedded in input at each timestep (3D input)
#   - Goodness computed from SOEN neuron states
# ==============================================================================

@dataclass
class FFSOENConfig:
    """Configuration for Forward-Forward training with SOEN."""
    
    # Model specification
    model_spec: str = "training/test_models/model_specs/3D_5D_2D_PulseNetSpec_FF.yaml"
    
    # Dataset
    data_path: str = "training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5"
    seq_len: int = 64
    signal_dim: int = 1       # Original signal dimension
    num_classes: int = 2      # Binary classification
    input_dim: int = 3        # signal (1) + label one-hot (2)
    
    # Architecture (from model spec)
    hidden_dim: int = 5       # SingleDendrite layer dimension
    
    # Training
    batch_size: int = 32
    num_epochs: int = 100
    learning_rate: float = 0.01   # For J matrix updates
    
    # Forward-Forward specific
    threshold: float = 0.5    # Goodness threshold (adjusted for SOEN states)
    goodness_type: str = "mean_squared"  # "mean_squared" or "mean_abs"


config = FFSOENConfig()

# Print configuration
print("="*60)
print("FF-SOEN CONFIGURATION")
print("="*60)
print(f"Model spec: {config.model_spec}")
print(f"Dataset: {config.data_path}")
print()
print(f"Input: {config.input_dim}D × {config.seq_len} timesteps")
print(f"  - Signal: {config.signal_dim}D")
print(f"  - Label: {config.num_classes}D one-hot")
print(f"Hidden: {config.hidden_dim} SingleDendrite neurons")
print(f"Output: {config.num_classes} classes")
print()
print(f"Batch size: {config.batch_size}")
print(f"Epochs: {config.num_epochs}")
print(f"Learning rate: {config.learning_rate}")
print(f"Threshold: {config.threshold}")
print("="*60)

## Step 3: Load and Visualize Dataset

We load the pulse classification dataset and visualize examples from each class.

### Dataset Structure
- **Shape**: `(N, T, D)` = `(samples, timesteps, features)`
- **Class 0**: Single pulse in the signal
- **Class 1**: Two pulses in the signal
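
If the HDF5 file is not at hand, a synthetic stand-in with the same `(N, T, D)` layout can be useful for smoke-testing the rest of the notebook. The sketch below is only an assumption about what such data might look like (rectangular unit-height pulses at random positions). `make_synthetic_pulses` is a hypothetical helper, not part of the SOEN toolkit, and the real dataset's pulse shapes may differ.

```python
import numpy as np

def make_synthetic_pulses(n_samples=8, seq_len=64, pulse_width=4, seed=0):
    """Hypothetical generator: class 0 has one pulse, class 1 has two."""
    rng = np.random.default_rng(seed)
    data = np.zeros((n_samples, seq_len, 1), dtype=np.float32)
    labels = rng.integers(0, 2, size=n_samples)
    half = seq_len // 2
    for i, label in enumerate(labels):
        # The first pulse always lands in the first half of the sequence.
        start = rng.integers(0, half - pulse_width)
        data[i, start:start + pulse_width, 0] = 1.0
        if label == 1:
            # The second pulse lands in the second half, so the two never overlap.
            start = rng.integers(half, seq_len - pulse_width)
            data[i, start:start + pulse_width, 0] = 1.0
    return data, labels

demo_data, demo_labels = make_synthetic_pulses()
print(demo_data.shape, demo_labels.shape)  # (8, 64, 1) (8,)
```

Arrays produced this way can be passed to `visualize_dataset` in place of the loaded data, since that function only relies on the `(N, T, D)` shape and integer labels.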

In [None]:
# ==============================================================================
# STEP 3: LOAD AND VISUALIZE DATASET
# ==============================================================================
# This cell loads the HDF5 dataset and visualizes examples from each class.
#
# The dataset contains:
#   - Class 0: Signals with ONE pulse
#   - Class 1: Signals with TWO pulses
#
# Dataset format:
#   - data: shape (N, T, D) where N=samples, T=timesteps, D=features
#   - labels: shape (N,) with integer class labels
# ==============================================================================

def load_dataset(data_path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load train and validation data from HDF5 file.
    
    Args:
        data_path: Path to HDF5 dataset file
        
    Returns:
        train_data, train_labels, val_data, val_labels
    """
    with h5py.File(data_path, 'r') as f:
        # Load training data
        train_data = np.array(f['train']['data'])
        train_labels = np.array(f['train']['labels'])
        
        # Load validation data (or use part of train if not available)
        if 'val' in f:
            val_data = np.array(f['val']['data'])
            val_labels = np.array(f['val']['labels'])
        else:
            # Split train data
            split_idx = int(len(train_data) * 0.8)
            val_data = train_data[split_idx:]
            val_labels = train_labels[split_idx:]
            train_data = train_data[:split_idx]
            train_labels = train_labels[:split_idx]
    
    return train_data, train_labels, val_data, val_labels


def visualize_dataset(data: np.ndarray, labels: np.ndarray, n_examples: int = 4):
    """
    Visualize examples from each class.
    
    Args:
        data: Input data of shape (N, T, D)
        labels: Class labels of shape (N,)
        n_examples: Number of examples per class to show
    """
    print(f"Dataset shape: {data.shape} (N samples, T timesteps, D features)")
    print(f"Labels shape: {labels.shape}")
    print(f"Class distribution: {np.bincount(labels)}")
    
    # Find examples of each class
    class_0_idx = np.where(labels == 0)[0][:n_examples]
    class_1_idx = np.where(labels == 1)[0][:n_examples]
    
    fig, axes = plt.subplots(2, n_examples, figsize=(3*n_examples, 5))
    fig.suptitle("Input Signals: One-Pulse (Class 0) vs Two-Pulse (Class 1)", 
                 fontsize=12, fontweight='bold')
    
    # Plot Class 0 (single pulse)
    for i, idx in enumerate(class_0_idx):
        axes[0, i].plot(data[idx, :, 0], 'b-', linewidth=1.5)
        axes[0, i].set_title(f"Sample {idx}", fontsize=10)
        axes[0, i].set_ylim(-0.1, 1.1)
        axes[0, i].grid(True, alpha=0.3)
        if i == 0:
            axes[0, i].set_ylabel("Class 0\n(One Pulse)", fontsize=10)
    
    # Plot Class 1 (two pulses)
    for i, idx in enumerate(class_1_idx):
        axes[1, i].plot(data[idx, :, 0], 'r-', linewidth=1.5)
        axes[1, i].set_title(f"Sample {idx}", fontsize=10)
        axes[1, i].set_ylim(-0.1, 1.1)
        axes[1, i].grid(True, alpha=0.3)
        if i == 0:
            axes[1, i].set_ylabel("Class 1\n(Two Pulses)", fontsize=10)
        axes[1, i].set_xlabel("Time step")
    
    plt.tight_layout()
    plt.show()


# Load dataset
train_data, train_labels, val_data, val_labels = load_dataset(config.data_path)

print(f"\nTrain set: {train_data.shape[0]} samples")
print(f"Val set: {val_data.shape[0]} samples")

# Visualize
visualize_dataset(train_data, train_labels)

## Step 4: Create FF-SOEN Dataset

For FF-SOEN, we embed the label at **each timestep**:

```
Original signal: [batch, 64, 1]  →  shape (B, T, 1)

With label embedding:
  POSITIVE: [batch, 64, 3]  →  [signal, correct_label_0, correct_label_1]
  NEGATIVE: [batch, 64, 3]  →  [signal, wrong_label_0, wrong_label_1]
```

The label is **repeated at every timestep** so the SOEN neurons receive consistent label information throughout the temporal processing.
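
The same embedding can also be applied to a whole batch at once. A minimal sketch, where `embed_labels` is an illustrative helper name rather than a function defined in this notebook:

```python
import torch
import torch.nn.functional as F

def embed_labels(signals: torch.Tensor, labels: torch.Tensor, num_classes: int = 2) -> torch.Tensor:
    """Append a one-hot label to every timestep: (B, T, 1) -> (B, T, 1 + num_classes)."""
    seq_len = signals.shape[1]
    one_hot = F.one_hot(labels, num_classes=num_classes).float()   # (B, C)
    one_hot = one_hot.unsqueeze(1).expand(-1, seq_len, -1)         # (B, T, C)
    return torch.cat([signals, one_hot], dim=2)

signals = torch.rand(4, 64, 1)
labels = torch.tensor([0, 1, 1, 0])
positive = embed_labels(signals, labels)        # correct labels
negative = embed_labels(signals, 1 - labels)    # flipped labels (binary case only)
print(positive.shape)  # torch.Size([4, 64, 3])
```

The signal channel is identical in `positive` and `negative`; only the two label channels differ, so any difference in goodness must come from how the network combines signal and label.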

In [None]:
# ==============================================================================
# STEP 4: CREATE FF-SOEN DATASET
# ==============================================================================
# Dataset that embeds labels temporally for SOEN processing.
#
# Key difference from standard FF:
#   - Labels are embedded at EACH TIMESTEP (not flattened)
#   - Output shape: [batch, seq_len, 3] where 3 = signal + one-hot label
#   - Preserves temporal structure for SOEN dynamics
# ==============================================================================

class PulseFFSOENDataset(Dataset):
    """
    Pulse classification dataset for FF-SOEN learning.
    
    Embeds labels temporally: at each timestep, the input is
    [signal_value, label_0, label_1] where label is one-hot.
    """
    
    def __init__(self, data: np.ndarray, labels: np.ndarray, num_classes: int = 2):
        """
        Args:
            data: Signal data of shape (N, T, D) where D=1
            labels: Class labels of shape (N,)
            num_classes: Number of classes (2 for pulse classification)
        """
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.num_classes = num_classes
        self.seq_len = data.shape[1]
        
        print(f"Created FF-SOEN Dataset:")
        print(f"  Original shape: {self.data.shape} (N, T, D)")
        print(f"  Output shape: (N, {self.seq_len}, {1 + num_classes})")
    
    def __len__(self) -> int:
        return len(self.labels)
    
    def _embed_label_temporal(self, signal: torch.Tensor, label: int) -> torch.Tensor:
        """
        Embed label at each timestep.
        
        Args:
            signal: Signal of shape (T, 1)
            label: Integer class label
            
        Returns:
            Embedded tensor of shape (T, 3) = [signal, label_one_hot]
        """
        seq_len = signal.shape[0]
        
        # Create one-hot label and repeat for all timesteps
        one_hot = F.one_hot(torch.tensor(label), num_classes=self.num_classes).float()
        one_hot_repeated = one_hot.unsqueeze(0).expand(seq_len, -1)  # (T, 2)
        
        # Concatenate: [signal, label_0, label_1]
        embedded = torch.cat([signal, one_hot_repeated], dim=1)  # (T, 3)
        return embedded
    
    def _get_wrong_label(self, true_label: int) -> int:
        """Get a wrong label for negative example."""
        if self.num_classes == 2:
            return 1 - true_label
        else:
            wrong = torch.randint(0, self.num_classes - 1, (1,)).item()
            if wrong >= true_label:
                wrong += 1
            return wrong
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a sample with positive and negative versions.
        
        Returns:
            Dictionary with:
                - positive: signal with correct label, shape (T, 3)
                - negative: signal with wrong label, shape (T, 3)
                - label: true class label
                - signal: original signal (T, 1)
        """
        signal = self.data[idx]  # (T, 1)
        true_label = self.labels[idx].item()
        wrong_label = self._get_wrong_label(true_label)
        
        return {
            "positive": self._embed_label_temporal(signal, true_label),
            "negative": self._embed_label_temporal(signal, wrong_label),
            "label": self.labels[idx],
            "signal": signal,
        }


# Create datasets
print("Creating training dataset...")
train_dataset = PulseFFSOENDataset(train_data, train_labels, num_classes=config.num_classes)

print("\nCreating validation dataset...")
val_dataset = PulseFFSOENDataset(val_data, val_labels, num_classes=config.num_classes)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

# Verify a sample
print("\n" + "="*60)
print("SAMPLE VERIFICATION")
print("="*60)
sample = train_dataset[0]
print(f"Positive shape: {sample['positive'].shape}  (T, 3)")
print(f"Negative shape: {sample['negative'].shape}  (T, 3)")
print(f"True label: {sample['label'].item()}")
print()
print(f"Timestep 0:")
print(f"  Signal value: {sample['positive'][0, 0].item():.4f}")
print(f"  Positive label part: {sample['positive'][0, 1:].tolist()}")
print(f"  Negative label part: {sample['negative'][0, 1:].tolist()}")

## Step 5: Load SOEN Model

We load the actual SOEN model from the YAML specification. This gives us:
- **SingleDendrite layers** with real superconductor dynamics
- **J matrices** (connection weights) that we'll train with FF
- **Temporal processing** with recurrence

In [None]:
# ==============================================================================
# STEP 5: LOAD SOEN MODEL
# ==============================================================================
# Load the actual SOEN model from YAML specification.
#
# The model has:
#   - Layer 0 (Input): dim=3 (signal + label)
#   - Layer 1 (SingleDendrite): dim=5 (hidden layer with SOEN dynamics)
#   - Layer 2 (Output): dim=2 (for readout)
#
# Connection matrices (J):
#   - J_0_to_1: [5, 3] = 15 weights
#   - J_1_to_1: [5, 5] - 5 diagonal = 20 weights (recurrent)
#   - J_1_to_2: [2, 5] = 10 weights
# ==============================================================================

# Build SOEN model from YAML
model_path = Path(config.model_spec)
model = build_model_from_yaml(model_path)
model = model.to(DEVICE)

# Enable state tracking for goodness computation
for layer in model.layers:
    if hasattr(layer, 'set_tracking_flags'):
        layer.set_tracking_flags(phi=True, g=True, s=True)

print("="*60)
print("SOEN MODEL LOADED")
print("="*60)
print(f"\nLayers:")
for i, layer in enumerate(model.layers):
    layer_type = type(layer).__name__
    if hasattr(layer, 'dim'):
        print(f"  Layer {i}: {layer_type}, dim={layer.dim}")
    else:
        print(f"  Layer {i}: {layer_type}")

print(f"\nConnections (J matrices):")
for key in model.connections:
    J = model.connections[key]
    print(f"  {key}: shape={list(J.shape)}, params={J.numel()}")

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal trainable parameters: {total_params}")
print("="*60)

## Step 6: Define FF-SOEN Training Functions

We define functions to:
1. **Compute goodness** from SOEN neuron states
2. **Run forward pass** through SOEN with label-embedded input
3. **Compute FF loss** (push positive goodness up, negative down)
4. **Update J matrices** based on local gradients

```
┌─────────────────────────────────────────────────────────────────┐
│  SOEN Forward Pass with FF Learning                             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Input [B, T, 3] → Layer 0 → J_0_to_1 → Layer 1 (SingleDendrite)│
│                                  ↺ J_1_to_1 (recurrent)         │
│                                                                 │
│  Goodness = mean(s²) from Layer 1 states over all timesteps     │
│                                                                 │
│  Loss = softplus(θ - g_pos) + softplus(g_neg - θ)               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
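
To get a feel for the loss in the diagram, it helps to evaluate it on a few hand-picked goodness values. The numbers below are illustrative and not measured from the SOEN model; the threshold matches the `threshold = 0.5` default from Step 2.

```python
import torch
import torch.nn.functional as F

theta = 0.5
cases = {
    "well separated": (torch.tensor(2.0), torch.tensor(0.0)),
    "no separation":  (torch.tensor(0.5), torch.tensor(0.5)),
    "inverted":       (torch.tensor(0.0), torch.tensor(2.0)),
}

losses = {}
for name, (g_pos, g_neg) in cases.items():
    # Loss = softplus(theta - g_pos) + softplus(g_neg - theta)
    losses[name] = (F.softplus(theta - g_pos) + F.softplus(g_neg - theta)).item()
    print(f"{name:>15}: {losses[name]:.4f}")
```

With no separation the loss equals 2 ln 2 ≈ 1.386, so a training loss that stays near that value suggests the layer has not yet learned to tell positive from negative inputs. Because goodness is a mean of squares (or absolute values) it cannot drop below zero, so with θ = 0.5 the negative term is bounded below by softplus(-0.5) ≈ 0.474 and the total loss never reaches zero.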

In [None]:
# ==============================================================================
# STEP 6: FF-SOEN TRAINING FUNCTIONS
# ==============================================================================
# Functions for Forward-Forward learning with SOEN model.
#
# Key functions:
#   - compute_goodness: Extract goodness from SOEN layer states
#   - ff_loss: Forward-Forward loss function
#   - forward_and_get_states: Run a forward pass and collect per-layer states
#   - predict_multipass: Classify by comparing goodness across candidate labels
# ==============================================================================

def compute_goodness(states: torch.Tensor, goodness_type: str = "mean_squared") -> torch.Tensor:
    """
    Compute goodness from SOEN neuron states.
    
    Args:
        states: Neuron states of shape (batch, seq_len, dim)
        goodness_type: "mean_squared" or "mean_abs"
        
    Returns:
        Goodness per sample, shape (batch,)
    """
    if goodness_type == "mean_squared":
        # Mean of squared activities across neurons and time
        return (states ** 2).mean(dim=(1, 2))
    else:  # mean_abs
        return states.abs().mean(dim=(1, 2))


def ff_loss(pos_goodness: torch.Tensor, neg_goodness: torch.Tensor, 
            threshold: float) -> torch.Tensor:
    """
    Compute Forward-Forward loss.
    
    Goal: Push positive goodness ABOVE threshold, negative BELOW.
    
    Args:
        pos_goodness: Goodness for positive data, shape (batch,)
        neg_goodness: Goodness for negative data, shape (batch,)
        threshold: Goodness threshold
        
    Returns:
        Scalar loss value
    """
    # Softplus loss: smooth approximation
    pos_loss = F.softplus(-(pos_goodness - threshold))  # Want pos > threshold
    neg_loss = F.softplus(neg_goodness - threshold)     # Want neg < threshold
    return (pos_loss + neg_loss).mean()


def forward_and_get_states(model, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """
    Run forward pass and extract layer states.
    
    Args:
        model: SOEN model
        x: Input tensor of shape (batch, seq_len, input_dim)
        
    Returns:
        output: Model output
        layer_states: List of state tensors for each layer
    """
    # Reset stateful components before forward pass
    if hasattr(model, 'reset_stateful_components'):
        model.reset_stateful_components()
    
    # Forward pass
    output, all_outputs = model(x)
    
    # Extract states from layers
    layer_states = []
    for i, layer in enumerate(model.layers):
        if hasattr(layer, 'get_state_history'):
            states = layer.get_state_history()  # (batch, seq_len+1, dim)
            if states is not None:
                # Remove initial state (first timestep)
                states = states[:, 1:, :]  # (batch, seq_len, dim)
                layer_states.append(states)
        elif i < len(all_outputs):
            # Fallback: use layer output as state
            layer_states.append(all_outputs[i])
    
    return output, layer_states


def predict_multipass(model, signals: torch.Tensor, num_classes: int) -> torch.Tensor:
    """
    Predict using multi-pass inference (one pass per class).
    
    For each class, embed that label and compute total goodness.
    Predict the class with highest goodness.
    
    Args:
        model: SOEN model
        signals: Original signals of shape (batch, seq_len, 1)
        num_classes: Number of classes
        
    Returns:
        Predicted class labels, shape (batch,)
    """
    batch_size = signals.shape[0]
    seq_len = signals.shape[1]
    device = signals.device
    
    all_goodness = torch.zeros(batch_size, num_classes, device=device)
    
    for label in range(num_classes):
        # Create one-hot label repeated at each timestep
        one_hot = F.one_hot(torch.tensor([label], device=device), num_classes).float()
        one_hot_repeated = one_hot.unsqueeze(1).expand(batch_size, seq_len, -1)
        
        # Embed label in input
        x = torch.cat([signals, one_hot_repeated], dim=2)  # (batch, seq_len, 3)
        
        # Forward pass
        _, layer_states = forward_and_get_states(model, x)
        
        # Compute goodness from hidden layer (layer 1)
        if len(layer_states) > 1:
            hidden_states = layer_states[1]  # SingleDendrite layer states
        else:
            hidden_states = layer_states[0]
        
        goodness = compute_goodness(hidden_states, config.goodness_type)
        all_goodness[:, label] = goodness
    
    return all_goodness.argmax(dim=1)


# Test the functions
print("Testing FF-SOEN functions...")
sample = train_dataset[0]
test_pos = sample["positive"].unsqueeze(0).to(DEVICE)  # (1, T, 3)
test_neg = sample["negative"].unsqueeze(0).to(DEVICE)

# Forward pass
output_pos, states_pos = forward_and_get_states(model, test_pos)
output_neg, states_neg = forward_and_get_states(model, test_neg)

print(f"Output shape: {output_pos.shape}")
print(f"Number of layer states: {len(states_pos)}")
for i, s in enumerate(states_pos):
    print(f"  Layer {i} states shape: {s.shape}")

# Compute goodness
if len(states_pos) > 1:
    g_pos = compute_goodness(states_pos[1], config.goodness_type)
    g_neg = compute_goodness(states_neg[1], config.goodness_type)
else:
    g_pos = compute_goodness(states_pos[0], config.goodness_type)
    g_neg = compute_goodness(states_neg[0], config.goodness_type)

print(f"\nGoodness (positive): {g_pos.item():.4f}")
print(f"Goodness (negative): {g_neg.item():.4f}")
print(f"FF Loss: {ff_loss(g_pos, g_neg, config.threshold).item():.4f}")

## Step 7: Training Loop

The FF-SOEN training loop:

1. **Positive pass**: Forward with correct label → compute positive goodness
2. **Negative pass**: Forward with wrong label → compute negative goodness  
3. **FF loss**: Push positive above threshold, negative below
4. **Update**: Gradient descent on J matrices (connection weights)

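**Key difference from standard backprop training**: the objective is the goodness of the hidden layer's own states, not a global cross-entropy loss at the output. Gradients are still obtained with autograd (`loss.backward()`), but only for the weights that shape those hidden states; connections downstream of layer 1, such as `J_1_to_2`, are not expected to receive a gradient from this loss.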
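
The four steps can be sketched end to end on a toy model. The `ToyRecurrentLayer` below is a plain leaky `tanh` recurrence used purely as a stand-in. It is not the SingleDendrite dynamics and shares nothing with the SOEN toolkit beyond the tensor shapes. The aim is only to show the positive pass, negative pass, loss, and update in one place.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyRecurrentLayer(nn.Module):
    """Stand-in for a recurrent hidden layer (NOT SOEN dynamics)."""
    def __init__(self, input_dim=3, hidden_dim=5):
        super().__init__()
        self.w_in = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w_rec = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x):                       # x: (B, T, input_dim)
        s = x.new_zeros(x.shape[0], self.w_rec.in_features)
        history = []
        for t in range(x.shape[1]):
            s = 0.9 * s + 0.1 * torch.tanh(self.w_in(x[:, t]) + self.w_rec(s))
            history.append(s)
        return torch.stack(history, dim=1)      # (B, T, hidden_dim)

torch.manual_seed(0)
layer = ToyRecurrentLayer()
optimizer = torch.optim.Adam(layer.parameters(), lr=0.01)
theta = 0.5

signal = torch.rand(8, 64, 1)
labels = torch.randint(0, 2, (8,))

def with_label(y):
    """Append the one-hot label y to every timestep of `signal`."""
    one_hot = F.one_hot(y, 2).float().unsqueeze(1).expand(-1, 64, -1)
    return torch.cat([signal, one_hot], dim=2)

pos, neg = with_label(labels), with_label(1 - labels)

optimizer.zero_grad()
g_pos = (layer(pos) ** 2).mean(dim=(1, 2))      # 1. positive goodness
g_neg = (layer(neg) ** 2).mean(dim=(1, 2))      # 2. negative goodness
loss = (F.softplus(theta - g_pos) + F.softplus(g_neg - theta)).mean()  # 3. FF loss
loss.backward()                                 # 4. gradients for this layer's weights
optimizer.step()
print(f"loss={loss.item():.4f}")
```

Note that `loss.backward()` here still differentiates through the unrolled timesteps of the one layer. What makes the scheme forward-forward is that the objective is the layer's own goodness rather than an error signal propagated back from the output.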

In [None]:
# ==============================================================================
# STEP 7: FF-SOEN TRAINING LOOP
# ==============================================================================
# Train the SOEN model using Forward-Forward learning.
#
# Training flow:
#   For each batch:
#     1. Forward positive data → compute positive goodness
#     2. Forward negative data → compute negative goodness
#     3. Compute FF loss
#     4. loss.backward(): autograd on the hidden-layer goodness (no output-layer loss)
#     5. Update J matrices
# ==============================================================================

def train_ff_soen(
    model,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: FFSOENConfig,
    device: torch.device,
) -> Dict:
    """
    Train SOEN model using Forward-Forward algorithm.
    
    Args:
        model: SOEN model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        config: Training configuration
        device: Device to train on
        
    Returns:
        Training history dictionary
    """
    model.to(device)
    model.train()
    
    # Optimizer for all trainable parameters (J matrices)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    
    # Training history
    history = {
        "train_loss": [],
        "val_acc": [],
        "pos_goodness": [],
        "neg_goodness": [],
    }
    
    print(f"\n{'='*60}")
    print("FF-SOEN TRAINING")
    print(f"{'='*60}")
    print(f"Epochs: {config.num_epochs}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Threshold: {config.threshold}")
    print(f"Goodness type: {config.goodness_type}")
    print(f"{'='*60}\n")
    
    for epoch in range(config.num_epochs):
        model.train()
        
        epoch_loss = 0.0
        epoch_pos_g = 0.0
        epoch_neg_g = 0.0
        num_batches = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
        
        for batch in pbar:
            pos_data = batch["positive"].to(device)  # (batch, seq_len, 3)
            neg_data = batch["negative"].to(device)
            
            optimizer.zero_grad()
            
            # ===== POSITIVE PASS =====
            _, states_pos = forward_and_get_states(model, pos_data)
            
            # Get hidden layer states (layer 1 = SingleDendrite)
            if len(states_pos) > 1:
                hidden_pos = states_pos[1]
            else:
                hidden_pos = states_pos[0]
            
            pos_goodness = compute_goodness(hidden_pos, config.goodness_type)
            
            # ===== NEGATIVE PASS =====
            _, states_neg = forward_and_get_states(model, neg_data)
            
            if len(states_neg) > 1:
                hidden_neg = states_neg[1]
            else:
                hidden_neg = states_neg[0]
            
            neg_goodness = compute_goodness(hidden_neg, config.goodness_type)
            
            # ===== FF LOSS =====
            loss = ff_loss(pos_goodness, neg_goodness, config.threshold)
            
            # ===== UPDATE =====
            loss.backward()
            
            # Clip gradients for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            # Track metrics
            epoch_loss += loss.item()
            epoch_pos_g += pos_goodness.mean().item()
            epoch_neg_g += neg_goodness.mean().item()
            num_batches += 1
            
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "g+": f"{pos_goodness.mean().item():.3f}",
                "g-": f"{neg_goodness.mean().item():.3f}",
            })
        
        # Epoch averages
        avg_loss = epoch_loss / num_batches
        avg_pos_g = epoch_pos_g / num_batches
        avg_neg_g = epoch_neg_g / num_batches
        
        history["train_loss"].append(avg_loss)
        history["pos_goodness"].append(avg_pos_g)
        history["neg_goodness"].append(avg_neg_g)
        
        # Evaluate on the first epoch, every 10th epoch, and the last epoch
        if (epoch + 1) % 10 == 0 or epoch == 0 or epoch == config.num_epochs - 1:
            val_acc = evaluate_ff_soen(model, val_loader, config.num_classes, device)
            history["val_acc"].append(val_acc)
            
            sep = avg_pos_g - avg_neg_g
            print(f"\nEpoch {epoch+1}: Loss={avg_loss:.4f}, "
                  f"g+={avg_pos_g:.3f}, g-={avg_neg_g:.3f}, sep={sep:+.3f}, "
                  f"Val Acc={val_acc:.1%}")
    
    return history


def evaluate_ff_soen(model, data_loader: DataLoader, num_classes: int, 
                     device: torch.device) -> float:
    """
    Evaluate FF-SOEN model accuracy using multi-pass inference.
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in data_loader:
            signals = batch["signal"].to(device)  # (batch, seq_len, 1)
            labels = batch["label"].to(device)
            
            predictions = predict_multipass(model, signals, num_classes)
            
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
    return correct / total

In [None]:
# ==============================================================================
# TRAIN THE MODEL
# ==============================================================================
# Run FF-SOEN training.
#
# Watch for:
#   - Goodness separation (g+ should become > g-)
#   - Validation accuracy improvement
#   - Loss decreasing
# ==============================================================================

history = train_ff_soen(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=DEVICE,
)

## Step 8: Visualize Training Progress

We plot the training metrics to understand how learning progressed.
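
Validation accuracy is only recorded on the first epoch, every tenth epoch, and the last epoch, so its x-axis needs the matching epoch indices. One way to derive them is to reuse the same condition as the training loop. `eval_epoch_indices` is a hypothetical helper shown for illustration:

```python
def eval_epoch_indices(num_epochs: int, every: int = 10) -> list:
    """Zero-based epochs at which the training loop runs validation."""
    return [
        epoch for epoch in range(num_epochs)
        if (epoch + 1) % every == 0 or epoch == 0 or epoch == num_epochs - 1
    ]

print(eval_epoch_indices(25))  # [0, 9, 19, 24]
```

Deriving the indices from the loop's own condition keeps the plot aligned with `history["val_acc"]` even when `num_epochs` is not a multiple of 10.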

In [None]:
# ==============================================================================
# STEP 8: VISUALIZE TRAINING PROGRESS
# ==============================================================================
# Plot training metrics for FF-SOEN.
# ==============================================================================

def plot_ff_soen_history(history: Dict, config: FFSOENConfig):
    """Plot FF-SOEN training metrics."""
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle('FF-SOEN Training Progress', fontsize=14, fontweight='bold')
    
    # 1. Training loss
    axes[0].plot(history["train_loss"], 'b-', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('FF Loss')
    axes[0].set_title('Training Loss')
    axes[0].grid(True, alpha=0.3)
    
    # 2. Goodness separation
    epochs = range(len(history["pos_goodness"]))
    axes[1].plot(epochs, history["pos_goodness"], 'g-', linewidth=2, label='Positive')
    axes[1].plot(epochs, history["neg_goodness"], 'r-', linewidth=2, label='Negative')
    axes[1].axhline(y=config.threshold, color='black', linestyle='--', 
                    linewidth=1, label=f'Threshold ({config.threshold})')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Goodness')
    axes[1].set_title('Goodness Separation')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # 3. Validation accuracy
    if len(history["val_acc"]) > 0:
        eval_epochs = [0] + list(range(9, len(history["train_loss"]), 10))
        eval_epochs = eval_epochs[:len(history["val_acc"])]
        if len(eval_epochs) < len(history["val_acc"]):
            eval_epochs.append(len(history["train_loss"]) - 1)
        axes[2].plot(eval_epochs[:len(history["val_acc"])], history["val_acc"], 
                     'go-', linewidth=2, markersize=6)
        axes[2].set_xlabel('Epoch')
        axes[2].set_ylabel('Accuracy')
        axes[2].set_title('Validation Accuracy')
        axes[2].set_ylim([0, 1.05])
        axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary
    print(f"\n{'='*60}")
    print("TRAINING SUMMARY")
    print(f"{'='*60}")
    if len(history["val_acc"]) > 0:
        print(f"Final validation accuracy: {history['val_acc'][-1]:.1%}")
        print(f"Best validation accuracy: {max(history['val_acc']):.1%}")
    print(f"Final goodness separation: {history['pos_goodness'][-1] - history['neg_goodness'][-1]:+.4f}")
    print(f"  Positive: {history['pos_goodness'][-1]:.4f}")
    print(f"  Negative: {history['neg_goodness'][-1]:.4f}")


plot_ff_soen_history(history, config)

## Step 9: Visualize Predictions

Let's see how the model performs on individual samples.
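
The plot below is titled "Multi-Pass Inference": the signal is run once per candidate label, and the label that yields the highest goodness is the prediction. The following is a minimal NumPy sketch of that idea, not the tutorial's `predict_multipass`. The names `embed_label`, `toy_goodness`, and `predict_by_goodness` are hypothetical, and `toy_goodness` is a hand-written stand-in for a forward pass through the SOEN model.

```python
import numpy as np

def embed_label(signal, label, num_classes):
    """Append a one-hot label to every timestep: (T, 1) -> (T, 1 + num_classes)."""
    one_hot = np.zeros((signal.shape[0], num_classes))
    one_hot[:, label] = 1.0
    return np.concatenate([signal, one_hot], axis=1)

def toy_goodness(x):
    """Hypothetical stand-in for the model: goodness is high when the number
    of pulses in the signal matches the embedded label."""
    signal, one_hot = x[:, 0], x[:, 1:]
    active = signal > 0.5
    # Count rising edges (plus one if the signal starts inside a pulse)
    n_pulses = int(np.sum(active[1:] & ~active[:-1]) + active[0])
    claimed = int(np.argmax(one_hot[0])) + 1  # class 0 -> 1 pulse, class 1 -> 2 pulses
    return 1.0 if n_pulses == claimed else 0.1

def predict_by_goodness(signal, num_classes, goodness_fn):
    """One pass per candidate label; return the label with the highest goodness."""
    scores = [goodness_fn(embed_label(signal, c, num_classes))
              for c in range(num_classes)]
    return int(np.argmax(scores)), scores

# A two-pulse signal over 64 timesteps
signal = np.zeros((64, 1))
signal[10:15, 0] = 1.0
signal[40:45, 0] = 1.0

pred, scores = predict_by_goodness(signal, num_classes=2, goodness_fn=toy_goodness)
print(f"Predicted class: {pred}, goodness per label: {scores}")
```

The cost of this scheme is one forward pass per class, which is cheap for two classes but grows linearly with the number of labels.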

In [None]:
# ==============================================================================
# STEP 9: VISUALIZE PREDICTIONS
# ==============================================================================
# Show model predictions on individual samples.
# ==============================================================================

def visualize_ff_soen_predictions(model, dataset, num_classes: int, 
                                   n_samples: int = 8, device=DEVICE):
    """Visualize FF-SOEN model predictions."""
    
    model.eval()
    
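    # Never request more samples than the dataset holds, so the grid and the
    # accuracy denominator match the number of samples actually drawn.
    n_samples = min(n_samples, len(dataset))
    indices = np.random.choice(len(dataset), n_samples, replace=False)
    
    n_cols = min(4, n_samples)
    n_rows = (n_samples + n_cols - 1) // n_cols
    
    # squeeze=False always returns a 2D axes array, even for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.5*n_cols, 3*n_rows),
                             squeeze=False)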
    
    fig.suptitle('FF-SOEN Predictions (Multi-Pass Inference)', fontsize=14, fontweight='bold')
    
    class_names = ["One Pulse", "Two Pulses"]
    correct_count = 0
    
    with torch.no_grad():
        for i, idx in enumerate(indices):
            sample = dataset[idx]
            signal = sample["signal"].unsqueeze(0).to(device)  # (1, T, 1)
            label = sample["label"].item()
            
            prediction = predict_multipass(model, signal, num_classes).item()
            
            is_correct = prediction == label
            correct_count += is_correct
            
            row, col = i // n_cols, i % n_cols
            ax = axes[row, col]
            
            ax.plot(sample["signal"][:, 0].numpy(), 'b-', linewidth=1.5)
            ax.set_ylim(-0.1, 1.1)
            ax.grid(True, alpha=0.3)
            
            color = 'green' if is_correct else 'red'
            symbol = '✓' if is_correct else '✗'
            
            ax.set_title(
                f"{symbol} Pred: {class_names[prediction]}\nTrue: {class_names[label]}",
                fontsize=9, color=color,
                fontweight='bold' if not is_correct else 'normal'
            )
            
            if col == 0:
                ax.set_ylabel("Signal")
            if row == n_rows - 1:
                ax.set_xlabel("Time step")
    
    for i in range(n_samples, n_rows * n_cols):
        row, col = i // n_cols, i % n_cols
        axes[row, col].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nSample accuracy: {correct_count}/{n_samples} ({correct_count/n_samples:.1%})")


visualize_ff_soen_predictions(model, val_dataset, config.num_classes, n_samples=8)

## Step 10: Analyze Goodness Distributions

Let's see how well the hidden layer separates positive from negative samples by goodness, and how the SOEN neuron states differ between the two.
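
As a reference point for the analysis, here is a minimal sketch of how a `mean(s²)` goodness and the resulting separation can be computed from state trajectories of shape `(N, T, D)`. It uses synthetic states rather than real SOEN outputs, and `mean_square_goodness` is a hypothetical helper. The tutorial's `compute_goodness` also takes a `goodness_type` argument, and its exact reduction may differ.

```python
import numpy as np

def mean_square_goodness(states):
    """Goodness per sample: mean of s^2 over time and neurons.

    states: array of shape (N, T, D) -> returns shape (N,)
    """
    return np.mean(states ** 2, axis=(1, 2))

# Synthetic stand-ins for hidden-state trajectories (assumption: not real SOEN data).
rng = np.random.default_rng(0)
pos_states = rng.normal(0.0, 1.0, size=(100, 64, 5))  # larger activity
neg_states = rng.normal(0.0, 0.3, size=(100, 64, 5))  # smaller activity

g_pos = mean_square_goodness(pos_states)
g_neg = mean_square_goodness(neg_states)
separation = g_pos.mean() - g_neg.mean()

print(f"Positive goodness: {g_pos.mean():.4f}")
print(f"Negative goodness: {g_neg.mean():.4f}")
print(f"Separation:        {separation:+.4f}")
```

A positive separation means positive samples produce more hidden activity than negative ones, which is what the threshold line in the histogram is meant to split.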

In [None]:
# ==============================================================================
# STEP 10: ANALYZE GOODNESS AND SOEN STATES
# ==============================================================================
# Analyze how the SOEN neuron states differ for positive vs negative data.
# ==============================================================================

def analyze_soen_states(model, data_loader, config, device=DEVICE):
    """Analyze SOEN neuron states and goodness distributions."""
    
    model.eval()
    
    all_pos_goodness = []
    all_neg_goodness = []
    all_pos_states = []
    all_neg_states = []
    
    with torch.no_grad():
        for batch in data_loader:
            pos_data = batch["positive"].to(device)
            neg_data = batch["negative"].to(device)
            
            # Positive pass
            _, states_pos = forward_and_get_states(model, pos_data)
            if len(states_pos) > 1:
                hidden_pos = states_pos[1]
            else:
                hidden_pos = states_pos[0]
            
            g_pos = compute_goodness(hidden_pos, config.goodness_type)
            all_pos_goodness.extend(g_pos.cpu().numpy())
            all_pos_states.append(hidden_pos.cpu())
            
            # Negative pass
            _, states_neg = forward_and_get_states(model, neg_data)
            if len(states_neg) > 1:
                hidden_neg = states_neg[1]
            else:
                hidden_neg = states_neg[0]
            
            g_neg = compute_goodness(hidden_neg, config.goodness_type)
            all_neg_goodness.extend(g_neg.cpu().numpy())
            all_neg_states.append(hidden_neg.cpu())
    
    all_pos_goodness = np.array(all_pos_goodness)
    all_neg_goodness = np.array(all_neg_goodness)
    all_pos_states = torch.cat(all_pos_states, dim=0)  # (N, T, D)
    all_neg_states = torch.cat(all_neg_states, dim=0)
    
    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle('SOEN State Analysis', fontsize=14, fontweight='bold')
    
    # 1. Goodness distributions
    axes[0].hist(all_pos_goodness, bins=50, alpha=0.6, color='green', 
                 label=f'Positive (μ={all_pos_goodness.mean():.3f})', density=True)
    axes[0].hist(all_neg_goodness, bins=50, alpha=0.6, color='red', 
                 label=f'Negative (μ={all_neg_goodness.mean():.3f})', density=True)
    axes[0].axvline(x=config.threshold, color='black', linestyle='--', 
                    linewidth=2, label=f'Threshold ({config.threshold})')
    axes[0].set_xlabel('Goodness')
    axes[0].set_ylabel('Density')
    axes[0].set_title('Goodness Distributions')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # 2. Mean neuron activity over time
    mean_pos = all_pos_states.mean(dim=(0, 2)).numpy()  # (T,)
    mean_neg = all_neg_states.mean(dim=(0, 2)).numpy()
    axes[1].plot(mean_pos, 'g-', linewidth=2, label='Positive')
    axes[1].plot(mean_neg, 'r-', linewidth=2, label='Negative')
    axes[1].set_xlabel('Timestep')
    axes[1].set_ylabel('Mean Neuron Activity')
    axes[1].set_title('SOEN State Dynamics')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # 3. Per-neuron activity comparison
    pos_per_neuron = all_pos_states.mean(dim=(0, 1)).numpy()  # (D,)
    neg_per_neuron = all_neg_states.mean(dim=(0, 1)).numpy()
    x = np.arange(len(pos_per_neuron))
    width = 0.35
    axes[2].bar(x - width/2, pos_per_neuron, width, color='green', alpha=0.7, label='Positive')
    axes[2].bar(x + width/2, neg_per_neuron, width, color='red', alpha=0.7, label='Negative')
    axes[2].set_xlabel('Neuron Index')
    axes[2].set_ylabel('Mean Activity')
    axes[2].set_title('Per-Neuron Activity')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    separation = all_pos_goodness.mean() - all_neg_goodness.mean()
    print("\nGoodness Statistics:")
    print(f"  Positive: mean={all_pos_goodness.mean():.4f}, std={all_pos_goodness.std():.4f}")
    print(f"  Negative: mean={all_neg_goodness.mean():.4f}, std={all_neg_goodness.std():.4f}")
    print(f"  Separation: {separation:+.4f}")


analyze_soen_states(model, val_loader, config)

## Summary

### What We Achieved

This tutorial demonstrates **true FF-SOEN integration**:

1. **Real SOEN dynamics**: Uses actual SingleDendrite layers with superconductor physics
2. **Temporal processing**: 64-timestep sequences (not flattened like standard FF)
3. **Label embedding**: One-hot labels at each timestep `[signal, label_0, label_1]`
4. **Goodness from SOEN states**: `mean(s²)` computed from neuron state trajectories
5. **J matrix updates**: Connection weights trained with FF local learning
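
The label embedding in item 3 is what turns one labeled signal into a positive and a negative training example. The sketch below shows one way to build such pairs with NumPy. `make_ff_pair` is a hypothetical helper, and the choice of wrong label (a random non-zero shift of the true label) is an assumption; the tutorial's dataset class may construct negatives differently.

```python
import numpy as np

def make_ff_pair(signals, labels, num_classes, rng):
    """Build positive/negative FF inputs from labeled signals.

    signals: (N, T, 1) float array, labels: (N,) int array
    returns: positive, negative, each of shape (N, T, 1 + num_classes)
    """
    n, t, _ = signals.shape
    # Wrong label: shift the true label by a random non-zero offset
    offsets = rng.integers(1, num_classes, size=n)
    wrong = (labels + offsets) % num_classes

    eye = np.eye(num_classes)
    # Repeat the one-hot vector at every timestep: (N, C) -> (N, T, C)
    pos_labels = np.repeat(eye[labels][:, None, :], t, axis=1)
    neg_labels = np.repeat(eye[wrong][:, None, :], t, axis=1)

    positive = np.concatenate([signals, pos_labels], axis=2)
    negative = np.concatenate([signals, neg_labels], axis=2)
    return positive, negative

rng = np.random.default_rng(0)
signals = rng.random((4, 64, 1))
labels = np.array([0, 1, 1, 0])

positive, negative = make_ff_pair(signals, labels, num_classes=2, rng=rng)
print(positive.shape, negative.shape)
print("positive label channels at t=0:", positive[:, 0, 1:].tolist())
print("negative label channels at t=0:", negative[:, 0, 1:].tolist())
```

Both tensors carry the same signal channel; only the label channels differ, so any difference in goodness must come from how the network responds to the signal-label combination.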

### Architecture Comparison

| Component | SOEN Backprop | FF-SOEN (This Tutorial) |
|-----------|---------------|-------------------------|
| **Layers** | SingleDendrite | SingleDendrite (same!) |
| **Input dim** | 1 | 3 (signal + label) |
| **Processing** | Temporal (64 steps) | Temporal (64 steps) |
| **Recurrence** | Yes (J_1_to_1) | Yes (J_1_to_1) |
| **Learning** | Backprop + CE loss | FF + goodness loss |
| **J_0_to_1** | 5 params | 15 params |
| **J_1_to_1** | 20 params | 20 params |
| **J_1_to_2** | ~2 params (one-to-one) | 10 params (all-to-all) |

### Key Insight

Forward-Forward learning is a **learning algorithm**, not an architecture. By integrating FF with SOEN:
- We preserve the **temporal dynamics** and **recurrence** of SOEN
- We replace **backpropagation** with **local goodness-based learning**
- The model remains **hardware-compatible** (no weight transport needed)
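
To make "local goodness-based learning" concrete, the sketch below evaluates a logistic goodness loss of the kind used in the original Forward-Forward formulation: it penalizes positive goodness below the threshold and negative goodness above it. This is an illustration under that assumption; `ff_loss` and `softplus` are hypothetical helpers, and the loss defined earlier in this tutorial may differ in form or scaling.

```python
import numpy as np

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return np.logaddexp(0.0, x)

def ff_loss(g_pos, g_neg, threshold):
    """Logistic FF loss: push g_pos above the threshold and g_neg below it."""
    pos_term = np.mean(softplus(threshold - g_pos))  # large when g_pos < threshold
    neg_term = np.mean(softplus(g_neg - threshold))  # large when g_neg > threshold
    return float(pos_term + neg_term)

# Well-separated goodness values give a lower loss than overlapping ones
well_separated = ff_loss(np.array([2.0, 2.5]), np.array([0.1, 0.2]), threshold=1.0)
overlapping = ff_loss(np.array([0.9, 1.0]), np.array([1.0, 1.1]), threshold=1.0)

print(f"Loss (well separated): {well_separated:.4f}")
print(f"Loss (overlapping):    {overlapping:.4f}")
```

Because the loss depends only on one layer's goodness, each layer can be updated from its own activity without an error signal from the layers above it.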

### Parameter Comparison

| Connection | Original (1D input) | FF-SOEN (3D input) |
|------------|---------------------|---------------------|
| J_0_to_1 | 1×5 = 5 | 3×5 = **15** |
| J_1_to_1 | 5×5-5 = 20 | 5×5-5 = **20** |
| J_1_to_2 | one-to-one ≈ 2 | all-to-all = **10** |
| **Total** | ~27 | ~**45** |

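The extra parameters come from:
- The expanded 3D input (signal + two label channels), which triples `J_0_to_1` from 5 to 15
- The full (all-to-all) output connection `J_1_to_2` instead of one-to-one (needed for proper gradient flow)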
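
The totals in the table can be reproduced with a few lines of arithmetic. `count_j_params` is a hypothetical helper, and it assumes that a one-to-one output connection contributes one weight per output neuron, which matches the "≈ 2" entry above.

```python
def count_j_params(n_in, n_hidden, n_out, all_to_all_output):
    """Count connection weights for an input -> hidden -> output SOEN model."""
    j_0_to_1 = n_in * n_hidden                  # dense input-to-hidden
    j_1_to_1 = n_hidden * n_hidden - n_hidden   # recurrent, no self-connections
    # Assumption: one-to-one output wiring has one weight per output neuron
    j_1_to_2 = n_hidden * n_out if all_to_all_output else n_out
    return {
        "J_0_to_1": j_0_to_1,
        "J_1_to_1": j_1_to_1,
        "J_1_to_2": j_1_to_2,
        "total": j_0_to_1 + j_1_to_1 + j_1_to_2,
    }

backprop = count_j_params(n_in=1, n_hidden=5, n_out=2, all_to_all_output=False)
ff_soen = count_j_params(n_in=3, n_hidden=5, n_out=2, all_to_all_output=True)

print("SOEN backprop:", backprop)  # total = 5 + 20 + 2 = 27
print("FF-SOEN:      ", ff_soen)   # total = 15 + 20 + 10 = 45
```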

### Next Steps

- Compare accuracy with backprop version (Tutorial 02)
- Try different goodness functions (`mean_abs` for hardware)
- Experiment with threshold values
- Apply to MNIST (Tutorial 04)