# 1. Set up the environment and datasets

In [1]:
import snntorch as snn
import os
import time
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
import numpy as np

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


import matplotlib.pyplot as plt

import models
import train
# Default floating-point dtype, and the compute device: CUDA when a GPU is visible, CPU otherwise.
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Prepare MNIST datasets
data_path = '/tmp/data/mnist'
batch_size = 256

# Resize/Grayscale are no-ops for MNIST (already 28x28, single channel) but make the pipeline
# robust if another image dataset is swapped in. Normalize((0,), (1,)) computes (x - 0) / 1,
# i.e. it is an identity and leaves ToTensor's [0, 1] pixel range unchanged.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Shuffle only the training data; the test set is evaluated in a fixed order so
# evaluation is deterministic and comparable across runs.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)


# 2. Design an SNN architecture

In [3]:
# Define the hyperparameters
num_steps = 2            # number of simulation time steps T
n_first_hidden = 40      # width of the first hidden layer
num_binary_layers = 4    # total number of Linear+Leaky stages (4 yields 784-40-h-h-10, see the printed model)
n_hidden = 20            # width of each hidden layer after the first; 20 is the width the model is built with

# Draw a seed at random, but actually apply it and record it in results.txt so the run
# can be reproduced later. Replace with a fixed value (e.g. seed = 30) to pin the initialization.
seed = int(np.random.randint(100))
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators

# Create a folder to save results, named after the layer sizes,
# e.g. ./mnist_results/784-40-20-20-10/ (784 = 28*28 input pixels, 10 output classes).
layer_sizes = [784, n_first_hidden] + [n_hidden] * (num_binary_layers - 2) + [10]
name = '-'.join(str(size) for size in layer_sizes)
save_path = './mnist_results/' + name + '/'
os.makedirs(save_path, exist_ok=True)

# `with` guarantees the file is closed even if writing fails.
with open(save_path + 'results.txt', 'w') as file:
    print('Network architecture: ' + name + '\nNumber of time steps: T=' + str(num_steps), file=file)
    print('Random seed: ' + str(seed), file=file)

In [4]:
# Build the SNN from the hyperparameters defined above (instead of hardcoded values) so the
# instantiated network always matches the architecture name and T recorded in results.txt.
net = models.SNN(num_steps=num_steps, n_first_hidden=n_first_hidden,
                 num_binary_layers=num_binary_layers, n_hidden=n_hidden).to(device)
print(net)

SNN(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=40, bias=True)
    (1): Leaky()
    (2): Linear(in_features=40, out_features=20, bias=True)
    (3): Leaky()
    (4): Linear(in_features=20, out_features=20, bias=True)
    (5): Leaky()
    (6): Linear(in_features=20, out_features=10, bias=True)
    (7): Leaky()
  )
)


# 3. Train / load the network and save results

In [5]:
num_epochs = 20

# perf_counter is monotonic, so the measured duration is immune to system clock adjustments.
start_time = time.perf_counter()
# NOTE(review): the return value of train.train_snn is discarded; if it returns loss
# histories (needed by the plotting cell) capture them here — TODO confirm its return type.
train.train_snn(net, train_loader, test_loader, save_path=save_path, num_epochs=num_epochs, output='spike')
elapsed = time.perf_counter() - start_time

# Append the timing to the same results.txt written when the hyperparameters were set;
# `with` closes the file even if print_and_save raises.
with open(save_path + 'results.txt', 'a') as file:
    train.print_and_save(f"Training time is approximately {elapsed:.0f} seconds over {num_epochs} epochs.", file)

  0%|          | 0/20 [00:00<?, ?it/s]

---------- Training epoch 0 ------------
Iteration 0 --- Train Loss: 2.30 --- Minibatch accuracy: 9.77%

Iteration 50 --- Train Loss: 2.30 --- Minibatch accuracy: 6.25%

Iteration 100 --- Train Loss: 2.26 --- Minibatch accuracy: 20.31%

Iteration 150 --- Train Loss: 1.79 --- Minibatch accuracy: 33.59%

Iteration 200 --- Train Loss: 1.38 --- Minibatch accuracy: 58.98%



  5%|▌         | 1/20 [00:14<04:37, 14.61s/it]

---------- Training epoch 1 ------------
Iteration 0 --- Train Loss: 1.26 --- Minibatch accuracy: 68.75%

Iteration 50 --- Train Loss: 1.12 --- Minibatch accuracy: 80.47%

Iteration 100 --- Train Loss: 1.14 --- Minibatch accuracy: 75.78%

Iteration 150 --- Train Loss: 1.10 --- Minibatch accuracy: 78.52%

Iteration 200 --- Train Loss: 1.08 --- Minibatch accuracy: 77.34%



 10%|█         | 2/20 [00:28<04:19, 14.41s/it]

---------- Training epoch 2 ------------
Iteration 0 --- Train Loss: 1.08 --- Minibatch accuracy: 80.86%

Iteration 50 --- Train Loss: 1.01 --- Minibatch accuracy: 82.81%

Iteration 100 --- Train Loss: 1.03 --- Minibatch accuracy: 77.73%

Iteration 150 --- Train Loss: 1.06 --- Minibatch accuracy: 81.25%

Iteration 200 --- Train Loss: 1.03 --- Minibatch accuracy: 83.59%



 15%|█▌        | 3/20 [00:43<04:06, 14.51s/it]

---------- Training epoch 3 ------------
Iteration 0 --- Train Loss: 0.91 --- Minibatch accuracy: 93.36%

Iteration 50 --- Train Loss: 0.97 --- Minibatch accuracy: 89.84%

Iteration 100 --- Train Loss: 0.93 --- Minibatch accuracy: 92.58%

Iteration 150 --- Train Loss: 1.00 --- Minibatch accuracy: 88.67%

Iteration 200 --- Train Loss: 0.94 --- Minibatch accuracy: 92.58%



 20%|██        | 4/20 [00:57<03:51, 14.45s/it]

---------- Training epoch 4 ------------
Iteration 0 --- Train Loss: 0.97 --- Minibatch accuracy: 91.41%

Iteration 50 --- Train Loss: 0.91 --- Minibatch accuracy: 93.36%

Iteration 100 --- Train Loss: 0.91 --- Minibatch accuracy: 92.19%

Iteration 150 --- Train Loss: 0.90 --- Minibatch accuracy: 93.75%

Iteration 200 --- Train Loss: 0.93 --- Minibatch accuracy: 92.19%



 25%|██▌       | 5/20 [01:12<03:36, 14.42s/it]

---------- Training epoch 5 ------------
Iteration 0 --- Train Loss: 0.94 --- Minibatch accuracy: 91.41%

Iteration 50 --- Train Loss: 0.93 --- Minibatch accuracy: 91.80%

Iteration 100 --- Train Loss: 0.88 --- Minibatch accuracy: 95.70%

Iteration 150 --- Train Loss: 0.92 --- Minibatch accuracy: 92.19%

Iteration 200 --- Train Loss: 0.91 --- Minibatch accuracy: 92.19%



 30%|███       | 6/20 [01:26<03:21, 14.42s/it]

---------- Training epoch 6 ------------
Iteration 0 --- Train Loss: 0.88 --- Minibatch accuracy: 94.14%

Iteration 50 --- Train Loss: 0.91 --- Minibatch accuracy: 93.75%

Iteration 100 --- Train Loss: 0.94 --- Minibatch accuracy: 91.80%

Iteration 150 --- Train Loss: 0.89 --- Minibatch accuracy: 93.75%

Iteration 200 --- Train Loss: 0.93 --- Minibatch accuracy: 92.58%



 35%|███▌      | 7/20 [01:41<03:07, 14.42s/it]

---------- Training epoch 7 ------------
Iteration 0 --- Train Loss: 0.91 --- Minibatch accuracy: 92.19%

Iteration 50 --- Train Loss: 0.85 --- Minibatch accuracy: 96.48%

Iteration 100 --- Train Loss: 0.90 --- Minibatch accuracy: 94.14%

Iteration 150 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%

Iteration 200 --- Train Loss: 0.87 --- Minibatch accuracy: 94.92%



 40%|████      | 8/20 [01:55<02:53, 14.44s/it]

---------- Training epoch 8 ------------
Iteration 0 --- Train Loss: 0.91 --- Minibatch accuracy: 95.31%

Iteration 50 --- Train Loss: 0.89 --- Minibatch accuracy: 93.36%

Iteration 100 --- Train Loss: 0.88 --- Minibatch accuracy: 94.53%

Iteration 150 --- Train Loss: 0.88 --- Minibatch accuracy: 96.09%

Iteration 200 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%



 45%|████▌     | 9/20 [02:09<02:38, 14.43s/it]

---------- Training epoch 9 ------------
Iteration 0 --- Train Loss: 0.86 --- Minibatch accuracy: 96.88%

Iteration 50 --- Train Loss: 0.92 --- Minibatch accuracy: 92.58%

Iteration 100 --- Train Loss: 0.89 --- Minibatch accuracy: 94.14%

Iteration 150 --- Train Loss: 0.91 --- Minibatch accuracy: 93.36%

Iteration 200 --- Train Loss: 0.91 --- Minibatch accuracy: 92.58%



 50%|█████     | 10/20 [02:24<02:23, 14.40s/it]

---------- Training epoch 10 ------------
Iteration 0 --- Train Loss: 0.87 --- Minibatch accuracy: 96.09%

Iteration 50 --- Train Loss: 0.90 --- Minibatch accuracy: 92.97%

Iteration 100 --- Train Loss: 0.85 --- Minibatch accuracy: 96.88%

Iteration 150 --- Train Loss: 0.85 --- Minibatch accuracy: 96.09%

Iteration 200 --- Train Loss: 0.88 --- Minibatch accuracy: 95.31%



 55%|█████▌    | 11/20 [02:39<02:10, 14.51s/it]

---------- Training epoch 11 ------------
Iteration 0 --- Train Loss: 0.90 --- Minibatch accuracy: 94.53%

Iteration 50 --- Train Loss: 0.85 --- Minibatch accuracy: 96.48%

Iteration 100 --- Train Loss: 0.91 --- Minibatch accuracy: 92.19%

Iteration 150 --- Train Loss: 0.88 --- Minibatch accuracy: 94.92%

Iteration 200 --- Train Loss: 0.89 --- Minibatch accuracy: 93.75%



 60%|██████    | 12/20 [02:54<01:58, 14.85s/it]

---------- Training epoch 12 ------------
Iteration 0 --- Train Loss: 0.88 --- Minibatch accuracy: 94.92%

Iteration 50 --- Train Loss: 0.89 --- Minibatch accuracy: 94.14%

Iteration 100 --- Train Loss: 0.90 --- Minibatch accuracy: 94.92%

Iteration 150 --- Train Loss: 0.88 --- Minibatch accuracy: 95.70%

Iteration 200 --- Train Loss: 0.88 --- Minibatch accuracy: 95.31%



 65%|██████▌   | 13/20 [03:13<01:53, 16.16s/it]

---------- Training epoch 13 ------------
Iteration 0 --- Train Loss: 0.91 --- Minibatch accuracy: 92.97%

Iteration 50 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%

Iteration 100 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%

Iteration 150 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%

Iteration 200 --- Train Loss: 0.87 --- Minibatch accuracy: 96.09%



 70%|███████   | 14/20 [03:29<01:35, 16.00s/it]

---------- Training epoch 14 ------------
Iteration 0 --- Train Loss: 0.86 --- Minibatch accuracy: 96.88%

Iteration 50 --- Train Loss: 0.86 --- Minibatch accuracy: 96.09%

Iteration 100 --- Train Loss: 0.90 --- Minibatch accuracy: 94.14%

Iteration 150 --- Train Loss: 0.86 --- Minibatch accuracy: 96.09%

Iteration 200 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%



 75%|███████▌  | 15/20 [03:45<01:19, 15.88s/it]

---------- Training epoch 15 ------------
Iteration 0 --- Train Loss: 0.88 --- Minibatch accuracy: 94.53%

Iteration 50 --- Train Loss: 0.88 --- Minibatch accuracy: 96.09%

Iteration 100 --- Train Loss: 0.88 --- Minibatch accuracy: 94.53%

Iteration 150 --- Train Loss: 0.88 --- Minibatch accuracy: 94.53%

Iteration 200 --- Train Loss: 0.87 --- Minibatch accuracy: 95.70%



 80%|████████  | 16/20 [04:00<01:02, 15.64s/it]

---------- Training epoch 16 ------------
Iteration 0 --- Train Loss: 0.86 --- Minibatch accuracy: 95.70%

Iteration 50 --- Train Loss: 0.87 --- Minibatch accuracy: 96.09%

Iteration 100 --- Train Loss: 0.85 --- Minibatch accuracy: 96.09%

Iteration 150 --- Train Loss: 0.85 --- Minibatch accuracy: 95.70%

Iteration 200 --- Train Loss: 0.85 --- Minibatch accuracy: 96.48%



 85%|████████▌ | 17/20 [04:15<00:46, 15.61s/it]

---------- Training epoch 17 ------------
Iteration 0 --- Train Loss: 0.88 --- Minibatch accuracy: 94.92%

Iteration 50 --- Train Loss: 0.90 --- Minibatch accuracy: 94.14%

Iteration 100 --- Train Loss: 0.88 --- Minibatch accuracy: 95.31%

Iteration 150 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%

Iteration 200 --- Train Loss: 0.88 --- Minibatch accuracy: 95.70%



 90%|█████████ | 18/20 [04:31<00:31, 15.71s/it]

---------- Training epoch 18 ------------
Iteration 0 --- Train Loss: 0.89 --- Minibatch accuracy: 94.92%

Iteration 50 --- Train Loss: 0.87 --- Minibatch accuracy: 95.70%

Iteration 100 --- Train Loss: 0.86 --- Minibatch accuracy: 96.48%

Iteration 150 --- Train Loss: 0.90 --- Minibatch accuracy: 93.75%

Iteration 200 --- Train Loss: 0.90 --- Minibatch accuracy: 94.92%



 95%|█████████▌| 19/20 [04:46<00:15, 15.58s/it]

---------- Training epoch 19 ------------
Iteration 0 --- Train Loss: 0.85 --- Minibatch accuracy: 96.88%

Iteration 50 --- Train Loss: 0.85 --- Minibatch accuracy: 96.88%

Iteration 100 --- Train Loss: 0.86 --- Minibatch accuracy: 96.09%

Iteration 150 --- Train Loss: 0.91 --- Minibatch accuracy: 92.97%

Iteration 200 --- Train Loss: 0.85 --- Minibatch accuracy: 96.48%

####### Statistics over the whole train/test dataset after epoch 19 #######
Train loss: 0.87, train accuracy: 95.65 % 


100%|██████████| 20/20 [05:17<00:00, 15.88s/it]

Test loss: 0.90, test accuracy: 93.85 %
Training time is approximately 318 seconds over 20 epochs.





In [6]:
# Disabled: plot train/test loss curves.
# NOTE(review): train_loss_hist and test_loss_hist are never defined in this notebook — the
# training cell discards train.train_snn's return value. Capture the histories there (if
# train_snn returns them — TODO confirm) before uncommenting, or this raises NameError.
#train.plot_learning_curve(train_loss_hist, test_loss_hist)

In [7]:
# Disabled: persist the trained weights to <save_path>/params.pth so the network can be
# reloaded without rerunning the training cell. Uncomment to enable.
#torch.save(net.state_dict(), save_path+'params.pth')
