In [13]:
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [14]:
# Define hyperparameters
batch_size = 64
num_epochs = 20

# Load FER2013 dataset
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset
train_dataset = datasets.ImageFolder(
    root='./dataset/train', transform=transform)
test_dataset = datasets.ImageFolder(
    root='./dataset/test', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
class FeedforwardSNN(nn.Module):
    def __init__(self, num_classes=7):
        super(FeedforwardSNN, self).__init__()
        self.fc1 = nn.Linear(48 * 48, 512)    # First fully connected layer
        # Leaky integrate-and-fire layer for spiking
        self.lif1 = snn.Leaky(beta=0.9)
        self.fc2 = nn.Linear(512, 256)        # Second fully connected layer
        self.lif2 = snn.Leaky(beta=0.9)
        self.fc3 = nn.Linear(256, 128)  # Output layer
        self.lif3 = snn.Leaky(beta=0.9)
        self.fc4 = nn.Linear(128, num_classes)  # Output layer
        self.lif4 = snn.Leaky(beta=0.9)

    def forward(self, x):
        # Flatten the input image
        x = x.view(x.size(0), -1)
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()
        # print(x.size())
        # Layer 1 with spiking
        cur1 = self.fc1(x)
        spk1, mem1 = self.lif1(cur1, mem1)
        # print(spk1.size())
        # Layer 2 with spiking
        cur2 = self.fc2(spk1)
        spk2, mem2 = self.lif2(cur2, mem2)
        # print(spk2.size())
        cur3 = self.fc3(spk2)
        spk3, mem3 = self.lif3(cur3, mem3)
        
        out = self.fc4(spk3)
        # print(spk3.size())

        return out
net = FeedforwardSNN().to(device)

In [16]:
train_class_counts = {3: 7215, 4: 4965,
                      5: 4830, 2: 4097, 0: 3995, 6: 3171, 1: 436}
class_weights = torch.tensor([1 / train_class_counts[i]
                             for i in range(7)]).to(device)
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999))
num_epochs = 20
num_steps = 100
# Training loop


def train_snn(num_epochs):
    for epoch in range(num_epochs):
        net.train()
        running_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            labels = labels.long()

            # Forward pass
            optimizer.zero_grad()
            spk_rec = net(images)
            # spk_rec= forward_pass(net, num_steps, data)
            # print(spk_rec.size(),epoch)
            # spk_rec.squeeze(1)
            # labels = labels.view(-1)
            # print(spk_rec.size())
            loss = loss_fn(spk_rec, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(
            f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')


train_snn(num_epochs)

Epoch [1/20], Loss: 1.9062
Epoch [2/20], Loss: 1.8204
Epoch [3/20], Loss: 1.7699
Epoch [4/20], Loss: 1.7326
Epoch [5/20], Loss: 1.7021
Epoch [6/20], Loss: 1.6775
Epoch [7/20], Loss: 1.6519
Epoch [8/20], Loss: 1.6233
Epoch [9/20], Loss: 1.6088
Epoch [10/20], Loss: 1.5789
Epoch [11/20], Loss: 1.5649
Epoch [12/20], Loss: 1.5344
Epoch [13/20], Loss: 1.5146
Epoch [14/20], Loss: 1.5126
Epoch [15/20], Loss: 1.4868
Epoch [16/20], Loss: 1.4797
Epoch [17/20], Loss: 1.4444
Epoch [18/20], Loss: 1.4556
Epoch [19/20], Loss: 1.4392
Epoch [20/20], Loss: 1.4247


In [17]:
val_losses = []
def evaluate():
    correct = 0
    total = 0
    val_loss = 0
    net.eval()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # loss = loss_fn(outputs, labels)
            # val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {100 * correct / total:.2f}%')
    
    # val_loss /= len(test_loader)
    # val_losses.append(val_loss)
    # val_accuracy = 100 * correct / total


evaluate()

Test Accuracy: 38.44%


In [None]:
# torch.save(net.state_dict(), 'ffsnn.pth')

In [11]:
from collections import Counter
train_labels = [label for _, label in train_dataset]
test_labels = [label for _, label in test_dataset]

# Count the occurrences of each class label
train_class_counts = Counter(train_labels)
test_class_counts = Counter(test_labels)

# Print the counts
print("Training Class Counts:", train_class_counts)
print("Testing Class Counts:", test_class_counts)

Training Class Counts: Counter({3: 7215, 4: 4965, 5: 4830, 2: 4097, 0: 3995, 6: 3171, 1: 436})
Testing Class Counts: Counter({3: 1774, 5: 1247, 4: 1233, 2: 1024, 0: 958, 6: 831, 1: 111})
