# Setup

In [None]:
%matplotlib notebook
import numpy as np
import scipy.signal
import matplotlib
import matplotlib.pyplot as plt

import os
import neo.io
import elephant

# analysis.py module in same folder
from bgcellmodels.common import units
units.set_units_module('quantities')
import quantities as pq

# Jupyter notebook extensions
%load_ext bgcellmodels.extensions.jupyter.skip_cell_extension

## Plotting Options

In [None]:
# Width of the page for calibrating fig_size.
# Approx. 16 for matplotlib backend %inline  8 for %notebook
if matplotlib.get_backend() == 'nbAgg':
    from bgcellmodels.extensions.jupyter import jupyterutil
    jupyterutil.notebook_show_figs_after_exception() # fix bug for notebook backend where figures not shown
    page_width = 10
else:
    page_width = 14
ax_height = 3

# Style of figures (default colors etc.): see https://matplotlib.org/gallery/style_sheets/style_sheets_reference.html
plt.style.use('default')

## Load Data

In [None]:
# All populations:
outputs = "/run/media/luye/Windows7_OS/Users/lkoelman/simdata-win/q9_sweep-channel-densities/2018.08.10_job-781466.sonic-head_DA-depleted-v3_CTX-f0_STN-GPE-gNaP-x2.0"

# Single population:
# outputs = [
# "/home/luye/storage/2018.06.21_job-testmpi6_DA-control_CTX-beta/CTX_2018.06.21_pop-100_dur-50.0_job-testmpi6.mat",
# ]

if isinstance(outputs, str):
    filenames = os.listdir(outputs)
    pop_files = [os.path.join(outputs, f) for f in filenames if f.endswith('.mat')]
    params_file = next((os.path.join(outputs,f) for f in filenames if f.startswith('pop-parameters')), None)
else:
    pop_files = outputs
    params_file = None

pops_segments = {}
read_segment_id = 0

# Read binary files using Neo IO module
for pop_file in pop_files:
    reader = neo.io.get_io(pop_file)
    blocks = reader.read()
    assert len(blocks) == 1, "More than one Neo Block in file."
    pop_label = blocks[0].name

    if len(blocks[0].segments)-1 < read_segment_id:
        raise ValueError("Segment index greater than number of Neo segments"
                         " in file {}".format(pop_file))
    if pop_label in pops_segments:
        raise ValueError("Duplicate population labels in files")
        
    pops_segments[pop_label] = blocks[0].segments[read_segment_id]

In [None]:
# Containers collecting analysis results for comparison in figures.
# Each one is filled by later cells, keyed by signal or population label.
all_psd = dict()              # label -> (freqs, psd)
all_psd_peaks = dict()        # label -> (peak_freqs, peak_psd)
all_psd_sum_subband = dict()  # label -> (band_limits, summed psd per band)
all_fpeak = dict()            # label -> frequency of the highest PSD peak

all_signals = dict()          # neo.AnalogSignal variables
all_mean_rate = dict()        # population label -> mean firing rate
all_morgera = dict()
all_spectrogram = dict()      # label -> (freqs, t, Sxx)

# Data that will be exported. The values are the containers above (same
# objects, not copies), so results stored later show up here as well.
exported_data = dict(
    mean_rate=all_mean_rate,
    PSD=all_psd,
    PSD_peaks=all_psd_peaks,
    PSD_subband_power=all_psd_sum_subband,
    PSD_input_freq=all_fpeak,
    spectrogram=all_spectrogram,
    Morgera_index=all_morgera,
)

In [None]:
# The recordings are saved in Neo format. See:
# http://neo.readthedocs.io/en/latest/
# http://neo.readthedocs.io/en/latest/api_reference.html#neo.core.AnalogSignal
# - Each segment has attributes 'analogsignals' and 'spiketrains'
# - Each quantity (e.g. AnalogSignal.signal) has attributes magnitude, units, dimensionality

# List all recorded signals
for pop_label, segment in pops_segments.items():
    print("\n{} has following signals:".format(pop_label))
    for signal in segment.analogsignals:
        print("\t- '{:10}\t[{}] - description: {}".format(signal.name, signal.units, signal.description))
    print("\t- {} spiketrains".format(len(segment.spiketrains)))

# Network Parameters

In [None]:
if params_file is not None:
    import pickle
    with open(params_file, 'rb') as pf:
        network_params = pickle.load(pf)
        print(network_params.keys())

## Connection Matrices

In [None]:
import analysis

# Plot the connection matrix of every STN/GPE projection.
# `network_params` only exists when a 'pop-parameters' file was found and
# unpickled, so skip this cell otherwise instead of failing with a NameError.
if params_file is None:
    print("No population parameters file found: skipping connection matrices.")
else:
    for pop0, pop1 in [('STN', 'GPE'), ('GPE', 'STN'), ('GPE', 'GPE'), ('STN', 'STN')]:
        conn_weights = network_params[pop0][pop1]['conn_matrix']
        analysis.plot_connectivity_matrix(conn_weights, pop0=pop0, pop1=pop1, pop_size=100, seaborn=False)

# Spike Trains

<span style='color:red;font-weight:bold'>WARNING</span>: In rastergram plots, note the number of spike trains plotted (see y-axis). If it is too high, the marker bars overlap (the marker height is larger than the row height allocated to one spike train). This makes the plot misleading: overlapping spike trains look like an artificially elevated firing rate.

In [None]:
line_colors = 'crkgbm'

# Colors are assigned in sorted label order and wrap around when there are
# more populations than color codes.
pop_color_map = dict(
    (label, line_colors[idx % len(line_colors)])
    for idx, label in enumerate(sorted(pops_segments))
)

def get_pop_color(pop_label):
    """Return the single-letter matplotlib color code assigned to `pop_label`."""
    return pop_color_map[pop_label]

In [None]:
# %%skip False

from bgcellmodels.common import signal as spikeprocs

def plot_avg_spikerate(pop_label, cell_ids=None, t_range=None, bin_width=20.0, adaptive=False, fig=None):
    """
    Plot the binned mean firing rate of one population over time.

    @param    pop_label : str
              Key into `pops_segments` selecting the population.

    @param    cell_ids : iterable(int) or None
              Indices into the segment's spiketrains. None selects all of them.

    @param    t_range : tuple(float, float) or None
              Start and stop time passed to the rate estimator. None means
              from 0 to the stop time of the first spiketrain.

    @param    bin_width : float
              Bin width passed to the rate estimator.

    @param    adaptive : bool
              If True use `nrn_avg_rate_adaptive` (with a minimum spike count
              of 10), otherwise `nrn_avg_rate_simple`.

    @param    fig : matplotlib figure or None
              If given, a new subplot row is added to it. Otherwise a new
              figure is created.

    @return   tuple(fig, ax)
              Figure and the axis that the rate was drawn on.
    """
    segment = pops_segments[pop_label]
    if cell_ids is None:
        cell_ids = range(len(segment.spiketrains))
    if t_range is None:
        t_range = (0.0, np.round(segment.spiketrains[0].t_stop.magnitude, 3))

    selected_trains = [segment.spiketrains[i] for i in cell_ids]

    if adaptive:
        # Adaptive method increases bin adaptively until 'min_spikes' included
        min_spikes = 10
        rate_vec = spikeprocs.nrn_avg_rate_adaptive(selected_trains,
                                                    t_range[0], t_range[1],
                                                    binwidth=bin_width,
                                                    minsum=min_spikes)
    else:
        rate_vec = spikeprocs.nrn_avg_rate_simple(selected_trains,
                                                  t_range[0], t_range[1],
                                                  binwidth=bin_width)
    mean_rates = rate_vec.as_numpy()

    if fig is not None:
        # Append a new row below the axes already present in the figure
        row = len(fig.axes) + 1
        ax = fig.add_subplot(row, 1, row)
    else:
        fig, ax = plt.subplots(1, 1, figsize=(page_width, ax_height))

    # NOTE(review): samples are placed bin_width/2 apart, which only matches the
    # data if the estimator returns two values per bin -- confirm against
    # bgcellmodels.common.signal.
    t_axis = np.arange(mean_rates.size) * bin_width / 2.0 + t_range[0]
    ax.plot(t_axis, mean_rates, color=get_pop_color(pop_label))

    fig.suptitle('{} mean population firing rate (running average)'.format(pop_label))
    return fig, ax

for pop in 'STN', 'GPE', 'CTX', 'STR':
    plot_avg_spikerate(pop, t_range=(4e3, 8e3))
#fig.tight_layout()

In [None]:
num_pops = len(pops_segments)
i_pop = 0

def plot_spiketrain(pop_label, cell_ids, t_range, sharex=None, sharey=None):
    """
    Draw a rastergram of one population in its own figure.

    All spiketrains of the population are drawn. `cell_ids` and `t_range` only
    set the initial axis limits, so the remaining data can be reached by
    panning and zooming. Each call increments the notebook-level counter `i_pop`.

    @param    pop_label : str
              Key into `pops_segments` selecting the population.

    @param    cell_ids : sequence(int)
              Cell indices in population that will be visible (y-axis limits).

    @param    t_range : tuple(float, float)
              Initial x-axis limits.

    @param    sharex, sharey : matplotlib axis or None
              Axis to share the x/y axis with, passed on to plt.subplot.

    @return   tuple(fig, ax)
    """
    global i_pop
    trains = pops_segments[pop_label].spiketrains
    sim_dur = trains[0].t_stop.magnitude

    # One figure per population (rastergrams are not combined in one figure)
    fig_spikes = plt.figure(figsize=(page_width, ax_height))
    ax = plt.subplot(1, 1, 1, sharex=sharex, sharey=sharey)
    fig_spikes.suptitle('{} spiketrains'.format(pop_label))

    pop_size = len(trains)

    for row in range(pop_size):
        train = trains[row]
        # Prefer the recorded source id as row position, else the list index
        y_pos = train.annotations.get('source_id', row)
        ax.plot(train, np.ones_like(train) * y_pos,
                marker='|', linestyle='',
                snap=True, color=get_pop_color(pop_label))

    # Major ticks every 1000 time units on x and every 5 cells on y
    ax.set_xticks(np.arange(0, sim_dur+1000, 1000), minor=False)
    ax.set_yticks(np.arange(0, pop_size+5, 5), minor=False)
    ax.grid(True, axis='x', which='major')

    # Restrict the initial view to the requested time window and cells
    ax.set_xlim(t_range)
    ax.set_ylim((min(cell_ids)-0.5, max(cell_ids)+0.5))
    ax.set_ylabel('{} cell #'.format(pop_label))

    i_pop += 1
    return fig_spikes, ax
    
# Choose populations and cells indices to plot
shared_axis = None
for pop_label, segment in pops_segments.items():
    cell_indices = range(20)
    t_interval = (7e3, 9e3)
    fig, shared_axis = plot_spiketrain(pop_label, cell_indices, t_interval,
                                       sharex=shared_axis, sharey=shared_axis)

# plot_spiketrain('STN', range(20), (0.0, 5e3))

## Spike Statistics

In [None]:
import collections
pop_firing_rates = collections.OrderedDict()
for pop_label, segment in pops_segments.items():
    # Can use elephant.spike_train_generation.peak_detection() if only raw voltage signals
    pop_rate = 0.0
    for st in segment.spiketrains:
        pop_rate += elephant.statistics.mean_firing_rate(st).rescale('Hz').magnitude
    pop_rate = pop_rate / len(segment.spiketrains)
    pop_firing_rates[pop_label] = pop_rate
    
    all_mean_rate[pop_label] = pop_rate
    print("Mean firing rate for {} is {}".format(pop_label, pop_rate))

In [None]:
fig, ax = plt.subplots()

# Materialize the dict views as lists: matplotlib expects real sequences for
# bar heights and tick labels, which dict views are not under Python 3.
rate_labels = list(pop_firing_rates.keys())
rate_values = list(pop_firing_rates.values())

# One bar per population that a rate was computed for
index = np.arange(len(rate_values))
bar_width = 0.35
opacity = 0.4

ax.bar(index, rate_values, bar_width, alpha=opacity, color='g')

ax.set_xlabel('Population')
ax.set_ylabel('Mean firing rate (Hz)')
ax.set_title('Mean firing rate for each population')
ax.set_xticks(index)
ax.set_xticklabels(rate_labels)
# Major ticks every 5 Hz and minor ticks every 1 Hz, up to just above the maximum
y_tick_stop = int(max(rate_values)+2)
ax.set_yticks(np.arange(0, y_tick_stop, 5), minor=False)
ax.set_yticks(np.arange(0, y_tick_stop, 1.0), minor=True)
ax.grid(True, axis='y', which='major')
# ax.legend()

# Raw Signals

## STN Vm

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
signal = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))
stn_vm_signal = signal

In [None]:
def plot_vm_signals(signal, cell_indices, interval, interval_only=True):
    """
    Plot the membrane voltage of selected cells, one axis per cell.

    The figure title uses the notebook-level variable `pop_label`, so it must
    name the population that `signal` belongs to when this is called.

    @param    signal : neo.AnalogSignal
              Recorded traces, indexed as signal[sample, cell].

    @param    cell_indices : sequence(int)
              Column indices of the cells to plot. Must support len().

    @param    interval : sequence(float, float) or None
              Start and stop time of the displayed window, in the units of
              signal.sampling_period. None selects the entire recording.

    @param    interval_only : bool
              If True only the samples inside `interval` are drawn. If False
              the entire trace is drawn and the x-axis is merely limited to
              `interval`, so the rest can be reached by panning.
    """
    rec_dt = signal.sampling_period.magnitude
    if interval is None:
        # The stop index is exclusive: use the full length to keep the last sample
        irange = [0, signal.shape[0]]
    else:
        # NOTE(review): indices are computed as t/dt, i.e. assuming the recording
        # starts at t=0 -- confirm for signals with a nonzero t_start.
        irange = [int(t/rec_dt) for t in interval]
    times = signal.times[irange[0]:irange[1]]

    # squeeze=False always returns a 2D array of axes, also for a single cell
    fig, axes = plt.subplots(len(cell_indices), 1,
                             figsize=(0.75*page_width,2*ax_height),
                             sharex=True, sharey=True, squeeze=False)
    fig.suptitle("{} membrane voltage".format(pop_label))

    # Plot each Vm on separate axis
    for i_ax, i_cell in enumerate(cell_indices):
        ax = axes[i_ax, 0]
        if 'source_ids' in signal.annotations:
            label = "id {}".format(signal.annotations['source_ids'][i_cell])
        else:
            label = "cell {}".format(i_cell)

        if interval_only:
            ax.plot(times, signal[irange[0]:irange[1], i_cell], label=label)
        else:
            ax.plot(signal.times, signal[:, i_cell], label=label)

        ax.grid(True)
        ax.set_ylim((-90, 25))
        ax.set_xlim((times[0].magnitude, times[-1].magnitude))
        # ax.legend()

        # Axes share x, so only the bottom one gets an x-label
        if i_ax == len(cell_indices)-1:
            ax.set_xlabel('time ({})'.format(times.units))

    # One common y-label for all axes, placed left of the subplots
    fig.text(0.06, 0.5, "voltage ({})".format(signal.units), va='center', rotation='vertical')
    fig.subplots_adjust(bottom=0.15) # prevent clipping of xlabel

In [None]:
# Choose plot interval and cell indices
max_num_plot = 10
num_signals = min(signal.shape[1], max_num_plot)
interval = [13200, 13600] # [12.75e3, 14e3]
cell_indices = range(1) # range(num_signals)
plot_vm_signals(stn_vm_signal, cell_indices, interval, interval_only=False)

## STN LFP

In [None]:
# Load each individual cell's LFP contribution
pop_label = 'STN'
segment = pops_segments[pop_label]
lfp_sigs = next((sig for sig in segment.analogsignals if sig.name == 'lfp'), None)

WITHOUT_LFP = lfp_sigs is None
WITH_LFP = not WITHOUT_LFP

if not WITH_LFP:
    lfp_signal = None
else:
    lfp_summed = lfp_sigs.sum(axis=1)

    # Turn it into AnalogSignal object
    lfp_signal = stn_lfp_signal = neo.AnalogSignal(
                                    lfp_summed,
                                    units=lfp_sigs.units, 
                                    sampling_rate=stn_vm_signal.sampling_rate,
                                    t_start=stn_vm_signal.times[0])

In [None]:
if WITH_LFP:
    interval = [12.75e3, 14e3] # None
    rec_dt = lfp_sigs.sampling_period.magnitude
    irange = [0, lfp_sigs.shape[0]] if interval is None else [int(t/rec_dt) for t in interval]
    lfp_times = lfp_signal.times[irange[0]:irange[1]]

    fig, ax = plt.subplots(figsize=(page_width, ax_height))
    ax.plot(lfp_times, lfp_signal[irange[0]:irange[1]],
            label='{} LFP'.format(pop_label))
    # ax.plot(lfp_sigs.times, lfp_sigs[:, 5], label='{} LFP'.format(pop_label))
    # ax.plot(lfp_sigs.times, lfp_sigs[:, 8], 'r.', ms=1, mew=1, label='{} LFP'.format(pop_label))

    # ax.set_ylim((-0.5, 4.0))
    ax.set_xlim((lfp_times[0].magnitude, lfp_times[-1].magnitude))
    ax.set_ylabel('LFP magnitude ({})'.format(lfp_sigs.units))
    ax.set_xlabel('time ({})'.format(lfp_sigs.times.units))
    ax.set_title('LFP for {} population'.format(pop_label))
    ax.grid(True)
    ax.legend()

## GPe Vm

In [None]:
pop_label = 'GPE'
segment = pops_segments[pop_label]
gpe_vm_signal = signal = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))

In [None]:
# Choose plot interval and cell indices
max_num_plot = 10
num_signals = min(signal.shape[1], max_num_plot)
interval = [12.75e3, 14e3] # [2000.0, 6000.0]
cell_indices = range(5) # range(num_signals)
plot_vm_signals(gpe_vm_signal, cell_indices, interval)

# Power Spectra

In [None]:
psd_subbands = [(0,5), (6,12), (13,19), (20,30), (20,25), (25,30)]

In [None]:
# Total power in sub-bands
def calc_subband_power(freqs, psd, sig_label):
    """
    Sum the PSD inside each frequency band listed in `psd_subbands`.

    The sums are printed and stored, together with the band limits, in
    `all_psd_sum_subband[sig_label]`.

    @param    freqs : quantity array
              Frequency of each PSD bin.

    @param    psd : quantity array
              Power per bin. It is indexed by bin number, so it may be shorter
              than `freqs` as long as it covers all bands.

    @param    sig_label : str
              Key under which the result is stored.
    """
    band_sums = []
    for band in psd_subbands:
        f_low, f_high = band[0], band[1]
        # Band limits are inclusive on both sides
        (bin_idx,) = np.nonzero((freqs >= f_low) & (freqs <= f_high))
        band_power = np.sum(psd[bin_idx])
        band_sums.append(band_power)
        print("Power in sub-band {} is {} {}".format(
            band, band_power.magnitude, band_power.units.dimensionality))
    all_psd_sum_subband[sig_label] = (psd_subbands, band_sums)

# Find frequency of highest peak
def calc_peak_frequencies(freqs, psd, sig_label, percentile=80.0):
    """
    Calculate PSD peak frequencies.

    Stores the frequency of the highest peak in `all_fpeak[sig_label]` and the
    frequencies and heights of all detected peaks in `all_psd_peaks[sig_label]`.

    @param    freqs : quantity array
              Frequency of each PSD bin, evenly spaced.

    @param    psd : quantity array
              Power per bin. May be shorter than `freqs`.

    @param    sig_label : str
              Key under which the results are stored.

    @param    percentile : float
              Peaks below this percentile of the PSD values are discarded.
    """
    # Highest peak: argmax returns the first occurrence of the maximum
    i_peak = int(np.argmax(psd.magnitude))
    f_peak = freqs[i_peak]
    all_fpeak[sig_label] = f_peak.magnitude
    print("PSD peak power occurs at f = {}".format(f_peak))

    # Minimum spacing between peaks: 4 frequency units, converted to bins.
    # The bin width is taken from `freqs` itself rather than from the notebook
    # variable `freq_res`, which other cells reassign.
    if freqs.magnitude.size > 1:
        bin_width = float(freqs.magnitude[1] - freqs.magnitude[0])
        min_distance = max(1, int(round(4.0 / bin_width)))  # find_peaks requires >= 1
    else:
        min_distance = 1

    # Detect all peaks
    min_height = np.percentile(psd.magnitude, percentile)
    idx_peak, props_peak = scipy.signal.find_peaks(psd.magnitude, height=min_height,
                                                   distance=min_distance)
    p_peak = props_peak['peak_heights']
    f_peaks = freqs.magnitude[idx_peak]
    all_psd_peaks[sig_label] = (f_peaks, p_peak)    # (peak_freqs, peak_psd)

## STN LFP

In [None]:
%%skip $WITHOUT_LFP

pop_label = 'STN'
sig_label = pop_label + '_LFP'

# Computes PSD of all 100 LFP signals at the same time
freq_res = 0.5
Ts = lfp_signal.sampling_period.magnitude
freqs, psd = elephant.spectral.welch_psd(lfp_signal[int(5000/Ts):,:], freq_res=freq_res)
psd = psd.ravel() # we only have one axis so make 1-dimensional
psd_rel = psd[0:int(250/freq_res)] # relevant region of psd

all_psd[sig_label] = (freqs, psd)

In [None]:
%%skip $WITHOUT_LFP
calc_subband_power(freqs, psd, sig_label)
calc_peak_frequencies(freqs, psd_rel, sig_label, percentile=80.0)

In [None]:
%%skip $WITHOUT_LFP

# Plot the PSD
fig, ax = plt.subplots(figsize=(0.5*page_width, ax_height))
ax.plot(freqs, psd, label='{} PSD'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Welch PSD for LFP of STN cells')
ax.legend()

fig, ax = plt.subplots(figsize=(0.5*page_width, ax_height))
ax.plot(freqs, psd, label='{} PSD'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.set_ylim((0, 0.4))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Welch PSD (shared scale)')
ax.legend()

<b> Spectrogram </b>

In [None]:
%%skip $WITHOUT_LFP

# Compute spectrogram using STFT
dt = 0.05 # ms, NOTE(review): hard-coded, assumed to match the recording time step
fs = 1/dt*1e3 # sampling rate in Hz
freq_res = 1.0
nperseg = int(fs/freq_res) # determines frequency resolution
t_res = 20.0 # ms
noverlap = nperseg - int(t_res/dt)
# 'hann' is SciPy's canonical name for this window; the legacy 'hanning' alias
# is not guaranteed to be available.
freqs, t, Sxx = scipy.signal.spectrogram(stn_lfp_signal.ravel(), 1/dt, window='hann',
                                         nperseg=nperseg, noverlap=noverlap, scaling='density')
freqs = freqs * 1000 # sampling rate was passed in 1/ms, so convert kHz -> Hz

# Save spectrogram up to 50 Hz.
# Sxx has shape (n_freqs, n_times): frequencies are the rows, so slice axis 0.
df = freqs[1]-freqs[0]
n_freq_saved = int(50/df)
all_spectrogram[sig_label] = (freqs[0:n_freq_saved], t, Sxx[0:n_freq_saved, :])

In [None]:
%%skip $WITHOUT_LFP

# Spectrogram 1
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
plt.pcolormesh(t, freqs, Sxx)
# plt.imshow(Sxx, cmap='jet', aspect='auto', vmax=abs(Sxx).max(), vmin=Sxx.min())
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
# plt.clim(0, 20)
plt.clim(0, abs(Sxx[:, int(5000.0/t_res):]).max()) # discard first second
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('Spectrogram - evolution of STN LFP over time ($nV^2/Hz$)')

# Spectrogram 2 (common axis)
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
plt.pcolormesh(t, freqs, Sxx)
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
# plt.clim(0, 20)
plt.clim(0, 200)
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('Spectrogram - evolution of STN LFP over time ($nV^2/Hz$)')

<b> Subband Power Evolution </b>

In [None]:
%%skip $WITHOUT_LFP

dur = stn_vm_signal.t_stop.magnitude

# Total power in sub-bands
Psum_subbands = {}
edges_subbands = {}
all_bands = [(13,19), (20,30), (20,25), (25,30)]
main_bands = [(13,19), (20,30)]
for subband in all_bands:
    bin_indices, = np.where((freqs >= subband[0]) & (freqs <= subband[1]))
    edges_subbands[subband] = (bin_indices[0], bin_indices[-1])
    subband_spectrogram = Sxx[bin_indices, :]
    
    # Summed power in subband over time
    Psum = np.sum(subband_spectrogram, axis=0) # / len(bin_indices)
    Psum_subbands[subband] = Psum
    
# ### Low Beta
# bin_indices, = np.where((freqs >= 13) & (freqs <= 21))
# beta_spectrogram = Sxx[bin_indices, :]
# betapower = np.sum(beta_spectrogram, axis=0) / len(bin_indices)

# # Find time to (1-exp(-1)) * max power
# p_saturation = (1.0-np.exp(-1.0)) * betapower.max()
# i_saturated, = np.where(betapower >= p_saturation)
# t_saturated = t[i_saturated[0]]

# ### High Beta
# bins_beta_high, = np.where((freqs >= 22) & (freqs <= 30))
# beta_high_spectrogram = Sxx[bins_beta_high, :]
# beta_high_power = np.sum(beta_high_spectrogram, axis=0) / len(bins_beta_high)

# # Find time to (1-exp(-1)) * max power
# p_sat_high = (1.0-np.exp(-1.0)) * beta_high_power.max()
# i_sat_high, = np.where(beta_high_power >= p_sat_high)
# t_sat_high = t[i_sat_high[0]]

In [None]:
%%skip $WITHOUT_LFP

# Plot power evolution
fig, ax = plt.subplots(figsize=(0.75*page_width, 1.5*ax_height))

for subband, Psum in Psum_subbands.items():
    linestyle = '-' if subband in main_bands else '--'
    ax.plot(t, Psum, linestyle, label='sum(P) ({} - {} Hz)'.format(
            freqs[edges_subbands[subband][0]],
            freqs[edges_subbands[subband][1]]))

# Figure decoration
# ax.set_ylim((0, Sxx.max()))
ax.set_xlim((0, dur))
ax.set_ylabel('Power ($nV^2/Hz$)')
ax.set_xlabel('time (ms)')
ax.grid(True)
ax.legend()
ax.set_title('Summed Power in sub-bands')

## STN Vm

In [None]:
pop_label = 'STN'
sig_label = 'STN_Vm'
segment = pops_segments[pop_label]
vm_sig = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))

# Computes PSD of all 100 Vm signals at the same time
Ts = vm_sig.sampling_period.magnitude
freq_res = 0.5
freqs, psd = elephant.spectral.welch_psd(vm_sig[int(5000/Ts):,:], freq_res=freq_res)
psd_avg = psd.sum(axis=0) / psd.shape[0]
psd_units = psd.units
psd_rel = psd_avg[0:int(250/freq_res)] # relevant region of psd

# Save PSD
all_psd[sig_label] = (freqs, psd_avg)

In [None]:
# Total power in sub-bands
calc_subband_power(freqs, psd_rel, sig_label)

# Find frequency of highest peak
calc_peak_frequencies(freqs, psd_rel, sig_label, percentile=80.0)

In [None]:
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, color='b', label='{} Vm'.format(pop_label))

ax.set_ylabel('Power ({})'.format(psd_avg.units), color='b')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.grid(True)
ax.set_title('Welch PSD for average STN membrane voltages')
# ax.set_yscale('log')
ax.legend(loc='upper right')

if WITH_LFP:
    # Overlay the LFP spectrum on a second y-axis
    lfp_freqs, lfp_psd = all_psd['STN_LFP']
    ax2 = ax.twinx()
    # NOTE(review): the legend says 'x 10' but no scaling is applied -- confirm
    ax2.plot(lfp_freqs, lfp_psd, color='r', label='STN LFP x 10')
    # Label with the units of the PSD (element 1), not of the frequency vector
    ax2.set_ylabel('Power ({})'.format(lfp_psd.units), color='r')
    ax2.legend(loc='center right')

# PSD on shared scale for comparison with other simulations
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, color='b', label='{} Vm'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd_avg.units), color='b')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.set_ylim((0, 60))
ax.grid(True)
ax.set_title('PSD of STN Vm (shared scale)')

In [None]:
# Compute mean Vm signal (average over cells, kept as a single-column signal)
stn_vm_mean = vm_sig.duplicate_with_new_array(vm_sig.sum(axis=1).reshape((-1,1)) / vm_sig.shape[1])

# Compute spectrogram using STFT
dt = 0.05 # ms, NOTE(review): hard-coded, assumed to match the recording time step
fs = 1/dt*1e3 # sampling rate in Hz
freq_res = 1.0
nperseg = int(fs/freq_res) # determines frequency resolution
t_res = 20.0 # ms
noverlap = nperseg - int(t_res/dt)
# 'hann' is SciPy's canonical name for this window; the legacy 'hanning' alias
# is not guaranteed to be available.
freqs, t, Sxx = scipy.signal.spectrogram(stn_vm_mean.ravel(), 1/dt, window='hann',
                                         nperseg=nperseg, noverlap=noverlap, scaling='density')
freqs = freqs * 1000 # sampling rate was passed in 1/ms, so convert kHz -> Hz

# Save spectrogram up to 50 Hz.
# Sxx has shape (n_freqs, n_times): frequencies are the rows, so slice axis 0.
df = freqs[1]-freqs[0]
n_freq_saved = int(50/df)
all_spectrogram[sig_label] = (freqs[0:n_freq_saved], t, Sxx[0:n_freq_saved, :])

In [None]:
# Spectrogram 1
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
plt.pcolormesh(t, freqs, Sxx)
# plt.imshow(Sxx, cmap='jet', aspect='auto', vmax=abs(Sxx).max(), vmin=Sxx.min())
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
# plt.clim(0, 20)
plt.clim(0, abs(Sxx[:, int(5000.0/t_res):]).max()) # discard first second
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('Spectrogram of mean STN Vm ({})'.format(psd_units.dimensionality))

# Spectrogram 2 (common axis)
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
plt.pcolormesh(t, freqs, Sxx)
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
# plt.clim(0, 20)
plt.clim(0, 50e3)
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('Spectrogram of mean STN Vm ({})'.format(psd_units.dimensionality))

## CTX spikes

In [None]:
pop_label = 'CTX'
sig_label = 'CTX_Vm'
segment = pops_segments[pop_label]
spiketrains = segment.spiketrains

<b> Convolve spike times with AP </b>

In [None]:
# Load Pyramidal cell action potential from saved recording
pyramidal_trace = np.loadtxt('pyramidal.dat')
vm = pyramidal_trace[:,1]
tvec = pyramidal_trace[:,0]
dt = np.round(tvec[1] - tvec[0], 4)
ap_t_interval = [4.9, 20.0]
ap_i_interval = [int(t/dt) for t in ap_t_interval]
subsample = 2
ap_range = range(ap_i_interval[0], ap_i_interval[1]+1, subsample)
ap_kernel = vm[ap_range]
ap_baseline = -65.0
ap_kernel -= ap_baseline # center on 0 for convolution

plt.figure(figsize=(6,2))
plt.plot(tvec[ap_range], ap_kernel)
plt.suptitle('AP kernel for cortical neurons')
plt.grid(True)

In [None]:
# Construct AnalogSignal of N channels from N spiketrains
dur = np.round(spiketrains[0].t_stop.magnitude, 4)
dt = 0.05
# A single sample count is used for both the matrix and the pulse train.
# np.arange(0, dur + dt, dt) can come out one sample longer or shorter because
# of floating point rounding, which would break the column assignment below.
num_samples = int(dur/dt) + 1
signal_matrix = np.empty((num_samples, len(spiketrains)))
time = np.arange(num_samples) * dt

for i, st in enumerate(spiketrains):

    spiketimes = st.times
    # Unit pulse at the sample index of every spike
    spike_pulses = np.zeros(num_samples)
    spike_pulses[[int(t/dt) for t in spiketimes]] = 1.0

    # Convolve pulses at spike times with AP kernel.
    # The kernel was shifted to a zero baseline, so add the baseline back.
    spike_signal = np.convolve(spike_pulses, ap_kernel, mode='same') + ap_baseline
    signal_matrix[:, i] = spike_signal

ctx_vm_signal = neo.AnalogSignal(signal_matrix, units='mV', sampling_rate=1/dt/pq.ms,
                                t_start=spiketrains[0].t_start, t_stop=spiketrains[0].t_stop)
# Mean over cells
ctx_vm_mean = ctx_vm_signal.sum(axis=1) / ctx_vm_signal.shape[1]

<b> Average PSD </b>

In [None]:
# Computes PSD of all 100 Vm signals at the same time
freq_res = 0.5
freqs, psd = elephant.spectral.welch_psd(ctx_vm_signal, freq_res=freq_res)
psd_avg = psd.sum(axis=0) / psd.shape[0]
psd_rel = psd_avg[0:int(250/freq_res)] # relevant region of psd

all_psd[pop_label + '_Vm'] = (freqs, psd_avg)

In [None]:
# Total power in sub-bands.
# Use the channel-averaged spectrum: `psd` still holds one spectrum per cell
# (2D), so indexing it with frequency bin numbers would select cells instead.
calc_subband_power(freqs, psd_rel, sig_label)

# Find frequency of highest peak
calc_peak_frequencies(freqs, psd_rel, sig_label, percentile=80.0)

In [None]:
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, label='{} Vm'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Mean Welch PSD for CTX membrane voltages')
ax.legend()

## GPe Vm

In [None]:
pop_label = 'GPE'
sig_label = 'GPE_Vm'
segment = pops_segments[pop_label]
vm_sig = gpe_vm_signal = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))

# Computes PSD of all 100 Vm signals at the same time
freq_res = 0.5
Ts = vm_sig.sampling_period.magnitude
freqs, psd = elephant.spectral.welch_psd(vm_sig[int(5000/Ts):,:], freq_res=freq_res)
psd_avg = psd.sum(axis=0) / psd.shape[0]
psd_rel = psd_avg[0:int(250/freq_res)] # relevant region of psd

all_psd[sig_label] = (freqs, psd_avg)

In [None]:
# Total power in sub-bands
calc_subband_power(freqs, psd_rel, sig_label)

# Find frequency of highest peak
calc_peak_frequencies(freqs, psd_rel, sig_label, percentile=80.0)

In [None]:
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(all_psd['STN_Vm'][0], all_psd['STN_Vm'][1], 'r', label='STN avg(Vm)')
ax.plot(freqs, psd_avg, label='{} avg(Vm)'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Welch PSD for average GPE membrane voltages')
ax.legend()

# PSD on shared scale for comparison with other simulations
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, label='{} avg(Vm)'.format(pop_label))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.set_ylim((0, 60))
ax.grid(True)
ax.set_title('PSD for GPE Vm (shared scale)')
ax.legend()

<b> Spectrogram </b>

In [None]:
# Mean GPe membrane voltage (average over cells)
gpe_vm_mean = gpe_vm_signal.sum(axis=1) / gpe_vm_signal.shape[1]

# Calculate spectrogram using STFT
dt = 0.05 # ms, NOTE(review): hard-coded, assumed to match the recording time step
fs = 1/dt*1e3 # sampling rate in Hz
freq_res = 1.0
nperseg = int(fs/freq_res) # determines frequency resolution
t_res = 20.0 # ms
noverlap = nperseg - int(t_res/dt)
# 'hann' is SciPy's canonical name for this window; the legacy 'hanning' alias
# is not guaranteed to be available.
freqs, t, Sxx = scipy.signal.spectrogram(gpe_vm_mean.ravel(), 1/dt, window='hann',
                                         nperseg=nperseg, noverlap=noverlap, scaling='density')
freqs = freqs * 1000 # sampling rate was passed in 1/ms, so convert kHz -> Hz

# Save spectrogram up to 50 Hz.
# Sxx has shape (n_freqs, n_times): frequencies are the rows, so slice axis 0.
df = freqs[1]-freqs[0]
n_freq_saved = int(50/df)
all_spectrogram[sig_label] = (freqs[0:n_freq_saved], t, Sxx[0:n_freq_saved, :])

In [None]:
# Spectrogram 1
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
plt.pcolormesh(t, freqs, Sxx)
# plt.imshow(Sxx, cmap='jet', aspect='auto', vmax=abs(Sxx).max(), vmin=Sxx.min())
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
plt.clim(0, abs(Sxx[:, int(5000.0/t_res):]).max()) # discard first second
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('GPe mean(Vm) Power ($nV^2/Hz$)')

# Spectrogram 2 (common axis)
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
plt.pcolormesh(t, freqs, Sxx)
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
# plt.clim(0, 20)
plt.clim(0, 50e3)
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('GPe mean(Vm) Power, shared axis ($nV^2/Hz$)')

In [None]:
dur = gpe_vm_signal.t_stop.magnitude

# Total power in sub-bands
Psum_subbands = {}
edges_subbands = {}
all_bands = [(13,19), (20,30), (20,25), (25,30)]
main_bands = [(13,19), (20,30)]
for subband in all_bands:
    bin_indices, = np.where((freqs >= subband[0]) & (freqs <= subband[1]))
    edges_subbands[subband] = (bin_indices[0], bin_indices[-1])
    subband_spectrogram = Sxx[bin_indices, :]
    
    # Summed power in subband over time
    Psum = np.sum(subband_spectrogram, axis=0) # / len(bin_indices)
    Psum_subbands[subband] = Psum

In [None]:
# Plot power evolution
fig, ax = plt.subplots(figsize=(0.75*page_width, 1.5*ax_height))

for subband, Psum in Psum_subbands.items():
    linestyle = '-' if subband in main_bands else '--'
    ax.plot(t, Psum, linestyle, label='sum(P) ({} - {} Hz)'.format(
            freqs[edges_subbands[subband][0]],
            freqs[edges_subbands[subband][1]]))

# Figure decoration
# ax.set_ylim((0, Sxx.max()))
ax.set_xlim((0, dur))
ax.set_ylabel('Power ($nV^2/Hz$)')
ax.set_xlabel('time (ms)')
ax.grid(True)
ax.legend()
ax.set_title('Summed Power in sub-bands')

# Phase Relationships

For suitable measures and implementations, see
- bookmarks/neuroscience/signal_processing
- google `measure + site:github.com`
- ask Amir for his mutual information and related measures
- see Beta-related and other neurophysiology articles

For example, we can use following measures

- __Coherence__ : linear relationship between two signals by frequency component
    + see `welch_cohere` in [elephant.spectral](http://elephant.readthedocs.io/en/latest/reference/spectral.html)


- __Phase-Amplitude Coupling__ (PAC)
    + see `hilbert` in [elephant.signal_processing](http://elephant.readthedocs.io/en/latest/reference/signal_processing.html) to do band-pass filter + Hilbert transform
    + see `comodulogram` in [pactools](https://pactools.github.io/auto_examples/plot_comodulogram.html)

## STN Phase

In [None]:
# Calculate analytical signal -> phase
sig_label = 'STN_Vm'
signal = stn_vm_signal

# Take 10 cells within shorter time interval
rec_dt = signal.sampling_period.magnitude
interval= [1e3, 26e3]
irange = [int(t/rec_dt) for t in interval]
islice = np.s_[irange[0]:irange[1]] # slice object

num_traces = 100
pslice = np.s_[0:num_traces]
traces_raw = signal[islice, pslice]

times = signal.times[islice]

### Phase Calculation

In [None]:
# Band-pass filter in frequency band of interest

# Elephant built-in filtering
# import elephant.signal_processing as sigproc
# signal_bp = sigproc.butter(signal, highpass_freq=20, lowpass_freq=30, order=3, filter_function='filtfilt')

# Manual band-pass filtering
Fs = signal.sampling_rate.rescale('Hz').magnitude
Fn = Fs / 2. # Nyquist frequency
hpfreq, lpfreq = 5.0, 15.0
order = 4

assert hpfreq < lpfreq
low, high = hpfreq / Fn, lpfreq / Fn # band edges normalized to the Nyquist frequency
# The (b, a) form is only used for the stability check below;
# the response that is plotted is that of the second-order sections (sos).
b, a = scipy.signal.butter(order, [low, high], btype='bandpass', analog=False)
sos = scipy.signal.butter(order, [low, high], btype='bandpass', analog=False, output='sos')
# sos = scipy.signal.butter(order, high, btype='lowpass', analog=False, output='sos')

# Check filter stability (otherwise -> NaN values)
filter_stable = np.all(np.abs(np.roots(a))<1)
if not filter_stable:
    # raise Exception("Unstable filter!")
    print("Filter in b, a form is unstable!")

# Plot filter response
# np.linspace requires an integer number of points; 2**np.ceil(...) is a float
num_freq_points = int(2**np.ceil(np.log2(Fn)))
w, h = scipy.signal.sosfreqz(sos, np.linspace(0, np.pi, num_freq_points))
angles = np.unwrap(np.angle(h))
fax = w * Fn / (np.pi) # rad/sample -> Hz

fig, axes = plt.subplots(2, 1, sharex=True)
fig.suptitle("Filter response (2*pi = {})".format(Fs))
ax = axes[0]
ax.plot(fax, abs(h), 'b') # 20 * np.log10(abs(h)) for dB
ax.set_ylabel('Amplitude (linear gain)', color='b')

ax = axes[1] # ax2 = ax.twinx()
ax.plot(fax, angles, 'g')
ax.set_ylabel('Angle (radians)', color='g')

# plt.axis('tight')
ax.set_xlim((0, 50))
ax.set_xlabel('Frequency [Hz]')
ax.grid(True)

In [None]:
# Filter signal
data = np.asarray(signal)
# signal_bp = scipy.signal.filtfilt(b, a, data, axis=0) # can also use 'lfilter'
signal_bp = signal.duplicate_with_new_array(scipy.signal.sosfiltfilt(sos, data, axis=0))

In [None]:
# Subsample to preserve up to 2 x highest frequency
fmax = 2 * lpfreq
fs_old = signal.sampling_rate.rescale('Hz').magnitude 
subsample_factor = int(fs_old / (2 * fmax))
subsample_factor = min(20, subsample_factor)

print("Subsampling with factor {}".format(subsample_factor))
signal_bpss = signal_bp[::subsample_factor, :] # adjusts sampling period property
print("New sampling period is {}".format(signal_bpss.sampling_period))

# Adjust indices of intervals
Ts = signal_bpss.sampling_period.magnitude
irange_ss = [int(t/Ts) for t in interval]
islice_ss = np.s_[irange_ss[0]:irange_ss[1]] # slice object
times_ss = signal_bpss.times[islice_ss]

In [None]:
# Compute analytic signal - magnitude and phase
from scipy.signal import hilbert

# Restrict both the subsampled band-pass signal and the raw signal to the
# analysis interval and to the first `num_traces` columns.
traces_bp = signal_bpss[islice_ss, 0:num_traces]
traces_raw = signal[islice, 0:num_traces]

# Hilbert transform along time (axis 0): |.| is the amplitude envelope,
# angle(.) the instantaneous phase of each trace.
analytic_signal = hilbert(traces_bp, axis=0)
analytic_mag = np.abs(analytic_signal)
analytic_phase = np.angle(analytic_signal)
# NOTE: phases are already wrapped
# analytic_phase = np.unwrap(np.angle(analytic_signal), axis=0) # transform angle in interval (0, 2*pi)

In [None]:
# Check result : plot filtered signals
# Three stacked panels for a single trace: raw vs. band-passed Vm, analytic
# magnitude (envelope) and analytic phase.
fig, axes = plt.subplots(3, 1, figsize=(0.75*page_width,2*ax_height), sharex=True, sharey=False)
fig.suptitle("Filtered & Analytic Signal")

trace_id = 1
tplot = (7500., 8500.)
# NOTE(review): these indices are applied to arrays that were already cut to
# `interval` (times, traces_raw, times_ss, traces_bp), so the displayed window
# appears to be `tplot` shifted by the interval start -- confirm this is intended.
iplot = np.s_[int(tplot[0]/rec_dt):int(tplot[1]/rec_dt)] # slice object
iplot_ss = np.s_[int(tplot[0]/Ts):int(tplot[1]/Ts)]

# Band-pass filtered trace
ax = axes[0]
ax.plot(times[iplot], traces_raw[iplot,trace_id], color='b', label='Vm raw')
ax.plot(times_ss[iplot_ss], traces_bp[iplot_ss, trace_id], color='g', label='Vm bandpass')
ax.set_ylabel('Vm raw & filtered')
ax.grid(True)
ax.legend()
# ax.set_ylim((-80, 25))

# Magnitude of analytic signal = amplitude envelope
ax = axes[1]
ax.plot(times_ss[iplot_ss], analytic_mag[iplot_ss, trace_id], label='magnitude')
ax.set_ylabel('|analytic| [mV]')
ax.grid(True)

# Phase of analytic signal
ax = axes[2]
ax.plot(times_ss[iplot_ss], analytic_phase[iplot_ss, trace_id], label='phase')
ax.grid(True)
ax.set_ylabel('angle(analytic) [rad]')
ax.set_xlabel('time ({})'.format(times.units))

### Phase Variability

In [None]:
# Sample polar representation of analytic signal at fixed time points in Beta cycle
# The sampling frequency of the phasors is the beta peak frequency stored for
# the STN Vm signal in `all_fpeak` (filled in elsewhere in the notebook).
f_beta_peak = all_fpeak['STN_Vm']
print("Calculating phase vectors at multiples of {} Hz".format(f_beta_peak))

phase_trange = times.magnitude
phase_tstart, phase_tstop = phase_trange[0], phase_trange[-1]

# Random point in cycle of main Beta frequency is ok, as long as it's consistent
# Trigger times are spaced one beta period apart (1e3/f: period in ms for f in Hz)
# and converted to row indices of the subsampled analytic arrays.
beta_trigger_t = np.arange(phase_tstart, phase_tstop, 1e3/f_beta_peak)
beta_trigger_i = [int((t-phase_tstart)/Ts) for t in beta_trigger_t]

# Phasor vectors for all traces at all trigger times
vec_magnitudes = analytic_mag[beta_trigger_i, :]
vec_phases = analytic_phase[beta_trigger_i, :]

In [None]:
# Instantaneous standard deviation of phases
# Blue line: std across traces at every 20th subsampled time point.
# Red dots: std across traces at the beta-cycle trigger times.
# NOTE(review): analytic_phase comes from np.angle and is wrapped, so this is a
# linear std of wrapped angles -- phases near +pi and -pi count as far apart.
# A circular statistic may be more appropriate; confirm intent.
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
ax.plot(times_ss[::20], analytic_phase[::20, :].std(axis=1))
ax.plot(beta_trigger_t, vec_phases.std(axis=1), 'r.')
ax.set_ylabel('sigma (radians)')
ax.set_xlabel('time ({})'.format(times.units))
ax.set_title('Standard deviation of STN population phases')
ax.grid(True)

### Phase Differences

In [None]:
# Phase difference of every trace relative to one reference trace.
# NOTE(review): this cell was headed "PLV (Phase Locking Value)" but no PLV is
# computed here; only unwrapped phase differences are calculated and plotted.
# Choose reference signal
ref_id = 0
phase_reference = analytic_phase[:, ref_id]
# Unwrap along time and divide by 2*pi so that differences are in cycles.
phase_diffs = np.unwrap(analytic_phase - phase_reference[:,np.newaxis], axis=0) / (2*np.pi)

# Five panels with ten traces each (columns 0-49 of phase_diffs).
fig, axes = plt.subplots(5, 1, figsize=(0.75*page_width, 5*ax_height))
for i, ax in enumerate(axes):
    if i == 0:
        ax.set_title('Phase difference with cell {}'.format(ref_id))
    ax.plot(times_ss, phase_diffs[:,(10*i):10*(i+1)])
    ax.set_ylabel('phase difference (normalized)')
    ax.set_xlabel('time ({})'.format(times_ss.units))
    # grid: minor ticks at integer numbers of cycles
    ax.set_yticks(np.arange(-10,10), minor=True)
    # ax.set_ylim((-5,5))
    ax.grid(which='both', linestyle=':') # for minor and major ticks

In [None]:
def smooth(x, window_len=11, window='flat'):
    """
    Smooth a 1-D signal by convolving it with a normalised window.

    See https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    The signal is extended on both sides with ``window_len - 1`` mirrored
    samples before a 'valid' convolution, so the result is LONGER than the
    input: ``len(y) == len(x) + window_len - 1``. Output sample ``k`` is
    centred on input sample ``k - (window_len - 1) / 2``.

    Parameters
    ----------
    x : numpy.ndarray
        One-dimensional input signal.
    window_len : int
        Length of the smoothing window in samples. For values below 3 the
        input is returned unchanged.
    window : str
        One of 'flat' (moving average), 'hanning', 'hamming', 'bartlett',
        'blackman'.

    Returns
    -------
    numpy.ndarray
        The smoothed signal (or `x` itself if ``window_len < 3``).

    Raises
    ------
    ValueError
        If `x` is not 1-D, is shorter than the window, or `window` is not one
        of the supported names.
    """
    # NOTE: raise X("msg") works on both Python 2 and 3; the old
    # `raise X, "msg"` form is a SyntaxError on Python 3.
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")
    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")
    if window_len < 3:
        return x

    window_names = ('flat', 'hanning', 'hamming', 'bartlett', 'blackman')
    if window not in window_names:
        raise ValueError(
            "Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    # Pad the signal on both sides with mirrored copies of its ends
    s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]
    if window == 'flat': # moving average
        w = np.ones(window_len, 'd')
    else:
        # Look the window function up by name instead of eval()'ing a string
        w = getattr(np, window)(window_len)

    y = np.convolve(w/w.sum(), s, mode='valid')
    return y

In [None]:
# Smooth the phase differences
# Window length in samples of the subsampled signal.
window_ms = 500.0
window_len = int(window_ms / signal_bpss.sampling_period.magnitude)

# smooth() returns len(x) + window_len - 1 samples (mirror padding + 'valid'
# convolution), hence the extra rows in the pre-allocated array.
delta_phi_smooth = np.empty((analytic_phase.shape[0]+window_len-1, analytic_phase.shape[1])) # for convolve
#delta_phi_smooth = np.empty(analytic_phase.shape) # for medfilt
for i in range(analytic_phase.shape[1]):
    delta_phi_smooth[:,i] = smooth(phase_diffs[:,i], window_len, window='hanning')
    #delta_phi_smooth[:,i] = scipy.signal.medfilt(phase_diffs[:,i], window_len+1)

In [None]:
# Time derivative of the smoothed phase differences (central differences).
# np.gradient's positional argument after the array is the SAMPLE SPACING, not
# the accuracy order: the previous call `np.gradient(x, 2, axis=0)` divided by a
# spacing of 2 samples. Use the sampling period Ts as spacing so the result is a
# derivative with respect to time, and request second-order accurate
# differences at the boundaries via `edge_order`.
d_delta_phi = np.gradient(delta_phi_smooth, Ts, axis=0, edge_order=2)
# Average over traces (axis 1) -> one value per time point
avg_ddphi = np.mean(d_delta_phi, axis=1)

In [None]:
# Plot the trace-averaged derivative of the smoothed phase differences.
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
ax.set_title('Average derivative of phase difference with cell {}'.format(ref_id))

# Time axis of the smoothed data. smooth() pads (window_len - 1) mirrored
# samples on each side and convolves in 'valid' mode, so output sample k is
# centred on input sample k - (window_len - 1)/2, and input sample 0 lies at
# times_ss[0]. The previous axis `arange(n) * Ts` started at 0 and ignored both.
# Assumes Ts is expressed in the same time unit as times_ss -- TODO confirm.
t = float(times_ss[0].magnitude) + (np.arange(avg_ddphi.size) - 0.5 * (window_len - 1)) * Ts
ax.plot(t, avg_ddphi)
# ax.plot(t, delta_phi_smooth[:,1:5])
ax.set_ylabel(r'$ \frac{d}{dt} \Delta \phi $')
ax.set_xlabel('time ({})'.format(times_ss.units))
ax.grid(True)

# Plot a few derivatives to check
# fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
# ax.plot(t, d_delta_phi[:,1:5])
# ax.set_ylabel(r'$ \frac{d}{dt} \Delta \phi $')
# ax.set_xlabel('time ({})'.format(times_ss.units))
# ax.grid(True)

### SPIKE Synchronization

In [None]:
import pyspike
# Convert to PySpike data format
# One pyspike.SpikeTrain per STN spike train, built from the raw spike times
# and the (t_start, t_stop) edges of the recording.
trains = []
for st in pops_segments['STN'].spiketrains:
    trains.append(pyspike.SpikeTrain(st.times.magnitude, (st.t_start, st.t_stop)))

In [None]:
# SPIKE-synchronization profile over all converted spike trains
sync_profile = pyspike.spike_sync_profile(*trains)

In [None]:
# Plot spike synchronization profile
x, y = sync_profile.get_plottable_data()
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
ax.plot(x,y)
# Only the 10-20 s stretch is shown (x axis is labelled in ms)
ax.set_xlim((10e3, 20e3))
ax.set_title('Kreuz SPIKE Synchronization for STN cells')
ax.set_ylabel('synchronization index')
ax.set_xlabel('time (ms)')

### Morgera Synchronization Index

In [None]:
# Morgera synchronization index in a sliding window.
# For each window the squared singular values of the band-passed data matrix
# (rows = time samples, columns = traces) are normalised to sum to one; the
# index is M = 1 - H / log(K), with H the Shannon entropy of the normalised
# values and K their number.
M_values = []
delta_t = 200.0   # window step
window = 1000.0   # window length
t0_values = np.arange(0, signal.t_stop.magnitude, delta_t)
t1_values = []
for t0 in t0_values:
    win_t1 = t0 + window
    if win_t1 > signal.t_stop.magnitude:
        break
    # Window-local names: do not overwrite the global `interval` / `irange` /
    # `islice` that define the analysis interval used by earlier cells.
    win_slice = np.s_[int(t0/Ts):int(win_t1/Ts)]

    # Only the singular values are needed, so skip computing U and V
    s = np.linalg.svd(np.asarray(signal_bpss[win_slice, :]), compute_uv=False)
    lambdas = s**2
    sigmas = lambdas / lambdas.sum()
    # 0 * log(0) is taken as 0; evaluating it directly would give NaN
    nonzero = sigmas[sigmas > 0]
    C = - 1./np.log(len(sigmas)) * np.sum(nonzero * np.log(nonzero))

    t1_values.append(win_t1)
    M_values.append(1 - C)

# list(): on Python 3 zip() is a one-shot iterator that cannot be re-read or
# pickled. zip() also truncates t0_values to the windows actually evaluated.
all_morgera[sig_label] = (list(zip(t0_values, t1_values)), M_values)

In [None]:
# Plot the Morgera index against the END time of each window
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
ax.plot(t1_values, M_values)
ax.set_title('Morgera synchronization index [t-{}, t]'.format(window))
ax.set_ylabel('M (0-1)')
ax.set_xlabel('time (ms)')
ax.grid(True)
ax.set_ylim((0,1))
ax.set_xlim((0, signal.t_stop.magnitude))

## GPe Phase

In [None]:
# Calculate analytical signal -> phase
# From here on `signal` / `sig_label` refer to the GPe membrane voltages.
sig_label = "GPE_Vm"
signal = gpe_vm_signal

# Take the first `num_traces` cells within a shorter time interval
rec_dt = signal.sampling_period.magnitude
interval= [1e3, 26e3]
# Interval bounds converted to sample indices (assumes the signal starts at t = 0)
irange = [int(t/rec_dt) for t in interval]
islice = np.s_[irange[0]:irange[1]] # slice object

num_traces = 100
pslice = np.s_[0:num_traces]
traces_raw = signal[islice, pslice]

times = signal.times[islice]

### Phase Calculation

In [None]:
# Band-pass filter in frequency band of interest

# Elephant built-in filtering
# import elephant.signal_processing as sigproc
# signal_bp = sigproc.butter(signal, highpass_freq=20, lowpass_freq=30, order=3, filter_function='filtfilt')

# Manual band-pass filtering
Fs = signal.sampling_rate.rescale('Hz').magnitude
Fn = Fs / 2. # Nyquist frequency
# hpfreq, lpfreq = 5.0, 15.0  (band edges are taken from an earlier cell)
order = 4

assert hpfreq < lpfreq
# Band edges normalised to the Nyquist frequency, as expected by butter()
low, high = hpfreq / Fn, lpfreq / Fn
b, a = scipy.signal.butter(order, [low, high], btype='bandpass', analog=False)
sos = scipy.signal.butter(order, [low, high], btype='bandpass', analog=False, output='sos')

# Check filter stability (otherwise -> NaN values)
# Only the (b, a) form is checked; the filtering itself uses the `sos` form.
filter_stable = np.all(np.abs(np.roots(a))<1)
if not filter_stable:
    # raise Exception("Unstable filter!")
    print("Filter in b, a form is unstable!")

# Plot filter response of the second-order-sections filter.
# The number of frequency points must be an integer: 2**np.ceil(...) is a float,
# which np.linspace rejects as `num` in current NumPy versions.
num_freqs = int(2**np.ceil(np.log2(Fn)))
w, h = scipy.signal.sosfreqz(sos, np.linspace(0, np.pi, num_freqs))
angles = np.unwrap(np.angle(h))
fax = w * Fn / (np.pi) # rad/sample -> Hz

fig, axes = plt.subplots(2, 1, sharex=True)
fig.suptitle("Filter response (2*pi = {})".format(Fs))
ax = axes[0]
ax.plot(fax, abs(h), 'b') # 20 * np.log10(abs(h))
# |h| is plotted on a linear scale, so do not label the axis as dB
ax.set_ylabel('Amplitude (linear gain)', color='b')

ax = axes[1] # ax2 = ax.twinx()
ax.plot(fax, angles, 'g')
ax.set_ylabel('Angle (radians)', color='g')

# plt.axis('tight')
ax.set_xlim((0, 50))
ax.set_xlabel('Frequency [Hz]')
ax.grid(True)

In [None]:
# Filter signal
# Forward-backward (zero-phase) filtering along the time axis (axis 0) with the
# second-order-sections filter `sos` designed in the previous cell.
data = np.asarray(signal)
# signal_bp = scipy.signal.filtfilt(b, a, data, axis=0) # can also use 'lfilter'
# Wrap the filtered array in a duplicate of `signal` (presumably carries over
# sampling rate / units metadata -- confirm against the neo API).
signal_bp = signal.duplicate_with_new_array(scipy.signal.sosfiltfilt(sos, data, axis=0))

In [None]:
# Subsample the band-passed signal.
# The target sampling rate is 2*fmax = 4*lpfreq, i.e. twice the Nyquist rate of
# the upper band edge. The decimation factor is capped at 20.
fmax = 2 * lpfreq
fs_old = signal.sampling_rate.rescale('Hz').magnitude
# Clamp to [1, 20]: int() truncates to 0 when fs_old < 2*fmax, and a slice
# step of 0 raises "slice step cannot be zero".
subsample_factor = max(1, min(20, int(fs_old / (2 * fmax))))

print("Subsampling with factor {}".format(subsample_factor))
signal_bpss = signal_bp[::subsample_factor, :] # adjusts sampling period property
print("New sampling period is {}".format(signal_bpss.sampling_period))

# Adjust indices of intervals to the new sampling period.
# NOTE: int(t/Ts) treats `interval` as offsets from the first sample, i.e. it
# assumes the signal starts at t = 0.
Ts = signal_bpss.sampling_period.magnitude
irange_ss = [int(t/Ts) for t in interval]
islice_ss = np.s_[irange_ss[0]:irange_ss[1]] # slice object
times_ss = signal_bpss.times[islice_ss]

In [None]:
# Compute analytic signal - magnitude and phase
from scipy.signal import hilbert

# Restrict both the subsampled band-pass signal and the raw signal to the
# analysis interval and to the first `num_traces` columns.
traces_bp = signal_bpss[islice_ss, 0:num_traces]
traces_raw = signal[islice, 0:num_traces]

# Hilbert transform along time (axis 0): |.| is the amplitude envelope,
# angle(.) the instantaneous phase of each trace.
analytic_signal = hilbert(traces_bp, axis=0)
analytic_mag = np.abs(analytic_signal)
analytic_phase = np.angle(analytic_signal)
# NOTE: phases are already wrapped
# analytic_phase = np.unwrap(np.angle(analytic_signal), axis=0) # transform angle in interval (0, 2*pi)

In [None]:
# Check result : plot filtered signals
# Three stacked panels for a single trace: raw vs. band-passed Vm, analytic
# magnitude (envelope) and analytic phase.
fig, axes = plt.subplots(3, 1, figsize=(0.75*page_width,2*ax_height), sharex=True, sharey=False)
fig.suptitle("Filtered & Analytic Signal")

trace_id = 1
tplot = (7500., 8500.)
# NOTE(review): these indices are applied to arrays that were already cut to
# `interval` (times, traces_raw, times_ss, traces_bp), so the displayed window
# appears to be `tplot` shifted by the interval start -- confirm this is intended.
iplot = np.s_[int(tplot[0]/rec_dt):int(tplot[1]/rec_dt)] # slice object
iplot_ss = np.s_[int(tplot[0]/Ts):int(tplot[1]/Ts)]

# Band-pass filtered trace
ax = axes[0]
ax.plot(times[iplot], traces_raw[iplot,trace_id], color='b', label='Vm raw')
ax.plot(times_ss[iplot_ss], traces_bp[iplot_ss, trace_id], color='g', label='Vm bandpass')
ax.set_ylabel('Vm raw & filtered')
ax.grid(True)
ax.legend()
# ax.set_ylim((-80, 25))

# Magnitude of analytic signal = amplitude envelope
ax = axes[1]
ax.plot(times_ss[iplot_ss], analytic_mag[iplot_ss, trace_id], label='magnitude')
ax.set_ylabel('|analytic| [mV]')
ax.grid(True)

# Phase of analytic signal
ax = axes[2]
ax.plot(times_ss[iplot_ss], analytic_phase[iplot_ss, trace_id], label='phase')
ax.grid(True)
ax.set_ylabel('angle(analytic) [rad]')
ax.set_xlabel('time ({})'.format(times.units))

### Phase Variability

In [None]:
# Sample polar representation of analytic signal at fixed time points in Beta cycle
# The trigger frequency is hard-coded here rather than read from `all_fpeak`.
f_beta_peak = 17.0 # SETPARAM: all_fpeak['STN_Vm']
print("Calculating phase vectors at multiples of {} Hz".format(f_beta_peak))

phase_trange = times.magnitude
phase_tstart, phase_tstop = phase_trange[0], phase_trange[-1]

# Random point in cycle of main Beta frequency is ok, as long as it's consistent
# Trigger times are spaced one beta period apart (1e3/f: period in ms for f in Hz)
# and converted to row indices of the subsampled analytic arrays.
beta_trigger_t = np.arange(phase_tstart, phase_tstop, 1e3/f_beta_peak)
beta_trigger_i = [int((t-phase_tstart)/Ts) for t in beta_trigger_t]

# Phasor vectors for all traces at all trigger times
vec_magnitudes = analytic_mag[beta_trigger_i, :]
vec_phases = analytic_phase[beta_trigger_i, :]

In [None]:
# Instantaneous standard deviation of phases
# Blue line: std across traces at every 20th subsampled time point.
# Red dots: std across traces at the beta-cycle trigger times.
# NOTE(review): analytic_phase comes from np.angle and is wrapped, so this is a
# linear std of wrapped angles -- a circular statistic may be more appropriate.
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
ax.plot(times_ss[::20], analytic_phase[::20, :].std(axis=1))
ax.plot(beta_trigger_t, vec_phases.std(axis=1), 'r.')
ax.set_ylabel('sigma (radians)')
ax.set_xlabel('time ({})'.format(times.units))
# This cell analyses the GPe signal; the title was copied from the STN section
ax.set_title('Standard deviation of GPe population phases')
ax.grid(True)

### Phase Differences

In [None]:
# Phase difference of every trace relative to one reference trace.
# NOTE(review): this cell was headed "PLV (Phase Locking Value)" but no PLV is
# computed here; only unwrapped phase differences are calculated and plotted.
# Choose reference signal
ref_id = 0
phase_reference = analytic_phase[:, ref_id]
# Unwrap along time and divide by 2*pi so that differences are in cycles.
phase_diffs = np.unwrap(analytic_phase - phase_reference[:,np.newaxis], axis=0) / (2*np.pi)

# Five panels with ten traces each (columns 0-49 of phase_diffs).
fig, axes = plt.subplots(5, 1, figsize=(0.75*page_width, 5*ax_height))
for i, ax in enumerate(axes):
    if i == 0:
        ax.set_title('Phase difference with cell {}'.format(ref_id))
    ax.plot(times_ss, phase_diffs[:,(10*i):10*(i+1)])
    ax.set_ylabel('phase difference (normalized)')
    ax.set_xlabel('time ({})'.format(times_ss.units))
    # grid: minor ticks at integer numbers of cycles
    ax.set_yticks(np.arange(-10,10), minor=True)
    # ax.set_ylim((-5,5))
    ax.grid(which='both', linestyle=':') # for minor and major ticks

### SPIKE Synchronization

In [None]:
import pyspike
# Convert to PySpike data format
# One pyspike.SpikeTrain per GPe spike train, built from the raw spike times
# and the (t_start, t_stop) edges of the recording.
trains = []
for st in pops_segments['GPE'].spiketrains:
    trains.append(pyspike.SpikeTrain(st.times.magnitude, (st.t_start, st.t_stop)))

In [None]:
# SPIKE-synchronization profile over all converted spike trains
sync_profile = pyspike.spike_sync_profile(*trains)

In [None]:
# Plot spike synchronization profile
x, y = sync_profile.get_plottable_data()
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
ax.plot(x,y)
# Only the 10-20 s stretch is shown (x axis is labelled in ms)
ax.set_xlim((10e3, 20e3))
# The profile was computed from pops_segments['GPE']; the title was copied
# from the STN section and named the wrong population.
ax.set_title('Kreuz SPIKE Synchronization for GPe cells')
ax.set_ylabel('synchronization index')
ax.set_xlabel('time (ms)')

### Morgera Synchronization Index

In [None]:
# Morgera synchronization index in a sliding window.
# For each window the squared singular values of the band-passed data matrix
# (rows = time samples, columns = traces) are normalised to sum to one; the
# index is M = 1 - H / log(K), with H the Shannon entropy of the normalised
# values and K their number.
M_values = []
delta_t = 200.0   # window step
window = 1000.0   # window length
t0_values = np.arange(0, signal.t_stop.magnitude, delta_t)
t1_values = []
for t0 in t0_values:
    win_t1 = t0 + window
    if win_t1 > signal.t_stop.magnitude:
        break
    # Window-local names: do not overwrite the global `interval` / `irange` /
    # `islice` that define the analysis interval used by earlier cells.
    win_slice = np.s_[int(t0/Ts):int(win_t1/Ts)]

    # Only the singular values are needed, so skip computing U and V
    s = np.linalg.svd(np.asarray(signal_bpss[win_slice, :]), compute_uv=False)
    lambdas = s**2
    sigmas = lambdas / lambdas.sum()
    # 0 * log(0) is taken as 0; evaluating it directly would give NaN
    nonzero = sigmas[sigmas > 0]
    C = - 1./np.log(len(sigmas)) * np.sum(nonzero * np.log(nonzero))

    t1_values.append(win_t1)
    M_values.append(1 - C)

# list(): on Python 3 zip() is a one-shot iterator that cannot be re-read or
# pickled. zip() also truncates t0_values to the windows actually evaluated.
all_morgera[sig_label] = (list(zip(t0_values, t1_values)), M_values)

In [None]:
# Plot the Morgera index against the END time of each window
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
ax.plot(t1_values, M_values)
ax.set_title('Morgera synchronization index [t-{}, t]'.format(window))
ax.set_ylabel('M (0-1)')
ax.set_xlabel('time (ms)')
ax.grid(True)
ax.set_ylim((0,1))
ax.set_xlim((0, signal.t_stop.magnitude))

## CTX - STN

In [None]:
# STN signal used in the coherence analyses below: the STN mean Vm when no LFP
# is available (WITHOUT_LFP), the STN LFP otherwise.
stn_target_signal = stn_vm_mean if WITHOUT_LFP else stn_lfp_signal

### Coherence

In [None]:
# ctx_vm_mean = neo.AnalogSignal(ctx_vm_signal.sum(axis=1).reshape((-1,1)), units=ctx_vm_signal.units, 
#                               sampling_rate=ctx_vm_signal.sampling_rate, t_start=ctx_vm_signal.t_start)

# Copy signal metadata from CTX Vm signal
# NOTE: despite the name, this is the SUM over axis 1 (not the mean), reshaped
# to a single-column signal.
ctx_vm_mean = ctx_vm_signal.duplicate_with_new_array(ctx_vm_signal.sum(axis=1).reshape((-1,1)))

# Coherence between the summed cortical Vm and the STN target signal
# (STN LFP, or STN mean Vm when WITHOUT_LFP is set)
# N x CTX signals -> 1 x LFP signal
freqs, coherence, phase_lag = elephant.spectral.welch_cohere(ctx_vm_mean, stn_target_signal, freq_res=0.5)

In [None]:
# Coherence magnitude (top) and phase lag (bottom) between CTX and STN
fig, axes = plt.subplots(2, 1, figsize=(0.75*page_width, 2*ax_height))
ax = axes[0]
# plot messed up because positive frequency axis comes before negative part:
# keep only the first half. Floor division keeps the slice index an integer
# (len(freqs)/2 is a float on Python 3 and cannot be used as an index).
ineg = len(freqs) // 2
ax.plot(freqs[:ineg], coherence[:ineg], label='CTX Vm - STN LFP')
ax.set_ylabel('Coherence (unitless)')
# ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Coherence between Cortex mean Vm - STN LFP')
ax.legend(loc='upper right')

# Plot phase
ax = axes[1]
ax.plot(freqs[:ineg], phase_lag[:ineg], 'g-', label='phase')
ax.set_ylabel('Phase (rad)')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Phase lag > 0 means CTX leads STN')
ax.legend(loc='upper right')

### Phase-Amplitude Coupling

In [None]:
# from tensorpac import Pac
# estimator = Pac(idpac=(1, 0, 0), fpha=(3, 33, 1, 1), famp=(20, 150, 5, 5),
#                 dcomplex='hilbert', filt='butter')

# Filter the data and extract PAC :
# xpac = estimator.filterfit(fs,
#                            xpha=stn_vm_signal.magnitude[int(1e3/dt):int(10e3/dt), 0:10],
#                            xamp=ctx_vm_signal.magnitude[int(1e3/dt):int(10e3/dt), 0:10],
#                            axis=0, traxis=1, njobs=2)
# xpac = estimator.filterfit(fs, 
#                            xpha=stn_lfp_signal.magnitude.reshape((-1,1)),
#                            xamp=ctx_vm_mean.magnitude.reshape((-1,1)),
#                            axis=0, traxis=1, njobs=2)

In [None]:
# Plot PAC
# estimator.comodulogram(xpac.mean(-1), title='PAC: STN LFP (phase) - CTX Vm (amplitude)',
#                        cmap='Spectral_r', plotas='imshow')

## STN - GPE

In [None]:
# Copy signal metadata from GPe Vm signal
# NOTE: despite the name, this is the SUM over axis 1 (not the mean), reshaped
# to a single-column signal.
gpe_vm_mean = gpe_vm_signal.duplicate_with_new_array(gpe_vm_signal.sum(axis=1).reshape((-1,1)))

# Coherence between the STN target signal (STN LFP, or STN mean Vm when
# WITHOUT_LFP is set) and the summed GPe Vm
# 1 x STN signal -> 1 x summed GPe signal
freqs, coherence, phase_lag = elephant.spectral.welch_cohere(stn_target_signal, gpe_vm_mean, freq_res=0.5)

### Coherence

In [None]:
# Coherence magnitude (top) and phase lag (bottom) between STN and GPe
fig, axes = plt.subplots(2, 1, figsize=(0.75*page_width, 2*ax_height))
ax = axes[0]
# plot messed up because positive frequency axis comes before negative part:
# keep only the first half. Floor division keeps the slice index an integer
# (len(freqs)/2 is a float on Python 3 and cannot be used as an index).
ineg = len(freqs) // 2
ax.plot(freqs[:ineg], coherence[:ineg], label='STN LFP - Avg GPe Vm')
ax.set_ylabel('Coherence (unitless)')
# ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Coherence between STN LFP - Average GPe Vm')
ax.legend(loc='upper right')

# Plot phase
ax = axes[1]
ax.plot(freqs[:ineg], phase_lag[:ineg], 'g-', label='phase')
ax.set_ylabel('Phase (rad)')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
# welch_cohere was called as (STN, GPe) here. Following the sign convention
# stated on the CTX-STN plot (first argument leads for positive lag), the
# title must name STN and GPe, not CTX and STN.
# NOTE(review): confirm the sign convention against the elephant documentation.
ax.set_title('Phase lag > 0 means STN leads GPe')
ax.legend(loc='upper right')

# Save Notebook

In [None]:
# alternative: %notebook -e foo.ipynb
# from IPython.display import Javascript
# script = '''
# require(["base/js/namespace"],function(Jupyter) {
#     Jupyter.notebook.save_checkpoint();
# });
# '''
# Javascript(script)

In [None]:
%%javascript
// Save a checkpoint of this notebook from the front-end, so that the export
// in the following cell picks up the current state.
require(["base/js/namespace"],function(Jupyter) {
    Jupyter.notebook.save_checkpoint();
});
// Jupyter.notebook.kernel.execute("notebook_name = " + "\'"+Jupyter.notebook.notebook_name+"\'");

In [None]:
# Export this notebook with nbconvert into the simulation output directory.
# import os.path
# NOTE: the notebook file name is hard-coded and must match the name of this file.
thisfile = 'synchrony_analysis.ipynb'
# outfile = os.path.join(outputs, 'synchrony_analysis.html')
# Export using built-in template:
#!jupyter nbconvert --to html --template=full --output-dir=$outputs $thisfile
# NOTE: template comes from ToC2 Notebook Extension
# NOTE(review): --output-dir=$outputs presumably requires `outputs` to be a
# single directory path (the string form set in "Load Data") -- confirm.
!jupyter nbconvert $thisfile --template=toc2 --output-dir=$outputs

In [None]:
# Export analysis results
# Pickle the `exported_data` object (assembled elsewhere in the notebook) into
# the simulation output directory. os.path.join requires `outputs` to be a
# single path string here.
import pickle, os.path
outfile = os.path.join(outputs, 'analysis_results.pkl')
with open(outfile, 'wb') as fout:
    pickle.dump(exported_data, fout)
print("Saved analysis results to file: " + outfile)