In [None]:
%reload_ext autoreload
%autoreload 2

import os, pickle, sys
import numpy as np
import pandas as pd
import main_funcs as mfun
import utils_funcs as utils
import plot_funcs as pfun
import matplotlib.pyplot as plt
from scipy import stats
import glob
from scipy.signal import resample
from scipy.cluster.hierarchy import linkage, leaves_list
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning) 
warnings.filterwarnings("ignore", category= FutureWarning) 
warnings.filterwarnings("ignore", category= DeprecationWarning)
warnings.filterwarnings("ignore", category= UserWarning)

pfun.set_figure()


## Parameters
fRate = 1000/30
pre_frames, post_frames, analysisWindowDur, simulationDur = pfun.set_analysisParams ()
responsiveness_test_duration = 1000.0 #in ms
simulationDur_ms = 350.0 # in ms 
simulationDur  = int(np.ceil(simulationDur_ms/fRate))
shutterLength  = int(np.round(simulationDur_ms/fRate))
tTypes = ['All','Both', 'onlyOpto', 'onlyVis']

In [None]:
# Loop through all recordings - MUST BE RUN TO UPDATE THE INFO FILE.
# For each "2025*" recording folder one file (extracted_variablesTH.pkl) is saved;
# in its per-cell arrays/lists each index is a cell.
base_dir = r"C:\Users\Huriye\Documents\code\clapfcstimulation\analysis"

# Seed for the random trial subsampling that builds the 'Both' trial set, so that
# re-running this cell reproduces the same p-values and traces.
RANDOM_SEED = 0
# More opto pulses than this is treated as a faulty paq extraction.
MAX_OPTO_PULSES = 500
# Minimum gap (frames) after the previous onset for an onset to start a new trial
# ("should be at least one sec between stim starts").
MIN_ONSET_GAP = 30


def _trials_of_type(trial_types, wanted):
    """Return the indices of the trials labelled `wanted`."""
    return [idx for idx, t_type in enumerate(trial_types) if t_type == wanted]


def _mixed_both_indices(trial_types, rng):
    """Return the sorted (int) trial indices forming the 'Both' analysis set.

    half_n = min(#onlyVis, #Both) // 2 trials are drawn without replacement from
    each of the 'onlyVis' and 'Both' groups using `rng`, then combined.
    Returns an empty int array when nothing can be drawn.
    NOTE(review): confirm this balanced vis/opto mix is the intended 'Both' set.
    """
    visual_inds = _trials_of_type(trial_types, 'onlyVis')
    opto_inds = _trials_of_type(trial_types, 'Both')
    half_n = min(len(visual_inds), len(opto_inds)) // 2
    if half_n == 0:
        return np.array([], dtype=int)
    selected_visual = rng.choice(visual_inds, size=half_n, replace=False)
    selected_opto = rng.choice(opto_inds, size=half_n, replace=False)
    return np.sort(np.concatenate([selected_visual, selected_opto]))


def _trial_indices(trial_type, trial_types, n_trials, both_inds):
    """Return the trial indices for one of tTypes ('All', 'Both', 'onlyOpto', 'onlyVis')."""
    if trial_type == 'All':
        return np.arange(n_trials)
    if trial_type == 'Both':
        return both_inds
    return _trials_of_type(trial_types, trial_type)


def _classify_onsets(onsets, opto_times, vis_times):
    """Label each onset 'Both', 'onlyOpto', 'onlyVis', or 'CHECK' if it matches neither."""
    opto_times = np.asarray(opto_times)
    vis_times = np.asarray(vis_times)
    labels = []
    for onset in onsets:
        has_opto = bool(np.any(opto_times == onset))
        has_vis = bool(np.any(vis_times == onset))
        if has_opto and has_vis:
            labels.append('Both')
        elif has_opto:
            labels.append('onlyOpto')
        elif has_vis:
            labels.append('onlyVis')
        else:
            labels.append('CHECK')
    return labels


for ind, folder_name in enumerate(os.listdir(base_dir)):
    folder_path = os.path.join(base_dir, folder_name)
    if not (os.path.isdir(folder_path) and folder_name.startswith("2025")):
        continue
    # Every 2025* folder is processed (no index-based skipping of the first entry).
    pathname = folder_path + '\\'
    print(str(ind) + ': Creating: ' + pathname)
    print('Creating the dict for:' + pathname)

    ########## Organise stimuli times
    paqData = pd.read_pickle(pathname + 'paq-data.pkl')
    paqRate = paqData['rate']
    frame_clock = utils.paq_data(paqData, 'prairieFrame', threshold_ttl=True, plot=False)
    optoStimTimes = utils.paq_data(paqData, 'optoLoopback', threshold_ttl=True, plot=False)
    if len(optoStimTimes) > MAX_OPTO_PULSES:
        # Skip instead of continuing with the previous recording's trial times/types.
        # No extracted_variablesTH.pkl is (re)written for this folder.
        print('Opto stim times is a lot , check it out : ' + str(ind) + ' -> recording skipped')
        continue

    # The frame_clock is slightly longer in paq as there is up to ~1 s delay from the
    # microscope to the PAQ I/O software (presumably why frame_clock=None is passed).
    stim_kwargs = dict(frame_clock=None, stim_times=None, plane=0, n_planes=1)
    optoStimTimes = utils.stim_start_frame(paq=paqData, stim_chan_name='optoLoopback', **stim_kwargs)
    visStimTimes = utils.stim_start_frame(paq=paqData, stim_chan_name='maskerLED', **stim_kwargs)
    shutterTimes = utils.shutter_start_frame(paq=paqData, stim_chan_name='shutterLoopback', **stim_kwargs)

    # Merge opto and visual onsets and keep only the first onset of each cluster
    # (onsets closer than MIN_ONSET_GAP frames to the previous one are dropped).
    stimOnsets = np.unique(np.concatenate((optoStimTimes, visStimTimes), 0))
    if stimOnsets.size:
        keep = np.concatenate(([0], np.flatnonzero(np.diff(stimOnsets) > MIN_ONSET_GAP) + 1))
        stimOnsets = stimOnsets[keep]
    trialTypes = _classify_onsets(stimOnsets, optoStimTimes, visStimTimes)

    # Trials are aligned to shutter onsets; trialTypes is used to index them 1:1.
    trialStartTimes = np.asarray(shutterTimes)
    if len(trialTypes) != len(trialStartTimes):
        print(f'WARNING {ind}: {len(trialTypes)} trial types vs '
              f'{len(trialStartTimes)} shutter onsets - check trial alignment')

    ########## Organise calcium imaging traces
    pvals, pvalsWilcoxon = {}, {}
    dffTrace, dffTrace_mean, dffTrace_median = {}, {}, {}
    dffAfterStim1500ms_mean = {}
    dffTrace_normalised, dffTrace_mean_normalised = {}, {}

    imData = pd.read_pickle(pathname + 'imaging-data.pkl')
    fluR = imData['flu']
    n_frames = imData['n_frames']
    flu_raw = imData['flu_raw']

    # NaN the shutterLength frames starting at every shutter onset.
    blanked = np.zeros(fluR.shape[1], dtype=bool)
    for sT in shutterTimes:
        blanked[sT:(sT + shutterLength)] = True
    fluR[:, blanked] = np.nan

    # clean detrended traces
    flu = utils.clean_traces(fluR)
    flu_normalised = mfun.norm_to_zero_one(flu)
    nCell = len(flu)
    print('number of cell: ' + str(nCell))

    rng = np.random.default_rng(RANDOM_SEED)
    tBoth = _mixed_both_indices(trialTypes, rng)
    n_trials = len(trialStartTimes)
    test_frames = int(np.ceil(responsiveness_test_duration / fRate))

    ### Responsiveness and dff values for the 4 trial types
    for t in tTypes:
        trialInd = _trial_indices(t, trialTypes, n_trials, tBoth)
        if len(trialInd) <= 1:
            continue
        starts = trialStartTimes[trialInd]
        for testType, store in (('ttest', pvals), ('wilcoxon', pvalsWilcoxon)):
            _, _, pval = utils.test_responsive(flu, frame_clock, starts,
                                               pre_frames=test_frames, post_frames=test_frames,
                                               offset=simulationDur, testType=testType)
            store[t] = pval

        dffTrace[t] = utils.flu_splitter(flu, starts, pre_frames, post_frames)  # Cell x time x trial
        dffTrace_mean[t] = np.mean(dffTrace[t], 2)  # Cell x time
        dffTrace_median[t] = np.median(dffTrace[t], 2)  # Cell x time
        dffAfterStim1500ms_mean[t] = np.nanmean(
            dffTrace_mean[t][:, (pre_frames + simulationDur):(pre_frames + analysisWindowDur)], 1)
        dffTrace_normalised[t] = utils.flu_splitter(flu_normalised, starts, pre_frames, post_frames)
        dffTrace_mean_normalised[t] = np.mean(dffTrace_normalised[t], 2)  # Cell x time

    ########## Organise pupil traces
    pupilTrace, pupilTrace_mean = {}, {}
    pupilTraceVer, pupilTrace_meanVer = {}, {}

    pupilData = pd.read_pickle(pathname + 'pupil-data.pkl')
    pupilrawh = pupilData['horizontalDis']
    pupilrawv = pupilData['verticalDis']
    pupilQualityh = pupilData['horizontallikelihood']
    pupilQualityv = pupilData['verticallikelihood']

    for t in tTypes:
        trialInd = _trial_indices(t, trialTypes, n_trials, tBoth)
        if len(trialInd) <= 1:
            continue
        starts = trialStartTimes[trialInd]
        # Trial-split pupil traces (utils.trace_splitter) and their mean over axis 2.
        pupilTrace[t] = utils.trace_splitter(pupilrawh, starts, pre_frames, post_frames)
        pupilTrace_mean[t] = np.mean(pupilTrace[t], 2)
        pupilTraceVer[t] = utils.trace_splitter(pupilrawv, starts, pre_frames, post_frames)
        pupilTrace_meanVer[t] = np.mean(pupilTraceVer[t], 2)

    ########## Per-cell recording info
    # Placeholder values: the per-recording info (coordinates, animal ID, date, ...)
    # is not looked up here. One entry per row of flu.
    n_rows = nCell if len(imData) > 0 else 0
    x_coordinate = [0] * n_rows
    y_coordinate = [0] * n_rows
    z_coordinate = [0] * n_rows
    animalID = [25000] * n_rows
    stimuliFamilarity = [21] * n_rows
    dataQuality = [1] * n_rows
    recData = [0] * n_rows
    recID = [0] * n_rows
    cellID = list(range(n_rows))

    # Save outputs for population analysis. The list order is the file format read
    # back by index in the merge / stats cells - do not reorder.
    savename = os.path.join(folder_path, 'extracted_variablesTH.pkl')
    with open(savename, 'wb') as f:
        pickle.dump([trialStartTimes, trialTypes, pvals, dffTrace,
                     dffTrace_mean, dffAfterStim1500ms_mean,
                     x_coordinate, y_coordinate, z_coordinate,
                     animalID, stimuliFamilarity, pvalsWilcoxon, dataQuality,
                     recData, recID, cellID, pupilTrace, pupilTrace_mean,
                     pupilQualityh, pupilTraceVer, pupilTrace_meanVer, pupilQualityv,
                     dffTrace_normalised, dffTrace_mean_normalised, dffTrace_median], f)

print('All should be done!!')

In [None]:
## Merge all datapoints & save them in a dictionary
# Concatenate the per-recording extracted_variablesTH.pkl files into population-level
# trace arrays and per-cell lists, then pickle them into base_dir.

# Positions of the fields inside each extracted_variablesTH.pkl list
# (matches the pickle.dump order used when the file is written).
_EXT_IDX = {'pvals': 2, 'dff_mean': 4, 'x': 6, 'y': 7, 'z': 8, 'animalID': 9,
            'stimuliFamilarity': 10, 'dataQuality': 12, 'recData': 13, 'recID': 14,
            'cellID': 15, 'dff_mean_normalised': 23, 'dff_median': 24}
# Frames per trial-averaged trace passed to mfun.update_dff_traces.
# NOTE(review): presumably pre + post frames (60 + 180 at 1000/30 ms per frame) - confirm.
TRACE_N_FRAMES = 240
# Value stored for cells whose recording has no p-value for a trial type.
PVAL_PLACEHOLDER = 5


def _update_triplet(triplet, trace_dict, shape):
    """Merge one recording's 'Both'/'onlyVis'/'onlyOpto' traces into the running (Both, Vis, Opto) arrays."""
    both, vis, opto = triplet
    both = mfun.update_dff_traces(both, trace_dict, 'Both', shape)
    vis = mfun.update_dff_traces(vis, trace_dict, 'onlyVis', shape)
    opto = mfun.update_dff_traces(opto, trace_dict, 'onlyOpto', shape)
    return both, vis, opto


def _pvals_list(pvals, trial_type, n_cells):
    """Per-cell p-values for `trial_type` as a list, or n_cells placeholders if absent.

    Padding with n_cells entries (not a single one) keeps the p-value lists aligned
    with the (n_cells, TRACE_N_FRAMES)-shaped trace rows added for the same recording.
    """
    if trial_type in pvals:
        return np.asarray(pvals[trial_type]).tolist()
    return [PVAL_PLACEHOLDER] * n_cells


def _save_pickle(filename, payload):
    """Pickle `payload` to base_dir\\filename (overwrites an existing file)."""
    with open(base_dir + '\\' + filename, 'wb') as f:
        pickle.dump(payload, f)


# Initialise accumulators (None = nothing merged yet; lists grow per cell).
dff_traceBoth, dff_traceVis, dff_traceOpto = None, None, None
dff_traceBoth_median, dff_traceVis_median, dff_traceOpto_median = None, None, None
dff_traceBoth_normalised, dff_traceVis_normalised, dff_traceOpto_normalised = None, None, None
pupil_traceBoth, pupil_traceVis, pupil_traceOpto = None, None, None
pupilID, x_coordinate, y_coordinate, z_coordinate = [], [], [], []
animalID, stimuliFamilarity, dataQuality, recData, recID, cellID = [], [], [], [], [], []
pvalsBoth, pvalsVis, pvalsOpto = [], [], []
dff_meanBothValue, dff_meanVisValue, dff_meanOptoValue = [], [], []

ty = 0
indPupil = 0

base_dir = r"C:\Users\Huriye\Documents\code\clapfcstimulation\analysis"

# Loop through folders starting with "2025" (same os.listdir order as the other cells,
# which the per-cell alignment of the saved files relies on).
for ind, folder_name in enumerate(os.listdir(base_dir)):
    folder_path = os.path.join(base_dir, folder_name)
    if not (os.path.isdir(folder_path) and folder_name.startswith("2025")):
        continue
    pathname = folder_path + '\\'
    extData = pd.read_pickle(pathname + 'extracted_variablesTH.pkl')
    sys.stdout.write(f'\rExtraction started for : {ind}')
    sys.stdout.flush()

    # At least one row per recording, so a recording without cells still adds a row.
    len_cellID = max(len(extData[_EXT_IDX['cellID']]), 1)
    shape = (len_cellID, TRACE_N_FRAMES)

    # Trial-averaged dff traces: mean, median and 0-1 normalised mean.
    dff_traceBoth, dff_traceVis, dff_traceOpto = _update_triplet(
        (dff_traceBoth, dff_traceVis, dff_traceOpto), extData[_EXT_IDX['dff_mean']], shape)
    dff_traceBoth_median, dff_traceVis_median, dff_traceOpto_median = _update_triplet(
        (dff_traceBoth_median, dff_traceVis_median, dff_traceOpto_median),
        extData[_EXT_IDX['dff_median']], shape)
    dff_traceBoth_normalised, dff_traceVis_normalised, dff_traceOpto_normalised = _update_triplet(
        (dff_traceBoth_normalised, dff_traceVis_normalised, dff_traceOpto_normalised),
        extData[_EXT_IDX['dff_mean_normalised']], shape)

    # Pupil traces
    indPupil, pupilID, pupil_traceBoth, pupil_traceVis, pupil_traceOpto = mfun.update_pupil_traces(
        extData, len_cellID, indPupil, pupilID, pupil_traceBoth, pupil_traceVis, pupil_traceOpto)

    # P values: a missing trial type is padded per cell instead of raising KeyError.
    pvals = extData[_EXT_IDX['pvals']]
    missing = [key for key in ('Both', 'onlyVis', 'onlyOpto') if key not in pvals]
    if missing:
        print(f'\nWARNING: {folder_name} has no p-values for {missing}; '
              f'storing {PVAL_PLACEHOLDER} for its {len_cellID} cell(s)')
    pvalsBoth += _pvals_list(pvals, 'Both', len_cellID)
    pvalsVis += _pvals_list(pvals, 'onlyVis', len_cellID)
    pvalsOpto += _pvals_list(pvals, 'onlyOpto', len_cellID)

    # Cell-related information
    x_coordinate += extData[_EXT_IDX['x']]
    y_coordinate += extData[_EXT_IDX['y']]
    z_coordinate += extData[_EXT_IDX['z']]
    animalID += extData[_EXT_IDX['animalID']]
    stimuliFamilarity += extData[_EXT_IDX['stimuliFamilarity']]
    dataQuality += extData[_EXT_IDX['dataQuality']]
    recData += extData[_EXT_IDX['recData']]
    recID += extData[_EXT_IDX['recID']]
    cellID += extData[_EXT_IDX['cellID']]

########################################
####################### Organise & save files

# NOTE: these redefine the notebook-level pre_frames / post_frames / simulationDur
# set in the parameter cell, with values computed from fixed ms durations here.
fRate = 1000/30.0
pre_frames = int(np.ceil(2000.0 / fRate))     # 2000 ms
post_frames = int(np.ceil(6000.0 / fRate))    # 6000 ms
analysis_time = int(np.ceil(1500.0 / fRate))  # 1500 ms
simulationDur_ms = 350.0  # ms
simulationDur = int(np.ceil(simulationDur_ms / fRate))
pupil_resample_num = int(6*5)

### Normalise to baseline - MEAN (order: Vis, Opto, Both)
dff_traceVis_normalisedtopre, dff_traceOpto_normalisedtopre, dff_traceBoth_normalisedtopre = (
    mfun.normalize_to_baseline(tr, pre_frames) for tr in (dff_traceVis, dff_traceOpto, dff_traceBoth))
### Normalise to baseline - MEDIAN
dff_traceVis_normalisedtopre_median, dff_traceOpto_normalisedtopre_median, dff_traceBoth_normalisedtopre_median = (
    mfun.normalize_to_baseline(tr, pre_frames)
    for tr in (dff_traceVis_median, dff_traceOpto_median, dff_traceBoth_median))

### Pupil traces: baseline-normalise, then resample to pupil_resample_num samples along time
pupil_traceVis, pupil_traceOpto, pupil_traceBoth = (
    mfun.normalize_to_baseline(tr, pre_frames) for tr in (pupil_traceVis, pupil_traceOpto, pupil_traceBoth))
pupil_traceOpto = resample(pupil_traceOpto, pupil_resample_num, axis=1)
pupil_traceVis = resample(pupil_traceVis, pupil_resample_num, axis=1)
pupil_traceBoth = resample(pupil_traceBoth, pupil_resample_num, axis=1)

# Per-cell mean dff from stimulus onset through stimulation + analysis window.
post_window = slice(pre_frames, pre_frames + simulationDur + analysis_time)
dff_meanBoth = np.nanmean(dff_traceBoth[:, post_window], axis=1)
dff_meanVis = np.nanmean(dff_traceVis[:, post_window], axis=1)
dff_meanOpto = np.nanmean(dff_traceOpto[:, post_window], axis=1)

zdff_traceBoth = stats.zscore(dff_traceBoth, nan_policy='omit')
zdff_traceVis = stats.zscore(dff_traceVis, nan_policy='omit')
zdff_traceOpto = stats.zscore(dff_traceOpto, nan_policy='omit')

# Tuple orders below are the file formats read by downstream cells - do not reorder.
_save_pickle('infoForAnalysisTH-readyForSelectingInterestedCells.pkl',
             (animalID, stimuliFamilarity, dataQuality,
              recData, recID, cellID,
              pvalsBoth, pvalsVis, pvalsOpto,
              dff_meanVis, dff_meanBoth,
              dff_meanOpto, pupilID))
_save_pickle('infoForAnalysisTH-readyForPlotting.pkl',
             (dff_traceBoth, dff_traceVis, dff_traceOpto))
_save_pickle('infoForAnalysisTH-readyForPlotting_median.pkl',
             (dff_traceBoth_median, dff_traceVis_median, dff_traceOpto_median))
_save_pickle('infoForAnalysisTH-readyForPlotting_normalised.pkl',
             (dff_traceBoth_normalised, dff_traceVis_normalised, dff_traceOpto_normalised))
_save_pickle('infoForAnalysisTH-readyForPlotting_normalisedtoPre.pkl',
             (dff_traceBoth_normalisedtopre, dff_traceVis_normalisedtopre, dff_traceOpto_normalisedtopre))
_save_pickle('infoForAnalysisTH-readyForPlotting_normalisedtoPre_median.pkl',
             (dff_traceBoth_normalisedtopre_median, dff_traceVis_normalisedtopre_median,
              dff_traceOpto_normalisedtopre_median))
_save_pickle('infoForAnalysisTH-readyForPlotting_position.pkl',
             (x_coordinate, y_coordinate, z_coordinate))
_save_pickle('infoForAnalysisTH-readyForPlottingPupil.pkl',
             (pupil_traceVis, pupil_traceBoth, pupil_traceOpto))
print('\nAll should be done!!')
    


In [None]:
# More stats: per-cell variance, cross-correlation, SNR, absolute magnitude, MI and
# DTW for each trial type, computed from the single-trial traces (index 3 of
# extracted_variablesTH.pkl, Cell x time x trial).

pre_frames, post_frames, analysis_frame, simulationDur = pfun.set_analysisParams ()
pre_analysisWindow = np.arange(pre_frames - analysis_frame, pre_frames)
post_analysisWindow = np.arange((pre_frames + simulationDur), (pre_frames + simulationDur + analysis_frame))

STAT_DATASETS = ['Both', 'onlyVis', 'onlyOpto']

# One list of per-cell values per trial type, concatenated across recordings.
variance_dict_pre = {d: [] for d in STAT_DATASETS}
variance_dict_post = {d: [] for d in STAT_DATASETS}
mi_dict = {d: [] for d in STAT_DATASETS}
snr_dict = {d: [] for d in STAT_DATASETS}
abs_dict = {d: [] for d in STAT_DATASETS}
crosscorr_dict_pre = {d: [] for d in STAT_DATASETS}
crosscorr_dict_post = {d: [] for d in STAT_DATASETS}
dtw_dict = {d: [] for d in STAT_DATASETS}
_all_stat_dicts = (variance_dict_pre, variance_dict_post, crosscorr_dict_pre, crosscorr_dict_post,
                   snr_dict, abs_dict, mi_dict, dtw_dict)

# Same os.listdir order as the merge cell, so per-cell entries stay aligned.
for ind, folder_name in enumerate(os.listdir(base_dir)):
    folder_path = os.path.join(base_dir, folder_name)
    if not (os.path.isdir(folder_path) and folder_name.startswith("2025")):
        continue
    pathname = folder_path + '\\'
    extData = pd.read_pickle(pathname + 'extracted_variablesTH.pkl')
    sys.stdout.write(f'\rMore stats are calculating for : {ind}')
    sys.stdout.flush()
    len_cellID = len(extData[15])  # cellID list: one entry per cell of the recording
    dff_trace = extData[3]

    for dataset in STAT_DATASETS:
        if dataset not in dff_trace:
            # Trial type absent in this recording: pad every metric with one NaN per
            # cell so the per-cell lists stay aligned across recordings.
            placeholder = [np.nan] * len_cellID
            for stat_dict in _all_stat_dicts:
                stat_dict[dataset] += placeholder
            continue

        flu = dff_trace[dataset]
        # Variance
        variance_dict_pre[dataset] += mfun.variance_cell_rates(flu, pre_analysisWindow).tolist()
        variance_dict_post[dataset] += mfun.variance_cell_rates(flu, post_analysisWindow).tolist()
        # Cross-correlation
        crosscorr_value_pre = mfun.mean_cross_correlation(flu, pre_analysisWindow)
        crosscorr_dict_pre[dataset] += crosscorr_value_pre.tolist()
        crosscorr_value_post = mfun.mean_cross_correlation(flu, post_analysisWindow)
        crosscorr_dict_post[dataset] += crosscorr_value_post.tolist()
        # SNR
        snr_dict[dataset] += mfun.calculate_SNR(flu, pre_analysisWindow, post_analysisWindow).tolist()
        # Absolute value
        abs_dict[dataset] += mfun.calculate_absMagnitude(flu, pre_analysisWindow, post_analysisWindow).tolist()
        # MI
        mi_dict[dataset] += mfun.calculate_MI(flu, pre_analysisWindow, post_analysisWindow).tolist()
        # DTW
        dtw_dict[dataset] += mfun.calculate_per_cell_dtw(flu, post_analysisWindow).tolist()

filenameINFO = base_dir + '\\infoForAnalysisTH-readyForPlotting_moreStats.pkl'
with open(filenameINFO, 'wb') as f:
    pickle.dump((variance_dict_pre, variance_dict_post, snr_dict, mi_dict,
                 crosscorr_dict_pre, crosscorr_dict_post, abs_dict, dtw_dict), f)

print('\nCompleted')

In [None]:
# Supplementary figure : Chrimson in Th mice - Naive population analysis Increase variance in response

interestedCohort = 'Th'
interestedTrainedLevel = 'Naive'

pfun.set_figure()
# 8,11 for full A4 page; figsize=(6.85, 9.05) for full page, (3.35, 9.05) for single page
fig = plt.figure(constrained_layout=False, figsize=(16, 18))


def _axes_for(gs):
    """Create one subplot per cell of `gs` on `fig`, keyed 0..nrows*ncols-1."""
    n_rows, n_cols = gs.get_geometry()
    return {k: fig.add_subplot(gs[k]) for k in range(n_rows * n_cols)}


def _bold_text(ax, x, y, text, transform=None):
    """Draw a bold, top-right-anchored, unboxed annotation on `ax` (axes coords by default)."""
    return ax.text(x, y, text,
                   transform=ax.transAxes if transform is None else transform,
                   fontsize=16, fontweight='bold', va='top', ha='right',
                   bbox=dict(facecolor='none', edgecolor='none', boxstyle='round,pad=0.1'))


# ---- Layout: one gridspec per panel group ----
gs_visHeatmap = fig.add_gridspec(ncols=3, nrows=1, bottom=0.65, top=0.95, left=0.05, right=0.46,
                                 wspace=0.1, hspace=0.2)
gs_visHeatmapCax = fig.add_gridspec(ncols=1, nrows=1, bottom=0.65, top=0.95, left=0.48, right=0.50,
                                    wspace=0.2, hspace=0.4)
gs_visuaPlots = fig.add_gridspec(ncols=3, nrows=2, bottom=0.65, top=0.95, left=0.56, right=0.95,
                                 wspace=0.3, hspace=0.4)
gs_optoPlots = fig.add_gridspec(ncols=5, nrows=1, bottom=0.46, top=0.59, left=0.0, right=0.95,
                                width_ratios=[0.8, 0.8, 1, 1, 1], wspace=0.45, hspace=0.4)
gs_bothPlots = fig.add_gridspec(ncols=5, nrows=1, bottom=0.27, top=0.40, left=0.0, right=0.95,
                                width_ratios=[0.8, 0.8, 1, 1, 1], wspace=0.45, hspace=0.4)

# Panel A: heatmap for visual responsive cells
ax_gs_visHeatmap = _axes_for(gs_visHeatmap)
cax = _axes_for(gs_visHeatmapCax)
colorbarlimitsForHeatMap = [-1, 1]
pfun.heatmap_comparison('Visual', 'Visual + Opto', sortType='Visual', cohort=interestedCohort,
                        trainedLevel=interestedTrainedLevel, condition='Sensory',
                        colormapSelection='OptoProject', axis=ax_gs_visHeatmap, cbar_ax=cax[0],
                        savefigname=None, savefigpath=None, colorbarlimits=colorbarlimitsForHeatMap)

# Panels B-G: visual responsive cells
plotParams = {
    'ylimitsforhist': [0, 50],
    'xlimitsforhist': [-0.75, 0.75],
    'analysis_time': 1500,  # in ms
    'colorbarlimitsForHeatMap': [-2, 2],
    'scatterplotlimits': [-4.5, 4.5],
    'ylimitsforECDF': [0.5, 1.05],
    'xlimitsforABS': [-0.05, 1.2],
    'ylimitsforCV': [0.1, 0.15],
    'faceColors': ['black', 'red'],
}
ax_gs_visuaPlots = _axes_for(gs_visuaPlots)
pfun.population_plots('Visual', 'Visual + Opto', sortType='Visual', cohort=interestedCohort,
                      trainedLevel=interestedTrainedLevel, condition='Sensory', plotParams=plotParams,
                      axisAll=ax_gs_visuaPlots, savefigname=None, savefigpath=None)

# Opto responsive cells
ax_gs_optoPlots = _axes_for(gs_optoPlots)
plotParams['faceColors'] = ['black', 'blue']
plotParams['ylimitsforhist'] = [0, 100]
pfun.population_plots('Visual', 'Visual + Opto', sortType='Visual', cohort=interestedCohort,
                      trainedLevel=interestedTrainedLevel, condition='Opto', plotParams=plotParams,
                      axisAll=ax_gs_optoPlots, savefigname=None, savefigpath=None)

# Opto-boosted cells
ax_gs_bothPlots = _axes_for(gs_bothPlots)
plotParams['faceColors'] = ['black', 'green']
plotParams['ylimitsforhist'] = [0, 20]
pfun.population_plots('Visual', 'Visual + Opto', sortType='Visual', cohort=interestedCohort,
                      trainedLevel=interestedTrainedLevel, condition='Opto-boosted', plotParams=plotParams,
                      axisAll=ax_gs_bothPlots, savefigname=None, savefigpath=None)

# Panel letters, assigned in axis-creation order across all groups. Empty strings
# leave a panel unlabelled; the three heatmap panels use a tighter offset.
axes = [ax_gs_visHeatmap, ax_gs_visuaPlots, ax_gs_optoPlots, ax_gs_bothPlots]
labels = ['A','B','C','D','','E','F','G','','H','','I','J','K','L','','M','N','O','P',
          'Q','R','S','T','U','V','W','X','Y','','Z']
panel_axes = [panel for group in axes for panel in group.values()]
for lInd, panel in enumerate(panel_axes):
    x_off, y_off = (-0.04, 1.04) if lInd < 3 else (-0.08, 1.1)
    _bold_text(panel, x_off, y_off, labels[lInd])

# Section titles. The first one is attached to axes[1][1] but positioned in
# axes[1][0] coordinates.
_bold_text(axes[1][1], 2.5, 1.3, 'Sensory responsive neurons', transform=axes[1][0].transAxes)
_bold_text(axes[2][1], 3.5, 1.2, 'Opto responsive neurons')
_bold_text(axes[3][1], 3.5, 1.2, 'Opto-boosted neurons')

plt.tight_layout()
savefigpath = r'G:\My Drive\Manuscripts\0 CLAStPFC\panels_raw2'
savefigname = 'Supp_Th_NaivePopulationAnalysisRawSorted'
pfun.save_figure(savefigname, savefigpath)
plt.close()