In [None]:
# Install snnTorch first if it is not already available in this environment:
# %pip install snntorch

In [None]:
import snntorch as snn


In [None]:
import os

# Root directory for all downloaded datasets. The default (/tmp/data) matches the
# original notebook, but /tmp is commonly cleared on reboot, forcing a re-download;
# set the NMNIST_DATADIR environment variable to keep the data somewhere persistent.
DATADIR = os.environ.get("NMNIST_DATADIR", "/tmp/data")

## Download Dataset using `spikedata` (deprecated)

In [None]:
from snntorch.spikevision import spikedata

# Legacy snnTorch loader: a default transform is already applied by spikedata.
# num_steps presumably sets the number of time bins per sample; dt is the number
# of microseconds integrated into each bin. Both splits live under f"{DATADIR}/nmnist"
# and are built in order: training split first, then test split.
train_ds, test_ds = (
    spikedata.NMNIST(f"{DATADIR}/nmnist", train=is_train, num_steps=300, dt=1000)
    for is_train in (True, False)
)

In [None]:
train_ds  # last expression in the cell: Jupyter displays the spikedata training split's repr

## Download Dataset using `tonic`

In [None]:
import tonic

# First instantiation of the tonic N-MNIST splits (no transform), stored under
# DATADIR; per the section heading this is where the download happens. Both
# splits are built in order: training split first, then test split.
train_ds, test_ds = (
    tonic.datasets.NMNIST(save_to=DATADIR, train=split_is_train)
    for split_is_train in (True, False)
)

In [None]:
import tonic
import tonic.transforms as transforms

sensor_size = tonic.datasets.NMNIST.sensor_size

# Transform timing parameters, in microseconds (N-MNIST event timestamps are in µs
# per the tonic docs).
DENOISE_FILTER_TIME_US = 10_000  # Denoise drops events with no neighbouring event within this window
FRAME_TIME_WINDOW_US = 1_000     # ToFrame integrates events over windows of this length

# Denoise removes isolated, one-off events. ToFrame then bins the remaining events
# into frames of FRAME_TIME_WINDOW_US each. The number of frames therefore depends
# on each recording's duration and is not fixed across samples.
frame_transform = transforms.Compose([
    transforms.Denoise(filter_time=DENOISE_FILTER_TIME_US),
    transforms.ToFrame(sensor_size=sensor_size, time_window=FRAME_TIME_WINDOW_US),
])

train_ds = tonic.datasets.NMNIST(save_to=DATADIR, transform=frame_transform, train=True)
test_ds = tonic.datasets.NMNIST(save_to=DATADIR, transform=frame_transform, train=False)

In [None]:
train_ds  # last expression in the cell: Jupyter displays the frame-transformed tonic training split's repr

## Create DataLoader

In [None]:
import tonic
from torch.utils.data import DataLoader

BATCH_SIZE = 64

# ToFrame with a fixed time_window yields a different number of frames per
# recording. The default collate (torch.stack) would therefore fail as soon as a
# batch mixes samples of unequal length. PadTensors zero-pads every sample to the
# longest one in the batch; batch_first=True keeps the batch dimension first,
# like the default collate.
pad_collate = tonic.collation.PadTensors(batch_first=True)

train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, collate_fn=pad_collate)
test_dl = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE, collate_fn=pad_collate)

In [None]:
# Number of samples in the dataset wrapped by the training DataLoader.
train_size = len(train_dl.dataset)
print('the number of items in the dataset is', train_size)

## Play with Data

In [None]:
import torch

# Get a feel for the data: pull one fixed (not random) sample from the training set.
i_item = 20000  # index into the training dataset; any value in range(len(train_dl.dataset)) works
data, label = train_dl.dataset[i_item]
# The transformed sample is presumably a NumPy array. Convert it to a float32 tensor,
# the same dtype the legacy torch.Tensor(...) constructor produced.
data = torch.as_tensor(data, dtype=torch.float32)

print('The data sample has size', data.shape)
print(f"The target label is: {label}")

## Visualize

In [None]:
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML, display
import numpy as np

# Frames hold per-pixel event counts, so some entries are 2 or more; binarize to
# "at least one spike".
data = (data >= 1).float()
# Signed polarity map: channel 0 minus channel 1. The two event polarities get
# opposite signs, which the diverging 'seismic' colormap shows in different colours.
# To merge both polarities into one channel instead, use data[:, 0] + data[:, 1].
a = data[:, 0, :, :] - data[:, 1, :, :]

fig, ax = plt.subplots()
ax.set_title(f"N-MNIST sample {i_item} (label {label})")
anim = splt.animator(a, fig, ax, interval=30, cmap='seismic')
# Close the figure so Jupyter shows only the video, without an extra static copy
# of the last frame. The animation still renders from the closed figure.
plt.close(fig)
# To save to disk instead: anim.save('nmnist_animation.mp4', writer='ffmpeg', fps=50)
HTML(anim.to_html5_video())