[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width="400">](https://github.com/jeshraghian/snntorch/)

# snnTorch - Spiking Neural Networks

For a comprehensive overview of how SNNs work and what is going on under the hood, [see the snnTorch tutorial series available here.](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)
The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. "Training Spiking Neural Networks Using Lessons From Deep Learning". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>

# Introduction


In [1]:

!pip install git+https://github.com/jeshraghian/snntorch.git@master
#!pip install nir
#!pip install nirtorch
!pip install tonic


Collecting git+https://github.com/jeshraghian/snntorch.git@master
  Cloning https://github.com/jeshraghian/snntorch.git (to revision master) to /tmp/pip-req-build-rqytp2o4
  Running command git clone --filter=blob:none --quiet https://github.com/jeshraghian/snntorch.git /tmp/pip-req-build-rqytp2o4
  Resolved https://github.com/jeshraghian/snntorch.git to commit bd56874f961b921a457968eb069b85c195331005
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: snntorch
  Building wheel for snntorch (pyproject.toml) ... [?25l[?25hdone
  Created wheel for snntorch: filename=snntorch-0.9.4-py2.py3-none-any.whl size=126032 sha256=f6a6587ac2024b96a45452ba018603c185d38690efebd8c3b37b2e1da924f2a8
  Stored in directory: /tmp/pip-ephem-wheel-cache-89ejpkoa/wheels/64/d0/0d/a0be9822312baa7950bca3896192dea9a2395a4eba2eff1da4
Successfully built snn

In [3]:
import torch, torch.nn as nn
import snntorch as snn

# Data Loading

In [5]:
# Dataset directory setting.
# NOTE(review): the dataset constructors in this notebook pass '../data' directly
# and do not read this variable — confirm which location is intended.
data_path = '/data/nmnist'

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load the Neuromorphic MNIST (NMNIST) dataset.

In [36]:
from torch.utils.data import DataLoader
from tonic import datasets, transforms
import numpy as np

# Number of time bins each event recording is converted into; the training and
# evaluation loops iterate over this many steps per sample.
num_steps = 30


def _frames_to_float_tensor(frames):
    """Convert the numpy frame array produced by ToFrame into a float32 torch tensor."""
    return torch.from_numpy(frames.astype(np.float32))


# Bin the raw events into `num_steps` frames for a 34 x 34 sensor with 2 channels,
# then convert the result to a torch tensor.
transform = transforms.Compose([
    transforms.ToFrame(sensor_size=(34, 34, 2), n_time_bins=num_steps, include_incomplete=True),
    _frames_to_float_tensor,
])

# Train / test splits share the same root directory and transform.
dataset_root = '../data'
trainset = datasets.NMNIST(dataset_root, train=True, transform=transform)
testset = datasets.NMNIST(dataset_root, train=False, transform=transform)

batch_size = 32

# Arguments common to both loaders; drop_last=True discards a final partial batch.
loader_kwargs = dict(batch_size=batch_size, num_workers=0, drop_last=True, pin_memory=True)
train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(testset, **loader_kwargs)

# Define Network



In [6]:
# Settings shared by every spiking layer in the network. Each layer is built with
# init_hidden=True, and the loops in this notebook call net(x) once per time step.
lif_kwargs = dict(threshold=1, beta=0.4, init_hidden=True)

# Convolutional spiking network. The shape comments give (channels x height x width)
# of each layer's output for a single 2 x 34 x 34 input frame.
net = nn.Sequential(
    nn.Conv2d(2, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False),  # 16 x 16 x 16
    snn.Leaky(**lif_kwargs),
    nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),  # 16 x 16 x 16
    snn.Leaky(**lif_kwargs),
    nn.AvgPool2d(kernel_size=(2, 2)),  # 16 x 8 x 8
    nn.Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),  # 8 x 8 x 8
    snn.Leaky(**lif_kwargs),
    nn.AvgPool2d(kernel_size=(2, 2)),  # 8 x 4 x 4
    nn.Flatten(),  # 4 * 4 * 8 = 128 features
    nn.Linear(4 * 4 * 8, 256, bias=False),
    snn.Leaky(**lif_kwargs),
    nn.Linear(256, 10, bias=False),  # one output per class
    snn.Leaky(**lif_kwargs),
).to(device)

# Training
Define the optimizer and loss function. Here, we use the MSE Count Loss, which counts up the total number of output spikes at the end of the simulation run.

The correct class has a target firing probability of 100%, and incorrect classes are set to 0%.

In [38]:
import snntorch.functional as SF

# MSE spike-count loss: the correct class targets a firing rate of 1.0 and every
# other class a rate of 0.0.
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0)

# Adam over all network parameters.
adam_settings = dict(lr=2e-3, betas=(0.9, 0.999))
optimizer = torch.optim.Adam(net.parameters(), **adam_settings)

In [42]:
from snntorch import utils

def test_accuracy(data_loader, net, num_steps, population_code=False, num_classes=False):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      utils.reset(net)
      spk_rec = []
      for step in range(num_steps):
        spk_out = net(data[:, step, :, :, :])
        spk_rec.append(spk_out)

      spk_rec = torch.stack(spk_rec, dim=0)


      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)

      total += spk_rec.size(1)

  return acc/total

In [43]:
# Evaluate on the test loader; this cell sits before the training loop in the notebook.
baseline_acc = test_accuracy(test_loader, net, num_steps)
print(f"Test set accuracy: {baseline_acc*100:.3f}%\n")

Test set accuracy: 9.816%



In [None]:
# Number of passes over the training set.
# NOTE(review): no cell in this notebook defined `num_epochs`, so this loop raised
# NameError on a fresh kernel. 1 is a placeholder — set it to the value used for
# the recorded results.
num_epochs = 1

for epoch in range(num_epochs):
    net.train()

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Reset hidden states for snn.Leaky layers before each new batch.
        # Called through `snn.utils` directly instead of rebinding the global
        # name `utils` on every iteration.
        snn.utils.reset(net)

        # Forward pass through time: one network call per time step, with the
        # per-step output spikes stacked along a new leading (time) axis.
        spk_rec = torch.stack(
            [net(images[:, step, :, :, :]) for step in range(num_steps)], dim=0
        )

        # Compute loss
        loss_val = loss_fn(spk_rec, labels)

        # Standard optimisation step.
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    # Reports the loss of the last batch of the epoch, not an epoch average.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_val.item():.4f}")

In [50]:
# Evaluate on the test loader again; this cell sits after the training loop in the notebook.
trained_acc = test_accuracy(test_loader, net, num_steps)
print(f"Test set accuracy: {trained_acc*100:.3f}%\n")

Test set accuracy: 94.121%



In [64]:
from snntorch.export_nir import export_to_nir

# Move the network to the CPU before exporting (nn.Module.to moves it in place
# and returns the same module).
cpu_net = net.to('cpu')

# Random example input handed to the exporter: 30 frames of 2 x 34 x 34.
sample_data = torch.randn(30, 2, 34, 34)

# NOTE(review): the recorded run of this cell failed with
# "TypeError: LIF.__init__() missing 1 required positional argument: 'v_reset'".
# Presumably the installed nir package and this snnTorch exporter disagree on
# the LIF node signature — confirm the package versions.
nir_graph = export_to_nir(cpu_net, sample_data)

TypeError: LIF.__init__() missing 1 required positional argument: 'v_reset'

In [71]:
!pip install git+https://github.com/jeshraghian/snntorch.git@master

Collecting git+https://github.com/jeshraghian/snntorch.git@master
  Cloning https://github.com/jeshraghian/snntorch.git (to revision master) to /tmp/pip-req-build-l_fkt9pt
  Running command git clone --filter=blob:none --quiet https://github.com/jeshraghian/snntorch.git /tmp/pip-req-build-l_fkt9pt
  Resolved https://github.com/jeshraghian/snntorch.git to commit bd56874f961b921a457968eb069b85c195331005
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [10]:
!pip install nir
!pip install nirtorch

Collecting nirtorch
  Downloading nirtorch-2.0.5-py3-none-any.whl.metadata (6.4 kB)
Downloading nirtorch-2.0.5-py3-none-any.whl (27 kB)
Installing collected packages: nirtorch
Successfully installed nirtorch-2.0.5


In [16]:
from snntorch.export_nir import export_to_nir

# Move the network to the CPU before exporting (nn.Module.to moves it in place
# and returns the same module).
cpu_net = net.to('cpu')

# Random example input handed to the exporter: a single 2 x 34 x 34 frame.
sample_data = torch.randn(1, 2, 34, 34)

# NOTE(review): the recorded run of this cell failed with
# "IndexError: tuple index out of range"; the cause is not visible from this
# notebook — check the export_to_nir signature of the installed snnTorch version.
nir_graph = export_to_nir(cpu_net, sample_data, model_name='snntorch.nir')

IndexError: tuple index out of range