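This notebook provides a minimal example of using LFP to train a simple MLP-Spiking Neural Network (SNN) on MNIST, alongside a surrogate-gradient baseline trained on the same data. Note: the parameters cell below currently selects `dataset_name = "cifar10"` and `model_name = "lifresnetlike"`; adjust both to reproduce the MLP/MNIST setting.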

For more complex examples, refer to the experiment notebooks in `./nbs`.

In [1]:
# snntorch is an optional dependency: fail early with installation instructions if it is missing.
try:
    import snntorch as snn
    from snntorch import utils as snnutils
except ImportError as err:
    # Adjacent string literals are concatenated into ONE message; passing them as two
    # print() arguments would insert an extra separator space between the sentences.
    print(
        "The SNN functionality of this package requires extra dependencies "
        "which can be installed via pip install lfprop[snn] (or lfprop[full] for all dependencies).",
    )
    # Chain explicitly so the traceback shows the original import failure as the direct cause.
    raise ImportError("snntorch required; reinstall lfprop with option `snn` (pip install lfprop[snn])") from err

### Imports

In [2]:
import os

import numpy as np
import torch
import torcheval.metrics
import torchvision.datasets as tvisiondata
import torchvision.transforms as T
from tqdm import tqdm

import snntorch.functional as SF

from experiment_utils.data.datasets import get_dataset
from lfprop.model import activations
from experiment_utils.data.dataloaders import get_dataloader
from experiment_utils.data.transforms import get_transforms

from lfprop.rewards import reward_functions as rewards  # Reward Functions

  from .autonotebook import tqdm as notebook_tqdm


### Parameters

In [None]:
# Directory for artifacts produced by this notebook (created if it does not exist yet).
savepath = "./minimal-example-data"
os.makedirs(savepath, exist_ok=True)

dataset_name = "cifar10"
# Dataset root directory. Set the LFP_DATA_ROOT environment variable to point at your own
# copy of the datasets; the fallback is the original author's local path and will not
# exist on other machines.
data_root = os.environ.get("LFP_DATA_ROOT", "/media/lweber/f3ed2aae-a7bf-4a55-b50d-ea8fb534f1f51/Datasets")
data_path = f"{data_root}/{dataset_name}"
training_mode = "lfp"  # options: "lfp", "surr", "both"

batch_size = 32  # 128
n_channels = 1 if dataset_name == "mnist" else 3  # one input channel for "mnist", three otherwise
n_outputs = 10
n_steps = 15  # number of forward passes (time steps) per batch
lr = 0.02
momentum = 0.9
epochs = 20
model_name = "lifresnetlike"  # alternative: "deepersnn"

# Neuron settings for the LFP-trained model.
# Previously tried alternatives: "surrogate_disable": True (without "spike_grad"), or "spike_grad": "atan".
lif_kwargs = {"beta": 0.9, "reset_mechanism": "subtract", "surrogate_disable": False, "spike_grad": "step"}
# Neuron settings for the surrogate-gradient baseline.
# Previously tried alternative: "spike_grad": "step".
surrogate_kwargs = {"beta": 0.9, "reset_mechanism": "subtract", "surrogate_disable": False, "spike_grad": "atan"}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load Dataset

In [4]:
# Arguments shared by the train and test split requests.
dataset_args = {"dataset_name": dataset_name, "root_path": data_path}

# get_dataset is unpacked as a 4-tuple; only its first element (the dataset) is used here.
training_data, _, _, _ = get_dataset(
    transform=get_transforms(dataset_name, "train"),
    mode="train",
    **dataset_args,
)
validation_data, _, _, _ = get_dataset(
    transform=get_transforms(dataset_name, "test"),
    mode="test",
    **dataset_args,
)

# Shuffle only the training split; validation batches keep the dataset order.
loader_cls = torch.utils.data.DataLoader
training_loader = loader_cls(training_data, batch_size=batch_size, shuffle=True)
validation_loader = loader_cls(validation_data, batch_size=batch_size, shuffle=False)

OSError: None is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

### Load Model

In [None]:
from experiment_utils.model.models import get_model

# Constructor arguments common to both networks; only the neuron keyword arguments differ.
shared_model_args = dict(model_name=model_name, n_channels=n_channels, n_outputs=n_outputs, device=device)

# Network trained with LFP; its freshly initialized weights are the common starting point.
model = get_model(**shared_model_args, **lif_kwargs)
model.reset()
model.to(device)
model.eval()

# Network trained with surrogate gradients, started from a copy of the same weights.
model_surr = get_model(**shared_model_args, **surrogate_kwargs)
model_surr.load_state_dict(model.state_dict())
model_surr.reset()
model_surr.to(device)
model_surr.eval()

def name_modules(module, name):
    """
    Tag every descendant of `module` with a dotted path stored in `.tmpname` (debugging aid).

    Children of a module passed in with name "root" receive their bare child name;
    all deeper modules receive "<parent path>.<child name>". The function recurses
    depth-first through `named_children()`.
    """
    at_root = name == "root"
    for child_name, submodule in module.named_children():
        if at_root:
            submodule.tmpname = child_name
        else:
            submodule.tmpname = f"{name}.{child_name}"
        name_modules(submodule, submodule.tmpname)

# Attach debug names to the submodules of both networks.
for _net in (model, model_surr):
    name_modules(_net, "root")

### Set Up LFP

In [None]:
# Initialize the SNN-Propagator
from lfprop.propagation.propagator_snn import LFPSNNEpsilonComposite

# SNN propagation composite; lfp_step enters its context around the forward and backward pass.
propagation_composite = LFPSNNEpsilonComposite(epsilon=1e-6)

# Reward function, used for the training feedback as well as for evaluation.
reward_func = rewards.SnnCorrectClassRewardSpikesRateCoded(device)

# The LFP updates end up in the .grad attribute of the model parameters,
# so a standard torch optimizer can apply them.
lfp_sgd_args = {"lr": lr, "momentum": momentum}
optimizer = torch.optim.SGD(model.parameters(), **lfp_sgd_args)

### Set Up Gradient Descent

In [None]:
# Surrogate-gradient baseline: snntorch's ce_rate_loss on the recorded output spikes,
# optimized by SGD with the same learning rate and momentum as the LFP model.
loss_fn_surr = SF.loss.ce_rate_loss()
optimizer_surr = torch.optim.SGD(params=model_surr.parameters(), lr=lr, momentum=momentum)

### Set Up Simple Evaluation using torcheval

In [None]:
def eval_model(loader, n_steps: int = 15):
    """
    Evaluates the LFP-trained `model` on a single dataset.

    Args:
        loader: iterable with a length that yields (inputs, labels) batches.
        n_steps: number of forward passes (time steps) the model is run for on every batch.

    Returns:
        dict with the keys "reward" and "accuracy"; each value is the computed
        metric converted to a numpy array.
    """
    eval_metrics = {
        "reward": torcheval.metrics.Mean(device=device),
        # n_outputs (instead of a literal 10) keeps this in line with eval_model_surr
        "accuracy": torcheval.metrics.MulticlassAccuracy(average="micro", num_classes=n_outputs, k=1, device=device),
    }

    model.eval()

    # Iterate over Data Loader
    for inputs, labels in tqdm(loader, desc="Evaluating", total=len(loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            # Reset before EVERY batch (as lfp_step does during training) so that state
            # accumulated over the time steps of one batch cannot leak into the next one.
            model.reset()

            # Run the model for n_steps time steps and record spikes and membrane potentials
            u_rec, spk_rec = [], []
            for _ in range(n_steps):
                spk_out, u_out = model(inputs)
                u_rec.append(u_out)
                spk_rec.append(spk_out)

            # Stack along a new leading (time-step) dimension
            spikes = torch.stack(spk_rec, dim=0)
            membrane_potential = torch.stack(u_rec, dim=0)

            # Get rewards and predictions
            reward = reward_func(spikes=spikes, potentials=membrane_potential, labels=labels)
            outputs = reward_func.get_predictions(spikes=spikes, potentials=membrane_potential)

        eval_metrics["reward"].update(reward)
        eval_metrics["accuracy"].update(outputs, labels)

    return_dict = {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}
    # Leave the model in a clean state for whatever runs next
    model.reset()
    # Return evaluation
    return return_dict

def eval_model_surr(loader, n_steps=15):
    """
    Evaluates the surrogate-gradient model `model_surr` on a single dataset.

    Args:
        loader: iterable with a length that yields (inputs, labels) batches.
        n_steps: number of forward passes (time steps) the model is run for on every batch.

    Returns:
        dict with the keys "loss" and "accuracy"; each value is the computed
        metric converted to a numpy array.
    """
    eval_metrics = {
        "loss": torcheval.metrics.Mean(device=device),
        "accuracy": torcheval.metrics.MulticlassAccuracy(average="micro", num_classes=n_outputs, k=1, device=device),
    }
    model_surr.eval()

    for inputs, labels in tqdm(loader, desc="Evaluating (Grad)", total=len(loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            # Reset before EVERY batch (as grad_step does during training) so that state
            # accumulated over the time steps of one batch cannot leak into the next one.
            model_surr.reset()

            spk_rec = []
            for _ in range(n_steps):
                spk_out, _ = model_surr(inputs)
                spk_rec.append(spk_out)
            # Stack along a new leading (time-step) dimension
            spikes = torch.stack(spk_rec, dim=0)
            loss = loss_fn_surr(spk_out=spikes, targets=labels)
            # Prediction = output index with the highest spike count summed over the time steps
            outputs = spikes.sum(0).argmax(-1)
        eval_metrics["loss"].update(loss)
        eval_metrics["accuracy"].update(outputs, labels)
    return {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}


### Training Loop

In [None]:
import time

def grad_step(inputs, labels, n_steps=15):
    """
    Performs a single surrogate-gradient training step on `model_surr`.

    Args:
        inputs: input batch, already on the target device.
        labels: target labels for the batch, already on the target device.
        n_steps: number of forward passes (time steps) for this batch.

    Returns:
        Tuple (loss, forward_time, backward_time): the loss as a Python float and the
        wall-clock seconds spent in the forward loop and in loss + backward + clipping +
        optimizer step, respectively.
    """

    def _timestamp():
        # CUDA kernels are launched asynchronously: wait for outstanding work so the
        # measured intervals cover the computation and not only the kernel launches.
        if inputs.is_cuda:
            torch.cuda.synchronize(inputs.device)
        return time.perf_counter()

    model_surr.train()
    model_surr.reset()
    optimizer_surr.zero_grad()

    # Forward pass through time
    spk_rec = []
    forward_start = _timestamp()
    for _ in range(n_steps):
        spk_out, _ = model_surr(inputs)
        spk_rec.append(spk_out)
    spikes = torch.stack(spk_rec, dim=0)
    forward_time = _timestamp() - forward_start

    # Compute loss and backward
    backward_start = _timestamp()
    loss = loss_fn_surr(spk_out=spikes, targets=labels)
    loss.backward()
    # Clip the total gradient 2-norm to at most 3.0 before the update
    torch.nn.utils.clip_grad_norm_(model_surr.parameters(), max_norm=3.0, norm_type=2.0)
    optimizer_surr.step()
    backward_time = _timestamp() - backward_start

    model_surr.reset()
    model_surr.eval()
    return loss.item(), forward_time, backward_time

def lfp_step(inputs, labels, n_steps: int = 15, print_model: bool = False):
    """
    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.

    Args:
        inputs: input batch, already on the target device.
        labels: target labels for the batch, already on the target device.
        n_steps: number of forward passes (time steps) for this batch.
        print_model: if True, print the model as modified by the propagation composite.

    Returns:
        Tuple (forward_time, backward_time): wall-clock seconds spent in the forward loop,
        and in the feedback propagation plus the optimizer step, respectively.
    """

    def _timestamp():
        # CUDA kernels are launched asynchronously: wait for outstanding work so the
        # measured intervals cover the computation and not only the kernel launches.
        if inputs.is_cuda:
            torch.cuda.synchronize(inputs.device)
        return time.perf_counter()

    model.train()
    model.reset()

    optimizer.zero_grad()
    with propagation_composite.context(model) as modified:
        # The backward call below differentiates w.r.t. the inputs, so they must require grad
        inputs = inputs.detach().requires_grad_(True)

        if print_model:
            print(modified)

        # Forward pass
        spk_rec = []
        forward_start = _timestamp()
        for _ in range(n_steps):
            spk_out, _ = modified(inputs)
            spk_rec.append(spk_out)
        spikes = torch.stack(spk_rec, dim=0)
        forward_time = _timestamp() - forward_start

        # Reward, detached from the graph and scaled by the number of time steps
        # (out-of-place division, so the tensor returned by reward_func is not modified)
        reward = reward_func(spikes=spikes, labels=labels).detach().to(device) / n_steps

        # Backward pass. The returned input gradient is intentionally discarded.
        # NOTE(review): presumably the composite stores the per-parameter feedback in
        # `param.feedback` during this call -- confirm against the propagator implementation.
        propagate_start = _timestamp()
        torch.autograd.grad((spikes,), (inputs,), grad_outputs=(reward,), retain_graph=False)
        propagate_time = _timestamp() - propagate_start

    for name, param in model.named_parameters():
        if not hasattr(param, "feedback"):
            # Report and skip: dereferencing the missing attribute would raise an AttributeError.
            print(f"Parameter {name} does not have feedback attribute.")
            continue
        # The optimizer minimizes, so the feedback is written with a flipped sign
        param.grad = -param.feedback

    # Clip the total update 2-norm to at most 3.0 before applying it
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0, norm_type=2.0)

    update_start = _timestamp()
    optimizer.step()
    update_time = _timestamp() - update_start

    # Writing .grad and clipping are not part of the reported time
    backward_time = propagate_time + update_time
    model.reset()
    model.eval()
    return forward_time, backward_time

# Initialize dictionary to store timing info per epoch
# (keyed by the zero-based epoch index; the evaluation results are added to each entry as well)
timing_stats = {}
# Initialize dictionary to store evaluation stats per epoch
# (keyed by the zero-based epoch index)
eval_stats_per_epoch = {}

# Training Loop
# Every batch is used for one LFP step on `model` AND one surrogate-gradient step on
# `model_surr`, so both networks see identical data in identical order.
# NOTE(review): `training_mode` from the parameters cell is not consulted in this loop;
# both training steps always run -- confirm whether that is intended.
for epoch in range(epochs):
    # Per-epoch accumulators for the wall-clock seconds reported by the two step functions
    lfp_fwd_total, lfp_bwd_total = 0.0, 0.0
    surr_fwd_total, surr_bwd_total = 0.0, 0.0
    num_batches = 0

    # Progress bar that is advanced manually, once per processed batch
    with tqdm(total=len(training_loader)) as pbar:
        for index, (inputs, labels) in enumerate(training_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # LFP step
            # (the modified model is printed only for the very first batch of the first epoch)
            lfp_fwd, lfp_bwd = lfp_step(inputs, labels, n_steps=n_steps, print_model=index==0 and epoch==0)
            lfp_fwd_total += lfp_fwd
            lfp_bwd_total += lfp_bwd

            # Surrogate gradient step
            # (surr_loss is returned by grad_step but not used further in this loop)
            surr_loss, surr_fwd, surr_bwd = grad_step(inputs, labels, n_steps=n_steps)
            surr_fwd_total += surr_fwd
            surr_bwd_total += surr_bwd

            num_batches += 1
            pbar.update(1)

    # Store timing info for this epoch
    # (sums over all batches of the epoch; evaluation time is not included)
    timing_stats[epoch] = {
        "lfp_forward": lfp_fwd_total,
        "lfp_backward": lfp_bwd_total,
        "lfp_total": lfp_fwd_total + lfp_bwd_total,
        "surr_forward": surr_fwd_total,
        "surr_backward": surr_bwd_total,
        "surr_total": surr_fwd_total + surr_bwd_total,
        "num_batches": num_batches
    }

    # Evaluate both models
    # (each on the training and on the validation loader)
    eval_stats_train_lfp = eval_model(training_loader, n_steps=n_steps)
    eval_stats_val_lfp = eval_model(validation_loader, n_steps=n_steps)
    eval_stats_train_surr = eval_model_surr(training_loader, n_steps=n_steps)
    eval_stats_val_surr = eval_model_surr(validation_loader, n_steps=n_steps)

    # Store evaluation stats for this epoch
    eval_stats_per_epoch[epoch] = {
        "eval_train_lfp": eval_stats_train_lfp,
        "eval_val_lfp": eval_stats_val_lfp,
        "eval_train_surr": eval_stats_train_surr,
        "eval_val_surr": eval_stats_val_surr
    }

    # Also store in timing_stats for compatibility
    # (same dict objects as in eval_stats_per_epoch, not copies)
    timing_stats[epoch]["eval_train_lfp"] = eval_stats_train_lfp
    timing_stats[epoch]["eval_val_lfp"] = eval_stats_val_lfp
    timing_stats[epoch]["eval_train_surr"] = eval_stats_train_surr
    timing_stats[epoch]["eval_val_surr"] = eval_stats_val_surr

    # Print aggregated timing info for this epoch
    print(
        f"Epoch {epoch+1}/{epochs} Timing:\n"
        f"  LFP:   forward {lfp_fwd_total:.4f}s, backward {lfp_bwd_total:.4f}s\n"
        f"  Grad:   forward {surr_fwd_total:.4f}s, backward {surr_bwd_total:.4f}s"
    )
    # Print the evaluation metrics of both models for this epoch
    print(
        f"Epoch {epoch+1}/{epochs}:\n"
        f"  LFP:   (Train Reward) {eval_stats_train_lfp['reward']:.2f}; (Train Acc) {eval_stats_train_lfp['accuracy']:.2f}; "
        f"(Val Reward) {eval_stats_val_lfp['reward']:.2f}; (Val Acc) {eval_stats_val_lfp['accuracy']:.2f}\n"
        f"  Grad:   (Train Loss) {eval_stats_train_surr['loss']:.2f}; (Train Acc) {eval_stats_train_surr['accuracy']:.2f}; "
        f"(Val Loss) {eval_stats_val_surr['loss']:.2f}; (Val Acc) {eval_stats_val_surr['accuracy']:.2f}"
    )


  0%|          | 0/391 [00:00<?, ?it/s]

ResNet(
  (block1): LFPEpsilonSNN(
    (module): SpikingLayer(
      (parameterized_layer): NoisyWrapper(
        (module): Conv2d(3, 30, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      )
      (spike_mechanism): Leaky(
        (spike_grad): Step()
      )
    )
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): LFPEpsilonSNN(
    (module): SpikingLayer(
      (parameterized_layer): NoisyWrapper(
        (module): Conv2d(30, 150, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (spike_mechanism): Leaky(
        (spike_grad): Step()
      )
    )
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): LFPEpsilonSNN(
    (module): SpikingLayer(
      (parameterized_layer): NoisyWrapper(
        (module): Conv2d(150, 250, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (spike_mechanism): Leaky(
        (spike_grad): 

  warn(
 29%|██▊       | 112/391 [01:06<02:45,  1.69it/s]


KeyboardInterrupt: 

### Plot Results

In [None]:
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import matplotlib.cm as cm
import numpy as np
import copy

# Set font properties and plot style (copied from beans-vit-training)
font_path = plt.matplotlib.get_data_path() + "/fonts/ttf/cmr10.ttf"
cmfont = font_manager.FontProperties(fname=font_path)
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = cmfont.get_name()
plt.rcParams["mathtext.fontset"] = "cm"
plt.rcParams["font.size"] = 15
plt.rcParams["axes.unicode_minus"] = False
plt.rcParams["axes.formatter.use_mathtext"] = True
plt.rcParams["axes.linewidth"] = 1.5

# Prepare data
epochs_range = list(eval_stats_per_epoch.keys())
if not epochs_range:
    # Entries are only written after a full epoch; an interrupted first epoch leaves nothing to plot.
    raise RuntimeError(
        "eval_stats_per_epoch is empty: complete at least one training epoch before plotting."
    )
acc_lfp = [eval_stats_per_epoch[e]["eval_val_lfp"]["accuracy"] for e in epochs_range]
acc_surr = [eval_stats_per_epoch[e]["eval_val_surr"]["accuracy"] for e in epochs_range]
runtime_lfp = [timing_stats[e]["lfp_total"] for e in epochs_range]
runtime_surr = [timing_stats[e]["surr_total"] for e in epochs_range]

# Colormap and labels
colors = np.linspace(0, 1, 2)
# plt.get_cmap replaces matplotlib.cm.get_cmap, which has been removed from recent Matplotlib releases
palette = plt.get_cmap("Set1")(colors)
# Blend every color with white to obtain a softer ("pastel") tone
pastel = 0.3
palette = (1 - pastel) * palette + pastel * np.ones((2, 4))

LABELS = {
    "lfp": r"LFP-$\varepsilon$",
    "surr": r"Surrogate Grad",
}


def _style_accuracy_axis(ax):
    """Apply the y-axis layout shared by both plots: accuracy fractions in [0, 1] labelled as percent."""
    ax.set_ylabel("Validation Accuracy [%]")
    ax.set_ylim([0.0, 1.01])
    ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels([0, 20, 40, 60, 80, 100])
    ax.tick_params(length=6, width=2)


# Plot accuracy over epochs
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
xaxis = np.arange(len(epochs_range))
# With a single recorded epoch the last index is 0; keep a non-degenerate x-range in that case
x_max = max(xaxis[-1], 1)

ax.plot(xaxis, acc_lfp, color=palette[0], label=LABELS["lfp"], linewidth=3.5, alpha=1)
ax.plot(xaxis, acc_surr, color=palette[1], label=LABELS["surr"], linewidth=3.5, alpha=1)

# Horizontal guide lines at the y-tick positions, drawn behind the curves
linelocs = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
ax.hlines(
    linelocs,
    xmin=-1,
    xmax=x_max,
    color=(0.5, 0.5, 0.5, 1),
    linewidth=1.5,
    zorder=0,
)

ax.set_xlabel("Epoch")
ax.set_xlim([0.0, x_max])
_style_accuracy_axis(ax)
ax.legend()
plt.tight_layout()
plt.show()

# Plot runtime vs. accuracy (logscale x-axis)
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
ax.plot(np.cumsum(runtime_lfp), acc_lfp, color=palette[0], label=LABELS["lfp"], marker="o", linewidth=3.5)
ax.plot(np.cumsum(runtime_surr), acc_surr, color=palette[1], label=LABELS["surr"], marker="o", linewidth=3.5)
ax.set_xscale("log")
ax.set_xlabel("Cumulative Runtime (s, log scale)")
_style_accuracy_axis(ax)
ax.legend()
plt.tight_layout()
plt.show()