# Tutorial 03b: MNIST Classification with SOEN (7×112 Format)

This notebook demonstrates training a SOEN (Superconducting Optoelectronic Network) model on the MNIST digit classification task using a **7×112 input format** designed for neurons with **at most 8 dendrites**.

---

## 🔊 NOISE CONFIGURATION: ENABLED (Default)

> **This tutorial runs with NOISE INJECTION by default (documented values).**
>
> | Parameter | Default | Description |
> |-----------|---------|-------------|
> | `phi` | **0.01** | Noise on input flux |
> | `s` | **0.005** | Noise on state |
> | `relative` | **false** | Absolute scaling |
>
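> **To toggle noise on/off**, set `NOISE_ENABLED` in the configuration cell below and apply it with `set_model_noise(model, enabled=NOISE_ENABLED)`; the flag on its own only changes the printed status.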
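
As a rough illustration of what the `phi` and `s` settings mean, the sketch below adds zero-mean Gaussian noise to an array at every timestep. It is not the toolkit's implementation: the helper `add_gaussian_noise` is hypothetical, and reading `relative=True` as "scale the noise by the signal magnitude" is an assumption made for this example.

```python
import numpy as np

def add_gaussian_noise(x, std, relative=False, rng=None):
    """Hypothetical helper (not part of soen_toolkit): add zero-mean Gaussian noise.

    Assumption: relative=False adds noise of fixed scale `std`, while
    relative=True scales the noise by the magnitude of the signal.
    """
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.normal(0.0, std, size=x.shape)
    if relative:
        noise = noise * np.abs(x)
    return x + noise

rng = np.random.default_rng(0)
phi = np.zeros((2, 7, 112))  # (batch, timesteps, features)
noisy_phi = add_gaussian_noise(phi, std=0.01, rng=rng)  # independent draw per timestep
print(noisy_phi.shape, round(float(noisy_phi.std()), 4))
```

With absolute scaling, the spread of the perturbation stays near `std` regardless of the signal, which is why an all-zero input still comes back noisy here.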

---

## Key Differences from Standard Tutorial 03

| Aspect | Standard (28×28) | This Version (7×112) |
|--------|------------------|----------------------|
| **Input Shape** | 28 timesteps × 28 features | 7 timesteps × 112 features |
| **Row Grouping** | 1 row per timestep | 4 rows per timestep |
| **Dendrite Compatibility** | N/A | 112 / 8 = 14 inputs per dendrite |
| **Temporal Steps** | 28 | 7 |

## Architecture

```
Input (112D) → SingleDendrite (128D) → Output (10D)
     ↓               ↓ ↺                    ↓
  4 rows×28px    Recurrent SOEN       10 digit classes
  (7 timesteps)  (8 dendrite ready)
```
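
To make the data flow in the diagram concrete, here is a minimal NumPy sketch of a generic recurrent layer with the same dimensions (112 → 128 → 10). It is only a stand-in: the `tanh` update and the random weights are assumptions for illustration and do not reproduce the SingleDendrite dynamics or any trained SOEN parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 112, 128, 10

# Random stand-in weights (illustrative only, not trained SOEN parameters)
w_in = rng.normal(0, 0.1, size=(n_in, n_hidden))       # input -> hidden
w_rec = rng.normal(0, 0.1, size=(n_hidden, n_hidden))  # hidden -> hidden (recurrent)
w_out = rng.normal(0, 0.1, size=(n_hidden, n_out))     # hidden -> output

def forward(x):
    """Run a generic tanh recurrent layer over x of shape (batch, T, 112)."""
    batch, n_steps, _ = x.shape
    h = np.zeros((batch, n_hidden))
    outputs = []
    for t in range(n_steps):
        # tanh is a placeholder nonlinearity, NOT the SingleDendrite dynamics
        h = np.tanh(x[:, t] @ w_in + h @ w_rec)
        outputs.append(h @ w_out)
    return np.stack(outputs, axis=1)  # (batch, T, 10)

y = forward(rng.random((2, 7, 112)))
print(y.shape)
```

Each of the 7 timesteps consumes one 112-feature slice and produces one 10-dimensional output, so a batch of 2 yields an array of shape `(2, 7, 10)`.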

## Why 7×112?

- **8 Dendrite Compatibility**: 112 features / 8 dendrites = 14 inputs per dendrite
- **Shorter Sequences**: 7 timesteps vs 28 timesteps (4× fewer sequential steps per sample)
- **Spatial Context**: Each timestep sees 4 rows at once (more spatial context)
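
The arithmetic behind these points can be checked with a short script. The helper name `candidate_formats` is hypothetical; it simply enumerates the row groupings of a 28×28 image whose feature count divides evenly among a given number of dendrites.

```python
IMAGE_ROWS, IMAGE_COLS = 28, 28

def candidate_formats(n_dendrites):
    """List (timesteps, features, inputs_per_dendrite) for row groupings that divide evenly."""
    formats = []
    for rows_per_step in range(1, IMAGE_ROWS + 1):
        if IMAGE_ROWS % rows_per_step:
            continue  # the grouping must tile the 28 rows exactly
        features = rows_per_step * IMAGE_COLS
        if features % n_dendrites == 0:
            formats.append((IMAGE_ROWS // rows_per_step, features, features // n_dendrites))
    return formats

for timesteps, features, per_dendrite in candidate_formats(8):
    print(f"{timesteps:2d} x {features:3d} -> {per_dendrite} inputs per dendrite")
```

For 8 dendrites the list includes 7 × 112 with 14 inputs per dendrite, the format used in this tutorial; the standard 28 × 28 layout is absent because 28 is not divisible by 8.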

## Setup

In [None]:
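# Configure tqdm through environment variables BEFORE tqdm is imported
import os
os.environ["TQDM_DISABLE"] = "0"  # Intended to keep bars on; tqdm may read any non-empty string as True, so delete this line if bars disappear
os.environ["TQDM_MININTERVAL"] = "1"  # Refresh progress output at most once per second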

# Setup: Ensure soen_toolkit is importable
import sys
from pathlib import Path

# Add src directory to path if running from notebook location
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.nn as nn
import glob
import gzip
import urllib.request
import struct

# Use standard tqdm (not notebook version to avoid widget errors)
try:
    from tqdm import tqdm
except ImportError:
    # Simple fallback if tqdm not available
    def tqdm(iterable, **kwargs):
        return iterable

# Allow TF32-class matmul kernels for float32 (faster on tensor-core GPUs such as the H100)
torch.set_float32_matmul_precision('high')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# ============================================================================
# NOISE CONFIGURATION TOGGLE
# ============================================================================
# Set NOISE_ENABLED to control noise injection:
#   True  = Noise enabled (default) - adds stochastic noise per timestep
#   False = Ideal conditions - no noise, deterministic behavior
# ============================================================================

NOISE_ENABLED = True  # Toggle this to enable/disable noise

# Default noise values from documentation (used when NOISE_ENABLED = True)
NOISE_DEFAULTS = {
    "phi": 0.01,           # Noise on input flux
    "s": 0.005,            # Noise on state
    "g": 0.0,              # Source function noise
    "bias_current": 0.0,   # Bias current noise
    "j": 0.0,              # Connection weight noise
    "relative": False,     # Absolute scaling
}

def set_model_noise(model, enabled=True, noise_values=None):
    """
    Toggle noise injection on/off for a SOEN model.
    
    Args:
        model: SOENModelCore instance
        enabled: True to enable noise, False to disable
        noise_values: Optional dict of noise values (uses NOISE_DEFAULTS if None)
    
    Returns:
        model: The modified model (in-place modification)
    """
    from soen_toolkit.core.configs import NoiseConfig
    
    if noise_values is None:
        noise_values = NOISE_DEFAULTS
    
    for cfg in model.layers_config:
        if enabled:
            # Apply noise values
            cfg.noise = NoiseConfig(
                phi=noise_values.get("phi", 0.01),
                s=noise_values.get("s", 0.005),
                g=noise_values.get("g", 0.0),
                bias_current=noise_values.get("bias_current", 0.0),
                j=noise_values.get("j", 0.0),
                relative=noise_values.get("relative", False),
                extras=getattr(cfg.noise, "extras", {}),
            )
        else:
            # Disable noise (all zeros)
            cfg.noise = NoiseConfig(
                phi=0.0, s=0.0, g=0.0, bias_current=0.0, j=0.0,
                relative=False,
                extras=getattr(cfg.noise, "extras", {}),
            )
    
    # Also update connection noise
    for conn_cfg in model.connections_config:
        if enabled:
            conn_cfg.noise = NoiseConfig(
                phi=0.0, g=0.0, s=0.0, bias_current=0.0,
                j=noise_values.get("j", 0.0),
                relative=noise_values.get("relative", False),
                extras={},
            )
        else:
            conn_cfg.noise = NoiseConfig(
                phi=0.0, g=0.0, s=0.0, bias_current=0.0, j=0.0,
                relative=False, extras={},
            )
    
    status = "ENABLED" if enabled else "DISABLED"
    print(f"✓ Noise injection {status}")
    if enabled:
        print(f"  phi={noise_values.get('phi', 0.01)}, s={noise_values.get('s', 0.005)}")
    
    return model

print(f"Noise configuration: {'ENABLED' if NOISE_ENABLED else 'DISABLED'}")

## 1. Prepare MNIST Dataset (7×112 Format)

We'll download MNIST and convert it to HDF5 format with the **7×112 reshaping**:

**Data format**: Images are reshaped from (28, 28) to (7, 112) where:
- Time dimension = 7 (groups of 4 rows)
- Feature dimension = 112 (4 rows × 28 pixels per row)
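
The mapping described above can be verified on a toy image whose pixel values equal their flat index. This is a small self-check sketch, independent of the dataset files.

```python
import numpy as np

# Toy "image" whose pixel value equals its flat index (row * 28 + col)
image = np.arange(28 * 28).reshape(28, 28)

# Same reshaping as in this tutorial: (28, 28) -> (7, 4, 28) -> (7, 112)
sequence = image.reshape(7, 4, 28).reshape(7, 112)

# Timestep t holds image rows 4t .. 4t+3, flattened row by row
for t in range(7):
    assert np.array_equal(sequence[t], image[4 * t: 4 * t + 4].ravel())

# Feature k of a timestep comes from local row k // 28, column k % 28
t, k = 2, 30
assert sequence[t, k] == image[4 * t + k // 28, k % 28]

# The mapping is lossless: reshaping back recovers the image
assert np.array_equal(sequence.reshape(28, 28), image)
print(sequence.shape)
```

Because NumPy reshapes in row-major order, no pixels are reordered; only the way the 784 values are split between the time and feature axes changes.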

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/"):
    """Download a single MNIST file if not already present."""
    data_dir = Path("./data/mnist")
    data_dir.mkdir(parents=True, exist_ok=True)
    
    filepath = data_dir / filename
    if not filepath.exists():
        url = base_url + filename
        print(f"Downloading {filename}...")
        urllib.request.urlretrieve(url, filepath)
    return filepath

def read_mnist_images(filepath):
    """Read MNIST image file (idx3-ubyte format)."""
    with gzip.open(filepath, 'rb') as f:
        # Read magic number and dimensions
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        # Read image data
        images = np.frombuffer(f.read(), dtype=np.uint8)
        images = images.reshape(num_images, rows, cols)
    return images

def read_mnist_labels(filepath):
    """Read MNIST label file (idx1-ubyte format)."""
    with gzip.open(filepath, 'rb') as f:
        # Read magic number and count
        magic, num_labels = struct.unpack('>II', f.read(8))
        # Read labels
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    return labels

def reshape_to_7x112(images):
    """
    Reshape MNIST images from (N, 28, 28) to (N, 7, 112).
    
    Groups 4 consecutive rows into each timestep:
    - Timestep 0: rows 0-3 (4 × 28 = 112 features)
    - Timestep 1: rows 4-7 (4 × 28 = 112 features)
    - ...
    - Timestep 6: rows 24-27 (4 × 28 = 112 features)
    
    This format is designed for 8 dendrites: 112 / 8 = 14 inputs per dendrite
    """
    n_samples = images.shape[0]
    # Reshape: (N, 28, 28) -> (N, 7, 4, 28) -> (N, 7, 112)
    reshaped = images.reshape(n_samples, 7, 4, 28)
    reshaped = reshaped.reshape(n_samples, 7, 112)
    return reshaped

def prepare_mnist_hdf5_7x112(output_path="training/datasets/mnist_seq7x112.hdf5", 
                             normalize=True,
                             val_split=0.1):
    """
    Download MNIST and save as HDF5 for SOEN training with 7×112 format.
    
    Args:
        output_path: Where to save the HDF5 file
        normalize: Whether to normalize pixel values to [0, 1]
        val_split: Fraction of training data to use for validation
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Check if already exists
    if output_path.exists():
        print(f"Dataset already exists at {output_path}")
        with h5py.File(output_path, 'r') as f:
            print(f"  Train samples: {len(f['train']['labels'])}")
            print(f"  Val samples: {len(f['val']['labels'])}")
            print(f"  Test samples: {len(f['test']['labels'])}")
            print(f"  Data shape: {f['train']['data'].shape}")
        return output_path
    
    print("Downloading MNIST (without torchvision)...")
    
    # Download MNIST files
    train_images_file = download_mnist_file("train-images-idx3-ubyte.gz")
    train_labels_file = download_mnist_file("train-labels-idx1-ubyte.gz")
    test_images_file = download_mnist_file("t10k-images-idx3-ubyte.gz")
    test_labels_file = download_mnist_file("t10k-labels-idx1-ubyte.gz")
    
    # Read the data
    train_images = read_mnist_images(train_images_file).astype(np.float32)
    train_labels = read_mnist_labels(train_labels_file).astype(np.int64)
    test_images = read_mnist_images(test_images_file).astype(np.float32)
    test_labels = read_mnist_labels(test_labels_file).astype(np.int64)
    
    # Normalize to [0, 1]
    if normalize:
        train_images = train_images / 255.0
        test_images = test_images / 255.0
    
    # Reshape to 7×112 format
    print("\nReshaping images from (28, 28) to (7, 112)...")
    print("  - 7 timesteps (groups of 4 rows)")
    print("  - 112 features per timestep (4 rows × 28 pixels)")
    print("  - 8 dendrite compatible: 112 / 8 = 14 inputs per dendrite")
    
    train_images = reshape_to_7x112(train_images)
    test_images = reshape_to_7x112(test_images)
    
    # Split training into train/val
    n_train = len(train_images)
    n_val = int(n_train * val_split)
    
    # Shuffle before splitting
    np.random.seed(42)
    indices = np.random.permutation(n_train)
    val_indices = indices[:n_val]
    train_indices = indices[n_val:]
    
    val_images = train_images[val_indices]
    val_labels = train_labels[val_indices]
    train_images = train_images[train_indices]
    train_labels = train_labels[train_indices]
    
    print(f"\nFinal shapes:")
    print(f"  Train: {train_images.shape} (N, T=7 timesteps, D=112 features)")
    print(f"  Val: {val_images.shape}")
    print(f"  Test: {test_images.shape}")
    
    # Save to HDF5 (no compression for speed; 70,000 float32 images of 784 pixels come to roughly 220 MB)
    print(f"\nSaving to {output_path}...")
    with h5py.File(output_path, 'w') as f:
        # Training set
        train_grp = f.create_group('train')
        train_grp.create_dataset('data', data=train_images)
        train_grp.create_dataset('labels', data=train_labels)
        
        # Validation set
        val_grp = f.create_group('val')
        val_grp.create_dataset('data', data=val_images)
        val_grp.create_dataset('labels', data=val_labels)
        
        # Test set
        test_grp = f.create_group('test')
        test_grp.create_dataset('data', data=test_images)
        test_grp.create_dataset('labels', data=test_labels)
        
        # Metadata
        f.attrs['description'] = 'MNIST as sequences (7 timesteps x 112 features) for 8 dendrite neurons'
        f.attrs['num_classes'] = 10
        f.attrs['seq_len'] = 7
        f.attrs['feature_dim'] = 112
        f.attrs['rows_per_timestep'] = 4
        f.attrs['dendrite_inputs'] = 14  # 112 / 8 dendrites
    
    print("Done!")
    print(f"  Train: {len(train_labels)} samples")
    print(f"  Val: {len(val_labels)} samples")
    print(f"  Test: {len(test_labels)} samples")
    
    return output_path

# Prepare the dataset
data_path = prepare_mnist_hdf5_7x112()

## 2. Visualize the 7×112 Dataset

Let's see how MNIST looks when reshaped to 7 timesteps × 112 features.

In [None]:
def visualize_mnist_7x112_samples(data_path, n_samples=5):
    """
    Visualize MNIST samples in both original 28×28 and reshaped 7×112 formats.
    """
    with h5py.File(data_path, 'r') as f:
        images_7x112 = np.array(f['train']['data'][:n_samples])
        labels = np.array(f['train']['labels'][:n_samples])
    
    # Reconstruct 28×28 from 7×112 for comparison
    images_28x28 = images_7x112.reshape(n_samples, 7, 4, 28).reshape(n_samples, 28, 28)
    
    fig, axes = plt.subplots(3, n_samples, figsize=(3*n_samples, 9))
    fig.suptitle('MNIST: Original vs 7×112 Reshaping', fontsize=14, fontweight='bold')
    
    for i in range(n_samples):
        # Top row: Original 28×28 image
        axes[0, i].imshow(images_28x28[i], cmap='gray')
        axes[0, i].set_title(f'Label: {labels[i]}', fontsize=10)
        axes[0, i].axis('off')
        
        # Middle row: 7×112 sequential view (heatmap)
        im = axes[1, i].imshow(images_7x112[i], cmap='viridis', aspect='auto')
        axes[1, i].set_xlabel('Feature (0-111)', fontsize=8)
        if i == 0:
            axes[1, i].set_ylabel('Timestep (0-6)', fontsize=8)
        axes[1, i].set_yticks(range(7))
        
        # Bottom row: Show 7×112 as 7 groups of 4 rows
        img_grouped = images_7x112[i].reshape(7, 4, 28)
        # Stack horizontally to show all 7 timesteps
        stacked = np.hstack([img_grouped[t] for t in range(7)])
        axes[2, i].imshow(stacked, cmap='gray', aspect='auto')
        axes[2, i].set_xlabel('Timestep blocks (0-6)', fontsize=8)
        if i == 0:
            axes[2, i].set_ylabel('4 rows per block', fontsize=8)
        # Add vertical lines to separate timesteps
        for t in range(1, 7):
            axes[2, i].axvline(x=t*28-0.5, color='red', linewidth=0.5, alpha=0.5)
    
    axes[0, 0].set_ylabel('Original\n28×28', fontsize=10)
    axes[1, 0].set_ylabel('7×112\nSequence', fontsize=10)
    axes[2, 0].set_ylabel('4 rows per\ntimestep', fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary
    print("\n" + "="*60)
    print("7×112 Format Summary")
    print("="*60)
    print(f"• Original shape: 28×28 = 784 pixels")
    print(f"• Reshaped: 7 timesteps × 112 features = 784 pixels")
    print(f"• Each timestep: 4 rows × 28 pixels = 112 features")
    print(f"• 8 dendrite compatibility: 112 / 8 = 14 inputs per dendrite")
    print("="*60)

visualize_mnist_7x112_samples(data_path)

In [None]:
def visualize_dendrite_mapping(n_dendrites=8, features_per_timestep=112):
    """
    Visualize how the 112 input features map to 8 dendrites.
    """
    inputs_per_dendrite = features_per_timestep // n_dendrites
    
    fig, ax = plt.subplots(figsize=(14, 4))
    
    # Create a color map for dendrites
    colors = plt.cm.tab10(np.linspace(0, 1, n_dendrites))
    
    # Draw the feature mapping
    for d in range(n_dendrites):
        start = d * inputs_per_dendrite
        end = start + inputs_per_dendrite
        ax.barh(0, inputs_per_dendrite, left=start, height=0.8, 
                color=colors[d], edgecolor='black', linewidth=1)
        ax.text(start + inputs_per_dendrite/2, 0, f'D{d}\n({start}-{end-1})',
                ha='center', va='center', fontsize=9, fontweight='bold')
    
    ax.set_xlim(0, 112)
    ax.set_ylim(-0.5, 0.5)
    ax.set_xlabel('Feature Index (0-111)', fontsize=12)
    ax.set_title(f'{n_dendrites} Dendrite Mapping: {features_per_timestep} features → {inputs_per_dendrite} inputs per dendrite',
                 fontsize=14, fontweight='bold')
    ax.set_yticks([])
    
    # Add pixel mapping annotation
    ax.text(56, -0.35, 'Each timestep: 4 rows × 28 pixels = 112 features',
            ha='center', fontsize=10, style='italic')
    
    plt.tight_layout()
    plt.show()
    
    # Show detailed mapping
    print("\nDendrite → Pixel Mapping:")
    print("-" * 50)
    for d in range(n_dendrites):
        start = d * inputs_per_dendrite
        end = start + inputs_per_dendrite
        row_start = start // 28
        col_start = start % 28
        row_end = (end - 1) // 28
        col_end = (end - 1) % 28
        print(f"  Dendrite {d}: features [{start:3d}-{end-1:3d}] → "
              f"row {row_start}, col {col_start:2d} to row {row_end}, col {col_end:2d}")

visualize_dendrite_mapping()
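
The same contiguous grouping can be expressed as a single reshape of the feature axis. The sketch below assumes, as in the plot above, that dendrite `d` receives features `14*d` to `14*d + 13`; how the toolkit actually wires inputs to dendrites is not shown in this tutorial, so treat the grouping as illustrative.

```python
import numpy as np

n_dendrites = 8
x = np.arange(2 * 7 * 112, dtype=np.float32).reshape(2, 7, 112)  # (batch, T, features)

# Split the feature axis into contiguous blocks: (batch, T, dendrite, inputs)
per_dendrite = x.reshape(2, 7, n_dendrites, 112 // n_dendrites)

# Dendrite d sees features 14*d .. 14*d + 13 of each timestep
d = 3
assert np.array_equal(per_dendrite[:, :, d], x[:, :, 14 * d: 14 * (d + 1)])

# Example aggregate: the summed drive each dendrite receives per timestep
drive = per_dendrite.sum(axis=-1)
print(per_dendrite.shape, drive.shape)
```

Since each row has 28 pixels and each dendrite takes 14 features, every dendrite covers exactly half of one image row within the 4-row block.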

## 3. Examine the Model Architecture

Let's look at the SOEN model configured for 7×112 input.

In [None]:
from soen_toolkit.core.model_yaml import build_model_from_yaml

# Load and inspect the model
model_path = Path("training/test_models/model_specs/MNIST_SOENSpec_7x112.yaml")
model = build_model_from_yaml(model_path)

print("=" * 60)
print("MNIST SOEN MODEL ARCHITECTURE (7×112 Format)")
print("=" * 60)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

print("\nLayer dimensions:")
for layer_id, dim in model.layer_nodes.items():
    if layer_id == 0:
        print(f"  Layer {layer_id}: {dim} neurons (112 = 4 rows × 28 pixels)")
    else:
        print(f"  Layer {layer_id}: {dim} neurons")

print("\nConnections:")
for name, param in model.connections.items():
    print(f"  {name}: {param.shape} (learnable: {param.requires_grad})")

# Test forward pass
print("\nTesting forward pass...")
x_test = torch.randn(2, 7, 112)  # [batch=2, seq_len=7, features=112]
with torch.no_grad():
    output, states = model(x_test)
print(f"  Input shape: {x_test.shape} (batch, 7 timesteps, 112 features)")
print(f"  Output shape: {output.shape}")
print("  Forward pass successful!")

print("\n" + "=" * 60)
print("8 DENDRITE COMPATIBILITY")
print("=" * 60)
print(f"  Input features: 112")
print(f"  Number of dendrites: 8")
print(f"  Inputs per dendrite: 112 / 8 = 14")
print("=" * 60)

## 4. Train the Model

Now let's train the SOEN model on MNIST using the 7×112 training configuration.

In [None]:
# Disable progress bar to avoid widget errors in notebooks
import os
os.environ["SOEN_NO_PROGRESS_BAR"] = "1"

from soen_toolkit.training.trainers.experiment import run_from_config

# Run training
print("Starting MNIST SOEN training (7×112 format)...")
print("This may take a while depending on your hardware.")
print("="*60)
print("Input format: 7 timesteps × 112 features")
print("8 dendrite compatible: 112 / 8 = 14 inputs per dendrite")
print("="*60)

run_from_config("training/training_configs/mnist_soen_7x112.yaml", script_dir=Path.cwd())

## 5. Evaluate the Trained Model

Let's load the most recent checkpoint and evaluate it on the test set.

In [None]:
def load_best_checkpoint_7x112():
    """Find and load the most recent checkpoint from training (chosen by file modification time)."""
    
    # Find checkpoints
    ckpt_patterns = [
        "training/temp/**/checkpoints/**/*.ckpt",
        "training/temp/**/*.ckpt",
    ]
    
    all_ckpts = []
    for pattern in ckpt_patterns:
        all_ckpts.extend(glob.glob(pattern, recursive=True))
    
    # Filter for 7x112 checkpoints
    ckpts_7x112 = [c for c in all_ckpts if '7x112' in c or 'MNIST_SOEN_7x112' in c]
    
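    if not ckpts_7x112 and all_ckpts:
        # Fall back to all checkpoints, but say so: they may come from a different input format
        print("Warning: no checkpoint path contains '7x112'; falling back to all checkpoints found.")
        ckpts_7x112 = all_ckpts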
    
    if not ckpts_7x112:
        print("No checkpoint found. Run training first.")
        return None, None
    
    # Get the most recent checkpoint
    latest_ckpt = max(ckpts_7x112, key=lambda x: Path(x).stat().st_mtime)
    print(f"Loading checkpoint: {latest_ckpt}")
    
    # Load model
    model_path = Path("training/test_models/model_specs/MNIST_SOENSpec_7x112.yaml")
    model = build_model_from_yaml(model_path)
    
    # Load weights
    ckpt = torch.load(latest_ckpt, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)
    
    # Remove 'model.' prefix if present
    clean_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('model.'):
            clean_state_dict[k[6:]] = v
        else:
            clean_state_dict[k] = v
    
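    # strict=False tolerates key mismatches, so report them instead of ignoring them silently
    load_result = model.load_state_dict(clean_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"  Missing keys: {len(load_result.missing_keys)}, "
              f"unexpected keys: {len(load_result.unexpected_keys)}")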
    model.eval()
    
    return model, latest_ckpt

trained_model, ckpt_path = load_best_checkpoint_7x112()

In [None]:
def evaluate_on_test_set(model, data_path, batch_size=128):
    """Evaluate model on the test set."""
    
    if model is None:
        print("No model loaded.")
        return
    
    # Load test data
    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])
    
    print(f"Evaluating on {len(test_labels)} test samples...")
    print(f"Input shape: {test_data.shape} (N, 7 timesteps, 112 features)")
    
    model.eval()
    device = next(model.parameters()).device
    
    all_preds = []
    all_probs = []
    
    with torch.no_grad():
        for i in tqdm(range(0, len(test_data), batch_size)):
            batch_data = test_data[i:i+batch_size]
            x = torch.tensor(batch_data, dtype=torch.float32).to(device)
            
            output, _ = model(x)
            
            # Apply max pooling over time
            if output.dim() == 3:
                pooled = output.max(dim=1)[0]
            else:
                pooled = output
            
            probs = torch.softmax(pooled, dim=1)
            preds = torch.argmax(probs, dim=1)
            
            all_preds.append(preds.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
    
    all_preds = np.concatenate(all_preds)
    all_probs = np.concatenate(all_probs)
    
    # Calculate accuracy
    accuracy = (all_preds == test_labels).mean()
    
    print(f"\n{'='*50}")
    print(f"TEST SET RESULTS (7×112 Format)")
    print(f"{'='*50}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Correct: {(all_preds == test_labels).sum()}/{len(test_labels)}")
    
    return all_preds, all_probs, accuracy

if trained_model is not None:
    predictions, probabilities, test_accuracy = evaluate_on_test_set(trained_model, data_path)
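
The readout used in `evaluate_on_test_set` (max pooling over time, then softmax and argmax) can be seen in isolation on a toy array. This NumPy sketch mirrors that logic; the function name `readout` and the toy values are made up for illustration.

```python
import numpy as np

def readout(outputs):
    """Max-pool over time, then softmax; outputs has shape (batch, T, classes)."""
    pooled = outputs.max(axis=1)                          # (batch, classes)
    shifted = pooled - pooled.max(axis=1, keepdims=True)  # stabilise the exponentials
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    return probs.argmax(axis=1), probs

# One sample, 3 timesteps, 3 classes: class 2 peaks at t=1 only
toy = np.array([[[0.1, 0.2, 0.0],
                 [0.0, 0.1, 0.9],
                 [0.3, 0.2, 0.1]]])
preds, probs = readout(toy)
print(preds)  # class 2 wins because its peak (0.9) is the largest over time
```

Note that reading only the last timestep would pick class 0 for this toy input, so the choice of temporal pooling can change the prediction.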

## 6. Visualize Predictions

In [None]:
def visualize_predictions_7x112(model, data_path, n_samples=20):
    """Visualize model predictions on random test samples."""
    
    if model is None:
        print("No model loaded.")
        return
    
    # Load test data
    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])
    
    # Random sample
    np.random.seed(42)
    indices = np.random.choice(len(test_data), n_samples, replace=False)
    
    samples = test_data[indices]  # Shape: (n_samples, 7, 112)
    labels = test_labels[indices]
    
    # Reconstruct 28×28 for visualization
    samples_28x28 = samples.reshape(n_samples, 7, 4, 28).reshape(n_samples, 28, 28)
    
    # Get predictions
    model.eval()
    with torch.no_grad():
        x = torch.tensor(samples, dtype=torch.float32)
        output, _ = model(x)
        
        if output.dim() == 3:
            pooled = output.max(dim=1)[0]
        else:
            pooled = output
        
        probs = torch.softmax(pooled, dim=1)
        preds = torch.argmax(probs, dim=1).numpy()
        confidence = probs.max(dim=1)[0].numpy()
    
    # Plot
    n_cols = 5
    n_rows = (n_samples + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.5*n_cols, 3*n_rows))
    axes = axes.flatten()
    
    fig.suptitle('MNIST Predictions (SOEN 7×112 Model)', fontsize=14, fontweight='bold')
    
    for i in range(n_samples):
        ax = axes[i]
        ax.imshow(samples_28x28[i], cmap='gray')
        
        is_correct = preds[i] == labels[i]
        color = 'green' if is_correct else 'red'
        symbol = '✓' if is_correct else '✗'
        
        ax.set_title(
            f"{symbol} Pred: {preds[i]} ({confidence[i]:.0%})\nTrue: {labels[i]}",
            fontsize=9,
            color=color,
            fontweight='bold' if not is_correct else 'normal'
        )
        ax.axis('off')
    
    # Hide empty subplots
    for i in range(n_samples, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Summary
    accuracy = (preds == labels).mean()
    print(f"\nSample accuracy: {accuracy:.1%} ({(preds == labels).sum()}/{n_samples})")

if trained_model is not None:
    visualize_predictions_7x112(trained_model, data_path)

## 7. Confusion Matrix

In [None]:
def plot_confusion_matrix(predictions, true_labels):
    """Plot confusion matrix for predictions."""
    
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    
    cm = confusion_matrix(true_labels, predictions)
    
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=range(10), yticklabels=range(10))
    ax.set_xlabel('Predicted', fontsize=12)
    ax.set_ylabel('True', fontsize=12)
    ax.set_title('Confusion Matrix (7×112 Format)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Per-class accuracy
    print("\nPer-class accuracy:")
    for digit in range(10):
        mask = true_labels == digit
        class_acc = (predictions[mask] == digit).mean()
        print(f"  Digit {digit}: {class_acc:.1%}")

if trained_model is not None:
    with h5py.File(data_path, 'r') as f:
        test_labels = np.array(f['test']['labels'])
    plot_confusion_matrix(predictions, test_labels)
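
If scikit-learn or seaborn is not installed, the counts themselves can be computed with NumPy alone. The helper below is a hypothetical stand-in (named `confusion_matrix_np` here) that follows the same convention as the plot above: rows are true classes, columns are predicted classes.

```python
import numpy as np

def confusion_matrix_np(true_labels, predictions, n_classes=10):
    """Count matrix with rows = true class and columns = predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (true_labels, predictions), 1)  # accumulates repeated index pairs
    return cm

true = np.array([0, 1, 1, 2, 2, 2])
pred = np.array([0, 1, 2, 2, 2, 0])
cm = confusion_matrix_np(true, pred, n_classes=3)
print(cm)

# Per-class accuracy (recall): diagonal over row sums
per_class = cm.diagonal() / cm.sum(axis=1)
print(per_class)
```

`np.add.at` is used instead of plain fancy-index assignment because the latter would count each repeated (true, predicted) pair only once.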

## 8. Analyze SOEN Dynamics (7 Timesteps)

Let's visualize how the SingleDendrite layer processes the input over the 7 timesteps.

In [None]:
def visualize_soen_dynamics_7x112(model, data_path, sample_idx=0):
    """Visualize the temporal dynamics of the SOEN hidden layer with 7×112 input."""
    
    if model is None:
        print("No model loaded.")
        return
    
    # Load a sample
    with h5py.File(data_path, 'r') as f:
        sample = np.array(f['test']['data'][sample_idx:sample_idx+1])  # (1, 7, 112)
        label = np.array(f['test']['labels'][sample_idx])
    
    # Reconstruct 28×28 for visualization
    sample_28x28 = sample.reshape(1, 7, 4, 28).reshape(28, 28)
    
    model.eval()
    
    with torch.no_grad():
        x = torch.tensor(sample, dtype=torch.float32)
        output, all_states = model(x)
        
        # Get hidden layer states (Layer 1)
        hidden_states = all_states[1]  # [batch, seq_len+1, hidden_dim]
        hidden_states = hidden_states[0, 1:, :].numpy()  # Remove batch and initial state
        
        # Get output for the single sample (indexing drops the batch dimension in either case)
        output_over_time = output[0].numpy()
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(f'SOEN Dynamics for Digit {label} (7×112 Input)', fontsize=14, fontweight='bold')
    
    # 1. Input image with timestep boundaries
    ax = axes[0, 0]
    ax.imshow(sample_28x28, cmap='gray')
    # Draw horizontal lines to show 7 timestep boundaries
    for t in range(1, 7):
        ax.axhline(y=t*4-0.5, color='red', linewidth=1, alpha=0.7)
    ax.set_title('Input Image (red lines = timestep boundaries)', fontsize=12)
    ax.set_ylabel('Rows (4 per timestep)')
    
    # 2. Hidden layer activity over time (only 7 timesteps)
    im = axes[0, 1].imshow(hidden_states.T, aspect='auto', cmap='viridis')
    axes[0, 1].set_xlabel('Timestep (0-6)', fontsize=10)
    axes[0, 1].set_ylabel('Hidden Neuron', fontsize=10)
    axes[0, 1].set_title('Hidden Layer Activity (128 neurons × 7 timesteps)', fontsize=12)
    axes[0, 1].set_xticks(range(7))
    plt.colorbar(im, ax=axes[0, 1], label='Activation')
    
    # 3. Mean hidden activity over time
    mean_activity = hidden_states.mean(axis=1)
    std_activity = hidden_states.std(axis=1)
    time_steps = np.arange(len(mean_activity))
    axes[1, 0].plot(time_steps, mean_activity, 'b-', linewidth=2, marker='o', label='Mean')
    axes[1, 0].fill_between(time_steps, 
                            mean_activity - std_activity,
                            mean_activity + std_activity,
                            alpha=0.3, label='±1 std')
    axes[1, 0].set_xlabel('Timestep', fontsize=10)
    axes[1, 0].set_ylabel('Activation', fontsize=10)
    axes[1, 0].set_title('Mean Hidden Layer Activity (7 timesteps)', fontsize=12)
    axes[1, 0].set_xticks(range(7))
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # 4. Output class probabilities (max over time, matching evaluate_on_test_set)
    if output_over_time.ndim == 2:
        pooled_output = output_over_time.max(axis=0)
    else:
        pooled_output = output_over_time
    exp_output = np.exp(pooled_output - pooled_output.max())  # Numerically stable softmax
    probs = exp_output / exp_output.sum()
    
    bars = axes[1, 1].bar(range(10), probs, color='steelblue', edgecolor='black')
    bars[label].set_color('green')  # Highlight true class
    axes[1, 1].set_xlabel('Digit Class', fontsize=10)
    axes[1, 1].set_ylabel('Probability', fontsize=10)
    axes[1, 1].set_title(f'Output Probabilities (True: {label})', fontsize=12)
    axes[1, 1].set_xticks(range(10))
    
    plt.tight_layout()
    plt.show()

if trained_model is not None:
    # Visualize dynamics for a few samples
    for idx in [0, 5, 10]:
        visualize_soen_dynamics_7x112(trained_model, data_path, sample_idx=idx)

## Summary

In this tutorial, we:

1. **Prepared MNIST** with 7×112 reshaping:
   - 7 timesteps (groups of 4 rows)
   - 112 features per timestep (4 rows × 28 pixels)
   - Compatible with 8 dendrite neurons (112 / 8 = 14 inputs per dendrite)

2. **Built a SOEN model** with:
   - Input layer (112D)
   - SingleDendrite hidden layer (128D) with recurrent connections
   - Linear output layer (10D) with **learnable** connections

3. **Trained and evaluated** the model on MNIST

### Key Differences from Standard 28×28 Format

| Aspect | 28×28 | 7×112 |
|--------|-------|-------|
| Timesteps | 28 | 7 |
| Features/timestep | 28 | 112 |
| Rows per timestep | 1 | 4 |
| Dendrite compatible | N/A | 8 dendrites (14 inputs each) |
| Processing cost | 28 sequential steps per sample | 7 sequential steps per sample (4× fewer) |
| Spatial context | Single row | 4 rows at once |

### Benefits of 7Ã—112 Format

- **8 Dendrite Compatibility**: 112 features split evenly across 8 dendrites (14 inputs each)
- **Shorter Sequences**: 7 vs 28 timesteps means 4× fewer sequential steps per sample
- **Richer Spatial Context**: Each timestep sees 4 rows at once

### Next Steps

- Compare accuracy between 28Ã—28 and 7Ã—112 formats
- Experiment with different dendrite configurations
- Try other reshaping strategies (e.g., 4×196 for 7 dendrites)
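
As a starting point for the last item, the fixed 7×112 reshape can be generalised to any row grouping that divides 28. The function `reshape_rows` below is a hypothetical helper written for this sketch, not part of the toolkit.

```python
import numpy as np

def reshape_rows(images, rows_per_step):
    """Reshape (N, 28, 28) images to (N, 28 // rows_per_step, rows_per_step * 28)."""
    n, height, width = images.shape
    if height % rows_per_step:
        raise ValueError(f"{rows_per_step} rows per step does not divide {height} rows")
    return images.reshape(n, height // rows_per_step, rows_per_step * width)

batch = np.zeros((5, 28, 28), dtype=np.float32)
print(reshape_rows(batch, 4).shape)  # the 7 x 112 format used in this tutorial
print(reshape_rows(batch, 7).shape)  # the 4 x 196 variant suggested above
```

A model spec and training config matching the new feature dimension would still be needed; only the data-side reshaping is shown here.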