In [None]:
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools
import tensorflow as tf
import os

In [None]:
# Leaky integrate-and-fire neuron whose spike non-linearity uses a hand-written
# surrogate gradient (fast sigmoid) on the backward pass.
class LeakySigmoidSurrogate(nn.Module):
    """Leaky integrate-and-fire neuron with a fast-sigmoid surrogate gradient.

    Args:
        beta: decay factor multiplied into the membrane potential every step.
        threshold: firing threshold; a spike is emitted where ``mem > threshold``.
        k: slope of the fast-sigmoid surrogate used on the backward pass.
    """

    def __init__(self, beta, threshold=1.0, k=25):
        super().__init__()
        self.beta = beta
        self.threshold = threshold
        self.k = k
        self.surrogate_func = self.FastSigmoid.apply

    def forward(self, input_, mem):
        """Advance the neuron by one time step.

        Returns:
            ``(spk, mem)``: spikes computed from the incoming ``mem`` and the
            updated membrane potential.
        """
        # Heaviside on the forward pass, fast-sigmoid derivative on backward.
        spk = self.surrogate_func(mem - self.threshold, self.k)
        # Reset by subtraction: remove `threshold` only where a spike fired.
        # Detached so the reset term does not contribute to the gradient.
        reset = (spk * self.threshold).detach()
        mem = self.beta * mem + input_ - reset
        return spk, mem

    class FastSigmoid(torch.autograd.Function):
        """Forward: Heaviside step (Eq. 1). Backward: fast-sigmoid gradient (Eq. 4)."""

        @staticmethod
        def forward(ctx, mem, k=25):
            ctx.save_for_backward(mem)  # needed to evaluate the surrogate in backward
            ctx.k = k
            return (mem > 0).float()

        @staticmethod
        def backward(ctx, grad_output):
            (mem,) = ctx.saved_tensors
            grad = grad_output / (ctx.k * torch.abs(mem) + 1.0) ** 2
            # The slope `k` is not a tensor input, so it receives no gradient.
            return grad, None

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import numpy as np
import pandas as pd
import os
import shutil
# Run-wide settings: `batch_size` feeds the DataLoaders, `device` is where the
# helpers and the training loop place tensors. `dtype` is currently unused.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float
batch_size = 1
def split_train(data_folder, train_fraction, transform):
    """Load ``data_folder`` as an ImageFolder and split it into two subsets.

    The first subset receives ``int(train_fraction * N)`` samples and the
    second the remainder. A generator seeded with 42 makes the split the same
    on every call.

    Returns:
        ``(train, test)`` subsets of the ImageFolder dataset.
    """
    dataset = ImageFolder(data_folder, transform)
    n_total = len(dataset)
    n_train = int(train_fraction * n_total)
    seeded = torch.Generator().manual_seed(42)
    train_subset, test_subset = random_split(
        dataset, [n_train, n_total - n_train], generator=seeded
    )
    return train_subset, test_subset

def prepare_data_classes_from_train(src_data_folder, dest_data_folder, train_labels_csv_file, move = False):
    """Sort the files of ``src_data_folder`` into one sub-folder per class label.

    Produces the ``dest_data_folder/<label>/<file>`` layout that
    ``torchvision.datasets.ImageFolder`` reads.

    Args:
        src_data_folder: folder holding the raw files (not recursed into).
        dest_data_folder: output folder. WARNING: it is deleted and recreated on
            every call, so anything already inside it is lost.
        train_labels_csv_file: CSV with an ``id`` column (only the part before
            the first ``.`` is used to match file names) and an integer
            ``pos_label`` column.
        move: move the files instead of copying them.

    Returns:
        The number of distinct ``pos_label`` values in the CSV.

    Files with no label in the CSV, or that fail to copy/move, are skipped and
    reported on stdout rather than aborting the whole run.
    """
    data = pd.read_csv(train_labels_csv_file)

    # file stem -> integer class label
    labels_dic = {
        str(file_id).split(".")[0]: int(pos_label)
        for file_id, pos_label in zip(data['id'], data['pos_label'])
    }
    n_classes = int(data['pos_label'].nunique())

    # Start from an empty destination so stale files from an earlier run
    # cannot leak into the dataset.
    if os.path.exists(dest_data_folder):
        shutil.rmtree(dest_data_folder)
    os.makedirs(dest_data_folder)

    # One folder per label that actually occurs; labels need not be 0..n-1,
    # and ImageFolder rejects class folders that end up empty.
    for cls in sorted(set(labels_dic.values())):
        os.mkdir(os.path.join(dest_data_folder, str(cls)))

    transfer = shutil.move if move else shutil.copy
    verb = 'moving' if move else 'copying'
    for file in os.listdir(src_data_folder):
        label = labels_dic.get(file.split(".")[0])
        if label is None:
            print(f'Skipping {file}: no label in {train_labels_csv_file}')
            continue
        try:
            transfer(os.path.join(src_data_folder, file),
                     os.path.join(dest_data_folder, str(label), file))
        except OSError as err:
            print(f'Error in {verb} file {file}: {err}')

    return n_classes

def batch_accuracy(train_loader, net, num_steps):
  """Return the spike-rate accuracy of ``net`` over every batch of a loader.

  Despite its name, ``train_loader`` may be any DataLoader (the training cell
  passes the test loader). Each batch is simulated for ``num_steps`` steps via
  ``forward_pass`` and scored with ``SF.accuracy_rate``. Per-batch accuracies
  are weighted by batch size, so a short final batch is not over-counted.

  The network is evaluated in eval mode, and its previous train/eval mode is
  restored before returning.

  Raises:
      ValueError: if the loader yields no samples.
  """
  was_training = net.training
  total = 0
  acc = 0
  net.eval()
  try:
    with torch.no_grad():
      for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)
        spk_rec, _ = forward_pass(net, num_steps, data)

        # dim 0 is time (torch.stack in forward_pass); dim 1 is the batch
        batch_n = spk_rec.size(1)
        acc += SF.accuracy_rate(spk_rec, targets) * batch_n
        total += batch_n
  finally:
    net.train(was_training)

  if total == 0:
    raise ValueError("batch_accuracy: the loader yielded no samples")
  return acc / total

def forward_pass(net, num_steps, data):
  """Simulate ``net`` on the same input ``data`` for ``num_steps`` time steps.

  All snnTorch hidden states are reset first, then ``net(data)`` is called
  once per step.

  Returns:
      ``(spk_rec, mem_rec)``: the per-step output spikes and membrane
      potentials, each stacked along a new leading time dimension.
  """
  utils.reset(net)  # clear hidden states left over from any previous call
  outputs = [net(data) for _ in range(num_steps)]
  spikes = [spk for spk, _ in outputs]
  membranes = [mem for _, mem in outputs]
  return torch.stack(spikes), torch.stack(membranes)

In [None]:
# Preprocessing: resize to 224x224, convert to a single grayscale channel, to tensor.
# NOTE: Normalize((0,), (1,)) subtracts 0 and divides by 1, i.e. it is an
# identity; replace with real dataset statistics if normalisation is wanted.
transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

data_folder = 'snn_original1'
dest_data_folder = 'snn_original_classes'

# Rebuilds dest_data_folder from scratch and returns the number of distinct labels.
n_classes = prepare_data_classes_from_train(data_folder, dest_data_folder, 'labels.csv', move=False)
X_train, X_test = split_train(dest_data_folder, 0.8, transform)
train_dataloader = DataLoader(X_train, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling.
test_dataloader = DataLoader(X_test, batch_size=batch_size, shuffle=False)

In [None]:
# Re-create the (seed-42) split and the loaders. batch_size is passed
# explicitly so these loaders follow the configured value instead of
# DataLoader's default.
X_train, X_test = split_train(dest_data_folder, 0.8, transform)
train_dataloader = DataLoader(X_train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(X_test, batch_size=batch_size, shuffle=False)

In [None]:
# Neuron and simulation parameters.
num_steps = 50  # time steps each input is presented for (see forward_pass)
beta = 0.5      # membrane decay rate shared by every Leaky layer
spike_grad = surrogate.fast_sigmoid(slope=25)  # surrogate gradient for spiking

In [7]:
#  Initialize Network
# Feature-map size for a 224x224 input (see `transform`):
#   conv5 -> 220, pool2 -> 110, conv5 -> 106, pool2 -> 53  =>  64*53*53 features.
# The output layer gets one neuron per class folder discovered by ImageFolder,
# rather than a hard-coded 10.
num_outputs = len(X_train.dataset.classes)
net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, threshold=0.6, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 64, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, threshold=0.4, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(64*53*53, num_outputs),
                    snn.Leaky(beta=beta, threshold=0.2, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

In [None]:
# Sanity check: push one training batch through the network.
data, targets = next(iter(train_dataloader))
data = data.to(device)
targets = targets.to(device)

# Start from reset hidden states, and skip autograd: nothing here is
# back-propagated, so recording a num_steps-long graph would only waste memory.
utils.reset(net)
with torch.no_grad():
    for step in range(num_steps):
        spk_out, mem_out = net(data)

In [None]:
def forward_pass(net, num_steps, data):
  """Simulate ``net`` on the same input ``data`` for ``num_steps`` time steps.

  NOTE(review): this redefines the identical ``forward_pass`` from the helpers
  cell; consider keeping only one copy.

  All snnTorch hidden states are reset first, then ``net(data)`` is called
  once per step.

  Returns:
      ``(spk_rec, mem_rec)``: per-step output spikes and membrane potentials,
      each stacked along a new leading time dimension.
  """
  utils.reset(net)  # clear hidden states left over from any previous call
  outputs = [net(data) for _ in range(num_steps)]
  spikes = [spk for spk, _ in outputs]
  membranes = [mem for _, mem in outputs]
  return torch.stack(spikes), torch.stack(membranes)

In [None]:
# Record spikes and membrane potentials of the sample batch over all time steps.
spk_rec, mem_rec = forward_pass(net=net, num_steps=num_steps, data=data)

In [None]:
# snnTorch rate-coded cross-entropy loss; applied below as loss_fn(spk_rec, targets)
# on the spikes recorded over all time steps.
loss_fn = SF.ce_rate_loss()

In [None]:
def batch_accuracy(train_loader, net, num_steps):
  """Return the spike-rate accuracy of ``net`` over every batch of a loader.

  NOTE(review): this redefines ``batch_accuracy`` from the helpers cell and
  shadows it; consider keeping only one copy.

  Despite its name, ``train_loader`` may be any DataLoader (the training cell
  passes the test loader). Each batch is simulated for ``num_steps`` steps via
  ``forward_pass`` and scored with ``SF.accuracy_rate``. Per-batch accuracies
  are weighted by batch size, so a short final batch is not over-counted.

  The network is evaluated in eval mode, and its previous train/eval mode is
  restored before returning.

  Raises:
      ValueError: if the loader yields no samples.
  """
  was_training = net.training
  total = 0
  acc = 0
  net.eval()
  try:
    with torch.no_grad():
      for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)
        spk_rec, _ = forward_pass(net, num_steps, data)

        # dim 0 is time (torch.stack in forward_pass); dim 1 is the batch
        batch_n = spk_rec.size(1)
        acc += SF.accuracy_rate(spk_rec, targets) * batch_n
        total += batch_n
  finally:
    net.train(was_training)

  if total == 0:
    raise ValueError("batch_accuracy: the loader yielded no samples")
  return acc / total

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999))
num_epochs = 20
loss_hist = []       # training loss of every optimisation step
train_acc_hist = []  # running training accuracy, one value per epoch
test_acc_hist = []   # test accuracy, one value per epoch
counter = 0          # total number of optimisation steps taken so far

# Outer training loop
for epoch in range(num_epochs):
    # batch_accuracy below switches to eval mode, so re-enter train mode per epoch.
    net.train()
    train_correct = 0
    train_seen = 0

    # Training loop
    for data, targets in train_dataloader:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass (forward_pass resets hidden states before simulating)
        spk_rec, _ = forward_pass(net, num_steps, data)

        # rate-coded loss over all recorded time steps
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()
        counter += 1

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Running training accuracy from the spikes already computed above,
        # so no extra forward pass is needed. It is measured while the
        # weights are still changing during the epoch.
        with torch.no_grad():
            batch_n = spk_rec.size(1)
            train_correct += SF.accuracy_rate(spk_rec, targets) * batch_n
            train_seen += batch_n

    train_acc = float(train_correct / train_seen) if train_seen else 0.0
    train_acc_hist.append(train_acc)

    with torch.no_grad():
        net.eval()
        test_acc = batch_accuracy(test_dataloader, net, num_steps)
    print(f'Epoch {epoch}, iteration {counter}, Train Acc: {train_acc * 100:.2f}%, Test Acc: {test_acc * 100:.2f}%\n')
    test_acc_hist.append(float(test_acc))

In [None]:
# Per-epoch train vs. test accuracy, also saved to accuracy_curves.svg.
# NOTE(review): train_acc_hist must be populated by the training cell; if it
# is not defined there, this cell raises NameError.
fig, ax = plt.subplots(facecolor="w")
ax.plot(train_acc_hist)
ax.plot(test_acc_hist)
ax.legend(["Train Accuracy", "Test Accuracy"])
ax.set_title("Accuracy Curves")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
fig.savefig('accuracy_curves.svg')
plt.show()