In [2]:
from tonic import transforms
import tonic
from torch.utils.data import DataLoader
import torch
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np


def get_nmnist_dataset(timestep_size, sensor_width):
    h,w = sensor_width, sensor_width
    transform = transforms.Compose([
        transforms.Downsample(spatial_factor=sensor_width / 34.0),
        transforms.ToFrame(sensor_size=(h,w,2), time_window=timestep_size),
        torch.from_numpy,
    ])
    train_dataset = tonic.datasets.NMNIST(save_to="./n_mnist/data_small", transform=transform, train=True, first_saccade_only=True)
    test_dataset = tonic.datasets.NMNIST(save_to="./n_mnist/data_small", transform=transform, train=False, first_saccade_only=True)
    return train_dataset, test_dataset

train_dataset, test_dataset = get_nmnist_dataset(100000, 12)
trainloader = DataLoader(train_dataset, batch_size=128, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
testloader = DataLoader(test_dataset, batch_size=1024, collate_fn=tonic.collation.PadTensors(batch_first=False))

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
print(train_dataset[0][0].shape)

torch.Size([1, 2, 12, 12])


800


In [3]:
# Some of this comes from: https://snntorch.readthedocs.io/en/latest/tutorials/tutorial_7.html

# neuron and simulation parameters
spike_grad = surrogate.atan()
beta = 1.0 # No decay

#  Initialize Network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(12*12*2, 64),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Linear(64, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
                    ).to(device)

def forward_pass(net, data):
    spk_rec = []
    mem_rec = []
    utils.reset(net)

    for step in range(data.size(0)):
        spk_out, mem_out = net(data[step])
        spk_rec.append(spk_out)
        mem_rec.append(mem_out)

    return torch.stack(spk_rec), torch.stack(mem_rec)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
loss_fn = SF.ce_rate_loss()

num_epochs = 10
train_loss_hist = []
train_acc_hist = []
test_loss_hist = []
test_acc_hist = []

# training loop
for epoch in range(num_epochs):
    # Train the network
    net.train()
    for data, targets in tqdm(iter(trainloader)):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        spk_rec, mem_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        train_loss_hist.append(loss_val.item())
        acc = SF.accuracy_rate(spk_rec, targets)
        train_acc_hist.append(acc)

    # Test the network
    net.eval()
    test_losses_epoch = []
    test_accs_epoch = []
    batch_sizes = []
    for data, targets in tqdm(iter(testloader)):
        data = data.to(device)
        targets = targets.to(device)

        spk_rec, mem_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Store loss history for future plotting
        test_losses_epoch.append(loss_val.item())
        acc = SF.accuracy_rate(spk_rec, targets)
        test_accs_epoch.append(acc)
        batch_sizes.append(data.size(0))
    test_loss_hist.append(np.average(test_losses_epoch, weights=batch_sizes))
    test_acc_hist.append(np.average(test_accs_epoch, weights=batch_sizes))

    # Print training and testing statistics
    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {np.mean(train_loss_hist):.4f} | Train Acc: {np.mean(train_acc_hist):.4f} | Test Loss: {test_loss_hist[-1]:.4f} | Test Acc: {test_acc_hist[-1]:.4f}")

100%|██████████| 469/469 [00:34<00:00, 13.74it/s]
100%|██████████| 10/10 [00:05<00:00,  1.97it/s]


Epoch 1/10 | Train Loss: 1.6128 | Train Acc: 0.7738 | Test Loss: 1.5348 | Test Acc: 0.8664


100%|██████████| 469/469 [00:32<00:00, 14.32it/s]
100%|██████████| 10/10 [00:04<00:00,  2.10it/s]


Epoch 2/10 | Train Loss: 1.5722 | Train Acc: 0.8244 | Test Loss: 1.5268 | Test Acc: 0.8907


100%|██████████| 469/469 [00:34<00:00, 13.69it/s]
100%|██████████| 10/10 [00:07<00:00,  1.40it/s]


Epoch 3/10 | Train Loss: 1.5567 | Train Acc: 0.8454 | Test Loss: 1.5217 | Test Acc: 0.8874


100%|██████████| 469/469 [00:39<00:00, 12.01it/s]
100%|██████████| 10/10 [00:04<00:00,  2.44it/s]


Epoch 4/10 | Train Loss: 1.5476 | Train Acc: 0.8577 | Test Loss: 1.5215 | Test Acc: 0.8972


100%|██████████| 469/469 [00:57<00:00,  8.19it/s]
100%|██████████| 10/10 [00:05<00:00,  1.87it/s]


Epoch 5/10 | Train Loss: 1.5418 | Train Acc: 0.8660 | Test Loss: 1.5197 | Test Acc: 0.9029


100%|██████████| 469/469 [00:45<00:00, 10.25it/s]
100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch 6/10 | Train Loss: 1.5376 | Train Acc: 0.8720 | Test Loss: 1.5228 | Test Acc: 0.8963


100%|██████████| 469/469 [00:44<00:00, 10.65it/s]
100%|██████████| 10/10 [00:08<00:00,  1.13it/s]


Epoch 7/10 | Train Loss: 1.5348 | Train Acc: 0.8760 | Test Loss: 1.5187 | Test Acc: 0.9088


100%|██████████| 469/469 [00:46<00:00, 10.04it/s]
100%|██████████| 10/10 [00:05<00:00,  1.97it/s]


Epoch 8/10 | Train Loss: 1.5323 | Train Acc: 0.8795 | Test Loss: 1.5182 | Test Acc: 0.8874


100%|██████████| 469/469 [00:30<00:00, 15.50it/s]
100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Epoch 9/10 | Train Loss: 1.5304 | Train Acc: 0.8824 | Test Loss: 1.5163 | Test Acc: 0.9061


100%|██████████| 469/469 [00:30<00:00, 15.35it/s]
100%|██████████| 10/10 [00:04<00:00,  2.34it/s]

Epoch 10/10 | Train Loss: 1.5288 | Train Acc: 0.8849 | Test Loss: 1.5183 | Test Acc: 0.9044



