In [18]:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from spikingjelly.activation_based import encoding, monitor
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing
import numpy as np

device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Define the model

In [24]:
class SCNN(nn.Module):
    def __init__(self, T: int):
        super().__init__()
        self.T = T
        self.conv_and_fc = nn.Sequential(
            layer.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2, bias=False),
            layer.BatchNorm2d(6),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Conv2d(in_channels=6, out_channels=12, kernel_size=5, bias=False),
            layer.BatchNorm2d(12),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Flatten(),
            layer.Linear(5*5*12, 10, bias=False),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            # layer.Linear(120, 84),
            # neuron.IFNode(surrogate_function=surrogate.LeakyKReLU()),
            # layer.Linear(84, 10),
            # neuron.IFNode(surrogate_function=surrogate.ATan())
            )
        functional.set_step_mode(self, step_mode='m')
    def forward(self, x):
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        z = self.conv_and_fc(x_seq)
        fr = z.mean(0)
        return fr

Set Variables

In [25]:
# tau = 2.0
timesteps = 40
model = SCNN(T=timesteps).to(device=device)
EPOCHS=100
AMP=True #automatic mixed precision training
lr= 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
out_dir = "./outputs/CSNN"
encoder = encoding.PoissonEncoder()
batch_size=64
num_workers = 10


Download the sets

In [26]:
root = './FMNIST'
train_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

# train_loader = torch.utils.data.DataLoader(
#     dataset=train_set,
#     batch_size=batch_size,
#     shuffle=True
# )

test_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

# test_loader = torch.utils.data.DataLoader(
#     dataset=test_set,
#     batch_size=batch_size,
#     shuffle=False
# )

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers,
    pin_memory=True
)

Train the CSNN

In [27]:
scaler = None
if AMP:
    scaler = amp.GradScaler()
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=0)
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))
file_dir = './Models/'
if not os.path.exists(file_dir):
    os.makedirs(file_dir)
    print(f'Mkdir {file_dir}.')
full_path = file_dir + '/CSNN.pt'

In [28]:
functional.reset_net(model)
check = input('running this will remove the old saved model')
best_loss = 10000000.0
for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for x, label in train_loader:
        optimizer.zero_grad()
        x = x.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is None:
            out_fr = 0.
            # for t in range(timesteps):
            #     encoded_img = encoder(x)
            #     out_fr += model(encoded_img)
            # out_fr = out_fr / timesteps
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()
        else:
            with amp.autocast():

                out_fr = model(x)
                # out_fr = out_fr / timesteps
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        # print(out_fr.shape)
        # print(label.shape)
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(model) #need to reset the snn before reuse
    if train_loss < best_loss:
        torch.save(model.state_dict(), f=full_path)
        best_loss = train_loss
        print('new best model saved')
    print('epoch: ' + str(epoch) + '; loss' + str(train_loss))

new best model saved
epoch: 0; loss1880.071486890316
new best model saved
epoch: 1; loss1328.5742701888084
new best model saved
epoch: 2; loss1220.2118311524391
new best model saved
epoch: 3; loss1158.671469926834
new best model saved
epoch: 4; loss1118.4579518437386
new best model saved
epoch: 5; loss1092.347270399332
new best model saved
epoch: 6; loss1067.9730680286884
new best model saved
epoch: 7; loss1045.5461457967758
new best model saved
epoch: 8; loss1033.0673846006393
new best model saved
epoch: 9; loss1016.8942696750164
new best model saved
epoch: 10; loss1005.9954246878624
new best model saved
epoch: 11; loss993.9650643467903
new best model saved
epoch: 12; loss981.9952650368214
new best model saved
epoch: 13; loss975.672704577446
new best model saved
epoch: 14; loss970.6636627316475
new best model saved
epoch: 15; loss958.3928230702877
new best model saved
epoch: 16; loss953.7237045019865
new best model saved
epoch: 17; loss947.1482636332512
new best model saved
epoch: 18;

Load the best model

In [6]:
file_dir = './Models/'
full_path = file_dir + '/CSNN.pt'
checkpoint = torch.load(f=full_path)
model.load_state_dict(checkpoint)

<All keys matched successfully>

Test the Accuracy

In [9]:
start_time = time.time()
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
for x, label in test_loader:
    optimizer.zero_grad()
    x = x.to(device)
    label = label.to(device)
    label_onehot = F.one_hot(label, 10).float()
    out_fr = model(x)
    loss = F.mse_loss(out_fr, label_onehot)
    
    test_samples += label.numel()
    test_loss += loss.item() * label.numel()
    test_acc += (out_fr.argmax(1) == label).float().sum().item()

    functional.reset_net(model) #need to reset the snn before reuse
test_acc = test_acc/test_samples
print('acc: ' + str(test_acc) + '; loss: ' + str(test_loss))

acc: 0.6539; loss: 576.0729441046715


Count the number of Spikes for power

In [12]:
spike_monitor = monitor.OutputMonitor(model, neuron.LIFNode)
start_time = time.time()
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
for x, label in test_loader:
    optimizer.zero_grad()
    x = x.to(device)
    label = label.to(device)
    label_onehot = F.one_hot(label, 10).float()
    out_fr = model(x)
    loss = F.mse_loss(out_fr, label_onehot)
    
    test_samples += label.numel()
    test_loss += loss.item() * label.numel()
    test_acc += (out_fr.argmax(1) == label).float().sum().item()

    functional.reset_net(model) #need to reset the snn before reuse
test_acc = test_acc/test_samples
print('acc: ' + str(test_acc) + '; loss: ' + str(test_loss))
print(model)
total_spikes = 0
for tens in spike_monitor.records:
    tensnp = tens.cpu().numpy()
    total_spikes += np.sum(tensnp)
print('total spikes: ' + str(total_spikes))
print(f'spike_seq_monitor.records=\n{len(spike_monitor.records)}')

acc: 0.6539; loss: 576.0729424357414
SCNN(
  (conv_and_fc): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), step_mode=m)
    (1): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): LeakyKReLU()
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (3): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1), step_mode=m)
    (4): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): LeakyKReLU()
    )
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (6): Flatten(start_dim=1, end_dim=-1, step_mode=m)
    (7): Linear(in_features=300, out_features=10, bias=True)
    (8): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): LeakyKReLU()
    )
  )
)
total spikes: 