In [1]:
import torch
from torch import nn

import norse.torch as snn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_per_process_memory_fraction(0.85, device=0)



import numpy as np

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import namedtuple

State = namedtuple("State", "obs labels")



2024-02-25 09:41:07.503370: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-25 09:41:07.530486: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-25 09:41:07.648088: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
sensor_size = tonic.datasets.NMNIST.sensor_size

# Denoise removes isolated, one-off events
# time_window
frame_transform = transforms.Compose([
                                      transforms.ToFrame(sensor_size=sensor_size, 
                                                         n_time_bins=64),
                                      lambda x : np.array(x, dtype=np.uint8)
                                     ])

train_dataset = tonic.datasets.NMNIST(save_to='./tmp/data', transform=frame_transform, train=True)

In [3]:
train_dl = iter(DataLoader(train_dataset, batch_size=len(train_dataset)//20,
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=True))
        
x_train, y_train = next(train_dl)
x_train, y_train = x_train.to(torch.uint8), y_train.to(torch.uint8)
x_train, y_train = x_train.to(device), y_train.to(device)

In [4]:

def shuffle(dataset, batch_size):
    x, y = dataset

    full_batches = y.shape[0] // batch_size

    indices = torch.randperm(y.shape[0])[:full_batches*batch_size]
    obs, labels = x[indices], y[indices]


    obs = torch.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = torch.reshape(labels, (-1, batch_size)) # should make batch size a global

    return State(obs=obs, labels=labels)
    

In [26]:
#  Initialize Network


class Net(torch.nn.Module):
    def __init__(self, channel_multiplier):
        super().__init__()

        mult = channel_multiplier

        p1 = norse.LIFBoxParameters(
            tau_mem_inv=torch.tensor([100.0], device="cuda"),
            v_leak=torch.tensor([0.0], device="cuda"),
            v_th=torch.tensor([1.0], device="cuda"),
            v_reset=torch.tensor([0.0], device="cuda"),
            alpha=torch.tensor([100.0], device="cuda"),
        )

        p2 = norse.LIFBoxParameters(
            tau_mem_inv=torch.tensor([100.0], device="cuda"),
            v_leak=torch.tensor([0.0], device="cuda"),
            v_th=torch.tensor([1.0], device="cuda"),
            v_reset=torch.tensor([0.0], device="cuda"),
            alpha=torch.tensor([100.0], device="cuda"),
        )

        p3 = norse.LIBoxParameters(
            tau_mem_inv=torch.tensor([100.0], device="cuda"),
            v_leak=torch.tensor([0.0], device="cuda"),
        )

        self.model = snn.SequentialState(
            nn.Conv2d(2, 12*mult, 5, bias=False),
            snn.LIFBoxCell(p1),
            nn.MaxPool2d(2),
            nn.Conv2d(12*mult, 32*mult, 5, bias=False),
            snn.LIFBoxCell(p2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            snn.LIBoxCell(p3),
        )




    def forward(self, data):
        x = data.permute(1,0,2,3,4).to(torch.float32) #make time leading axis
        
        # time, batch, classes
        V = torch.zeros((x.shape[0], x.shape[1], 10), device=x.device)

        state = None
        for t in range(x.shape[0]):
            out, state = self.model(x, state)
            V[t] = out
        return V
    
#net = torch.compile(_net, fullgraph=True, mode="reduce-overhead")

In [27]:
net = torch.compile(Net(1))

In [28]:
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)

In [29]:
def benchmark(net, loss, num_epochs, batch_size):
    
    acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

    
    loss_hist = []    
    # Outer training loop
    start = time()
    for epoch in range(num_epochs):    
        print(epoch)
        train_batch = shuffle((x_train, y_train), batch_size)
        train_data, targets = train_batch
        
        
        # Minibatch training loop
        for data, targets in zip(train_data, targets):
    
            # forward pass
            net.train()
            out_V = net(data)
            print(out_V)
            # initialize the loss & sum over time
            loss_val = loss(torch.sum(out_V, axis=0), targets.to(torch.int64))
    
            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
    
        # Store loss history for future plotting
        loss_hist.append(loss_val.item())
    
    end = time() - start
    #print(end)
    
    return end


In [30]:
from time import time

def run_bench(trials, num_epochs, batch_size, mult, comp=False):
    
    loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
    net = Net(mult).to(device)

    if comp:
        net = torch.compile(net, fullgraph=True, mode="reduce-overhead")
        loss = torch.compile(loss, fullgraph=True, mode="reduce-overhead")

    print("starting warmup")
    comp_start = time()
    benchmark(net, loss, 1, batch_size)
    
    print("Warmup compilation finished:", time() - comp_start)
    
    times = []
    for t in range(trials):
        times.append(benchmark(net, loss, num_epochs, batch_size))
        print(t, ":", times[t])
    
    print("Mean:", np.mean(times), "Std. Dev.:", np.std(times))
    

In [31]:
run_bench(3, 20, 32, 1, True)

starting warmup
0


TorchRuntimeError: Failed running call_function <built-in method full of type object at 0x7f35ce2ce840>(*((32, 12, 30, 30), FakeTensor(..., device='cuda:0', size=(1,))), **{'device': device(type='cuda', index=0), 'dtype': torch.float32}):
full() received an invalid combination of arguments - got (tuple, FakeTensor, dtype=torch.dtype, device=torch.device), but expected one of:
 * (tuple of ints size, Number fill_value, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, Number fill_value, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


from user code:
   File "/tmp/ipykernel_6642/3964507059.py", line 43, in forward
    return self.model(x[0])
  File "/home/legion/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/legion/.local/lib/python3.10/site-packages/norse/torch/module/sequential.py", line 108, in forward
    input_tensor, s = module(input_tensor, state[index])
  File "/home/legion/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/legion/.local/lib/python3.10/site-packages/norse/torch/module/snn.py", line 84, in forward
    state = state if state is not None else self.state_fallback(input_tensor)
  File "/home/legion/.local/lib/python3.10/site-packages/norse/torch/module/lif_box.py", line 46, in initial_state
    v=torch.full(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
