In [1]:
import numpy as np
import torch, torch.nn as nn
import snntorch as snn
from snntorch import surrogate
import sys
sys.path.append('../src')
from dataloader import WISDM_Dataset_parser, WISDM_Dataset
from torch.utils.data import  DataLoader

In [2]:
batch_size = 128
data_path='/tmp/data/fmnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f'Using device {device}')

Using device cuda


In [4]:

dataset = WISDM_Dataset_parser('../data/watch_subset2_40.npz', norm=None)
train_set = dataset.get_training_set()
test_set = dataset.get_validation_set()

train_dataset = WISDM_Dataset(train_set)
test_dataset = WISDM_Dataset(test_set)

train_loader = DataLoader(dataset=train_dataset, batch_size=int(batch_size), shuffle=True, num_workers=8)
test_loader  = DataLoader(dataset= test_dataset, batch_size=int(batch_size), shuffle=True, num_workers=8)


(6,)
(6,)
num classes train dataset: 7 occurrences of each class:[3189 2987 3083 3262 3046 3071 3082]
num classes eval dataset: 7 occurrences of each class:[1050 1017  982  998 1058 1055 1080]
num classes test dataset: 7 occurrences of each class:[1031  948 1014 1076 1062 1038 1072]


In [5]:
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(num_inputs, num_hidden),
                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
                    nn.Linear(num_hidden, num_outputs),
                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)
                    ).to(device)

In [6]:
import snntorch.functional as SF

optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0)

In [7]:
from snntorch import utils

def test_accuracy(data_loader, net, num_steps, population_code=False, num_classes=False):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      utils.reset(net)
      spk_rec, _ = net(data)

      if population_code:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets, population_code=True, num_classes=num_outputs) * spk_rec.size(1)
      else:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets) * spk_rec.size(1)

      total += spk_rec.size(1)

  return acc/total

In [8]:
from snntorch import backprop



# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,
                          optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)

    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net, num_steps, num_classes=num_outputs)*100:.3f}%\n")


  from snntorch import backprop


Epoch: 0
Test set accuracy: 28.905%

Epoch: 1
Test set accuracy: 31.672%

Epoch: 2
Test set accuracy: 31.690%

Epoch: 3
Test set accuracy: 33.906%

Epoch: 4
Test set accuracy: 33.959%

Epoch: 5
Test set accuracy: 35.409%

Epoch: 6
Test set accuracy: 33.682%

Epoch: 7
Test set accuracy: 34.407%

Epoch: 8
Test set accuracy: 30.233%

Epoch: 9
Test set accuracy: 29.083%



In [9]:

net_pop = nn.Sequential(nn.Flatten(),
                        nn.Linear(num_inputs, num_hidden),
                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
                        nn.Linear(num_hidden, pop_outputs),
                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)
                        ).to(device)

In [10]:
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0, population_code=True, num_classes=num_outputs)
optimizer = torch.optim.Adam(net_pop.parameters(), lr=2e-3, betas=(0.9, 0.999))

In [11]:
# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net_pop, train_loader, num_steps=num_steps,
                            optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)

    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net_pop, num_steps, population_code=True, num_classes=num_outputs)*100:.3f}%\n")

Epoch: 0
Test set accuracy: 63.603%

Epoch: 1
Test set accuracy: 66.827%

Epoch: 2
Test set accuracy: 68.241%

Epoch: 3
Test set accuracy: 70.465%

Epoch: 4
Test set accuracy: 70.465%

Epoch: 5
Test set accuracy: 71.191%

Epoch: 6
Test set accuracy: 72.841%

Epoch: 7
Test set accuracy: 73.293%

Epoch: 8
Test set accuracy: 73.142%

Epoch: 9
Test set accuracy: 73.549%

