# **Spike-Triggered Average (STA)**

We have already looked into how to handle and analyze spike trains recorded under visual stimulation. We will now turn to a very common tool in systems neuroscience, the spike-triggered average. 


**You will learn to:**

-  Collect the stimulus portions preceding each spike, given a certain time window.
-  Find the average stimulus that evokes spiking.

First, let's import relevant packages:

In [None]:
%matplotlib inline
%config InlineBackend.rc={'figure.figsize': (12, 6), 'font.size': 14 }
from matplotlib import pyplot as plt
import numpy as np
from numpy import load
from pathlib import Path

# **1. Spike-triggered average from full-field flicker responses**

The data you are going to work with comes from extracellular recordings of retinal ganglion cells stimulated with full-field flicker: a screen-wide stimulus whose contrast changes at a fixed rate. The contrast of each presentation is set by a number drawn (pseudo)randomly from a Gaussian distribution with mean zero and standard deviation 0.3. Positive values correspond to brighter presentations and negative values to darker ones.

<img src="images/fff.gif" width="200">
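A minimal sketch of how a contrast sequence like this could be generated with NumPy. It assumes a seeded `numpy.random.default_rng` generator and a hypothetical number of frames (`n_frames`); the generator, seed and length used in the actual experiment are not given here.

```python
import numpy as np

# Hypothetical length, for illustration only; the real sequence comes with the data
n_frames = 1000
rng = np.random.default_rng(seed=0)  # seeded so the sequence is reproducible

# One contrast value per presentation: Gaussian with mean 0 and standard deviation 0.3
contrast = rng.normal(loc=0.0, scale=0.3, size=n_frames)

# Positive values = brighter than the mean, negative values = darker
print(contrast[:5])
print(contrast.mean(), contrast.std())
```

The sample mean and standard deviation should land close to 0 and 0.3, respectively.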


What are we going to do with the data? We are going to collect all the contrast values presented during a time window preceding **every single spike**. Once we have all these stimulus chunks, we average them; the result is the spike-triggered average, i.e., the average stimulus that preceded a spike.

<img src="images/sta.png" width="600">
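In symbols, as a sketch assuming the frame-based indexing used later in this notebook: let $N$ be the number of spikes, $W$ the number of stimulus presentations per chunk, $s[j]$ the contrast of presentation $j$, and $j_i$ the index of the last presentation shown before spike $i$. Element $m$ of the STA (with $m = 0$ the oldest presentation in the chunk) is then

$$\mathrm{STA}[m] = \frac{1}{N} \sum_{i=1}^{N} s\big[j_i - (W - 1) + m\big], \qquad m = 0, \dots, W-1.$$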

# **2. Loading the data**

`np.save` writes a single array to a binary file with the extension .npy. `np.savez` writes several arrays into one archive with the extension .npz, storing each array under a name: the keyword used when saving it (or `arr_0`, `arr_1`, ... for arrays passed positionally). Loading an .npz file with `np.load` returns a dictionary-like object whose keys are those names. The data we are about to load is such an .npz file.
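A small round trip showing both cases. It assumes in-memory `io.BytesIO` buffers as stand-ins for files on disk, and the arrays are made up for illustration.

```python
import io
import numpy as np

# One array saved with np.save: a .npy file; loading it gives the array back
npy_buffer = io.BytesIO()  # stands in for a file on disk
np.save(npy_buffer, np.array([0.1, -0.3, 0.2]))
npy_buffer.seek(0)
single = np.load(npy_buffer)  # a plain ndarray

# Several arrays saved with np.savez: one .npz archive, each array under its keyword name
npz_buffer = io.BytesIO()
np.savez(npz_buffer, spikes=np.array([0.12, 0.48, 0.91]), ttls=np.linspace(0.0, 1.0, 5))
npz_buffer.seek(0)
archive = np.load(npz_buffer)  # a dictionary-like NpzFile

print(type(single).__name__)
print(sorted(archive.files))
print(archive["ttls"])
```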

In [None]:
# `filepath` must point to the .npz data file provided for this exercise
data = load(filepath)

Now, we can retrieve the names of the stored variables with the following command:

In [None]:
data.files


To see the value of a variable, index the loaded data with its name:

In [None]:
data["volts"]  # any name listed in data.files works here

To manipulate a variable, assign its value to a new variable. Keeping the original name is the most sensible choice.

In [None]:
volts = data["volts"]

A more direct way is to copy all the stored variables into the workspace at once:

In [None]:
# works at the top level of a notebook, where locals() is the global namespace
locals().update(data)
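`locals().update` creates variables silently, so a typo or a missing array only shows up later. A more explicit alternative, sketched here with a hypothetical helper `unpack` and a plain dict standing in for the loaded NpzFile (both are indexed by name), names every variable up front and fails immediately if one is missing:

```python
import numpy as np

def unpack(store, names):
    """Hypothetical helper: return the arrays stored under `names`, in order.

    Raises KeyError if a name is missing instead of silently skipping it.
    """
    return tuple(store[name] for name in names)

# A plain dict stands in for the loaded NpzFile; both are indexed by name
demo_store = {"ttls": np.array([0.0, 0.5, 1.0]), "spikes": np.array([0.7])}
demo_ttls, demo_spikes = unpack(demo_store, ["ttls", "spikes"])
print(demo_ttls, demo_spikes)
```

With the real data this would read `volts, spikes, ttls, stim_rand_nums = unpack(data, ["volts", "spikes", "ttls", "stim_rand_nums"])`.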

Let us go over what each of these variables represents, beginning with "volts". As its name suggests, it is the raw voltage recording: an array of voltage values sampled throughout the experiment. As a sanity check that the recording contains spikes at all, plot this variable to visualize the voltage trace.

In [None]:
plt.plot(volts)

We won't work further with "volts" and will now focus on the other three variables: "spikes", "ttls" and "stim_rand_nums".

- "spikes" holds the timestamps, in seconds, of every spike.
- "ttls" holds the timestamps of the pulses that signal each stimulus presentation, i.e., the time points at which the stimulus changed.
- "stim_rand_nums" is the sequence of random numbers that set the contrast of each presentation; for simplicity we will treat them as contrast values.

As another sanity check, confirm that "ttls" and "stim_rand_nums" have the same length, as expected if each pulse marks one presentation.

In [None]:
###START CODE HERE###

len(ttls) == len(stim_rand_nums) ###Result should be: True
###END CODE HERE###

Remember that we are going to collect the stimulus chunk preceding each spike. How much time does a chunk of a given size represent? Each contrast value (random number) corresponds to one stimulus presentation, so a chunk of n values lasts n times the interval between consecutive presentations, whatever n we choose. That interval is the difference between two consecutive ttls. Compute this difference for 3 pairs of consecutive ttls.

In [None]:
###START CODE HERE###

for n in (1, 2, 3):
    print(ttls[n] - ttls[n-1])  # interval between presentations n-1 and n

###END CODE HERE###
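The reasoning above as a formula: writing $W$ for the number of stimulus presentations in a chunk and $\Delta t$ for the interval between consecutive ttls, the chunk spans

$$T_{\text{chunk}} = W \, \Delta t.$$

For instance, with a hypothetical interval of $\Delta t = 1/30$ s, a chunk of $W = 30$ values would span $30 \times 1/30 = 1$ s. The actual $\Delta t$ has to be measured from ttls.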

A more robust estimate is the average of the differences over all consecutive pairs of ttls (as you did yesterday).

In [None]:
###START CODE HERE###

avg_diff = np.mean(np.diff(ttls))

###END CODE HERE###
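What `np.diff` returns, shown on made-up pulse times (not taken from the recording). The standard deviation of the intervals is an extra sanity check assumed here: it should be close to zero if the stimulus changed at a regular rate, in which case the mean is a good summary.

```python
import numpy as np

# Made-up pulse times in seconds, for illustration only
demo_ttls = np.array([0.00, 0.50, 1.00, 1.60])

intervals = np.diff(demo_ttls)  # differences between consecutive pulses (4 pulses -> 3 intervals)
print(intervals)
print(np.mean(intervals))       # average interval between presentations
print(np.std(intervals))        # spread of the intervals; near 0 for a regular stimulus
```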

Before continuing, let's take only the relevant spikes, those occurring after the first pulse and before the last.

In [None]:
###START CODE HERE###

# Keep only spikes that fall between the first and the last stimulus pulse
spikes = spikes[(spikes > ttls[0]) & (spikes < ttls[-1])]

###END CODE HERE###
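Each comparison on a NumPy array produces a boolean mask, and `&` combines two masks element-wise. A toy example with made-up times; note that the parentheses are required because `&` binds more tightly than `>` and `<`.

```python
import numpy as np

# Made-up spike and pulse times, for illustration only
demo_spikes = np.array([0.2, 1.5, 3.1, 7.9])
demo_ttls = np.array([1.0, 2.0, 3.0, 4.0])

# True where a spike lies strictly between the first and the last pulse
inside = (demo_spikes > demo_ttls[0]) & (demo_spikes < demo_ttls[-1])
print(inside)
print(demo_spikes[inside])  # boolean indexing keeps only the True positions
```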

Since different cells integrate information over different stretches of time, we should try several time windows; 0.5-2 seconds is a reasonable range. Once we have settled on a time window, we first discard the spikes that occurred too early, before a full stimulus chunk of the desired size had been presented. The first remaining spike is then the first one for which a complete chunk can be collected.

In [None]:
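###START CODE HERE###
"""
Change the spikes vector so that it only contains spikes
preceded by at least `window` stimulus presentations.
"""

window_s = 1.0  # time window in seconds; try values between 0.5 and 2 s
window = int(round(window_s / avg_diff))  # the same window, in stimulus presentations

spikes = spikes[spikes > ttls[window]]

###END CODE HERE###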
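A toy check, on made-up times, of why the condition `spikes > ttls[window]` is enough: every spike that passes it has at least `window` presentations strictly before it (in fact one more, so the condition is slightly conservative).

```python
import numpy as np

# Toy data for illustration only: one pulse every 0.1 s, a few spikes
demo_ttls = np.arange(10) * 0.1
demo_spikes = np.array([0.05, 0.35, 0.62, 0.88])
demo_window = 4  # presentations per stimulus chunk

kept = demo_spikes[demo_spikes > demo_ttls[demo_window]]

# Count how many presentations occurred strictly before each kept spike
frames_before = [int(np.sum(demo_ttls < s)) for s in kept]
print(kept, frames_before)
```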

Now that we have our first useful spike, let's initialize an array to which we will add the stimulus portions preceding each spike:

In [None]:
###START CODE HERE

stim_vector = np.zeros(window)

###END CODE HERE

We now go through the spikes one by one (i.e., their timestamps) and, for each, take the stimulus portion preceding it: starting with the stimulus value shown immediately before the spike and stretching back over the window defined above.

In [None]:
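###START CODE HERE###
"""
It is useful to keep a counter of how many spikes
the loop went over. Check the built-in function enumerate;
with start=1, counter ends up equal to the number of spikes.
"""
for counter, spike in enumerate(spikes, start=1):
    # indices of the last `window` presentations shown before this spike
    ttls_indices = np.where(ttls < spike)[0][-window:]
    # add this spike's stimulus chunk to the running sum
    stim_vector += stim_rand_nums[ttls_indices]

###END CODE HERE###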

The STA is the sum of the collected stimulus chunks divided by the number of spikes:

In [None]:
sta = stim_vector/counter

Let's first plot the STA alone and see what we have:

In [None]:
plt.plot(sta)

Finally, plot the STA against time before the spike, with axis labels:

In [None]:
###START CODE HERE###

"""
for adequate plotting, create a vector with values within the range of
what is going to be the X-axis and plot the STA against it

"""

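past_time_limit = -avg_diff*window
# one point per stimulus presentation, spaced avg_diff apart,
# from the oldest (-window*avg_diff) to the most recent (-avg_diff)
x_ax = np.linspace(past_time_limit, -avg_diff, num=len(sta))
plt.plot(x_ax, sta)
plt.xlabel("Time before spike (s)")
plt.ylabel("Average contrast")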

###END CODE HERE

The STA shows which temporal stimulus pattern the cell "prefers". Depending on your data set, the cell may prefer darkening followed by brightening before it spikes, brightening followed by darkening, or just brightening.