In [1]:
import torch, torch.nn as nn
import snntorch as snn

In [2]:
batch_size = 128
data_path = '/tmp/data/mnist'  # MNIST is downloaded here on first run

# Seed PyTorch's RNG so weight initialization and DataLoader shuffling repeat across runs.
seed = 0
torch.manual_seed(seed)

# Set use_cuda = True to train on a GPU when one is available; CPU is the default.
use_cuda = False
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

In [3]:
# Display the compute device the model and data will be moved to
device

device(type='cpu')

In [4]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a transform. MNIST digits are already 28x28 single-channel images, so Resize and
# Grayscale leave them unchanged and Normalize((0,), (1,)) is the identity; these steps only
# matter if a different image dataset is swapped in.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders. Only the training set is shuffled: shuffling does not change accuracy
# aggregated over the whole test set, and a fixed test order keeps evaluation reproducible.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [5]:
from snntorch import surrogate

# Settings shared by every leaky integrate-and-fire (LIF) layer below
beta = 0.9  # membrane potential decay rate
spike_grad = surrogate.fast_sigmoid()  # surrogate gradient used in place of the spike's derivative

lif_kwargs = dict(beta=beta, spike_grad=spike_grad, init_hidden=True)

# Convolutional SNN. Spatial size: 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4,
# hence 16 channels * 4 * 4 features into the classifier.
net = nn.Sequential(
    nn.Conv2d(1, 8, 5),
    nn.MaxPool2d(2),
    snn.Leaky(**lif_kwargs),
    nn.Conv2d(8, 16, 5),
    nn.MaxPool2d(2),
    snn.Leaky(**lif_kwargs),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 10),
    # output=True: the last layer returns a (spikes, membrane) pair instead of spikes only
    snn.Leaky(output=True, **lif_kwargs),
).to(device)

In [6]:
# Display the layer stack; the Linear layer's 256 inputs are 16 channels * 4 * 4 spatial positions
net

Sequential(
  (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))
  (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (2): Leaky()
  (3): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Leaky()
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=256, out_features=10, bias=True)
  (8): Leaky()
)

In [7]:
from snntorch import utils 

def forward_pass(net, data, num_steps):
    """Simulate ``net`` for ``num_steps`` time steps on a fixed input and record output spikes.

    The hidden state of every snnTorch neuron in ``net`` is reset first, so consecutive
    calls (e.g. one per minibatch) do not carry membrane potential over between batches.

    Args:
        net: network whose call returns a ``(spikes, membrane)`` pair; in this notebook
            the final ``snn.Leaky`` layer is built with ``output=True`` for that purpose.
        data: input batch, fed unchanged to ``net`` at every time step.
        num_steps: number of simulation time steps; must be >= 1.

    Returns:
        The per-step output spikes stacked along a new leading time dimension,
        i.e. shape ``(num_steps, *spikes.shape)``.

    Raises:
        ValueError: if ``num_steps`` < 1 (there would be nothing to stack).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    utils.reset(net)  # reset/initialize hidden states for all LIF neurons in net

    spk_rec = []  # output spikes, one entry per time step
    for _ in range(num_steps):
        spk_out, _mem_out = net(data)  # one time step; membrane potential is not needed here
        spk_rec.append(spk_out)

    return torch.stack(spk_rec)

In [8]:
import snntorch.functional as SF

# Adam optimizer; betas (0.9, 0.999) match PyTorch's defaults and are spelled out for clarity
learning_rate = 2e-3
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
)

# Spike-count MSE loss: the correct class is pushed toward spiking on 80% of the time steps,
# every incorrect class toward 20%.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [9]:
num_epochs = 1  # run for 1 epoch - each data sample is seen only once
num_steps = 25  # run for 25 time steps
print_every = 25  # log loss and batch accuracy every this many iterations
max_iters = None  # e.g. 150 for faster termination; None trains on the whole epoch

loss_hist = []  # training loss at every iteration
acc_hist = []  # training-batch accuracy, recorded every `print_every` iterations

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # Re-entered every iteration so the loop stays correct if evaluation is added inside it
        net.train()
        spk_rec = forward_pass(net, data, num_steps)  # forward-pass over all time steps
        loss_val = loss_fn(spk_rec, targets)  # loss calculation

        optimizer.zero_grad()  # null gradients
        loss_val.backward()  # calculate gradients
        optimizer.step()  # update weights

        train_loss = loss_val.item()
        loss_hist.append(train_loss)  # store loss

        if i % print_every == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {train_loss:.2f}")

            # accuracy on the current *training* batch, not a held-out estimate
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

        # same effect as the old commented-out `if i == 150: break` when max_iters = 150
        if max_iters is not None and i >= max_iters:
            break


Epoch 0, Iteration 0 
Train Loss: 2.42
Accuracy: 4.69%

Epoch 0, Iteration 25 
Train Loss: 0.72
Accuracy: 43.75%

Epoch 0, Iteration 50 
Train Loss: 0.44
Accuracy: 74.22%

Epoch 0, Iteration 75 
Train Loss: 0.40
Accuracy: 87.50%

Epoch 0, Iteration 100 
Train Loss: 0.33
Accuracy: 85.16%

Epoch 0, Iteration 125 
Train Loss: 0.28
Accuracy: 89.84%

Epoch 0, Iteration 150 
Train Loss: 0.23
Accuracy: 96.88%

Epoch 0, Iteration 175 
Train Loss: 0.23
Accuracy: 91.41%

Epoch 0, Iteration 200 
Train Loss: 0.23
Accuracy: 93.75%

Epoch 0, Iteration 225 
Train Loss: 0.21
Accuracy: 92.97%

Epoch 0, Iteration 250 
Train Loss: 0.20
Accuracy: 93.75%

Epoch 0, Iteration 275 
Train Loss: 0.20
Accuracy: 96.88%

Epoch 0, Iteration 300 
Train Loss: 0.20
Accuracy: 94.53%

Epoch 0, Iteration 325 
Train Loss: 0.21
Accuracy: 93.75%

Epoch 0, Iteration 350 
Train Loss: 0.20
Accuracy: 95.31%

Epoch 0, Iteration 375 
Train Loss: 0.14
Accuracy: 100.00%

Epoch 0, Iteration 400 
Train Loss: 0.17
Accuracy: 95.31%

Ep