# Surrogate Gradient SNN Training with Complete Energy Tracking

This notebook trains a two-layer spiking neural network (SNN) on MNIST using surrogate gradients, and saves **all metrics needed for energy analysis**:

## What Gets Saved

| Category | Metrics | Used For |
|----------|---------|----------|
| **Training Spikes** | total_input, total_hidden, total_output | Training energy analysis |
| **Inference Spikes** | input/hidden/output per inference | Inference energy analysis |
| **Connectivity** | fanout, synapse counts | Synaptic event calculation |
| **Performance** | accuracy, training time | Efficiency comparisons |

## Workflow
1. Run all cells in order
2. Training automatically tracks all spikes
3. After training, the best weights are reloaded and inference spikes are measured on the first `N_INFERENCE_SAMPLES` (4,096) test images
4. Complete checkpoint saved with all energy-relevant data

In [None]:
# =============================================================================
# Cell 1: Imports and Setup
# =============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
from datetime import datetime
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Seed both random sources used in this notebook: NumPy drives the Poisson
# spike encoder, PyTorch drives weight initialisation and batch shuffling.
np.random.seed(42)
torch.manual_seed(42)

print("\nSetup complete!")

In [None]:
# =============================================================================
# Cell 2: LIF Neuron with Spike Counting
# =============================================================================

class SimpleLIFNeuron(nn.Module):
    """
    Fully connected layer of leaky integrate-and-fire (LIF) units.

    Each call to ``forward`` advances the layer by one time step: the
    membrane potential decays, the weighted input is added, units at or
    above ``threshold`` emit a spike and are reset to zero. The backward
    pass uses a straight-through surrogate with a constant slope of 0.3.

    The layer also keeps a running total of emitted spikes
    (``total_spikes``) for energy analysis; counting can be switched off
    through ``counting_enabled``.
    """
    def __init__(self, input_size, output_size, threshold=1.0, decay=0.9):
        super().__init__()
        self.synapses = nn.Linear(input_size, output_size)
        self.threshold = threshold
        self.decay = decay

        # Lazily allocated on the first forward call (batch size unknown here).
        self.membrane_v = None

        # Energy bookkeeping.
        self.counting_enabled = True
        self.total_spikes = 0

    def forward(self, input_spikes):
        """Advance one time step and return the (batch, output_size) spikes."""
        n_samples = input_spikes.size(0)

        # (Re)create the state when there is none or the batch size changed.
        state_is_stale = (
            self.membrane_v is None or self.membrane_v.size(0) != n_samples
        )
        if state_is_stale:
            self.membrane_v = torch.zeros(
                n_samples, self.synapses.out_features,
                device=input_spikes.device
            )

        # Leak, then integrate the incoming synaptic current.
        potential = self.decay * self.membrane_v + self.synapses(input_spikes)
        self.membrane_v = potential

        # Forward value is the hard threshold; the second term is numerically
        # zero but gives d(spike)/d(potential) = 0.3 in the backward pass.
        fired = (potential >= self.threshold).float()
        spikes = fired + (potential - potential.detach()) * 0.3

        emitted = spikes.detach()
        if self.counting_enabled:
            self.total_spikes += emitted.sum().item()

        # Units that fired restart from zero; the reset itself carries no gradient.
        self.membrane_v = potential * (1 - emitted)

        return spikes

    def reset_state(self):
        """Drop the membrane potential so the next forward call starts from zero."""
        self.membrane_v = None

    def reset_spike_count(self):
        """Zero the accumulated spike counter."""
        self.total_spikes = 0


class SimpleSpikingNetwork(nn.Module):
    """
    Feed-forward spiking network: input spikes -> hidden LIF -> output LIF.

    Besides the forward computation the network keeps running totals of
    input, hidden and output spikes so that energy proxies can be derived
    afterwards.
    """
    def __init__(self, n_input=784, n_hidden=400, n_output=10,
                 threshold=1.0, decay=0.9):
        super().__init__()
        self.n_input, self.n_hidden, self.n_output = n_input, n_hidden, n_output

        # Both layers share the same threshold and decay.
        self.layer1 = SimpleLIFNeuron(n_input, n_hidden, threshold, decay)
        self.layer2 = SimpleLIFNeuron(n_hidden, n_output, threshold, decay)

        # Input spikes are tallied here; hidden/output spikes live in the layers.
        self.counting_enabled = True
        self.total_input_spikes = 0

    def forward(self, spike_sequence):
        """
        Run the whole spike sequence through the network.

        Args:
            spike_sequence: (batch, time_steps, n_input) tensor of input spikes

        Returns:
            (batch, n_output) tensor with the output spikes summed over time
        """
        n_samples, n_steps, _ = spike_sequence.shape

        # Every sequence starts from a clean membrane state.
        for lif_layer in (self.layer1, self.layer2):
            lif_layer.reset_state()

        accumulated = torch.zeros(
            n_samples, self.n_output,
            device=spike_sequence.device
        )

        for step in range(n_steps):
            frame = spike_sequence[:, step, :]

            if self.counting_enabled:
                self.total_input_spikes += frame.sum().item()

            accumulated += self.layer2(self.layer1(frame))

        return accumulated

    def reset_all_spike_counts(self):
        """Zero the input counter and the counters of both layers."""
        self.total_input_spikes = 0
        for lif_layer in (self.layer1, self.layer2):
            lif_layer.reset_spike_count()

    def set_counting(self, enabled):
        """Turn spike counting on or off for the network and both layers."""
        self.counting_enabled = enabled
        for lif_layer in (self.layer1, self.layer2):
            lif_layer.counting_enabled = enabled

    def get_spike_counts(self):
        """Return the accumulated spike totals keyed by layer."""
        return dict(
            input=self.total_input_spikes,
            hidden=self.layer1.total_spikes,
            output=self.layer2.total_spikes,
        )

    def get_architecture_info(self):
        """Return layer sizes, synapse counts, fan-outs and neuron parameters."""
        syn_in_hidden = self.n_input * self.n_hidden
        syn_hidden_out = self.n_hidden * self.n_output
        return {
            'n_input': self.n_input,
            'n_hidden': self.n_hidden,
            'n_output': self.n_output,
            'n_syn_input_hidden': syn_in_hidden,
            'n_syn_hidden_output': syn_hidden_out,
            'n_synapses_total': syn_in_hidden + syn_hidden_out,
            # Fully connected: each presynaptic spike reaches every unit
            # of the next layer.
            'fanout_input_to_hidden': self.n_hidden,
            'fanout_hidden_to_output': self.n_output,
            # Neuron parameters are read from the first layer.
            'threshold': self.layer1.threshold,
            'decay': self.layer1.decay,
        }


# Short confirmation that the cell ran, with a one-line summary per class.
print(
    "Network classes defined!",
    "  SimpleLIFNeuron: LIF with surrogate gradient + spike counting",
    "  SimpleSpikingNetwork: 2-layer SNN with full spike tracking",
    sep="\n",
)

In [None]:
# =============================================================================
# Cell 3: Data Loading and Encoding
# =============================================================================

def poisson_encoder(image, time_window, max_rate=80):
    """
    Encode one image as a Poisson-style spike train.

    Every pixel fires independently at each time step with probability
    ``pixel * max_rate / 1000`` (a rate in Hz sampled at 1 ms steps).
    Randomness comes from NumPy's global generator.

    Args:
        image: image tensor or array-like of any shape; it is flattened
        time_window: Number of time steps to generate
        max_rate: Firing rate in Hz for a pixel of intensity 1.0

    Returns:
        (time_window, n_pixels) float32 tensor containing 0.0 / 1.0
    """
    # Bring the pixels to a flat NumPy vector, whatever the input container.
    pixels = (
        image.cpu().reshape(-1).numpy()
        if torch.is_tensor(image)
        else np.array(image).reshape(-1)
    )

    # Per-step firing probability of each pixel.
    spike_probs = (pixels * max_rate) / 1000.0

    # One uniform draw per (time step, pixel); a spike occurs where the draw
    # falls below that pixel's probability.
    uniform_draws = np.random.rand(time_window, len(spike_probs))
    fired = uniform_draws < spike_probs
    return torch.from_numpy(fired.astype(np.float32))


def prepare_spike_batch(images, time_window, max_rate=80):
    """
    Convert a batch of images to Poisson spike trains in one vectorized step.

    Each pixel fires independently at every time step with probability
    ``pixel * max_rate / 1000``. All random numbers for the batch are drawn
    with a single call to NumPy's global generator; the draws are laid out
    image by image, so for a given seed the result matches encoding the
    images one after another.

    Args:
        images: (batch, C, H, W) image batch (tensor), or an iterable of
            per-image tensors / arrays that all have the same pixel count
        time_window: Number of time steps
        max_rate: Firing rate in Hz for a pixel of intensity 1.0

    Returns:
        spike_batch: (batch, time_window, n_pixels) float32 tensor of 0.0 / 1.0
    """
    # Flatten every image to a row of pixels: (batch, n_pixels).
    if torch.is_tensor(images):
        flat = images.detach().cpu().reshape(images.shape[0], -1).numpy()
    else:
        flat = np.stack([
            (img.detach().cpu().numpy() if torch.is_tensor(img)
             else np.asarray(img)).reshape(-1)
            for img in images
        ])

    # Per-step firing probability of each pixel.
    spike_probs = (flat * max_rate) / 1000.0

    # One uniform draw per (image, time step, pixel); the probabilities are
    # broadcast along the time axis.
    draws = np.random.rand(flat.shape[0], time_window, flat.shape[1])
    fired = draws < spike_probs[:, np.newaxis, :]
    return torch.from_numpy(fired.astype(np.float32))


# Fetch MNIST into ./data (downloaded on first use) as tensors.
# NOTE(review): the Poisson encoder multiplies pixel values by max_rate / 1000,
# so it presumably relies on ToTensor producing intensities in [0, 1] - confirm
# against the torchvision documentation.
transform = transforms.ToTensor()
train_data, test_data = (
    datasets.MNIST('./data', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

print(f"MNIST loaded: {len(train_data)} train, {len(test_data)} test")
print(f"Image shape: {train_data[0][0].shape}")

In [None]:
# =============================================================================
# Cell 4: Hyperparameters
# =============================================================================

# --- Network architecture (fully connected, input -> hidden -> output) ---
N_INPUT = 784
N_HIDDEN = 400
N_OUTPUT = 10

# --- LIF neuron parameters (shared by both layers) ---
THRESHOLD = 1.0
DECAY = 0.9

# --- Optimisation ---
BATCH_SIZE = 64
MAX_EPOCHS = 20
LEARNING_RATE = 0.003
WEIGHT_DECAY = 1e-5
PATIENCE = 5  # epochs without validation improvement before stopping

# --- Poisson input encoding ---
TIME_WINDOW = 50  # number of simulation steps (treated as 1 ms each)
MAX_RATE = 80     # Hz, firing rate of a pixel with intensity 1.0

# --- Number of test samples used to measure inference spikes ---
N_INFERENCE_SAMPLES = 4096

# --- Where the final checkpoint is written ---
CHECKPOINT_PATH = "surrogate_snn_checkpoint.pth"

# Echo the configuration so it is recorded in the notebook output.
print(
    "="*60,
    "HYPERPARAMETERS",
    "="*60,
    f"\nArchitecture: {N_INPUT} → {N_HIDDEN} → {N_OUTPUT}",
    f"Synapses: {N_INPUT * N_HIDDEN + N_HIDDEN * N_OUTPUT:,}",
    f"\nNeuron: threshold={THRESHOLD}, decay={DECAY}",
    f"Encoding: {TIME_WINDOW}ms window, {MAX_RATE}Hz max rate",
    f"\nTraining: {MAX_EPOCHS} epochs, batch={BATCH_SIZE}, LR={LEARNING_RATE}",
    f"Early stopping patience: {PATIENCE}",
    sep="\n",
)

In [None]:
# =============================================================================
# Cell 5: Training
# =============================================================================

# Create data loaders
# Training batches are reshuffled every epoch; the test loader keeps a fixed
# order (shuffle=False).
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Create network
snn = SimpleSpikingNetwork(
    n_input=N_INPUT,
    n_hidden=N_HIDDEN,
    n_output=N_OUTPUT,
    threshold=THRESHOLD,
    decay=DECAY
).to(device)

# Loss and optimizer
# The network returns output spikes summed over time; those counts are fed to
# CrossEntropyLoss directly as class scores.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(snn.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# mode='max' because the scheduler is stepped with validation accuracy below:
# the learning rate is halved when accuracy stops improving.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=2
)

# Training tracking
best_val_acc = 0
best_model_state = None  # CPU copy of the weights with the best validation accuracy
epochs_without_improvement = 0
training_history = {
    'losses': [],       # mean training loss per epoch
    'train_acc': [],    # training accuracy (%) per epoch
    'val_acc': [],      # validation accuracy (%) per epoch
    'epoch_times': [],  # wall-clock seconds per epoch (training + validation)
}

# Reset spike counters before training
# The counters are never reset inside the loop, so they accumulate over all
# epochs.
snn.reset_all_spike_counts()
snn.set_counting(True)

print("\n" + "="*60)
print("TRAINING (with spike tracking)")
print("="*60 + "\n")

training_start = time.time()
# Counts every sample presentation, i.e. it grows by the dataset size each epoch.
total_training_samples = 0

for epoch in range(MAX_EPOCHS):
    epoch_start = time.time()
    snn.train()

    epoch_loss = 0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        labels = labels.to(device)
        # Fresh Poisson spike trains are sampled for every batch, so the same
        # image is encoded differently each time it is seen.
        spike_trains = prepare_spike_batch(images, TIME_WINDOW, MAX_RATE).to(device)

        optimizer.zero_grad()
        outputs = snn(spike_trains)  # Spikes counted here!
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(snn.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += loss.item()
        # Predicted class = output neuron with the most spikes.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_training_samples += labels.size(0)

        if batch_idx % 200 == 0:
            print(f"  Epoch {epoch+1}, Batch {batch_idx}/{len(train_loader)}: "
                  f"Loss={loss.item():.4f}, Acc={100*correct/total:.1f}%")

    # Validation (don't count these spikes as training)
    # NOTE(review): validation runs on test_loader, i.e. the MNIST test split.
    # The same split drives the LR scheduler, early stopping and best-model
    # selection, and is later reported as "test accuracy" - consider holding
    # out part of the training set instead.
    snn.eval()
    snn.set_counting(False)

    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            labels = labels.to(device)
            spike_trains = prepare_spike_batch(images, TIME_WINDOW, MAX_RATE).to(device)
            outputs = snn(spike_trains)
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    snn.set_counting(True)  # Re-enable for next epoch

    train_acc = 100 * correct / total
    val_acc = 100 * val_correct / val_total
    epoch_time = time.time() - epoch_start

    training_history['losses'].append(epoch_loss / len(train_loader))
    training_history['train_acc'].append(train_acc)
    training_history['val_acc'].append(val_acc)
    training_history['epoch_times'].append(epoch_time)

    print(f"\n*** Epoch {epoch+1}: Train={train_acc:.1f}%, Val={val_acc:.1f}%, "
          f"Time={epoch_time:.1f}s ***\n")

    scheduler.step(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Clone to CPU so later optimizer steps cannot modify the snapshot.
        best_model_state = {k: v.cpu().clone() for k, v in snn.state_dict().items()}
        epochs_without_improvement = 0
        print(f"  ★ New best: {best_val_acc:.2f}%")
    else:
        epochs_without_improvement += 1

    # Stop when validation accuracy has not improved for PATIENCE epochs.
    if epochs_without_improvement >= PATIENCE:
        print(f"\nEarly stopping after {epoch+1} epochs")
        break

    # Stop as soon as the accuracy target is reached.
    if val_acc >= 98.0:
        print(f"\nReached 98%+, stopping")
        break

training_time = time.time() - training_start
epochs_trained = len(training_history['losses'])

# Get training spike counts
# These totals cover the forward passes of every training batch in every epoch
# (validation passes were excluded by disabling counting above).
training_spikes = snn.get_spike_counts()

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"  Total samples processed: {total_training_samples:,}")
print(f"  Total time: {training_time:.1f}s ({training_time/60:.1f} min)")
print(f"  Epochs: {epochs_trained}")
print(f"  Best validation accuracy: {best_val_acc:.2f}%")
print(f"\n  TRAINING SPIKES (accumulated during forward passes):")
print(f"    Input:  {training_spikes['input']:,.0f} "
      f"({training_spikes['input']/total_training_samples:.1f}/sample)")
print(f"    Hidden: {training_spikes['hidden']:,.0f} "
      f"({training_spikes['hidden']/total_training_samples:.1f}/sample)")
print(f"    Output: {training_spikes['output']:,.0f} "
      f"({training_spikes['output']/total_training_samples:.1f}/sample)")

In [None]:
# =============================================================================
# Cell 6: Load Best Model and Measure Inference Spikes
# =============================================================================

# Restore the weights that achieved the best validation accuracy and switch
# to evaluation mode.
snn.load_state_dict(best_model_state)
snn.eval()

# Start from clean counters so only the passes below are measured.
snn.reset_all_spike_counts()
snn.set_counting(True)

print("\n" + "="*60)
print(f"MEASURING INFERENCE SPIKES ({N_INFERENCE_SAMPLES} test samples)")
print("="*60 + "\n")

inference_samples_seen = 0
inference_start = time.time()

with torch.no_grad():
    for batch_number, (images, _) in enumerate(test_loader, start=1):
        remaining = N_INFERENCE_SAMPLES - inference_samples_seen
        if remaining <= 0:
            break

        # Trim the final batch so exactly N_INFERENCE_SAMPLES are measured
        # (slicing past the end leaves a full batch untouched).
        images = images[:remaining]

        spike_trains = prepare_spike_batch(images, TIME_WINDOW, MAX_RATE).to(device)
        snn(spike_trains)  # the forward pass updates the spike counters
        inference_samples_seen += images.size(0)

        # Progress report every 20 batches.
        if batch_number % 20 == 0:
            print(f"  [{inference_samples_seen}/{N_INFERENCE_SAMPLES}]")

inference_time = time.time() - inference_start
inference_spikes = snn.get_spike_counts()

print(f"\n  Measurement complete in {inference_time:.1f}s")
print(f"\n  INFERENCE SPIKES (per inference):")
for layer_label, layer_key in (("Input: ", "input"), ("Hidden:", "hidden"), ("Output:", "output")):
    print(f"    {layer_label} {inference_spikes[layer_key]/inference_samples_seen:.2f}")

In [None]:
# =============================================================================
# Cell 7: Final Evaluation and Visualization
# =============================================================================

# Final evaluation on the full test set. Spike counting is switched off so the
# inference-spike totals measured earlier are not disturbed.
snn.set_counting(False)
snn.eval()

print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60 + "\n")

all_predictions = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        # Labels are only compared on the CPU below, so they are not moved
        # to the device.
        spike_trains = prepare_spike_batch(images, TIME_WINDOW, MAX_RATE).to(device)
        outputs = snn(spike_trains)
        # Predicted class = output neuron with the most spikes.
        _, predicted = torch.max(outputs, 1)

        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)

test_accuracy = 100 * np.mean(all_predictions == all_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Per-class accuracy
print(f"\nPer-class accuracy:")
for digit in range(10):
    mask = all_labels == digit
    if not mask.any():
        # Avoid the NaN (and RuntimeWarning) from averaging an empty selection.
        print(f"  Digit {digit}: N/A (no samples)")
        continue
    class_acc = 100 * np.mean(all_predictions[mask] == digit)
    print(f"  Digit {digit}: {class_acc:.1f}%")

# Confusion matrix
# Passing labels explicitly keeps the matrix 10x10 (rows = true digit,
# columns = predicted digit) even if some digit never occurs, so it always
# lines up with the fixed 0-9 tick labels below.
cm = confusion_matrix(all_labels, all_predictions, labels=list(range(10)))

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=range(10), yticklabels=range(10))
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('True', fontsize=12)
plt.title(f'Confusion Matrix (Accuracy: {test_accuracy:.2f}%)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# =============================================================================
# Cell 8: Training Curves
# =============================================================================

# Two panels side by side: training loss (left) and accuracies (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, acc_ax = axes

# Left panel: mean training loss per epoch.
loss_ax.plot(training_history['losses'], 'b-', linewidth=2)
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')

# Right panel: training vs. validation accuracy per epoch.
for history_key, line_style, curve_label in (
    ('train_acc', 'b-', 'Train'),
    ('val_acc', 'r-', 'Validation'),
):
    acc_ax.plot(training_history[history_key], line_style, label=curve_label, linewidth=2)
acc_ax.set_title('Training & Validation Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.legend()

for panel in (loss_ax, acc_ax):
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Headline numbers below the figure.
print(f"\nFinal train accuracy: {training_history['train_acc'][-1]:.2f}%")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"Final test accuracy: {test_accuracy:.2f}%")

In [None]:
# =============================================================================
# Cell 9: Save Complete Checkpoint
# =============================================================================

print("\n" + "="*60)
print("SAVING COMPLETE CHECKPOINT")
print("="*60 + "\n")

# Static description of the network (sizes, synapse counts, fan-outs).
arch_info = snn.get_architecture_info()
n_synapses = arch_info['n_synapses_total']
fanout_in_h = arch_info['fanout_input_to_hidden']
fanout_h_out = arch_info['fanout_hidden_to_output']

# Average spikes per measured inference, computed once and reused below.
input_per_inference = inference_spikes['input'] / inference_samples_seen
hidden_per_inference = inference_spikes['hidden'] / inference_samples_seen
output_per_inference = inference_spikes['output'] / inference_samples_seen

# Average spikes per training sample presentation.
train_input_per_sample = training_spikes['input'] / total_training_samples
train_hidden_per_sample = training_spikes['hidden'] / total_training_samples
train_output_per_sample = training_spikes['output'] / total_training_samples

# Energy proxies.
# Dense baseline: every synapse is evaluated at every time step.
dense_ops_per_inference = n_synapses * TIME_WINDOW
# Event-driven: a synapse is only touched when its presynaptic neuron spikes,
# so each spike costs one event per outgoing connection.
event_syn_per_inference = (
    input_per_inference * fanout_in_h +
    hidden_per_inference * fanout_h_out
)
sparsity_ratio = event_syn_per_inference / dense_ops_per_inference

# Assemble the checkpoint section by section.
encoding_section = {
    "type": "poisson",
    "time_window": TIME_WINDOW,
    "max_rate_hz": MAX_RATE,
    "spike_prob_formula": "p = (pixel * max_rate) / 1000 per timestep",
}

training_config_section = {
    "batch_size": BATCH_SIZE,
    "max_epochs": MAX_EPOCHS,
    "learning_rate": LEARNING_RATE,
    "weight_decay": WEIGHT_DECAY,
    "optimizer": "Adam",
    "patience": PATIENCE,
    "grad_clip_max_norm": 1.0,
}

training_results_section = {
    "epochs_trained": epochs_trained,
    "total_samples": total_training_samples,
    "training_time_seconds": training_time,
    "best_val_accuracy": best_val_acc,
    "final_train_accuracy": training_history['train_acc'][-1],
}

# Spike totals accumulated over all training forward passes.
training_spikes_section = {
    "total_input": training_spikes['input'],
    "total_hidden": training_spikes['hidden'],
    "total_output": training_spikes['output'],
    "input_per_sample": train_input_per_sample,
    "hidden_per_sample": train_hidden_per_sample,
    "output_per_sample": train_output_per_sample,
    "total_samples": total_training_samples,
}

# Spike totals from the dedicated inference measurement.
inference_spikes_section = {
    "n_samples": inference_samples_seen,
    "total_input": inference_spikes['input'],
    "total_hidden": inference_spikes['hidden'],
    "total_output": inference_spikes['output'],
    "input_per_inference": input_per_inference,
    "hidden_per_inference": hidden_per_inference,
    "output_per_inference": output_per_inference,
    "measurement_time_seconds": inference_time,
}

connectivity_section = {
    "n_syn_input_hidden": arch_info['n_syn_input_hidden'],
    "n_syn_hidden_output": arch_info['n_syn_hidden_output'],
    "n_synapses_total": n_synapses,
    "fanout_input_to_hidden": fanout_in_h,
    "fanout_hidden_to_output": fanout_h_out,
}

energy_proxies_section = {
    "dense_ops_per_inference": dense_ops_per_inference,
    "event_syn_per_inference": event_syn_per_inference,
    "sparsity_ratio": sparsity_ratio,
    "formula_event_syn": "S_in * fanout_in_h + S_hidden * fanout_h_out",
    "formula_dense": "N_synapses * TIME_WINDOW",
}

checkpoint = {
    # Metadata
    "schema_version": "surrogate_complete_v1",
    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "model_class": "SimpleSpikingNetwork",

    # Weights of the best-validation epoch
    "model_state_dict": best_model_state,

    # Architecture (sizes are duplicated at the top level)
    "architecture": arch_info,
    "n_input": N_INPUT,
    "n_hidden": N_HIDDEN,
    "n_output": N_OUTPUT,

    # Protocol, configuration and results
    "encoding": encoding_section,
    "training_config": training_config_section,
    "training_results": training_results_section,
    "training_history": training_history,
    "test_accuracy": test_accuracy,

    # Energy analysis data
    "training_spikes": training_spikes_section,
    "inference_spikes": inference_spikes_section,
    "connectivity": connectivity_section,
    "energy_proxies": energy_proxies_section,

    # Optimizer state as of the end of training.
    # NOTE(review): this is the state after the last epoch, whereas
    # model_state_dict holds the best-epoch weights - keep that in mind
    # when resuming.
    "optimizer_state_dict": optimizer.state_dict(),
}

# Write everything to disk in one file.
torch.save(checkpoint, CHECKPOINT_PATH)

print(f"Checkpoint saved: {CHECKPOINT_PATH}")
print(f"\nSchema: {checkpoint['schema_version']}")
print("\n" + "-"*60)
print("CHECKPOINT CONTENTS SUMMARY")
print("-"*60)

print("\n[ARCHITECTURE]")
print(f"  {N_INPUT} → {N_HIDDEN} → {N_OUTPUT}")
print(f"  Synapses: {n_synapses:,}")
print(f"  Time window: {TIME_WINDOW} steps")

print("\n[TRAINING]")
print(f"  Samples: {total_training_samples:,}")
print(f"  Epochs: {epochs_trained}")
print(f"  Time: {training_time:.1f}s")

print("\n[PERFORMANCE]")
print(f"  Best val accuracy: {best_val_acc:.2f}%")
print(f"  Test accuracy: {test_accuracy:.2f}%")

print("\n[TRAINING SPIKES - for Training Energy]")
print(f"  Total input:  {training_spikes['input']:,.0f}")
print(f"  Total hidden: {training_spikes['hidden']:,.0f}")
print(f"  Total output: {training_spikes['output']:,.0f}")
print(f"  Input/sample:  {train_input_per_sample:.1f}")
print(f"  Hidden/sample: {train_hidden_per_sample:.1f}")

print("\n[INFERENCE SPIKES - for Inference Energy]")
print(f"  Measured on: {inference_samples_seen} samples")
print(f"  Input/inference:  {input_per_inference:.2f}")
print(f"  Hidden/inference: {hidden_per_inference:.2f}")
print(f"  Output/inference: {output_per_inference:.2f}")

print("\n[ENERGY PROXIES]")
print(f"  Dense ops/inference:      {dense_ops_per_inference:,}")
print(f"  Event syn/inference:      {event_syn_per_inference:,.2f}")
print(f"  Sparsity ratio:           {sparsity_ratio:.4f}")

print("\n" + "="*60)
print("CHECKPOINT COMPLETE - READY FOR ENERGY ANALYSIS")
print("="*60)

In [None]:
# =============================================================================
# Cell 10: Verify Checkpoint (Optional)
# =============================================================================

def _format_metric(section, key, spec="", suffix=""):
    """Format ``section[key]`` with ``spec``; return 'N/A' when it is missing or None."""
    value = section.get(key)
    if value is None:
        return "N/A"
    return f"{value:{spec}}{suffix}"


def inspect_surrogate_checkpoint(filepath):
    """
    Load a checkpoint and print a summary of its contents.

    Every field is optional: values that are absent from the checkpoint are
    shown as ``N/A`` (or ``unknown`` for the metadata) instead of raising a
    formatting error.

    Args:
        filepath: Path to a checkpoint file written with ``torch.save``

    Returns:
        The loaded checkpoint dictionary
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # open checkpoints from a trusted source. It is kept because the checkpoint
    # stores more than tensors (history lists, plain and NumPy scalars).
    ckpt = torch.load(filepath, map_location='cpu', weights_only=False)

    print(f"\n{'='*70}")
    print(f"CHECKPOINT: {filepath}")
    print(f"{'='*70}")

    print(f"\n[METADATA]")
    print(f"  Schema: {ckpt.get('schema_version', 'unknown')}")
    print(f"  Timestamp: {ckpt.get('timestamp', 'unknown')}")

    print(f"\n[ARCHITECTURE]")
    print(f"  Input:  {ckpt.get('n_input', 'N/A')}")
    print(f"  Hidden: {ckpt.get('n_hidden', 'N/A')}")
    print(f"  Output: {ckpt.get('n_output', 'N/A')}")

    enc = ckpt.get('encoding', {})
    print(f"\n[ENCODING]")
    print(f"  Type: {enc.get('type', 'N/A')}")
    print(f"  Time window: {enc.get('time_window', 'N/A')}")
    print(f"  Max rate: {enc.get('max_rate_hz', 'N/A')} Hz")

    tr = ckpt.get('training_results', {})
    print(f"\n[TRAINING RESULTS]")
    print(f"  Epochs: {tr.get('epochs_trained', 'N/A')}")
    print(f"  Samples: {_format_metric(tr, 'total_samples', ',')}")
    print(f"  Time: {_format_metric(tr, 'training_time_seconds', '.1f', 's')}")
    print(f"  Best val acc: {_format_metric(tr, 'best_val_accuracy', '.2f', '%')}")

    print(f"\n[TEST ACCURACY]")
    print(f"  {_format_metric(ckpt, 'test_accuracy', '.2f', '%')}")

    ts = ckpt.get('training_spikes', {})
    print(f"\n[TRAINING SPIKES]")
    print(f"  Total input:   {_format_metric(ts, 'total_input', ',.0f')}")
    print(f"  Total hidden:  {_format_metric(ts, 'total_hidden', ',.0f')}")
    print(f"  Total output:  {_format_metric(ts, 'total_output', ',.0f')}")
    print(f"  Input/sample:  {_format_metric(ts, 'input_per_sample', '.1f')}")
    print(f"  Hidden/sample: {_format_metric(ts, 'hidden_per_sample', '.1f')}")

    inf = ckpt.get('inference_spikes', {})
    print(f"\n[INFERENCE SPIKES]")
    print(f"  N samples:          {inf.get('n_samples', 'N/A')}")
    print(f"  Input/inference:    {_format_metric(inf, 'input_per_inference', '.2f')}")
    print(f"  Hidden/inference:   {_format_metric(inf, 'hidden_per_inference', '.2f')}")
    print(f"  Output/inference:   {_format_metric(inf, 'output_per_inference', '.2f')}")

    conn = ckpt.get('connectivity', {})
    print(f"\n[CONNECTIVITY]")
    print(f"  Syn input→hidden:  {_format_metric(conn, 'n_syn_input_hidden', ',')}")
    print(f"  Syn hidden→output: {_format_metric(conn, 'n_syn_hidden_output', ',')}")
    print(f"  Total synapses:    {_format_metric(conn, 'n_synapses_total', ',')}")
    print(f"  Fanout in→hidden:  {conn.get('fanout_input_to_hidden', 'N/A')}")
    print(f"  Fanout hidden→out: {conn.get('fanout_hidden_to_output', 'N/A')}")

    ep = ckpt.get('energy_proxies', {})
    print(f"\n[ENERGY PROXIES]")
    print(f"  Dense ops/inference:  {_format_metric(ep, 'dense_ops_per_inference', ',')}")
    print(f"  Event syn/inference:  {_format_metric(ep, 'event_syn_per_inference', ',.2f')}")
    print(f"  Sparsity ratio:       {_format_metric(ep, 'sparsity_ratio', '.6f')}")

    return ckpt


# Verify the saved checkpoint by reloading it from disk and printing its summary.
# The returned dict is bound to `_` because it is not needed here; presumably
# this also keeps the notebook from echoing the whole checkpoint as cell output.
_ = inspect_surrogate_checkpoint(CHECKPOINT_PATH)

---

## Checkpoint Schema Reference

The saved checkpoint (`surrogate_snn_checkpoint.pth`) contains:

### For Training Energy Analysis
```python
ckpt['training_spikes']['total_input']      # Total input spikes during training
ckpt['training_spikes']['total_hidden']     # Total hidden layer spikes
ckpt['training_spikes']['total_output']     # Total output layer spikes
ckpt['training_spikes']['input_per_sample'] # Average per training sample
ckpt['training_spikes']['total_samples']    # Training samples processed, summed over all epochs
ckpt['training_results']['training_time_seconds']
```

### For Inference Energy Analysis
```python
ckpt['inference_spikes']['input_per_inference']   # Avg input spikes per test image
ckpt['inference_spikes']['hidden_per_inference']  # Avg hidden spikes per test image
ckpt['inference_spikes']['output_per_inference']  # Avg output spikes per test image
ckpt['inference_spikes']['n_samples']             # Number measured
```

### For Synaptic Event Calculations
```python
ckpt['connectivity']['n_syn_input_hidden']       # 784 * 400 = 313,600
ckpt['connectivity']['n_syn_hidden_output']      # 400 * 10 = 4,000
ckpt['connectivity']['fanout_input_to_hidden']   # 400
ckpt['connectivity']['fanout_hidden_to_output']  # 10
```

### Pre-computed Proxies
```python
ckpt['energy_proxies']['dense_ops_per_inference']   # N_syn * T
ckpt['energy_proxies']['event_syn_per_inference']   # S_in*400 + S_hid*10
ckpt['energy_proxies']['sparsity_ratio']            # event/dense
```

### Performance
```python
ckpt['test_accuracy']                         # Final test accuracy
ckpt['training_results']['best_val_accuracy'] # Best validation accuracy
```