# 0. Package Import

In [4]:
import torch, torch.nn as nn
import snntorch as snn

# 1. Hyperparams

In [7]:
batch_size = 128
data_path='/tmp/data/fmnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# 2. Data Load

In [9]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

fmnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)
fmnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:08<00:00, 3154993.34it/s]


Extracting /tmp/data/fmnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 190334.78it/s]


Extracting /tmp/data/fmnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:02<00:00, 1585766.07it/s]


Extracting /tmp/data/fmnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 20524978.13it/s]

Extracting /tmp/data/fmnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/data/fmnist/FashionMNIST/raw





# 3. Model

In [24]:
from snntorch import surrogate

# network parameters
num_inputs = 28*28
num_hidden = 128
num_outputs = 10
num_steps = 1

beta = 0.9  # neuron decay rate
grad = surrogate.fast_sigmoid()

In [21]:
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(num_inputs, num_hidden),
                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
                    nn.Linear(num_hidden, num_outputs),
                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)
                    ).to(device)

In [25]:
output_pre_class = 50
pop_outputs = num_outputs * output_pre_class

net_pop = nn.Sequential(nn.Flatten(),
                        nn.Linear(num_inputs, num_hidden),
                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
                        nn.Linear(num_hidden, pop_outputs),
                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)
                        ).to(device)

# 4. Training

In [22]:
import snntorch.functional as SF

optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0)



from snntorch import utils

def test_accuracy(data_loader, net, num_steps, population_code=False, num_classes=False):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      utils.reset(net)
      spk_rec, _ = net(data)

      if population_code:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets, population_code=True, num_classes=10) * spk_rec.size(1)
      else:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets) * spk_rec.size(1)

      total += spk_rec.size(1)

  return acc/total


from snntorch import backprop

num_epochs = 5

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,
                          optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)

    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%\n")

Epoch: 0
Test set accuracy: 60.661%
Epoch: 1
Test set accuracy: 56.260%
Epoch: 2
Test set accuracy: 65.971%
Epoch: 3
Test set accuracy: 68.829%
Epoch: 4
Test set accuracy: 67.692%


In [26]:
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0, population_code=True, num_classes=10)
optimizer = torch.optim.Adam(net_pop.parameters(), lr=2e-3, betas=(0.9, 0.999))

num_epochs = 5

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net_pop, train_loader, num_steps=num_steps,
                            optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)

    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net_pop, num_steps, population_code=True, num_classes=10)*100:.3f}%\n")

                             
                             

Epoch: 0
Test set accuracy: 80.202%
Epoch: 1
Test set accuracy: 81.141%
Epoch: 2
Test set accuracy: 81.774%
Epoch: 3
Test set accuracy: 82.417%
Epoch: 4
Test set accuracy: 82.822%
