This repository contains the code underlying the paper *Smooth Exact Gradient Descent Learning in Spiking Neural Networks*. It shows how to perform gradient descent learning in spiking neural networks that is both exact and smooth (or at least continuous). The scheme relies on neuron models whose spikes can only appear or vanish at the end of a trial, such as quadratic integrate-and-fire neurons. Such neuron models further allow adding or removing spikes in a principled, gradient-based manner.
We use event-based spiking neural network simulations, iterating over input and network spikes. The code is written in Python using JAX and makes use of its automatic differentiation and JIT-compilation features.
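As a toy illustration of this approach (a minimal sketch, not code from this repository: it assumes a quadratic integrate-and-fire neuron with constant suprathreshold input, whose spike time then has a closed form), JAX can compute exact spike-time gradients and JIT-compile them:

```python
import jax
import jax.numpy as jnp


def spike_time(I, V0=0.0):
    """Spike (blow-up) time of a QIF neuron dV/dt = V**2 + I driven by a
    constant suprathreshold input I > 0, starting from voltage V0."""
    sqrtI = jnp.sqrt(I)
    return (jnp.pi / 2 - jnp.arctan(V0 / sqrtI)) / sqrtI


# Exact gradient of the spike time w.r.t. the input current, JIT-compiled.
grad_fn = jax.jit(jax.grad(spike_time))
print(spike_time(1.0))  # ~1.571 (= pi/2)
print(grad_fn(1.0))     # ~-0.785 (= -pi/4)
```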
- Create a virtual environment and install Python 3.10 as well as JAX, preferably with GPU support.
- Clone or download this repository.
- Install the package and the dependencies necessary to use it or to run most of the experiments via `pip install -e .`
- For the MNIST experiments, additionally install PyTorch (cpu-only is sufficient) and TorchVision to load the dataset; see the example command after this list.
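One possible command for the cpu-only builds, using PyTorch's cpu wheel index (check the PyTorch website for the variant matching your platform):

```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
```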
The neuron models, including methods to simulate networks of them and to compute pseudospike times, are located in `spikegd`. Abstract base classes, i.e., templates for neuron models, are given in `models.py`. Specific implementations of a few neuron models are provided in the other files in the same directory.
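As a purely hypothetical sketch of how a new model could plug into these templates (every name below is an assumption, not necessarily the actual `spikegd` API):

```python
# Hypothetical sketch: all names below are assumptions about the API.
from spikegd.models import AbstractNeuron  # abstract base class (name assumed)


class MyNeuron(AbstractNeuron):
    """A custom neuron model filling in the template from models.py."""

    # Implement the abstract simulation and pseudospike-time methods here.
    ...
```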
The `experiments` folder contains scripts and notebooks to generate the main results of the paper. They can also serve as a starting point for implementing new experiments.
If you use this code in your research, please cite our arXiv paper:
    @misc{klos2023smooth,
        title={Smooth Exact Gradient Descent Learning in Spiking Neural Networks},
        author={Christian Klos and Raoul-Martin Memmesheimer},
        year={2023},
        eprint={2309.14523},
        archivePrefix={arXiv},
        primaryClass={q-bio.NC}
    }