## CRI MNIST Demonstration with Spikingjelly

## Training SNN with sikingjelly

In [1]:
# imports
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing
import shutil
from quant_layer import *
from cri_converter import BN_Folder, Quantize_Network, CRI_Converter
from torchsummary import summary

### Import MNIST datasets

In [2]:
# dataloader arguments
batch_size = 1
data_path='~/justinData/mnist'
out_dir = 'runs/transformers'
epochs = 25
start_epoch = 0
lr = 0.1
momentum = 0.9
T = 4
channels = 8
max_test_acc = -1

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Prepare the dataset
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transforms.ToTensor())

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

### Define the network

In [4]:
class CSNN(nn.Module):
    def __init__(self, T: int, channels: int, use_cupy=False):
        super().__init__()
        self.T = T

        self.conv_fc = nn.Sequential(
        layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
        layer.BatchNorm2d(channels),
        neuron.IFNode(surrogate_function=surrogate.ATan()),
        layer.AvgPool2d(2, 2),  # 14 * 14

        layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
        layer.BatchNorm2d(channels),
        neuron.IFNode(surrogate_function=surrogate.ATan()),
        layer.AvgPool2d(2, 2),  # 7 * 7

        layer.Flatten(),
        layer.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
        neuron.IFNode(surrogate_function=surrogate.ATan()),

        layer.Linear(channels * 4 * 4, 10, bias=False),
        neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        
        functional.set_step_mode(self, step_mode='m')
        
    def forward(self, x: torch.Tensor):
        # x.shape = [N, C, H, W]
        # self.output_shape(x)
        x_seq = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr
    
    def spiking_encoder(self):
        return self.conv_fc[0:3]
    
    def output_shape(self,x):
        for layer in self.conv_fc:
            x = layer(x)
            print(x.shape)
            

In [5]:
net = CSNN(T = T, channels = channels, use_cupy=False)

In [6]:
print(net)

CSNN(
  (conv_fc): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (2): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (3): AvgPool2d(kernel_size=2, stride=2, padding=0, step_mode=m)
    (4): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (6): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (7): AvgPool2d(kernel_size=2, stride=2, padding=0, step_mode=m)
    (8): Flatten(start_dim=1, end_dim=-1, step_mode=m)
    (9): Linear(in_features=392, out_features=128, b

In [7]:
# net.to(device)

In [8]:
# optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

In [9]:
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

In [10]:
writer = SummaryWriter(out_dir)

### Training

In [11]:
# for epoch in range(start_epoch, epochs):
#     start_time = time.time()
#     net.train()
#     train_loss = 0
#     train_acc = 0
#     train_samples = 0
#     for img, label in train_loader:
#         optimizer.zero_grad()
#         img = img.to(device)
#         label = label.to(device)
#         label_onehot = F.one_hot(label, 10).float()
#         out_fr = net(img)
#         loss = F.mse_loss(out_fr, label_onehot)
#         loss.backward()
#         optimizer.step()
        
#         train_samples += label.numel()
#         train_loss += loss.item() * label.numel()
#         train_acc += (out_fr.argmax(1) == label).float().sum().item()
        
#         functional.reset_net(net)
        
#     train_time = time.time()
#     train_speed = train_samples / (train_time - start_time)
#     train_loss /= train_samples
#     train_acc /= train_samples

#     writer.add_scalar('train_loss', train_loss, epoch)
#     writer.add_scalar('train_acc', train_acc, epoch)
#     lr_scheduler.step()
    
#     net.eval()
#     test_loss = 0
#     test_acc = 0
#     test_samples = 0

#     with torch.no_grad():
#         for img, label in test_loader:
#             img = img.to(device)
#             label = label.to(device)
#             label_onehot = F.one_hot(label, 10).float()
#             out_fr = net(img)
#             loss = F.mse_loss(out_fr, label_onehot)

#             test_samples += label.numel()
#             test_loss += loss.item() * label.numel()
#             test_acc += (out_fr.argmax(1) == label).float().sum().item()
#             functional.reset_net(net)
#         test_time = time.time()
#         test_speed = test_samples / (test_time - train_time)
#         test_loss /= test_samples
#         test_acc /= test_samples
#         writer.add_scalar('test_loss', test_loss, epoch)
#         writer.add_scalar('test_acc', test_acc, epoch)
    
#     save_max = False
#     if test_acc > max_test_acc:
#         max_test_acc = test_acc
#         save_max = True
            
#     checkpoint = {
#         'net': net.state_dict(),
#         'optimizer': optimizer.state_dict(),
#         'lr_scheduler': lr_scheduler.state_dict(),
#         'epoch': epoch,
#         'max_test_acc': max_test_acc
#     }

#     if save_max:
#         torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

#     torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))
    
#     print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
#     print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
#     print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')


    

### Save Checkpoint for Model

In [12]:
def save_checkpoint(state, is_quan, fdir, num_layer):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_quan:
        shutil.copyfile(filepath, os.path.join(fdir, f'model_spikingjelly_quan_{num_layer}.pth.tar'))
    else:
        shutil.copyfile(filepath, os.path.join(fdir, f'model_spikingjelly_{num_layer}.pth.tar'))

In [13]:
if not os.path.exists('result'):
    os.makedirs('result')
fdir = 'result/'
if not os.path.exists(fdir):
    os.makedirs(fdir)

In [14]:
# save_checkpoint({'state_dict': net.state_dict(),}, 0, fdir, len(net.state_dict()))

In [15]:
# n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
# print(f"number of params: {n_parameters}")

### Validate

In [16]:
def validate(net, test_loader, device):
    start_time = time.time()
    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            label_onehot = F.one_hot(label, 10).float()
            out_fr = net(img)
            print(out_fr.shape)
            return 
            loss = F.mse_loss(out_fr, label_onehot)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - start_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss)
        writer.add_scalar('test_acc', test_acc)
    
    print(f'test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}')
    print(f'test speed ={test_speed: .4f} images/s')


In [17]:
validate(net, test_loader, device)

torch.Size([1, 10])


In [18]:
n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)

In [19]:
net.conv_fc[0].weight.shape

torch.Size([8, 1, 3, 3])

In [20]:
for p in net.parameters():
    if p.requires_grad:
        print(p.numel())

72
8
8
576
8
8
50176
1280


### Load Saved Model

In [21]:
best_model_path = '/home/keli/code/CRI_Mapping/result/model_spikingjelly_14_s.pth.tar'
checkpoint = torch.load(best_model_path, map_location=device)
net_1 = CSNN(T = T, channels = channels, use_cupy=False)
net_1.load_state_dict(checkpoint['state_dict'])
net_1.eval()
net_1.to(device)


CSNN(
  (conv_fc): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (2): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (3): AvgPool2d(kernel_size=2, stride=2, padding=0, step_mode=m)
    (4): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (6): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (7): AvgPool2d(kernel_size=2, stride=2, padding=0, step_mode=m)
    (8): Flatten(start_dim=1, end_dim=-1, step_mode=m)
    (9): Linear(in_features=392, out_features=128, b

In [22]:
validate(net_1, test_loader, device)

torch.Size([1, 10])


### Quantization

In [23]:
bn = BN_Folder()  #Fold the BN layer 
net_bn = bn.fold(net_1.eval())

In [24]:
# summary(net_bn)

In [25]:
validate(net_bn, test_loader, device)

torch.Size([1, 10])


In [26]:
quan_fun = Quantize_Network(dynamic_alpha = False) # weight_quantization
net_quan = quan_fun.quantize(net_bn)

Quantized:  conv_fc
Quantized:  0
Quantized:  1
Quantized:  2
Quantized:  3
Quantized:  4
Quantized:  5
Quantized:  6
Quantized:  7
Quantized:  8
Quantized:  9
Quantized:  10
Quantized:  11
Quantized:  12
Quantization time: 0.0020759105682373047
Quantization time: 0.0036220550537109375


In [27]:
# summary(net_quan)

In [28]:
validate(net_quan, test_loader, device)

torch.Size([1, 10])


In [29]:
cri_convert = CRI_Converter(4, 0, 11, np.array((1, 28, 28)),'spikingjelly') # num_steps, input_layer, output_layer, input_size
cri_convert.layer_converter(net_quan)

Number of layers in net:  1
Number of layers in net:  13
Constructing Axons from Conv2d Layer
Input layer shape(infeature, outfeature): [ 1 28 28] [ 8 28 28]


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.59it/s]


Constructing 8 bias axons for input layer.
Numer of neurons: 0, number of axons: 792
Converting Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=m) takes 0.0696554183959961
Constructing hidden avgpool layer
Hidden layer shape(infeature, outfeature): (8, 28, 28) [ 8 14 14]
Neuron_offset: 6272
Last output: 7839


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 708.72it/s]


Numer of neurons: 6272, number of axons: 792
Constructing Neurons from Conv2d Layer
Hidden layer shape(infeature, outfeature): (8, 14, 14) [ 8 14 14]
Neuron_offset: 7840
Last output: ['9394' '9395' '9396' '9397' '9398' '9399' '9400' '9401' '9402' '9403'
 '9404' '9405' '9406' '9407']


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 77.28it/s]


Constructing 8 bias axons for input layer.
Numer of neurons: 7840, number of axons: 800
Converting Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=m) takes 0.10711503028869629
Constructing hidden avgpool layer
Hidden layer shape(infeature, outfeature): (8, 14, 14) [8 7 7]
Neuron_offset: 9408
Last output: 9799


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2834.23it/s]

Numer of neurons: 9408, number of axons: 800
Constructing Neurons from Linear Layer
Hidden layer shape(infeature, outfeature):  392 128
Neuron_offset: 9800
Last output: 9927
curr_neuron_offset, next_neuron_offset: (9408, 9800)





Numer of neurons: 9800, number of axons: 800
Constructing Neurons from Linear Layer
Hidden layer shape(infeature, outfeature):  128 10
Neuron_offset: 9928
Last output: 9937
curr_neuron_offset, next_neuron_offset: (9800, 9928)
Instantiate output neurons
Numer of neurons: 9938, number of axons: 800


In [30]:
cri_convert.output_neurons

['9928',
 '9929',
 '9930',
 '9931',
 '9932',
 '9933',
 '9934',
 '9935',
 '9936',
 '9937']

In [31]:
# cri_convert.axon_dict

In [32]:
len(cri_convert.neuron_dict)

9938

In [33]:
print(len(cri_convert.neuron_dict['6272']))

8
