# Working with Stimuli

So far we have learned how to load and handle spike times. Now we will look at the stimuli that elicit these spikes, so that we can investigate the stimulus-response relationship of the recorded cells.

**You will learn to:**
 - Align spike times to stimuli
 - Split spike times into multiple trials
 - Plot a peristimulus time histogram (PSTH)
 - Compute and compare the on-off index across several cells
 
Let's start again by importing relevant packages.

In [None]:
%config InlineBackend.rc={'figure.figsize': (8, 4), 'font.size': 10}
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path

## 1 - On-Off Stimulus

We will take a look at a stimulus consisting of on-off steps in light intensity. This is roughly what the stimulus looks like.

<p><center><img src="images/onoffsteps.gif" width="175px"></center></p>

The information about the stimulus is stored in a text file. For this stimulus, the file contains the time of each step transition.

**Exercise:** Load the stimulus step times from the file `filepath` into the variable `stimulus`, as we learned yesterday with the spike trains.

In [None]:
filepath = 'data_on_off_steps/stimulus.txt'

### START CODE HERE ###
stimulus = np.loadtxt(filepath)
### END CODE HERE ###

In [None]:
# Let's check if your code is correct
print(stimulus[:5])

**Expected output:**  
`[10.0405 12.0263 13.0192 15.0051 15.998 ]`

To understand the loaded data, we will plot the light intensity against time.

For this, we need to assign the corresponding light intensity to each stimulus time. The light intensities cycle through  
`0.5, 1, 0.5, 0, 0.5, 1, 0.5, 0, ...`

To obtain the number of cycles, we divide the number of elements in the stimulus array by four and round down:

In [None]:
cycles = np.floor(len(stimulus) / 4)

`cycles` is now a float, but for later use we need it as an integer.
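As an aside, rounding down and converting to an integer can be done in one step with Python's floor-division operator `//`, which returns an `int` when both operands are integers. A minimal sketch, using a hypothetical array `toy_stimulus` in place of the loaded stimulus:

```python
import numpy as np

# Hypothetical stand-in for the loaded stimulus times (10 transitions)
toy_stimulus = np.arange(10.0)

# Floor division divides, rounds down, and yields an int directly
toy_cycles = len(toy_stimulus) // 4
print(toy_cycles, type(toy_cycles))
```

Here `len` returns an `int`, so `10 // 4` gives the integer `2` without a separate `int()` call.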

In [None]:
cycles = int(cycles)

To create a list of these repeating values, we can append them to a list in a loop:

In [None]:
intensities = []
for i in range(cycles):
    intensities.append(0.5)
    intensities.append(1)
    intensities.append(0.5)
    intensities.append(0)

Here is a more concise way, where we concatenate two lists using the augmented assignment operator `+=`.

In [None]:
intensities = []
for i in range(cycles):
    intensities += [0.5, 1, 0.5, 0]
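
NumPy can also build the repeated pattern without a loop: `np.tile` repeats an array a given number of times. A minimal sketch, with a hypothetical number of cycles `toy_cycles`:

```python
import numpy as np

pattern = [0.5, 1, 0.5, 0]  # gray, on, gray, off
toy_cycles = 3              # hypothetical number of cycles

# Repeat the four-step pattern toy_cycles times
toy_intensities = np.tile(pattern, toy_cycles)
print(toy_intensities)
```

Unlike the loop versions, this returns a NumPy array rather than a list, which `plt.step` accepts just as well.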

In [None]:
print(intensities)

In [None]:
print(len(intensities))
print(len(stimulus))

Now we have the light intensity values on the one hand and the stimulus times, i.e. the transition times between these intensities, on the other. With the function `plt.step`, we can now visualize the stimulus.
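With `where='post'`, `plt.step` holds each intensity from its own transition time until the next one. The same lookup can be sketched numerically with `np.searchsorted`; the helper `intensity_at` and the toy times below are hypothetical, and the sketch assumes every query time lies at or after the first transition.

```python
import numpy as np

# Hypothetical transition times and the intensity that starts at each of them
toy_times = np.array([0.0, 2.0, 3.0, 5.0])
toy_levels = np.array([0.5, 1.0, 0.5, 0.0])

def intensity_at(t, times, levels):
    """Return the intensity in effect at time t (each step is held until the next transition)."""
    # side='right' places t after a transition occurring exactly at t,
    # so subtracting 1 gives the index of the last transition <= t.
    # Assumes t >= times[0]; earlier times would give index -1 (wrapping around).
    idx = np.searchsorted(times, t, side='right') - 1
    return levels[idx]

query = np.array([0.5, 2.0, 4.9, 5.5])
print(intensity_at(query, toy_times, toy_levels))
```

A query exactly at a transition (here `2.0`) already returns the new level, which matches how `where='post'` draws the steps.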

In [None]:
plt.step(stimulus, intensities, where='post')
plt.xlabel('Time (s)')
plt.ylabel('Light intensity');

The figure above shows the stimulus over the entire recording, which makes it hard to see the exact shape of the curve.

**Exercise:** Using `plt.xlim`, show only two cycles of the stimulus. You might have to try different values.

In [None]:
plt.step(stimulus, intensities, where='post')
plt.xlabel('Time (s)')
plt.ylabel('Light intensity')

### START CODE HERE ###
plt.xlim(10, 22);
### END CODE HERE

## 2 - Stimulus Step Durations

From the figure above, we get a rough idea of the stimulus. Since the stimulus repeats, we will treat each stimulus cycle as a trial. Each trial consists of four steps:

 1. Mean gray step
 1. On step
 1. Mean gray step
 1. Off step

The number of trials corresponds to the number of cycles we determined above.

In [None]:
num_trials = cycles

In [None]:
num_trials

We also want to obtain the durations of the stimulus steps. The stimulus contains this information in the differences from one time event to the next, e.g. the difference between the first two elements is the duration of the first mean gray step.

In [None]:
stimulus[1] - stimulus[0]

To get the durations of all stimulus steps we can use `np.diff`, which computes the difference between adjacent items in a list. To get the durations for the first trial, we take the differences of the first five elements:

| Step           | Duration                    |
| ----           | --------                    |
| Mean gray step | `stimulus[1] - stimulus[0]` |
| On step        | `stimulus[2] - stimulus[1]` |
| Mean gray step | `stimulus[3] - stimulus[2]` |
| Off step       | `stimulus[4] - stimulus[3]` |

In [None]:
stimulus[:5]

In [None]:
np.diff(stimulus[:5])

Since the stimulus timing may vary slightly from trial to trial, we compute the step durations for every trial and then average them.

We start off by taking the differences of all adjacent elements.

In [None]:
durations = np.diff(stimulus)

To average across trials easily, we reshape the array into a matrix with one row per trial and four columns, one per step.

In [None]:
durations.reshape(num_trials, 4)

What went wrong?

In [None]:
len(durations)

In [None]:
len(durations) / 4

Taking the differences of `n` elements produces `n-1` elements. The last trial is therefore incomplete: the stimulus file does not contain the end time of its off step, so only three of its four durations can be computed. We exclude this trial from the average by removing the last three elements from `durations`.

In [None]:
durations = durations[:-3]
durations = durations.reshape(num_trials - 1, 4)

In [None]:
durations.shape

The `reshape` method can also infer one dimension automatically when `-1` is passed for it. For example, the line above is equivalent to
```python
durations = durations.reshape(-1, 4)
```
`reshape` automatically infers that the `-1` corresponds to `num_trials - 1`.
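To recap the trimming and the `-1` shorthand together, here is a toy stimulus of three trials with four transitions each (all values hypothetical): `np.diff` yields 11 durations, which cannot fill a four-column matrix until the three durations of the incomplete last trial are dropped.

```python
import numpy as np

# Hypothetical stimulus: 3 trials x 4 transitions, one transition every 1.5 s
toy_stimulus = np.arange(12) * 1.5
toy_durations = np.diff(toy_stimulus)
print(len(toy_durations))  # one fewer than the 12 transition times

try:
    toy_durations.reshape(-1, 4)  # 11 elements cannot fill rows of 4
except ValueError as err:
    print('reshape failed:', err)

# Drop the incomplete last trial (its 3 durations), then let -1 infer the rows
toy_matrix = toy_durations[:-3].reshape(-1, 4)
print(toy_matrix.shape)
```

After trimming, 8 elements remain, so `-1` is inferred as 2 rows.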

In [None]:
durations = durations.reshape(-1, 4)

In [None]:
durations.shape

Let's have a look: each row holds one trial, and each column holds the duration of one step.

In [None]:
durations[:10]  # Show the first 10 rows (i.e. trials)

Now we take the average across the trials, i.e. the rows.

In [None]:
durations_avg = durations.mean(axis=0)

Now we have the average duration of each stimulus step:

In [None]:
durations_avg

In addition to the durations, we would like to know the average onsets and offsets of the on and off steps relative to the beginning of the trial.

![](images/onoffstepoffsets.png)

For this we take the cumulative sum of the average durations.
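Here is the same idea on made-up average durations (hypothetical values, in seconds): the cumulative sum turns the duration of each step into the time at which that step ends, measured from the trial onset.

```python
import numpy as np

# Hypothetical average durations of the four steps: gray, on, gray, off
toy_durations_avg = np.array([2.0, 1.0, 2.0, 1.0])

# Running total: end of gray (= on_start), end of on, end of gray (= off_start), end of off
toy_timings = toy_durations_avg.cumsum()
print(toy_timings)
```

Because the first step is the mean gray step, its end coincides with the start of the on step, which is why `step_timings[0]` is used as `on_start` below.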

In [None]:
step_timings = durations_avg.cumsum()

Let's have a look:

In [None]:
step_timings

We can now store these on- and offsets into their own variables.

In [None]:
on_start = step_timings[0]
on_end = step_timings[1]
off_start = step_timings[2]
off_end = step_timings[3]

If you prefer to do so, you can also do all of this in one line.

In [None]:
(on_start,
 on_end,
 off_start,
 off_end) = np.diff(stimulus)[:-3].reshape(-1, 4).mean(axis=0).cumsum()

Although the offset of the off-step equals the average trial length, we round it up to the nearest full second here. This will come in handy later when splitting the trial into equal-sized bins.

In [None]:
off_end  # The off-step offset is equal to the average trial length

In [None]:
trial_length = np.ceil(off_end)  # Rounding up to the nearest integer

In [None]:
trial_length

## 3 - Alignment

A typical problem when handling data is the alignment of stimulus and recorded data, since it is not easy to start the stimulus and the recording at exactly the same time. Fortunately, we have the timings of both stimulus and spikes to control for this.

In [None]:
# Load the spike timings from file
filepath = 'data_on_off_steps/8_SP_C3002.txt'
spike_times = np.loadtxt(filepath)

As you can see here, the first spike occurred before the stimulus started.

In [None]:
spike_times[0], stimulus[0]

Likewise, the last spike recorded happened after the stimulus had ended.

In [None]:
spike_times[-1], stimulus[-1]

The following line returns a boolean array that indicates, for each spike, whether it occurred after stimulus onset.

In [None]:
spike_times > stimulus[0]

We can use this boolean array to index `spike_times`, which keeps only the spikes that occurred after the first stimulus time and cuts off all earlier ones.
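Two conditions can be combined element-wise with `&`. The parentheses matter: `&` binds more tightly than `>` and `<`, so each comparison must be wrapped. A minimal sketch on toy spike times and a hypothetical stimulus window from 1 s to 4 s:

```python
import numpy as np

toy_spikes = np.array([0.2, 1.1, 2.5, 3.9, 4.7])  # hypothetical spike times (s)
start, end = 1.0, 4.0                              # hypothetical stimulus window (s)

# True only for spikes strictly inside the window
mask = (toy_spikes > start) & (toy_spikes < end)
print(mask)
print(toy_spikes[mask])
```

Without the parentheses, `toy_spikes > start & toy_spikes < end` would be parsed as a chained comparison involving `start & toy_spikes`, which raises an error for float arrays.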

In [None]:
spike_times[spike_times > stimulus[0]]

**Exercise:** Using this indexing technique, complete the function `align` so that it returns only the spikes that occurred within the stimulus period, i.e. between the first and the last stimulus time.

In [None]:
def align(spikes, stimulus):
    
    ### START CODE HERE ###
    spikes_aligned = spikes[spikes > stimulus[0]]  # Remove spikes before
    spikes_aligned = spikes_aligned[spikes_aligned < stimulus[-1]]  # And after
    
    # Equivalent alternative in one step: combine both conditions with &
    spikes_aligned = spikes[(spikes > stimulus[0]) & (spikes < stimulus[-1])]
    ### END CODE HERE ###

    return spikes_aligned

In [None]:
# Let's check if your code is correct
print(align(spike_times, stimulus))

**Expected output:**  
`[ 10.1834  10.3542  10.3721 ... 413.9438 413.9635 414.1202]`

In [None]:
spike_times = align(spike_times, stimulus)

## 4 - Trials

We want to refer to our spike times by individual stimulus trials. To do so, we assign a trial number to each spike and create a separate array of spikes for each trial.

From before, we know that every fourth element of `stimulus` (starting with the first) marks the onset of a trial.

In [None]:
def get_trial_onsets(stimulus):
    trial_onsets = stimulus[::4]

    return trial_onsets

Let's have a look:

In [None]:
trial_onsets = get_trial_onsets(stimulus)

print(trial_onsets)

Each element of this array is the onset of a trial, and its index is the trial number. Any spike occurring between the first two onsets belongs to the first trial, any spike occurring between the second and third onsets belongs to the second trial, and so on. Spikes after the last onset belong to the last trial.
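`np.digitize` performs exactly this kind of assignment: for increasing bin edges and the default `right=False`, it returns `i` such that `edges[i-1] <= x < edges[i]`, so the first interval is numbered 1. A toy example with hypothetical onsets at 0, 6 and 12 s:

```python
import numpy as np

toy_onsets = np.array([0.0, 6.0, 12.0])             # hypothetical trial onsets (s)
toy_spikes = np.array([0.5, 5.9, 6.0, 7.2, 13.0])   # hypothetical spike times (s)

idx = np.digitize(toy_spikes, toy_onsets)
print(idx)      # interval numbers, starting at 1
print(idx - 1)  # zero-based trial numbers
```

A spike exactly at an onset (here `6.0`) is counted in the trial that starts there, and a spike after the last onset receives `len(toy_onsets)`, i.e. the last trial once 1 is subtracted.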

**Exercise:** Using the function `np.digitize`, create an array `trial_idx` that contains the trial index each spike belongs to.  
In other words: each element of `trial_idx` should be the trial number of the spike at the same position in `spike_times`. This will produce an array starting with `[0, 0, 0, 0, 0, ..., 1, 1, 1, 1, ...]`. Since we will use these numbers for indexing later, start counting trials at 0.

In [None]:
def get_trial_indices(spikes, trial_onsets):

    ### START CODE HERE ###
    trial_idx = np.digitize(spikes, trial_onsets)
    trial_idx -= 1  # Start with trial 0
    ### END CODE HERE ###
    
    return trial_idx

In [None]:
trial_idx = get_trial_indices(spike_times, trial_onsets)

# Let's check if your code is correct
print(trial_idx)

**Expected output:**  
`[ 0  0  0 ... 67 67 67]`

To find all spike times that belong to the same trial, we use `trial_idx` to index `spike_times`. The code below, for example, returns all spike times that occurred during the first trial (remember that indexing starts at 0). As you can see, all elements lie between `10.0405` and `15.998`, the first two elements of `trial_onsets` (see above), which mark the start and end of the first trial.

In [None]:
spike_times[trial_idx == 0]

For each trial, we create a list like the one above and collect all these lists in a list `trials`.
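The same split can also be sketched without `trial_idx`: because spike times are sorted, `np.searchsorted` finds the position where each new trial starts, and `np.split` cuts the array at those positions. This always returns one array per onset, including empty arrays for trials without spikes. A minimal sketch with hypothetical data, offered as an alternative rather than the approach used in this notebook:

```python
import numpy as np

toy_onsets = np.array([0.0, 6.0, 12.0, 18.0])      # hypothetical trial onsets (s)
toy_spikes = np.array([0.5, 5.9, 6.0, 7.2, 19.0])  # no spikes in the third trial

# Positions in toy_spikes where trials 2, 3, ... begin (spikes must be sorted)
split_points = np.searchsorted(toy_spikes, toy_onsets[1:])
toy_trials = np.split(toy_spikes, split_points)

for i, tr in enumerate(toy_trials):
    print(i, tr)
```

With the default `side='left'`, a spike exactly at an onset starts the new trial, consistent with `np.digitize`.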

In [None]:
def get_trials(spikes, trial_idx, trial_onsets):
    num_trials = len(trial_onsets)  # One trial per onset, even if a trial has no spikes
    
    trials = []
    for trial in range(num_trials):
        tr = spikes[trial_idx == trial]
        trials.append(tr)
    
    return trials

In [None]:
trials = get_trials(spike_times, trial_idx, trial_onsets)
trials

Remember that the recording was continuous, so the spike times keep increasing from one trial to the next, as the raster below shows. However, we want every trial to start at zero.

In [None]:
plt.eventplot(trials, linewidths=0.8)
plt.xlim(0, 51)
plt.ylim(-1, 10)
plt.xlabel('Time (s)')
plt.ylabel('Trials');

**Exercise:** Redefine the function from above, changing the body of its loop so that all trials are aligned at zero.  
*Hint:* Subtract the trial onset from the spike times.

In [None]:
def get_trials(spikes, trial_idx, trial_onsets):
    num_trials = len(trial_onsets)  # One trial per onset, even if a trial has no spikes
    
    trials = []
    for trial in range(num_trials):
        ### CHANGE CODE HERE ###
        tr = spikes[trial_idx == trial] - trial_onsets[trial]
        trials.append(tr)
        ### END CODE HERE ###
    
    return trials

In [None]:
trials = get_trials(spike_times, trial_idx, trial_onsets)

# Let's check if your code is correct
print(trials[0][:5])

**Expected output:**  
`[0.1429 0.3137 0.3316 0.4045 0.6188]`

`trials` is now a list of arrays, one per trial, each containing that trial's spike times relative to its onset. Let's have a look at the first three trials.

In [None]:
trials[:3]

**Exercise**: Now plot the raster for the multiple trials, as we learned previously.

In [None]:
### START CODE HERE ###
plt.eventplot(trials)
plt.xlabel('Time (s)')
plt.ylabel('Trial');
### END CODE HERE ###

The raster already shows a clear structure in the spike times, but it helps to also mark the corresponding stimulus changes.

**Exercise**: Add the following line of code to your raster plot:

```python
plt.vlines([on_start, on_end, off_start, off_end], 0, num_trials);
```

In [None]:
### START CODE HERE ###
plt.eventplot(trials)
plt.xlabel('Time (s)')
plt.ylabel('Trial')
plt.vlines([on_start, on_end, off_start, off_end], 0, num_trials);
### END CODE HERE ###

## 5 - Plot PSTH

A peristimulus time histogram (PSTH) shows how often a neuron spikes as a function of time relative to stimulus onset, averaged across trials. It is useful for investigating how the firing rate of a neuron responds to a stimulus.
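Before computing it on the real data, here is why the normalization used below, dividing the counts by `num_trials * dt`, yields spikes per second: each bin count sums the spikes of all trials within a window of `dt` seconds. A toy check with two hypothetical trials and 0.5 s bins:

```python
import numpy as np

# Hypothetical spike times of two trials, relative to trial onset (s)
toy_trials = [np.array([0.1, 0.2, 0.7]), np.array([0.15, 0.6])]
toy_dt = 0.5
toy_edges = np.arange(0, 1 + toy_dt, toy_dt)  # bin edges 0, 0.5, 1

# Pool all trials, count spikes per bin, then average over trials and divide by bin width
counts, _ = np.histogram(np.concatenate(toy_trials), toy_edges)
toy_rate = counts / (len(toy_trials) * toy_dt)
print(counts, toy_rate)
```

The first bin holds 3 spikes from 2 trials in 0.5 s, i.e. 3 / (2 * 0.5) = 3 spikes per second.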

First we define bins to collect the spikes in.

In [None]:
# Bin size
dt = 0.01

# Divide trial into bins of length dt to obtain the left and right bin edges
binedges = np.arange(0, trial_length+dt, dt)

# Get the bin centers (midpoints between adjacent bin edges)
bincenters = binedges[:-1] + dt/2

We need to flatten `trials` from a list of arrays into a single one-dimensional array. Unlike in `spike_times`, the spike times in `trials_flattened` are measured from the onset of each trial.

In [None]:
trials_flattened = np.concatenate(trials)

In [None]:
trials_flattened.shape

In [None]:
rate, binedges = np.histogram(trials_flattened, binedges)
rate = rate / (num_trials * dt)  # Normalize to firing rate in spikes per second

**Exercise:** Complete the function below so that it returns the firing rate and the bin centers.

In [None]:
def get_firingrate(trials, trial_length, dt):

    ### START CODE HERE ###
    trials_flattened = np.concatenate(trials)
    
    binedges = np.arange(0, trial_length+dt, dt)
    bincenters = binedges[:-1] + dt/2
    rate = np.histogram(trials_flattened, binedges)[0]
    rate = rate / (len(trials) * dt)  # Use the number of trials passed in, not a global
    ### END CODE HERE ###
    
    return rate, bincenters

In [None]:
rate, bincenters = get_firingrate(trials, trial_length, 0.01)

In [None]:
print(rate[:5])

**Expected output:**
```python
[ 4.41176471 13.23529412  1.47058824  0.          0.        ]
```

Now let's plot the PSTH!

Only the first three lines are strictly necessary. The rest of the cell makes the plot prettier and adds bars indicating the stimulus steps.

In [None]:
plt.plot(bincenters, rate, 'black')
plt.xlabel('Time (s)')
plt.ylabel('Firing rate (Hz)')

# Optionally make plot pretty (you may inspect this code if you are interested)

# Obtain the axes object from the figure
ax = plt.gca()

# Create rectangles for the stimulus period in the respective color
from matplotlib.patches import Rectangle
bar_h = rate.max()*0.025
ax.add_patch(Rectangle((0, -bar_h*2), width=on_start, height=bar_h, fc='gray'))
ax.add_patch(Rectangle((on_start, -bar_h*2), width=on_end-on_start,
                       height=bar_h, fc='lightgray'))
ax.add_patch(Rectangle((on_end, -bar_h*2), width=off_start-on_end,
                       height=bar_h, fc='gray'))
ax.add_patch(Rectangle((off_start, -bar_h*2), width=off_end-off_start,
                       height=bar_h, fc='black'))

# Remove top and right axes, and 'detach' left and bottom axes
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_bounds(0, rate.max())
ax.spines['bottom'].set_bounds(binedges.min(), binedges.max())

# Add some margins to the 'detached' axes
ax.margins(x=0.025, y=0.05)

## 6 - Calculate On-Off Index

The on-off index quantifies whether a cell responds preferentially to increases (on) or decreases (off) in light intensity. It ranges from `-1` (off cell) to `1` (on cell) and is computed from the firing rate during the on and off steps as

$$
\frac{\sum{r_\text{on}} - \sum{r_\text{off}}}{\sum{r_\text{on}} + \sum{r_\text{off}}} \,,
$$

where $r$ is the firing rate.
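A quick worked example with made-up firing rates: a cell firing at 10 Hz in each of 4 on-step bins and at 90 Hz in each of 4 off-step bins has sums of 40 and 360, so its index is (40 - 360) / (40 + 360) = -0.8, i.e. a cell with a strong off preference.

```python
import numpy as np

toy_rate_on = np.full(4, 10.0)   # hypothetical on-step firing rates (Hz)
toy_rate_off = np.full(4, 90.0)  # hypothetical off-step firing rates (Hz)

s_on, s_off = toy_rate_on.sum(), toy_rate_off.sum()
toy_index = (s_on - s_off) / (s_on + s_off)
print(toy_index)
```

Comparing sums is fair here because the on and off steps last about equally long (compare the entries of `durations_avg`); with steps of very different lengths, mean rates would be the more natural comparison.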

Let's first define $r_\text{on}$ (the firing rate during the on-step) and $r_\text{off}$ (the firing rate during the off-step). For this, we need the bins that fall within each step, which we find by comparing the bin centers with the start and end of each step:

\begin{align}
\text{bins}_\text{on} &= \{\, b \mid \text{on}_\text{start} \leq b < \text{on}_\text{end} \,\}\\
\text{bins}_\text{off} &= \{\, b \mid \text{off}_\text{start} \leq b < \text{off}_\text{end} \,\}
\end{align}

where $b$ is the center of a bin.

In [None]:
indices_on = (on_start <= bincenters) & (bincenters < on_end)
indices_off = (off_start <= bincenters) & (bincenters < off_end)

We can now use these boolean masks to index `rate` and obtain $r_\text{on}$ and $r_\text{off}$, i.e. `rate_on` and `rate_off`.

In [None]:
rate_on = rate[indices_on]
rate_off = rate[indices_off]

In [None]:
plt.plot(bincenters[indices_on], rate_on)
plt.plot(bincenters[indices_off], rate_off)
plt.xlim(0, trial_length);

**Exercise:** Given the indices of the on and off steps, implement the equation above to find the on-off index of the cell and write it to the variable `on_off_idx`.

$$
\frac{\sum{r_\text{on}} - \sum{r_\text{off}}}{\sum{r_\text{on}} + \sum{r_\text{off}}}
$$

*Hint:* To sum `rate_on` and `rate_off` use `np.sum()`.  
If you are done quickly, try to prevent a possible division-by-zero. The on-off index should be `np.nan` in that case.

In [None]:
def get_onoffindex(rate, indices_on, indices_off):
    rate_on = rate[indices_on]
    rate_off = rate[indices_off]
    
    ### START CODE HERE ###
    rate_on_sum = rate_on.sum()
    rate_off_sum = rate_off.sum()

    # Calculate on-off index
    if rate_on_sum == rate_off_sum == 0:
        on_off_idx = np.nan
    else:
        on_off_idx = (rate_on_sum-rate_off_sum) / (rate_on_sum+rate_off_sum)
    ### END CODE HERE ###
    
    return on_off_idx

In [None]:
on_off_idx = get_onoffindex(rate, indices_on, indices_off)

# Let's check if your code is correct
print(on_off_idx)

**Expected output:**  
`-0.8161078465260974`

## 7 - Compare Multiple Cells

Let's compare the on-off index between multiple cells.

This means we need to repeat the analysis above for every cell. The stimulus-related steps, however, do not have to be redone, because the stimulus is identical for all cells.

For each cell we'll have to...
 1. Load spike times from file
 1. Align the spike times to the stimulus times
 1. Split the spike times into trials
    - Get the trial onsets
    - Get the trial indices
    - Get the trials from the indices
 1. Compute the firing rates from the trials
 1. Compute the on-off index
 
**Note:** We have implemented all of these steps in individual functions, so we don't have to rewrite them, but only call the existing functions.
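For comparison, the whole per-cell pipeline can also be chained into a single function. The sketch below is a compact, self-contained re-implementation, not the notebook's own functions; the name `on_off_index_for_cell` and the synthetic test data are hypothetical.

```python
import numpy as np

def on_off_index_for_cell(spikes, stimulus, dt=0.01):
    """Sketch of the per-cell pipeline: align, split into trials, PSTH, on-off index."""
    # Average step timings relative to trial onset (incomplete last trial dropped)
    on_start, on_end, off_start, off_end = (
        np.diff(stimulus)[:-3].reshape(-1, 4).mean(axis=0).cumsum())
    trial_length = np.ceil(off_end)

    # 1. Keep only spikes within the stimulus period
    spikes = spikes[(spikes > stimulus[0]) & (spikes < stimulus[-1])]

    # 2. Assign each spike to a trial and express it relative to that trial's onset
    onsets = stimulus[::4]
    trial_idx = np.digitize(spikes, onsets) - 1
    relative = spikes - onsets[trial_idx]

    # 3. PSTH pooled over all trials, in spikes per second
    edges = np.arange(0, trial_length + dt, dt)
    centers = edges[:-1] + dt / 2
    counts, _ = np.histogram(relative, edges)
    rate = counts / (len(onsets) * dt)

    # 4. On-off index from the summed rates during each step
    r_on = rate[(on_start <= centers) & (centers < on_end)].sum()
    r_off = rate[(off_start <= centers) & (centers < off_end)].sum()
    if r_on + r_off == 0:
        return np.nan  # no spikes in either step
    return (r_on - r_off) / (r_on + r_off)
```

Inside the loop over files, one call such as `on_off_indices.append(on_off_index_for_cell(np.loadtxt(filepath), stimulus))` would replace the individual steps. Recomputing the step timings for every cell is slightly wasteful, but it keeps the function self-contained.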
 
**Exercise:** Call the functions we have written above (scroll up to find them) from the loop below.  
*Hint:* Since the stimulus is the same for all cells, you can skip retrieving the trial onsets and step timings, as well as computing the bin indices for the on-off index.

In [None]:
# This list will collect the on-off indices of all cells
on_off_indices = []

# Generate a list of the spike files of all cells
filepaths = sorted(Path('data_on_off_steps').glob('8_SP_C*.txt'))

# Iterate over all files
for filepath in filepaths:

    ### START CODE HERE ###

    # 1. Load spike times from file
    spike_times = np.loadtxt(filepath)

    # 2. Align the spike times to the stimulus times
    spike_times = align(spike_times, stimulus)

    # 3. Split the spike times into trials
    trial_idx = get_trial_indices(spike_times, trial_onsets)
    trials = get_trials(spike_times, trial_idx, trial_onsets)

    # 4. Compute the firing rates from the trials
    rate, bincenters = get_firingrate(trials, trial_length, 0.01)

    # 5. Compute the on-off index
    on_off_idx = get_onoffindex(rate, indices_on, indices_off)

    ### END CODE HERE ###


    # Append the on-off index to the list
    on_off_indices.append(on_off_idx)

In [None]:
# Let's check if your code is correct
print(on_off_indices[:5])

**Expected output:**  
```
[0.8707865168539325, 0.7170731707317074, -0.24096385542168686, 0.022480058013053177, 0.7278338945005612]
```

**Exercise:** Visualize the distribution of on-off indices across all cells using `plt.hist`.

In [None]:
### START CODE HERE ###
plt.hist(on_off_indices, bins=20, range=[-1, 1])
plt.xlabel('On-off index')
plt.ylabel('Number of cells')
plt.legend([f'n = {len(filepaths)}']);  # Show the number of cells
### END CODE HERE ###

We can also find example cells with interesting properties. For example, we can select the cell with the highest on-off index with
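One caveat, assuming some cells fire no spikes during either step: `get_onoffindex` then returns `np.nan`, and `np.argmax` returns the position of the first NaN rather than that of the largest index. `np.nanargmax` ignores NaNs instead. A toy example with hypothetical values:

```python
import numpy as np

toy_indices = [0.5, np.nan, 0.9, -0.3]  # hypothetical on-off indices, one silent cell

print(np.argmax(toy_indices))     # the NaN position wins
print(np.nanargmax(toy_indices))  # the largest real value
```

If `on_off_indices` contains no NaN, both functions return the same position.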

In [None]:
high_idx = np.argmax(on_off_indices)
high_idx

By copying the body of the loop over file paths from the second-to-last exercise, we can compute the firing rate of this cell and then plot its PSTH.

In [None]:
filepath = filepaths[high_idx]

### COPY CODE FROM ABOVE ###

# 1. Load spike times from file
spike_times = np.loadtxt(filepath)

# 2. Align the spike times to the stimulus times
spike_times = align(spike_times, stimulus)

# 3. Split the spike times into trials
trial_idx = get_trial_indices(spike_times, trial_onsets)
trials = get_trials(spike_times, trial_idx, trial_onsets)

# 4. Compute the firing rates from the trials
rate, bincenters = get_firingrate(trials, trial_length, 0.01)

### END CODE HERE ###

We can also save the figure to disk using `plt.savefig`. The file extension determines the file format, e.g. `png`, `svg`, or `pdf`.
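Two commonly used `plt.savefig` keyword arguments are `dpi`, the resolution that mainly matters for raster formats such as PNG, and `bbox_inches='tight'`, which crops surrounding whitespace. A minimal sketch that saves a throwaway figure in three formats to a temporary directory; the file names are hypothetical:

```python
import tempfile
from pathlib import Path

from matplotlib import pyplot as plt

out_dir = Path(tempfile.mkdtemp())  # throwaway output directory

plt.plot([0, 1, 2], [0, 1, 0], 'black')
plt.xlabel('Time (s)')

# The extension selects the format for each file
for name in ['psth.png', 'psth.svg', 'psth.pdf']:
    plt.savefig(out_dir / name, dpi=150, bbox_inches='tight')
plt.close()  # close the figure once it has been saved

print(sorted(p.name for p in out_dir.iterdir()))
```

Call `plt.savefig` in the same cell that draws the figure, before the figure is closed or shown.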

In [None]:
### THIS CODE IS COPIED AND MODIFIED FROM ABOVE ###
plt.plot(bincenters, rate, 'black')
plt.xlabel('Time (s)')
plt.ylabel('Firing rate (Hz)')
ax = plt.gca()
bar_h = rate.max()*0.025
ax.add_patch(Rectangle((0, -bar_h*2), width=on_start, height=bar_h, fc='gray'))
ax.add_patch(Rectangle((on_start, -bar_h*2), width=on_end-on_start,
                       height=bar_h, fc='lightgray'))
ax.add_patch(Rectangle((on_end, -bar_h*2), width=off_start-on_end,
                       height=bar_h, fc='gray'))
ax.add_patch(Rectangle((off_start, -bar_h*2), width=off_end-off_start,
                       height=bar_h, fc='black'))
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_bounds(0, rate.max())
ax.spines['bottom'].set_bounds(binedges.min(), binedges.max())
ax.margins(x=0.025, y=0.05)
### THIS CODE IS COPIED AND MODIFIED FROM ABOVE ###


# Save the figure to disk
plt.savefig('figure.png');