In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, sampler
from skimage.filters import gaussian
from skimage.util import random_noise
import pickle
import os
import time
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
from collections import defaultdict
import noises

In [None]:
class PseudoSpikeRect(torch.autograd.Function):
    """Spike generator with a rectangular pseudo-gradient.

    Forward is a hard threshold: 1.0 where the voltage exceeds ``vth``, else 0.0.
    Backward passes ``grad_amp * grad_output`` through wherever
    ``|voltage - vth| < grad_win`` and zero elsewhere; the scalar arguments
    receive no gradient.
    """

    @staticmethod
    def forward(ctx, input, vth, grad_win, grad_amp):
        """
        Args:
            input (Torch Tensor): membrane voltages of the neurons in a layer
            vth (Float): voltage threshold for spiking
            grad_win (Float): half-width of the window around vth with a non-zero pseudo-gradient
            grad_amp (Float): amplification factor for the gradients

        Returns:
            output (Torch Tensor): float tensor of 1.0 (spike) / 0.0 (no spike), same shape as input
        """
        # The voltages are needed again in backward(); the scalars are stashed on ctx.
        ctx.save_for_backward(input)
        ctx.vth, ctx.grad_win, ctx.grad_amp = vth, grad_win, grad_amp
        return input.gt(vth).float()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Args:
            grad_output (Torch Tensor): gradient w.r.t. the spike output

        Returns:
            (grad, None, None, None): gradient w.r.t. input, then None for vth, grad_win and grad_amp
        """
        (volt,) = ctx.saved_tensors
        # Rectangular surrogate derivative: 1 inside the open window |v - vth| < grad_win, 0 outside.
        in_window = (volt - ctx.vth).abs().lt(ctx.grad_win).float()
        grad = ctx.grad_amp * grad_output * in_window
        return grad, None, None, None

In [None]:
class LinearIFCell(nn.Module):
    """Layer of integrate-and-fire neurons with a multiplicative voltage decay.

    Each step, the previous voltage is scaled by ``vdecay`` (leaky when vdecay < 1)
    and dropped entirely for neurons that spiked on the previous step; the
    synaptic input ``psp_func(input_data)`` is then added.
    """

    def __init__(self, psp_func, pseudo_grad_ops, param):
        """
        Args:
            psp_func (callable): maps the layer input to a synaptic current (e.g. an nn.Linear)
            pseudo_grad_ops (callable): spike function called as (volt, vth, grad_win, grad_amp)
            param (sequence of 4): (vdecay, vth, grad_win, grad_amp)
        """
        super(LinearIFCell, self).__init__()
        self.psp_func = psp_func
        self.pseudo_grad_ops = pseudo_grad_ops
        vdecay, vth, grad_win, grad_amp = param
        self.vdecay = vdecay
        self.vth = vth
        self.grad_win = grad_win
        self.grad_amp = grad_amp

    def forward(self, input_data, state):
        """
        Advance the layer by one timestep.

        Args:
            input_data (Tensor): input spikes/features from the pre-synaptic layer
            state (tuple): (spike, voltage) of this layer at the previous timestep
        Returns:
            output: spikes emitted at this timestep
            state: new (spike, voltage) tuple; the spike entry is the same tensor as output
        """
        prev_spike, prev_volt = state

        # Decayed carry-over voltage; neurons that just fired start again from zero.
        carried = self.vdecay * prev_volt * (1 - prev_spike)
        volt = carried + self.psp_func(input_data)

        spikes = self.pseudo_grad_ops(volt, self.vth, self.grad_win, self.grad_amp)
        return spikes, (spikes, volt)

In [None]:
class SNN(nn.Module):
    """Spiking CNN: two conv -> ReLU -> avg-pool stages and dropout feeding two spiking layers.

    The convolutional front end runs independently on the input frame of every
    timestep; the two LinearIFCell layers carry (spike, voltage) state across timesteps.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, conv1_dim, conv2_dim, param_dict):
        """
        Args:
            input_dim (int): number of pixels of one single-channel input frame (e.g. 28*28 = 784);
                used only to size the hidden layer's input
            output_dim (int): number of output neurons
            hidden_dim (int): number of hidden neurons
            conv1_dim (int): output channels of convolutional layer 1
            conv2_dim (int): output channels of convolutional layer 2
            param_dict (dict): neuron parameters per spiking layer under the keys 'hid_layer' and
                'out_layer', each (vdecay, vth, grad_win, grad_amp)
        """
        super(SNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        pseudo_grad_ops = PseudoSpikeRect.apply

        # Convolution + pooling stages. Both convolutions keep the spatial size
        # (kernel 5, padding 2, stride 1); each pooling halves it (rounding down).
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, conv1_dim, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(conv1_dim, conv2_dim, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2))
        self.drop_out = nn.Dropout()

        # Hidden layer: its input width is the flattened layer2 output, computed from the
        # conv geometry (4 * input_dim is only right for conv2_dim == 64).
        conv_flat_dim = self._conv_flat_dim(input_dim, conv2_dim)
        self.hidden_cell = LinearIFCell(nn.Linear(conv_flat_dim, hidden_dim, bias=False),
                                        pseudo_grad_ops, param_dict['hid_layer'])

        # Output layer.
        self.output_cell = LinearIFCell(nn.Linear(hidden_dim, output_dim, bias=False),
                                        pseudo_grad_ops, param_dict['out_layer'])

    @staticmethod
    def _conv_flat_dim(input_dim, conv2_dim):
        """Flattened size of layer2's output for a square single-channel frame of input_dim pixels.

        The convolutions keep the size and each AvgPool2d(2, 2) maps a side s to s // 2,
        so a side s ends up as s // 2 // 2. If input_dim is not a perfect square the frame
        shape is unknown; fall back to 4 * input_dim, which equals
        conv2_dim * (H // 4) * (W // 4) whenever conv2_dim == 64 and H, W are multiples of 4.
        """
        side = int(round(input_dim ** 0.5))
        if side * side != input_dim:
            return 2 * 2 * input_dim
        pooled_side = side // 2 // 2
        return conv2_dim * pooled_side * pooled_side

    def forward(self, spike_data, init_states_dict, batch_size, spike_ts):
        """
        Run the network over spike_ts timesteps and count output spikes.

        Args:
            spike_data (Tensor): event frames indexed as spike_data[:, :, :, :, tt]; layer1 takes
                one-channel frames, so the expected shape is (batch_size, 1, height, width, spike_ts)
            init_states_dict (dict): initial (spike, voltage) states - 'hid_layer' for the hidden
                layer, 'out_layer' for the output layer
            batch_size (int): batch size (not used here; kept for interface compatibility)
            spike_ts (int): number of timesteps to simulate
        Returns:
            output (Tensor): output-layer spike counts summed over all timesteps
        """
        hidden_state, out_state = init_states_dict['hid_layer'], init_states_dict['out_layer']
        output_list = []  # output spikes of every timestep
        for tt in range(spike_ts):
            input_spikes = spike_data[:, :, :, :, tt]

            # Convolution and pooling, then flatten per sample.
            conv_out = self.layer2(self.layer1(input_spikes))
            conv_out = conv_out.reshape(conv_out.size(0), -1)
            conv_out = self.drop_out(conv_out)

            # Call the cells as modules (not .forward) so nn.Module hooks run.
            hidden_layer_spikes, hidden_state = self.hidden_cell(conv_out, hidden_state)
            output_layer_spikes, out_state = self.output_cell(hidden_layer_spikes, out_state)

            output_list.append(output_layer_spikes)

        # Spike count of each output neuron over the whole simulation.
        return torch.stack(output_list, 0).sum(0)

In [None]:
class WrapSNN(nn.Module):
    """Wrapper around SNN that creates zeroed per-layer initial states for every batch."""

    def __init__(self, input_dim, output_dim, hidden_dim, conv1_dim, conv2_dim, param_dict, device):
        """
        Args:
            input_dim (int): input dimension, forwarded to SNN
            output_dim (int): number of output neurons
            hidden_dim (int): number of hidden neurons
            conv1_dim (int): output channels of convolutional layer 1
            conv2_dim (int): output channels of convolutional layer 2
            param_dict (dict): neuron parameter dictionary, forwarded to SNN
            device (device): device on which the initial states are allocated
        """
        super(WrapSNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.conv1_dim = conv1_dim
        self.conv2_dim = conv2_dim
        self.device = device
        self.snn = SNN(input_dim, output_dim, hidden_dim, conv1_dim, conv2_dim, param_dict)

    def _zero_state(self, batch_size, width):
        """Return a fresh (spike, voltage) pair of zero tensors of shape (batch_size, width)."""
        spike = torch.zeros(batch_size, width, device=self.device)
        volt = torch.zeros(batch_size, width, device=self.device)
        return spike, volt

    def forward(self, spike_data):
        """
        Args:
            spike_data (Tensor): spike data input; dim 0 is the batch, the last dim is time
        Returns:
            output: number of spikes of output layer
        """
        n_batch = spike_data.shape[0]
        n_steps = spike_data.shape[-1]
        init_states_dict = {
            'hid_layer': self._zero_state(n_batch, self.hidden_dim),
            'out_layer': self._zero_state(n_batch, self.output_dim),
        }
        return self.snn(spike_data, init_states_dict, n_batch, n_steps)

In [None]:
def img_2_event_img(image, device, spike_ts):
    """
    Rate-code a batch of intensity images into random spike trains.

    At every timestep a pixel spikes when its intensity is strictly greater than a fresh
    uniform [0, 1) sample, so a pixel of value p in [0, 1] spikes with probability p per step
    (values >= 1 always spike, values <= 0 never do). Inputs are presumably in [0, 1]
    (e.g. from ToTensor) - confirm for other pipelines.

    Args:
        image (Tensor): images of shape (batch, channels, height, width); height and width
            may differ
        device (device): device (can be either CPU or GPU) on which the output is created
        spike_ts (int): number of timesteps
    Returns:
        event_image: float 0/1 tensor of shape (batch, channels, height, width, spike_ts)
    """
    # Move to the target device (the result must be kept - Tensor.to is not in-place)
    # and add a trailing time axis for broadcasting against the random thresholds.
    image = image.to(device).unsqueeze(-1)

    # One independent uniform threshold per pixel and timestep.
    random_image = torch.rand((*image.shape[:-1], spike_ts), device=device)

    # A spike wherever the intensity beats its threshold.
    event_image = (image > random_image).float()

    return event_image

In [None]:
def get_sample_dataset(dataset, desired_labels, sample_number):
    """
    Keep the first 'sample_number' samples of each label in desired_labels, in dataset order.

    Args:
        dataset: indexable dataset whose items are (data, label) pairs with hashable labels
        desired_labels (dict or None): the keys are the labels to keep; None keeps every label
        sample_number (int): maximum number of samples to keep for each label
    Returns:
        sample_dataset (torch.utils.data.Subset): subset of dataset containing only the selected
            samples, in their original order
    """
    counts = dict()
    sample_indices = []
    # Number of desired labels still below their quota; lets the scan stop early.
    unfilled = None if desired_labels is None else len(desired_labels)

    if sample_number > 0 and unfilled != 0:
        for i, item in enumerate(dataset):
            label = item[1]
            if desired_labels is not None and label not in desired_labels:
                continue
            seen = counts.get(label, 0)
            if seen >= sample_number:
                continue
            counts[label] = seen + 1
            sample_indices.append(i)
            if unfilled is not None and seen + 1 == sample_number:
                unfilled -= 1
                if unfilled == 0:
                    break  # every desired label has its samples

    return torch.utils.data.Subset(dataset, sample_indices)

In [None]:
def _emnist_rotate(img):
    """Rotate a PIL image by -90 degrees (first step of the EMNIST preprocessing)."""
    return torchvision.transforms.functional.rotate(img, -90)


def _emnist_transform(noise=None):
    """
    Build the EMNIST preprocessing pipeline.

    It rotates by -90 degrees, flips horizontally, converts to a tensor and, when noise is
    truthy, applies noise last. Module-level callables are used instead of lambdas so the
    pipeline can be pickled into DataLoader worker processes on spawn-based platforms.
    """
    steps = [_emnist_rotate, torchvision.transforms.functional.hflip, torchvision.transforms.ToTensor()]
    if noise:
        steps.append(noise)
    return torchvision.transforms.Compose(steps)


def _dataset_labels(dataset):
    """
    Return the label of every item of dataset, in order.

    torchvision's MNIST-family datasets (EMNIST included) keep labels in a targets tensor.
    Reading it avoids loading and transforming every image. Other datasets fall back to
    iterating over their (data, label) items.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is None:
        return [item[1] for item in dataset]
    return targets.tolist() if torch.is_tensor(targets) else list(targets)


def stbp_snn_training(network, desired_labels, spike_ts, device, batch_size=128, test_batch_size=256,
                      epoch=100, sample_number=200, noise=None, confusion_start_epoch=30):
    """
    STBP SNN training on a subset of EMNIST-balanced classes, with a test pass after every epoch.

    The loss is MSE between output firing rates and one-hot targets, optimized with Adam.

    Args:
        network (SNN): STBP learning SNN mapping event images to output spike counts
        desired_labels (dict): EMNIST label -> subset label (0 .. len(desired_labels) - 1);
            only samples whose label is a key are used
        spike_ts (int): spike timestep
        device (device): device used for training and testing
        batch_size (int): batch size for training
        test_batch_size (int): batch size for testing
        epoch (int): number of epochs
        sample_number (int): maximum number of training samples kept per label
        noise (callable or None): extra transform applied to training tensors only
        confusion_start_epoch (int): the confusion matrix is printed on a new best test accuracy
            only from this epoch on
    Returns:
        train_loss_list: mean training loss of each epoch
        test_accuracy_list: overall test accuracy of each epoch (as 0-dim tensors)
        test_accuracy_by_letter: dict subset label -> list of per-epoch test accuracies
    """
    # Load the EMNIST dataset (downloaded into ./data on first use).
    data_path = './data/'
    os.makedirs(data_path, exist_ok=True)
    train_full = torchvision.datasets.EMNIST(root=data_path, split="balanced", train=True, download=True,
                                             transform=_emnist_transform(noise))
    test_full = torchvision.datasets.EMNIST(root=data_path, split="balanced", train=False, download=True,
                                            transform=_emnist_transform())

    # Keep only the desired labels.
    train_labels = _dataset_labels(train_full)
    train_dataset = torch.utils.data.Subset(
        train_full, [i for i, y in enumerate(train_labels) if y in desired_labels])
    test_labels = _dataset_labels(test_full)
    test_indices = [i for i, y in enumerate(test_labels) if y in desired_labels]
    test_dataset = torch.utils.data.Subset(test_full, test_indices)

    # Train on at most sample_number samples per label.
    sample_dataset = get_sample_dataset(train_dataset, desired_labels, sample_number)
    train_dataloader = DataLoader(sample_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size,
                                 shuffle=False, num_workers=4)

    num_classes = len(desired_labels)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3, betas=[0.9, 0.999])

    # Lists for saving loss and accuracy.
    train_loss_list, test_accuracy_list = [], []
    test_num = len(test_dataset)

    # Number of test samples per subset label, and per-epoch accuracy for each.
    test_num_by_letter = dict()
    for i in test_indices:
        subset_label = desired_labels[test_labels[i]]
        test_num_by_letter[subset_label] = test_num_by_letter.get(subset_label, 0) + 1
    test_accuracy_by_letter = {v: [] for v in desired_labels.values()}

    network.to(device)
    # Best test accuracy so far; a new best triggers printing the confusion matrix.
    max_test_accuracy = float('-inf')
    for ee in range(epoch):
        # ---- Training ----
        network.train()
        running_loss = 0.0
        running_batch_num = 0
        train_start = time.time()
        for images, labels in train_dataloader:
            subset_labels = torch.tensor([desired_labels[x.item()] for x in labels])
            # One-hot targets, compared by MSE against each output neuron's firing rate.
            targets = nn.functional.one_hot(subset_labels, num_classes).float().to(device)
            images = images.to(device)
            event_images = img_2_event_img(images, device, spike_ts)

            optimizer.zero_grad()
            # Spike counts -> firing rates.
            outputs = network(event_images) / spike_ts
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_batch_num += 1

        train_end = time.time()
        train_loss_list.append(running_loss / max(running_batch_num, 1))
        print("Epoch %d Training Loss %.4f" % (ee, train_loss_list[-1]), end=" ")

        # ---- Testing ----
        # eval() disables the network's dropout; without it, conv features would be
        # randomly dropped at test time as well.
        network.eval()
        test_correct_num = 0
        test_correct_by_letter = defaultdict(int)
        # confusion_matrix[predicted][true]
        confusion_matrix = [[0] * num_classes for _ in range(num_classes)]

        test_start = time.time()
        with torch.no_grad():
            for images, labels in test_dataloader:
                subset_labels = torch.tensor([desired_labels[x.item()] for x in labels]).to(device)
                images = images.to(device)
                event_images = img_2_event_img(images, device, spike_ts)

                outputs = network(event_images)
                # Predicted class = output neuron with the most spikes.
                _, predicted = torch.max(outputs, 1)

                test_correct_num += torch.sum(torch.eq(predicted, subset_labels).float())
                for p, t in zip(predicted.tolist(), subset_labels.tolist()):
                    confusion_matrix[p][t] += 1
                    if p == t:
                        test_correct_by_letter[p] += 1

        test_end = time.time()
        test_accuracy_list.append(test_correct_num / test_num)
        for v in desired_labels.values():
            n_v = test_num_by_letter.get(v, 0)
            test_accuracy_by_letter[v].append(test_correct_by_letter[v] / n_v if n_v else float('nan'))

        print("Test Accuracy %.4f Training Time: %.1f Test Time: %.1f" % (
            test_accuracy_list[-1], train_end - train_start, test_end - test_start))
        for k, v in desired_labels.items():
            print("Test Accuracy for %.f : %.4f" % (k, test_accuracy_by_letter[v][-1]))
        if test_accuracy_list[-1] > max_test_accuracy and ee >= confusion_start_epoch:
            max_test_accuracy = test_accuracy_list[-1]
            print(f"The confusion matrix:\n {confusion_matrix}")

    print("End Training")
    network.to('cpu')
    return train_loss_list, test_accuracy_list, test_accuracy_by_letter

In [None]:
# Define the device on which training will be performed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so weight init, dropout, spike encoding and shuffling are repeatable
# (numpy too, in case the noise transforms draw from it).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Input dimension: number of pixels of one 28x28 EMNIST image.
input_dim = 28 * 28

# Output dimension: one output neuron per selected letter.
output_dim = 5

# Hidden layer dimension.
hidden_dim = 400

# Output channels of the two convolutional layers.
conv1_dim = 32
conv2_dim = 64

# Neuron parameters for the hidden and output layers, each [vdecay, vth, grad_win, grad_amp].
param_dict = {'hid_layer': [0.5, 0.5, 0.5, 1.0], 'out_layer': [0.5, 0.5, 0.5, 1.0]}

# Create the SNN (WrapSNN builds the initial neuron states for every batch).
network = WrapSNN(input_dim, output_dim, hidden_dim, conv1_dim, conv2_dim, param_dict, device)

# Number of SNN timesteps per image.
spike_ts = 30

# Batch sizes for training and testing.
batch_size = 256
test_batch_size = 128

# Number of epochs.
epoch = 40

# Maximum number of training samples per letter.
sample_number = 800

# Keys are the EMNIST-balanced labels to keep (a, b, d, g and q); values are their new
# labels in the subset.
desired_labels = {36: 0, 37: 1, 38: 2, 41: 3, 44: 4}

# To train with noise, build a transform from the `noises` module
# (e.g. noise = noises.AddGaussianBlur(5, 0.2)) and pass it as noise=noise below.
train_loss_list, test_accuracy_list, test_accuracy_by_letter = stbp_snn_training(
    network, desired_labels, spike_ts, device,
    batch_size=batch_size, test_batch_size=test_batch_size,
    epoch=epoch, sample_number=sample_number)

In [None]:
# Plot the per-letter and overall test accuracy of every epoch.

# Characters of the EMNIST "balanced" split, indexed by label (36 -> 'a', 41 -> 'g', ...).
EMNIST_BALANCED_CHARS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabdefghnqrt'

epochs = range(0, epoch)
# The size must be given when the figure is created; changing rcParams after plotting
# does not resize a figure that already exists.
fig, ax = plt.subplots(figsize=(10, 10))
for emnist_label, subset_label in desired_labels.items():
    ax.plot(epochs, test_accuracy_by_letter[subset_label], '-o',
            label=f'Label {EMNIST_BALANCED_CHARS[emnist_label]}')
# float() accepts both 0-dim tensors and plain numbers.
ax.plot(epochs, [float(x) for x in test_accuracy_list], '-o', label='Overall Accuracy')
ax.set_xticks(np.arange(0, epoch, 1.0))
ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
ax.yaxis.set_major_formatter(PercentFormatter(1))
ax.set_title('Testing Accuracy At Each Epoch')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()