Make simple snn to classify MNIST

In [1]:
import os
import time
import argparse
import sys
import datetime
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda import amp
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from spikingjelly.activation_based import lava_exchange


from spikingjelly.activation_based import neuron, layer, functional, monitor
from spikingjelly.activation_based import surrogate, encoding

device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Define Model

In [2]:
class SNN(nn.Module):
    def __init__(self, tau):
        super().__init__()
        self.layer1 = nn.Sequential(
            layer.Linear(28*28, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
        )
    def forward(self, x: torch.Tensor):
        # print(x.shape)
        return self.layer1(x)

Create a model with a default tau value

In [3]:
tau = 2.0 #membrane time constant. 
#^ A default value in some of the spikingjelly repo files.
model = SNN(tau=tau).to(device=device)

Download the MNIST dataset and make train and test datasets

In [4]:
data_dir='./data'
train_dataset = torchvision.datasets.MNIST(
    root=data_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=data_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

In [5]:
batch_size=32
num_workers=10 #todo: see if this can be increased with more cores
train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    pin_memory=True
)

More settings

In [6]:
EPOCHS=50
AMP=True #automatic mixed precision training
lr= 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
out_dir = "./outputs/SLM"
encoder = encoding.PoissonEncoder()
timesteps = 15

Make settings for AMP and outputs

In [7]:
scaler = None
if AMP:
    scaler = amp.GradScaler()
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=0)
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))

Start training

In [8]:
functional.reset_net(model)
for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for x, label in train_data_loader:
        x = x.to(device)
        x = torch.reshape(x, (batch_size, -1))
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is None:
            out_fr = 0.
            for t in range(timesteps):
                encoded_img = encoder(x)
                out_fr += model(encoded_img)
            out_fr = out_fr / timesteps
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()
        else:
            with amp.autocast():
                out_fr = 0.
                for t in range(timesteps):
                    encoded_img = encoder(x)
                    out_fr += model(encoded_img)
                out_fr = out_fr / timesteps
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(model) #need to reset the snn before reuse
    print('epoch: ' + str(epoch) + '; loss' + str(train_loss))

epoch: 0; loss1483.0915752798319
epoch: 1; loss1093.1480142846704
epoch: 2; loss1033.6684593632817
epoch: 3; loss998.5604588165879
epoch: 4; loss981.6506809517741
epoch: 5; loss961.9164581075311
epoch: 6; loss951.537347830832
epoch: 7; loss941.334234520793
epoch: 8; loss933.3186763897538
epoch: 9; loss920.0017856657505
epoch: 10; loss921.0937893763185
epoch: 11; loss922.0360093489289
epoch: 12; loss922.6604538038373
epoch: 13; loss921.3768982067704
epoch: 14; loss921.1560067981482
epoch: 15; loss921.5457861572504
epoch: 16; loss922.295121088624
epoch: 17; loss918.2773398645222
epoch: 18; loss922.0582334920764
epoch: 19; loss920.178232267499
epoch: 20; loss922.7777863517404
epoch: 21; loss920.6004546359181
epoch: 22; loss920.7440111264586
epoch: 23; loss923.024898186326
epoch: 24; loss920.7657874710858
epoch: 25; loss919.8426771387458
epoch: 26; loss920.7262304574251
epoch: 27; loss922.6457889024168
epoch: 28; loss921.815567297861
epoch: 29; loss921.3644532114267
epoch: 30; loss923.8084

Save the SNN Model so that you do not have to retrain

In [9]:
input('are you sure you want to save the moel?')
file_dir = './Models/'
if not os.path.exists(file_dir):
    os.makedirs(file_dir)
    print(f'Mkdir {file_dir}.')
full_path = file_dir + '/SingleLayer_SNN.pt'
torch.save(model.state_dict(), f=full_path)

Load Model

In [10]:
file_dir = './Models/'
full_path = file_dir + '/SingleLayer_SNN.pt'
checkpoint = torch.load(f=full_path)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [11]:
#test the accuracy
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
with torch.no_grad():
    for x, label in test_data_loader:
        x = x.to(device)
        # print(x.shape)
        x = torch.reshape(x, (x.shape[0], -1))
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()
        out_fr = 0.
        for t in range(timesteps):
            encoded_img = encoder(x)
            out_fr += model(encoded_img)
        out_fr = out_fr / timesteps
        loss = F.mse_loss(out_fr, label_onehot)

        test_samples += label.numel()
        test_loss += loss.item() * label.numel()
        test_acc += (out_fr.argmax(1) == label).float().sum().item()
        functional.reset_net(model)
test_time = time.time()
# test_speed = test_samples / (test_time - train_time)
test_loss /= test_samples
test_acc /= test_samples
# writer.add_scalar('test_loss', test_loss, epoch)
# writer.add_scalar('test_acc', test_acc, epoch)
print(test_acc)

0.9209


Get Spike Counts

In [12]:
spike_monitor = monitor.OutputMonitor(model, neuron.LIFNode)

In [13]:
spike_monitor = monitor.OutputMonitor(model, neuron.LIFNode)
#test the accuracy
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
z = 0
with torch.no_grad():
    for x, label in test_data_loader:
        x = x.to(device)
        # print(x.shape)
        x = torch.reshape(x, (x.shape[0], -1))
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()
        out_fr = 0.
        for t in range(timesteps):
            encoded_img = encoder(x)
            out_fr += model(encoded_img)
        out_fr = out_fr / timesteps
        loss = F.mse_loss(out_fr, label_onehot)

        test_samples += label.numel()
        test_loss += loss.item() * label.numel()
        test_acc += (out_fr.argmax(1) == label).float().sum().item()
        functional.reset_net(model)
        # if z > 1:
        #     break
        # z += 1
test_time = time.time()
# test_speed = test_samples / (test_time - train_time)
test_loss /= test_samples
test_acc /= test_samples
# writer.add_scalar('test_loss', test_loss, epoch)
# writer.add_scalar('test_acc', test_acc, epoch)
print(test_acc)
print(test_samples)
print(model)
total_spikes = 0
for tens in spike_monitor.records:
    tensnp = tens.cpu().numpy()
    total_spikes += np.sum(tensnp)
print('total spikes: ' + str(total_spikes))
print(f'spike_seq_monitor.records=\n{len(spike_monitor.records)}')
print('total time steps: ' + str(timesteps*test_samples))

0.9208
0.9208
SNN(
  (layer1): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
total spikes: 143701.0
spike_seq_monitor.records=
4695
