# Tutorial 03 — MNIST Classification with SOEN

This notebook demonstrates training a SOEN (Superconducting Optoelectronic Network) model on the MNIST digit classification task.

## Key Concepts

1. **Sequential Processing**: Each 28×28 MNIST image is treated as a 28-timestep sequence, where each timestep is one row of 28 pixels.
2. **Temporal Dynamics**: The SingleDendrite layer integrates information over time, so evidence from successive rows accumulates in its state.
3. **Learnable Readout**: The output connection is learnable (a lesson from Tutorial 02).

## Architecture

```
Input (28D) → SingleDendrite (128D) → Output (10D)
     ↓              ↓ ↺                   ↓
  Row pixels    Recurrent SOEN      10 digit classes
```
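
As a rough sanity check on model size, the weight counts implied by the diagram can be worked out by hand. This is only a sketch: it assumes dense, all-to-all weight matrices for each arrow (including the recurrent loop) and ignores biases and any per-neuron layer parameters, so the totals printed later by the notebook may differ.

```python
# Weights-only estimate, assuming dense connections (an assumption, not the model spec)
n_in, n_hidden, n_out = 28, 128, 10

input_to_hidden = n_in * n_hidden        # 28 * 128
hidden_recurrent = n_hidden * n_hidden   # 128 * 128
hidden_to_output = n_hidden * n_out      # 128 * 10

total_weights = input_to_hidden + hidden_recurrent + hidden_to_output
print(input_to_hidden, hidden_recurrent, hidden_to_output, total_weights)
```

Under these assumptions the recurrent block dominates the count, which is why changing the hidden size has a roughly quadratic effect on model size.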

## Setup

In [None]:
# Setup: Ensure soen_toolkit is importable
import sys
from pathlib import Path

# Add src directory to path if running from notebook location
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.nn as nn
import glob
import gzip
import urllib.request
import struct

# Use standard tqdm (not notebook version to avoid widget errors)
try:
    from tqdm import tqdm
except ImportError:
    # Simple fallback if tqdm not available
    def tqdm(iterable, **kwargs):
        return iterable

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Prepare MNIST Dataset

We'll download MNIST and convert it to HDF5 format compatible with the SOEN training pipeline.

**Data format**: No reshaping is needed. Each (28, 28) image is read directly as a (T, D) sequence, where:
- Time dimension = 28 (rows)
- Feature dimension = 28 (columns/pixels per row)
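
The (N, T, D) reading of an image batch can be sketched with plain NumPy. The random array below is a stand-in for normalized MNIST images, not real data; it only shows that indexing the second axis yields one row per timestep.

```python
import numpy as np

# Stand-in for 4 normalized images (random values, for illustration only)
rng = np.random.default_rng(0)
batch = rng.random((4, 28, 28), dtype=np.float32)

n_samples, n_steps, n_features = batch.shape

# At timestep t the model sees row t of every image: shape (N, D)
rows = [batch[:, t, :] for t in range(n_steps)]
first_row_shape = rows[0].shape

# Stacking the rows back along the time axis recovers the original batch
rebuilt = np.stack(rows, axis=1)
print(first_row_shape, rebuilt.shape)
```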

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/"):
    """Download a single MNIST file if not already present."""
    data_dir = Path("./data/mnist")
    data_dir.mkdir(parents=True, exist_ok=True)
    
    filepath = data_dir / filename
    if not filepath.exists():
        url = base_url + filename
        print(f"Downloading {filename}...")
        urllib.request.urlretrieve(url, filepath)
    return filepath

def read_mnist_images(filepath):
    """Read MNIST image file (idx3-ubyte format)."""
    with gzip.open(filepath, 'rb') as f:
        # Read magic number and dimensions
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        # Read image data
        images = np.frombuffer(f.read(), dtype=np.uint8)
        images = images.reshape(num_images, rows, cols)
    return images

def read_mnist_labels(filepath):
    """Read MNIST label file (idx1-ubyte format)."""
    with gzip.open(filepath, 'rb') as f:
        # Read magic number and count
        magic, num_labels = struct.unpack('>II', f.read(8))
        # Read labels
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    return labels

def prepare_mnist_hdf5(output_path="training/datasets/mnist_seq28.hdf5", 
                       normalize=True,
                       val_split=0.1):
    """Download MNIST and save as HDF5 for SOEN training.
    
    Args:
        output_path: Where to save the HDF5 file
        normalize: Whether to normalize pixel values to [0, 1]
        val_split: Fraction of training data to use for validation
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Check if already exists
    if output_path.exists():
        print(f"Dataset already exists at {output_path}")
        with h5py.File(output_path, 'r') as f:
            print(f"  Train samples: {len(f['train']['labels'])}")
            print(f"  Val samples: {len(f['val']['labels'])}")
            print(f"  Test samples: {len(f['test']['labels'])}")
        return output_path
    
    print("Downloading MNIST (without torchvision)...")
    
    # Download MNIST files
    train_images_file = download_mnist_file("train-images-idx3-ubyte.gz")
    train_labels_file = download_mnist_file("train-labels-idx1-ubyte.gz")
    test_images_file = download_mnist_file("t10k-images-idx3-ubyte.gz")
    test_labels_file = download_mnist_file("t10k-labels-idx1-ubyte.gz")
    
    # Read the data
    train_images = read_mnist_images(train_images_file).astype(np.float32)
    train_labels = read_mnist_labels(train_labels_file).astype(np.int64)
    test_images = read_mnist_images(test_images_file).astype(np.float32)
    test_labels = read_mnist_labels(test_labels_file).astype(np.int64)
    
    # Normalize to [0, 1]
    if normalize:
        train_images = train_images / 255.0
        test_images = test_images / 255.0
    
    # Split training into train/val
    n_train = len(train_images)
    n_val = int(n_train * val_split)
    
    # Shuffle before splitting
    np.random.seed(42)
    indices = np.random.permutation(n_train)
    val_indices = indices[:n_val]
    train_indices = indices[n_val:]
    
    val_images = train_images[val_indices]
    val_labels = train_labels[val_indices]
    train_images = train_images[train_indices]
    train_labels = train_labels[train_indices]
    
    # Images are already (N, 28, 28) which is (N, T, D) for our sequence format
    print(f"Train shape: {train_images.shape} (N, T=28 timesteps, D=28 features)")
    print(f"Val shape: {val_images.shape}")
    print(f"Test shape: {test_images.shape}")
    
    # Save to HDF5
    print(f"\nSaving to {output_path}...")
    with h5py.File(output_path, 'w') as f:
        # Training set
        train_grp = f.create_group('train')
        train_grp.create_dataset('data', data=train_images, compression='gzip')
        train_grp.create_dataset('labels', data=train_labels)
        
        # Validation set
        val_grp = f.create_group('val')
        val_grp.create_dataset('data', data=val_images, compression='gzip')
        val_grp.create_dataset('labels', data=val_labels)
        
        # Test set
        test_grp = f.create_group('test')
        test_grp.create_dataset('data', data=test_images, compression='gzip')
        test_grp.create_dataset('labels', data=test_labels)
        
        # Metadata
        f.attrs['description'] = 'MNIST as sequences (28 timesteps x 28 features)'
        f.attrs['num_classes'] = 10
        f.attrs['seq_len'] = 28
        f.attrs['feature_dim'] = 28
    
    print("Done!")
    print(f"  Train: {len(train_labels)} samples")
    print(f"  Val: {len(val_labels)} samples")
    print(f"  Test: {len(test_labels)} samples")
    
    return output_path

# Prepare the dataset
data_path = prepare_mnist_hdf5()
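
The IDX header parsing used by `read_mnist_images` can be checked without downloading anything. The sketch below builds a tiny synthetic idx3 payload in memory (2 images of 3×4 pixels, chosen arbitrarily) and reads it back with the same `struct` format string. The helper name `parse_idx3_header` is hypothetical.

```python
import gzip
import io
import struct

def parse_idx3_header(raw: bytes):
    """Return (magic, count, rows, cols) from the first 16 bytes of an idx3 file."""
    return struct.unpack(">IIII", raw[:16])

# Synthetic payload: big-endian header followed by one byte per pixel.
# 2051 (0x00000803) is the magic number of idx3-ubyte image files.
header = struct.pack(">IIII", 2051, 2, 3, 4)
pixels = bytes(range(24))
compressed = gzip.compress(header + pixels)

with gzip.open(io.BytesIO(compressed), "rb") as f:
    magic, count, rows, cols = parse_idx3_header(f.read(16))
    body = f.read()

# The remaining bytes must match the dimensions announced in the header
assert len(body) == count * rows * cols
print(magic, count, rows, cols)
```

Checking `magic` and the body length in the real readers would catch truncated or mislabeled downloads before they reach the HDF5 file.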

## 2. Visualize the Dataset

Let's see how MNIST looks when treated as sequences.

In [None]:
def visualize_mnist_samples(data_path, n_samples=10):
    """Visualize MNIST samples and their sequential representation."""
    
    with h5py.File(data_path, 'r') as f:
        images = np.array(f['train']['data'][:n_samples])
        labels = np.array(f['train']['labels'][:n_samples])
    
    fig, axes = plt.subplots(2, n_samples, figsize=(2*n_samples, 5))
    fig.suptitle('MNIST Samples: Image View vs Sequential View', fontsize=14, fontweight='bold')
    
    for i in range(n_samples):
        # Top row: Original image view
        axes[0, i].imshow(images[i], cmap='gray')
        axes[0, i].set_title(f'Label: {labels[i]}', fontsize=10)
        # Hide ticks only; axis('off') would also hide the row label set below
        axes[0, i].set_xticks([])
        axes[0, i].set_yticks([])
        
        # Bottom row: Sequential view (heatmap of 28 timesteps x 28 features)
        axes[1, i].imshow(images[i], cmap='viridis', aspect='auto')
        axes[1, i].set_xlabel('Feature (pixel)', fontsize=8)
    
    # Row labels on the leftmost column
    axes[0, 0].set_ylabel('Image View', fontsize=10)
    axes[1, 0].set_ylabel('Sequence View\nTime (row), T=28', fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    # Show class distribution
    with h5py.File(data_path, 'r') as f:
        all_labels = np.array(f['train']['labels'])
    
    fig, ax = plt.subplots(figsize=(10, 4))
    counts = np.bincount(all_labels)
    ax.bar(range(10), counts, color='steelblue', edgecolor='black')
    ax.set_xlabel('Digit Class', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title('Training Set Class Distribution', fontsize=14)
    ax.set_xticks(range(10))
    plt.tight_layout()
    plt.show()

visualize_mnist_samples(data_path)

## 3. Examine the Model Architecture

Let's look at the SOEN model we'll be training.

In [None]:
from soen_toolkit.core.model_yaml import build_model_from_yaml

# Load and inspect the model
model_path = Path("training/test_models/model_specs/MNIST_SOENSpec.yaml")
model = build_model_from_yaml(model_path)

print("=" * 60)
print("MNIST SOEN MODEL ARCHITECTURE")
print("=" * 60)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

print("\nLayer dimensions:")
for layer_id, dim in model.layer_nodes.items():
    print(f"  Layer {layer_id}: {dim} neurons")

print("\nConnections:")
for name, param in model.connections.items():
    print(f"  {name}: {param.shape} (learnable: {param.requires_grad})")

# Test forward pass
print("\nTesting forward pass...")
x_test = torch.randn(2, 28, 28)  # [batch=2, seq_len=28, features=28]
with torch.no_grad():
    output, states = model(x_test)
print(f"  Input shape: {x_test.shape}")
print(f"  Output shape: {output.shape}")
print("  Forward pass successful!")

## 4. Train the Model

Now let's train the SOEN model on MNIST using the training configuration.

In [None]:
from soen_toolkit.training.trainers.experiment import run_from_config

# Run training
print("Starting MNIST SOEN training...")
print("This may take a while depending on your hardware.")
print("="*60)

run_from_config("training/training_configs/mnist_soen.yaml", script_dir=Path.cwd())

## 5. Evaluate the Trained Model

Let's load the most recently written checkpoint (which is not necessarily the one with the best validation score) and evaluate it on the test set.
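
Checkpoints saved by a training wrapper often prefix parameter names (here `model.`), so the keys must be renamed before they match the bare model. A minimal sketch of that renaming step follows, with a check for leftover mismatches. The key names are invented for illustration and do not come from a real SOEN checkpoint.

```python
def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Hypothetical checkpoint contents and expected parameter names
ckpt_state = {"model.layers.1.weight": 1, "model.connections.J_1_to_2": 2, "epoch": 3}
expected_keys = {"layers.1.weight", "connections.J_1_to_2"}

cleaned = strip_prefix(ckpt_state)
missing = expected_keys - set(cleaned)       # parameters the checkpoint does not provide
unexpected = set(cleaned) - expected_keys    # checkpoint entries the model does not use
print("missing:", sorted(missing))
print("unexpected:", sorted(unexpected))
```

In PyTorch, `load_state_dict(..., strict=False)` returns an object with `missing_keys` and `unexpected_keys` fields, so the same check can be done on the real model by inspecting that return value.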

In [None]:
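def load_best_checkpoint():
    """Find and load the most recently modified checkpoint from training.

    Note: "most recent" is selected by file modification time, which is not
    necessarily the checkpoint with the best validation metric.
    """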
    
    # Find checkpoints
    ckpt_patterns = [
        "training/temp/**/checkpoints/**/*.ckpt",
        "training/temp/**/*.ckpt",
    ]
    
    all_ckpts = []
    for pattern in ckpt_patterns:
        all_ckpts.extend(glob.glob(pattern, recursive=True))
    
    if not all_ckpts:
        print("No checkpoint found. Run training first.")
        return None, None
    
    # Get the most recent checkpoint
    latest_ckpt = max(all_ckpts, key=lambda x: Path(x).stat().st_mtime)
    print(f"Loading checkpoint: {latest_ckpt}")
    
    # Load model
    model_path = Path("training/test_models/model_specs/MNIST_SOENSpec.yaml")
    model = build_model_from_yaml(model_path)
    
    # Load weights
    ckpt = torch.load(latest_ckpt, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)
    
    # Remove 'model.' prefix if present
    clean_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('model.'):
            clean_state_dict[k[6:]] = v
        else:
            clean_state_dict[k] = v
    
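    # strict=False tolerates key mismatches, so report them instead of failing silently
    result = model.load_state_dict(clean_state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"  Missing keys: {result.missing_keys}")
        print(f"  Unexpected keys: {result.unexpected_keys}")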
    model.eval()
    
    return model, latest_ckpt

trained_model, ckpt_path = load_best_checkpoint()

In [None]:
def evaluate_on_test_set(model, data_path, batch_size=128):
    """Evaluate model on the test set."""
    
    if model is None:
        print("No model loaded.")
        # Match the 3-tuple returned on success so callers can always unpack
        return None, None, None
    
    # Load test data
    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])
    
    print(f"Evaluating on {len(test_labels)} test samples...")
    
    model.eval()
    device = next(model.parameters()).device
    
    all_preds = []
    all_probs = []
    
    with torch.no_grad():
        for i in tqdm(range(0, len(test_data), batch_size)):
            batch_data = test_data[i:i+batch_size]
            x = torch.tensor(batch_data, dtype=torch.float32).to(device)
            
            output, _ = model(x)
            
            # Apply max pooling over time
            if output.dim() == 3:
                pooled = output.max(dim=1)[0]
            else:
                pooled = output
            
            probs = torch.softmax(pooled, dim=1)
            preds = torch.argmax(probs, dim=1)
            
            all_preds.append(preds.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
    
    all_preds = np.concatenate(all_preds)
    all_probs = np.concatenate(all_probs)
    
    # Calculate accuracy
    accuracy = (all_preds == test_labels).mean()
    
    print(f"\n{'='*50}")
    print("TEST SET RESULTS")
    print(f"{'='*50}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Correct: {(all_preds == test_labels).sum()}/{len(test_labels)}")
    
    return all_preds, all_probs, accuracy

if trained_model is not None:
    predictions, probabilities, test_accuracy = evaluate_on_test_set(trained_model, data_path)
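
The evaluation above reduces the per-timestep outputs with max pooling over time before applying softmax. The toy example below (hand-picked numbers, two classes, three timesteps) sketches that reduction in NumPy and contrasts it with reading only the final timestep, to show that the choice of pooling can change the predicted class.

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax: subtract the max before exponentiating."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Toy outputs with shape [batch=1, T=3, classes=2]
outputs = np.array([[[0.0, 1.0],
                     [3.0, 0.0],
                     [0.0, 2.0]]])

pooled_max = outputs.max(axis=1)     # per-class maximum over time
pooled_last = outputs[:, -1, :]      # final timestep only

pred_max = softmax(pooled_max).argmax(axis=1)
pred_last = softmax(pooled_last).argmax(axis=1)
print(pred_max, pred_last)
```

Whichever reduction is used at evaluation time should match the one used during training, otherwise reported accuracy will not reflect what the model was optimized for.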

## 6. Visualize Predictions

In [None]:
def visualize_predictions(model, data_path, n_samples=20):
    """Visualize model predictions on random test samples."""
    
    if model is None:
        print("No model loaded.")
        return
    
    # Load test data
    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])
    
    # Random sample
    np.random.seed(42)
    indices = np.random.choice(len(test_data), n_samples, replace=False)
    
    samples = test_data[indices]
    labels = test_labels[indices]
    
    # Get predictions
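    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        # Keep inputs on the same device as the model
        x = torch.tensor(samples, dtype=torch.float32).to(device)
        output, _ = model(x)
        
        # Max pooling over time, matching evaluate_on_test_set
        if output.dim() == 3:
            pooled = output.max(dim=1)[0]
        else:
            pooled = output
        
        probs = torch.softmax(pooled, dim=1)
        # Move to CPU before converting to NumPy
        preds = torch.argmax(probs, dim=1).cpu().numpy()
        confidence = probs.max(dim=1)[0].cpu().numpy()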
    
    # Plot
    n_cols = 5
    n_rows = (n_samples + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.5*n_cols, 3*n_rows))
    axes = axes.flatten()
    
    fig.suptitle('MNIST Predictions (SOEN Model)', fontsize=14, fontweight='bold')
    
    for i in range(n_samples):
        ax = axes[i]
        ax.imshow(samples[i], cmap='gray')
        
        is_correct = preds[i] == labels[i]
        color = 'green' if is_correct else 'red'
        symbol = '✓' if is_correct else '✗'
        
        ax.set_title(
            f"{symbol} Pred: {preds[i]} ({confidence[i]:.0%})\nTrue: {labels[i]}",
            fontsize=9,
            color=color,
            fontweight='bold' if not is_correct else 'normal'
        )
        ax.axis('off')
    
    # Hide empty subplots
    for i in range(n_samples, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Summary
    accuracy = (preds == labels).mean()
    print(f"\nSample accuracy: {accuracy:.1%} ({(preds == labels).sum()}/{n_samples})")

if trained_model is not None:
    visualize_predictions(trained_model, data_path)

## 7. Confusion Matrix

In [None]:
def plot_confusion_matrix(predictions, true_labels):
    """Plot confusion matrix for predictions."""
    
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    
    cm = confusion_matrix(true_labels, predictions)
    
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=range(10), yticklabels=range(10))
    ax.set_xlabel('Predicted', fontsize=12)
    ax.set_ylabel('True', fontsize=12)
    ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Per-class accuracy
    print("\nPer-class accuracy:")
    for digit in range(10):
        mask = true_labels == digit
        class_acc = (predictions[mask] == digit).mean()
        print(f"  Digit {digit}: {class_acc:.1%}")

if trained_model is not None:
    with h5py.File(data_path, 'r') as f:
        test_labels = np.array(f['test']['labels'])
    plot_confusion_matrix(predictions, test_labels)

## 8. Analyze SOEN Dynamics

Let's visualize how the SingleDendrite layer processes the input over time.

In [None]:
def visualize_soen_dynamics(model, data_path, sample_idx=0):
    """Visualize the temporal dynamics of the SOEN hidden layer."""
    
    if model is None:
        print("No model loaded.")
        return
    
    # Load a sample
    with h5py.File(data_path, 'r') as f:
        sample = np.array(f['test']['data'][sample_idx:sample_idx+1])
        label = int(f['test']['labels'][sample_idx])  # plain int for indexing and titles
    
    model.eval()
    
    with torch.no_grad():
        x = torch.tensor(sample, dtype=torch.float32)
        output, all_states = model(x)
        
        # Get hidden layer states (Layer 1)
        hidden_states = all_states[1]  # [batch, seq_len+1, hidden_dim]
        hidden_states = hidden_states[0, 1:, :].numpy()  # Remove batch and initial state
        
        # Get output for the single sample (drop the batch dimension)
        output_over_time = output[0].numpy()
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(f'SOEN Dynamics for Digit {label}', fontsize=14, fontweight='bold')
    
    # 1. Input image
    axes[0, 0].imshow(sample[0], cmap='gray')
    axes[0, 0].set_title('Input Image', fontsize=12)
    axes[0, 0].axis('off')
    
    # 2. Hidden layer activity over time
    im = axes[0, 1].imshow(hidden_states.T, aspect='auto', cmap='viridis')
    axes[0, 1].set_xlabel('Time (row)', fontsize=10)
    axes[0, 1].set_ylabel('Hidden Neuron', fontsize=10)
    axes[0, 1].set_title('Hidden Layer Activity (128 neurons)', fontsize=12)
    plt.colorbar(im, ax=axes[0, 1], label='Activation')
    
    # 3. Mean hidden activity over time
    mean_activity = hidden_states.mean(axis=1)
    std_activity = hidden_states.std(axis=1)
    time_steps = np.arange(len(mean_activity))
    axes[1, 0].plot(time_steps, mean_activity, 'b-', linewidth=2, label='Mean')
    axes[1, 0].fill_between(time_steps, 
                            mean_activity - std_activity,
                            mean_activity + std_activity,
                            alpha=0.3, label='±1 std')
    axes[1, 0].set_xlabel('Time (row)', fontsize=10)
    axes[1, 0].set_ylabel('Activation', fontsize=10)
    axes[1, 0].set_title('Mean Hidden Layer Activity', fontsize=12)
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
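    # 4. Output class probabilities at final timestep
    # Note: evaluate_on_test_set uses max pooling over time, so the two can disagree
    if output_over_time.ndim == 2:
        final_output = output_over_time[-1]
    else:
        final_output = output_over_time
    # Numerically stable softmax (subtract the max before exponentiating)
    exp_out = np.exp(final_output - final_output.max())
    probs = exp_out / exp_out.sum()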
    
    bars = axes[1, 1].bar(range(10), probs, color='steelblue', edgecolor='black')
    bars[label].set_color('green')  # Highlight true class
    axes[1, 1].set_xlabel('Digit Class', fontsize=10)
    axes[1, 1].set_ylabel('Probability', fontsize=10)
    axes[1, 1].set_title(f'Output Probabilities (True: {label})', fontsize=12)
    axes[1, 1].set_xticks(range(10))
    
    plt.tight_layout()
    plt.show()

if trained_model is not None:
    # Visualize dynamics for a few samples
    for idx in [0, 5, 10]:
        visualize_soen_dynamics(trained_model, data_path, sample_idx=idx)

## Summary

In this tutorial, we:

1. **Prepared MNIST** as a sequential dataset (28 timesteps × 28 features)
2. **Built a SOEN model** with:
   - Input layer (28D)
   - SingleDendrite hidden layer (128D) with recurrent connections
   - Linear output layer (10D) with **learnable** connections
3. **Trained the model** using the SOEN training pipeline
4. **Evaluated and visualized** the results

### Key Takeaways

- SOEN networks naturally process sequential data through temporal integration
- The SingleDendrite layer accumulates evidence over time (each row of the image)
- **Learnable output connections** are critical for effective training (from Tutorial 02)
- The recurrent connections allow the network to maintain context across timesteps
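
The idea of accumulating evidence over time can be illustrated with a generic leaky integrator. This is not the toolkit's SingleDendrite implementation: the update rule, the `tanh` nonlinearity, and the default values are assumptions chosen for illustration, and the parameter names are borrowed from the temporal parameters mentioned in this tutorial.

```python
import math

def leaky_integrate(inputs, gamma_plus=1.0, gamma_minus=0.1, dt=0.1):
    """Forward-Euler steps of ds/dt = gamma_plus * tanh(x) - gamma_minus * s.

    Illustrative stand-in only; the real layer dynamics may differ.
    """
    s = 0.0
    trace = []
    for x in inputs:
        s = s + dt * (gamma_plus * math.tanh(x) - gamma_minus * s)
        trace.append(s)
    return trace

# Five steps of constant drive, then five steps of silence
trace = leaky_integrate([1.0] * 5 + [0.0] * 5)
print([round(v, 4) for v in trace])
```

In this sketch the state rises while input is present and decays slowly afterwards, so information from early rows can still influence the state at later rows.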

### Next Steps

- Try different hidden layer sizes (64, 256, 512)
- Experiment with multiple hidden layers
- Adjust the temporal parameters (gamma_plus, gamma_minus, dt)
- Compare with traditional RNN/LSTM baselines
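
For the baseline comparison, one possible starting point is a vanilla RNN with the same layer sizes and the same max-over-time readout used in the evaluation code. The class name `RowRNNBaseline` is hypothetical, and the model is a conventional PyTorch network, not a SOEN model.

```python
import torch
import torch.nn as nn

class RowRNNBaseline(nn.Module):
    """Hypothetical baseline: vanilla RNN over image rows, max-pooled over time."""

    def __init__(self, input_dim=28, hidden_dim=128, num_classes=10):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.readout = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        hidden, _ = self.rnn(x)             # [batch, T, hidden_dim]
        logits_t = self.readout(hidden)     # [batch, T, num_classes]
        return logits_t.max(dim=1)[0]       # max pooling over time

baseline = RowRNNBaseline()
logits = baseline(torch.randn(2, 28, 28))   # [batch=2, T=28, D=28]
n_params = sum(p.numel() for p in baseline.parameters())
print(tuple(logits.shape), n_params)
```

Swapping `nn.RNN` for `nn.LSTM` gives an LSTM baseline with the same interface; for a fair comparison, the same train/val/test split and normalization should be used for all models.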