# Tutorial 03 — MNIST with Gated SOEN (MultiplierNOCC)

**Gated architecture** using MultiplierNOCC for selective memory.

## Architecture

```
Input (7) → Hidden (28, MultiplierNOCC) → Output (10)
                    ↺ recurrent
```

## Why MultiplierNOCC Gating?

MultiplierNOCC (No Collection Coils) has:
- **Dual SQUID states** (s1, s2): Can selectively amplify or suppress signals
- **Aggregated output** (m): Gated combination of states
- **Learnable phi_y**: Secondary input that acts as a gate control

This allows the network to:
1. **Hold important patterns** across 112 timesteps
2. **Suppress noise** or irrelevant inputs
3. **Selectively integrate** temporal information

## Hardware Compatibility

| Layer | Neurons | Inputs/Neuron |
|-------|---------|---------------|
| Input | 7 | - |
| Hidden (MultiplierNOCC) | 28 | 7 ✓ |
| Output | 10 | 28 |

In [None]:
import os
# NOTE(review): presumably picked up by tqdm as overrides for its `disable` /
# `mininterval` arguments (env-var overrides exist in recent tqdm releases) — confirm.
os.environ["TQDM_DISABLE"] = "0"
os.environ["TQDM_MININTERVAL"] = "1"

import sys
from pathlib import Path

# Make the in-repo `soen_toolkit` importable: find the nearest ancestor of the
# working directory whose `src/` contains the package and put it first on
# sys.path.
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        src_dir = str(candidate)
        # Remove any earlier copy first so re-running this cell keeps exactly
        # one entry (still at the front) instead of piling up duplicates.
        sys.path[:] = [p for p in sys.path if p != src_dir]
        sys.path.insert(0, src_dir)
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import glob
import gzip
import urllib.request
import struct

try:
    from tqdm import tqdm
except ImportError:
    # Minimal stand-in so `tqdm(iterable, ...)` calls still work without tqdm.
    def tqdm(iterable, **kwargs):
        return iterable

# Permit reduced-precision internal math (e.g. TF32) for float32 matmuls where
# the hardware supports it.
torch.set_float32_matmul_precision('high')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

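## 1. Prepare Dataset (112×7)

Each 28×28 image is flattened row-major into 784 pixels and fed to the network as 112 timesteps of 7 pixels each (4 timesteps per image row).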

In [None]:
def download_mnist_file(filename, base_url="https://ossci-datasets.s3.amazonaws.com/mnist/",
                        data_dir="./data/mnist"):
    """Return the local path of an MNIST archive, downloading it if missing.

    Args:
        filename: Archive name, e.g. ``"train-images-idx3-ubyte.gz"``.
        base_url: URL prefix that ``filename`` is appended to.
        data_dir: Local cache directory (created if needed).

    Returns:
        ``Path`` to ``data_dir / filename``.

    The download is written to a ``.part`` file and only renamed into place
    once it completes. An interrupted download is therefore never mistaken for
    a cached file on the next call.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    filepath = data_dir / filename
    if not filepath.exists():
        print(f"Downloading {filename}...")
        tmp_path = filepath.with_name(filepath.name + ".part")
        try:
            urllib.request.urlretrieve(base_url + filename, tmp_path)
        except BaseException:
            # Includes KeyboardInterrupt: never leave a partial file behind.
            tmp_path.unlink(missing_ok=True)
            raise
        tmp_path.replace(filepath)
    return filepath

def read_mnist_images(filepath):
    """Read a gzipped IDX3 image file into a ``(num, rows, cols)`` uint8 array.

    Raises:
        ValueError: if the header magic is not the IDX3 unsigned-byte magic
            (2051), e.g. when a label file is passed by mistake, or if the
            payload length does not match the header dimensions.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(f"{filepath}: bad IDX image magic {magic} (expected 2051)")
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    expected = num * rows * cols
    if pixels.size != expected:
        raise ValueError(f"{filepath}: expected {expected} pixel bytes, found {pixels.size}")
    return pixels.reshape(num, rows, cols)

def read_mnist_labels(filepath):
    """Read a gzipped IDX1 label file into a 1-D uint8 array.

    Raises:
        ValueError: if the header magic is not the IDX1 unsigned-byte magic
            (2049), e.g. when an image file is passed by mistake, or if the
            number of labels does not match the header count.
    """
    with gzip.open(filepath, 'rb') as f:
        magic, num = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(f"{filepath}: bad IDX label magic {magic} (expected 2049)")
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != num:
        raise ValueError(f"{filepath}: expected {num} labels, found {labels.size}")
    return labels

def _load_mnist_split(images_file, labels_file):
    """Download (if needed) and decode one MNIST split as 112×7 sequences.

    Returns:
        ``(images, labels)``. ``images`` is float32 in [0, 1] with shape
        ``(num, 112, 7)``; ``labels`` is int64 with shape ``(num,)``.
    """
    images = read_mnist_images(download_mnist_file(images_file)).astype(np.float32) / 255.0
    labels = read_mnist_labels(download_mnist_file(labels_file)).astype(np.int64)
    # Row-major (C-order) regrouping of each 28×28 image: its 784 pixels become
    # 112 timesteps of 7 pixels, i.e. 4 timesteps per image row.
    return images.reshape(-1, 112, 7), labels


def prepare_mnist_112x7(output_path="training/datasets/mnist_seq112x7.hdf5", seed=42, val_fraction=0.1):
    """Build the 112×7 sequential-MNIST HDF5 file, or reuse an existing one.

    The file holds groups ``train``, ``val`` and ``test``. Each group has a
    ``data`` dataset of shape (N, 112, 7), float32, and a ``labels`` dataset
    of shape (N,), int64. ``val`` is a random ``val_fraction`` slice of the
    official training set; ``test`` is the official test set.

    Args:
        output_path: Destination HDF5 path; its parent directory is created.
        seed: Seed for the train/val split (42 reproduces the original split).
        val_fraction: Fraction of the training images held out for validation.

    Returns:
        ``Path`` to the dataset file.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if output_path.exists():
        print(f"Dataset exists: {output_path}")
        with h5py.File(output_path, 'r') as f:
            print(f"  Shape: {f['train']['data'].shape}")
        return output_path

    print("Preparing MNIST (112×7 format)...")

    train_img, train_lbl = _load_mnist_split("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz")
    test_img, test_lbl = _load_mnist_split("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz")

    # A private RandomState draws the same permutation as the former
    # np.random.seed(seed) + np.random.permutation pair. It does so without
    # resetting the global NumPy RNG for the rest of the notebook.
    idx = np.random.RandomState(seed).permutation(len(train_img))
    n_val = int(len(train_img) * val_fraction)

    val_img, val_lbl = train_img[idx[:n_val]], train_lbl[idx[:n_val]]
    train_img, train_lbl = train_img[idx[n_val:]], train_lbl[idx[n_val:]]

    # Write to a temporary file and rename on success. Otherwise the existence
    # check above would accept a half-written file after an interruption.
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with h5py.File(tmp_path, 'w') as f:
            for name, data, labels in [('train', train_img, train_lbl),
                                       ('val', val_img, val_lbl),
                                       ('test', test_img, test_lbl)]:
                g = f.create_group(name)
                g.create_dataset('data', data=data)
                g.create_dataset('labels', data=labels)
        tmp_path.replace(output_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

    print("Done!")
    return output_path

# Build (or reuse) the 112×7 HDF5 dataset; the evaluation cells below read the
# test split from this path.
data_path = prepare_mnist_112x7()

## 2. Explain MultiplierNOCC Gating

In [None]:
def explain_gating():
    """Show a two-panel figure contrasting SingleDendrite with MultiplierNOCC.

    The left panel is a monospace text comparison of the two neuron types. The
    right panel sketches the gating path: input φ feeds the dual states
    s1/s2, and the gate φ_y feeds the aggregated output m. A short list of the
    intended benefits is printed afterwards. Returns nothing.
    """
    _, (text_ax, diagram_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # ---- Left panel: textual comparison ---------------------------------
    summary_text = """
    SingleDendrite (simple):
    ─────────────────────────
    • One state variable (s)
    • Leaky integration: ds/dt = γ⁺g(φ) - γ⁻s
    • Good for basic temporal processing
    
    MultiplierNOCC (gating):
    ─────────────────────────
    • Dual SQUID states (s1, s2)
    • Secondary input φ_y (learnable gate)
    • Aggregated output m = f(s1, s2, φ_y)
    • Can selectively amplify/suppress
    • Better for long-range dependencies
    """
    text_ax.text(0.5, 0.95, 'SingleDendrite vs MultiplierNOCC', ha='center',
                 fontsize=14, fontweight='bold', transform=text_ax.transAxes)
    text_ax.text(0.1, 0.8, summary_text, fontsize=10, family='monospace',
                 transform=text_ax.transAxes, va='top')
    text_ax.axis('off')

    # ---- Right panel: gating schematic ----------------------------------
    def arrow(tip, tail, color, width):
        """Draw an unlabelled arrow from ``tail`` to ``tip`` on the diagram."""
        diagram_ax.annotate('', xy=tip, xytext=tail,
                            arrowprops={'arrowstyle': '->', 'color': color, 'lw': width})

    diagram_ax.add_patch(plt.Rectangle((0.1, 0.3), 0.3, 0.4, fill=False, ec='blue', lw=2))
    diagram_ax.text(0.25, 0.5, 's1, s2\n(dual states)', ha='center', va='center', fontsize=10)
    diagram_ax.add_patch(plt.Circle((0.6, 0.5), 0.1, fill=False, ec='green', lw=2))
    diagram_ax.text(0.6, 0.5, 'm', ha='center', va='center', fontsize=12, fontweight='bold')

    arrow((0.5, 0.5), (0.4, 0.5), 'black', 1.5)    # states box -> m
    arrow((0.25, 0.3), (0.25, 0.1), 'red', 1.5)    # input up into the states box
    diagram_ax.text(0.25, 0.05, 'φ (input)', ha='center', fontsize=10, color='red')
    arrow((0.6, 0.3), (0.6, 0.1), 'purple', 1.5)   # gate up toward m
    diagram_ax.text(0.6, 0.05, 'φ_y (gate)', ha='center', fontsize=10, color='purple')
    arrow((0.85, 0.5), (0.7, 0.5), 'green', 2)     # m -> output
    diagram_ax.text(0.9, 0.5, 'output', ha='left', va='center', fontsize=10, color='green')

    diagram_ax.set_xlim(0, 1)
    diagram_ax.set_ylim(0, 1)
    diagram_ax.set_title('MultiplierNOCC Gating Mechanism', fontsize=12, fontweight='bold')
    diagram_ax.axis('off')

    plt.tight_layout()
    plt.show()

    print("\nGating Benefits for 112-timestep sequences:")
    print("=" * 50)
    for benefit in (
        "• Selective memory: hold important features, forget noise",
        "• Learnable φ_y: network learns what to gate",
        "• Dual states: richer dynamics than single-state neurons",
        "• Better gradient flow: multiplicative gating helps backprop",
    ):
        print(benefit)

# Render the gating explainer figure and print its summary.
explain_gating()

## 3. Load Model

In [None]:
from soen_toolkit.core.model_yaml import build_model_from_yaml

# Build the untrained gated architecture from its YAML spec. `model_path` is
# reused later when loading trained weights.
model_path = Path("training/test_models/model_specs/MNIST_SOENSpec_112x7_gated.yaml")
model = build_model_from_yaml(model_path)

print("="*60)
print("GATED SOEN MODEL (MultiplierNOCC)")
print("="*60)

print("\nLayers:")
for lid, dim in model.layer_nodes.items():
    # NOTE(review): assumes layer ids double as indices into layers_config — confirm.
    layer_type = model.layers_config[lid].layer_type if lid < len(model.layers_config) else "Unknown"
    print(f"  Layer {lid}: {dim} neurons ({layer_type})")

print("\nConnections:")
for name, param in model.connections.items():
    print(f"  {name}: {list(param.shape)}")

print("\nLearnable gate parameter (phi_y):")
phi_y_params = [(name, param) for name, param in model.named_parameters() if 'phi_y' in name]
for name, param in phi_y_params:
    print(f"  {name}: shape={list(param.shape)}, requires_grad={param.requires_grad}")
if not phi_y_params:
    # Say so explicitly rather than leaving the section silently empty.
    print("  (no parameter with 'phi_y' in its name was found)")

# Smoke test: a random batch of 2 sequences in the dataset's 112×7 layout.
print("\nForward pass test...")
x = torch.randn(2, 112, 7)
with torch.no_grad():
    y, states = model(x)
print(f"  Input: {x.shape}")
print(f"  Output: {y.shape}")
print("  Success!")

## 4. Train

In [None]:
# Set before the trainer import in case the toolkit reads it at import time.
# NOTE(review): presumably disables the toolkit's own progress bar — confirm.
os.environ["SOEN_NO_PROGRESS_BAR"] = "1"

from soen_toolkit.training.trainers.experiment import run_from_config

_rule = "=" * 60
print(_rule, "TRAINING GATED MODEL (MultiplierNOCC)", _rule,
      "Hidden layer uses MultiplierNOCC for selective memory gating", _rule, sep="\n")

# Launch training from the experiment config, resolving its paths against the
# current working directory.
run_from_config("training/training_configs/mnist_soen_112x7_gated.yaml", script_dir=Path.cwd())

## 5. Evaluate

In [None]:
def load_checkpoint():
    """Load the most recently written gated-model checkpoint, if any.

    Searches ``training/temp`` recursively for ``*.ckpt`` files. Paths that
    mention ``gated`` or ``nocc`` are preferred; otherwise any checkpoint is
    used. The newest file by modification time is chosen. The architecture is
    rebuilt from the global ``model_path`` YAML spec and the weights are
    loaded on CPU.

    Returns:
        The model in eval mode, or ``None`` when no checkpoint exists.
    """
    patterns = ["training/temp/**/checkpoints/**/*.ckpt", "training/temp/**/*.ckpt"]
    # The second pattern also matches every file the first one does; collect
    # into a set to drop duplicates.
    ckpts = sorted({c for p in patterns for c in glob.glob(p, recursive=True)})

    gated_ckpts = [c for c in ckpts if 'gated' in c.lower() or 'nocc' in c.lower()] or ckpts
    if not gated_ckpts:
        print("No checkpoint found.")
        return None

    latest = max(gated_ckpts, key=lambda c: Path(c).stat().st_mtime)
    print(f"Loading: {latest}")

    model = build_model_from_yaml(model_path)
    ckpt = torch.load(latest, map_location='cpu')
    # Use the nested 'state_dict' entry when present; otherwise treat the whole
    # file as a bare state dict.
    state = ckpt.get('state_dict', ckpt)
    prefix = 'model.'
    # Strip the wrapper prefix so keys match the bare model's parameter names.
    clean = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state.items()}
    result = model.load_state_dict(clean, strict=False)

    # strict=False tolerates extra checkpoint entries. It also silently leaves
    # parameters at their initial values when names don't line up, so surface
    # any mismatch instead of evaluating an untrained model unawares.
    # (Assumes torch.nn.Module's return value with missing/unexpected keys.)
    missing = list(getattr(result, 'missing_keys', []))
    unexpected = list(getattr(result, 'unexpected_keys', []))
    if missing:
        print(f"Warning: {len(missing)} model key(s) not found in checkpoint, e.g. {missing[:3]}")
    if unexpected:
        print(f"Warning: {len(unexpected)} checkpoint key(s) not used by the model, e.g. {unexpected[:3]}")

    model.eval()
    return model

# None when no checkpoint was found; the evaluation cells below check this
# before running.
trained_model = load_checkpoint()

In [None]:
def evaluate(model, data_path, batch_size=128):
    """Compute and print top-1 accuracy of ``model`` on the test split.

    Args:
        model: Model whose call returns ``(outputs, states)``; ``None`` is a no-op.
        data_path: HDF5 file with a ``test`` group holding ``data`` and ``labels``.
        batch_size: Number of test sequences per forward pass.

    Returns:
        Accuracy as a float in [0, 1]. Returns ``None`` if ``model`` is
        ``None`` or the test split is empty.
    """
    if model is None:
        return None

    with h5py.File(data_path, 'r') as f:
        test_data = np.array(f['test']['data'])
        test_labels = np.array(f['test']['labels'])

    if len(test_data) == 0:
        # np.concatenate([]) below would otherwise raise.
        print("Test split is empty; nothing to evaluate.")
        return None

    model.eval()
    all_preds = []

    with torch.no_grad():
        for start in tqdm(range(0, len(test_data), batch_size)):
            x = torch.tensor(test_data[start:start + batch_size], dtype=torch.float32)
            y, _ = model(x)
            # A 3-D output is averaged over dim 1 before the argmax.
            # NOTE(review): assumes (batch, time, classes) — confirm with the model spec.
            if y.dim() == 3:
                y = y.mean(dim=1)
            all_preds.append(y.argmax(dim=1).cpu().numpy())

    preds = np.concatenate(all_preds)
    acc = float((preds == test_labels).mean())

    print(f"\nTest Accuracy: {acc:.4f} ({acc*100:.2f}%)")
    return acc

# Explicit None check: skip evaluation only when no checkpoint was loaded,
# not whenever the model object happens to evaluate as falsy.
if trained_model is not None:
    evaluate(trained_model, data_path)

## 6. Visualize Predictions

In [None]:
def visualize_predictions(model, data_path, n=20, ncols=5, seed=42):
    """Plot ``n`` random test digits with the model's prediction and confidence.

    Each title shows a check or cross, the predicted class with its softmax
    probability, and the true label. Titles are green when correct and red
    otherwise.

    Args:
        model: Model whose call returns ``(outputs, states)``; ``None`` is a no-op.
        data_path: HDF5 file with a ``test`` group holding ``data`` and ``labels``.
        n: Number of test samples to show (capped at the test-set size).
        ncols: Grid columns; rows are added as needed.
        seed: Seed for sample selection (42 reproduces the original picks).
    """
    if model is None:
        return

    with h5py.File(data_path, 'r') as f:
        data = np.array(f['test']['data'])
        labels = np.array(f['test']['labels'])

    n = min(n, len(data))
    if n <= 0:
        print("No samples to visualize.")
        return

    # A private RandomState gives the same draw as np.random.seed(seed) +
    # np.random.choice, without resetting the global NumPy RNG.
    idx = np.random.RandomState(seed).choice(len(data), n, replace=False)
    samples, true = data[idx], labels[idx]
    # Undo the 112×7 regrouping to recover each 28×28 image (C order).
    images = samples.reshape(n, 28, 28)

    model.eval()
    with torch.no_grad():
        x = torch.tensor(samples, dtype=torch.float32)
        y, _ = model(x)
        if y.dim() == 3:
            y = y.mean(dim=1)
        probs = torch.softmax(y, dim=1)
        preds = probs.argmax(dim=1).numpy()
        conf = probs.max(dim=1)[0].numpy()

    # Size the grid from n: the previous fixed 4×5 grid raised IndexError for
    # n > 20 and left empty framed cells for n < 20.
    ncols = max(1, min(ncols, n))
    nrows = -(-n // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 2.5 * nrows), squeeze=False)
    axes = axes.flatten()
    fig.suptitle('Predictions (Gated SOEN - MultiplierNOCC)', fontsize=14, fontweight='bold')

    for i in range(n):
        axes[i].imshow(images[i], cmap='gray')
        ok = preds[i] == true[i]
        axes[i].set_title(f"{'✓' if ok else '✗'} {preds[i]} ({conf[i]:.0%})\nTrue: {true[i]}",
                          color='green' if ok else 'red', fontsize=9)
        axes[i].axis('off')
    for ax in axes[n:]:
        ax.axis('off')  # hide unused grid cells

    plt.tight_layout()
    plt.show()

    print(f"Sample accuracy: {(preds == true).mean():.1%}")

# Explicit None check: skip plotting only when no checkpoint was loaded,
# not whenever the model object happens to evaluate as falsy.
if trained_model is not None:
    visualize_predictions(trained_model, data_path)

## Summary

| Aspect | SingleDendrite | MultiplierNOCC |
|--------|----------------|----------------|
| State variables | 1 (s) | 3 (s1, s2, m) |
| Gating | None | Via φ_y |
| Memory | Leaky integration | Selective memory |
| Parameters | γ⁺, γ⁻ | α, β, β_out, φ_y |

**Gating can help with long sequences (112 timesteps) by:**
1. Selectively retaining important patterns
2. Suppressing noise and irrelevant inputs
3. Learning what to remember via φ_y