# Setup

In [None]:
%matplotlib inline
import numpy as np
import scipy.signal
import matplotlib.pyplot as plt

import os
import neo.io
import elephant

# `units` helper from the `common` package (must be importable from this notebook's path)
from common import units
units.set_units_module('quantities')
import quantities as pq

## Plotting Options

In [None]:
# Figure sizes below are expressed relative to these two values.
# page_width: approx. 16 suits the `%matplotlib inline` backend, 8 suits `%matplotlib notebook`.
# ax_height: base height passed to `figsize` (scaled per figure by the plotting cells).
page_width, ax_height = 14, 3

# Matplotlib style sheet (default colors etc.); available styles are listed at
# https://matplotlib.org/gallery/style_sheets/style_sheets_reference.html
plt.style.use('default')

## Load Data

In [None]:
# `outputs` is either a directory containing one '.mat' file per population
# or an explicit list of '.mat' file paths.

# All populations:
outputs = "/home/luye/storage/2018.06.18_job-776933_DA-control_CTX-SWA"

# Single population:
# outputs = [
# "/home/luye/storage/2018.06.21_job-testmpi6_DA-control_CTX-beta/CTX_2018.06.21_pop-100_dur-50.0_job-testmpi6.mat",
# ]

if isinstance(outputs, str):
    # os.listdir() returns entries in arbitrary order; sort them so the files
    # are always processed in the same order.
    filenames = sorted(os.listdir(outputs))
    pop_files = [os.path.join(outputs, f) for f in filenames if f.endswith('.mat')]
else:
    pop_files = outputs

# Maps population label (the Neo Block name) -> the selected Neo Segment
pops_segments = {}
# Index of the segment to read from each file's block
read_segment_id = 0

# Read binary files using Neo IO module
for pop_file in pop_files:
    reader = neo.io.get_io(pop_file)
    blocks = reader.read()
    # Explicit check instead of `assert`, which is skipped under `python -O`
    if len(blocks) != 1:
        raise ValueError("Expected exactly one Neo Block in file {}, found {}".format(
                         pop_file, len(blocks)))
    block = blocks[0]
    pop_label = block.name

    if read_segment_id >= len(block.segments):
        raise ValueError("Segment index greater than number of Neo segments"
                         " in file {}".format(pop_file))
    if pop_label in pops_segments:
        raise ValueError("Duplicate population label '{}' in file {}".format(
                         pop_label, pop_file))

    pops_segments[pop_label] = block.segments[read_segment_id]

# Save all PSDs for comparison in figures
all_psd = {}
all_signals = {}

In [None]:
# The recordings are saved in Neo format. See:
# http://neo.readthedocs.io/en/latest/
# http://neo.readthedocs.io/en/latest/api_reference.html#neo.core.AnalogSignal
# - Each segment has attributes 'analogsignals' and 'spiketrains'
# - Each quantity (e.g. AnalogSignal.signal) has attributes magnitude, units, dimensionality

# Print, per population, every recorded analog signal and the spike train count
for pop_label, segment in pops_segments.items():
    print("\n{} has following signals:".format(pop_label))
    for signal in segment.analogsignals:
        # Quote the name first, then pad: keeps the quotes balanced and also
        # works when `signal.name` is None (a width spec on None would raise).
        quoted_name = "'{}'".format(signal.name)
        print("\t- {:12}\t[{}] - description: {}".format(quoted_name, signal.units, signal.description))
    print("\t- {} spiketrains".format(len(segment.spiketrains)))

# Spike Trains

<span style='color:red;font-weight:bold'>WARNING</span>: In rastergram plots, note the number of spike trains plotted (see y-axis). If it is too high, the marker bars overlap (the marker height is larger than the row height allocated to one spike train). This makes the plot misleading: overlapping spike trains look like an artificially elevated firing rate.

In [None]:
# Rastergrams: one axes per population, `pops_per_fig` populations per figure.
num_pops = len(pops_segments)
pops_per_fig = 2
# Upper bound on plotted spike trains per population (plotting all cells
# causes overlapping markers and a misleading plot).
max_trains_per_pop = 20

# Re-created every `pops_per_fig` populations inside the loop
fig_spikes, axes_spikes = None, None # plt.subplots(num_pops, 1, figsize=(10,14))

pop_spike_colors = 'rgcbm'
for i_pop, (pop_label, segment) in enumerate(pops_segments.items()):
    # Don't plot all rastergrams in same figure
    if i_pop % pops_per_fig == 0:
        fig_spikes, axes_spikes = plt.subplots(pops_per_fig, 1, figsize=(page_width,2*ax_height))
        # plt.subplots returns a bare Axes when pops_per_fig == 1; normalise to an array
        axes_spikes = np.atleast_1d(axes_spikes)
        fig_spikes.suptitle('Spikes for each population')
    ax = axes_spikes[i_pop % pops_per_fig]
    # Cycle through the palette whatever its length
    color = pop_spike_colors[i_pop % len(pop_spike_colors)]

    # Never index past the number of spike trains actually recorded
    cell_ids = range(min(max_trains_per_pop, len(segment.spiketrains)))

    for i_train in cell_ids:
        spiketrain = segment.spiketrains[i_train]
        # Row position: the 'source_id' annotation if present, else the train index
        y = spiketrain.annotations.get('source_id', i_train)
        y_vec = np.ones_like(spiketrain) * y
        ax.plot(spiketrain, y_vec, marker='|', linestyle='', snap=True, color=color)

    ax.set_ylabel('{} cell #'.format(pop_label))

plt.show(block=False)

# Raw Signals

## STN Vm

In [None]:
# Select the STN segment and its first analog signal named 'Vm'.
# `next` raises StopIteration when no such signal exists.
pop_label = 'STN'
segment = pops_segments[pop_label]
stn_vm_signal = signal = next(sig for sig in segment.analogsignals if sig.name == 'Vm')

In [None]:
# Plot up to `max_num_plot` columns of `signal`, one axes per column,
# restricted to `interval` (or the whole recording when `interval` is None).
max_num_plot = 10
num_signals = min(signal.shape[1], max_num_plot)
interval = [5e3, 10e3] # [2000.0, 6000.0]
# Sampling period as a bare number; assumed to be in the same time unit as
# `interval` -- TODO confirm
rec_dt = signal.sampling_period.magnitude
if interval is None:
    # Slice end is exclusive, so use the full length to keep the last sample
    irange = [0, signal.shape[0]]
else:
    irange = [int(t/rec_dt) for t in interval]
times = signal.times[irange[0]:irange[1]]

fig, axes = plt.subplots(num_signals, 1, figsize=(0.75*page_width,2*ax_height))
# plt.subplots returns a bare Axes when num_signals == 1; normalise to an array
axes = np.atleast_1d(axes)
fig.suptitle("{} membrane voltage".format(pop_label))

# Plot a bunch of STN Vm signals
for i_cell in range(num_signals):
    ax = axes[i_cell]
    # Label is only shown if the legend call below is re-enabled
    if 'source_ids' in signal.annotations:
        label = "id {}".format(signal.annotations['source_ids'][i_cell])
    else:
        label = "cell {}".format(i_cell)

    sig = signal[irange[0]:irange[1], i_cell]

    ax.plot(times, sig, label=label)
    ax.grid(True)
    # ax.legend()

    # Only the bottom axes carries the axis labels
    if i_cell == num_signals-1:
        ax.set_ylabel("voltage ({})".format(signal.units))
        ax.set_xlabel('time ({})'.format(times.units))

## STN LFP

In [None]:
# Load each individual cell's LFP contribution and sum over axis 1 to get
# the population LFP.
pop_label = 'STN'
segment = pops_segments[pop_label]
lfp_sigs = next((sig for sig in segment.analogsignals if sig.name == 'lfp'))
lfp_summed = lfp_sigs.sum(axis=1)

# Optional time window; None keeps the whole recording
interval = None # [5e3, 10e3]
# Sampling period as a bare number; assumed to be in the same time unit as
# `interval` -- TODO confirm
rec_dt = lfp_sigs.sampling_period.magnitude
irange = [0, lfp_sigs.shape[0]] if interval is None else [int(t/rec_dt) for t in interval]

lfp_times = lfp_sigs.times[irange[0]:irange[1]]
lfp_ranged = lfp_summed[irange[0]:irange[1]]

# Turn it into AnalogSignal object. The sampling rate comes from the LFP
# recording itself rather than from the Vm signal of another cell, so this
# stays correct if the two were recorded at different rates.
lfp_signal = stn_lfp_signal = neo.AnalogSignal(lfp_ranged, units=lfp_ranged.units,
                              sampling_rate=lfp_sigs.sampling_rate, t_start=lfp_times[0])

In [None]:
# Plot the summed LFP signal over time
fig, ax = plt.subplots(figsize=(page_width, ax_height))
ax.plot(lfp_signal.times, lfp_signal, label='{} LFP'.format(pop_label))
# ax.plot(lfp_sigs.times, lfp_sigs[:, 5], label='{} LFP'.format(pop_label))
# ax.plot(lfp_sigs.times, lfp_sigs[:, 8], 'r.', ms=1, mew=1, label='{} LFP'.format(pop_label))

ax.set_ylim((-0.5, 3))
# Axis units are read from the plotted signal itself, not from variables
# left over from other cells
ax.set_ylabel('LFP magnitude ({})'.format(lfp_signal.units))
ax.set_xlabel('time ({})'.format(lfp_signal.times.units))
ax.set_title('LFP for {} population'.format(pop_label))
ax.grid(True)
ax.legend()

# Power Spectrum

## STN LFP

In [None]:
# Welch PSD of `lfp_signal` (the LFP summed over cells, built in an earlier cell)
freqs, psd = elephant.spectral.welch_psd(lfp_signal, freq_res=0.5)
psd = psd.ravel() # we only have one axis so make 1-dimensional
# Keep the result so later figures can compare PSDs across signals
all_psd['STN_LFP'] = (freqs, psd)

In [None]:
# Draw the current `psd` against `freqs`, with the x-axis limited to 0-50
fig, ax = plt.subplots(figsize=(0.5*page_width, ax_height))
ax.plot(freqs, psd, label='{} PSD'.format(pop_label))
ax.set(
    title='Welch PSD for LFP of STN cells',
    xlabel='frequency ({})'.format(freqs.units),
    ylabel='Power ({})'.format(psd.units),
    xlim=(0, 50),
)
ax.grid(True)
ax.legend()

In [None]:
# Plot spectrogram using STFT
dt = 0.05 # sample period in ms; hard-coded, assumed to match the LFP recording -- TODO confirm
fs = 1/dt*1e3 # sampling frequency in Hz
freq_res = 1.0
nperseg = int(fs/freq_res) # determines frequency resolution
t_res = 20.0 # ms
# Consecutive windows are shifted by t_res, so they overlap by everything else
noverlap = nperseg - int(t_res/dt)
# The sampling frequency is passed as 1/dt (samples per ms), so `t` comes out
# in ms and `freqs` in kHz. 'hann' is the canonical SciPy window name; the
# older 'hanning' alias is not accepted by recent SciPy releases.
freqs, t, Sxx = scipy.signal.spectrogram(stn_lfp_signal.ravel(), 1/dt, window='hann',
                                         nperseg=nperseg, noverlap=noverlap, scaling='density')
freqs = freqs * 1000 # kHz -> Hz
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))

plt.pcolormesh(t, freqs, Sxx)
# plt.imshow(Sxx, cmap='jet', aspect='auto', vmax=abs(Sxx).max(), vmin=Sxx.min())
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
plt.clim(0, abs(Sxx).max())
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('Spectrogram - evolution of STN LFP over time ($nV^2/Hz$)')

## STN Vm

In [None]:
# Welch PSD of the STN 'Vm' analog signal, averaged over the first axis of
# the returned PSD array (presumably one row per recorded cell -- TODO confirm).
pop_label = 'STN'
segment = pops_segments[pop_label]
vm_sig = next(sig for sig in segment.analogsignals if sig.name == 'Vm')

freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
psd_avg = psd.sum(axis=0) / psd.shape[0]

# Keep the result so later figures can compare PSDs across signals
all_psd['{}_Vm'.format(pop_label)] = (freqs, psd_avg)

In [None]:
# Average STN Vm PSD, with the STN LFP PSD overlaid after scaling by 1e3
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, label='{} Vm'.format(pop_label))
# The legend states the scale factor that is actually applied (1e3)
ax.plot(all_psd['STN_LFP'][0], all_psd['STN_LFP'][1] * 1e3, 'r', label='STN LFP x 1000')
ax.set_ylabel('Power ({})'.format(psd_avg.units))
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 50))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Welch PSD for average STN membrane voltages')
ax.legend()

## CTX spikes

### Convolve spike times with AP

In [None]:
# Build the action-potential kernel from a saved pyramidal-cell recording.
# Column 0 of the file is read as `tvec`, column 1 as `vm`.
pyramidal_trace = np.loadtxt('pyramidal.dat')
tvec, vm = pyramidal_trace[:,0], pyramidal_trace[:,1]
dt = np.round(tvec[1] - tvec[0], 4)

# Portion of the trace used as kernel, converted from time to sample indices
ap_t_interval = [4.9, 20.0]
ap_i_interval = [int(t/dt) for t in ap_t_interval]
subsample = 2 # keep every 2nd sample of that portion (end index inclusive)
ap_range = range(ap_i_interval[0], ap_i_interval[1]+1, subsample)

# Subtract the baseline so the kernel is centered on 0 for convolution
ap_baseline = -65.0
ap_kernel = vm[ap_range] - ap_baseline

plt.figure(figsize=(6,2))
plt.plot(tvec[ap_range], ap_kernel)
plt.suptitle('AP kernel for cortical neurons')
plt.grid(True)

In [None]:
# Select the CTX population's segment and its spike trains for the cells below
pop_label = 'CTX'
segment = pops_segments[pop_label]
spiketrains = segment.spiketrains

In [None]:
# Construct AnalogSignal of N channels from N spiketrains: each spike train
# becomes a pulse train that is convolved with the AP kernel.
dur = np.round(spiketrains[0].t_stop.magnitude, 4)
dt = 0.05
# Single source of truth for the number of samples. Deriving the length from
# np.arange(0, dur + dt, dt) can be off by one through float rounding and then
# no longer matches the rows of `signal_matrix`.
n_samples = int(dur/dt) + 1
time = np.arange(n_samples) * dt
signal_matrix = np.empty((n_samples, len(spiketrains)))

for i, st in enumerate(spiketrains):
    # Unit pulse at the sample index of every spike (spike time / dt, truncated).
    # Assumes spike times are expressed in the same time unit as `dt` -- TODO confirm
    spike_pulses = np.zeros(n_samples)
    spike_idx = (st.times.magnitude / dt).astype(int)
    spike_pulses[spike_idx] = 1.0

    # Convolve pulses at spike times with AP kernel, then restore the baseline.
    # NOTE(review): mode='same' centres the kernel on each pulse, so each
    # waveform starts about half a kernel length before its spike time.
    spike_signal = np.convolve(spike_pulses, ap_kernel, mode='same') + ap_baseline
    signal_matrix[:, i] = spike_signal

ctx_vm_signal = neo.AnalogSignal(signal_matrix, units='mV', sampling_rate=1/dt/pq.ms,
                                t_start=spiketrains[0].t_start, t_stop=spiketrains[0].t_stop)
# Mean over channels (axis 1)
ctx_vm_mean = ctx_vm_signal.sum(axis=1) / ctx_vm_signal.shape[1]

### Average PSD

In [None]:
# Welch PSD of the synthetic CTX Vm signal, averaged over the first axis of
# the returned PSD array (presumably one row per channel -- TODO confirm).
freqs, psd = elephant.spectral.welch_psd(ctx_vm_signal, freq_res=0.5)
psd_avg = psd.sum(axis=0) / psd.shape[0]
# Keep the result so later figures can compare PSDs across signals
all_psd['{}_Vm'.format(pop_label)] = (freqs, psd_avg)

In [None]:
# Draw the current `psd_avg` against `freqs`, with the x-axis limited to 0-50
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
ax.plot(freqs, psd_avg, label='{} Vm'.format(pop_label))
ax.set(
    title='Mean Welch PSD for CTX membrane voltages',
    xlabel='frequency ({})'.format(freqs.units),
    ylabel='Power ({})'.format(psd_avg.units),
    xlim=(0, 50),
)
ax.grid(True)
ax.legend()

## GPe Vm

In [None]:
# Welch PSD of the GPE 'Vm' analog signal, averaged over the first axis of
# the returned PSD array (presumably one row per recorded cell -- TODO confirm).
pop_label = 'GPE'
segment = pops_segments[pop_label]
vm_sig = gpe_vm_signal = next((sig for sig in segment.analogsignals if sig.name == 'Vm'))

freqs, psd = elephant.spectral.welch_psd(vm_sig, freq_res=0.5)
psd_avg = psd.sum(axis=0) / psd.shape[0]

# Store under '<pop>_Vm' like the STN and CTX Vm PSDs; the bare population
# label is kept as well so existing lookups of all_psd['GPE'] still work.
all_psd[pop_label + '_Vm'] = all_psd[pop_label] = (freqs, psd_avg)

In [None]:
# Overlay the stored average STN Vm PSD (red) with the current `psd_avg`
fig, ax = plt.subplots(figsize=(page_width*0.5, ax_height))
stn_freqs, stn_psd_avg = all_psd['STN_Vm'][0], all_psd['STN_Vm'][1]
ax.plot(stn_freqs, stn_psd_avg, 'r', label='Avg STN Vm PSD')
ax.plot(freqs, psd_avg, label='Avg {} Vm PSD'.format(pop_label))
ax.set(
    title='Welch PSD for average GPE membrane voltages',
    xlabel='frequency ({})'.format(freqs.units),
    ylabel='Power ({})'.format(psd_avg.units),
    xlim=(0, 50),
)
ax.grid(True)
ax.legend()

In [None]:
# Plot spectrogram using STFT

# Mean GPe membrane voltage: sum over axis 1 divided by the number of
# columns of the Vm signal itself (not by a dimension of the PSD array).
gpe_vm_mean = gpe_vm_signal.sum(axis=1) / gpe_vm_signal.shape[1]

dt = 0.05 # sample period in ms; hard-coded, assumed to match the Vm recording -- TODO confirm
fs = 1/dt*1e3 # sampling frequency in Hz
freq_res = 1.0
nperseg = int(fs/freq_res) # determines frequency resolution
t_res = 20.0 # ms
# Consecutive windows are shifted by t_res, so they overlap by everything else
noverlap = nperseg - int(t_res/dt)
# The sampling frequency is passed as 1/dt (samples per ms), so `t` comes out
# in ms and `freqs` in kHz. 'hann' is the canonical SciPy window name; the
# older 'hanning' alias is not accepted by recent SciPy releases.
freqs, t, Sxx = scipy.signal.spectrogram(gpe_vm_mean.ravel(), 1/dt, window='hann',
                                         nperseg=nperseg, noverlap=noverlap, scaling='density')
freqs = freqs * 1000 # kHz -> Hz
fig, ax = plt.subplots(figsize=(0.75*page_width, ax_height))

plt.pcolormesh(t, freqs, Sxx)
# plt.imshow(Sxx, cmap='jet', aspect='auto', vmax=abs(Sxx).max(), vmin=Sxx.min())
f_max = 50
plt.ylim((0, f_max))
plt.colorbar()
plt.clim(0, abs(Sxx).max())
plt.ylabel('frequency (Hz)')
plt.xlabel('time (ms)')
plt.suptitle('Spectrogram - Power of mean GPe Vm over time ($nV^2/Hz$)')

# Phase Relationships

For suitable measures and implementations, see
- bookmarks/neuroscience/signal_processing
- google `measure + site:github.com`
- ask Amir for his mutual information and related measures
- see Beta-related and other neurophysiology articles

For example, we can use following measures

- __Coherence__ : linear relationship between two signals by frequency component
    + see `welch_cohere` in [elephant.spectral](http://elephant.readthedocs.io/en/latest/reference/spectral.html)


- __Phase-Amplitude Coupling__ (PAC)
    + see `hilbert` in [elephant.signal_processing](http://elephant.readthedocs.io/en/latest/reference/signal_processing.html) to do band-pass filter + Hilbert transform
    + see `comodulogram` in [pactools](https://pactools.github.io/auto_examples/plot_comodulogram.html)

## CTX - STN

### Coherence

In [None]:
# Single-column mean of the CTX Vm channels, wrapped in a signal that copies
# its metadata from `ctx_vm_signal`. Dividing by the channel count makes the
# variable hold the mean its name promises, rather than the sum.
ctx_vm_mean = ctx_vm_signal.duplicate_with_new_array(
    ctx_vm_signal.sum(axis=1).reshape((-1,1)) / ctx_vm_signal.shape[1])

# Coherence between averaged cortical Vm and STN LFP
# (1 x mean CTX Vm signal vs 1 x LFP signal)
freqs, coherence, phase_lag = elephant.spectral.welch_cohere(ctx_vm_mean, stn_lfp_signal, freq_res=0.5)

In [None]:
# Two stacked panels: coherence (top) and phase lag (bottom) versus frequency
fig, axes = plt.subplots(2, 1, figsize=(0.75*page_width, 2*ax_height))
ax = axes[0]
# Only the first half of the arrays is plotted: the plot is messed up otherwise
# because the positive frequency axis comes before the negative part.
# Integer division keeps `ineg` usable as a slice index on Python 2 and 3.
ineg = len(freqs) // 2
ax.plot(freqs[:ineg], coherence[:ineg], label='CTX Vm - STN LFP')
ax.set_ylabel('Coherence (unitless)')
# ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Coherence between Cortex mean Vm - STN LFP')
ax.legend(loc='upper right')

# Plot phase
ax = axes[1]
ax.plot(freqs[:ineg], phase_lag[:ineg], 'g-', label='phase')
ax.set_ylabel('Phase (rad)')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Phase lag > 0 means CTX leads STN')
ax.legend(loc='upper right')

### Phase-Amplitude Coupling

In [None]:
# from tensorpac import Pac
# estimator = Pac(idpac=(1, 0, 0), fpha=(3, 33, 1, 1), famp=(20, 150, 5, 5),
#                 dcomplex='hilbert', filt='butter')

# Filter the data and extract PAC :
# xpac = estimator.filterfit(fs,
#                            xpha=stn_vm_signal.magnitude[int(1e3/dt):int(10e3/dt), 0:10],
#                            xamp=ctx_vm_signal.magnitude[int(1e3/dt):int(10e3/dt), 0:10],
#                            axis=0, traxis=1, njobs=2)
# xpac = estimator.filterfit(fs, 
#                            xpha=stn_lfp_signal.magnitude.reshape((-1,1)),
#                            xamp=ctx_vm_mean.magnitude.reshape((-1,1)),
#                            axis=0, traxis=1, njobs=2)

In [None]:
# Plot PAC
# estimator.comodulogram(xpac.mean(-1), title='PAC: STN LFP (phase) - CTX Vm (amplitude)',
#                        cmap='Spectral_r', plotas='imshow')

## STN - GPE

In [None]:
# Single-column mean of the GPe Vm channels, wrapped in a signal that copies
# its metadata from `gpe_vm_signal`. The data must come from the GPe signal
# itself, not from the cortical signal.
gpe_vm_mean = gpe_vm_signal.duplicate_with_new_array(
    gpe_vm_signal.sum(axis=1).reshape((-1,1)) / gpe_vm_signal.shape[1])

# Coherence between STN LFP and averaged GPe Vm
# (1 x LFP signal vs 1 x mean GPe Vm signal)
freqs, coherence, phase_lag = elephant.spectral.welch_cohere(stn_lfp_signal, gpe_vm_mean, freq_res=0.5)

### Coherence

In [None]:
# Two stacked panels: coherence (top) and phase lag (bottom) versus frequency
fig, axes = plt.subplots(2, 1, figsize=(0.75*page_width, 2*ax_height))
ax = axes[0]
# Only the first half of the arrays is plotted: the plot is messed up otherwise
# because the positive frequency axis comes before the negative part.
# Integer division keeps `ineg` usable as a slice index on Python 2 and 3.
ineg = len(freqs) // 2
ax.plot(freqs[:ineg], coherence[:ineg], label='STN LFP - Avg GPe Vm')
ax.set_ylabel('Coherence (unitless)')
# ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
ax.set_title('Coherence between STN LFP - Average GPe Vm')
ax.legend(loc='upper right')

# Plot phase
ax = axes[1]
ax.plot(freqs[:ineg], phase_lag[:ineg], 'g-', label='phase')
ax.set_ylabel('Phase (rad)')
ax.set_xlabel('frequency ({})'.format(freqs.units))
ax.set_xlim((0, 100))
ax.grid(True)
# ax.set_yscale('log')
# Title names the two signals actually compared here (STN LFP first, GPe Vm
# second), using the same "first argument leads" wording as the CTX-STN plot.
ax.set_title('Phase lag > 0 means STN leads GPe')
ax.legend(loc='upper right')