<a href="https://colab.research.google.com/github/ebatty/EncodingDecodingNotes/blob/main/Notes/03_SpikeTriggeredAverages.ipynb" target="_blank"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"/></a>

# 03 - Spike Triggered Averages

Learning objectives of lecture/notes: After lecture, students should be able to:

- Interpret a spike-triggered average (STA) and know when it is valid to do so

- Compute an STA using both math and code



 Imports


In [None]:
# @markdown Imports
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
!pip install -q gdown

# Section 1: Why do we need to move beyond tuning curves?


In the last lecture, we saw the use of tuning curves as a way to understand the relationship between a stimulus and the responses of a neuron. However, this approach is pretty limited. Tuning curves are very simple: you can't use them to capture responses to complex or multiple features. You can only use them to look at the average firing rate for discrete stimulus options that you choose (like different reaches). Additionally, they don't consider timing information present in neural responses. Neurons tend to respond to aspects of the stimuli in a recent window in the past, essentially integrating over the recent past. They also respond in a time-varying manner to different stimuli. So, we need to move beyond tuning curves to more sophisticated methods of observing the relationship between stimulus and neural response.

# Section 2: One-dimensional spike triggered average (STA)



## Section 2.1: Intro to STA

We will first work with spike times from a single neuron, recorded while a 1-dimensional stimulus was presented. We will work with data collected by Rob de Ruyter van Steveninck from a fly H1 neuron responding to an approximate white-noise visual motion stimulus (from Dayan & Abbott 2001). The data was collected for 20 minutes at a sampling rate of 500 Hz.



 Execute this cell to download & visualize data


In [None]:
# @markdown Execute this cell to download & visualize data
!gdown --id  1V2QAdFLQnAb-gdFQ9-Py1AydI7tlTifm

d = sio.loadmat('c1p8.mat')

spikes = d['rho'].reshape((-1, ))
stim = d['stim'].reshape((-1,))
dt = 2  # length of each bin in milliseconds

fig, axes = plt.subplots(2, 1,figsize = (10, 5), sharex = True)

axes[0].plot(stim[:500], 'k')
axes[0].set(title = 'Stimulus')

axes[1].eventplot(np.where(spikes[:500])[0], color = 'k')
axes[1].set(title = 'Spikes', xlabel = 'Time bins')

As you can see in the plot above, for each time bin (a 2 ms chunk of time), we have the value of the stimulus and whether a spike occurred or not. The stimulus is denoted by just a single number per time bin.

So how do we start to understand what about the stimulus (if any) is causing the neuron to spike? We could simply ask **what on average was the stimulus before a spike**. For example, if a neuron responds to high values of the stimulus, you would expect the stimulus to have on average high values before a spike. This is the idea behind a spike-triggered average!

To compute the spike-triggered average, we want to gather the stimulus chunks before each spike, and take the average of them. Let's make that mathematical. Given spike times $t_1$, $t_2$, ..., $t_n$ for $n$ spikes and stimulus $s$:

$$STA(\tau) = \frac{1}{n}\sum_{i=1}^n s(t_i - \tau)$$

This tells us that the spike triggered average at some $\tau$ before a spike is equal to the average of the stimulus at that $\tau$ before every spike.

**STA in terms of time bins:**

We often work within the context of time bins, as we saw in our last lecture, so let's reframe the STA in this context, instead of thinking about continuous time. Let's assume we have the spike counts at each time bin as $y_1$, $y_2$, ..., $y_T$ and the stimulus at each time bin as $s_1$, $s_2$, ..., $s_T$ for $T$ time bins. We choose the number of time bins before a spike we want to include in our spike-triggered average - let's say in this case we choose 5 time bins. Then, our equation becomes:

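$$STA = \frac{1}{n}\sum_{i = 6}^{T} y_i \, s_{i - 5:i}$$

where $s_{i - 5:i}$ is a vector of the stimulus over the 5 time bins before bin $i$ and $n = \sum_{i=6}^{T} y_i$ is the number of spikes included in the sum. The sum starts at $i = 6$ because earlier bins do not have 5 preceding stimulus bins. Note that for all the bins where there are zero spikes, the term $y_i \, s_{i - 5:i}$ will be 0, so we could speed up our computation by just focusing on bins where there are spikes.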
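
The binned formula can be sketched directly in code. This is a minimal illustration, not the notebook's own implementation: the helper name `binned_sta` and the toy arrays are made up for the example, and Python's zero-based indexing means that bin `i` in the code corresponds to bin $i + 1$ in the equation.

```python
import numpy as np

def binned_sta(y, s, n_lags=5):
    """Spike-count-weighted average of the n_lags stimulus bins before each bin.

    Hypothetical helper written for illustration only.
    """
    y = np.asarray(y, dtype=float)
    s = np.asarray(s, dtype=float)
    sta_sum = np.zeros(n_lags)
    n_spikes = 0.0
    # Start at n_lags so that a full window of preceding bins exists
    for i in range(n_lags, len(y)):
        if y[i] == 0:
            continue  # bins without spikes contribute nothing to the sum
        sta_sum += y[i] * s[i - n_lags:i]  # the n_lags bins before bin i
        n_spikes += y[i]
    if n_spikes == 0:
        raise ValueError("no spikes with a full preceding stimulus window")
    return sta_sum / n_spikes

# Made-up toy data (not from the H1 recording)
toy_s = np.arange(10, dtype=float)
toy_y = np.array([0, 0, 0, 0, 0, 1, 0, 2, 0, 0])
toy_sta = binned_sta(toy_y, toy_s, n_lags=5)
print(toy_sta)
```

Because the bin with a spike count of 2 is weighted twice, its preceding window counts as two spikes in both the sum and the divisor.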

The cell below shows the STA computed for the data above.

 Execute to visualize STA


In [None]:
# @markdown Execute to visualize STA
# Figure out how many time bins we want the STA to be
STA_length = 150

# Figure out when spikes occur
spike_times = np.where(spikes)[0]

# Ignore spike times less than STA length (since there won't be enough preceding stimulus)
spike_times = spike_times[spike_times > STA_length]

# Initialize STA
STA = np.zeros((STA_length,))

# Loop over spikes
for i_sp in spike_times:
    
    # Add preceding stimulus
    STA += stim[i_sp - STA_length + 1: i_sp + 1]

# Divide by number of spikes
STA = STA / len(spike_times)

time_before_spike = (-np.arange(STA_length)*dt)[::-1]

fig, ax = plt.subplots(1, 1)

ax.plot(time_before_spike, STA, 'k', lw = 2)

ax.set(title = 'STA', 
      xlabel = 'Time before spike (ms)',
      ylabel = 'Stimulus intensity');

## Section 2.2: Computing STAs by hand

Let's compute a spike-triggered average (STA) with a numerical example. Let's say we have the following stimulus values, where each entry in the vector is the stimulus at a time bin (so the first entry is $s_1$, the second is $s_2$, and so on).

$s$ =    [0, 1, 1, 0, 2, -1, 0, 3, 1, 0, -4]

We have the following spike counts for the corresponding time bins:

$y$ = [0, 0, 0, 1, 0,  1,  0, 1, 0, 0, 0]




```{admonition} **Stop and think!** What is the spike-triggered average? Let's use 3 time bins, including the time bin in which the spike occurs.
:class: tip, dropdown
Our first spike is at time step 4. The previous stimulus is [1, 1, 0]. Remember we're including the time step of the spike, so we are not using [0, 1, 1]. 

The next spike is at time step 6. The previous stimulus is [0, 2, -1].

The next spike is at time step 8. The previous stimulus is [-1, 0, 3].

So our spike-triggered stimuli are: [1, 1, 0], [0, 2, -1], and [-1, 0, 3]. 

Our STA is the average of those. Summing them gives [0, 3, 2], and dividing by the 3 spikes gives: [0, 1, 2/3].
```
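
The hand computation can be checked with a few lines of NumPy. This is a small sketch using the example vectors from this section; the variable names (`s_example`, `y_example`, `hand_windows`, `hand_sta`) are introduced here for illustration. Note that the time steps in the text count from 1, while NumPy indices count from 0.

```python
import numpy as np

s_example = np.array([0, 1, 1, 0, 2, -1, 0, 3, 1, 0, -4])
y_example = np.array([0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0])
window = 3  # number of bins per window, including the spike bin

# Zero-based indices of bins containing a spike
example_spike_bins = np.where(y_example)[0]

# Stack the 3-bin stimulus window that ends at each spike bin
hand_windows = np.array(
    [s_example[i - window + 1: i + 1] for i in example_spike_bins]
)

# Average across spikes (sum of the windows divided by the number of spikes)
hand_sta = hand_windows.mean(axis=0)

print(hand_windows)
print(hand_sta)
```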

## Section 2.3: Computing STAs in code

Let's dive into some code that computes a spike-triggered average on our data. `spikes` is an array that gives the sequence of spiking events at the sampled times (every 2 ms). When an element of `spikes` is one, this indicates the presence of a spike at the corresponding time, whereas a zero value indicates no spike. 

The variable `stim` gives the sequence of stimulus values at the sampled times. 

We will compute a spike-triggered average that is 300 ms long (the 300 ms before each spike).

See the cell below for code to compute a spike triggered average. We dive into more details of the steps in the text below.

In [None]:
# Step 1) Figure out how many time bins we want the STA to be
STA_length = int(300 / 2)

# Step 2) Figure out when spikes occur
spike_times = np.where(spikes)[0]

# Step 3) Ignore spike times less than STA length
spike_times = spike_times[spike_times > STA_length]

# Step 4) Initialize STA
STA = np.zeros((STA_length,))

# Step 5) Loop over spikes & add preceding stimulus
for i_sp in spike_times:
    
    # Add preceding stimulus
    STA += stim[i_sp - STA_length + 1: i_sp + 1]

# Step 6) Divide by number of spikes
STA = STA / len(spike_times)

# Visualize
fig, ax = plt.subplots(1, 1)

time_before_spike = (-np.arange(STA_length)*dt)[::-1]
ax.plot(time_before_spike, STA, 'k', lw = 2)

ax.set(title = 'STA', 
      xlabel = 'Time before spike (ms)',
      ylabel = 'Stimulus intensity');

Step 1) We first figure out how many time bins we want the STA to be. If we want it to be 300 ms long and each bin is 2 ms, we want 300 ms / (2 ms/bin) = 150 bins. 

Step 2) We then find the bins where spikes occur. We can use `np.where` to get the indices of the entries in `spikes` that are non-zero. Our maximum spike count in `spikes` is 1 so we don't have to worry about time bins having more than one spike (although we sometimes would!).

Step 3) We will just ignore spike times that occur really early in our data, as there is not a long enough preceding stimulus to use for our averaging. 

Step 4) We want to initialize our STA so we can add stimulus chunks to it iteratively. We create a vector of zeros the correct length.

Step 5) We loop over spike times. For each spike time, we take the 150 time bins of the stimulus ending at the spike using `stim[i_sp - STA_length + 1: i_sp + 1]` and add them to the STA. The +1 on both ends of the slice shifts the window forward by one bin so that it includes the time bin of the spike.

Step 6) We have summed over the stimuli before spikes but we want to average, so we need to divide by the number of spikes.
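
The same six steps can also be written without an explicit loop, using NumPy indexing to gather every spike-triggered window at once. This is an optional sketch under stated assumptions: `sta_vectorized` is a hypothetical helper, the data are synthetic stand-ins rather than the H1 recording, and each bin is assumed to hold at most one spike. It keeps spikes from bin `sta_length - 1` onward, which is slightly less conservative than the `> STA_length` cutoff used in the cell above.

```python
import numpy as np

def sta_vectorized(stim, spikes, sta_length):
    """Average the sta_length stimulus bins ending at (and including) each spike bin."""
    stim = np.asarray(stim, dtype=float)
    spike_bins = np.flatnonzero(spikes)
    # Keep only spikes with a full window of preceding stimulus
    spike_bins = spike_bins[spike_bins >= sta_length - 1]
    # Offsets -(sta_length - 1), ..., 0 relative to each spike bin
    offsets = np.arange(-sta_length + 1, 1)
    # Rows are spikes, columns are time bins within the window
    windows = stim[spike_bins[:, None] + offsets[None, :]]
    return windows.mean(axis=0)

# Synthetic stand-in data (assumption: not the real recording)
rng = np.random.default_rng(0)
demo_stim = rng.standard_normal(2000)
demo_spikes = (rng.random(2000) < 0.05).astype(int)

demo_sta = sta_vectorized(demo_stim, demo_spikes, sta_length=20)

# Loop-based reference following the same steps, for comparison
ref_bins = np.flatnonzero(demo_spikes)
ref_bins = ref_bins[ref_bins >= 20 - 1]
ref_sta = np.zeros(20)
for i_sp in ref_bins:
    ref_sta += demo_stim[i_sp - 20 + 1: i_sp + 1]
ref_sta /= len(ref_bins)

print(np.allclose(demo_sta, ref_sta))
```

The vectorized version builds an array of shape (number of spikes x STA length), so for very long recordings the loop may be the more memory-friendly choice.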

# Section 3: STAs for higher-dimensional stimuli

So far, we have only looked at computing STAs for a 1d stimulus over time, in other words, a stimulus that has only one data point per time bin. We can compute an STA for a stimulus with any number of dimensions. For example, let's say the stimulus is a movie. Instead of looking at the average number that triggered a spike at certain delays before a spike (as in the case of the 1d stimulus), we want to look at the average image that triggered a spike at various time bins before the spike. So we want the average image 1 time step back, the average image 2 time steps back, and so on. This idea is conveyed in the image below. We'd end up with a spike-triggered average of shape (number of time bins of STA x number of pixels x number of pixels).

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/2/2c/Illustration_diagram_for_the_Spike-triggered_average.pdf/page1-1335px-Illustration_diagram_for_the_Spike-triggered_average.pdf.jpg" alt="STA from movie">

*Spike triggered average with a movie stimulus from https://upload.wikimedia.org/wikipedia/commons/2/2c/Illustration_diagram_for_the_Spike-triggered_average.pdf*

We may have stimuli that aren't images but have more than 1 dimension. We'd treat these similarly to the movie, averaging over each feature (or dimension) separately. So if our stimulus is represented by 3 numbers per time bin, we would get an STA of shape (number of time bins x 3).
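
As a rough sketch of how this works in code, the loop from Section 2.3 carries over almost unchanged when the stimulus has extra dimensions, because slicing along the first (time) axis keeps all remaining axes. The helper `sta_multidim` and the random arrays below are assumptions made for illustration; they are not part of the dataset used in these notes.

```python
import numpy as np

def sta_multidim(stim, spikes, sta_length):
    """STA for a stimulus of shape (n_bins, ...); returns shape (sta_length, ...)."""
    stim = np.asarray(stim, dtype=float)
    spike_bins = np.flatnonzero(spikes)
    # Keep only spikes with a full window of preceding stimulus
    spike_bins = spike_bins[spike_bins >= sta_length - 1]
    sta = np.zeros((sta_length,) + stim.shape[1:])
    for i_sp in spike_bins:
        # Slicing the time axis keeps every feature/pixel dimension intact
        sta += stim[i_sp - sta_length + 1: i_sp + 1]
    return sta / len(spike_bins)

# Synthetic stand-in data (made up for this example)
rng = np.random.default_rng(1)
demo_spikes_nd = rng.random(5000) < 0.05

feat_stim = rng.standard_normal((5000, 3))      # 3 numbers per time bin
feat_sta = sta_multidim(feat_stim, demo_spikes_nd, sta_length=10)

movie_stim = rng.standard_normal((5000, 8, 8))  # an 8 x 8 pixel movie
movie_sta = sta_multidim(movie_stim, demo_spikes_nd, sta_length=10)

print(feat_sta.shape)   # (10, 3)
print(movie_sta.shape)  # (10, 8, 8)
```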

# Section 4: Interpreting STAs

## Section 4.1: STAs ~ receptive fields
The spike-triggered average tells us what about the stimulus leads the neuron to respond, allowing us to interpret something about neural processing. In other words, the STA provides an estimate of a neuron's linear receptive field. Let's take another look at the STA from our example data.

 Execute to visualize STA


In [None]:
# @markdown Execute to visualize STA
# Figure out how many time bins we want the STA to be
STA_length = 150

# Figure out when spikes occur
spike_times = np.where(spikes)[0]

# Ignore spike times less than STA length (since there won't be enough preceding stimulus)
spike_times = spike_times[spike_times > STA_length]

# Initialize STA
STA = np.zeros((STA_length,))

# Loop over spikes
for i_sp in spike_times:
    
    # Add preceding stimulus
    STA += stim[i_sp - STA_length + 1: i_sp + 1]

# Divide by number of spikes
STA = STA / len(spike_times)

time_before_spike = (-np.arange(STA_length)*dt)[::-1]

fig, ax = plt.subplots(1, 1)

ax.plot(time_before_spike, STA, 'k', lw = 2)

ax.set(title = 'STA', 
      xlabel = 'Time before spike (ms)',
      ylabel = 'Stimulus intensity');

```{admonition} **Stop and think!** What can you interpret about this fly H1 neuron response to the visual motion stimulus from the STA?
:class: tip, dropdown
High values of the stimulus tend to prompt a neural spike around 50 milliseconds later.
```

 Execute to visualize another STA


In [None]:
# @markdown Execute to visualize another STA
time_before_spike = (-np.arange(STA_length)*dt)[::-1]

fig, ax = plt.subplots(1, 1)

ax.plot(time_before_spike, np.zeros((len(time_before_spike),)), 'k', lw = 2)

ax.set(title = 'STA', 
      xlabel = 'Time before spike (ms)',
      ylabel = 'Stimulus intensity');

```{admonition} **Stop and think!** If you saw a completely flat STA, as above, what would you guess about the neuron?
:class: tip, dropdown
You could hypothesize that the neural response doesn't depend on the stimulus. You might see this if you are recording from an auditory neuron and trying to relate its responses to a visual stimulus, for example! However, this may not be true: the neural response may depend on the stimulus in a way that's more complex than an STA can capture.
```
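
To make that last point concrete, here is a small simulation sketch. It assumes a hypothetical neuron (not the H1 neuron) whose spike probability depends on the square of the stimulus 3 bins earlier, so large positive and large negative values drive it equally. All names and parameter values are invented for the example.

```python
import numpy as np

rng = np.random.default_rng(2)
n_bins, sta_length, lag = 50000, 10, 3

sym_stim = rng.standard_normal(n_bins)

# Hypothetical neuron: spike probability grows with the squared stimulus `lag` bins earlier
p_spike = np.zeros(n_bins)
p_spike[lag:] = np.clip(0.1 * sym_stim[:-lag] ** 2, 0, 1)
sym_spikes = rng.random(n_bins) < p_spike

# Gather the stimulus window ending at each spike bin
sym_spike_bins = np.flatnonzero(sym_spikes)
sym_spike_bins = sym_spike_bins[sym_spike_bins >= sta_length - 1]
sym_offsets = np.arange(-sta_length + 1, 1)
sym_windows = sym_stim[sym_spike_bins[:, None] + sym_offsets]

# Spike-triggered average: positive and negative drivers cancel out
sym_sta = sym_windows.mean(axis=0)

# Spike-triggered mean of the squared stimulus: elevated at the driving lag
sym_power = (sym_windows ** 2).mean(axis=0)

print(np.round(sym_sta, 2))
print(np.round(sym_power, 2))
```

In this simulation the STA stays near zero at every lag even though the spikes depend strongly on the stimulus, while the average squared stimulus stands out 3 bins before the spike. A flat STA on its own therefore does not rule out stimulus dependence.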

```{admonition} **Stop and think!** Why do we use the prior stimulus before a spike to compute the STA? Why not also use the stimulus after a spike?
:class: tip, dropdown
A neuron could not respond based on a future stimulus that it hasn't seen yet. If we were looking at a behavioral variable, it could be useful to look at the spike-triggered average in the past and future as neurons could be firing in advance of some behavior. However, behavior doesn't usually fulfil the input requirements for useful spike-triggered averages (detailed in next section).
```

## Section 4.2: Stimulus requirements

The spike-triggered average is only meaningful to compute and interpret for certain types of stimuli: specifically, uncorrelated stimuli, where the values of the stimulus are **independent and identically distributed** at each time bin. In other words, the value of the stimulus at a certain time should not depend on the values of the stimulus before or after, and should be drawn at random from a probability distribution (the same distribution should be used for all time bins). **White noise stimuli** fulfill these requirements.
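
One informal way to check whether a stimulus is uncorrelated in time is to look at its sample autocorrelation, which should be close to zero at every nonzero lag. The sketch below compares synthetic white noise with a smoothed copy of it; the helper `autocorr` and both signals are assumptions made for illustration, and near-zero autocorrelation is a necessary sign of independence rather than a proof of it.

```python
import numpy as np

def autocorr(x, max_lag):
    """Sample autocorrelation at lags 0..max_lag, normalized so that lag 0 equals 1."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    denom = np.dot(x, x)
    return np.array(
        [np.dot(x[:len(x) - k], x[k:]) / denom for k in range(max_lag + 1)]
    )

rng = np.random.default_rng(3)
white = rng.standard_normal(20000)

# A 10-bin moving average introduces correlations between nearby bins
smooth = np.convolve(white, np.ones(10) / 10, mode='same')

ac_white = autocorr(white, 5)
ac_smooth = autocorr(smooth, 5)

print(np.round(ac_white, 3))   # near zero for lags >= 1
print(np.round(ac_smooth, 3))  # clearly nonzero for lags >= 1
```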

So, why can't we interpret spike-triggered averages with correlated stimuli? Imagine we had a repeating stimulus, as below, and that we presented it to the fly H1 neuron we've been working with and recorded the spikes. Please note that the data below is simulated, not recorded.



 Execute cell to visualize correlated stimulus and simulated responses


In [None]:
# @markdown Execute cell to visualize correlated stimulus and simulated responses

# Create repeating stimulus
stim_segment = -STA[::-1] + STA
corr_stim = np.tile(stim_segment, int(stim.shape[0]/stim_segment.shape[0]))

# Create model for fake spikes
STA_stim = np.zeros((stim.shape[0],))
for i_t in range(150, stim.shape[0]):
    STA_stim[i_t] = np.dot(stim[i_t - STA_length + 1: i_t + 1], STA)

bins = np.linspace(np.min(STA_stim), np.max(STA_stim), 25)

all_stim_bins, _ = np.histogram(STA_stim, bins)
spike_stim_bins, _ = np.histogram(STA_stim[spike_times], bins)
nonlin = spike_stim_bins/all_stim_bins

# Create fake spikes
STA_stim = np.zeros((stim.shape[0],))
sim_spikes = np.zeros((stim.shape[0],))
all_nonlin = np.zeros((stim.shape[0],))

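for i_t in range(STA_length, stim.shape[0]):
    # Project the preceding stimulus window onto the STA (linear filter output)
    STA_stim[i_t] = np.dot(corr_stim[i_t - STA_length + 1: i_t + 1], STA)

    # Look up the spike rate for this filter output, clamping outside the bin range
    if STA_stim[i_t] < bins[0]:
        this_nonlin = nonlin[0]
    elif STA_stim[i_t] > bins[-1]:
        this_nonlin = nonlin[-1]
    else:
        nonlin_bin, _ = np.histogram(STA_stim[i_t], bins)
        # Take the single matching bin as a scalar rather than a length-1 array
        this_nonlin = nonlin[np.where(nonlin_bin)[0][0]]

    all_nonlin[i_t] = this_nonlin
    sim_spikes[i_t] = np.random.poisson(this_nonlin)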

# Visualize
fig, axes = plt.subplots(2, 1,figsize = (10, 5), sharex = True)

axes[0].plot(corr_stim[:1500], 'k')
axes[0].set(title = 'Stimulus')

axes[1].eventplot(np.where(sim_spikes[:1500])[0], color = 'k')
axes[1].set(title = 'Spikes', xlabel = 'Time bins');

The neuron always spikes around the same point in the pattern. This means that when we average over the stimuli before the spikes, we get unwanted structure in our STA. See the STA computed in the cell below. As we've seen, the neuron responds to high stimulus values about 50 ms before the spike. Because of the repeating nature of the stimulus, however, this STA makes it look like the neuron also responds to negative values of the stimulus around .. before the spike, which is incorrect! In essence, because the stimulus is correlated, we haven't explored enough stimulus patterns to get a good sense of how the neural response depends on the stimulus. This is an especially egregious case of a correlated stimulus (we have repeating chunks), but any amount of correlation can interfere with our ability to estimate the linear receptive field well!

 Execute to visualize STA computed from correlated stimulus


In [None]:
# @markdown Execute to visualize STA computed from correlated stimulus

# Step 1) Figure out how many time bins we want the STA to be
STA_length = int(300 / 2)

# Step 2) Figure out when spikes occur
spike_times = np.where(sim_spikes)[0]

# Step 3) Ignore spike times less than STA length
spike_times = spike_times[spike_times > STA_length]

# Step 4) Initialize STA
sim_STA = np.zeros((STA_length,))

# Step 5) Loop over spikes & add preceding stimulus
for i_sp in spike_times:
    
    # Add preceding stimulus
    sim_STA += corr_stim[i_sp - STA_length + 1: i_sp + 1]

# Step 6) Divide by number of spikes
sim_STA = sim_STA / len(spike_times)


fig, ax = plt.subplots(1, 1, figsize = (10, 7))
ax.plot(time_before_spike, sim_STA, 'k', lw = 2)

ax.set(title = 'STA', 
      xlabel = 'Time before spike (ms)',
      ylabel = 'Stimulus intensity');

# Optional Reading

Dayan & Abbott cover spike triggered averages in Chapter 1: http://www.gatsby.ucl.ac.uk/~lmate/biblio/dayanabbott.pdf

More advanced than we go, but a very thorough review of spike-triggered neural characterization: https://jov.arvojournals.org/article.aspx?articleid=2192881