In [39]:
!pip install -e ./snntorch

Obtaining file:///root/snntorch
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Installing collected packages: snntorch
  Attempting uninstall: snntorch
    Found existing installation: snntorch 0.7.0
    Uninstalling snntorch-0.7.0:
      Successfully uninstalled snntorch-0.7.0
  Running setup.py develop for snntorch
Successfully installed snntorch
[0m

In [40]:
import torch
from torch import nn
import snntorch as snn
from snntorch import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Cap this process at 85% of GPU 0's memory. The CUDA allocator call fails
# without a GPU, so only apply it when we actually selected CUDA.
if device.type == "cuda":
    torch.cuda.set_per_process_memory_fraction(0.85, device=0)



import numpy as np

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import namedtuple

# Container returned by `shuffle`: batched observations and their labels.
State = namedtuple("State", ["obs", "labels"])



In [3]:
sensor_size = tonic.datasets.NMNIST.sensor_size
DATA_ROOT = './tmp/data'

# Bin each sample's event stream into 64 dense time frames, then store the
# per-bin event counts as uint8 to keep the in-memory dataset small.
# NOTE: no denoising is applied here; prepend e.g.
# transforms.Denoise(filter_time=...) to the Compose list if that is wanted.
# Counts are clipped to 255 before the cast so a pixel with many events in one
# bin saturates instead of silently wrapping around (np.array(..., uint8)
# on an integer array wraps modulo 256).
frame_transform = transforms.Compose([
    transforms.ToFrame(sensor_size=sensor_size, n_time_bins=64),
    lambda x: np.clip(x, 0, 255).astype(np.uint8),
])

train_dataset = tonic.datasets.NMNIST(save_to=DATA_ROOT, transform=frame_transform, train=True)
test_dataset = tonic.datasets.NMNIST(save_to=DATA_ROOT, transform=frame_transform, train=False)

In [4]:
# Pull a random half of the training set as ONE padded batch and keep it
# resident on `device`; the benchmark below slices minibatches from it.
train_dl = iter(
    DataLoader(
        train_dataset,
        batch_size=len(train_dataset) // 2,
        shuffle=True,
        drop_last=True,
        collate_fn=tonic.collation.PadTensors(batch_first=True),
    )
)

x_train, y_train = next(train_dl)
# Cast on the host first (smaller transfer), then move to the device.
x_train = x_train.to(torch.uint8).to(device)
y_train = y_train.to(torch.uint8).to(device)

In [92]:

def shuffle(dataset, batch_size):
    """Randomly permute ``dataset`` and reshape it into full minibatches.

    Args:
        dataset: ``(x, y)`` pair of tensors sharing the same leading
            (sample) dimension.
        batch_size: number of samples per minibatch; must be positive.

    Returns:
        ``State(obs, labels)`` where ``obs`` has shape
        ``(num_batches, batch_size, *x.shape[1:])`` and ``labels`` has shape
        ``(num_batches, batch_size)``, with ``num_batches = len(y) // batch_size``.
        Samples that do not fill a final batch are dropped; if
        ``batch_size > len(y)`` both tensors have a leading dimension of 0.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    x, y = dataset
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = y.shape[0]
    num_batches = num_samples // batch_size
    keep = num_batches * batch_size

    # Explicit stop index: the previous `[:-cutoff]` became `[:0]` (empty)
    # whenever num_samples was an exact multiple of batch_size.
    indices = torch.randperm(num_samples)[:keep]
    obs, labels = x[indices], y[indices]

    # Use the explicit batch count rather than -1: reshaping 0 elements with
    # an inferred dimension is ambiguous and raises in torch.
    obs = torch.reshape(obs, (num_batches, batch_size) + tuple(obs.shape[1:]))
    labels = torch.reshape(labels, (num_batches, batch_size))

    return State(obs=obs, labels=labels)
    

In [93]:
#test_dl = iter(DataLoader(test_dataset, batch_size=len(test_dataset),
#                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
#        
#x_test, y_test = next(test_dl)
#x_test, y_test = x_test.to(torch.uint8), y_test.to(torch.uint8)
#x_test, y_test = x_test.to(device), y_test.to(device)
#x_test, y_test = shuffle((x_test, y_test))

In [94]:
#  Initialize Network

class Net(torch.nn.Module):
    """Small convolutional spiking network (conv-LIF-pool x2, linear-LIF).

    The layer sizes assume 2-channel 34x34 input frames: two 5x5 valid
    convolutions each followed by 2x2 max-pooling reduce 34 -> 30 -> 15 -> 11 -> 5,
    which is where the ``5*5`` in the Linear layer comes from.

    Args:
        channel_multiplier: scales the number of conv channels (default 1).
    """

    def __init__(self, channel_multiplier=1):
        super().__init__()

        mult = channel_multiplier
        # NOTE: each Leaky's beta is applied elementwise to the membrane
        # potential of shape (B, C, H, W), so a 1-D beta broadcasts along the
        # LAST (width) axis. Its length must therefore equal the feature-map
        # width (30 after conv1, 11 after conv2) and must NOT scale with
        # `mult`, which changes only the channel count. (The previous
        # `30*mult` / `11*mult` broke every multiplier other than 1.)
        # If one decay per channel is intended instead, use shape (C, 1, 1).
        self._net = nn.Sequential(
            nn.Conv2d(2, 12 * mult, 5, bias=False),
            snn.Leaky(beta=torch.ones(30) * 0.5, learn_beta=True, init_hidden=True),
            nn.MaxPool2d(2),

            nn.Conv2d(12 * mult, 32 * mult, 5, bias=False),
            snn.Leaky(beta=torch.ones(11) * 0.5, learn_beta=True, init_hidden=True),
            nn.MaxPool2d(2),

            nn.Flatten(),
            nn.Linear(32 * mult * 5 * 5, 10, bias=False),
            snn.Leaky(beta=torch.ones(10) * 0.5, learn_beta=True, init_hidden=True, output=True),
        ).to(device)

    def forward(self, data):
        """Run the network over every time step and return the output spikes.

        Args:
            data: batch-major event frames; dims 0 and 1 are swapped so the
                loop iterates over dim 1 of the input (presumably
                (B, T, 2, H, W) — TODO confirm against the loader).

        Returns:
            Tensor of output-layer spikes stacked over time, shape (T, B, 10).
        """
        data = data.permute(1, 0, 2, 3, 4).to(torch.float32)
        spk_rec = []
        utils.reset(self._net)  # resets hidden states for all LIF neurons in net

        for step in range(data.size(0)):  # data.size(0) = number of time steps
            spk_out, _ = self._net(data[step])  # final Leaky has output=True -> (spk, mem)
            spk_rec.append(spk_out)

        return torch.stack(spk_rec)
    
#net = torch.compile(_net, fullgraph=True, mode="reduce-overhead")

In [95]:
# Stand-alone training objects for interactive experimentation.
# `net` is built here explicitly: previously this cell relied on a `net`
# left over in the kernel (its only definition is commented out), so it
# raised NameError on a fresh Restart & Run All.
net = Net(channel_multiplier=1)
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)


def acc(predictions, targets):
    """Fraction of rows whose argmax over the last axis equals ``targets``."""
    return (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)

In [99]:
def benchmark(net, loss, num_epochs, batch_size, dataset=None, lr=5e-4, verbose=True):
    """Train ``net`` for ``num_epochs`` epochs and return the elapsed time.

    A fresh Adam optimizer is created on every call, so repeated calls keep
    training the same weights but start with reset optimizer state.

    Args:
        net: module whose output is summed over axis 0 before the loss
            (the spiking ``Net`` returns spikes stacked over time).
        loss: criterion called as ``loss(outputs, int64_targets)``.
        num_epochs: number of passes over the reshuffled training data.
        batch_size: minibatch size; samples that don't fill a final batch
            are dropped by ``shuffle``.
        dataset: optional ``(obs, labels)`` pair; defaults to the global
            ``(x_train, y_train)``.
        lr: Adam learning rate.
        verbose: print the epoch index at the start of each epoch.

    Returns:
        Elapsed wall-clock seconds (float), measured with a monotonic clock.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-benchmark.
    from time import perf_counter

    if dataset is None:
        dataset = (x_train, y_train)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss_hist = []  # last-minibatch loss of each epoch

    net.train()  # hoisted out of the minibatch loop: the mode never changes
    start = perf_counter()
    for epoch in range(num_epochs):
        if verbose:
            print(epoch)
        train_data, train_targets = shuffle(dataset, batch_size)

        loss_val = None  # stays None if this epoch has no full minibatch
        for data, batch_targets in zip(train_data, train_targets):
            # forward pass; sum the output over time (axis 0) for the loss
            out_V = net(data)
            loss_val = loss(torch.sum(out_V, axis=0), batch_targets.to(torch.int64))

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

        if loss_val is not None:
            # .item() copies to host, which also waits for queued GPU work.
            loss_hist.append(loss_val.item())

    # Make sure no asynchronous CUDA work is still pending when we stop the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return perf_counter() - start


In [97]:
from time import time

def run_bench(trials, num_epochs, batch_size, mult=1, comp=False):
    """Time ``benchmark`` over several trials after a one-epoch warm-up.

    Args:
        trials: number of timed repetitions.
        num_epochs: epochs per timed repetition.
        batch_size: minibatch size passed to ``benchmark``.
        mult: channel multiplier for ``Net``. Defaults to 1 so that calls
            like ``run_bench(3, 20, 31)`` work (``mult`` used to be required,
            which made that call raise TypeError).
        comp: wrap both the network and the loss with ``torch.compile``.

    Returns:
        List of per-trial elapsed times in seconds.
    """
    loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
    net = Net(mult).to(device)

    if comp:
        net = torch.compile(net, fullgraph=True, mode="reduce-overhead")
        loss = torch.compile(loss, fullgraph=True, mode="reduce-overhead")

    # Warm-up epoch: triggers compilation when comp=True and absorbs other
    # one-time setup costs; it is excluded from the statistics below.
    print("starting warmup")
    comp_start = time()
    benchmark(net, loss, 1, batch_size)

    print("Warmup compilation finished:", time() - comp_start)

    times = []
    for t in range(trials):
        times.append(benchmark(net, loss, num_epochs, batch_size))
        print(t, ":", times[t])

    if times:  # np.mean/np.std of an empty list would warn and give nan
        # np.std is the population std (ddof=0)
        print("Mean:", np.mean(times), "Std. Dev.:", np.std(times))
    return times
    

In [98]:
# 3 timed trials x 20 epochs, batch size 31, channel multiplier 1, eager mode.
run_bench(3, 20, 31, mult=1)

starting warmup
0
Warmup compilation finished: 43.07494878768921
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0 : 860.8501217365265
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1 : 860.0979428291321
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2 : 861.8277435302734
Mean: 860.9252693653107 Std. Dev.: 0.7081845291547473
