# **Spike Triggered Average – STA**

We have already looked into how to handle and analyze spike trains recorded under visual stimulation. We will now turn to a very common tool in systems neuroscience: the spike-triggered average (STA).
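Formally, the STA is the mean of the stimulus segments that preceded each spike. For $N$ spikes at times $t_1, \dots, t_N$ and a stimulus $s(t)$, the STA at a lag $\tau \ge 0$ (how far before the spike we look) is

$$\mathrm{STA}(\tau) = \frac{1}{N} \sum_{n=1}^{N} s(t_n - \tau).$$

When the stimulus changes in discrete frames of duration $\Delta t$, only lags that are multiples of $\Delta t$ matter, so for a window of $W$ frames the STA reduces to $W$ numbers: the average value of each of the $W$ frames preceding a spike.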

**(here goes the theory, if not in a separate presentation)**


**You will learn to:**
-  Match the signal from the stimulus generator to the actual presentation of the stimulus.
-  Collect the stimulus portions preceding each spike, given a certain time window.
-  Find the average stimulus that evokes spiking.
-  **(simulation, fitting, shuffle analysis?)**

First, let's import relevant packages:

In [None]:
%matplotlib inline
%config InlineBackend.rc={'figure.figsize': (12, 6), 'font.size': 14 }
from matplotlib import pyplot as plt
import numpy as np
from numpy import load
from pathlib import Path

# **1. Full-field flicker stimulus**

The data you are going to work with comes from an extracellular recording of a retinal ganglion cell stimulated with a full-field flicker: a stimulus in which the whole screen displays a single contrast level that changes at a fixed frequency. The contrast value of each presentation is calculated from a number drawn "randomly" from a (...).
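To make the stimulus concrete, here is a minimal sketch of a full-field flicker: one number per frame, held on the whole screen until the next frame begins. The frame rate, number of frames and Gaussian distribution below are assumptions chosen for illustration, not the values used in the recording, and `toy_stimulus_at` is a hypothetical helper.

```python
import numpy as np

# Hypothetical parameters: the real refresh rate and distribution are not given here.
TOY_FRAME_RATE_HZ = 30.0   # assumed rate at which the contrast changes
TOY_N_FRAMES = 300         # assumed number of frames in this toy stimulus

rng = np.random.default_rng(seed=0)

# One "random" number per frame; drawing them from a Gaussian is an assumption.
toy_rand_nums = rng.standard_normal(TOY_N_FRAMES)

# Onset time (s) of each frame; in the recording, the TTL pulses play this role.
toy_onsets = np.arange(TOY_N_FRAMES) / TOY_FRAME_RATE_HZ


def toy_stimulus_at(t):
    """Value on screen at time t (s, assumed >= 0): that of the last frame with onset <= t."""
    frame_idx = np.searchsorted(toy_onsets, t, side="right") - 1
    return toy_rand_nums[frame_idx]


# 0.05 s falls inside frame 1 (onsets at 0.033 s and 0.067 s), so this prints True.
print(toy_stimulus_at(0.05) == toy_rand_nums[1])
```

Because the screen is uniform, the whole stimulus is described by a single number per frame plus the frame onset times; this is why the STA below can be computed from just "stim_rand_nums" and "ttls".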

# **2. Loading the data**

In [None]:
filepath = Path("/home/juand/")   # point this at the .npz file with the recording

data = load(filepath)   # for .npz files, returns a dict-like NpzFile


The file we just loaded stores several variables, whose names can be listed as follows:

In [None]:
data.files

To access and manipulate these variables directly by name, we can copy them into the namespace of the current workspace:

In [None]:
locals().update(data)
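Updating `locals()` works here because, at the top level of a notebook, `locals()` is the global namespace; inside a function the same call would not create usable local variables, and in any case it hides where each name comes from. A more explicit alternative is to index the loaded archive by key. The sketch below creates a small synthetic `.npz` file (hypothetical name and contents, not the real recording) so that it runs on its own.

```python
import tempfile
from pathlib import Path

import numpy as np

# Build a small stand-in archive with the same variable names as the recording.
# File name and contents are hypothetical; the arrays are NOT real data.
toy_path = Path(tempfile.mkdtemp()) / "toy_recording.npz"
np.savez(
    toy_path,
    volts=np.zeros(1000),
    spikes=np.array([0.1, 0.5, 0.9]),
    ttls=np.arange(10) * 0.1,
    stim_rand_nums=np.zeros(10),
)

# For .npz files, np.load returns a dict-like NpzFile; the context manager closes it.
with np.load(toy_path) as archive:
    print(archive.files)             # names of the stored arrays
    toy_spikes = archive["spikes"]   # explicit access by key, no locals() needed
    toy_ttls = archive["ttls"]

print(toy_spikes.shape, toy_ttls.shape)
```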

Of these variables, four are of interest to us: volts, spikes, ttls and stim_rand_nums. Let us begin with volts. As its name suggests, it is the raw voltage recording, i.e., an array of voltage values sampled throughout the recording. To visualize the voltage trace, plot this variable:

In [None]:
plt.plot(volts)

You probably cannot make much out of this graph: the spikes, which are part of the voltage trace, are so numerous that at this scale they cannot be told apart. To see them better, create a new variable containing only a short segment of "volts" and plot it again:

In [None]:
x = 10000   # number of samples to display (arbitrary; adjust as needed)
volts_1 = volts[:x]
plt.plot(volts_1)

We won't be working further with "volts" and will now focus on "spikes", "ttls" and "stim_rand_nums". "spikes" is an array of timestamps, in seconds, one for each spike. "ttls" is an array of timestamps of the pulses generated to signal the presentation of the stimulus, i.e., the time points at which the stimulus changed. "stim_rand_nums" is the sequence of "random" numbers that determined the contrast of each stimulus presentation.
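The code cells below use `spikes_adjusted` and `ttls_adjusted`, which are not defined in this section; presumably they come from the step of matching the stimulus-generator signal to the actual stimulus presentation. As a minimal sketch, assuming that "adjusting" simply means discarding spikes outside the stimulation period and expressing all times relative to the first TTL, it could look like this (`align_to_stimulus` is a hypothetical helper, not the course's own function):

```python
import numpy as np


def align_to_stimulus(spikes, ttls):
    """Hypothetical alignment step: keep only spikes between the first and last TTL
    and express spike and TTL times relative to the first TTL (stimulus onset)."""
    spikes = np.asarray(spikes, dtype=float)
    ttls = np.asarray(ttls, dtype=float)
    t0 = ttls[0]
    during_stim = (spikes >= ttls[0]) & (spikes <= ttls[-1])
    return spikes[during_stim] - t0, ttls - t0


# Made-up timestamps (s): the stimulus runs from 2.0 s to 3.0 s.
toy_spikes = np.array([1.5, 2.05, 2.31, 2.9, 4.2])
toy_ttls = np.linspace(2.0, 3.0, 11)
toy_spikes_adj, toy_ttls_adj = align_to_stimulus(toy_spikes, toy_ttls)
print(toy_spikes_adj)   # the spikes at 1.5 s and 4.2 s are dropped
```

If the real recording has a known delay between the TTL pulse and the actual change on screen, that offset would also have to be applied here; this sketch does not assume any.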

Remember that we are going to collect the stimulus portion preceding each spike. How long should these portions be, i.e., how far back into the past should we look? There is no single answer: different cells integrate information over different stretches of time, so we have to try several time windows. Here the window is measured in stimulus frames. The first thing we need to do is find the index of the first spike that occurred late enough after stimulus onset for a full window of stimulus frames to precede it.
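The search relies on `np.searchsorted(ttls, t)`, which returns the position at which `t` would be inserted to keep `ttls` sorted. With the default `side='left'`, that equals the number of TTL timestamps strictly smaller than `t`, i.e., the number of frames that had started before the spike. A spike is usable only if at least `window` frames precede it. A toy example with made-up timestamps:

```python
import numpy as np

toy_ttls = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])   # made-up frame onsets (s)
toy_spikes = np.array([0.05, 0.25, 0.42])             # made-up spike times (s)
toy_window = 3                                        # frames to look back

# Number of frame onsets strictly before each spike.
n_frames_before = np.searchsorted(toy_ttls, toy_spikes)
print(n_frames_before)            # [1 3 5]

# Only spikes preceded by at least toy_window frames yield a full-length segment.
usable = n_frames_before >= toy_window
print(usable)                     # [False  True  True]

# Index of the first usable spike (assumes at least one exists).
first_spike = int(np.argmax(usable))
print(first_spike)                # 1
```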

In [None]:
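###START CODE HERE###
window = 15   # number of stimulus frames to look back (try different values)
i = 0
while True:
    # Number of frames whose onset precedes spike i
    index = np.searchsorted(ttls_adjusted, spikes_adjusted[i])
    if index < window:   # fewer than `window` preceding frames: try the next spike
        i += 1
    else:
        break
###FINISH CODE HERE###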

Now that we know which spike to begin with, let's initialize the variable that will store the stimulus portions preceding each spike:

In [None]:
stim_matrix = []

We now go through the spikes one by one (i.e., through their timestamps) and, for each, collect the stimulus portion preceding it. Each portion ends with the stimulus value on screen when the spike occurred (the last frame that started before it) and stretches back by the number of frames set in "window":

In [None]:
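while i < len(spikes_adjusted):
    # Number of frames that started before spike i
    idx = np.searchsorted(ttls_adjusted, spikes_adjusted[i])
    # The `window` frames preceding the spike, ending with the one on screen at spike time
    stim_vect = stim_rand_nums[idx - window: idx]
    stim_matrix.append(stim_vect)
    i += 1

stim_matrix = np.asarray(stim_matrix)   # shape: (number of spikes, window)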
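The same matrix can be built without a Python loop by computing every index at once. A minimal sketch, with `build_stim_matrix` as a hypothetical helper and made-up data; assuming the spike times are sorted, keeping all spikes with at least `window` preceding frames is equivalent to starting from the first usable spike as above:

```python
import numpy as np


def build_stim_matrix(spike_times, frame_onsets, stim_values, window):
    """Return an (n_usable_spikes, window) array: for each spike preceded by at least
    `window` frames, the `window` stimulus values ending with the frame on screen."""
    idx = np.searchsorted(frame_onsets, spike_times)     # frames started before each spike
    idx = idx[idx >= window]                             # keep spikes with a full window
    offsets = np.arange(-window, 0)                      # -window, ..., -1
    return stim_values[idx[:, None] + offsets[None, :]]  # all rows at once via fancy indexing


# Toy comparison with the loop used above (made-up data).
rng = np.random.default_rng(seed=1)
toy_onsets = np.arange(200) * 0.1
toy_values = rng.standard_normal(200)
toy_spike_times = np.sort(rng.uniform(0.0, 20.0, size=50))
toy_window = 15

fast = build_stim_matrix(toy_spike_times, toy_onsets, toy_values, toy_window)

slow = []
for t in toy_spike_times:
    k = np.searchsorted(toy_onsets, t)
    if k >= toy_window:
        slow.append(toy_values[k - toy_window:k])
slow = np.asarray(slow)

print(fast.shape, np.array_equal(fast, slow))
```

The vectorized form mainly pays off for long recordings with many spikes; the loop is easier to read and step through.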

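We calculate the STA as the average of the collected stimulus portions (the mean over the rows of "stim_matrix") and build a time axis with one point per stimulus frame: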

In [None]:
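sta = stim_matrix.mean(axis=0)   # average stimulus over all collected spikes
frame_dt = ttls_adjusted[2] - ttls_adjusted[1]   # duration of one stimulus frame (s)
# One point per frame, spaced frame_dt; 0 is the frame on screen when the spike occurred
x_ax = np.linspace(-frame_dt * (window - 1), 0, num=len(sta))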

And finally we plot it:

In [None]:
plt.plot(x_ax,sta)
plt.xlabel('Time (s)')
plt.ylabel('Contrast')
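A useful sanity check, sketched here as an assumption rather than as the course's own method, is to run the same computation on a simulated cell whose filter is known. The sketch assumes a linear-nonlinear-Poisson cell driven by Gaussian white noise, with a made-up temporal filter and a rectifying nonlinearity, and attributes each spike to the frame on which it occurs. For Gaussian white-noise input, the STA of such a model is proportional to its linear filter, so the two should be strongly correlated.

```python
import numpy as np

rng = np.random.default_rng(seed=42)

# Hypothetical simulated cell (linear-nonlinear-Poisson): all parameters are made up.
sim_n_frames, sim_window = 50_000, 15
sim_stim = rng.standard_normal(sim_n_frames)        # Gaussian white-noise flicker

lags = np.arange(sim_window)                        # lag 0 = current frame
true_filter = np.exp(-lags / 3.0) * np.sin(lags / 2.0 + 0.5)
true_filter /= np.linalg.norm(true_filter)

# Linear stage: drive[k] = sum over lags of true_filter[lag] * sim_stim[k - lag].
drive = np.convolve(sim_stim, true_filter)[:sim_n_frames]
counts = rng.poisson(np.maximum(drive, 0.0))        # rectify, then draw spike counts per frame

# STA: spike-count-weighted average of the sim_window frames ending with the spike's frame,
# ordered oldest -> newest like the rows of stim_matrix.
frames = np.arange(sim_window - 1, sim_n_frames)    # frames with a full history
segments = sim_stim[frames[:, None] - lags[None, ::-1]]
weights = counts[frames]
sim_sta = (weights[:, None] * segments).sum(axis=0) / weights.sum()

# Reverse the STA so that index 0 is lag 0, then compare its shape with the true filter.
similarity = np.corrcoef(sim_sta[::-1], true_filter)[0, 1]
print(round(float(similarity), 3))
```

With fewer frames, and hence fewer spikes, the estimate becomes noisier, which is one reason long stimulus sequences are recorded.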