# Spike Train Generation
## 1. MNIST Conversion to Spike Train
### 1.1 Import Packages and Set Up the Environment

In [1]:
# Remove any previously cloned copy of snntorch before cloning a fresh one
#!rm -rf snntorch

In [2]:
#!git clone https://Username:Password@github.com/username/repository.git

Importing `os` and `sys` and modifying the search path are only needed while this notebook is under development.
These lines can be safely removed once snntorch is distributed on PyPI.

In [1]:
import torch
import os
import sys

# When running locally, change to the directory containing the snntorch repository
# (uncomment and edit the path for your machine; it raises FileNotFoundError elsewhere):
# os.chdir("C:\\Users\\Jason\\Dropbox\\repos\\snntorch")

# When running on Colab, add the cloned repository to the module search path:
sys.path.insert(0, '/content/snntorch')

import snntorch as snn
from snntorch.spikevision import datamod, spikegen

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Initialize the configuration object, which holds information about the dataset.

The first argument is the dimensions of each input sample ([28, 28] for MNIST). `channels` is the number of input channels (1 for grayscale).

`batch_size` is the number of samples in each minibatch.

`split` is the fraction of the training set assigned to the validation set.
*E.g., for a split of 0.2, the validation set will be made up of 20% of the training set.*

`subset` reduces the training and test sets by the given factor.
*E.g., for a subset of 100, a training set of 60,000 samples will be reduced to 600.*

`num_classes` is the number of output classes (10 for MNIST).

`T` is the number of time steps to be simulated.

`data_path` is the target directory for downloading the dataset.

In [2]:
config = snn.utils.Configuration([28, 28], channels=1, batch_size=100, split=0.1, subset=100,
                                 num_classes=10, T=1000, data_path='/data/mnist')

### 1.2 Download Dataset

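Note that `mnist_val` is loaded as a second, identical copy of the full MNIST training set; the two copies are partitioned in the next step.
This retains the `data` and `targets` attributes, which would otherwise be lost had we used `random_split`, because it returns `Subset` wrappers that do not expose them.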
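The following sketch, using a hypothetical `ToyMNIST` stand-in rather than the real download, illustrates the problem with `random_split`: it returns `torch.utils.data.Subset` objects, which only store `.dataset` and `.indices`, so the `.data` and `.targets` attributes of the original dataset are no longer directly accessible.

```python
import torch
from torch.utils.data import Dataset, random_split

class ToyMNIST(Dataset):
    """Hypothetical stand-in for datasets.MNIST that exposes .data and .targets like torchvision does."""
    def __init__(self, n=10):
        self.data = torch.zeros(n, 28, 28, dtype=torch.uint8)
        self.targets = torch.arange(n) % 10

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

full = ToyMNIST(n=10)
train_part, val_part = random_split(full, [9, 1])

# random_split wraps the dataset in Subset objects that only keep .dataset and .indices
print(type(train_part).__name__)       # Subset
print(hasattr(full, "data"))           # True
print(hasattr(train_part, "data"))     # False
print(hasattr(train_part, "targets"))  # False
```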

In [3]:
from torchvision import datasets, transforms

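# Define a transform. ToTensor() scales pixel values to [0, 1], and Normalize with mean 0 and
# std 1 leaves them unchanged, so each pixel can later be used directly as a spike probability.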
transform = transforms.Compose([
            transforms.Resize(config.input_size),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(config.data_path, train=True, download=True, transform=transform)
mnist_val = datasets.MNIST(config.data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(config.data_path, train=False, download=True, transform=transform)

Create train/validation split using the value in `config.split`.

`spikevision` is a subpackage of `snntorch` containing useful functions for modifying datasets.
These functions are in the `datamod` module.

In [4]:
mnist_train, mnist_val = datamod.valid_split(mnist_train, mnist_val, config)
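The source of `datamod.valid_split` isn't shown here. As a rough sketch only, assuming it partitions the two copies by index, the hypothetical `valid_split_sketch` below uses contiguous, unshuffled slices (the real helper may shuffle) to keep the first 90% in the training copy and the last 10% in the validation copy:

```python
import torch

class ToyMNIST:
    """Hypothetical stand-in exposing the .data and .targets attributes that torchvision's MNIST has."""
    def __init__(self, n):
        self.data = torch.zeros(n, 28, 28, dtype=torch.uint8)
        self.targets = torch.arange(n) % 10

    def __len__(self):
        return len(self.targets)

def valid_split_sketch(ds_train, ds_val, split):
    """Hypothetical stand-in for datamod.valid_split using contiguous, unshuffled slicing.

    ds_train and ds_val must be two separately loaded copies of the same training set.
    """
    n_val = round(len(ds_train) * split)
    n_train = len(ds_train) - n_val
    # The training copy keeps the first n_train samples ...
    ds_train.data, ds_train.targets = ds_train.data[:n_train], ds_train.targets[:n_train]
    # ... and the validation copy keeps the remaining n_val samples
    ds_val.data, ds_val.targets = ds_val.data[n_train:], ds_val.targets[n_train:]
    return ds_train, ds_val

train, val = valid_split_sketch(ToyMNIST(1000), ToyMNIST(1000), split=0.1)
print(len(train), len(val))  # 900 100
```

Both results are still plain dataset objects, so `.data` and `.targets` remain available.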

Reduce training, validation and test sets to smaller subsets for faster processing.

In [5]:
mnist_train = datamod.data_subset(mnist_train, config)
mnist_val = datamod.data_subset(mnist_val, config)
mnist_test = datamod.data_subset(mnist_test, config)

As a sanity check, let's take a look at the length of each of our datasets:

In [6]:
print(f"The size of mnist_train is {len(mnist_train)}")
print(f"The size of mnist_val is {len(mnist_val)}")
print(f"The size of mnist_test is {len(mnist_test)}")

The size of mnist_train is 540
The size of mnist_val is 60
The size of mnist_test is 100
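These sizes follow from the configuration: a split of 0.1 leaves 54,000 of MNIST's 60,000 training images for training and 6,000 for validation, and a subset factor of 100 then divides each set, including the 10,000-image test set, by 100. The hypothetical `data_subset_sketch` below keeps the first `len // subset` samples; how `datamod.data_subset` actually chooses samples is an assumption here.

```python
from types import SimpleNamespace

def data_subset_sketch(ds, subset):
    """Hypothetical: keep the first len(ds) // subset samples (the real helper may sample differently)."""
    n_keep = len(ds.targets) // subset
    ds.data = ds.data[:n_keep]
    ds.targets = ds.targets[:n_keep]
    return ds

def toy_dataset(n):
    # Placeholder lists standing in for image and label tensors
    return SimpleNamespace(data=list(range(n)), targets=list(range(n)))

split, subset = 0.1, 100
n_val = round(60_000 * split)   # 6,000 validation samples
n_train = 60_000 - n_val        # 54,000 training samples

train_ds = data_subset_sketch(toy_dataset(n_train), subset)
val_ds = data_subset_sketch(toy_dataset(n_val), subset)
test_ds = data_subset_sketch(toy_dataset(10_000), subset)

print(len(train_ds.targets), len(val_ds.targets), len(test_ds.targets))  # 540 60 100
```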


### 1.3 Create Dataloaders

In [7]:
from torch.utils.data import DataLoader

train_loader = DataLoader(mnist_train, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=config.batch_size, shuffle=True)
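With `batch_size=100`, the 540 training samples are served as six minibatches, the last of which holds only the 40 remaining samples (passing `drop_last=True` to `DataLoader` discards it). A quick check on a random `TensorDataset` stand-in:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for mnist_train after splitting and subsetting: 540 single-channel 28x28 images
toy_train = TensorDataset(torch.rand(540, 1, 28, 28), torch.randint(0, 10, (540,)))
loader = DataLoader(toy_train, batch_size=100, shuffle=True)

# Shuffling changes which samples land in each batch, but not the batch sizes
batch_sizes = [images.shape[0] for images, labels in loader]
print(batch_sizes)  # [100, 100, 100, 100, 100, 40]
```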

### 1.4 Spike Train Generation
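The pixel intensities of each image are used as the probabilities of a Bernoulli distribution, which is sampled independently at every time step to produce a rate-coded spike train (a discrete-time approximation of a Poisson spike train).
This is done one minibatch at a time. `gain` scales these probabilities, and the result is clipped to the range [0, 1].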
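As a minimal sketch of this kind of rate coding (not the implementation inside `spikegen.spike_conversion`), the hypothetical `rate_code_sketch` below clips `gain * pixel` to [0, 1], copies the probabilities along a new leading time dimension of length T, and draws an independent Bernoulli sample at every time step with `torch.bernoulli`:

```python
import torch

def rate_code_sketch(images, T, gain=1.0):
    """Bernoulli rate coding of a minibatch (a sketch, not snntorch's implementation).

    images: [B x C x H x W] tensor with pixel values in [0, 1]
    returns: [T x B x C x H x W] tensor of 0s and 1s
    """
    # Scale the intensities by gain and clip them to valid probabilities
    probs = torch.clamp(gain * images, 0.0, 1.0)
    # Copy the probabilities along a new leading time dimension
    probs = probs.unsqueeze(0).repeat(T, *([1] * images.dim()))
    # Sample every (time step, pixel) independently: 1 = spike, 0 = no spike
    return torch.bernoulli(probs)

images = torch.rand(4, 1, 28, 28)
images[0, 0, 0, 0] = 0.0   # a black pixel: never spikes
images[0, 0, 0, 1] = 1.0   # a white pixel: spikes at every step when gain >= 1
spikes = rate_code_sketch(images, T=100)
print(spikes.shape)  # torch.Size([100, 4, 1, 28, 28])
```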

In [8]:
# Create an iterator over the training loader and fetch a single minibatch
data = iter(train_loader)
data_it, targets_it = next(data)
data_it = data_it.to(device)
targets_it = targets_it.to(device)

# Spiking Data
spike_data, spike_targets = spikegen.spike_conversion(data_it, targets_it, config, gain=1)

#### 1.4.1 Visualising Data
##### 1.4.1.1 Animations
The `spikeplot` module contains useful functions for visualising spiking data. `spike_data` has dimensions [T x B x C x W x H];
to animate one sample, we index into the first sample and first channel, reducing it to [T x W x H].
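For example, with a zero-filled placeholder tensor in the same layout (a batch of 4 is used here instead of 100 to keep memory small):

```python
import torch

T, B, C, W, H = 1000, 4, 1, 28, 28
spike_data = torch.zeros(T, B, C, W, H)   # placeholder with the same layout as the spiking data

sample = spike_data[:, 0, 0]              # first sample in the batch, first (only) channel
print(sample.shape)                       # torch.Size([1000, 28, 28])
```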

In [10]:
# matplotlib.use("TkAgg") is only needed when running this notebook in PyCharm:
# import matplotlib.pyplot as plt
# import matplotlib; matplotlib.use("TkAgg")

from snntorch import spikeplot
from IPython.display import HTML


spike_data_visualizer = spike_data[:,0,0]
data_sample = spikeplot.spike_animator(spike_data_visualizer, x=28, y=28, T=100)
HTML(data_sample.to_html5_video())

# to_html5_video() requires the ffmpeg package (including when running in PyCharm); without it,
# matplotlib raises "RuntimeError: Requested MovieWriter (ffmpeg) not available".
#plt.show()

# To save as a gif instead, uncomment the following line:
# data_sample.save("spike_plot.gif", writer='imagemagick')

And the associated label can be indexed as follows:

In [10]:
print(f"The target is: {spikegen.from_one_hot(spike_targets[0][0])}")

The target is: 6
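Assuming `spike_targets` stores one-hot labels laid out as [T x B x num_classes] (consistent with the indexing `spike_targets[0][0]` above, but not confirmed by the source), decoding a label amounts to an argmax. A sketch using `torch.nn.functional.one_hot`; this is not `spikegen.from_one_hot` itself:

```python
import torch
import torch.nn.functional as F

targets = torch.tensor([6, 2, 9])               # example labels for a batch of 3
one_hot = F.one_hot(targets, num_classes=10)    # [B x num_classes]

T = 5
spike_targets = one_hot.unsqueeze(0).repeat(T, 1, 1)   # assumed [T x B x num_classes] layout

# Decoding a one-hot vector back to its class index is an argmax
label = spike_targets[0][0].argmax().item()
print(label)  # 6
```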


As a matter of interest, let's do that again with `gain=0.25`, which scales every firing probability to 25% of its original value and produces sparser spike trains:
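Because each pixel fires with probability `clamp(gain * pixel, 0, 1)` at every step, its expected spike count over T steps is `T * clamp(gain * pixel, 0, 1)`. A small sketch with the hypothetical helper `spike_counts`:

```python
import torch

def spike_counts(pixels, gain, T):
    """Spikes emitted by each pixel over T steps of Bernoulli rate coding."""
    probs = torch.clamp(gain * pixels, 0.0, 1.0)
    return torch.bernoulli(probs.repeat(T, 1)).sum(dim=0)

pixels = torch.tensor([0.0, 0.5, 1.0])   # black, mid-grey and white pixels
for gain in (1.0, 0.25):
    # Expected counts are T * clamp(gain * pixel, 0, 1): [0, 500, 1000] and [0, 125, 250]
    print(gain, spike_counts(pixels, gain, T=1000).tolist())
```

The black pixel never fires and, at `gain=1`, the white pixel fires at every step; the other counts vary from run to run around their expected values.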

In [14]:
spike_data, spike_targets = spikegen.spike_conversion(data_it, targets_it, config, gain=0.25)
spike_data_visualizer = spike_data[:,0,0]
data_sample = spikeplot.spike_animator(spike_data_visualizer, x=28, y=28, T=100)
HTML(data_sample.to_html5_video())
#print(f"The target is: {spikegen.from_one_hot(spike_targets[0][0])}")
#plt.show()


Now let's average the spikes out over time and reconstruct the input image.

In [None]:
import matplotlib.pyplot as plt

plt.imshow(spike_data_visualizer.mean(dim=0).reshape((28,-1)).cpu())
plt.show()
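Averaging over time works because the mean of T Bernoulli samples estimates their probability, with a standard error of `sqrt(p * (1 - p) / T)`. The reconstruction therefore approximates `gain * pixel` rather than the pixel itself, although `imshow` rescales the colour map automatically, so the image looks similar. A sketch on a random stand-in image:

```python
import torch

torch.manual_seed(0)
image = torch.rand(28, 28)          # stand-in for a grayscale image with values in [0, 1]
gain, T = 0.25, 1000

probs = torch.clamp(gain * image, 0, 1)
spikes = torch.bernoulli(probs.repeat(T, 1, 1))   # [T x 28 x 28]
reconstruction = spikes.mean(dim=0)               # per-pixel firing rate

# The firing rate estimates gain * pixel, so divide by gain to recover the original scale
max_err = (reconstruction / gain - image).abs().max().item()
print(f"max absolute error after rescaling: {max_err:.3f}")
```

Increasing `T` shrinks the error in proportion to `1 / sqrt(T)`.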

##### 1.4.1.2 Spike Raster Plots
Let's look at a raster plot of the input spikes by calling `spikeplot.raster`.
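A raster plot places one marker at (time step, neuron) for every spike, which is why the sample is flattened to a 2-D [T x 784] tensor before plotting. Assuming `spikeplot.raster` works along these lines (its source isn't shown here), the coordinates can be extracted with `torch.nonzero`:

```python
import torch

spikes = torch.tensor([[0, 1, 0],
                       [1, 0, 0],
                       [0, 1, 1]])  # [T x N]: 3 time steps, 3 neurons

# Each nonzero entry becomes one (time step, neuron) point on the raster
times, neurons = torch.nonzero(spikes, as_tuple=True)
print(times.tolist())    # [0, 1, 2, 2]
print(neurons.tolist())  # [1, 0, 1, 2]
# These coordinates could then be drawn with ax.scatter(times, neurons, s=1, c="black")
```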

In [9]:
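import matplotlib.pyplot as plt

# Take the first sample and channel, then flatten each frame to get [T x 784]: one row per time step
spike_data_visualizer = spike_data[:,0,0]
spike_data_visualizer = spike_data_visualizer.reshape((config.T, -1))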

# raster plot
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)

# Note: if the following line doesn't work on Colab, update PIL and restart runtime
# !pip install Pillow==5.3.0
spikeplot.raster(data=spike_data_visualizer, ax=ax, s=1, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()


We can also index into a single neuron:

In [None]:
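# Pixel 1 lies in MNIST's blank border and never spikes, so pick a pixel nearer the centre.
# unsqueeze(1) keeps the data 2-D ([T x 1]); a fresh figure is needed after plt.show().
fig = plt.figure(facecolor="w", figsize=(8, 1))
ax = fig.add_subplot(111)
spikeplot.raster(data=spike_data_visualizer[:, 210].unsqueeze(1), ax=ax, s=100, c="black", marker="|")
plt.title("Input Neuron")
plt.xlabel("Time step")
plt.show()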
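To read off when a single neuron fires, index its column and find the nonzero time steps. The sketch below uses a random stand-in raster; keeping the column 2-D with `unsqueeze(1)` preserves the [time x neuron] layout of a raster:

```python
import torch

torch.manual_seed(0)
T = 50
raster = torch.bernoulli(torch.full((T, 784), 0.3))   # stand-in for a flattened [T x 784] spike raster

neuron = raster[:, 210]              # one pixel/neuron over time: shape [T]
neuron_2d = neuron.unsqueeze(1)      # shape [T x 1]: a raster with a single neuron

spike_times = torch.nonzero(neuron).flatten()   # time steps at which this neuron fired
firing_rate = neuron.mean().item()              # fraction of time steps with a spike
print(spike_times.tolist(), firing_rate)
```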


### 2.0 LIF Neuron: 3-Factor Learning Rule
### 2.1 LIF Neuron: Voltage & Current Dependent
### 2.2 LIF Neuron w/Alpha Function

### Random Spike Train Generation

Another function, `spikegen.spike_train`, generates a random spike train that does not depend on any input data.

In [None]:
# Create a random spike train
Sin = torch.as_tensor(spikegen.spike_train(N_in=10, data_config=config, rate=0.5), dtype=torch.float)

# raster plot
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)

spikeplot.raster(data=Sin, ax=ax, s=0.2, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()
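The semantics of `spikegen.spike_train` aren't documented here. Assuming `rate` is the per-step firing probability and the output is laid out as [T x N_in] (time steps as rows, as in the raster plots above), a comparable random spike train could be sketched as follows; `random_spike_train_sketch` is hypothetical:

```python
import torch

def random_spike_train_sketch(n_in, T, rate, seed=0):
    """Hypothetical: n_in neurons each spike independently with probability `rate` per time step.

    returns: [T x n_in] tensor of 0s and 1s (time steps as rows, neurons as columns)
    """
    gen = torch.Generator().manual_seed(seed)   # local generator for reproducible sampling
    probs = torch.full((T, n_in), float(rate))
    return torch.bernoulli(probs, generator=gen)

Sin_sketch = random_spike_train_sketch(n_in=10, T=1000, rate=0.5)
print(Sin_sketch.shape, Sin_sketch.mean().item())   # roughly half of all entries are spikes
```

Using a dedicated `torch.Generator` keeps the sketch reproducible without touching the global random seed.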