# STDP Network Training with Complete Energy Tracking

This notebook provides:
1. **Seamless resume** - Training can be paused and resumed without losing any data
2. **Complete spike tracking** - Training spikes are accumulated across sessions; inference spikes are measured on test data each time a checkpoint is saved
3. **Energy-ready checkpoints** - All data needed for energy analysis is saved

## Workflow
1. Run Cells 1-5 (Setup, Encoder, Network, Save/Load functions, Evaluation functions)
2. Run Cell 6 to either start fresh or resume from the checkpoint
3. Training automatically saves a checkpoint with all spike data
4. Evaluation (receptive fields, accuracy, confusion matrix) runs automatically
5. To continue training, just run Cell 6 again

In [None]:
# =============================================================================
# Cell 1: Install Brian2 and Imports
# =============================================================================

!pip install -q git+https://github.com/brian-team/brian2.git

import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
import brian2 as b2
from brian2 import *
import time
import pickle
import os
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Brian2 config: compile generated code with Cython and store state in float32.
b2.prefs.codegen.target = 'cython'
b2.prefs.codegen.cpp.extra_compile_args_gcc = ['-O3', '-ffast-math']
b2.prefs.core.default_float_dtype = np.float32

# Reproducibility.
# np.random.seed / torch.manual_seed cover NumPy- and PyTorch-side randomness.
# b2.seed() is Brian2's documented way to seed the random numbers drawn by its
# generated code (e.g. rand() in the initial weights, PoissonGroup spikes).
# Seeding NumPy alone is not guaranteed to cover every code-generation target.
np.random.seed(42)
torch.manual_seed(42)
b2.seed(42)

print(f"Brian2 version: {b2.__version__}")
print(f"PyTorch version: {torch.__version__}")
print("Setup complete!")

In [None]:
# =============================================================================
# Cell 2: MNIST Spike Encoder
# =============================================================================

class MNISTSpikeEncoder:
    """
    Container for the MNIST splits plus the rate-coding settings used by
    Diehl & Cook (2015).

    Callers convert a pixel intensity into a Poisson firing rate by
    multiplying it by ``max_rate`` (Hz).

    Args:
        time_window: Stimulus presentation duration in ms (kept as
            ``self.time_window``).
        rest_window: Rest duration after each stimulus in ms (kept as
            ``self.rest_window``).
        max_firing_rate: Rate in Hz for a full-intensity pixel (kept as
            ``self.max_rate``).

    Side effects:
        Builds torchvision MNIST datasets rooted at ``./data`` with
        ``download=True`` and prints a short summary.
    """

    def __init__(self, time_window=350, rest_window=150, max_firing_rate=63.75):
        # Encoding settings first, then the datasets.
        self.time_window, self.rest_window, self.max_rate = (
            time_window,
            rest_window,
            max_firing_rate,
        )

        to_tensor = transforms.ToTensor()
        # Generator unpacking builds the train split first, then the test split.
        self.train_data, self.test_data = (
            datasets.MNIST('./data', train=is_train, download=True, transform=to_tensor)
            for is_train in (True, False)
        )

        print(f"Loaded {len(self.train_data)} training, {len(self.test_data)} test samples")
        print(f"Encoding: {max_firing_rate} Hz max rate, {time_window}ms presentation")

# Create the shared encoder with its default settings (350 ms presentation,
# 150 ms rest, 63.75 Hz max rate); the MNIST datasets are built here with
# download=True, rooted at ./data.
encoder = MNISTSpikeEncoder()
print("Encoder ready!")

In [None]:
# =============================================================================
# Cell 3: STDP Network Class
# =============================================================================

class BiologicalSTDPNetwork:
    """
    STDP Network matching Diehl & Cook (2015).

    Architecture (built by ``build_network``):
    - ``n_input`` Poisson input neurons, all-to-all connected to the
      excitatory layer through plastic STDP synapses (weights in ``[0, wmax]``).
    - ``n_exc`` conductance-based LIF excitatory neurons with an adaptive
      threshold ``theta`` that grows by ``theta_plus`` on every spike and
      decays with ``tau_theta``.
    - ``n_inh`` inhibitory neurons: excitatory neuron ``i`` drives inhibitory
      neuron ``i``; inhibitory neuron ``i`` inhibits every excitatory neuron
      except ``i`` (lateral inhibition).

    Spike monitors on all three layers record every spike since
    ``build_network``; their counts are what the checkpoint code uses for
    energy analysis.

    Note: plasticity (STDP weight updates and theta adaptation) is part of
    the model equations, so it is active during *every* ``net.run``.

    Attributes:
        params: Model constants used as the Brian2 namespace of every group.
        weight_norm_target: Column sum enforced by ``normalize_weights``.
        final_weights: ``{'weights', 'theta'}`` snapshot for analysis code.
        training_time: Wall-clock seconds of the most recent ``train`` call.
        samples_trained: Training samples presented so far. ``train`` starts
            from this dataset index by default, so resumed sessions continue
            through the training set instead of repeating the first images.
        net: Brian2 ``Network`` (``None`` until ``build_network`` runs).
    """

    def __init__(self, n_input=784, n_excitatory=400, n_inhibitory=400):
        """Store layer sizes and model constants; no Brian2 objects are built yet.

        Args:
            n_input: Number of Poisson input neurons (784 = 28x28 pixels).
            n_excitatory: Number of excitatory neurons.
            n_inhibitory: Number of inhibitory neurons. The exc->inh wiring
                pairs neurons by index (``i == j``), so this is expected to
                equal ``n_excitatory``.
        """
        self.n_input = n_input
        self.n_exc = n_excitatory
        self.n_inh = n_inhibitory

        # Parameters from Diehl & Cook (2015)
        self.params = {
            'v_rest_e': -65 * mV,
            'v_rest_i': -60 * mV,
            'v_reset_e': -65 * mV,
            'v_reset_i': -45 * mV,
            'v_thresh_e': -52 * mV,
            'v_thresh_i': -40 * mV,
            'tau_mem_e': 100 * ms,
            'tau_mem_i': 10 * ms,
            'refrac_e': 5 * ms,
            'refrac_i': 2 * ms,
            'E_exc': 0 * mV,
            'E_inh': -100 * mV,
            'tau_ge': 1 * ms,
            'tau_gi': 2 * ms,
            'tau_pre': 20 * ms,
            'tau_post': 20 * ms,
            'eta_pre': 0.0001,
            'eta_post': 0.01,
            'mu': 0.4,
            'x_tar': 0.4,
            'wmax': 1.0,
            'theta_plus': 0.05 * mV,
            'tau_theta': 1e7 * ms,
            'w_ei': 10.4,
            'w_ie': 17.0,
        }

        self.weight_norm_target = 78.4
        self.final_weights = None
        self.training_time = 0
        self.samples_trained = 0
        self.net = None

    def build_network(self):
        """Construct all neurons, synapses and spike monitors into ``self.net``.

        Input->exc weights start uniform in ``[0, 0.3 * wmax)`` and are then
        rescaled by ``normalize_weights``.
        """
        # Input layer (rates are set per stimulus by present_sample)
        self.input_group = PoissonGroup(
            self.n_input,
            rates=np.zeros(self.n_input) * Hz
        )

        # Excitatory neurons: conductance-based LIF with adaptive threshold theta
        eqs_exc = '''
        dv/dt = (v_rest_e - v + ge*(E_exc - v) + gi*(E_inh - v)) / tau_mem_e : volt (unless refractory)
        dge/dt = -ge / tau_ge : 1
        dgi/dt = -gi / tau_gi : 1
        dtheta/dt = -theta / tau_theta : volt
        '''

        self.exc_neurons = NeuronGroup(
            self.n_exc,
            eqs_exc,
            threshold='v > v_thresh_e + theta',
            reset='v = v_reset_e; theta += theta_plus',
            refractory='refrac_e',
            method='euler',
            namespace=self.params
        )
        self.exc_neurons.v = self.params['v_rest_e']
        self.exc_neurons.theta = 0 * mV

        # Inhibitory neurons
        eqs_inh = '''
        dv/dt = (v_rest_i - v + ge*(E_exc - v)) / tau_mem_i : volt (unless refractory)
        dge/dt = -ge / tau_ge : 1
        '''

        self.inh_neurons = NeuronGroup(
            self.n_inh,
            eqs_inh,
            threshold='v > v_thresh_i',
            reset='v = v_reset_i',
            refractory='refrac_i',
            method='euler',
            namespace=self.params
        )
        self.inh_neurons.v = self.params['v_rest_i']

        # STDP synapses: Input -> Excitatory (pre/post traces, soft-bounded updates)
        stdp_model = '''
        w : 1
        dApre/dt = -Apre / tau_pre : 1 (event-driven)
        dApost/dt = -Apost / tau_post : 1 (event-driven)
        '''

        on_pre = '''
        ge_post += w
        Apre += 1.0
        w = clip(w - eta_pre * Apost * (w ** mu), 0, wmax)
        '''

        on_post = '''
        Apost += 1.0
        w = clip(w + eta_post * (Apre - x_tar) * (clip(wmax - w, 0, wmax) ** mu), 0, wmax)
        '''

        self.syn_input_exc = Synapses(
            self.input_group,
            self.exc_neurons,
            model=stdp_model,
            on_pre=on_pre,
            on_post=on_post,
            namespace=self.params
        )
        self.syn_input_exc.connect()
        self.syn_input_exc.w = 'rand() * wmax * 0.3'

        # Exc -> Inh (one-to-one)
        self.syn_exc_inh = Synapses(
            self.exc_neurons,
            self.inh_neurons,
            on_pre='ge_post += w_ei',
            namespace=self.params
        )
        self.syn_exc_inh.connect('i == j')

        # Inh -> Exc (lateral inhibition, excluding self)
        self.syn_inh_exc = Synapses(
            self.inh_neurons,
            self.exc_neurons,
            on_pre='gi_post += w_ie',
            namespace=self.params
        )
        self.syn_inh_exc.connect('i != j')

        # SPIKE MONITORS - Critical for energy analysis!
        self.spike_mon_in = SpikeMonitor(self.input_group)
        self.spike_mon_exc = SpikeMonitor(self.exc_neurons)
        self.spike_mon_inh = SpikeMonitor(self.inh_neurons)

        # Assemble network
        self.net = Network(
            self.input_group,
            self.exc_neurons,
            self.inh_neurons,
            self.syn_input_exc,
            self.syn_exc_inh,
            self.syn_inh_exc,
            self.spike_mon_in,
            self.spike_mon_exc,
            self.spike_mon_inh
        )

        self.normalize_weights()
        print(f"Network built: {self.n_input} input → {self.n_exc} exc ↔ {self.n_inh} inh")
        print(f"Spike monitors: input, exc, inh (all active)")

    def normalize_weights(self):
        """Divisively rescale each excitatory neuron's input weights.

        Every column of the dense ``(n_input, n_exc)`` weight matrix is scaled
        so it sums to ``weight_norm_target``. The result is then clipped to
        ``[0, wmax]``, so a column can end slightly below the target if
        clipping was needed. All-zero columns are left unchanged.
        """
        w = np.zeros((self.n_input, self.n_exc))
        w[self.syn_input_exc.i[:], self.syn_input_exc.j[:]] = self.syn_input_exc.w[:]
        col_sums = w.sum(axis=0)
        col_sums[col_sums == 0] = 1  # avoid division by zero for silent columns
        w = w * (self.weight_norm_target / col_sums)
        w = np.clip(w, 0, self.params['wmax'])
        self.syn_input_exc.w[:] = w[self.syn_input_exc.i[:], self.syn_input_exc.j[:]]

    def present_sample(self, rates_hz, min_spikes=5, max_attempts=5,
                       presentation_time=None, rest_time=None,
                       boost_step_hz=32.0):
        """Present one stimulus, re-presenting it at a higher peak rate if needed.

        Each attempt runs ``presentation_time`` of input followed by
        ``rest_time`` with all input rates at 0 Hz.

        Following Diehl & Cook (2015), if the excitatory layer emits fewer
        than ``min_spikes`` spikes, the *maximum* input rate is raised by
        ``boost_step_hz`` and the stimulus is shown again. All pixel rates are
        scaled proportionally, so zero-intensity (background) pixels stay
        silent instead of receiving a uniform rate offset.

        Spikes from every attempt remain in the monitors, so energy totals
        include the failed attempts.

        Args:
            rates_hz: Per-input firing rates (Brian2 quantity in Hz).
            min_spikes: Excitatory spikes required to accept the presentation.
            max_attempts: Maximum number of presentations.
            presentation_time: Stimulus duration (Brian2 time; default 350 ms).
            rest_time: Rest duration (Brian2 time; default 150 ms).
            boost_step_hz: Increase of the peak rate per failed attempt.

        Returns:
            ``(n_spikes, attempts)``: excitatory spikes in the last
            presentation window and the number of attempts used
            (``(0, 0)`` when ``max_attempts`` is 0).
        """
        if presentation_time is None:
            presentation_time = 350 * ms
        if rest_time is None:
            rest_time = 150 * ms

        base_rates = np.asarray(rates_hz / Hz, dtype=float)
        peak = float(base_rates.max()) if base_rates.size else 0.0
        silence = np.zeros(self.n_input) * Hz

        n_spikes = 0
        for attempt in range(1, max_attempts + 1):
            boost = boost_step_hz * (attempt - 1)
            # Raise the peak rate by `boost` Hz via proportional scaling
            # (scale == 1.0 exactly on the first attempt).
            scale = (peak + boost) / peak if peak > 0 else 1.0

            n_before = self.spike_mon_exc.num_spikes
            self.input_group.rates = base_rates * scale * Hz
            self.net.run(presentation_time)
            n_spikes = self.spike_mon_exc.num_spikes - n_before

            # Rest period: silent inputs let potentials/conductances relax.
            self.input_group.rates = silence
            self.net.run(rest_time)

            if n_spikes >= min_spikes:
                return n_spikes, attempt

        return n_spikes, max_attempts

    def train(self, encoder, n_samples, normalize_interval=20, print_interval=100,
              start_index=None):
        """Train the network with unsupervised STDP.

        Args:
            encoder: Object exposing ``train_data`` (indexable ``(img, label)``
                pairs whose ``img`` supports ``.numpy()``) and ``max_rate``
                (Hz). Its optional ``time_window``/``rest_window`` attributes
                (ms) set the presentation and rest durations (default 350/150).
            n_samples: Number of samples to present in this call.
            normalize_interval: Re-normalize weights every this many samples.
            print_interval: Progress print cadence (the first 3 samples are
                always printed).
            start_index: Dataset index of the first sample. ``None`` (default)
                continues from ``self.samples_trained``. Indices wrap modulo
                the dataset length.

        Side effects:
            Runs the simulation and updates ``samples_trained``,
            ``training_time`` (this call only) and ``final_weights``.
        """
        print(f"\n{'='*60}")
        print(f"Training STDP Network")
        print(f"{'='*60}")
        print(f"Samples: {n_samples}")

        if self.net is None:
            self.build_network()

        offset = self.samples_trained if start_index is None else int(start_index)
        n_train = len(encoder.train_data)
        presentation_time = getattr(encoder, 'time_window', 350) * ms
        rest_time = getattr(encoder, 'rest_window', 150) * ms

        start_time = time.time()
        total_spikes = 0  # exc spikes in the accepted presentation of each sample

        for i in range(n_samples):
            img, _label = encoder.train_data[(offset + i) % n_train]
            rates = img.numpy().flatten() * encoder.max_rate * Hz
            n_spikes, _attempts = self.present_sample(
                rates,
                presentation_time=presentation_time,
                rest_time=rest_time,
            )
            total_spikes += n_spikes

            if (i + 1) % normalize_interval == 0:
                self.normalize_weights()

            if (i + 1) % print_interval == 0 or (i + 1) <= 3:
                elapsed = time.time() - start_time
                avg_time = elapsed / (i + 1)
                remaining = avg_time * (n_samples - i - 1)
                print(f"[{i+1:6d}/{n_samples}] "
                      f"spikes: {total_spikes/(i+1):.1f}, "
                      f"θ: {np.mean(self.exc_neurons.theta/mV):.2f}mV, "
                      f"ETA: {remaining/60:.1f}min")

        self.samples_trained = offset + n_samples
        self.training_time = time.time() - start_time
        self._save_final_state()
        print(f"\nTraining complete! Time: {self.training_time/60:.1f} min")

    def _save_final_state(self):
        """Cache the dense input->exc weights and theta (in mV) for analysis."""
        w = np.zeros((self.n_input, self.n_exc))
        w[self.syn_input_exc.i[:], self.syn_input_exc.j[:]] = self.syn_input_exc.w[:]
        self.final_weights = {
            'weights': w,
            'theta': np.array(self.exc_neurons.theta / mV),
        }

print("BiologicalSTDPNetwork class defined!")

In [None]:
# =============================================================================
# Cell 4: Save/Load Functions with Spike Accumulation
# =============================================================================

def save_complete_checkpoint(
    network,
    encoder,
    filepath,
    total_training_samples,
    previous_checkpoint=None,
    inference_accuracy=None,
    neuron_labels=None,
    n_inference_samples=500,
    notes=None,
):
    """
    Measure inference spikes and save a complete, energy-ready checkpoint.

    The checkpoint contains:
    - Training spikes accumulated across sessions: this session's monitor
      counts plus the totals stored in ``previous_checkpoint``.
    - Inference spikes measured by running ``n_inference_samples`` test images.
    - Learned weights/theta, architecture, connectivity and evaluation results.

    STDP and threshold adaptation are active during every ``net.run``. The
    learned state (input->exc weights and theta) is therefore snapshotted
    *before* inference and restored after every inference sample. This keeps
    test images from modifying either the saved or the in-memory model.

    The current monitor counts are treated as this session's training spikes,
    so call this once per session, after ``train`` and before any other
    ``net.run``.

    Args:
        network: Built ``BiologicalSTDPNetwork`` with spike monitors.
        encoder: Object with ``test_data`` and ``max_rate``; optional
            ``time_window``/``rest_window`` (ms) set presentation timing.
        filepath: Destination path; written atomically via a ``.tmp`` file.
        total_training_samples: Cumulative samples trained over all sessions.
        previous_checkpoint: Checkpoint dict from the previous session, if any.
        inference_accuracy: Accuracy (%) to store, if measured.
        neuron_labels: Per-neuron digit labels to store, if assigned.
        n_inference_samples: Test images used for the inference-spike count.
        notes: Free-form note stored in the checkpoint.

    Returns:
        The checkpoint dict that was written.
    """

    def _per(total, count):
        # Per-sample average; 0.0 instead of ZeroDivisionError when count is 0.
        return total / count if count else 0.0

    print(f"\n{'='*60}")
    print("Saving Complete Checkpoint")
    print(f"{'='*60}")

    # =========================================================================
    # Part 1: Accumulate TRAINING spikes
    # =========================================================================
    current_train_in = int(network.spike_mon_in.num_spikes)
    current_train_exc = int(network.spike_mon_exc.num_spikes)
    current_train_inh = int(network.spike_mon_inh.num_spikes)

    if previous_checkpoint is not None:
        prev_train = previous_checkpoint.get("training_spikes", {})
        prev_in = int(prev_train.get("total_input", 0) or 0)
        prev_exc = int(prev_train.get("total_exc", 0) or 0)
        prev_inh = int(prev_train.get("total_inh", 0) or 0)
        prev_time = float(previous_checkpoint.get("training_time_seconds", 0) or 0)
        print(f"  Previous training spikes: in={prev_in}, exc={prev_exc}, inh={prev_inh}")
    else:
        prev_in = prev_exc = prev_inh = 0
        prev_time = 0

    acc_train_in = current_train_in + prev_in
    acc_train_exc = current_train_exc + prev_exc
    acc_train_inh = current_train_inh + prev_inh
    acc_time = network.training_time + prev_time

    print(f"  Current session spikes: in={current_train_in}, exc={current_train_exc}, inh={current_train_inh}")
    print(f"  ACCUMULATED training:   in={acc_train_in}, exc={acc_train_exc}, inh={acc_train_inh}")

    # =========================================================================
    # Part 2: Snapshot the learned state BEFORE any inference run
    # =========================================================================
    w_snapshot = np.array(network.syn_input_exc.w[:])  # synapse-ordered copy
    theta_snapshot_mv = np.array(network.exc_neurons.theta / mV)

    w_dense = np.zeros((network.n_input, network.n_exc), dtype=np.float32)
    w_dense[network.syn_input_exc.i[:], network.syn_input_exc.j[:]] = \
        np.asarray(w_snapshot, dtype=np.float32)
    theta_arr = np.asarray(theta_snapshot_mv, dtype=np.float32)

    # =========================================================================
    # Part 3: Measure INFERENCE spikes (on test data) with frozen learned state
    # =========================================================================
    print(f"\n  Measuring inference spikes on {n_inference_samples} test samples...")

    presentation_time = getattr(encoder, 'time_window', 350) * ms
    rest_time = getattr(encoder, 'rest_window', 150) * ms
    silence = np.zeros(network.n_input) * Hz

    # Record current counts (after training)
    pre_inf_in = int(network.spike_mon_in.num_spikes)
    pre_inf_exc = int(network.spike_mon_exc.num_spikes)
    pre_inf_inh = int(network.spike_mon_inh.num_spikes)

    inf_start = time.time()
    for i in range(n_inference_samples):
        img, _label = encoder.test_data[i % len(encoder.test_data)]
        rates = img.numpy().flatten() * encoder.max_rate * Hz
        network.input_group.rates = rates
        network.net.run(presentation_time)
        network.input_group.rates = silence
        network.net.run(rest_time)

        # Undo any STDP / theta changes caused by this test image.
        network.syn_input_exc.w[:] = w_snapshot
        network.exc_neurons.theta = theta_snapshot_mv * mV

        if (i + 1) % 100 == 0:
            print(f"    [{i+1}/{n_inference_samples}]")

    inf_time = time.time() - inf_start

    # Compute inference-only spikes
    inf_in = int(network.spike_mon_in.num_spikes) - pre_inf_in
    inf_exc = int(network.spike_mon_exc.num_spikes) - pre_inf_exc
    inf_inh = int(network.spike_mon_inh.num_spikes) - pre_inf_inh

    print(f"  Inference spikes: in={inf_in}, exc={inf_exc}, inh={inf_inh}")
    print(f"  Spikes/inference: in={_per(inf_in, n_inference_samples):.1f}, "
          f"exc={_per(inf_exc, n_inference_samples):.1f}, "
          f"inh={_per(inf_inh, n_inference_samples):.1f}")

    # =========================================================================
    # Part 4: Build and save checkpoint
    # =========================================================================
    checkpoint = {
        "schema_version": "complete_v1",
        "saved_at_unix": time.time(),
        "notes": notes,

        # Architecture
        "n_input": int(network.n_input),
        "n_exc": int(network.n_exc),
        "n_inh": int(network.n_inh),

        # Learned state (as it was before the inference measurement)
        "weights": w_dense,
        "theta": theta_arr,

        # Training metadata
        "training_samples": int(total_training_samples),
        "training_time_seconds": acc_time,

        # TRAINING SPIKES (accumulated across all sessions)
        "training_spikes": {
            "total_input": acc_train_in,
            "total_exc": acc_train_exc,
            "total_inh": acc_train_inh,
            "input_per_sample": _per(acc_train_in, total_training_samples),
            "exc_per_sample": _per(acc_train_exc, total_training_samples),
            "inh_per_sample": _per(acc_train_inh, total_training_samples),
        },

        # INFERENCE SPIKES (from test data measurement)
        "inference_spikes": {
            "n_samples": n_inference_samples,
            "total_input": inf_in,
            "total_exc": inf_exc,
            "total_inh": inf_inh,
            "input_per_inference": _per(inf_in, n_inference_samples),
            "exc_per_inference": _per(inf_exc, n_inference_samples),
            "inh_per_inference": _per(inf_inh, n_inference_samples),
            "measurement_time_seconds": inf_time,
        },

        # For backward compatibility with energy_comparison_blocks.ipynb
        "total_input_spikes": acc_train_in,
        "total_exc_spikes": acc_train_exc,
        "total_inh_spikes": acc_train_inh,

        # Connectivity
        "connectivity": {
            "n_syn_input_exc": len(network.syn_input_exc),
            "n_syn_exc_inh": len(network.syn_exc_inh),
            "n_syn_inh_exc": len(network.syn_inh_exc),
            "fanout_input_to_exc": int(network.n_exc),
            "fanout_exc_to_inh": 1,
            "fanout_inh_to_exc": max(network.n_exc - 1, 0),
        },

        # Evaluation results
        "inference_accuracy": inference_accuracy,
        "neuron_labels": neuron_labels,
    }

    # Write to a temp file, then atomically replace, so an interrupted save
    # cannot corrupt the existing checkpoint (the only resume point).
    tmp_path = f"{filepath}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(checkpoint, f)
    os.replace(tmp_path, filepath)

    print(f"\n{'='*60}")
    print(f"✓ Checkpoint saved: {filepath}")
    print(f"{'='*60}")
    print(f"  Training samples:     {total_training_samples}")
    print(f"  Training time:        {acc_time/60:.1f} min")
    print(f"  Train spikes/sample:  in={_per(acc_train_in, total_training_samples):.1f}, "
          f"exc={_per(acc_train_exc, total_training_samples):.1f}")
    print(f"  Inf spikes/inference: in={_per(inf_in, n_inference_samples):.1f}, "
          f"exc={_per(inf_exc, n_inference_samples):.1f}")
    if inference_accuracy is not None:
        print(f"  Accuracy:             {inference_accuracy:.2f}%")

    return checkpoint


def load_checkpoint(filepath, verbose=True):
    """
    Load a checkpoint and rebuild a network from it for resuming.

    Builds a fresh ``BiologicalSTDPNetwork`` with the saved layer sizes, then
    restores the input->exc weights, theta, the accumulated training time and
    the number of samples trained so far. Missing or ``None`` fields (e.g.
    from older checkpoints) fall back to defaults instead of raising.

    Security: ``pickle.load`` can execute arbitrary code; only load
    checkpoint files you created or trust.

    Args:
        filepath: Path of the pickled checkpoint.
        verbose: Print a summary of what was restored.

    Returns:
        ``(network, checkpoint)``: the rebuilt network and the raw dict.
    """
    with open(filepath, "rb") as f:
        checkpoint = pickle.load(f)

    def _fmt(value, spec):
        # Format numbers with `spec`; placeholders such as 'N/A' or None make
        # format() raise, so print them verbatim instead.
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Loading checkpoint: {filepath}")
        print(f"{'='*60}")
        print(f"  Schema: {checkpoint.get('schema_version', 'unknown')}")
        print(f"  Training samples: {checkpoint.get('training_samples', 'unknown')}")

    # Create fresh network with correct Brian2 units
    n_input = int(checkpoint.get("n_input", 784))
    n_exc = int(checkpoint.get("n_exc", 400))
    n_inh = int(checkpoint.get("n_inh", 400))

    network = BiologicalSTDPNetwork(
        n_input=n_input,
        n_excitatory=n_exc,
        n_inhibitory=n_inh
    )
    network.build_network()

    # Restore weights (dense (n_input, n_exc) matrix -> synapse order)
    w = checkpoint.get("weights")
    if w is not None:
        network.syn_input_exc.w[:] = w[network.syn_input_exc.i[:], network.syn_input_exc.j[:]]
        if verbose:
            print(f"  ✓ Weights restored: {w.shape}")

    # Restore theta (stored in mV)
    theta = checkpoint.get("theta")
    if theta is not None:
        network.exc_neurons.theta = np.asarray(theta) * mV
        if verbose:
            print(f"  ✓ Theta restored: mean={np.mean(theta):.2f} mV")

    # Restore accumulated training time (None in the file is treated as 0)
    network.training_time = float(checkpoint.get("training_time_seconds", 0) or 0)

    # Number of samples already trained, kept on the network for resume bookkeeping.
    network.samples_trained = int(checkpoint.get("training_samples", 0) or 0)

    # Cache the restored state for the analysis functions. Reading it back from
    # the network keeps it valid even when the checkpoint lacks weights/theta.
    network._save_final_state()

    if verbose:
        train_spikes = checkpoint.get("training_spikes") or {}
        inf_spikes = checkpoint.get("inference_spikes") or {}
        print(f"  Training spikes (accumulated): "
              f"in={train_spikes.get('total_input', 'N/A')}, "
              f"exc={train_spikes.get('total_exc', 'N/A')}")
        print(f"  Inference spikes/inf: "
              f"in={_fmt(inf_spikes.get('input_per_inference', 'N/A'), '.1f')}, "
              f"exc={_fmt(inf_spikes.get('exc_per_inference', 'N/A'), '.1f')}")
        print(f"  Ready to resume training!")

    return network, checkpoint


print("Save/Load functions defined!")

In [None]:
# =============================================================================
# Cell 5: Evaluation Functions
# =============================================================================

def assign_neuron_labels(network, encoder, n_samples=10000, k=10):
    """
    Assign a digit label to each excitatory neuron from its weight responses.

    Each training image is projected onto the learned input->exc weights
    (``img @ weights``; no spiking simulation). The ``k`` most responsive
    neurons add their response to the running total for that image's label.
    A neuron's label is the digit with the largest total. Neurons that never
    make the top ``k`` keep all-zero totals and therefore get label 0.

    Args:
        network: Object with ``n_exc`` and ``final_weights['weights']`` of
            shape ``(n_input, n_exc)``.
        encoder: Object with ``train_data`` (indexable ``(img, label)`` pairs
            whose ``img`` supports ``.numpy()``); indices wrap if
            ``n_samples`` exceeds its length.
        n_samples: Number of training images to use.
        k: Number of top-responding neurons credited per image (>= 1).

    Returns:
        ``(neuron_labels, confidence)``: per-neuron digit labels and the
        fraction of each neuron's total response that went to its label.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    print(f"\nAssigning neuron labels ({n_samples} samples)...")

    weights = network.final_weights['weights']
    neuron_responses = np.zeros((network.n_exc, 10))
    n_train = len(encoder.train_data)

    for i in range(n_samples):
        img, label = encoder.train_data[i % n_train]
        response = img.numpy().flatten() @ weights

        # Winner-take-all: the top-k neurons credit this image's label.
        # argsort indices are unique, so fancy-indexed += is safe.
        top_k = np.argsort(response)[-k:]
        neuron_responses[top_k, label] += response[top_k]

        if (i + 1) % 2000 == 0:
            print(f"  [{i+1}/{n_samples}]")

    neuron_labels = np.argmax(neuron_responses, axis=1)
    max_responses = np.max(neuron_responses, axis=1)
    sum_responses = np.sum(neuron_responses, axis=1)
    sum_responses[sum_responses == 0] = 1  # silent neurons: confidence 0
    confidence = max_responses / sum_responses

    # Report distribution
    print(f"\n  Label distribution:")
    for digit in range(10):
        count = np.sum(neuron_labels == digit)
        print(f"    Digit {digit}: {count} neurons ({100*count/len(neuron_labels):.1f}%)")

    return neuron_labels, confidence


def test_accuracy(network, encoder, neuron_labels, n_samples=10000, k=10):
    """
    Score test images with a top-k weighted vote over labeled neurons.

    Each test image is projected onto the learned input->exc weights
    (``img @ weights``; no spiking simulation). The ``k`` most responsive
    neurons each add their response to the vote of their assigned digit, and
    the digit with the largest vote is the prediction.

    Args:
        network: Object with ``final_weights['weights']`` of shape
            ``(n_input, n_exc)``.
        encoder: Object with ``test_data`` (indexable ``(img, label)`` pairs
            whose ``img`` supports ``.numpy()``); indices wrap if
            ``n_samples`` exceeds its length.
        neuron_labels: Digit label per excitatory neuron (e.g. from
            ``assign_neuron_labels``).
        n_samples: Number of test images to score.
        k: Number of top-responding neurons that vote (>= 1). This should
            match the ``k`` used when assigning labels.

    Returns:
        ``(accuracy_percent, predictions, true_labels)``, with the last two
        as NumPy arrays.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    print(f"\nTesting accuracy ({n_samples} samples)...")

    weights = network.final_weights['weights']
    labels = np.asarray(neuron_labels)
    n_test = len(encoder.test_data)
    predictions = []
    true_labels = []

    for i in range(n_samples):
        img, label = encoder.test_data[i % n_test]
        response = img.numpy().flatten() @ weights

        # Voting: each top-k neuron adds its response to its digit's vote.
        votes = np.zeros(10)
        top_k = np.argsort(response)[-k:]
        for idx in top_k:
            votes[labels[idx]] += response[idx]

        predictions.append(np.argmax(votes))
        true_labels.append(label)

    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    accuracy = 100.0 * np.mean(predictions == true_labels)

    print(f"\n  Accuracy: {accuracy:.2f}%")
    return accuracy, predictions, true_labels


# The name matches pytest's ``test_*`` pattern; opt out of collection in case
# this code is ever exported to a .py module that pytest imports.
test_accuracy.__test__ = False


def plot_receptive_fields(network, n_show=100):
    """
    Plot the learned input->exc weight vectors as square images.

    Args:
        network: Object with ``final_weights['weights']`` of shape
            ``(n_input, n_exc)``; ``n_input`` must be a perfect square
            (784 -> 28x28).
        n_show: Maximum number of neurons (columns) to display.

    Raises:
        ValueError: if ``n_input`` is not a perfect square.
    """
    weights = np.asarray(network.final_weights['weights'])
    n_pixels, n_cols = weights.shape
    side = int(round(np.sqrt(n_pixels)))
    if side * side != n_pixels:
        raise ValueError(f"Cannot reshape {n_pixels} inputs into a square image")

    n_neurons = max(min(n_show, n_cols), 0)
    # At least a 1x1 grid, so n_show=0 still produces a (blank) figure.
    grid_size = max(int(np.ceil(np.sqrt(n_neurons))), 1)

    # squeeze=False keeps `axes` 2-D even for a 1x1 grid.
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12), squeeze=False)

    for i, ax in enumerate(axes.flat):  # row-major order
        if i < n_neurons:
            rf = weights[:, i].reshape(side, side)
            ax.imshow(rf, cmap='hot', interpolation='nearest')
        ax.axis('off')

    plt.suptitle('Learned Receptive Fields', fontsize=14)
    plt.tight_layout()
    plt.show()


def plot_confusion_matrix(true_labels, predictions, accuracy):
    """
    Plot a 10x10 digit confusion matrix and print per-class accuracy.

    Args:
        true_labels: Ground-truth digits (array-like of ints 0-9).
        predictions: Predicted digits (array-like of ints 0-9).
        accuracy: Overall accuracy in percent, shown in the title.
    """
    true_labels = np.asarray(true_labels)
    predictions = np.asarray(predictions)
    digits = list(range(10))

    # Fix the label set so the matrix is always 10x10 and matches the tick
    # labels, even when some digit never occurs in either array.
    cm = confusion_matrix(true_labels, predictions, labels=digits)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=digits, yticklabels=digits)
    plt.xlabel('Predicted', fontsize=12)
    plt.ylabel('True', fontsize=12)
    plt.title(f'Confusion Matrix (Accuracy: {accuracy:.1f}%)', fontsize=14)
    plt.tight_layout()
    plt.show()

    print("\nPer-class accuracy:")
    for digit in digits:
        mask = true_labels == digit
        if not mask.any():
            # Avoid mean-of-empty (NaN + RuntimeWarning) for absent classes.
            print(f"  Digit {digit}: n/a (no samples)")
            continue
        class_acc = 100.0 * np.mean(predictions[mask] == digit)
        print(f"  Digit {digit}: {class_acc:.1f}%")


print("Evaluation functions defined!")

In [None]:
# =============================================================================
# Cell 6: MAIN TRAINING/RESUME CELL
# =============================================================================
# 
# CONFIGURE THESE PARAMETERS:
#

CHECKPOINT_PATH = "stdp_checkpoint.pkl"  # Checkpoint file path
SAMPLES_THIS_SESSION = 1000              # How many samples to train this session
N_INFERENCE_SAMPLES = 500                # Samples for inference spike measurement

#
# =============================================================================

# Start fresh scope (reset Brian2's object scope before (re)building the model)
b2.start_scope()

# Check if resuming or starting fresh: an existing checkpoint file is loaded
# and a new network is rebuilt from it; otherwise a new network is created.
if os.path.exists(CHECKPOINT_PATH):
    print("Found existing checkpoint - RESUMING")
    stdp_net, prev_ckpt = load_checkpoint(CHECKPOINT_PATH)
    previous_samples = prev_ckpt.get("training_samples", 0)
else:
    print("No checkpoint found - STARTING FRESH")
    stdp_net = BiologicalSTDPNetwork()
    stdp_net.build_network()
    prev_ckpt = None
    previous_samples = 0

# Train
stdp_net.train(
    encoder=encoder,
    n_samples=SAMPLES_THIS_SESSION,
    normalize_interval=20,
    print_interval=100
)

# Cumulative sample count across all sessions (stored in the checkpoint).
total_samples = previous_samples + SAMPLES_THIS_SESSION
print(f"\nTotal training samples: {total_samples}")

# Evaluate
# NOTE: label assignment and accuracy below score images by a linear
# projection onto the learned weights (img @ weights, top-k vote); they do
# not run the spiking simulation.
print("\n" + "="*60)
print("EVALUATION")
print("="*60)

plot_receptive_fields(stdp_net, n_show=100)

neuron_labels, confidence = assign_neuron_labels(stdp_net, encoder, n_samples=10000)
accuracy, predictions, true_labels = test_accuracy(stdp_net, encoder, neuron_labels, n_samples=10000)

plot_confusion_matrix(true_labels, predictions, accuracy)

# Save complete checkpoint (overwrites CHECKPOINT_PATH). This also runs the
# network on N_INFERENCE_SAMPLES test images to measure inference spikes.
save_complete_checkpoint(
    network=stdp_net,
    encoder=encoder,
    filepath=CHECKPOINT_PATH,
    total_training_samples=total_samples,
    previous_checkpoint=prev_ckpt,
    inference_accuracy=accuracy,
    neuron_labels=neuron_labels,
    n_inference_samples=N_INFERENCE_SAMPLES,
    notes=f"Session: {SAMPLES_THIS_SESSION} samples, Total: {total_samples}"
)

print("\n" + "="*60)
print("SESSION COMPLETE")
print("="*60)
print(f"To continue training, just run this cell again!")
print(f"Checkpoint: {CHECKPOINT_PATH}")

In [None]:
# =============================================================================
# Cell 7: View Checkpoint Contents (Optional)
# =============================================================================

def inspect_checkpoint(filepath):
    """
    Print a human-readable summary of a checkpoint file and return its dict.

    Fields missing from the checkpoint (e.g. older schemas without
    ``inference_spikes``) or stored as ``None`` are shown as ``N/A``
    instead of raising.

    Security: uses ``pickle.load``, which can execute arbitrary code; only
    inspect files you created or trust.

    Args:
        filepath: Path of the pickled checkpoint.

    Returns:
        The loaded checkpoint dict.
    """
    with open(filepath, "rb") as f:
        ckpt = pickle.load(f)

    def _fmt(value, spec=""):
        # Format numbers with `spec`; placeholders such as 'N/A' or None make
        # format() raise, so print them verbatim instead.
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    print(f"\n{'='*70}")
    print(f"CHECKPOINT CONTENTS: {filepath}")
    print(f"{'='*70}")

    accuracy = ckpt.get('inference_accuracy')
    accuracy_text = 'N/A' if accuracy is None else f"{_fmt(accuracy, '.2f')}%"
    train_time = ckpt.get('training_time_seconds') or 0

    print(f"\n[METADATA]")
    print(f"  Schema:           {ckpt.get('schema_version', 'unknown')}")
    print(f"  Training samples: {ckpt.get('training_samples', 'N/A')}")
    print(f"  Training time:    {train_time/60:.1f} min")
    print(f"  Accuracy:         {accuracy_text}")

    print(f"\n[ARCHITECTURE]")
    print(f"  Input:  {ckpt.get('n_input', 'N/A')}")
    print(f"  Exc:    {ckpt.get('n_exc', 'N/A')}")
    print(f"  Inh:    {ckpt.get('n_inh', 'N/A')}")

    print(f"\n[TRAINING SPIKES - for Training Energy Analysis]")
    ts = ckpt.get("training_spikes") or {}
    print(f"  Total input:     {_fmt(ts.get('total_input', 'N/A'), ',')}")
    print(f"  Total exc:       {_fmt(ts.get('total_exc', 'N/A'), ',')}")
    print(f"  Total inh:       {_fmt(ts.get('total_inh', 'N/A'), ',')}")
    print(f"  Input/sample:    {_fmt(ts.get('input_per_sample', 'N/A'), '.1f')}")
    print(f"  Exc/sample:      {_fmt(ts.get('exc_per_sample', 'N/A'), '.1f')}")

    print(f"\n[INFERENCE SPIKES - for Inference Energy Analysis]")
    inf = ckpt.get("inference_spikes") or {}
    print(f"  Measured on:     {inf.get('n_samples', 'N/A')} test samples")
    print(f"  Input/inference: {_fmt(inf.get('input_per_inference', 'N/A'), '.1f')}")
    print(f"  Exc/inference:   {_fmt(inf.get('exc_per_inference', 'N/A'), '.1f')}")
    print(f"  Inh/inference:   {_fmt(inf.get('inh_per_inference', 'N/A'), '.1f')}")

    print(f"\n[CONNECTIVITY]")
    conn = ckpt.get("connectivity") or {}
    print(f"  Input→Exc synapses: {_fmt(conn.get('n_syn_input_exc', 'N/A'), ',')}")
    print(f"  Fanout input→exc:   {conn.get('fanout_input_to_exc', 'N/A')}")

    print(f"\n[LEARNED STATE]")
    w = ckpt.get("weights")
    theta = ckpt.get("theta")
    if w is not None and np.size(w) > 0:
        w = np.asarray(w)
        print(f"  Weights shape:   {w.shape}")
        print(f"  Weights range:   [{w.min():.4f}, {w.max():.4f}]")
    if theta is not None and np.size(theta) > 0:
        theta = np.asarray(theta)
        print(f"  Theta mean:      {np.mean(theta):.2f} mV")
        print(f"  Theta range:     [{theta.min():.2f}, {theta.max():.2f}] mV")

    labels = ckpt.get("neuron_labels")
    if labels is not None:
        labels = np.asarray(labels)
        print(f"\n[NEURON LABELS]")
        for digit in range(10):
            count = np.sum(labels == digit)
            print(f"  Digit {digit}: {count} neurons")

    return ckpt


# Inspect the current checkpoint.
# CHECKPOINT_PATH is defined in the main training cell (Cell 6); run that cell
# first in this kernel, otherwise this raises NameError.
if os.path.exists(CHECKPOINT_PATH):
    _ = inspect_checkpoint(CHECKPOINT_PATH)
else:
    print(f"Checkpoint not found: {CHECKPOINT_PATH}")

---
## Energy Analysis Compatibility

The checkpoint saved by this notebook contains everything needed for `energy_comparison_blocks.ipynb`:

### For Training Energy (Block 3)
- `training_spikes.total_input` - Total input spikes during training
- `training_spikes.total_exc` - Total excitatory spikes during training
- `training_spikes.total_inh` - Total inhibitory spikes during training
- `training_samples` - Number of training samples
- `training_time_seconds` - Wall-clock training time

### For Inference Energy (Block 4)
- `inference_spikes.input_per_inference` - Avg input spikes per test sample
- `inference_spikes.exc_per_inference` - Avg exc spikes per test sample
- `inference_spikes.inh_per_inference` - Avg inh spikes per test sample
- `inference_spikes.n_samples` - Number of test samples measured

### For Network Topology
- `connectivity.n_syn_input_exc` - Number of Input→Exc synapses
- `connectivity.fanout_input_to_exc` - Fanout per input neuron
- All other connectivity metadata

### Backward Compatibility
- `total_input_spikes`, `total_exc_spikes`, `total_inh_spikes` - Legacy fields that mirror the accumulated training totals (`training_spikes.total_*`)