# Tutorial 03 — MNIST with Hierarchical SOEN (7 → 8 → 8 → 10)

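This tutorial trains a **hierarchical architecture** that builds up pattern representations over 112 timesteps. Each 28×28 MNIST image (784 pixels) is fed to the network as a sequence of 112 timesteps with 7 pixels per step.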

## Architecture

```
Input (7) → Hidden1 (8) → Hidden2 (8) → Output (10)
            7 inputs      8 inputs      8 inputs
            per neuron    per neuron    per neuron
```

**All neurons receive ≤ 8 inputs** (hardware compatible)

## Why Hierarchical?

- **Layer 1**: Extracts local features from the 7 input pixels presented at each timestep
- **Layer 2**: Combines features from Layer 1
- **Output**: Classifies based on combined features

This creates a feature hierarchy that can better integrate patterns across 112 timesteps.

In [None]:
# Notebook setup: environment variables, import-path discovery, imports.
# The TQDM_* variables are set before tqdm is imported further down.
# NOTE(review): presumably these are read by tqdm as defaults — confirm against
# the installed tqdm version.
import os
os.environ["TQDM_DISABLE"] = "0"
os.environ["TQDM_MININTERVAL"] = "1"

import sys
from pathlib import Path

# Walk upward from the current working directory and put the first `src/`
# directory that contains `soen_toolkit` at the front of sys.path, so the
# toolkit imports used later in this notebook resolve to that checkout.
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import glob
import gzip
import urllib.request
import struct

# tqdm is optional: without it, fall back to a stand-in that accepts the same
# call shape and simply returns the iterable unchanged (no progress bar).
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        return iterable

# Select PyTorch's 'high' float32 matmul precision mode.
torch.set_float32_matmul_precision('high')

# Report the runtime environment so results can be interpreted later.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Prepare Dataset (112×7)

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/"):
    """Return the local path of an MNIST archive, downloading it if missing.

    Files are cached in ``./data/mnist`` (relative to the current working
    directory). The download is written to a temporary ``.part`` file and only
    renamed to its final name once the transfer has finished, so an interrupted
    download is never mistaken for a complete file on the next run.

    Args:
        filename: Name of the file to fetch; it is appended to ``base_url``.
        base_url: URL prefix the file is downloaded from.

    Returns:
        ``pathlib.Path`` of the cached file.
    """
    data_dir = Path("./data/mnist")
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if filepath.exists():
        return filepath

    print(f"Downloading {filename}...")
    partial = filepath.with_name(filepath.name + ".part")
    try:
        urllib.request.urlretrieve(base_url + filename, partial)
        # Publish the file under its real name only after a full transfer.
        partial.replace(filepath)
    finally:
        # Still present only if the download or the rename failed.
        if partial.exists():
            partial.unlink()
    return filepath

def read_mnist_images(filepath):
    """Read a gzip-compressed IDX3 image file into a uint8 array.

    The 16-byte big-endian header holds the magic number, the image count and
    the row/column sizes; the remaining bytes are the pixels.

    Args:
        filepath: Path of the ``.gz`` image file.

    Returns:
        Read-only ``numpy.ndarray`` of dtype ``uint8`` and shape
        ``(num, rows, cols)``.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2051
            (the IDX3 image magic), or the payload size does not match the
            dimensions declared in the header.
    """
    with gzip.open(filepath, 'rb') as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError(f"{filepath}: truncated IDX image header")
        magic, num, rows, cols = struct.unpack('>IIII', header)
        # 2051 == 0x00000803: unsigned-byte data with three dimensions.
        if magic != 2051:
            raise ValueError(
                f"{filepath}: not an IDX image file (magic {magic}, expected 2051)"
            )
        # reshape raises ValueError itself when the payload size is wrong.
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(num, rows, cols)

def read_mnist_labels(filepath):
    """Read a gzip-compressed IDX1 label file into a uint8 array.

    The 8-byte big-endian header holds the magic number and the label count;
    the remaining bytes are the labels, one byte each.

    Args:
        filepath: Path of the ``.gz`` label file.

    Returns:
        Read-only one-dimensional ``numpy.ndarray`` of dtype ``uint8``.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2049
            (the IDX1 label magic), or the number of labels differs from the
            count declared in the header.
    """
    with gzip.open(filepath, 'rb') as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(f"{filepath}: truncated IDX label header")
        magic, num = struct.unpack('>II', header)
        # 2049 == 0x00000801: unsigned-byte data with one dimension.
        if magic != 2049:
            raise ValueError(
                f"{filepath}: not an IDX label file (magic {magic}, expected 2049)"
            )
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if len(labels) != num:
        raise ValueError(
            f"{filepath}: header declares {num} labels but file holds {len(labels)}"
        )
    return labels

def prepare_mnist_112x7(output_path="training/datasets/mnist_seq112x7.hdf5"):
    """Build the sequence-formatted MNIST dataset as an HDF5 file, once.

    If ``output_path`` already exists it is reused as-is. Otherwise the MNIST
    archives are downloaded, pixel values are scaled to ``[0, 1]`` float32,
    every 28×28 image is flattened row-major and reshaped to 112 timesteps of
    7 values, and 10% of the training images are held out for validation.

    The file is written under a temporary name and renamed when complete, so
    an interrupted run cannot leave a partial file that a later run would
    accept as a finished dataset. The split uses a private ``RandomState(42)``,
    which yields the same permutation as seeding the global NumPy generator
    with 42 but leaves the global generator untouched.

    Args:
        output_path: Destination of the HDF5 file; parent directories are
            created as needed.

    Returns:
        ``pathlib.Path`` of the dataset file, which contains the groups
        ``train``, ``val`` and ``test``, each with ``data`` and ``labels``.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if output_path.exists():
        print(f"Dataset exists: {output_path}")
        with h5py.File(output_path, 'r') as f:
            print(f"  Shape: {f['train']['data'].shape}")
        return output_path

    print("Preparing MNIST (112×7 format)...")

    train_img = read_mnist_images(download_mnist_file("train-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    train_lbl = read_mnist_labels(download_mnist_file("train-labels-idx1-ubyte.gz")).astype(np.int64)
    test_img = read_mnist_images(download_mnist_file("t10k-images-idx3-ubyte.gz")).astype(np.float32) / 255.0
    test_lbl = read_mnist_labels(download_mnist_file("t10k-labels-idx1-ubyte.gz")).astype(np.int64)

    # 784 pixels per image, row-major -> 112 timesteps × 7 features.
    train_img = train_img.reshape(-1, 112, 7)
    test_img = test_img.reshape(-1, 112, 7)

    # Deterministic 90/10 train/validation split.
    idx = np.random.RandomState(42).permutation(len(train_img))
    n_val = int(len(train_img) * 0.1)

    val_img, val_lbl = train_img[idx[:n_val]], train_lbl[idx[:n_val]]
    train_img, train_lbl = train_img[idx[n_val:]], train_lbl[idx[n_val:]]

    print(f"  Train: {train_img.shape}")
    print(f"  Val: {val_img.shape}")
    print(f"  Test: {test_img.shape}")

    splits = [
        ('train', train_img, train_lbl),
        ('val', val_img, val_lbl),
        ('test', test_img, test_lbl),
    ]
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with h5py.File(tmp_path, 'w') as f:
            for name, data, labels in splits:
                g = f.create_group(name)
                g.create_dataset('data', data=data)
                g.create_dataset('labels', data=labels)
        # Publish the dataset only after it has been written and closed.
        tmp_path.replace(output_path)
    finally:
        # Still present only if writing or renaming failed.
        if tmp_path.exists():
            tmp_path.unlink()

    print("Done!")
    return output_path

# Build the dataset file (or reuse an existing one); `data_path` is passed to
# the evaluation and prediction-visualisation calls later in the notebook.
data_path = prepare_mnist_112x7()

## 2. Visualize Hierarchical Architecture

In [None]:
def visualize_architecture():
    """Plot the 7 → 8 → 8 → 10 network diagram and print its weight counts.

    Each layer is drawn as a column of circles with a caption underneath. A
    few sample connections are drawn between neighbouring layers, with a
    fan-in label above each gap. After the figure is shown, a per-layer
    weight-count summary is printed.
    """
    fig, ax = plt.subplots(figsize=(14, 6))
    radius = 0.025

    # (x position, neuron count, caption, fill colour) for each layer.
    layer_specs = [
        (0.1, 7, 'Input\n(7)', 'lightblue'),
        (0.35, 8, 'Hidden 1\n(8)', 'lightgreen'),
        (0.6, 8, 'Hidden 2\n(8)', 'lightgreen'),
        (0.85, 10, 'Output\n(10)', 'lightyellow'),
    ]
    for x_pos, count, caption, fill in layer_specs:
        for y_pos in np.linspace(0.1, 0.9, count):
            ax.add_patch(
                plt.Circle((x_pos, y_pos), radius, color=fill, ec='black', linewidth=1)
            )
        ax.text(x_pos, -0.05, caption, ha='center', fontsize=11, fontweight='bold')

    # (source x, source count, target x, target count, fan-in label).
    link_specs = [
        (0.1, 7, 0.35, 8, '7 inputs'),
        (0.35, 8, 0.6, 8, '8 inputs'),
        (0.6, 8, 0.85, 10, '8 inputs'),
    ]
    for x_src, n_src, x_dst, n_dst, caption in link_specs:
        src_ys = np.linspace(0.1, 0.9, n_src)
        dst_ys = np.linspace(0.1, 0.9, n_dst)
        # Only up to 3×3 evenly spread sample lines, to keep the figure readable.
        for dst_k in range(min(3, n_dst)):
            dst_y = dst_ys[dst_k * n_dst // 3]
            for src_k in range(min(3, n_src)):
                src_y = src_ys[src_k * n_src // 3]
                ax.plot([x_src + radius, x_dst - radius], [src_y, dst_y],
                        'gray', alpha=0.2, linewidth=0.5)
        ax.text((x_src + x_dst) / 2, 0.95, caption, ha='center', fontsize=9, color='blue')

    ax.set_xlim(0, 1)
    ax.set_ylim(-0.15, 1.05)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('Hierarchical Architecture: 7 → 8 → 8 → 10\n(All neurons receive ≤ 8 inputs)',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    summary_lines = [
        "\nHardware Compatibility Check:",
        "="*50,
        f"  Layer 1: 8 neurons × 7 inputs = {8*7} weights  ✓ (7 ≤ 8)",
        f"  Layer 2: 8 neurons × 8 inputs = {8*8} weights  ✓ (8 ≤ 8)",
        f"  Output: 10 neurons × 8 inputs = {10*8} weights  ✓ (8 ≤ 8)",
        f"  Total: {8*7 + 8*8 + 10*8} learnable weights",
    ]
    for line in summary_lines:
        print(line)

# Show the architecture diagram and print the weight-count summary.
visualize_architecture()

## 3. Load Model

In [None]:
# Build the model from its YAML spec, print its layers and connections, and
# run one forward pass on random data as a smoke test.
from soen_toolkit.core.model_yaml import build_model_from_yaml

# Relative path: resolved against the current working directory.
model_path = Path("training/test_models/model_specs/MNIST_SOENSpec_112x7_hierarchical.yaml")
model = build_model_from_yaml(model_path)

print("="*60)
print("HIERARCHICAL SOEN MODEL")
print("="*60)

print("\nLayers:")
for lid, dim in model.layer_nodes.items():
    print(f"  Layer {lid}: {dim} neurons")

print("\nConnections:")
for name, param in model.connections.items():
    shape = param.shape
    # The second dimension is reported as fan-in; 1-D parameters count as 1.
    # NOTE(review): assumes weights are laid out as (out, in) — confirm.
    fan_in = shape[1] if len(shape) > 1 else 1
    print(f"  {name}: {list(shape)} ({fan_in} inputs per neuron)")

print("\nForward pass test...")
# Random batch of 2 sequences, each 112 timesteps × 7 features.
x = torch.randn(2, 112, 7)
with torch.no_grad():
    # The model call returns a pair; only the first element is reported here.
    y, states = model(x)
print(f"  Input: {x.shape}")
print(f"  Output: {y.shape}")
print("  Success!")

## 4. Train

In [None]:
# Set before the trainer module is imported and run.
# NOTE(review): presumably makes the toolkit suppress its progress bar —
# confirm against the soen_toolkit trainer.
os.environ["SOEN_NO_PROGRESS_BAR"] = "1"

from soen_toolkit.training.trainers.experiment import run_from_config

print("="*60)
print("TRAINING HIERARCHICAL MODEL (7 → 8 → 8 → 10)")
print("="*60)

# Run the experiment described by the training config; the current working
# directory is passed as `script_dir`.
run_from_config("training/training_configs/mnist_soen_112x7_hierarchical.yaml", script_dir=Path.cwd())

## 5. Evaluate

In [None]:
def load_checkpoint():
    """Load the most recently modified checkpoint into a freshly built model.

    Checkpoints are searched under ``training/temp``. Paths containing
    ``hierarchical`` (case-insensitive) or ``7_8_8_10`` are preferred; if none
    match, every checkpoint found is considered. Keys starting with ``model.``
    have that prefix removed before loading.

    Loading is non-strict, so keys that do not line up are tolerated. They are
    printed rather than silently dropped, because a checkpoint that matches
    nothing would otherwise leave the model at its initial weights without any
    hint.

    Returns:
        The model in eval mode, or ``None`` when no checkpoint file exists.
    """
    patterns = ["training/temp/**/checkpoints/**/*.ckpt", "training/temp/**/*.ckpt"]
    # The second pattern also matches everything the first one does;
    # dict.fromkeys drops the duplicates while keeping discovery order.
    ckpts = list(dict.fromkeys(
        path for p in patterns for path in glob.glob(p, recursive=True)
    ))

    hier_ckpts = [c for c in ckpts if 'hierarchical' in c.lower() or '7_8_8_10' in c] or ckpts
    if not hier_ckpts:
        print("No checkpoint found.")
        return None

    latest = max(hier_ckpts, key=lambda x: Path(x).stat().st_mtime)
    print(f"Loading: {latest}")

    model = build_model_from_yaml(model_path)
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # were produced locally or come from a trusted source.
    ckpt = torch.load(latest, map_location='cpu')
    # Use the nested 'state_dict' entry when present, else the object itself.
    state = ckpt.get('state_dict', ckpt)
    prefix = 'model.'
    clean = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }
    result = model.load_state_dict(clean, strict=False)
    # getattr keeps this working even if load_state_dict returns no report.
    missing = list(getattr(result, 'missing_keys', None) or [])
    unexpected = list(getattr(result, 'unexpected_keys', None) or [])
    if missing:
        print(f"  Warning: {len(missing)} model key(s) not found in checkpoint: {missing}")
    if unexpected:
        print(f"  Warning: {len(unexpected)} checkpoint key(s) not used by model: {unexpected}")
    model.eval()
    return model

# `trained_model` is None when no checkpoint file was found.
trained_model = load_checkpoint()

In [None]:
def evaluate(model, data_path, batch_size=128):
    """Compute and print classification accuracy on the test split.

    Args:
        model: Callable model returning a pair whose first element holds the
            class scores, or ``None`` to skip evaluation.
        data_path: HDF5 file containing ``test/data`` and ``test/labels``.
        batch_size: Number of samples per forward pass.

    Returns:
        Test accuracy as a float in ``[0, 1]``, or ``None`` when ``model`` is
        ``None`` or the test split is empty.
    """
    if model is None:
        return None

    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])

    model.eval()
    all_preds = []

    with torch.no_grad():
        for start in tqdm(range(0, len(test_data), batch_size)):
            x = torch.tensor(test_data[start:start + batch_size], dtype=torch.float32)
            y, _ = model(x)
            # A 3-D output is averaged over dim 1 before taking the argmax.
            # NOTE(review): presumably dim 1 is the time axis — confirm.
            if y.dim() == 3:
                y = y.mean(dim=1)
            all_preds.append(y.argmax(dim=1).numpy())

    if not all_preds:
        # np.concatenate rejects an empty list; report instead of crashing.
        print("No test samples found.")
        return None

    preds = np.concatenate(all_preds)
    acc = (preds == test_labels).mean()

    print(f"\nTest Accuracy: {acc:.4f} ({acc*100:.2f}%)")
    return acc

# Compare against None explicitly: truthiness of a model object can depend on
# `__len__`/`__bool__` defined by its class, which is unrelated to whether a
# checkpoint was loaded.
if trained_model is not None:
    evaluate(trained_model, data_path)

## 6. Visualize Predictions

In [None]:
def visualize_predictions(model, data_path, n=20):
    """Show predictions for ``n`` randomly chosen test images.

    The sample is drawn with a private ``RandomState(42)``, so it is the same
    on every call and the global NumPy generator is left untouched. Images
    are laid out five per row; the number of rows follows from ``n``, and
    grid cells without an image stay blank.

    Args:
        model: Callable model returning a pair whose first element holds the
            class scores, or ``None`` to do nothing.
        data_path: HDF5 file containing ``test/data`` and ``test/labels``.
        n: Number of samples to show; capped at the size of the test split.

    Returns:
        None.
    """
    if model is None:
        return

    with h5py.File(data_path, 'r') as f:
        data = np.array(f['test']['data'])
        labels = np.array(f['test']['labels'])

    # Sampling without replacement cannot draw more items than exist.
    n = min(n, len(data))
    if n < 1:
        print("No samples to visualize.")
        return

    idx = np.random.RandomState(42).choice(len(data), n, replace=False)
    samples, true = data[idx], labels[idx]
    # Each sample holds 784 values; view it as the original 28×28 image.
    images = samples.reshape(n, 28, 28)

    model.eval()
    with torch.no_grad():
        x = torch.tensor(samples, dtype=torch.float32)
        y, _ = model(x)
        # A 3-D output is averaged over dim 1 before the softmax.
        # NOTE(review): presumably dim 1 is the time axis — confirm.
        if y.dim() == 3:
            y = y.mean(dim=1)
        probs = torch.softmax(y, dim=1)
        preds = probs.argmax(dim=1).numpy()
        conf = probs.max(dim=1)[0].numpy()

    cols = 5
    rows = -(-n // cols)  # ceiling division
    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(rows, cols, figsize=(12, 2.5 * rows), squeeze=False)
    axes = axes.flatten()
    fig.suptitle('Predictions (Hierarchical 7→8→8→10)', fontsize=14, fontweight='bold')

    # Turn every cell off first so unused cells do not show empty frames.
    for ax in axes:
        ax.axis('off')

    for i in range(n):
        axes[i].imshow(images[i], cmap='gray')
        ok = preds[i] == true[i]
        axes[i].set_title(f"{'✓' if ok else '✗'} {preds[i]} ({conf[i]:.0%})\nTrue: {true[i]}",
                          color='green' if ok else 'red', fontsize=9)

    plt.tight_layout()
    plt.show()

    print(f"Sample accuracy: {(preds == true).mean():.1%}")

# Compare against None explicitly: truthiness of a model object can depend on
# `__len__`/`__bool__` defined by its class, which is unrelated to whether a
# checkpoint was loaded.
if trained_model is not None:
    visualize_predictions(trained_model, data_path)

## Summary

| Layer | Neurons | Inputs/Neuron | Weights |
|-------|---------|---------------|--------|
| Input | 7 | - | - |
| Hidden 1 | 8 | 7 | 56 |
| Hidden 2 | 8 | 8 | 64 |
| Output | 10 | 8 | 80 |
| **Total** | - | - | **200** |

All connections satisfy the **≤ 8 inputs per neuron** hardware constraint.