In [3]:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.activation_based import neuron, functional, surrogate, layer, monitor
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing
from spikingjelly.activation_based import ann2snn
from spikingjelly.activation_based import encoding


device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Model needs to account for 2 things:

1. ANN proposes Batch Normalization for fast training and convergence. Batch normalization aims to normalize the ANN output to 0 mean, which is contrary to the properties of SNNs. Therefore, the parameters of BN can be absorbed into the previous parameter layers (Linear, Conv2d)

2. According to the transformation theory, the input and output of each layer of ANN need to be limited to the range of [0,1], which requires scaling the parameters (model normalization)

3. There is not a good way to use MaxPooling. AvgPool is recommended instead.

In [2]:
class LeNet_Modified(nn.Module):
    def __init__(self):
        super(LeNet_Modified, self).__init__()
        #5x5 kernal on 28x28 image. Should have 2 padding for "32x32" image
        self.c1 = nn.Conv2d(in_channels=1, kernel_size=5, padding=2, out_channels=6) #results in 28x28 in 6 channels. Should have 1 channel-in bc it is one image at first
        self.bn1 = nn.BatchNorm2d(num_features=6, eps=1e-3)
        self.ap1 = nn.AvgPool2d(kernel_size=2, stride=2)
        #Relu 28x28 -> 28x28
        #avg pool 28x28 -> 14x14 (stride=2)
        self.c2 = nn.Conv2d(in_channels=6, kernel_size=5, out_channels=16) #6 channels to 16 channels. 14x14 -> 10x10 with 5x5kernel
        self.bn2 = nn.BatchNorm2d(num_features=16, eps=1e-3)
        self.ap2 = nn.AvgPool2d(kernel_size=2, stride=2)
        #avg pool 10x10 -> 5x5
        self.fc1 = nn.Linear(25*16, 120) #5x5 images, 16 channels in. 120 out
        self.fc2 = nn.Linear(120, 84) #120 -> 84
        self.fc3 = nn.Linear(84, 10)

        self.network = nn.Sequential(
            self.c1,
            self.bn1,
            nn.ReLU(),
            self.ap1,
            self.c2,
            self.bn2,
            nn.ReLU(),
            self.ap2,
            nn.Flatten(),
            self.fc1,
            nn.ReLU(),
            self.fc2,
            nn.ReLU(),
            self.fc3,
        )

    def forward(self, x):
        x = self.network(x)
        return x

    def name(self):
        return "LeNet"

Define variables and Loss

In [3]:
EPOCHS=100
lr = 0.001
batch_size = 32
model_ann = LeNet_Modified().to(device=device)
optimizer = torch.optim.Adam(model_ann.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

Download the FMNIST dataset and then create the dataloaders

In [4]:
root = './FMNIST'
train_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True
)

test_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False
)

Train the ANN model

In [5]:
for epoch in range(100):
    # trainning
    total_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        x, target = x.cuda(), target.cuda()
        out = model_ann(x)
        loss = criterion(out, target)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    avg_loss = total_loss / len(train_set)
    print(f'==>>> epoch: {epoch}, train loss: {avg_loss:.6f}')

==>>> epoch: 0, train loss: 0.014192
==>>> epoch: 1, train loss: 0.009726
==>>> epoch: 2, train loss: 0.008742
==>>> epoch: 3, train loss: 0.007902
==>>> epoch: 4, train loss: 0.007462
==>>> epoch: 5, train loss: 0.006932
==>>> epoch: 6, train loss: 0.006502
==>>> epoch: 7, train loss: 0.006204
==>>> epoch: 8, train loss: 0.005820
==>>> epoch: 9, train loss: 0.005489
==>>> epoch: 10, train loss: 0.005226
==>>> epoch: 11, train loss: 0.004959
==>>> epoch: 12, train loss: 0.004651
==>>> epoch: 13, train loss: 0.004485
==>>> epoch: 14, train loss: 0.004230
==>>> epoch: 15, train loss: 0.004021
==>>> epoch: 16, train loss: 0.003779
==>>> epoch: 17, train loss: 0.003572
==>>> epoch: 18, train loss: 0.003467
==>>> epoch: 19, train loss: 0.003244
==>>> epoch: 20, train loss: 0.003070
==>>> epoch: 21, train loss: 0.002970
==>>> epoch: 22, train loss: 0.002856
==>>> epoch: 23, train loss: 0.002715
==>>> epoch: 24, train loss: 0.002540
==>>> epoch: 25, train loss: 0.002549
==>>> epoch: 26, train

Save the Model

In [6]:
file_dir = './Models'
if not os.path.exists(file_dir):
    os.makedirs(file_dir)
    print(f'Mkdir {file_dir}.')
full_path = file_dir + '/LeNET_ann.pt'
torch.save(model_ann.state_dict(), f=full_path)

Test accuracy

In [None]:
model_ann.eval()
total_loss = 0
correct_cnt = 0
for batch_idx, (x, target) in enumerate(test_loader):
    x, target = x.to(device), target.to(device)
    out = model_ann(x)
    loss = criterion(out, target)
    _, pred_label = torch.max(out, 1)
    correct_cnt += (pred_label == target).sum()
    # smooth average
    total_loss += loss.item()
avg_loss = total_loss / len(test_set)
avg_acc = correct_cnt / len(test_set)
print(f'test loss: {avg_loss:.6f}, test accuracy: {avg_acc:.6f}')

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173


: 

Convert Model from ANN to SNN
3 different modes:
* MaxNorm
* 99.9%
* scaling mode (float 0-1)

In [8]:
model_converter = ann2snn.Converter(mode='max', dataloader=train_loader)
model_snn = model_converter(model_ann)

  0%|          | 0/1875 [00:00<?, ?it/s]

100%|██████████| 1875/1875 [00:08<00:00, 208.98it/s]


In [9]:
print(model_snn)

LeNet_Modified(
  (c1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (network): Module(
    (8): Flatten(start_dim=1, end_dim=-1)
  )
  (ap1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (c2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (ap2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (snn tailor): Module(
    (0): Module(
      (0): VoltageScaler(0.129585)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(7.716962)
    )
    (1): Module(
      (0): VoltageScaler(0.203869)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
        (surrogate_function): Sigmoid(alpha=4

Run test on SNN model

In [13]:
model_ann.eval()
T = 100
total_loss = 0
correct_cnt = 0
for batch_idx, (x, target) in enumerate(test_loader):
    x, target = x.to(device), target.to(device)
    #reset before new eval
    for m in model_snn.modules():
        if hasattr(m, 'reset'):
            m.reset()
    #evaluate for a number of timesteps
    for t in range (T):
        if t ==0:
            out = model_snn(x)
        else:
            out += model_snn(x)
    # loss = criterion(out, target)
    _, pred_label = torch.max(out, 1)
    correct_cnt += (pred_label == target).sum()
    # smooth average
    total_loss += loss.item()
# avg_loss = total_loss / len(test_set)
avg_acc = correct_cnt / len(test_set)
print(f'test accuracy: {avg_acc:.6f}')

test accuracy: 0.905000


Test for spike count.
Go through test set again, but with 1 img at a time.
For everytime step, record membrane potential and number of spikes (if seeing spikes is possible)

In [None]:
spike_monitor = monitor.OutputMonitor()