# Setup

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

import os
import neo.io
import elephant

# analysis.py module in same folder
from analysis import read_population_segments

## Function Definitions

## Load Data

In [None]:
# Simulation output files to load (presumably one .mat file per population,
# given the "STN_" prefix — confirm against analysis.read_population_segments).
# Earlier run, kept for reference:
# sim_outputs = "/home/luye/storage/2018.06.13_pop-100_dur-10s_job-776691_DA-control_CTX-SWA"
sim_outputs = [
    "/home/luye/storage/2018.06.16_pop-20_dur-1000.0_job-testmpi6/"
    "STN_2018.06.16_pop-20_dur-1000.0_job-testmpi6.mat",
]
pops_segments = read_population_segments(
    sim_outputs,
    False,  # NOTE(review): positional flag; its meaning is defined in analysis.read_population_segments
    read_ext=".mat",
)

# Spike Trains

In [None]:
num_pops = len(pops_segments)

# Raster plot of spikes, one row of axes per population.
# squeeze=False keeps `axes_spikes` a 2-D (num_pops x 1) array even when only
# one population is loaded; with the default squeeze=True plt.subplots returns
# a bare Axes object there, and indexing it raises TypeError.
fig_spikes, axes_spikes = plt.subplots(num_pops, 1, figsize=(10, 14), squeeze=False)
fig_spikes.suptitle('Spikes for each population')

pop_spike_colors = 'rgcbm'
for i_pop, (pop_label, segment) in enumerate(pops_segments.items()):
    ax = axes_spikes[i_pop, 0]
    # Cycle through the palette so any number of populations gets a color.
    color = pop_spike_colors[i_pop % len(pop_spike_colors)]
    for i_train, spiketrain in enumerate(segment.spiketrains):
        # Raster row: the train's 'source_id' annotation if present,
        # otherwise its position within the segment.
        y = spiketrain.annotations.get('source_id', i_train)
        y_vec = np.ones_like(spiketrain) * y
        ax.plot(spiketrain, y_vec, marker='|', linestyle='', snap=True, color=color)
    # Label once per population (this also labels populations without spikes).
    ax.set_ylabel(pop_label)

plt.show(block=False)

# Raw Signals

## STN Vm

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
signal = next((sig for sig in segment.analogsignals if sig.name == 'Vm'), None)
if signal is None:
    raise ValueError("Population {} has no analog signal named 'Vm'".format(pop_label))

# Plot at most this many cells, one subplot each.
max_num_plot = 10
num_signals = min(signal.shape[1], max_num_plot)

# Optional time window to plot, e.g. [2000.0, 6000.0]; None plots everything.
# NOTE(review): times are converted to sample indices as t / sampling_period,
# which assumes the signal starts at t = 0 and that `interval` uses the same
# units as the sampling period — confirm.
interval = None
rec_dt = signal.sampling_period.magnitude
# Half-open [start, stop) sample range. The full range must stop at shape[0]
# so the last sample is included (shape[0] - 1 silently dropped it).
irange = [0, signal.shape[0]] if interval is None else [int(t / rec_dt) for t in interval]
times = signal.times[irange[0]:irange[1]]

# squeeze=False keeps `axes` 2-D so axes[i, 0] also works when only one cell
# is plotted (plt.subplots would otherwise return a single Axes).
fig, axes = plt.subplots(num_signals, 1, figsize=(10, 8), squeeze=False)
fig.suptitle("{} membrane voltage".format(pop_label))

for i_cell in range(num_signals):
    ax = axes[i_cell, 0]
    if 'source_ids' in signal.annotations:
        label = "id {}".format(signal.annotations['source_ids'][i_cell])
    else:
        label = "cell {}".format(i_cell)

    sig = signal[irange[0]:irange[1], i_cell]

    ax.plot(times, sig, label=label)
    ax.grid(True)

    # Axis labels only on the bottom subplot. `.dimensionality` formats as the
    # bare unit (e.g. "mV") rather than the unit quantity ("1.0 mV").
    if i_cell == num_signals - 1:
        ax.set_ylabel("voltage ({})".format(signal.dimensionality))
        ax.set_xlabel('time ({})'.format(times.dimensionality))

## STN LFP

In [None]:
# Load each individual cell's LFP contribution and sum them into one trace.
pop_label = 'STN'
segment = pops_segments[pop_label]
lfp_sigs = next((sig for sig in segment.analogsignals if sig.name == 'lfp'), None)
if lfp_sigs is None:
    raise ValueError("Population {} has no analog signal named 'lfp'".format(pop_label))
# Sum across columns (one column per cell contribution) into the population LFP.
lfp_summed = lfp_sigs.sum(axis=1)

# Optional time window to plot, e.g. (2000.0, 4000.0); None keeps everything.
# NOTE(review): t / sampling_period assumes the signal starts at t = 0 and that
# `interval` uses the same units as the sampling period — confirm.
interval = None
rec_dt = lfp_sigs.sampling_period.magnitude
# Half-open [start, stop) sample range covering the whole recording by default.
irange = [0, lfp_sigs.shape[0]] if interval is None else [int(t / rec_dt) for t in interval]

times = lfp_sigs.times[irange[0]:irange[1]]
lfp_ranged = lfp_summed[irange[0]:irange[1]]

In [None]:
# Plot the summed population LFP over the selected time window.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(times, lfp_ranged, label='{} LFP'.format(pop_label))
# To inspect a single cell's contribution instead, e.g.:
# ax.plot(lfp_sigs.times, lfp_sigs[:, 5], label='{} LFP'.format(pop_label))
# ax.plot(lfp_sigs.times, lfp_sigs[:, 8], 'r.', ms=1, mew=1, label='{} LFP'.format(pop_label))

# `.dimensionality` formats as the bare unit rather than "1.0 <unit>".
ax.set_ylabel('LFP magnitude ({})'.format(lfp_sigs.dimensionality))
ax.set_xlabel('time ({})'.format(times.dimensionality))
ax.set_title('LFP for {} population'.format(pop_label))
ax.legend()

# Power Spectrum

In [None]:
# The recordings are saved in Neo format. See:
# http://neo.readthedocs.io/en/latest/
# http://neo.readthedocs.io/en/latest/api_reference.html#neo.core.AnalogSignal
# An AnalogSignal is itself a `quantities.Quantity`, with attributes
# magnitude, units and dimensionality.

# List the analog signals recorded for each population.
for pop_label, segment in pops_segments.items():
    # Explicit formatting: print(pop_label, " has ...") produced a double space.
    print("{} has following signals:".format(pop_label))
    for signal in segment.analogsignals:
        # dimensionality prints the bare unit (e.g. "mV") instead of "1.0 mV".
        print(signal.name, signal.dimensionality, signal.description)

## CTX spikes PSD

## STN Vm PSD

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
vm_sig = next((sig for sig in segment.analogsignals if sig.name == 'Vm'), None)
if vm_sig is None:
    raise ValueError("Population {} has no analog signal named 'Vm'".format(pop_label))

# Welch PSD of every recorded Vm trace at once, then the mean across cells
# (axis 0 of the PSD array indexes the signal's channels/cells).
freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
psd_avg = psd.mean(axis=0)

# Per-population (freqs, mean PSD) pairs, collected for the comparison plot.
all_psd = {}
all_psd[pop_label] = (freqs, psd_avg)

In [None]:
# Plot the cell-averaged Vm PSD computed in the previous cell.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(freqs, psd_avg, label='{} PSD'.format(pop_label))
# `.dimensionality` formats as the bare unit rather than "1.0 <unit>".
ax.set_ylabel('Power ({})'.format(psd_avg.dimensionality))
ax.set_xlabel('frequency ({})'.format(freqs.dimensionality))
ax.set_xlim((0, 50))
ax.grid(True)
# ax.set_yscale('log')
# Title follows pop_label so it always agrees with the legend.
ax.set_title('Welch PSD for average {} membrane voltages'.format(pop_label))
ax.legend()

## GPe Vm PSD

In [None]:
pop_label = 'GPE'
# NOTE(review): the Load Data cell currently lists only an STN output file, so
# this lookup presumably raises KeyError until a GPE file is added — confirm.
segment = pops_segments[pop_label]
vm_sig = next((sig for sig in segment.analogsignals if sig.name == 'Vm'), None)
if vm_sig is None:
    raise ValueError("Population {} has no analog signal named 'Vm'".format(pop_label))

# Welch PSD of every recorded Vm trace at once, then the mean across cells
# (axis 0 of the PSD array indexes the signal's channels/cells).
freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
psd_avg = psd.mean(axis=0)

# Create the PSD collection if the STN PSD cell has not been run this session.
if 'all_psd' not in globals():
    all_psd = {}
all_psd[pop_label] = (freqs, psd_avg)

In [None]:
# Overlay every population PSD collected in `all_psd` (STN, GPE, ...), so the
# plot does not depend on which PSD cell happened to run last.
fig, ax = plt.subplots(figsize=(10, 4))
for psd_label, (psd_freqs, psd_values) in all_psd.items():
    ax.plot(psd_freqs, psd_values, label='{} PSD'.format(psd_label))
# `.dimensionality` formats as the bare unit rather than "1.0 <unit>".
ax.set_ylabel('Power ({})'.format(psd_values.dimensionality))
ax.set_xlabel('frequency ({})'.format(psd_freqs.dimensionality))
ax.set_xlim((0, 50))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Welch PSD for average membrane voltages ({})'.format(', '.join(all_psd)))
ax.legend()

# Phase Relationships

For suitable measures and implementations, see
- bookmarks/neuroscience/signal_processing
- google `measure + site:github.com`
- ask Amir for his implementations of mutual information and related measures
- see Beta-related and other neurophysiology articles

For example, we can use the following measures:

- __Coherence__ : linear relationship between two signals, resolved by frequency component
    + see `welch_cohere` in [elephant.spectral](http://elephant.readthedocs.io/en/latest/reference/spectral.html)


- __Phase-Amplitude Coupling__ (PAC)
    + see `butter` (band-pass filter) and `hilbert` (Hilbert transform) in [elephant.signal_processing](http://elephant.readthedocs.io/en/latest/reference/signal_processing.html) to extract the phase and amplitude envelope of each frequency band
    + see `comodulogram` in [pactools](https://pactools.github.io/auto_examples/plot_comodulogram.html)

## CTX - STN Synchrony

## STN - GPE Synchrony

## CTX - STN Synchrony