# Project 1: Probing direction selectivity in the mouse retina

Welcome to the first project of the class. 

**You will learn to:** 
- Construct direction tuning curves from grating data.
- Quantify direction selectivity.
- Perform statistical comparison of paired samples.

Let's first import the packages we are going to use, and set up some plotting parameters.

In [None]:
%matplotlib inline
%config InlineBackend.rc={'figure.figsize': (12, 6), 'font.size': 14 }
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from scipy import stats

## 1 - Creating tuning curves

The stimulus times, as well as the spike times of 19 cells, are in the folder `data_drifting_grating`. Let's start by loading the stimulus times contained in the file `stimulus.txt`.

In [None]:
datapath = 'data_drifting_grating/'
stimulus = np.loadtxt(datapath + 'stimulus.txt')

The stimulus consists of gratings drifting in eight different directions (or angles). In each cycle of the stimulus, the eight directional gratings are presented one after another, with a gray screen in between.

<img src="driftinggrating.gif" width="200px">

We can match the gratings and their presentation times to the spike recordings using the stimulus timestamps (i.e. pulses) provided in the stimulus file. Each grating is shown for four temporal periods. Each period is marked by a pulse, except for the first one: because of the transition from the gray screen, the first period is usually ignored in the analysis and is not marked. This results in three pulses per grating presentation.

The stimulus consists of four cycles, as the sequence of eight directions is repeated four times.

Let's first define these properties.

In [None]:
num_angles = 8  # Number of grating direction angles
num_periods = 4 - 1  # Number of grating periods (-1 because first period does not give a pulse)
num_cycles = 4  # Number of stimulus repetitions

Now we can calculate some important values that will help sort out the spikes relative to the stimulus times.

In [None]:
# The angles of the drifting grating in radians
angles = np.linspace(0, 2*np.pi, num=num_angles, endpoint=False)

# Number of pulses per cycle and for the entire stimulus
num_cycle_pulses = num_periods * num_angles
num_stim_pulses = num_cycle_pulses * num_cycles

# Obtain stimulus times by cycles: num_cycles x num_angles x num_periods
pulse_times = stimulus[:num_stim_pulses]
pulse_times = pulse_times.reshape(num_cycles, num_angles, num_periods)

# Duration of one temporal grating period, from the average inter-pulse interval
period_duration = np.mean(np.diff(pulse_times, axis=2))

To get a better picture of how we extracted the stimulus structure, let's look at the `pulse_times`. This is now a three dimensional array: `cycles x directions x periods`, so `4 x 8 x 3`.
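To see how `reshape` assigns consecutive pulses, here is a small sketch on a dummy pulse vector whose values are simply the pulse indices (not real stimulus times). NumPy's default (C) order fills the last axis (periods) fastest, then directions, then cycles, which matches the order in which the pulses were recorded.

```python
import numpy as np

num_angles, num_periods, num_cycles = 8, 3, 4

# Dummy "stimulus": each pulse time is just its index (illustration only)
dummy_stimulus = np.arange(num_angles * num_periods * num_cycles, dtype=float)

# Same reshape as above: cycles x directions x periods
dummy_pulses = dummy_stimulus.reshape(num_cycles, num_angles, num_periods)

# Pulse index of (cycle c, direction d, period p) is c*24 + d*3 + p
c, d, p = 1, 2, 0
print(dummy_pulses[c, d, p])  # 30.0
print(dummy_pulses.shape)     # (4, 8, 3)
```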

In [None]:
pulse_times

Having set up the stimulus times, we can now count spikes! Below you can find a function that counts the spikes of a single cell for each grating presentation:

In [None]:
def get_spike_counts(pulses, spikes):
    """
    Calculate the number of spikes per grating direction
    
    Parameters
    ----------
    pulses : (nc, nd, np) numpy.ndarray
        Stimulus times, where nc is the number of cycles, nd is the number of
        directions and np is the number of periods per angle
    
    spikes : numpy.ndarray
        One-dimensional array containing the spike times of a cell
    
    Returns
    -------
    spike_counts : (nc, nd) numpy.ndarray
        Number of spikes for each direction of each stimulus cycle
    """
    # Recover some information
    num_cycles, num_angles, num_periods = pulses.shape
    period_duration = np.mean(np.diff(pulses, axis=2))

    # Pre-allocate array
    spike_counts = np.zeros((num_cycles, num_angles))
    
    # Iterate over all cycles and all angles
    for cyc in range(num_cycles):
        for ang in range(num_angles):
            # Obtain the spikes that fall within the presentation duration
            spk = spikes[spikes >= pulses[cyc, ang, 0]]
            spk = spk[spk < pulses[cyc, ang, -1] + period_duration]
            
            # Count the number of spikes in this duration
            spike_counts[cyc, ang] = spk.size

    return spike_counts
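As an aside, the same counting windows can be evaluated without Python loops. The sketch below is a hypothetical alternative (not part of the exercise) that uses `np.searchsorted` on the sorted spike times: for each window start and stop it finds how many spikes fall before that time, and the difference is the number of spikes with `start <= t < stop`, exactly as in `get_spike_counts`. It is checked on dummy pulses and uniformly scattered dummy spikes.

```python
import numpy as np

def get_spike_counts_vectorized(pulses, spikes):
    """
    Hypothetical loop-free variant of get_spike_counts: counts spikes in
    [first pulse, last pulse + period_duration) for every (cycle, direction).
    """
    period_duration = np.mean(np.diff(pulses, axis=2))
    starts = pulses[:, :, 0]                     # (nc, nd) window starts
    stops = pulses[:, :, -1] + period_duration   # (nc, nd) window ends
    spikes = np.sort(spikes)
    # side='left' returns how many spikes are strictly smaller than each value,
    # so the difference counts the spikes with start <= t < stop
    n_before_start = np.searchsorted(spikes, starts, side='left')
    n_before_stop = np.searchsorted(spikes, stops, side='left')
    return (n_before_stop - n_before_start).astype(float)

# Usage on dummy data (not the recorded spikes)
rng = np.random.default_rng(0)
dummy_pulses = np.arange(96, dtype=float).reshape(4, 8, 3)
dummy_spikes = rng.uniform(0, 96, size=500)
print(get_spike_counts_vectorized(dummy_pulses, dummy_spikes).shape)  # (4, 8)
```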

Let's look at what our function returns. It should be a 2-dimensional array of shape `(num_cycles, num_angles)`, which we will use for most of our calculations.

In [None]:
example_spike_times = np.loadtxt(datapath + '5_SP_C3601.txt')
example_spike_counts = get_spike_counts(pulse_times, example_spike_times)
print(example_spike_counts)

The first thing we can do with such an array is plot a tuning curve. A tuning curve shows neuronal responses as a function of a continuous stimulus attribute, such as orientation, wavelength, or frequency. 

**Exercise:** Calculate a tuning curve by averaging spike counts over the different cycles.

In [None]:
### START CODE HERE ###
example_tuning_curve = None
example_tuning_curve = example_spike_counts.mean(axis=0)
### END CODE HERE ###

In [None]:
print(example_tuning_curve)

Expected output: 
```python 
[ 74.75  90.25 177.5  166.   158.    74.5   51.75  52.5 ]
```

Good job! Now we can plot the mean spike count vs the drifting grating direction. Since our ```angles``` variable is in radians, we transform it to degrees just before plotting.

In [None]:
plt.plot(np.rad2deg(angles), example_tuning_curve)
plt.title('Tuning Curve')
plt.ylabel('Spike count');
plt.xlabel('Direction (deg)');

Can you understand this curve? Which directions did our neuron prefer?
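One simple way to answer the question is to read off the direction with the largest mean response. A minimal sketch, using the expected tuning-curve values printed above (hard-coded here so the snippet runs on its own):

```python
import numpy as np

angles = np.linspace(0, 2*np.pi, num=8, endpoint=False)
# Expected output of example_tuning_curve from above
tuning_curve = np.array([74.75, 90.25, 177.5, 166., 158., 74.5, 51.75, 52.5])

# Direction with the largest mean spike count (index 2, i.e. 90 degrees)
preferred = np.rad2deg(angles[np.argmax(tuning_curve)])
print(preferred)
```

Note that 135° and 180° give similar counts, so a single `argmax` hides how broad the tuning is; the vector-sum measure in Section 2 uses all directions at once.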

It is also customary to plot a tuning curve in a polar fashion, because it helps us visually relate the tuning curve to the underlying stimulus.

**Exercise:** Plot the tuning curve in a polar plot by filling in `plot_tuning_curve_polar(angles, responses)`. Consider using `plt.polar`, and try to make your plot pretty.

*Hint*: Make sure that your final curve is closed! The function `np.append` may be useful for that.

In [None]:
def plot_tuning_curve_polar(angles, responses):
    """
    Create a polar figure showing the tuning curve of a cell
    
    Parameters
    ---------
    angles : (na,) numpy.ndarray
        One dimensional list of angles, where na is the number of angles
        
    responses : (na,) numpy.ndarray
        Average number of spikes for each angle
    """
    
    ### START CODE HERE ###
    
    # Append the first element to the end of the list
    angles_to_plot = np.append(angles, angles[0])
    responses_to_plot = np.append(responses, responses[0])
    
    # Plot
    plt.polar(angles_to_plot, responses_to_plot)
    
    ### END CODE HERE ###

In [None]:
plot_tuning_curve_polar(angles, example_tuning_curve)

We would also like to look at the tuning curves of other neurons. We will first load the spike trains of all the neurons in our dataset into a list, and you will have to figure out how to get properly structured spike counts for all cells.

In [None]:
# Finding all data file names
filenames = sorted(Path(datapath).glob('5_SP_C*.txt'))

# List of all spike times
spike_trains = [np.loadtxt(fpath) for fpath in filenames]

`spike_trains` is a list of one-dimensional numpy arrays. Each numpy array contains the spike times of one cell. Here we display the shape of each array, i.e. the number of spikes for each cell in `spike_trains`.

In [None]:
for i in spike_trains:
    print(i.shape)

We now want to calculate the tuning curves for all cells in our dataset.

**Exercise:** First, we need to sort the spikes of all cells into spike counts. Fill in the function ```get_spike_counts_multi(pulses, spikes)``` that does that.

*Hint:* You can iterate over cells using the function `get_spike_counts` from above, but you can also start over!

In [None]:
def get_spike_counts_multi(pulses, spikes):
    """
    Calculate the number of spikes per grating direction for multiple
    cells
    
    Parameters
    ----------
    pulses : (nc, nd, np) numpy.ndarray
        Stimulus times, where nc is the number of cycles, nd is the number of
        directions and np is the number of periods per direction
    
    spikes : list of ncell numpy.ndarrays
        List of spike trains (one-dimensional numpy arrays) for all cells

    Returns
    -------
    spike_counts : (ncell, nc, nd) numpy.ndarray
        Number of spikes for each direction of each stimulus cycle for all
        cells
    """

    ### START CODE HERE ###

    # Recover some information
    num_cells = len(spikes)
    num_cycles, num_angles, num_periods = pulses.shape
    period_duration = np.mean(np.diff(pulses, axis=2))

    # Pre-allocate array
    spike_counts = np.zeros((num_cells, num_cycles, num_angles))
    
    # Iterate over all cycles, angles and cells
    for cyc in range(num_cycles):
        for ang in range(num_angles):
            for cell in range(num_cells):
                spk = spikes[cell]

                # Obtain the spikes that fall within the presentation duration
                spk = spk[spk >= pulses[cyc, ang, 0]]
                spk = spk[spk < pulses[cyc, ang, -1] + period_duration]

                # Count the number of spikes in this duration
                spike_counts[cell, cyc, ang] = spk.size

    ### END CODE HERE ###
    
    return spike_counts

In [None]:
all_spike_counts = get_spike_counts_multi(pulse_times, spike_trains)
print(all_spike_counts[10, :, :])

Expected output: 

```python
[[13. 31. 50. 41. 39. 34. 17. 12.]
 [12. 27. 20. 26. 28. 15. 20. 20.]
 [ 9. 17. 29. 30. 30. 22.  8.  5.]
 [ 3. 12. 21. 24. 23. 27. 17.  9.]]
```

**Exercise:** Now use the ```np.mean``` function along the right axis to get all the tuning curves!

In [None]:
### START CODE HERE ###
all_tuning_curves = np.mean(all_spike_counts, axis=1)
### END CODE HERE ###

In [None]:
print(all_tuning_curves[3, :])

Expected output: 

```python
[50.5  28.25 22.   20.75 22.75 35.25 54.   54.75]
```
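The choice of axis follows from the shape of `all_spike_counts`, which is `(cells, cycles, directions)`: averaging over cycles means averaging over axis 1. A quick shape check on a dummy array of ones:

```python
import numpy as np

num_cells, num_cycles, num_angles = 19, 4, 8
dummy_counts = np.ones((num_cells, num_cycles, num_angles))

print(np.mean(dummy_counts, axis=1).shape)  # (19, 8): one tuning curve per cell
print(np.mean(dummy_counts, axis=0).shape)  # (4, 8): averaged over cells instead
```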

By running the following cell, you can examine the tuning curve of the 14th cell. Change the index in ```all_tuning_curves``` to examine the tuning curves for different cells. Can you understand what all of them mean? Can you spot the direction selective cells?

In [None]:
plot_tuning_curve_polar(angles, all_tuning_curves[13, :])

## 2 - Quantification of direction selectivity

### The direction selectivity index (DSI)

The direction selectivity index (DSI) is a common quantification of direction tuning. One of the ways to calculate it is the following:

$$
DSI = \frac{1}{\sum_{k=1}^{N}{r_{k}}} \left|\sum_{k=1}^{N}{r_{k}e^{i\phi_{k}}}\right|\,,
$$

where $r_k$ is the response (i.e. the average number of spikes) at angle $\phi_k$, and $N$ is the number of angles.
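To build intuition for the formula, here is a sketch that evaluates it on three made-up tuning curves over the same eight directions: a cell that fires in only one direction, a cell that fires equally in all directions, and a cell that responds equally to two opposite directions.

```python
import numpy as np

angles = np.linspace(0, 2*np.pi, num=8, endpoint=False)

def dsi(responses):
    # Normalized magnitude of the vector sum of responses
    return np.abs(np.sum(responses * np.exp(1j * angles))) / np.sum(responses)

only_one = np.array([0, 0, 10, 0, 0, 0, 0, 0], dtype=float)   # 90 deg only
uniform = np.full(8, 10.0)                                     # same everywhere
opposite = np.array([0, 0, 10, 0, 0, 0, 10, 0], dtype=float)  # 90 and 270 deg

print(round(dsi(only_one), 6))   # 1.0
print(round(dsi(uniform), 6))    # 0.0
print(round(dsi(opposite), 6))   # 0.0
```

The last case shows that this DSI is zero for a cell that responds equally to both directions along one axis (orientation-selective but not direction-selective), because opposite unit vectors cancel in the sum.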

**Exercise:** Fill in ```get_dsi(angles, responses)```. To help you: in Python, the imaginary unit is written `1j` (so $2i$ is `2j`), and you can calculate the magnitude of a complex number with ```np.abs()```.

In [None]:
def get_dsi(angles, responses):
    """
    Calculate the direction selectivity index (DSI) of a cell
    
    Parameters
    ---------
    angles : (a,) numpy.ndarray
        One dimensional list of angles a in radians
        
    responses : (a,) numpy.ndarray
        One dimensional list of average number of spikes per angle a
    
    Returns
    -------
    dsi : float
        DSI for the cell
    """
    
    ### START CODE HERE ###
    vsum = np.sum(responses * np.exp(1j * angles))
    vsum /= np.sum(responses)
    dsi = np.abs(vsum)
    ### END CODE HERE ###

    return dsi

In [None]:
print('DSI =', get_dsi(angles, example_tuning_curve))

Expected output:

```python
DSI = 0.31383540910149094
```

Good job! Now let's try to find the DSI values for all the cells we provided for you.

**Exercise:** Calculate the DSIs for all cells provided. Then run the cell below, and examine the histogram. 

*Hint:* Instead of using a ```for``` loop, you can use linear algebra (a matrix-vector product) to calculate the DSIs for all cells simultaneously!

In [None]:
def get_dsi_multi(angles, responses):
    """
    Calculate the direction selectivity index (DSI) of multiple cells
    
    Parameters
    ---------
    angles : (a,) numpy.ndarray
        One dimensional list of angles a in radians
        
    responses : (x, a) numpy.ndarray
        Array of average number of spikes per angle a for all cells x
    
    Returns
    -------
    dsi : (x,) numpy.ndarray
        DSI for all cells x
    """
    
    ### START CODE HERE ###
    vsums = responses @ np.exp(1j * angles)
    vsums /= np.sum(responses, axis=1)
    dsi = np.abs(vsums)
    ### END CODE HERE ###

    return dsi
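The matrix product in `get_dsi_multi` is the row-wise version of the sum in `get_dsi`: each row of `responses` is multiplied element-wise by the vector `np.exp(1j * angles)` and summed. A sketch that checks this against a per-cell loop on random, made-up tuning curves:

```python
import numpy as np

rng = np.random.default_rng(1)
angles = np.linspace(0, 2*np.pi, num=8, endpoint=False)
responses = rng.uniform(1, 50, size=(19, 8))   # dummy tuning curves
phasors = np.exp(1j * angles)                  # unit vectors e^{i phi_k}

# Vectorized: (19, 8) @ (8,) -> (19,)
dsi_vec = np.abs(responses @ phasors) / responses.sum(axis=1)

# Loop: one cell at a time
dsi_loop = np.array([np.abs(np.sum(r * phasors)) / np.sum(r) for r in responses])

print(np.allclose(dsi_vec, dsi_loop))  # True
```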

In [None]:
plt.hist(get_dsi_multi(angles, all_tuning_curves), num_angles)
plt.xlabel('DSI')
plt.ylabel('# cells');

A DSI above 0.2 is commonly used as an indicator that a cell is selective for a particular stimulus direction. Do the DSI values for each cell match the tuning curves you examined above?
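To give the 0.2 threshold a concrete meaning, consider a hypothetical cosine-shaped tuning curve $r_k = 1 + a\cos(\phi_k - \theta)$ sampled at the eight directions, with modulation depth $0 \le a \le 1$ and preferred direction $\theta$. Because the eight unit vectors $e^{i\phi_k}$ sum to zero, the numerator of the DSI formula reduces to $4a$ and the denominator to $8$, so $DSI = a/2$: a DSI above 0.2 corresponds to a modulation depth above 0.4. A numerical check of this sketch:

```python
import numpy as np

angles = np.linspace(0, 2*np.pi, num=8, endpoint=False)

def dsi(responses):
    return np.abs(np.sum(responses * np.exp(1j * angles))) / np.sum(responses)

theta = np.deg2rad(135)  # arbitrary preferred direction
for a in [0.2, 0.4, 0.8]:
    responses = 1 + a * np.cos(angles - theta)
    print(a, round(dsi(responses), 6))  # DSI equals a / 2
```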

### Monte Carlo permutation for calculating DSI p-values

Examine the tuning curve of the following cell. What does its DSI value suggest?

In [None]:
plot_tuning_curve_polar(angles, all_tuning_curves[11, :])
print('DSI is', get_dsi(angles, all_tuning_curves[11, :]))

Although the DSI value is above 0.2, the cell barely responded to the stimulus. For this reason, just looking at DSI values can be misleading. 
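The reason is that the DSI is a ratio: multiplying all responses by the same factor leaves it unchanged, so a cell firing only a handful of spikes can have the same DSI as a vigorously responding one. A small sketch with a made-up tuning shape:

```python
import numpy as np

angles = np.linspace(0, 2*np.pi, num=8, endpoint=False)

def dsi(responses):
    return np.abs(np.sum(responses * np.exp(1j * angles))) / np.sum(responses)

tuning_shape = np.array([1, 2, 4, 2, 1, 1, 1, 1], dtype=float)  # dummy shape
weak = 0.5 * tuning_shape    # low firing
strong = 50 * tuning_shape   # high firing

print(np.isclose(dsi(weak), dsi(strong)))  # True
```

With so few spikes, a single extra spike in one cycle can noticeably reshape the weak cell's tuning curve, which the DSI value alone does not reveal.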

In cases like this one, p-values can be estimated by resampling. Here, we compute the distribution of DSI values obtained from shuffled spike counts and use it to calculate a one-sided p-value for the observed DSI. **The p-value is the probability of finding a permuted DSI larger than the observed one.** If this probability is high, the observed DSI could easily have arisen by chance.

Enumerating all possible orderings of the values in our data would require (4 * 8)! = 32! permutations, which is a huge number. For that reason, we use Monte Carlo sampling: we calculate the DSI values for a relatively small, random subset of permutations.
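For a sense of scale, Python's `math.factorial` gives the number of orderings of the 32 spike counts of one cell:

```python
import math

n_orderings = math.factorial(4 * 8)
print(f"{n_orderings:.2e}")  # 2.63e+35
```

Even at a million DSI evaluations per second, enumerating them all would take on the order of $10^{22}$ years, whereas 1000 random permutations already give p-values with a resolution of 0.001.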

**Exercise (hard):** Calculate a permutation distribution of `num_permutations` DSIs, returned as `dsi_rand`, and the p-value, returned as `pval`. Your function ```get_dsi``` will definitely be useful!

*Hint:* The function `np.random.permutation` may be useful to shuffle the spike counts in conjunction with flattening `spike_counts` (see `np.ravel`) and then reshaping it afterwards to its original dimensions (see `np.reshape`).

*Hint:* Think of what the p-value should be if there is no permuted DSI larger than the observed one (given the number of permutations we performed).


In [None]:
def get_pval(angles, spike_counts, num_permutations):
    """
    Calculate the permutation distribution of the DSI along with
    the p-value
    
    Parameters
    ----------
    angles : (a,) numpy.ndarray
        One dimensional list of angles a in radians
    
    spike_counts : (c, a) numpy.ndarray
        Number of spikes for each stimulus cycle c and angle a
    
    num_permutations : int
        Number of permutations

    Returns
    -------
    pval : float
        P-value
    
    dsi_rand : (num_permutations,) numpy.ndarray
        Random DSI distribution
    """
    dsi_true = get_dsi(angles, spike_counts.mean(axis=0))

    ### START CODE HERE ###
    
    # Pre-allocate the array
    dsi_rand = np.zeros(num_permutations)
    
    # Perform num_permutations random permutations
    for p in range(num_permutations):
        # Randomize spike counts
        rand_counts = np.random.permutation(spike_counts.ravel())
        rand_counts = rand_counts.reshape(spike_counts.shape)
        dsi_rand[p] = get_dsi(angles, rand_counts.mean(axis=0))

    # Compute the p-value
    pval = np.sum(dsi_true < dsi_rand) / num_permutations
    
    
    if pval == 0:
        pval = 1 / num_permutations
        
    ### END CODE HERE ###

    return pval, dsi_rand
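The `if pval == 0` branch above keeps the p-value from being reported as exactly zero, which a finite number of permutations cannot justify. A commonly used alternative convention counts the observed DSI as one of the permutations, giving $p = (b + 1)/(n + 1)$, where $b$ is the number of permuted DSIs at least as large as the observed one and $n$ is the number of permutations. The sketch below implements that variant (a different convention from the one in `get_pval`), using NumPy's `Generator` for reproducible shuffles and made-up spike counts:

```python
import numpy as np

def dsi(angles, responses):
    return np.abs(np.sum(responses * np.exp(1j * angles))) / np.sum(responses)

def permutation_pval(angles, spike_counts, num_permutations, seed=0):
    """Sketch: permutation p-value with the (b + 1) / (n + 1) convention."""
    rng = np.random.default_rng(seed)
    dsi_true = dsi(angles, spike_counts.mean(axis=0))
    dsi_rand = np.empty(num_permutations)
    for p in range(num_permutations):
        # Shuffle all counts across cycles and directions, keep the shape
        shuffled = rng.permutation(spike_counts.ravel()).reshape(spike_counts.shape)
        dsi_rand[p] = dsi(angles, shuffled.mean(axis=0))
    b = np.sum(dsi_rand >= dsi_true)
    return (b + 1) / (num_permutations + 1), dsi_rand

angles = np.linspace(0, 2*np.pi, num=8, endpoint=False)
# Dummy cell: 50 spikes at 90 degrees in every cycle, 1 spike elsewhere
tuned = np.ones((4, 8))
tuned[:, 2] = 50
pval, dsi_rand = permutation_pval(angles, tuned, 1000)
print(pval)  # small: few or no shuffles reach the observed DSI
```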

In [None]:
pval, dsi_rand = get_pval(angles, all_spike_counts[2, :, :], 1000)
print('P-value is', pval)

Expected output:
```python
P-value is 0.001
```

We can plot the distribution we generated.

In [None]:
plt.hist(dsi_rand, 25)
plt.xlim(0, 0.5)
plt.xlabel('Permuted DSI')
plt.ylabel('# permutations');

**Exercise:** Calculate the p-values of all DSIs observed. Then we will plot them versus the DSI value.

In [None]:
def get_pval_multi(angles, spike_counts, num_perturb):
    """
    Parameters
    ----------
    angles : numpy.ndarray
        One dimensional list of angles in radians
    
    multiTuningCurve : (Ncells, Nangles) numpy.ndarray

    Returns
    -------
    dsiAll : (Ncells, 1) numpy.ndarray
        List of DSI values for each cell
    """
    num_cells = spike_counts.shape[0]
    pvals = np.zeros(num_cells)

    ### START CODE HERE ###  
    for cell in range(num_cells):
        pv = get_pval(angles, spike_counts[cell, :, :], num_perturb)[0]
        pvals[cell] = pv
    ### END CODE HERE ###

    return pvals

In [None]:
all_pvals = get_pval_multi(angles, all_spike_counts, 1000)
all_dsi = get_dsi_multi(angles, all_tuning_curves)
plt.plot(all_dsi, np.log10(all_pvals), 'o')
plt.xlabel('DSI')
plt.ylabel('log10(p-value)');

Can you understand the plot above? Where do the direction selective cells lie?

## 3 - Comparing direction selectivity between different stimuli

The data we gave you also contain a second set of drifting gratings with different parameters, presented to the retina after the first set was over. The pulses for the second set are also included in the ```stimulus``` variable, and the following parameters will help you read them.

In [None]:
num_angles_2 = 8
num_periods_2 = 12 - 1
num_cycles_2 = 4

**Exercise:** Starting only from those parameters, try to extract the spike counts and DSI values for the new stimulus.

First, start by retrieving the pulse times.  
*Hint:* Make sure to start counting pulses from the end of the previous stimulus.

In [None]:
### START CODE HERE ###
num_cycle_pulses_2 = num_periods_2 * num_angles_2
num_stim_pulses_2 = num_cycle_pulses_2 * num_cycles_2

# Using num_stim_pulses as a starting point
pulse_times_2 = stimulus[num_stim_pulses:num_stim_pulses + num_stim_pulses_2]
pulse_times_2 = pulse_times_2.reshape(num_cycles_2, num_angles_2,
                                      num_periods_2)

# Separate name so the period duration of the first stimulus is not overwritten
period_duration_2 = np.mean(np.diff(pulse_times_2, axis=2))
### END CODE HERE ###
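A quick sanity check of the indexing on a dummy pulse vector whose values are the pulse indices: the first stimulus uses 3 × 8 × 4 = 96 pulses and the second 11 × 8 × 4 = 352, so the second block should start at index 96.

```python
import numpy as np

num_stim_pulses = 3 * 8 * 4       # first stimulus: 96 pulses
num_stim_pulses_2 = 11 * 8 * 4    # second stimulus: 352 pulses
dummy_stimulus = np.arange(num_stim_pulses + num_stim_pulses_2)

block_2 = dummy_stimulus[num_stim_pulses:num_stim_pulses + num_stim_pulses_2]
block_2 = block_2.reshape(4, 8, 11)

print(block_2.shape)      # (4, 8, 11)
print(block_2[0, 0, 0])   # 96: first pulse after the first stimulus
```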

Now, obtain the spike counts and the DSI values, `all_spike_counts_2` and `all_dsi_2`.

In [None]:
### START CODE HERE ###  
all_spike_counts_2 = get_spike_counts_multi(pulse_times_2, spike_trains)
angles_2 = np.linspace(0, 2*np.pi, num=num_angles_2, endpoint=False)
all_dsi_2 = get_dsi_multi(angles_2, np.mean(all_spike_counts_2, axis=1))
### END CODE HERE ###

We can now plot the two DSI values against each other for all cells.

In [None]:
# Draw diagonal
plt.plot([0, np.max([all_dsi, all_dsi_2])],
         [0, np.max([all_dsi, all_dsi_2])], 'gray')

plt.plot(all_dsi, all_dsi_2, 'o')
plt.xlabel('$DSI_1$')  # LaTeX label
plt.ylabel('$DSI_2$')
plt.gca().set_aspect('equal')

Does the stimulus choice change the magnitude of direction selectivity?

**Exercise:** Instead of a paired t-test, which assumes that the paired differences are normally distributed, we will perform a Wilcoxon signed-rank test. ```stats.wilcoxon``` should help.

In [None]:
### START CODE HERE ###
stats.wilcoxon(all_dsi, all_dsi_2)
### END CODE HERE ###
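To see what `stats.wilcoxon` returns, here is a sketch on made-up paired values (not the recorded DSIs) in which the second measurement is consistently larger than the first; the paired t-test from `stats.ttest_rel` is shown alongside for comparison.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(2)
dsi_a = rng.uniform(0.0, 0.5, size=19)               # dummy DSIs, stimulus 1
dsi_b = dsi_a + 0.1 + rng.uniform(0, 0.05, size=19)  # dummy DSIs, stimulus 2

res_w = stats.wilcoxon(dsi_a, dsi_b)   # signed-rank test on paired differences
res_t = stats.ttest_rel(dsi_a, dsi_b)  # paired t-test, assumes normal differences

print(res_w.statistic, res_w.pvalue)
print(res_t.statistic, res_t.pvalue)
```

Both tests look only at the paired differences; the signed-rank test uses their signs and ranks rather than their raw values, which makes it less sensitive to a few extreme cells.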

How do you interpret the result of the test? Which of the two stimuli would you prefer to probe direction selectivity?