In [None]:
# Import required libraries
import snntorch as snn
from snntorch import spikegen
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Select the compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataloader parameters
batch_size = 128      # samples per mini-batch, shared by both loaders
data_path = "./data"  # root directory handed to datasets.MNIST

# Preprocessing pipeline applied to every image: resize to 28x28, convert to
# grayscale, convert to a tensor, then normalize with mean 0 and std 1.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

# Load the MNIST train / test splits (downloaded into data_path if missing).
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Evaluation does not benefit from shuffling. Combined with drop_last=True,
# shuffling would discard a different random remainder of the test set on
# every pass, so the reported accuracy would not be measured on a fixed set
# of samples. drop_last=True is kept so every batch has exactly batch_size
# samples.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=True)

# Layer widths of the fully connected network: one input per pixel of a
# 28x28 image, a hidden layer, and one output per class.
num_inputs, num_hidden, num_outputs = 28 * 28, 1000, 10

# Temporal dynamics: number of simulation steps per forward pass, and the
# value handed to snn.Leaky as its `beta` argument.
num_steps, beta = 25, 0.95

# Define the Network
class SNN(nn.Module):
    """Two-layer fully connected spiking network built from snn.Leaky neurons.

    Layout: Linear(num_inputs -> num_hidden) -> Leaky -> Linear(num_hidden ->
    num_outputs) -> Leaky. Layer sizes, ``beta`` and ``num_steps`` are read
    from the module-level constants of the same names.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: Input tensor fed to ``fc1``; presumably of shape
                (batch, num_inputs) -- callers flatten images before the call.

        Returns:
            A tuple ``(spk2_rec, mem2_rec)``: the output layer's spikes and
            membrane potentials from every step, stacked along a new leading
            dimension (so dim 0 indexes the time step).
        """
        # Initialize hidden states
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # The same input is presented at every time step, so the first-layer
        # current does not depend on the step. Compute it once instead of
        # repeating the identical matmul num_steps times.
        cur1 = self.fc1(x)

        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Build the model and move its parameters onto the selected device.
net = SNN()
net = net.to(device)

# Training objective and parameter update rule.
loss_fn = nn.CrossEntropyLoss()
learning_rate = 5e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Helper functions for training and testing
def calculate_accuracy(data, targets):
    """Return the fraction of samples in one batch that ``net`` classifies correctly.

    The predicted class for each sample is the output neuron with the largest
    spike count summed over all time steps.

    Args:
        data: Batch of inputs; everything after the first (batch) dimension is
            flattened before being passed to ``net``.
        targets: Class indices compared element-wise against the predictions.

    Returns:
        The batch accuracy as a Python float in [0, 1].
    """
    # Flatten using the batch's own size rather than the global batch_size, so
    # a batch of any size keeps one row per sample instead of raising or being
    # reshaped into the wrong layout.
    spk_rec, _ = net(data.view(data.size(0), -1))
    # Sum spikes over the time dimension (dim 0), then take the arg-max class.
    _, predicted = spk_rec.sum(dim=0).max(1)
    acc = (targets == predicted).float().mean()
    return acc.item()

def train_one_epoch(epoch):
    """Train ``net`` for one pass over ``train_loader`` and print the mean loss.

    For every batch, the loss is the sum over all recorded time steps of the
    cross-entropy between the output layer's membrane potential and the
    targets.

    Args:
        epoch: Epoch number, used only in the printed progress message.
    """
    net.train()
    total_loss = 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)

        # Flatten using the batch's own size rather than the global
        # batch_size, so a batch of any size keeps one row per sample.
        _, mem_rec = net(data.view(data.size(0), -1))

        # mem_rec is stacked along dim 0 (one slice per time step), so
        # iterating it visits each step's membrane potentials in order.
        loss = torch.zeros(1, device=device)
        for step_mem in mem_rec:
            loss += loss_fn(step_mem, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch} Training Loss: {total_loss / len(train_loader):.4f}")

def test_network():
    """Evaluate ``net`` on ``test_loader`` and print the mean per-batch accuracy.

    The network is switched to eval mode and gradients are disabled for the
    whole pass. The printed figure is the average of the per-batch accuracies
    returned by ``calculate_accuracy``, expressed as a percentage.
    """
    net.eval()
    batch_accuracies = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            batch_accuracies.append(calculate_accuracy(images, labels))
    total_accuracy = sum(batch_accuracies)
    print(f"Test Set Accuracy: {total_accuracy / len(test_loader) * 100:.2f}%")

# Run training: after each epoch of optimization, report test-set accuracy.
num_epochs = 10
epoch_numbers = range(1, num_epochs + 1)  # 1-based so printed epochs start at 1
for epoch in epoch_numbers:
    train_one_epoch(epoch)
    test_network()

Epoch 1 Training Loss: 10.3127
Test Set Accuracy: 94.21%
Epoch 2 Training Loss: 5.0605
Test Set Accuracy: 95.84%
