# Setup

In [None]:
%matplotlib inline
import numpy as np
import scipy.signal
import matplotlib.pyplot as plt

import os
import neo.io
import elephant

# analysis.py module in same folder
from bgcellmodels.common import units
units.set_units_module('quantities')
import quantities as pq

## Plotting Options

In [None]:
# Figure geometry shared by every plot in this notebook:
#   page_width - width (inches) of the rendered page, used to scale figsize;
#                roughly 16 for the %inline backend, 8 for %notebook
#   ax_height  - height (inches) given to one row of axes
page_width, ax_height = 14, 3

# Matplotlib style sheet (default colours etc.); gallery of alternatives:
# https://matplotlib.org/gallery/style_sheets/style_sheets_reference.html
plt.style.use('default')

## Load Data

In [None]:
# Simulation output: either a directory with one Neo .mat file per population,
# or an explicit list of .mat files. The default directory can be overridden
# with the SYNCHRONY_OUTPUTS environment variable so the notebook is portable.
outputs = os.environ.get(
    'SYNCHRONY_OUTPUTS',
    "/run/media/luye/Windows7_OS/Users/lkoelman/simdata-win/2018.06.27_job-777694.sonic-head_DA-depleted_CTX-poisson-f150_STN-lateral-f02_GPE-lateral-p05")

# Single population:
# outputs = [
# "/home/luye/storage/2018.06.21_job-testmpi6_DA-control_CTX-beta/CTX_2018.06.21_pop-100_dur-50.0_job-testmpi6.mat",
# ]

if isinstance(outputs, str):
    # os.listdir returns entries in arbitrary order; sorting makes the population
    # order (and therefore figure order further down) reproducible.
    pop_files = [os.path.join(outputs, f) for f in sorted(os.listdir(outputs))
                 if f.endswith('.mat')]
else:
    pop_files = list(outputs)

pops_segments = {}   # population label -> Neo Segment
read_segment_id = 0  # which Neo segment of each file to analyse

# Read binary files using Neo IO module
for pop_file in pop_files:
    reader = neo.io.get_io(pop_file)
    blocks = reader.read()
    assert len(blocks) == 1, "More than one Neo Block in file."
    pop_label = blocks[0].name
    segments = blocks[0].segments

    if read_segment_id >= len(segments):
        raise ValueError("Segment index greater than number of Neo segments"
                         " in file {}".format(pop_file))
    if pop_label in pops_segments:
        raise ValueError("Duplicate population labels in files")

    pops_segments[pop_label] = segments[read_segment_id]

# Save all PSDs for comparison in figures
all_psd = {}
all_signals = {}

In [None]:
# Neo data model (http://neo.readthedocs.io/en/latest/):
# - every Segment exposes 'analogsignals' and 'spiketrains'
# - every signal is a Quantity with .magnitude, .units and .dimensionality
# AnalogSignal reference:
# http://neo.readthedocs.io/en/latest/api_reference.html#neo.core.AnalogSignal

# Inventory of the recorded signals of each population
for pop_label, segment in pops_segments.items():
    report = ["\n{} has following signals:".format(pop_label)]
    report.extend(
        "\t- '{:10}\t[{}] - description: {}".format(sig.name, sig.units, sig.description)
        for sig in segment.analogsignals)
    report.append("\t- {} spiketrains".format(len(segment.spiketrains)))
    print("\n".join(report))

# Spike Trains

<span style='color:red;font-weight:bold'>WARNING</span>: In raster plots, check the number of spike trains plotted (see the y-axis). If it is too high, the marker bars overlap, because the marker height exceeds the row height allocated to one spike train. Overlapping spike trains produce misleading plots that suggest an artificially elevated firing rate.

In [None]:
# Raster plots of the first spike trains of every population, pops_per_fig
# populations per figure.
num_pops = len(pops_segments)
pops_per_fig = 2
# Plotting all cells makes the markers overlap and the plot misleading
# (see the warning above), so at most this many trains are shown.
max_trains_plotted = 20

fig_spikes, axes_spikes = None, None
pop_spike_colors = 'rgcbm'

for i_pop, (pop_label, segment) in enumerate(pops_segments.items()):
    # Start a new figure every pops_per_fig populations
    if i_pop % pops_per_fig == 0:
        fig_spikes, axes_spikes = plt.subplots(pops_per_fig, 1, figsize=(page_width, 2*ax_height),
                                               sharex=True)
        fig_spikes.suptitle('Spikes for each population')
    ax = axes_spikes[i_pop % pops_per_fig]
    color = pop_spike_colors[i_pop % len(pop_spike_colors)]

    # Guard against populations with fewer trains than max_trains_plotted
    num_trains = min(max_trains_plotted, len(segment.spiketrains))
    for i_train in range(num_trains):
        spiketrain = segment.spiketrains[i_train]
        # Row = the cell's source id when annotated, else its position
        y = spiketrain.annotations.get('source_id', i_train)
        # np.full gives plain (unitless) y-values; np.ones_like on the
        # spike train would carry its time units along.
        ax.plot(spiketrain, np.full(len(spiketrain), y), marker='|', linestyle='',
                snap=True, color=color)
    ax.set_ylabel('{} cell #'.format(pop_label))

plt.show(block=False)

# Raw Signals

## STN Vm

In [None]:
# STN membrane voltage: the first analog signal called 'Vm'
pop_label = 'STN'
segment = pops_segments[pop_label]
stn_vm_signal = signal = next(sig for sig in segment.analogsignals if sig.name == 'Vm')

In [None]:
def plot_vm_signals(signal, cell_indices, interval, label=None):
    """
    Plot membrane-voltage traces of selected cells, one stacked axes per cell.

    Parameters
    ----------
    signal : neo.AnalogSignal
        Multi-channel voltage signal; column ``i`` is the trace of cell ``i``.
        If ``signal.annotations`` contains 'source_ids', they label the traces.
    cell_indices : sequence of int
        Columns (cells) to plot.
    interval : (float, float) or None
        Start and stop time, in the units of ``signal.sampling_period``
        (presumably ms). None plots the whole recording.
    label : str, optional
        Population name used in the figure title. When omitted, falls back to
        the notebook-global ``pop_label`` (the original implicit behaviour).
    """
    if label is None:
        label = globals().get('pop_label', '')

    rec_dt = signal.sampling_period.magnitude
    if interval is None:
        # Stop index is exclusive, so shape[0] keeps the final sample
        irange = [0, signal.shape[0]]
    else:
        # Index relative to t_start, and round rather than truncate: t/rec_dt
        # can land just below an integer and lose a sample with int().
        t0 = signal.t_start.rescale(signal.sampling_period.units).magnitude
        irange = [int(round((t - t0) / rec_dt)) for t in interval]
    times = signal.times[irange[0]:irange[1]]

    # squeeze=False keeps `axes` 2-D even when a single cell is plotted
    fig, axes = plt.subplots(len(cell_indices), 1,
                             figsize=(0.75*page_width, 2*ax_height),
                             sharex=True, sharey=True, squeeze=False)
    fig.suptitle("{} membrane voltage".format(label))

    source_ids = signal.annotations.get('source_ids')
    last_ax = len(cell_indices) - 1
    for i_ax, i_cell in enumerate(cell_indices):
        ax = axes[i_ax, 0]
        if source_ids is not None:
            trace_label = "id {}".format(source_ids[i_cell])
        else:
            trace_label = "cell {}".format(i_cell)

        ax.plot(times, signal[irange[0]:irange[1], i_cell], label=trace_label)
        ax.grid(True)
        ax.set_ylim((-80, 25))  # fixed range; assumes Vm in mV

        if i_ax == last_ax:
            ax.set_xlabel('time ({})'.format(times.units))

    # One shared y-label for the column of axes
    fig.text(0.06, 0.5, "voltage ({})".format(signal.units), va='center', rotation='vertical')

In [None]:
# Plot a few STN Vm traces in a time window (ms)
max_num_plot = 10
num_signals = min(stn_vm_signal.shape[1], max_num_plot)
interval = [12.75e3, 14e3]  # [2000.0, 6000.0]
# At most 5 cells, and never more than were recorded
cell_indices = range(min(5, num_signals))  # range(num_signals)
plot_vm_signals(stn_vm_signal, cell_indices, interval)

## STN LFP

In [None]:
# Population LFP of STN: each column of the 'lfp' signal is one cell's
# contribution, so the population LFP is their sum over columns.
pop_label = 'STN'
segment = pops_segments[pop_label]
lfp_sigs = next(sig for sig in segment.analogsignals if sig.name == 'lfp')
lfp_summed = lfp_sigs.sum(axis=1)

# Wrap as a single-channel AnalogSignal. Timing metadata comes from the LFP
# recording itself (not the Vm signal), so the time axis stays correct even
# if the two were recorded with different rates or start times.
lfp_signal = stn_lfp_signal = neo.AnalogSignal(lfp_summed, units=lfp_sigs.units,
                                               sampling_rate=lfp_sigs.sampling_rate,
                                               t_start=lfp_sigs.t_start)

In [None]:
# Plot the summed STN LFP in a time window (ms); interval=None plots everything
interval = [12.75e3, 14e3]  # None
rec_dt = lfp_sigs.sampling_period.magnitude
if interval is None:
    irange = [0, lfp_sigs.shape[0]]
else:
    # Relative to t_start and rounded: int() truncation can drop a sample when
    # t/rec_dt lands just below an integer.
    t0 = lfp_signal.t_start.rescale(lfp_sigs.sampling_period.units).magnitude
    irange = [int(round((t - t0) / rec_dt)) for t in interval]
lfp_times = lfp_signal.times[irange[0]:irange[1]]

fig, ax = plt.subplots(figsize=(page_width, ax_height))
ax.plot(lfp_times, lfp_signal[irange[0]:irange[1]],
        label='{} LFP'.format(pop_label))
ax.set_xlim((lfp_times[0].magnitude, lfp_times[-1].magnitude))
ax.set_ylabel('LFP magnitude ({})'.format(lfp_sigs.units))
ax.set_xlabel('time ({})'.format(lfp_sigs.times.units))
ax.set_title('LFP for {} population'.format(pop_label))
ax.grid(True)
ax.legend()

## GPe Vm

In [None]:
# GPe membrane voltage: the first analog signal called 'Vm'
pop_label = 'GPE'
segment = pops_segments[pop_label]
signal = next(sig for sig in segment.analogsignals if sig.name == 'Vm')
gpe_vm_signal = signal

In [None]:
# Plot a few GPe Vm traces in a time window (ms)
max_num_plot = 10
num_signals = min(gpe_vm_signal.shape[1], max_num_plot)
interval = [12.75e3, 14e3]  # [2000.0, 6000.0]
# At most 5 cells, and never more than were recorded
cell_indices = range(min(5, num_signals))  # range(num_signals)
plot_vm_signals(gpe_vm_signal, cell_indices, interval)

# Power Spectrum

## STN LFP

In [None]:
# Welch PSD of the summed STN LFP (a single channel, built from all cells'
# contributions in the STN LFP cell)
freqs, psd = elephant.spectral.welch_psd(lfp_signal, freq_res=0.5)
psd = psd.ravel()  # single channel -> 1-D array over frequencies
all_psd['STN_LFP'] = (freqs, psd)

In [None]:
# Plot the Welch PSD of the summed STN LFP. The stored copy in all_psd is used
# instead of the loose globals: by this point the global pop_label was last set
# to 'GPE', which previously mislabelled this curve.
freqs_lfp, psd_lfp = all_psd['STN_LFP']
fig, ax = plt.subplots(figsize=(0.5*page_width, ax_height))
ax.plot(freqs_lfp, psd_lfp, label='STN LFP PSD')
ax.set_ylabel('Power ({})'.format(psd_lfp.units))
ax.set_xlabel('frequency ({})'.format(freqs_lfp.units))
ax.set_xlim((0, 100))
ax.grid(True)
ax.set_title('Welch PSD for LFP of STN cells')
ax.legend()

### Beta Power Evolution

In [None]:
# Spectrogram (short-time Fourier transform) of the summed STN LFP.
dt = float(stn_lfp_signal.sampling_period.rescale('ms').magnitude)  # ms
fs = 1e3 / dt                              # sampling frequency (Hz)
freq_res = 1.0                             # Hz
nperseg = int(round(fs / freq_res))        # window length sets frequency resolution
t_res = 20.0                               # ms between successive windows
noverlap = nperseg - int(round(t_res / dt))

# Passing fs in Hz gives freqs in Hz, t in seconds and a density in units^2/Hz
# (passing 1/dt in 1/ms would give a density per kHz despite the /Hz label).
freqs, t, Sxx = scipy.signal.spectrogram(np.asarray(stn_lfp_signal).ravel(), fs, window='hann',
                                         nperseg=nperseg, noverlap=noverlap, scaling='density')
t = t * 1e3  # s -> ms
lfp_unit = stn_lfp_signal.dimensionality.string

fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
mesh = ax.pcolormesh(t, freqs, Sxx)
f_max = 50
ax.set_ylim((0, f_max))
fig.colorbar(mesh, ax=ax)
# Colour scale ignores the first 1000 ms (each column is t_res ms); clamp so at
# least one column remains for short recordings.
i_skip = min(int(1000.0 / t_res), Sxx.shape[1] - 1)
mesh.set_clim(0, np.abs(Sxx[:, i_skip:]).max())
ax.set_ylabel('frequency (Hz)')
ax.set_xlabel('time (ms)')
fig.suptitle('Spectrogram - evolution of STN LFP over time ({}$^2$/Hz)'.format(lfp_unit))

In [None]:
# Mean spectrogram power in two beta sub-bands over time, and the first time
# each reaches (1 - 1/e) of its maximum ("saturation").
dur = stn_vm_signal.t_stop.magnitude
power_label = 'Power ({}$^2$/Hz)'.format(stn_lfp_signal.dimensionality.string)

### Low Beta (13 - 21 Hz)
bin_indices, = np.where((freqs >= 13) & (freqs <= 21))
beta_spectrogram = Sxx[bin_indices, :]
betapower = beta_spectrogram.mean(axis=0)

p_saturation = (1.0 - np.exp(-1.0)) * betapower.max()
i_saturated, = np.where(betapower >= p_saturation)
t_saturated = t[i_saturated[0]]

### High Beta (22 - 30 Hz)
bins_beta_high, = np.where((freqs >= 22) & (freqs <= 30))
beta_high_spectrogram = Sxx[bins_beta_high, :]
beta_high_power = beta_high_spectrogram.mean(axis=0)

p_sat_high = (1.0 - np.exp(-1.0)) * beta_high_power.max()
i_sat_high, = np.where(beta_high_power >= p_sat_high)
t_sat_high = t[i_sat_high[0]]

fig, ax = plt.subplots(figsize=(0.75*page_width, 1.5*ax_height))
# Low beta: power trace, saturation threshold and crossing point
ax.plot(t, betapower, label=r'$\beta$ ({} - {} Hz)'.format(freqs[bin_indices[0]], freqs[bin_indices[-1]]))
ax.hlines(p_saturation, 0, dur, 'orange', label='$(1-e^{-1}) * P_{max}$')
ax.plot(t_saturated, p_saturation, '+', color='red', markersize=5)

# High beta
ax.plot(t, beta_high_power, 'r-', label=r'$\beta$ ({} - {} Hz)'.format(freqs[bins_beta_high[0]], freqs[bins_beta_high[-1]]))
ax.hlines(p_sat_high, 0, dur, 'orange')
ax.plot(t_sat_high, p_sat_high, '+', color='red', markersize=5)

ax.set_ylabel(power_label)
ax.set_xlabel('time (ms)')
ax.set_xlim((0, dur))
ax.grid(True)
ax.legend()
# Both sub-bands are plotted, so the title names neither range specifically
ax.set_title('Mean power of STN LFP in beta sub-bands')

print("Time to Beta-mid saturation is {} ms".format(t_saturated))
print("Max of mean Beta-mid power is {}".format(betapower.max()))

print("Time to Beta-high saturation is {} ms".format(t_sat_high))
print("Max of mean Beta-high power is {}".format(beta_high_power.max()))

## STN Vm

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
vm_sig = next(sig for sig in segment.analogsignals if sig.name == 'Vm')

# Welch PSD of every STN Vm trace in one call, then averaged over axis 0
# (presumably the channel/cell axis)
freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
n_channels = psd.shape[0]
psd_avg = psd.sum(axis=0) / n_channels

all_psd['{}_Vm'.format(pop_label)] = (freqs, psd_avg)

In [None]:
# Overlay the average STN Vm PSD and the STN LFP PSD. The LFP PSD is multiplied
# by lfp_scale so both are visible on one linear axis; the legend reports the
# actual factor (it previously said "x 10" while scaling by 1e3).
lfp_scale = 1e3
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, label='{} Vm'.format(pop_label))
ax.plot(all_psd['STN_LFP'][0], all_psd['STN_LFP'][1] * lfp_scale, 'r',
        label='STN LFP x {:g}'.format(lfp_scale))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.grid(True)
ax.set_title('Welch PSD for average STN membrane voltages')
ax.legend()

## CTX spikes

### Convolve spike times with AP

In [None]:
# Load a recorded pyramidal-cell action potential and turn it into a
# zero-baseline kernel for convolution with CTX spike trains.
pyramidal_trace = np.loadtxt('pyramidal.dat')  # column 0: time, column 1: Vm
tvec = pyramidal_trace[:, 0]
vm = pyramidal_trace[:, 1]
dt = np.round(tvec[1] - tvec[0], 4)
ap_t_interval = [4.9, 20.0]  # window containing the AP, in the file's time units
# Indices relative to the first sample, rounded: int() truncation is off by one
# whenever t/dt lands just below an integer (e.g. 4.9/0.05).
ap_i_interval = [int(round((t - tvec[0]) / dt)) for t in ap_t_interval]
subsample = 2
# NOTE(review): the kernel's sample spacing is subsample*dt; it is convolved
# with pulse trains sampled every 0.05 ms in the next cells, so the two should
# agree - confirm against pyramidal.dat.
ap_range = np.arange(ap_i_interval[0], ap_i_interval[1] + 1, subsample)
ap_baseline = -65.0
ap_kernel = vm[ap_range] - ap_baseline  # centre on 0 for convolution

plt.figure(figsize=(6, 2))
plt.plot(tvec[ap_range], ap_kernel)
plt.suptitle('AP kernel for cortical neurons')
plt.grid(True)

In [None]:
# Cortical population: only spike trains are recorded
pop_label, segment = 'CTX', pops_segments['CTX']
spiketrains = segment.spiketrains

In [None]:
# Reconstruct a Vm-like AnalogSignal (one channel per CTX spike train) by
# placing the AP kernel at every spike time.
dt = 0.05  # ms
t_start = spiketrains[0].t_start
t_stop = spiketrains[0].t_stop
dur = np.round((t_stop - t_start).rescale(pq.ms).magnitude, 4)
# Sample count fixed once; np.arange(0, dur + dt, dt) could yield one extra
# sample through float rounding and break the column assignment below.
n_samples = int(round(dur / dt)) + 1
signal_matrix = np.empty((n_samples, len(spiketrains)))

for i, st in enumerate(spiketrains):
    # Unit pulse at the sample nearest each spike (relative to t_start)
    spike_idx = np.round((st.times - t_start).rescale(pq.ms).magnitude / dt).astype(int)
    spike_idx = spike_idx[(spike_idx >= 0) & (spike_idx < n_samples)]
    spike_pulses = np.zeros(n_samples)
    spike_pulses[spike_idx] = 1.0

    # mode='same' keeps the pulse-train length and centres the kernel on each pulse
    signal_matrix[:, i] = np.convolve(spike_pulses, ap_kernel, mode='same') + ap_baseline

# t_stop follows from t_start, sampling_rate and length. NOTE(review): the
# AnalogSignal constructor appears to take no t_stop parameter (extra keyword
# args become annotations), so it is not passed - verify for the installed neo.
ctx_vm_signal = neo.AnalogSignal(signal_matrix, units='mV', sampling_rate=1/dt/pq.ms,
                                 t_start=t_start)
ctx_vm_mean = ctx_vm_signal.sum(axis=1) / ctx_vm_signal.shape[1]

### Average PSD

In [None]:
# Welch PSD of every reconstructed CTX Vm channel, averaged over axis 0
# (presumably the channel axis)
freqs, psd = elephant.spectral.welch_psd(ctx_vm_signal, freq_res=0.5)
n_channels = psd.shape[0]
psd_avg = psd.sum(axis=0) / n_channels
all_psd['{}_Vm'.format(pop_label)] = (freqs, psd_avg)

In [None]:
# Plot the channel-averaged CTX Vm PSD (freqs, psd_avg) from the previous cell;
# pop_label was set to 'CTX' in the CTX selection cell.
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, label='{} Vm'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))  # show 0-50 Hz only
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Mean Welch PSD for CTX membrane voltages')
ax.legend()

## GPe Vm

In [None]:
pop_label = 'GPE'
segment = pops_segments[pop_label]
vm_sig = gpe_vm_signal = next(sig for sig in segment.analogsignals if sig.name == 'Vm')

# Welch PSD of every GPe Vm trace, averaged over axis 0 (presumably channels)
freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
psd_avg = psd.sum(axis=0) / psd.shape[0]

# Peak frequency of the averaged PSD. np.argmax is a single O(n) pass; the
# previous generator recomputed psd_avg.max() for every element.
i_peak = int(np.argmax(psd_avg))
f_peak = freqs[i_peak]
print("PSD peak power occurs at f = {}".format(f_peak))

all_psd[pop_label] = (freqs, psd_avg)

In [None]:
# Compare the averaged STN Vm PSD (stored in all_psd['STN_Vm']) with the
# averaged GPe Vm PSD computed in the previous cell.
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(all_psd['STN_Vm'][0], all_psd['STN_Vm'][1], 'r', label='Avg STN Vm PSD')
ax.plot(freqs, psd_avg, label='Avg {} Vm PSD'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))  # show 0-50 Hz only
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Welch PSD for average GPE membrane voltages')
ax.legend()

### Beta Power Evolution

In [None]:
# Spectrogram (short-time Fourier transform) of the population-mean GPe Vm.

# Mean over cells: divide by the number of columns (cells). The previous
# version divided by psd.shape[1], i.e. the number of PSD frequency bins.
gpe_vm_mean = gpe_vm_signal.sum(axis=1) / gpe_vm_signal.shape[1]

dt = float(gpe_vm_signal.sampling_period.rescale('ms').magnitude)  # ms
fs = 1e3 / dt                              # sampling frequency (Hz)
freq_res = 1.0                             # Hz
nperseg = int(round(fs / freq_res))        # window length sets frequency resolution
t_res = 20.0                               # ms between successive windows
noverlap = nperseg - int(round(t_res / dt))

# fs in Hz -> freqs in Hz, t in seconds, density in units^2/Hz
freqs, t, Sxx = scipy.signal.spectrogram(np.asarray(gpe_vm_mean).ravel(), fs, window='hann',
                                         nperseg=nperseg, noverlap=noverlap, scaling='density')
t = t * 1e3  # s -> ms
vm_unit = gpe_vm_signal.dimensionality.string

fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
mesh = ax.pcolormesh(t, freqs, Sxx)
f_max = 50
ax.set_ylim((0, f_max))
fig.colorbar(mesh, ax=ax)
mesh.set_clim(0, np.abs(Sxx).max())
ax.set_ylabel('frequency (Hz)')
ax.set_xlabel('time (ms)')
fig.suptitle('Spectrogram - Power of mean GPe Vm over time ({}$^2$/Hz)'.format(vm_unit))

In [None]:
# Mean spectrogram power of the mean GPe Vm in two beta sub-bands over time,
# and the first time each reaches (1 - 1/e) of its maximum ("saturation").
dur = gpe_vm_signal.t_stop.magnitude
power_label = 'Power ({}$^2$/Hz)'.format(gpe_vm_signal.dimensionality.string)

### Low Beta (13 - 21 Hz)
bin_indices, = np.where((freqs >= 13) & (freqs <= 21))
beta_spectrogram = Sxx[bin_indices, :]
betapower = beta_spectrogram.mean(axis=0)

p_saturation = (1.0 - np.exp(-1.0)) * betapower.max()
i_saturated, = np.where(betapower >= p_saturation)
t_saturated = t[i_saturated[0]]

### High Beta (22 - 30 Hz)
bins_beta_high, = np.where((freqs >= 22) & (freqs <= 30))
beta_high_spectrogram = Sxx[bins_beta_high, :]
beta_high_power = beta_high_spectrogram.mean(axis=0)

p_sat_high = (1.0 - np.exp(-1.0)) * beta_high_power.max()
i_sat_high, = np.where(beta_high_power >= p_sat_high)
t_sat_high = t[i_sat_high[0]]

fig, ax = plt.subplots(figsize=(0.75*page_width, 1.5*ax_height))
# Low beta: power trace, saturation threshold and crossing point
ax.plot(t, betapower, label=r'$\beta$ ({} - {} Hz)'.format(freqs[bin_indices[0]], freqs[bin_indices[-1]]))
ax.hlines(p_saturation, 0, dur, 'orange', label='$(1-e^{-1}) * P_{max}$')
ax.plot(t_saturated, p_saturation, '+', color='red', markersize=5)

# High beta
ax.plot(t, beta_high_power, 'r-', label=r'$\beta$ ({} - {} Hz)'.format(freqs[bins_beta_high[0]], freqs[bins_beta_high[-1]]))
ax.hlines(p_sat_high, 0, dur, 'orange')
ax.plot(t_sat_high, p_sat_high, '+', color='red', markersize=5)

ax.set_ylabel(power_label)
ax.set_xlabel('time (ms)')
ax.set_xlim((0, dur))
ax.grid(True)
ax.legend()
# Both sub-bands are plotted, so the title names neither range specifically
ax.set_title('Mean power of mean GPe Vm in beta sub-bands')

print("Time to Beta-mid saturation is {} ms".format(t_saturated))
print("Max of mean Beta-mid power is {}".format(betapower.max()))

print("Time to Beta-high saturation is {} ms".format(t_sat_high))
print("Max of mean Beta-high power is {}".format(beta_high_power.max()))

# Phase Relationships

For suitable measures and implementations, see
- bookmarks/neuroscience/signal_processing
- google `measure + site:github.com`
- ask Amir for his mutual information and related measures
- see Beta-related and other neurophysiology articles

For example, we can use following measures

- __Coherence__: linear relationship between two signals, per frequency component
    + see `welch_cohere` in [elephant.spectral](http://elephant.readthedocs.io/en/latest/reference/spectral.html)


- __Phase-Amplitude Coupling__ (PAC)
    + see `hilbert` in [elephant.signal_processing](http://elephant.readthedocs.io/en/latest/reference/signal_processing.html) to do band-pass filter + Hilbert transform
    + see `comodulogram` in [pactools](https://pactools.github.io/auto_examples/plot_comodulogram.html)

## Beta Phase

### Calculate Phase

In [None]:
# Select the GPe Vm traces for the phase analysis (analytic signal -> phase):
# the first num_traces cells inside a time window (ms).
# See neuroscience & signal analysis papers for phase calculations, useful metrics, visualizations
signal = gpe_vm_signal

rec_dt = signal.sampling_period.magnitude
interval = [5e3, 22e3]
# Indices relative to t_start and rounded (int() truncation can be off by one)
t0 = signal.t_start.rescale(signal.sampling_period.units).magnitude
irange = [int(round((t - t0) / rec_dt)) for t in interval]
islice = np.s_[irange[0]:irange[1]]  # time slice

num_traces = 10
pslice = np.s_[0:num_traces]  # cell (column) slice
traces_raw = signal[islice, pslice]

times = signal.times[islice]

In [None]:
# Design a Butterworth band-pass filter for the beta band and plot its response.

# Elephant built-in filtering
# import elephant.signal_processing as sigproc
# signal_bp = sigproc.butter(signal, highpass_freq=20, lowpass_freq=30, order=3, filter_function='filtfilt')

# Manual band-pass filtering
Fs = signal.sampling_rate.rescale('Hz').magnitude
Fn = Fs / 2.  # Nyquist frequency
hpfreq, lpfreq = 15.0, 33.0
order = 3

assert hpfreq < lpfreq
low, high = hpfreq / Fn, lpfreq / Fn  # normalised to Nyquist, as butter() expects
b, a = scipy.signal.butter(order, [low, high], btype='bandpass', analog=False)

# Check filter stability (otherwise -> NaN values)
filter_stable = np.all(np.abs(np.roots(a)) < 1)
if not filter_stable:
    raise Exception("Unstable filter!")

# Frequency response at explicit angular frequencies in [0, pi]. np.linspace
# requires an integer sample count; 2**ceil(...) is a float.
n_freqs = int(2**np.ceil(np.log2(Fn)))
w, h = scipy.signal.freqz(b, a, np.linspace(0, np.pi, n_freqs))
angles = np.unwrap(np.angle(h))
fax = w * Fn / np.pi  # rad/sample -> Hz

fig, axes = plt.subplots(2, 1, sharex=True)
fig.suptitle("Filter response (2*pi = {})".format(Fs))
ax = axes[0]
ax.plot(fax, abs(h), 'b')  # linear gain; use 20 * np.log10(abs(h)) for dB
ax.set_ylabel('Gain (linear)', color='b')

ax = axes[1]
ax.plot(fax, angles, 'g')
ax.set_ylabel('Angle (radians)', color='g')

ax.set_xlim((0, 50))
ax.set_xlabel('Frequency [Hz]')
ax.grid(True)

In [None]:
# Zero-phase band-pass filtering. Only the first num_traces columns are used
# downstream, so only those are filtered. filtfilt processes each column
# independently along axis 0, so their filtered values are unchanged.
data = np.asarray(signal)[:, :num_traces]
signal_bp = scipy.signal.filtfilt(b, a, data, axis=0)  # can also use 'lfilter'

In [None]:
# Analytic signal of the band-passed traces: its modulus is the amplitude
# envelope and its argument the instantaneous phase.
from scipy.signal import hilbert

traces_bp = signal_bp[islice, 0:num_traces]
analytic_signal = hilbert(traces_bp, axis=0)
analytic_mag, analytic_phase = np.abs(analytic_signal), np.angle(analytic_signal)
# np.angle already returns wrapped phases in [-pi, pi]; use
# np.unwrap(analytic_phase, axis=0) for a continuous (unwrapped) phase.

In [None]:
# Raw vs band-passed Vm, amplitude envelope and instantaneous phase of one GPe cell
fig, axes = plt.subplots(3, 1, figsize=(0.75*page_width, 2*ax_height), sharex=True, sharey=False)
fig.suptitle("Analytic signal for GPE Vm")

trace_id = 1
# Plot window in absolute time. The arrays below start at times[0] (the start
# of the phase-analysis interval), so indices are offsets from that time.
# Assumes times and rec_dt share units (ms).
plot_t_interval = [7500., 8500.]
t_first = times[0].magnitude
iplot = np.s_[int(round((plot_t_interval[0] - t_first) / rec_dt)):
              int(round((plot_t_interval[1] - t_first) / rec_dt))]  # slice object

# Raw and band-pass filtered trace
ax = axes[0]
ax.plot(times[iplot], traces_raw[iplot, trace_id], color='b', label='Vm raw')
ax.plot(times[iplot], traces_bp[iplot, trace_id], color='g', label='Vm bandpass')
ax.set_ylabel('Vm raw & filtered')
ax.grid(True)

# Magnitude of analytic signal = amplitude envelope
ax = axes[1]
ax.plot(times[iplot], analytic_mag[iplot, trace_id], label='magnitude')
ax.set_ylabel('|analytic| [mV]')
ax.grid(True)

# Phase of analytic signal
ax = axes[2]
ax.plot(times[iplot], analytic_phase[iplot, trace_id], label='phase')
ax.grid(True)
ax.set_ylabel('angle(analytic) [rad]')
ax.set_xlabel('time ({})'.format(times.units))

### Plot Phase

See following matplotlib examples:
- https://matplotlib.org/gallery/pie_and_polar_charts/polar_demo.html
- https://matplotlib.org/api/animation_api.html
- http://tiao.io/posts/notebooks/embedding-matplotlib-animations-in-jupyter-as-interactive-javascript-widgets/

In [None]:
# Sample the analytic signal once per cycle of the dominant beta frequency.
# f_beta_peak is hard-coded (see the f_peak printed in the GPe PSD cell).
f_beta_peak = 25.5
phase_trange = times.magnitude
phase_tstart = phase_trange[0]
phase_tstop = phase_trange[-1]

# The starting point within the cycle is arbitrary; it only has to be consistent
beta_period = 1e3 / f_beta_peak
beta_trigger_t = np.arange(phase_tstart, phase_tstop, beta_period)
beta_trigger_i = [int(offset / rec_dt) for offset in beta_trigger_t - phase_tstart]

# Phasors (magnitude, phase) of all traces at every trigger sample
vec_magnitudes = analytic_mag[beta_trigger_i, :]
vec_phases = analytic_phase[beta_trigger_i, :]

In [None]:
# Plot single time point of Beta phasors: one arrow per trace, drawn from the
# origin to (phase, magnitude) at trigger row 1 of vec_phases/vec_magnitudes.
ax = plt.subplot(111, projection='polar')

phases = vec_phases[1,:]
magnitudes = vec_magnitudes[1,:]

ax.grid(True)
ax.set_rmax(np.ceil(max(magnitudes)))
ax.set_rticks(np.linspace(np.ceil(min(magnitudes)),
                          np.ceil(max(magnitudes)),
                          5, endpoint=True))  # Less radial ticks
ax.set_rlabel_position(-90.0)  # Move radial labels away from plotted line

kw = dict(arrowstyle="->", color='g')
for angle, radius in zip(phases, magnitudes):
    # Empty annotation text: only the arrow (xytext -> xy) is drawn
    ax.annotate("", xy=(angle, radius), xytext=(0, 0), arrowprops=kw)

In [None]:
# Animate the beta phasors of all traces over the trigger times (one frame per
# beta cycle). Rendering with to_html5_video presumably needs ffmpeg installed.
from matplotlib import rc, animation
from IPython.display import HTML

fig = plt.figure()
ax = fig.add_subplot(111, projection='polar')

# Artists updated during the animation: one line per trace
arrows = []
for i in range(num_traces):
    ln, = ax.plot([], [], 'g-', animated=True)
    arrows.append(ln)

# Radial limit over all frames. The single-frame `magnitudes` from the previous
# cell can be smaller than later phasors, which would then be clipped.
r_max = np.ceil(vec_magnitudes.max())


def init_animation():
    """Configure the polar axes; FuncAnimation needs the artists that change returned."""
    ax.grid(True)
    ax.set_rmax(r_max)
    ax.set_rticks(np.arange(0, r_max + 1, 2))
    ax.set_rlabel_position(-90.0)  # Move radial labels away from plotted line
    return arrows


def animate(i):
    """Point every trace's line from the origin to its phasor at trigger i."""
    for j, arrow in enumerate(arrows):
        arrow.set_data([vec_phases[i, j]] * 2, [0.0, vec_magnitudes[i, j]])
    return arrows


anim = animation.FuncAnimation(fig, animate, init_func=init_animation,
                               frames=range(vec_magnitudes.shape[0]), blit=True)
# rc('animation', html='jshtml') # set default output to JS animation
# HTML(anim.to_jshtml()) # or convert it explicitly
# rc('animation', html='html5') # set default output to video animation
HTML(anim.to_html5_video())  # or convert explicitly

## CTX - STN

### Coherence

In [None]:
# Population-mean CTX Vm as a single-channel signal, with metadata copied from
# ctx_vm_signal. The sum is divided by the number of channels so the variable
# really is a mean, consistent with the CTX reconstruction cell.
ctx_vm_mean = ctx_vm_signal.duplicate_with_new_array(
    (ctx_vm_signal.sum(axis=1) / ctx_vm_signal.shape[1]).reshape((-1, 1)))

# Coherence between the averaged cortical Vm and the STN LFP
freqs, coherence, phase_lag = elephant.spectral.welch_cohere(ctx_vm_mean, stn_lfp_signal, freq_res=0.5)

In [None]:
# Coherence and phase lag between mean CTX Vm and STN LFP. Only non-negative
# frequencies are kept: the returned axis lists the positive part before the
# negative part. A boolean mask also avoids the float slice index that
# len(freqs)/2 produces under Python 3.
pos = np.asarray(freqs) >= 0

fig, axes = plt.subplots(2, 1, figsize=(0.75*page_width, 2*ax_height))
ax = axes[0]
ax.plot(freqs[pos], coherence[pos], label='CTX Vm - STN LFP')
ax.set_ylabel('Coherence (unitless)')
ax.set_xlim((0, 100))
ax.grid(True)
ax.set_title('Coherence between Cortex mean Vm - STN LFP')
ax.legend(loc='upper right')

# Plot phase
ax = axes[1]
ax.plot(freqs[pos], phase_lag[pos], 'g-', label='phase')
ax.set_ylabel('Phase (rad)')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# NOTE(review): sign convention taken from the original title; confirm against
# the elephant welch_cohere documentation.
ax.set_title('Phase lag > 0 means CTX leads STN')
ax.legend(loc='upper right')

### Phase-Amplitude Coupling

In [None]:
# from tensorpac import Pac
# estimator = Pac(idpac=(1, 0, 0), fpha=(3, 33, 1, 1), famp=(20, 150, 5, 5),
#                 dcomplex='hilbert', filt='butter')

# Filter the data and extract PAC :
# xpac = estimator.filterfit(fs,
#                            xpha=stn_vm_signal.magnitude[int(1e3/dt):int(10e3/dt), 0:10],
#                            xamp=ctx_vm_signal.magnitude[int(1e3/dt):int(10e3/dt), 0:10],
#                            axis=0, traxis=1, njobs=2)
# xpac = estimator.filterfit(fs, 
#                            xpha=stn_lfp_signal.magnitude.reshape((-1,1)),
#                            xamp=ctx_vm_mean.magnitude.reshape((-1,1)),
#                            axis=0, traxis=1, njobs=2)

In [None]:
# Plot PAC
# estimator.comodulogram(xpac.mean(-1), title='PAC: STN LFP (phase) - CTX Vm (amplitude)',
#                        cmap='Spectral_r', plotas='imshow')

## STN - GPE

In [None]:
# Population-mean GPe Vm as a single-channel signal, with metadata copied from
# gpe_vm_signal. (Previously this was built from the CTX data by mistake.)
gpe_vm_mean = gpe_vm_signal.duplicate_with_new_array(
    (gpe_vm_signal.sum(axis=1) / gpe_vm_signal.shape[1]).reshape((-1, 1)))

# Coherence between the STN LFP and the averaged GPe Vm
freqs, coherence, phase_lag = elephant.spectral.welch_cohere(stn_lfp_signal, gpe_vm_mean, freq_res=0.5)

### Coherence

In [None]:
# Coherence and phase lag between STN LFP and mean GPe Vm. Only non-negative
# frequencies are kept: the returned axis lists the positive part before the
# negative part. A boolean mask also avoids a float slice index under Python 3.
pos = np.asarray(freqs) >= 0

fig, axes = plt.subplots(2, 1, figsize=(0.75*page_width, 2*ax_height))
ax = axes[0]
ax.plot(freqs[pos], coherence[pos], label='STN LFP - Avg GPe Vm')
ax.set_ylabel('Coherence (unitless)')
ax.set_xlim((0, 100))
ax.grid(True)
ax.set_title('Coherence between STN LFP - Average GPe Vm')
ax.legend(loc='upper right')

# Plot phase
ax = axes[1]
ax.plot(freqs[pos], phase_lag[pos], 'g-', label='phase')
ax.set_ylabel('Phase (rad)')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# NOTE(review): same convention as the CTX-STN plot (first argument of
# welch_cohere leads when the lag is positive); confirm against elephant docs.
ax.set_title('Phase lag > 0 means STN leads GPe')
ax.legend(loc='upper right')

# Save Notebook

In [None]:
# alternative: %notebook -e foo.ipynb
# from IPython.display import Javascript
# script = '''
# require(["base/js/namespace"],function(Jupyter) {
#     Jupyter.notebook.save_checkpoint();
# });
# '''
# Javascript(script)

In [None]:
%%javascript
// Ask the notebook front end to save a checkpoint, presumably so that the
// nbconvert export in the next cell picks up the latest outputs.
// NOTE(review): "base/js/namespace" is the classic Notebook front-end API;
// it is likely unavailable in JupyterLab.
require(["base/js/namespace"],function(Jupyter) {
    Jupyter.notebook.save_checkpoint();
});
// Jupyter.notebook.kernel.execute("notebook_name = " + "\'"+Jupyter.notebook.notebook_name+"\'");

In [None]:
# import os.path
thisfile = 'synchrony_analysis.ipynb'
# outfile = os.path.join(outputs, 'synchrony_analysis.html')
# NOTE: template comes from ToC2 Notebook Extension
!jupyter nbconvert $thisfile --template=toc2 --output-dir=$outputs