# Tutorial 03 — MNIST with Parallel Input (One Row per Neuron)

**Different approach**: Instead of sequential row-by-row input, each neuron receives a different row simultaneously.

## Architecture

```
Image (28×28)
     ↓
28 Hidden Neurons (each gets one row = 28 pixels)
     ↓ ↺ recurrent
10 Output Neurons
```

## Timeline

```
t=1  t=2  t=3  t=4  t=5  │  t=6  t=7  t=8  t=9  t=10
─────────────────────────┼─────────────────────────────
    INPUT PHASE          │       SETTLE PHASE
  (same input ×5)        │      (no input)
                         │   ↑
                         │   Take output here (t=6)
```

## Hyperparameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `n_input_steps` | 5 | Number of timesteps the same input is presented |
| `n_settle_steps` | 5 | Number of timesteps run without input |
| `output_step` | 6 | Timestep (1-indexed) whose output is used as the prediction; 6 = first settle step |

In [None]:
import os
import sys
from pathlib import Path

notebook_dir = Path.cwd()
# Walk upward from the working directory and put the first ``src`` folder that
# contains a ``soen_toolkit`` package at the front of ``sys.path``.
# Note: no cell shown here imports soen_toolkit directly, so nothing breaks
# when the package is not found.
_src_candidates = (root / "src" for root in (notebook_dir, *notebook_dir.parents))
_toolkit_src = next(
    (src for src in _src_candidates if (src / "soen_toolkit").exists()),
    None,
)
if _toolkit_src is not None:
    sys.path.insert(0, str(_toolkit_src))

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import gzip
import urllib.request
import struct
from tqdm import tqdm

# Allow faster, reduced-precision internal float32 matmuls (see torch docs).
torch.set_float32_matmul_precision('high')
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')

print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")

## 1. Hyperparameters

In [None]:
# ============================================================
# KEY HYPERPARAMETERS
# ============================================================

N_INPUT_STEPS = 5      # Present same input for this many timesteps
N_SETTLE_STEPS = 5     # Run without input for this many timesteps
OUTPUT_STEP = N_INPUT_STEPS + 1   # Which timestep to read output (1-indexed);
                                  # N_INPUT_STEPS + 1 = first timestep after input phase ends

# Fail fast here, rather than deep inside the model, if the read-out step is
# outside the simulated window 1 .. N_INPUT_STEPS + N_SETTLE_STEPS.
if not 1 <= OUTPUT_STEP <= N_INPUT_STEPS + N_SETTLE_STEPS:
    raise ValueError(
        f"OUTPUT_STEP={OUTPUT_STEP} must be in [1, {N_INPUT_STEPS + N_SETTLE_STEPS}]"
    )

# Network dimensions
HIDDEN_DIM = 28        # One neuron per row
INPUT_DIM = 28         # Each neuron gets 28 pixels (one row)
OUTPUT_DIM = 10        # 10 digit classes

# SOEN dynamics
DT = 0.1
GAMMA_PLUS = 0.1       # Integration rate
GAMMA_MINUS = 0.01     # Leak rate

# Training
BATCH_SIZE = 128
EPOCHS = 30
LR = 0.01

print(f"Timeline: {N_INPUT_STEPS} input steps + {N_SETTLE_STEPS} settle steps = {N_INPUT_STEPS + N_SETTLE_STEPS} total")
print(f"Output taken at step {OUTPUT_STEP}")
print(f"Architecture: {INPUT_DIM} (per neuron) → {HIDDEN_DIM} hidden → {OUTPUT_DIM} output")

## 2. Load MNIST

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/",
                        data_dir="./data/mnist"):
    """Return the local path of an MNIST archive, downloading it on first use.

    Args:
        filename: Archive name, e.g. ``"train-images-idx3-ubyte.gz"``.
        base_url: URL prefix that ``filename`` is appended to.
        data_dir: Local cache directory (created if missing).

    Returns:
        ``pathlib.Path`` of the cached file.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if not filepath.exists():
        print(f"Downloading {filename}...")
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that later runs
        # would mistake for a complete cache entry.
        tmp_path = filepath.with_name(filepath.name + ".part")
        try:
            urllib.request.urlretrieve(base_url + filename, tmp_path)
            tmp_path.replace(filepath)
        finally:
            tmp_path.unlink(missing_ok=True)
    return filepath

def read_mnist_images(filepath):
    """Read a gzipped IDX3 image file into a ``(num, rows, cols)`` uint8 array.

    Args:
        filepath: Path to a ``*-images-idx3-ubyte.gz`` file.

    Returns:
        NumPy uint8 array of shape ``(num, rows, cols)`` taken from the header.

    Raises:
        ValueError: if the header magic number is not 2051 (IDX, uint8, 3 dims),
            e.g. when a label file or a corrupt download is passed.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(f"{filepath}: bad IDX image magic number {magic} (expected 2051)")
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(num, rows, cols)

def read_mnist_labels(filepath):
    """Read a gzipped IDX1 label file into a 1-D uint8 array.

    Args:
        filepath: Path to a ``*-labels-idx1-ubyte.gz`` file.

    Returns:
        NumPy uint8 array with one label per example.

    Raises:
        ValueError: if the magic number is not 2049 (IDX, uint8, 1 dim) or the
            number of labels read differs from the count in the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(f"{filepath}: bad IDX label magic number {magic} (expected 2049)")
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    # Catch truncated/oversized payloads, which reshape would not detect for 1-D data.
    if labels.size != num:
        raise ValueError(f"{filepath}: header says {num} labels, found {labels.size}")
    return labels

def load_mnist(n_val=6000, seed=42):
    """Load MNIST as (N, 28, 28) arrays - each row later goes to a different neuron.

    Images are scaled to float32 in [0, 1]; labels are int64. ``n_val`` images
    from the official training set, chosen by a permutation seeded with
    ``seed``, are held out as a validation split.

    Args:
        n_val: Number of training images moved to the validation split.
        seed: Seed for the split permutation.

    Returns:
        ``((train_img, train_lbl), (val_img, val_lbl), (test_img, test_lbl))``.

    Raises:
        ValueError: if ``n_val`` does not leave both splits non-empty.
    """
    train_img = read_mnist_images(download_mnist_file("train-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    train_lbl = read_mnist_labels(download_mnist_file("train-labels-idx1-ubyte.gz")).astype(np.int64)
    test_img = read_mnist_images(download_mnist_file("t10k-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    test_lbl = read_mnist_labels(download_mnist_file("t10k-labels-idx1-ubyte.gz")).astype(np.int64)

    if not 0 < n_val < len(train_img):
        raise ValueError(f"n_val must be in (0, {len(train_img)}), got {n_val}")

    # Train/val split. A private RandomState seeded with `seed` yields the same
    # permutation as `np.random.seed(seed); np.random.permutation(...)`, without
    # reseeding NumPy's global RNG as a hidden side effect.
    idx = np.random.RandomState(seed).permutation(len(train_img))
    val_idx, train_idx = idx[:n_val], idx[n_val:]

    val_img, val_lbl = train_img[val_idx], train_lbl[val_idx]
    train_img, train_lbl = train_img[train_idx], train_lbl[train_idx]

    print(f"Train: {train_img.shape}, Val: {val_img.shape}, Test: {test_img.shape}")
    return (train_img, train_lbl), (val_img, val_lbl), (test_img, test_lbl)

(train_data, train_labels), (val_data, val_labels), (test_data, test_labels) = load_mnist()

## 3. Visualize the Parallel Input Concept

In [None]:
def visualize_parallel_input(image, label):
    """Show how each hidden neuron receives a different image row.

    Draws three panels:
      - the original image;
      - the row -> neuron assignment;
      - a timeline of the input phase, settle phase and read-out step.
    The timeline uses the module-level ``N_INPUT_STEPS``, ``N_SETTLE_STEPS``
    and ``OUTPUT_STEP``. Afterwards a short summary is printed.

    Args:
        image: 2-D array of shape (rows, pixels); row i feeds neuron i.
        label: Class label shown in the first panel's title.
    """
    n_rows, n_cols = np.shape(image)
    fig, axes = plt.subplots(1, 3, figsize=(14, 5))

    # Original image
    axes[0].imshow(image, cmap='gray')
    axes[0].set_title(f'Original Image (Label: {label})', fontsize=12)
    axes[0].set_xlabel(f'{n_cols} columns')
    axes[0].set_ylabel(f'{n_rows} rows')

    # Row assignment: every pixel is coloured by the index of its row (= neuron)
    row_colors = np.arange(n_rows).reshape(n_rows, 1) * np.ones((1, n_cols))
    axes[1].imshow(row_colors, cmap='tab20', aspect='auto')
    for i in range(0, n_rows, 4):
        axes[1].axhline(y=i-0.5, color='white', linewidth=0.5)
        # Place the label just right of the last pixel column
        axes[1].text(n_cols + 1, i, f'Neuron {i}', va='center', fontsize=8)
    axes[1].set_title('Row → Neuron Assignment', fontsize=12)
    axes[1].set_xlabel(f'{n_cols} pixels (input to that neuron)')
    axes[1].set_ylabel('Neuron index')

    # Timeline: step t (1-indexed) occupies the interval [t-1, t] on the x-axis
    ax = axes[2]
    total_steps = N_INPUT_STEPS + N_SETTLE_STEPS

    # Input phase
    ax.barh(0, N_INPUT_STEPS, left=0, height=0.5, color='blue', alpha=0.7, label='Input phase')
    ax.text(N_INPUT_STEPS/2, 0, f'Same input ×{N_INPUT_STEPS}', ha='center', va='center', color='white', fontweight='bold')

    # Settle phase
    ax.barh(0, N_SETTLE_STEPS, left=N_INPUT_STEPS, height=0.5, color='orange', alpha=0.7, label='Settle phase')
    ax.text(N_INPUT_STEPS + N_SETTLE_STEPS/2, 0, f'No input ×{N_SETTLE_STEPS}', ha='center', va='center', color='white', fontweight='bold')

    # Output marker at the centre of the read-out step's interval
    ax.axvline(x=OUTPUT_STEP - 0.5, color='red', linewidth=2, linestyle='--', label=f'Output (step {OUTPUT_STEP})')
    ax.scatter([OUTPUT_STEP - 0.5], [0], color='red', s=100, zorder=5, marker='v')

    ax.set_xlim(-0.5, total_steps + 0.5)
    ax.set_ylim(-1, 1)
    ax.set_xlabel('Timestep')
    ax.set_title('Timeline', fontsize=12)
    ax.legend(loc='upper right')
    # One tick per step, at the centre of its interval, so the labels line up with
    # the bars and the output marker (steps run 1 .. total_steps only).
    step_numbers = range(1, total_steps + 1)
    ax.set_xticks([t - 0.5 for t in step_numbers])
    ax.set_xticklabels([str(t) for t in step_numbers])
    ax.set_yticks([])

    plt.tight_layout()
    plt.show()

    if OUTPUT_STEP <= N_INPUT_STEPS:
        output_phase = 'input phase'
    elif OUTPUT_STEP == N_INPUT_STEPS + 1:
        output_phase = 'first settle step'
    else:
        output_phase = 'settle phase'

    print("\nKey insight:")
    print(f"  • Each of {n_rows} neurons receives a different row ({n_cols} pixels)")
    print(f"  • Same input presented {N_INPUT_STEPS} times for integration")
    print(f"  • Network settles for {N_SETTLE_STEPS} steps without input")
    print(f"  • Output read at step {OUTPUT_STEP} ({output_phase})")

visualize_parallel_input(train_data[0], train_labels[0])

## 4. Parallel SOEN Model

In [None]:
class ParallelSOEN(nn.Module):
    """
    SOEN model with parallel input: each neuron receives a different row.

    Architecture:
        - ``hidden_dim`` hidden neurons (28 by default, one per image row)
        - Each neuron i receives row i of the image (``input_dim`` pixels)
        - Recurrent connections between hidden neurons (self-connections zeroed)
        - Linear read-out ``hidden_dim`` → ``output_dim``

    The input connection is row-local: neuron i is driven only by row i,
    through its own weight vector ``W_i2h[i, :]``. This is the key difference
    from the standard all-to-all input connection.

    Timeline: ``n_input_steps`` updates with the image presented, then
    ``n_settle_steps`` updates without input. ``forward`` returns the logits
    computed after update number ``output_step`` (1-indexed).

    Raises:
        ValueError: if ``output_step`` is outside ``1 .. n_input_steps + n_settle_steps``.
    """

    def __init__(self, hidden_dim=28, input_dim=28, output_dim=10,
                 n_input_steps=5, n_settle_steps=5, output_step=6,
                 dt=0.1, gamma_plus=0.1, gamma_minus=0.01):
        super().__init__()
        self._validate_output_step(output_step, n_input_steps + n_settle_steps)

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.n_input_steps = n_input_steps
        self.n_settle_steps = n_settle_steps
        self.output_step = output_step
        self.dt = dt
        self.gamma_plus = gamma_plus
        self.gamma_minus = gamma_minus

        # Input weights: W_i2h[i, j] = weight from pixel j of row i to neuron i.
        # Stored as one (hidden_dim, input_dim) matrix but used row-wise in
        # ``step``, i.e. hidden_dim independent weight vectors.
        self.W_i2h = nn.Parameter(torch.empty(hidden_dim, input_dim))

        # Recurrent weights: all-to-all between hidden neurons
        self.W_h2h = nn.Parameter(torch.empty(hidden_dim, hidden_dim))

        # Output weights: hidden_dim → output_dim
        self.W_h2o = nn.Parameter(torch.empty(output_dim, hidden_dim))

        # Biases
        self.bias_h = nn.Parameter(torch.zeros(hidden_dim))
        self.bias_o = nn.Parameter(torch.zeros(output_dim))

        self._init_weights()

    @staticmethod
    def _validate_output_step(output_step, total_steps):
        """Raise ValueError unless 1 <= output_step <= total_steps."""
        # output_step is used as a 1-indexed list position in forward(); 0 or a
        # negative value would otherwise silently wrap around to a late step.
        if not 1 <= output_step <= total_steps:
            raise ValueError(
                f"output_step must be in [1, {total_steps}] "
                f"(n_input_steps + n_settle_steps), got {output_step}"
            )

    def _init_weights(self):
        """Initialise weights and zero the recurrent diagonal."""
        nn.init.uniform_(self.W_i2h, -0.2, 0.2)
        nn.init.normal_(self.W_h2h, 0, 0.1)
        nn.init.normal_(self.W_h2o, 0, 0.2)
        # No self-connections in recurrent
        with torch.no_grad():
            self.W_h2h.fill_diagonal_(0)

    def source_function(self, phi):
        """Smooth approximation of Heaviside source function (sigmoid with slope 5)."""
        return torch.sigmoid(5 * phi)

    def step(self, s, x=None):
        """
        Advance the hidden state by one explicit-Euler step.

        Args:
            s: Hidden state (batch, hidden_dim)
            x: Input image (batch, hidden_dim, input_dim), or None during the
               settle phase (no external input)

        Returns:
            s_new: Updated hidden state, same shape as ``s``
        """
        if x is None:
            input_contrib = 0
        else:
            # Row-local input: input_contrib[b, i] = sum_j W_i2h[i, j] * x[b, i, j]
            # (elementwise product with the broadcast weights, summed over pixels)
            input_contrib = (x * self.W_i2h.unsqueeze(0)).sum(dim=2)  # (batch, hidden_dim)

        # Recurrent contribution
        recurrent_contrib = F.linear(s, self.W_h2h)  # (batch, hidden_dim)

        # Total flux
        phi = input_contrib + recurrent_contrib + self.bias_h

        # SingleDendrite dynamics: ds/dt = γ⁺g(φ) - γ⁻s, integrated with forward Euler
        g = self.source_function(phi)
        dsdt = self.gamma_plus * g - self.gamma_minus * s
        return s + self.dt * dsdt

    def forward(self, x):
        """
        Forward pass with parallel input.

        Args:
            x: Images (batch, hidden_dim, input_dim)

        Returns:
            output: Class logits (batch, output_dim) after step ``output_step``
            states: Dictionary with
                'all_states'  - list of n_steps + 1 states (index 0 = initial zeros),
                'all_outputs' - list of n_steps logits (index t-1 = after step t),
                'final_state' - state after the last step

        Raises:
            ValueError: if ``self.output_step`` was set to an out-of-range value.
        """
        self._validate_output_step(self.output_step, self.n_input_steps + self.n_settle_steps)
        batch_size = x.shape[0]

        # Zero initial state on the input's device and in the parameters' dtype,
        # so the model also runs after e.g. ``.double()``.
        s = torch.zeros(batch_size, self.hidden_dim, device=x.device, dtype=self.W_h2h.dtype)

        # Store states for analysis
        all_states = [s.clone()]
        all_outputs = []

        # INPUT PHASE: Present same input for n_input_steps
        for _ in range(self.n_input_steps):
            s = self.step(s, x)
            all_states.append(s.clone())
            all_outputs.append(F.linear(s, self.W_h2o, self.bias_o))

        # SETTLE PHASE: Run without input for n_settle_steps
        for _ in range(self.n_settle_steps):
            s = self.step(s, x=None)
            all_states.append(s.clone())
            all_outputs.append(F.linear(s, self.W_h2o, self.bias_o))

        # Get output at specified step (1-indexed; validated above)
        output = all_outputs[self.output_step - 1]

        return output, {
            'all_states': all_states,
            'all_outputs': all_outputs,
            'final_state': s
        }

# Create model from the hyperparameter cell and move it to the selected device
model_kwargs = dict(
    hidden_dim=HIDDEN_DIM,
    input_dim=INPUT_DIM,
    output_dim=OUTPUT_DIM,
    n_input_steps=N_INPUT_STEPS,
    n_settle_steps=N_SETTLE_STEPS,
    output_step=OUTPUT_STEP,
    dt=DT,
    gamma_plus=GAMMA_PLUS,
    gamma_minus=GAMMA_MINUS,
)
model = ParallelSOEN(**model_kwargs).to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model: {HIDDEN_DIM} hidden neurons, each receiving {INPUT_DIM} pixels")
print(f"Parameters: {n_params}")
print(f"W_i2h: {model.W_i2h.shape} (each row is one neuron's input weights)")
print(f"W_h2h: {model.W_h2h.shape} (recurrent)")
print(f"W_h2o: {model.W_h2o.shape} (output)")

## 5. Training

In [None]:
def train_model(model, train_data, train_labels, val_data, val_labels,
                epochs=30, batch_size=128, lr=0.01, max_grad_norm=1.0):
    """
    Train the parallel SOEN model with Adam and cosine learning-rate decay.

    Args:
        model: ParallelSOEN instance, already moved to its target device.
        train_data, train_labels: training images (N, 28, 28) and integer labels (N,).
        val_data, val_labels: validation images and labels in the same format.
        epochs: number of passes over the training set.
        batch_size: training batch size (validation uses 256).
        lr: initial Adam learning rate.
        max_grad_norm: global gradient-norm clipping threshold.

    Returns:
        history: dict of per-epoch lists 'train_loss', 'train_acc', 'val_loss',
        'val_acc'. On return the model holds the weights of the epoch with the
        highest validation accuracy.
    """
    # Train on whichever device the model's parameters already live on.
    device = next(model.parameters()).device

    # Create dataloaders
    train_dataset = TensorDataset(
        torch.tensor(train_data, dtype=torch.float32),
        torch.tensor(train_labels, dtype=torch.long)
    )
    val_dataset = TensorDataset(
        torch.tensor(val_data, dtype=torch.float32),
        torch.tensor(val_labels, dtype=torch.long)
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Training history
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc = 0
    best_state = None

    print("="*60)
    print("PARALLEL SOEN TRAINING")
    print("="*60)
    print("Each neuron receives a different row of the image")
    print(f"Input steps: {model.n_input_steps}, Settle steps: {model.n_settle_steps}")
    print(f"Output at step: {model.output_step}")
    print("="*60)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        epoch_correct = 0
        epoch_total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for x, labels in pbar:
            x, labels = x.to(device), labels.to(device)

            optimizer.zero_grad()
            output, _ = model(x)
            loss = F.cross_entropy(output, labels)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()

            # Enforce no self-connections (the optimizer step can make them non-zero)
            with torch.no_grad():
                model.W_h2h.fill_diagonal_(0)

            # Metrics
            pred = output.argmax(dim=1)
            epoch_correct += (pred == labels).sum().item()
            epoch_total += len(labels)
            epoch_loss += loss.item() * len(labels)

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{epoch_correct/epoch_total:.3f}'})

        scheduler.step()

        # Epoch metrics
        train_loss = epoch_loss / epoch_total
        train_acc = epoch_correct / epoch_total

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for x, labels in val_loader:
                x, labels = x.to(device), labels.to(device)
                output, _ = model(x)
                loss = F.cross_entropy(output, labels)

                val_loss += loss.item() * len(labels)
                val_correct += (output.argmax(dim=1) == labels).sum().item()
                val_total += len(labels)

        val_loss /= val_total
        val_acc = val_correct / val_total

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Track best. The first epoch is always snapshotted so a best state
        # exists even if validation accuracy never rises above 0.
        improved = best_state is None or val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc
            best_state = {k: v.clone() for k, v in model.state_dict().items()}

        # '*' marks only epochs that set a new best (not ties with an earlier best).
        print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, train_acc={train_acc:.3f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f} {'*' if improved else ''}")

    # Restore best
    if best_state is not None:
        model.load_state_dict(best_state)
    print(f"\nBest validation accuracy: {best_val_acc:.4f}")

    return history

history = train_model(model, train_data, train_labels, val_data, val_labels,
                      epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR)

## 6. Visualize Training

In [None]:
def plot_training(history):
    """Plot per-epoch train/val loss (left) and accuracy (right) side by side.

    Args:
        history: dict with lists 'train_loss', 'val_loss', 'train_acc', 'val_acc'.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (loss_ax, 'train_loss', 'val_loss', 'Loss', 'Loss (Parallel SOEN)'),
        (acc_ax, 'train_acc', 'val_acc', 'Accuracy', 'Accuracy (Parallel SOEN)'),
    )
    for ax, train_key, val_key, y_label, title in panels:
        ax.plot(history[train_key], label='Train')
        ax.plot(history[val_key], label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

plot_training(history)

## 7. Evaluate

In [None]:
@torch.no_grad()
def evaluate(model, test_data, test_labels):
    """Return the test-set accuracy of ``model`` and print it as a banner.

    The data is run through the model in batches of 256 on the global
    ``device``. The argmax of each logit row is compared with ``test_labels``.
    """
    model.eval()

    images = torch.tensor(test_data, dtype=torch.float32)
    targets = torch.tensor(test_labels, dtype=torch.long)
    loader = DataLoader(TensorDataset(images, targets), batch_size=256, shuffle=False)

    pred_batches, label_batches = [], []
    for batch_x, batch_y in tqdm(loader, desc="Testing"):
        logits, _ = model(batch_x.to(device))
        pred_batches.append(logits.argmax(dim=1).cpu())
        label_batches.append(batch_y)

    hits = torch.cat(pred_batches) == torch.cat(label_batches)
    accuracy = hits.float().mean().item()

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"TEST ACCURACY (Parallel SOEN): {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(rule)

    return accuracy

test_acc = evaluate(model, test_data, test_labels)

## 8. Visualize Dynamics

In [None]:
def visualize_dynamics(model, image, label):
    """Plot one sample's hidden-state trajectory, logit trajectory and prediction.

    Runs ``model`` on a single image (batch of 1, on the global ``device``) and
    draws a 2x2 figure:
      - the input image;
      - the hidden state of every neuron over time;
      - the output logits over time;
      - the class probabilities at the model's read-out step.
    """
    model.eval()
    batch = torch.tensor(image, dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        logits, trace = model(batch)

    # (n_steps + 1, 28): row 0 is the initial state, row t the state after step t
    state_traj = torch.stack(trace['all_states']).squeeze().cpu().numpy()
    # (n_steps, 10): row t-1 holds the logits after step t
    logit_traj = torch.stack(trace['all_outputs']).squeeze().cpu().numpy()

    fig, ((ax_img, ax_state), (ax_logit, ax_prob)) = plt.subplots(2, 2, figsize=(14, 10))

    # Original image
    ax_img.imshow(image, cmap='gray')
    ax_img.set_title(f'Input Image (Label: {label})')
    ax_img.set_xlabel('Pixel')
    ax_img.set_ylabel('Row (= Neuron)')

    # Hidden state evolution: column t = state after t steps, so the input phase
    # ends between column n_input_steps and n_input_steps + 1.
    heat = ax_state.imshow(state_traj.T, aspect='auto', cmap='viridis')
    ax_state.axvline(x=model.n_input_steps + 0.5, color='red', linestyle='--', label='Input ends')
    ax_state.axvline(x=model.output_step, color='green', linestyle='--', label=f'Output step {model.output_step}')
    ax_state.set_xlabel('Timestep')
    ax_state.set_ylabel('Neuron')
    ax_state.set_title('Hidden State Evolution')
    ax_state.legend()
    plt.colorbar(heat, ax=ax_state)

    # Output evolution: logit rows are offset by one step relative to the states
    for cls in range(10):
        ax_logit.plot(logit_traj[:, cls], label=f'{cls}', alpha=0.7)
    ax_logit.axvline(x=model.n_input_steps - 0.5, color='red', linestyle='--')
    ax_logit.axvline(x=model.output_step - 1, color='green', linestyle='--')
    ax_logit.set_xlabel('Timestep')
    ax_logit.set_ylabel('Logit')
    ax_logit.set_title('Output Logits Over Time')
    ax_logit.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Final prediction: true class green, a wrong prediction red, others blue
    pred = logits.argmax(dim=1).item()
    probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()
    bar_colors = ['green' if cls == label else 'blue' for cls in range(10)]
    bar_colors[pred] = 'green' if pred == label else 'red'
    ax_prob.bar(range(10), probs, color=bar_colors)
    ax_prob.set_xlabel('Class')
    ax_prob.set_ylabel('Probability')
    verdict = "✓" if pred == label else "✗"
    ax_prob.set_title(f'Prediction: {pred} (True: {label}) {verdict}')
    ax_prob.set_xticks(range(10))

    plt.tight_layout()
    plt.show()

# Visualize a few samples
for sample_idx in range(3):
    visualize_dynamics(model, test_data[sample_idx], test_labels[sample_idx])

## 9. Hyperparameter Exploration

In [None]:
def explore_output_steps(model, test_data, test_labels, max_samples=1000):
    """
    Compare accuracy when the logits are read out at each timestep.

    Runs one forward pass over the first ``max_samples`` test images and scores
    the logits recorded after every step. Accuracies are printed and shown as a
    bar chart (blue = input phase, orange = settle phase).

    Returns:
        List of ``(step, accuracy, phase)`` tuples, one per timestep, where
        ``step`` is 1-indexed and ``phase`` is ``'input'`` or ``'settle'``.
    """
    model.eval()
    # Evaluate on whichever device the model's parameters live on.
    model_device = next(model.parameters()).device

    # Sample subset for speed
    x = torch.tensor(test_data[:max_samples], dtype=torch.float32).to(model_device)
    labels = torch.tensor(test_labels[:max_samples], dtype=torch.long)

    with torch.no_grad():
        _, states = model(x)

    all_outputs = states['all_outputs']  # List of (batch, 10); index t-1 = after step t

    results = []
    for step, output in enumerate(all_outputs, 1):
        pred = output.argmax(dim=1).cpu()
        acc = (pred == labels).float().mean().item()
        phase = 'input' if step <= model.n_input_steps else 'settle'
        results.append((step, acc, phase))
        print(f"Step {step:2d} ({phase:6s}): {acc:.4f}")

    if not results:
        print("No timesteps to evaluate.")
        return results

    # Plot
    steps = [r[0] for r in results]
    accs = [r[1] for r in results]
    colors = ['blue' if r[2] == 'input' else 'orange' for r in results]

    plt.figure(figsize=(10, 5))
    plt.bar(steps, accs, color=colors)
    plt.axvline(x=model.n_input_steps + 0.5, color='red', linestyle='--', label='Input phase ends')
    plt.xlabel('Output Step')
    plt.ylabel('Accuracy')
    plt.title('Accuracy vs Output Step')
    plt.legend()
    plt.xticks(steps)
    plt.grid(True, alpha=0.3, axis='y')
    plt.show()

    best_step = steps[np.argmax(accs)]
    print(f"\nBest output step: {best_step} (accuracy: {max(accs):.4f})")

    return results

# Assign the result so the notebook does not echo the returned list.
step_results = explore_output_steps(model, test_data, test_labels)

## Summary

| Aspect | Sequential (28 steps) | Parallel (5+5 steps) |
|--------|----------------------|----------------------|
| Input format | 28 timesteps, 1 row each | 1 image, 28 parallel neurons |
| Total timesteps | 28 | 10 (5 input + 5 settle) |
| Each neuron sees | All rows over time | Only its assigned row |
| Gradient path | Through 28 steps | Through 6 steps (loss uses the step-6 output; later settle steps receive no gradient) |

### Key Insights

1. **Parallel input reduces temporal depth**: 10 steps vs 28 steps
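2. **Each neuron has a dedicated input**: neuron i receives external input only from row i
3. **Recurrence integrates across rows**: Hidden→Hidden connections combine information
4. **Read-out timing matters**: with the default `OUTPUT_STEP = 6`, only the first settle step shapes the prediction; the sweep in Section 9 shows how accuracy changes when the output is read at other steps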