# Tutorial 03c — MNIST Classification with SOEN (8 Inputs Per Neuron)

This notebook demonstrates training a **hardware-compatible** SOEN model on MNIST where **each neuron receives at most 8 inputs**.

---

## Hardware Constraint

| Connection | From → To | Fan-in per Neuron |
|------------|-----------|-------------------|
| J_0_to_1 | Input (112) → Hidden (128) | **8 inputs** per hidden neuron |
| J_1_to_1 | Hidden (128) → Hidden (128) | **8 inputs** per hidden neuron |
| J_1_to_2 | Hidden (128) → Output (10) | **8 inputs** per output neuron |

## Architecture

```
Input (112D)  →  Hidden (128D)  →  Output (10D)
              \      ↺          /
            8 inputs  8 recurrent  8 inputs
            per neuron  per neuron  per neuron
```

## Data Format

- **7 timesteps** × **112 features** per timestep
- Each timestep = 4 rows × 28 pixels = 112 features

## Setup

In [None]:
# tqdm-related environment variables are exported before tqdm is imported below.
# NOTE(review): presumably "0" keeps progress bars enabled and "1" is the minimum
# refresh interval in seconds — confirm against the installed tqdm version.
import os
os.environ["TQDM_DISABLE"] = "0"
os.environ["TQDM_MININTERVAL"] = "1"

import sys
from pathlib import Path

# Walk from the current working directory up through its parents and prepend the
# first `src` directory that contains `soen_toolkit` to sys.path, so the
# `soen_toolkit` imports in later cells can resolve from a source checkout.
# If no such directory is found, sys.path is left unchanged.
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.nn as nn
import glob
import gzip
import urllib.request
import struct

# tqdm is optional: without it, fall back to a no-op wrapper that returns the
# iterable unchanged and ignores all keyword arguments.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        return iterable

# Request PyTorch's 'high' float32 matmul precision mode.
torch.set_float32_matmul_precision('high')

# Report the runtime environment (the device name is printed only when CUDA is available).
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Create Fixed Fan-In Masks (8 Inputs Per Neuron)

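Each destination neuron receives **exactly 8 inputs from each connection that targets it** (so a hidden neuron gets 8 feedforward inputs plus 8 recurrent inputs). We create sparse connectivity masks that enforce this constraint.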

In [None]:
def create_fixed_fanin_mask(from_nodes: int, to_nodes: int, fan_in: int = 8, 
                            allow_self_connections: bool = True, seed: int = 42) -> np.ndarray:
    """
    Create a connectivity mask where each destination neuron receives `fan_in` inputs.

    Sources are drawn uniformly at random without replacement, independently
    for every destination neuron. If fewer than `fan_in` sources are available
    (for example `from_nodes < fan_in`), the neuron is connected to all
    available sources instead, so the effective fan-in is
    `min(fan_in, number of available sources)`.

    The function uses its own seeded generator and therefore does not touch
    NumPy's global random state. The generated masks are identical to those
    obtained with `np.random.seed(seed)` followed by `np.random.choice`.

    Args:
        from_nodes: Number of source neurons
        to_nodes: Number of destination neurons
        fan_in: Number of inputs per destination neuron (default: 8)
        allow_self_connections: Whether to allow neuron i to connect to itself.
            Only has an effect when `from_nodes == to_nodes`; otherwise source
            and destination indices are treated as different populations.
        seed: Random seed for reproducibility

    Returns:
        mask: Binary float32 mask [to_nodes, from_nodes] where 1 indicates a connection
    """
    # Legacy RandomState reproduces the stream of the global np.random functions
    # for the same seed, but keeps the state private to this call.
    rng = np.random.RandomState(seed)
    mask = np.zeros((to_nodes, from_nodes), dtype=np.float32)

    # Integer source indices; an integer dtype keeps indexing valid even when
    # there are no sources at all.
    all_sources = np.arange(from_nodes)
    exclude_self = (not allow_self_connections) and from_nodes == to_nodes

    for i in range(to_nodes):
        # Drop index i only for square (recurrent) masks without self-connections
        available = np.delete(all_sources, i) if exclude_self else all_sources

        # Never request more sources than exist
        k = min(fan_in, len(available))
        selected = rng.choice(available, size=k, replace=False)
        mask[i, selected] = 1.0

    return mask


def save_mask(mask: np.ndarray, filepath: str):
    """Write `mask` to `filepath` via np.savez (under the key "mask") and print a summary.

    Missing parent directories of `filepath` are created first. The summary
    lists the target path, the mask shape, the per-row sums (fan-in of each
    destination neuron) and the total number of connections.
    """
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    np.savez(filepath, mask=mask)

    per_neuron_fan_in = mask.sum(axis=1).astype(int)
    n_connections = int(mask.sum())
    report = (
        f"Saved mask to {filepath}",
        f"  Shape: {mask.shape}",
        f"  Fan-in per neuron: {per_neuron_fan_in}",
        f"  Total connections: {n_connections}",
    )
    for line in report:
        print(line)


# Output directory for the generated mask files
masks_dir = Path("training/test_models/masks")
masks_dir.mkdir(parents=True, exist_ok=True)

# Network dimensions
INPUT_DIM = 112   # 4 rows × 28 pixels
HIDDEN_DIM = 128  # Hidden layer size
OUTPUT_DIM = 10   # 10 digit classes
FAN_IN = 8        # Inputs requested per neuron for every connection

banner = "="*60
print(banner)
print("CREATING FIXED FAN-IN MASKS (8 inputs per neuron)")
print(banner)

# Each connection gets its own seed so the three masks are independent.

# J_0_to_1: Input → Hidden (112 → 128)
print("\n[J_0_to_1] Input → Hidden")
mask_0_to_1 = create_fixed_fanin_mask(INPUT_DIM, HIDDEN_DIM, fan_in=FAN_IN, seed=42)
save_mask(mask_0_to_1, str(masks_dir / "J_0_to_1_fanin8.npz"))

# J_1_to_1: Hidden → Hidden (128 → 128, recurrent, no self-connections)
print("\n[J_1_to_1] Hidden → Hidden (recurrent)")
mask_1_to_1 = create_fixed_fanin_mask(
    HIDDEN_DIM, HIDDEN_DIM, fan_in=FAN_IN, allow_self_connections=False, seed=43
)
save_mask(mask_1_to_1, str(masks_dir / "J_1_to_1_fanin8.npz"))

# J_1_to_2: Hidden → Output (128 → 10)
print("\n[J_1_to_2] Hidden → Output")
mask_1_to_2 = create_fixed_fanin_mask(HIDDEN_DIM, OUTPUT_DIM, fan_in=FAN_IN, seed=44)
save_mask(mask_1_to_2, str(masks_dir / "J_1_to_2_fanin8.npz"))

print("\n" + banner)
print("MASK SUMMARY")
print(banner)

# One summary line per connection: name, source size, destination size, mask
for conn_name, n_src, n_dst, conn_mask in (
    ("J_0_to_1", INPUT_DIM, HIDDEN_DIM, mask_0_to_1),
    ("J_1_to_1", HIDDEN_DIM, HIDDEN_DIM, mask_1_to_1),
    ("J_1_to_2", HIDDEN_DIM, OUTPUT_DIM, mask_1_to_2),
):
    print(f"{conn_name}: {n_src} → {n_dst}, {FAN_IN} inputs/neuron, {int(conn_mask.sum())} connections")

# Compare the number of active connections with a fully dense network
sparse_total = mask_0_to_1.sum() + mask_1_to_1.sum() + mask_1_to_2.sum()
dense_total = INPUT_DIM*HIDDEN_DIM + HIDDEN_DIM*HIDDEN_DIM + HIDDEN_DIM*OUTPUT_DIM
print(f"\nTotal connections: {int(sparse_total)}")
print(f"Dense equivalent: {dense_total}")
sparsity = 1 - sparse_total / dense_total
print(f"Sparsity: {sparsity:.1%}")

## 2. Visualize Connectivity Masks

In [None]:
def visualize_masks():
    """Plot the three module-level connectivity masks and their per-neuron fan-in.

    Produces two figures: the masks rendered as binary images (annotated with
    the fan-in range and connection count), followed by a bar chart of the
    fan-in of every destination neuron with the target of 8 marked in red.
    """
    # Figure 1: the masks themselves
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    fig.suptitle('Sparse Connectivity Masks (8 inputs per neuron)', fontsize=14, fontweight='bold')

    panels = [
        (mask_0_to_1, "J_0_to_1: Input→Hidden\n(112→128)"),
        (mask_1_to_1, "J_1_to_1: Hidden→Hidden\n(128→128, recurrent)"),
        (mask_1_to_2, "J_1_to_2: Hidden→Output\n(128→10)"),
    ]
    note_box = {'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8}

    for ax, (mask, title) in zip(axes, panels):
        ax.imshow(mask, aspect='auto', cmap='binary')
        ax.set_title(title, fontsize=11)
        ax.set_xlabel('Source Neurons')
        ax.set_ylabel('Destination Neurons')

        # Row sums give the number of inputs of each destination neuron
        row_sums = mask.sum(axis=1)
        lowest, highest = int(row_sums.min()), int(row_sums.max())
        n_conns = int(mask.sum())
        ax.text(0.02, 0.98, f"Fan-in: {lowest}-{highest}\nConns: {n_conns}",
                transform=ax.transAxes, va='top', fontsize=9, bbox=note_box)

    plt.tight_layout()
    plt.show()

    # Figure 2: fan-in of every destination neuron
    fig, axes = plt.subplots(1, 3, figsize=(14, 3))
    fig.suptitle('Fan-in Distribution (should be exactly 8 for each neuron)', fontsize=12)

    for ax, (mask, title) in zip(axes, panels):
        row_sums = mask.sum(axis=1)
        ax.bar(range(len(row_sums)), row_sums, color='steelblue', edgecolor='none', alpha=0.7)
        ax.axhline(y=8, color='red', linestyle='--', linewidth=2, label='Target (8)')
        ax.set_xlabel('Neuron Index')
        ax.set_ylabel('Number of Inputs')
        # Only the first line of the two-line panel title is used here
        ax.set_title(title.partition('\n')[0])
        ax.set_ylim(0, 10)
        ax.legend(loc='upper right')

    plt.tight_layout()
    plt.show()

visualize_masks()

## 3. Prepare MNIST Dataset (7×112 Format)

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/"):
    """Return the local path of an MNIST archive, downloading it if necessary.

    Files are cached in ``./data/mnist`` (relative to the current working
    directory). An existing file is returned as-is without contacting the
    server.

    The download is written to a temporary ``<filename>.part`` file and only
    renamed to its final name once it has completed, so an interrupted
    download can never be mistaken for a complete cached file on the next call.

    Args:
        filename: Name of the file below `base_url`, also used as the local file name.
        base_url: URL prefix that `filename` is appended to.

    Returns:
        Path of the local file.

    Raises:
        Whatever `urllib.request.urlretrieve` raises when the download fails.
    """
    data_dir = Path("./data/mnist")
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if filepath.exists():
        return filepath

    url = base_url + filename
    print(f"Downloading {filename}...")
    partial_path = filepath.with_name(filepath.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial_path)
        partial_path.replace(filepath)
    finally:
        # After a successful rename the partial file no longer exists; after a
        # failure this removes whatever was written so far.
        if partial_path.exists():
            partial_path.unlink()
    return filepath

def read_mnist_images(filepath):
    """Read a gzip-compressed IDX image file into a uint8 array.

    The 16-byte header holds four big-endian unsigned 32-bit integers: the
    magic number, the number of images, the rows and the columns per image.
    The remaining bytes are the pixel values.

    Args:
        filepath: Path of the ``*-images-idx3-ubyte.gz`` file.

    Returns:
        uint8 array of shape (num_images, rows, cols).

    Raises:
        ValueError: If the magic number is not that of an IDX image file
            (2051) or the pixel payload does not match the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        # 2051 (0x00000803) identifies unsigned-byte data with three dimensions
        if magic != 2051:
            raise ValueError(
                f"{filepath} is not an IDX image file (magic number {magic}, expected 2051)"
            )
        images = np.frombuffer(f.read(), dtype=np.uint8)

    expected_size = num_images * rows * cols
    if images.size != expected_size:
        raise ValueError(
            f"{filepath} holds {images.size} pixel bytes, "
            f"but its header announces {expected_size}"
        )
    return images.reshape(num_images, rows, cols)

def read_mnist_labels(filepath):
    """Read a gzip-compressed IDX label file into a uint8 array.

    The 8-byte header holds two big-endian unsigned 32-bit integers: the magic
    number and the number of labels. The remaining bytes are the labels.

    Args:
        filepath: Path of the ``*-labels-idx1-ubyte.gz`` file.

    Returns:
        1-D uint8 array with one entry per label.

    Raises:
        ValueError: If the magic number is not that of an IDX label file
            (2049) or the number of label bytes does not match the header.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        # 2049 (0x00000801) identifies unsigned-byte data with one dimension
        if magic != 2049:
            raise ValueError(
                f"{filepath} is not an IDX label file (magic number {magic}, expected 2049)"
            )
        labels = np.frombuffer(f.read(), dtype=np.uint8)

    # A truncated file would otherwise silently yield too few labels
    if labels.size != num_labels:
        raise ValueError(
            f"{filepath} holds {labels.size} labels, but its header announces {num_labels}"
        )
    return labels

def reshape_to_7x112(images):
    """Turn a batch of images with 784 values each into sequences of shape (n, 7, 112).

    In row-major order, every group of 4 consecutive image rows (4 × 28 = 112
    values) becomes one timestep, giving 7 timesteps per 28×28 image.
    """
    # Splitting into (7, 4, 28) and merging the last two axes is the same as
    # reshaping to (7, 112) directly.
    return images.reshape(images.shape[0], 7, 112)

def prepare_mnist_hdf5_7x112(output_path="training/datasets/mnist_seq7x112.hdf5", 
                             normalize=True, val_split=0.1):
    """Build (or reuse) the MNIST dataset in 7×112 sequence format as an HDF5 file.

    If `output_path` already exists, its split sizes are printed and the path
    is returned unchanged; `normalize` and `val_split` are ignored in that case.

    Otherwise the four MNIST archives are downloaded, the images are
    optionally scaled to [0, 1], reshaped to (n, 7, 112), and a validation
    split is carved out of the training set with a fixed permutation
    (seed 42). The split uses a private generator, so NumPy's global random
    state is left untouched; the resulting permutation is identical to the one
    produced by `np.random.seed(42)` followed by `np.random.permutation`.

    The file contains the groups ``train``, ``val`` and ``test``, each with a
    ``data`` and a ``labels`` dataset, plus descriptive file attributes.

    Args:
        output_path: Destination of the HDF5 file; parent directories are created.
        normalize: Divide pixel values by 255 when True.
        val_split: Fraction of the training images moved to the validation split.

    Returns:
        `output_path` as a Path.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    if output_path.exists():
        print(f"Dataset already exists at {output_path}")
        with h5py.File(output_path, 'r') as f:
            print(f"  Train samples: {len(f['train']['labels'])}")
            print(f"  Val samples: {len(f['val']['labels'])}")
            print(f"  Test samples: {len(f['test']['labels'])}")
            print(f"  Data shape: {f['train']['data'].shape}")
        return output_path
    
    print("Downloading MNIST...")
    train_images_file = download_mnist_file("train-images-idx3-ubyte.gz")
    train_labels_file = download_mnist_file("train-labels-idx1-ubyte.gz")
    test_images_file = download_mnist_file("t10k-images-idx3-ubyte.gz")
    test_labels_file = download_mnist_file("t10k-labels-idx1-ubyte.gz")
    
    train_images = read_mnist_images(train_images_file).astype(np.float32)
    train_labels = read_mnist_labels(train_labels_file).astype(np.int64)
    test_images = read_mnist_images(test_images_file).astype(np.float32)
    test_labels = read_mnist_labels(test_labels_file).astype(np.int64)
    
    if normalize:
        train_images = train_images / 255.0
        test_images = test_images / 255.0
    
    print("\nReshaping to 7×112 format...")
    train_images = reshape_to_7x112(train_images)
    test_images = reshape_to_7x112(test_images)
    
    # Fixed, reproducible train/validation split drawn from a private generator
    n_train = len(train_images)
    n_val = int(n_train * val_split)
    indices = np.random.RandomState(42).permutation(n_train)
    val_indices = indices[:n_val]
    train_indices = indices[n_val:]
    
    # Validation samples must be extracted before the training arrays are overwritten
    val_images = train_images[val_indices]
    val_labels = train_labels[val_indices]
    train_images = train_images[train_indices]
    train_labels = train_labels[train_indices]
    
    splits = {
        'train': (train_images, train_labels),
        'val': (val_images, val_labels),
        'test': (test_images, test_labels),
    }
    
    print(f"\nSaving to {output_path}...")
    with h5py.File(output_path, 'w') as f:
        for split_name, (split_data, split_labels) in splits.items():
            grp = f.create_group(split_name)
            grp.create_dataset('data', data=split_data)
            grp.create_dataset('labels', data=split_labels)
        
        f.attrs['description'] = 'MNIST 7x112 format for 8-input neurons'
        f.attrs['num_classes'] = 10
        f.attrs['seq_len'] = 7
        f.attrs['feature_dim'] = 112
    
    print("Done!")
    return output_path

data_path = prepare_mnist_hdf5_7x112()

## 4. Illustrate the 8-Input Constraint

In [None]:
def illustrate_8input_constraint():
    """Draw a two-panel figure illustrating the 8-input constraint.

    Left panel: for a handful of hidden neurons, the input feature indices
    they are connected to according to the module-level `mask_0_to_1`.
    Right panel: a schematic with 8 inputs feeding one hidden neuron, which in
    turn feeds one output neuron.
    """
    _, (conn_ax, sketch_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # --- Left: connected input features of a few example hidden neurons ---
    example_neurons = [0, 10, 50, 100, 127]
    n_examples = len(example_neurons)

    # The first example neuron is drawn in the top row, the last one at the bottom
    for row, neuron_idx in zip(range(n_examples - 1, -1, -1), example_neurons):
        connected_inputs = np.flatnonzero(mask_0_to_1[neuron_idx] == 1)
        conn_ax.scatter(connected_inputs, [row] * len(connected_inputs),
                        s=50, alpha=0.7, label=f'Hidden neuron {neuron_idx}')

    conn_ax.set_xlabel('Input Feature Index (0-111)', fontsize=11)
    conn_ax.set_ylabel('Hidden Neuron')
    conn_ax.set_yticks(range(n_examples))
    conn_ax.set_yticklabels([f'H{n}' for n in example_neurons[::-1]])
    conn_ax.set_xlim(-2, 114)
    conn_ax.set_title('Input Connections to Example Hidden Neurons\n(Each receives exactly 8 inputs)', fontsize=11)
    conn_ax.grid(True, alpha=0.3)

    # --- Right: schematic circuit diagram ---
    input_x, hidden_x, output_x = 0.1, 0.5, 0.9
    input_y = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]  # 8 sample inputs
    hidden_y = 0.5  # Single hidden neuron
    output_y = 0.5

    # Input neurons
    for y in input_y:
        sketch_ax.scatter([input_x], [y], s=200, c='lightblue', edgecolors='black', zorder=3)
    sketch_ax.text(input_x, 1.0, 'Input\n(8 of 112)', ha='center', fontsize=10)

    # Hidden neuron
    sketch_ax.scatter([hidden_x], [hidden_y], s=400, c='lightgreen', edgecolors='black', zorder=3)
    sketch_ax.text(hidden_x, 0.2, 'Hidden\nNeuron\n(1 of 128)', ha='center', fontsize=10)

    # Output neuron
    sketch_ax.scatter([output_x], [output_y], s=300, c='lightyellow', edgecolors='black', zorder=3)
    sketch_ax.text(output_x, 0.2, 'Output\n(1 of 10)', ha='center', fontsize=10)

    # One line from each of the 8 inputs to the hidden neuron
    for y in input_y:
        sketch_ax.plot([input_x, hidden_x], [y, hidden_y], 'b-', alpha=0.5, linewidth=1)

    # Single line from the hidden neuron to the output neuron
    sketch_ax.plot([hidden_x, output_x], [hidden_y, output_y], 'g-', alpha=0.5, linewidth=2)

    sketch_ax.set_xlim(-0.1, 1.1)
    sketch_ax.set_ylim(-0.1, 1.1)
    sketch_ax.axis('off')
    sketch_ax.set_title('8-Input Constraint: Each Neuron Receives Only 8 Inputs', fontsize=11)

    # Label on the bundle of input lines
    sketch_ax.annotate('Max 8 inputs', xy=(0.3, 0.6), fontsize=10, color='blue',
                       bbox={'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8})

    plt.tight_layout()
    plt.show()

illustrate_8input_constraint()

## 5. Load and Inspect the Model

In [None]:
from soen_toolkit.core.model_yaml import build_model_from_yaml

# Build the model from its YAML specification. `model_path` and `model` stay
# module-level names; `model_path` is read again by load_best_checkpoint below.
model_path = Path("training/test_models/model_specs/MNIST_SOENSpec_7x112_8input.yaml")
model = build_model_from_yaml(model_path)

print("=" * 60)
print("MNIST SOEN MODEL (8 Inputs Per Neuron)")
print("=" * 60)

# Count all parameter elements, and separately those with requires_grad set
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

print("\nLayer dimensions:")
for layer_id, dim in model.layer_nodes.items():
    print(f"  Layer {layer_id}: {dim} neurons")

print("\nConnections (sparse with 8 fan-in):")
for name, param in model.connections.items():
    # Report the sparsity of the mask the model holds for this connection;
    # connections without a mask are listed with their shape only.
    mask = model.connection_masks.get(name)
    if mask is not None:
        active = (mask > 0).sum().item()
        total = mask.numel()
        sparsity = 1 - active / total
        # Mean of the sums along dim=1.
        # NOTE(review): this is the fan-in only if masks are laid out as
        # [destination, source] — confirm against soen_toolkit.
        fan_in = mask.sum(dim=1).mean().item()
        print(f"  {name}: {param.shape}, active: {int(active)}/{total} ({sparsity:.1%} sparse), fan-in: {fan_in:.1f}")
    else:
        print(f"  {name}: {param.shape}")

# Smoke test: push a random batch of 2 sequences (7 timesteps × 112 features)
# through the model without tracking gradients.
print("\nTesting forward pass...")
x_test = torch.randn(2, 7, 112)
with torch.no_grad():
    output, states = model(x_test)
print(f"  Input: {x_test.shape}")
print(f"  Output: {output.shape}")
print("  Forward pass successful!")

## 6. Train the Model

In [None]:
# The environment variable is set before the trainer module is imported.
# NOTE(review): presumably this turns off the toolkit's own progress bar —
# confirm against soen_toolkit.
import os
os.environ["SOEN_NO_PROGRESS_BAR"] = "1"

from soen_toolkit.training.trainers.experiment import run_from_config

print("="*60)
print("TRAINING SOEN MODEL (8 Inputs Per Neuron)")
print("="*60)
print("\nHardware constraints enforced:")
print("  • Each hidden neuron: 8 feedforward + 8 recurrent inputs")
print("  • Each output neuron: 8 inputs")
print("="*60)

# Launch training from the YAML config, passing the current working directory
# as `script_dir`.
# NOTE(review): presumably relative paths in the config are resolved against
# `script_dir` — confirm against run_from_config.
run_from_config("training/training_configs/mnist_soen_7x112_8input.yaml", script_dir=Path.cwd())

## 7. Evaluate the Trained Model

In [None]:
def load_best_checkpoint():
    """Load the most recently modified checkpoint into a freshly built model.

    Despite its name, this function does not compare validation metrics: it
    searches ``training/temp`` recursively for ``*.ckpt`` files, prefers those
    whose path contains ``8input`` (falling back to all checkpoints found),
    and picks the one with the latest modification time.

    Weights are loaded non-strictly, but keys that are missing from the
    checkpoint or not expected by the model are reported, so that a
    checkpoint from a different architecture does not go unnoticed.

    Returns:
        (model, checkpoint_path) with the model in eval mode on the CPU, or
        (None, None) when no checkpoint exists.
    """
    ckpt_patterns = [
        "training/temp/**/checkpoints/**/*.ckpt",
        "training/temp/**/*.ckpt",
    ]
    
    # The patterns overlap, so files may be listed twice; harmless for max() below
    all_ckpts = []
    for pattern in ckpt_patterns:
        all_ckpts.extend(glob.glob(pattern, recursive=True))
    
    # Prefer 8input checkpoints
    ckpts_8input = [c for c in all_ckpts if '8input' in c]
    if not ckpts_8input:
        ckpts_8input = all_ckpts
    
    if not ckpts_8input:
        print("No checkpoint found. Run training first.")
        return None, None
    
    latest_ckpt = max(ckpts_8input, key=lambda x: Path(x).stat().st_mtime)
    print(f"Loading checkpoint: {latest_ckpt}")
    
    model = build_model_from_yaml(model_path)
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    ckpt = torch.load(latest_ckpt, map_location='cpu')
    # Accept both a wrapper dict with a 'state_dict' entry and a bare state dict
    state_dict = ckpt.get('state_dict', ckpt)
    
    # Drop a leading 'model.' from parameter names so they match the bare model
    prefix = 'model.'
    clean_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    
    load_result = model.load_state_dict(clean_state_dict, strict=False)
    # strict=False never raises on mismatched keys, so surface them explicitly
    missing_keys = list(getattr(load_result, 'missing_keys', None) or [])
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', None) or [])
    if missing_keys:
        print(f"Warning: {len(missing_keys)} model key(s) not found in checkpoint: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: {len(unexpected_keys)} checkpoint key(s) not used by the model: {unexpected_keys}")
    
    model.eval()
    return model, latest_ckpt

trained_model, ckpt_path = load_best_checkpoint()

In [None]:
def evaluate_on_test_set(model, data_path, batch_size=128):
    """Classify the ``test`` split stored in `data_path` and print the accuracy.

    The data is processed in batches on the device of the model's first
    parameter. A 3-dimensional model output is reduced with a maximum over
    dim 1 before the softmax; any other output is used as-is.

    Args:
        model: Model to evaluate, or None.
        data_path: HDF5 file containing ``test/data`` and ``test/labels``.
        batch_size: Number of samples per forward pass.

    Returns:
        (predictions, probabilities, accuracy), or None (after printing a
        message) when `model` is None.
    """
    if model is None:
        print("No model loaded.")
        return

    with h5py.File(data_path, 'r') as f:
        test_split = f['test']
        test_data = np.array(test_split['data'])
        test_labels = np.array(test_split['labels'])

    n_test = len(test_labels)
    print(f"Evaluating on {n_test} test samples...")

    model.eval()
    device = next(model.parameters()).device

    pred_chunks, prob_chunks = [], []
    with torch.no_grad():
        for start in tqdm(range(0, len(test_data), batch_size)):
            x = torch.tensor(test_data[start:start + batch_size], dtype=torch.float32).to(device)

            output, _ = model(x)

            # Collapse dim 1 of a 3-D output by taking its maximum
            pooled = output.max(dim=1)[0] if output.dim() == 3 else output

            probs = torch.softmax(pooled, dim=1)
            pred_chunks.append(torch.argmax(probs, dim=1).cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    all_preds = np.concatenate(pred_chunks)
    all_probs = np.concatenate(prob_chunks)
    is_correct = all_preds == test_labels
    accuracy = is_correct.mean()

    rule = '=' * 50
    print(f"\n{rule}")
    print("TEST SET RESULTS (8 Inputs Per Neuron)")
    print(rule)
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Correct: {is_correct.sum()}/{n_test}")

    return all_preds, all_probs, accuracy

if trained_model is not None:
    predictions, probabilities, test_accuracy = evaluate_on_test_set(trained_model, data_path)

## 8. Visualize Predictions

In [None]:
def visualize_predictions(model, data_path, n_samples=20):
    """Show model predictions for a fixed random selection of test images.

    `n_samples` test samples are drawn without replacement using a private
    generator seeded with 42 (NumPy's global random state is not modified;
    the selection equals the one obtained via `np.random.seed(42)`). The
    samples are run through the model on the model's own device, and each
    image is plotted with its predicted class, confidence and true label.
    Wrong predictions are titled in bold red, correct ones in green. The
    accuracy on the selection is printed at the end.

    Args:
        model: Model to use, or None (prints a message and returns).
        data_path: HDF5 file containing ``test/data`` and ``test/labels``.
        n_samples: Number of test images to display.

    Raises:
        ValueError: If `n_samples` exceeds the number of test samples.
    """
    if model is None:
        print("No model loaded.")
        return
    
    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])
    
    rng = np.random.RandomState(42)
    indices = rng.choice(len(test_data), n_samples, replace=False)
    samples = test_data[indices]
    labels = test_labels[indices]
    
    # Reconstruct 28×28 for visualization
    samples_28x28 = samples.reshape(n_samples, 7, 4, 28).reshape(n_samples, 28, 28)
    
    model.eval()
    # Run on whatever device the model lives on, as evaluate_on_test_set does;
    # a model without parameters is evaluated on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        x = torch.tensor(samples, dtype=torch.float32).to(device)
        output, _ = model(x)
        # Collapse dim 1 of a 3-D output by taking its maximum
        if output.dim() == 3:
            pooled = output.max(dim=1)[0]
        else:
            pooled = output
        probs = torch.softmax(pooled, dim=1)
        # Tensors must be on the CPU before they can be converted to NumPy
        preds = torch.argmax(probs, dim=1).cpu().numpy()
        confidence = probs.max(dim=1)[0].cpu().numpy()
    
    n_cols = 5
    n_rows = (n_samples + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.5*n_cols, 3*n_rows))
    axes = axes.flatten()
    fig.suptitle('MNIST Predictions (8-Input SOEN Model)', fontsize=14, fontweight='bold')
    
    for i in range(n_samples):
        ax = axes[i]
        ax.imshow(samples_28x28[i], cmap='gray')
        is_correct = preds[i] == labels[i]
        color = 'green' if is_correct else 'red'
        symbol = '✓' if is_correct else '✗'
        ax.set_title(f"{symbol} Pred: {preds[i]} ({confidence[i]:.0%})\nTrue: {labels[i]}",
                     fontsize=9, color=color, fontweight='bold' if not is_correct else 'normal')
        ax.axis('off')
    
    # Hide the unused cells of the last grid row
    for unused_ax in axes[n_samples:]:
        unused_ax.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    accuracy = (preds == labels).mean()
    print(f"\nSample accuracy: {accuracy:.1%} ({(preds == labels).sum()}/{n_samples})")

if trained_model is not None:
    visualize_predictions(trained_model, data_path)

## Summary

This notebook demonstrated training a **hardware-compatible** SOEN model where:

| Constraint | Value | Notes |
|------------|-------|-------|
| Max inputs per neuron | **8** | Enforced via sparse masks |
| Input shape | 7 × 112 | 7 timesteps, 112 features/timestep |
| Hidden layer | 128 neurons | Each receives 8 feedforward + 8 recurrent |
| Output layer | 10 neurons | Each receives 8 hidden inputs |

### Connectivity Summary

| Connection | Dense Size | Sparse Size | Sparsity |
|------------|------------|-------------|----------|
| J_0_to_1 (Input→Hidden) | 112×128 = 14,336 | 128×8 = 1,024 | 92.9% |
| J_1_to_1 (Hidden→Hidden) | 128×128 = 16,384 | 128×8 = 1,024 | 93.8% |
| J_1_to_2 (Hidden→Output) | 128×10 = 1,280 | 10×8 = 80 | 93.8% |
| **Total** | 32,000 | 2,128 | **93.4%** |

### Hardware Mapping

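Each connection in this model delivers exactly 8 inputs to every neuron it targets, so it can be mapped directly onto hardware neurons with 8 input ports. Note that hidden neurons are the target of two connections and therefore receive 8 feedforward plus 8 recurrent inputs.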