In [1]:
!pip install spikingjelly cupy-cuda12x

[0m

In [2]:
import torch
import spikingjelly

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_per_process_memory_fraction(0.9, device=0)

import numpy as np

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import namedtuple

State = namedtuple("State", "obs labels")



## Dataloading


In [3]:
sensor_size = tonic.datasets.NMNIST.sensor_size

# Denoise removes isolated, one-off events
# time_window
frame_transform = transforms.Compose([
                                      transforms.ToFrame(sensor_size=sensor_size, 
                                                         n_time_bins=64),
                                      lambda x : np.array(x, dtype=np.uint8)
                                     ])

train_dataset = tonic.datasets.NMNIST(save_to='./tmp/data', transform=frame_transform, train=True)

In [4]:
batch_size = 32

def shuffle(dataset, batch_size):
    x, y = dataset

    cutoff = y.shape[0] % batch_size

    indices = torch.randperm(y.shape[0])[:-cutoff]
    obs, labels = x[indices], y[indices]


    obs = torch.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = torch.reshape(labels, (-1, batch_size)) # should make batch size a global

    return State(obs=obs, labels=labels)

In [5]:
train_dl = iter(DataLoader(train_dataset, batch_size=len(train_dataset)//2,
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_train, y_train = next(train_dl)
x_train, y_train = x_train.to(torch.uint8), y_train.to(torch.uint8)
x_train, y_train = x_train.to(device), y_train.to(device)

### SpikingJelly Setup and Training

In [26]:
from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torch.nn as nn


class CSNN(nn.Module):
    def __init__(self, channel_multiplier: int=1, use_cupy=False):
        super().__init__()
        self.cupy = use_cupy
        mult = channel_multiplier

        self.conv_fc = nn.Sequential(
            layer.Conv2d(2, 12*mult, kernel_size=5, bias=False),
            neuron.ParametricLIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  
    
            layer.Conv2d(12*mult, 32*mult, kernel_size=5, bias=False),
            neuron.ParametricLIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2), 
    
            layer.Flatten(),
            layer.Linear(32*mult * 5 * 5, 10, bias=False),
            neuron.ParametricLIFNode(surrogate_function=surrogate.ATan()),
        )

        functional.set_step_mode(self, step_mode='m')
        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        # x.shape = [N, T, C, H, W]
        x = torch.swapaxes(x.to(torch.float32), 0, 1)
        # [N, T, C, H, W] -> [T, N, C, H, W]
        if self.cupy:
            return self.conv_fc(x)
        return torch.stack([self.conv_fc(t) for t in x])
        

In [29]:
net = CSNN(2, True).to(device)


In [30]:
x_train.shape
# [# batches, batch_size, time_bins, channels, x, y]

torch.Size([30000, 64, 2, 34, 34])

In [31]:
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)

In [32]:
def benchmark(net, loss, num_epochs, batch_size):
    
    acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

    
    loss_hist = []    
    # Outer training loop
    start = time()
    for epoch in range(num_epochs):    
        print(epoch)
        train_batch = shuffle((x_train, y_train), batch_size)
        train_data, targets = train_batch
        
        
        # Minibatch training loop
        for data, targets in zip(train_data, targets):
    
            # forward pass
            net.train()
            out_V = net(data)
            # initialize the loss & sum over time
            loss_val = loss(torch.sum(out_V, axis=0), targets.to(torch.int64))
    
            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            functional.reset_net(net)
    
        # Store loss history for future plotting
        loss_hist.append(loss_val.item())
    
    end = time() - start
    
    return end


In [33]:
from time import time

def run_bench(trials, num_epochs, batch_size, mult, comp=False):
    
    loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
    net = CSNN(mult, comp).to(device)

    print("starting warmup")
    comp_start = time()
    benchmark(net, loss, 1, batch_size)
    
    print("Warmup compilation finished:", time() - comp_start)
    
    times = []
    for t in range(trials):
        times.append(benchmark(net, loss, num_epochs, batch_size))
        print(t, ":", times[t])
    
    print("Mean:", np.mean(times), "Std. Dev.:", np.std(times))
    

In [15]:
 run_bench(5, 20, 32, 1, True)

starting warmup
0
Warmup compilation finished: 6.32789158821106
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 : 128.57680249214172
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 : 129.72782492637634
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 : 129.63487672805786
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
3 : 129.54661536216736
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
4 : 129.5935082435608
Mean: 129.41592555046083 Std. Dev.: 0.42378385661630535


In [23]:
 run_bench(5, 20, 64, 1, True)

starting warmup
0
Warmup compilation finished: 6.210282564163208
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 : 126.09527707099915
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 : 126.6083436012268
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 : 126.5502233505249
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
3 : 126.54639339447021
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
4 : 126.54906249046326
Mean: 126.46985998153687 Std. Dev.: 0.18872135098070214


In [24]:
 run_bench(5, 20, 128, 1, True)

starting warmup
0
Warmup compilation finished: 6.253170967102051
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 : 125.16312313079834
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 : 125.25466871261597
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 : 125.26775765419006
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
3 : 125.27563834190369
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
4 : 125.2936954498291
Mean: 125.25097665786743 Std. Dev.: 0.04570901261347594


In [34]:
 run_bench(5, 20, 32, 2, True)

starting warmup
0
Warmup compilation finished: 11.568138360977173
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 : 233.6879382133484
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 : 234.07036876678467
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 : 234.03429627418518
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
3 : 233.99937176704407
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
4 : 233.9830777645111
Mean: 233.9550105571747 Std. Dev.: 0.13687736361568467


In [35]:
 run_bench(5, 20, 64, 2, True)

starting warmup
0
Warmup compilation finished: 11.489962577819824
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 : 229.7857584953308
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 : 229.79479908943176
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 : 229.68084836006165
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
3 : 229.6432912349701
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
4 : 229.59819674491882
Mean: 229.70057878494262 Std. Dev.: 0.07782838043685052


In [36]:
 run_bench(5, 20, 128, 2, True)

starting warmup
0
Warmup compilation finished: 11.382784128189087
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 : 227.60896158218384
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 : 227.61265659332275
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 : 227.71440172195435
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
3 : 227.63836693763733
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
4 : 227.62950778007507
Mean: 227.64077892303467 Std. Dev.: 0.03835721956863768


In [14]:
net.cuda

<bound method Module.cuda of CSNN(
  (conv_fc): Sequential(
    (0): Conv2d(2, 12, kernel_size=(5, 5), stride=(1, 1), bias=False, step_mode=m)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (3): Conv2d(12, 32, kernel_size=(5, 5), stride=(1, 1), bias=False, step_mode=m)
    (4): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (6): Flatten(start_dim=1, end_dim=-1, step_mode=m)
    (7): Linear(in_features=800, out_features=10, bias=False)
    (8): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy, tau=2.0
      (