In [1]:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from spikingjelly.activation_based import encoding, monitor
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
from spikingjelly.activation_based import lava_exchange

import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing
import numpy as np

device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Define the model

In [2]:
class SCNN(nn.Module):
    def __init__(self, T: int):
        super().__init__()
        self.T = T
        self.encoder = encoding.PoissonEncoder()
        self.conv_and_fc = nn.Sequential(
            layer.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2, bias=False),
            layer.BatchNorm2d(6),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Conv2d(in_channels=6, out_channels=12, kernel_size=5, bias=False),
            layer.BatchNorm2d(12),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(kernel_size=2, stride=2),

            layer.Flatten(),
            layer.Linear(5*5*12, 120, bias=False),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.Linear(120, 84),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.Linear(84, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
            )
        functional.set_step_mode(self, step_mode='m')
    def forward(self, x):
        original_shape = x.shape
        x_rep = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1) #probs between 0 and 1
        # print(torch.max(x_rep))
        # print(torch.mean(x_rep))
        # print(x_rep.shape)
        x_seq = self.encoder(x_rep) #use probs to generate random 0's and 1's
        # print(x_seq.shape)
        z = self.conv_and_fc(x_seq)
        fr = z.mean(0)
        return fr
    # def encode_img(self, img):
    #     return self.encoder(img)


Set Variables

In [3]:
# tau = 2.0
timesteps = 10
model = SCNN(T=timesteps).to(device=device)
EPOCHS=50 #set to 50 epochs bc of diminishing returns
AMP=True #automatic mixed precision training
lr= 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
out_dir = "./outputs/CSNN"
encoder = encoding.PoissonEncoder()
batch_size=64
num_workers = 10


Download the sets

In [4]:
root = './FMNIST'
train_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

# train_loader = torch.utils.data.DataLoader(
#     dataset=train_set,
#     batch_size=batch_size,
#     shuffle=True
# )

test_set = torchvision.datasets.FashionMNIST(
    root=root,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

# test_loader = torch.utils.data.DataLoader(
#     dataset=test_set,
#     batch_size=batch_size,
#     shuffle=False
# )

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers,
    pin_memory=True
)

Train the CSNN

In [5]:
scaler = None
if AMP:
    scaler = amp.GradScaler()
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=0)
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))
file_dir = './Models/'
if not os.path.exists(file_dir):
    os.makedirs(file_dir)
    print(f'Mkdir {file_dir}.')
full_path = file_dir + '/CSNN.pt'

In [103]:
functional.reset_net(model)
# check = input('running this will remove the old saved model')
best_loss = 10000000.0
for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for x, label in train_loader:
        optimizer.zero_grad()
        x = x.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is None:
            out_fr = 0.
            # for t in range(timesteps):
            # encoded_img = encoder(x)
            #     out_fr += model(encoded_img)
            # out_fr = out_fr / timesteps
            out_fr = model(x)
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()
        else:
            with amp.autocast():
                # encoded_img = encoder(x)
                #encoding inside the model instead
                out_fr = model(x)
                # out_fr = out_fr / timesteps
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        # print(out_fr.shape)
        # print(label.shape)
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(model) #need to reset the snn before reuse
    if train_loss < best_loss:
        torch.save(model.state_dict(), f=full_path)
        best_loss = train_loss
        print('new best model saved')
    print('epoch: ' + str(epoch) + '; loss' + str(train_loss))
    

new best model saved
epoch: 0; loss1976.0430250763893
new best model saved
epoch: 1; loss1326.1370067596436
new best model saved
epoch: 2; loss1213.9310052394867
new best model saved
epoch: 3; loss1153.8910070955753
new best model saved
epoch: 4; loss1103.1390072703362
new best model saved
epoch: 5; loss1079.2440074384212
new best model saved
epoch: 6; loss1058.4520105421543
new best model saved
epoch: 7; loss1031.2150076329708
new best model saved
epoch: 8; loss1013.7670096158981
new best model saved
epoch: 9; loss1000.6220080852509
new best model saved
epoch: 10; loss996.7340102791786
new best model saved
epoch: 11; loss978.2880076169968
new best model saved
epoch: 12; loss961.4770079702139
new best model saved
epoch: 13; loss952.4600071907043
new best model saved
epoch: 14; loss944.6530073583126
new best model saved
epoch: 15; loss938.0820101946592
new best model saved
epoch: 16; loss933.2530103325844
new best model saved
epoch: 17; loss915.2670090794563
new best model saved
epoch: 

Load the best model

In [7]:
file_dir = './Models/'
full_path = file_dir + '/CSNN.pt'
checkpoint = torch.load(f=full_path)
model.load_state_dict(checkpoint)

<All keys matched successfully>

Test the Accuracy

In [8]:
start_time = time.time()
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
for x, label in test_loader:
    optimizer.zero_grad()
    x = x.to(device)
    label = label.to(device)
    label_onehot = F.one_hot(label, 10).float()
    out_fr = model(x)
    loss = F.mse_loss(out_fr, label_onehot)
    
    test_samples += label.numel()
    test_loss += loss.item() * label.numel()
    test_acc += (out_fr.argmax(1) == label).float().sum().item()

    functional.reset_net(model) #need to reset the snn before reuse
test_acc = test_acc/test_samples
print('acc: ' + str(test_acc) + '; loss: ' + str(test_loss))

acc: 0.8885; loss: 166.0440024882555


Count the number of Spikes for power

In [9]:
spike_monitor = monitor.OutputMonitor(model, neuron.LIFNode)

start_time = time.time()
model.eval()
test_loss = 0
test_acc = 0
test_samples = 0
for x, label in test_loader:
    optimizer.zero_grad()
    x = x.to(device)
    label = label.to(device)
    label_onehot = F.one_hot(label, 10).float()
    out_fr = model(x)
    loss = F.mse_loss(out_fr, label_onehot)
    
    test_samples += label.numel()
    test_loss += loss.item() * label.numel()
    test_acc += (out_fr.argmax(1) == label).float().sum().item()

    functional.reset_net(model) #need to reset the snn before reuse
test_acc = test_acc/test_samples
print('acc: ' + str(test_acc) + '; loss: ' + str(test_loss))
print(model)
print(test_samples)
total_spikes = 0
for tens in spike_monitor.records:
    tensnp = tens.cpu().numpy()
    total_spikes += np.sum(tensnp) #outputs are just 0's and 1's. summing up will get the total number of spikes of all neurons
total_steps = timesteps * test_samples
print('total spikes: ' + str(total_spikes))
print('timesteps taken: ' + str(total_steps))
# print(f'spike_seq_monitor.records=\n{len(spike_monitor.records)}')

acc: 0.8881; loss: 167.79400219768286
SCNN(
  (encoder): PoissonEncoder()
  (conv_and_fc): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False, step_mode=m)
    (1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (2): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (4): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1), bias=False, step_mode=m)
    (5): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (6): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, c

Calculate Power

In [107]:
neurons = 26298
Pi = 0.25
Pb = 4/1000
Pn = 0.0234 #watts (J per second)
Ps = 11.3 * (10**-9)

power = Pi + Pb + (neurons * Pn) + (total_spikes * Ps)
print('total power (watts): ' + str(power))

total power (watts): 616.2277274147


In [108]:
print(spike_monitor.records[0].shape)
print(len(spike_monitor.records))
print(f'spike_seq_monitor.monitored_layers={spike_monitor.monitored_layers}')
for p in model.parameters():
    print(p.shape)
    print(len(p))

torch.Size([10, 64, 6, 28, 28])
471
spike_seq_monitor.monitored_layers=['conv_and_fc.2', 'conv_and_fc.6', 'conv_and_fc.10']
torch.Size([6, 1, 5, 5])
6
torch.Size([6])
6
torch.Size([6])
6
torch.Size([12, 6, 5, 5])
12
torch.Size([12])
12
torch.Size([12])
12
torch.Size([120, 300])
120
torch.Size([84, 120])
84
torch.Size([84])
84
torch.Size([10, 84])
10
torch.Size([10])
10
