Make simple snn to classify MNIST

In [1]:
import os
import time
import argparse
import sys
import datetime
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda import amp
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from spikingjelly.activation_based import neuron, layer, functional, monitor
from spikingjelly.activation_based import surrogate, encoding

device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Define Model

In [2]:
class SNN(nn.Module):
    def __init__(self, tau):
        super().__init__()
        self.layer1 = nn.Sequential(
            layer.Linear(28*28, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
        )
    def forward(self, x: torch.Tensor):
        # print(x.shape)
        return self.layer1(x)

Create a model with a default tau value

In [3]:
tau = 2.0 #membrane time constant. 
#^ A default value in some of the spikingjelly repo files.
model = SNN(tau=tau).to(device=device)

Download the MNIST dataset and make train and test datasets

In [4]:
data_dir='./data'
train_dataset = torchvision.datasets.MNIST(
    root=data_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=data_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

In [5]:
batch_size=32
num_workers=10 #todo: see if this can be increased with more cores
train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    pin_memory=True
)

More settings

In [6]:
EPOCHS=100
AMP=True #automatic mixed precision training
lr= 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
out_dir = "./outputs/SLM"
encoder = encoding.PoissonEncoder()
timesteps = 100

Make settings for AMP and outputs

In [7]:
scaler = None
if AMP:
    scaler = amp.GradScaler()
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=0)
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))

Start training

In [42]:
functional.reset_net(model)
for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for x, label in train_data_loader:
        x = x.to(device)
        x = torch.reshape(x, (batch_size, -1))
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is None:
            out_fr = 0.
            for t in range(timesteps):
                encoded_img = encoder(x)
                out_fr += model(encoded_img)
            out_fr = out_fr / timesteps
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()
        else:
            with amp.autocast():
                out_fr = 0.
                for t in range(timesteps):
                    encoded_img = encoder(x)
                    out_fr += model(encoded_img)
                out_fr = out_fr / timesteps
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(model) #need to reset the snn before reuse
    print('epoch: ' + str(epoch) + '; loss' + str(train_loss))

epoch: 0; loss1465.4321936815977
epoch: 1; loss1074.6598600372672
epoch: 2; loss1013.7691004052758
epoch: 3; loss978.1228725686669
epoch: 4; loss955.739439509809
epoch: 5; loss936.859009757638
epoch: 6; loss925.2431800663471
epoch: 7; loss916.6635410562158
epoch: 8; loss893.2880175858736
epoch: 9; loss893.4390191398561
epoch: 10; loss893.4556886255741
epoch: 11; loss892.8046701513231
epoch: 12; loss892.7847896181047
epoch: 13; loss892.8264885768294
epoch: 14; loss892.7614993713796
epoch: 15; loss893.962477080524
epoch: 16; loss893.4855384528637
epoch: 17; loss893.6630686409771
epoch: 18; loss893.4488283991814
epoch: 19; loss892.6274483725429
epoch: 20; loss893.0832287408412
epoch: 21; loss893.6761670857668
epoch: 22; loss893.6992178708315
epoch: 23; loss892.9940860755742
epoch: 24; loss892.7976972423494
epoch: 25; loss892.8244476206601
epoch: 26; loss893.2834375761449
epoch: 27; loss893.6475564651191
epoch: 28; loss893.0162987858057
epoch: 29; loss893.0060585215688
epoch: 30; loss893.2

Save the SNN Model so that you do not have to retrain

In [43]:
file_dir = './Models/'
if not os.path.exists(file_dir):
    os.makedirs(file_dir)
    print(f'Mkdir {file_dir}.')
input('are you sure you want to save the model?')
full_path = file_dir + '/SingleLayer_SNN.pt'
torch.save(model.state_dict(), f=full_path)

Load Model

In [8]:
file_dir = './Models/'
full_path = file_dir + '/SingleLayer_SNN.pt'
checkpoint = torch.load(f=full_path)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [9]:
#test the accuracy
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
with torch.no_grad():
    for x, label in test_data_loader:
        x = x.to(device)
        # print(x.shape)
        x = torch.reshape(x, (x.shape[0], -1))
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()
        out_fr = 0.
        for t in range(timesteps):
            encoded_img = encoder(x)
            out_fr += model(encoded_img)
        out_fr = out_fr / timesteps
        loss = F.mse_loss(out_fr, label_onehot)

        test_samples += label.numel()
        test_loss += loss.item() * label.numel()
        test_acc += (out_fr.argmax(1) == label).float().sum().item()
        functional.reset_net(model)
test_time = time.time()
# test_speed = test_samples / (test_time - train_time)
test_loss /= test_samples
test_acc /= test_samples
# writer.add_scalar('test_loss', test_loss, epoch)
# writer.add_scalar('test_acc', test_acc, epoch)
print(test_acc)

0.9236


Get Spike Counts

In [14]:
spike_monitor = monitor.OutputMonitor(model, neuron.LIFNode)

In [27]:
spike_monitor = monitor.OutputMonitor(model, neuron.LIFNode)
#test the accuracy
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
z = 0
with torch.no_grad():
    for x, label in test_data_loader:
        x = x.to(device)
        # print(x.shape)
        x = torch.reshape(x, (x.shape[0], -1))
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()
        out_fr = 0.
        for t in range(timesteps):
            encoded_img = encoder(x)
            out_fr += model(encoded_img)
        out_fr = out_fr / timesteps
        loss = F.mse_loss(out_fr, label_onehot)

        test_samples += label.numel()
        test_loss += loss.item() * label.numel()
        test_acc += (out_fr.argmax(1) == label).float().sum().item()
        functional.reset_net(model)
        # if z > 1:
        #     break
        # z += 1
test_time = time.time()
# test_speed = test_samples / (test_time - train_time)
test_loss /= test_samples
test_acc /= test_samples
# writer.add_scalar('test_loss', test_loss, epoch)
# writer.add_scalar('test_acc', test_acc, epoch)
print(test_acc)
print(test_acc)
print(model)
total_spikes = 0
for tens in spike_monitor.records:
    tensnp = tens.cpu().numpy()
    total_spikes += np.sum(tensnp)
print('total spikes: ' + str(total_spikes))
print(f'spike_seq_monitor.records=\n{len(spike_monitor.records)}')

0.9236
0.9236
SNN(
  (layer1): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
total spikes: 986453.0
spike_seq_monitor.records=
31300
