In [4]:
import pandas as pd; pd.set_option('display.max_columns', 30); pd.set_option('display.max_rows', 100)
import numpy as np
from cmlreaders import CMLReader, get_data_index
from ptsa.data.filters import ButterworthFilter, ResampleFilter, MorletWaveletFilter
import xarray as xarray
import sys
import os
import matplotlib.pyplot as plt
from pylab import *
from copy import copy
from scipy import stats
from scipy.stats import zscore
import seaborn as sns       
import pickle
plt.rcParams['pdf.fonttype'] = 42; plt.rcParams['ps.fonttype'] = 42 # fix fonts for Illustrator
sys.path.append('/home1/john/johnModules')
from brain_labels import HPC_labels, ENT_labels, PHC_labels, temporal_lobe_labels,\
                         MFG_labels, IFG_labels, nonHPC_MTL_labels
from general import *
from SWRmodule import *
import statsmodels.formula.api as smf
from ripples_HFA_SME import ripple_analysis_SME
from ripples_HFA_analysis import ripple_HFA_analysis

In [5]:
class ripple_analysis_SCE(ripple_analysis_SME):
    
    def __init__(self, exp, df, sub_selection, sr_factor=2, ripple_bin_start_end=[100,1700], HFA_bins=[400,1100], 
                 ripple_sampling_time=2.0, pre_encoding_time=-700, encoding_time=2300, bin_size=100,
                smoothing_triangle=5, samples=100, regions_selected=['ca1', 'dg']):

        super().__init__(exp, df, sub_selection, ripple_bin_start_end, HFA_bins, ripple_sampling_time
                            , pre_encoding_time, encoding_time, bin_size,
                            smoothing_triangle, samples, regions_selected)
        
        self.psth_start = self.pre_encoding_time
        self.sr_factor = sr_factor
        
    def remove_subject_sessions(self):
        
        super().remove_subject_sessions()
    
    def load_data_from_cluster(self, selected_period, region_name='HPC'):
        
        super().load_data_from_cluster(selected_period, region_name)
        
    def getStartArray(self):
        
        super().getStartArray()
        
    def semantic_clustering(self):
        
        self.num_sessions = 0 
        
        self.counter_delta = 0 
        
        # select which serialpositions you're looking at (since curious if 1-6 show all the SCE)
        serialpos_select = np.arange(1,13) 

        # EF1208, what is this?
        remove_chaining = 0 # 2022-07-19 trying a control to see if SCE still exists after removing recalls that begin with SP 1+2 in a row
        
        # these values are all for subject-level SCE v. avg_recalls analysis
        if self.sub_selection == 'whole':
            min_SCE_trials = 20 # minimum SCE trials in session to include in SCE v. avg_recalls plot
        else:
            min_SCE_trials = 10
            
        stats_bin = self.ripple_bin_start_end[1]-self.ripple_bin_start_end[0] # only using 1 bin for encoding 

        self.adj_semantic_encoding_array = []
        self.rem_semantic_encoding_array = []
        self.rem_unclustered_encoding_array = []
        self.last_recall_encoding_array = [] # the last word remembered on each list (no transitions)...but make sure it's not an intrusion or repeat too!
        self.forgot_encoding_array = []
        self.sub_name_array0 = []; self.sess_name_array0 = []; self.elec_name_array0 = []
        self.sub_name_array1 = []; self.sess_name_array1 = []; self.elec_name_array1 = []
        self.sub_name_array2 = []; self.sess_name_array2 = []; self.elec_name_array2 = []
        self.sub_name_array3 = []; self.sess_name_array3 = []; self.elec_name_array3 = []
        self.sess_name_array4 = [] # forgot why I keep the others but leaving them 2022-06-10
        self.sub_name_array5 = []; self.sess_name_array5 = []; self.elec_name_array5 = []

        # for clustered v. unclustered subject-level analysis (need to record at session-level though for mixed model)
        self.sess_sessions = []
        self.sess_delta = []
        self.sess_subjects = []
        self.sess_recall_num = []
        self.sess_clust_num = []
        self.sess_prop_semantic = []
        
        session_names = np.unique(self.session_name_array)
        
        # EF1208, converting to numpy for parallel indexing 
        for sess in session_names:
            
            self.num_sessions += 1
            
            # for each session will get a clustered and unclustered start array to get the delta SCE
            start_arrayC = []; start_arrayU = []
            
            # and also the proportion of semantically clustered recalls
            temp_corr = []; temp_sem_key = []
            
            # select list numbers within the session, either 12 or 25
            sess_list_nums = np.unique(self.list_num_key_np[self.session_names_np==sess]) 
             
            for ln in sess_list_nums:
                
                list_elec_array = np.unique(self.electrode_array_np[(self.session_names_np==sess) & (self.list_num_key_np==ln)])
                
                for elec in list_elec_array:
                
                    # only select data from a given session, given list number, and given electrode 
                    # list_ch_idxs is a boolean array, with the number of True elements equal to the list length (12)
                    list_ch_idxs = (self.session_names_np==sess) & (self.list_num_key_np==ln) & (self.electrode_array_np==elec)
                    list_ch_encoding_array = self.start_array_np[list_ch_idxs]  # start time of ripples
                    list_ch_cats = self.cat_array_np[list_ch_idxs] # semantic category of presented words
                    list_ch_corr = self.word_correct_array_np[list_ch_idxs] # binary array, whether or not the word was correctly recalled 
                    list_ch_semantic_key = self.semantic_array_np[list_ch_idxs] # list of containing semantic category of recalled items
                    list_ch_recall_positions = self.recall_position_np[list_ch_idxs] # list containing encoded position of recalled words
                    
                    # remove ones starting with serialpos 1->2 as a control (or just 1 if it's len 1)
                    # EF1208, why are we doing this?
                    if remove_chaining == 1:
                        if len(list_ch_recall_positions[0])==1:
                            if list_ch_recall_positions[0][0]==1: # if 1st serialpos
                                continue # get out of this loop if only one recall and it's serialpos 1
                        elif len(list_ch_recall_positions[0])>0:
                            if ((list_ch_recall_positions[0][0]==1)&(list_ch_recall_positions[0][1]==2)):     
                                continue # get out of loop if recalls are serialpos 1->2 (no matter what)            

                    for i_recall_type, recall_type in enumerate(list_ch_semantic_key[0]): # all 12 lists have same values so just take 1st one
                        recall_position = list_ch_recall_positions[0][i_recall_type] # ditto re: taking 1st
                        if recall_position in serialpos_select: # so can select by serialpos (e.g. 1:6 or 7:12)
                            if recall_type == 'A': # adjacent semantic and adjacent in time 
        #                     if recall_type in ['A','C']: # adjacent AND remote semantic
                                # note the -1 since recall positions are on scale of 1-12
                                self.adj_semantic_encoding_array = superVstack(self.adj_semantic_encoding_array,list_ch_encoding_array[recall_position-1])
                                self.sub_name_array0.append(sess[0:6])
                                self.sess_name_array0.append(sess)
                                self.elec_name_array0.append(elec)
                            elif recall_type == 'C': # remote semantic, remote in time but from the same semantic category 
                                self.rem_semantic_encoding_array = superVstack(self.rem_semantic_encoding_array,list_ch_encoding_array[recall_position-1])
                                self.sub_name_array1.append(sess[0:6])
                                self.sess_name_array1.append(sess)
                                self.elec_name_array1.append(elec)
                            elif ( (recall_type == 'D') ): # & (recall_position>0) ): # remote unclustered
                                self.rem_unclustered_encoding_array = superVstack(self.rem_unclustered_encoding_array,list_ch_encoding_array[recall_position-1])
                                self.sub_name_array2.append(sess[0:6])
                                self.sess_name_array2.append(sess)  
                                self.elec_name_array2.append(elec)
                            elif ( (recall_type == 'Z') ): #& (recall_position>0) ): # last word of list & was actually a recalled word
                                self.last_recall_encoding_array = superVstack(self.last_recall_encoding_array,list_ch_encoding_array[recall_position-1])
                                self.sub_name_array3.append(sess[0:6])
                                self.sess_name_array3.append(sess)
                                self.elec_name_array3.append(elec)
                            else:
                                self.sess_name_array4.append(sess[0:6])
                                
                        # I also want to do an SCE v. avg_recall_num by subject analysis
                        # to do this will combine A/C as SCE group and combine D/Z as control  
                        
                        # Creating clustered vs unclustered conditioned 
                        # start_arracyC will be of shape N x ripple_start:ripple_end, where N is the number of clustered recalls
                        # and ripple_start:ripple_end is the timestep range of interest for ripples
                        if recall_position in serialpos_select: # so can select by serialpos (e.g. 1:6 or 7:12)
                            if recall_type in ['A','C']: # adjacent semantic or remote semantic
                                # note the -1 since recall positions are on scale of 1-12
                                start_arrayC = superVstack(start_arrayC,list_ch_encoding_array[recall_position-1])
                            elif ( (recall_type in ['D','Z']) & (recall_position>0) ): # remote unclustered or dead end (>0 means recalled word)
                                start_arrayU = superVstack(start_arrayU,list_ch_encoding_array[recall_position-1])
                                
                                
                    # unpack semantic clustering key to trial level (only need to do once for one electrode)
                    if elec == list_elec_array[0]:
                        for word in range(sum(list_ch_idxs)): 
                            if (word+1) in list_ch_recall_positions[0]: # serial positions are 1-indexed so add 1 to check in list_ch_recall_positions
                                temp_corr.append(1)
                                # use index from serialpos to get clustering classification
                                if ((sess== 'R1108J-2')&(ln==25)): # single mistake shwos up
                                    if word == 8:
                                        temp_sem_key.append('A')
                                    elif word == 9:
                                        temp_sem_key.append('Z')
                                else: 
                                    temp_sem_key.append(list_ch_semantic_key[0][list_ch_recall_positions[0].index(word+1)])
                            else:
                                temp_corr.append(0)
                                temp_sem_key.append('')               
                                
                    # make forgotten array to plot along with SCE too which is easy enough 
                    forgotten_words = 1-np.array(list_ch_corr)
                    if sum(forgotten_words)>0: # R1065 a whiz
                        self.forgot_encoding_array = superVstack(self.forgot_encoding_array,np.array(list_ch_encoding_array)[findInd(forgotten_words),:])
                        self.sub_name_array5.extend(np.tile(sess[0:6],int(sum(forgotten_words))))
                        self.sess_name_array5.extend(np.tile(sess,int(sum(forgotten_words))))
                        self.elec_name_array5.extend(np.tile(elec,int(sum(forgotten_words))))
                        
            self.start_arrayC = start_arrayC
            self.start_arrayU = start_arrayU

            if ( (len(start_arrayC)>min_SCE_trials) & (len(start_arrayU)>min_SCE_trials) & (len(start_arrayC)!=1500) ): # last one in there for a len(1) start_arrayC

                # back at session-level record the delta, sub, sess, and avg_recall_num for *all* trials
                self.sess_sessions.append(sess)
                self.sess_subjects.append(sess[0:6])   

                # can just use list_elec_array to select only one electrode we know exists for this session (altho should be irrelevant when we average anyyway)
                sess_word_correct_array = self.word_correct_array[((self.electrode_array==list_elec_array[0]) & (self.session_name_array==sess))]
                self.sess_recall_num.append(12*sum(sess_word_correct_array)/len(sess_word_correct_array)) # *12 to convert to recall_num/list
                
                # while temp_sem_key is NOT in the same order as sess_word_correct_array, can use it anyway to get proportion of clustered recalls
        #         if sum(temp_corr) == sum(sess_word_correct_array):
                self.sess_prop_semantic.append(sum([trial in ['A','C'] for trial in temp_sem_key])/sum(temp_corr))
        #         else:
        #             print('temp_sem_key must be messed up for '+sess+'!!')
        #             print(sum(temp_corr))
        #             print(sum(sess_word_correct_array))

                # calculate single delta ripples value for clust v. not clust for this session
                
                # obtain ripple start times in desired time range
                # and create histogram for ripples associated with clustered + 
                # unclustered words
                self.start_arrayC = start_arrayC
                self.start_arrayU = start_arrayU
                self.binned_stats_arrayC = binBinaryArray(start_arrayC,stats_bin,self.sr_factor)
                self.binned_stats_arrayU = binBinaryArray(start_arrayU,stats_bin,self.sr_factor)
                delta = np.mean(self.binned_stats_arrayC)-np.mean(self.binned_stats_arrayU)
                if delta == 0:
                    print(sess)
                    print(np.sum(list_ch_encoding_array))
                self.sess_delta.append(np.mean(self.binned_stats_arrayC)-np.mean(self.binned_stats_arrayU)) # the diff. in ripple rate for clustered and unclustered recalls
                
        trial_nums = [len(self.sub_name_array0),len(self.sub_name_array1),len(self.sub_name_array2),len(self.sub_name_array3),len(self.sess_name_array4)]
        

In [3]:
# init variables
df = get_data_index("r1") # all RAM subjects
exp = 'catFR1' # 'FR1' 'catFR1' 'RepFR1'
ripple_bin_start_end = [100, 1700]
region_name = 'AMY'
regions_selected = ['ca1']
rs_str = ''
num_regions = len(regions_selected)
for i, region in enumerate(regions_selected):
    rs_str += region
    if i + 1 < len(regions_selected):
        rs_str += '_'         
SCE = ripple_analysis_SCE(exp=exp, df=df, sub_selection='first half', ripple_bin_start_end=ripple_bin_start_end, regions_selected=regions_selected)
SCE.remove_subject_sessions()
SCE.load_data_from_cluster(selected_period='encoding', region_name=region_name)

  ((self.df.subject!='R1239E') | (self.df.session!=0)) # some correlated noise (can see in catFR1 problem sessions ppt)


catFR1
2023-04-27_19-59-25: DF Exception: Sub: R1004D, Sess: 0, FileNotFoundError, [Errno 2] No such file or directory: '/scratch/john/[100, 1700]/ENCODING/SWR_catFR1_R1004D_0_HPC_encoding_soz_in_hamming.p', file: ripples_HFA_analysis.py, line no: 167
2023-04-27_19-59-25: DF Exception: Sub: R1024E, Sess: 0, FileNotFoundError, [Errno 2] No such file or directory: '/scratch/john/[100, 1700]/ENCODING/SWR_catFR1_R1024E_0_HPC_encoding_soz_in_hamming.p', file: ripples_HFA_analysis.py, line no: 167
2023-04-27_19-59-25: DF Exception: Sub: R1032D, Sess: 0, FileNotFoundError, [Errno 2] No such file or directory: '/scratch/john/[100, 1700]/ENCODING/SWR_catFR1_R1032D_0_HPC_encoding_soz_in_hamming.p', file: ripples_HFA_analysis.py, line no: 167
2023-04-27_19-59-25: DF Exception: Sub: R1032D, Sess: 1, FileNotFoundError, [Errno 2] No such file or directory: '/scratch/john/[100, 1700]/ENCODING/SWR_catFR1_R1032D_1_HPC_encoding_soz_in_hamming.p', file: ripples_HFA_analysis.py, line no: 167
2023-04-27_19

IndexError: too many indices for array

In [6]:
SCE.loaded_files

0

In [17]:
SCE.getStartArray()
SCE.select_idxs_numpy()
SCE.semantic_clustering()

Total # of ripples: 5971
(9459, 800)
(9459, 800)
(9459,)




R1361C-2
0
R1426N-0
0
R1448T-0
0
R1456D-3
0
R1465D-2
0
R1482J-0
0
R1482J-1
0
R1482J-2
0
R1501J-6
0
R1501J-8
0


In [8]:
SCE.sess_delta

[0.2055921052631579,
 0.16596889952153104,
 0.24944125159642405,
 0.0921474358974359,
 -0.08220502901353965,
 0.2597402597402596,
 0.07121091679915209,
 0.06304347826086953,
 0.3130951047384338,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [61]:
SCE.sess_delta

[0.2055921052631579,
 0.16596889952153104,
 0.24944125159642405,
 0.0,
 0.0921474358974359,
 -0.08220502901353965,
 0.2597402597402596,
 0.07121091679915209,
 0.06304347826086953,
 0.3130951047384338,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [89]:
np.unique(SCE.list_num_key_np[SCE.session_names_np=='R1015J-0']) 
SCE.session_names_np

array(['R1015J-0', 'R1015J-0', 'R1015J-0', ..., 'R1015J-0', 'R1015J-0',
       'R1015J-0'], dtype='<U9')

In [59]:
SCE.list_num_key[SCE.session_name_array=='R1015J-0']
SCE.session_name_array=='R1015J-0'

False

In [None]:
def stats_SCE_(self):
        
        save_plot = 1
        save_data_df = 1

        sess_df = pd.DataFrame({'delta_ripple_rate':self.sess_delta,'avg_recall_num':self.sess_recall_num,
                                'session':self.sess_sessions,'subject':self.sess_subjects,'prop_semantic':self.sess_prop_semantic})
        print('Mixed model of ripple_rate ~ avg_recall_num at session-level')        
        vc = {'session':'0+session'}
        sig_bin_model = smf.mixedlm("delta_ripple_rate ~ avg_recall_num", sess_df, groups="subject", vc_formula=vc,re_formula="avg_recall_num")
        bin_model1 = sig_bin_model.fit(reml=True, method='nm',maxiter=2000) 
        print(bin_model1.summary())
        print(bin_model1.pvalues)
        print(bin_model1.params)
        print(bin_model1.bse_fe)

        # want to plot at subject-level even though DF above is at subject_level
        sns.set(rc={'figure.figsize':(11.7,8.27),"font.size":20,"axes.titlesize":20,"axes.labelsize":20},style="white")

        if location_selected == 'dg':
            min_rate = 0.1
            plot_color = [(0.8,0,0)]
            lmplot_color = (0.8,0,0,0)
            errorbar_color = (1,0.5,0.5)
            line_color = 'darkred'
        elif location_selected == 'ca1':
            min_rate = 0.1 # for CA1 and DG I require a FR minimum for the two bins being compared since I do further stats on them
            plot_color = [(0,0,0.8)]
            lmplot_color = (0,0,0.8,0)
            errorbar_color = (0.5,0.5,1)
            line_color = 'darkblue'
        else:
            lmplot_color = (1,0,1)
            line_color = (1,0,1)    

        # convert session-level df to subject-level for plot
        sub_delta = []
        sub_recall_num = []
        sub_prop_semantic = []
        for sub in np.unique(self.sess_subjects):
            sub_delta.append(np.mean(np.array(self.sess_delta)[np.array(self.sess_subjects) == sub]))
            sub_recall_num.append(np.mean(np.array(self.sess_recall_num)[np.array(self.sess_subjects) == sub]))
            sub_prop_semantic.append(np.mean(np.array(self.sess_prop_semantic)[np.array(self.sess_subjects) == sub]))

        set_pubfig()
        deltaRR_df = pd.DataFrame({'avg_recall_num':sub_recall_num,'ripple_rates':sub_delta,'prop_semantic':sub_prop_semantic})
        fig, ax = subplots(figsize=(4,4))
        sns.regplot(ax=ax,x='avg_recall_num',y='ripple_rates',data=deltaRR_df,
                scatter_kws = {'color': lmplot_color},line_kws = {'color': line_color})

        if self.exp == 'FR1':
            gca().set(ylim=(-0.31,0.31),xlim=(0.5,7.5))
            xticks(np.arange(1,7.5+0.01,1))
        elif self.exp == 'catFR1':
            gca().set(ylim=(-0.35,0.35),xlim=(0.5,9.5))
            xticks(np.arange(1,9.5+0.01,1))
        gca().tick_params(labelsize=12)

        plot_corr = stats.pearsonr(sub_recall_num,sub_delta)

        annotate('Correlation: '+str(np.round(plot_corr[0],3)),(3.5,-0.25)) 
        # annotate('p-values: '+str(np.round(plot_corr[1],3)),(3.5,-0.29)) # use mixed model p-value instead:
        # annotate('$\it{P}$-value: '+str(np.round(bin_model1.pvalues[1],3)),(3.5,-0.29))
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)   
        tight_layout()

        ylabel(u'SCE Δ ripple rate (Hz)')
        xlabel('Average recalls/list')
        if type(location_selected)==list:
            location_selected = location_selected[0]+location_selected[1]
            
        if save_plot == 1:
            if type(location_selected)==list:
                location_selected = 'ca1dg'
            fname = '/home1/efegghi/SWR/figures/subject-level_plots/SCE_v_recall_'+ self.exp+'_'+ self.region_name+'_'+\
                location_selected+'_'+ self.filter_type+'_'+ self.sub_selection+'.pdf'
            savefig(fname,format='pdf',transparent=True)
        if save_data_df == 1:
            if type(location_selected)==list:
                location_selected = 'ca1dg'
            fname = '/home1/efegghi/SWR/figures/subject-level_plots/SCE_v_recall_'+ self.exp+'_'+ self.region_name+'_'+\
                location_selected+'_'+ self.filter_type+'_'+ self.sub_selection+'.p'
            with open(fname,'wb') as f:
                pickle.dump({'sess_df':sess_df}, f)
        
    def plot_figs_woah(self):
    
        location_selected = ''
    
         # (make sure you run semantic clustering for catFR cell first)
        save_fig = 0
        plot_SE = 0 # if you only want to do stats making this 0 will be much quicker since it doesn't do the ME model at each bin
                    # this must be 1 if plot_ME_mean is 1
        plot_ME_mean = 0 # 0 for typical PSTH; 1 for ME mean; 2 for average across sub averages

        # get PSTH for two of the types
        plot_three = [5,4,6] ## for clust v. unclust model: [5,4,6] # SCE contrast: [4,6]; SME contrast: # [5,4]
        # 0 is adjacent_semantic, 1 is remote_semantic, 
        # 2 is remote unclustered, 3 is dead end, 
        # 4 is combined remote unclustered & dead end
        # 5 is not recalled
        # 6 is combined clustering (adjacent semantic and remote semantic)

        bin_size = 100 # in ms
        smoothing_triangle = 5 # triangular smoothing window width
        pad = int(np.floor(smoothing_triangle/2)) # factor is how many points from middle does triangle go?
        text_height = 0.21

        ripple_bin_start_end = [100,1700]
        encoding_range = range(int((-self.pre_encoding_time+ripple_bin_start_end[0])/self.sr_factor),int((-self.pre_encoding_time+ripple_bin_start_end[1])/self.sr_factor)) 
        # comes to 400:1200 for 100 to 1700 of 1500 len timeseries

        ME_start_array = []
        ME_sub_name_array = []
        ME_session_name_array = []
        ME_indicator = [] # keeps track of trial type in comparison defined above (e.g. clustered v. not)

        # for recalled and then forgotten words
        for i_array,array_num in enumerate(plot_three):
            if array_num == 0:
                temp_start_array = self.adj_semantic_encoding_array
        #         label = 'List words lead to \\textbf{semantic transition}'        
                label = 'List words lead to \\textbf{adjacent semantic}'
                plot_color = (1,0.33,0)
            elif array_num == 1:
                temp_start_array = self.rem_semantic_encoding_array
                label = 'List words lead to \\textbf{remote semantic}'
                plot_color = (0,0.66,1)
            elif array_num == 2:
                temp_start_array = self.rem_unclustered_encoding_array
                label = 'List words lead to \\textbf{remote unclustered}'
                plot_color = (0,0.3,0)
            elif array_num == 3:
                temp_start_array = self.last_recall_encoding_array
                label = 'List words lead to \\textbf{dead end}'
                plot_color = (0,0,0.3)
            elif array_num == 5:
                temp_start_array = self.forgot_encoding_array
                label = 'List words later \\textbf{not recalled}'
                plot_color = (.66,0.33,0)
                if location_selected == '':
                    text_height = 0.125        
            elif array_num == 4:
                temp_start_array = superVstack(self.rem_unclustered_encoding_array,self.last_recall_encoding_array)
                label = 'List words lead to \\textbf{unclustered recalls}'
                plot_color = (0.5,0.5,0.5)
                self.sub_name_array4 = self.sub_name_array2+self.sub_name_array3
                self.sess_name_array4 = self.sess_name_array2+self.sess_name_array3
                self.elec_name_array4 = self.elec_name_array2+self.elec_name_array3
                if location_selected == '':
                    text_height = 0.125
            elif array_num == 6:
                temp_start_array = superVstack(self.adj_semantic_encoding_array,self.rem_semantic_encoding_array)
                label = 'List words lead to \\textbf{clustered recalls}'
                if location_selected == 'ca1':
                    plot_color = (0,0.66,1)
                    text_height = 0.21
                elif location_selected == 'dg':
                    plot_color = (1,0.33,0.66)
                    text_height = 0.21
                else:
                    plot_color = (1,0,1)
                    text_height = 0.125
                self.sub_name_array6 = self.sub_name_array0+self.sub_name_array1
                self.sess_name_array6 = self.sess_name_array0+self.sess_name_array1
                self.elec_name_array6 = self.elec_name_array0+self.elec_name_array1        
            else:
                print('not using this array_num, pick another my guy')
                break
                
            sub_name_array = eval('self.sub_name_array'+str(array_num))
            sess_name_array = eval('self.sess_name_array'+str(array_num))
            elec_name_array = eval('self.elec_name_array'+str(array_num))
            
            # update accumulating totals for ME model
            ME_start_array = superVstack(ME_start_array,temp_start_array)
            ME_sub_name_array = np.concatenate((ME_sub_name_array,sub_name_array))
            ME_session_name_array = np.concatenate((ME_session_name_array,sess_name_array))
            ME_indicator.extend(np.repeat(i_array,len(temp_start_array)))

            if array_num == plot_three[0]:
                # for plot
                subplots(1,1,figsize=(5,3.75))
                text(-600,text_height,label,usetex=True,size=16,color=plot_color) 
                text(-600,text_height-0.03,'Number of trials: '+str(temp_start_array.shape[0]),color=plot_color,size=12)
            elif array_num == plot_three[1]:
                text(-600,text_height-0.08,label,usetex=True,size=16,color=plot_color) 
                text(-600,text_height-0.11,'Number of trials: '+str(temp_start_array.shape[0]),color=plot_color,size=12)
            else:
                text(-600,0.05,label,usetex=True,size=16,color=plot_color) 
                text(-600,0.02,'Number of trials: '+str(temp_start_array.shape[0]),color=plot_color,size=12)        
                
            # make a PSTH based on start times of SWRs

            PSTH,bin_centers = fullPSTH(temp_start_array,bin_size,smoothing_triangle,self.sr,self.pre_encoding_time)

            # get binned start array (trials X bins)
            binned_start_array = binBinaryArray(temp_start_array,bin_size,self.sr_factor)

            print('done making binned start_array with shape:')
            print(binned_start_array.shape)

            # note that output is the net ± distance from mean
            if plot_SE == True:
                mean_plot,SE_plot = getMixedEffectMeanSEs(binned_start_array,sub_name_array,sess_name_array) #,elec_name_array)
                print('SEs created!')
                
            if plot_ME_mean == 1:
                PSTH = triangleSmooth(mean_plot,smoothing_triangle) # replace PSTH with means from ME model (after smoothing as usual)
            elif plot_ME_mean == 2: # average across subs first
                temp_means = []
                for sub in np.unique(sub_name_array):
                    temp_means = superVstack(temp_means,np.mean(binned_start_array[np.array(sub_name_array)==sub],0))
                PSTH = triangleSmooth(np.mean(temp_means,0),smoothing_triangle)
                SE_sub_level = np.std(temp_means,0)/np.sqrt(len(temp_means))
            # plot
            xr = bin_centers #np.arange(psth_start,psth_end,binsize)
            xr = xr[pad:-pad]
            if pad > 0:
                binned_start_array = binned_start_array[:,pad:-pad] # remove edge bins
                PSTH = PSTH[pad:-pad] 
                if plot_SE == True:
                    SE_plot = SE_plot[:,pad:-pad]
            
            plot(xr,PSTH,color=plot_color)
            if plot_SE == True:
                fill_between(xr, PSTH-SE_plot[0,:], PSTH+SE_plot[0,:], alpha = 0.3, facecolor=plot_color)
                
        xticks(np.arange(self.pre_encoding_time+pad*100,self.encoding_time-pad*100+1,500),
            np.arange((self.pre_encoding_time+pad*100)/1000,(self.encoding_time-pad*100)/1000+1,500/1000))
        xlabel('Time from word presentation (s)',fontsize=14)
        ylabel('Ripple rate (events/s)',fontsize=14)
        ax = plt.gca()
        ylim = 0.5
        ax.set_ylim(0,ylim)
        ax.set_xlim(self.pre_encoding_time,self.encoding_time)
        plot([0,0],[ax.get_ylim()[0],ax.get_ylim()[1]],linewidth=1,linestyle='-',color=(0,0,0))
        plot([1600,1600],[ax.get_ylim()[0],ax.get_ylim()[1]],linewidth=1,linestyle='--',color=(0.7,0.7,0.7))
        xticks(fontsize=12)
        yticks(fontsize=12)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)   
        tight_layout()

        # cell below uses comp_str so grab it regardless
        if type(location_selected)==list:
            location_selected = 'ca1dg'
            
        if len(plot_three)==3:
            comp_str = str(plot_three[0])+'-'+str(plot_three[1])+'-'+str(plot_three[2])
        elif len(plot_three)==2:
            comp_str = str(plot_three[0])+'-'+str(plot_three[1])
                
        if save_fig == 1:
            # get strings for path name for save and loading cluster data

            path_name = '/home1/efeghhi/SWR/figures/ENCODING/'
            fn = os.path.join(path_name,
                'semantic_clustering_'+self.exp+'_'+self.region_name+'_'+\
                location_selected+'_'+self.filter_type+'_'+self.sub_selection+comp_str+'_'+str(plot_ME_mean)+'.pdf')
            plt.savefig(fn,transparent=True)  

In [104]:
a = np.array([0,1,2])
a[[1, 1, 0]]

array([1, 1, 0])

In [None]:
Sess R1221P-0,  delta 0.2467105263157895
Sess R1221P-1,  delta 0.46949760765550236
Sess R1221P-2,  delta -0.012771392081736888
Sess R1239E-1,  delta 0.0
Sess R1269E-0,  delta 0.10683760683760685
Sess R1269E-2,  delta -0.017730496453900735
Sess R1278E-0,  delta 0.3084415584415585
/home1/efeghhi/.conda/envs/env1/lib/python3.7/site-packages/ipykernel_launcher.py:383: RuntimeWarning: invalid value encountered in double_scalars
Sess R1278E-10,  delta -0.014573396926338078
Sess R1332M-0,  delta -0.17463768115942035
Sess R1332M-1,  delta 0.41111235577461613
Sess R1361C-2,  delta 0.0
Sess R1426N-0,  delta 0.0
Sess R1448T-0,  delta 0.0
Sess R1456D-3,  delta 0.0
Sess R1465D-2,  delta 0.0
Sess R1482J-0,  delta 0.0
Sess R1482J-1,  delta 0.0
Sess R1482J-2,  delta 0.0
Sess R1501J-6,  delta 0.0
Sess R1501J-8,  delta 0.0

In [34]:
if not True:
    print("Hi")