In [None]:
#Importing some modules 

import os 
from datetime import datetime
from datetime import date

%matplotlib inline

%load_ext autoreload
%autoreload
 
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
from numpy import genfromtxt
from numpy.linalg import norm
import pandas as pd
from numba import jit

from heapq import nsmallest
from scipy.signal import chirp, find_peaks, peak_widths
from scipy.ndimage import gaussian_filter1d

from os import walk
from ast import literal_eval
import re 
import scipy
from scipy.ndimage import gaussian_filter
# import more_itertools as mit
import random
from datetime import date

import xml.etree.ElementTree as ET

# start_time = datetime.now().strftime("%H:%M:%S")
# Date = date.today.strftime("%d%m%Y")

from scipy.spatial.distance import pdist, squareform
from scipy.stats import zscore, bootstrap, sem
from scipy.signal import savgol_filter as sg_filter
from scipy.signal import gaussian
from scipy.ndimage import gaussian_filter1d
from scipy.stats import wilcoxon, pearsonr
from scipy import stats
from scipy.optimize import minimize
# from sklearn.prepocessing import normalize

import copy
from sklearn.neighbors import KernelDensity
from sklearn.decomposition import FastICA
from sklearn.decomposition import PCA
from statsmodels.stats.weightstats import ztest


### Edgar's code for preprocessing data and assembly analysis

In [None]:
#GENERAL FUNCTIONS 
def generate_folder_path(dirs):
    """Create each directory in ``dirs`` (parents included) and report the outcome.

    A path that already exists is reported and skipped instead of raising.
    """
    for folder in dirs:
        try:
            os.makedirs(folder)
        except FileExistsError:
            print("Directory " , folder ,  " already exists")
        else:
            print("Directory " , folder ,  " Created ")
            
def load_stim_file(dp_StimF):
    """Load the imaging trigger frames from ``StimF1.csv`` in ``dp_StimF`` as ints.

    Side effect: the working directory is changed to ``dp_StimF``.
    """
    os.chdir(dp_StimF)
    triggers = genfromtxt('StimF1.csv', delimiter=',')
    return triggers.astype(int)


def load_s2p_files(dp_s2p, neuropil_correction):
    """ Load the suite2p output files of one plane folder.

    See https://mouseland.github.io/suite2p/_build/html/outputs.html
    (TODO: could be refined to load only a subset of the files.)

    Parameters
    ----------
    dp_s2p : str
        Folder containing the suite2p ``.npy`` outputs. Side effect: the working
        directory is changed to this folder.
    neuropil_correction : float
        Neuropil coefficient. If > 0, ``F - neuropil_correction * Fneu`` is
        computed and every row is then shifted so that its minimum is 0.
        If <= 0 (e.g. 0), F is returned without correction.

    Returns
    -------
    F, Spks, ops, iscell, stat, Fneu
        ``ops`` is unpacked to the stored object (``.item()``); in ``iscell``
        column 0 is the binary cell label and column 1 the classifier
        probability that the ROI is a cell.
    """
    os.chdir(dp_s2p)
    F = np.load('F.npy')
    # suite2p saves the deconvolved traces as 'spks.npy'; the old name
    # 'Spks.npy' only resolved on case-insensitive filesystems. Prefer the real
    # name and fall back to the capitalised one for renamed legacy folders.
    Spks = np.load('spks.npy' if os.path.isfile('spks.npy') else 'Spks.npy')
    # allow_pickle is needed for suite2p's dict/object outputs; only load trusted folders
    ops = np.load('ops.npy', allow_pickle=True).item()
    iscell = np.load('iscell.npy')
    stat = np.load('stat.npy', allow_pickle=True)
    Fneu = np.load('Fneu.npy')

    if neuropil_correction > 0:
        F = F - Fneu * neuropil_correction
        # shift every trace so its minimum is 0
        F = F - np.min(F, axis=1, keepdims=True)
    return F, Spks, ops, iscell, stat, Fneu

def pre_process_imaging(iscell, F, stat, FOVsizeum, mode, fov_pixels=512):
    """ Compute dF/F for the curated cells and their centroid positions.

    Parameters
    ----------
    iscell, F, stat : suite2p outputs (see ``load_s2p_files``).
    FOVsizeum : float
        Width of the field of view in micrometres.
    mode : str
        'median' -> F0 is each trace's median (``dff_median``);
        '10'     -> F0 is the median of each trace's lowest 10% (``dff_10percent``).
    fov_pixels : int
        Number of pixels across the field of view used to convert pixel
        centroids to micrometres (default 512, the previously hard-coded value).

    Returns
    -------
    FNc, iscell_list, xa, yb, x, y
        dF/F of the curated cells, their ROI indices, centroids in micrometres
        (xa, yb) and centroids in pixels (x, y).

    Raises
    ------
    ValueError
        If ``mode`` is not 'median' or '10' (previously this surfaced as an
        UnboundLocalError on ``FNc``).
    """
    # validate first so a typo fails fast and clearly
    if mode == 'median':
        dff_func = dff_median
    elif mode == '10':
        dff_func = dff_10percent
    else:
        raise ValueError("mode must be 'median' or '10', got %r" % (mode,))
    iscell_list = get_curated_cells(iscell)
    FNc = dff_func(F[iscell_list])
    x, y = get_cell_centroids(stat, iscell_list)
    um_per_px = FOVsizeum / fov_pixels
    xa = [i * um_per_px for i in x]
    yb = [i * um_per_px for i in y]
    return FNc, iscell_list, xa, yb, x, y

def get_cell_centroids(stat, index_list):
    """Return the (x, y) centroids of the selected suite2p ROIs as two tuples.

    Element 1 of ``stat[i]['med']`` is used as x and element 0 as y.
    """
    centroids = [(stat[roi]['med'][1], stat[roi]['med'][0]) for roi in index_list]
    xs, ys = zip(*centroids)
    return xs, ys

def get_curated_cells(iscell):
    """Indices of the ROIs whose label in column 0 of ``iscell`` equals 1."""
    is_labelled_cell = iscell[:, 0] == 1
    return np.flatnonzero(is_labelled_cell)

def dff_10percent(traces):
    """dF/F per row, with F0 = the median of that row's lowest 10% of samples.

    Parameters
    ----------
    traces : array-like, shape (n_cells, n_frames)

    Returns
    -------
    ndarray with the same shape: (F - F0) / F0 for every row. The result is
    floating point even for integer input (the old version wrote into
    ``np.empty_like(traces)`` and truncated integer input).
    """
    traces = np.asarray(traces)
    n_frames = traces.shape[1]
    # at least one sample: with fewer than 10 frames int(n/10) is 0, which
    # selected nothing and produced a NaN baseline
    k = max(1, int(n_frames / 10))
    # the k smallest values of every row (order within them is irrelevant for the median)
    lowest = np.partition(traces, k - 1, axis=1)[:, :k]
    bsl = np.median(lowest, axis=1, keepdims=True)
    return (traces - bsl) / bsl
def dff_median(traces):
    """dF/F per row, with F0 = the median of that row.

    Parameters
    ----------
    traces : array-like, shape (n_cells, n_frames)

    Returns
    -------
    ndarray with the same shape: (F - F0) / F0 for every row. Integer input
    yields floating-point output (previously ``np.empty_like`` truncated it),
    and a zero-row input no longer raises on an unused ``traces[0]`` lookup.
    """
    traces = np.asarray(traces)
    bsl = np.median(traces, axis=1, keepdims=True)
    return (traces - bsl) / bsl

def grab_file_info(dp):
    """ Parse the mouse ID and date from the name of a file in ``dp``.

        Only files directly inside ``dp`` are considered, and the first name
        listed by ``os.walk`` is used (listing order is filesystem-dependent).
        From the part before the first '.', characters 0-5 are returned as the
        date and characters 6-8 as the mouse ID.

        Parameters: datapath

        Returns: Mouse_ID, Date
    """
    _, _, top_level_files = next(walk(dp), (None, None, []))
    stem = top_level_files[0].split('.')[0]
    Mouse_ID = stem[6:9]
    Date = stem[:6]
    return Mouse_ID, Date
def sorted_nicely( l ):
    """Sort strings in natural order, so that 'f2' comes before 'f10'.

    Each string is split into runs of ASCII digits and non-digits; digit runs
    are compared as integers and the rest as text.
    """
    def _natural_key(name):
        return [int(part) if part.isdigit() else part for part in re.split('([0-9]+)', name)]
    return sorted(l, key=_natural_key)

def behaviour_trial_function_1Map(data, start_trial):
    """Split raw behaviour trials into per-trial lengths and sample values.

    Trials before ``start_trial`` are dropped. For each remaining trial the
    number of samples is recorded, and element 0 of every sample is kept.

    Returns (behav_trial_lengths, behav_trials).
    """
    kept_trials = data[start_trial:]
    behav_trial_lengths = list(map(len, kept_trials))
    behav_trials = [[sample[0] for sample in trial] for trial in kept_trials]
    return behav_trial_lengths, behav_trials

def get_imaging_trial_lengths(Triggers):
    """Frames per trial: the differences between consecutive trigger frames."""
    return [later - earlier for earlier, later in zip(Triggers[:-1], Triggers[1:])]

def re_sample_behaviour(behav_trials, img_trial_lengths, y_start, y_end):
    """Resample each trial's positions to its imaging length and clamp to the track.

    Trial ``k`` is FFT-resampled (``scipy.signal.resample``) to
    ``img_trial_lengths[k]`` samples. Values below ``y_start`` or above ``y_end``
    are clipped, and the first / last 10 samples are forced to ``y_start`` /
    ``y_end``. Returns a list of arrays.
    """
    resampled_trials = []
    for trial_no, trial in enumerate(behav_trials):
        positions = scipy.signal.resample(trial, img_trial_lengths[trial_no])
        positions[positions < y_start] = y_start
        positions[positions > y_end] = y_end
        positions[0:10] = y_start
        positions[-10:] = y_end
        resampled_trials.append(positions)
    return resampled_trials

def re_sample_speed(behav_trials, img_trial_lengths):
    """FFT-resample each trial's speed trace to that trial's imaging length (no clamping)."""
    return [scipy.signal.resample(trial, img_trial_lengths[k]) for k, trial in enumerate(behav_trials)]

def normalizedata(data):
    """Min-max rescale ``data`` onto [0, 1] using its global minimum and maximum."""
    lo = np.min(data)
    hi = np.max(data)
    return (data - lo) / (hi - lo)

# @jit(nopython=True)
def interpolation(arr_3d,):
    """Fill NaNs along the last axis of a 3-D array by linear interpolation.

    Each 1-D slice ``arr_3d[i, j, :]`` is handled independently:

    * an all-NaN slice is returned unchanged (nothing to interpolate from);
    * NaNs before the first / after the last valid sample take that sample's value;
    * interior NaNs are linearly interpolated between the nearest valid samples.

    Returns a new array with the same shape and dtype. The input is not modified
    (the previous element-by-element version filled NaNs in place through a
    view of ``arr_3d``, and rescanned neighbours for every NaN).
    """
    arr_3d = np.asarray(arr_3d)
    result = np.array(arr_3d, copy=True)
    positions = np.arange(arr_3d.shape[2])
    for i in range(arr_3d.shape[0]):
        for j in range(arr_3d.shape[1]):
            row = arr_3d[i, j, :]
            valid = ~np.isnan(row)
            if valid.all() or not valid.any():
                continue
            # np.interp holds the end values constant outside the valid range,
            # which is exactly the edge fill described above
            result[i, j, :] = np.interp(positions, positions[valid], row[valid])
    return result

def behaviour_pre_processing(dp_behav, StimFs, y_start, y_end, speed_window=20, speed_scale=0.5, behav_fs=55):
    """Load one session's behaviour, resample it to imaging frames and derive speed.

    Steps: read the date from the file names in ``dp_behav``, load the trial
    files for that date, drop trial 0, take element 0 of every sample, then
    resample position (clamped to [y_start, y_end]) and speed onto the imaging
    frame grid given by consecutive ``StimFs`` triggers.

    Speed at sample i is
    ``(max - min of the previous speed_window samples) * speed_scale / (speed_window / behav_fs)``;
    the first ``speed_window`` samples reuse the first computable value.
    Defaults reproduce the previously hard-coded 20, 0.5 and 55.
    NOTE(review): ``behav_fs`` presumably is the behaviour sampling rate in Hz
    and ``speed_scale`` a position-unit conversion factor — confirm.

    Side effect: the working directory is changed to ``dp_behav``.

    Returns
    -------
    re_sampled_behav, re_sampled_speed, behav_trial_lengths

    Raises
    ------
    ValueError
        If a trial has ``speed_window`` samples or fewer, so no speed value can
        be computed (previously an opaque IndexError).
    """
    os.chdir(dp_behav)
    _mouse_id, Date = grab_file_info(dp_behav)
    behaviour_data = fully_sorted_data(dp_behav, Date)
    behav_trial_lengths, behav_trials = behaviour_trial_function_1Map(behaviour_data, 1)
    img_trial_lengths = get_imaging_trial_lengths(StimFs)
    re_sampled_behav = re_sample_behaviour(behav_trials, img_trial_lengths, y_start, y_end)

    window_duration = speed_window / behav_fs
    behav_speed = []
    for trial_idx, t in enumerate(behav_trials):
        if len(t) <= speed_window:
            raise ValueError("behaviour trial %d has %d samples; more than %d are needed to compute speed"
                             % (trial_idx, len(t), speed_window))
        speed_list = [(np.max(t[i - speed_window:i]) - np.min(t[i - speed_window:i])) * speed_scale / window_duration
                      for i in range(speed_window, len(t))]
        # pad the first window with the first computable value so lengths match the trial
        behav_speed.append([speed_list[0]] * speed_window + speed_list)
    re_sampled_speed = re_sample_speed(behav_speed, img_trial_lengths)

    return re_sampled_behav, re_sampled_speed, behav_trial_lengths

def fully_sorted_data(dp_behav, Date):
    """Load every behaviour file for ``Date`` from ``dp_behav`` in natural name order.

    A file is selected when ``Date`` occurs in its first 6 characters. Each
    file's text is wrapped in brackets, parsed with ``ast.literal_eval`` (a
    literal-only parser, not ``eval``), and element 0 of the result is kept.
    Side effect: the working directory is changed to ``dp_behav``.
    """
    os.chdir(dp_behav)
    matching = [name for name in os.listdir(dp_behav)
                if os.path.isfile(os.path.join(dp_behav, name)) and Date in name[:6]]
    data = []
    for name in sorted_nicely(matching):
        with open(name, 'r') as fh:
            data.append(literal_eval('[' + fh.read() + ']')[0])
    return data


#PLACE CELL FUNCTIONS 

def _bin_trial_by_position(trial_traces, trial_position, trial_speed, bins, speed_thresh):
    """Mean of each row of ``trial_traces`` in every position bin of one trial.

    Frames with speed below ``speed_thresh`` are set to NaN and skipped via
    ``np.nanmean``. The old plain ``mean`` turned a whole bin into NaN as soon
    as it contained one slow frame. Bins with no running frames come out NaN and
    are filled later by ``interpolation``.

    Column ``b`` holds frames with ``np.digitize(position, bins) == b``:
    column 0 collects positions below ``bins[0]``, column ``b >= 1`` the
    interval ``[bins[b-1], bins[b])``, and positions ``>= bins[-1]`` fall in
    no column. NOTE(review): this one-bin offset is kept from the original for
    compatibility — confirm it is intended.

    Returns an array of shape (n_rows, len(bins)).
    """
    import warnings

    binned = np.array(trial_traces, copy=True)
    if not np.issubdtype(binned.dtype, np.floating):
        binned = binned.astype(float)  # NaN masking needs a float dtype
    binned[:, np.where(np.asarray(trial_speed) < speed_thresh)[0]] = np.nan
    dig = np.digitize(trial_position, bins)
    with warnings.catch_warnings():
        # empty / all-slow bins legitimately give NaN; silence "Mean of empty slice"
        warnings.simplefilter('ignore', category=RuntimeWarning)
        columns = [np.nanmean(binned[:, dig == b], axis=1) for b in range(len(bins))]
    return np.array(columns).transpose()


def new_rate_map(traces, speed, y_start, y_end, binsize, triggers, trial_n, re_sampled_behav, speed_thresh=5):
    """Trial-by-trial spatial rate maps of every row of ``traces``.

    Parameters
    ----------
    traces : [cell, frame] activity.
    speed, re_sampled_behav : per-trial speed and position sequences indexed by
        trial number; each as long as that trial's frame span.
    y_start, y_end, binsize : bin edges ``np.arange(y_start, y_end + binsize, binsize)``.
    triggers : trial ``t`` spans frames ``triggers[t]:triggers[t+1]``.
    trial_n : trial numbers to include.
    speed_thresh : frames with speed below this are ignored (default 5, as before).

    Returns
    -------
    Array (len(trial_n), n_cells, n_bins) with NaN bins filled by ``interpolation``.
    """
    tracesc = np.asarray(traces)
    bins = np.arange(y_start, y_end+binsize, binsize)
    master = np.zeros([len(trial_n), len(tracesc), len(bins)])
    for tdx, t in enumerate(trial_n):
        master[tdx] = _bin_trial_by_position(tracesc[:, triggers[t]:triggers[t+1]], re_sampled_behav[t],
                                             speed[t], bins, speed_thresh)
    return interpolation(master)


def new_rate_map_shuffle(traces, speed, y_start, y_end, binsize, triggers, trial_n, re_sampled_behav, speed_thresh=5):
    """Same as ``new_rate_map`` but with each trial's activity circularly shifted in time.

    For each trial the whole [cell, frame] block is rolled along the time axis
    by one random shift ``random.randint(5*30, 100000)`` (effectively taken
    modulo the trial length; 5*30 presumably means 5 s at 30 Hz — confirm).
    Cells keep their identity and mutual alignment but lose alignment with
    position and speed. Previously ``np.roll`` was called without ``axis`` and
    shifted the flattened array, mixing data across cells. Uses the global
    ``random`` module, so ``random.seed`` makes it reproducible.
    """
    tracesc = np.asarray(traces)
    bins = np.arange(y_start, y_end+binsize, binsize)
    master = np.zeros([len(trial_n), len(tracesc), len(bins)])
    for tdx, t in enumerate(trial_n):
        shift = random.randint(5*30, 100000)
        shuffled = np.roll(tracesc[:, triggers[t]:triggers[t+1]], shift, axis=1)
        master[tdx] = _bin_trial_by_position(shuffled, re_sampled_behav[t], speed[t], bins, speed_thresh)
    return interpolation(master)

def peaksort(alist):
    """Index of the largest element of ``alist`` (first on ties, flattened for N-D input)."""
    index_of_max = np.argmax(alist)
    return index_of_max

def place_cell_dombeck(trace):
    """Find candidate place fields in a spatially binned dF/F trace.

    The name refers to Dombeck-style criteria; this implementation applies:

    1. baseline ``minn`` = median of the lowest quarter of the bins;
    2. candidate fields = runs of consecutive bins whose value exceeds
       ``(peak - minn) / 4`` (NOTE(review): not offset by the baseline;
       confirm this is intended);
    3. keep runs longer than ``min_bins`` and shorter than ``max_bins`` bins;
    4. keep runs with more than one bin above ``df_thresh``;
    5. keep runs whose mean in-field value is more than 3x the mean of all
       out-of-field bins.

    Returns a list of fields, each a list of bin indices.
    """
    max_bins = 80
    min_bins = 4
    df_thresh = 0.1

    peak = np.max(trace)
    minn = np.median(sorted(trace)[:int(len(trace)/4)])
    greater_indicies = [udx for udx, u in enumerate(trace) if u > (peak-minn)/4]

    # Group consecutive bin indices into runs. This replaces
    # more_itertools.consecutive_groups: its import is commented out in the
    # imports cell, so `mit` was undefined and the function raised NameError.
    fields = []
    for udx in greater_indicies:
        if fields and udx == fields[-1][-1] + 1:
            fields[-1].append(udx)
        else:
            fields.append([udx])

    big_enough_fields = [i for i in fields if max_bins > len(i) > min_bins]
    percent_20 = [i for i in big_enough_fields if sum(trace[u] > df_thresh for u in i) > 1]

    thresholded_field = []
    for i in percent_20:
        in_field = set(i)
        out_of_field = [c for cdx, c in enumerate(trace) if cdx not in in_field]
        if np.mean([trace[u] for u in i]) > np.mean(out_of_field) * 3:
            thresholded_field.append(i)
    return thresholded_field

def calculate_map_stability(rate_map):
    """Mean leave-one-out correlation between each trial and the average of the others.

    ``rate_map`` has one trial per row. NOTE: a second
    ``calculate_map_stability(rate_map, Sigma)`` is defined later in this same
    cell, so once the cell has run that version replaces this one.
    """
    correlations = []
    for held_out in range(rate_map.shape[0]):
        others_mean = np.delete(rate_map, held_out, axis=0).mean(axis=0)
        correlations.append(np.corrcoef(rate_map[held_out], others_mean)[1, 0])
    return np.mean(correlations)

def my_z_score(dist, value):
    """Z-score of ``value`` relative to the sample ``dist`` (population std, ddof=0)."""
    return (value - np.mean(dist)) / np.std(dist)

def my_place_cell_func(smoothed_trace, shuffled_peaks):
    """Z-score of the trace's maximum against a null distribution of shuffled peaks.

    Computes (max(smoothed_trace) - mean(shuffled_peaks)) / std(shuffled_peaks)
    using the population std, the same value as
    ``my_z_score(shuffled_peaks, np.max(smoothed_trace))``.
    """
    observed_peak = np.max(smoothed_trace)
    null_mean = np.mean(shuffled_peaks)
    null_std = np.std(shuffled_peaks)
    return (observed_peak - null_mean) / null_std


def calculate_map_stability(rate_map, Sigma):
    """Per-trial leave-one-out correlation after Gaussian smoothing.

    For each trial (row), both the trial and the mean of all other trials are
    smoothed with ``gaussian_filter(..., sigma=Sigma)`` before correlating.
    Returns a list with one Pearson r per trial.
    """
    def _smooth(vec):
        return gaussian_filter(vec, sigma=Sigma)

    corr = []
    for held_out in range(rate_map.shape[0]):
        rest_mean = np.delete(rate_map, held_out, axis=0).mean(axis=0)
        corr.append(np.corrcoef(_smooth(rate_map[held_out]), _smooth(rest_mean))[1, 0])
    return corr

def FWHM(X,Y):
    """Locate every crossing of half the maximum of ``Y``.

    ``d[k]`` is positive where Y rises through max(Y)/2 between samples k and
    k+1, and negative where it falls through it.

    Returns (left_idx, right_idx): arrays with all rising-crossing indices and
    all falling-crossing indices. ``X`` is not used.

    NOTE(review): the original comments speak of the "left and right most"
    indexes and a width ``X[right] - X[left]``. The code, however, returns
    every crossing: ``np.where(...)[-1]`` on 1-D input is the same array as
    ``[0]``. Callers wanting the outermost width should use ``left_idx[0]``
    and ``right_idx[-1]``.
    """
    half_max = max(Y) / 2
    sign_gap = np.sign(half_max - np.asarray(Y))
    # difference of successive signs: +2/+1 on upward, -2/-1 on downward crossings
    d = sign_gap[:-1] - sign_gap[1:]
    left_idx = np.flatnonzero(d > 0)
    right_idx = np.flatnonzero(d < 0)
    return left_idx, right_idx

def PC_FWHM_field_width(trace, default_right=200, default_left=0):
    """Width (in bins) of the field around the trace's peak at half maximum.

    The trace is baseline-subtracted with the median of its lowest 20% of
    values. From the peak, ``r`` is the first bin to the right below half the
    peak value, and ``l`` is the first bin to the left below it. Returns
    ``r - l``.

    Parameters
    ----------
    trace : 1-D array-like
    default_right : int
        ``r`` used when the trace never drops below half max right of the peak
        (default 200, as before — NOTE(review): presumably the number of
        spatial bins; pass ``len(trace)`` for other bin counts).
    default_left : int
        ``l`` used when it never drops below half max left of the peak (default 0).

    Raises
    ------
    ValueError
        For an empty trace. Previously a bare ``except`` hid this and
        reported a width of 200.
    """
    trace = np.asarray(trace)
    bsl = np.median(sorted(trace)[:int(len(trace)/5)])  # median of the 20% lowest values
    trace = trace - bsl
    peak_idx = np.argmax(trace)
    half_max = np.max(trace) / 2

    # steps from the peak to the first sub-half-max bin, walking right
    right_steps = next((xdx for xdx, x in enumerate(trace[peak_idx:]) if x < half_max), None)
    r = default_right if right_steps is None else peak_idx + right_steps

    # same walk on the reversed trace, i.e. leftwards from the peak
    r_trace = trace[::-1]
    left_steps = next((xdx for xdx, x in enumerate(r_trace[np.argmax(r_trace):]) if x < half_max), None)
    l = default_left if left_steps is None else peak_idx - left_steps
    return r - l

# def new_rate_map(traces, speed, y_start, y_end, binsize, triggers, trial_n, re_sampled_behav):
#     tracesc = np.array(traces, copy=True)  
#     bins = np.arange(y_start, y_end+binsize, binsize)
#     master = np.zeros([len(trial_n), len(tracesc), len(bins)] )
#     for tdx, t in enumerate(trial_n): 
#         print(tdx)
#         bt = re_sampled_behav[t]
#         st = speed[t]
#         tt = tracesc[:,triggers[t]:triggers[t+1]]
#         dig = np.digitize(bt,bins)
#         tt[:, np.where(st < 5)[0]] = np.nan
#         master[tdx] = np.array([tt[:,dig == i].mean(axis=1) for i in range(0,len(bins))]).transpose()
#     return interpolation(master)

def position_and_speed_img_vec(FNc_behav, re_sampled_behav, re_sampled_speed, StimFs):
    """Concatenate per-trial position and speed into full-recording frame vectors.

    Frames before ``StimFs[0]`` and from ``StimFs[-1]`` up to
    ``FNc_behav.shape[1]`` are padded with NaN. Returns (behav, speed).
    """
    positions = np.array([value for trial in re_sampled_behav for value in trial])
    speeds = np.array([value for trial in re_sampled_speed for value in trial])
    lead_pad = np.full(StimFs[0], np.nan)
    tail_pad = np.full(FNc_behav.shape[1] - StimFs[-1], np.nan)
    behav = np.concatenate([lead_pad, positions, tail_pad])
    speed = np.concatenate([lead_pad, speeds, tail_pad])
    return behav, speed


# Defining some functions
def mpPDF(var, q, pts):
    """
    Marchenko-Pastur probability density evaluated on ``pts`` points.

    Input
    ----------
    var: noise variance (float, or a length-1 array as passed by ``minimize``)
    q: ratio T/N of the data matrix dimensions (rows / columns)
    pts: number of evaluation points, evenly spaced between the theoretical
        minimum and maximum eigenvalues

    Returns
    -------
    pandas Series of densities indexed by eigenvalue

    Source
    ------
    Adapted from https://medium.com/swlh/an-empirical-view-of-marchenko-pastur-theorem-1f564af5603d

    Ref
    ---
    Marchenko and Pastur 1967 - https://www.researchgate.net/publication/303008084_Distribution_of_eigenvalues_for_some_sets_of_random_matrices
    """
    # scipy's minimize hands var over as a shape-(1,) array; unwrap it to a scalar
    if isinstance(var, np.ndarray) and var.shape == (1,):
        var = var[0]
    inv_q_root = (1. / q) ** .5
    eMin = var * (1 - inv_q_root) ** 2
    eMax = var * (1 + inv_q_root) ** 2
    eVal = np.linspace(eMin, eMax, pts)
    density = q / (2 * np.pi * var * eVal) * ((eMax - eVal) * (eVal - eMin)) ** .5
    return pd.Series(density, index=eVal)


def fitKDE(obs, bWidth=.25, kernel='gaussian', x=None):
    """
    Empirical PDF of ``obs`` from a kernel density estimate, evaluated at ``x``.

    Input
    ----------
    obs: array of observations to fit (eigenvalues); 1-D input is reshaped to a column
    bWidth: kernel bandwidth (default 0.25)
    kernel: any sklearn KernelDensity kernel name
        ('gaussian' | 'tophat' | 'epanechnikov' | 'exponential' | 'linear' | 'cosine'), default 'gaussian'
    x: points at which to evaluate the fit; defaults to the unique values of ``obs``

    Returns
    -------
    pandas Series of densities indexed by the flattened evaluation points

    Source
    ------
    Adapted from https://medium.com/swlh/an-empirical-view-of-marchenko-pastur-theorem-1f564af5603d

    Ref
    ---
    Marchenko and Pastur 1967 - https://www.researchgate.net/publication/303008084_Distribution_of_eigenvalues_for_some_sets_of_random_matrices
    """
    samples = obs.reshape(-1, 1) if len(obs.shape) == 1 else obs
    kde = KernelDensity(kernel=kernel, bandwidth=bWidth).fit(samples)
    if x is None:
        x = np.unique(samples).reshape(-1, 1)
    elif len(x.shape) == 1:
        x = x.reshape(-1, 1)
    # score_samples returns log-density
    density = np.exp(kde.score_samples(x))
    return pd.Series(density, index=x.flatten())


def errPDFs(var, eVal, q, bWidth, pts=1000):
    """
    Sum of squared differences between the Marchenko-Pastur PDF and a KDE of the eigenvalues.

    Input
    ----------
    var: noise variance (float)
    eVal: array of eigenvalues
    q: ratio T/N of the data matrix dimensions (rows / columns)
    bWidth: KDE bandwidth
    pts: number of points on which both PDFs are compared

    Returns
    -------
    float sum of squared errors (used as the objective in ``findMaxEval``)

    Source
    ------
    Adapted from https://medium.com/swlh/an-empirical-view-of-marchenko-pastur-theorem-1f564af5603d

    Ref
    ---
    Marchenko and Pastur 1967 - https://www.researchgate.net/publication/303008084_Distribution_of_eigenvalues_for_some_sets_of_random_matrices
    """
    theoretical = mpPDF(var, q, pts)
    # evaluate the empirical density on exactly the theoretical grid so the Series align
    empirical = fitKDE(eVal, bWidth, x=theoretical.index.values)
    return np.sum((empirical - theoretical) ** 2)


def findMaxEval(eVal, q, bWidth):
    """
    Largest eigenvalue explainable by noise, from a Marchenko-Pastur fit.

    The noise variance is fitted by minimising ``errPDFs`` over (1e-5, 1 - 1e-5),
    starting at 0.5; if the optimiser reports failure, a variance of 1 is used.
    Eigenvalues above the returned eMax are treated as signal.

    Input
    ----------
    eVal: array of eigenvalues to fit
    q: ratio T/N of the data matrix dimensions (rows / columns)
    bWidth: KDE bandwidth

    Returns
    -------
    (eMax, var): upper Marchenko-Pastur edge and the fitted noise variance

    Source
    ------
    Adapted from https://medium.com/swlh/an-empirical-view-of-marchenko-pastur-theorem-1f564af5603d

    Ref
    ---
    Marchenko and Pastur 1967 - https://www.researchgate.net/publication/303008084_Distribution_of_eigenvalues_for_some_sets_of_random_matrices
    """
    fit = minimize(errPDFs, .5, args=(eVal, q, bWidth), bounds=((1E-5, 1 - 1E-5),))
    var = fit['x'][0] if fit['success'] else 1
    eMax = var * (1 + (1. / q) ** .5) ** 2
    return eMax, var


### ancillary functions

In [None]:
def normalize(x):
    '''
    Min-max rescale a vector onto [0, 1].

    Input:
    x: numpy.ndarray shape: (n, )
    '''
    lowest = np.min(x)
    span = np.max(x) - lowest
    return (x - lowest) / span

def first_index(x, y, condition='>', order='forward'):
    '''
    Return the first index i (in scan order) at which ``x[i] <condition> y[i]`` holds.

    Inputs:
    x: numpy.ndarray shape: (n, )
    y: numpy.ndarray shape: (n, ) (only the first len(x) entries are read)
    condition: one of '>', '>=', '<', '<=', '=='
    order: 'forward' scans from index 0 upwards, 'backward' from the last index down

    Returns:
    The matching index. If nothing matches, ``len(x) - 1`` for 'forward'
    (i.e. -1 for empty input) and 0 for 'backward'. ``detect_HSE`` relies
    on these fall-back values.

    Raises:
    ValueError for an unknown ``condition`` or ``order``. Previously any other
    condition silently behaved like '==' and any other order like 'backward'.
    '''
    comparisons = {
        '>': lambda a, b: a > b,
        '>=': lambda a, b: a >= b,
        '<': lambda a, b: a < b,
        '<=': lambda a, b: a <= b,
        '==': lambda a, b: a == b,
    }
    if condition not in comparisons:
        raise ValueError("condition must be one of %s, got %r" % (sorted(comparisons), condition))
    if order not in ('forward', 'backward'):
        raise ValueError("order must be 'forward' or 'backward', got %r" % (order,))
    compare = comparisons[condition]

    if order == 'forward':
        for i in range(len(x)):
            if compare(x[i], y[i]):
                return i
        return len(x) - 1
    for i in range(len(x) - 1, -1, -1):
        if compare(x[i], y[i]):
            return i
    return 0



### process data

In [None]:
def get_sparse_spike(dffs, nbins=15, std=1):
    '''
    Binarise dF/F into a spike train.

    A sample counts as a spike (1) when it exceeds its own cell's
    mean + 2 * std over all frames; otherwise 0.

    Inputs:
    dffs: numpy.ndarray [cell_id, frames]
    nbins, std: accepted for interface compatibility but unused (the Gaussian
        smoothing step that used them is disabled)

    Return:
    spike: integer array with the same shape as ``dffs``
    '''
    n_cells = dffs.shape[0]
    cell_mean = np.mean(dffs, axis=1).reshape(n_cells, 1)
    cell_std = np.std(dffs, axis=1).reshape(n_cells, 1)
    above_threshold = dffs > cell_mean + 2 * cell_std
    return np.where(above_threshold, 1, 0)

def get_rest(velocity, activity, rest_speed=2):
    '''
    Keep only the frames in which the animal is resting (velocity < rest_speed).

    Input:
    velocity: speed over [frames] (array; NaN frames are excluded)
    activity: dF/F over [cell_id, frames]
    rest_speed: speed threshold, default 2
    '''
    is_resting = velocity < rest_speed
    return activity[:, np.flatnonzero(is_resting)]

def gaussian_smooth(dffs, nbins=15, std=1):
    '''
    Smooth each row of ``dffs`` along time with a normalised Gaussian kernel.

    The rows are zero-padded at both ends, so edge samples are pulled toward 0.
    Output sample t is ``sum_k w[k] * dffs[:, t + k - (nbins - 1) // 2]``.

    Inputs:
    dffs: dF/F over [cell_id, frame]
    nbins: length of the Gaussian kernel. Odd or even; even lengths used to
        fail with a shape-mismatch error.
    std: std of the Gaussian kernel, in samples

    Return:
    smooth_dffs: float array with the same shape as ``dffs``
    '''
    # scipy.signal.gaussian was deprecated and then removed in SciPy 1.13; the
    # window lives in scipy.signal.windows (available since SciPy 1.1)
    try:
        from scipy.signal.windows import gaussian as gaussian_window
    except ImportError:
        from scipy.signal import gaussian as gaussian_window

    cell_num, n_frames = dffs.shape
    kernel = gaussian_window(nbins, std=std)
    kernel /= np.sum(kernel)
    pad_left = (nbins - 1) // 2
    pad_right = nbins - 1 - pad_left
    padded = np.concatenate([np.zeros((cell_num, pad_left)), dffs, np.zeros((cell_num, pad_right))], axis=1)
    smooth_dffs = np.zeros(dffs.shape)
    # loop over the nbins kernel taps (vectorised over frames) instead of over every frame
    for k in range(nbins):
        smooth_dffs += kernel[k] * padded[:, k:k + n_frames]
    return smooth_dffs

def pos_dff(dffs, locations, speed, pos_bins=2, run_speed=2):
    '''
    Average activity in each position bin, using running frames only.

    Inputs:
    dffs: [cell_id, frame]
    locations: location over [frame]; NaN locations are ignored
    speed: speed over [frame]
    pos_bins: bin size of position
    run_speed: frames with speed >= run_speed count as running (default 2, as before)

    Returns:
    dff_pos: [cell_id, position bin] mean dF/F per bin; bins never visited
        while running are 0. Bins span nanmin(locations)..nanmax(locations).
    '''
    locations = np.asarray(locations, dtype=float)
    speed = np.asarray(speed, dtype=float)
    cell_num, n_frames = dffs.shape[0], dffs.shape[1]
    loc_min = np.nanmin(locations)
    loc_max = np.nanmax(locations)
    bins_n = int((loc_max - loc_min) / pos_bins) + 1

    frame_loc = locations[:n_frames]
    running = (speed[:n_frames] >= run_speed) & ~np.isnan(frame_loc)
    bin_idx = ((frame_loc[running] - loc_min) / pos_bins).astype(int)

    dff_pos = np.zeros((cell_num, bins_n))
    # unbuffered add so frames that fall in the same bin all accumulate
    np.add.at(dff_pos.T, bin_idx, dffs[:, running].T)
    # Clamp empty counts to 1 only after tallying every frame. The old code did
    # this inside the loop, which added a spurious +1 to every bin not visited
    # by the first running frame and so under-estimated their means.
    bins_num = np.bincount(bin_idx, minlength=bins_n).astype(float)
    bins_num[bins_num == 0] = 1
    return dff_pos / bins_num


### HSE detection and plots

In [None]:
def detect_HSE(activity, spike, pcs, least_pc_num=5, min_win=150, max_win=1000, 
               bin_size=1000/30, upper_bound=3, lower_bound=1, min_peak_dis=20, mov_win=1000):  
    '''
    detect high synchrony events (HSEs)

    The population trace is the z-scored mean of ``activity`` over cells. Its
    running mean/std over a centred ``mov_win``-frame window define a lower
    threshold (mean + lower_bound*std) and an upper threshold
    (mean + upper_bound*std). Local maxima at or above the upper threshold are
    candidate peaks. Each event runs from the last frame at or below the lower
    threshold before the peak to the first such frame after it.

    Input:
    activity: dF/F traces [cell, frame]
    spike: spike train [cell, frame] (only "> 0" is tested)
    pcs: indices of all place cells
    least_pc_num: minimum number of place cells with at least one spike inside an event
    min_win: minimum window size of a HSE (ms)
    max_win: maximum window size of a HSE (ms); longer events are replaced by a
        window of max_win_bins frames centred on the peak
    bin_size: time length of a time bin in activity and spike (ms)
    upper_bound, lower_bound: threshold multipliers of the running std (see above)
    min_peak_dis: a peak is kept only if it lies more than min_peak_dis frames
        after the previous candidate peak; the first peak is always kept
    mov_win: the length of the moving window (frames)

    Returns:
    HSE_events: list of [start, end] frame indices
    HSE_peaks: list of peak frames
    lower, upper: per-frame threshold arrays
    max_win_bins: int(max_win / bin_size) + 1
    '''
    population = zscore(np.mean(activity, axis=0))
    half_win = mov_win // 2
    mean = np.zeros(population.shape)
    std = np.zeros(population.shape)
    for i in range(population.shape[0]):
        # Clamp the window start at 0. A negative start wraps around to the end
        # of the array, which gave an empty (NaN) or wrong window for the first
        # mov_win/2 frames and so disabled detection there.
        window = population[max(0, i - half_win):i + half_win]
        mean[i] = np.mean(window)
        std[i] = np.std(window)

    peak_all = []
    t = 0
    while t < activity.shape[1]:
        # start: first frame >= lower threshold; exceed1: first frame >= upper threshold;
        # exceed2: first frame after exceed1 back <= upper; end: first frame from exceed2 <= lower
        start = first_index(population[t:], mean[t:]+lower_bound*std[t:], '>=') + t
        exceed1 = first_index(population[start:], mean[start:]+upper_bound*std[start:], '>=') + start
        exceed2 = first_index(population[exceed1+1:], mean[exceed1+1:]+upper_bound*std[exceed1+1:], '<=') + exceed1 + 1
        end = first_index(population[exceed2:], mean[exceed2:]+lower_bound*std[exceed2:], '<=') + exceed2
        if exceed1 == exceed2:
            if population[exceed1] >= mean[exceed1]+upper_bound*std[exceed1]:
                peak_all.append(int(exceed1))
        else:
            # local maxima inside the supra-threshold stretch: slope sign flips from + to -
            sign = np.sign(population[exceed1+1: exceed2] - population[exceed1: exceed2-1])
            if len(sign) == 1:
                if population[exceed1+1]>=mean[exceed1+1]+upper_bound*std[exceed1+1]:
                    peak_all.append(int(exceed1)+1)
            else:
                peaks = np.where((sign[1:]-sign[:len(sign)-1])<0)[0] + 1 + exceed1 + 1
                for peak in peaks:
                    if population[peak]>=mean[peak]+upper_bound*std[peak]:
                        peak_all.append(peak)
        t = end + 1

    peak_all = np.sort(np.array(peak_all, dtype=np.int64))
    if peak_all.size > 0:
        # keep the first peak, then every peak more than min_peak_dis frames after
        # its predecessor (the old indexing always discarded the first peak)
        keep = np.concatenate(([True], np.diff(peak_all) > min_peak_dis))
        peak_all = peak_all[keep]

    HSE_events = []
    HSE_peaks = []
    max_win_bins = int(max_win/bin_size)+1
    for peak in peak_all:
        assert population[peak] >=  mean[peak]+upper_bound*std[peak]
        start = first_index(population[: peak], mean[:peak]+lower_bound*std[:peak], '<=', order='backward')
        end = first_index(population[peak: ], mean[peak:]+lower_bound*std[peak:], '<=') + peak
        if end - start >= int(min_win/bin_size):
            # require enough place cells that spike at least once inside the event
            if np.where(np.mean(spike[pcs[:], start:end], axis=1)>0)[0].shape[0] >= least_pc_num:
                if end - start > max_win_bins:
                    # too long: use a max_win-sized window centred on the peak
                    HSE_events.append([peak-max_win_bins//2, peak+max_win_bins//2])
                    HSE_peaks.append(peak)
                else:
                    HSE_events.append([start, end])
                    HSE_peaks.append(peak)

    return HSE_events, HSE_peaks, mean+lower_bound*std, mean+upper_bound*std, max_win_bins  # HSE of place cell


In [None]:
# heatmap
def plot_heatmap(aver_all, normalized, aver_center=None, lower=None, upper=None,
                 xlabel='Frames', ylabel1='Mean dF/F (z-score)', ylabel2='Cell ID', title='Map 1'):
    """Show a population trace above a cell-by-time heatmap.

    aver_all : 1-D trace drawn across the whole window (top panel).
    normalized : [cell, frame] matrix drawn with ``imshow`` (bottom panel).
    aver_center : optional shorter trace drawn centred within the window.
    lower, upper : optional threshold traces drawn on the top panel.
    """
    n_bins = aver_all.shape[0]
    if aver_center is not None:
        # x-offset that centres the shorter trace in the window
        center_len = len(aver_center)
        offset = n_bins // 2 - int(center_len / 2)

    fig = plt.figure(figsize=(4, 8))
    grid = fig.add_gridspec(3, 1)
    trace_ax = fig.add_subplot(grid[0, 0], )
    heat_ax = fig.add_subplot(grid[1:, 0], sharex=trace_ax)

    trace_ax.plot(np.arange(n_bins), aver_all)
    if aver_center is not None:
        trace_ax.plot(np.arange(offset, offset + center_len), aver_center)
    for threshold in (lower, upper):
        if threshold is not None:
            trace_ax.plot(threshold)
    im = heat_ax.imshow(normalized, aspect='auto', interpolation='None', cmap='cividis', )

    trace_ax.set_title(title, fontsize=20)
    plt.setp(trace_ax.get_xticklabels(), visible=False)
    trace_ax.set_ylabel(ylabel1, fontsize=15)
    heat_ax.set_ylabel(ylabel2, fontsize=15)
    heat_ax.set_xlabel(xlabel, fontsize=15)
    for ax in (trace_ax, heat_ax):
        ax.tick_params(axis='both', which='major', labelsize=12)
    fig.align_ylabels()
    cbar = plt.colorbar(im, ax=[trace_ax, heat_ax], )
    cbar.ax.set_ylabel('Normalized dF/F', fontsize=15)
    plt.show()


def _rank_hse_cells(normalized, sort, peaks, pcs, win_bins, win):
    """Ascending cell order for an HSE heatmap (callers reverse it for display).

    sort == 'value': by each cell's maximum within the central ``win`` bins;
    sort == 'peak': by ``peaks`` if given, else by the position of each cell's
        maximum within the central ``win`` bins;
    otherwise ('pc_peak'): cells not in ``pcs`` first, then ``pcs``, each group
        ordered by ``peaks`` (which must then be supplied).
    """
    centre = normalized[:, win_bins // 2 - win // 2:win_bins // 2 + win // 2]
    if sort == 'value':
        return np.argsort(np.max(centre, axis=1))
    if sort == 'peak':
        order_key = np.argmax(centre, axis=1) if peaks is None else peaks
        return np.argsort(order_key)
    non_pl = np.setdiff1d(np.arange(normalized.shape[0]), pcs)
    return np.concatenate([non_pl[np.argsort(peaks[non_pl[:]])[:]], pcs[np.argsort(peaks[pcs[:]])]])


def HSE_heatmap(HSE_events, HSE_peaks, activity, lower, upper, gaussian_bins=7, gaussian_std=10,
                win_bins=100, win=30, sort='value', plot_all=False, peaks=None, pcs=None, if_rank=False,
                plot_aver=False,
                if_zscore=False, plot_cells=None, cell_order=None):  # activity: after smooth and zscore
    """Average (and optionally plot) activity in a ``win_bins`` window around each HSE peak.

    Parameters (selection)
    ----------
    HSE_events : unused here; kept for interface compatibility.
    HSE_peaks : peak frame indices; peaks whose window does not fit inside the
        recording are skipped.
    activity : [cell, frame] activity.
    lower, upper : per-frame threshold arrays, plotted per HSE and returned.
    gaussian_bins, gaussian_std : unused (any smoothing is done by the caller).
    win : width of the central window used for sorting and the highlighted trace.
    sort : 'value', 'peak' or 'pc_peak' (see ``_rank_hse_cells``).
    plot_all : plot every HSE and classify it with ``good_HSE``.
    if_rank : reorder heatmap rows by the computed rank (or by ``cell_order``).
    if_zscore : if False, rows are z-scored before forming the population trace.
    plot_cells : optional subset of cell indices to use.

    Returns
    -------
    [mean_hse_mean, lower, upper, normalized], descending cell rank, number of good HSEs.
    If no peak fits, the averages are NaN (0/0), as before.

    Fixes: the per-HSE 'peak' sort used to argsort the max *values* and
    overwrote ``peaks``, so the final ranking never used peak positions. It
    is now computed by ``_rank_hse_cells`` without touching ``peaks``. An odd
    ``win_bins`` used to make every window one frame short, so none matched.
    """
    mean_hse_mean = np.zeros(win_bins)
    if not if_zscore:
        population = zscore(np.mean(zscore(activity, axis=1), axis=0))
    else:
        population = zscore(np.mean(activity, axis=0))

    if plot_cells is not None:
        activity = activity[plot_cells[:], :]
    hse_mean = np.zeros((activity.shape[0], win_bins))
    cell_num = activity.shape[0]

    num = 0
    good_hse = 0
    for peak in HSE_peaks:
        left = peak - win_bins // 2
        right = left + win_bins
        # skip peaks whose window would run off either end of the recording
        if left < 0 or right > activity.shape[1]:
            continue
        hse_mean += activity[:, left:right]
        mean_hse_mean += population[left:right]
        num += 1

        if plot_all:
            normalized = np.zeros((cell_num, win_bins))
            for cell in range(cell_num):
                normalized[cell, :] = normalize(activity[cell, left:right])
            if if_rank:
                if cell_order is None:
                    normalized = normalized[_rank_hse_cells(normalized, sort, peaks, pcs, win_bins, win)[::-1], :]
                else:
                    # NOTE(review): cell_order is applied as given here but reversed
                    # for the averaged heatmap below — confirm which is intended
                    normalized = normalized[cell_order, :]
            plot_heatmap(population[left:right], normalized, lower=lower[left:right], upper=upper[left:right],
                         aver_center=population[peak - win // 2:peak + win // 2])
            if good_HSE(population, peak, lower, upper):
                print('This is a good HSE')
                good_hse += 1
            else:
                print('This is not a good HSE')

    hse_mean /= num
    mean_hse_mean /= num
    normalized = np.zeros(hse_mean.shape)
    for cell in range(cell_num):
        normalized[cell, :] = normalize(hse_mean[cell, :])
    if cell_order is None:
        cell_rank = _rank_hse_cells(normalized, sort, peaks, pcs, win_bins, win)
    else:
        cell_rank = cell_order
    if if_rank:
        normalized = normalized[cell_rank[::-1], :]

    if plot_aver:
        plot_heatmap(mean_hse_mean, normalized, aver_center=None)

    return [mean_hse_mean, lower, upper, normalized], cell_rank[::-1], good_hse


def plot_pre_behav_post(list1, list2, list3, rank):
    '''
    Plot the pre, behav and post HSE summaries side by side.

    Each column shows, on top, the HSE mean-of-means trace together with its lower
    and upper bounds, and below it the heatmap of the HSE mean over [cell, frames]
    with rows reordered by `rank`.

    Input:
    list1, list2, list3: [mean of means, lower_bound, upper_bound, heatmap] of pre, behav, post
    rank: the order of cells (row indices applied to every heatmap)

    All three heatmaps are drawn with one shared colour range, so the single
    colorbar on the right is valid for every panel.
    '''
    win_bins = list1[0].shape[0]
    panels = [(list1, 'Pre'), (list2, 'Behav'), (list3, 'Post')]
    heatmaps = [np.asarray(data[3])[rank[:], :] for data, _ in panels]
    # imshow autoscales each image independently; without a shared range the one
    # colorbar (built from the last image) would mislabel the other two panels.
    vmin = min(np.nanmin(h) for h in heatmaps)
    vmax = max(np.nanmax(h) for h in heatmaps)

    fig = plt.figure(figsize=(12, 8))
    gs = fig.add_gridspec(3, 3)
    all_axes = []
    first_top = None
    im = None
    for col, ((data, title), heatmap) in enumerate(zip(panels, heatmaps)):
        # Trace panels share the y-axis with the first column; each heatmap shares
        # its x-axis with the trace above it.
        ax_top = fig.add_subplot(gs[0, col], sharey=first_top)
        ax_map = fig.add_subplot(gs[1:, col], sharex=ax_top)
        if first_top is None:
            first_top = ax_top

        ax_top.plot(np.arange(win_bins), data[0])
        ax_top.plot(np.ones(win_bins) * (data[1]))
        ax_top.plot(np.ones(win_bins) * (data[2]))
        im = ax_map.imshow(heatmap, aspect='auto', interpolation='None', cmap='cividis',
                           vmin=vmin, vmax=vmax)
        ax_top.set_title(title, fontsize=20)
        plt.setp(ax_top.get_xticklabels(), visible=False)
        ax_map.set_xlabel('Frames', fontsize=15)
        ax_top.tick_params(axis='both', which='major', labelsize=12)
        ax_map.tick_params(axis='both', which='major', labelsize=12)
        if col == 0:
            ax_top.set_ylabel('Mean dF/F (z-score)', fontsize=15)
            ax_map.set_ylabel('Cell ID', fontsize=15)
        else:
            # the same cell order is used in every column, so label it only once
            plt.setp(ax_map.get_yticklabels(), visible=False)
        all_axes.extend([ax_top, ax_map])

    fig.align_ylabels()
    cbar = plt.colorbar(im, ax=all_axes)
    cbar.ax.set_ylabel('Normalized dF/F', fontsize=15)

    plt.show()
    plt.close(fig)



def _count_per_window(peaks, n_windows, bin_size):
    '''Count the peaks falling in each consecutive `bin_size`-frame window (float array of length n_windows).'''
    idx = np.asarray(peaks, dtype=int).ravel() // bin_size
    if idx.size and (idx.min() < 0 or idx.max() >= n_windows):
        # a negative index would otherwise silently wrap into the last window
        raise ValueError('HSE peak lies outside the session length given by total_time')
    return np.bincount(idx, minlength=n_windows).astype(float)


def HSE_counts_overtime(pre_HSE_peaks, post_HSE_peaks, total_time, bin_size=1800):
    '''
    Plot HSE number as a function of time window.

    Inputs:
    pre_HSE_peaks: HSE peak frames in pre
    post_HSE_peaks: HSE peak frames in post
    total_time: session lengths in frames, indexed as total_time[0] (pre) and
        total_time[1] (post); a list, tuple or array is accepted
    bin_size: frame size of one time window

    The post windows are drawn directly to the right of the pre windows, and a flat
    line marks each session's mean count per window.

    Raises:
    ValueError if a peak falls outside [0, (total_time // bin_size + 1) * bin_size).
    '''
    total_time = np.asarray(total_time, dtype=int)
    nbins = total_time // bin_size + 1
    pre = _count_per_window(pre_HSE_peaks, nbins[0], bin_size)
    post = _count_per_window(post_HSE_peaks, nbins[1], bin_size)
    pre_x = np.arange(nbins[0])
    post_x = np.arange(nbins[1]) + nbins[0]

    fig = plt.figure(figsize=(12, 4))
    gs = fig.add_gridspec(1, 1)
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(pre_x, pre)
    ax1.plot(pre_x, np.full(nbins[0], np.mean(pre)))
    ax1.plot(post_x, post)
    ax1.plot(post_x, np.full(nbins[1], np.mean(post)))
    plt.setp(ax1.get_xticklabels(), visible=False)
    ax1.set_ylabel('HSE counts', fontsize=15)
    # session names are drawn as minor tick labels centred under each session
    ax1.xaxis.set_minor_locator(mticker.FixedLocator((nbins[0] // 2, nbins[1] // 2 + nbins[0])))
    ax1.xaxis.set_minor_formatter(mticker.FixedFormatter(("Pre", "Post")))
    plt.setp(ax1.xaxis.get_minorticklabels(), size=20, va="center")
    # the x-axis ignores `left`; `bottom=False` hides the minor tick marks, keeping only the labels
    ax1.tick_params("x", which="minor", pad=25, bottom=False)
    ax1.tick_params(axis='both', which='major', labelsize=12)
    plt.show()
    plt.close(fig)
    
def _pairwise_corr_by_distance(corr, pc_peak, dis_bin, order):
    '''Group the lower-triangle entries of `corr` by binned peak distance; returns one 1-D array per distance bin.'''
    peaks = np.asarray(pc_peak, dtype=float)
    # number of dis_bin-wide bins spanned by the place-field peaks
    span_bins = int((np.max(peaks) - np.min(peaks)) // dis_bin + 1)
    if order == 'circulate':
        nbins = span_bins // 2 + 1
    elif order == 'sequence':
        nbins = span_bins
    else:
        raise ValueError("order must be 'circulate' or 'sequence', got {!r}".format(order))
    rows, cols = np.tril_indices(len(peaks), k=-1)  # every pair (i, j) with j < i
    bin_idx = (np.abs(peaks[rows] - peaks[cols]) // dis_bin).astype(int)
    if order == 'circulate':
        # shorter way round: wrap against the FULL span (max index then is span_bins // 2 = nbins - 1)
        bin_idx = np.minimum(bin_idx, span_bins - bin_idx)
    values = corr[rows, cols]
    return [values[bin_idx == b] for b in range(nbins)]


def _sliding_mean_ci(groups, slid_bins):
    '''Pool a slid_bins-wide window of groups around each bin; return mean, 95% bootstrap CI low and high (NaN where undefined).'''
    nbins = len(groups)
    half = slid_bins // 2
    width = max(slid_bins, 1)
    means = np.full(nbins, np.nan)
    lows = np.full(nbins, np.nan)
    highs = np.full(nbins, np.nan)
    for k in range(nbins):
        # window [k - half, k - half + width), clipped to the valid bin range
        window = groups[max(k - half, 0):min(k - half + width, nbins)]
        data = np.concatenate(window) if window else np.empty(0)
        if data.size == 0:
            continue
        means[k] = np.mean(data)
        if data.size < 2:
            continue  # scipy's bootstrap needs at least two observations
        bs = bootstrap((data, ), np.mean, method='percentile', vectorized=False)
        lows[k], highs[k] = bs.confidence_interval  # (low, high)
    return means, lows, highs


def plot_PC_corr(pre_dffs, post_dffs, pc_peak, slid_bins=4, dis_bin=2, order='circulate'):
    '''
    Plot the correlation between place cells' dF/F as a function of peak distance.
    The two periods compared can be pre/post or HSE.

    Input:
    pre_dffs: [cell_id, frame] dF/F of all place cells in pre
    post_dffs: [cell_id, frame] dF/F of all place cells in post
    pc_peak: the peaks of all place cells (one per row of pre_dffs / post_dffs)
    slid_bins: bins number of a sliding window used to pool neighbouring distance bins
    dis_bin: bin size of distance between cell peaks
    order: 'circulate' (distance measured the shorter way round the span of peaks)
        or 'sequence' (plain absolute distance)

    Left panel: the sliding-window mean correlation per distance bin, Gaussian-smoothed
    (sigma=1 bin), with an unsmoothed 95% percentile-bootstrap band. Windows with no
    pairs give NaN. Right panel: pre vs post values per bin, their mean +- SEM, and
    the identity line.

    Raises:
    ValueError for an unknown `order`.
    '''
    pre_corr = np.corrcoef(pre_dffs)
    post_corr = np.corrcoef(post_dffs)
    pre_groups = _pairwise_corr_by_distance(pre_corr, pc_peak, dis_bin, order)
    post_groups = _pairwise_corr_by_distance(post_corr, pc_peak, dis_bin, order)
    pre_func, pre_bootstrap_l, pre_bootstrap_h = _sliding_mean_ci(pre_groups, slid_bins)
    post_func, post_bootstrap_l, post_bootstrap_h = _sliding_mean_ci(post_groups, slid_bins)

    pre_func = gaussian_filter1d(pre_func, sigma=1)
    post_func = gaussian_filter1d(post_func, sigma=1)
    distances = np.arange(len(pre_func)) * dis_bin

    fig = plt.figure(figsize=(12, 4))
    gs = fig.add_gridspec(1, 5)
    ax1 = fig.add_subplot(gs[0, :2])
    ax2 = fig.add_subplot(gs[0, 3:])
    ax1.plot(distances, pre_func, label='Pre')
    ax1.fill_between(x=distances, y1=pre_bootstrap_l, y2=pre_bootstrap_h, alpha=0.2)
    ax1.plot(distances, post_func, label='Post')
    ax1.fill_between(x=distances, y1=post_bootstrap_l, y2=post_bootstrap_h, alpha=0.2)
    ax1.legend()
    ax1.set_xlabel('Run PF \n peak distance(mm)')
    ax1.set_ylabel('Offline Pairwise \n correlation coefficient')

    pre_ok = pre_func[np.isfinite(pre_func)]
    post_ok = post_func[np.isfinite(post_func)]
    pre_mean, post_mean = np.mean(pre_ok), np.mean(post_ok)
    pre_sem, post_sem = sem(pre_ok), sem(post_ok)
    ax2.scatter(pre_func, post_func, color='blue', s=10)
    ax2.scatter(pre_mean, post_mean, color='orange', s=10)
    ax2.plot([pre_mean, pre_mean], [post_mean - post_sem, post_mean + post_sem], color='orange')
    ax2.plot([pre_mean - pre_sem, pre_mean + pre_sem], [post_mean, post_mean], color='orange')
    lo = min(np.min(pre_ok), np.min(post_ok))
    hi = max(np.max(pre_ok), np.max(post_ok))
    diag = np.linspace(lo, hi, 100)
    ax2.plot(diag, diag, linestyle='--')
    ax2.set_xlabel('Pre offline synchrony \n (mean corrected coefficient)')
    ax2.set_ylabel('Post offline synchrony \n (mean corrected coefficient)')
    plt.show()

### Detect HSE-modulated cells

In [None]:
def test_mod_cell(HSE_peaks, cell_dff, hse_win=30, win_size=100, mov_win=500, exceed_percent=0.05):
    '''
    test if a cell is modulated by HSE or not
    
    Input:
    HSE_peaks: list, peaks of HSEs
    cell_dff: dF/F of the cell to be tested
    hse_win: window length of HSE
    win_size: window length of testing period
    mov_win: compute baseline within the moving window length
    exceed_percent: test a cell as positive/negative if exceed_percent more/less than base line
    
    Output:
    label, baseline: label=1, positive modulated; -1, negative modulated; 0, not modulated
    '''
    hse_means = np.zeros(len(HSE_peaks))
    com_means = np.zeros(len(HSE_peaks))
    com_medians = np.zeros(len(HSE_peaks))
    for t in range(len(HSE_peaks)):
        hse_means[t] = np.mean(cell_dff[HSE_peaks[t]-hse_win//2:HSE_peaks[t]+hse_win//2])
        com_means[t] = np.mean(cell_dff[HSE_peaks[t] - win_size // 2:HSE_peaks[t] - win_size // 2 + hse_win])
        com_medians[t] = np.median(cell_dff[HSE_peaks[t]-hse_win//2:HSE_peaks[t]-hse_win//2])
    
    wil_score = wilcoxon(hse_means, com_means)
    baseline = np.mean(com_means)
    if wil_score[1] > 0.05:
#         print('p value > 0.05')
        return 0, baseline
    elif np.mean(hse_means) > (1 + exceed_percent) * baseline:
#         print('positive')
        return 1, baseline
    elif np.mean(hse_means) < (1 - exceed_percent) * baseline:
#         print('negative')
        return -1, baseline
    else:
        return 0, baseline


def _peri_hse_windows(trace, HSE_peaks, win_size):
    '''Normalized peri-HSE windows of one trace, one row per HSE whose window fits inside it; shape (n_valid, win_size).'''
    rows = []
    for peak in HSE_peaks:
        start = int(peak) - win_size // 2
        stop = start + win_size
        if start < 0 or stop > trace.shape[0]:
            continue
        rows.append(normalize(trace[start:stop]))
    return np.vstack(rows) if rows else np.zeros((0, win_size))


def _group_mean_response(cells, dff, HSE_peaks, hse_win, win_size, baselines, title, plot_all):
    '''Average the peri-HSE matrices of `cells` (None for an empty group); optionally plot each cell with its own baseline.'''
    if not cells:
        return None
    centre = slice(win_size // 2 - hse_win // 2, win_size // 2 + hse_win // 2)
    total = None
    for cell in cells:
        cell_hse = _peri_hse_windows(dff[cell, :], HSE_peaks, win_size)
        total = cell_hse.copy() if total is None else total + cell_hse
        if plot_all:
            # rows ordered by each HSE's maximum normalized response inside the HSE window
            cell_hse = cell_hse[np.argsort(np.max(cell_hse[:, centre], axis=1)), :]
            cell_mean = np.mean(cell_hse, axis=0)
            baseline = baselines[cell]
            plot_heatmap(cell_mean, cell_hse, ylabel2='HSE number', title=title,
                         lower=baseline * np.ones(len(cell_mean)), upper=baseline * np.ones(len(cell_mean)),
                         aver_center=cell_mean[centre])
    return total / len(cells)


def find_mod_cell(HSE_peaks, dff, hse_win=30, win_size=100, title='', plot_all=False, plot_aver=False, exceed_percent=0.05):
    '''
    Use test_mod_cell to test all the cells and make plots.

    Input:
    HSE_peaks: peak of all HSEs
    dff: [cell_id, frame] dF/F
    hse_win: window length of HSEs
    win_size: window length of peri-HSEs
    title: prefix of the per-cell plot titles
    plot_all: True if make plots of all cells, else False
    plot_aver: True if make plots of average, else False
    exceed_percent: passed to test_mod_cell

    Peri-HSE windows that would run past either end of the recording are left out of
    the plots. The average plot of a group with no modulated cells is skipped.

    Output:
    pos_mod_cell: list, cell id of all positive modulated cell
    neg_mod_cell: list, cell id of all negative modulated cell
    '''
    pos_mod_cell = []
    neg_mod_cell = []
    baselines = {}
    for cell in range(dff.shape[0]):
        label, baseline = test_mod_cell(HSE_peaks, dff[cell, :], hse_win, win_size, exceed_percent=exceed_percent)
        # each cell's plot must use its own baseline, not the last one computed
        baselines[cell] = baseline
        if label == 1:
            pos_mod_cell.append(cell)
        elif label == -1:
            neg_mod_cell.append(cell)

    pos_hse_mean = _group_mean_response(pos_mod_cell, dff, HSE_peaks, hse_win, win_size,
                                        baselines, title + 'Positive', plot_all)
    neg_hse_mean = _group_mean_response(neg_mod_cell, dff, HSE_peaks, hse_win, win_size,
                                        baselines, title + 'Negative', plot_all)
    if plot_aver:
        if pos_hse_mean is not None:
            plot_heatmap(np.mean(pos_hse_mean, axis=0), pos_hse_mean, ylabel2='HSE number', title='Positive')
        if neg_hse_mean is not None:
            plot_heatmap(np.mean(neg_hse_mean, axis=0), neg_hse_mean, ylabel2='HSE number', title='Negative')

    return pos_mod_cell, neg_mod_cell
    
    
def plot_mod_cell_num(pre_num, post_num):
    '''
    Grouped bar chart of positively and negatively modulated cell counts in pre and post HSEs.

    Input:
    pre_num: [pos_num, neg_num]
    post_num: [pos_num, neg_num]

    Each bar is annotated with its count.
    '''
    session_names = ['Pre', 'Post']
    centers = np.arange(1, len(session_names) + 1)  # group centres at x = 1, 2
    bar_width = 0.35

    fig, ax = plt.subplots()
    bar_groups = []
    series = (('Positive', -bar_width / 2), ('Negative', bar_width / 2))
    for sign_idx, (kind, offset) in enumerate(series):
        heights = [pre_num[sign_idx], post_num[sign_idx]]
        bar_groups.append(ax.bar(centers + offset, heights, bar_width, label=kind))

    ax.set_ylabel('Cell counts')
    # hide the default tick labels; the session names are drawn as minor ticks at the group centres
    plt.setp(ax.get_xticklabels(), visible=False)
    ax.xaxis.set_minor_locator(mticker.FixedLocator((1, 2)))
    ax.xaxis.set_minor_formatter(mticker.FixedFormatter(session_names))
    plt.setp(ax.xaxis.get_minorticklabels())
    ax.tick_params(axis="x", which="minor")
    ax.legend()

    for bars in bar_groups:
        ax.bar_label(bars, padding=3)

    fig.tight_layout()

    plt.show()

#### Tereda's method: detect modulated cells for each HSE

In [None]:
## Tereda's
def _normalize_rows_sorted_by_peak(block, center):
    '''Normalize each row of `block` in place and return the rows ordered by the frame of their maximum inside `center`.'''
    for row in range(block.shape[0]):
        block[row, :] = normalize(block[row, :])
    return block[np.argsort(np.argmax(block[:, center], axis=1)), :]


def get_reactivated_cell(peak, dff, hse_win, win, cell_type=None, pass_corr=0.5):
    '''
    Detect reactivated cells for one HSE from their correlation with the first principal component.

    A PCA (2 components, only the first is used) is fitted on
    dff[:, peak - hse_win//2 : peak + hse_win//2] with cells as samples. Its first
    component is therefore a time course of length hse_win. Each cell's trace in that
    window is Pearson-correlated with it.
    NOTE(review): the sign of a PCA component is arbitrary, so "positive"/"negative" is
    relative to the orientation of that component — confirm this is intended.

    Input:
    peak: peak frame of the HSE
    dff: [cell_id, frames] dF/F
    hse_win: window size of a HSE
    win: peri-HSE window size used for the returned / plotted matrices
    cell_type: unused; kept for backward compatibility
    pass_corr: r > pass_corr -> positive, r < -pass_corr -> negative; cells with an
        undefined (NaN) correlation are ignored

    Side effect: plots heatmaps of all reactivated cells, the positive ones and the
    negative ones via plot_heatmap; an empty group is not plotted.

    Return:
    hse_reac_pos, hse_reac_neg: [cell, win] normalized peri-HSE dF/F of positively /
    negatively correlated cells, rows sorted by the frame of their maximum inside the HSE window
    '''
    hse_slice = slice(peak - hse_win // 2, peak + hse_win // 2)
    # PCA.fit copies its input by default, so the slice of dff is not modified
    pca = PCA(n_components=2).fit(dff[:, hse_slice])  # +- 3s
    pca1 = pca.components_[0]

    pos_react_cell = []
    neg_react_cell = []
    for cell_id in range(dff.shape[0]):
        r = pearsonr(dff[cell_id, hse_slice], pca1)[0]
        if np.isnan(r):
            continue
        if r > pass_corr:
            pos_react_cell.append(cell_id)
        elif r < -pass_corr:
            neg_react_cell.append(cell_id)
    pos_react_cell = np.array(pos_react_cell, dtype=np.int64)
    neg_react_cell = np.array(neg_react_cell, dtype=np.int64)

    peri = slice(peak - win // 2, peak + win // 2)
    center = slice(win // 2 - hse_win // 2, win // 2 + hse_win // 2)
    # integer-array indexing returns a copy, so normalizing rows in place leaves dff intact
    hse_reac_pos = _normalize_rows_sorted_by_peak(dff[pos_react_cell, peri], center)
    hse_reac_neg = _normalize_rows_sorted_by_peak(dff[neg_react_cell, peri], center)
    hse_reac = np.concatenate((hse_reac_pos, hse_reac_neg), axis=0)
    for group in (hse_reac, hse_reac_pos, hse_reac_neg):
        if group.shape[0] > 0:
            plot_heatmap(np.mean(group, axis=0), group)

    return hse_reac_pos, hse_reac_neg


### Reactivation analysis

In [None]:
### Grosmark 

def reactivation_strength(ICA_matrix, activity_matrix):
    '''
    Compute the reactivation strength matrix.

    R[i, t] = (sum_c ICA_matrix[c, i] * activity_matrix[c, t]) ** 2, i.e. the squared
    projection of the activity at frame t onto assembly i's weight vector.

    ICA_matrix: [cell_id, assembly_id] weights
    activity_matrix: [cell_id, frames] dF/F

    Returns: [assembly_id, frames] float array.
    '''
    weights = np.asarray(ICA_matrix, dtype=float)
    activity = np.asarray(activity_matrix, dtype=float)
    # One matrix product replaces the per-assembly loop, which allocated a full
    # cells x frames weighted copy of the activity for every assembly.
    return (weights.T @ activity) ** 2


def PCC_score(x, ICA_matrix, activity_matrix):  # xth cell
    '''
    Compute the PCC score of one cell.

    The score is the mean (over assemblies and frames) drop in reactivation strength
    when cell x is removed from both the weights and the activity.

    Input:
    x: id of cell (Python or NumPy integer, 0 <= x < number of cells)
    ICA_matrix: [cell_id, assembly_id] weights
    activity_matrix: [cell_id, frames] dF/F

    Raises:
    AssertionError if x is not an integer; IndexError if x is out of range.
    '''
    # bool is a subclass of int but is not a meaningful cell id
    assert isinstance(x, (int, np.integer)) and not isinstance(x, bool), 'x must be an integer cell id'
    n_cells = ICA_matrix.shape[0]
    if not 0 <= x < n_cells:
        # an out-of-range id would remove no cell and silently score 0
        raise IndexError('cell id {} out of range for {} cells'.format(x, n_cells))
    keep = np.arange(n_cells) != x
    R1 = reactivation_strength(ICA_matrix, activity_matrix)
    R2 = reactivation_strength(ICA_matrix[keep, :], activity_matrix[keep, :])
    return np.mean(R1 - R2)


def _plot_pre_post_strength(pre_trace, post_trace, hse_win, title):
    '''Left: pre/post reactivation traces; right: pre vs post maximum within the central hse_win frames.'''
    def _centre_max(trace):
        mid = len(trace) // 2
        return np.max(trace[mid - hse_win // 2:mid + hse_win // 2])

    fig = plt.figure(figsize=(12, 4))
    gs = fig.add_gridspec(1, 2)
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(np.arange(len(pre_trace)), pre_trace, label='Pre')
    ax1.plot(np.arange(len(post_trace)), post_trace, label='Post')
    ax1.set_title(title, fontsize=20)
    plt.setp(ax1.get_xticklabels(), visible=False)
    ax1.set_ylabel('Run ensemble reactivation', fontsize=15)
    ax1.set_xlabel('Peri-HSE time(s)')
    ax1.tick_params(axis='y', which='major', labelsize=12)
    ax1.legend()

    ax2 = fig.add_subplot(gs[0, 1])
    pre_peak, post_peak = _centre_max(pre_trace), _centre_max(post_trace)
    ax2.scatter(pre_peak, post_peak, marker='+')
    # identity line over at least the original 0-1.9 range, extended to include the point
    lo = np.nanmin([0.0, pre_peak, post_peak])
    hi = np.nanmax([1.9, pre_peak, post_peak])
    ax2.plot([lo, hi], [lo, hi], linestyle='--')
    ax2.set_ylabel('Post within-HSE \n run ensemble reactivation', fontsize=15)
    ax2.tick_params(axis='both', which='major', labelsize=12)
    ax2.set_xlabel('Pre within-HSE \n run ensemble reactivation', fontsize=15)
    plt.show()


def assembly_activation_strength(ICA_matrix, pre_hse_matrix, post_hse_matrix, hse_win, plot_num=None):
    '''
    Compute the activation strength of each assembly and optionally plot it.

    The reactivation strength of each assembly is z-scored over frames (axis=1).
    The traces are smoothed with a Gaussian (sigma=5 frames) for plotting only.

    Inputs:
    ICA_matrix: [cell_id, assembly_id] weights
    pre_hse_matrix: [cell_id, frames] dF/F of all HSEs in pre
    post_hse_matrix: [cell_id, frames] dF/F of all HSEs in post
    hse_win: window length of HSE (central frames used for the within-HSE maximum)
    plot_num: if int, make the plot of plot_num(th) assembly; if 'plot_average'
        (or 'plot average'), make the average plot of all assemblies; otherwise no plot

    Returns:
    (pre_assembly_strength, post_assembly_strength): [assembly_id, frames] z-scored
    strengths (unsmoothed).
    '''
    # hse_matrix: 1s length
    pre_assembly_strength = zscore(reactivation_strength(ICA_matrix, pre_hse_matrix), axis=1)
    post_assembly_strength = zscore(reactivation_strength(ICA_matrix, post_hse_matrix), axis=1)

    if isinstance(plot_num, (int, np.integer)) and not isinstance(plot_num, bool):
        _plot_pre_post_strength(gaussian_filter1d(pre_assembly_strength[plot_num, :], sigma=5),
                                gaussian_filter1d(post_assembly_strength[plot_num, :], sigma=5),
                                hse_win, 'Run ensemble ' + str(plot_num))
    elif isinstance(plot_num, str) and plot_num in ('plot_average', 'plot average'):
        _plot_pre_post_strength(gaussian_filter1d(np.mean(pre_assembly_strength, axis=0), sigma=5),
                                gaussian_filter1d(np.mean(post_assembly_strength, axis=0), sigma=5),
                                hse_win, 'Run ensemble average')

    return pre_assembly_strength, post_assembly_strength


### Assembly analysis (Edgar's code)

In [None]:
def _plot_assembly_overview(z_matrix, covariance_matrix, eig_vals, q, weights):
    '''Four diagnostic panels: row-normalized z-scored matrix, covariance, eigenvalue histogram vs the Marchenko-Pastur bound, ICA weights.'''
    fig = plt.figure(figsize=(12, 8))
    gs = fig.add_gridspec(2, 2)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])

    show_z = np.zeros(z_matrix.shape)
    for i in range(show_z.shape[0]):
        show_z[i, :] = normalize(z_matrix[i, :])
    im = ax1.imshow(show_z, aspect='auto', interpolation='None', cmap='cividis')
    ax1.set_title('Z-scored Spike Matrix', fontsize=20)
    ax1.set_ylabel('Neuron', fontsize=15)
    ax1.set_xlabel('Time(0.025s bins)', fontsize=15)
    cbar = plt.colorbar(im, ax=ax1)
    cbar.ax.set_ylabel('Normalized(Z score)', fontsize=15)

    im = ax2.imshow(covariance_matrix, aspect='auto', interpolation='None', cmap='cividis')
    ax2.set_title('Covariance Matrix', fontsize=20)
    ax2.set_ylabel('Neuron', fontsize=15)
    ax2.set_xlabel('Neuron', fontsize=15)
    cbar = plt.colorbar(im, ax=ax2)
    cbar.ax.set_ylabel('Covariance', fontsize=15)

    # Eigenvalue histogram in 0.2-wide bins. Covariance eigenvalues are >= 0 up to
    # rounding, so tiny negatives are clipped into the first bin.
    bin_w = 0.2
    edges = np.arange(0, np.max(eig_vals) + 1, bin_w)
    idx = (np.clip(eig_vals, 0, None) // bin_w).astype(int)
    counts = np.bincount(idx, minlength=len(edges))[:len(edges)].astype(float)
    probs = counts / np.sum(counts)
    # width matches the 0.2 bin spacing (the default 0.8 made neighbouring bars overlap)
    ax3.bar(edges, probs, width=bin_w, align='edge', color='orange', label='empirical data')
    ax3.axvline((1 + np.sqrt(1 / q)) ** 2, linestyle='--', label=r'max $\lambda$')
    ax3.legend()
    ax3.set_title('Marchenko-Pastur Theorem', fontsize=20)
    ax3.set_xlabel(r'$\lambda$', fontsize=15)
    ax3.set_ylabel(r'prob[$\lambda$]', fontsize=15)

    im = ax4.imshow(weights.T, aspect='auto', interpolation='None', cmap='cividis')
    ax4.set_title('Assembly Contribution', fontsize=20)
    ax4.set_xlabel('Neuron', fontsize=15)
    ax4.set_ylabel('Assembly', fontsize=15)
    cbar = plt.colorbar(im, ax=ax4)
    cbar.ax.set_ylabel('ICA weight', fontsize=15)


def _plot_top_assemblies(matrix, weights, assembly_strength, bin_size, size=100):
    '''Raster around assembly 0's strongest bin, with each assembly's top-3-weight neurons coloured, plus strength traces and weight stems (up to 3 assemblies).'''
    colors = ['C0', 'C1', 'C2']
    n_show = min(len(colors), weights.shape[1])
    n_neurons = matrix.shape[0]

    # top-3 neurons by weight of each shown assembly; an earlier assembly's colour wins on overlap
    color_seq = ['gray'] * n_neurons
    for a in reversed(range(n_show)):
        for neuron in np.argsort(weights[:, a])[::-1][:3]:
            color_seq[neuron] = colors[a]

    # window of +-size bins around assembly 0's peak, clipped so it never wraps around
    peak_bin = int(np.argmax(assembly_strength[0, :]))
    start = max(peak_bin - size, 0)
    stop = min(peak_bin + size, matrix.shape[1])

    fig = plt.figure(figsize=(8, 8))
    gs = fig.add_gridspec(3, 3)
    ax2 = fig.add_subplot(gs[2, :2])
    ax1 = fig.add_subplot(gs[:2, :2], sharex=ax2)
    plt.subplots_adjust(wspace=0.3, hspace=0.3)
    ax5 = fig.add_subplot(gs[2, 2])
    ax3 = fig.add_subplot(gs[0, 2], sharex=ax5, sharey=ax5)
    ax4 = fig.add_subplot(gs[1, 2], sharex=ax5, sharey=ax5)
    weight_axes = [ax3, ax4, ax5]

    for neuron in range(n_neurons):
        active = np.flatnonzero(matrix[neuron, start:stop] > 0)
        ax1.scatter(active, np.full(active.shape, neuron), s=3, color=color_seq[neuron])

    neuron_ids = np.arange(n_neurons)
    for a in range(n_show):
        ax2.plot(assembly_strength[a, start:stop])
        ax = weight_axes[a]
        # stems and markers share the same y (neuron index)
        ax.hlines(y=neuron_ids, xmin=0, xmax=weights[:, a], color=colors[a], alpha=0.7)
        ax.plot(weights[:, a], neuron_ids, "o", color=colors[a], markersize=3)
        ax.set_ylabel('Neuron', fontsize=10)
        ax.set_title('Assembly {}'.format(a + 1), fontsize=10)
    for ax in weight_axes[n_show:]:
        ax.set_visible(False)

    ax5.set_xlim(-1, 1)
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax3.get_xticklabels(), visible=False)
    plt.setp(ax4.get_xticklabels(), visible=False)
    ax5.set_yticks(np.arange(0, 25, 5))
    ax1.set_yticks(np.arange(0, 25, 5))
    ax1.set_ylabel('Neuron', fontsize=10)
    ax2.set_ylabel('Assembly \n Strength (z)', fontsize=10)
    ax2.set_xlabel('Time ({} s bins)'.format(bin_size), fontsize=10)


def EB_assembly_analysis(data):
    '''
    Detect cell assemblies and plot diagnostics.

    Steps:
    1) sum `data` over 25-frame bins (labelled as 0.025 s bins), smooth each row with a
       Gaussian (sigma = 1 bin) and z-score each row over the whole recording;
    2) count the covariance eigenvalues above the Marchenko-Pastur bound returned by
       findMaxEval; that count is the number of assemblies;
    3) run FastICA with that many components on the z-scored matrix and compute each
       assembly's time-resolved strength with reactivation_strength.

    Input:
    data: [neurons, frames] activity (e.g. deconvolved spikes)

    Returns:
    X_transformed: [neuron, assembly] ICA weights
    assembly_strength: [assembly, time bins] squared projection of the z-scored matrix
        onto each assembly's weights

    Raises:
    ValueError if no eigenvalue exceeds the Marchenko-Pastur bound.
    '''
    bin_size = 0.025  # seconds per bin; used for axis labels
    matrix = binning(data, 25)  # [neurons, time bins]
    print('binning finished')
    matrix = gaussian_filter1d(matrix, sigma=1)
    print('Smoothing finished')

    # z-scoring is on the full matrix on purpose (null mean, unit variance). Rows with
    # zero variance z-score to NaN; zero them so they do not make the covariance NaN.
    z_matrix = np.nan_to_num(stats.zscore(matrix, axis=1))
    print('zscore finished')
    covariance_matrix = np.cov(z_matrix)

    # The covariance is symmetric, so eigvalsh returns real eigenvalues (np.linalg.eig
    # may return a complex dtype). eigvalsh sorts them ascending.
    # NOTE(review): findMaxEval is assumed to be order-invariant — confirm.
    eig_vals = np.linalg.eigvalsh(covariance_matrix)
    print('Got eigenvalues')
    # q = Ncol / Nrow of the spike matrix, used by the Marchenko-Pastur estimate
    q = z_matrix.shape[1] / z_matrix.shape[0]
    eMax0, var0 = findMaxEval(eig_vals, q, bWidth=0.01)
    n_assemblies = int(np.sum(eig_vals > eMax0))
    print('Got number of significant assemblies')
    if n_assemblies == 0:
        raise ValueError('no covariance eigenvalue exceeds the Marchenko-Pastur bound; no assemblies to extract')

    # ICA is run directly on the z-scored matrix. Per the original author's note,
    # PCA-reduction first (Lopes dos Santos et al. 2013) gave essentially the same
    # weights, apart from occasional sign flips.
    transformer = FastICA(n_components=n_assemblies, whiten=True, random_state=101)
    X_transformed = transformer.fit_transform(z_matrix)
    print('fastICA finished')

    assembly_strength = reactivation_strength(X_transformed, z_matrix)
    print('Computed assembly strength')

    _plot_assembly_overview(z_matrix, covariance_matrix, eig_vals, q, X_transformed)
    _plot_top_assemblies(matrix, X_transformed, assembly_strength, bin_size)
    plt.show()

    return X_transformed, assembly_strength

def binning(data, nbins):
    '''
    Sum consecutive, non-overlapping groups of `nbins` columns of `data`.

    Despite its name, `nbins` is the bin *width* (how many consecutive columns
    are summed into one output column), not the number of output bins.

    Parameters
    ----------
    data : 2-D array-like, [cells, activity]
        Each row is binned independently along axis 1.
    nbins : int
        Number of consecutive columns per bin; must be a positive integer.

    Returns
    -------
    np.ndarray of float64, shape (data.shape[0], data.shape[1] // nbins)
        Column k is the sum of data[:, k*nbins:(k+1)*nbins]. Trailing columns
        that do not fill a complete bin are discarded.

    Raises
    ------
    ValueError
        If `nbins` is not a positive integer.
    '''
    if not isinstance(nbins, (int, np.integer)) or nbins < 1:
        raise ValueError(f"nbins must be a positive integer, got {nbins!r}")
    data = np.asarray(data)
    n_cells = data.shape[0]
    bin_num = data.shape[1] // nbins
    # Drop the incomplete trailing bin, then view the remaining columns as
    # (cells, bin_num, nbins) so every bin is summed in one vectorised call.
    trimmed = data[:, :bin_num * nbins]
    binned = trimmed.reshape(n_cells, bin_num, nbins).sum(axis=2)
    # The original implementation always returned float64; keep that contract.
    return binned.astype(float)

In [None]:
## loading data

In [None]:
# Per-session preprocessing: sparse spike train, Gaussian-smoothed trace
# (sigma=2 samples along the last axis) and a per-neuron z-score (axis=1).
# NOTE(review): zscore yields NaN for any neuron whose smoothed trace is
# constant (e.g. no activity in that session) — confirm downstream code copes.
def _spike_smooth_zscore(spks):
    """Return (spike_train, smoothed, zscored) for one session's spike matrix."""
    spike_train = get_sparse_spike(spks, nbins=15, std=2)
    smoothed = gaussian_filter1d(spks, sigma=2)
    zscored = zscore(smoothed, axis=1)
    return spike_train, smoothed, zscored


pre_spike_train, pre_smooth, pre_zscore = _spike_smooth_zscore(Spks_pre)
post_spike_train, post_smooth, post_zscore = _spike_smooth_zscore(Spks_post)

# Earlier variants (not run): smoothing FNc_pre / FNc_post with
# gaussian_smooth(nbins=15, std=1), and the same pipeline for a rest period
# extracted with get_rest(..., Spks_behave).

In [None]:
# Detect HSEs in the pre and post sessions from the z-scored traces and spike
# trains; upper_bound=3 is forwarded to detect_HSE (its meaning is defined there).
(pre_HSE, pre_HSE_peaks, pre_lower, pre_upper,
 pre_max_bins) = detect_HSE(pre_zscore, pre_spike_train, pcs, upper_bound=3)
(post_HSE, post_HSE_peaks, post_lower, post_upper,
 post_max_bins) = detect_HSE(post_zscore, post_spike_train, pcs, upper_bound=3)
# Rest-period detection (detect_HSE on rest_zscore / rest_spike_train) is not run.

In [None]:
def _concat_event_windows(trace, events):
    """Concatenate trace[:, start:stop] for every (start, stop) in `events`.

    The windows are joined along axis 1 (time). Returns None when `events` is
    empty, matching the previous behaviour of this cell.
    """
    windows = [trace[:, event[0]:event[1]] for event in events]
    # A single np.concatenate avoids re-copying the growing array on every
    # event. It also always returns a copy, so the result never aliases `trace`.
    return np.concatenate(windows, axis=1) if windows else None


# Despite the `_dff` suffix, these hold the z-scored traces
# (pre_zscore / post_zscore) restricted to the HSE windows.
pre_hse_dff = _concat_event_windows(pre_zscore, pre_HSE)
post_hse_dff = _concat_event_windows(post_zscore, post_HSE)

In [None]:
# Compare pre- vs post-session activity of the neurons selected by `pcs`
# (presumably place cells), together with their PeakMap1 entries.
plot_PC_corr(
    pre_zscore[pcs, :],
    post_zscore[pcs, :],
    PeakMap1[pcs],
    dis_bin=2,
)

In [None]:
# Heatmaps around each HSE with neurons sorted by peak location (PeakMap1).
# With if_rank=True, HSE_heatmap appears to return a sequence: the pre call
# unpacks two values and the post call keeps only the first.
heatmap_kwargs = dict(sort='peak', peaks=PeakMap1, win_bins=500,
                      plot_aver=False, plot_all=True, if_rank=True)
list_pre, rank = HSE_heatmap(pre_HSE, pre_HSE_peaks, pre_zscore, pre_lower, pre_upper,
                             pcs=pcs, **heatmap_kwargs)
# NOTE(review): unlike the pre call, `pcs` is not passed here, so the two lists
# may cover different neuron sets — confirm this is intentional.
list_post = HSE_heatmap(post_HSE, post_HSE_peaks, post_zscore, post_lower, post_upper,
                        **heatmap_kwargs)[0]

In [None]:
# HSE counts over time for both sessions; each session's length is the number
# of columns (samples) of its z-scored trace.
session_lengths = np.array([trace.shape[1] for trace in (pre_zscore, post_zscore)])
HSE_counts_overtime(pre_HSE_peaks, post_HSE_peaks, session_lengths)

In [None]:
# Split neurons into HSE-modulated / not-modulated for each session.
# plot_all=True / plot_aver=False select which figures find_mod_cell draws.
mod_cell_kwargs = dict(win_size=300, plot_aver=False, plot_all=True, exceed_percent=0.5)
pre_mod_cell, pre_not_mod_cell = find_mod_cell(
    pre_HSE_peaks, pre_zscore, title='Pre ', **mod_cell_kwargs)
post_mod_cell, post_not_mod_cell = find_mod_cell(
    post_HSE_peaks, post_zscore, title='Post ', **mod_cell_kwargs)

In [None]:
# Same classification as the previous cell, re-run with plot_aver=True /
# plot_all=False to draw the other set of figures. The results overwrite
# pre_/post_mod_cell and pre_/post_not_mod_cell.
mod_cell_kwargs = dict(win_size=300, plot_aver=True, plot_all=False, exceed_percent=0.5)
pre_mod_cell, pre_not_mod_cell = find_mod_cell(
    pre_HSE_peaks, pre_zscore, title='Pre ', **mod_cell_kwargs)
post_mod_cell, post_not_mod_cell = find_mod_cell(
    post_HSE_peaks, post_zscore, title='Post ', **mod_cell_kwargs)

In [None]:
# Counts passed as [modulated, not modulated] for each session
# (order: pre+, pre-, post+, post-).
pre_counts = [len(pre_mod_cell), len(pre_not_mod_cell)]
post_counts = [len(post_mod_cell), len(post_not_mod_cell)]
plot_mod_cell_num(pre_counts, post_counts)

In [None]:
# Run get_reactivated_cell around every pre-session HSE peak. 30 and 200 are
# passed positionally (their meaning is defined by get_reactivated_cell), and
# the return value is discarded.
for hse_peak in pre_HSE_peaks:
    get_reactivated_cell(hse_peak, pre_zscore, 30, 200)

In [None]:
# Assembly analysis on the behaviour-session spikes; the next cell iterates over
# ICA_weight.shape[1], so columns presumably index assemblies — TODO confirm.
# NOTE(review): the commented rest-period code above spells this `Spks_behave`;
# confirm `Spks_behav` is the name actually defined when loading data.
ICA_weight = EB_assembly_analysis(Spks_behav)

In [None]:
# One assembly_activation_strength call per column of ICA_weight, each using the
# last element of list_pre / list_post (from the HSE_heatmap cell) and a
# 30-sample HSE window.
n_assemblies = ICA_weight.shape[1]
for assembly_idx in range(n_assemblies):
    assembly_activation_strength(
        ICA_weight, list_pre[-1], list_post[-1], hse_win=30, plot_num=assembly_idx
    )