# Learning Goals

In Assignment 1, we studied how information is represented by a single spiking neuron. In this assignment, you will learn how to construct networks of spiking neurons for a given cognitive task, how to propagate information through a network, and how to develop intuition for network design choices.

Let's first import all the libraries required for this assignment.

In [25]:
import math
import numpy as np
import matplotlib.pyplot as plt
import pickle

# Question 1: From single neuron to network of neurons
## 1a.
What computational advantages do networks of neurons offer compared with information processing by a single neuron? In other words, why do we need networks of neurons?

## Answer 1a.
Double click to enter your response to Question 1a here

## 1b. 
Describe algorithmically the information flow through a network of spiking leaky integrate-and-fire (LIF) neurons. Specifically, trace out the steps required to compute the network output from given continuous-valued inputs. The algorithm should describe how continuous-valued inputs are fed to the SNN input layer, how the layer activations are computed, and how the output layer activity is decoded. Also, provide a diagrammatic overview of the algorithm to aid your explanation. You are free to assume any network size and any input and output dimensions.

## Answer 1b.
Double click to enter your response to Question 1b here

# Question 2: Elements of Constructing Feedforward Networks
The fundamental components of a feedforward network are layers of neurons and connections between those layers. In this exercise, you will implement these two fundamental components of a feedforward spiking neural network.
## 2a. 
As the first step towards creating an SNN, we will create a class that defines a layer of LIF neurons. The layer object holds the state of a collection of LIF neurons; on each call, it applies the synaptic inputs to update their membrane voltages and returns the spiking output of the whole layer.

In [18]:
class LIFNeurons:
    """ Define Leaky Integrate-and-Fire Neuron Layer """

    def __init__(self, dimension, vdecay, vth):
        """
        Args:
            dimension (int): Number of LIF neurons in the layer
            vdecay (float): voltage decay of LIF neurons
            vth (float): voltage threshold of LIF neurons
        
        This function is complete. You do not need to do anything here.
        """
        self.dimension = dimension
        self.vdecay = vdecay
        self.vth = vth

        # Initialize LIF neuron states
        self.volt = np.zeros(self.dimension)
        self.spike = np.zeros(self.dimension)
    
    def __call__(self, psp_input):
        """
        Args:
            psp_input (ndarray): synaptic inputs 
        Return:
            self.spike: output spikes from the layer
        
        Write the expressions for updating the voltage and generating the spike. 
        """
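        # Leak the voltage; neurons that spiked on the previous step are reset to zero
        self.volt = self.vdecay * self.volt * (1. - self.spike) + psp_input
        # A neuron spikes when its voltage exceeds the threshold
        self.spike = (self.volt > self.vth).astype(float)
        return self.spike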
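The update in `__call__` is the discrete-time LIF rule $v_t = \text{vdecay}\, v_{t-1}(1 - s_{t-1}) + I_t$, $s_t = \mathbb{1}[v_t > v_{th}]$: the voltage leaks by a factor `vdecay`, is reset to zero on the step after a spike, and integrates the new input. The standalone sketch below traces a single neuron under a constant input so that the leak, threshold crossing, and reset are visible. The helper `lif_trace` and the parameter values are hypothetical, chosen only for illustration.

```python
import numpy as np


def lif_trace(inputs, vdecay, vth):
    """Hypothetical helper: simulate one LIF neuron with the same rule as LIFNeurons.__call__."""
    volt, spike = 0.0, 0.0
    volts, spikes = [], []
    for psp in inputs:
        # Leak the previous voltage; a spike on the previous step resets it to zero
        volt = vdecay * volt * (1.0 - spike) + psp
        # Emit a spike when the voltage strictly exceeds the threshold
        spike = float(volt > vth)
        volts.append(volt)
        spikes.append(spike)
    return np.array(volts), np.array(spikes)


volts, spikes = lif_trace(np.full(6, 0.6), vdecay=0.5, vth=1.0)
print(spikes)  # [0. 0. 1. 0. 0. 1.]
```

With these values, the voltage rises through 0.6, 0.9, and 1.05, crosses `vth`, is reset, and the cycle repeats, so the neuron fires on every third step. A single input of 0.6 is below threshold; the spike comes from input accumulating across timesteps, which is exactly what `vdecay` controls.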

To verify the correctness of your class implementation, create a layer of neurons using the class definition above and pass random inputs through it.

In [5]:
# Create a layer of neurons using the class definition above
layer = LIFNeurons(10, 0.3, 0.75)

# Create random input spikes and print them
in_spikes = np.random.choice([0, 1], 10, p=[0.6, 0.4])
print('Inputs: ', in_spikes)

# Propagate the random input spikes through the layer and print the output
out = layer(in_spikes)
print('Output: ', out)

Inputs:  [0 1 0 0 0 1 0 0 1 0]
Output:  [0. 1. 0. 0. 0. 1. 0. 0. 1. 0.]


## 2b.
Now we will create a class that defines the connections between spiking neuron layers. A connection object is constructed with the weight matrix between the pre- and postsynaptic layers; when called with the spikes of the presynaptic layer, it returns the postsynaptic potentials (PSPs) that drive the postsynaptic layer.

In [19]:
class Connections:
    """ Define connections between spiking neuron layers """

    def __init__(self, weights, pre_dimension, post_dimension):
        """
        Args:
            weights (ndarray): connection weights, shape (post_dimension, pre_dimension)
            pre_dimension (int): dimension for pre-synaptic neurons
            post_dimension (int): dimension for post-synaptic neurons
        """
        self.weights = weights
        self.pre_dimension = pre_dimension
        self.post_dimension = post_dimension
    
    def __call__(self, spike_input):
        """
        Args:
            spike_input (ndarray): spikes generated by the pre-synaptic neurons
        Return:
            psp: postsynaptic potentials (input to the postsynaptic layer), shape (post_dimension,)
        """
        psp = np.matmul(self.weights, spike_input)
        return psp
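Because `weights` has shape `(post_dimension, pre_dimension)`, `np.matmul(self.weights, spike_input)` gives each postsynaptic neuron the sum of the weights from the presynaptic neurons that spiked on the current timestep. The toy weights and spikes below are hypothetical and only illustrate that equivalence.

```python
import numpy as np

# Hypothetical weights for 3 presynaptic -> 2 postsynaptic neurons, shape (post, pre)
weights = np.array([[0.2, -0.5, 1.0],
                    [0.7,  0.3, 0.0]])
pre_spikes = np.array([1.0, 0.0, 1.0])  # presynaptic neurons 0 and 2 spiked

# Matrix-vector product, as in Connections.__call__
psp = np.matmul(weights, pre_spikes)

# Same result written out: sum each row's weights over the neurons that spiked
psp_explicit = weights[:, pre_spikes == 1].sum(axis=1)

print(psp)  # [1.2 0.7]
```

A negative weight such as the -0.5 above acts as inhibition: had presynaptic neuron 1 spiked, it would have lowered the first postsynaptic potential.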

To verify the correctness of your class implementation, create a connection object and compute the postsynaptic potentials for random presynaptic spikes and random connection weights.

In [9]:
# Define the dimension of the presynaptic layer
pre_dim = 10

# Define the dimension of the postsynaptic layer
post_dim = 100

# Create random presynaptic input spikes
in_spikes = np.random.choice([0, 1], pre_dim, p=[0.6, 0.4])

# Create a random connection weight matrix of shape (post_dim, pre_dim)
conn_weights = np.random.rand(post_dim, pre_dim)

# Initialize a Connections object with the variables created above
conn = Connections(conn_weights, pre_dim, post_dim)

# Compute the postsynaptic potentials for the random presynaptic spikes
psp = conn(in_spikes)

# Print the shape of the postsynaptic potentials
print(psp.shape)

(100,)


# Question 3: Constructing Feedforward SNN
Now that you have implemented the basic elements of an SNN (layers and connections), you are all set to implement a fully functioning SNN with one hidden layer.

In [20]:
class SNN:
    """ Define a Spiking Neural Network with One Hidden Layer """

    def __init__(self, input_2_hidden_weight, hidden_2_output_weight, 
                 input_dimension=784, hidden_dimension=256, output_dimension=10,
                 vdecay=0.5, vth=0.5, snn_timestep=20):
        """
        Args:
            input_2_hidden_weight (ndarray): weights for connection between input and hidden layer
            hidden_2_output_weight (ndarray): weights for connection between hidden and output layer
            input_dimension (int): number of input neurons
            hidden_dimension (int): number of hidden neurons
            output_dimension (int): number of output neurons
            vdecay (float): voltage decay of LIF neuron
            vth (float): voltage threshold of LIF neuron
            snn_timestep (int): number of timesteps for inference
        """
        self.snn_timestep = snn_timestep
        self.hidden_layer = LIFNeurons(hidden_dimension, vdecay, vth)
        self.output_layer = LIFNeurons(output_dimension, vdecay, vth)
        self.input_2_hidden_connection = Connections(input_2_hidden_weight, input_dimension, hidden_dimension)
        self.hidden_2_output_connection = Connections(hidden_2_output_weight, hidden_dimension, output_dimension)
    
    def __call__(self, spike_encoding):
        """
        Args:
            spike_encoding (ndarray): input spike trains, shape (input_dimension, snn_timestep)
        Return:
            spike counts of the output neurons, summed over all timesteps
        """
        spike_output = np.zeros(self.output_layer.dimension)
        for tt in range(self.snn_timestep):
            input_2_hidden_psp = self.input_2_hidden_connection(spike_encoding[:, tt])
            hidden_spikes = self.hidden_layer(input_2_hidden_psp)
            hidden_2_output_psp = self.hidden_2_output_connection(hidden_spikes)
            output_spikes = self.output_layer(hidden_2_output_psp)
            spike_output += output_spikes
        return spike_output
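For reference, one call to `SNN.__call__` computes the following for $t = 0, \dots, T-1$, where $\mathbf{x}_t$ is column `tt` of `spike_encoding`, $T$ is `snn_timestep`, $W_1$ is `input_2_hidden_weight`, $W_2$ is `hidden_2_output_weight`, and $h$ and $o$ label the hidden and output layers:

$$
\begin{aligned}
\mathbf{v}^{h}_t &= \text{vdecay}\,\mathbf{v}^{h}_{t-1} \odot (1 - \mathbf{s}^{h}_{t-1}) + W_1 \mathbf{x}_t, &
\mathbf{s}^{h}_t &= \mathbb{1}\left[\mathbf{v}^{h}_t > v_{th}\right],\\
\mathbf{v}^{o}_t &= \text{vdecay}\,\mathbf{v}^{o}_{t-1} \odot (1 - \mathbf{s}^{o}_{t-1}) + W_2 \mathbf{s}^{h}_t, &
\mathbf{s}^{o}_t &= \mathbb{1}\left[\mathbf{v}^{o}_t > v_{th}\right],\\
\mathbf{c} &= \textstyle\sum_{t=0}^{T-1} \mathbf{s}^{o}_t .
\end{aligned}
$$

The returned `spike_output` is the count vector $\mathbf{c}$, so each entry lies between 0 and `snn_timestep`, and hidden spikes at step $t$ reach the output layer within the same step. The layer states start at zero when the layers are created in `__init__`; `__call__` does not reset them, so consecutive calls continue from the previous state.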

To verify the correctness of your class implementation, define the arguments needed to initialize the SNN, initialize it, and then pass random input spikes through it to compute the network output.

In [13]:
input_dim = 10
hid_dim = 100
out_dim = 10
vdecay = 0.5
vth = 0.5
snn_timestep = 20
# Weight shapes follow Connections: (postsynaptic dimension, presynaptic dimension)
in_2_hid_weight = np.random.rand(hid_dim, input_dim)
hid_2_out_weight = np.random.rand(out_dim, hid_dim)
spike_encoding = np.random.choice([0, 1], (input_dim, snn_timestep), p=[0.6, 0.4])
print('Inputs: ', spike_encoding)
snn = SNN(in_2_hid_weight, hid_2_out_weight, input_dimension=input_dim, hidden_dimension=hid_dim,
          output_dimension=out_dim, vdecay=vdecay, vth=vth, snn_timestep=snn_timestep)
out = snn(spike_encoding)
print('Outputs: ', out)

Inputs:  [[1 0 1 0 0 1 0 0 0 1 0 1 0 1 0 1 0 0 1 1]
 [0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 1 0 1 1 0]
 [0 1 0 1 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 1 0 1]
 [0 0 1 1 0 0 1 0 0 1 1 1 1 0 0 1 1 1 0 1]
 [0 1 1 0 0 0 0 0 0 0 1 0 0 1 0 1 1 1 0 0]
 [0 1 0 1 1 1 1 1 0 0 1 0 1 1 0 1 0 0 0 0]
 [1 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0]
 [0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 0 1 0 1 0]
 [1 0 1 0 0 0 1 0 0 1 0 0 1 1 0 1 1 0 0 0]]
Outputs:  [19. 19. 19. 19. 19. 19. 19. 19. 19. 19.]


# Question 4: SNN for Classification of Digits
So far, we have learned how to construct SNNs and run them on random inputs. In this exercise, you will use your SNN implementation to classify real-world data, using the MNIST dataset of handwritten digits as an example. The dataset is provided as pickled NumPy arrays in the `data` folder.

## 4a. 
First, we need to write two helper functions: one to read the data from the saved data files, and one to convert an image into spikes. The function to read the data is already written for you; you need to complete the function that encodes the data into spikes.
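Judging from its code, the reading helper `read_numpy_mnist_data` expects `save_root` to hold one pickle file per sample, named `0.p`, `1.p`, and so on, each containing an `(image, label)` pair whose image is a 28x28 array. The sketch below writes a few toy samples in that layout to a temporary directory and reads one back the same way. The writer `save_numpy_mnist_sample` and the random images are hypothetical, used only to illustrate the expected format.

```python
import os
import pickle
import tempfile

import numpy as np


def save_numpy_mnist_sample(save_root, index, image, label):
    """Hypothetical writer for the layout read_numpy_mnist_data expects."""
    with open(os.path.join(save_root, str(index) + '.p'), 'wb') as f:
        # Element 0 is the 28x28 image, element 1 is the integer label
        pickle.dump((image, label), f)


rng = np.random.default_rng(0)
save_root = tempfile.mkdtemp()
for ii in range(3):
    save_numpy_mnist_sample(save_root, ii, rng.random((28, 28)), ii)

# Read one file back the same way read_numpy_mnist_data does
with open(save_root + '/' + str(1) + '.p', 'rb') as f:
    image_label = pickle.load(f)
print(image_label[0].shape, image_label[1])  # (28, 28) 1
```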

In [21]:
def read_numpy_mnist_data(save_root, num_sample):
    """
    Read saved numpy MNIST data
    Args:
        save_root (str): directory containing the saved data files
        num_sample (int): number of samples to read
    Returns:
        image_list: array of MNIST images, shape (num_sample, 28, 28)
        label_list: list of labels
    """
    image_list = np.zeros((num_sample, 28, 28))
    label_list = []
    for ii in range(num_sample):
        # Each file <ii>.p holds an (image, label) pair; the with-block closes it after loading
        with open(save_root + '/' + str(ii) + '.p', 'rb') as f:
            image_label = pickle.load(f)
        image_list[ii] = image_label[0]
        label_list.append(image_label[1])

    return image_list, label_list

def img_2_event_img(image, snn_timestep):
    """
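    Transform images into event images using rate (Bernoulli) coding
    Args:
        image (ndarray): batch of square images, shape (batch_size, image_size, image_size)
        snn_timestep (int): number of spike timesteps
    Returns:
        event_image: spike trains, shape (batch_size, image_size, image_size, snn_timestep)
    """
    batch_size = image.shape[0]
    image_size = image.shape[2]
    # Add a trailing time axis so each pixel broadcasts across all timesteps
    image = image.reshape(batch_size, image_size, image_size, 1)
    # A pixel spikes at a timestep when a uniform random draw falls below its intensity
    random_image = np.random.rand(batch_size, image_size, image_size, snn_timestep)
    event_image = (random_image < image).astype(float)

    return event_image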
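`img_2_event_img` is a rate (Bernoulli) encoding: at every timestep, a pixel spikes when a fresh uniform random number falls below its intensity, so a pixel with intensity $p \in [0, 1]$ spikes with probability $p$ per step, and its firing rate over many steps approaches $p$. The check below applies the same comparison to a hypothetical 2x2 "image" with a deliberately large number of timesteps so the rates settle. It assumes intensities are scaled to [0, 1]; any value of 1 or above would spike on every step.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical batch of one 2x2 image, shape (1, 2, 2); intensities act as spike probabilities
image = np.array([[[0.0, 0.25],
                   [0.5, 1.0]]])
snn_timestep = 10000  # large only so the empirical rates settle

# Same comparison as img_2_event_img: one uniform draw per pixel per timestep
random_image = rng.random(image.shape + (snn_timestep,))
event_image = (random_image < image[..., np.newaxis]).astype(float)

# Fraction of timesteps on which each pixel spiked
rates = event_image.mean(axis=-1)
print(np.round(rates, 2))
```

Pixels at 0 never spike and pixels at 1 spike on every step; intermediate pixels fluctuate around their intensity with variance $p(1-p)/T$, so a short `snn_timestep` gives a noisier encoding.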

To verify your implementation, load a few sample digits from the saved files and convert them into event images. Then print the shapes of the images and of the event images.

In [15]:
data, labels = read_numpy_mnist_data("../numpy-mnist-snn/data/mnist_test", 5)
print(data.shape)
event_img = img_2_event_img(data, 20)
print(event_img.shape)

(5, 28, 28)
(5, 28, 28, 20)


## 4b. 
Next, we need another helper function to compute the classification accuracy of the network on a given number of MNIST samples. The function takes as arguments the SNN, the directory in which the data is saved, and the number of samples to take from the MNIST dataset. Your task is to use the helper functions created above to load the data, convert it into event images, and then compute the network's predictions and accuracy.

In [22]:
def test_snn_with_mnist(network, data_save_dir, data_sample_num):
    """
    Test SNN with MNIST test data
    Args:
        network (SNN): defined SNN network
        data_save_dir (str): directory for the test data
        data_sample_num (int): number of test samples
    Returns:
        test_accuracy (float): fraction of test samples classified correctly
    """
    test_image_list, test_label_list = read_numpy_mnist_data(data_save_dir, data_sample_num)
    test_event_image_list = img_2_event_img(test_image_list, network.snn_timestep)
    correct_prediction = 0
    for ii in range(data_sample_num):
        # Flatten the 28x28 pixel grid into 784 input neurons, keeping time as the last axis
        snn_output = network(test_event_image_list[ii].reshape(-1, network.snn_timestep))
        # Predict the class whose output neuron spiked most often
        pred_label = np.argmax(snn_output)
        if pred_label == test_label_list[ii]:
            correct_prediction += 1
    test_accuracy = correct_prediction / data_sample_num
    return test_accuracy
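The prediction step decodes the output spike counts with `np.argmax`, and accuracy is the fraction of predictions that match the labels. The spike counts and labels below are made up for illustration. Note that `np.argmax` returns the first index among tied maxima, so a sample whose output neurons all stay silent is always predicted as class 0.

```python
import numpy as np

# Hypothetical output spike counts for 3 samples x 3 classes, and their true labels
spike_counts = np.array([[0, 5, 2],
                         [7, 1, 0],
                         [0, 0, 0]])
labels = np.array([1, 0, 2])

# One prediction per sample: the output neuron with the most spikes
pred_labels = np.argmax(spike_counts, axis=1)  # ties resolve to the lowest index
accuracy = np.mean(pred_labels == labels)

print(pred_labels)  # [1 0 0]
print(accuracy)
```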

# Question 5: Tuning Membrane Properties for Correct Classification 
Great! We now have everything we need to measure the performance of the SNN on MNIST digit classification. First, we create the SNN using the class defined in Q3; then we call the test function written in Q4b. Note, however, that the SNN takes the connection weights between its layers as inputs. These weights are typically obtained by "training" the network for a given task (such as MNIST classification). Since training is not part of this assignment, we provide already-trained weights.
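Since `Connections` computes `np.matmul(weights, spike_input)`, the trained matrices must have shapes `(hidden_dimension, input_dimension)` and `(output_dimension, hidden_dimension)`, that is, (256, 784) and (10, 256) with the `SNN` defaults. The check below is a hypothetical helper, not part of the provided code; random arrays stand in for `snn_param_dict['weight1']` and `snn_param_dict['weight2']`.

```python
import numpy as np


def check_snn_weights(w1, w2, input_dimension=784, hidden_dimension=256, output_dimension=10):
    """Hypothetical sanity check: raise ValueError if the weights do not fit the SNN layout."""
    expected = {
        'input_2_hidden_weight': (w1, (hidden_dimension, input_dimension)),
        'hidden_2_output_weight': (w2, (output_dimension, hidden_dimension)),
    }
    for name, (w, shape) in expected.items():
        if w.shape != shape:
            raise ValueError(f"{name} has shape {w.shape}, expected {shape}")


# Random stand-ins for the trained weights loaded from the pickle file
check_snn_weights(np.random.rand(256, 784), np.random.rand(10, 256))  # passes silently

try:
    # A transposed input-to-hidden matrix is rejected
    check_snn_weights(np.random.rand(784, 256), np.random.rand(10, 256))
except ValueError as err:
    print(err)
```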

## 5a. 
Your task in this exercise is to initialize an SNN with vdecay=1.0 and vth=0.5, test it on the MNIST dataset, and obtain the classification accuracy.

In [23]:
snn_param_dir = '../numpy-mnist-snn/save_models/snn_bptt_mnist_train.p'
with open(snn_param_dir, 'rb') as f:
    snn_param_dict = pickle.load(f)
input_2_hidden_weight = snn_param_dict['weight1']
hidden_2_output_weight = snn_param_dict['weight2']
vdecay = 1.0
vth = 0.5
snn = SNN(input_2_hidden_weight, hidden_2_output_weight, vdecay=vdecay, vth=vth)
test_acc = test_snn_with_mnist(snn, '../numpy-mnist-snn/data/mnist_test', 1000)
print("Test Accuracy: ", test_acc)

Test Accuracy:  0.248


What could be a possible reason for the poor accuracy?

## Answer 5a. 
Double click to enter your response to Question 5a. here.

## 5b. 
Can you tune the membrane properties (vdecay and vth) to obtain higher classification accuracies?

In [24]:
#Write your implementation of Question 5b. here

## 5c.
Based on your response to Questions 5a. and 5b., can you explain how membrane properties affect network activity?

## Answer 5c.
Double click to enter your response to Question 5c. here