# Spiking Neural Networks (SNNs): from LIF neuron to a tiny MNIST demo

**What you’ll learn:**  
- Why SNNs process information as **spikes over time** and where that’s useful.  
- The **Leaky Integrate-and-Fire (LIF)** neuron: intuition, equation, and a small simulation.  
- How to **encode** numbers/images as **spike trains** (Poisson rate coding).  
- How to **train** a small SNN with **surrogate gradients** on MNIST.  
- How to **visualise** spikes, voltage traces, and class spike counts.


## 1. Why SNNs at all (very short)

- **Temporal computation:** SNNs process *events through time* (closer to how sensors/brains work).
- **Sparsity:** neurons spike only when needed → fewer operations than dense activations.
- **Hardware fit:** event-driven neuromorphic chips can execute spikes very efficiently.

## 2. From perceptron to a spiking neuron

A classical **perceptron** computes:

$$
y = f(Wx + b),
$$

where $f$ is a continuous activation (e.g. sigmoid, ReLU).

In contrast, a spiking neuron integrates input over time — it has memory of its past state.

When the membrane potential $V_m(t)$ exceeds a threshold $V_{\text{th}}$, the neuron emits a spike and resets.

### The Leaky Integrate-and-Fire (LIF) neuron

We’ll use the **LIF** model later in training. It behaves like an **RC circuit** and can be described using the following differential equation:

$$
C\,\frac{dV_m(t)}{dt} \;=\; -\frac{V_m(t) - E_L}{R} \;+\; I(t),
\qquad \tau_m = RC
$$

Equivalently,

$$
\tau_m \,\frac{dV_m}{dt} \;=\; -\big(V_m - E_L\big) \;+\; R\,I(t).
$$

**Symbols:**  
- $V_m(t)$: membrane voltage (neuron’s internal state)  
- $E_L$: leak/rest voltage (decay target)  
- $I(t)$: input current (from synaptic spikes)  
- $R$: membrane resistance, $C$: membrane capacitance  
- $\tau_m = RC$: membrane time constant (leak speed)

**Meaning of terms (at a glance):**  
- $R\,I(t)$: **integration** (input pushes $V_m$ upward)  
- $-\big(V_m - E_L\big)$: **leak** (pulls $V_m$ back toward rest)  
- **Spike rule:** if $V_m \ge V_{\text{th}}$, emit a spike; reset $V_m \to V_{\text{reset}}$.
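To simulate the neuron, the equation has to be discretised. With a forward-Euler step of size $\Delta t$ (the scheme the code below uses, with $R = 1$ and $E_L = 0$), the leak becomes a multiplicative decay factor $\beta$, and the spike rule is applied after every update:

$$
V_m[t+1] \;=\; V_m[t] + \frac{\Delta t}{\tau_m}\Big(-\big(V_m[t] - E_L\big) + R\,I[t]\Big)
\;=\; \beta\,V_m[t] + (1-\beta)\big(E_L + R\,I[t]\big),
\qquad \beta = 1 - \frac{\Delta t}{\tau_m} \approx e^{-\Delta t/\tau_m}.
$$

With $\Delta t = 1$ ms and $\tau_m = 20$ ms this gives $\beta = 0.95$ (versus $e^{-0.05} \approx 0.951$): each step the membrane keeps about 95% of its deviation from rest. The same kind of per-step decay factor reappears as `beta` in the snnTorch neurons of Section 5.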

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- Params ---
T   = 2000           # total steps
dt  = 1e-3           # 1 ms bin
tau = 0.02           # 20 ms membrane time constant
v_th = 1.0           # threshold
v_reset = 0.0
E_L = 0.0            # rest (leak target)

# --- Input current I(t): several plateaus; some > 1.0 to guarantee spikes ---
I = np.zeros(T)
I[200:350]   += 0.60    # subthreshold: integrates then leaks
I[500:650]   += 1.10    # suprathreshold: periodic spiking
I[800:950]   += 0.90    # near-threshold: V approaches at most ~0.95 (< v_th), so no spike
I[1200:1600] += 1.20    # longer suprathreshold: clear spike train

# gentle sinusoidal wiggle (always nonnegative) to show smooth integration/leak
t_arr = np.arange(T)
I += 0.05 * (0.5 + 0.5*np.sin(2*np.pi*t_arr/120))   # in [0, 0.05]

# --- Simulate LIF with Euler step ---
V   = np.zeros(T)
spk = np.zeros(T, dtype=bool)

for t in range(T-1):
    dV = (-(V[t] - E_L) + I[t]) * (dt / tau)    # R=1
    V[t+1] = V[t] + dV
    if V[t+1] >= v_th:
        spk[t+1] = True
        V[t+1] = v_reset  # reset after spike

time_ms = t_arr * dt * 1000

# --- Plot ---
fig, ax = plt.subplots(2, 1, figsize=(10,6), sharex=True)

ax[0].plot(time_ms, I, label='Input current  $I(t)$')
ax[0].set_ylabel('I  (a.u.)')
ax[0].legend(loc="upper right")
ax[0].grid(alpha=0.3)

ax[1].plot(time_ms, V, label='Membrane  $V_m(t)$')
ax[1].axhline(v_th, ls='--', c='gray', alpha=0.8, label='threshold')
ax[1].vlines(time_ms[spk], ymin=v_reset, ymax=v_th*1.05, color='crimson', alpha=0.7, label='spikes')
ax[1].set_xlabel('time (ms)')
ax[1].set_ylabel('V')
ax[1].legend(loc="upper right")
ax[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

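You should see spikes only during the 1.10 and 1.20 plateaus. With $R = 1$ and $E_L = 0$ the membrane relaxes toward $V_m = I$, so the 0.60 and 0.90 plateaus (at most 0.95 including the wiggle) charge the membrane without ever reaching $V_{\text{th}} = 1$, after which it leaks back toward rest.
This simple integrate, leak, and fire mechanism is the core of how information is transmitted in an SNN.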
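For a constant input the LIF equation can be solved in closed form, which gives a quick sanity check on the simulation. Assuming $E_L = V_{\text{reset}} = 0$ and $R = 1$ (as in the code above), the membrane follows $V_m(t) = I\,(1 - e^{-t/\tau_m})$ after a reset, so it reaches threshold after the interspike interval $\tau_m \ln\!\big(I / (I - V_{\text{th}})\big)$, and never if $I \le V_{\text{th}}$. The sketch below (helper names are ours) compares that formula with the same Euler loop:

```python
import numpy as np

DT, TAU, V_TH = 1e-3, 0.02, 1.0   # same values as the simulation above (E_L = V_reset = 0, R = 1)

def lif_spike_steps(i_const, n_steps=1000, dt=DT, tau=TAU, v_th=V_TH):
    """Forward-Euler LIF under a constant current; returns the step indices of spikes."""
    v, steps = 0.0, []
    for t in range(n_steps):
        v += (-v + i_const) * (dt / tau)   # leak toward 0 and integrate R*I with R = 1
        if v >= v_th:
            steps.append(t)
            v = 0.0                         # reset to V_reset = 0
    return np.array(steps, dtype=int)

def analytic_isi(i_const, tau=TAU, v_th=V_TH):
    """Continuous-time interspike interval in seconds; infinite if the drive never reaches threshold."""
    if i_const <= v_th:
        return np.inf
    return tau * np.log(i_const / (i_const - v_th))

for i_const in (0.9, 1.1, 1.2, 2.0):
    steps = lif_spike_steps(i_const)
    sim_isi = np.diff(steps).mean() * DT if len(steps) > 1 else np.inf
    print(f"I={i_const}: simulated ISI = {sim_isi * 1e3:.1f} ms, "
          f"analytic ISI = {analytic_isi(i_const) * 1e3:.1f} ms")
```

The two intervals agree to within a few percent; the remaining gap comes from the Euler step ($\Delta t/\tau_m = 0.05$ is not tiny) and from spike times being quantised to 1 ms bins. The formula also explains why the 0.90 plateau above stays silent.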

## 3. Encoding Real-Valued Data into Spikes

Unlike artificial neural networks (ANNs), which process continuous-valued activations, spiking neural networks (SNNs) require *discrete events* — spikes — as input. Therefore, continuous data (such as pixel intensities or audio amplitudes) must first be transformed into spike trains. This process is known as **neural encoding**.

Several biologically inspired encoding schemes have been proposed, each relying on a different principle of how information is represented in spike activity:

- **Rate coding:** the information is represented by the *average firing rate* — a higher input value produces a higher spike frequency.
- **Temporal coding:** the information is encoded in the *precise timing* of spikes — larger input values result in earlier spikes within a defined time window.
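To make the contrast concrete, the NumPy sketch below encodes a single value $x \in [0,1]$ both ways. The linear latency mapping $t_{\text{spike}} = \operatorname{round}\big((1-x)(T-1)\big)$ and the names `rate_code` and `latency_code` are illustrative assumptions of ours; the rest of this tutorial uses only rate coding.

```python
import numpy as np

def rate_code(x, num_steps=100, max_rate_hz=100.0, dt=1e-3, seed=0):
    """Rate code: an independent spike with probability min(x * f_max * dt, 1) at each step."""
    rng = np.random.default_rng(seed)
    p = min(x * max_rate_hz * dt, 1.0)
    return rng.random(num_steps) < p

def latency_code(x, num_steps=100):
    """Toy temporal (latency) code: one spike, earlier for larger x; x = 0 never spikes."""
    train = np.zeros(num_steps, dtype=bool)
    if x > 0:
        t_spike = int(round((1.0 - x) * (num_steps - 1)))
        train[t_spike] = True
    return train

for x in (0.2, 0.9):
    r, l = rate_code(x), latency_code(x)
    print(f"x={x}: rate code -> {int(r.sum())} spikes, "
          f"latency code -> single spike at step {np.flatnonzero(l)[0]}")
```

The rate code spreads the information over many stochastic spikes (robust, but slow to read out), whereas the latency code needs only one spike whose exact timing carries the value.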

In this tutorial, we will employ **rate coding** using a **Poisson spike generator**.  
A Poisson process models spike events as random occurrences that follow a specific average rate.  
This stochasticity introduces variability similar to biological neurons while maintaining an expected firing rate proportional to the input magnitude.

Formally, let the firing rate be denoted by $\lambda$ (in spikes per second) and the time step by $dt$.  
The probability of emitting a spike in a small interval $dt$ is given by

$$
p = \min(\lambda \, dt, \, 1).
$$

In practice, $\lambda$ is derived from the input value (for instance, a pixel brightness scaled to a maximum firing rate).  
At each time step, we sample a random number from a uniform distribution on $[0,1)$; if it is less than $p$, a spike is generated.  
This is a discrete-time (Bernoulli) approximation of a Poisson process: the spike count over $T$ steps is binomial with mean $T\,p$, which approaches a Poisson distribution when $p \ll 1$. The mean firing rate encodes the input intensity, while the trial-to-trial variability resembles that of biological neurons.

In [None]:
import torch

def poisson_encode(images, num_steps, max_rate_hz=100.0, dt=1e-3, rng=None):
    """
    Encode continuous-valued images into Poisson spike trains.

    Args:
        images (torch.Tensor): [B, *] tensor with values in [0,1]
        num_steps (int): number of simulation time steps
        max_rate_hz (float): maximum firing rate corresponding to pixel=1.0
        dt (float): duration of a single time step (in seconds)
        rng (torch.Generator, optional): random number generator for reproducibility

    Returns:
        spikes (torch.BoolTensor): [T, B, N] tensor of spike trains
    """
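    if images.dim() > 2:
        B = images.size(0)
        flat = images.reshape(B, -1)       # e.g. [B, 1, 28, 28] -> [B, 784]
    else:
        flat = images
    B, N = flat.shape

    lam = flat * max_rate_hz               # firing rate per pixel (Hz), [B, N]
    p = torch.clamp(lam * dt, 0.0, 1.0)    # spike probability per step, [B, N]

    # Sample T independent Bernoulli trials per pixel.
    # With rng=None, torch.rand uses the global RNG, so every call (and epoch) draws a
    # new spike pattern. Creating a fresh torch.Generator() here instead would restart
    # from the same default seed on every call and repeat identical spikes each epoch.
    rand = torch.rand((num_steps, B, N), generator=rng, device=images.device)
    spikes = rand < p                      # [T, B, N], bool
    return spikes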



## 4. MNIST as a Temporally Encoded Dataset

The MNIST dataset consists of **static** grayscale images $\mathbf{x} \in [0,1]^{28\times 28}$.  
Spiking neural networks (SNNs), however, operate on **temporal event streams**. To bridge this mismatch, we convert each pixel $x_i$ into a **Poisson spike train** whose average firing rate is proportional to its intensity.

Let $f_{\max}$ denote the maximum firing rate (in Hz), and let $dt$ be the simulation time step (in seconds).  
For pixel $i$, we define the rate:

$$
\lambda_i = x_i \, f_{\max}, \qquad \lambda_i \ge 0,
$$

and generate spikes by drawing, at each discrete time step, an independent Bernoulli trial with success probability:

$$
p_i = \min(\lambda_i \, dt,\, 1).
$$

Across a window of $T$ steps (assuming $f_{\max}\,dt \le 1$, so the clamp never activates), the expected spike count for pixel $i$ is:

$$
\mathbb{E}[N_i] = \lambda_i \, T \, dt = x_i \, f_{\max} \, T \, dt,
$$

so **brighter pixels produce more spikes on average**, preserving the spatial structure statistically over time.
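A quick Monte-Carlo check of the expected-count formula, using the same settings as the encoding example later in this section ($f_{\max} = 100$ Hz, $dt = 1$ ms, $T = 100$) for one pixel at half intensity. It also shows the spread: because each step is a Bernoulli trial, the count is binomial with variance $T\,p_i(1-p_i)$, slightly below the Poisson value $\lambda_i T\,dt$; the two agree as $p_i \to 0$.

```python
import numpy as np

# One pixel at half intensity, with the tutorial's encoding settings
x_i, f_max, T, dt = 0.5, 100.0, 100, 1e-3
p_i = min(x_i * f_max * dt, 1.0)        # spike probability per step (0.05)
expected = x_i * f_max * T * dt         # E[N_i] = 5 spikes

rng = np.random.default_rng(0)
n_trials = 20_000                       # independent re-draws of the same pixel
counts = (rng.random((n_trials, T)) < p_i).sum(axis=1)

print(f"E[N_i] = {expected:.2f}, empirical mean = {counts.mean():.2f}")
print(f"binomial var T*p*(1-p) = {T * p_i * (1 - p_i):.3f}, "
      f"empirical var = {counts.var():.3f}, Poisson var would be {expected:.3f}")
```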

**Why re-draw Poisson realisations each epoch?**  
Because the Poisson generator is stochastic, re-sampling for the same image at each epoch acts as **regularisation / data augmentation**: the label remains fixed while the precise spike times vary. The network is thus encouraged to learn the **rate-coded structure** (the invariant intensity pattern) rather than overfitting to a single realisation of spike times.

> **Assumptions.** This encoding assumes that spikes are conditionally independent across pixels and time given $\lambda$ (i.e., a discrete-time approximation of a Poisson process). While simplified relative to the biological setting, it is standard and effective for SNN training on static images.

In [None]:
# --- MNIST loading (values in [0,1]) ---
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.ToTensor()  # keeps pixels in [0,1]
train_ds = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)

subset_size = 10000  # e.g. 10k out of 60k
indices = torch.randperm(len(train_ds))[:subset_size]
train_ds = Subset(train_ds, indices)

test_ds  = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

batch_size = 64
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, drop_last=False)

len(train_loader), len(test_loader)

In [None]:
# --- Example: encode one minibatch for T time steps (re-draw every call/epoch) ---
num_steps = 100
dt = 1e-3
max_rate_hz = 100.0

batch_imgs, batch_labels = next(iter(train_loader))
batch_imgs = batch_imgs.to(device)   # [B,1,28,28]
spikes = poisson_encode(batch_imgs, num_steps=num_steps, max_rate_hz=max_rate_hz, dt=dt)  # [T,B,784]
spikes.shape, batch_labels.shape

In [None]:
# --- Visualisation: original image, spike raster (time × pixels), and spike-rate map ---
import matplotlib.pyplot as plt

# pick the first sample in the batch
img0 = batch_imgs[0]                  # [1,28,28]
lab0 = batch_labels[0].item()
spk0 = spikes[:, 0]                   # [T, 784]
spk0_img = spk0.view(num_steps, 28, 28)

# spike count per pixel across time (estimate of rate × window)
spk_count_map = spk0_img.sum(dim=0).cpu()

fig = plt.figure(figsize=(11,3.6))

# Original image
ax1 = plt.subplot(1,3,1)
ax1.imshow(img0.squeeze().cpu(), cmap="gray")
ax1.set_title(f"MNIST image (label={lab0})")
ax1.axis("off")

# Spike raster as time × pixel-index image (flattened spatially)
ax2 = plt.subplot(1,3,2)
ax2.imshow(spk0.cpu().T, aspect="auto", interpolation="nearest", cmap="gray_r")
ax2.set_title("Spike raster (time × pixels)")
ax2.set_xlabel("Time step")
ax2.set_ylabel("Pixel index")

# Spike count map (approx. rate over window)
ax3 = plt.subplot(1,3,3)
im = ax3.imshow(spk_count_map, cmap="viridis")
ax3.set_title("Spike counts per pixel\n(∝ intensity × window)")
ax3.axis("off")
plt.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

## 5. Training a Spiking Neural Network on MNIST

Now that we can represent MNIST digits as spike trains, we can train a simple spiking neural network (SNN) to classify them.  
Our architecture will consist of two fully connected layers with leaky integrate-and-fire (LIF) neurons:

$$
\text{Input spikes (784)} \;\longrightarrow\; \text{Linear} \;\longrightarrow\; \text{LIF} \;\longrightarrow\; \text{Linear} \;\longrightarrow\; \text{LIF (10 outputs)}
$$

Each of the ten output neurons corresponds to one digit class (0–9).  
During the simulation window of $T$ time steps, the network produces spikes at its output layer.  
We decode the prediction by **counting spikes per output neuron** over time and choosing the class with the **highest total spike count**:

$$
\hat{y} = \arg\max_{k \in \{0,\ldots,9\}} \sum_{t=1}^{T} s_k(t),
$$

where $s_k(t) \in \{0,1\}$ is the binary spike output of neuron $k$ at time step $t$.
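A toy illustration of this decoding rule with made-up output spikes ($T = 5$, ten classes). One detail matters in practice: NumPy's `argmax` returns the first index among ties, and PyTorch documents the same convention for `torch.argmax`, so a completely silent output layer is decoded as class 0. Early in training, when few output neurons fire, this can make accuracy look better on class 0 than it really is.

```python
import numpy as np

# Hypothetical output spikes s_k(t): rows are time steps, columns are classes 0..9
s = np.zeros((5, 10), dtype=int)
s[[0, 2, 4], 3] = 1            # class 3 fires 3 times
s[[1, 3], 7] = 1               # class 7 fires twice

counts = s.sum(axis=0)         # sum over t of s_k(t), one count per class
y_hat = int(counts.argmax())   # class with the most spikes

silent = np.zeros((5, 10), dtype=int)
print(counts, y_hat, int(silent.sum(axis=0).argmax()))
```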

---

### The challenge of non-differentiability

The spike generation function is a **Heaviside step function**:

$$
s(t) = H(V_m(t) - V_{\text{th}}) =
\begin{cases}
1, & V_m(t) \ge V_{\text{th}}, \\
0, & V_m(t) < V_{\text{th}},
\end{cases}
$$

which is discontinuous at $V_{\text{th}}$ and therefore **non-differentiable** there. Its derivative, in the distributional sense, is the **Dirac delta function**:

$$
\frac{ds}{dV_m} = \delta(V_m - V_{\text{th}}),
$$

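which is zero everywhere except at $V_m = V_{\text{th}}$, where it is infinite. Backpropagating through this derivative is useless: every weight feeding a neuron whose membrane is not exactly at threshold receives a zero gradient, so the network gets no learning signal (often called the *dead neuron* problem).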

---

### Surrogate gradients

To enable backpropagation through spikes, modern SNNs use **surrogate gradient methods**.  
The idea is to replace the true derivative of the step function with a **smooth, differentiable approximation**, such as:

$$
\frac{ds}{dV_m} \approx \frac{1}{(1 + \alpha |V_m - V_{\text{th}}|)^2},
$$

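where $\alpha$ controls the steepness of the approximation (a larger $\alpha$ concentrates the gradient more tightly around the threshold).  
This choice is exactly the derivative of the "fast sigmoid" $x / (1 + \alpha |x|)$ evaluated at $x = V_m - V_{\text{th}}$; it is the surrogate selected later with `surrogate.fast_sigmoid()`, whose `slope` argument plays the role of $\alpha$. The forward pass still emits exact binary spikes and only the backward pass uses the smooth surrogate, so the network can be trained with standard optimizers (e.g., Adam) while keeping its spiking dynamics.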
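A minimal NumPy sketch of the idea for a single neuron and a single time step with no leak ($V_m = w\,x$). The helper names and the value $\alpha = 25$ are our own choices for illustration. The neuron sits just below threshold, so it does not spike and the true derivative of the step is 0, yet the surrogate still returns a nonzero gradient with respect to the weight via the chain rule $\frac{ds}{dw} = \frac{ds}{dV_m}\,\frac{dV_m}{dw}$:

```python
import numpy as np

def spike_forward(v, v_th=1.0):
    """Forward pass: the exact Heaviside step."""
    return (v >= v_th).astype(float)

def spike_surrogate_grad(v, v_th=1.0, alpha=25.0):
    """Backward pass: fast-sigmoid surrogate 1 / (1 + alpha*|v - v_th|)^2 instead of the Dirac delta."""
    return 1.0 / (1.0 + alpha * np.abs(v - v_th)) ** 2

x, w = 0.8, 1.2
v = np.array(w * x)                          # 0.96: just below threshold
s = spike_forward(v)                         # no spike in the forward pass
ds_dw_true = 0.0                             # the step function is flat at v = 0.96
ds_dw_surr = spike_surrogate_grad(v) * x     # ds/dv (surrogate) * dv/dw

print(float(s), ds_dw_true, float(ds_dw_surr))
```

With $|V_m - V_{\text{th}}| = 0.04$ and $\alpha = 25$, the surrogate equals $1/(1+1)^2 = 0.25$, so the weight gradient is $0.25 \times 0.8 = 0.2$: the weight can still be nudged upward until the neuron starts firing.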

---

### Practical setup

We will use the `snnTorch` library, which provides efficient implementations of the LIF neuron model and surrogate-gradient learning.  
The training process will proceed as follows:

1. **Encode** each image into Poisson spike trains.  
2. **Propagate** spikes through the network for $T$ time steps.  
3. **Compute** the cross-entropy loss between the output spike counts and the true labels.  
4. **Update** network parameters via backpropagation using surrogate gradients.

This framework keeps the temporal dynamics of spiking neurons while making learning possible with standard deep-learning toolchains.

### 5.1 Model definition (Linear → LIF → Linear → LIF)

We implement a compact feed-forward SNN:

$$
\text{Input }(784) \;\to\; \text{Linear} \;\to\; \text{LIF} \;\to\; \text{Linear} \;\to\; \text{LIF (10 outputs)}.
$$

The forward pass expects spike trains of shape $[T, B, N]$ (time, batch, features).  
At each time step, spikes are propagated through the two LIF layers; we record the output spikes $s_k(t)$ to compute a **count-based loss** and predictions:

$$
\hat{y} \;=\; \arg\max_{k} \sum_{t=1}^{T} s_k(t).
$$

In [None]:
!pip install snntorch==0.9.4 torchvision

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

import snntorch as snn
from snntorch import surrogate
import snntorch.functional as SF

import matplotlib.pyplot as plt

# Device: CPU by default, GPU if available (code works in both)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

torch.manual_seed(0)
if device.type == "cuda":
    torch.cuda.manual_seed_all(0)

#### SNN module

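We implement two `nn.Linear` layers, each followed by an `snn.Leaky` layer of LIF neurons.  
Each `snn.Leaky` uses the fast-sigmoid surrogate gradient (passed explicitly via `spike_grad=surrogate.fast_sigmoid()`), and its decay parameter $\beta \in (0,1)$ controls the membrane leak: the membrane is multiplied by $\beta$ at every step, so $\beta$ close to 1 means a slow leak. By default `snn.Leaky` fires at a threshold of 1 and resets by subtracting the threshold, rather than resetting to $V_{\text{reset}}$ as in the Section 2 simulation.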
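To see what one such neuron does per step, here is a NumPy sketch of a discrete LIF update with multiplicative decay $\beta$ and reset-by-subtraction. `leaky_step`, `beta_from_tau` and `tau_from_beta` are hypothetical helpers, not snnTorch functions; for readability the sketch subtracts the threshold in the same step that produced the spike, and snnTorch's internal ordering of the reset may differ, so treat it as a description of the dynamics rather than a reimplementation. If each step is taken to be 1 ms, matching $\beta$ to the continuous leak via $\beta = e^{-\Delta t/\tau_m}$ shows that `beta=0.9` corresponds to $\tau_m \approx 9.5$ ms.

```python
import numpy as np

def leaky_step(mem, inp, beta=0.9, threshold=1.0):
    """One discrete LIF step: decay, integrate, fire, then reset by subtracting the threshold."""
    mem = beta * np.asarray(mem, dtype=float) + inp   # leak (decay by beta) + integrate input
    spk = (mem >= threshold).astype(float)            # Heaviside spike (forward pass)
    mem = mem - spk * threshold                       # subtract threshold where a spike occurred
    return spk, mem

def beta_from_tau(tau, dt=1e-3):
    return np.exp(-dt / tau)

def tau_from_beta(beta, dt=1e-3):
    return -dt / np.log(beta)

# Constant input 0.3: the membrane climbs 0.3 -> 0.57 -> 0.813 -> 1.0317 and spikes at step 3
mem, spikes = 0.0, []
for t in range(12):
    spk, mem = leaky_step(mem, 0.3, beta=0.9)
    spikes.append(int(spk))

print(spikes)
print(f"beta=0.9 at dt=1 ms <-> tau of about {tau_from_beta(0.9) * 1e3:.1f} ms")
print(f"tau=20 ms at dt=1 ms <-> beta of about {beta_from_tau(0.02):.3f}")
```

Reset-by-subtraction keeps the "excess" charge above threshold (here 0.0317) instead of discarding it, which preserves a little more information about the input strength than a hard reset to zero.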

In [None]:
class TinySNN(nn.Module):
    def __init__(self, in_dim=28*28, hidden=128, out_dim=10, beta_hidden=0.9, beta_out=0.9):
        super().__init__()
        self.in_dim = in_dim
        self.hidden = hidden
        self.out_dim = out_dim

        self.fc1  = nn.Linear(in_dim, hidden, bias=True)
        self.lif1 = snn.Leaky(beta=beta_hidden, spike_grad=surrogate.fast_sigmoid())
        self.fc2  = nn.Linear(hidden, out_dim, bias=True)
        self.lif2 = snn.Leaky(beta=beta_out, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x_seq, return_mem=False):
        """
        x_seq: [T, B, N] spike trains (bool or float)
        returns:
            spk_out: [T, B, out_dim]
            if return_mem=True: also (mem1, mem2) traces
        """
        T, B, N = x_seq.shape
        assert N == self.in_dim, f"Expected input dim {self.in_dim}, got {N}"

        # Initialise LIF hidden states at t=0
        mem1 = self.lif1.init_leaky()   # no args in snnTorch 0.9.4
        mem2 = self.lif2.init_leaky()

        spk_rec = []
        mem1_rec = [] if return_mem else None
        mem2_rec = [] if return_mem else None

        for t in range(T):
            x_t = x_seq[t].float()               # [B, N]
            h1 = self.fc1(x_t)
            spk1, mem1 = self.lif1(h1, mem1)
            h2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(h2, mem2)

            spk_rec.append(spk2)

            if return_mem:
                mem1_rec.append(mem1)
                mem2_rec.append(mem2)

        spk_rec = torch.stack(spk_rec, dim=0)    # [T, B, out_dim]

        if return_mem:
            mem1_rec = torch.stack(mem1_rec, dim=0)  # [T, B, hidden]
            mem2_rec = torch.stack(mem2_rec, dim=0)  # [T, B, out_dim]
            return spk_rec, (mem1_rec, mem2_rec)

        return spk_rec

### 5.2 Data pipeline and training loop

- **Input:** MNIST, normalized to $[0,1]$.
- **Encoding:** Poisson rate encoding for $T$ time steps (re-sampled each iteration).
- **Loss:** Cross-entropy on **spike counts** (sum over time), i.e., logits $=\sum_t s_k(t)$.
- **Prediction:** $\arg\max_k$ of spike counts over the window.
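The loss used below, `SF.ce_count_loss()`, is documented by snnTorch as accumulating the output spikes over time and applying cross-entropy to the resulting counts. The NumPy function below is our own sketch of that idea (spike counts used directly as logits, followed by a numerically stable log-softmax), not snnTorch's code:

```python
import numpy as np

def ce_on_counts(spk_rec, labels):
    """Cross-entropy with spike counts as logits. spk_rec: [T, B, C] of 0/1, labels: [B] ints."""
    counts = spk_rec.sum(axis=0).astype(float)              # [B, C]: logits = sum over t of s_k(t)
    shifted = counts - counts.max(axis=1, keepdims=True)    # subtract max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# Hypothetical example: T = 5 steps, one sample, 10 classes; the correct class (1) fires every step
spk_rec = np.zeros((5, 1, 10))
spk_rec[:, 0, 1] = 1
print(ce_on_counts(spk_rec, np.array([1])))   # log(1 + 9*exp(-5)), about 0.0589
```

Because counts are bounded by $T$, the logits lie in $[0, T]$; a longer window lets the network express more confident predictions, while a silent output layer gives uniform probabilities and a loss of $\ln 10$.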

In [None]:
model = TinySNN().to(device)

# Training hyperparameters
NUM_STEPS   = 50      # time steps T
DT          = 1e-3    # 1 ms
MAX_RATE_HZ = 100.0   # max firing rate for pixel=1

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Cross-entropy on spike *counts* over time
# (snnTorch helper: ce_count_loss)
loss_fn = SF.ce_count_loss()

In [None]:
from tqdm.auto import tqdm

def train_one_epoch(epoch):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        imgs   = imgs.to(device)
        labels = labels.to(device)

        # Poisson encode: [T, B, 784]
        spikes = poisson_encode(
            imgs, num_steps=NUM_STEPS,
            max_rate_hz=MAX_RATE_HZ, dt=DT
        ).to(device)

        # Forward: [T, B, 10]
        spk_rec = model(spikes)

        # Loss: snnTorch cross-entropy on spike counts
        loss = loss_fn(spk_rec, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # --- Metrics (correct decoding) ---
        with torch.no_grad():
            counts = spk_rec.sum(dim=0)          # [B,10]
            preds  = counts.argmax(dim=1)        # [B]
            total_correct += (preds == labels).sum().item()
            total_seen    += labels.size(0)
            total_loss    += loss.item() * labels.size(0)

        # free graph
        del spikes, spk_rec, loss

    avg_loss = total_loss / total_seen
    acc = total_correct / total_seen
    print(f"Epoch {epoch}: train loss={avg_loss:.4f} acc={acc:.4f}")

    return avg_loss, acc

@torch.no_grad()
def evaluate():
    model.eval()
    total_correct = 0
    total_seen = 0

    for imgs, labels in test_loader:
        imgs   = imgs.to(device)
        labels = labels.to(device)

        spikes = poisson_encode(
            imgs, num_steps=NUM_STEPS,
            max_rate_hz=MAX_RATE_HZ, dt=DT
        ).to(device)

        spk_rec = model(spikes)

        # correct decoding
        counts = spk_rec.sum(dim=0)      # [B,10]
        preds  = counts.argmax(dim=1)

        total_correct += (preds == labels).sum().item()
        total_seen    += labels.size(0)

        del spikes, spk_rec

    acc = total_correct / total_seen
    print(f"Test acc={acc:.4f}")
    return acc


In [None]:
train_hist = {"loss": [], "acc": [], "val_acc": []}

for ep in range(1, 16):  # 15 epochs; reduce (e.g. to 3) for a quick CPU run
    tr_loss, tr_acc = train_one_epoch(ep)
    va_acc = evaluate()
    train_hist["loss"].append(tr_loss)
    train_hist["acc"].append(tr_acc)
    train_hist["val_acc"].append(va_acc)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def viz_one_example():
    model.eval()

    imgs, labels = next(iter(test_loader))
    imgs = imgs.to(device)
    labels = labels.to(device)

    # reproducible spikes for this example
    g = torch.Generator(device=imgs.device).manual_seed(0)
    spikes = poisson_encode(
        imgs[:1],
        num_steps=NUM_STEPS,
        max_rate_hz=MAX_RATE_HZ,
        dt=DT,
        rng=g,
    ).to(device)

    # run model, record membranes
    spk_rec, (mem1, mem2) = model(spikes, return_mem=True)  # [T,1,10], [T,1,H], [T,1,10]

    spk = spk_rec[:, 0].cpu().numpy()   # [T,10]
    counts = spk.sum(axis=0)            # [10]
    pred = int(counts.argmax())
    true = int(labels[0].item())

    fig, axes = plt.subplots(2, 1, figsize=(8, 6))
    # make sure axes is always indexable as an array
    axes = np.atleast_1d(axes)

    # --- Raster: time vs class id ---
    t_idx, c_idx = np.where(spk > 0.5)
    axes[0].scatter(t_idx, c_idx, s=8)
    axes[0].invert_yaxis()
    axes[0].set_title(f"Output spikes over time — true={true}, pred={pred}")
    axes[0].set_xlabel("time step")
    axes[0].set_ylabel("class id")

    # --- Spike counts bar plot ---
    axes[1].bar(np.arange(10), counts)
    axes[1].set_xlabel("class")
    axes[1].set_ylabel("spike count")
    axes[1].set_title("Spike counts over window")

    plt.tight_layout()
    plt.show()

viz_one_example()