# Project 1: Probing direction selectivity in the mouse retina

Welcome to the first project of the class. 

**You will learn to:** 
- Construct direction tuning curves from grating data.
- Quantify direction selectivity.
- Perform statistical comparison of paired samples.

Let's first import the packages we are going to use, and set up some plotting parameters.

In [None]:
%matplotlib inline
%config InlineBackend.rc={'figure.figsize': (12, 6), 'font.size': 14 }
import matplotlib.pyplot as plt
import numpy as np
from os import listdir
from pathlib import Path
from scipy import stats

## 1 - Creating tuning curves

The stimulus times, as well as the spike times of 20 cells, are in the folder `data_drifting_grating`. Let's start by loading the stimulus times from the file `stimulus.txt`.

In [None]:
datapath = 'data_drifting_grating/'
stimulus = np.loadtxt(datapath + 'stimulus.txt')

The stimulus consists of gratings drifting in eight different directions. In each cycle of the stimulus, the eight gratings are presented one after another (400 frames each), separated by a gray screen (300 frames). One cycle therefore lasts `8 * (400 + 300) = 5600` frames.

From the stimulus timestamps (i.e. pulses) in the stimulus file, we can align the grating presentations, and the timing of each spatial period, with the spike recordings. Each grating drifts through one spatial period every 100 frames. Each period is marked by a pulse, except for the first period of every presentation, which is usually excluded from the analysis because of the transition from the gray screen. This results in three pulses per grating presentation.

The stimulus consists of four cycles.
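
To make this structure concrete, the sketch below builds an idealised pulse train in frame units. It assumes that each cycle starts with the first grating, that each gray screen follows its grating, and that pulses sit at the starts of the second, third and fourth spatial periods; the real timestamps in `stimulus.txt` come from the recording and may be in different units. `synthetic_pulses` is a hypothetical helper, not part of the dataset.

```python
import numpy as np

def synthetic_pulses(duration=400, gray=300, period=100, num_angles=8, num_cycles=4):
    """Hypothetical helper: idealised pulse frames, shape (cycles, angles, periods)."""
    cycle_length = num_angles * (duration + gray)  # 8 * (400 + 300) = 5600 frames
    num_periods = duration // period - 1           # first period has no pulse -> 3
    pulses = np.zeros((num_cycles, num_angles, num_periods))
    for cyc in range(num_cycles):
        for ang in range(num_angles):
            # Assumed onset of this grating presentation
            onset = cyc * cycle_length + ang * (duration + gray)
            # Pulses at the starts of periods 2..4 (the onset itself gets none)
            pulses[cyc, ang] = onset + period * np.arange(1, num_periods + 1)
    return pulses

pulses = synthetic_pulses()
print(pulses.shape)   # (4, 8, 3)
print(pulses[0, :2])  # first two presentations of the first cycle
```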

Let's first define these properties.

In [None]:
duration = 400  # Duration of one grating presentation
period = 100  # Duration of one spatial period
num_angles = 8  # Number of grating direction angles
num_cycles = 4  # Number of stimulus repetitions

Now we can calculate some important values that will help sort out the spikes relative to the stimulus times.

In [None]:
# The angles of the drifting grating in radians
angles = np.linspace(0, 2*np.pi, num=num_angles, endpoint=False)

# Number of pulses per angle, cycle and the entire stimulus
num_periods = int(np.floor(duration / period)) - 1
num_cycle_pulses = num_periods * num_angles
num_stim_pulses = num_cycle_pulses * num_cycles

# Obtain stimulus times by cycles: num_cycles x num_directions x num_periods
pulse_times = stimulus[:num_stim_pulses]
pulse_times = pulse_times.reshape(num_cycles, num_angles, num_periods)

# Calculate the duration of one spatial grating period from the average
period_duration = np.mean(np.diff(pulse_times, axis=2))

To get a better picture of how we extracted the stimulus structure, let's look at `pulse_times`. It is now a three-dimensional array of shape `cycles x directions x periods`, i.e. `4 x 8 x 3`.
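
The reshape works because NumPy's `reshape` uses row-major (C) order by default: consecutive pulses fill the periods axis first, then angles, then cycles. Assuming the pulses in `stimulus.txt` are stored in chronological order, each `[cycle, angle, period]` entry therefore lands where we expect. A quick check with stand-in integers:

```python
import numpy as np

num_cycles, num_angles, num_periods = 4, 8, 3
flat = np.arange(num_cycles * num_angles * num_periods)  # stand-in for pulse indices
cube = flat.reshape(num_cycles, num_angles, num_periods)

# Row-major order: element [c, a, p] comes from flat index
# c * (num_angles * num_periods) + a * num_periods + p
c, a, p = 2, 5, 1
print(cube[c, a, p], c * num_angles * num_periods + a * num_periods + p)
```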

In [None]:
pulse_times

Having set up the stimulus times, we can now count spikes! Let's write a function that counts the spikes of a cell during each grating presentation:

In [None]:
def get_spike_counts(pulses, spikes):
    """
    Calculate the number of spikes per grating direction
    
    Parameters
    ----------
    pulses : (c, a, p) numpy.ndarray
        Stimulus times, where c is the number of cycles, a is the number of
        angles and p is the number of periods per angle
    
    spikes : numpy.ndarray
        One dimensional list containing the spike times of a cell    
    
    Returns
    -------
    spike_counts : (c, a) numpy.ndarray
        Number of spikes for each angle a in each stimulus cycle c
    """
    # Recover some information
    num_cycles, num_angles, num_periods = pulses.shape
    period_duration = np.mean(np.diff(pulses, axis=2))

    # Pre-allocate array
    spike_counts = np.zeros((num_cycles, num_angles))
    
    # Iterate over all cycles and all angles
    for cyc in range(num_cycles):
        for ang in range(num_angles):
            # Obtain the spikes that fall within the presentation duration
            spk = spikes[spikes >= pulses[cyc, ang, 0]]
            spk = spk[spk < pulses[cyc, ang, -1] + period_duration]
            
            # Count the number of spikes in this duration
            spike_counts[cyc, ang] = spk.size

    return spike_counts
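
Boolean masking rescans the whole spike train for every window. If the spike times are sorted (an assumption worth checking with `np.all(np.diff(spikes) >= 0)`), `np.searchsorted` gives the counts for the same half-open windows `[first pulse, last pulse + period_duration)` in a single call. A sketch, with `count_in_windows` as a hypothetical helper:

```python
import numpy as np

def count_in_windows(spikes, starts, stops):
    """Hypothetical helper: count sorted spike times t with start <= t < stop.

    `starts` and `stops` can be arrays of any matching shape, e.g. (cycles, angles).
    """
    # side='left' returns the index of the first spike >= each boundary,
    # so the difference counts the spikes in the half-open interval [start, stop)
    return (np.searchsorted(spikes, stops, side='left')
            - np.searchsorted(spikes, starts, side='left'))

spikes = np.array([0.5, 1.0, 1.2, 2.0, 2.9, 3.0, 4.1])
starts = np.array([[0.0, 1.0], [3.0, 5.0]])
stops = np.array([[1.0, 3.0], [4.5, 6.0]])
print(count_in_windows(spikes, starts, stops))
```

With the notebook's arrays, `count_in_windows(example_spike_times, pulse_times[..., 0], pulse_times[..., -1] + period_duration)` should match `get_spike_counts` (as integers rather than floats), provided the spike times are sorted.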

Let's test our function below. Our goal is to obtain a matrix `num_cycles x num_directions`, which we will use for most of our calculations.

In [None]:
example_spike_times = np.loadtxt('data_drifting_grating/5_SP_C3601.txt')
example_spike_counts = get_spike_counts(pulse_times, example_spike_times)
print(example_spike_counts)

The first thing we can do is plot a tuning curve: the mean spike count across cycles versus the grating direction.

In [None]:
tuning_curve = np.mean(example_spike_counts, axis=0)
plt.plot(np.rad2deg(angles), tuning_curve)
plt.title('Tuning Curve')
plt.ylabel('Spike count');
plt.xlabel('Direction (deg)');

It is also customary to plot tuning curves in polar coordinates.

**Exercise:** Plot the tuning curve in a polar plot by filling in `plot_tuning_curve_polar(angles, responses)`. Consider using `plt.polar`, and try to make your plot pretty.

*Hint*: Make sure that your final curve is closed!

In [None]:
def plot_tuning_curve_polar(angles, responses):
    """
    Create a polar figure showing the tuning curve of a cell
    
    Parameters
    ---------
    angles : (a,) numpy.ndarray
        One dimensional list of angles, where a is the number of angles
        
    responses : (a,) numpy.ndarray
        Average number of spikes for each angle a
    """
    
    ### START CODE HERE ###
    
    # Append the first element to the end of the list
    angles = np.append(angles, angles[0])
    responses = np.append(responses, responses[0])
    
    plt.polar(angles, responses)
    ### END CODE HERE ###

In [None]:
plot_tuning_curve_polar(angles, tuning_curve)

We now want to look at all of our data. We will load the spike trains of all cells in the dataset in a list, and you will have to figure out how to properly get structured spike counts for all cells.

In [None]:
filenames = sorted([i.name for i in Path(datapath).glob('5_SP_C*.txt')])

# List of all spike times
spike_trains = [np.loadtxt(Path(datapath, fpath)) for fpath in filenames]

`spike_trains` is a list of one-dimensional numpy arrays, each containing the spike times of one cell. Here we display the number of spikes of each cell in `spike_trains`.

In [None]:
for i in spike_trains:
    print(i.shape)

**Exercise:** Let's now calculate the tuning curves for all cells in our dataset. First, we need to sort the spikes of all cells into spike counts. Fill in the function that does that.

*Hint:* You can iterate over cells using the function `get_spike_counts` from above, but you can also start over!

In [None]:
def get_spike_counts_multi(pulses, spikes):
    """
    Calculate the number of spikes per grating direction for multiple
    cells
    
    Parameters
    ----------
    pulses : (c, a, p) numpy.ndarray
        Stimulus times, where c is the number of cycles, a is the number of
        angles and p is the number of periods per angle
    
    spikes : list of x numpy.ndarrays
        List of spike trains (one dimensional numpy arrays) for all cells x

    Returns
    -------
    spike_counts : (x, c, a) numpy.ndarray
        Number of spikes for each angle a in each stimulus cycle c for all
        cells x
    """

    ### START CODE HERE ###

    # Recover some information
    num_cells = len(spikes)
    num_cycles, num_angles, num_periods = pulses.shape
    period_duration = np.mean(np.diff(pulses, axis=2))

    # Pre-allocate array
    spike_counts = np.zeros((num_cells, num_cycles, num_angles))
    
    # Iterate over all cycles, angles and cells
    for cyc in range(num_cycles):
        for ang in range(num_angles):
            for cell in range(num_cells):
                spk = spikes[cell]

                # Obtain the spikes that fall within the presentation duration
                spk = spk[spk >= pulses[cyc, ang, 0]]
                spk = spk[spk < pulses[cyc, ang, -1] + period_duration]

                # Count the number of spikes in this duration
                spike_counts[cell, cyc, ang] = spk.size

    ### END CODE HERE ###
    
    return spike_counts

In [None]:
all_spike_counts = get_spike_counts_multi(pulse_times, spike_trains)
print(all_spike_counts[10, :, :])

Expected output: 

```python
[[13. 31. 50. 41. 39. 34. 17. 12.]
 [12. 27. 20. 26. 28. 15. 20. 20.]
 [ 9. 17. 29. 30. 30. 22.  8.  5.]
 [ 3. 12. 21. 24. 23. 27. 17.  9.]]
```

Now use the ```np.mean``` function along the right axis to get all the tuning curves!

In [None]:
all_tuning_curves = np.mean(all_spike_counts, axis=1)
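
Averaging over `axis=1` collapses the cycles axis of the `(cells, cycles, angles)` array, leaving one tuning curve (one value per angle) per cell. A tiny check on stand-in data:

```python
import numpy as np

# Stand-in counts with the same layout as all_spike_counts: (cells, cycles, angles)
counts = np.arange(2 * 4 * 8, dtype=float).reshape(2, 4, 8)

tuning = counts.mean(axis=1)  # average over cycles -> (cells, angles)
print(tuning.shape)           # (2, 8)
```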

By running the following cell, you can examine the tuning curve of cell 13 (the fourteenth cell, since Python indexing starts at 0). Change the index to examine the tuning curves of other cells. Can you interpret each of them?

In [None]:
plot_tuning_curve_polar(angles, all_tuning_curves[13, :])

## 2 - Quantification of direction selectivity

### The direction selectivity index (DSI)

The direction selectivity index (DSI) is a common quantification of direction tuning. One of the ways to calculate it is the following:

$$
DSI = \frac{1}{\sum_{k=1}^{N}{r_{k}}} \left|\sum_{k=1}^{N}{r_{k}e^{i\phi_{k}}}\right|\,,
$$

where $r_k$ is the response at angle $\phi_k$ (i.e. the average number of spikes for that direction) and $N$ is the number of angles.
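
A few hand-made response vectors show the range of the index: a cell responding to a single direction gets DSI = 1, equal responses in all directions give 0, and equal responses to two opposite directions also give 0, because their vectors cancel. The helper `dsi` below is a stand-in written directly from the formula above:

```python
import numpy as np

def dsi(angles, responses):
    """Stand-in DSI: |sum_k r_k exp(i phi_k)| / sum_k r_k."""
    return np.abs(np.sum(responses * np.exp(1j * angles))) / np.sum(responses)

angles = np.linspace(0, 2 * np.pi, num=8, endpoint=False)

one_direction = np.array([5, 0, 0, 0, 0, 0, 0, 0], dtype=float)
uniform = np.full(8, 5.0)
opposite_pair = np.array([5, 0, 0, 0, 5, 0, 0, 0], dtype=float)

print(dsi(angles, one_direction))  # 1.0: responds to a single direction only
print(dsi(angles, uniform))        # ~0: vectors cancel around the circle
print(dsi(angles, opposite_pair))  # ~0: responses at 0 and 180 deg cancel
```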

**Exercise:** Fill in ```get_dsi(angles, responses)```. In Python, the imaginary unit is written `1j` (so `2j`, `3j`, ... are imaginary numbers), and you can calculate the magnitude of a complex number with ```np.abs()```.

In [None]:
def get_dsi(angles, responses):
    """
    Calculate the direction selectivity index (DSI) of a cell
    
    Parameters
    ---------
    angles : (a,) numpy.ndarray
        One dimensional list of angles a in radians
        
    responses : (a,) numpy.ndarray
        One dimensional list of average number of spikes per angle a
    
    Returns
    -------
    dsi : float
        DSI for the cell
    """
    
    ### START CODE HERE ###
    vsum = np.sum(responses * np.exp(1j * angles))
    vsum /= np.sum(responses)
    dsi = np.abs(vsum)
    ### END CODE HERE ###

    return dsi

In [None]:
print('DSI = ' + str(get_dsi(angles, tuning_curve)))

Expected output:

```python
0.31383540910149094
```

Good job! Now let's try to find the DSI values for all the cells we provided for you.

**Exercise:** Calculate the DSIs for all cells provided. Then run the cell below, and examine the histogram. 

*Hint:* Instead of using a for loop, you can use a matrix-vector product to calculate the DSIs of all cells simultaneously!
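
Stacking the tuning curves as rows of a `(cells, angles)` matrix turns the complex vector sum into a single matrix-vector product with `exp(1j * angles)`. The sketch below checks that this gives the same DSIs as a per-cell loop, on random stand-in responses:

```python
import numpy as np

rng = np.random.default_rng(0)
angles = np.linspace(0, 2 * np.pi, num=8, endpoint=False)
responses = rng.poisson(10, size=(5, 8)).astype(float)  # stand-in: 5 cells x 8 angles

# One matrix-vector product gives the complex vector sum for every cell at once
vsums = responses @ np.exp(1j * angles)                 # shape (5,)
dsi_vec = np.abs(vsums) / responses.sum(axis=1)

# Same result with an explicit loop over cells
dsi_loop = np.array([np.abs(np.sum(r * np.exp(1j * angles))) / r.sum()
                     for r in responses])
print(np.allclose(dsi_vec, dsi_loop))  # True
```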

In [None]:
def get_dsi_multi(angles, responses):
    """
    Calculate the direction selectivity index (DSI) of multiple cells
    
    Parameters
    ---------
    angles : (a,) numpy.ndarray
        One dimensional list of angles a in radians
        
    responses : (x, a) numpy.ndarray
        Array of average number of spikes per angle a for all cells x
    
    Returns
    -------
    dsi : (x,) numpy.ndarray
        DSI for all cells x
    """
    
    ### START CODE HERE ###
    vsums = responses @ np.exp(1j * angles)
    vsums /= np.sum(responses, axis=1)
    dsi = np.abs(vsums)
    ### END CODE HERE ###

    return dsi

In [None]:
plt.hist(get_dsi_multi(angles, all_tuning_curves), num_angles)
plt.xlabel('DSI')
plt.ylabel('Cells');

### Monte Carlo permutation test for DSI significance

Examine the tuning curve of the following cell. What about its DSI value?

In [None]:
plot_tuning_curve_polar(angles, all_tuning_curves[11, :])
print('DSI is ' + str(get_dsi(angles, all_tuning_curves[11, :])))

Although its DSI value is above 0.2, the cell barely responded to the stimulus. We therefore need to test whether such a DSI could have arisen by chance.

**Exercise:** Calculate a permutation distribution of DSIs for cell 11. A vectorized DSI calculation such as ```get_dsi_multi``` can come in handy here.

*Hint:* The function `np.random.permutation` may be useful to shuffle the spike counts, in conjunction with flattening `spike_counts` (see `np.ravel`) and then reshaping it to its original dimensions (see `reshape`).
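
Shuffling the flattened counts breaks any link between a count and the direction it was recorded at, while keeping exactly the same set of counts, so the total number of spikes is unchanged. This sketch uses NumPy's `Generator.permutation` with a fixed seed for reproducibility; `np.random.permutation` behaves the same way without an explicit generator. The counts are stand-in values:

```python
import numpy as np

rng = np.random.default_rng(42)
spike_counts = np.array([[3., 0., 1., 0., 0., 2., 0., 1.],
                         [0., 1., 0., 0., 4., 0., 0., 0.]])  # stand-in (cycles, angles)

# Flatten, shuffle all entries across cycles and angles, restore the shape
shuffled = rng.permutation(spike_counts.ravel()).reshape(spike_counts.shape)

print(shuffled.shape == spike_counts.shape)  # True
print(np.array_equal(np.sort(shuffled.ravel()), np.sort(spike_counts.ravel())))  # True
```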

In [None]:
spike_counts = all_spike_counts[11, :, :]

In [None]:
def get_pval(angles, spike_counts, num_perturb):
    """
    Calculate the permutation distribution of the DSI along with
    the p-value
    
    Parameters
    ----------
    angles : (a,) numpy.ndarray
        One dimensional list of angles a in radians
    
    spike_counts : (c, a) numpy.ndarray
        Number of spikes for each cycle c and angle a
    
    num_perturb : int
        Number of permutations

    Returns
    -------
    pval : float
        P-value
    
    dsi_rand : (num_perturb,) numpy.ndarray
        DSIs of the permuted spike counts
    """
    
    ### START CODE HERE ###
    
    # Observed DSI. get_dsi broadcasts over the (cycles, angles) array, which
    # gives the same value as the DSI of the cycle-averaged tuning curve
    dsi_true = get_dsi(angles, spike_counts)
    
    # Pre-allocate the array
    dsi_rand = np.zeros(num_perturb)
    
    # Perform num_perturb permutations
    for p in range(num_perturb):
        # Shuffle spike counts across cycles and angles
        rand_counts = np.random.permutation(spike_counts.ravel())
        rand_counts = rand_counts.reshape(spike_counts.shape)
        dsi_rand[p] = get_dsi(angles, rand_counts)

    # P-value: fraction of permuted DSIs at least as large as the observed one,
    # floored at 1 / num_perturb
    pval = np.mean(dsi_rand >= dsi_true)
    if pval == 0:
        pval = 1 / num_perturb
    ### END CODE HERE ###

    return pval, dsi_rand

In [None]:
pval, dsi_rand = get_pval(angles, all_spike_counts[11, :, :], 1000)
print('P-value is ' + str(pval))

Expected output: 0.001
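
With 1000 permutations, the smallest p-value this procedure can report is 1/1000 = 0.001, so a cell whose DSI beats every permutation shows exactly that value. The sketch below, with the hypothetical helper `permutation_pvalue` and stand-in permutation DSIs, contrasts that floor with the `(k + 1) / (n + 1)` convention, which counts the observed data as one extra permutation:

```python
import numpy as np

def permutation_pvalue(observed, null_samples):
    """Hypothetical helper: two variants of a one-sided permutation p-value."""
    null_samples = np.asarray(null_samples)
    k = np.sum(null_samples >= observed)  # permutations at least as extreme
    n = null_samples.size
    floored = max(k / n, 1 / n)           # floor at 1 / n, as in get_pval
    plus_one = (k + 1) / (n + 1)          # a common alternative that is never 0
    return floored, plus_one

null = np.linspace(0.0, 0.2, 1000)       # stand-in permutation DSIs
print(permutation_pvalue(0.35, null))    # observed DSI beats every permutation
```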

We can plot the permutation distribution we generated:

In [None]:
plt.hist(dsi_rand, 25)
plt.xlim(0, 0.5)
plt.xlabel('Permuted DSI');

Calculate the p-values of all DSIs observed, and plot them versus the DSI value:

In [None]:
def get_pval_multi(angles, spike_counts, num_perturb):
    """
    Calculate the permutation p-value of the DSI for multiple cells
    
    Parameters
    ----------
    angles : (a,) numpy.ndarray
        One dimensional list of angles a in radians
    
    spike_counts : (x, c, a) numpy.ndarray
        Spike counts for all cells x, cycles c and angles a

    num_perturb : int
        Number of permutations per cell

    Returns
    -------
    pvals : (x,) numpy.ndarray
        P-value for each cell x
    """
    
    ### START CODE HERE ###
    num_cells = spike_counts.shape[0]
    pvals = np.zeros(num_cells)
    
    # Run the single-cell permutation test for every cell
    for cell in range(num_cells):
        pvals[cell], _ = get_pval(angles, spike_counts[cell, :, :], num_perturb)
    ### END CODE HERE ###

    return pvals

In [None]:
allpvals = get_pval_multi(angles, all_spike_counts, 1000)
alldsis = get_dsi_multi(angles, all_tuning_curves)
plt.plot(alldsis, np.log10(allpvals), 'o')
plt.xlabel('DSI')
plt.ylabel('log10(p-value)');

## 3 - Comparing direction selectivity between different stimuli

The recording also includes a second grating stimulus with different parameters, presented after the first one.

In [None]:
duration = 180
period = 15
Nangles = 8
Ncycles = 4

Starting from these parameters, extract the spike counts and DSI values for the second stimulus and compare them with those of the first.

*Hint:* Make sure to start counting pulses from the end of the previous stimulus...
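
Because the second stimulus's pulses follow the first one's in `stimulus.txt`, its slice starts where the first stimulus ended. Assuming, as the notebook's code does, that the second stimulus also has no pulse for the first period of each presentation, the index arithmetic works out as follows (`pulses_per_stimulus` is a hypothetical helper):

```python
def pulses_per_stimulus(duration, period, num_angles=8, num_cycles=4):
    """Hypothetical helper: pulses contributed by one stimulus block."""
    # One pulse per spatial period, minus the unmarked first period
    return (duration // period - 1) * num_angles * num_cycles

n1 = pulses_per_stimulus(400, 100)  # (4 - 1) * 8 * 4 = 96
n2 = pulses_per_stimulus(180, 15)   # (12 - 1) * 8 * 4 = 352
start, stop = n1, n1 + n2
print(start, stop)                  # 96 448
```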

In [None]:
num_periods2 = int(np.floor(duration / period)) - 1
num_cycle_pulses2 = num_periods2 * Nangles
num_stim_pulses2 = num_cycle_pulses2 * Ncycles

# The second stimulus starts right after the num_stim_pulses pulses of the first
pulse_times2 = stimulus[num_stim_pulses:num_stim_pulses + num_stim_pulses2]
pulse_times2 = pulse_times2.reshape(Ncycles, Nangles, num_periods2)
period_duration2 = np.mean(np.diff(pulse_times2, axis=2))

In [None]:
all_spike_counts2 = get_spike_counts_multi(pulse_times2, spike_trains)
alldsis2 = get_dsi_multi(angles, np.mean(all_spike_counts2, axis=1))
plt.plot(alldsis, alldsis2, 'o')
plt.xlabel('DSI (first stimulus)')
plt.ylabel('DSI (second stimulus)');

Does the stimulus change the magnitude of direction selectivity?
Instead of a paired t-test, which assumes that the paired differences are normally distributed, we will perform a Wilcoxon signed-rank test, a nonparametric test for paired samples.
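
On stand-in data where every cell's DSI increases by a different small amount, both tests detect the shift. The Wilcoxon test only uses the signs and ranks of the paired differences, so it does not require them to be normally distributed. The numbers below are synthetic and say nothing about the recorded cells:

```python
import numpy as np
from scipy import stats

# Stand-in paired DSIs for 20 cells: b is shifted up by a distinct amount per cell
a = np.linspace(0.1, 0.6, 20)
b = a + 0.01 * np.arange(1, 21)

res_w = stats.wilcoxon(a, b)   # signed-rank test on the paired differences a - b
res_t = stats.ttest_rel(a, b)  # paired t-test, for comparison
print(res_w.statistic, res_w.pvalue)
print(res_t.pvalue)
```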

In [None]:
stats.wilcoxon(alldsis, alldsis2)