# Tutorial 03 — MNIST Sliding Window + Hierarchical (14→14→10)

**Two-layer hierarchical architecture** with 14 neurons per layer.

## Architecture

```
Input (8-pixel-wide window over all 28 rows / columns)
       ↓
Hidden1 (14 neurons) ←── Each neuron handles 2 adjacent rows (or columns)
       ↓ ↺ recurrent
Hidden2 (14 neurons) ←── Processes Hidden1 output
       ↓ ↺ recurrent
Output (10 neurons)
```

## Row Assignment

- Neuron 0: rows 0-1
- Neuron 1: rows 2-3
- ...
- Neuron 13: rows 26-27

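Each Hidden1 neuron receives 2×8 = 16 inputs (2 rows × 8 window pixels). During the column phase of each sweep the window slides down the image instead, and neuron *k* receives columns 2k and 2k+1 (8 window pixels from each).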

In [None]:
import os
import sys
from pathlib import Path

notebook_dir = Path.cwd()
# Walk up from the working directory until a ``src/soen_toolkit`` package is
# found and make that ``src`` directory importable with top priority.
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        src_path = str(candidate)
        # Notebook cells are often re-run: drop any earlier copy so sys.path
        # does not accumulate duplicates, then put the entry back in front.
        if src_path in sys.path:
            sys.path.remove(src_path)
        sys.path.insert(0, src_path)
        break

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import gzip
import urllib.request
import struct
from tqdm import tqdm

# Allow float32 matmuls to use faster reduced-precision kernels where the backend supports them.
torch.set_float32_matmul_precision("high")

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("\n".join([f"PyTorch: {torch.__version__}", f"Device: {device}"]))

## 1. Hyperparameters

In [None]:
# ============================================================
# KEY HYPERPARAMETERS
# ============================================================

# Image geometry (MNIST digits are 28×28)
IMAGE_SIZE = 28

# Sliding window
WINDOW_SIZE = 8
# Window start positions per phase are 0..N-1 (0..19 here). NOTE: covering the
# whole image with an 8-pixel window would need IMAGE_SIZE - WINDOW_SIZE + 1 = 21
# positions, so the row phase never reaches column 27 and the column phase
# never reaches row 27.
N_ROW_STEPS = 20
N_COL_STEPS = 20
STEPS_PER_SWEEP = N_ROW_STEPS + N_COL_STEPS  # 40

# Timing
N_INPUT_STEPS = 100
N_SETTLE_STEPS = 1
# 1-indexed step whose output is used as the prediction. Derived (rather than
# hard-coded) so it keeps pointing at the last settle step if the counts change.
OUTPUT_STEP = N_INPUT_STEPS + N_SETTLE_STEPS  # 101

# HIERARCHICAL: Two layers with 14 neurons each
HIDDEN1_DIM = 14  # First hidden layer
HIDDEN2_DIM = 14  # Second hidden layer
if IMAGE_SIZE % HIDDEN1_DIM != 0:
    raise ValueError(
        f"HIDDEN1_DIM={HIDDEN1_DIM} must divide IMAGE_SIZE={IMAGE_SIZE} "
        "so every Hidden1 neuron gets the same number of rows"
    )
ROWS_PER_NEURON = IMAGE_SIZE // HIDDEN1_DIM  # 28 rows / 14 neurons = 2 rows each

OUTPUT_DIM = 10

# SOEN dynamics
DT = 0.1
GAMMA_PLUS = 0.1
GAMMA_MINUS = 0.01

# Training
BATCH_SIZE = 128
EPOCHS = 30
LR = 0.005

print(f"Hierarchical: {HIDDEN1_DIM} → {HIDDEN2_DIM} → {OUTPUT_DIM}")
print(f"Each Hidden1 neuron: {ROWS_PER_NEURON} rows × {WINDOW_SIZE} pixels = {ROWS_PER_NEURON * WINDOW_SIZE} inputs")

## 2. Load MNIST

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/",
                        data_dir="./data/mnist"):
    """Return the local path of an MNIST archive, downloading it on first use.

    Args:
        filename: Archive name, e.g. ``"train-images-idx3-ubyte.gz"``.
        base_url: URL prefix that ``filename`` is appended to.
        data_dir: Cache directory; created if it does not exist.

    Returns:
        ``Path`` of the cached file (``data_dir / filename``).

    Raises:
        OSError / urllib.error.URLError: if the download fails. No partial file
        is left behind in that case.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if not filepath.exists():
        print(f"Downloading {filename}...")
        # Download to a side file and rename only after it completes, so an
        # interrupted transfer never leaves a truncated archive that later runs
        # would find via the ``exists()`` check above and try to parse.
        tmp_path = filepath.with_name(filepath.name + ".part")
        try:
            urllib.request.urlretrieve(base_url + filename, tmp_path)
            tmp_path.replace(filepath)
        finally:
            if tmp_path.exists():
                tmp_path.unlink()
    return filepath

def read_mnist_images(filepath):
    """Read a gzipped IDX3 (unsigned byte) MNIST image file.

    Args:
        filepath: Path of the ``*-images-idx3-ubyte.gz`` file.

    Returns:
        ``uint8`` NumPy array of shape ``(num, rows, cols)`` taken from the header.

    Raises:
        ValueError: if the magic number is not 2051 (IDX, unsigned byte, 3 dims)
            or the pixel payload does not match the header dimensions.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(f"{filepath}: not an IDX3 image file (magic {magic}, expected 2051)")
        data = np.frombuffer(f.read(), dtype=np.uint8)
    expected = num * rows * cols
    if data.size != expected:
        # Typically a truncated/corrupted download.
        raise ValueError(f"{filepath}: expected {expected} pixels from header, found {data.size}")
    return data.reshape(num, rows, cols)

def read_mnist_labels(filepath):
    """Read a gzipped IDX1 (unsigned byte) MNIST label file.

    Args:
        filepath: Path of the ``*-labels-idx1-ubyte.gz`` file.

    Returns:
        ``uint8`` NumPy array of length ``num`` taken from the header.

    Raises:
        ValueError: if the magic number is not 2049 (IDX, unsigned byte, 1 dim)
            or the number of labels does not match the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(f"{filepath}: not an IDX1 label file (magic {magic}, expected 2049)")
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != num:
        # Typically a truncated/corrupted download.
        raise ValueError(f"{filepath}: expected {num} labels from header, found {labels.size}")
    return labels

def load_mnist():
    train_img = read_mnist_images(download_mnist_file("train-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    train_lbl = read_mnist_labels(download_mnist_file("train-labels-idx1-ubyte.gz")).astype(np.int64)
    test_img = read_mnist_images(download_mnist_file("t10k-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    test_lbl = read_mnist_labels(download_mnist_file("t10k-labels-idx1-ubyte.gz")).astype(np.int64)
    
    np.random.seed(42)
    idx = np.random.permutation(len(train_img))
    n_val = 6000
    
    val_img, val_lbl = train_img[idx[:n_val]], train_lbl[idx[:n_val]]
    train_img, train_lbl = train_img[idx[n_val:]], train_lbl[idx[n_val:]]
    
    print(f"Train: {train_img.shape}, Val: {val_img.shape}, Test: {test_img.shape}")
    return (train_img, train_lbl), (val_img, val_lbl), (test_img, test_lbl)

# Unpack the (images, labels) pairs for the train / validation / test splits.
(
    (train_data, train_labels),
    (val_data, val_labels),
    (test_data, test_labels),
) = load_mnist()

## 3. Visualize Hierarchical Architecture

In [None]:
def visualize_architecture():
    """Plot the row → Hidden1 assignment and a schematic of the layer stack, then print a summary.

    Sizes come from the module-level hyperparameters (``HIDDEN1_DIM``,
    ``HIDDEN2_DIM``, ``OUTPUT_DIM``, ``ROWS_PER_NEURON``, ``WINDOW_SIZE``) so the
    figure and text stay in sync with the configured model.
    """
    n_lines = HIDDEN1_DIM * ROWS_PER_NEURON  # image rows covered by Hidden1 (28 for MNIST)
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Left: which Hidden1 neuron receives each image row (row phase).
    ax = axes[0]
    neuron_of_row = np.repeat(np.arange(HIDDEN1_DIM), ROWS_PER_NEURON)  # [0, 0, 1, 1, ...]
    row_colors = np.tile(neuron_of_row[:, None], (1, n_lines))

    ax.imshow(row_colors, cmap='tab20', aspect='auto')
    ax.set_xlabel('Column')
    ax.set_ylabel('Row')
    ax.set_title('Row → Hidden1 Neuron Assignment')

    # Label each neuron at the vertical centre of its row band, just right of the image.
    for neuron in range(HIDDEN1_DIM):
        row = neuron * ROWS_PER_NEURON + (ROWS_PER_NEURON - 1) / 2
        ax.text(n_lines + 1, row, f'N{neuron}', va='center', fontsize=8)

    # Right: architecture diagram
    ax = axes[1]
    layers = [
        {'name': 'Input', 'x': 1, 'y': 5, 'size': f'{WINDOW_SIZE} px', 'color': 'lightblue'},
        {'name': 'Hidden1', 'x': 4, 'y': 5, 'size': str(HIDDEN1_DIM), 'color': 'lightgreen'},
        {'name': 'Hidden2', 'x': 7, 'y': 5, 'size': str(HIDDEN2_DIM), 'color': 'lightyellow'},
        {'name': 'Output', 'x': 10, 'y': 5, 'size': str(OUTPUT_DIM), 'color': 'lightcoral'},
    ]

    for layer in layers:
        # facecolor/edgecolor rather than color=: passing color= overrides the
        # edge colour, so the black outline would never be drawn.
        circle = plt.Circle((layer['x'], layer['y']), 0.8,
                            facecolor=layer['color'], edgecolor='black', linewidth=2)
        ax.add_patch(circle)
        ax.text(layer['x'], layer['y'], layer['size'], ha='center', va='center', fontsize=12, fontweight='bold')
        ax.text(layer['x'], layer['y'] - 1.3, layer['name'], ha='center', fontsize=10)

    # Feed-forward arrows between consecutive layers
    for src, dst in zip(layers, layers[1:]):
        ax.annotate('', xy=(dst['x'] - 0.9, dst['y']),
                    xytext=(src['x'] + 0.9, src['y']),
                    arrowprops=dict(arrowstyle='->', color='black', lw=2))

    # Recurrent self-loop arrows and labels on Hidden1 and Hidden2
    for layer in layers[1:3]:
        ax.annotate('', xy=(layer['x'] - 0.3, layer['y'] + 0.9),
                    xytext=(layer['x'] + 0.3, layer['y'] + 0.9),
                    arrowprops=dict(arrowstyle='->', color='blue', lw=1.5,
                                    connectionstyle='arc3,rad=-0.5'))
        ax.text(layer['x'], layer['y'] + 1.5, 'recurrent', fontsize=8, color='blue', ha='center')

    ax.set_xlim(-0.5, 11.5)
    ax.set_ylim(2, 8)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(f'Hierarchical Architecture: {HIDDEN1_DIM} → {HIDDEN2_DIM} → {OUTPUT_DIM}')

    plt.tight_layout()
    plt.show()

    print("\nArchitecture summary:")
    print(f"  Input: {WINDOW_SIZE} pixels (sliding window)")
    print(f"  Hidden1: {HIDDEN1_DIM} neurons (each handles {ROWS_PER_NEURON} rows)")
    print(f"  Hidden2: {HIDDEN2_DIM} neurons (processes Hidden1 output)")
    print(f"  Output: {OUTPUT_DIM} neurons")
    print("\nConnections:")
    print(f"  Input → Hidden1: each neuron gets {ROWS_PER_NEURON}×{WINDOW_SIZE}={ROWS_PER_NEURON*WINDOW_SIZE} inputs")
    print(f"  Hidden1 → Hidden2: {HIDDEN1_DIM} → {HIDDEN2_DIM} (dense)")
    print(f"  Hidden2 → Output: {HIDDEN2_DIM} → {OUTPUT_DIM} (dense)")

visualize_architecture()

## 4. Hierarchical Sliding Window SOEN

In [None]:
class HierarchicalSlidingSOEN(nn.Module):
    """
    Two-layer hierarchical SOEN with sliding window input.

    Architecture: Input → Hidden1 (hidden1_dim) → Hidden2 (hidden2_dim) → Output (output_dim),
    i.e. 14 → 14 → 10 with the defaults.

    Input schedule: each sweep has ``n_row_steps`` row-phase steps followed by
    ``n_col_steps`` column-phase steps. In the row phase a ``window_size``-wide
    column window is cut from every image row, and Hidden1 neuron ``k`` receives
    rows ``k*rows_per_neuron`` .. ``(k+1)*rows_per_neuron - 1``. In the column phase
    the window slides over rows instead, and neuron ``k`` receives the matching
    group of columns. Every neuron therefore sees ``rows_per_neuron * window_size``
    inputs per step, and ``hidden1_dim * rows_per_neuron`` must equal the image
    height (and width).

    Both hidden layers use forward-Euler dynamics
    ``s <- s + dt * (gamma_plus * sigmoid(5 * phi) - gamma_minus * s)``.
    """

    def __init__(self, hidden1_dim=14, hidden2_dim=14, window_size=8, output_dim=10,
                 rows_per_neuron=2, n_row_steps=20, n_col_steps=20,
                 n_input_steps=100, n_settle_steps=1, output_step=101,
                 dt=0.1, gamma_plus=0.1, gamma_minus=0.01):
        super().__init__()

        # Fail fast on schedules that would otherwise crash mid-forward (modulo by
        # zero, empty output list) or silently read an unintended output step.
        if output_step < 1:
            raise ValueError(f"output_step is 1-indexed and must be >= 1, got {output_step}")
        if n_input_steps + n_settle_steps < 1:
            raise ValueError("n_input_steps + n_settle_steps must be at least 1")
        if n_input_steps > 0 and n_row_steps + n_col_steps < 1:
            raise ValueError("n_row_steps + n_col_steps must be at least 1 when n_input_steps > 0")

        self.hidden1_dim = hidden1_dim
        self.hidden2_dim = hidden2_dim
        self.window_size = window_size
        self.output_dim = output_dim
        self.rows_per_neuron = rows_per_neuron
        self.input_per_neuron = rows_per_neuron * window_size  # 2 * 8 = 16
        self.n_row_steps = n_row_steps
        self.n_col_steps = n_col_steps
        self.steps_per_sweep = n_row_steps + n_col_steps
        self.n_input_steps = n_input_steps
        self.n_settle_steps = n_settle_steps
        self.output_step = output_step
        self.dt = dt
        self.gamma_plus = gamma_plus
        self.gamma_minus = gamma_minus

        # Layer 1 weights: Input → Hidden1
        # Each neuron has its own rows_per_neuron × window_size input weights.
        self.W_i2h1 = nn.Parameter(torch.empty(hidden1_dim, self.input_per_neuron))  # (14, 16)
        self.W_h1h1 = nn.Parameter(torch.empty(hidden1_dim, hidden1_dim))  # (14, 14) recurrent
        self.bias_h1 = nn.Parameter(torch.zeros(hidden1_dim))

        # Layer 2 weights: Hidden1 → Hidden2
        self.W_h1h2 = nn.Parameter(torch.empty(hidden2_dim, hidden1_dim))  # (14, 14)
        self.W_h2h2 = nn.Parameter(torch.empty(hidden2_dim, hidden2_dim))  # (14, 14) recurrent
        self.bias_h2 = nn.Parameter(torch.zeros(hidden2_dim))

        # Output weights: Hidden2 → Output
        self.W_h2o = nn.Parameter(torch.empty(output_dim, hidden2_dim))  # (10, 14)
        self.bias_o = nn.Parameter(torch.zeros(output_dim))

        self._init_weights()

    def _init_weights(self):
        """Randomly initialise all weight matrices and zero the recurrent self-connections."""
        nn.init.uniform_(self.W_i2h1, -0.2, 0.2)
        nn.init.normal_(self.W_h1h1, 0, 0.1)
        nn.init.normal_(self.W_h1h2, 0, 0.15)
        nn.init.normal_(self.W_h2h2, 0, 0.1)
        nn.init.normal_(self.W_h2o, 0, 0.2)
        with torch.no_grad():
            self.W_h1h1.fill_diagonal_(0)
            self.W_h2h2.fill_diagonal_(0)

    def source_function(self, phi):
        """Nonlinearity applied to the net input of every hidden neuron."""
        return torch.sigmoid(5 * phi)

    def _window_bounds(self, position, extent):
        """Return ``(start, end)`` of the window at ``position`` along an axis of length ``extent``.

        The window is shifted back so it always ends inside the image and keeps
        exactly ``window_size`` pixels.
        """
        if extent < self.window_size:
            raise ValueError(f"image dimension {extent} is smaller than window_size={self.window_size}")
        end = min(position + self.window_size, extent)
        return end - self.window_size, end

    def get_hierarchical_input(self, images, step):
        """
        Extract the per-neuron input for Hidden1 at timestep ``step``.

        Args:
            images: ``(batch, height, width)`` tensor.
            step: Input timestep; its position within the sweep selects the phase and window.

        Returns:
            ``(batch, hidden1_dim, rows_per_neuron * window_size)`` tensor — with the
            defaults each of 14 neurons gets 2 adjacent rows (or columns) × 8 pixels = 16 inputs.

        Raises:
            ValueError: if the image does not split into ``hidden1_dim`` groups of
                ``rows_per_neuron`` lines, or is narrower than the window.
        """
        batch_size, height, width = images.shape[0], images.shape[1], images.shape[2]
        step_in_sweep = step % self.steps_per_sweep

        if step_in_sweep < self.n_row_steps:
            # ROW PHASE: the window slides along the columns; group image rows.
            window_start, window_end = self._window_bounds(step_in_sweep, width)
            all_windows = images[:, :, window_start:window_end]  # (batch, height, window)
        else:
            # COLUMN PHASE: the window slides along the rows; transpose so that
            # image columns are the axis being grouped.
            window_start, window_end = self._window_bounds(step_in_sweep - self.n_row_steps, height)
            all_windows = images[:, window_start:window_end, :].transpose(1, 2)  # (batch, width, window)

        n_lines = all_windows.shape[1]
        if n_lines != self.hidden1_dim * self.rows_per_neuron:
            raise ValueError(
                f"cannot split {n_lines} image lines into hidden1_dim={self.hidden1_dim} "
                f"groups of rows_per_neuron={self.rows_per_neuron}"
            )

        # Consecutive lines are concatenated per neuron:
        # (batch, lines, window) → (batch, hidden1_dim, rows_per_neuron * window)
        return all_windows.reshape(batch_size, self.hidden1_dim, self.input_per_neuron)

    def step(self, s1, s2, h_input=None):
        """
        Single timestep update for both hidden layers.

        Args:
            s1: Hidden1 state (batch, hidden1_dim)
            s2: Hidden2 state (batch, hidden2_dim)
            h_input: Hierarchical input (batch, hidden1_dim, input_per_neuron), or None
                during the settle phase (no external drive).

        Returns:
            s1_new, s2_new
        """
        # LAYER 1 UPDATE
        if h_input is not None:
            # Each neuron applies its own weight vector to its own input slice.
            input_contrib1 = (h_input * self.W_i2h1.unsqueeze(0)).sum(dim=2)  # (batch, hidden1_dim)
        else:
            input_contrib1 = 0

        recurrent1 = F.linear(s1, self.W_h1h1)
        phi1 = input_contrib1 + recurrent1 + self.bias_h1
        g1 = self.source_function(phi1)
        ds1 = self.gamma_plus * g1 - self.gamma_minus * s1
        s1_new = s1 + self.dt * ds1

        # LAYER 2 UPDATE (driven by the freshly updated layer-1 state)
        forward12 = F.linear(s1_new, self.W_h1h2)  # (batch, hidden2_dim)
        recurrent2 = F.linear(s2, self.W_h2h2)
        phi2 = forward12 + recurrent2 + self.bias_h2
        g2 = self.source_function(phi2)
        ds2 = self.gamma_plus * g2 - self.gamma_minus * s2
        s2_new = s2 + self.dt * ds2

        return s1_new, s2_new

    def forward(self, images, return_all=False):
        """
        Run ``n_input_steps`` driven steps followed by ``n_settle_steps`` undriven steps.

        Args:
            images: ``(batch, height, width)`` tensor.
            return_all: Also record the per-step hidden states (cloned).

        Returns:
            ``(final_output, info)`` where ``final_output`` is the readout at the
            1-indexed ``output_step`` (clamped to the last step) and ``info`` holds
            ``all_outputs`` (every step's readout), ``all_s1``/``all_s2`` (lists when
            ``return_all`` else None), ``final_s1`` and ``final_s2``.
        """
        batch_size = images.shape[0]

        # Initialize states
        s1 = torch.zeros(batch_size, self.hidden1_dim, device=images.device)
        s2 = torch.zeros(batch_size, self.hidden2_dim, device=images.device)

        all_outputs = []
        all_s1 = [] if return_all else None
        all_s2 = [] if return_all else None

        # INPUT PHASE
        for t in range(self.n_input_steps):
            h_input = self.get_hierarchical_input(images, t)
            s1, s2 = self.step(s1, s2, h_input)
            output = F.linear(s2, self.W_h2o, self.bias_o)
            all_outputs.append(output)

            if return_all:
                all_s1.append(s1.clone())
                all_s2.append(s2.clone())

        # SETTLE PHASE
        for t in range(self.n_settle_steps):
            s1, s2 = self.step(s1, s2, h_input=None)
            output = F.linear(s2, self.W_h2o, self.bias_o)
            all_outputs.append(output)

            if return_all:
                all_s1.append(s1.clone())
                all_s2.append(s2.clone())

        # Get output at specified (1-indexed) step; output_step >= 1 is enforced in __init__.
        output_idx = min(self.output_step - 1, len(all_outputs) - 1)
        final_output = all_outputs[output_idx]

        return final_output, {
            'all_outputs': all_outputs,
            'all_s1': all_s1,
            'all_s2': all_s2,
            'final_s1': s1,
            'final_s2': s2
        }

# Build the network from the hyperparameters defined earlier and move it onto `device`.
model = HierarchicalSlidingSOEN(
    hidden1_dim=HIDDEN1_DIM, hidden2_dim=HIDDEN2_DIM,
    window_size=WINDOW_SIZE, output_dim=OUTPUT_DIM,
    rows_per_neuron=ROWS_PER_NEURON,
    n_row_steps=N_ROW_STEPS, n_col_steps=N_COL_STEPS,
    n_input_steps=N_INPUT_STEPS, n_settle_steps=N_SETTLE_STEPS,
    output_step=OUTPUT_STEP,
    dt=DT, gamma_plus=GAMMA_PLUS, gamma_minus=GAMMA_MINUS,
)
model = model.to(device)

# Report each weight matrix's shape alongside the connection it implements.
print("Hierarchical Model Created")
for _name, _role in (
    ("W_i2h1", "input → hidden1"),
    ("W_h1h1", "hidden1 recurrent"),
    ("W_h1h2", "hidden1 → hidden2"),
    ("W_h2h2", "hidden2 recurrent"),
    ("W_h2o", "hidden2 → output"),
):
    # Pad "name:" to 8 chars so shapes line up (W_h2o gets an extra space).
    print(f"  {_name + ':':<8}{getattr(model, _name).shape} ({_role})")
del _name, _role
print(f"  Total parameters: {sum(map(torch.numel, model.parameters()))}")

## 5. Training

In [None]:
def _make_loader(data, labels, batch_size, shuffle):
    """Wrap NumPy images/labels as a float32/long ``DataLoader``."""
    dataset = TensorDataset(
        torch.tensor(data, dtype=torch.float32),
        torch.tensor(labels, dtype=torch.long),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _train_one_epoch(model, loader, optimizer, model_device, desc):
    """Run one optimisation pass over ``loader``; return ``(mean loss, accuracy)``."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0

    pbar = tqdm(loader, desc=desc)
    for x, labels in pbar:
        x, labels = x.to(model_device), labels.to(model_device)

        optimizer.zero_grad()
        output, _ = model(x)
        loss = F.cross_entropy(output, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Recurrent self-connections are kept at zero (as in the weight init);
        # re-apply after every step because the optimizer also updates the diagonal.
        with torch.no_grad():
            model.W_h1h1.fill_diagonal_(0)
            model.W_h2h2.fill_diagonal_(0)

        correct += (output.argmax(dim=1) == labels).sum().item()
        seen += len(labels)
        total_loss += loss.item() * len(labels)

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{correct/seen:.3f}'})

    return total_loss / seen, correct / seen


@torch.no_grad()
def _validate(model, loader, model_device):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean loss, accuracy)``."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    for x, labels in loader:
        x, labels = x.to(model_device), labels.to(model_device)
        output, _ = model(x)
        loss = F.cross_entropy(output, labels)
        total_loss += loss.item() * len(labels)
        correct += (output.argmax(dim=1) == labels).sum().item()
        seen += len(labels)
    return total_loss / seen, correct / seen


def train_model(model, train_data, train_labels, val_data, val_labels,
                epochs=30, batch_size=128, lr=0.005):
    """Train ``model`` with Adam + cosine LR decay and restore the best-validation weights.

    Batches are moved to the device of the model's parameters.

    Args:
        model: A ``HierarchicalSlidingSOEN`` (needs ``W_h1h1``/``W_h2h2`` and the
            dimension attributes printed in the header).
        train_data, train_labels: NumPy training images and integer labels.
        val_data, val_labels: NumPy validation images and integer labels.
        epochs: Number of passes over the training set (also the cosine ``T_max``).
        batch_size: Training batch size (validation uses 256).
        lr: Initial Adam learning rate.

    Returns:
        ``history`` dict with per-epoch lists under ``'train_loss'``, ``'train_acc'``,
        ``'val_loss'`` and ``'val_acc'``. The model is modified in place and ends
        with the weights of the epoch with the highest validation accuracy.
    """
    model_device = next(model.parameters()).device
    train_loader = _make_loader(train_data, train_labels, batch_size, shuffle=True)
    val_loader = _make_loader(val_data, val_labels, 256, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc = 0
    best_state = None

    print("="*60)
    print("HIERARCHICAL SLIDING WINDOW TRAINING")
    print("="*60)
    print(f"Architecture: {model.hidden1_dim} → {model.hidden2_dim} → {model.output_dim}")
    print(f"Each Hidden1 neuron: {model.rows_per_neuron} rows × {model.window_size} pixels")
    print("="*60)

    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, optimizer, model_device, desc=f"Epoch {epoch+1}/{epochs}")
        scheduler.step()

        val_loss, val_acc = _validate(model, val_loader, model_device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Only a strict improvement is checkpointed, so only that gets the '*' marker
        # (previously ties were starred even though their weights were not kept).
        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, train_acc={train_acc:.3f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f} {'*' if improved else ''}")

    if best_state is not None:
        model.load_state_dict(best_state)
    print(f"\nBest validation accuracy: {best_val_acc:.4f}")

    return history

history = train_model(model, train_data, train_labels, val_data, val_labels,
                      epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR)

## 6. Visualize Training

In [None]:
def plot_training(history):
    """Draw side-by-side loss and accuracy curves (train vs. validation) per epoch.

    Args:
        history: dict with per-epoch lists under 'train_loss', 'val_loss',
            'train_acc' and 'val_acc', as returned by ``train_model``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    panels = (
        (loss_ax, 'train_loss', 'val_loss', 'Loss', 'Loss (Hierarchical 14→14)'),
        (acc_ax, 'train_acc', 'val_acc', 'Accuracy', 'Accuracy (Hierarchical 14→14)'),
    )
    for ax, train_key, val_key, ylabel, title in panels:
        ax.plot(history[train_key], label='Train')
        ax.plot(history[val_key], label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

plot_training(history)

## 7. Evaluate

In [None]:
@torch.no_grad()
def evaluate(model, test_data, test_labels, batch_size=256):
    """Compute and print classification accuracy of ``model`` on a held-out set.

    Batches are moved to the device of the model's parameters, so the function
    works wherever the model lives.

    Args:
        model: A ``HierarchicalSlidingSOEN`` (uses ``hidden1_dim``/``hidden2_dim`` for the banner).
        test_data: NumPy images.
        test_labels: NumPy integer labels.
        batch_size: Evaluation batch size.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if ``test_labels`` is empty.
    """
    if len(test_labels) == 0:
        raise ValueError("evaluate() needs at least one test sample")

    model.eval()
    model_device = next(model.parameters()).device

    test_dataset = TensorDataset(
        torch.tensor(test_data, dtype=torch.float32),
        torch.tensor(test_labels, dtype=torch.long)
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    all_preds = []
    all_labels = []

    for x, labels in tqdm(test_loader, desc="Testing"):
        output, _ = model(x.to(model_device))
        all_preds.append(output.argmax(dim=1).cpu())
        all_labels.append(labels)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    accuracy = (all_preds == all_labels).float().mean().item()

    print(f"\n{'='*60}")
    print(f"TEST ACCURACY (Hierarchical {model.hidden1_dim}→{model.hidden2_dim}): "
          f"{accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"{'='*60}")

    return accuracy

test_acc = evaluate(model, test_data, test_labels)

## 8. Visualize Layer Activations

In [None]:
def visualize_layers(model, image, label):
    """Plot the input image, both hidden layers' state traces and the output logits over time.

    Args:
        model: A ``HierarchicalSlidingSOEN``; inputs are sent to the device of its parameters.
        image: A single image (anything ``torch.as_tensor`` accepts) of shape (height, width).
        label: True class, shown in the titles.
    """
    model.eval()
    model_device = next(model.parameters()).device
    x = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0).to(model_device)

    with torch.no_grad():
        output, states = model(x, return_all=True)

    # Take batch element 0 explicitly: .squeeze() would also drop a size-1
    # neuron/class axis and break the plots below.
    all_s1 = torch.stack(states['all_s1'])[:, 0].cpu().numpy()            # (T, hidden1_dim)
    all_s2 = torch.stack(states['all_s2'])[:, 0].cpu().numpy()            # (T, hidden2_dim)
    all_outputs = torch.stack(states['all_outputs'])[:, 0].cpu().numpy()  # (T, output_dim)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Original image
    axes[0, 0].imshow(image, cmap='gray')
    axes[0, 0].set_title(f'Input (Label: {label})')

    # Hidden1 activations
    im1 = axes[0, 1].imshow(all_s1.T, aspect='auto', cmap='viridis')
    axes[0, 1].set_xlabel('Timestep')
    axes[0, 1].set_ylabel('Hidden1 Neuron')
    axes[0, 1].set_title(f'Hidden Layer 1 ({model.hidden1_dim} neurons)')
    plt.colorbar(im1, ax=axes[0, 1])

    # Hidden2 activations
    im2 = axes[1, 0].imshow(all_s2.T, aspect='auto', cmap='viridis')
    axes[1, 0].set_xlabel('Timestep')
    axes[1, 0].set_ylabel('Hidden2 Neuron')
    axes[1, 0].set_title(f'Hidden Layer 2 ({model.hidden2_dim} neurons)')
    plt.colorbar(im2, ax=axes[1, 0])

    # Output evolution; the dashed line marks the end of the driven input phase.
    for i in range(model.output_dim):
        axes[1, 1].plot(all_outputs[:, i], label=f'{i}', alpha=0.7)
    axes[1, 1].axvline(x=model.n_input_steps - 0.5, color='green', linestyle='--')
    axes[1, 1].set_xlabel('Timestep')
    axes[1, 1].set_ylabel('Logit')
    axes[1, 1].set_title('Output Evolution')
    axes[1, 1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    plt.suptitle(f'Hierarchical Network Activations (True: {label}, Pred: {output.argmax().item()})',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

for i in range(3):
    visualize_layers(model, test_data[i], test_labels[i])

## Summary

| Aspect | Single Layer (28) | Hierarchical (14→14) |
|--------|------------------|----------------------|
| Hidden neurons | 28 | 14 + 14 = 28 |
| Layers | 1 | 2 |
| Input per neuron | 8 | 16 (2 rows × 8) |
| Feature hierarchy | None | Low → High level |

### Why Hierarchical Helps

1. **Feature composition**: Layer 2 learns combinations of Layer 1 features
2. **Grouped input**: Each neuron sees 2 adjacent rows (local receptive field)
3. **More processing depth**: Input passes through two recurrent layers of dynamics before reaching the readout, and gradients flow back through both.