# Tutorial 03 — MNIST Classification (112 timesteps × 7 features)

**Hardware-compatible SOEN model** for MNIST digit classification.

## Architecture

| Layer | Dimension | Inputs per Neuron |
|-------|-----------|-------------------|
| Input | 7 | - |
| Hidden | 28 | **7** from the input layer (under the 8-input hardware limit) |
| Output | 10 | 28 |

## Data Format

```
Original: 28 × 28 = 784 pixels
Reshaped: 112 timesteps × 7 features = 784 pixels

Each timestep: the next 7 pixels in row-major order (4 timesteps per 28-pixel row)
```

In [None]:
import os
# tqdm settings are passed via environment variables, so set them before tqdm is imported below.
os.environ.update({"TQDM_DISABLE": "0", "TQDM_MININTERVAL": "1"})

import sys
from pathlib import Path

# Make the in-repo `soen_toolkit` package importable: walk upward from the
# working directory and prepend the first `src/` directory that contains it.
notebook_dir = Path.cwd()
src_dir = next(
    (
        ancestor / "src"
        for ancestor in (notebook_dir, *notebook_dir.parents)
        if (ancestor / "src" / "soen_toolkit").exists()
    ),
    None,
)
if src_dir is not None:
    sys.path.insert(0, str(src_dir))

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import glob
import gzip
import urllib.request
import struct

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **_ignored):
        """No-op stand-in used when tqdm is not installed: returns *iterable* unchanged."""
        return iterable

torch.set_float32_matmul_precision('high')

print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Prepare MNIST Dataset (112 × 7 Format)

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/",
                        data_dir="./data/mnist"):
    """Return the local path of an MNIST archive, downloading it on first use.

    The file is fetched from ``base_url + filename`` into ``data_dir`` only if
    it is not already present there. The transfer goes to a ``.part`` file
    that is renamed into place only on success. An interrupted or failed
    download therefore never leaves a truncated file that later calls would
    mistake for a valid cached copy.

    Args:
        filename: Archive name, e.g. ``"train-images-idx3-ubyte.gz"``.
        base_url: URL prefix (ending in ``/``) that ``filename`` is appended to.
        data_dir: Cache directory; created if missing.

    Returns:
        ``Path`` to the local file.

    Raises:
        Propagates any error from ``urllib.request.urlretrieve`` (after
        removing the partial download).
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if filepath.exists():
        return filepath

    url = base_url + filename
    print(f"Downloading {filename}...")
    tmp_path = filepath.with_name(filepath.name + ".part")
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Publish the file only once it is complete.
        tmp_path.replace(filepath)
    finally:
        # Present only if the download or rename failed.
        if tmp_path.exists():
            tmp_path.unlink()
    return filepath

def read_mnist_images(filepath):
    """Read a gzipped IDX3 image file into a uint8 array of shape (N, rows, cols).

    The 16-byte big-endian header holds the magic number, the image count,
    the row count and the column count. The header is followed by
    ``N * rows * cols`` unsigned-byte pixels.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2051
            (IDX3 unsigned-byte), or the payload is shorter than the header
            announces (e.g. a truncated download).
    """
    with gzip.open(filepath, 'rb') as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError(f"{filepath}: truncated IDX3 header")
        magic, num_images, rows, cols = struct.unpack('>IIII', header)
        if magic != 2051:
            raise ValueError(f"{filepath}: bad magic number {magic} (expected 2051 for an IDX3 image file)")
        expected = num_images * rows * cols
        images = np.frombuffer(f.read(expected), dtype=np.uint8)
    if images.size != expected:
        raise ValueError(f"{filepath}: expected {expected} pixel bytes, got {images.size}")
    return images.reshape(num_images, rows, cols)

def read_mnist_labels(filepath):
    """Read a gzipped IDX1 label file into a 1-D uint8 array of length N.

    The 8-byte big-endian header holds the magic number and the label count.
    It is followed by one unsigned byte per label.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2049
            (IDX1 unsigned-byte), or fewer label bytes follow than the header
            announces (e.g. a truncated download).
    """
    with gzip.open(filepath, 'rb') as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(f"{filepath}: truncated IDX1 header")
        magic, num_labels = struct.unpack('>II', header)
        if magic != 2049:
            raise ValueError(f"{filepath}: bad magic number {magic} (expected 2049 for an IDX1 label file)")
        labels = np.frombuffer(f.read(num_labels), dtype=np.uint8)
    if labels.size != num_labels:
        raise ValueError(f"{filepath}: expected {num_labels} labels, got {labels.size}")
    return labels

def reshape_to_112x7(images):
    """
    Reshape MNIST from (N, 28, 28) to (N, 112, 7).

    Pixels are read in row-major order, 7 per timestep, so each 28-pixel row
    spans 4 consecutive timesteps (28 × 28 = 112 × 7 = 784 pixels).

    Args:
        images: Array-like of shape (N, ...) whose trailing dimensions hold
            exactly 784 pixels per sample, e.g. (N, 28, 28) or (N, 784).

    Returns:
        Array of shape (N, 112, 7), a view of the input when NumPy can
        reshape without copying.

    Raises:
        ValueError: If a sample does not contain exactly 784 pixels.
    """
    images = np.asarray(images)
    pixels_per_sample = int(np.prod(images.shape[1:]))
    if pixels_per_sample != 112 * 7:
        raise ValueError(
            f"expected 784 pixels per sample, got shape {images.shape} "
            f"({pixels_per_sample} pixels per sample)"
        )
    # One reshape with an explicit N (no -1) also handles an empty batch:
    # NumPy cannot infer a -1 dimension from a size-0 array.
    return images.reshape(images.shape[0], 112, 7)

def prepare_mnist_hdf5_112x7(output_path="training/datasets/mnist_seq112x7.hdf5", 
                              normalize=True, val_split=0.1):
    """Build (or reuse) an HDF5 copy of MNIST with each image as a (112, 7) sequence.

    The file contains ``train``, ``val`` and ``test`` groups. Each group holds
    a float32 ``data`` dataset of shape (N, 112, 7) and an int64 ``labels``
    dataset. The root attributes are ``seq_len``, ``feature_dim`` and
    ``num_classes``. The validation split is carved out of the original
    training set using a fixed seed (42).

    Args:
        output_path: Destination HDF5 path; parent directories are created.
        normalize: If True, scale pixel values from [0, 255] to [0, 1].
        val_split: Fraction of the training images held out for validation;
            must be in [0, 1).

    Returns:
        ``Path`` to the HDF5 file.

    Raises:
        ValueError: If ``val_split`` is outside [0, 1).

    Note:
        If ``output_path`` already exists it is reused as-is. ``normalize``
        and ``val_split`` are not checked against the cached file, so delete
        it to rebuild with different settings.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    if output_path.exists():
        print(f"Dataset already exists at {output_path}")
        with h5py.File(output_path, 'r') as f:
            print(f"  Train: {f['train']['data'].shape}")
            print(f"  Val: {f['val']['data'].shape}")
            print(f"  Test: {f['test']['data'].shape}")
        return output_path
    
    print("Downloading MNIST...")
    train_images = read_mnist_images(download_mnist_file("train-images-idx3-ubyte.gz")).astype(np.float32)
    train_labels = read_mnist_labels(download_mnist_file("train-labels-idx1-ubyte.gz")).astype(np.int64)
    test_images = read_mnist_images(download_mnist_file("t10k-images-idx3-ubyte.gz")).astype(np.float32)
    test_labels = read_mnist_labels(download_mnist_file("t10k-labels-idx1-ubyte.gz")).astype(np.int64)
    
    if normalize:
        train_images = train_images / 255.0
        test_images = test_images / 255.0
    
    print("\nReshaping to 112 timesteps × 7 features...")
    print("  Hardware compatible: each neuron receives 7 inputs (< 8 max)")
    train_images = reshape_to_112x7(train_images)
    test_images = reshape_to_112x7(test_images)
    
    # Deterministic train/val split. A private RandomState(42) produces exactly
    # the permutation that np.random.seed(42) + np.random.permutation gave, so
    # splits stay reproducible. It also leaves the global NumPy RNG untouched.
    n_train = len(train_images)
    n_val = int(n_train * val_split)
    indices = np.random.RandomState(42).permutation(n_train)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    
    val_images, val_labels = train_images[val_idx], train_labels[val_idx]
    train_images, train_labels = train_images[train_idx], train_labels[train_idx]
    
    print("\nFinal shapes:")
    print(f"  Train: {train_images.shape} (N, T=112, D=7)")
    print(f"  Val: {val_images.shape}")
    print(f"  Test: {test_images.shape}")
    
    print(f"\nSaving to {output_path}...")
    splits = {
        'train': (train_images, train_labels),
        'val': (val_images, val_labels),
        'test': (test_images, test_labels),
    }
    # Write to a sibling temp file and rename only on success. An interrupted
    # run then never leaves a half-written file that the existence check
    # above would treat as a finished dataset.
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    with h5py.File(tmp_path, 'w') as f:
        for split_name, (split_data, split_labels) in splits.items():
            group = f.create_group(split_name)
            group.create_dataset('data', data=split_data)
            group.create_dataset('labels', data=split_labels)
        
        f.attrs['seq_len'] = 112
        f.attrs['feature_dim'] = 7
        f.attrs['num_classes'] = 10
    tmp_path.replace(output_path)
    
    print("Done!")
    return output_path

data_path = prepare_mnist_hdf5_112x7()

## 2. Visualize the 112×7 Format

In [None]:
def visualize_112x7_format(data_path, n_samples=5):
    """Plot training digits as 28×28 images, as (112, 7) sequences, and the scan order.

    Row 1 shows the reconstructed 28×28 image, row 2 the raw (112, 7)
    sequence as a heatmap, and row 3 which pixels are read at every 16th
    timestep.

    Args:
        data_path: HDF5 file with a ``train`` group (``data``/``labels``).
        n_samples: Number of leading training samples to show; clamped to
            the number actually stored.
    """
    with h5py.File(data_path, 'r') as f:
        images_112x7 = np.array(f['train']['data'][:n_samples])
        labels = np.array(f['train']['labels'][:n_samples])
    
    # Slicing can return fewer rows than requested, so size everything from the data.
    n_samples = len(images_112x7)
    if n_samples == 0:
        print("No training samples to visualize.")
        return
    
    # Reconstruct 28×28 (row-major inverse of the 112×7 reshape).
    images_28x28 = images_112x7.reshape(n_samples, 28, 28)
    
    # Scan-order overlay (identical for every column): colour the 7 pixels
    # read at every 16th timestep by t/112. Unvisited pixels are NaN so they
    # render as background, which keeps timestep 0 (value 0.0) visible.
    scan_vis = np.full(784, np.nan)
    for t in range(0, 112, 16):
        scan_vis[t * 7:(t + 1) * 7] = t / 112
    scan_vis = scan_vis.reshape(28, 28)
    
    # squeeze=False keeps `axes` 2-D even when n_samples == 1.
    fig, axes = plt.subplots(3, n_samples, figsize=(3 * n_samples, 8), squeeze=False)
    fig.suptitle('MNIST: 112 timesteps × 7 features', fontsize=14, fontweight='bold')
    
    def _hide_frame(ax):
        # Unlike ax.axis('off'), this keeps the y-label (row title) visible.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    
    for i in range(n_samples):
        # Row 1: Original 28×28
        axes[0, i].imshow(images_28x28[i], cmap='gray')
        axes[0, i].set_title(f'Label: {labels[i]}')
        _hide_frame(axes[0, i])
        
        # Row 2: 112×7 as heatmap
        axes[1, i].imshow(images_112x7[i], cmap='viridis', aspect='auto')
        axes[1, i].set_xlabel('Feature (0-6)')
        
        # Row 3: Scanning pattern
        axes[2, i].imshow(scan_vis, cmap='plasma', aspect='equal')
        axes[2, i].set_title('Scan order')
        _hide_frame(axes[2, i])
    
    axes[0, 0].set_ylabel('Original\n28×28', fontsize=10)
    axes[1, 0].set_ylabel('Sequence 112×7\nTimestep (0-111)', fontsize=10)
    axes[2, 0].set_ylabel('Scan\nPattern', fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*60)
    print("112×7 Format Summary")
    print("="*60)
    print("• 112 timesteps, 7 features each")
    print("• Total: 112 × 7 = 784 pixels (matches 28×28)")
    print("• Each hidden neuron receives 7 inputs (< 8 max)")
    print("="*60)

visualize_112x7_format(data_path)

## 3. Load and Inspect Model

In [None]:
from soen_toolkit.core.model_yaml import build_model_from_yaml

# Build the SOEN network described by the YAML spec (path relative to the working directory).
model_path = Path("training/test_models/model_specs/MNIST_SOENSpec_112x7.yaml")
model = build_model_from_yaml(model_path)

banner = "=" * 60
print(banner)
print("MNIST SOEN MODEL (112×7 Format)")
print(banner)

# Per-layer neuron counts as exposed by the model object.
print("\nArchitecture:")
for lid, width in model.layer_nodes.items():
    print(f"  Layer {lid}: {width} neurons")

# Shape of each connection's parameter.
print("\nConnections:")
for conn_name, weights in model.connections.items():
    print(f"  {conn_name}: {weights.shape}")

# NOTE(review): these fan-in lines are fixed text, not computed from the
# model above. Keep them in sync with the YAML spec by hand.
print("\nHardware compatibility:")
for fan_in_note in (
    "  Input → Hidden: 7 inputs per neuron ✓ (< 8)",
    "  Hidden → Hidden: 28 inputs per neuron (recurrent)",
    "  Hidden → Output: 28 inputs per neuron",
):
    print(fan_in_note)

# Smoke-test a forward pass on random input: batch=2, 112 timesteps, 7 features.
print("\nTesting forward pass...")
x_test = torch.randn(2, 112, 7)
with torch.no_grad():
    output, states = model(x_test)
for tag, shown in (("Input", x_test), ("Output", output)):
    print(f"  {tag}: {shown.shape}")
print("  Success!")

## 4. Train the Model

In [None]:
# Set before importing the trainer; presumably read by soen_toolkit to suppress its own progress bar.
os.environ["SOEN_NO_PROGRESS_BAR"] = "1"

from soen_toolkit.training.trainers.experiment import run_from_config

banner = "=" * 60
for line in (
    banner,
    "TRAINING: 112 timesteps × 7 features",
    banner,
    "Each hidden neuron receives 7 inputs (hardware compatible)",
    banner,
):
    print(line)

# script_dir is the current working directory (presumably the base for relative paths in the config).
config_file = "training/training_configs/mnist_soen_112x7.yaml"
run_from_config(config_file, script_dir=Path.cwd())

## 5. Evaluate

In [None]:
def load_best_checkpoint():
    """Load the newest checkpoint under ``training/temp`` into a freshly built model.

    Despite the name, the checkpoint is chosen by file modification time
    (most recent), not by any validation metric. Paths containing ``112x7``
    are preferred; if none match, any ``.ckpt`` found is considered.

    Returns:
        The model built from the module-level ``model_path`` spec, with the
        checkpoint weights loaded and switched to eval mode. Returns None if
        no checkpoint file exists.
    """
    patterns = ["training/temp/**/checkpoints/**/*.ckpt", "training/temp/**/*.ckpt"]
    # The second pattern already matches everything the first does; the set drops duplicates.
    all_ckpts = sorted({path for pattern in patterns for path in glob.glob(pattern, recursive=True)})
    
    ckpts_112x7 = [c for c in all_ckpts if '112x7' in c] or all_ckpts
    if not ckpts_112x7:
        print("No checkpoint found.")
        return None
    
    latest = max(ckpts_112x7, key=lambda x: Path(x).stat().st_mtime)
    print(f"Loading: {latest}")
    
    model = build_model_from_yaml(model_path)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
    ckpt = torch.load(latest, map_location='cpu')
    # Weights may be nested under 'state_dict' with a 'model.' prefix
    # (presumably added by the training wrapper; confirm against the trainer).
    state_dict = ckpt.get('state_dict', ckpt)
    prefix = 'model.'
    clean = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
    
    # strict=False tolerates key mismatches, but silently loading nothing
    # would leave the model at its initialization. Report any mismatch.
    result = model.load_state_dict(clean, strict=False)
    missing = list(getattr(result, 'missing_keys', []))
    unexpected = list(getattr(result, 'unexpected_keys', []))
    if missing:
        print(f"Warning: {len(missing)} model parameter(s) not in checkpoint, left at initialization: {missing[:5]}")
    if unexpected:
        print(f"Warning: {len(unexpected)} checkpoint key(s) not used by the model: {unexpected[:5]}")
    model.eval()
    return model

trained_model = load_best_checkpoint()

In [None]:
def evaluate(model, data_path, batch_size=128):
    """Compute top-1 accuracy of ``model`` on the HDF5 test split and print it.

    Args:
        model: Trained model, or None (no-op). It is called as
            ``output, _ = model(x)`` with a float32 batch taken from the test
            ``data`` dataset. A 3-D ``output`` is mean-pooled over dim 1
            (assumed to be time; confirm against the model) before argmax.
        data_path: HDF5 file with a ``test`` group (``data``/``labels``).
        batch_size: Number of test samples per forward pass.

    Returns:
        Accuracy as a float in [0, 1], or None if ``model`` is None or the
        test split is empty.
    """
    if model is None:
        return None
    
    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])
    
    if len(test_data) == 0:
        print("Test split is empty; nothing to evaluate.")
        return None
    
    model.eval()
    batch_preds = []
    
    with torch.no_grad():
        for start in tqdm(range(0, len(test_data), batch_size)):
            x = torch.tensor(test_data[start:start + batch_size], dtype=torch.float32)
            output, _ = model(x)
            if output.dim() == 3:
                output = output.mean(dim=1)  # Mean pooling
            batch_preds.append(output.argmax(dim=1).cpu().numpy())
    
    all_preds = np.concatenate(batch_preds)
    accuracy = float((all_preds == test_labels).mean())
    
    print(f"\nTest Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    return accuracy

# Explicit None check: a torch Module's truthiness is not a reliable "was loaded" signal.
if trained_model is not None:
    evaluate(trained_model, data_path)

## 6. Visualize Predictions

In [None]:
def visualize_predictions(model, data_path, n_samples=20, n_cols=5):
    """Show a grid of random test digits with predicted class, confidence and true label.

    Args:
        model: Trained model, or None (no-op). It is called as
            ``output, _ = model(x)``. A 3-D ``output`` is mean-pooled over
            dim 1 (assumed to be time; confirm) before the softmax.
        data_path: HDF5 file with a ``test`` group (``data``/``labels``).
        n_samples: Number of random test images to plot; clamped to the size
            of the test split.
        n_cols: Number of grid columns; rows are added as needed.
    """
    if model is None:
        return
    
    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])
    
    n_samples = min(n_samples, len(test_data))
    if n_samples <= 0:
        print("No test samples to visualize.")
        return
    
    # Private RNG: the same draw as np.random.seed(42) + np.random.choice,
    # without resetting the global NumPy RNG.
    idx = np.random.RandomState(42).choice(len(test_data), n_samples, replace=False)
    samples = test_data[idx]
    labels = test_labels[idx]
    
    # Reconstruct 28×28
    images = samples.reshape(n_samples, 28, 28)
    
    model.eval()
    with torch.no_grad():
        x = torch.tensor(samples, dtype=torch.float32)
        output, _ = model(x)
        if output.dim() == 3:
            output = output.mean(dim=1)
        probs = torch.softmax(output, dim=1)
        preds = probs.argmax(dim=1).cpu().numpy()
        conf = probs.max(dim=1)[0].cpu().numpy()
    
    # Size the grid from n_samples; the defaults (20 samples, 5 columns) give the original 4×5 grid at 12×10 inches.
    n_cols = max(1, min(n_cols, n_samples))
    n_rows = -(-n_samples // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.4 * n_cols, 2.5 * n_rows), squeeze=False)
    axes = axes.flatten()
    fig.suptitle('Predictions (112×7 SOEN)', fontsize=14, fontweight='bold')
    
    for i in range(n_samples):
        axes[i].imshow(images[i], cmap='gray')
        correct = preds[i] == labels[i]
        color = 'green' if correct else 'red'
        axes[i].set_title(f"{'✓' if correct else '✗'} {preds[i]} ({conf[i]:.0%})\nTrue: {labels[i]}",
                          color=color, fontsize=9)
        axes[i].axis('off')
    # Blank out unused cells in the last row.
    for ax in axes[n_samples:]:
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Explicit None check: a torch Module's truthiness is not a reliable "was loaded" signal.
if trained_model is not None:
    visualize_predictions(trained_model, data_path)

## Summary

| Aspect | Value |
|--------|-------|
| Input shape | (112, 7) |
| Timesteps | 112 |
| Features/timestep | 7 |
| Hidden neurons | 28 |
| Inputs per hidden neuron (from input layer) | **7** (under the 8-input hardware limit) |

This format is **hardware compatible** because each hidden neuron receives only 7 inputs from the input layer.