# Project 1: Probing direction selectivity in the mouse retina

Welcome to the first project of the class. 

**You will learn to:** 
- Construct direction tuning curves from grating data.
- Quantify direction selectivity.
- Perform statistical comparison of paired samples.

Let's first import the packages we are going to use, and set up some plotting parameters.

In [None]:
# import packages
import numpy as np
import matplotlib.pyplot as plt
import spkbasic
from os import listdir
from os.path import isfile, join
from scipy import stats

%matplotlib inline
plt.rcParams['figure.figsize'] = (15.0, 6.0) # set default size of plots

## 1 - Creating tuning curves

The stimulus times and the spike times of 20 cells are in the folder ```data_drifting_grating```. Let's start by loading the stimulus times.

In [None]:
stimulus = np.loadtxt('data_drifting_grating/stimulus.txt')

In each cycle, gratings drifting in each of 8 directions are presented in turn. Each direction is shown for 400 frames, and the grating has a period of 100 frames. A pulse marks the beginning of each period except the first one, which is usually ignored in the analysis, so each direction contributes 3 pulses. Between directions, a gray screen is shown for 300 frames.

Let's first declare these values:

In [None]:
duration=400
period=100
Nangles=8
Ncycles=4

Now we can calculate some important values that will help us sort the spikes relative to the stimulus times.

In [None]:
angles=np.linspace(0,2*np.pi,num=Nangles,endpoint=False) # the angles of the drifting grating in radians
NanglePulses=int(np.floor(duration/period))-1 # pulses per direction (the first period has no pulse)
NcyclePulses=NanglePulses*Nangles # pulses per cycle
NstimPulses=NcyclePulses*Ncycles # pulses in the whole stimulus
pulseTimes=stimulus[0:int(NstimPulses)]
pulseTimes=np.reshape(pulseTimes,(Ncycles,Nangles,NanglePulses)) # cycle x direction x pulse
periodDuration=np.mean(np.diff(pulseTimes,n=1,axis=2)) # mean time between consecutive pulses
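
To see how the reshape lays out the pulses, here is a minimal sketch on synthetic pulse times. The numbers are made up and only the array layout matters. It assumes that the pulses are stored in presentation order: all pulses of one direction, then the next direction, and so on through the cycles.

```python
import numpy as np

# hypothetical toy sizes, smaller than the real stimulus
n_cycles, n_angles, n_pulses = 2, 4, 3

# fake pulse times, one per time unit, in presentation order
fake_pulses = np.arange(n_cycles * n_angles * n_pulses, dtype=float)

# same reshape as in the notebook:
# axis 0 = cycle, axis 1 = direction, axis 2 = pulse within a direction
toy = fake_pulses.reshape(n_cycles, n_angles, n_pulses)

print(toy.shape)   # (2, 4, 3)
print(toy[0, 1])   # second direction, first cycle: [3. 4. 5.]
print(toy[1, 0])   # first direction, second cycle: [12. 13. 14.]
```

Because the last axis holds the pulses of a single direction, taking ```np.diff``` along ```axis=2``` only compares pulses within one presentation and never spans the gray screen between directions.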

Having set up the stimulus times, we can now count spikes! Let's define a function that counts the spikes of a cell for each cycle and direction, and then load the spikes of an example cell.

In [None]:
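def calculateSpikeCounts(pulseTimes,spikeTimes):
    '''
    Inputs:
            pulseTimes: Ncycles x Nangles x NanglePulses np array of pulse times
            spikeTimes: np array of spike times of one cell
    Outputs:
            spikeCounts: Ncycles x Nangles np array of spike counts
    '''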
    periodDuration=np.mean(np.diff(pulseTimes,n=1,axis=2)) # get period duration
    pulseTimesShape=np.shape(pulseTimes)
    spikeCounts=np.zeros(pulseTimesShape[:2]) # pre-allocate the array
    
    for iCycle in range(0,pulseTimesShape[0]):
        for iDirection in range(0,pulseTimesShape[1]):
            dirSpikes=spikeTimes[spikeTimes>=pulseTimes[iCycle,iDirection,0]]
            dirSpikes=dirSpikes[dirSpikes<pulseTimes[iCycle,iDirection,-1]+periodDuration]
            spikeCounts[iCycle,iDirection]=np.size(dirSpikes)
    return spikeCounts
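
The same counts can also be obtained without the double loop. The following is a minimal sketch, not part of the course material: the helper name ```countSpikesSorted``` is hypothetical, and it relies on ```np.searchsorted```, which requires the spike times to be in ascending order (the sketch sorts them to be safe).

```python
import numpy as np

def countSpikesSorted(pulseTimes, spikeTimes):
    # hypothetical helper: same counting window as calculateSpikeCounts,
    # i.e. [first pulse, last pulse + one period) for every cycle and direction
    periodDuration = np.mean(np.diff(pulseTimes, n=1, axis=2))
    starts = pulseTimes[:, :, 0]
    stops = pulseTimes[:, :, -1] + periodDuration
    spikeTimes = np.sort(spikeTimes)  # searchsorted needs ascending input
    # (number of spikes < stop) minus (number of spikes < start)
    return (np.searchsorted(spikeTimes, stops, side='left')
            - np.searchsorted(spikeTimes, starts, side='left'))

# toy check: 1 cycle, 2 directions, pulses 1 time unit apart
toyPulses = np.array([[[0.0, 1.0, 2.0], [10.0, 11.0, 12.0]]])
toySpikes = np.array([0.5, 1.5, 2.9, 3.0, 10.0, 12.5])
toyCounts = countSpikesSorted(toyPulses, toySpikes)
print(toyCounts)  # [[3 2]]
```

The spike at 3.0 falls exactly on the end of the first window and is excluded, while the spike at 10.0 falls on the start of the second window and is included, matching the ```>=``` and ```<``` comparisons in the loop version.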

In [None]:
exampleSpikeTimes=np.loadtxt('data_drifting_grating/5_SP_C3601.txt')
exampleSpikeCounts=calculateSpikeCounts(pulseTimes,exampleSpikeTimes)
print(exampleSpikeCounts)
tuningCurve=np.mean(exampleSpikeCounts,axis=0)
plt.plot(np.rad2deg(angles),tuningCurve)
plt.title('Tuning Curve')
plt.ylabel('Spike count');
plt.xlabel('Direction (deg)');

**Exercise:** Plot the tuning curve in a polar plot by filling in ```plotTuningCurvePolar(angles, responses)```. Consider using ```plt.polar```, and try to make your plot pretty.

*Hint*: Make sure that your final curve is closed!

In [None]:
def plotTuningCurvePolar(angles,responses):
    '''
    Inputs:
            angles: np array of angles in radians
            responses: np array of responses, one per angle
    '''
    
    ### START CODE HERE ### (approx. 1-2 lines)
    plt.polar(np.append(angles,angles[0]),np.append(responses,responses[0]),linewidth=2)
    ### END CODE HERE ###

In [None]:
plotTuningCurvePolar(angles,tuningCurve)

Now we want to look at all of our data. We will load the spike trains of all cells in the dataset into a list, and you will have to figure out how to get structured spike counts for all cells.

In [None]:

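dataDir='data_drifting_grating' # assumption: every file here except stimulus.txt is one cell's spike train
pathList=sorted(join(dataDir,f) for f in listdir(dataDir) if isfile(join(dataDir,f)) and f!='stimulus.txt')
spikeTrainList=[np.loadtxt(spath) for spath in pathList]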

**Exercise:** Let's now calculate the tuning curves for all cells in our dataset. First, we need to convert the spike times of all cells into spike counts. Fill in the function that does that:

*Hint:* You can iterate over cells using the function ```calculateSpikeCounts```, but you can also start over!

In [None]:
def calculateAllSpikeCounts(pulseTimes,allSpikeTimes):
    '''
    Inputs:
            pulseTimes: Ncycles x Nangles x NanglePulses np array of pulse times
            allSpikeTimes: list of spike trains (in np arrays)
    Outputs:
            allSpikeCounts: Ncells x Ncycles x Nangles np array of spike counts
    '''
    
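    ### START CODE HERE ### (approx. 5-10 lines)
    Ncells=len(allSpikeTimes)
    Ncycles,Nangles=np.shape(pulseTimes)[:2]
    allSpikeCounts=np.zeros((Ncells,Ncycles,Nangles)) # pre-allocate the array
    
    for iCell in range(Ncells):
        allSpikeCounts[iCell]=calculateSpikeCounts(pulseTimes,allSpikeTimes[iCell])
    ### END CODE HERE ###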
    
    return allSpikeCounts

In [None]:
allSpikeCounts=calculateAllSpikeCounts(pulseTimes,spikeTrainList)
print(allSpikeCounts[2,:,:])

Expected output:

OK, now use the ```np.mean``` function along the right axis to get all the tuning curves!

In [None]:
allTuningCurves=np.mean(allSpikeCounts,axis=1) # average over cycles: Ncells x Nangles

By running the following cell, you can examine the tuning curve of the second cell. Change the indices to examine the tuning curves for different cells. Can you understand what all of them mean?

In [None]:
plotTuningCurvePolar(angles,allTuningCurves[1,:])

## 2 - Quantification of direction selectivity

### The direction selectivity index (DSI)

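The direction selectivity index (DSI) is a common measure of direction tuning. One way to calculate it is as the normalized magnitude of the vector sum of the responses:

$$ DSI = \frac{1}{\sum_{k=1}^{N}{r_{k}}} \left|\sum_{k=1}^{N}{r_{k}e^{i\phi_{k}}}\right|$$

Here $r_{k}$ is the response to the grating drifting in direction $\phi_{k}$, and $N$ is the number of directions. For non-negative responses, the DSI ranges from 0 (equal responses to all directions) to 1 (response to a single direction only).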

**Exercise:** Fill in ```calculateDSI(angles,responses)```. To help you: in Python, the imaginary unit is written as ```1j``` (and its multiples as 2j, 3j, ...). You can calculate the magnitude of a complex number with ```np.absolute()```.

In [None]:
def calculateDSI(angles,responses):
    '''
    Inputs:
            angles: np array of angles in radians
            responses: np array of responses, one per angle
    '''
    
    ### START CODE HERE ### (approx. 1-2 lines)
    vsum=np.sum(responses*np.exp(1j*angles))/np.sum(responses)
    dsi=np.absolute(vsum)
    ### END CODE HERE ###

    return dsi

In [None]:
print('DSI = ' + str(calculateDSI(angles, tuningCurve)))

Expected output: 0.31383540910149094
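
As a quick sanity check of the formula, the sketch below evaluates it on two made-up extreme tuning curves: one that is flat and one that responds to a single direction. The helper ```dsiSketch``` and the toy arrays are illustrative names, not part of the provided data.

```python
import numpy as np

def dsiSketch(angles, responses):
    # normalized magnitude of the response vector sum (same formula as above)
    return np.absolute(np.sum(responses * np.exp(1j * angles))) / np.sum(responses)

toyAngles = np.linspace(0, 2 * np.pi, num=8, endpoint=False)

flat = np.ones(8)      # responds equally to every direction
sharp = np.zeros(8)    # responds to a single direction only
sharp[2] = 5.0

dsiFlat = dsiSketch(toyAngles, flat)
dsiSharp = dsiSketch(toyAngles, sharp)
print(round(dsiFlat, 6), round(dsiSharp, 6))  # 0.0 1.0
```

For the flat curve, the eight unit vectors cancel out, so the DSI is zero up to floating-point error. For the sharp curve, the vector sum has the same length as the summed response, so the DSI is one.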

Good job! Now let's try to find the DSI values for all the cells we provided for you.

**Exercise:** Calculate the DSIs for all cells provided. Then run the cell below, and examine the histogram. 

*Hint:* Instead of using a for loop, you can use linear algebra to calculate the vector sums for all cells simultaneously!

In [None]:
def calculateDSIall(angles,multipleResponses):
    '''
    Inputs:
            angles: np array of angles in radians
            multipleResponses: Ncells x Nangles np array
    Outputs:
            dsiAll: np array of length Ncells with the DSI value of each cell
    '''
    
    ### START CODE HERE ### (approx. 1-2 lines)
    vsums=(multipleResponses@np.exp(1j*angles))/np.sum(multipleResponses,axis=1)
    dsiAll=np.absolute(vsums)
    ### END CODE HERE ###

    return dsiAll

In [None]:
plt.hist(calculateDSIall(angles,allTuningCurves))

### Monte Carlo permutation for creating DSI confidence intervals

Examine the tuning curve of the following cell. Although its DSI value is above 0.2, the cell barely responded to the stimulus, so the DSI alone does not tell us whether the tuning is reliable.

**Exercise:** Calculate a permutation distribution of DSIs for cell 1. Here, the vectorized function ```calculateDSIall()``` will definitely be useful!
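
One possible way to build such a distribution is sketched below. It is an assumption about the intended procedure, not the official solution: the null hypothesis is taken to be "spike counts do not depend on direction", so the counts of one cell are shuffled across all cycle and direction slots before the tuning curve is recomputed. The function name ```permutedDSIs``` and the synthetic Poisson counts are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible

def permutedDSIs(angles, spikeCounts, nPerm=1000):
    # spikeCounts: Ncycles x Nangles counts of one cell
    # assumption: under the null hypothesis every count may be
    # reassigned to any cycle/direction slot
    nCycles, nAngles = spikeCounts.shape
    flat = spikeCounts.ravel()
    # one shuffled copy of the counts per permutation
    shuffled = np.stack([rng.permutation(flat) for _ in range(nPerm)])
    shuffled = shuffled.reshape(nPerm, nCycles, nAngles)
    curves = shuffled.mean(axis=1)           # nPerm x Nangles tuning curves
    vsums = curves @ np.exp(1j * angles)     # one vector sum per permutation
    # the total count is preserved by shuffling, so this divisor is the same
    # for every permutation (and must be non-zero for the DSI to be defined)
    return np.absolute(vsums) / curves.sum(axis=1)

toyAngles = np.linspace(0, 2 * np.pi, num=8, endpoint=False)
toyCounts = rng.poisson(lam=2.0, size=(4, 8))  # synthetic, weakly responding, untuned cell
nullDSIs = permutedDSIs(toyAngles, toyCounts, nPerm=500)
print(nullDSIs.shape)  # (500,)
```

A histogram of ```nullDSIs``` shows which DSI values arise by chance alone for a cell with this few spikes. The observed DSI can then be compared against that distribution.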

Calculate the p-values of all DSIs observed, and plot them versus the DSI value:

**Exercise (bonus):** It is possible to vectorize the p-value calculation as well. Can you think how to do it?
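
A minimal sketch of one vectorized approach, assuming the permuted DSIs of all cells are stored in an Ncells x Nperm array. The function name ```permutationPValues``` and the toy numbers are made up for illustration. The sketch uses the common (k+1)/(Nperm+1) convention; the plain fraction k/Nperm is an alternative.

```python
import numpy as np

def permutationPValues(observed, nullDSIs):
    # observed: np array with Ncells observed DSIs
    # nullDSIs: Ncells x Nperm permuted DSIs (one row per cell)
    # broadcasting compares every permuted value with its cell's observed DSI
    exceed = nullDSIs >= observed[:, np.newaxis]
    # the +1 counts the observed value as one member of the null distribution,
    # so a p-value can never be exactly zero
    return (exceed.sum(axis=1) + 1) / (nullDSIs.shape[1] + 1)

# hand-made toy numbers: 2 cells x 4 permutations
toyObserved = np.array([0.5, 0.1])
toyNull = np.array([[0.1, 0.2, 0.3, 0.6],
                    [0.2, 0.3, 0.1, 0.4]])
toyP = permutationPValues(toyObserved, toyNull)
print(toyP.tolist())  # [0.4, 1.0]
```

For the first toy cell, only one of the four permuted values reaches the observed DSI, giving (1+1)/(4+1) = 0.4. For the second, all four do, giving 1.0.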

## 3 - Comparing direction selectivity between different stimuli

Instead of a paired t-test, which assumes that the paired differences are normally distributed, we will perform a Wilcoxon signed-rank test.

In [None]:
stats.wilcoxon(dsi1,dsi2)
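
The call above needs two arrays of equal length. The sketch below runs the test on synthetic stand-ins, assuming that ```dsi1``` and ```dsi2``` hold one DSI per cell under each of the two stimuli, with entry i of both arrays belonging to the same cell. The toy values are random numbers, not results from the dataset.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(1)  # fixed seed so the sketch is reproducible

# synthetic stand-ins for dsi1 and dsi2: one DSI per cell under each stimulus
toyDsi1 = rng.uniform(0.1, 0.6, size=20)
toyDsi2 = np.clip(toyDsi1 + rng.normal(0.05, 0.05, size=20), 0, 1)

# paired test: the order of the cells must be identical in both arrays
result = stats.wilcoxon(toyDsi1, toyDsi2)
print(result.statistic, result.pvalue)
```

The test works on the signed ranks of the differences ```toyDsi1 - toyDsi2```, so shuffling one array relative to the other would destroy the pairing and invalidate the result.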