# Tutorial 2
In this tutorial, we are going to directly train a simple spiking neural network (SNN) with a single hidden layer using e-prop on the MNIST dataset, converted to a latency spike code.

Clearly, this is far from a state-of-the-art architecture, but it still achieves 97.6% accuracy on MNIST.
## Install PyGeNN wheel from Google Drive
Download wheel file

In [1]:
!gdown 1fllqUtL_1_tyGzjNHSR-9i5fXFI9jqAi 

Downloading...
From: https://drive.google.com/uc?id=1fllqUtL_1_tyGzjNHSR-9i5fXFI9jqAi
To: /content/pygenn-4.8.1-cp310-cp310-linux_x86_64.whl
100% 22.3M/22.3M [00:00<00:00, 82.6MB/s]


and then install PyGeNN from the wheel file

In [2]:
!pip install pygenn-4.8.1-cp310-cp310-linux_x86_64.whl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./pygenn-4.8.1-cp310-cp310-linux_x86_64.whl
Collecting deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Installing collected packages: deprecated, pygenn
Successfully installed deprecated-1.2.13 pygenn-4.8.1


and check out mlGeNN from git and install it

In [7]:
!git clone --branch master https://github.com/genn-team/ml_genn.git
!pip install ml_genn/ml_genn

Cloning into 'ml_genn'...
remote: Enumerating objects: 6759, done.

Set an environment variable to allow GeNN to find CUDA

In [1]:
%env CUDA_PATH=/usr/local/cuda

env: CUDA_PATH=/usr/local/cuda


## Install MNIST package

In [2]:
!pip install mnist

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Build model
Import standard modules and required mlGeNN classes

In [3]:
# Dataset loading, numerics and plotting
import mnist
import numpy as np
import matplotlib.pyplot as plt

# mlGeNN network, training and serialisation classes
from ml_genn import InputLayer, Layer, SequentialNetwork
from ml_genn.callbacks import Checkpoint
from ml_genn.compilers import EPropCompiler, InferenceCompiler
from ml_genn.connectivity import Dense, FixedProbability
from ml_genn.initializers import Normal
from ml_genn.neurons import LeakyIntegrate, LeakyIntegrateFire, SpikeInput
from ml_genn.serialisers import Numpy

# Helpers for spike-encoding datasets
from ml_genn.utils.data import calc_latest_spike_time, log_latency_encode_data

# e-prop default parameters, passed to SequentialNetwork below
from ml_genn.compilers.eprop_compiler import default_params


## Parameters

Define some model parameters


In [4]:
NUM_INPUT = 28 * 28   # one input neuron per MNIST pixel
NUM_HIDDEN = 128
NUM_OUTPUT = 10       # one output neuron per digit class
BATCH_SIZE = 128

## Latency encoding
There are numerous ways to encode images using spikes, but here we are going to emit a single spike for each pixel, at a time calculated from its intensity $x$ as follows:
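\begin{align}
    T(x) = \begin{cases}
        \tau_\text{eff} \log\left(\frac{x}{x-\theta}\right) & x > \theta\\
        \infty & \text{otherwise}
    \end{cases}
\end{align}
where $\tau_\text{eff}=20\,\text{ms}$ and $\theta=51$ (on the 0-255 intensity scale of the MNIST images). Brighter pixels therefore spike earlier, and pixels with intensity at or below $\theta$ never spike.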
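The encoding formula can be checked numerically with a small NumPy sketch. The helper `latency_spike_times` is hypothetical and written only for illustration; it is not mlGeNN's `log_latency_encode_data`, whose internals (for example, how it stores spike times per example) may differ.

```python
import numpy as np

def latency_spike_times(x, tau_eff=20.0, theta=51.0):
    """Hypothetical helper: spike time T(x) in ms for each pixel intensity x.

    Pixels with x <= theta never spike, so their time is +inf.
    """
    x = np.asarray(x, dtype=float)
    times = np.full(x.shape, np.inf)
    above = x > theta
    # T(x) = tau_eff * log(x / (x - theta)) for pixels above threshold
    times[above] = tau_eff * np.log(x[above] / (x[above] - theta))
    return times

pixels = np.array([0, 51, 52, 102, 255])
times = latency_spike_times(pixels)
print(times)
```

Brighter pixels spike earlier: intensity 255 fires at about 4.5 ms and 102 at about 13.9 ms, while 52, the dimmest intensity that spikes at all, fires last at about 79 ms. Pixels at or below 51 produce no spike.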

In [5]:
# Encode training images as single-spike latencies (tau_eff = 20 ms, theta = 51)
train_spikes = log_latency_encode_data(mnist.train_images(), 20.0, 51)

## Network definition
Because our network is entirely feedforward, we can define it as a ``SequentialNetwork``, in which each layer is automatically connected to the previous one. As we have converted the MNIST dataset to spikes, we use a ``SpikeInput`` to inject these directly into the network. For the hidden layer we use standard leaky integrate-and-fire (LIF) neurons, as this task does not require the more computationally expensive adaptive LIF neurons. Finally, we use a non-spiking output layer and read classifications out of it by summing its membrane voltage over each example.



In [6]:
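# Serialiser used by the Checkpoint callback to save weights during training
serialiser = Numpy("latency_mnist_checkpoints")

# Create sequential model; each layer connects to the previous one
network = SequentialNetwork(default_params)
with network:
    # Input layer injecting the latency-coded spikes; max_spikes leaves room
    # for one spike per pixel for every example in a batch
    input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * NUM_INPUT),
                       NUM_INPUT)
    # Hidden layer of LIF neurons, densely connected to the input
    hidden = Layer(Dense(Normal(sd=1.0 / np.sqrt(NUM_INPUT))),
                   LeakyIntegrateFire(v_thresh=0.61, tau_mem=20.0,
                                      tau_refrac=5.0),
                   NUM_HIDDEN)
    # Non-spiking output layer, read out by summing its membrane voltage
    output = Layer(Dense(Normal(sd=1.0 / np.sqrt(NUM_HIDDEN))),
                   LeakyIntegrate(tau_mem=20.0, readout="sum_var"),
                   NUM_OUTPUT)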
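To give an intuition for what this network computes, here is a simplified NumPy sketch of one example passing through a LIF hidden layer into a leaky-integrator output whose voltage is summed over time. It assumes a 1 ms timestep, exponential membrane decay, input weights added directly to the voltage and a reset to zero after each spike; GeNN's actual LIF and readout equations may scale or reset differently, and `simulate_lif_readout` is a hypothetical name, not part of mlGeNN.

```python
import numpy as np

def simulate_lif_readout(input_spikes, w_hid, w_out, v_thresh=0.61,
                         tau_mem=20.0, tau_refrac=5.0, dt=1.0):
    """Simplified sketch (not mlGeNN's equations): LIF hidden layer feeding
    a non-spiking leaky-integrator output whose voltage is summed over time.

    input_spikes: (timesteps, num_input) array of 0/1 spikes.
    Returns the summed output voltage of each output neuron.
    """
    alpha = np.exp(-dt / tau_mem)            # membrane decay per timestep
    num_hidden = w_hid.shape[1]
    v_hid = np.zeros(num_hidden)
    refrac = np.zeros(num_hidden)            # remaining refractory time (ms)
    v_out = np.zeros(w_out.shape[1])
    v_out_sum = np.zeros_like(v_out)
    for s_in in input_spikes:
        # Hidden layer: decay and integrate input unless refractory
        active = refrac <= 0.0
        v_hid = np.where(active, alpha * v_hid + s_in @ w_hid, v_hid)
        spikes = (v_hid >= v_thresh).astype(float)
        # Reset neurons that spiked and start their refractory period
        v_hid = np.where(spikes > 0, 0.0, v_hid)
        refrac = np.where(spikes > 0, tau_refrac, refrac - dt)
        # Output layer: non-spiking leaky integrator, voltage accumulated
        v_out = alpha * v_out + spikes @ w_out
        v_out_sum += v_out
    return v_out_sum

# Random weights drawn like the Normal initialisers in the network above
rng = np.random.default_rng(0)
num_input, num_hidden, num_output = 784, 128, 10
w_hid = rng.normal(0.0, 1.0 / np.sqrt(num_input), (num_input, num_hidden))
w_out = rng.normal(0.0, 1.0 / np.sqrt(num_hidden), (num_hidden, num_output))
spikes = (rng.random((80, num_input)) < 0.01).astype(float)
scores = simulate_lif_readout(spikes, w_hid, w_out)
prediction = int(np.argmax(scores))  # predicted class = largest summed voltage
```

Summing the output voltage over the whole example, rather than reading it at the final timestep, lets every output spike contribute to the classification regardless of when it arrives.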

## Compilation
In mlGeNN, to turn an abstract network description into something that can actually be used for training or inference, you use a *compiler* class. Here, we use the ``EPropCompiler`` to train with e-prop, specifying the batch size and the number of timesteps to simulate each example for, as well as our optimiser and loss function. The number of timesteps is derived from the latest spike time in the encoded training data, so that every input spike falls within the simulated window. Because this is a classification task, we use cross-entropy loss and, because our labels are integer class indices (rather than e.g. one-hot encoded), we use the sparse categorical variant.

In [7]:
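# Simulate each example until the latest input spike in the training set
max_example_timesteps = int(np.ceil(calc_latest_spike_time(train_spikes)))
# Train with e-prop, the Adam optimiser and sparse categorical cross-entropy
compiler = EPropCompiler(example_timesteps=max_example_timesteps,
                         losses="sparse_categorical_crossentropy",
                         optimiser="adam", batch_size=BATCH_SIZE)
compiled_net = compiler.compile(network)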
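To make the "sparse" distinction concrete, here is a small NumPy sketch showing that cross-entropy computed from integer labels gives the same value as cross-entropy computed from the equivalent one-hot labels. The function names are hypothetical and this is not mlGeNN's implementation; it applies a softmax to arbitrary example scores, and how mlGeNN feeds the output readout into the loss during e-prop is not shown here.

```python
import numpy as np

def softmax(logits):
    # Subtract the row max for numerical stability
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sparse_categorical_crossentropy(logits, labels):
    """Mean cross-entropy where labels are integer class indices."""
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

def categorical_crossentropy(logits, one_hot):
    """Mean cross-entropy where labels are one-hot vectors."""
    probs = softmax(logits)
    return -np.mean(np.sum(one_hot * np.log(probs), axis=1))

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])        # integer labels, like mnist.train_labels()
one_hot = np.eye(3)[labels]      # the same labels, one-hot encoded
print(sparse_categorical_crossentropy(logits, labels))
print(categorical_crossentropy(logits, one_hot))
```

Both calls print the same loss; the sparse variant simply indexes the probability of the correct class instead of multiplying by a one-hot vector, which avoids materialising one-hot labels for the whole dataset.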

## Training
Now we will train the model for 10 epochs using our compiled network. To verify its performance, we hold out 10% of the training data as a validation set and, alongside a progress bar, add a ``Checkpoint`` callback to save the weights every epoch.



In [None]:
with compiled_net:
    # Train model on numpy dataset, showing a progress bar and
    # checkpointing weights with the serialiser every epoch
    callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
    compiled_net.train({input: train_spikes},
                       {output: mnist.train_labels()},
                       num_epochs=10, shuffle=True,
                       validation_split=0.1,
                       callbacks=callbacks)

  0%|          | 0/422 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]
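The progress bars above report 422 training batches and 47 validation batches per epoch. A quick check, assuming the 60,000-image MNIST training set is split 90/10 and a final partial batch counts as a batch (the exact rounding mlGeNN uses for the split is an assumption here), reproduces those numbers:

```python
import math

NUM_TRAIN = 60_000       # images in the MNIST training set
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.1

num_val = round(NUM_TRAIN * VALIDATION_SPLIT)   # 6,000 validation examples
num_train = NUM_TRAIN - num_val                 # 54,000 training examples
# A final partial batch still counts, hence ceil
train_batches = math.ceil(num_train / BATCH_SIZE)
val_batches = math.ceil(num_val / BATCH_SIZE)
print(train_batches, val_batches)               # 422 47
```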