# Learning Goals

In Assignment 1, we studied how information is represented by a single spiking neuron. In this assignment, you will learn how to construct networks of spiking neurons for a given cognitive task, how to propagate information through a network, and the intuition behind network design choices.

Let's first import all the libraries required for this assignment.

In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt
import pickle

# Question 1: From single neuron to network of neurons
## 1a.
What computational advantages do networks of neurons offer when compared against information processing by a single neuron? In other words, why do we need networks of neurons? 

## Answer 1a.
Single neuron models, such as the Hodgkin-Huxley (HH) and Leaky Integrate-And-Fire (LIF), offer computational advantages by allowing us to observe how action potentials (impulses) are generated and propagated in neurons. The Hodgkin-Huxley model is a mathematical representation that accurately captures the biophysical features of real neurons. By incorporating voltage-gated ion channels and considering their dynamics, this model provides a detailed description of how changes in membrane conductance lead to action potential generation. On the other hand, the Leaky Integrate-And-Fire model, while less detailed, is computationally efficient and easier to implement in simulations, particularly for neural networks, due to its simplicity.

While single neuron models have their advantages, networks of neurons also offer distinct computational benefits. Networks of neurons are capable of handling more complex computations than single neurons. By connecting multiple neurons into networks, it becomes feasible to perform a wide array of information processing tasks. In fact, networks can encode and process information using patterns of activity distributed across many neurons. This distributed representation enables the network to capture complex relationships and features present in the input data. Additionally, networks of neurons can be trained with learning algorithms to perform tasks such as classification (for example softmax regression, something I *tried* to implement in my Intro to AI class!).
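
As a small illustration of the claim that networks can compute things a single neuron cannot, here is a hedged sketch (not part of the assignment). XOR of two binary inputs is not linearly separable, so a single threshold unit cannot compute it, but a two-layer network of threshold units can. The helper `threshold_unit` and the hand-picked weights and thresholds are assumptions made for this example.

```python
import numpy as np

def threshold_unit(inputs, weights, threshold):
    """Fire (1) when the weighted input reaches the threshold, else stay silent (0)."""
    return int(np.dot(inputs, weights) >= threshold)

def xor_network(x):
    """Two hidden units (OR and AND) feeding one output unit."""
    h_or = threshold_unit(x, np.array([1.0, 1.0]), 1.0)   # fires if any input is active
    h_and = threshold_unit(x, np.array([1.0, 1.0]), 2.0)  # fires only if both are active
    # The output fires when OR is active but AND is not.
    return threshold_unit(np.array([h_or, h_and]), np.array([1.0, -1.0]), 1.0)

for x in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print(x, "->", xor_network(np.array(x)))
```

The hidden layer re-represents the input as a pattern across two units, and only that distributed pattern makes the output linearly separable.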

## 1b. 
Describe the algorithm for the information flow through a network of spiking leaky-integrate-and-fire (LIF) neurons. Specifically, trace out the steps required to compute network output from a given (continuous-valued) inputs. The algorithm should describe how continuous-valued inputs are fed to the SNN input layer, how the layer activations are computed, and how the output layer activity is decoded. Also, provide a diagrammatic overview of the algorithm to aid your explanation. You are free to assume any network size, and input and output dimensions. 

## Answer 1b.
We consider a basic feedforward spiking neural network with one input layer, one hidden layer, and one output layer. The steps involved in computing the network output from continuous-valued inputs are as follows:

**Input Layer:**
- Continuous-valued inputs are encoded into spike trains that represent the input signals to the network. This encoding can be achieved using techniques such as rate coding, where the firing rate of neurons in the input layer represents the *magnitude* of the input values, or temporal coding, where spike *timing* encodes the input values.

**Hidden Layers:**
- The encoded spike trains are fed into the input layer of the network.
- The input layer neurons integrate incoming spikes over time according to the leaky-integrate-and-fire dynamics. Each neuron computes its membrane potential based on the incoming spikes and the leakage of charge over time.
- Once the membrane potential of a neuron crosses a certain threshold, it generates an action potential (or spike) and resets its membrane potential to a resting state.
- The spikes generated by neurons in the input layer propagate to neurons in the hidden layer, where the process of integration and spike generation repeats.
- This propagation then reaches the output layer.

**Output Layer:**
- Neurons in the output layer integrate incoming spikes from the preceding layer and produce output spikes according to their membrane potential dynamics.
- The spike trains generated by neurons in the output layer are decoded to obtain continuous-valued output values. This decoding can be achieved using techniques such as spike counting (rate decoding), population decoding, or temporal decoding.

Diagram:
`Continuous-valued Inputs -> Input Encoding -> Input Layer (LIF Neurons) -> Hidden Layers (LIF Neurons) -> Output Layer (LIF Neurons) -> Output Decoding -> Continuous-valued Output`

Each layer of the network consists of LIF neurons that integrate incoming spikes and generate output spikes based on their membrane potential dynamics.
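
The per-timestep computation in each layer can also be written compactly. The equations below are a sketch under stated assumptions: a discrete-time LIF model, a weight matrix $W^{(l)}$ for the connection into layer $l$, a decay factor $\lambda$, a threshold $v_{th}$, reset to zero after a spike, and spike-count decoding over $T$ timesteps. The symbols are notation chosen here for illustration.

$$
\begin{aligned}
I^{(l)}[t] &= W^{(l)}\, s^{(l-1)}[t] \\
v^{(l)}[t] &= \lambda\, \tilde{v}^{(l)}[t-1] + I^{(l)}[t] \\
s^{(l)}[t] &= \mathbb{1}\left[\, v^{(l)}[t] \ge v_{th} \,\right] \\
\tilde{v}^{(l)}[t] &= v^{(l)}[t]\,\left(1 - s^{(l)}[t]\right) \\
y &= \sum_{t=1}^{T} s^{(L)}[t]
\end{aligned}
$$

Here $s^{(0)}[t]$ is the spike encoding of the input at timestep $t$, $\tilde{v}$ is the voltage after reset, and $y$ is the decoded output of the last layer $L$.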

# Question 2: Elements of Constructing Feedforward Networks
In this exercise, you will implement the two fundamental components of a feedforward spiking neural network: i) layers of neurons and ii) connections between those layers
## 2a. 
As the first step towards creating an SNN, we will create a class that defines a layer of LIF neurons. The layer object creates a collection of LIF neurons and applies input current to it (also called psp_input for postsynaptic input) to produce the collective spiking output of the layer. 

Below is the class definition for a layer of LIFNeurons. Fill in the components to define the layer. 

In [2]:
class LIFNeurons:
    """ Define Leaky Integrate-and-Fire Neuron Layer """

    def __init__(self, dimension, vdecay, vth):
        """
        Args:
            dimension (int): Number of LIF neurons in the layer.
            vdecay (float): voltage decay of LIF neurons.
            vth (float): voltage threshold of LIF neurons.
        
        This function is complete. You do not need to do anything here.
        """
        self.dimension = dimension
        self.vdecay = vdecay
        self.vth = vth

        # Initialize LIF neuron states.
        self.volt = np.zeros(self.dimension)
        self.spike = np.zeros(self.dimension)
    
    def __call__(self, psp_input):
        """
        Args:
            psp_input (ndarray): Synaptic input current at a single timestep. The shape of this is same as the number of neurons in the layer. 
        Return:
            self.spike: Output spikes from the layer. The shape of this should be the same as the number of neurons in the layer. 
        
        Write the expressions for updating the voltage and generating the spikes for the layer given psp_input at one timestep. 
        """
        # Update the voltage.
        self.volt = self.vdecay * self.volt + psp_input
        
        # Generate the spikes from the voltage.
        self.spike = (self.volt >= self.vth).astype(int)
        
        # Reset the voltage if the neuron spikes.
        self.volt[self.volt >= self.vth] = 0
        
        return self.spike

To verify the correctness of your class implementation, create a layer of neurons using the class definition above, and pass through it random inputs. 

In [3]:
# Create a layer of neurons using the class definition above. You can pick any parameter values for the neurons. 
dimension = 15 # Layer contains 15 LIF neurons.
vdecay = 0.75
vth = 0.75
layer = LIFNeurons(dimension, vdecay, vth)

# Create random input spikes with any probability and print them.
# Numpy random.choice function might be useful here. 
input_random_spikes = np.random.choice(2, dimension)
print("Input spikes:", input_random_spikes)

# Propagate the random input spikes through the layer and print the output.
output_random_spikes = layer(input_random_spikes)
print("Output spikes:", output_random_spikes)

Input spikes: [1 0 1 0 0 1 0 1 1 0 1 1 1 0 0]
Output spikes: [1 0 1 0 0 1 0 1 1 0 1 1 1 0 0]
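
In the check above, every input spike has value 1, which already exceeds the threshold of 0.75, so the output simply copies the input and the leak never comes into play. The following hedged sketch drives one neuron with a constant sub-threshold input over several timesteps to show integration, firing, and reset. `lif_step` is a hypothetical stand-alone function that mirrors the update in `LIFNeurons.__call__`; the input value 0.5 and the parameters are arbitrary choices.

```python
import numpy as np

def lif_step(volt, psp_input, vdecay, vth):
    """One LIF timestep: leak, integrate, threshold, reset. Returns (new_volt, spikes)."""
    volt = vdecay * volt + psp_input
    spike = (volt >= vth).astype(int)
    volt = np.where(spike == 1, 0.0, volt)  # reset only the neurons that fired
    return volt, spike

vdecay, vth = 0.5, 0.75
volt = np.zeros(1)
spike_train = []
for _ in range(6):
    volt, spike = lif_step(volt, np.array([0.5]), vdecay, vth)
    spike_train.append(int(spike[0]))

print(spike_train)  # [0, 1, 0, 1, 0, 1]
```

The neuron needs two timesteps of input to reach threshold (0.5, then 0.5 * 0.5 + 0.5 = 0.75), fires, resets to zero, and repeats.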


## 2b.
Now, we will create a class that defines the connection between a presynaptic layer and a postsynaptic layer. To create the connection, we need the activity of the presynaptic layer (also called presynaptic layer activation) and the weight matrix connecting the presynaptic and postsynaptic neurons. The output of the class should be the current for the postsynaptic layer.

Below is the class definition for Connections. Fill in the components to create the connections. 

In [4]:
class Connections:
    """ Define connections between spiking neuron layers """

    def __init__(self, weights, pre_dimension, post_dimension):
        """
        Args:
            pre_dimension (int): Number of neurons in the presynaptic layer.
            post_dimension (int): Number of neurons in the postsynaptic layer.
            weights (ndarray): Connection weights of shape post_dimension x pre_dimension.

        This function is complete. You do not need to do anything here.

        """
        self.weights = weights
        self.pre_dimension = pre_dimension
        self.post_dimension = post_dimension
    
    def __call__(self, spike_input):
        """
        Args:
            spike_input (ndarray): Spikes generated by the pre-synaptic neurons.
        Return:
            psp: Current for the post-synaptic neurons.
        
        Write the operation for computing psp.
        """
        # Compute psp given spike_input and self.weights.
        psp = np.dot(spike_input, np.transpose(self.weights))
        return psp

To verify the correctness of your class implementation, create a connection object and compute the postsynaptic current for random presynaptic activation inputs and random connection weights. You can pick arbitrary values for class arguments. 

In [5]:
# Define the dimensions of the presynaptic layer in a variable.
presynaptic_dimension = 15

# Define the dimensions of the postsynaptic layer in a variable.
postsynaptic_dimension = 5

# Create random presynaptic inputs with any probability.
# Numpy random choice function might be useful here.
presynaptic_input_spikes = np.random.choice(2, presynaptic_dimension)
print("Input spikes:", presynaptic_input_spikes)

# Create a random connection weight matrix.
# Numpy random rand function might be useful here.
weight_matrix = np.random.rand(postsynaptic_dimension, presynaptic_dimension)
print("Weight matrix:\n", weight_matrix)

# Initialize a connection object using the Connection class definition and pass the variables created above as arguments.
connection = Connections(weight_matrix, presynaptic_dimension, postsynaptic_dimension)

# Compute the current for the postsynaptic layer when the connection object is fed random presynaptic activation inputs.
postsynaptic_current = connection(presynaptic_input_spikes)
print("Postsynaptic current:", postsynaptic_current)

# Print the shape of the current.
print("Postsynaptic current shape:", postsynaptic_current.shape)

Input spikes: [1 0 0 0 1 0 0 0 1 1 0 1 0 1 1]
Weight matrix:
 [[0.26391067 0.55366929 0.3168378  0.13303788 0.90770484 0.6312124
  0.29965567 0.22770978 0.83435587 0.64517841 0.13746944 0.83242934
  0.38821789 0.37660779 0.09378695]
 [0.00250425 0.96725963 0.10562083 0.26514974 0.17059024 0.41420219
  0.63261446 0.68923838 0.19734805 0.81045031 0.59879075 0.97231647
  0.08918538 0.03986981 0.78913102]
 [0.58828671 0.21071273 0.72032937 0.58988139 0.1796041  0.72278894
  0.25409031 0.24071129 0.3490788  0.98012182 0.39217624 0.84857443
  0.38422023 0.81411904 0.63597258]
 [0.71034475 0.0302767  0.47419059 0.02722505 0.70278506 0.00731493
  0.53194681 0.05236041 0.83843408 0.34066595 0.04489991 0.78290076
  0.71728313 0.2438638  0.23712024]
 [0.86013299 0.05630427 0.89600343 0.96446268 0.3866694  0.88353065
  0.08458162 0.76968445 0.86356813 0.13837095 0.25429938 0.68780056
  0.83820123 0.87912276 0.03760556]]
Postsynaptic current: [3.95397386 2.98221016 4.39575748 3.85611463 3.85327036]
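
Because the presynaptic activity is binary, the postsynaptic current of each neuron is just the sum of the weights coming from the presynaptic neurons that fired. The sketch below checks this on a small hand-written example; the weight values are made up for illustration.

```python
import numpy as np

weights = np.array([[0.1, 0.2, 0.3],
                    [0.4, 0.5, 0.6]])   # shape: post (2) x pre (3)
spikes = np.array([1, 0, 1])            # presynaptic neurons 0 and 2 fired

psp = weights @ spikes                  # same result as np.dot(spikes, weights.T)
column_sum = weights[:, spikes == 1].sum(axis=1)

print(psp)                              # approximately [0.4, 1.0]
print(np.allclose(psp, column_sum))     # True
```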


# Question 3: Constructing Feedforward SNN
Now that you have implemented the basic elements of an SNN- layer and connection, you are all set to implement a fully functioning SNN. The SNN that you will implement here consists of an input layer, a hidden layer, and an output layer. 

Below is the class definition of an SNN. Your task is to create the layers and connections that form the network using the class definitions in Question 2. Then complete the function to propagate a given input through the network and decode network output. 

In [6]:
class SNN:
    """ Define a Spiking Neural Network with One Hidden Layer """
    def __init__(self, input_2_hidden_weight, hidden_2_output_weight, 
                 input_dimension=784, hidden_dimension=256, output_dimension=10,
                 vdecay=0.5, vth=0.5, snn_timestep=20):
        """
        Args:
            input_2_hidden_weight (ndarray): weights for connection between input and hidden layer. dimension should be hidden_dimension x input_dimension. 
            hidden_2_output_weight (ndarray): weights for connection between hidden and output layer. dimension should be output dimension x hidden dimension. 
            input_dimension (int): number of neurons in the input layer
            hidden_dimension (int): number of neurons in the hidden layer
            output_dimension (int): number of neurons in the output layer
            vdecay (float): voltage decay of LIF neurons
            vth (float): voltage threshold of LIF neurons
            snn_timestep (int): number of timesteps for simulating the network (also called inference timesteps)
        """
        self.snn_timestep = snn_timestep
        
        # Create the hidden layer.
        self.hidden_layer = LIFNeurons(hidden_dimension, vdecay, vth)
        
        # Create the output layer.
        self.output_layer = LIFNeurons(output_dimension, vdecay, vth)
        
        # Create the connection between input and hidden layer.
        self.input_2_hidden_connection = Connections(input_2_hidden_weight, input_dimension, hidden_dimension)
        
        # Create the connection between hidden and output layer.
        self.hidden_2_output_connection = Connections(hidden_2_output_weight, hidden_dimension, output_dimension)
    
    def __call__(self, spike_encoding):
        """
        Args:
            spike_encoding (ndarray): spike encoding of input
        Return:
            output: decoded output from the network
        """
        # Initialize an array to store the decoded network output for all neurons in the output layer.
        spike_output = np.zeros(self.output_layer.dimension)
                                                    
        # Loop through the simulation timesteps and process the input at each timestep tt.
        for tt in range(self.snn_timestep):
            # Propagate the input through the input to hidden layer and compute current for hidden layer.
            hidden_layer_current = self.input_2_hidden_connection(spike_encoding[tt])
           
            # Compute hidden layer spikes.
            hidden_layer_spikes = self.hidden_layer(hidden_layer_current)
            
            # Propagate hidden layer spikes to the output layer and compute current for output layer.
            output_layer_current = self.hidden_2_output_connection(hidden_layer_spikes)
            
            # Compute output layer spikes.
            output_layer_spikes = self.output_layer(output_layer_current)
            
            # Decode spike outputs by summing them up.
            spike_output = spike_output + output_layer_spikes
            
        return spike_output

To verify the correctness of your class implementation, define the arguments to initialize the SNN. Then initialize the SNN and pass through it random inputs and compute network outputs. You can pick arbitrary values for class arguments. 

In [7]:
# Define the dimensions of the input layer in a variable.
input_dim = 15

# Define the dimensions of the hidden layer in a variable.
hidden_dim = 10

# Define the dimensions of the output layer in a variable.
output_dim = 5

# Define vdecay in a variable.
vdecay = 0.5

# Define vth in a variable.
vth = 0.5

# Define snn_timesteps in a variable.
snn_timestep = 20

# Create random input to hidden layer weights.
# Numpy random rand function might be useful here.
input_2_hidden_weight = np.random.rand(hidden_dim, input_dim)

# Create random hidden to output layer weights.
# Numpy random rand function might be useful here.
hidden_2_output_weight = np.random.rand(output_dim, hidden_dim)

# Create random spike inputs to the network.
# Numpy random choice function might be useful here
spike_inputs = np.random.choice(2, (snn_timestep, input_dim))

# Print the inputs.
print('Input:\n', spike_inputs)

# Create an SNN object using the class definition and variables defined above.
SNN_object = SNN(input_2_hidden_weight, hidden_2_output_weight,
                 input_dim, hidden_dim, output_dim,
                 vdecay, vth, snn_timestep)

# Pass the random spike inputs through the SNN and print the output of the SNN.
output = SNN_object(spike_inputs)
print('Output:\n', output)

Input:
 [[1 0 1 0 0 1 0 1 0 1 1 0 1 1 0]
 [1 1 1 0 1 1 1 0 1 1 1 0 1 1 1]
 [0 1 1 0 0 0 0 1 0 1 1 0 1 0 1]
 [0 1 1 1 0 0 0 0 1 0 0 1 0 1 0]
 [0 0 0 1 0 1 0 1 1 0 1 1 1 0 1]
 [0 0 1 1 1 0 0 1 0 1 1 1 1 0 0]
 [1 0 1 0 1 0 1 1 0 0 1 1 1 1 0]
 [0 0 1 1 1 0 1 1 0 0 1 1 0 1 0]
 [1 0 0 0 1 0 1 0 1 0 0 1 0 0 1]
 [0 0 1 1 1 1 1 1 0 0 0 0 1 1 1]
 [1 1 1 1 1 0 0 1 0 0 0 0 1 0 1]
 [1 1 0 0 0 1 1 0 0 0 1 1 0 1 0]
 [0 0 0 0 1 1 0 1 0 0 1 0 0 0 0]
 [0 1 1 0 0 0 0 1 1 1 0 0 1 0 0]
 [1 1 1 1 1 0 1 0 1 0 0 1 1 1 1]
 [1 0 0 1 1 0 0 0 1 0 0 0 0 0 0]
 [1 0 1 0 0 1 0 1 0 1 1 0 0 0 1]
 [1 1 0 1 1 0 0 1 0 1 0 0 1 0 0]
 [0 1 1 0 0 1 1 0 0 0 0 1 0 0 1]
 [1 0 1 0 1 1 1 0 1 1 0 0 1 1 1]]
Output:
 [20. 20. 20. 20. 20.]


# Question 4: SNN for Classification of Digits
So far we have learnt how to construct SNNs for random inputs. In this exercise, you will use your implementation of SNNs to classify real-world data, taking the dataset of handwritten digits as an example. The dataset is provided as numpy arrays in the folder "data". Each sample in the MNIST dataset is a 28x28 image of a digit and a label (between 0 and 9) of that image. We will be dealing with batches, which means that we will read a fixed number of samples from the dataset (also called the batch size).

## 4a. 
First, we need to write two helper functions- to read the data from the saved data files, and to convert an image into spikes. The function to read the data is already written for you. You need to complete the function for encoding the data into spikes. 

In [8]:
def read_numpy_mnist_data(save_root, num_sample):
    """
    Read saved numpy MNIST data
    Args:
        save_root (str): path to the folder where the MNIST data is saved
        num_sample (int): number of samples to read
    Returns:
        image_list: list of MNIST image
        label_list: list of corresponding labels
    
    This function is complete. You do not need to do anything here.
    """
    image_list = np.zeros((num_sample, 28, 28))
    label_list = []
    for ii in range(num_sample):
        image_label = pickle.load(open(save_root + '/' + str(ii) + '.p', 'rb'))
        image_list[ii] = image_label[0]
        label_list.append(image_label[1])

    return image_list, label_list

def img_2_event_img(image, snn_timestep):
    """
    Transform image to spikes, also called an event image
    Args:
        image (ndarray): image of shape batch_size x 28 x 28
        snn_timestep (int): spike timestep
    Returns:
        event_image: event image- spike encoding of the image
        
    Complete the expression for converting the image to spikes (event image)
    """
    
    # Reshape the image. Do not touch this code.
    batch_size = image.shape[0]
    image_size = image.shape[2]
    image = image.reshape(batch_size, image_size, image_size, 1)
    
    # Generate a random image of the shape batch_size x image_size x image_size x snn_timestep.
    # Numpy random rand function will be useful here.
    random_image = np.random.rand(batch_size, image_size, image_size, snn_timestep)
    
    # Generate the event image.
    temp = random_image.transpose(0, 3, 1, 2).reshape(batch_size, snn_timestep, -1)
    event_image = (image.reshape(batch_size, -1)[:, np.newaxis, :] > temp).astype(int)
    
    return event_image

To verify the correctness of your function implementation, load sample digits from the saved files and convert them into event images. Then print the shape of the event images.

In [9]:
# Load 1000 samples from the MNIST dataset using the read function defined above.
image_list, label_list = read_numpy_mnist_data("data/mnist_test", 1000)

# Print the shape of the data.
print(image_list.shape)

# Convert the images to event images.
event_image_list = img_2_event_img(image_list, snn_timestep)

# Print the shape of the event image.
print(event_image_list.shape)

(1000, 28, 28)
(1000, 20, 784)
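
The encoding above is a form of rate coding: comparing each pixel against fresh uniform random numbers makes the pixel fire with a probability equal to its intensity. The sketch below checks this on four made-up intensities. It assumes the values are scaled to [0, 1]; `rate_encode` is a hypothetical helper and not the assignment's `img_2_event_img`.

```python
import numpy as np

def rate_encode(intensities, num_steps, rng):
    """Bernoulli rate coding: each value in [0, 1] is the per-step spike probability."""
    random_draws = rng.random((num_steps, intensities.size))
    return (intensities[np.newaxis, :] > random_draws).astype(int)

rng = np.random.default_rng(0)
intensities = np.array([0.0, 0.25, 0.5, 1.0])
spikes = rate_encode(intensities, num_steps=10_000, rng=rng)

print(spikes.shape)         # (10000, 4)
print(spikes.mean(axis=0))  # empirical firing rates, close to the intensities
```

With only 20 timesteps, as used in this assignment, the empirical rate is a much noisier estimate of the intensity, which is one reason the number of inference timesteps matters.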


## 4b. 
Next, we need another helper function to compute the classification accuracy of the network. The classification accuracy is defined as the percentage of the samples that the network classifies correctly. To compute the classification accuracy, you need to:

- Propagate each input through the network and obtain the network output.
- Based on the network output, the class of the image is the one for which the output neuron has maximum value. Let's call this predicted class. 
- Compare the predicted class against the true class. 
- Compute accuracy as the percentage of correct predictions. 
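
The steps above can be sketched on toy data before wiring them to the network. The spike counts and labels below are hypothetical values chosen for illustration.

```python
import numpy as np

# Hypothetical spike counts: 4 samples x 3 output neurons.
spike_counts = np.array([[3, 0, 1],
                         [0, 5, 2],
                         [1, 1, 4],
                         [2, 2, 0]])
true_labels = np.array([0, 1, 2, 1])

predicted = np.argmax(spike_counts, axis=1)   # ties resolve to the lowest index
accuracy = np.mean(predicted == true_labels)

print(predicted)  # [0 1 2 0]
print(accuracy)   # 0.75
```

Note the last sample: neurons 0 and 1 tie, and `np.argmax` returns the first maximum, so ties can silently count as wrong (or right) predictions.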

Below is the function for computing the test accuracy. The function takes in as arguments the SNN, directory in which the MNIST data is saved, and the number of samples to take from the MNIST dataset. Your task is to use the helper functions created above to load the data, convert into event images, and then compute network prediction and accuracies. 

In [10]:
def test_snn_with_mnist(network, data_save_dir, data_sample_num):
    """
    Test SNN with MNIST test data
    Args:
        network (SNN): defined SNN network
        data_save_dir (str): directory for the test data
        data_sample_num (int): number of test data examples
    """
    # Read image and labels using the read function.
    test_image_list, test_label_list = read_numpy_mnist_data(data_save_dir, data_sample_num)
    
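    # Convert the images to event images, using the network's own number of timesteps
    # (rather than relying on a global variable that may have changed).
    test_event_image_list = img_2_event_img(test_image_list, network.snn_timestep)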
    
    # Initialize number of correct predictions to 0.
    correct_prediction = 0
    
    # Loop through the test images.
    for ii in range(data_sample_num):
        # Compute network output for each image.
        # You might have to reshape the image using Numpy reshape function so that it is appropriate for the SNN.
        network_output = network(test_event_image_list[ii])
        
        # Determine the class of the image from the network output.
        # Numpy argmax function might be useful here.
        classes = np.argmax(network_output)
        
        # Compare the predicted class against true class and update correct_prediction counter.
        if classes == test_label_list[ii]:
            correct_prediction += 1
    
    # Compute test accuracy.
    test_accuracy = correct_prediction / data_sample_num
    print(correct_prediction, '/', data_sample_num)
    
    return test_accuracy

# Question 5: Tuning Membrane Properties for Correct Classification 
Great! We have everything that we need to measure the performance of the SNN for classification of MNIST digits. For this, we first need to create the SNN using the class definition we wrote in Q.3. Then we need to call the test function that we wrote in Q.4b. However, note that the SNN needs the connection weights between the layers as inputs. These weights are typically obtained as a result of "training" the network for a given task (such as MNIST classification). However, since training the network isn't a part of this assignment, we provide to you already trained weights. 

## 5a. 
Your task in this exercise is to initialize an SNN with vdecay=1.0 and vth=0.5. Test the SNN on MNIST dataset and obtain the classification accuracy.  

In [11]:
# Load the weights. Do not touch this code.
snn_param_dir = 'save_models/snn_bptt_mnist_train.p'
snn_param_dict = pickle.load(open(snn_param_dir, 'rb'))
input_2_hidden_weight = snn_param_dict['weight1']
hidden_2_output_weight = snn_param_dict['weight2']

# Define a variable for vdecay.
vdecay = 1.0

# Define a variable for vth.
vth = 0.5

# Create the SNN using the class definition in Q3 and the variables defined above.
input_dimension = 784
hidden_dimension = 256
output_dimension = 10
snn_timestep = 20

SNN_object2 = SNN(input_2_hidden_weight, hidden_2_output_weight,
              input_dimension, hidden_dimension, output_dimension, vdecay, vth, snn_timestep)

# Compute test accuracy for the SNN on 1000 examples from MNIST dataset and print it.
test_accuracy = test_snn_with_mnist(SNN_object2, "data/mnist_test", 1000)
print(test_accuracy)

241 / 1000
0.241


What could be a possible reason for the poor accuracy?

## Answer 5a. 
Referring back to the `__init__` method of the `SNN` class, the parameter vdecay is the decay factor for the membrane potential (voltage) of the neurons in the network. Its default value of 0.5 means that half of the voltage leaks away at each timestep. With vdecay = 1.0, however, the membrane potential does not decay at all: each neuron keeps all of its accumulated voltage and behaves as a perfect integrator rather than a leaky one. Even weak inputs then add up until they cross the threshold, so neurons likely fire far more often than they presumably did when the weights were trained. A probable consequence is that the spike counts of the output neurons become hard to tell apart, which would explain the poor accuracy.
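
The effect of removing the leak can be seen on a single neuron. The following hedged sketch counts the spikes of one LIF neuron driven by a constant weak input; `count_spikes` and the input value 0.2 are assumptions made for this illustration, not part of the assignment code.

```python
def count_spikes(vdecay, vth=0.5, input_current=0.2, num_steps=20):
    """Simulate one LIF neuron under constant input and count its spikes."""
    volt, spikes = 0.0, 0
    for _ in range(num_steps):
        volt = vdecay * volt + input_current
        if volt >= vth:
            spikes += 1
            volt = 0.0  # reset after a spike
    return spikes

print(count_spikes(vdecay=1.0))  # 6
print(count_spikes(vdecay=0.5))  # 0
```

Without a leak the weak input accumulates and the neuron fires every third step; with vdecay = 0.5 the voltage levels off below the threshold and the neuron stays silent.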

## 5b. 
Can you tune the membrane properties (vdecay and vth) to obtain higher classification accuracies?

In [12]:
# Write your implementation of Question 5b here.
vdecay = 0.4
vth = 0.5

input_dimension = 784
hidden_dimension = 256
output_dimension = 10
snn_timestep = 20

SNN_object3 = SNN(input_2_hidden_weight, hidden_2_output_weight,
              input_dimension, hidden_dimension, output_dimension, vdecay, vth, snn_timestep)

test_accuracy2 = test_snn_with_mnist(SNN_object3, "data/mnist_test", 1000)
print(test_accuracy2)

# Student Note: I kept getting dimension errors, and I believe I fixed them.
# But, just in case, if you want to try other vdecay values here, please do [Restart Kernel and Run All Cells] :(

975 / 1000
0.975


## 5c.
Based on your response to Questions 5a and 5b, can you explain how membrane properties affect network activity for classification?

## Answer 5c.
Membrane properties in SNNs can have significant effects on network activity, which in turn affects the network's ability to perform classification tasks. Specifically, the decay factor (vdecay) controls how much of a neuron's membrane potential is retained from one timestep to the next. A higher decay factor (close to 1) means that the membrane potential leaks slowly, so the neuron integrates input spikes over a longer period; at vdecay = 1.0 there is no leak at all, and in 5a the accuracy was only 0.241. A lower decay factor (closer to 0) means that the membrane potential leaks quickly, so the neuron responds mainly to recent, strong input; with vdecay = 0.4 in 5b the accuracy rose to 0.975. The voltage threshold (vth) sets how much accumulated input a neuron needs before it spikes, so together with vdecay it scales the overall firing rate. Essentially, membrane properties determine spiking behavior, so they must be matched to the trained weights for the SNN to classify correctly.
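
A short worked derivation makes the interaction between the two parameters concrete. It is a sketch assuming a single neuron that starts at zero voltage, receives a constant input $I > 0$, and follows the update $v[t] = \lambda\, v[t-1] + I$, where $\lambda$ stands for vdecay. Before the first spike,

$$
v[t] = I \sum_{k=0}^{t-1} \lambda^{k} = I\,\frac{1 - \lambda^{t}}{1 - \lambda}
\quad (0 \le \lambda < 1),
\qquad
\lim_{t \to \infty} v[t] = \frac{I}{1 - \lambda}.
$$

The neuron can therefore fire only if $\frac{I}{1-\lambda} > v_{th}$. For $\lambda = 1$ the sum becomes $v[t] = t\,I$, which crosses any threshold after about $v_{th}/I$ steps. For example, with $I = 0.2$ and $v_{th} = 0.5$, a decay of $\lambda = 0.5$ gives a limit of $0.4$, so the neuron never fires, whereas $\lambda = 1$ makes it fire every third timestep.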