In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import itertools

  from snntorch import backprop


In [2]:
# Define hyperparameters
batch_size = 64

# Load FER2013 dataset.
# Every split is converted to single-channel 48x48 tensors; ToTensor gives [0, 1] and
# Normalize(0.5, 0.5) maps that to [-1, 1].
# Random flips/rotations are data augmentation and must only be applied to the
# training split — applying them to the test split makes evaluation non-deterministic.
train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Deterministic preprocessing for evaluation: same geometry/normalization, no augmentation.
test_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Backward-compatible alias for any later cell that still refers to `transform`.
transform = train_transform

# Load the dataset
train_dataset = datasets.ImageFolder(
    root='./dataset/train', transform=train_transform)
test_dataset = datasets.ImageFolder(
    root='./dataset/test', transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# NOTE(review): drop_last=True discards the final len(test_dataset) % batch_size images
# from every evaluation; switch to drop_last=False once the forward pass is known to
# handle a short final batch, so the whole test set is scored.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [51]:
# Sanity check: shape of one training batch, then how many samples each loader's
# drop_last=True discards (dataset size modulo the loader's own batch size).
sample_images, _sample_labels = next(iter(train_loader))
print("Shape of sample object: ", sample_images.shape)
print(len(train_loader.dataset) % train_loader.batch_size,
      len(test_loader.dataset) % test_loader.batch_size)

Shape of sample object:  torch.Size([64, 1, 48, 48])
37 10


In [11]:
gradient = surrogate.fast_sigmoid(slope=25)
beta = 0.65

# Initializing the network
net = nn.Sequential(nn.Conv2d(1, 32, 3),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=gradient,
                              init_hidden=True),
                    nn.Conv2d(32, 64, 3),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=gradient,
                              init_hidden=True),
                    
                    nn.Flatten(),
                    nn.Linear(10*10, 7),
                    snn.Leaky(beta=beta, spike_grad=gradient,
                              init_hidden=True, output=True)
                    ).to(device)

In [12]:
def forward_pass(net, data):
  spk_rec = []
  snn.utils.reset(net)
  for step in range(data.size(0)):
      spk_out, mem_out = net(data[step])
      spk_rec.append(spk_out)
  return torch.stack(spk_rec)

In [13]:
# Adam over every parameter of the spiking network (betas spelled out explicitly).
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=2e-4,
    betas=(0.9, 0.999),
)
# Spike-count MSE: push the correct class to fire on 80% of time steps and every
# incorrect class on 20%.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [14]:
num_epochs = 50
counter = 0
eval_every = 16  # training iterations between full test-set evaluations

loss_hist = []
acc_hist = []
test_acc_hist = []


def _evaluate(net, loader):
    """Return the accuracy of ``net`` over every batch of ``loader``, in percent.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built for the
    test forward passes. Leaves ``net`` in eval mode; the training loop switches it
    back with ``net.train()`` at the start of each iteration.

    Raises:
        ValueError: if ``loader`` yields no samples (avoids a bare ZeroDivisionError).
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for test_data, test_targets in loader:
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)
            test_spk = forward_pass(net, test_data)
            n_samples = test_targets.size(0)
            # accuracy_rate is a per-batch fraction; weight it by batch size.
            correct += SF.accuracy_rate(test_spk, test_targets) * n_samples
            total += n_samples
    if total == 0:
        raise ValueError("test loader yielded no samples")
    return (correct / total) * 100


# Training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        # propagating one batch through the network and evaluating loss
        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss/accuracy history for future plotting
        loss = loss_val.item()
        loss_hist.append(loss)
        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)

        # print metrics every so often
        if counter % eval_every == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss:.2f}")
            print(f"Train Accuracy: {acc * 100:.2f}%\n")

            test_acc = _evaluate(net, test_loader)
            test_acc_hist.append(test_acc)
            print(f"========== Test Set Accuracy: {test_acc:.2f}% ==========\n")

        counter += 1

Epoch 0, Iteration 0 
Train Loss: 7.73
Train Accuracy: 14.06%


Epoch 0, Iteration 16 
Train Loss: 7.73
Train Accuracy: 14.06%


Epoch 0, Iteration 32 
Train Loss: 7.73
Train Accuracy: 9.38%


Epoch 0, Iteration 48 
Train Loss: 7.72
Train Accuracy: 20.31%


Epoch 0, Iteration 64 
Train Loss: 7.67
Train Accuracy: 17.19%


Epoch 0, Iteration 80 
Train Loss: 6.62
Train Accuracy: 6.25%


Epoch 0, Iteration 96 
Train Loss: 4.44
Train Accuracy: 1.56%


Epoch 0, Iteration 112 
Train Loss: 3.17
Train Accuracy: 18.75%


Epoch 0, Iteration 128 
Train Loss: 2.88
Train Accuracy: 28.12%


Epoch 0, Iteration 144 
Train Loss: 2.86
Train Accuracy: 29.69%


Epoch 0, Iteration 160 
Train Loss: 3.01
Train Accuracy: 26.56%


Epoch 0, Iteration 176 
Train Loss: 2.89
Train Accuracy: 23.44%


Epoch 0, Iteration 192 
Train Loss: 2.84
Train Accuracy: 28.12%


Epoch 0, Iteration 208 
Train Loss: 2.87
Train Accuracy: 28.12%


Epoch 0, Iteration 224 
Train Loss: 2.95
Train Accuracy: 21.88%


Epoch 0, Iteration 24

KeyboardInterrupt: 