# Setup

In [None]:
%matplotlib inline
import numpy as np
import scipy.signal
import matplotlib
import matplotlib.pyplot as plt

import os
import neo.io
import elephant

# analysis.py module in same folder
from bgcellmodels.common import units, analysis
units.set_units_module('quantities')
import quantities as pq

# Jupyter notebook extensions
%load_ext bgcellmodels.extensions.jupyter.skip_cell_extension

## Plotting Options

In [None]:
# Figure sizing: page_width is the figure width in inches used throughout
# (approx. 16 for the inline backend, 8 for the interactive %notebook
# backend) and ax_height is the height of one row of axes.
if matplotlib.get_backend() != 'nbAgg':
    page_width = 14
else:
    # Interactive notebook backend: work around a bug where figures are
    # not shown after a cell raised an exception.
    from bgcellmodels.extensions.jupyter import jupyterutil
    jupyterutil.notebook_show_figs_after_exception()
    page_width = 10
ax_height = 3

# Default matplotlib style (colors etc.), see
# https://matplotlib.org/gallery/style_sheets/style_sheets_reference.html
plt.style.use('default')

## Support Functions

In [None]:
def make_slice(interval, t_start, Ts):
    """
    Make a slice object addressing the time interval (a, b) in a signal that
    starts at ``t_start`` and is sampled with period ``Ts``.

    @param    interval : (float, float)
              Start and end time, in the same units as t_start and Ts.
    @param    t_start : float
              Time of the signal's first sample.
    @param    Ts : float
              Sampling period.
    @return   slice
              Sample indices [i_a, i_b) covering the interval.
    """
    # Use the t_start argument (not a notebook global) and round to the nearest
    # sample so that floating-point error in (t - t_start)/Ts cannot drop an index.
    irange = [int(round((t - t_start) / Ts)) for t in interval]
    return np.s_[irange[0]:irange[1]]

## Load Data

In [None]:
# All populations:
outputs = "/home/luye/storage/BBB_LuNetStnGpe/LuNetStnGpe_2018.11.09_17.42.20_job-CALDD_StnGpe_template_syn-V10"

# Single population:
# outputs = [
# "/home/luye/storage/2018.06.21_job-testmpi6_DA-control_CTX-beta/CTX_2018.06.21_pop-100_dur-50.0_job-testmpi6.mat",
# ]

if isinstance(outputs, str):
    # Sort for a deterministic file (and thus population) order; os.listdir order is arbitrary
    filenames = sorted(os.listdir(outputs))
    # SETPARAM: choose time segment to load if split over files
    file_filter = lambda f: f.endswith('.mat') and '-2500' in f
    pop_files = [os.path.join(outputs, f) for f in filenames if file_filter(f)]
    params_file = next((os.path.join(outputs,f) for f in filenames if f.startswith('pop-parameters')), None)
else:
    pop_files = outputs
    params_file = None

if not pop_files:
    raise ValueError("No population data files found in {}".format(outputs))

pops_segments = {}
read_segment_id = 0 # index of the Neo segment to read from each file

# Read binary files using Neo IO module: one Neo Block per population file
for pop_file in pop_files:
    reader = neo.io.get_io(pop_file)
    blocks = reader.read()
    if len(blocks) != 1:
        raise ValueError("Expected exactly one Neo Block in file {}, "
                         "found {}".format(pop_file, len(blocks)))
    pop_label = blocks[0].name

    if read_segment_id >= len(blocks[0].segments):
        raise ValueError("Segment index greater than number of Neo segments"
                         " in file {}".format(pop_file))
    if pop_label in pops_segments:
        raise ValueError("Duplicate population labels in files")

    pops_segments[pop_label] = blocks[0].segments[read_segment_id]

# Extract some signal metadata
pop_labels = list(pops_segments.keys())
sim_dur = pops_segments['STN'].spiketrains[0].t_stop.magnitude

In [None]:
# Containers for per-population analysis results, filled in by later cells.
# The comment next to each one names the value stored per key.
all_psd = {}               # (freqs, psd)
all_psd_peaks = {}         # (peak_freqs, peak_psd)
all_psd_sum_subband = {}   # (band_limits, sum_psd)
all_fpeak = {}             # (freqs, psd)

all_signals = {}           # neo.AnalogSignal variables
all_mean_rate = {}
all_morgera = {}
all_spectrogram = {}
all_sigmean = {}

# Data that will be exported: the same dict objects as above, so anything
# stored in them later is exported as well.
exported_data = dict(
    mean_rate=all_mean_rate,
    PSD=all_psd,
    PSD_peaks=all_psd_peaks,
    PSD_subband_power=all_psd_sum_subband,
    PSD_input_freq=all_fpeak,
    spectrogram=all_spectrogram,
    Morgera_index=all_morgera,
)

# Phase-vector results are only created here and live solely in exported_data
all_cell_phase_vecs = exported_data['cell_phase_vecs'] = {}
all_pop_phase_vecs = exported_data['pop_phase_vecs'] = {}

In [None]:
# The recordings are saved in Neo format. See:
# http://neo.readthedocs.io/en/latest/
# http://neo.readthedocs.io/en/latest/api_reference.html#neo.core.AnalogSignal
# - Each segment has attributes 'analogsignals' and 'spiketrains'
# - Each quantity (e.g. AnalogSignal.signal) has attributes magnitude, units, dimensionality

# List all recorded signals (name, units, description) and the number of
# spike trains of each population.
# NOTE: pop_label, segment and signal are notebook globals; after this cell
# `signal` is bound to the last analog signal listed.
for pop_label, segment in pops_segments.items():
    print("\n{} has following signals:".format(pop_label))
    for signal in segment.analogsignals:
        print("\t- '{:10}\t[{}] - description: {}".format(signal.name, signal.units, signal.description))
    print("\t- {} spiketrains".format(len(segment.spiketrains)))

# Network Parameters

In [None]:
# Load the network parameters pickled by the simulation, if a parameter file
# was found. Defaults to an empty dict so the cells below still run without one.
# NOTE: pickle.load can execute arbitrary code; only load trusted simulation output.
network_params = {}
if params_file is not None:
    import pickle
    with open(params_file, 'rb') as pf:
        network_params = pickle.load(pf)
    print("Found parameters for following populations:\n"
          "{}\n".format(network_params.keys()))

# For each projection pre -> post that stores parameter samples, print the
# parameter names (presumably 'conpair_pnames') and the column-wise maximum of
# the parameter values from column 2 onwards.
# .items() works on both Python 2 and 3 (.iteritems() is Python 2 only).
for pre_pop, pre_params in network_params.items():
    if not isinstance(pre_params, dict):
        print("Ignoring entry: '{}'\n".format(pre_pop))
        continue
    for post_pop, pre_post_params in pre_params.items():
        if 'conpair_pvals' in pre_post_params:
            print("Parameters for connection {} -> {}:\n"
                  "{}\n{}\n".format(pre_pop, post_pop,
                                    pre_post_params['conpair_pnames'], 
                                    pre_post_params['conpair_pvals'][:,2:].max(axis=0)))

## Connection Matrices

In [None]:
# Plot the connectivity matrix of every projection pre -> post found in the
# network parameters (projections from CTX are skipped).
# `analysis` is already imported in the Setup cell.
for pre_pop, pre_params in network_params.items():
    if not isinstance(pre_params, dict) or pre_pop=='CTX':
        continue
    for post_pop, pre_post_params in pre_params.items():
        if 'conn_matrix' in pre_post_params:
            adjacency_mat = pre_post_params['conn_matrix']
            # Spacing arguments for plot_connectivity_matrix: ~1/10 of the matrix
            # size, at least 3. Integer division keeps them integral under
            # Python 3, matching the original Python 2 '/' behaviour.
            divx = max(3, adjacency_mat.shape[1] // 10)
            divy = max(3, adjacency_mat.shape[0] // 10)
            analysis.plot_connectivity_matrix(adjacency_mat, 
                                              pop0=pre_pop, pop1=post_pop,
                                              px=divx, py=divy,
                                              title='{} -> {}'.format(pre_pop, post_pop))

# for pops in [('STN', 'GPE'), ('GPE', 'STN'), ('GPE', 'GPE'), ('STN', 'STN')]:
#     stn_gpe_weights = network_params[pops[0]][pops[1]]['conn_matrix']
#     analysis.plot_connectivity_matrix(stn_gpe_weights, pop0=pops[0], pop1=pops[1], pop_size=100, seaborn=False)

# Spike Trains

<span style='color:red;font-weight:bold'>WARNING</span>: In rastergram plots, check the number of spike trains plotted (see y-axis). If it is too high, the marker bars overlap (the marker height exceeds the row height allotted to one spike train). Overlapping spike trains make the plot misleading, because they look like an artificially elevated firing rate.

In [None]:
line_colors = 'crkgbm'
# Fixed color per population: cycle through line_colors in sorted-label order
pop_color_map = {
    label: line_colors[idx % len(line_colors)]
    for idx, label in enumerate(sorted(pops_segments))
}

def get_pop_color(pop_label):
    """
    Return the matplotlib color code assigned to population ``pop_label``
    in pop_color_map (KeyError for unknown labels).
    """
    color = pop_color_map[pop_label]
    return color

def get_pop_order(pop_label):
    """
    Sort key giving the canonical plotting order of populations:
    CTX -> STR -> GPE -> STN (matched by label prefix), unknown labels last.

    @return   int : 0..3 for known prefixes, 100 otherwise.
    """
    for rank, prefix in enumerate(('CTX', 'STR', 'GPE', 'STN')):
        if pop_label.startswith(prefix):
            return rank
    return 100

## Spike Statistics

In [None]:
import collections

# Mean firing rate (Hz) of each population: average of the per-spike-train
# mean rates, in canonical population order. Also stored in all_mean_rate.
# (With raw voltage signals only, spikes could be detected with
# elephant.spike_train_generation.peak_detection().)
pop_firing_rates = collections.OrderedDict()
for pop_label in sorted(pops_segments.keys(), key=get_pop_order):
    segment = pops_segments[pop_label]
    cell_rates = [elephant.statistics.mean_firing_rate(st).rescale('Hz').magnitude
                  for st in segment.spiketrains]
    pop_rate = sum(cell_rates, 0.0) / len(segment.spiketrains)
    pop_firing_rates[pop_label] = all_mean_rate[pop_label] = pop_rate
    print("Mean firing rate for {} is {}".format(pop_label, pop_rate))

In [None]:

# Also plot target rates in DD and normal conditions
# In vivo reference firing rates (Hz) per population as [lower, upper] bounds
# (either a reported range or median-std, median+std). Label prefix 'CTL.' is
# drawn with solid lines, 'DD.' with dotted lines.
invivo_rates = { # (upper, lower) or (median-std, median+std)
    'CTX': {
        'CTL.Bergman2015-Li2012': [1.0, 5.0],
    },
    'STR.MSN': {
        'DD.Rat.KitaKita2011': [3.8, 9.4], # Kita & Kita (2011): baseline, peak is 140
        'CTL.Rat.KitaKita2011': [0.0, 1.65], # Kita Kita (2011)
    },
    'STR.FSI': {
        'CTL.Berke2004': [10.0, 100.0], # Berke (2004, 2008)
    },
    'GPE.proto': {
        'DD.Monkey.Nambu2014': [18.7, 63.7], # Nambu (2014): mean 41.2
        'CTL.Rat.Stoptask.Mallet2016': [41.3, 53.5],
        'CTL.Rat.Sleep-Act.Mallet2008': [32.4, 35.0],
        'DD.Rat.Sleep-Act.Mallet2008': [13.6, 14.6],
        # Nambu(2014): mean 65.2
    },
    'GPE.arky': {
        'CTL.Rat.Mallet2016': [7.0, 10.8], # rat, Mallet (2016)
    },
    'STN': {
        'CTL.Monkey.Bergman2015': [20.0, 25.0], # Bergman book
        'CTL.Rat.Sleep-Act.Mallet2008': [11.1, 16.6],
        'DD.Rat.Sleep-Act.Mallet2008': [25.8, 37.4],
        'DD.Human.Sharott2018': [33.0, 34.0],
    }
}

fig, ax = plt.subplots()

model_pops = [k for k in pop_firing_rates.keys() if 'surrogate' not in k]
index = np.arange(len(model_pops))
bar_width = 0.35
opacity = 0.4
ref_line_colors = 'rgbmk'

rate_indicators = []
rate_labels = []
for i, pop_label in enumerate(model_pops):
    pop_rate = pop_firing_rates[pop_label]
    ref_ranges = invivo_rates.get(pop_label, {})

    # Draw in vivo rates: vertical bar with end caps per reference range
    for j, (label, bounds) in enumerate(ref_ranges.items()):
        line_kwargs = {'color': ref_line_colors[j % len(ref_line_colors)], 'linewidth': 1.0,
                       'linestyle': ':' if label.startswith('DD') else '-'}
        lines = ax.vlines(i-.02+j*.02, bounds[0], bounds[1], **line_kwargs)
        ax.hlines(bounds, i-bar_width/4, i+bar_width/4, **line_kwargs)
        rate_indicators.append(lines)
        rate_labels.append('{} - {}'.format(pop_label, label))

    # Draw simulation rates. Decide the color from this population's own
    # reference ranges (previously the loop-leaked `bounds` of the last range,
    # or of another population, was used): green if inside any in vivo range,
    # red if outside all of them, grey if there is no reference data.
    if not ref_ranges:
        color = 'grey'
    elif any(lo <= pop_rate <= hi for lo, hi in ref_ranges.values()):
        color = 'green'
    else:
        color = 'red'
    ax.bar(i, pop_rate, bar_width, alpha=opacity, color=color)


ax.set_xlabel('Population')
ax.set_ylabel('Mean firing rate (Hz)')
ax.set_title('Mean firing rate for each population')
ax.set_xticks(index)
ax.set_xticklabels(model_pops)
max_rate = max(pop_firing_rates.values())+10
ax.set_yticks(np.arange(0, max_rate+5, 5), minor=False)
ax.set_yticks(np.arange(0, max_rate+2.5, 2.5), minor=True)
ax.set_ylim((0, max_rate))
ax.grid(True, axis='y', which='major')

# Draw separate legend
fig, ax = plt.subplots()
fig.legend(rate_indicators, rate_labels, loc='upper center', frameon=False)
ax.axis('off')

## Running Spikerates

In [None]:
# %%skip False

from bgcellmodels.common import signal as spikeprocs

def plot_avg_spikerate(pop_label, cell_ids=None, t_range=None, bin_width=20.0, adaptive=False, fig=None):
    """
    Plot the running mean firing rate of a population.

    Spikes are binned in consecutive bins of width ``bin_width``; the bin
    width is also the bin spacing (no sliding window).

    @param    pop_label : str
              Key into pops_segments.
    @param    cell_ids : iterable(int) or None
              Spike-train indices to include (default: all).
    @param    t_range : (float, float) or None
              Analysed interval; defaults to (0, t_stop of the first spike train).
    @param    adaptive : bool
              If True, bins are enlarged adaptively until 10 spikes are
              included (spikeprocs.nrn_avg_rate_adaptive).
    @param    fig : matplotlib Figure or None
              If given, a new axes is added below the figure's existing axes.
    @return   fig, ax
    """
    spiketrains = pops_segments[pop_label].spiketrains
    if cell_ids is None:
        cell_ids = range(len(spiketrains))
    if t_range is None:
        t_range = (0.0, np.round(spiketrains[0].t_stop.magnitude, 3))
    t_first, t_last = t_range[0], t_range[1]
    selected = [spiketrains[i] for i in cell_ids]

    if not adaptive:
        rate_vec = spikeprocs.nrn_avg_rate_simple(selected, t_first, t_last,
                                                  binwidth=bin_width)
    else:
        min_spikes = 10
        rate_vec = spikeprocs.nrn_avg_rate_adaptive(selected, t_first, t_last,
                                                    binwidth=bin_width,
                                                    minsum=min_spikes)
    mean_rates = rate_vec.as_numpy()

    if fig is not None:
        n_rows = len(fig.axes) + 1
        ax = fig.add_subplot(n_rows, 1, n_rows)
    else:
        fig, ax = plt.subplots(1, 1, figsize=(page_width,ax_height))

    bin_times = t_first + bin_width * np.arange(mean_rates.size)
    ax.plot(bin_times, mean_rates, color=get_pop_color(pop_label))
    ax.grid(True)

    fig.suptitle('{} mean population firing rate (running average)'.format(pop_label))
    return fig, ax

# Running mean rate of every population, in the canonical CTX -> STR -> GPE -> STN
# order used by the other figures (plain dict-key order is not deterministic).
for pop in sorted(pop_labels, key=get_pop_order):
    plot_avg_spikerate(pop, bin_width=50)

## Raster Plots

In [None]:
# SETPARAM: region of interest (ROI) for plotting, as (start, end) in the
# time units of the recordings (time axes are labelled in ms)
ROI_INTERVAL = (0.0, 2500.0)

In [None]:
num_pops = len(pops_segments)
i_pop = 0  # number of rastergram figures created so far by plot_spiketrain()

def plot_spiketrain(pop_label, cell_ids, t_range, sharex=None, sharey=None,
                    figsize=(page_width,ax_height), plot_compact=False):
    """
    Draw a rastergram of selected spike trains of one population in a new figure.

    @param    pop_label : str
              Key into pops_segments.
    @param    cell_ids : list(int)
              Indices of the spike trains to draw; they also determine the
              visible y-axis range (each train is drawn at y = its index).
    @param    t_range : (float, float)
              Visible x-axis interval.
    @param    sharex, sharey : matplotlib Axes or None
              Axes to share the x / y axis with.
    @param    plot_compact : bool
              If True: no title or y-label, one y-tick per cell and no tick labels.
    @return   fig_spikes, ax
    """
    global i_pop

    segment = pops_segments[pop_label]
    t_stop = segment.spiketrains[0].t_stop.magnitude

    # Each population gets its own figure rather than a shared one
    fig_spikes = plt.figure(figsize=figsize)
    ax = plt.subplot(1, 1, 1, sharex=sharex, sharey=sharey)
    if not plot_compact:
        fig_spikes.suptitle('{} spiketrains'.format(pop_label))

    # Only the selected spike trains are drawn, each on the row of its index
    y_vals = np.empty(len(cell_ids))
    for row, train_idx in enumerate(cell_ids):
        train = segment.spiketrains[train_idx]
        y_vals[row] = train_idx
        ax.plot(train, np.ones_like(train) * y_vals[row],
                marker='|', linestyle='',
                snap=True, color=get_pop_color(pop_label))

    if not plot_compact:
        ax.set_yticks(np.arange(min(y_vals), max(y_vals)+5, 5), minor=False)
        ax.set_xticks(np.arange(0, t_stop+1000, 1000), minor=False)
        ax.set_ylabel('{} cell #'.format(pop_label))
    else:
        ax.set_yticks(np.arange(min(y_vals), max(y_vals)+1, 1), minor=False)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    ax.set_xlim(t_range)
    ax.set_ylim((min(y_vals)-0.5, max(y_vals)+0.5))
    ax.grid(True, axis='x', which='major')

    i_pop += 1
    return fig_spikes, ax

In [None]:
# Choose populations and cells indices to plot
# One rastergram figure per population in canonical order (get_pop_order);
# each figure shares its x-axis with the previously created one.
shared_axis = None
for pop_label in sorted(pop_labels, key=get_pop_order):
    segment = pops_segments[pop_label]
    # Cap on rows per rastergram (see the WARNING above about overlapping markers)
    max_num_plot = 100
    num_spiketrains = min(len(segment.spiketrains), max_num_plot)
    # SETPARAM: choose cells to plot
    # The offset (0 here) selects a contiguous window of cells; the modulo wraps it around the population
    cell_indices = (0 + np.arange(num_spiketrains, dtype=int)) % len(segment.spiketrains)
    t_interval = ROI_INTERVAL # SETPARAM: time interval
    # Figure height scales with the number of rows: ax_height per 20 spike trains
    fig_height = ax_height*num_spiketrains/20.0
    # NOTE: don't use sharey when different population sizes
    fig, shared_axis = plot_spiketrain(pop_label, cell_indices, t_interval,
                                       sharex=shared_axis, sharey=None,
                                       figsize=(page_width, fig_height), plot_compact=False)

# plot_spiketrain('STN', range(20), (0.0, 5e3))

# Raw Signals

In [None]:
def plot_vm_signals(signal, cell_indices, interval, interval_only=True, title_label=None):
    """
    Plot membrane voltage traces of selected cells, one axes per cell.

    @param    signal : neo.AnalogSignal
              Multi-channel Vm signal, time along axis 0 and one column per cell.
    @param    cell_indices : sequence(int)
              Column indices of the cells to plot.
    @param    interval : (float, float) or None
              Time interval (same units as signal.t_start) shown on the x-axis;
              None selects the whole signal.
    @param    interval_only : bool
              If True, only the samples inside `interval` are plotted; otherwise
              the full trace is plotted and the x-axis is limited to `interval`.
    @param    title_label : str or None
              Population label used in the figure title. Defaults to the
              notebook-global `pop_label` (previous behaviour).
    """
    if title_label is None:
        title_label = pop_label

    rec_dt = signal.sampling_period.magnitude
    tstart = signal.t_start.magnitude
    if interval is None:
        # Whole signal; the slice end is exclusive, so use the full length
        irange = [0, signal.shape[0]]
    else:
        irange = [int((t-tstart)/rec_dt) for t in interval]
    times = signal.times[irange[0]:irange[1]]

    # squeeze=False: axes is always 2D, also when plotting a single cell
    fig, axes = plt.subplots(len(cell_indices), 1,
                             figsize=(0.75*page_width,2*ax_height),
                             sharex=True, sharey=True, squeeze=False)
    fig.suptitle("{} membrane voltage".format(title_label))

    # Plot each Vm on separate axis
    for i_ax, i_cell in enumerate(cell_indices):
        ax = axes[i_ax, 0]
        if 'source_ids' in signal.annotations:
            label = "id {}".format(signal.annotations['source_ids'][i_cell])
        else:
            label = "cell {}".format(i_cell)

        if interval_only:
            ax.plot(times, signal[irange[0]:irange[1], i_cell], label=label)
        else:
            ax.plot(signal.times, signal[:, i_cell], label=label)

        ax.grid(True)
        ax.set_ylim((-90, 25))
        ax.set_xlim((times[0].magnitude, times[-1].magnitude))

        if i_ax == len(cell_indices)-1:
            ax.set_xlabel('time ({})'.format(times.units))

    # Shared y-label for all axes
    fig.text(0.06, 0.5, "voltage ({})".format(signal.units), va='center', rotation='vertical')
    fig.subplots_adjust(bottom=0.15) # prevent clipping of xlabel

## STN Vm

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
# First analog signal named 'Vm' in the STN segment (one column per cell, as plotted below)
stn_vm_signal = signal = next(sig for sig in segment.analogsignals if sig.name == 'Vm')

In [None]:
# Choose plot interval and cell indices
max_num_plot = 10
interval = ROI_INTERVAL # SETPARAM: plot interval
# SETPARAM: first cell to plot. Fall back to cell 0 for small populations so
# that the selected indices never exceed the number of recorded cells.
first_cell = 10 if stn_vm_signal.shape[1] > 10 else 0
num_signals = min(stn_vm_signal.shape[1] - first_cell, max_num_plot)
cell_indices = [first_cell + i for i in range(num_signals)]
plot_vm_signals(stn_vm_signal, cell_indices, interval, interval_only=False)

## GPe Vm

In [None]:
pop_label = 'GPE.proto'
segment = pops_segments[pop_label]
# Recorded membrane voltage of the prototypic GPe cells (first signal named 'Vm')
signal = next(sig for sig in segment.analogsignals if sig.name == 'Vm')
gpe_vm_signal = signal

In [None]:
# Choose plot interval and cell indices
max_num_plot = 10
num_signals = min(gpe_vm_signal.shape[1], max_num_plot)
interval = ROI_INTERVAL # [2000.0, 6000.0]
# First 5 cells, or fewer if fewer are recorded (a bare range(5) would raise
# IndexError for small populations)
cell_indices = range(min(5, num_signals)) # range(num_signals)
plot_vm_signals(gpe_vm_signal, cell_indices, interval)

## CTX Artificial

Convolve spike times with stereotypical AP to obtain artificial voltage signal.

In [None]:
pop_label = 'CTX'
segment = pops_segments[pop_label]
# Only the first 100 spike trains are converted into artificial Vm traces
spiketrains = segment.spiketrains[:100]

In [None]:
# Load Pyramidal cell action potential from saved recording
# (column 0: time, column 1: membrane voltage)
pyramidal_trace = np.loadtxt('../../KlmnNetMorpho/analysis/pyramidal.dat')
vm = pyramidal_trace[:,1]
tvec = pyramidal_trace[:,0]
dt = np.round(tvec[1] - tvec[0], 4)
ap_t_interval = [4.9, 20.0] # time window of the recording that contains the AP
# Sample indices of the window, measured from the first time point. Round instead
# of truncating so floating-point error in t/dt cannot shift the window by a sample.
ap_i_interval = [int(round((t - tvec[0]) / dt)) for t in ap_t_interval]
subsample = 2
ap_range = range(ap_i_interval[0], ap_i_interval[1]+1, subsample)
# Indexing with a sequence returns a copy, so the in-place shift below leaves vm untouched
ap_kernel = vm[ap_range]
ap_baseline = -65.0
ap_kernel -= ap_baseline # center on 0 for convolution
# NOTE(review): the kernel is sampled every subsample*dt of this recording, while
# the next cell convolves it on the simulation's sampling grid -- confirm they match.

plt.figure(figsize=(6,2))
plt.plot(tvec[ap_range], ap_kernel)
plt.suptitle('AP kernel for cortical neurons')
plt.grid(True)

In [None]:
# Construct AnalogSignal of N channels from N spiketrains: each channel is a
# unit pulse train at the spike times convolved with the AP kernel.
dur = np.round(spiketrains[0].duration.magnitude, 4)
tstart = np.round(spiketrains[0].t_start.magnitude, 4)
tstop = np.round(spiketrains[0].t_stop.magnitude, 4)
dt = stn_vm_signal.sampling_period.magnitude # use the simulation's sampling period
# One sample count for both the matrix and the pulse vectors (an independent
# np.arange(tstart, tstop+dt, dt) could differ by one sample due to float error)
num_samples = int(round(dur/dt)) + 1
signal_matrix = np.empty((num_samples, len(spiketrains)))
time = tstart + dt * np.arange(num_samples) # sample times (loop-invariant)

# Convolution operation for each spiketrain
for i, st in enumerate(spiketrains):
    spiketimes = st.times.magnitude
    spike_pulses = np.zeros(num_samples)
    # Sample at or before each spike, clipped to the valid index range
    spike_idx = np.clip(((spiketimes - tstart) / dt).astype(int), 0, num_samples - 1)
    spike_pulses[spike_idx] = 1.0

    # Convolve pulses at spike times with AP kernel; mode='same' keeps the
    # length and centres the kernel on each pulse
    spike_signal = np.convolve(spike_pulses, ap_kernel, mode='same') + ap_baseline
    signal_matrix[:, i] = spike_signal

# Resulting Neo signals
ctx_vm_signal = neo.AnalogSignal(signal_matrix, units='mV', sampling_rate=1/dt/pq.ms,
                                t_start=spiketrains[0].t_start, t_stop=spiketrains[0].t_stop)

# Population-mean CTX signal. Derive it from ctx_vm_signal itself: the global
# `signal` still refers to another population's Vm and carries its metadata.
ctx_vm_mean = ctx_vm_signal.duplicate_with_new_array(
    ctx_vm_signal.sum(axis=1).reshape((-1,1)) / ctx_vm_signal.shape[1])
all_sigmean[pop_label + '_Vm'] = ctx_vm_mean

In [None]:
# Plot artificial signal to verify
# At most max_num_plot CTX channels. interval_only=False plots each full trace;
# plot_vm_signals still limits the x-axis to `interval`.
max_num_plot = 10
num_signals = min(ctx_vm_signal.shape[1], max_num_plot)
interval = ROI_INTERVAL # [12.75e3, 14e3]
cell_indices = range(num_signals)
plot_vm_signals(ctx_vm_signal, cell_indices, interval, interval_only=False)

# Power Spectra

## PSD

In [None]:
def calc_psd(pop_label, t_start=1000.0, vm_sig=None):
    """
    Calculate the population-average Welch PSD of membrane voltage traces.

    The PSD is computed for each Vm trace (channel) and the resulting spectra
    are averaged across channels.

    @param    pop_label : str
              Population label; used to look up the recorded 'Vm' signal when
              vm_sig is None, and as key prefix in all_psd.
    @param    t_start : float
              Start time of the analysed segment (same units as the sampling
              period, ms); earlier samples are discarded.
    @param    vm_sig : neo.AnalogSignal or None
              Signal to analyse instead of the population's recorded 'Vm'.
    @return   freqs, psd_avg

    Side effects: stores (freqs, psd_avg) in all_psd[pop_label + '_Vm'] and
    sets the global psd_units (used by other plotting functions).
    """
    global psd_units

    if vm_sig is None:
        segment = pops_segments[pop_label]
        vm_sig = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))

    Ts = vm_sig.sampling_period.magnitude # ms
    # Sampling rate of the analysed signal itself, not of a notebook global
    fs = vm_sig.sampling_rate.rescale('Hz').magnitude
    # Discard samples before t_start, measured from the signal's own start time
    sig_t0 = vm_sig.t_start.rescale(vm_sig.sampling_period.units).magnitude
    i_start = max(0, int((t_start - sig_t0) / Ts))
    vm_segment = vm_sig[i_start:, :]

    # 0.5 Hz resolution, or coarser if the segment is too short
    # (fs / n_samples is the finest resolution a full-length window can give)
    dF = max(0.5, fs / vm_segment.shape[0])
    if dF != 0.5:
        print("Adjusted frequency resolution to data size: dF = {}".format(dF))

    # PSD of all Vm channels at once; channels assumed along axis 0 of psd
    # for AnalogSignal input -- TODO confirm for the installed elephant version
    freqs, psd = elephant.spectral.welch_psd(vm_segment, freq_res=dF)
    psd_avg = psd.sum(axis=0) / psd.shape[0]

    # Save PSD
    all_psd[pop_label+'_Vm'] = (freqs, psd_avg)

    # Save units for other plotting functions
    psd_units = psd.units

    return freqs, psd_avg


def plot_psd(freqs, psd_avg, pop_label):
    """
    Plot a PSD (0-50 Hz) in two figures: one with an automatic y-axis and one
    with a fixed 0-10 y-axis for comparison with other simulations.

    @param    freqs, psd_avg : quantities arrays (e.g. as returned by calc_psd)
    @param    pop_label : str, used in the line label and titles
    """
    panels = [
        (None, 'Welch PSD for mean {} Vm'.format(pop_label)),
        ((0, 10), 'Welch PSD for mean {} Vm (absolute y-axis)'.format(pop_label)),
    ]
    for y_limits, title in panels:
        fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
        ax.plot(freqs, psd_avg, label='{} Vm'.format(pop_label))
        ax.set_ylabel('Power ({})'.format(psd_avg.units.dimensionality))
        ax.set_xlabel('frequency ({})'.format(freqs.units.dimensionality))
        ax.set_xlim((0, 50))
        if y_limits is not None:
            ax.set_ylim(y_limits)
        ax.grid(True)
        ax.set_title(title)

### STN Vm

In [None]:
pop_label = 'STN'
# sig_label = 'STN_Vm'
# segment = pops_segments[pop_label]
# vm_sig = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))

# Cell-averaged Welch PSD of the recorded STN 'Vm' (calc_psd looks the signal
# up itself and stores the result in all_psd['STN_Vm'])
freqs, psd_avg = calc_psd(pop_label)

plot_psd(freqs, psd_avg, pop_label)

### CTX spikes

In [None]:
pop_label = 'CTX'
# segment = pops_segments[pop_label]

# CTX is analysed through the artificial spike-convolved signal ctx_vm_signal
freqs, psd_avg = calc_psd(pop_label, vm_sig=ctx_vm_signal)

plot_psd(freqs, psd_avg, pop_label)

### GPE.proto

In [None]:
pop_label = 'GPE.proto'
# segment = pops_segments[pop_label]

# Cell-averaged Welch PSD of the recorded GPE.proto 'Vm' (stored in all_psd['GPE.proto_Vm'])
freqs, psd_avg = calc_psd(pop_label)

plot_psd(freqs, psd_avg, pop_label)

## Spectrogram (STFT)

In [None]:
def calc_spectrogram(pop_label, signal=None):
    """
    Calculate the spectrogram (STFT) of the population-mean signal.

    Each spectral column is computed from a window of ``nperseg`` samples
    centred on its time point, so the time axis lacks about nperseg/2
    samples at each end of the signal.

    @param    pop_label : str
              Population label; used to look up 'Vm' when signal is None and
              as key prefix in all_sigmean / all_spectrogram.
    @param    signal : neo.AnalogSignal or None
              Multi-channel signal (time along axis 0); default: the
              population's recorded 'Vm'.
    @return   freqs, t, Sxx
              freqs in Hz, t in ms; Sxx has the frequency dimension along
              axis 0 and the time dimension along axis 1 (scipy convention).
              NOTE: scipy is given the sampling frequency in kHz (1/ms), so the
              'density' Sxx is per kHz rather than per Hz.

    Side effects: stores the mean signal in all_sigmean[pop_label + '_Vm'] and
    the 0-50 Hz band of the spectrogram in all_spectrogram[pop_label + '_Vm'].
    """
    if signal is None:
        segment = pops_segments[pop_label]
        signal = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))

    # Compute mean Vm signal
    sigmean = signal.duplicate_with_new_array(signal.sum(axis=1).reshape((-1,1)) / signal.shape[1])
    all_sigmean[pop_label + '_Vm'] = sigmean

    # Spectrogram using STFT
    dt = signal.sampling_period.magnitude # ms
    fs = 1/dt*1e3 # Hz
    freq_res = 1.0 # Hz
    nperseg = int(fs/freq_res) # determines frequency resolution
    t_res = 20.0 # ms
    noverlap = nperseg - int(t_res/dt)
    # Sampling frequency given in kHz (1/dt with dt in ms): freqs come out in kHz
    # and t in ms. 'hann' is the Hann window name documented by scipy.signal.get_window.
    freqs, t, Sxx = scipy.signal.spectrogram(sigmean.ravel(), 1/dt, window='hann',
                                             nperseg=nperseg, noverlap=noverlap, scaling='density')
    freqs = freqs * 1000
    t = t + signal.t_start.rescale('ms').magnitude

    # Save the 0-50 Hz band: frequency is axis 0 of Sxx, so slice rows (the
    # time axis is kept whole)
    df = freqs[1]-freqs[0]
    n_keep = int(50/df)
    all_spectrogram[pop_label + '_Vm'] = (freqs[0:n_keep], t, Sxx[0:n_keep, :])

    return freqs, t, Sxx

def plot_spectrogram(freqs, t, Sxx, pop_label, f_max=100.0, P_max=None):
    """
    Plot a spectrogram as a colour mesh (time on x, frequency on y).

    @param    freqs, t, Sxx
              As returned by calc_spectrogram: Sxx has shape (len(freqs), len(t)).
    @param    pop_label : str, used in the title
    @param    f_max : float
              Upper limit of the frequency axis (Hz).
    @param    P_max : float or None
              Upper colour limit. Default: maximum |Sxx| after discarding the
              first 1000 ms (initial transient), or over all times if the
              spectrogram does not extend beyond that.
    """
    fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))
    # cmap = 'viridis' / 'plasma' / 'jet'
    mesh = ax.pcolormesh(t, freqs, Sxx, cmap='viridis')
    ax.set_ylim((0, f_max))
    if P_max is None:
        # Discard first second; guard against spectrograms too short for that
        i_skip = int(1000.0/(t[1]-t[0])) if len(t) > 1 else 0
        if Sxx.shape[1] > i_skip:
            P_max = abs(Sxx[:, i_skip:]).max()
        else:
            P_max = abs(Sxx).max()
    mesh.set_clim(0, P_max)
    fig.colorbar(mesh, ax=ax)
    ax.set_ylabel('frequency (Hz)')
    ax.set_xlabel('time (ms)')
    # NOTE(review): psd_units is set by calc_psd() (Welch PSD per Hz); Sxx from
    # calc_spectrogram is a density per kHz, so this unit label may be off -- confirm.
    ax.set_title('Spectrogram of mean {} Vm ({})'.format(pop_label, psd_units.dimensionality))

### STN

In [None]:
# Spectrogram of the mean STN Vm, frequency axis limited to 50 Hz
freqs, t, Sxx = calc_spectrogram('STN')
plot_spectrogram(freqs, t, Sxx, 'STN', f_max=50.0)

### GPE.proto

In [None]:
# Spectrogram of the mean GPE.proto Vm (default f_max = 100 Hz)
freqs, t, Sxx = calc_spectrogram('GPE.proto')
plot_spectrogram(freqs, t, Sxx, 'GPE.proto')

### STR.MSN

In [None]:
# Spectrogram of the mean STR.MSN Vm (default f_max = 100 Hz)
freqs, t, Sxx = calc_spectrogram('STR.MSN')
plot_spectrogram(freqs, t, Sxx, 'STR.MSN')

# Signal Relationships

For suitable measures and implementations, see
- bookmarks/neuroscience/signal_processing
- google `measure + site:github.com`
- ask Amir for his mutual information and related measures
- see Beta-related and other neurophysiology articles

For example, we can use following measures

- __Coherence__ : linear relationship between two signals, resolved by frequency component
    + see `welch_cohere` in [elephant.spectral](http://elephant.readthedocs.io/en/latest/reference/spectral.html)


- __Phase-Amplitude Coupling__ (PAC)
    + see `hilbert` in [elephant.signal_processing](http://elephant.readthedocs.io/en/latest/reference/signal_processing.html) to do band-pass filter + Hilbert transform
    + see `comodulogram` in [pactools](https://pactools.github.io/auto_examples/plot_comodulogram.html)

## CTX Phase

In [None]:
# Calculate analytical signal -> phase
# Reference signals for the phase analysis: artificial CTX Vm traces and their population mean
sig_label = 'CTX_Vm'
signal = ctx_vm_signal
sigmean = ctx_vm_mean

# Analysis interval (currently the whole signal) and its sample-index slice.
# tstart / tstop / rec_dt are reused as globals by the cells below.
phase_tstart = tstart = np.round(signal.t_start.magnitude, 4)
phase_tstop = tstop = np.round(signal.t_stop.magnitude, 4)
rec_dt = signal.sampling_period.magnitude
interval= [tstart, tstop]
irange = [int((t-tstart)/rec_dt) for t in interval]
islice = np.s_[irange[0]:irange[1]] # slice object

# Number of channels (traces) analysed individually
num_traces = 100
pslice = np.s_[0:num_traces]
traces_raw = signal[islice, pslice]

times = signal.times[islice]

### Filtering

In [None]:
# Band-pass filter in frequency band of interest

# Elephant built-in filtering
# import elephant.signal_processing as sigproc
# signal_bp = sigproc.butter(signal, highpass_freq=20, lowpass_freq=30, order=3, filter_function='filtfilt')

# Manual band-pass filtering
Fs = signal.sampling_rate.rescale('Hz').magnitude
Fn = Fs / 2. # Nyquist frequency

# WARNING: phase of filter must be NEUTRAL at target frequency!!! -> CENTER on target
hpfreq, lpfreq = 16.0, 24.0 # SETPARAM: passband of BP filter
order = 4

assert hpfreq < lpfreq
low, high = hpfreq / Fn, lpfreq / Fn
b, a = scipy.signal.butter(order, [low, high], btype='bandpass', analog=False)
sos = scipy.signal.butter(order, [low, high], btype='bandpass', analog=False, output='sos')
# sos = scipy.signal.butter(order, high, btype='lowpass', analog=False, output='sos')

# Check stability of the (b, a) form (otherwise -> NaN values). Filtering
# below uses the second-order-sections form `sos`.
filter_stable = np.all(np.abs(np.roots(a))<1)
if not filter_stable:
    # raise Exception("Unstable filter!")
    print("Filter in b, a form is unstable!")

# Frequency response of the sos filter. The number of evaluation points must be
# an int (np.linspace rejects a float `num`).
n_freq_points = int(2**np.ceil(np.log2(Fn)))
w, h = scipy.signal.sosfreqz(sos, np.linspace(0, np.pi, n_freq_points))
angles = np.unwrap(np.angle(h))
fax = w * Fn / (np.pi) # rad/sample -> Hz (pi rad/sample is the Nyquist frequency)

fig, axes = plt.subplots(2, 1, sharex=True)
fig.suptitle("Filter response (2*pi = {})".format(Fs))
ax = axes[0]
ax.plot(fax, abs(h), 'b') # linear gain; use 20 * np.log10(abs(h)) for dB
ax.set_ylabel('Gain (linear)', color='b')

ax = axes[1]
ax.plot(fax, angles, 'g')
ax.set_ylabel('Angle (radians)', color='g')

ax.set_xlim((0, 50))
ax.set_xlabel('Frequency [Hz]')
ax.grid(True)

In [None]:
# Filter signal
# Forward-backward (zero-phase) filtering with the sos band-pass filter, along
# time (axis 0), for every trace and for the population mean
data = np.asarray(signal)
# signal_bp = scipy.signal.filtfilt(b, a, data, axis=0) # can also use 'lfilter'
signal_bp = signal.duplicate_with_new_array(scipy.signal.sosfiltfilt(sos, data, axis=0))
sigmean_bp = signal.duplicate_with_new_array(scipy.signal.sosfiltfilt(sos, np.asarray(sigmean), axis=0))

In [None]:
# Subsample signals for efficiency

# Traces: subsample while keeping the new sampling rate >= 2 * fmax,
# where fmax = 2 x the upper passband edge
fmax = 2 * lpfreq
fs_old = signal.sampling_rate.rescale('Hz').magnitude 
subsample_factor = int(fs_old / (2 * fmax))
# Cap at 20, and keep >= 1: a factor of 0 (fs_old < 2*fmax) is an invalid slice step
subsample_factor = max(1, min(20, subsample_factor))

print("Subsampling with factor {}".format(subsample_factor))
signal_bpss = signal_bp[::subsample_factor, :] # adjusts 'sampling period' attribute
print("New sampling period is {}".format(signal_bpss.sampling_period))

# Adjust indices of intervals
Ts = signal_bpss.sampling_period.magnitude
irange_ss = [int((t-tstart)/Ts) for t in interval]
islice_ss = np.s_[irange_ss[0]:irange_ss[1]] # slice object
times_ss = signal_bpss.times[islice_ss]

# Reference signal: must have small enough time resolution to look up spike phases
# 0.25 ms resolution: 1/200 of 20 Hz beta. Keep the factor >= 1 for coarse signals.
ss_factor_ref = max(1, int(0.25 / signal.sampling_period.magnitude))
print("Subsampling with factor {}".format(ss_factor_ref))
sigmean_bpss = sigmean_bp[::ss_factor_ref, :]
print("New sampling period is {}".format(sigmean_bpss.sampling_period))
Ts_ref = sigmean_bpss.sampling_period.magnitude

### Hilbert Transform

In [None]:
# Compute analytic signal - magnitude and phase
from scipy.signal import hilbert

# Per-cell band-passed (subsampled) and raw traces over the analysis interval
traces_bp = signal_bpss[islice_ss, 0:num_traces]
traces_raw = signal[islice, 0:num_traces]

# Hilbert transform along time (axis 0) of every trace:
# envelope = |analytic|, instantaneous phase = angle(analytic)
analytic_signal = hilbert(traces_bp, axis=0)
analytic_mag = np.abs(analytic_signal)
analytic_phase = np.angle(analytic_signal)

# Same for the band-passed population mean (the phase reference)
analsig_sigmean = hilbert(sigmean_bpss, axis=0)
analmag_sigmean = np.abs(analsig_sigmean)
analphase_sigmean = np.angle(analsig_sigmean)

# NOTE: phases are already wrapped (np.angle returns values in (-pi, pi])
# analytic_phase = np.unwrap(np.angle(analytic_signal), axis=0) # transform angle in interval (0, 2*pi)

In [None]:
# Check result : plot filtered signals
def plot_analytic_signal(signal, bandpass, magnitude, phase, interval):
    """
    Plot a raw signal with its band-pass filtered version, and the magnitude
    and phase of the analytic signal, over a time interval.

    @param    signal : neo.AnalogSignal
              Unfiltered signal.
    @param    bandpass : neo.AnalogSignal
              Band-pass filtered (possibly subsampled) signal.
    @param    magnitude, phase : numpy.ndarray
              |analytic signal| and angle(analytic signal), sampled like `bandpass`.
    @param    interval : (float, float) or None
              Time interval to plot (units of signal.times); None plots the
              last 1000 time units of `signal`.
    """
    fig, axes = plt.subplots(3, 1, figsize=(0.75*page_width,2*ax_height), sharex=True, sharey=False)
    fig.suptitle("Vmean: Filtered & Analytic Signal")

    if interval is None:
        t_end = signal.t_stop.magnitude
        interval = (t_end - 1000, t_end)

    def _interval_slice(sig):
        # Sample-index slice of `interval` in `sig`, using sig's own start time
        # and sampling period (raw and band-passed signals may be sampled differently)
        t0 = sig.t_start.magnitude
        Ts_sig = sig.sampling_period.magnitude
        i0, i1 = [max(0, int((t - t0) / Ts_sig)) for t in interval]
        return slice(i0, i1)

    iplot = _interval_slice(signal)
    iplot_bp = _interval_slice(bandpass)

    # Band-pass filtered trace, offset above the raw trace
    ax = axes[0]
    ax.plot(signal.times[iplot], signal[iplot], color='b', label='Vref')
    ax.plot(bandpass.times[iplot_bp], bandpass[iplot_bp] + signal.max(), color='g', label='Vref_BP')
    ax.set_ylabel('Vm (filtered)')
    ax.grid(True)
    ax.legend()

    # Magnitude of analytic signal = amplitude envelope
    ax = axes[1]
    ax.plot(bandpass.times[iplot_bp], magnitude[iplot_bp], label='magnitude')
    ax.set_ylabel('|analytic| [mV]')
    ax.grid(True)

    # Phase of analytic signal
    ax = axes[2]
    ax.plot(bandpass.times[iplot_bp], phase[iplot_bp], label='phase')
    ax.grid(True)
    ax.set_ylabel('angle(analytic) [rad]')
    ax.set_xlabel('time ({})'.format(signal.times.units))

In [None]:
# Analytic signal of the band-passed population-mean CTX Vm
plot_ival = (tstop-1000, tstop) # SETPARAM: plot interval
plot_analytic_signal(sigmean, sigmean_bpss, analmag_sigmean, analphase_sigmean, plot_ival)

In [None]:
# Analytic signal for one representative trace
# tid: column index into the per-cell traces (traces_raw, traces_bp, analytic_*)
tid = 1
plot_analytic_signal(traces_raw[:,tid], traces_bp[:,tid], 
                     analytic_mag[:,tid], analytic_phase[:,tid],
                     plot_ival)

### Phase Vectors

In [None]:
# Load the burst intervals from the simulation config file
# NOTE(review): os.path.join(outputs, ...) assumes `outputs` is the output
# directory (str), not the single-population file list -- confirm.
from bgcellmodels.common import fileutils
sim_config = fileutils.parse_json_file(os.path.join(outputs, 'sim_config.json'), nonstrict=True)
# Presumably a list of (start, end) times of CTX beta bursts; indexed as ival[0]/ival[1] when computing phase vectors
burst_intervals = sim_config['CTX']['burst_intervals']

In [None]:
# Unit-magnitude phase vectors z/|z| of the complex analytic signal
# (samples where |z| == 0 would give nan)
analsig_norm = np.divide(analsig_sigmean, np.abs(analsig_sigmean))

def calc_mean_phase_vectors(spiketrains, analytic_norm=None, ref_signal=None,
                            intervals=None):
    """
    Calculate mean phase vector of spikes with reference to given analytic signal
    (e.g. BP filtered + Hilbert transformed).

    Only spikes inside one of the intervals (t0, t1] are used. Each such spike
    contributes the normalized analytic-signal vector at the sample nearest to
    the spike time. Spikes that map outside the reference signal are ignored.

    @param    spiketrains : iterable of neo.SpikeTrain (or arrays of spike times)
              Spike times are assumed to be in the same time units as
              `ref_signal.t_start` -- TODO confirm.

    @param    analytic_norm : numpy.ndarray, optional
              Unit-magnitude complex analytic signal, sampled along axis 0.
              Defaults to the global `analsig_norm`.

    @param    ref_signal : neo.AnalogSignal, optional
              Signal defining the sampling (t_start, sampling_period) of
              `analytic_norm`. Defaults to the global `sigmean_bpss`.

    @param    intervals : sequence of (t0, t1), optional
              Time intervals in which spikes are counted.
              Defaults to the global `burst_intervals`.

    @return   mean_phase_vecs : numpy.ndarray of complex, shape (num_trains,)
              Mean phase vector per spike train (0j if it has no valid spikes).

    @return   pop_phase_vec : complex
              Mean of the per-train phase vectors.
    """
    if analytic_norm is None:
        analytic_norm = analsig_norm
    if ref_signal is None:
        ref_signal = sigmean_bpss
    if intervals is None:
        intervals = burst_intervals

    Ts = ref_signal.sampling_period.magnitude
    t_start = ref_signal.t_start.magnitude
    num_samples = len(analytic_norm)

    mean_phase_vecs = []
    for st in spiketrains:
        spiketimes = np.asarray(getattr(st, 'magnitude', st), dtype=float)

        # Keep spikes that fall within intervals of cortical beta bursts
        mask = np.zeros(spiketimes.shape, dtype=bool)
        for ival in intervals:
            mask |= (spiketimes > ival[0]) & (spiketimes <= ival[1])

        # Nearest sample of the analytic signal for each spike. Drop indices
        # outside the signal: they would raise IndexError or, if negative,
        # silently wrap around to the end of the signal.
        indices = np.round((spiketimes[mask] - t_start) / Ts).astype(int)
        indices = indices[(indices >= 0) & (indices < num_samples)]

        if indices.size > 0:
            mean_phase_vecs.append(np.mean(analytic_norm[indices])) # mean of normalized phase vectors
        else:
            # Scalar (not a 1-element array) so all entries have the same shape
            mean_phase_vecs.append(0j)

    mean_phase_vecs = np.array(mean_phase_vecs, dtype=complex)
    pop_phase_vec = np.mean(mean_phase_vecs)
    return mean_phase_vecs, pop_phase_vec


def plot_phase_vectors(mean_phase_vecs, pop_phase_vec, pop_label):
    """
    Plot mean phase vectors for individual neurons and whole population
    in Polar coordinates.

    The vectors are also stored in the global `exported_data` under
    ['cell_phase_vecs'][pop_label] and ['pop_phase_vecs'][pop_label].

    @param    mean_phase_vecs : array-like of complex
              Mean phase vector per neuron (green lines, green dots at rmax).

    @param    pop_phase_vec : complex
              Population mean phase vector (red line).

    @param    pop_label : str
              Population name used in the title and as export key.
    """
    # Save phase vectors for export (sub-dicts created if missing)
    exported_data.setdefault('cell_phase_vecs', {})[pop_label] = mean_phase_vecs
    exported_data.setdefault('pop_phase_vecs', {})[pop_label] = pop_phase_vec

    # New figure for every call: plt.subplot(111, ...) would reuse the polar
    # axes of the current figure and overlay successive populations.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='polar')

    cell_vecs = np.asarray(mean_phase_vecs).ravel()
    vec_angs = np.angle(cell_vecs)
    vec_lens = np.abs(cell_vecs)
    # Radial limit: next multiple of 0.1 above the longest vector (0.1 if none)
    max_len = vec_lens.max() if vec_lens.size > 0 else 0.0
    rmax = (max_len // 0.1 + 1) * 0.1

    ax.vlines(vec_angs, 0, vec_lens, color='g', linewidth=1, snap=True)
    ax.vlines(np.angle(pop_phase_vec), 0, np.abs(pop_phase_vec), color='r', linewidth=3)
    # Mark each neuron's mean angle on the outer circle
    ax.plot(vec_angs, np.zeros_like(vec_angs)+rmax, 'go', markersize=5)

    # Format axes
    ax.grid(True)
    ax.set_rticks(np.arange(0.1, 1.1, 0.1)) # less radial ticks
    ax.set_rmax(rmax)
    ax.set_title('Mean angle and vector length of {} neurons'.format(pop_label), va='bottom')

## Phase Vector Length

To test the degree of phase locking to the reference signal (mean CTX Vm), we compare the spike times of each population (STN, GPe, STR) to the analytic signal of the reference.

See the method description in Sharott et al. (2016, 2018).

- Hilbert transform on reference signal
- for each spike of each neuron: save the phase vector at spike time (angle + magnitude of Hilbert transform)
- for mean vector length: 
    + It does not make sense to use the magnitude of the Hilbert transform, since that magnitude belongs to the reference signal and not to the spikes.
    + It makes more sense to normalize each vector to length one (divide each vector by its norm).
    + As described in [this nature article](https://www.nature.com/articles/srep35135#methods-and-materials): *"The circular mean of the spike phases was calculated by taking the weighted sum of the cosine and sine of the angles, finally resulting in, the mean angle and mean vector length (R) over the number of spikes"*

### STN

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
spiketrains = segment.spiketrains[:] # all spike trains (change the slice to select a subset)

# Per-neuron mean phase vectors and their population mean, shown in a polar plot
mean_phase_vecs, pop_phase_vec = calc_mean_phase_vectors(spiketrains=spiketrains)
plot_phase_vectors(mean_phase_vecs=mean_phase_vecs,
                   pop_phase_vec=pop_phase_vec,
                   pop_label=pop_label)

### GPE.proto

In [None]:
pop_label = 'GPE.proto'
spiketrains = pops_segments[pop_label].spiketrains[:] # all spike trains (change the slice to select a subset)

# Per-neuron mean phase vectors and their population mean, shown in a polar plot
mean_phase_vecs, pop_phase_vec = calc_mean_phase_vectors(spiketrains=spiketrains)
plot_phase_vectors(mean_phase_vecs=mean_phase_vecs,
                   pop_phase_vec=pop_phase_vec,
                   pop_label=pop_label)

### GPE.arky

In [None]:
pop_label = 'GPE.arky'
spiketrains = pops_segments[pop_label].spiketrains[:] # all spike trains (change the slice to select a subset)

# Per-neuron mean phase vectors and their population mean, shown in a polar plot
mean_phase_vecs, pop_phase_vec = calc_mean_phase_vectors(spiketrains=spiketrains)
plot_phase_vectors(mean_phase_vecs=mean_phase_vecs,
                   pop_phase_vec=pop_phase_vec,
                   pop_label=pop_label)

### STR.MSN

In [None]:
pop_label = 'STR.MSN'
spiketrains = pops_segments[pop_label].spiketrains[:] # all spike trains (change the slice to select a subset)

# Per-neuron mean phase vectors and their population mean, shown in a polar plot
mean_phase_vecs, pop_phase_vec = calc_mean_phase_vectors(spiketrains=spiketrains)
plot_phase_vectors(mean_phase_vecs=mean_phase_vecs,
                   pop_phase_vec=pop_phase_vec,
                   pop_label=pop_label)

### STR.FSI

In [None]:
pop_label = 'STR.FSI'
spiketrains = pops_segments[pop_label].spiketrains[:] # all spike trains (change the slice to select a subset)

# Per-neuron mean phase vectors and their population mean, shown in a polar plot
mean_phase_vecs, pop_phase_vec = calc_mean_phase_vectors(spiketrains=spiketrains)
plot_phase_vectors(mean_phase_vecs=mean_phase_vecs,
                   pop_phase_vec=pop_phase_vec,
                   pop_label=pop_label)

## Coherence

### STN

In [None]:
# ctx_vm_mean = neo.AnalogSignal(ctx_vm_signal.sum(axis=1).reshape((-1,1)), units=ctx_vm_signal.units, 
#                               sampling_rate=ctx_vm_signal.sampling_rate, t_start=ctx_vm_signal.t_start)

# Copy signal metadata from CTX Vm signal
# ctx_vm_mean = ctx_vm_signal.duplicate_with_new_array(ctx_vm_signal.sum(axis=1).reshape((-1,1)))

# Coherence between averaged cortical Vm and STN LFP
# N x CTX signals -> 1 x LFP signal
# NOTE(review): the second signal is all_sigmean['STN_Vm'] (mean STN Vm), presumably
# used as an LFP proxy -- confirm. Outputs: frequency axis, coherence and phase lag
# per frequency; freq_res is the requested frequency resolution (presumably in Hz).
freqs, coherence, phase_lag = elephant.spectral.welch_cohere(all_sigmean['CTX_Vm'],
                                                             all_sigmean['STN_Vm'],
                                                             freq_res=0.5)

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(0.75*page_width, 2*ax_height))

# The frequency axis may be two-sided (positive part first, then the negative
# part), which messes up line plots: keep only f >= 0. A boolean mask works for
# both one- and two-sided output and under Python 3, where len(freqs)/2 is a
# float and cannot be used as a slice index.
ipos = np.asarray(freqs) >= 0

# Plot coherence
ax = axes[0]
ax.plot(freqs[ipos], coherence[ipos], label='CTX Vm - STN LFP')
ax.set_ylabel('Coherence (unitless)')
ax.set_xlim((0, 100))
ax.grid(True)
ax.set_title('Coherence between Cortex mean Vm - STN LFP')
ax.legend(loc='upper right')

# Plot phase
ax = axes[1]
ax.plot(freqs[ipos], phase_lag[ipos], 'g-', label='phase')
ax.set_ylabel('Phase (rad)')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
ax.set_title('Phase lag > 0 means CTX leads STN')
ax.legend(loc='upper right')

# Synaptic - Ionic Currents

In [None]:
# Beta trigger times: upward zero crossings of the Beta phase, i.e. the
# first sample where the phase becomes >= 0 after being negative.
phase_pos = (analphase_sigmean.ravel() >= 0).astype(float)
phase_zero_idx = np.flatnonzero(np.diff(phase_pos) == 1) + 1
# Times taken from the band-pass signal that was Hilbert transformed
phase_zero_times = sigmean_bpss.times.magnitude[phase_zero_idx]

In [None]:
# FIXME: temporary variables if Hilbert not calculated.
# Only assigned when the trigger variables do not exist yet, so that the trigger
# times computed in the previous cell are not overwritten with None on "Run All".
try:
    phase_zero_times
except NameError:
    phase_pos = None
    phase_zero_idx = None
    phase_zero_times = None

In [None]:
# Function definitions for plotting
import re, collections

def get_synapse_index(trace_name):
    """
    Return the synapse index encoded as trailing digits of a trace name,
    e.g. 'gAMPA12' -> 12.

    @raise    ValueError if the name does not end with digits.
    """
    matches = re.search(r'(?P<index>\d+)$', trace_name)
    if matches is None:
        raise ValueError(
            "Trace name '{}' does not end with a synapse index".format(trace_name))
    return int(matches.group('index'))


def sorted_signals(segment, trace_name):
    """
    Return the analog signals of `segment` recorded for `trace_name`,
    sorted by their synapse index.

    A signal matches if its name starts with `trace_name` and the rest of the
    name does not begin with a letter. This way 'iGABA' does not also select
    'iGABAA...' signals, which are a different trace type.
    """
    def _is_trace(name):
        if not name.startswith(trace_name):
            return False
        return not name[len(trace_name):][:1].isalpha()

    return sorted([sig for sig in segment.analogsignals if _is_trace(sig.name)],
                  key=lambda sig: get_synapse_index(sig.name))


def plot_signal_interval(ax, signal, interval, channels, **kwargs):
    """
    Plot neo.AnalogSignal in given interval.

    @param    ax : matplotlib Axes to plot into.

    @param    signal : neo.AnalogSignal

    @param    interval : tuple(float, float)
              (t0, t1) in the time units of `signal`. A t0 before the start
              of the signal is clipped to the first sample.

    @param    channels : int or index
              Column(s) of `signal` to plot.

    @param    kwargs : passed on to `ax.plot`.
    """
    rec_dt = signal.sampling_period.magnitude
    tstart = signal.t_start.magnitude
    # Sample indices of the interval. Clip at 0 because a negative start index
    # would wrap around to the end of the signal.
    i_start, i_stop = [max(0, int((t - tstart) / rec_dt)) for t in interval]
    ax.plot(signal.times[i_start:i_stop], signal[i_start:i_stop, channels], **kwargs)

    
def plot_synapse_traces(pop_label, max_ncell, max_nsyn, interval, interval_only=True,
                       channel_gates=None, channel_currents=None, beta_phase=True,
                       trace_names=None):
    """
    Plot Synaptic dynamics for specific cells and their recorded synapses.
    
    Typically both the cells in the population and their synapses have
    been sampled. One axis is created per (cell, recorded trace type).

    @param    pop_label : str
              Key into the global `pops_segments`.

    @param    max_ncell : int
              Maximum number of recorded cells to plot.

    @param    max_nsyn : int
              Maximum number of synapses to plot per group (GLU / GABA).

    @param    interval : tuple(float, float) or None
              Time interval to plot; None means the whole recording.

    @param    interval_only : bool
              If False, the full synaptic traces are plotted (the x-limits
              are still set to `interval`).

    @param    channel_gates : list(str), optional
              Names of channel gating signals, plotted on a twin axis of
              conductance ('g...') traces.

    @param    channel_currents : list(str), optional
              Names of channel current signals, plotted on a twin axis of
              current ('i...') traces.

    @param    beta_phase : bool
              Draw vertical lines at the global `phase_zero_times`
              (skipped if those are None).

    @param    trace_names : sequence(str), optional
              Prefixes of the synaptic traces to plot.
    
    @pre    If arguments 'channel_gates' and 'channel_currents' are given,
            the signals they refer to are recorded from the same cells
            as the synaptic currents and conductances. I.e. the indices
            into channel-related signals are the same as those into
            the synapse-related signals.
    """
    segment = pops_segments[pop_label]
    # NOTE: signals are akwardly ordered: one signal is the i-th synapse for all recorded cells
    default_tracenames = 'gAMPA', 'gNMDA', 'iGLU', 'gGABAA', 'gGABAB', 'iGABA'
    if trace_names is None:
        trace_names = default_tracenames
    all_synaptic_sigs = {tn: sorted_signals(segment, tn) for tn in trace_names}
    trace_groups = {k: 'GABA' if ('GABA' in k) else 'GLU' for k in all_synaptic_sigs}
    signal = next((sigs[0] for sigs in all_synaptic_sigs.values() if len(sigs) > 0), None)
    if signal is None:
        print("No synaptic traces for population {}".format(pop_label))
        return

    # Names of the synaptic traces that exist (are recorded)
    existing_traces = [sig_name for sig_name, sigs in all_synaptic_sigs.items() if len(sigs) > 0]
    num_ax_per_cell = len(existing_traces)
    num_cell = min(signal.shape[1], max_ncell)
    num_axes = num_cell * num_ax_per_cell
    if num_axes == 0:
        print("No cells to plot for population {}".format(pop_label))
        return

    # Select synapses to plot: synapse indices of the first `max_nsyn` signals
    # of the first recorded trace type in each group
    selected_synapses = {'GLU': [], 'GABA': []}
    for group in selected_synapses:
        tname = next((n for n in existing_traces if trace_groups[n] == group), None)
        if tname is not None:
            selected_synapses[group] = [get_synapse_index(sig.name)
                                        for sig in all_synaptic_sigs[tname][:max(0, max_nsyn)]]

    # Get signal time data
    rec_dt = signal.sampling_period.magnitude
    tstart = signal.t_start.magnitude
    tstop = signal.t_stop.magnitude
    if interval is None:
        interval = (tstart, tstop)
    # Clip at 0: a negative start index would wrap around to the end of the trace
    irange = [max(0, int((t - tstart) / rec_dt)) for t in interval]
    times = signal.times[irange[0]:irange[1]]

    # Make the figure. squeeze=False keeps `axes` indexable even for a single axis.
    fig, axes = plt.subplots(num_axes, 1, squeeze=False,
                             figsize=(0.75*page_width, num_axes*ax_height),
                             sharex=True)
    axes = axes[:, 0]

    # Beta trigger lines only if requested and the trigger times were computed
    draw_beta = beta_phase and phase_zero_times is not None

    # For each cell we try to plot all synapses of one type on one axis
    for i_cell in range(num_cell):
        # One axis per recorded trace type (at most 6 for the default trace names)
        for i_plotted, tracename in enumerate(existing_traces):
            ax = axes[(i_cell * num_ax_per_cell) + i_plotted]

            # Plot all selected synapses for this axis (same conductance or current)
            group_synapses = selected_synapses[trace_groups[tracename]]
            for sig in all_synaptic_sigs[tracename]:
                if get_synapse_index(sig.name) in group_synapses:
                    if interval_only:
                        ax.plot(times, sig[irange[0]:irange[1], i_cell], label=None)
                    else:
                        ax.plot(signal.times, sig[:, i_cell], label=None)

            # y-limits of the synaptic traces, restored below so that the
            # trigger lines do not rescale the axis
            ymin, ymax = ax.get_ylim()

            # Plot Beta trigger signal (zero phase)
            if draw_beta:
                ax.vlines(phase_zero_times, ymin, ymax, label=r'$\phi$ = 0',
                          colors='black', linestyle='dashed', linewidths=0.5)

            # NOTE: cell index -> see recorder._get_current_segment() -> should save source_ids/channel_ids
            # TODO: plot spike times or Vm of source_indices in same plot

            # Plot channel gating vars if we are plotting conductance
            if channel_gates and tracename.startswith('g'):
                ax_r = ax.twinx()
                gating_sigs = [csig for csig in segment.analogsignals if csig.name in channel_gates]
                for k, csig in enumerate(gating_sigs):
                    color, style = analysis.pick_line_options('red', 'broken', k)
                    plot_signal_interval(ax_r, csig, interval, i_cell, label=csig.name,
                                         color=color, linestyle=style)
                ax_r.legend()
                ax_r.set_ylabel('open')

            # Plot channel currents if we are plotting synaptic currents
            if channel_currents and tracename.startswith('i'):
                ax_r = ax.twinx()
                curr_sigs = [csig for csig in segment.analogsignals if csig.name in channel_currents]
                for k, csig in enumerate(curr_sigs):
                    color, style = analysis.pick_line_options('red', 'broken', k)
                    plot_signal_interval(ax_r, csig, interval, i_cell, label=csig.name,
                                         color=color, linestyle=style)
                ax_r.legend()
                ax_r.set_ylabel('current ($mA/cm^2$)')

            # Annotation and axes
            ax.grid(True, axis='y')
            if tracename.startswith('i'):
                ax.set_ylabel('current (nA)')
            else:
                ax.set_ylabel('conductance (uS)')
            if i_plotted == 0:
                ax.set_title('{} cell {}'.format(pop_label, i_cell))
            ax.set_ylim((ymin, ymax))
            ax.set_xlim((times[0].magnitude, times[-1].magnitude))
            # Only add a legend if something on this axis has a label
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

## STN Synapse Dynamics

In [None]:
# Synaptic currents of the first recorded STN cells
# NOTE(review): channel_gates are only drawn next to conductance ('g...') traces,
# but all trace_names below are currents ('i...'), so the STN_CaT gates are not plotted.
t_interval = ROI_INTERVAL # SETPARAM: time interval
stn_segment = pops_segments['STN']

plot_synapse_traces(
    'STN',
    max_ncell=5,
    max_nsyn=2,
    interval=t_interval,
    interval_only=True,
    trace_names=['i_NR2A', 'i_GLU', 'iGABA'],
    channel_gates=[sig.name for sig in stn_segment.analogsignals
                   if sig.name.startswith('STN_CaT')],
)

## STN Balance Exc/Inh

In [None]:
# We can get a rough estimate of the balance of excitation / inhibition
# by integrating the recorded synaptic currents for both excitatory and
# inhibitory synapses, and then multiplying each with the ratio of actual
# over recorded synapses.

def sum_total_current(currents_traces, cell_idx):
    """
    Sum synaptic currents for all recorded synapses of a given cell.

    @param    currents_traces : dict[str, list(neo.AnalogSignal)]
              Recorded signals per synaptic current type. Each signal is
              indexed as [sample, cell].

    @param    cell_idx : int
              Column (recorded cell) to sum.

    @return   Sum over all samples of all signals for that cell (0.0 if none).
    """
    return sum((sig.magnitude[:, cell_idx].sum()
                for traces in currents_traces.values()
                for sig in traces),
               0.0)


def calc_exc_inh_ratio(pop_label, exc_currents, inh_currents, cell_idx=0):
    """
    Estimate the ratio of integrated excitatory to inhibitory synaptic current
    for one recorded cell of a population.

    For each afferent population, the summed current of its recorded synapses
    is scaled by (number of afferent synapses / number of recorded synapses).

    @param    pop_label : str
              Key into the global `pops_segments`; target population in
              the global `network_params`.

    @param    exc_currents : dict[str, sequence(str)]
              Map excitatory afferent population labels to recorded synaptic
              current names.

    @param    inh_currents : dict[str, sequence(str)]
              Same as `exc_currents`, for inhibitory afferents.

    @param    cell_idx : int
              which cell to use out of all recorded cells

    @return   ratio : float
              -1 * (scaled excitatory current) / (scaled inhibitory current)
    """
    segment = pops_segments[pop_label]
    
    def afferents_total_current(afferents_currents):
        """
        Scaled total synaptic current from the given afferents.

        @param    afferents_currents : dict[str, list(str)]
                  Map afferent population labels to recorded synaptic current names
        """
        itot = 0.0 # total current for all given afferents
        for afferent_pop, current_names in afferents_currents.items():
            # Afferent currents, by current type
            aff_traces = {curr: sorted_signals(segment, curr) for curr in current_names}

            # Number of recorded synapses for the afferent population, taken from
            # the first listed current type (dict.values()[0] is not indexable in
            # Python 3 and has arbitrary order in Python 2)
            num_rec_aff = len(aff_traces[current_names[0]]) if len(current_names) > 0 else 0
            if num_rec_aff == 0:
                print("WARNING: no recorded synapses for afferent {} of {}, skipped.".format(
                      afferent_pop, pop_label))
                continue
            itot_aff = sum_total_current(aff_traces, cell_idx) # sum total of all recorded currents for afferent

            # Sum number of afferents (matrix column), should be same for each cell
            num_syn_aff = sum(network_params[afferent_pop][pop_label]['conn_matrix'][:, 0] > 0)

            # Scale total recorded current by ratio of afferent to recorded synapses
            itot += itot_aff * num_syn_aff / num_rec_aff
        return itot
    
    itot_exc = afferents_total_current(exc_currents) # excitatory currents are negative by convention
    itot_inh = afferents_total_current(inh_currents)
    
    ratio = -1.0 * itot_exc / itot_inh
    print("{}: Ratio of integrated current EXC / INH = {:.2f}".format(pop_label, ratio))
    return ratio

In [None]:
# STN afferents: CTX (currents 'i_NR2A', 'i_GLU') and GPE.all (current 'iGABA')
exc_currents = {'CTX': ('i_NR2A', 'i_GLU')}
inh_currents = {'GPE.all': ('iGABA',)}

stn_exh_inh_ratio = calc_exc_inh_ratio(pop_label='STN',
                                       exc_currents=exc_currents,
                                       inh_currents=inh_currents)

## GPe Synapse Dynamics

In [None]:
# Synaptic currents of the first recorded GPE.proto cells
t_interval = ROI_INTERVAL # SETPARAM: time interval

plot_synapse_traces(
    'GPE.proto',
    max_ncell=5,
    max_nsyn=2,
    trace_names=['iGLU', 'iGABA', 'iGABAA'], # iGABAA is from STR.MSN
    interval_only=True,
    interval=t_interval,
)

# Save Notebook

In [None]:
%%javascript
// Save a notebook checkpoint so the .ipynb file on disk is up to date
// before it is exported with nbconvert.
require(["base/js/namespace"],function(Jupyter) {
    Jupyter.notebook.save_checkpoint();
});
// Jupyter.notebook.kernel.execute("notebook_name = " + "\'"+Jupyter.notebook.notebook_name+"\'");

In [None]:
# import os.path
# Export this notebook (with table of contents) to the simulation output directory.
# NOTE: `thisfile` must match the notebook's actual file name.
thisfile = 'lunet_stn-gpe-str_analysis.ipynb'
# outfile = os.path.join(outputs, 'synchrony_analysis.html')
# Export using built-in template:
#!jupyter nbconvert --to html --template=full --output-dir=$outputs $thisfile
# NOTE: template comes from ToC2 Notebook Extension
!jupyter nbconvert $thisfile --template=toc2 --output-dir=$outputs

In [None]:
# Export analysis results
# Pickles `exported_data` to <outputs>/analysis_results.pkl.
# NOTE: only unpickle this file from a trusted source (pickle can execute code on load).
import pickle, os.path
outfile = os.path.join(outputs, 'analysis_results.pkl')
with open(outfile, 'wb') as fout:
    pickle.dump(exported_data, fout)
print("Saved analysis results to file: " + outfile)