# Setup

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

import os
import neo.io
import elephant

# analysis.py module in same folder
from common import units
units.set_units_module('quantities')
from analysis import read_population_segments

## Function Definitions

## Load Data

In [None]:
# Simulation output to load: either a directory (every '.mat' file in it is
# read) or a list of individual '.mat' file paths.

# All populations:
outputs = "/home/luye/storage/2018.06.18_job-776936_DA-depleted_CTX-beta/"

# Single population:
# outputs = ["/home/luye/storage/2018.06.18_job-776933_DA-control_CTX-SWA/STN_2018.06.18_pop-100_dur-10000.0_job-776933.sonic-head.mat"]

if isinstance(outputs, str) and os.path.isdir(outputs):
    # Sort so populations are loaded (and later plotted) in a reproducible
    # order; os.listdir() returns entries in arbitrary order.
    pop_files = sorted(os.path.join(outputs, f) for f in os.listdir(outputs)
                       if f.endswith('.mat'))
else:
    # A single file path or an explicit list of file paths.
    pop_files = [outputs] if isinstance(outputs, str) else list(outputs)

# Maps population label (the Neo Block name) -> the selected Neo Segment.
pops_segments = {}
read_segment_id = 0

# Read binary files using Neo IO module
for pop_file in pop_files:
    reader = neo.io.get_io(pop_file)
    blocks = reader.read()
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, and the old message was wrong for files with no block.
    if len(blocks) != 1:
        raise ValueError("Expected exactly one Neo Block in file {}, found {}".format(
                         pop_file, len(blocks)))
    pop_label = blocks[0].name

    if len(blocks[0].segments)-1 < read_segment_id:
        raise ValueError("Segment index greater than number of Neo segments"
                         " in file {}".format(pop_file))
    if pop_label in pops_segments:
        raise ValueError("Duplicate population label '{}' in file {}".format(
                         pop_label, pop_file))

    pops_segments[pop_label] = blocks[0].segments[read_segment_id]

In [None]:
# The recordings are saved in Neo format. See:
# http://neo.readthedocs.io/en/latest/
# http://neo.readthedocs.io/en/latest/api_reference.html#neo.core.AnalogSignal
# - Each segment has attributes 'analogsignals' and 'spiketrains'
# - Each quantity (e.g. AnalogSignal.signal) has attributes magnitude, units, dimensionality

# List all recorded signals
for pop_label, segment in pops_segments.items():
    print("\n{} has following signals:".format(pop_label))
    for signal in segment.analogsignals:
        # str() because '{:10}' raises TypeError when a signal has no name (None).
        print("\t- '{:10}'\t[{}] - description: {}".format(
              str(signal.name), signal.units, signal.description))
    print("\t- {} spiketrains".format(len(segment.spiketrains)))

# Spike Trains

In [None]:
num_pops = len(pops_segments)

# Plot spikes: one raster subplot per population. squeeze=False keeps the
# axes indexable as an array even when only a single population is loaded.
fig_spikes, axes_grid = plt.subplots(num_pops, 1, figsize=(10,14), squeeze=False)
axes_spikes = axes_grid[:, 0]
fig_spikes.suptitle('Spikes for each population')

pop_spike_colors = 'rgcbm'
for i_pop, (pop_label, segment) in enumerate(pops_segments.items()):
    ax = axes_spikes[i_pop]
    color = pop_spike_colors[i_pop % len(pop_spike_colors)]
    for i_train, spiketrain in enumerate(segment.spiketrains):
        # Raster row: the recorded cell id if annotated, else the train index.
        y = spiketrain.annotations.get('source_id', i_train)
        y_vec = np.ones_like(spiketrain) * y
        ax.plot(spiketrain, y_vec, marker='|', linestyle='', snap=True, color=color)
    ax.set_ylabel(pop_label)

plt.show(block=False)

# Raw Signals

## STN Vm

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
signal = next((sig for sig in segment.analogsignals if sig.name == 'Vm'), None)
if signal is None:
    raise ValueError("No 'Vm' signal recorded for population {}".format(pop_label))
# Separate name so the Vm signal stays reachable after `signal` is rebound.
stn_vm_signal = signal

max_num_plot = 10
num_signals = min(signal.shape[1], max_num_plot)

# Optional [start, stop] window in the units of signal.sampling_period;
# None plots the whole recording.
interval = None # [2000.0, 6000.0]
rec_dt = signal.sampling_period.magnitude
# The slice end is exclusive, so shape[0] (not shape[0]-1) keeps the last sample.
# NOTE(review): the window-to-index conversion assumes the signal starts at t = 0.
irange = [0, signal.shape[0]] if interval is None else [int(t/rec_dt) for t in interval]
times = signal.times[irange[0]:irange[1]]

# squeeze=False keeps `axes` indexable even if only one trace is plotted.
fig, axes = plt.subplots(num_signals, 1, figsize=(10,8), squeeze=False)
axes = axes[:, 0]
fig.suptitle("{} membrane voltage".format(pop_label))

for i_cell in range(num_signals):
    ax = axes[i_cell]
    if 'source_ids' in signal.annotations:
        label = "id {}".format(signal.annotations['source_ids'][i_cell])
    else:
        label = "cell {}".format(i_cell)

    sig = signal[irange[0]:irange[1], i_cell]

    ax.plot(times, sig, label=label)
    ax.grid(True)

    # Axis labels only on the bottom subplot to avoid clutter.
    if i_cell == num_signals-1:
        ax.set_ylabel("voltage ({})".format(signal.units))
        ax.set_xlabel('time ({})'.format(times.units))

## STN LFP

In [None]:
# Load the per-channel LFP contributions and sum them across axis 1 into a
# single population LFP trace.
pop_label = 'STN'
segment = pops_segments[pop_label]
lfp_sigs = next((sig for sig in segment.analogsignals if sig.name == 'lfp'), None)
if lfp_sigs is None:
    raise ValueError("No 'lfp' signal recorded for population {}".format(pop_label))
lfp_summed = lfp_sigs.sum(axis=1)

# Optional (start, stop) window in the units of lfp_sigs.sampling_period.
interval = None # (2000.0, 4000.0)
rec_dt = lfp_sigs.sampling_period.magnitude
irange = [0, lfp_sigs.shape[0]] if interval is None else [int(t/rec_dt) for t in interval]

lfp_times = lfp_sigs.times[irange[0]:irange[1]]
lfp_ranged = lfp_summed[irange[0]:irange[1]]

# Turn it into AnalogSignal object. Use the LFP's own sampling rate: the
# window indices above were derived from it, and the LFP need not be
# recorded at the same rate as Vm.
lfp_signal = neo.AnalogSignal(lfp_ranged, units=lfp_ranged.units,
                              sampling_rate=lfp_sigs.sampling_rate, t_start=lfp_times[0])

In [None]:
# Plot the summed population LFP built in the previous cell.
fig, ax = plt.subplots(figsize=(10,4))
ax.plot(lfp_signal.times, lfp_signal, label='{} LFP'.format(pop_label))

ax.set_ylabel('LFP magnitude ({})'.format(lfp_signal.units))
# Take time units from the LFP itself, not from `times` of the Vm cell,
# which belongs to a different signal and may not have been run.
ax.set_xlabel('time ({})'.format(lfp_signal.times.units))
ax.set_title('LFP for {} population'.format(pop_label))
ax.legend()

# Power Spectrum

In [None]:
# Registry of computed spectra, keyed by a descriptive name; each value is a
# (frequencies, power) pair reused by the comparison plots below.
all_psd = dict()

## STN LFP PSD

In [None]:
# Welch PSD of the summed STN population LFP (a single-channel signal)
# at 0.5 Hz frequency resolution.
freqs, psd = elephant.spectral.welch_psd(lfp_signal, freq_res=0.5)
psd = psd.ravel() # single channel -> flatten to a 1-D spectrum
all_psd['STN_LFP'] = (freqs, psd)

In [None]:
# Show the STN LFP spectrum from the previous cell, limited to 0-50 Hz.
fig, ax = plt.subplots(figsize=(10,4))
ax.plot(freqs, psd, label='{} PSD'.format(pop_label))
ax.set_title('Welch PSD for LFP of STN cells')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_ylabel('Power ({})'.format(psd.units))
ax.set_xlim((0, 50))
ax.grid(True)
ax.legend()

## STN Vm PSD

In [None]:
pop_label = 'STN'
segment = pops_segments[pop_label]
vm_sig = next((sig for sig in segment.analogsignals if sig.name == 'Vm'), None)
if vm_sig is None:
    raise ValueError("No 'Vm' signal recorded for population {}".format(pop_label))

# Welch PSD of every recorded Vm trace at once, then average into a
# population-mean spectrum.
# NOTE(review): assumes rows (axis 0) of `psd` are traces, as the averaging implies.
freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
psd_avg = psd.mean(axis=0)

all_psd[pop_label + '_Vm'] = (freqs, psd_avg)

In [None]:
# Compare the population-mean STN Vm spectrum with the rescaled STN LFP spectrum.
# NOTE(review): the two spectra may have different units; the LFP curve is
# rescaled only for visual comparison and shares the Vm y-axis label.
lfp_scale = 1e3
lfp_freqs, lfp_psd = all_psd['STN_LFP']

fig, ax = plt.subplots(figsize=(10,4))
ax.plot(freqs, psd_avg, label='{} Vm'.format(pop_label))
# Legend derived from the actual factor (previously said "x 10" while scaling by 1e3).
ax.plot(lfp_freqs, lfp_psd * lfp_scale, 'r', label='STN LFP x {:g}'.format(lfp_scale))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.grid(True)
ax.set_title('Welch PSD for average STN membrane voltages')
ax.legend()

## CTX spikes PSD

__TODO__

See the paper by Piotr Kiewics (1987), "A method of description of single muscle fiber action potential by an analytical function". It gives an analytical function that can be used to approximate the action potential (AP) waveform.

In [None]:
# def AP_function(t, V1, V2, V3, RT, DT):
#     for i = range(3):
#         # calculate b and sigma for phase
#         b = 1 # TODO
#         sigma = 1 # TODO
#         segment = (V2 - b*t**2) * np.exp(-t**2 / sigma**2)


def exp2fun(t, td, taur, taud):
    """
    Bi-exponential function, normalised so that its maximum value is 1.

    Parameters
    ----------
    t : float or array_like
        Time(s) at which to evaluate the function.
    td : float
        Onset time; the function is 0 for ``t <= td``.
    taur, taud : float
        Rise and decay time constants. They must differ, otherwise the
        peak-time formula divides by zero.

    Returns
    -------
    float or ndarray
        ``f * (exp(-(t-td)/taud) - exp(-(t-td)/taur))`` for ``t >= td`` and 0
        before, where ``f`` scales the value at the peak time ``tp`` to 1.
    """
    # Time at which the unnormalised difference of exponentials peaks.
    tp = td + (taud*taur)/(taud-taur)*np.log(taud/taur)
    f = 1/(np.exp(-(tp-td)/taud) - np.exp(-(tp-td)/taur))
    # Clamp elapsed time at 0 instead of multiplying by a (t >= td) mask:
    # for t far before td the raw exponents overflow to inf, and inf - inf = nan
    # survives multiplication by 0. At zero elapsed time the bracket is exactly 0,
    # so every t <= td still evaluates to 0.
    dt = np.maximum(np.asarray(t, dtype=float) - td, 0.0)
    return f * (np.exp(-dt/taud) - np.exp(-dt/taur))
    

def cortex_AP_kernel(t, peak_time=1.0, vrest=-60.0, vpeak=10.0,
                     taur=0.5, taud=1.0, t_min=-2.0, t_max=10.0):
    """
    Stereotyped action-potential waveform for use as a convolution kernel.

    The waveform is ``exp2fun`` (peak-normalised bi-exponential) scaled by the
    AP height above rest, ``vpeak - vrest``. It is evaluated at ``t + peak_time``,
    so it starts rising at ``t = -peak_time``, and it is forced to zero outside
    ``[t_min, t_max]``. With the default time constants the maximum lies at
    ``t = ln(2) - peak_time`` (about -0.31 in the units of ``t``).
    All defaults reproduce the previously hard-coded values.

    Parameters
    ----------
    t : float or array_like
        Time points relative to the spike.
    peak_time : float
        Time shift applied to the bi-exponential.
    vrest, vpeak : float
        Resting and peak potentials; only their difference (the height) is used.
    taur, taud : float
        Rise and decay time constants passed to ``exp2fun``.
    t_min, t_max : float
        Outside this window the kernel is set to 0.

    Returns
    -------
    ndarray
        Kernel values (relative to rest) with the same shape as ``t``.
    """
    t = np.asarray(t, dtype=float)  # also accepts lists/scalars
    height = vpeak - vrest
    kernel = height * exp2fun(t + peak_time, td=0.0, taur=taur, taud=taud)
    return np.where((t < t_min) | (t > t_max), 0.0, kernel)

In [None]:
# Demo: synthesise a cortical Vm trace by convolving spike impulses with the AP kernel.
dt = 0.05  # ms; shared by the kernel time grid and the simulated trace
t = np.arange(-50, 50, dt) # ms, kernel time axis (contains t == 0)
peak_time = 1.0

# Preview the bi-exponential on its own figure; a bare plt.plot() would draw
# into whichever figure happens to be current (e.g. the last PSD plot).
plt.figure()
plt.plot(t, 10 * exp2fun(t+peak_time, td=0.0, taur=0.5, taud=1.0))
spike_kernel = cortex_AP_kernel(t)

dur = 100.0
time = np.arange(0, dur + dt, dt)
spiketimes = np.arange(0, 100, 20)
spike_pulses = np.zeros_like(time)
# Round rather than truncate: spike/dt can fall just below an integer in floating point.
spike_pulses[np.rint(spiketimes / dt).astype(int)] = 1.0

# mode='same' can shift the result by one sample for an even-length kernel and
# returns max(len(a), len(v)) samples, which differs from len(time) when the
# kernel is longer than the trace. Instead take the full convolution and slice
# it so the kernel sample at t == 0 lines up with each spike.
vrest = -60.0
i_zero = int(np.rint(-t[0] / dt))
spike_signal = np.convolve(spike_pulses, spike_kernel, mode='full')[i_zero:i_zero + len(time)] + vrest

plt.figure()
plt.plot(time, spike_pulses)
plt.plot(time, spike_signal)

## GPe Vm PSD

In [None]:
pop_label = 'GPE'
segment = pops_segments[pop_label]
vm_sig = next((sig for sig in segment.analogsignals if sig.name == 'Vm'), None)
if vm_sig is None:
    raise ValueError("No 'Vm' signal recorded for population {}".format(pop_label))

# Welch PSD of every recorded Vm trace at once, then average into a
# population-mean spectrum.
# NOTE(review): assumes rows (axis 0) of `psd` are traces, as the averaging implies.
freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
psd_avg = psd.mean(axis=0)

# NOTE(review): stored under the bare label ('GPE') while the STN cell uses
# '<label>_Vm'; key kept unchanged so existing lookups still work.
all_psd[pop_label] = (freqs, psd_avg)

In [None]:
# Overlay the population-mean GPe Vm spectrum on the STN Vm reference (0-50 Hz).
stn_freqs, stn_psd_avg = all_psd['STN_Vm']

fig, ax = plt.subplots(figsize=(10,4))
ax.plot(stn_freqs, stn_psd_avg, 'r', label='STN Vm PSD')
ax.plot(freqs, psd_avg, label='{} PSD'.format(pop_label))
ax.set_title('Welch PSD for average GPE membrane voltages')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlim((0, 50))
ax.grid(True)
ax.legend()

# Phase Relationships

For suitable measures and implementations, see
- bookmarks/neuroscience/signal_processing
- google `measure + site:github.com`
- ask Amir for his mutual information and related measures
- see Beta-related and other neurophysiology articles

For example, we can use the following measures:

- __Coherence__ : linear relationship between two signals at each frequency component
    + see `welch_cohere` in [elephant.spectral](http://elephant.readthedocs.io/en/latest/reference/spectral.html)


- __Phase-Amplitude Coupling__ (PAC)
    + see `hilbert` in [elephant.signal_processing](http://elephant.readthedocs.io/en/latest/reference/signal_processing.html) to apply the Hilbert transform after band-pass filtering the signal
    + see `comodulogram` in [pactools](https://pactools.github.io/auto_examples/plot_comodulogram.html)

## CTX - STN Synchrony

## STN - GPE Synchrony

## CTX - STN Synchrony