This Notebook provides a minimal example for using LFP to train a simple MLP-Spiking Neural Network (SNN) on MNIST.

For more complex examples, refer to the experiment notebooks in ./nbs

In [1]:
try:
    import snntorch as snn
    from snntorch import utils as snnutils
except ImportError:
    print(
        "The SNN functionality of this package requires extra dependencies ",
        "which can be installed via pip install lfprop[snn] (or lfprop[full] for all dependencies).",
    )
    raise ImportError("snntorch required; reinstall lfprop with option `snn` (pip install lfprop[snn])")

### Imports

In [2]:
import os

import numpy as np
import torch
import torcheval.metrics
import torchvision.datasets as tvisiondata
import torchvision.transforms as T
from tqdm import tqdm

from lfprop.rewards import reward_functions as rewards  # Reward Functions
from lfprop.model.models import get_model

  from .autonotebook import tqdm as notebook_tqdm


### Parameters

In [3]:
savepath = "./minimal-example-data"
os.makedirs(savepath, exist_ok=True)

batch_size = 128  # 128
n_channels = 784
n_outputs = 10
n_steps = 15
lr = 0.02
momentum = 0.9
epochs = 3
model_name = "smalllifmlp"
lif_kwargs = {"beta": 0.9, "reset_mechanism": "subtract", "surrogate_disable": False, "spike_grad": "step"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load Dataset

In [4]:
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
training_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=True,
)

validation_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=False,
)

# [DEBUG] overfit to a small dataset
# training_data = torch.utils.data.Subset(training_data, list(range(0, len(training_data) // 2)))
# validation_data = torch.utils.data.Subset(validation_data, list(range(0, 10)) * 100)

training_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, shuffle=False)

### Load Model

In [5]:
model = get_model(model_name=model_name, n_channels=n_channels, n_outputs=n_outputs, device=device, **lif_kwargs)
model.reset()
model.to(device)
model.eval()

SmallLifMLP(
  (classifier): Sequential(
    (0): SpikingLayer(
      (parameterized_layer): NoisyWrapper(
        (module): Linear(in_features=784, out_features=1000, bias=True)
      )
      (spike_mechanism): Leaky(
        (spike_grad): Step()
      )
    )
    (1): SpikingLayer(
      (parameterized_layer): NoisyWrapper(
        (module): Linear(in_features=1000, out_features=10, bias=True)
      )
      (spike_mechanism): Leaky(
        (spike_grad): Step()
      )
    )
  )
)

### Set Up LFP

In [6]:
# Initialize the SNN-Propagator
from lfprop.propagation.propagator_snn import LFPSNNEpsilonComposite

snn_propagator = LFPSNNEpsilonComposite(epsilon=1e-6)

# Initialize the Reward Function.
reward_func = rewards.SnnCorrectClassRewardSpikesRateCoded(device)

# LFP writes its updates into the .grad attribute of the model parameters, and can thus utilize standard torch optimizers
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

### Set Up Simple Evaluation using torcheval

In [8]:
def eval_model(loader, n_steps: int = 15):
    """
    Evaluates the model on a single dataset
    """
    eval_metrics = {
        "reward": torcheval.metrics.Mean(device=device),
        "accuracy": torcheval.metrics.MulticlassAccuracy(average="micro", num_classes=10, k=1, device=device),
    }

    model.eval()
    model.reset()

    # Iterate over Data Loader
    for index, (inputs, labels) in tqdm(enumerate(loader), desc="Evaluating", total=len(loader)):
        inputs = inputs.to(device)
        labels = (labels).to(device)

        with torch.no_grad():
            # Get model predictions
            u_rec, spk_rec = [], []
            for step in tqdm(range(n_steps), disable=True):  # [ ] move this into the fwd method of the model?
                y = model(inputs)
                spk_out, u_out = y
                u_rec.append(u_out)
                spk_rec.append(spk_out)

            spikes = torch.stack(spk_rec, dim=0)

            # Get rewards
            reward = reward_func(spikes=spikes, labels=labels)
            outputs = reward_func.get_predictions(spikes=spikes)

        for k, v in eval_metrics.items():
            if k == "reward":
                eval_metrics[k].update(reward)
            else:
                eval_metrics[k].update(outputs, labels)

    return_dict = {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}
    model.reset()
    # Return evaluation
    return return_dict


### Training Loop

In [9]:
def lfp_step(inputs, labels, n_steps: int = 15, print_model: bool = False):
    """
    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.
    """
    model.train()
    model.reset()

    optimizer.zero_grad()
    with snn_propagator.context(model) as modified:
        inputs = inputs.detach().requires_grad_(True)
        
        if print_model:
            print(modified)
        
        # Forward pass
        u_rec, spk_rec = [], []
        for step in range(n_steps):
            outputs = modified(inputs)
            spk_out, u_out = outputs
            u_rec.append(u_out)
            spk_rec.append(spk_out)
        spikes = torch.stack(spk_rec, dim=0)
        membrane_potential = torch.stack(u_rec, dim=0)

        # Reward
        reward = torch.from_numpy(reward_func(spikes=spikes, labels=labels).detach().cpu().numpy()).to(device)
        reward /= n_steps

        # Backward pass
        torch.autograd.grad((spikes,), (inputs,), grad_outputs=(reward,), retain_graph=False)

    for name, param in model.named_parameters():
        if not hasattr(param, 'feedback'):
            print(f"Parameter {name} does not have feedback attribute.")
        param.grad = -param.feedback

    torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)

    optimizer.step()

    model.reset()
    model.eval()

# Training Loop
for epoch in range(epochs):
    with tqdm(total=len(training_loader)) as pbar:
        # Iterate over Data Loader
        for index, (inputs, labels) in enumerate(training_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Perform Update Step
            lfp_step(inputs, labels, n_steps=n_steps)

            # Update Progress Bar
            pbar.update(1)
            # if index >= 30:
            # break

    # Evaluate and print performance after every epoch
    eval_stats_train = eval_model(training_loader, n_steps=n_steps)
    eval_stats_val = eval_model(validation_loader, n_steps=n_steps)
    print(
        "Epoch {}/{}: (Train Reward) {:.2f}; (Train Accuracy) {:.2f}; (Val Reward) {:.2f}; (Val Accuracy) {:.2f}".format(
            epoch + 1,
            epochs,
            float(np.mean(eval_stats_train["reward"])),
            float(eval_stats_train["accuracy"]),
            float(np.mean(eval_stats_val["reward"])),
            float(eval_stats_val["accuracy"]),
        )
    )

# training takes approx. 5 min

  warn(
100%|██████████| 469/469 [00:47<00:00,  9.94it/s]
Evaluating: 100%|██████████| 469/469 [00:21<00:00, 22.13it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 22.24it/s]


Epoch 1/3: (Train Reward) -0.00; (Train Accuracy) 0.88; (Val Reward) -0.00; (Val Accuracy) 0.89


100%|██████████| 469/469 [00:47<00:00,  9.88it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 21.16it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 21.35it/s]


Epoch 2/3: (Train Reward) 0.01; (Train Accuracy) 0.92; (Val Reward) 0.01; (Val Accuracy) 0.91


100%|██████████| 469/469 [00:47<00:00,  9.87it/s]
Evaluating: 100%|██████████| 469/469 [00:21<00:00, 21.60it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 21.98it/s]

Epoch 3/3: (Train Reward) 0.01; (Train Accuracy) 0.89; (Val Reward) 0.01; (Val Accuracy) 0.88



