# Learning Goals
The goal of this assignment is to use the knowledge gained in the course to develop an end-to-end SNN for MNIST digit classification, trained with a state-of-the-art gradient-descent algorithm. Along the way, you will also learn the basics of developing any machine learning application: handling data with data loaders, defining and optimizing a loss, evaluating an algorithm on a validation set, and speeding up training with GPUs. You will also learn the basics of programming in PyTorch, one of the most widely used libraries for machine learning research, with applications such as autonomous driving, robotic control, and cancer research.

Let's import all the libraries required for this assignment. 

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, sampler
import torch.optim as optim
import os
import time
import pickle
import numpy as np

# Question 1: Limitation of backprop for SNN

## 1a. 
Sketch out the algorithm for training an SNN using backpropagation. Your algorithm should describe when and how the weights are updated. What is the main limitation of using backpropagation for training an SNN? Describe a solution to resolve the limitation. 

## Answer 1a. 
Double click to enter your response to Question 1a. here.

## 1b. 
In this exercise, we will implement the solution for overcoming the limitation of backprop for SNNs. First, the preliminaries: in PyTorch, arrays are called tensors. Gradients are computed automatically by the automatic differentiation package (autograd). The main elements of an autograd function are its forward and backward functions. The forward function performs the forward pass, i.e. it computes the output tensors from the input tensors. The backward function receives the gradient of some scalar value (typically the loss) w.r.t. the output tensors, and computes the gradient of the same scalar value w.r.t. the input tensors.
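To make the forward/backward contract concrete, here is a minimal sketch of a custom autograd function on a deliberately unrelated toy operation. The class name `Square` and the input values are hypothetical and chosen only for illustration; this is not the pseudo-gradient function you are asked to write.

```python
import torch


class Square(torch.autograd.Function):
    """Toy autograd function computing y = x ** 2 (illustration only)."""

    @staticmethod
    def forward(ctx, x):
        # Save the input so the backward pass can reuse it.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is d(loss)/dy; the chain rule gives d(loss)/dx = grad_output * dy/dx.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x


x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
loss = Square.apply(x).sum()   # scalar value to differentiate
loss.backward()                # calls Square.backward internally
grad_custom = x.grad.clone()   # gradient of the loss w.r.t. x, i.e. 2 * x
```

Note that `backward` returns one value per argument of `forward`; arguments that do not need a gradient get `None`.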

Below, we define the autograd function class for the pseudo-gradient using the rectangular function. Most of the implementation is already written for you. Your task is to fill in two key components: i) in the forward function, write the implementation for generating spike outputs from the inputs; ii) in the backward function, write the implementation for computing the gradient of the spike using the rectangular pseudo-gradient function.
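For reference, one common form of the rectangular pseudo-gradient is sketched below. This is an assumption drawn from the general surrogate-gradient literature, not a definition given in this assignment. In particular, whether the window extends `grad_win` or half of it on each side of the threshold is a convention you should check against the course material. Here $v$ is the membrane voltage, $s$ the spike output, $H$ the Heaviside step function, $v_{th}$ the threshold `vth`, and $w$ stands for `grad_win`.

$$
s = H(v - v_{th}), \qquad
\frac{\partial s}{\partial v} \approx
\begin{cases}
1 & \text{if } |v - v_{th}| < w \\
0 & \text{otherwise}
\end{cases}
$$

In the class below, this value is additionally scaled by the amplification factor `grad_amp`.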

In [None]:
class PseudoSpikeRect(torch.autograd.Function):
    """ Rectangular Pseudo-grad function """

    @staticmethod
    def forward(ctx, input, vth, grad_win, grad_amp):
        """
        Args:
            input (Torch Tensor): Input tensor containing voltages of neurons in a layer
            vth (Float): Voltage threshold for spiking 
            grad_win (Float): Window for computing pseudogradient
            grad_amp (Float): Amplification factor for the gradients
        
        Returns:
            output (Torch Tensor): Generated spikes for the input
        
        Write the operation for computing the output spikes from the input. The operation should be vectorized, i.e. no loops. 
        """
        
        #Saving variables for backward pass. Nothing to do here
        ctx.save_for_backward(input)
        ctx.vth = vth
        ctx.grad_win = grad_win
        ctx.grad_amp = grad_amp
        
        #Compute output from the input. No loops. Hint: Use Pytorch "greater than" function. 
        output = 
        
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Args:
            grad_output (Torch Tensor): Gradient of the loss w.r.t. the output spikes
        
        Returns:
            grad (Torch Tensor): Gradient of the loss w.r.t. the input voltages
        
        Write the operation for computing the rectangular pseudo-gradient of the spikes w.r.t. the input. The operation should be vectorized, i.e. no loops. 
        """
        
        #Retrieving variables from forward pass. Nothing to do here. 
        input, = ctx.saved_tensors
        vth = ctx.vth
        grad_win = ctx.grad_win
        grad_amp = ctx.grad_amp
        grad_input = grad_output.clone()
        
        #compute the gradient of the input using rectangular pseudograd function
        spike_pseudo_grad = 
        
        #Multiplying by gradient amplifier. Nothing to do here
        grad = grad_amp * grad_input * spike_pseudo_grad.float()
        return grad, None, None, None

# Question 2: Creating a Layer

In this exercise, we will create the class definition for a layer of LIF neurons (similar to the implementation in Assignment 2).

Below is the class definition of an LIF neuron layer. Your task is to write the operation for integrating the presynaptic spikes into voltage, and then transforming to spikes. 

In [None]:
class LinearIFCell(nn.Module):
    """ Leaky Integrate-and-fire neuron layer"""

    def __init__(self, psp_func, pseudo_grad_ops, param):
        """
        Args:
            psp_func (Torch Function): Function computing the post-synaptic potential (PSP) from pre-synaptic spikes
            pseudo_grad_ops (Torch Function): Pseudo-grad function
            param (tuple): Cell parameters (voltage decay, voltage threshold, gradient window, gradient amplitude)
        
        This function is complete. You do not need to do anything here. 
        """
        super(LinearIFCell, self).__init__()
        self.psp_func = psp_func
        self.pseudo_grad_ops = pseudo_grad_ops
        self.vdecay, self.vth, self.grad_win, self.grad_amp = param

    def forward(self, input_data, state):
        """
        Forward function
        Args:
            input_data (Tensor): input spike from pre-synaptic neurons
            state (tuple): output spike of last timestep and voltage of last timestep
        Returns:
            output: output spike
            state: updated neuron states
        
        Write the operation for integrating the presynaptic spikes into voltage.
        """
        pre_spike, pre_volt = state
        
        #Compute the voltage from the presynaptic inputs. This should be a vectorized operation. No loops. 
        volt = 
        
        #Compute the spike output by using the pseudo_grad_ops function. This should be a vectorized operation. No loops.
        output = 
        
        return output, (output, volt)

# Question 3: Creating a Network 

## 3a.
We will now create an SNN with one hidden layer using the class definitions above. Preliminaries: in Assignment 2, the post-synaptic potential (PSP) was computed using NumPy matrix multiplication of weights and inputs. In PyTorch, nn.Linear() achieves the same. You can find the documentation here: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html. We will use this class as the PSP function required for creating a layer according to the implementation in Q2.
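As a quick check of that equivalence, the sketch below compares `nn.Linear` with `bias=False` against an explicit matrix product. The layer sizes and the random input are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A bias-free linear layer mapping 4 inputs to 3 outputs (sizes are arbitrary).
layer = nn.Linear(in_features=4, out_features=3, bias=False)

x = torch.rand(2, 4)             # batch of 2 input vectors
out_linear = layer(x)            # shape (2, 3)

# nn.Linear stores its weight with shape (out_features, in_features),
# so the equivalent matrix product uses the transpose.
out_matmul = x @ layer.weight.T
```

Both results have shape `(batch_size, out_features)`, which is the shape a layer of neurons expects for its input current.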

Below is the class definition of a network. Your task is to fill in the required components in the init and forward functions. 

In [None]:
class SingleHiddenLayerSNN(nn.Module):
    """ SNN with single hidden layer """

    def __init__(self, input_dim, output_dim, hidden_dim, param_dict):
        """
        Args:
            input_dim (int): input dimension
            output_dim (int): output dimension
            hidden_dim (int): hidden layer dimension
            param_dict (dict): neuron parameter dictionary for each layer (voltage decay, voltage threshold, gradient window, gradient amplitude)
        
        Create the hidden and output layers using the layer implementation from Q2, with nn.Linear as the psp function. 
        """
        super(SingleHiddenLayerSNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        pseudo_grad_ops = PseudoSpikeRect.apply
        
        #Create the hidden layer. Assume that the hidden layer neuron parameters are in param_dict['hid_layer']. Set bias=False for nn.Linear.  
        self.hidden_cell = 
        
        
        #Create the output layer. Output layer params are in param_dict['out_layer']. Set bias=False for nn.Linear.   
        self.output_cell = 
        

    def forward(self, spike_data, init_states_dict, batch_size, spike_ts):
        """
        Forward function
        Args:
            spike_data (Tensor): spike data input (batch_size, input_dim, spike_ts)
            init_states_dict (dict): initial states for each layer- 'hid_layer' for hidden layer; 'out_layer' for output layer. 
            batch_size (int): batch size
            spike_ts (int): spike timesteps
        Returns:
            output: number of spikes of output layer
        
        Write the operations for propagating the input through the network and computing the spike outputs. 
        """
        hidden_state, out_state = init_states_dict['hid_layer'], init_states_dict['out_layer']
        spike_data_flatten = spike_data.view(batch_size, self.input_dim, spike_ts)
        output_list = [] #List to store the output at each timestep
        for tt in range(spike_ts):
            #Retrieve the input at time tt
            
            
            #Propagate through the hidden layer
            
            
            #Propagate through the output layer
            
            
            #Append output spikes to output list
            
        
        #Sum the outputs to compute the spike count for each output neuron. torch.stack and torch.sum might be useful here. No loops
        output = 
        
        return output

## 3b. 
Next, we need a wrapper class that: i) initializes the parameters required for creating the SNN class object; ii) creates the SNN class object using the initial parameters; iii) computes the SNN output and returns it. The class is already written for you. You do not need to do anything here. Just understand the implementation so that you can use it later. 

In [None]:
class WrapSNN(nn.Module):
    """ Wrapper of SNN """

    def __init__(self, input_dim, output_dim, hidden_dim, param_dict, device):
        """
        Args:
            input_dim (int): input dimension
            output_dim (int): output dimension
            hidden_dim (int): hidden layer dimension
            param_dict (dict): neuron parameter dictionary
            device (device): device
        """
        super(WrapSNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.device = device
        self.snn = SingleHiddenLayerSNN(input_dim, output_dim, hidden_dim, param_dict)

    def forward(self, spike_data):
        """
        Forward function
        Args:
            spike_data (Tensor): spike data input
        Returns:
            output: number of spikes of output layer
        """
        batch_size = spike_data.shape[0]
        spike_ts = spike_data.shape[-1]
        init_states_dict = {}
        # Hidden layer
        hidden_volt = torch.zeros(batch_size, self.hidden_dim, device=self.device)
        hidden_spike = torch.zeros(batch_size, self.hidden_dim, device=self.device)
        init_states_dict['hid_layer'] = (hidden_spike, hidden_volt)
        # Output layer
        out_volt = torch.zeros(batch_size, self.output_dim, device=self.device)
        out_spike = torch.zeros(batch_size, self.output_dim, device=self.device)
        init_states_dict['out_layer'] = (out_spike, out_volt)
        # SNN
        output = self.snn(spike_data, init_states_dict, batch_size, spike_ts)
        return output

# Question 4: Encoding MNIST into spikes
The following function converts a batch of MNIST images into spikes. You have already implemented it using NumPy in Assignment 2. The implementation remains the same, except that we will now use Torch tensors instead of NumPy arrays. Fill in the components to convert a batch of Torch tensors into spikes. Since the goal is to learn to write optimized PyTorch code, you must do this without any loops. Use vector operations instead. 
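Loop-free code of this kind usually relies on broadcasting, where a trailing dimension of size 1 is expanded to match the other operand. The sketch below illustrates the mechanism on hypothetical toy values with fixed thresholds. It is not MNIST data and not the encoding itself.

```python
import torch

# Hypothetical toy values, not MNIST data.
intensity = torch.tensor([0.0, 0.5, 1.0]).view(3, 1)   # shape (3, 1)
thresholds = torch.tensor([0.25, 0.75])                # shape (2,)

# Broadcasting expands both operands to shape (3, 2) before the
# element-wise comparison, so no explicit loop is needed.
above = (intensity > thresholds).float()
```

Each row of `above` compares one intensity against every threshold, which is why reshaping a tensor to end in a dimension of size 1 is a common first step before comparing it with a tensor that has an extra trailing axis.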

In [None]:
def img_2_event_img(image, device, spike_ts):
    """
    Transform image to event image
    Args:
        image (Tensor): image
        device (device): device (can be either CPU or GPU)
        spike_ts (int): spike timestep
    Returns:
        event_image: event image
    """
    batch_size = image.shape[0]
    channel_size = image.shape[1]
    image_size = image.shape[2]
    image = image.view(batch_size, channel_size, image_size, image_size, 1)
    
    #Create a random image of shape batch_size x channel_size x image_size x image_size x spike_ts. Torch rand function might be useful here. 
    #Remember to put the random image on the device specified in the function argument. 
    random_image = 
    
    #Generate event image using image and random image
    event_image = 

    return event_image

# Question 5: Training the SNN

In this exercise, we will write the function to train an SNN using spatiotemporal backprop (STBP). A typical training procedure works as follows:

1. Split the dataset into train and test. The network is trained on all the batches in the train dataset, and then validated on the test dataset. This gives us an idea of how well the network generalizes to unseen data. 

2. A criterion is defined to compute the loss. For classification tasks, this is generally the cross-entropy loss (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).

3. The network is initialized with random weights. 

4. Training is typically done on mini-batches of the training data: at every training iteration, the network receives one batch containing a small number of samples. The loss and the gradients used to update the weights are averaged over the samples in that batch. 

5. Within every training iteration, we first retrieve a batch of data and compute the network prediction. We then compute the loss by comparing the network prediction against the true labels. Next, the gradient of the loss with respect to all the network weights is computed. This gradient is then used to update the weights in the network.  
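The steps above can be sketched on a toy problem. This is a minimal illustration and not the assignment's solution: the linear model, the synthetic data, the learning rate, and the number of steps are all assumptions chosen so that the example runs quickly on a CPU, and it uses the full dataset as a single batch for simplicity.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Synthetic two-class data: the label depends on the sign of the first feature.
inputs = torch.randn(64, 2)
labels = (inputs[:, 0] > 0).long()

model = nn.Linear(2, 2)                             # randomly initialized weights
criterion = nn.CrossEntropyLoss()                   # loss definition
optimizer = optim.SGD(model.parameters(), lr=0.1)   # weight-update rule

loss_history = []
for step in range(20):
    optimizer.zero_grad()              # clear gradients from the previous step
    logits = model(inputs)             # network prediction
    loss = criterion(logits, labels)   # compare prediction against true labels
    loss.backward()                    # gradient of the loss w.r.t. all weights
    optimizer.step()                   # update the weights
    loss_history.append(loss.item())
```

In a real training run, the body of the loop would be executed once per mini-batch drawn from a data loader rather than on the same full batch each time.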

Thankfully, PyTorch provides APIs to automate most of the above steps. Below is a function that trains an SNN by implementing this procedure in PyTorch. Your task is to fill in the missing components. Refer to the comments for hints on which PyTorch functions to use.   

In [None]:
def stbp_snn_training(network, spike_ts, device, batch_size=128, test_batch_size=256, epoch=100):
    """
    STBP SNN training
    Args:
        network (SNN): STBP learning SNN
        spike_ts (int): spike timestep
        device (device): device
        batch_size (int): batch size for training
        test_batch_size (int): batch size for testing
        epoch (int): number of epochs
    Returns:
        train_loss_list: list of training loss for each epoch
        test_accuracy_list: list of test accuracy for each epoch
    """
    
    #Creating folder where MNIST data is saved. Loading the MNIST dataset. This code is complete. Do not touch it. 
    try:
        os.mkdir("./data")
        print("Directory data Created")
    except FileExistsError:
        print("Directory data already exists")
    
    data_path = './data/'
    train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=True,
                                               transform=transforms.ToTensor())
    test_dataset = torchvision.datasets.MNIST(root=data_path, train=False, download=True,
                                              transform=transforms.ToTensor())

    # Train and test dataloader
    
    #Given the train and test datasets, we need to create dataloaders to load the datasets in the right format. 
    #You can read about PyTorch Dataset and Dataloaders here: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
    #For now, this part is complete and you do not need to do anything here. 
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size,
                                 shuffle=False, num_workers=4)

    
    # Next we need to define a criterion for computing the loss. 
    # Refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html on how to define the cross entropy loss. 
    # Note that you just need to define the loss here (and not compute it) 
    
    criterion = 
    
    #Define an optimizer that will perform the weight updates. You can find more about the optimizers in PyTorch here: https://pytorch.org/docs/stable/optim.html
    #Taking the help of the documentation above, create an optimizer for stochastic gradient descent (SGD). 
    optimizer = 
    

    # List for saving loss and accuracy
    train_loss_list, test_accuracy_list = [], []
    test_num = len(test_dataset)

    # Start training
    
    #Put the network on the device (typically GPU but can also be cpu)
    network.to(device)
    
    #Loop for the epochs
    for ee in range(epoch):
        #Keep track of running loss
        running_loss = 0.0
        running_batch_num = 0
        train_start = time.time()
        
        #Iterate over the training data in train dataloader
        for data in train_dataloader:
            
            #Retrieve the image and label from data

            
            #Put the image and labels on the device

            
            #Convert images to event images

            
            #Before we backprop, we need to set the gradients for each tensor to zero. This is done using the zero_grad function in PyTorch

            
            #Compute the network output for the event images

            
            #Compute the loss using the criterion defined previously. Store in a variable called loss

            
            #Backpropagate the loss through the network. Use the PyTorch backward() function: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html

            
            #Update the network weights by taking an optimizer 'step'. 
            #You can learn how to do that here: https://pytorch.org/docs/stable/optim.html#taking-an-optimization-step

            
            #Updating tracking variables. Nothing to do here
            running_loss += loss.item()
            running_batch_num += 1
        train_end = time.time()
        train_loss_list.append(running_loss / running_batch_num)
        print("Epoch %d Training Loss %.4f" % (ee, train_loss_list[-1]), end=" ")
        
        
        #This ends one training epoch. After every epoch, we can evaluate how well the network does on data that it has not seen before. 
        #This step is called testing and is done on the test dataset. 
        
        #Counter to keep track of the number of correct predictions
        test_correct_num = 0
        test_start = time.time()
        with torch.no_grad():
            for data in test_dataloader:
                
                #Retrieve the image and label from test data

                
                #Put the image and labels on the device

                
                #Convert the image into event images

                
                #Compute the network predictions and store in a variable called outputs

                
                #Get the class label as the largest activation. This is complete. Nothing to do here.
                _, predicted = torch.max(outputs, 1)
                
                #Compare the network predictions against the true labels and update the counter for correct predictions. No loops. 

        
        #Updating tracking variables. Nothing to do here
        test_end = time.time()
        test_accuracy_list.append(test_correct_num / test_num)
        print("Test Accuracy %.4f Training Time: %.1f Test Time: %.1f" % (
            test_accuracy_list[-1], train_end - train_start, test_end - test_start))
    
    #Return the loss and accuracies. Nothing to do here. 
    print("End Training")
    network.to('cpu')
    return train_loss_list, test_accuracy_list

Now we have everything ready to train and test our SNN using backprop. All that is left to do is initialize the network with the right parameters and call the training function on it. 


In [None]:
# Define the device on which training will be performed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Define the input dimensions in a variable


#Define the output dimensions in a variable


#Define the hidden layer dimension in a variable


#Create a dictionary of the neuron parameters for the hidden and output layer. The keys should be 'hid_layer' and 'out_layer'.
#The values of the dictionary should be a list of the neuron parameters for each layer where the list elements are [vdecay, vth, grad_win, grad_amp]


#Define snn timesteps in a variable


#Create the SNN using the class definition in 3b and the arguments defined above


#Define the following training parameters
#Batch size for training


#Batch size for testing


#Epochs


#Train the snn using the above arguments and the definition in Q5. 


# Question 6: Concluding Remarks 
If you have been able to implement all the parts in Q1 to Q5, you can feel proud of yourself for having learned the basics of the deep learning pipeline that powers many modern AI applications. Of course, the exact input and network design vary for more complex applications, but the pipeline remains the same. In fact, we went beyond conventional deep learning and implemented a spiking deep learning algorithm, which is a less common skill. 

Now that you know the basic implementation, there are many directions you can pursue if you are interested in this research area. How do hyper-parameters such as the pseudo-gradient window affect the training process? How much training data do you need to achieve good performance? How does the amount of training data needed vary with the complexity of the task? How far can you push the spike encoding, i.e. what is the minimum number of timesteps you need to achieve good performance? 

You do not need to answer these questions in this assignment but it is something to think about if you are interested. 

**As the concluding question of the course, describe the three key lessons you learned from the course and explain why they are important.**

## Answer 6
Double click to enter your response to Question 6 here. 