In [1]:
!nvidia-smi

Tue Feb 27 17:38:33 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               On  | 00000000:08:00.0 Off |                  Off |
| 30%   33C    P8              32W / 300W |      2MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [1]:
import torch

import norse.torch as norse

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#torch.cuda.set_per_process_memory_fraction(0.8, device=0)



import numpy as np

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import namedtuple



In [2]:
torch.__version__

'2.2.1+cu121'

In [3]:
class _SHD2Raster():
    """ 
    Tool for rastering SHD samples into frames. Packs bits along the temporal axis for memory efficiency. This means
        that the used will have to apply jnp.unpackbits(events, axis=<time axis>) prior to feeding the data to the network.
    """

    def __init__(self, encoding_dim, sample_T = 100):
        self.encoding_dim = encoding_dim
        self.sample_T = sample_T
        
    def __call__(self, events):
        # tensor has dimensions (time_steps, encoding_dim)
        tensor = np.zeros((events["t"].max()+1, self.encoding_dim), dtype=int)
        np.add.at(tensor, (events["t"], events["x"]), 1)
        #return tensor[:self.sample_T,:]
        tensor = tensor[:self.sample_T,:]
        tensor = np.minimum(tensor, 1)
        #tensor = np.packbits(tensor, axis=0) pytorch does not have an unpack feature.
        return tensor.astype(np.uint8)

In [4]:
sample_T = 256
shd_timestep = 1e-6
shd_channels = 700
net_channels = 128
net_dt = 1/sample_T
batch_size = 256

obs_shape = tuple([net_channels,])
act_shape = tuple([20,])

transform = transforms.Compose([
    transforms.Downsample(
        time_factor=shd_timestep / net_dt,
        spatial_factor=net_channels / shd_channels
    ),
    _SHD2Raster(net_channels, sample_T=sample_T)
])

train_dataset = datasets.SHD("./data", train=True, transform=transform)

In [5]:
train_dl = iter(DataLoader(train_dataset, batch_size=len(train_dataset)//8,
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_train, y_train = next(train_dl)
x_train, y_train = x_train.to(torch.uint8), y_train.to(torch.uint8)
x_train, y_train = x_train.to(device), y_train.to(device)


In [6]:

def shuffle(dataset, batch_size):
    x, y = dataset

    full_batches = y.shape[0] // batch_size

    indices = torch.randperm(y.shape[0])[:full_batches*batch_size]
    obs, labels = x[indices], y[indices]


    obs = torch.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = torch.reshape(labels, (-1, batch_size)) # should make batch size a global

    return obs, labels
    

In [7]:
# Define Network
class Net(torch.nn.Module):
    def __init__(self, net_width, disable):
        super().__init__()
        num_hidden = net_width

        p1 = norse.LIFBoxParameters(
            tau_mem_inv=torch.tensor([100.0], device="cuda"),
            v_leak=torch.tensor([0.0], device="cuda"),
            v_th=torch.tensor([1.0], device="cuda"),
            v_reset=torch.tensor([0.0], device="cuda"),
            alpha=torch.tensor([100.0], device="cuda"),
        )

        p2 = norse.LIFBoxParameters(
            tau_mem_inv=torch.tensor([100.0], device="cuda"),
            v_leak=torch.tensor([0.0], device="cuda"),
            v_th=torch.tensor([1.0], device="cuda"),
            v_reset=torch.tensor([0.0], device="cuda"),
            alpha=torch.tensor([100.0], device="cuda"),
        )

        p3 = norse.LIBoxParameters(
            tau_mem_inv=torch.tensor([100.0], device="cuda"),
            v_leak=torch.tensor([0.0], device="cuda"),
        )

        self.model = torch.compile(norse.SequentialState(
            torch.nn.Linear(128, num_hidden, bias=False),
            norse.LIFBoxCell(p1),
            torch.nn.Linear(num_hidden, num_hidden, bias=False),
            norse.LIFBoxCell(p2),
            torch.nn.Linear(num_hidden, 20, bias=False),
            norse.LIBoxCell(p3)
        ).to(device), disable=disable)

            
    def forward(self, x):

        x = x.float() # [batch, time, channel]
        
        x = x.permute(1,0,2) # [time, batch, channel]
        # Initialize hidden states at t=0

        # time, batch, classes
        V = []
        
        state = None
        for t in range(x.shape[0]):
            out, state = self.model(x[t], state)
            V.append(out)
        return torch.stack(V, axis=0)
        
# Load the network onto CUDA if available
#precompiled_net = Net().to(device)
#net = torch.compile(precompiled_net, fullgraph=True)

In [8]:
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)

In [9]:
def benchmark(net, loss, num_epochs, batch_size):
    
    acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
    
    loss_hist = []    
    # Outer training loop
    start = time()
    for epoch in range(num_epochs):
        print(epoch)
        train_batch = shuffle((x_train, y_train), batch_size)
        train_data, targets = train_batch
        
        
        # Minibatch training loop
        for data, targets in zip(train_data, targets):
    
            # forward pass
            net.train()
            out_V = net(data)
            # initialize the loss & sum over time
            loss_val = loss(torch.sum(out_V, axis=0), targets.to(torch.int64))
            #print(acc(out_V, targets))
    
            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
    
        # Store loss history for future plotting
        loss_hist.append(loss_val.item())
    
    end = time() - start
    #print(end)
    
    return end


In [10]:
x_train.dtype

torch.uint8

In [11]:
from time import time

def run_bench(trials, num_epochs, net_width, batch_size, disable=True):
    
    loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
    net = Net(net_width, disable)

    loss = torch.compile(loss, fullgraph=True, mode="reduce-overhead")

    print("starting warmup")
    comp_start = time()
    benchmark(net, loss, 1, batch_size)
    
    print("Warmup compilation finished:", time() - comp_start)
    
    times = []
    for t in range(trials):
        times.append(benchmark(net, loss, num_epochs, batch_size))
        print(t, ":", times[t])
    
    print("Mean:", np.mean(times), "Std. Dev.:", np.std(times))
    

In [12]:
run_bench(1, 10, 128, 256, True)

starting warmup
0
Warmup compilation finished: 8.765508890151978
0
1
2
3
4
5
6
7
8
9
0 : 19.065460443496704
Mean: 19.065460443496704 Std. Dev.: 0.0


In [13]:
run_bench(1, 10, 128, 256, False)

starting warmup
0
Warmup compilation finished: 17.967191696166992
0
1
2
3
4
5
6
7
8
9
0 : 30.830129146575928
Mean: 30.830129146575928 Std. Dev.: 0.0


In [None]:


# Test set
with torch.no_grad():
    denominator = y_test[0]
    test_acc = 0
    batch_acc = []
    for test_data, test_targets in zip(x_test, y_test):
        net.eval()
        # Test set forward pass
        out_V = net(test_data)
        # Test set loss
        batch_acc.append( acc(torch.sum(out_V, axis=-2), test_targets) )
    
    test_acc = np.mean(batch_acc)


### Using NIR to load a network from Spyx in Norse:

The following code is boiler plate until snnTorch merges support for importing from NIR.

In [None]:
G = nir.read("./spyx_shd.nir")

In [None]:
net2 = from_nir(G).to(device)

In [None]:

def forward_pass(network, data):
  v_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  transposed_data = torch.permute(data, (1,0,2))

  for step in transposed_data:  # data.size(0) = number of time steps
      spk_out, v = network(step)
      v_rec.append(v)
  
  return torch.stack(v_rec)

In [None]:
# Test set
with torch.no_grad():
    denominator = y_test[0]
    test_acc = 0
    batch_acc = []
    for test_data, test_targets in zip(x_test, y_test):
        net2.zero_grad()
        # Test set forward pass
        out_V = forward_pass(net2, test_data.to(torch.float32))
        # Test set loss
        batch_acc.append( acc(torch.sum(out_V, axis=0), test_targets) )
    
    test_acc = np.mean(batch_acc)

test_acc

As we can see, it gets about the same accuracy as it did in Spyx.

In [None]:
batch_acc