In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal

We want to be able to take a noisy LFP or intracellular Vm trace, band-pass filter it in the theta band (4 - 10 Hz),<br>
and analyze the phase and amplitude of theta at each time point.<br>
<br>
Eventually, we want to take the intracellular Vm of a hippocampal place cell, and assign theta phase to spikes.<br>
We can also do this with the peaks of the theta-filtered intracellular Vm.<br>
If the phase of these spikes or Vm peaks changes with time or spatial position - this is phase precession!

In [None]:
# Simulation clock: `duration` ms sampled every `dt` ms.
duration = 10000.  # ms
dt = 1  # ms
t = np.arange(0., duration, dt)  # sample times (ms)

# Frequency of the reference theta oscillation.
freq = 7  # Hz

For a given input signal (e.g. an oscillating CA3 input or local PV inhibitory interneuron), we want to be able to <br>
independently control:<br>
i) the peak amplitude (e.g. firing rate)<br>
ii) the "depth" of the oscillation, or (peak - trough) / peak<br>
iii) the phase of the oscillation peak relative to some reference oscillation

In [None]:
# Parameters of the example oscillating input.
amplitude = 10.    # peak value of the trace (e.g. peak firing rate)
theta_depth = 0.8  # (peak - trough) / peak
phase_offset = 0   # rad; added to the cosine argument

# Raised cosine scaled so that its peak equals `amplitude` and its trough
# equals amplitude * (1 - theta_depth). t is in ms, hence the / 1000.
trace = amplitude * (
    theta_depth * (np.cos(2. * np.pi * freq * t / 1000. + phase_offset) + 1.) / 2.
    + (1 - theta_depth)
)

In [None]:
fig, ax = plt.subplots()
ax.plot(t, trace)
# Anchor the y-axis at zero so the modulation depth is visible.
ylim = ax.get_ylim()
ax.set_ylim(0., ylim[1])
ax.set_ylabel('Firing rate (Hz)')
ax.set_xlabel('Time (ms)')

Here is the syntax for constructing a band-pass FIR filter whose kernel (window) spans 2 seconds:

In [None]:
# Band-pass FIR filter for the theta band (4 - 10 Hz).
# The kernel spans `window_dur` seconds, i.e. `window_len` taps, because the
# sampling rate is 1000 / dt Hz (t and dt are in ms).
window_dur = 2.  # sec
window_len = int(window_dur * 1000 / dt)
# Cutoffs are in Hz relative to the sampling rate `fs`. This is equivalent to
# the deprecated `nyq=fs/2` argument, which newer SciPy releases no longer accept.
theta_filter = scipy.signal.firwin(window_len, [4., 10.], fs=1000. / dt, pass_zero=False)

We will compare four ways of filtering the signal. First, the naive way, with no padding at the edges:

In [None]:
# Naive zero-phase filtering with no edge padding (padtype=None).
# The FIR filter has no feedback terms, so the denominator is just [1.].
un_padded_filtered_trace = scipy.signal.filtfilt(b=theta_filter, a=[1.], x=trace, padtype=None)

A band-pass filter removes the mean (DC) component, so the filtered traces have approximately zero mean. To superimpose them,<br>
we subtract the mean from our original trace.

In [None]:
# The band-pass output has (approximately) zero mean, so compare it against
# the mean-subtracted original.
mean_subtracted_trace = trace - np.mean(trace)

plt.figure()
plt.plot(t, mean_subtracted_trace, label='original (mean subtracted)')
plt.plot(t, un_padded_filtered_trace, label='filtered, no padding')
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')
plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

This produces a large edge artifact, because the filter kernel extends past the start and end of the trace.<br>
<br>
Now let's try three alternative methods to "pad" the start and end of the trace.<br>
The 'odd' method is a 180 degree "rotation" of the trace around its start (and end) point. This is the scipy default.<br>
The 'even' method "mirrors" the trace at both ends (reflection about a vertical line through each endpoint).<br>
The 'gust' method (Gustafsson's method) pads nothing. Instead, it chooses initial conditions so that the forward-backward and backward-forward filters return the same result.

In [None]:
# Pad each end by one full filter length, mirroring the trace ('even').
pad_len = int(window_len)
even_padded_filtered_trace = scipy.signal.filtfilt(
    theta_filter, [1.], trace, padtype='even', padlen=pad_len)

fig, ax = plt.subplots()
# Draw the dashed original on top (higher zorder) so it stays visible.
ax.plot(t, mean_subtracted_trace, '--', label='original', zorder=1)
ax.plot(t, even_padded_filtered_trace, label='even filtered', zorder=0)
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (ms)')
ax.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

In [None]:
# Point-reflection ('odd') padding of one filter length at each end.
odd_padded_filtered_trace = scipy.signal.filtfilt(
    theta_filter, [1.], trace, padtype='odd', padlen=pad_len)
# Gustafsson's method chooses initial conditions instead of padding;
# padtype/padlen are ignored when method='gust'.
gust_method_filtered_trace = scipy.signal.filtfilt(
    theta_filter, [1.], trace, method='gust')

In [None]:
fig, ax = plt.subplots()
ax.plot(t, mean_subtracted_trace, '--', label='original', zorder=1)
# Filtered traces go underneath (zorder=0) so the dashed original stays visible.
for label, filtered in (('odd filtered', odd_padded_filtered_trace),
                        ('even filtered', even_padded_filtered_trace),
                        ('gust filtered', gust_method_filtered_trace)):
    ax.plot(t, filtered, label=label, zorder=0)
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (ms)')
ax.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

We can also analyze the residual error between the filtered traces and the original trace.

In [None]:
# Residual = filtered - original (mean subtracted). Any edge artifact shows
# up as residual near the start and end of the trace.
fig, ax = plt.subplots()
for label, filtered in (('odd filtered', odd_padded_filtered_trace),
                        ('even filtered', even_padded_filtered_trace),
                        ('gust filtered', gust_method_filtered_trace)):
    ax.plot(t, filtered - mean_subtracted_trace, label=label)
ax.set_ylabel('Residual error')
ax.set_xlabel('Time (ms)')
ax.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

In [None]:
# Total absolute residual for each padding method (smaller is better).
for label, filtered in (('odd filtered:', odd_padded_filtered_trace),
                        ('even filtered:', even_padded_filtered_trace),
                        ('gust filtered:', gust_method_filtered_trace)):
    print(label, np.sum(np.abs(filtered - mean_subtracted_trace)))

It appears that mirror-padding ('even' padtype) produces the smallest amplitude edge artifact.

Now we want to take an oscillating band-pass filtered signal as a reference (like an LFP), and assign an<br>
oscillation phase (or angle) to each time point. We can use the Hilbert transformation for this.<br>
The input to the Hilbert transformation should be a mean subtracted signal. Our filter output<br>
should already meet this criterion:

In [None]:
# Use the 'even'-padded filter output as the reference oscillation (like an LFP).
reference = even_padded_filtered_trace
# Complex analytic signal (named `signal`, unrelated to the scipy.signal module).
# Its real part is the reference, its magnitude is the amplitude envelope and
# its angle is the instantaneous phase.
signal = scipy.signal.hilbert(reference)
envelope = np.abs(signal)
phase = np.angle(signal)  # radians, in [-pi, pi]

plt.figure()
# Plot the real part explicitly: passing the complex array to plt.plot
# discards the imaginary part with a ComplexWarning.
plt.plot(t, np.real(signal), label='signal (real part)')
plt.plot(t, envelope, label='envelope')
plt.xlabel('Time (ms)')
plt.ylabel('Amplitude')
plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

plt.figure()
plt.plot(t, phase)
plt.ylabel('Angle (rads)')
plt.xlabel('Time (ms)')

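This phase ranges from -pi to +pi, with 0 at the peaks of the reference oscillation. Below we convert it to the range 0 to 360 degrees by first adding pi. As a result, the troughs of the reference fall at 0/360 degrees and its peaks fall at 180 degrees: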

In [None]:
# Shift by +pi, then convert radians to degrees. The result runs from 0 to 360,
# with reference peaks (angle 0) at 180 degrees and troughs at 0/360 degrees.
phase_degrees = 360. / 2. / np.pi * (phase + np.pi)

fig, ax = plt.subplots()
ax.plot(t, phase_degrees)
ax.set_ylabel('Angle (degrees)')
ax.set_xlabel('Time (ms)')

Let's treat the phase vs. time above as our reference (like an LFP).<br>
Next, consider a recording of intracellular Vm: we want to find the peaks of its oscillation and determine<br>
their phase with respect to the reference (LFP) oscillation.

In [None]:
def get_osc(t, freq, depth, phase_offset):
    """Raised-cosine oscillation with peak value 1 and trough value 1 - depth.

    Args:
        t: time in ms (scalar or array).
        freq: oscillation frequency in Hz.
        depth: modulation depth, (peak - trough) / peak; 1 means full modulation.
        phase_offset: phase in radians. It is subtracted from the cosine
            argument, so a positive value shifts the peaks later in time.

    Returns:
        Same shape as `t`. For 0 <= depth <= 1 the values lie in [1 - depth, 1].
    """
    angle = 2. * np.pi * freq * t / 1000. - phase_offset  # t is in ms
    raised = depth * (np.cos(angle) + 1.) / 2.
    return raised + (1 - depth)

In [None]:
# Vm oscillating at the reference frequency with full depth, with its peaks
# delayed by a quarter cycle (pi / 2).
theta_Vm = get_osc(t, 7., 1., np.pi / 2)

fig, ax = plt.subplots()
ax.plot(t, theta_Vm)
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (ms)')

We can use the scipy method find_peaks to find the locations of relative local peaks in our trace.

In [None]:
# find_peaks returns a tuple: (indices of local maxima, dict of peak properties).
peak_indexes = scipy.signal.find_peaks(x=theta_Vm)

In [None]:
type(peak_indexes), peak_indexes[0].size  # container type, number of peaks found

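`find_peaks` returns a tuple `(peaks, properties)`. The first element is an array holding the indices of the detected peaks. The second is a dict of peak properties, which is empty here because no conditions such as `height` were requested. Here 70 peaks were found (7 Hz for 10 s).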

In [None]:
peaks = peak_indexes[0]  # indices of the detected local maxima

fig, ax = plt.subplots()
ax.plot(t, theta_Vm)
ax.plot(t[peaks], theta_Vm[peaks], '.')
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (ms)')

Now we can look up the reference theta phase of these intracellular Vm peaks:

In [None]:
# Reference theta phase sampled at each Vm peak.
peak_idx = peak_indexes[0]
fig, ax = plt.subplots()
ax.plot(t[peak_idx], phase_degrees[peak_idx], '.')
ax.set_ylabel('Theta phase (degrees)')
ax.set_xlabel('Time (ms)')
ax.set_ylim(0., 360.)

This shows that our method can correctly report the offset phase of a signal that is a constant oscillation<br>
at the same frequency as the reference.<br>
<br>
What about a signal that is oscillating slightly faster than the reference?

In [None]:
# Vm oscillating slightly faster (7.2 Hz) than the reference.
theta_Vm = get_osc(t, 7.2, 1., 0.)
peak_indexes = scipy.signal.find_peaks(theta_Vm)
peak_idx = peak_indexes[0]

fig, ax = plt.subplots()
ax.plot(t[peak_idx], phase_degrees[peak_idx], '.')
ax.set_ylabel('Theta phase (degrees)')
ax.set_xlabel('Time (ms)')
ax.set_ylim(0., 360.)

That looks like phase precession! Except this is for a non-place cell that has a constant mean rate.

What about a place cell with intracellular Vm that has both a slow spatial ramp depolarization and a theta oscillation?

In [None]:
# Constant running speed, so position grows linearly with time.
velocity = 20.  # cm / s
position = (t / 1000.) * velocity  # cm; t is in ms

def get_gaussian_rate(x, peak_loc, width, wrap=False):
    """Gaussian tuning curve evaluated at positions `x`.

    The profile is exp(-((x - peak_loc) / sigma) ** 2) with
    sigma = width / (3 * sqrt(2)). This is a Gaussian with standard deviation
    width / 6: it equals 1 at `peak_loc` and falls to exp(-4.5) (about 1%)
    at peak_loc +/- width / 2.

    Args:
        x: array of positions. With wrap=True it needs at least two uniformly
            spaced samples.
        peak_loc: location of the peak, in the same units as `x`.
        width: field width, in the same units as `x`.
        wrap: if True, treat `x` as periodic with period (x[1] - x[0]) * len(x),
            so a field near one end also appears near the other end.

    Returns:
        Array of values in [0, 1] with the same length as `x`.
    """
    sigma = width / 3. / np.sqrt(2.)
    if not wrap:
        return np.exp(-((x - peak_loc) / sigma) ** 2)

    period = (x[1] - x[0]) * len(x)
    # Evaluate on three copies of the track, shifted by -period, 0 and +period,
    # and keep the largest contribution at each position.
    tiled_x = np.concatenate([x - period, x, x + period])
    tiled_rate = np.exp(-((tiled_x - peak_loc) / sigma) ** 2)
    below, centre, above = np.split(tiled_rate, 3)
    return np.maximum(np.maximum(below, centre), above)

In [None]:
# Spatial tuning of a single place cell: field centred at 100 cm, 60 cm wide.
place_field = get_gaussian_rate(position, 100., 60.)

fig, ax = plt.subplots()
ax.plot(position, place_field)
ax.set_ylabel('Amplitude')
ax.set_xlabel('Position (cm)')

Now we can multiply this spatial rate by a temporal theta modulation factor:

In [None]:
# Place-cell drive: spatial field times theta modulation at the reference
# frequency (7 Hz, depth 0.7).
place_cell_theta = get_osc(t, 7., 0.7, 0.)
place_cell_rate = place_field * place_cell_theta

# Isolate the theta component and find its positive peaks.
theta_filtered_trace = scipy.signal.filtfilt(
    theta_filter, [1.], place_cell_rate, padtype='even', padlen=pad_len)
peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)

# Rescale the reference so that its maximum sits at 1, to overlay it on the rate.
normalized_reference = 0.5 * reference / np.max(reference) + 0.5

fig, ax = plt.subplots()
ax.plot(t, normalized_reference, label='reference')
ax.plot(t, place_cell_rate, label='Vm')
ax.plot(t, theta_filtered_trace, label='theta Vm')
ylim = ax.get_ylim()
# Dashed lines mark the detected theta-Vm peaks.
ax.vlines(t[peak_indexes[0]], ymin=ylim[0], ymax=ylim[1], color='grey', linestyle='dashed', alpha=0.5)
ax.set_ylim(ylim)
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (ms)')
ax.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

It appears that every one of the peaks of the theta-filtered Vm is aligned to the peaks of the reference.

In [None]:
peak_idx = peak_indexes[0]
fig, ax = plt.subplots()
ax.plot(t[peak_idx], phase_degrees[peak_idx], '.')
ax.set_ylabel('Theta phase (degrees)')
ax.set_xlabel('Time (ms)')
ax.set_ylim(0., 360.)
ax.set_xlim(0., duration)

There is very little phase precession, mostly due to a slight warping of the oscillation by the spatial ramp.<br>
<br>
What if the Vm was oscillating slightly faster than the reference?

In [None]:
# Same place cell, but its theta runs slightly faster (7.2 Hz) than the reference.
place_cell_theta = get_osc(t, 7.2, 0.7, 0.)
place_cell_rate = place_field * place_cell_theta

# Isolate the theta component and find its positive peaks.
theta_filtered_trace = scipy.signal.filtfilt(
    theta_filter, [1.], place_cell_rate, padtype='even', padlen=pad_len)
peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)

# Rescale the reference so that its maximum sits at 1, to overlay it on the rate.
normalized_reference = 0.5 * reference / np.max(reference) + 0.5

fig, ax = plt.subplots()
ax.plot(t, normalized_reference, label='reference')
ax.plot(t, place_cell_rate, label='Vm')
ax.plot(t, theta_filtered_trace, label='theta Vm')
ylim = ax.get_ylim()
# Dashed lines mark the detected theta-Vm peaks.
ax.vlines(t[peak_indexes[0]], ymin=ylim[0], ymax=ylim[1], color='grey', linestyle='dashed', alpha=0.5)
ax.set_ylim(ylim)
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (ms)')
ax.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

In [None]:
peak_idx = peak_indexes[0]
fig, ax = plt.subplots()
ax.plot(t[peak_idx], phase_degrees[peak_idx], '.')
ax.set_ylabel('Theta phase (degrees)')
ax.set_xlabel('Time (ms)')
ax.set_ylim(0., 360.)
ax.set_xlim(0., duration)

In [None]:
pop_size = 200
track_length = duration / 1000. * velocity  # cm covered during the simulation
field_width = 60.  # cm

# Input field peaks are evenly spaced from half a track before the start to
# half a track past the end.
peak_locs = np.linspace(-track_length/2., track_length + track_length / 2., pop_size)

# One row per input cell: its Gaussian spatial tuning sampled along `position`.
input_spatial_rates = np.array(
    [get_gaussian_rate(position, loc, field_width) for loc in peak_locs])

In [None]:
fig, ax = plt.subplots()
# Rows are input cells (ordered by field peak); columns are positions along the track.
ax.imshow(input_spatial_rates, aspect='auto', extent=(0., track_length, pop_size, 0))
ax.set_xlabel('Position (cm)')
ax.set_ylabel('Input cell ID')

In [None]:
# Intracellular theta runs faster than the reference by 1 / L Hz, which gives
# one extra cycle per traversal of a place field.
L = field_width / velocity  # s, time to cross one place field
intra_freq = freq + 1. / L  # Hz
c = 1. / (L * intra_freq)
tau_offsets = c * peak_locs / velocity  # s

phase_offsets = 2. * np.pi * intra_freq * tau_offsets  # rad
# Algebraically, intra_freq cancels: phase_offsets == 2 * pi * (peak_locs / velocity) / L.
# Earlier alternative formulation (not used):
# phase_offsets = peak_locs / velocity * 2. * np.pi * freq / (field_width / velocity) / intra_freq

plt.figure()
plt.plot(peak_locs / velocity, phase_offsets)
plt.xlabel('Input field peak time (s)')
plt.ylabel('Phase offset (rad)')

In [None]:
# Derived oscillation parameters: field traversal time L (s), intracellular
# theta frequency (Hz) and the compression factor c. `L` is the traversal time
# defined above; the name `T` does not exist in this notebook.
print(L)
print(intra_freq)
print(c)

In [None]:
# Each input: its spatial tuning curve times a full-depth theta oscillation
# at intra_freq, delayed by that input's phase offset.
input_rates = np.empty((pop_size, len(position)))
for i, offset in enumerate(phase_offsets):
    input_rates[i, :] = input_spatial_rates[i, :] * get_osc(t, intra_freq, 1., offset)

In [None]:
# Input cell 100, whose field peaks near the middle of the track.
place_cell_rate = input_rates[100]

# Isolate the theta component and find its positive peaks.
theta_filtered_trace = scipy.signal.filtfilt(
    theta_filter, [1.], place_cell_rate, padtype='even', padlen=pad_len)
peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)

# Rescale the reference so that its maximum sits at 1, to overlay it on the rate.
normalized_reference = 0.5 * reference / np.max(reference) + 0.5

fig, ax = plt.subplots()
ax.plot(t, normalized_reference, label='reference')
ax.plot(t, place_cell_rate, label='Vm')
ax.plot(t, theta_filtered_trace, label='theta Vm')
ylim = ax.get_ylim()
# Dashed lines mark the detected theta-Vm peaks.
ax.vlines(t[peak_indexes[0]], ymin=ylim[0], ymax=ylim[1], color='grey', linestyle='dashed', alpha=0.5)
ax.set_ylim(ylim)
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (ms)')
ax.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

In [None]:
peak_idx = peak_indexes[0]
fig, ax = plt.subplots()
ax.plot(t[peak_idx], phase_degrees[peak_idx], '.')
ax.set_ylabel('Theta phase (degrees)')
ax.set_xlabel('Time (ms)')
ax.set_ylim(0., 360.)
ax.set_xlim(0., duration)

In [None]:
# Phase of the theta-Vm peaks for every 5th input cell, all on one plot.
fig, ax = plt.subplots()
for i in range(0, pop_size, 5):
    place_cell_rate = input_rates[i]
    theta_filtered_trace = scipy.signal.filtfilt(
        theta_filter, [1.], place_cell_rate, padtype='even', padlen=pad_len)
    peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)
    peak_idx = peak_indexes[0]
    ax.plot(t[peak_idx], phase_degrees[peak_idx], '.')
ax.set_ylabel('Theta phase (degrees)')
ax.set_xlabel('Time (ms)')
ax.set_ylim(0., 360.)
ax.set_xlim(0., duration)

In [None]:
# Summed rate of the whole input population.
pop_rate_sum = np.sum(input_rates, axis=0)

plt.figure()
# Every 5th input cell, plus the population sum scaled to a peak of 1.
for i in range(0, pop_size, 5):
    place_cell_rate = input_rates[i]
    plt.plot(t, place_cell_rate)
plt.plot(t, pop_rate_sum / np.max(pop_rate_sum), 'k', label='population sum (normalized)')
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')
plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

In [None]:
# Min-max normalize the population sum to [0, 1] for comparison with the reference.
normalized_pop_rate_sum = pop_rate_sum - np.min(pop_rate_sum)
normalized_pop_rate_sum /= np.max(normalized_pop_rate_sum)

plt.figure()
plt.plot(t, normalized_reference, label='reference')
plt.plot(t, normalized_pop_rate_sum, '--', label='population rate sum')
plt.ylabel('Normalized amplitude')
plt.xlabel('Time (ms)')
plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

In [None]:
# Theta-filter the population rate, then read out the reference phase at its peaks.
theta_filtered_pop_rate_sum = scipy.signal.filtfilt(theta_filter, [1.], pop_rate_sum, padtype='even', padlen=pad_len)

# Detect peaks on the theta-filtered sum, as for the single-cell traces. The
# unfiltered sum is dominated by its slow envelope, and the height threshold
# assumes a zero-mean filtered trace.
peak_indexes = scipy.signal.find_peaks(theta_filtered_pop_rate_sum, height=0.001)

plt.figure()
plt.plot(t[peak_indexes[0]], phase_degrees[peak_indexes[0]], '.')
plt.ylabel('Theta phase (degrees)')
plt.xlabel('Time (ms)')
plt.ylim(0., 360.)
plt.xlim(0., duration)

In [None]:
# Build a place field in a CA1 cell by weighting the population of CA3 inputs
# with a Gaussian centred on the target location.
target_peak_loc = 100  # cm
weight = get_gaussian_rate(peak_locs, target_peak_loc, field_width)

plt.figure()
plt.scatter(peak_locs, weight)
plt.xlabel('Input field peak location (cm)')
plt.ylabel('Synaptic weight')

# Weighted sum of the input rates = model intracellular Vm.
intra_Vm = weight.dot(input_rates)
plt.figure()
plt.plot(t, intra_Vm)
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')

theta_filtered_trace = scipy.signal.filtfilt(theta_filter, [1.], intra_Vm, padtype='even', padlen=pad_len)

peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)

normalized_reference = 0.5 * reference / np.max(reference) + 0.5

# Scale Vm so that its maximum is 1, to overlay it on the normalized reference.
normalized_intra_Vm = intra_Vm / np.max(intra_Vm)

plt.figure()
plt.plot(t, normalized_reference, label='reference')
plt.plot(t, normalized_intra_Vm, label='Vm')
plt.plot(t, theta_filtered_trace, label='theta Vm')
ylim = plt.ylim()
plt.vlines(t[peak_indexes[0]], ymin=ylim[0], ymax=ylim[1], color='grey', linestyle='dashed', alpha=0.5)
plt.ylim(ylim)
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')
plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

plt.figure()
plt.plot(t[peak_indexes[0]], phase_degrees[peak_indexes[0]], '.')
plt.ylabel('Theta phase (degrees)')
plt.xlabel('Time (ms)')
plt.ylim(0., 360.)
plt.xlim(0., duration)

In [None]:
# Build a place field in a CA1 cell by weighting the population of CA3 inputs:
# every input gets a baseline weight of 1, plus a Gaussian boost centred on
# the target location.
target_peak_loc = 100  # cm
weight = get_gaussian_rate(peak_locs, target_peak_loc, field_width) + 1

plt.figure()
plt.scatter(peak_locs, weight)
plt.xlabel('Input field peak location (cm)')
plt.ylabel('Synaptic weight')

intra_Vm = weight.dot(input_rates)

# Keep the excitation-only Vm for comparison with the inhibition cell below.
no_inh_intra_Vm = intra_Vm
plt.figure()
plt.plot(t, intra_Vm)
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')

theta_filtered_trace = scipy.signal.filtfilt(theta_filter, [1.], intra_Vm, padtype='even', padlen=pad_len)

peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)

normalized_reference = 0.5 * reference / np.max(reference) + 0.5

# Scale Vm so that its maximum is 1, to overlay it on the normalized reference.
normalized_intra_Vm = intra_Vm / np.max(intra_Vm)

plt.figure()
plt.plot(t, normalized_reference, label='reference')
plt.plot(t, normalized_intra_Vm, label='Vm')
plt.plot(t, theta_filtered_trace, label='theta Vm')
ylim = plt.ylim()
plt.vlines(t[peak_indexes[0]], ymin=ylim[0], ymax=ylim[1], color='grey', linestyle='dashed', alpha=0.5)
plt.ylim(ylim)
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')
plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

# Keep the excitation-only peaks for the with/without inhibition comparison.
no_inh_peak_indexes = peak_indexes

plt.figure()
plt.plot(t[peak_indexes[0]], phase_degrees[peak_indexes[0]], '.')
plt.ylabel('Theta phase (degrees)')
plt.xlabel('Time (ms)')
plt.ylim(0., 360.)
plt.xlim(0., duration)

In [None]:
# Placeholder array (not filled or read anywhere below).
population_phases = np.empty((pop_size, ))

# Phase of the theta-Vm peaks for every 5th input cell, all on one plot.
fig, ax = plt.subplots()
for i in range(0, pop_size, 5):
    place_cell_rate = input_rates[i]
    theta_filtered_trace = scipy.signal.filtfilt(
        theta_filter, [1.], place_cell_rate, padtype='even', padlen=pad_len)
    peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)
    peak_idx = peak_indexes[0]
    ax.plot(t[peak_idx], phase_degrees[peak_idx], '.')
ax.set_ylabel('Theta phase (degrees)')
ax.set_xlabel('Time (ms)')
ax.set_ylim(0., 360.)
ax.set_xlim(0., duration)

In [None]:
# Position-independent inhibition oscillating at the reference frequency (depth 0.5).
phase_offset = 0.
uniform_inh = get_osc(t, freq, 0.5, phase_offset)

plt.figure()
plt.plot(t, uniform_inh)
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')

In [None]:
# Add uniform theta-modulated inhibition to the same CA3 excitation.
CA3_weights = weight
CA3_input_rates = input_rates
inh_weight = -10.

phase_offset = 0.  # e.g. np.pi / 4 to shift the inhibition phase
uniform_inh = get_osc(t, freq, 0.65, phase_offset)

intra_Vm = CA3_weights.dot(CA3_input_rates) + inh_weight * uniform_inh
plt.figure()
plt.plot(t, intra_Vm, label='With inh')
plt.plot(t, no_inh_intra_Vm, label='No inh')
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')
plt.legend(loc='best', frameon=False)

theta_filtered_trace = scipy.signal.filtfilt(theta_filter, [1.], intra_Vm, padtype='even', padlen=pad_len)

peak_indexes = scipy.signal.find_peaks(theta_filtered_trace, height=0.001)

normalized_reference = 0.5 * reference / np.max(reference) + 0.5

# Scale Vm so that its maximum is 1, to overlay it on the normalized reference.
normalized_intra_Vm = intra_Vm / np.max(intra_Vm)

plt.figure()
plt.plot(t, normalized_reference, label='reference')
plt.plot(t, normalized_intra_Vm, label='Vm')
plt.plot(t, theta_filtered_trace, label='theta Vm')
ylim = plt.ylim()
plt.vlines(t[peak_indexes[0]], ymin=ylim[0], ymax=ylim[1], color='grey', linestyle='dashed', alpha=0.5)
plt.ylim(ylim)
plt.ylabel('Amplitude')
plt.xlabel('Time (ms)')
plt.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

plt.figure()
plt.plot(t[peak_indexes[0]], phase_degrees[peak_indexes[0]], '.', label='With inh')
plt.plot(t[no_inh_peak_indexes[0]], phase_degrees[no_inh_peak_indexes[0]], '.', label='No inh')
plt.ylabel('Theta phase (degrees)')
plt.xlabel('Time (ms)')
plt.ylim(0., 360.)
plt.xlim(0., duration)
plt.legend(loc='best', frameon=False)

1. In-field_Vm - Out-of-field_Vm (w/ inhibition) = ramp = 6 - 8 mV w/ Inh
2. Block inh -> Both In-field_Vm and Out-of-field_Vm -> increase ~ 3 mV
3. In-field_theta_amp_Vm > Out-of-field_theta_amp_Vm (w / inhibition)
4. Block inh -> Increase both In-field_theta_amp_Vm and Out-of-field_theta_amp_Vm
5. ~180-360 phase precession w/ Inhibition
6. Block inh -> decrease phase precession


TODO:
 - Make synapses conductance-based 
 - Compare uniform inh to decreasing inh-in-field to increasing inh-in-field
 - Add NMDARs