# 01c: Hebbian Spiking NN (MNIST)

This notebook explores a simple Hebbian ("fire together, wire together") learning approach on MNIST. It uses a two-layer spiking network with competition between neurons, so that only a few neurons fire at each time step rather than all of them.

- Encoding: Poisson spike trains from pixel intensities.
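- Neurons: leaky integration with soft reset in the hidden and output layers. After a spike, the neuron's potential is reduced by the threshold value. Which neurons spike is decided by competition among positive potentials, not by a threshold crossing.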
- Competition: k-WTA (top-k) at hidden and winner-take-all at output.
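- Learning: Oja's rule for input→hidden, and supervised Hebbian + anti-Hebbian learning for hidden→output. Random teacher spikes on the target class are added to the post-synaptic activity used by the learning rule; they do not change the output membrane dynamics.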
- Evaluation: Argmax of output spike counts.

This is a didactic example focused on clarity, not performance.

In [None]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Local package
import mlp.data_providers as data_providers

# Resolve MLP_DATA_DIR: honour an explicit setting, otherwise search upward for a data/ folder
def ensure_mlp_data_dir():
    """Return the directory holding the MNIST ``.npz`` files.

    If ``MLP_DATA_DIR`` is already set to a non-empty value, it is returned as a
    ``Path`` without further checks. Otherwise the current working directory and
    each of its ancestors are tried in turn. The first one whose ``data/``
    subfolder contains ``mnist-train.npz`` wins: ``MLP_DATA_DIR`` is exported for
    later code and that folder is returned.

    Raises:
        RuntimeError: if no ancestor has ``data/mnist-train.npz``.
    """
    configured = os.environ.get('MLP_DATA_DIR')
    if configured:
        return Path(os.environ['MLP_DATA_DIR'])
    start = Path.cwd().resolve()
    for ancestor in (start, *start.parents):
        candidate_dir = ancestor / 'data'
        if not (candidate_dir / 'mnist-train.npz').exists():
            continue
        # Export it so data providers constructed later pick up the same location.
        os.environ['MLP_DATA_DIR'] = str(candidate_dir)
        return candidate_dir
    raise RuntimeError('Could not locate data directory with mnist-*.npz')

# Resolve the data directory once, up front. When ensure_mlp_data_dir has to search for it,
# it also exports MLP_DATA_DIR so later cells see the same location.
data_dir = ensure_mlp_data_dir()
print('Using data dir:', data_dir)


In [None]:
# Training and validation providers (64 examples per batch). max_num_batches presumably caps
# the batches per epoch, keeping the pure-Python simulation on a small subset for speed.
# Only the training stream is shuffled.
train_dp = data_providers.MNISTDataProvider(
    'train',
    batch_size=64,
    max_num_batches=100,
    shuffle_order=True,
)
valid_dp = data_providers.MNISTDataProvider(
    'valid',
    batch_size=64,
    max_num_batches=50,
    shuffle_order=False,
)
print('Batches per epoch (train, valid):', train_dp.len(), valid_dp.len())


In [None]:
# Utilities: scaling and Poisson encoding
def scale01(x: np.ndarray) -> np.ndarray:
    """Return ``x`` as float32, dividing by 255 when it looks like raw 0-255 pixels.

    The decision uses the array-wide maximum: anything above 1.5 is treated as
    byte-scaled intensities. Otherwise the values are only cast, on the
    assumption that they are already in [0, 1].
    """
    looks_like_bytes = np.max(x) > 1.5
    scaled = x / 255.0 if looks_like_bytes else x
    return scaled.astype(np.float32)

def poisson_encode(inputs: np.ndarray, T: int, rate_hz: float = 30.0, dt: float = 0.005, rng: np.random.RandomState | None = None) -> np.ndarray:
    """Encode inputs in [0,1] into Poisson spikes over T steps.
    inputs: (B, D) in [0,1]
    returns: spikes (B, T, D) with 0/1 values
    """
    if rng is None:
        rng = np.random.RandomState(0)
    B, D = inputs.shape
    lam = np.clip(inputs, 0.0, 1.0) * rate_hz * dt  # probability per bin
    spikes = rng.rand(B, T, D) < lam[:, None, :]
    return spikes.astype(np.float32)


In [None]:
# Two-layer Hebbian SNN with competition and decay
def k_wta(v: np.ndarray, k: int) -> np.ndarray:
    """k-winners-take-all spiking over a 1-D membrane vector.

    The ``k`` largest entries of ``v`` are selected (ties resolved arbitrarily by
    ``np.argpartition``). Each selected entry emits a spike (1.0) only if its
    value is strictly positive; every other entry is 0.0.

    Args:
        v: 1-D array of membrane potentials.
        k: number of winners; ``None`` or ``<= 0`` means no neuron spikes.
            Values larger than ``len(v)`` are clipped.

    Returns:
        float32 array with the same shape as ``v`` holding 0.0/1.0.
    """
    spikes = np.zeros_like(v, dtype=np.float32)
    if k is None or k <= 0:
        return spikes
    k = min(k, v.shape[0])
    if k == 0:
        # Empty input: argpartition with kth=0 raises on an empty array,
        # and the slice [-0:] would select everything anyway.
        return spikes
    idx = np.argpartition(v, -k)[-k:]
    spikes[idx] = (v[idx] > 0).astype(np.float32)
    return spikes

class TwoLayerHebbianSNN:
    """Input -> hidden -> output spiking network trained with local Hebbian rules.

    Both layers are leaky integrators. On every time step a layer's membrane is
    multiplied by ``alpha = exp(-dt / tau)``, the weighted presynaptic spikes are
    added, the spiking neurons are chosen with ``k_wta`` (the ``k`` largest
    strictly positive potentials fire), and ``vth`` is subtracted from every
    neuron that fired (soft reset). Note that ``vth_h``/``vth_o`` act only as
    reset magnitudes: firing is decided by the k-WTA competition, not by
    comparing the potential to the threshold.

    Attributes:
        W1: (input_dim, hidden_dim) float32 input->hidden weights.
        W2: (hidden_dim, num_classes) float32 hidden->output weights.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int,
                 dt: float = 0.005, tau_h: float = 0.02, tau_o: float = 0.02,
                 vth_h: float = 1.0, vth_o: float = 1.0, seed: int = 0,
                 k_hidden: int = 32, k_output: int = 1):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.dt = dt
        # Per-step leak factors for exponential decay with time constants tau_h / tau_o.
        self.alpha_h = float(np.exp(-dt / tau_h))
        self.alpha_o = float(np.exp(-dt / tau_o))
        self.vth_h = vth_h
        self.vth_o = vth_o
        self.k_hidden = k_hidden
        self.k_output = k_output
        rng = np.random.RandomState(seed)
        self.W1 = (0.01 * rng.randn(input_dim, hidden_dim)).astype(np.float32)
        self.W2 = (0.01 * rng.randn(hidden_dim, num_classes)).astype(np.float32)

    @staticmethod
    def _normalize_columns(W: np.ndarray) -> None:
        """Rescale each column of ``W`` in place to unit L2 norm (epsilon-guarded)."""
        W /= np.linalg.norm(W, axis=0, keepdims=True) + 1e-6

    def _step(self, x_t: np.ndarray, v_h: np.ndarray, v_o: np.ndarray):
        """Advance both layers by one time step using the current weights.

        Args:
            x_t: (input_dim,) input spikes for this step.
            v_h: (hidden_dim,) hidden membrane potentials from the previous step.
            v_o: (num_classes,) output membrane potentials from the previous step.

        Returns:
            ``(v_h, v_o, s_h, s_o, v_o_pre)``: the post-reset potentials, the
            hidden and output spike vectors, and the output potential after
            integration but before reset.
        """
        v_h = self.alpha_h * v_h + x_t @ self.W1
        s_h = k_wta(v_h, self.k_hidden)
        v_h = v_h - self.vth_h * s_h
        v_o_pre = self.alpha_o * v_o + s_h @ self.W2
        s_o = k_wta(v_o_pre, self.k_output)
        v_o = v_o_pre - self.vth_o * s_o
        return v_h, v_o, s_h, s_o, v_o_pre

    def _simulate(self, spikes: np.ndarray):
        """Run inference on a (B, T, input_dim) spike tensor without learning.

        Returns:
            ``(counts, drive)``, both (B, num_classes) float32: output spike
            counts and the summed pre-reset output potential over all steps.
        """
        B, T, D = spikes.shape
        assert D == self.input_dim
        counts = np.zeros((B, self.num_classes), dtype=np.float32)
        drive = np.zeros((B, self.num_classes), dtype=np.float32)
        for b in range(B):
            # Membranes start at rest for every example.
            v_h = np.zeros(self.hidden_dim, dtype=np.float32)
            v_o = np.zeros(self.num_classes, dtype=np.float32)
            for t in range(T):
                v_h, v_o, _, s_o, v_o_pre = self._step(spikes[b, t], v_h, v_o)
                counts[b] += s_o
                drive[b] += v_o_pre
        return counts, drive

    def forward_counts(self, spikes: np.ndarray) -> np.ndarray:
        """Return (B, num_classes) float32 output spike counts for (B, T, input_dim) spikes."""
        counts, _ = self._simulate(spikes)
        return counts

    def predict(self, spikes: np.ndarray) -> np.ndarray:
        """Predict one class index per example from its output spike counts.

        The class with the most output spikes wins. Ties are broken by the larger
        summed pre-reset output potential, and then by the lowest index. This
        includes the case where no output neuron spiked, which a plain argmax
        would always resolve to class 0.
        """
        counts, drive = self._simulate(spikes)
        best = counts.max(axis=1, keepdims=True)
        tie_break = np.where(counts == best, drive, -np.inf)
        return tie_break.argmax(axis=1)

    def train_step(self, spikes: np.ndarray, y_int: np.ndarray,
                   eta1: float = 0.01, eta2: float = 0.02, eta2_neg: float = 0.01,
                   teacher_rate_hz: float = 60.0, decay1: float = 5e-5, decay2: float = 5e-5,
                   rng: np.random.RandomState | None = None):
        """Run the network on a batch and update W1/W2 online at every time step.

        Args:
            spikes: (B, T, input_dim) input spikes.
            y_int: (B,) integer class labels.
            eta1: Oja learning rate for W1.
            eta2: Hebbian potentiation rate for W2 (natural + teacher spikes).
            eta2_neg: anti-Hebbian depression rate for W2 on non-target natural spikes.
            teacher_rate_hz: rate of teacher spikes on the target class;
                ``teacher_rate_hz * dt`` is used as a per-step probability (capped at 1).
            decay1, decay2: multiplicative per-step weight decay for W1 / W2.
            rng: random source for teacher spikes. When None, a fresh
                ``RandomState(0)`` is used, so the teacher pattern repeats across calls.
        """
        if rng is None:
            rng = np.random.RandomState(0)
        B, T, D = spikes.shape
        C = self.num_classes
        p_teacher = min(1.0, teacher_rate_hz * self.dt)
        for b in range(B):
            v_h = np.zeros(self.hidden_dim, dtype=np.float32)
            v_o = np.zeros(C, dtype=np.float32)
            yb = int(y_int[b])
            for t in range(T):
                x_t = spikes[b, t]  # (D,)
                # Same dynamics as inference, computed with the weights as they stand before this step's update.
                v_h, v_o, s_h, s_o_nat, _ = self._step(x_t, v_h, v_o)
                # Teacher spike (supervision): it enters only the learning rule, not the output dynamics.
                teach = 1.0 if rng.rand() < p_teacher else 0.0
                y_vec = np.zeros(C, dtype=np.float32)
                y_vec[yb] = teach
                # Potentiation uses the union of natural and teacher spikes.
                post_pos = np.maximum(s_o_nat, y_vec)
                # Depression targets natural spikes on non-target classes only.
                post_neg = s_o_nat.copy()
                post_neg[yb] = 0.0
                # Oja's rule on the hidden winners: dW = eta * y * (x - y * w).
                if s_h.any():
                    J = np.where(s_h > 0)[0]
                    x_col = x_t.astype(np.float32)[:, None]
                    sj = s_h[J][None, :]
                    self.W1[:, J] += eta1 * (x_col @ sj - (sj * sj) * self.W1[:, J])
                # Supervised Hebbian for W2 with anti-Hebbian contrast.
                self.W2 += eta2 * np.outer(s_h, post_pos).astype(np.float32)
                if eta2_neg > 0:
                    self.W2 -= eta2_neg * np.outer(s_h, post_neg).astype(np.float32)
                # Weight decay
                if decay1 > 0:
                    self.W1 *= (1.0 - decay1)
                if decay2 > 0:
                    self.W2 *= (1.0 - decay2)
            # Column-wise normalization after every example.
            self._normalize_columns(self.W1)
            self._normalize_columns(self.W2)
        # Clip. Once at least one example has been normalised, every entry is already within (-1, 1).
        np.clip(self.W1, -1.0, 1.0, out=self.W1)
        np.clip(self.W2, -1.0, 1.0, out=self.W2)


In [None]:
# Train the two-layer Hebbian SNN on a subset for speed.
# Note: every sample and time step runs in a pure-Python loop, so this cell is slow.
input_dim = 28*28
hidden_dim = 256
num_classes = 10
snn = TwoLayerHebbianSNN(
    input_dim, hidden_dim, num_classes,
    dt=0.005, tau_h=0.02, tau_o=0.02, vth_h=1.0, vth_o=1.0,
    seed=3, k_hidden=32, k_output=1,
)

epochs = 20
T = 20
input_rate_hz = 30.0
rng = np.random.RandomState(123)  # training encodings and teacher spikes
valid_seed = 456                  # validation encodings are redrawn from this seed every epoch


def encode_batch(xb, yb, enc_rng):
    """Scale a batch to [0, 1], turn one-hot targets into ints, and Poisson-encode it.

    Uses the notebook-level T, input_rate_hz and snn.dt.
    """
    x01 = scale01(xb)
    y_int = yb.argmax(axis=1).astype(int)
    spikes = poisson_encode(x01, T=T, rate_hz=input_rate_hz, dt=snn.dt, rng=enc_rng)
    return spikes, y_int


for ep in range(1, epochs + 1):
    for xb, yb in train_dp:
        spikes, y_int = encode_batch(xb, yb, rng)
        snn.train_step(spikes, y_int, eta1=0.01, eta2=0.02, eta2_neg=0.01,
                       teacher_rate_hz=80.0, decay1=5e-5, decay2=5e-5, rng=rng)
    # Validate. A dedicated, re-seeded RNG keeps the validation spike trains the same across
    # epochs (valid_dp is unshuffled), so accuracy changes reflect learning rather than encoding
    # noise. It also stops validation from consuming the training random stream.
    valid_rng = np.random.RandomState(valid_seed)
    correct_v, total_v = 0, 0
    for xb, yb in valid_dp:
        spikes, y_int = encode_batch(xb, yb, valid_rng)
        preds = snn.predict(spikes)
        correct_v += int((preds == y_int).sum())
        total_v += y_int.shape[0]
    acc = correct_v / max(1, total_v)
    print(f'Epoch {ep:02d} | valid acc: {acc:.3f}')


In [None]:
# Visualize class templates by projecting W2 back to input space: W1 @ W2[:, c]
classes_to_show = [0, 1, 2, 3, 4]
fig, axes = plt.subplots(1, len(classes_to_show), figsize=(10, 3), squeeze=False)
for ax, c in zip(axes[0], classes_to_show):
    proto = (snn.W1 @ snn.W2[:, c]).reshape(28, 28)
    # Use colour limits symmetric about zero so that white in the diverging 'bwr' map means
    # zero weight and red/blue mean positive/negative. Without this, matplotlib autoscales
    # to [min, max] and the white midpoint lands on an arbitrary value.
    lim = float(np.abs(proto).max()) or 1.0
    ax.imshow(proto, cmap='bwr', vmin=-lim, vmax=lim)
    ax.axis('off')
    ax.set_title(f'class {c}')
fig.suptitle('Class templates (projected to input)')
fig.tight_layout()
plt.show()


In [None]:
# Inspect spikes for a single example (first item of a validation batch).
# Note: this encodes T=30 steps, longer than the T=20 used in the training cell.
xb, yb = next(iter(valid_dp))
xb = scale01(xb)
label = int(yb[0].argmax())
sp = poisson_encode(xb[:1], T=30, rate_hz=30.0, dt=snn.dt, rng=np.random.RandomState(0))  # (1, T, D)
pred = int(snn.predict(sp)[0])
print('True label:', label, 'Predicted:', pred)

subset = 200
raster = sp[0, :, :subset].T  # (subset, T): rows are input neurons, columns are time steps
sp_idx, t_idx = np.nonzero(raster > 0)
fig, ax = plt.subplots(figsize=(6, 3))
ax.scatter(t_idx, sp_idx, s=3, c='black')
ax.set_xlabel('Time step')
ax.set_ylabel('Input neuron (subset)')
ax.set_title(f'Input spikes (subset) | true={label} pred={pred}')
fig.tight_layout()
plt.show()
