<a href="https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_5_FCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This file is based on the tutorial based on:

> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. "Training Spiking Neural Networks Using Lessons From Deep Learning". Proceedings of the IEEE, 111(9) September 2023.](https://ieeexplore.ieee.org/abstract/document/10242251) </cite>

# Dependencies

In [1]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

import time
import csv
import os

# Load the dataset

In [2]:
# dataloader arguments
batch_size = 128
data_path='/tmp/data/mnist'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

In [4]:
print(mnist_train)
print(mnist_test)

Dataset MNIST
    Number of datapoints: 60000
    Root location: /tmp/data/mnist
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=warn)
               Grayscale(num_output_channels=1)
               ToTensor()
               Normalize(mean=(0,), std=(1,))
           )
Dataset MNIST
    Number of datapoints: 10000
    Root location: /tmp/data/mnist
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=warn)
               Grayscale(num_output_channels=1)
               ToTensor()
               Normalize(mean=(0,), std=(1,))
           )


In [5]:
# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# Define the Network(s)

## SNN_Baseline

In [6]:
# accepts a tensor of batchx784
# static coding used - the same value passed every time
class SNN_Baseline(nn.Module):
    def __init__(self, num_steps=25, beta=0.95, num_hidden=1000):
        super().__init__()
        self.name = "SNN_Baseline"
        
        # Temporal Dynamics
        self.num_steps = num_steps 
        self.beta = beta
        
        # Network Architecture
        num_inputs = 28*28
        num_outputs = 10

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=self.beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=self.beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        
        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# SNN_Large

In [45]:
# accepts a tensor of batchx784
# static coding used - the same value passed every time
class SNN_Large(nn.Module):
    def __init__(self, num_steps=25, beta=0.95, num_hidden=1000):
        super().__init__()
        self.name = "SNN_Large"
        
        # Temporal Dynamics
        self.num_steps = num_steps 
        self.beta = beta
        
        # Network Architecture
        num_inputs = 28*28
        num_outputs = 10

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=self.beta)
        self.fc2 = nn.Linear(num_hidden, num_hidden)
        self.lif2 = snn.Leaky(beta=self.beta)
        self.fc3 = nn.Linear(num_hidden, num_outputs)
        self.lif3 = snn.Leaky(beta=self.beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        
        # Record the final layer
        spkx_rec = []
        memx_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)
            spkx_rec.append(spk3)
            memx_rec.append(mem3)

        return torch.stack(spkx_rec, dim=0), torch.stack(memx_rec, dim=0)

# SNN_Clamp

In [52]:
# accepts a tensor of batchx784
# static coding used - the same value passed every time
class SNN_Clamp(nn.Module):
    def __init__(self, num_steps=25, beta=0.95, num_hidden=1000):
        super().__init__()
        self.name = "SNN_Clamp"
        
        # Temporal Dynamics
        self.num_steps = num_steps 
        self.beta = beta
        
        # Network Architecture
        num_inputs = 28*28
        num_outputs = 10

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=self.beta)
        self.fc2 = nn.Linear(num_hidden, 500)
        self.lif2 = snn.Leaky(beta=self.beta)
        self.fc3 = nn.Linear(500, num_outputs)
        self.lif3 = snn.Leaky(beta=self.beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        
        # Record the final layer
        spkx_rec = []
        memx_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)
            spkx_rec.append(spk3)
            memx_rec.append(mem3)

        return torch.stack(spkx_rec, dim=0), torch.stack(memx_rec, dim=0)

# Evaluation Functions

In [8]:
def plot_loss(loss_hist):
    fig = plt.figure(facecolor="w", figsize=(10, 5))
    plt.plot(loss_hist)
    plt.title("Loss Curves")
    plt.legend(["Train Loss", "Test Loss"])
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.show()

def evaluate(net):
    total = 0
    correct = 0
    
    # drop_last switched to False to keep all samples
    testLoader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)
    
    with torch.no_grad():
      net.eval()
      for data, targets in testLoader:
        data = data.to(device)
        targets = targets.to(device)
        
        # forward pass
        test_spk, _ = net(data.view(data.size(0), -1))
    
        # calculate total accuracy
        _, predicted = test_spk.sum(dim=0).max(1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
    
    print(f"Total correctly classified test set images: {correct}/{total}")
    print(f"Test Set Accuracy: {100 * correct / total:.2f}%")
    
    return correct / total

# The Training Function

In [9]:
def train_net(net, num_epochs):
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
    
    loss_hist = []
    counter = 0
    
    # Outer training loop
    for epoch in range(num_epochs):
        iter_counter = 0
        train_batch = iter(train_loader)
    
        # Minibatch training loop
        for data, targets in train_batch:
            data = data.to(device)
            targets = targets.to(device)
    
            # forward pass
            net.train()
            spk_rec, mem_rec = net(data.view(batch_size, -1))
    
            # initialize the loss & sum over time
            loss_val = torch.zeros((1), dtype=dtype, device=device)
            for step in range(net.num_steps):
                loss_val += loss(mem_rec[step], targets)
    
            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
    
            # Store loss history for future plotting
            loss_hist.append(loss_val.item())

            '''
            # Test set
            with torch.no_grad():
                net.eval()
                test_data, test_targets = next(iter(test_loader))
                test_data = test_data.to(device)
                test_targets = test_targets.to(device)
    
                # Test set forward pass
                test_spk, test_mem = net(test_data.view(batch_size, -1))
    
                # Test set loss
                test_loss = torch.zeros((1), dtype=dtype, device=device)
                for step in range(net.num_steps):
                    test_loss += loss(test_mem[step], test_targets)
                test_loss_hist.append(test_loss.item())
    
                # Print train/test loss/accuracy
                if counter % 10000 == 0:
                    train_printer(
                        data, targets, epoch,
                        counter, iter_counter,
                        loss_hist, test_loss_hist,
                        test_data, test_targets)
            '''
            counter += 1
            iter_counter +=1

            print("\rProgress %5d\t%8d / %8d" % (counter, counter*batch_size, len(mnist_train)*num_epochs), end='', flush=True)
    return loss_hist
            

# Training Loop

In [75]:
# List of tuples of format (model, num_steps, beta, num_hidden, num_epochs)
hyperList = [
    (SNN_Baseline, 5, 0.95, 3000, 2),
    (SNN_Baseline, 10, 0.95, 3000, 2),
    (SNN_Baseline, 10, 0.95, 3000, 30),
    (SNN_Baseline, 10, 0.95, 3000, 90),
    (SNN_Baseline, 15, 0.95, 3000, 2),
    (SNN_Baseline, 20, 0.95, 3000, 2),
    (SNN_Baseline, 25, 0.95, 3000, 30),
    (SNN_Baseline, 30, 0.95, 3000, 2),
    (SNN_Baseline, 40, 0.95, 3000, 2),
    (SNN_Large, 40, 0.95, 3000, 2),
    (SNN_Clamp, 40, 0.95, 5000, 30),
    (SNN_Baseline, 40, 0.95, 5000, 30),
]

In [76]:
# cache
csv_file = 'hyper_result_cache.csv'
results = {}
if os.path.exists(csv_file):
    with open(csv_file, mode='r', newline='') as f:
        reader = csv.reader(f)
        for row in reader:
            key_str, result = row
            results[key_str] = result

In [None]:
for hyper in hyperList:

    shyper = str(hyper)
    if (shyper in results) and True:
        print("Cached hyper", hyper)
        print(f"Test Set Accuracy: {100 * float(results[shyper]):.2f}%")
        continue
        
    print("Training for", hyper)
    startTime = time.time()
    
    net = hyper[0](hyper[1], hyper[2], hyper[3])
    loss_hist = train_net(net = net.to(device),
                          num_epochs = hyper[4])

    print("\nFinished, time elapsed %.2f s" % (time.time()-startTime))
    plot_loss(loss_hist)
    accu = evaluate(net)
    print("-----\n")

    with open(csv_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        results[shyper] = accu
        writer.writerow([shyper, accu])

Cached hyper (<class '__main__.SNN_Baseline'>, 5, 0.95, 3000, 2)
Test Set Accuracy: 93.62%
Cached hyper (<class '__main__.SNN_Baseline'>, 10, 0.95, 3000, 2)
Test Set Accuracy: 95.02%
Cached hyper (<class '__main__.SNN_Baseline'>, 10, 0.95, 3000, 30)
Test Set Accuracy: 95.22%
Training for (<class '__main__.SNN_Baseline'>, 10, 0.95, 3000, 90)
Progress 24296	 3109888 /  5400000

In [62]:
results

{"(<class '__main__.SNN_Baseline'>, 25, 0.95, 10000, 1)": '0.9464',
 "(<class '__main__.SNN_Baseline'>, 25, 0.95, 3000, 1)": '0.938',
 "(<class '__main__.SNN_Large'>, 25, 0.95, 3000, 1)": '0.9033',
 "(<class '__main__.SNN_Baseline'>, 25, 0.95, 3000, 2)": '0.9656',
 "(<class '__main__.SNN_Baseline'>, 25, 0.95, 10000, 2)": '0.9576',
 "(<class '__main__.SNN_Large'>, 25, 0.95, 3000, 2)": '0.8959',
 "(<class '__main__.SNN_Baseline'>, 5, 0.95, 3000, 2)": '0.9362',
 "(<class '__main__.SNN_Baseline'>, 10, 0.95, 3000, 2)": '0.9502',
 "(<class '__main__.SNN_Baseline'>, 15, 0.95, 3000, 2)": '0.959',
 "(<class '__main__.SNN_Baseline'>, 20, 0.95, 3000, 2)": '0.9523',
 "(<class '__main__.SNN_Baseline'>, 30, 0.95, 3000, 2)": '0.9583',
 "(<class '__main__.SNN_Baseline'>, 40, 0.95, 1000, 1)": '0.9215',
 "(<class '__main__.SNN_Baseline'>, 40, 0.95, 1000, 2)": '0.9478',
 "(<class '__main__.SNN_Baseline'>, 40, 0.95, 3000, 2)": '0.9423',
 "(<class '__main__.SNN_Large'>, 40, 0.95, 3000, 2)": '0.9081',
 "(<c