# Read Me
    Functions to be merged into Analysis.py in the Braingeneerspy package

# Import Library

In [95]:
import numpy as np
import matplotlib.pyplot as plt
import scipy 
import statsmodels.api as sm
import math

import jdc  # for %%add_to 'class' magic

# Functions

## Computation

In [96]:
class SpikeData:
    """Spike trains of one recording plus interspike-interval statistics."""

    def __init__(self, spike_times):
        """Store the spike trains.

        spike_times is iterated one entry per neuron; each entry is the
        sequence of spike times of that neuron.
        """
        self.train = spike_times

    def interspike_intervals(self):
        """Return a list with one array of consecutive spike-time differences per neuron."""
        per_neuron = []
        for unit_times in self.train:
            per_neuron.append(np.diff(unit_times))
        return per_neuron

    def skewness(self):
        """Return the skewness of the interspike-interval distribution of each neuron."""
        result = []
        for isi in self.interspike_intervals():
            result.append(scipy.stats.skew(isi))
        return result

In [97]:
%%add_to SpikeData
def log_histogram(self, bin_num=300):
    """Histogram each neuron's interspike intervals on logarithmically spaced bins.

    For every neuron, bin_num bins (bin_num + 1 edges) are spread evenly in
    log10 space between that neuron's smallest and largest interval.

    Returns two parallel lists: the histogram counts and the bin edges,
    one entry per neuron.
    """
    counts_per_unit = []
    edges_per_unit = []
    for isi in self.interspike_intervals():
        low_exp = np.log10(min(isi))
        high_exp = np.log10(max(isi))
        edges = np.logspace(low_exp, high_exp, bin_num + 1)
        counts = np.histogram(isi, edges)[0]
        counts_per_unit.append(counts)
        edges_per_unit.append(edges)
    return counts_per_unit, edges_per_unit

In [98]:
%%add_to SpikeData
def culmulative_moving_average(self, hist):
    '''
    Cumulative moving average (CMA) of each histogram in hist.

    hist is a sequence of histograms. For every histogram, the i-th output
    value is the mean of its first i+1 bin counts, updated incrementally.
    Returns one list of CMA values per histogram.
    '''
    all_averages = []
    for counts in hist:
        running = 0
        averages = []
        for seen, count in enumerate(counts):
            # incremental mean: fold the new count into the mean of `seen` values
            running = (running * seen + count) / (seen + 1)
            averages.append(running)
        all_averages.append(averages)
    return all_averages

In [99]:
%%add_to SpikeData
def max_cma(self, hist):
    '''
    Largest cumulative moving average (CMA) of each histogram and where it occurs.

    hist is a sequence of histograms. Returns two parallel lists: the
    maximum CMA value of every histogram and the bin index of that maximum.
    '''
    peak_values = []
    peak_positions = []
    for counts in hist:
        running = 0
        averages = []
        for seen, count in enumerate(counts):
            # incremental mean of the first seen+1 bin counts
            running = (running * seen + count) / (seen + 1)
            averages.append(running)
        peak_values.append(max(averages))
        peak_positions.append(np.argmax(averages))
    return peak_values, peak_positions

In [100]:
%%add_to SpikeData
def isi_threshold_cma(self, hist, bins, coef=1):
    '''
    Interspike-interval threshold from the cumulative moving average (CMA) [1].

    hist and bins are parallel sequences (one histogram and one array of bin
    edges per neuron); the docstring of the original notes they are expected
    to be logarithmic. For each neuron the CMA of the histogram is computed,
    the bin with the largest CMA is located, and the threshold is the upper
    edge of that bin multiplied by coef.

    Returns a list with one threshold per neuron.

    [1] Kapucu, Fikret Emre, et al. Frontiers in computational neuroscience 6 (2012): 38.
    '''
    thresholds = []
    for unit, counts in enumerate(hist):
        running = 0
        averages = []
        for seen, count in enumerate(counts):
            # incremental mean of the first seen+1 bin counts
            running = (running * seen + count) / (seen + 1)
            averages.append(running)
        peak_bin = np.argmax(averages)
        # edge peak_bin + 1 is the right-hand edge of the peak bin
        upper_edge = bins[unit][peak_bin + 1]
        thresholds.append(upper_edge * coef)

    return thresholds

In [101]:
%%add_to SpikeData
def lowess_smooth(hist, bins, frac=12/300, it=3):
    '''
    Smooth a log-ISI histogram with LOWESS (locally weighted linear regression).

    Parameters:
        hist -- histogram counts to smooth (the regression's y values).
        bins -- x positions; the last element is dropped before fitting, so
                bins is presumably the array of bin edges (one longer than hist).
        frac -- fraction of the data used for each local fit. The default
                12/300 reproduces the previously hard-coded value.
        it   -- number of residual-based reweighting iterations (default 3,
                the previously hard-coded value).

    Returns the smoothed y values in the same order as hist
    (return_sorted=False).

    NOTE(review): this function takes no `self` although it is attached to
    SpikeData via %%add_to -- confirm how it is meant to be called.
    '''
    lowess = sm.nonparametric.lowess
    left_edges = bins[:-1]  # pair every count with the left edge of its bin
    return lowess(hist, left_edges, frac=frac, it=it, return_sorted=False)

In [102]:
%%add_to SpikeData
def peaks_and_valleys(yest_sm):
    '''
    Find the two highest peaks of a curve and the lowest point between them.

    Parameters:
        yest_sm -- 1-D sequence of (smoothed) values.

    Returns:
        [(peak_1, peak_1_idx), (peak_2, peak_2_idx), (valley, valley_idx)]
        where peak_1 is the highest peak, peak_2 the second highest, and
        valley the minimum of yest_sm between the two peak positions.

    Raises:
        IndexError -- if fewer than two peaks are found.
    '''
    # Import the submodule explicitly instead of relying on `import scipy`
    # having already made scipy.signal available as an attribute.
    from scipy.signal import find_peaks

    yest_sm = np.asarray(yest_sm)
    peaks_idx, _ = find_peaks(yest_sm, distance=2)
    if len(peaks_idx) < 2:
        raise IndexError(
            "peaks_and_valleys needs at least two peaks, found %d" % len(peaks_idx))

    # Rank the peaks by height but keep their positions. Looking positions up
    # by value afterwards is ambiguous when two samples share the same value
    # (equal-height peaks would collapse onto a single index).
    by_height = np.argsort(yest_sm[peaks_idx], kind='stable')
    peak_1_idx = int(peaks_idx[by_height[-1]])
    peak_2_idx = int(peaks_idx[by_height[-2]])
    peak_1 = yest_sm[peak_1_idx]
    peak_2 = yest_sm[peak_2_idx]

    # Search for the valley only between the two peaks and offset the local
    # argmin so the returned index always lies inside that range.
    left, right = sorted((peak_1_idx, peak_2_idx))
    valley_idx = left + int(np.argmin(yest_sm[left:right]))
    valley = yest_sm[valley_idx]
    return [(peak_1, peak_1_idx), (peak_2, peak_2_idx), (valley, valley_idx)]

In [103]:
%%add_to SpikeData
def void_parameter(peak_1, peak_2, valley):
    '''
    Void parameter of two peaks and the valley between them.

    Computed as 1 - valley / sqrt(peak_1 * peak_2), i.e. one minus the valley
    height relative to the geometric mean of the two peak heights.
    '''
    geometric_mean = math.sqrt(peak_1 * peak_2)
    relative_valley = valley / geometric_mean
    return 1 - relative_valley

In [104]:
%%add_to SpikeData
def isi_threshold_pm(self, hist, bins):
    '''
    Interspike-interval threshold from peaks and local minima [1].

    The histogram is smoothed with lowess_smooth, the two principal peaks and
    the valley between them are located with peaks_and_valleys, and the void
    parameter of that triple is computed with void_parameter.

    Parameters:
        hist -- histogram counts of one neuron.
        bins -- bin positions matching hist (presumably the bin edges, one
                longer than hist, since index valley_idx + 1 is read).

    Returns:
        (threshold, void_parameter) where threshold is bins[valley_idx + 1].

    NOTE(review): lowess_smooth, peaks_and_valleys and void_parameter are
    resolved as plain names here -- confirm they are reachable at module
    level once the code leaves the notebook.

    [1] Pasquale, Valentina, et al. Journal of computational neuroscience 29.1 (2010): 213-229.
    '''
    yest_sm = lowess_smooth(hist, bins)
    pk_1, pk_2, vly = peaks_and_valleys(yest_sm)
    vp = void_parameter(pk_1[0], pk_2[0], vly[0])
    # `bins_n` was never defined (NameError); the threshold is read from `bins`.
    return bins[vly[1]+1], vp

## Burst Detection

In [105]:
def burst_detection(spike_times, burst_threshold, spike_num_thr=3):
    '''
    Detect bursts in one spike train.

    A burst is a run of consecutive spikes in which every interspike interval
    is <= burst_threshold and which contains at least spike_num_thr spikes.
    The interval threshold can come from the cumulative-moving-average
    method, the peaks-and-minimum method, or be set manually; the spike
    number threshold is set manually.

    Parameters:
        spike_times     -- indexable sequence of spike times of one neuron.
        burst_threshold -- largest interspike interval allowed inside a burst.
        spike_num_thr   -- minimum number of spikes in a burst (default 3).

    Returns:
        spike_num_list -- a list of [burst start idx, number of spikes in this burst]
        burst_set      -- a list of spike times in all the bursts.
    '''
    n_spikes = len(spike_times)
    spike_num_list = []    # [burst start idx, number of spikes in this burst]
    if n_spikes == 0:
        return spike_num_list, []

    run_start = 0  # index of the first spike of the run currently being grown
    for i in range(1, n_spikes):
        if spike_times[i] - spike_times[i-1] <= burst_threshold:
            continue
        # The gap before spike i is too long: the run ends at spike i-1.
        run_len = i - run_start
        if run_len >= spike_num_thr:
            spike_num_list.append([run_start, run_len])
        run_start = i

    # Flush the final run. Without this a burst that lasts until the last
    # spike is silently dropped, because no long gap follows it.
    run_len = n_spikes - run_start
    if run_len >= spike_num_thr:
        spike_num_list.append([run_start, run_len])

    burst_set = [spike_times[start + k]
                 for start, count in spike_num_list
                 for k in range(count)]
    return spike_num_list, burst_set

## Visualization

In [106]:
%%add_to SpikeData
def logISI_histogram(isi, logbins):
    '''
    Plot the log-ISI histogram of every neuron on a grid of subplots.

    Parameters:
        isi     -- sequence with the interspike intervals of each neuron.
        logbins -- sequence with the histogram bin edges of each neuron;
                   must have the same length as isi.

    Returns the string "Error: Input data length inconsistent!" when the two
    inputs differ in length; otherwise shows the figure and returns None.
    '''
    n_units = len(isi)
    if n_units != len(logbins):
        return "Error: Input data length inconsistent!"

    col = 5
    row = -(-n_units // col)  # ceiling division: enough rows for every neuron

    # squeeze=False keeps `axs` two-dimensional even for a single row; with
    # the default, five or fewer neurons produce a 1-D array and indexing it
    # by [row, column] fails.
    # NOTE(review): figsize is (width, height); scaling the width by `row`
    # and the height by `col` looks transposed -- confirm before changing.
    fig, axs = plt.subplots(row, col, figsize=(12*row, 8*col), squeeze=False)
    for r in range(row):
        for c in range(col):
            unit = col*r + c
            ax = axs[r, c]
            if unit < n_units:
                ax.hist(isi[unit], logbins[unit], rwidth=0.9)
                ax.set_xscale('log')
                ax.set_xlabel('ISI, time (s)', fontsize=16)
                ax.set_ylabel('Number of Intervals', fontsize=16)
            else:
                # hide the unused cells of the last row
                ax.set_axis_off()

    plt.show()

In [107]:
%%add_to SpikeData
def smoothed_LogISIH(y_values, bins):
    '''
    Plot each neuron's histogram together with its LOWESS-smoothed curve and
    mark the two principal peaks and the valley between them.

    Parameters:
        y_values -- sequence with the histogram counts of each neuron.
        bins     -- sequence with the bin positions of each neuron; the last
                    element of each is dropped for plotting, so these are
                    presumably bin edges. Must have the same length as
                    y_values.

    Returns the string "Error: Input data length inconsistent!" when the two
    inputs differ in length; otherwise shows the figure and returns None.
    '''
    n_units = len(y_values)
    if n_units != len(bins):
        return "Error: Input data length inconsistent!"

    col = 5
    row = -(-n_units // col)  # ceiling division: enough rows for every neuron

    # Smooth every histogram once and locate its peaks/valley before plotting.
    yest_sm = []
    peaks_valley = []
    for i in range(n_units):
        y = lowess_smooth(y_values[i], bins[i])
        peaks_valley.append(peaks_and_valleys(y))
        yest_sm.append(y)

    # squeeze=False keeps `axs` two-dimensional even for a single row; with
    # the default, five or fewer neurons produce a 1-D array and indexing it
    # by [row, column] fails.
    # NOTE(review): figsize is (width, height); scaling the width by `row`
    # and the height by `col` looks transposed -- confirm before changing.
    fig, axs = plt.subplots(row, col, figsize=(12*row, 8*col), squeeze=False)
    for r in range(row):
        for c in range(col):
            unit = col*r + c
            ax = axs[r, c]
            if unit < n_units:
                x_bins = bins[unit]
                x_bins = x_bins[0:len(x_bins)-1]  # drop the last position to match the counts
                ax.scatter(x_bins, y_values[unit])
                ax.plot(x_bins, yest_sm[unit], linewidth=2, color='green', label='lowess: statsmodel')
                # each entry is (value, index): the two peaks, then the valley
                for value, idx in peaks_valley[unit]:
                    ax.plot(x_bins[idx], value, '*', color='magenta', markersize=16, linewidth=4)
                ax.set_xscale('log')
                ax.set_xlabel('ISI, time (s)', fontsize=16)
                ax.set_ylabel('Number of Intervals', fontsize=16)
                ax.legend(fontsize=16)
            else:
                # hide the unused cells of the last row
                ax.set_axis_off()

    plt.show()

In [108]:
def labeled_raster(spike_times, burst_set, burst_st, cut, start_point, end_point):
    '''
    Draw a raster for one neuron with three overlaid event layers: burst
    start points (blue, tallest), spikes belonging to bursts (red) and every
    spike (green, shortest).

    Parameters:
        spike_times -- spike times of one neuron.
        burst_set   -- spike times that belong to bursts.
        burst_st    -- burst start times.
        cut         -- when equal to True, only events strictly between
                       start_point and end_point are drawn.
        start_point, end_point -- bounds of the window used when cut is set.
    '''
    if cut == True:
        def inside_window(events):
            # keep events strictly inside the open interval (start_point, end_point)
            return [t for t in events if start_point < t < end_point]

        spike_times = inside_window(spike_times)
        burst_st = inside_window(burst_st)
        burst_set = inside_window(burst_set)

    fig, ax = plt.subplots(1, 1, figsize=(36, 12))
    layers = (
        (burst_st, 1, 2, 'blue', 0.4),
        (burst_set, 0.7, 0.8, 'red', 1),
        (spike_times, 0.3, 0.8, 'green', 1),
    )
    for events, length, width, colour, opacity in layers:
        ax.eventplot(events, linelengths=length, linewidth=width, color=colour, alpha=opacity)

    ax.set_xlabel('Time (s)', fontsize=16)
    ax.set_ylabel('Unites', fontsize=16)
    ax.tick_params(labelsize=16)

    plt.show()