Build a simple spiking neural network (SNN) to classify MNIST

In [1]:
import os
import time
import argparse
import sys
import datetime
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda import amp
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from spikingjelly.activation_based import neuron, layer, functional, monitor
from spikingjelly.activation_based import surrogate, encoding

# Use the GPU when PyTorch reports one as available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


Define Model

In [2]:
class SNN(nn.Module):
    """Single-layer spiking classifier: a bias-free linear layer feeding 10 LIF neurons.

    Args:
        tau: membrane time constant forwarded to ``neuron.LIFNode``.
    """

    def __init__(self, tau):
        super().__init__()
        self.layer1 = nn.Sequential(
            layer.Linear(28*28, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
        )

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through the linear + LIF stack and return the LIF output.

        The linear layer expects 28*28 = 784 features in the last dimension, so
        everything after the batch dimension is flattened first. Image batches
        (e.g. ``(batch, 1, 28, 28)``) would otherwise fail with a shape
        mismatch; inputs that are already ``(batch, 784)`` are left unchanged.
        """
        return self.layer1(x.flatten(start_dim=1))

Create a model with a default tau value

In [3]:
# Membrane time constant of the LIF neurons.
# 2.0 is a default value used in some of the spikingjelly repo files.
tau = 2.0
model = SNN(tau=tau)
model = model.to(device=device)

Download the MNIST dataset and make train and test datasets

In [4]:
data_dir = './data'

# Build the train and test splits with identical settings; only `train` differs.
# Each split gets its own ToTensor transform and is downloaded into data_dir.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root=data_dir,
        train=is_train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)

In [5]:
batch_size = 32
num_workers = 4  # todo: see if this can be increased with more cores

# Training loader: reshuffled every epoch, incomplete final batch dropped.
train_data_loader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    shuffle=True,
    drop_last=True,
)

# Test loader: fixed order, every sample kept.
test_data_loader = data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    shuffle=False,
    drop_last=False,
)

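Training settings: epochs, mixed precision, optimizer, output directory, input encoder, and number of simulation timesteps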

In [6]:
# Training schedule.
EPOCHS = 100
timesteps = 100  # encoder + model passes accumulated per batch
lr = 0.001
AMP = True  # automatic mixed precision training

# Where the TensorBoard logs and args.txt are written.
out_dir = "./outputs/SLM"

# A fresh Poisson-encoded input is drawn from each image at every timestep.
encoder = encoding.PoissonEncoder()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

Set up the AMP gradient scaler, the output directory, and the TensorBoard writer

In [7]:
# A gradient scaler exists only when mixed precision is switched on;
# the training loop checks `scaler is None` to pick its code path.
scaler = amp.GradScaler() if AMP else None

# Create the output directory on first use and say so.
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=0)

# Record the process command line next to the logs (a leading newline, then argv).
with open(os.path.join(out_dir, 'args.txt'), mode='w', encoding='utf-8') as args_txt:
    args_txt.write('\n' + ' '.join(sys.argv))

Start training

In [8]:
# Train for EPOCHS epochs. For every batch the images are Poisson-encoded
# `timesteps` times, the model outputs are averaged into a firing rate, and the
# MSE between that rate and the one-hot label is minimised.
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for x, label in train_data_loader:
        optimizer.zero_grad()
        x = x.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        # autocast(enabled=False) is a no-op, so a single forward path serves
        # both the mixed-precision and the full-precision configuration.
        with amp.autocast(enabled=scaler is not None):
            out_fr = 0.
            for t in range(timesteps):
                encoded_img = encoder(x)
                out_fr += model(encoded_img)
            out_fr = out_fr / timesteps
            loss = F.mse_loss(out_fr, label_onehot)

        if scaler is None:
            loss.backward()
            optimizer.step()
        else:
            # Scale the loss before backward; the scaler then steps the
            # optimizer and updates its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(model) #need to reset the snn before reuse

    # Printed value is the loss summed over all samples of the epoch (unchanged).
    print('epoch: ' + str(epoch) + '; loss' + str(train_loss))

    # Log per-sample metrics so they are comparable with the test metrics,
    # which are also divided by the sample count.
    if train_samples:
        writer.add_scalar('train_loss', train_loss / train_samples, epoch)
        writer.add_scalar('train_acc', train_acc / train_samples, epoch)

epoch: 0; loss1473.2269598692656
epoch: 1; loss1075.422732412815
epoch: 2; loss1011.1290106400847
epoch: 3; loss976.7880894318223
epoch: 4; loss954.6542589962482
epoch: 5; loss939.7806394323707
epoch: 6; loss925.6065278202295
epoch: 7; loss915.5676494017243
epoch: 8; loss906.2968903928995
epoch: 9; loss900.0954683721066
epoch: 10; loss892.9173497408628
epoch: 11; loss887.5273577496409
epoch: 12; loss881.0368397347629
epoch: 13; loss877.2217765562236
epoch: 14; loss872.8955892696977
epoch: 15; loss869.4887975566089
epoch: 16; loss866.014687821269
epoch: 17; loss863.2271680906415
epoch: 18; loss861.9286185540259
epoch: 19; loss858.624628201127
epoch: 20; loss854.8929985538125
epoch: 21; loss854.0605385005474
epoch: 22; loss851.3097783252597
epoch: 23; loss850.2689578793943
epoch: 24; loss846.8791446890682
epoch: 25; loss844.2525178343058
epoch: 26; loss843.1649576984346
epoch: 27; loss843.2444987073541
epoch: 28; loss840.7092669159174
epoch: 29; loss839.7281689345837
epoch: 30; loss838.1

In [10]:
# Evaluate the trained model on the test split: average the model output over
# `timesteps` Poisson-encoded passes and compare the arg-max with the label.
model.eval()
test_samples = 0
test_loss = 0
test_acc = 0
with torch.no_grad():
    for x, label in test_data_loader:
        x, label = x.to(device), label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        out_fr = 0.
        for t in range(timesteps):
            encoded_img = encoder(x)
            out_fr += model(encoded_img)
        out_fr = out_fr / timesteps
        loss = F.mse_loss(out_fr, label_onehot)

        # Weight the batch-mean loss by the batch size so the final division
        # yields a per-sample average.
        test_samples += label.numel()
        test_loss += loss.item() * label.numel()
        test_acc += (out_fr.argmax(dim=1) == label).float().sum().item()

        # Clear neuron state before the next batch.
        functional.reset_net(model)
test_time = time.time()
# NOTE: a test-speed figure (samples per second) is not computed here.
test_loss = test_loss / test_samples
test_acc = test_acc / test_samples
# `epoch` is the loop variable left over from the training cell.
writer.add_scalar('test_loss', test_loss, epoch)
writer.add_scalar('test_acc', test_acc, epoch)
print(test_acc)

0.9283


Get Spike Counts

In [None]:
# Count the output spikes on the test set.
# Besides the accuracy, this accumulates spike_counts[c, k]: the total number of
# spikes emitted by output neuron k over all test images whose true label is c.
# Assumes the model returns 0/1 spikes at every timestep (the LIF layer's output)
# -- TODO confirm against the spikingjelly version in use.
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
num_classes = 10
spike_counts = torch.zeros(num_classes, num_classes, device=device)
with torch.no_grad():
    for x, label in test_data_loader:
        x = x.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, num_classes).float()

        # Sum of the per-timestep outputs = spikes per (sample, output neuron).
        out_spikes = 0.
        for t in range(timesteps):
            encoded_img = encoder(x)
            out_spikes += model(encoded_img)

        # Add each sample's spike vector to the row of its true label.
        spike_counts.index_add_(0, label, out_spikes.float())

        out_fr = out_spikes / timesteps
        loss = F.mse_loss(out_fr, label_onehot)

        test_samples += label.numel()
        test_loss += loss.item() * label.numel()
        test_acc += (out_fr.argmax(1) == label).float().sum().item()
        functional.reset_net(model)
test_loss /= test_samples
test_acc /= test_samples
# The test metrics are not written to TensorBoard again here: the accuracy cell
# already logs them, and a second write would duplicate the same step.
print(test_acc)
# Rows: true label, columns: output neuron.
spike_counts.cpu()