This Notebook provides a minimal example for using LFP to train a simple MLP-Spiking Neural Network (SNN) on MNIST.

For more complex examples, refer to the experiment notebooks in ./nbs

In [1]:
try:
    import snntorch as snn
    from snntorch import utils as snnutils
except ImportError:
    print(
        "The SNN functionality of this package requires extra dependencies ",
        "which can be installed via pip install lfprop[snn] (or lfprop[full] for all dependencies).",
    )
    raise ImportError("snntorch required; reinstall lfprop with option `snn` (pip install lfprop[snn])")

### Imports

In [2]:
import os

import numpy as np
import torch
import torcheval.metrics
import torchvision.datasets as tvisiondata
import torchvision.transforms as T
from tqdm import tqdm

from lfprop.rewards import reward_functions as rewards  # Reward Functions

### Parameters

In [3]:
savepath = "./minimal-example-data"
os.makedirs(savepath, exist_ok=True)

batch_size = 128  # 128
n_channels = 1 #784
n_outputs = 10
n_steps = 15
lr = 0.02
momentum = 0.9
epochs = 20
model_name = "lifcnn"
lif_kwargs = {"beta": 0.9, "reset_mechanism": "subtract"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load Dataset

In [4]:
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
training_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=True,
)

validation_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=False,
)

# [DEBUG] overfit to a small dataset
# training_data = torch.utils.data.Subset(training_data, list(range(0, len(training_data) // 2)))
# validation_data = torch.utils.data.Subset(validation_data, list(range(0, 10)) * 100)

training_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, shuffle=False)

### Load Model

In [5]:
from lfprop.model.spiking_networks import get_model

model = get_model(model_name=model_name, n_channels=n_channels, n_outputs=n_outputs, device=device, **lif_kwargs)
model.reset()
model.to(device)
model.eval()

def name_modules(module, name):
    """
    Recursive function to name modules for debugging 
    """
    
    for cname, child in module.named_children():
        child.tmpname = cname if name == "root" else f"{name}.{cname}"
        name_modules(child, child.tmpname)

name_modules(model, "root")

### Set Up LFP

In [6]:
# Initialize the SNN-Propagator
from lfprop.propagation.propagator_snn import LFPSNNEpsilonComposite

propagation_composite = LFPSNNEpsilonComposite(epsilon=1e-6)

# Initialize the Reward Function.
reward_func = rewards.SnnCorrectClassRewardSpikesRateCoded(device)

# LFP writes its updates into the .grad attribute of the model parameters, and can thus utilize standard torch optimizers
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  from .autonotebook import tqdm as notebook_tqdm


### Set Up Simple Evaluation using torcheval

In [7]:
def eval_model(loader, n_steps: int = 15):
    """
    Evaluates the model on a single dataset
    """
    eval_metrics = {
        "reward": torcheval.metrics.Mean(device=device),
        "accuracy": torcheval.metrics.MulticlassAccuracy(average="micro", num_classes=10, k=1, device=device),
    }

    model.eval()
    model.reset()

    # Iterate over Data Loader
    for index, (inputs, labels) in tqdm(enumerate(loader), desc="Evaluating", total=len(loader)):
        inputs = inputs.to(device)
        labels = (labels).to(device)

        with torch.no_grad():
            # Get model predictions
            u_rec, spk_rec = [], []
            for step in tqdm(range(n_steps), disable=True):  # [ ] move this into the fwd method of the model?
                y = model(inputs)
                spk_out, u_out = y
                u_rec.append(u_out)
                spk_rec.append(spk_out)

            spikes = torch.stack(spk_rec, dim=0)
            membrane_potential = torch.stack(u_rec, dim=0)

            # Get rewards
            reward = reward_func(spikes=spikes, potentials=membrane_potential, labels=labels)
            outputs = reward_func.get_predictions(spikes=spikes, potentials=membrane_potential)

        for k, v in eval_metrics.items():
            if k == "reward":
                eval_metrics[k].update(reward)
            else:
                eval_metrics[k].update(outputs, labels)

    return_dict = {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}
    model.reset()
    # Return evaluation
    return return_dict

### Training Loop

In [8]:

def lfp_step(inputs, labels, n_steps: int = 15):
    """
    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.
    """
    # Set Model to training mode
    model.train()
    model.reset()

    with torch.enable_grad():
        # Zero Optimizer
        optimizer.zero_grad()
        
        with propagation_composite.context(model) as modified:
    
            inputs = inputs.detach().requires_grad_(True)
        
            # Model forward pass
            u_rec, spk_rec = [], []
            for step in tqdm(range(n_steps), disable=True):
                outputs = modified(inputs)
                spk_out, u_out = outputs
                u_rec.append(u_out)
                spk_rec.append(spk_out)
            spikes = torch.stack(spk_rec, dim=0)
            membrane_potential = torch.stack(u_rec, dim=0)

            # Reward
            reward = torch.from_numpy(reward_func(spikes=spikes, potentials=membrane_potential, labels=labels).detach().cpu().numpy()).to(device)
            reward /= n_steps #Total reward should not increase with additional steps

            #print(reward.sum())

            # Modified Backward Pass
            #for step in range(n_steps):
            #    torch.autograd.grad((spk_rec[n_steps-(step+1)],), (inputs,), grad_outputs=(reward[n_steps-(step+1)],), retain_graph=False)
            torch.autograd.grad((spikes,), (inputs,), grad_outputs=(reward,), retain_graph=False)
            #snn_propagator.propagate(iteration_feedback=reward[-(step + 1)], iteration_idx=step)

    for name, param in model.named_parameters():
        param.grad = -param.feedback
        #print(name, param.feedback.abs().sum(), param.data.abs().sum())

    # Update Clipping. Training may become unstable otherwise, especially in small models with large learning rates.
    # In larger models (e.g., VGG, ResNet), where smaller learning rates are generally utilized, not clipping updates may result in better performance.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)

    # update parameters
    optimizer.step()

    # Set Model back to eval model
    model.reset()  # necessary to free the internal state of the model
    model.eval()


# Training Loop
for epoch in range(epochs):
    with tqdm(total=len(training_loader)) as pbar:
        # Iterate over Data Loader
        for index, (inputs, labels) in enumerate(training_loader):
            inputs = inputs.to(device)
            labels = (labels).to(device)

            # Perform Update Step
            lfp_step(inputs, labels, n_steps=n_steps)

            # Update Progress Bar
            pbar.update(1)
            # if index >= 30:
            # break

    # Evaluate and print performance after every epoch
    eval_stats_train = eval_model(training_loader, n_steps=n_steps)
    eval_stats_val = eval_model(validation_loader, n_steps=n_steps)
    print(
        "Epoch {}/{}: (Train Reward) {:.2f}; (Train Accuracy) {:.2f}; (Val Reward) {:.2f}; (Val Accuracy) {:.2f}".format(
            epoch + 1,
            epochs,
            float(np.mean(eval_stats_train["reward"])),
            float(eval_stats_train["accuracy"]),
            float(np.mean(eval_stats_val["reward"])),
            float(eval_stats_val["accuracy"]),
        )
    )

# training takes approx. 5 min

  warn(
100%|██████████| 469/469 [01:03<00:00,  7.40it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.92it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.85it/s]


Epoch 1/20: (Train Reward) 0.01; (Train Accuracy) 0.80; (Val Reward) 0.01; (Val Accuracy) 0.81


100%|██████████| 469/469 [00:58<00:00,  8.06it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.69it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.90it/s]


Epoch 2/20: (Train Reward) 0.02; (Train Accuracy) 0.76; (Val Reward) 0.02; (Val Accuracy) 0.77


100%|██████████| 469/469 [00:58<00:00,  8.08it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 19.95it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 19.87it/s]


Epoch 3/20: (Train Reward) 0.01; (Train Accuracy) 0.66; (Val Reward) 0.01; (Val Accuracy) 0.67


100%|██████████| 469/469 [00:58<00:00,  7.95it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.10it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.33it/s]


Epoch 4/20: (Train Reward) 0.04; (Train Accuracy) 0.66; (Val Reward) 0.03; (Val Accuracy) 0.67


100%|██████████| 469/469 [00:55<00:00,  8.39it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.26it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.35it/s]


Epoch 5/20: (Train Reward) 0.04; (Train Accuracy) 0.64; (Val Reward) 0.04; (Val Accuracy) 0.64


100%|██████████| 469/469 [00:55<00:00,  8.47it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.24it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.00it/s]


Epoch 6/20: (Train Reward) 0.04; (Train Accuracy) 0.64; (Val Reward) 0.04; (Val Accuracy) 0.64


100%|██████████| 469/469 [00:58<00:00,  8.02it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.53it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.38it/s]


Epoch 7/20: (Train Reward) 0.04; (Train Accuracy) 0.59; (Val Reward) 0.03; (Val Accuracy) 0.59


100%|██████████| 469/469 [00:56<00:00,  8.36it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.37it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.28it/s]


Epoch 8/20: (Train Reward) 0.04; (Train Accuracy) 0.57; (Val Reward) 0.04; (Val Accuracy) 0.57


100%|██████████| 469/469 [00:55<00:00,  8.41it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.81it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.46it/s]


Epoch 9/20: (Train Reward) 0.03; (Train Accuracy) 0.58; (Val Reward) 0.03; (Val Accuracy) 0.59


100%|██████████| 469/469 [00:57<00:00,  8.14it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.76it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.97it/s]


Epoch 10/20: (Train Reward) 0.04; (Train Accuracy) 0.57; (Val Reward) 0.04; (Val Accuracy) 0.57


100%|██████████| 469/469 [00:55<00:00,  8.48it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.68it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.89it/s]


Epoch 11/20: (Train Reward) 0.03; (Train Accuracy) 0.55; (Val Reward) 0.03; (Val Accuracy) 0.56


100%|██████████| 469/469 [00:55<00:00,  8.39it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.19it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.36it/s]


Epoch 12/20: (Train Reward) 0.04; (Train Accuracy) 0.54; (Val Reward) 0.04; (Val Accuracy) 0.55


100%|██████████| 469/469 [00:56<00:00,  8.29it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.30it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.52it/s]


Epoch 13/20: (Train Reward) 0.04; (Train Accuracy) 0.52; (Val Reward) 0.04; (Val Accuracy) 0.52


100%|██████████| 469/469 [00:57<00:00,  8.10it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.68it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.90it/s]


Epoch 14/20: (Train Reward) 0.05; (Train Accuracy) 0.53; (Val Reward) 0.05; (Val Accuracy) 0.53


100%|██████████| 469/469 [00:55<00:00,  8.44it/s]
Evaluating: 100%|██████████| 469/469 [00:22<00:00, 20.79it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.81it/s]


Epoch 15/20: (Train Reward) 0.04; (Train Accuracy) 0.51; (Val Reward) 0.04; (Val Accuracy) 0.51


100%|██████████| 469/469 [00:57<00:00,  8.18it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.11it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.41it/s]


Epoch 16/20: (Train Reward) 0.04; (Train Accuracy) 0.50; (Val Reward) 0.04; (Val Accuracy) 0.50


100%|██████████| 469/469 [00:56<00:00,  8.32it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.29it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.95it/s]


Epoch 17/20: (Train Reward) 0.04; (Train Accuracy) 0.50; (Val Reward) 0.03; (Val Accuracy) 0.50


100%|██████████| 469/469 [00:57<00:00,  8.14it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.25it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.26it/s]


Epoch 18/20: (Train Reward) 0.03; (Train Accuracy) 0.48; (Val Reward) 0.03; (Val Accuracy) 0.49


100%|██████████| 469/469 [00:58<00:00,  7.99it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.11it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.16it/s]


Epoch 19/20: (Train Reward) 0.05; (Train Accuracy) 0.49; (Val Reward) 0.05; (Val Accuracy) 0.49


100%|██████████| 469/469 [00:56<00:00,  8.36it/s]
Evaluating: 100%|██████████| 469/469 [00:23<00:00, 20.23it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.86it/s]

Epoch 20/20: (Train Reward) 0.05; (Train Accuracy) 0.48; (Val Reward) 0.04; (Val Accuracy) 0.48



