# Working with Stimuli

So far we have learned how to load and handle spike trains. Now we want to look at the stimuli that elicit the spikes, so that we can investigate the stimulus-response relationship of the recordings.

**You will learn to:**
 - Align spike times to stimuli
 - Split spike times into multiple trials
 - Plot a peristimulus time histogram (PSTH)
 - Compute and compare the ON-OFF index across several cells
 
Let's start again by importing relevant packages.

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import os

%matplotlib inline

In [None]:
plt.rcParams['figure.figsize'] = (15.0, 6.0) # set default size of plots

## 1 - On-Off Stimulus

We will look at a stimulus consisting of on-off steps in light intensity. This is roughly what the stimulus looks like:

<img src="onoffsteps.gif" width="200">

The information about the stimulus is stored in a text file. For this stimulus, the file contains the time of each step transition.

**Exercise:** Load the stimulus step times from the file at `filepath` into the variable `stimulus`, as we did yesterday with the spike trains.

In [None]:
filepath = os.path.join('data_on_off_steps', 'stimulus.txt')

### START CODE HERE ###
stimulus = np.loadtxt(filepath)
### END CODE HERE ###

To understand the loaded data, we plot light intensity against time.

In [None]:
# Cycle through the intensities 0.5, 1, 0.5, 0, 0.5, 1, ... (one value per transition);
# np.resize repeats the pattern and also handles an incomplete last cycle
intensities = np.resize([0.5, 1, 0.5, 0], len(stimulus))

plt.step(stimulus, intensities, where='post')
plt.xlabel('Time (s)')
plt.ylabel('Light intensity')
plt.xlim([0, 20]);

We now load the spike times.

In [None]:
# Load the spike timings from file
filepath = os.path.join('data_on_off_steps', '3_SP_C1001.txt')
spike_times = np.loadtxt(filepath)

## 2 - Alignment

A typical problem when handling data is aligning the stimulus with the recorded data, since it is hard to start the stimulus and the recording at exactly the same time. Fortunately, we have the timings of both stimulus and spikes, so we can simply discard any spikes occurring before the first or after the last stimulus time.

In [None]:
spike_times = spike_times[spike_times > stimulus[0]]  # Remove spikes before the first stimulus time
spike_times = spike_times[spike_times < stimulus[-1]]  # Remove spikes after the last stimulus time

## 3 - Trials

Since we are dealing with a repeating stimulus, we will treat each stimulus cycle as a trial. To do so, we want to split the list of spike times into separate lists for each trial.

Each trial in the stimulus consists of four parts:
 1. Mean grey step
 1. On step
 1. Mean grey step
 1. Off step

Since we designed the stimulus, we know the start and end of each step relative to trial onset (in seconds):

In [None]:
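on_onset = 1.5      # start of the On step
on_offset = 2.0     # end of the On step
off_onset = 3.5     # start of the Off step
off_offset = 4.0    # end of the Off step
trial_length = 4.0  # duration of one full cycle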

`stimulus` can thus be split into individual trials in groups of four elements. Every fourth element (starting with the first) marks the onset of a trial.

In [None]:
stimulus[::4]

In [None]:
num_trials = len(stimulus[::4])  # Total number of trials
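To check that the grouping by four really matches the step layout, the transition times can be reshaped into one row per trial. A minimal sketch on made-up transition times (`stimulus_demo` is hypothetical and assumes every trial starts with the mean grey step); note that it drops an incomplete final group, which `stimulus[::4]` would still count as a trial:

```python
import numpy as np

# Step timing within a trial (s), as defined above
on_onset, on_offset, off_onset = 1.5, 2.0, 3.5

# Hypothetical transition times for three trials starting at 10, 14 and 18 s
stimulus_demo = np.concatenate(
    [t0 + np.array([0.0, on_onset, on_offset, off_onset])
     for t0 in (10.0, 14.0, 18.0)])

# One row per complete trial, one column per step: grey, on, grey, off
n_full = len(stimulus_demo) // 4
steps = stimulus_demo[:4 * n_full].reshape(n_full, 4)

# Durations of the first three steps of every trial
durations = np.diff(steps, axis=1)
print(steps[:, 0])   # trial onsets, identical to stimulus_demo[::4]
print(durations)     # every row is [1.5 0.5 1.5]
```

Applied to the real `stimulus`, rows whose durations deviate from 1.5, 0.5 and 1.5 s would point to dropped or extra transitions.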

**Exercise:** Using the function `np.digitize`, which you know from binning spikes, create an array `trial_idx` that contains the trial index of every spike, i.e. each element is the number of the trial the corresponding spike belongs to. For later indexing, number the trials starting at 0.

In [None]:
### START CODE HERE ###
trial_idx = np.digitize(spike_times, stimulus[::4])
trial_idx -= 1  # Start with trial 0
### END CODE HERE ###
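Why subtract 1? `np.digitize(x, bins)` returns the index `i` with `bins[i-1] <= x < bins[i]`, so every spike after the first onset gets an index of at least 1. A small sketch with made-up onsets and spike times:

```python
import numpy as np

# Hypothetical trial onsets (s) and spike times (s)
onsets = np.array([0.0, 4.0, 8.0])
spikes = np.array([0.5, 4.2, 7.9, 8.1])

# np.digitize returns i such that onsets[i-1] <= spike < onsets[i]
idx = np.digitize(spikes, onsets)
print(idx)        # [1 2 2 3]
print(idx - 1)    # [0 1 1 2], i.e. trial numbers starting at 0
```

With the default `right=False`, a spike falling exactly on an onset is assigned to the trial that starts there.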

To find all spike times that belong to the same trial, we use the function `np.where`. The code below, for example, returns all spike times that occurred during the first trial (remember that indexing starts at 0).

In [None]:
spike_times[np.where(trial_idx == 0)]

For each trial, we create a list like the one above and collect all these lists in a list `trials`.

In [None]:
trials = []
for trial in range(num_trials):
    trials.append(spike_times[np.where(trial_idx == trial)])
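Because the spike times are sorted, the trials can also be cut out in one step with `np.searchsorted` and `np.split`, without building a mask for every trial. A minimal sketch on made-up numbers (`spike_times_demo` and `onsets` are hypothetical stand-ins for `spike_times` and `stimulus[::4]`); with `side='left'`, a spike exactly on an onset goes to the new trial, matching the `np.digitize` default:

```python
import numpy as np

# Hypothetical sorted spike times (s) and trial onsets (s)
spike_times_demo = np.array([0.5, 1.7, 4.2, 5.0, 7.9, 8.1])
onsets = np.array([0.0, 4.0, 8.0])

# Index of the first spike at or after each onset (the first onset is skipped,
# because all spikes before the second onset belong to trial 0)
boundaries = np.searchsorted(spike_times_demo, onsets[1:], side='left')

# np.split cuts the array at those indices: one sub-array per trial
trials_demo = np.split(spike_times_demo, boundaries)
for t, spikes in enumerate(trials_demo):
    print(t, spikes.tolist())
# 0 [0.5, 1.7]
# 1 [4.2, 5.0, 7.9]
# 2 [8.1]
```

This only works on sorted spike times, which is the case for a continuous recording.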

Let's remind ourselves: the recording was continuous, so the spike times keep increasing from trial to trial. However, we want every trial to start at zero.

**Exercise:** Change the content of the loop above to align all trials at zero. **Hint:** Subtract the trial onset time from the spike times.

In [None]:
trials = []
for trial in range(num_trials):
    ### CHANGE CODE HERE ###
    trials.append(spike_times[np.where(trial_idx == trial)]
                  - stimulus[::4][trial])
    ### END CODE HERE ###
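Since `trial_idx` already tells us which trial each spike belongs to, the alignment can also be done for all spikes at once by indexing the onsets with `trial_idx`. A sketch on made-up data (all `_demo` names and `onsets` are hypothetical):

```python
import numpy as np

# Hypothetical spike times (s), trial onsets (s) and trial indices
spike_times_demo = np.array([0.5, 1.7, 4.2, 5.0, 8.1])
onsets = np.array([0.0, 4.0, 8.0])
trial_idx_demo = np.digitize(spike_times_demo, onsets) - 1

# Fancy indexing picks the onset of each spike's own trial,
# so a single subtraction aligns every spike to the start of its trial
aligned = spike_times_demo - onsets[trial_idx_demo]

# Group the aligned spike times by trial
trials_demo = [aligned[trial_idx_demo == t] for t in range(len(onsets))]
print([np.round(t, 3).tolist() for t in trials_demo])
# [[0.5, 1.7], [0.2, 1.0], [0.1]]
```

The rounding is only for display; the subtraction leaves tiny floating-point residues such as 0.20000000000000018.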

Now plot a raster of all trials, as you learned previously.

In [None]:
### START CODE HERE ###
plt.eventplot(trials)
plt.xlabel('Time (s)')
plt.ylabel('Trial');
### END CODE HERE ###
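To relate the raster to the stimulus, the On and Off steps can be shaded with `Axes.axvspan`. A minimal sketch on made-up aligned spike times (`trials_demo` is hypothetical), assuming the step timings defined above; colors and labels are arbitrary choices:

```python
import numpy as np
from matplotlib import pyplot as plt

# Step timing within a trial (s), as defined above
on_onset, on_offset, off_onset, off_offset = 1.5, 2.0, 3.5, 4.0

# Hypothetical aligned spike times for three trials
trials_demo = [np.array([0.2, 1.6, 1.7]),
               np.array([1.55, 1.9, 3.8]),
               np.array([1.65, 3.6])]

fig, ax = plt.subplots()
ax.eventplot(trials_demo, colors='k')

# Shade the On and Off steps behind the spikes
ax.axvspan(on_onset, on_offset, color='gold', alpha=0.3, label='On step')
ax.axvspan(off_onset, off_offset, color='gray', alpha=0.3, label='Off step')

ax.set_xlabel('Time (s)')
ax.set_ylabel('Trial')
ax.legend(loc='upper left')
```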

## 4 - Plot PSTH

A peristimulus time histogram (PSTH) counts spikes in short time bins relative to trial onset, pooled over all trials and normalized to a firing rate. It shows how the firing rate of a neuron changes over the course of the stimulus.
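Written out, the normalization used in the next cell is as follows, with $N$ trials, bin width $\Delta t$, and $n_k$ the number of spikes (pooled over all trials) that fall into bin $k$:

$$
r_k = \frac{n_k}{N\,\Delta t}, \qquad n_k = \#\{\text{spikes with } k\,\Delta t \le t < (k+1)\,\Delta t\}
$$

For example, with hypothetical values of $N = 10$ trials and $\Delta t = 0.01$ s, a bin containing 3 spikes corresponds to $r_k = 3 / (10 \cdot 0.01\,\text{s}) = 30$ Hz.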

In [None]:
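# Bin size (s)
dt = 0.01

# Left edges of bins of length dt covering one trial
bins = np.arange(0, trial_length, dt)

# Assign a bin index to each spike; np.digitize starts counting at 1,
# so subtract 1 to make index i correspond to the bin starting at bins[i]
indices = np.digitize(np.concatenate(trials), bins) - 1

# Count spikes in each bin (one count per entry in `bins`)
rate = np.bincount(indices, minlength=len(bins))

# Convert to spikes per second: divide by the number of trials and the bin width
rate = rate / (num_trials*dt)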
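As a cross-check, the same PSTH can be computed with `np.histogram`, which takes the bin edges including the right end of the trial and returns exactly one count per bin (its last bin also includes the right edge). The sketch below uses made-up spike times (`trials_demo` is hypothetical) and builds the edges with `np.linspace` to avoid floating-point surprises at the end point of `np.arange`; it also shows that `np.digitize` numbers the first bin 1, so its indices must be shifted by one before `np.bincount`:

```python
import numpy as np

# Hypothetical aligned spike times (s) for two trials
trials_demo = [np.array([0.005, 0.012, 1.503]),
               np.array([0.007, 1.508, 3.995])]
num_trials_demo = len(trials_demo)
dt, trial_length = 0.01, 4.0

# Bin edges 0, dt, ..., trial_length (n_bins + 1 edges for n_bins bins)
n_bins = int(round(trial_length / dt))
edges = np.linspace(0.0, trial_length, n_bins + 1)

# PSTH via np.histogram
counts, _ = np.histogram(np.concatenate(trials_demo), bins=edges)
rate = counts / (num_trials_demo * dt)   # spikes per second (Hz)

# Same result with np.digitize on the left edges, shifted so bin 0 is [0, dt)
idx = np.digitize(np.concatenate(trials_demo), edges[:-1]) - 1
rate_digitize = np.bincount(idx, minlength=n_bins) / (num_trials_demo * dt)
print(np.allclose(rate, rate_digitize))   # True
```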

Now let's plot the PSTH!

Only the first three lines are strictly necessary; the rest of the code makes the plot prettier and adds a bar indicating the stimulus.

In [None]:
plt.plot(bins, rate, 'k')
plt.xlabel('Time (s)')
plt.ylabel('Firing rate (Hz)')


# Make the plot pretty (you may inspect this code if you are interested)

# Draw a bar below the trace for the stimulus: light gray for the mean grey steps,
# black for the Off step (the On step is left white)
from matplotlib.patches import Rectangle
rect = []
rect.append(Rectangle((0, -7), width=on_onset, height=3,
                      facecolor='lightgray'))
rect.append(Rectangle((on_offset, -7), width=off_onset-on_offset, height=3,
                      facecolor='lightgray'))
rect.append(Rectangle((off_onset, -7), width=off_offset-off_onset, height=3,
                      facecolor='black'))

# Add the rectangles to the plot
ax = plt.gca()
for r in rect:
    ax.add_patch(r)

# Hide the top and right spines and keep the bottom spine within the trial
ax.spines['bottom'].set_bounds(0, trial_length)
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.margins(x=0.03, y=0.07)

# Only show integers on x-axis
loc = ax.xaxis.get_major_locator()
loc.set_params(integer=True)

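# Limit the y-axis: prefer tick spacings of 1, 3 or 10 (times a power of ten)
# and draw the left spine only from 0 to 90 Hz (adjust the upper bound to your data)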
loc = ax.yaxis.get_major_locator()
loc.set_params(steps=[3])
ax.spines['left'].set_bounds(0, 90)
for i in [0, -1]:
    ax.get_yticklabels()[i].set_visible(False)
for i in [0, -2]:
    ax.get_yticklines()[i].set_visible(False)

## 5 - Calculate On-Off Index

...

## 6 - Compare Multiple Cells

...