# Tutorial 03 — MNIST with Sliding Window Input (8 pixels at a time)

**Redundant input with sliding windows**: Each hidden neuron receives only 8 pixels per timestep.
The window first slides along the rows, then down the columns, and the sweep repeats so the same pixels are presented several times.

## Input Scheme

```
SWEEP (40 steps):
├── Row Phase (20 steps): Slide 8-pixel window across each row
│   Step 0:  [0:8]   → pixels 0-7
│   Step 1:  [1:9]   → pixels 1-8
│   ...     
│   Step 19: [19:27] → pixels 19-26
│
└── Column Phase (20 steps): Slide 8-pixel window down each column
    Step 20: row [0:8]   → rows 0-7
    Step 21: row [1:9]   → rows 1-8
    ...     
    Step 39: row [19:27] → rows 19-26

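Repeat sweeps for 100 input steps (= 2.5 sweeps), then 1 settle step with no input.
Note: with 20 positions the window stops at index 26, so pixel (27, 27) is never presented.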
```

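## Hardware Constraint: 8 inputs per neuron

Each neuron can accept only 8 inputs. Instead of seeing a full 28-pixel row or column at once, it sees an 8-pixel window that moves over time.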

In [None]:
import os
import sys
from pathlib import Path

# Walk upward from the working directory looking for a `src/` folder that
# contains the `soen_toolkit` package; put the first match on sys.path.
notebook_dir = Path.cwd()
for search_root in (notebook_dir, *notebook_dir.parents):
    src_dir = search_root / "src"
    if not (src_dir / "soen_toolkit").exists():
        continue
    sys.path.insert(0, str(src_dir))
    break

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import gzip
import urllib.request
import struct
from tqdm import tqdm

torch.set_float32_matmul_precision('high')

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")

## 1. Hyperparameters

In [None]:
# ============================================================
# KEY HYPERPARAMETERS
# ============================================================

# Image geometry (MNIST digits are 28 x 28)
IMAGE_SIZE = 28

# Sliding window parameters
WINDOW_SIZE = 8         # Each neuron receives 8 pixels (hardware constraint)
N_ROW_STEPS = 20        # Sliding window steps across rows (window starts 0..19)
N_COL_STEPS = 20        # Sliding window steps down columns (window starts 0..19)
# NOTE: IMAGE_SIZE - WINDOW_SIZE + 1 = 21 start positions exist. With 20 steps
# the last window is [19:27], so index 27 is never reached within a phase and
# pixel (27, 27) is never presented (presumably blank background in MNIST).
STEPS_PER_SWEEP = N_ROW_STEPS + N_COL_STEPS  # 40

# Timing
N_INPUT_STEPS = 100     # Total input presentation steps
N_SETTLE_STEPS = 1      # Steps without input after
OUTPUT_STEP = N_INPUT_STEPS + 1  # 1-indexed: first settle step (= 101)

# Network
HIDDEN_DIM = IMAGE_SIZE  # One neuron per row/column
INPUT_DIM = WINDOW_SIZE  # 8 inputs per neuron
OUTPUT_DIM = 10          # 10 digit classes

# SOEN dynamics
DT = 0.1
GAMMA_PLUS = 0.1
GAMMA_MINUS = 0.01

# Training
BATCH_SIZE = 128
EPOCHS = 30
LR = 0.005

# Fail fast on inconsistent edits instead of deep inside the model.
if WINDOW_SIZE > IMAGE_SIZE:
    raise ValueError(f"WINDOW_SIZE ({WINDOW_SIZE}) cannot exceed IMAGE_SIZE ({IMAGE_SIZE})")
if not 1 <= OUTPUT_STEP <= N_INPUT_STEPS + N_SETTLE_STEPS:
    raise ValueError(
        f"OUTPUT_STEP ({OUTPUT_STEP}) must lie in 1..{N_INPUT_STEPS + N_SETTLE_STEPS} "
        "(1-indexed over input + settle steps)"
    )

n_sweeps = N_INPUT_STEPS / STEPS_PER_SWEEP
print(f"Sliding window: {WINDOW_SIZE} pixels")
print(f"Sweep: {N_ROW_STEPS} row steps + {N_COL_STEPS} col steps = {STEPS_PER_SWEEP} steps")
print(f"Total: {N_INPUT_STEPS} input steps = {n_sweeps:.1f} sweeps")
print(f"Output at step: {OUTPUT_STEP}")

## 2. Load MNIST

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/",
                        data_dir="./data/mnist"):
    """
    Return the local path of an MNIST archive, downloading it on first use.

    Args:
        filename: Archive name, e.g. "train-images-idx3-ubyte.gz".
        base_url: URL prefix that ``filename`` is appended to.
        data_dir: Local cache directory (str or Path); created if missing.

    Returns:
        Path to the cached file.

    The download is written to a temporary ``.part`` file and renamed only
    after it completes. An interrupted download therefore never leaves a
    truncated file that later runs would mistake for a valid cache entry.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if filepath.exists():
        return filepath

    print(f"Downloading {filename}...")
    partial = filepath.with_name(filepath.name + ".part")
    try:
        urllib.request.urlretrieve(base_url + filename, partial)
        partial.replace(filepath)
    finally:
        # Only present if the download or rename failed; drop the partial file.
        if partial.exists():
            partial.unlink()
    return filepath

def read_mnist_images(filepath):
    """
    Read a gzipped IDX3 (MNIST image) file into a uint8 array.

    Args:
        filepath: Path to an ``*-idx3-ubyte.gz`` file.

    Returns:
        np.ndarray of shape (num_images, rows, cols), dtype uint8.

    Raises:
        ValueError: if the header magic is not 2051 (unsigned-byte, 3-D IDX),
            or the pixel payload does not match the header dimensions.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(f"{filepath}: expected IDX3 image magic 2051, got {magic}")
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    if pixels.size != num * rows * cols:
        raise ValueError(
            f"{filepath}: header declares {num}x{rows}x{cols} pixels, found {pixels.size}"
        )
    return pixels.reshape(num, rows, cols)

def read_mnist_labels(filepath):
    """
    Read a gzipped IDX1 (MNIST label) file into a uint8 array.

    Args:
        filepath: Path to an ``*-idx1-ubyte.gz`` file.

    Returns:
        np.ndarray of shape (num_labels,), dtype uint8.

    Raises:
        ValueError: if the header magic is not 2049 (unsigned-byte, 1-D IDX),
            or the number of labels read does not match the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(f"{filepath}: expected IDX1 label magic 2049, got {magic}")
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != num:
        raise ValueError(f"{filepath}: header declares {num} labels, found {labels.size}")
    return labels

def load_mnist():
    """
    Load MNIST and hold out a fixed 6,000-image validation split from the training set.

    Images are converted to float32 and divided by 255; labels become int64.
    The split is deterministic (seed 42).

    Returns:
        ((train_img, train_lbl), (val_img, val_lbl), (test_img, test_lbl))
    """
    def _images(name):
        return read_mnist_images(download_mnist_file(name)).astype(np.float32) / 255.0

    def _labels(name):
        return read_mnist_labels(download_mnist_file(name)).astype(np.int64)

    train_img = _images("train-images-idx3-ubyte.gz")
    train_lbl = _labels("train-labels-idx1-ubyte.gz")
    test_img = _images("t10k-images-idx3-ubyte.gz")
    test_lbl = _labels("t10k-labels-idx1-ubyte.gz")

    # A private RandomState(42) produces the same permutation as
    # `np.random.seed(42); np.random.permutation(...)`, without reseeding the
    # global NumPy RNG as a side effect.
    idx = np.random.RandomState(42).permutation(len(train_img))
    n_val = 6000
    val_idx, train_idx = idx[:n_val], idx[n_val:]

    val_img, val_lbl = train_img[val_idx], train_lbl[val_idx]
    train_img, train_lbl = train_img[train_idx], train_lbl[train_idx]

    print(f"Train: {train_img.shape}, Val: {val_img.shape}, Test: {test_img.shape}")
    return (train_img, train_lbl), (val_img, val_lbl), (test_img, test_lbl)

(train_data, train_labels), (val_data, val_labels), (test_data, test_labels) = load_mnist()

## 3. Visualize Sliding Window Scheme

In [None]:
def visualize_sliding_window(image, label):
    """
    Visualize the sliding window input scheme.

    Top row: the image, the columns covered during the row phase, and the rows
    covered during the column phase. Each phase shows a few sample steps,
    shaded additively.
    Bottom: the input timeline (row/column phase of every sweep, the settle
    steps, and the readout step). Reads the module-level timing constants.
    """
    from matplotlib.patches import Patch

    n_rows, n_cols = image.shape[:2]
    fig = plt.figure(figsize=(16, 10))

    def sample_steps(n_steps):
        # Every 5th window position plus the last (0, 5, 10, 15, 19 for 20 steps).
        return sorted({*range(0, n_steps, 5), n_steps - 1})

    def window_start(step, length):
        # Same edge clamping as the model: keep the window inside the image.
        return min(step, length - WINDOW_SIZE)

    # Original image
    ax1 = fig.add_subplot(2, 3, 1)
    ax1.imshow(image, cmap='gray')
    ax1.set_title(f'Original Image (Label: {label})')

    # Row phase: the window covers a band of columns in every row
    ax2 = fig.add_subplot(2, 3, 2)
    row_vis = np.zeros((n_rows, n_cols))
    for step in sample_steps(N_ROW_STEPS):
        start = window_start(step, n_cols)
        row_vis[:, start:start + WINDOW_SIZE] += 0.2
    ax2.imshow(row_vis, cmap='Blues', vmin=0, vmax=1)
    ax2.set_title('Row Phase: Window slides →')
    ax2.set_xlabel('Column (window slides this way)')
    ax2.set_ylabel('Row (each neuron = one row)')

    # Column phase: the window covers a band of rows in every column
    ax3 = fig.add_subplot(2, 3, 3)
    col_vis = np.zeros((n_rows, n_cols))
    for step in sample_steps(N_COL_STEPS):
        start = window_start(step, n_rows)
        col_vis[start:start + WINDOW_SIZE, :] += 0.2
    ax3.imshow(col_vis, cmap='Oranges', vmin=0, vmax=1)
    ax3.set_title('Column Phase: Window slides ↓')
    ax3.set_xlabel('Column (each neuron = one column)')
    ax3.set_ylabel('Row (window slides this way)')

    # Timeline
    ax4 = fig.add_subplot(2, 1, 2)

    # Ceil division so a trailing partial sweep is drawn, whatever N_INPUT_STEPS is.
    n_sweeps = -(-N_INPUT_STEPS // STEPS_PER_SWEEP)
    for sweep in range(n_sweeps):
        start = sweep * STEPS_PER_SWEEP

        # Row phase
        row_end = min(start + N_ROW_STEPS, N_INPUT_STEPS)
        if row_end > start:
            ax4.barh(0, row_end - start, left=start, height=0.4,
                     color='blue', alpha=0.7, edgecolor='black')
            if row_end - start > 5:
                ax4.text((start + row_end)/2, 0, 'Row', ha='center', va='center',
                         color='white', fontweight='bold', fontsize=8)

        # Column phase
        col_start = start + N_ROW_STEPS
        col_end = min(col_start + N_COL_STEPS, N_INPUT_STEPS)
        if col_end > col_start:
            ax4.barh(0, col_end - col_start, left=col_start, height=0.4,
                     color='orange', alpha=0.7, edgecolor='black')
            if col_end - col_start > 5:
                ax4.text((col_start + col_end)/2, 0, 'Col', ha='center', va='center',
                         color='white', fontweight='bold', fontsize=8)

    # Settle phase
    ax4.barh(0, N_SETTLE_STEPS, left=N_INPUT_STEPS, height=0.4,
             color='green', alpha=0.7, edgecolor='black')

    # Output marker
    ax4.axvline(x=OUTPUT_STEP - 0.5, color='red', linewidth=2, linestyle='--')
    ax4.scatter([OUTPUT_STEP - 0.5], [0], color='red', s=100, zorder=5, marker='v')
    ax4.text(OUTPUT_STEP, 0.3, f'Output\n(step {OUTPUT_STEP})', ha='center', fontsize=9, color='red')

    ax4.set_xlim(-1, N_INPUT_STEPS + N_SETTLE_STEPS + 2)
    ax4.set_ylim(-0.5, 0.5)
    ax4.set_xlabel('Timestep')
    ax4.set_title(f'Timeline: {N_INPUT_STEPS} input steps ({N_INPUT_STEPS/STEPS_PER_SWEEP:.1f} sweeps) + {N_SETTLE_STEPS} settle')
    ax4.set_yticks([])

    # Legend (labels follow the configured step counts)
    legend_elements = [
        Patch(facecolor='blue', alpha=0.7, label=f'Row phase ({N_ROW_STEPS} steps)'),
        Patch(facecolor='orange', alpha=0.7, label=f'Col phase ({N_COL_STEPS} steps)'),
        Patch(facecolor='green', alpha=0.7, label='Settle'),
    ]
    ax4.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()
    plt.show()

    print("\nInput scheme summary:")
    print(f"  • Each neuron receives {WINDOW_SIZE} pixels at a time (hardware constraint)")
    print(f"  • Row phase: slide window across {n_cols} columns in {N_ROW_STEPS} steps")
    print(f"  • Col phase: slide window down {n_rows} rows in {N_COL_STEPS} steps")
    print(f"  • One sweep = {STEPS_PER_SWEEP} steps")
    print(f"  • Total: {N_INPUT_STEPS} steps = {N_INPUT_STEPS/STEPS_PER_SWEEP:.1f} sweeps")

visualize_sliding_window(train_data[0], train_labels[0])

## 4. Sliding Window SOEN Model

In [None]:
class SlidingWindowSOEN(nn.Module):
    """
    SOEN model with sliding window input.

    Each hidden neuron receives ``window_size`` pixels per timestep (8 by
    default, the hardware constraint). One sweep has two phases:

    * row phase (``n_row_steps`` steps): neuron ``i`` sees a window of row
      ``i`` whose start column advances by one pixel per step;
    * column phase (``n_col_steps`` steps): neuron ``i`` sees a window of
      column ``i`` whose start row advances by one pixel per step.

    Sweeps repeat for ``n_input_steps`` steps, followed by ``n_settle_steps``
    steps with no input. Neuron ``i`` is tied to row ``i`` and column ``i``,
    so ``hidden_dim`` must equal both the image height and width.
    """

    def __init__(self, hidden_dim=28, window_size=8, output_dim=10,
                 n_row_steps=20, n_col_steps=20,
                 n_input_steps=100, n_settle_steps=1, output_step=101,
                 dt=0.1, gamma_plus=0.1, gamma_minus=0.01):
        super().__init__()

        # output_step is 1-indexed. Values < 1 would make forward() index
        # all_outputs with a negative number and silently read from the end.
        if output_step < 1:
            raise ValueError(f"output_step is 1-indexed and must be >= 1, got {output_step}")

        self.hidden_dim = hidden_dim
        self.window_size = window_size
        self.output_dim = output_dim
        self.n_row_steps = n_row_steps
        self.n_col_steps = n_col_steps
        self.steps_per_sweep = n_row_steps + n_col_steps
        self.n_input_steps = n_input_steps
        self.n_settle_steps = n_settle_steps
        self.output_step = output_step
        self.dt = dt
        self.gamma_plus = gamma_plus
        self.gamma_minus = gamma_minus

        # Input weights: each neuron has its own weights for its pixel window
        self.W_i2h = nn.Parameter(torch.empty(hidden_dim, window_size))  # (28, 8)

        # Recurrent weights
        self.W_h2h = nn.Parameter(torch.empty(hidden_dim, hidden_dim))  # (28, 28)

        # Output weights
        self.W_h2o = nn.Parameter(torch.empty(output_dim, hidden_dim))  # (10, 28)

        # Biases
        self.bias_h = nn.Parameter(torch.zeros(hidden_dim))
        self.bias_o = nn.Parameter(torch.zeros(output_dim))

        self._init_weights()

    def _init_weights(self):
        """Small random initialisation; W_h2h starts with a zero diagonal (no self-connections)."""
        nn.init.uniform_(self.W_i2h, -0.2, 0.2)
        nn.init.normal_(self.W_h2h, 0, 0.1)
        nn.init.normal_(self.W_h2o, 0, 0.2)
        with torch.no_grad():
            self.W_h2h.fill_diagonal_(0)

    def source_function(self, phi):
        """Neuron source function g(phi) = sigmoid(5 * phi)."""
        return torch.sigmoid(5 * phi)

    def _window_start(self, start, length):
        """Clamp a window start so the window stays inside ``length`` pixels."""
        # A window that would run past the edge is pinned to the last full
        # window [length - window_size, length).
        return min(start, length - self.window_size)

    def get_window_input(self, images, step):
        """
        Extract each neuron's pixel window at the given timestep.

        Args:
            images: (batch, H, W); H and W must equal hidden_dim.
            step: Current timestep (0-indexed); taken modulo the sweep length.

        Returns:
            window: (batch, hidden_dim, window_size), one window per neuron.
        """
        n_rows, n_cols = images.shape[1], images.shape[2]
        step_in_sweep = step % self.steps_per_sweep

        if step_in_sweep < self.n_row_steps:
            # ROW PHASE: neuron i reads row i; the window slides across columns.
            start = self._window_start(step_in_sweep, n_cols)
            return images[:, :, start:start + self.window_size]

        # COLUMN PHASE: neuron i reads column i; the window slides down rows.
        # Slice the rows, then transpose so the neuron (column) axis is dim 1.
        start = self._window_start(step_in_sweep - self.n_row_steps, n_rows)
        return images[:, start:start + self.window_size, :].transpose(1, 2)

    def step(self, s, window_input=None):
        """
        Advance the hidden state by one forward-Euler step.

        ds/dt = gamma_plus * g(phi) - gamma_minus * s, with
        phi_i = sum_j W_i2h[i, j] * window[i, j] + (W_h2h @ s)_i + bias_h[i].

        Args:
            s: Hidden state (batch, hidden_dim).
            window_input: (batch, hidden_dim, window_size), or None during the
                settle phase (no input drive).

        Returns:
            Updated hidden state (batch, hidden_dim).
        """
        if window_input is not None:
            # Per-neuron dot product of its weights with its own window.
            input_contrib = (window_input * self.W_i2h.unsqueeze(0)).sum(dim=2)  # (batch, hidden)
        else:
            input_contrib = 0

        recurrent_contrib = F.linear(s, self.W_h2h)  # (batch, hidden)

        # Total flux
        phi = input_contrib + recurrent_contrib + self.bias_h

        # SingleDendrite dynamics
        g = self.source_function(phi)
        dsdt = self.gamma_plus * g - self.gamma_minus * s
        return s + self.dt * dsdt

    def forward(self, images):
        """
        Run the input phase followed by the settle phase.

        Args:
            images: (batch, H, W) with H == W == hidden_dim.

        Returns:
            output: (batch, output_dim) logits read at ``output_step``
                (1-indexed over input + settle steps, clamped to the last step).
            states: dict with 'all_states' (initial state + one per step),
                'all_outputs' (logits after every step) and 'final_state'.
        """
        batch_size = images.shape[0]
        s = torch.zeros(batch_size, self.hidden_dim, device=images.device)

        all_states = [s.clone()]
        all_outputs = []

        # INPUT PHASE: n_input_steps with sliding window
        for t in range(self.n_input_steps):
            s = self.step(s, self.get_window_input(images, t))
            all_states.append(s.clone())
            all_outputs.append(F.linear(s, self.W_h2o, self.bias_o))

        # SETTLE PHASE: no input
        for _ in range(self.n_settle_steps):
            s = self.step(s, window_input=None)
            all_states.append(s.clone())
            all_outputs.append(F.linear(s, self.W_h2o, self.bias_o))

        # Convert the 1-indexed output step to a list index (clamped to the end).
        output_idx = min(self.output_step - 1, len(all_outputs) - 1)
        output = all_outputs[output_idx]

        return output, {
            'all_states': all_states,
            'all_outputs': all_outputs,
            'final_state': s
        }

# Create model: every architectural and timing setting comes from the
# hyperparameter cell, so edits there propagate here.
model_config = dict(
    hidden_dim=HIDDEN_DIM,
    window_size=WINDOW_SIZE,
    output_dim=OUTPUT_DIM,
    n_row_steps=N_ROW_STEPS,
    n_col_steps=N_COL_STEPS,
    n_input_steps=N_INPUT_STEPS,
    n_settle_steps=N_SETTLE_STEPS,
    output_step=OUTPUT_STEP,
    dt=DT,
    gamma_plus=GAMMA_PLUS,
    gamma_minus=GAMMA_MINUS,
)
model = SlidingWindowSOEN(**model_config).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model created")
print(f"  W_i2h: {model.W_i2h.shape} (each neuron has {WINDOW_SIZE} input weights)")
print(f"  W_h2h: {model.W_h2h.shape} (recurrent)")
print(f"  W_h2o: {model.W_h2o.shape} (output)")
print(f"  Parameters: {n_params}")

## 5. Verify Window Extraction

In [None]:
def verify_window_extraction(model, image):
    """Plot the extracted (neuron x pixel) window at a few row-phase and column-phase steps."""
    batch = torch.tensor(image, dtype=torch.float32).unsqueeze(0).to(device)

    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    fig.suptitle('Window Extraction Verification', fontsize=14)

    # One row of panels per phase: (panel row, sampled steps, title prefix, y-label)
    phases = (
        (0, [0, 5, 10, 15, 19], 'Row step', 'Neuron (row)'),
        (1, [20, 25, 30, 35, 39], 'Col step', 'Neuron (col)'),
    )
    for panel_row, sampled_steps, prefix, ylabel in phases:
        for panel_col, step in enumerate(sampled_steps):
            ax = axes[panel_row, panel_col]
            window = model.get_window_input(batch, step).squeeze().cpu().numpy()
            ax.imshow(window, cmap='viridis', aspect='auto')
            ax.set_title(f'{prefix} {step}')
            ax.set_xlabel('Pixel in window')
            if panel_col == 0:
                ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

verify_window_extraction(model, train_data[0])

## 6. Training

In [None]:
def _validate(model, loader, device):
    """Return (mean cross-entropy loss, accuracy) of ``model`` over ``loader``."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, labels in loader:
            x, labels = x.to(device), labels.to(device)
            output, _ = model(x)
            total_loss += F.cross_entropy(output, labels).item() * len(labels)
            correct += (output.argmax(dim=1) == labels).sum().item()
            total += len(labels)
    return total_loss / total, correct / total


def train_model(model, train_data, train_labels, val_data, val_labels,
                epochs=30, batch_size=128, lr=0.005):
    """
    Train the sliding window SOEN model.

    Uses Adam with a cosine learning-rate schedule over ``epochs``, gradient
    clipping (max norm 1.0), and re-zeroes the diagonal of ``W_h2h`` after
    every optimizer step so neurons keep no self-connections.

    Args:
        model: SlidingWindowSOEN instance; trained in place on its own device.
        train_data, val_data: image arrays of shape (N, H, W).
        train_labels, val_labels: integer class labels of length N.
        epochs, batch_size, lr: optimisation settings.

    Returns:
        history: dict of per-epoch lists 'train_loss', 'train_acc',
        'val_loss', 'val_acc'. On return, ``model`` holds the weights from the
        epoch with the highest validation accuracy.
    """
    # Follow the model's device rather than relying on a global.
    model_device = next(model.parameters()).device

    train_dataset = TensorDataset(
        torch.tensor(train_data, dtype=torch.float32),
        torch.tensor(train_labels, dtype=torch.long)
    )
    val_dataset = TensorDataset(
        torch.tensor(val_data, dtype=torch.float32),
        torch.tensor(val_labels, dtype=torch.long)
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc = 0
    best_state = None

    print("="*60)
    print("SLIDING WINDOW SOEN TRAINING")
    print("="*60)
    print(f"Window size: {model.window_size} pixels (hardware constraint)")
    print(f"Sweep: {model.n_row_steps} row + {model.n_col_steps} col = {model.steps_per_sweep} steps")
    print(f"Total: {model.n_input_steps} input + {model.n_settle_steps} settle steps")
    print(f"Output at step: {model.output_step}")
    print("="*60)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        epoch_correct = 0
        epoch_total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for x, labels in pbar:
            x, labels = x.to(model_device), labels.to(model_device)

            optimizer.zero_grad()
            output, _ = model(x)
            loss = F.cross_entropy(output, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Keep the no-self-connection constraint after each update.
            with torch.no_grad():
                model.W_h2h.fill_diagonal_(0)

            pred = output.argmax(dim=1)
            epoch_correct += (pred == labels).sum().item()
            epoch_total += len(labels)
            epoch_loss += loss.item() * len(labels)

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{epoch_correct/epoch_total:.3f}'})

        scheduler.step()

        train_loss = epoch_loss / epoch_total
        train_acc = epoch_correct / epoch_total
        val_loss, val_acc = _validate(model, val_loader, model_device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Star only epochs whose weights were actually saved (strict improvement).
        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc
            best_state = {k: v.clone() for k, v in model.state_dict().items()}

        print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, train_acc={train_acc:.3f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f} {'*' if improved else ''}")

    if best_state is not None:
        model.load_state_dict(best_state)
    print(f"\nBest validation accuracy: {best_val_acc:.4f}")

    return history

history = train_model(model, train_data, train_labels, val_data, val_labels,
                      epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR)

## 7. Visualize Training

In [None]:
def plot_training(history):
    """Plot per-epoch train/validation loss (left) and accuracy (right)."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, train key, val key, y-label, title)
    panels = (
        (axes[0], 'train_loss', 'val_loss', 'Loss', 'Loss (Sliding Window SOEN)'),
        (axes[1], 'train_acc', 'val_acc', 'Accuracy', 'Accuracy (Sliding Window SOEN)'),
    )
    for ax, train_key, val_key, ylabel, title in panels:
        ax.plot(history[train_key], label='Train')
        ax.plot(history[val_key], label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

plot_training(history)

## 8. Evaluate

In [None]:
@torch.no_grad()
def evaluate(model, test_data, test_labels):
    """Compute and print test-set accuracy; returns it as a float in [0, 1]."""
    model.eval()

    loader = DataLoader(
        TensorDataset(
            torch.tensor(test_data, dtype=torch.float32),
            torch.tensor(test_labels, dtype=torch.long),
        ),
        batch_size=256,
        shuffle=False,
    )

    predictions, targets = [], []
    for batch_x, batch_y in tqdm(loader, desc="Testing"):
        logits, _ = model(batch_x.to(device))
        predictions.append(logits.argmax(dim=1).cpu())
        targets.append(batch_y)

    hits = torch.cat(predictions) == torch.cat(targets)
    accuracy = hits.float().mean().item()

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"TEST ACCURACY (Sliding Window SOEN): {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(rule)

    return accuracy

test_acc = evaluate(model, test_data, test_labels)

## 9. Visualize Dynamics

In [None]:
def visualize_dynamics(model, image, label):
    """
    Plot how the model processes a single image over time.

    Panels: the input image; a hidden-state heatmap (neuron x timestep) with
    sweep (white), row→column (red) and settle (green) boundaries; the output
    logits after every step; and the softmax probabilities of the readout.
    """
    model.eval()
    # Follow the model's device rather than relying on a global.
    model_device = next(model.parameters()).device
    x = torch.tensor(image, dtype=torch.float32).unsqueeze(0).to(model_device)

    with torch.no_grad():
        output, states = model(x)

    # all_states: index 0 is the initial state, index t the state after t steps.
    all_states = torch.stack(states['all_states']).squeeze().cpu().numpy()
    # all_outputs: index t holds the logits after step t (0-indexed).
    all_outputs = torch.stack(states['all_outputs']).squeeze().cpu().numpy()

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Original image
    axes[0, 0].imshow(image, cmap='gray')
    axes[0, 0].set_title(f'Input Image (Label: {label})')

    # Hidden state evolution
    im = axes[0, 1].imshow(all_states.T, aspect='auto', cmap='viridis')
    # Mark sweep boundaries. Ceil division covers a trailing partial sweep
    # for any n_input_steps (a fixed range(3) would miss later sweeps).
    n_sweeps = -(-model.n_input_steps // model.steps_per_sweep)
    for sweep in range(n_sweeps):
        pos = sweep * model.steps_per_sweep
        if pos < model.n_input_steps:
            axes[0, 1].axvline(x=pos + 0.5, color='white', linestyle=':', alpha=0.5)
        pos = sweep * model.steps_per_sweep + model.n_row_steps
        if pos < model.n_input_steps:
            axes[0, 1].axvline(x=pos + 0.5, color='red', linestyle='--', alpha=0.5)
    axes[0, 1].axvline(x=model.n_input_steps + 0.5, color='green', linestyle='--', label='Settle')
    axes[0, 1].set_xlabel('Timestep')
    axes[0, 1].set_ylabel('Neuron')
    axes[0, 1].set_title('Hidden State Evolution')
    plt.colorbar(im, ax=axes[0, 1])

    # Output evolution
    for i in range(10):
        axes[1, 0].plot(all_outputs[:, i], label=f'{i}', alpha=0.7)
    axes[1, 0].axvline(x=model.n_input_steps - 0.5, color='green', linestyle='--')
    axes[1, 0].set_xlabel('Timestep')
    axes[1, 0].set_ylabel('Logit')
    axes[1, 0].set_title('Output Logits Over Time')
    axes[1, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    # Final prediction: true class green, wrong prediction red, others blue
    pred = output.argmax(dim=1).item()
    probs = F.softmax(output, dim=1).squeeze().cpu().numpy()
    colors = ['green' if i == label else 'blue' for i in range(10)]
    colors[pred] = 'red' if pred != label else 'green'
    axes[1, 1].bar(range(10), probs, color=colors)
    axes[1, 1].set_xlabel('Class')
    axes[1, 1].set_ylabel('Probability')
    axes[1, 1].set_title(f'Prediction: {pred} (True: {label}) {"✓" if pred == label else "✗"}')
    axes[1, 1].set_xticks(range(10))

    plt.tight_layout()
    plt.show()

for i in range(3):
    visualize_dynamics(model, test_data[i], test_labels[i])

## 10. Explore Output Steps

In [None]:
def explore_output_steps(model, test_data, test_labels, max_samples=1000):
    """
    Report accuracy when the readout is taken at different timesteps.

    Runs one forward pass over the first ``max_samples`` test images, then
    scores the logits recorded at every 10th step plus the final step. Each
    accuracy is printed and plotted, coloured by the phase (row / col /
    settle) that produced it.
    """
    model.eval()

    images = torch.tensor(test_data[:max_samples], dtype=torch.float32).to(device)
    targets = torch.tensor(test_labels[:max_samples], dtype=torch.long)

    with torch.no_grad():
        _, trace = model(images)
    logits_per_step = trace['all_outputs']

    # Every 10th step plus the final one, deduplicated and ascending.
    probe_steps = sorted({*range(0, len(logits_per_step), 10), len(logits_per_step) - 1})

    def phase_of(t):
        if t >= model.n_input_steps:
            return 'settle'
        return 'row' if t % model.steps_per_sweep < model.n_row_steps else 'col'

    results = []
    for step in probe_steps:
        preds = logits_per_step[step].argmax(dim=1).cpu()
        acc = (preds == targets).float().mean().item()
        phase = phase_of(step)
        results.append((step + 1, acc, phase))  # report 1-indexed steps
        print(f"Step {step+1:3d} ({phase:6s}): {acc:.4f}")

    # Plot
    steps, accs, phases = (list(column) for column in zip(*results))
    palette = {'row': 'blue', 'col': 'orange', 'settle': 'green'}
    positions = range(len(steps))

    plt.figure(figsize=(12, 5))
    plt.bar(positions, accs, color=[palette[p] for p in phases])
    plt.xticks(positions, steps)
    plt.xlabel('Output Step')
    plt.ylabel('Accuracy')
    plt.title('Accuracy vs Output Step')

    from matplotlib.patches import Patch
    plt.legend(handles=[
        Patch(facecolor='blue', label='Row phase'),
        Patch(facecolor='orange', label='Col phase'),
        Patch(facecolor='green', label='Settle'),
    ])
    plt.grid(True, alpha=0.3, axis='y')
    plt.show()

    best_idx = np.argmax(accs)
    print(f"\nBest step: {steps[best_idx]} (accuracy: {accs[best_idx]:.4f})")

explore_output_steps(model, test_data, test_labels)

## Summary

| Aspect | Value |
|--------|-------|
| **Hardware constraint** | 8 inputs per neuron |
| Window size | 8 pixels |
| Row steps | 20 |
| Column steps | 20 |
| Steps per sweep | 40 |
| Total input steps | 100 (2.5 sweeps) |
| Settle steps | 1 |

### Key Features

1. **Hardware compatible**: Each neuron only receives 8 inputs at a time
2. **Redundant coverage**: Repeated sweeps and overlapping windows present the same pixels several times
3. **Row + Column**: Both orientations captured
4. **Sliding window**: Overlapping windows for smooth transitions