In [1]:
import torch
import snntorch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_per_process_memory_fraction(0.85, device=0)



import numpy as np

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import namedtuple

State = namedtuple("State", "obs labels")



In [2]:
sensor_size = tonic.datasets.NMNIST.sensor_size

# Denoise removes isolated, one-off events
# time_window
frame_transform = transforms.Compose([
                                      transforms.ToFrame(sensor_size=sensor_size, 
                                                         n_time_bins=64)
                                     ])

train_dataset = tonic.datasets.NMNIST(save_to='./tmp/data', transform=frame_transform, train=True)
test_dataset = tonic.datasets.NMNIST(save_to='./tmp/data', transform=frame_transform, train=False)

Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/1afc103f-8799-464a-a214-81bb9b1f9337 to ./tmp/data/NMNIST/train.zip


  0%|          | 0/1011893601 [00:00<?, ?it/s]

Extracting ./tmp/data/NMNIST/train.zip to ./tmp/data/NMNIST
Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to ./tmp/data/NMNIST/test.zip


  0%|          | 0/169674850 [00:00<?, ?it/s]

Extracting ./tmp/data/NMNIST/test.zip to ./tmp/data/NMNIST


In [3]:
train_dl = iter(DataLoader(train_dataset, batch_size=len(train_dataset),
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_train, y_train = next(train_dl)
x_train, y_train = x_train.to(torch.uint8), y_train.to(torch.uint8)
x_train, y_train = x_train.to(device), y_train.to(device)

In [None]:
def shuffle(dataset):
    x, y = dataset

    cutoff = y.shape[0] % batch_size

    indices = torch.randperm(y.shape[0])[:-cutoff]
    obs, labels = x[indices], y[indices]


    obs = torch.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = torch.reshape(labels, (-1, batch_size)) # should make batch size a global

    return State(obs=obs, labels=labels)

In [None]:
test_dl = iter(DataLoader(test_dataset, batch_size=len(test_dataset),
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_test, y_test = next(test_dl)
x_test, y_test = x_test.to(torch.uint8), y_test.to(torch.uint8)
x_test, y_test = x_test.to(device), y_test.to(device)
x_test, y_test = shuffle((x_test, y_test))

In [None]:
#  Initialize Network
net = nn.Sequential(nn.Conv2d(2, 12, 5),
                    snn.Leaky(beta=torch.ones(10)*0.5, learn_beta=True, init_hidden=True),
                    nn.MaxPool2d(2),
                    nn.Conv2d(12, 32, 5),
                    snn.Leaky(beta=torch.ones(10)*0.5, learn_beta=True, init_hidden=True),
                    nn.MaxPool2d(2),
                    nn.Flatten(),
                    nn.Linear(32*5*5, 10),
                    snn.Leaky(beta=torch.ones(10)*0.5, learn_beta=True, init_hidden=True, output=True)
                    ).to(device)

# this time, we won't return membrane as we don't need it 

def forward_pass(net, data):
  print(data.shape)
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(data.size(0)):  # data.size(0) = number of time steps
      spk_out, mem_out = net(data[step])
      spk_rec.append(spk_out)
  
  return torch.stack(spk_rec)

In [None]:
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)

In [None]:
num_epochs = 300
loss_hist = []
counter = 0


# Outer training loop
for epoch in range(num_epochs):    
    
    train_batch = shuffle((x_train, y_train))
    train_data, targets = train_batch
    
    
    # Minibatch training loop
    for data, targets in zip(train_data, targets):

        # forward pass
        net.train()
        out_V = forward(net, data)
        # initialize the loss & sum over time
        loss_val = loss(torch.sum(out_V, axis=-2), targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    # Store loss history for future plotting
    loss_hist.append(loss_val.item())


# Test set
with torch.no_grad():
    denominator = y_test[0]
    test_acc = 0
    batch_acc = []
    for test_data, test_targets in zip(x_test, y_test):
        net.eval()
        # Test set forward pass
        out_V = net(test_data)
        # Test set loss
        batch_acc.append( acc(torch.sum(out_V, axis=-2), test_targets) )
    
    test_acc = np.mean(batch_acc)


In [None]:
loss_hist

[9.558403968811035,
 6.229371547698975,
 3.955392360687256,
 3.6000406742095947,
 3.3873300552368164,
 3.3082501888275146,
 3.3575539588928223,
 3.0741758346557617,
 2.8566598892211914,
 2.754401206970215,
 2.689972400665283,
 2.703676223754883,
 2.647325277328491,
 2.479779005050659,
 2.46976375579834,
 2.428328275680542,
 2.4558167457580566,
 2.373805046081543,
 2.296424388885498,
 2.297205924987793,
 2.298617362976074,
 2.3695068359375,
 2.320899486541748,
 2.279393196105957,
 2.2895267009735107,
 2.2929110527038574,
 2.1684398651123047,
 2.2434680461883545,
 2.18837833404541,
 2.2598612308502197,
 2.170142650604248,
 2.178297758102417,
 2.178230047225952,
 2.1480274200439453,
 2.097745656967163,
 2.114993095397949,
 2.159008741378784,
 2.1527669429779053,
 2.1277217864990234,
 2.1327974796295166,
 2.1317038536071777,
 2.1158857345581055,
 2.125640869140625,
 2.0982773303985596,
 2.083098888397217,
 2.0460376739501953,
 2.093351364135742,
 2.0999112129211426,
 2.061932325363159,
 2.

In [None]:
test_acc

0.73095703125