In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from itertools import combinations
import copy
import pickle

from scipy.ndimage import gaussian_filter1d
from scipy.signal import butter, filtfilt
from scipy.stats import linregress

# Make the local helper modules (data_analysis_helperfuncs) importable
import sys

sys.path.insert(0, '/Users/charliehuang/Documents/Photometry_pipeline/data_analysis_helperfuncs')
import behav_data_analysis as bd
import dlc_helper as dh
import behav_helper as bh
import cube_helper as ch
import photom_helper as ph
import statistics_helper as sh

%load_ext autoreload

%autoreload 2
import importlib
importlib.reload(bd)
%config IPCompleter.greedy=True

# Important parameters

In [None]:
# Trial windows in manipulandum samples (manipulandum sampled at MANIP_FPS Hz)
BACK_WINDOW = 1000
PRE_MOVE_WINDOW = 140 #0.7 seconds
FORWARD_WINDOW = 1000

# Sampling rates used to convert manipulandum-sample windows into photometry frames
MANIP_FPS = 200
PHOTOM_FPS = 30

# The same windows expressed in photometry frames. Integer arithmetic avoids the
# truncation hazard of int(30*(x/200)) when the float product lands just below a whole number.
p_BACK_WINDOW = PHOTOM_FPS * BACK_WINDOW // MANIP_FPS
p_FORWARD_WINDOW = PHOTOM_FPS * FORWARD_WINDOW // MANIP_FPS
p_PRE_MOVE_WINDOW = PHOTOM_FPS * PRE_MOVE_WINDOW // MANIP_FPS

datapath = '/Users/charliehuang/Documents/python_work/data/Photometry'
manip_folder = '/Photometry_Manipulandum'
photom_addon = '_2C3T4B'
fluor_folder = '/Photometry_Fluorescence'
output_path = datapath + '/Outputs'

misc_pkl_folder = datapath + '/misc_pickles'
oreg_pkl_file = '/outlier_regions_dictionary.pkl'
blacklist_files_m = ['/RR20240320_J_2024-04-26.csv']  # no push trials in this session

# Outlier regions per session, keyed '<mouse> <date>' (e.g. '/RR20240320_J 2024_04_26');
# each value is a list of (start, end) photometry-frame pairs, consumed by within_oreg.
# NOTE: pickle.load can execute arbitrary code -- only load pickles this pipeline produced.
with open(misc_pkl_folder + oreg_pkl_file, 'rb') as f:
    loaded_oreg_dic = pickle.load(f)

# Important Classes and Helpers

## CLASS - Gen Cage

In [None]:
class Mouse:
    """One mouse and the sessions recorded for it, keyed by date string."""

    def __init__(self, name, mouse_folder):
        self.name = name
        self.mouse_folder = mouse_folder
        # date -> {'day_dic': ..., 'manip_file': ..., 'photom_df': ...}
        self.day_2_session = {}

    def add_session(self, date, day_dic, manip_file, photom_df):
        """Register (or overwrite) the session recorded on `date`."""
        session = dict(day_dic=day_dic, manip_file=manip_file, photom_df=photom_df)
        self.day_2_session[date] = session

    def days(self):
        """Return a live view of the dates that have a session."""
        return self.day_2_session.keys()
    
class Cage:
    """Collection of Mouse objects keyed by mouse name (a "gen cage")."""

    def __init__(self):
        print("fresh new cage")
        self.name_2_mouse = {}  # mouse name -> Mouse

    def add_mouse(self, mouse: "Mouse"):
        """Add `mouse`, keyed by `mouse.name`; replaces any mouse with the same name.

        `mouse` is required. It is annotated as Mouse rather than defaulting to the
        Mouse class object, which could never be a valid argument.
        """
        self.name_2_mouse[mouse.name] = mouse

    def get_mouse(self, mouse_name):
        """Return the Mouse stored under `mouse_name` (KeyError if absent)."""
        return self.name_2_mouse[mouse_name]

    def mice(self):
        """Return a live view of the stored mouse names."""
        return self.name_2_mouse.keys()
    
def full_mouse_name(mouse_ID):
    """Expand a single-letter mouse ID into its full prefixed name.

    IDs G-K get the 'RR20240320_' prefix, F gets 'RR20231109_', and every other ID
    falls back to 'RR20231108_'.
    """
    if mouse_ID in ('G', 'H', 'I', 'J', 'K'):
        prefix = 'RR20240320_'
    elif mouse_ID == 'F':
        prefix = 'RR20231109_'
    else:
        prefix = 'RR20231108_'
    return prefix + mouse_ID

## CLASS - Sessions Cage

In [None]:
# Folder (under datapath) holding one pickled session per file
pkl_folder = '/Pickles/Manip_BigRun_Pickle'

class sessions_cage:
    """Flat container of sessions keyed directly by session name ('<mouse>-<date>')."""

    def __init__(self):
        self.sessions = dict()  # session name -> session dict

    def add_sess(self, key, session):
        """Store `session` under `key`, overwriting any previous entry."""
        self.sessions.update({key: session})

    def show_sessions(self):
        """Print the names of all stored sessions."""
        print(self.sessions.keys())
    
def load_pickle_file(pkl_file):
    """Load one pickled session from ``datapath + pkl_folder + pkl_file``.

    Args:
        pkl_file (str): file name with a leading '/', e.g. '/<mouse>-<date>.pkl'.

    Returns:
        The unpickled session object.
    """
    full_path = datapath + pkl_folder + pkl_file
    print(full_path)
    # NOTE: pickle.load can execute arbitrary code -- only load pickles this pipeline wrote.
    # The with-block closes the file; no explicit close() is needed.
    with open(full_path, 'rb') as f:
        return pickle.load(f)

In [None]:
sess_cage = sessions_cage()
blacklist = ['RR20240320_J-2024_04_26'] # blacklisted out because NO TRIALS (pushes)
# Load every pickled session in the folder into the cage, keyed by file name minus extension.
# sorted() makes the load (and print) order deterministic.
for file in sorted(os.listdir(datapath+pkl_folder)):
    if file.startswith('.'):
        continue  # hidden files
    # splitext drops only the final extension; split('.')[0] would also cut names containing dots
    key = os.path.splitext(file)[0]
    if key in blacklist:
        continue
    obj = load_pickle_file('/' + file)
    sess_cage.add_sess(key, obj)

In [None]:
# Session names in sorted order -- used throughout for plotting/parsing in a fixed sequence
ordered_sessions = sorted(sess_cage.sessions)
ordered_sessions

# Wrappers

## Manip Wrapper (handles behavior)

In [None]:
def correct_framecount(og_frame_count, mtp):
    """Stitch over discontinuities in a frame-count trace.

    Finds every step whose absolute size exceeds 1 inside og_frame_count[mtp[0]:mtp[1]]
    (e.g. the counter jumping back down to 0). At each such step, every later sample is
    shifted so that the count continues as (previous value + 1). When any step is found,
    the original trace and the trace after each fix are plotted.

    NOTE(review): a step of +2 (a possibly genuine dropped frame) is also stitched over --
    confirm that is desired.

    Args:
        og_frame_count (np.ndarray): 1-D frame-count trace.
        mtp: pair of indices; only jumps inside og_frame_count[mtp[0]:mtp[1]] are detected.

    Returns:
        np.ndarray: the corrected trace, or the input array itself when no jump is found.
    """
    window = og_frame_count[mtp[0]:mtp[1]]
    jump_offsets = np.where(np.abs(np.diff(window)) > 1)[0]
    if not len(jump_offsets):
        return og_frame_count

    plt.figure()
    plt.plot(og_frame_count)
    plt.title('original frame count, mtp: ' + str(mtp))

    fixed = og_frame_count
    for iteration, offset in enumerate(jump_offsets):
        jump_at = offset + mtp[0]
        # Shift the tail so that fixed[jump_at + 1] == fixed[jump_at] + 1
        head = fixed[0:jump_at + 1]
        tail = fixed[jump_at + 1:] - fixed[jump_at + 1] + fixed[jump_at] + 1
        stitched = np.hstack([head, tail])
        assert len(fixed) == len(stitched)
        fixed = stitched
        plt.figure()
        plt.plot(fixed)
        plt.title('fix: iter' + str(iteration))
    return fixed

def manip_wrapper(file_m, path, single_trial_vis = None):
    """Load one manipulandum session file and build its behavioral trial data.

    Run via Big Run Part 1 (below). Reads the raw manipulandum file, fixes and re-zeros
    the frame count, finds trials (waves), builds the behavior cube, and computes the
    distance traces and push endpoints.

    NOTE - several day_dic keys are prefixed with "og". Not every trial (wave) survives
    the later photometry bounds/outlier checks, so the "og" versions hold all trials and
    the non-og versions (the ones actually used) are derived later (Part 3).

    Args:
        file_m (str): file name with a leading '/', e.g. '/RR20240320_J_2024-04-26.csv'.
        path (str): directory containing the file; the path read is ``path + file_m``.
        single_trial_vis: forwarded to ch.number_waveforms_modern as ``sing_trial``
            (presumably selects one trial to visualize -- confirm in cube_helper).

    Returns:
        day_dic (dict) with keys:
            'manip_data'    rows of the file (first row dropped) as a numpy array
            'metadata'      {'manip_trans_points': output of dh.determine_manip_trans}
            'combin_df'     DataFrame of manip_data with named columns; 'frame_count_1'
                            is jump-corrected and zeroed at manip_trans_points[0]
            'col_dic'       column name -> column position in combin_df
            'og_waves', 'og_summary'   from ch.number_waveforms_modern
            'behav_mat'     combin_df.to_numpy()
            'og_wcube_all'  behavior cube from bd.gen_manip_cube
            'og_manip_dist', 'og_endpoints'   from bh.calc_dist_forcube / bh.det_push_endpoints

    Side effects: prints the metadata; correct_framecount plots when it finds frame-count
    jumps, and bh.visualize_behav_cube is called (presumably plotting the cube).
    """

    day_dic = {}

    # Tab-separated manipulandum log read without a header; the first row is dropped below.
    manip_path = path + file_m
    manip_colnames = ['x','y','_','_','_','_','frame_count_1','frame_count_2', 'lick', 'reward', 'robot_state'] #, 'cam_frame']
    manip_df = pd.read_csv(manip_path, sep='\t', lineterminator='\n', header = None)
    manip_data = manip_df.to_numpy()[1:, :]
    day_dic['manip_data'] = manip_data
    combin_df = pd.DataFrame(data=manip_data, columns=manip_colnames)
    # The placeholder '_' columns repeat, so col_dic['_'] ends up pointing at the last one (5).
    col_dic = {col:i for col, i in zip(combin_df.columns,np.arange(combin_df.shape[1]))}

    # Transition points computed from the frame count (column 6, 'frame_count_1');
    # manip_trans[0] and manip_trans[1] bound the range checked for frame-count jumps.
    ref_dic = {'frame_count_1' : 6}
    manip_trans = dh.determine_manip_trans(manip_data[:,ref_dic['frame_count_1']])

    # Only manipulandum transition points are stored (no DLC/camera data for these sessions).
    metadata  = {'manip_trans_points': manip_trans}

    # Stitch over frame-count jumps, then re-zero at the first transition point.
    og_frame_count = combin_df['frame_count_1'].to_numpy()
    new_frame_count = correct_framecount(og_frame_count, manip_trans) # unchanged if there are no jumps
    new_frame_count = new_frame_count - new_frame_count[manip_trans[0]] # frame at manip_trans[0] becomes 0
    combin_df['frame_count_1'] = new_frame_count

    day_dic['metadata'] = metadata
    day_dic['combin_df'] = combin_df
    day_dic['col_dic'] = col_dic

    print(day_dic['metadata'])

    # Trial (wave) detection.
    # NOTE(review): this receives manip_data, not combin_df, so it may see the uncorrected
    # frame count -- confirm that is intended.
    day_dic['og_waves'], day_dic['og_summary'] = ch.number_waveforms_modern(day_dic['manip_data'],  day_dic['metadata'], sing_trial=single_trial_vis,
                                                                      plot=False, trial_front_half=FORWARD_WINDOW, mini_window=True, vel_threshes=[0.035, 0.015])

    day_dic['behav_mat'] = day_dic['combin_df'].to_numpy() 

    # Behavior cube around each og wave (presumably BACK_WINDOW samples before and
    # FORWARD_WINDOW after the wave onset -- see bd.gen_manip_cube).
    day_dic['og_wcube_all'] = bd.gen_manip_cube(day_dic['behav_mat'], day_dic['og_waves'], back_window=BACK_WINDOW, forward_window=FORWARD_WINDOW)
    
    bh.visualize_behav_cube(day_dic['og_wcube_all'], day_dic['combin_df'].columns, BACK_WINDOW, PRE_MOVE_WINDOW, indiv_trials=True, trial_type='all', dlc_flag=False)

    # Distance traces (from the 'x'/'y' columns) and push endpoints for every og trial.
    day_dic['og_manip_dist'] = bh.calc_dist_forcube(day_dic['og_wcube_all'], day_dic['col_dic'], 'x','y', BACK_WINDOW, PRE_MOVE_WINDOW)
    day_dic['og_endpoints'] = bh.det_push_endpoints(day_dic['og_wcube_all'], day_dic['og_manip_dist'], BACK_WINDOW, plot=False)
    return day_dic

## Photom Wrapper

In [None]:
"""Helper functions for loading the raw photometry data."""

def find_subfolder_name(mouse_folder, date, datapath=datapath, fluor_folder=fluor_folder):
    """Return the Fluorescence.csv path (relative to the mouse folder) for one recording date.

    Looks in ``datapath + fluor_folder + mouse_folder`` for exactly one non-hidden
    subfolder whose name starts with `date`.

    Args:
        mouse_folder (str): mouse folder with a leading '/', e.g. '/RR20240320_J_2C3T4B'.
        date (str): date prefix of the recording subfolder, e.g. '2024_04_26'.

    Returns:
        str: '/<subfolder>/Fluorescence.csv'

    Raises:
        ValueError: if no subfolder, or more than one, starts with `date`.
    """
    search_dir = datapath + fluor_folder + mouse_folder
    matches = sorted(name for name in os.listdir(search_dir)
                     if not name.startswith('.') and name.startswith(date))
    # Explicit check rather than `assert`, which is stripped under `python -O`
    if len(matches) != 1:
        raise ValueError(
            f'expected exactly one subfolder of {search_dir} starting with {date!r}, found {matches}')
    return '/' + matches[0] + '/Fluorescence.csv'

def photom_wrapper(mouse_folder, date, datapath=datapath, fluor_folder=fluor_folder, title=''):
    """Load one session's Fluorescence.csv and return background-corrected signals.

    CH1 is copied through unchanged (no background subtraction). CH2 (DCN), CH3 (Thal)
    and CH4 (SNr) have the CH5 control background ROI subtracted, separately for the
    470 and 410 columns. The resulting frame is plotted with `title`.

    Returns:
        sig_df: DataFrame with columns, in order, CH1-470, CH1-410, DCN-470, DCN-410,
        Thal-470, Thal-410, SNr-470, SNr-410. Keep this interleaved order stable:
        callers may index columns by position.
    """
    file_f = find_subfolder_name(mouse_folder, date, datapath=datapath, fluor_folder=fluor_folder)
    print(file_f)
    print(datapath, fluor_folder, mouse_folder, file_f)
    flur_df = pd.read_csv(datapath + fluor_folder + mouse_folder + file_f, header=1)
    print('flur_df columns: ', flur_df.columns)

    # CH5 is the control background ROI, per wavelength
    background = {wl: flur_df['CH5-' + wl] for wl in ('410', '470')}

    dat = {'CH1-470': flur_df['CH1-470'], 'CH1-410': flur_df['CH1-410']}
    for chan, region in ((2, 'DCN'), (3, 'Thal'), (4, 'SNr')):
        for wl in ('470', '410'):
            dat[f'{region}-{wl}'] = flur_df[f'CH{chan}-{wl}'] - background[wl]

    sig_df = pd.DataFrame.from_dict(dat)
    sig_df.plot(figsize=(20,10))
    plt.title(title)
    plt.legend()
    return sig_df

## Preprocess Wrapper

### helper functions (low pass, high pass)

In [None]:
def apply_butter_lowpass(photom, thresh, sampling_rate=30):
    """Zero-phase 4th-order Butterworth low-pass filter, applied along axis 0.

    Args:
        photom (np.ndarray): samples along axis 0 (1-D, or 2-D with one column per channel).
        thresh (float): cutoff frequency, in the same units as `sampling_rate`.
        sampling_rate (float): sampling rate of `photom`.

    Returns:
        np.ndarray: filtered array with the same shape as `photom`.
    """
    numer, denom = butter(4, thresh, btype='low', fs=sampling_rate)

    def _zero_phase(column):
        return filtfilt(numer, denom, column)

    return np.apply_along_axis(_zero_phase, 0, photom)

def apply_butter_highpass(photom, thresh, sampling_rate=30): #not used in current pipeline
    """Zero-phase 2nd-order Butterworth high-pass filter (even padding), applied along axis 0.

    Args:
        photom (np.ndarray): samples along axis 0 (1-D, or 2-D with one column per channel).
        thresh (float): cutoff frequency, in the same units as `sampling_rate`.
        sampling_rate (float): sampling rate of `photom`.

    Returns:
        np.ndarray: filtered array with the same shape as `photom`.
    """
    numer, denom = butter(2, thresh, btype='high', fs=sampling_rate)
    return np.apply_along_axis(
        lambda column: filtfilt(numer, denom, column, padtype='even'), 0, photom)

### Main Preprocessing Wrappers

In [None]:
def preprocess(photom, refpoint_framecount, combin_df, phot_coldic, parameter_dic, manip_fps=200, photom_fps=30, plotter_title=''):
    """Preprocess the photometry window of a single trial.

    Called once per trial by ``photom_cube_generate``. Pipeline:
      -1. take rows ``refpoint_framecount - p_BACK_WINDOW`` .. ``refpoint_framecount + p_FORWARD_WINDOW - 1``
          of `photom` (label-based ``.loc``, inclusive on both ends)
       1. zero-phase 4th-order Butterworth low-pass at ``parameter_dic['lowpass_threshold']``;
          if ``parameter_dic.get('lowpass_threshold_2')`` is not None, the '-410' columns use
          that cutoff instead
       2. motion correction per region: fit 410 to 470 by linear regression and compute
          (470 - 410_fit) / 410_fit
       3. z-score each region with the mean/std over rows
          ``norm_window[0]:norm_window[1]`` of this trial window

    Args:
        photom (pd.DataFrame): columns '<region>-470' / '<region>-410' for regions
            CH1, DCN, Thal, SNr (as produced by photom_wrapper).
        refpoint_framecount (int): photometry row label the trial is aligned to.
        combin_df, manip_fps, photom_fps, plotter_title: unused; kept for call compatibility.
        phot_coldic (dict): column name -> column position in `photom`.
        parameter_dic (dict): 'lowpass_threshold' and 'norm_window' are required;
            'lowpass_threshold_2' is optional (missing or None -> one cutoff for all channels).

    Returns:
        (output_dic, output_dic_keys): output_dic maps 'raw_sig', 'lowpass_photom',
        'deltaf_im_np', 'CH470_movcor_np', 'zscores' to 2-D arrays (rows = frames);
        output_dic_keys maps the same names to the column labels of each array.
    """
    # STEP -1: photometry frame boundaries for this trial
    phot_bounds = [refpoint_framecount-p_BACK_WINDOW, refpoint_framecount+p_FORWARD_WINDOW-1]
    print(phot_bounds)

    # STEP 0: raw signal (background subtraction already happened in photom_wrapper)
    raw_sig = photom.loc[phot_bounds[0]:phot_bounds[1]].to_numpy()
    raw_sig_keys = list(photom.keys())
    print('temp photom shape', raw_sig.shape)

    # STEP 1: low-pass filtering (noise correction).
    # .get(): parameter dicts may omit 'lowpass_threshold_2' entirely.
    lowpass_threshold_2 = parameter_dic.get('lowpass_threshold_2')
    if lowpass_threshold_2 is None:
        lowpass_photom = apply_butter_lowpass(raw_sig, parameter_dic['lowpass_threshold']) #4th order butterworth lowpass
    else:
        # Filter 470 and 410 channels with different cutoffs, writing each back into its own
        # column (found by name rather than assuming an interleaved column order).
        cols_470 = [i for i, key in enumerate(raw_sig_keys) if key.endswith('-470')]
        cols_410 = [i for i, key in enumerate(raw_sig_keys) if key.endswith('-410')]
        lowpass_photom = raw_sig.astype(float)  # copy; columns matching neither suffix stay raw
        lowpass_photom[:, cols_470] = apply_butter_lowpass(raw_sig[:, cols_470], parameter_dic['lowpass_threshold'])
        lowpass_photom[:, cols_410] = apply_butter_lowpass(raw_sig[:, cols_410], lowpass_threshold_2)
    lowpass_photom_keys = list(photom.keys())
    print('lowpass photom shape', lowpass_photom.shape)

    # STEP 2: motion correction, per region
    deltf_intermed = {}  # 470 as-is plus 410 linearly fitted onto 470 ('-410adj')
    CH470_movcor = {}    # (470 - 410_fit) / 410_fit
    for reg in ['CH1', 'DCN', 'Thal', 'SNr']:
        chan_470 = lowpass_photom[:, phot_coldic[reg+'-470']]
        chan_410 = lowpass_photom[:, phot_coldic[reg+'-410']]
        fit = linregress(x=chan_410, y=chan_470)
        chan_410_fitted = fit.intercept + fit.slope * chan_410
        deltf_intermed[reg+'-470'] = chan_470
        deltf_intermed[reg+'-410adj'] = chan_410_fitted
        CH470_movcor[reg+'-470_movcorr'] = (chan_470 - chan_410_fitted)/chan_410_fitted

    # STEP 2.1: CH470 (untouched) and CH410 (linearly aligned to CH470)
    deltaf_im_df = pd.DataFrame.from_dict(deltf_intermed)
    deltaf_im_keys = list(deltaf_im_df.keys())
    deltaf_im_np = deltaf_im_df.to_numpy()

    # STEP 2.2: movement-corrected CH470: (470 - 410_a) / 410_a
    CH470_movcor_df = pd.DataFrame.from_dict(CH470_movcor)
    CH470_movcor_keys = list(CH470_movcor_df.keys())
    CH470_movcor_np = CH470_movcor_df.to_numpy()

    # STEP 3: z-score against the normalization window (row indices within this trial window)
    p_start, p_end = parameter_dic['norm_window'][0], parameter_dic['norm_window'][1]
    baseline = CH470_movcor_np[p_start:p_end, :]
    means = np.mean(baseline, axis=0)
    stds = np.std(baseline, axis=0)
    zscores = (CH470_movcor_np - means) / stds  # z-scores (not F/F0)
    zscores_keys = ['CH1_zscore', 'DCN_zscore', 'Thal_zscore', 'SNr_zscore']

    # output_dic: numpy arrays; output_dic_keys: the column names for each array
    output_dic = {'raw_sig': raw_sig, 'lowpass_photom': lowpass_photom,
                  'deltaf_im_np': deltaf_im_np, 'CH470_movcor_np': CH470_movcor_np, 'zscores': zscores}
    output_dic_keys = {'raw_sig': raw_sig_keys, 'lowpass_photom': lowpass_photom_keys,
                       'deltaf_im_np': deltaf_im_keys, 'CH470_movcor_np': CH470_movcor_keys, 'zscores': zscores_keys}
    return output_dic, output_dic_keys

def within_oreg(refpoint_framecount, oreg_list):
    """Return True if a trial aligned at `refpoint_framecount` touches any outlier region.

    Each region ``pair`` (pair[0] = start frame, pair[1] = end frame) is widened by
    p_FORWARD_WINDOW frames before its start and p_BACK_WINDOW frames after its end. A
    reference frame is therefore flagged whenever its trial window (p_BACK_WINDOW frames
    before, p_FORWARD_WINDOW after, as in preprocess) would reach into the region.
    Used by photom_cube_generate.
    """
    return any(
        pair[0] - p_FORWARD_WINDOW <= refpoint_framecount <= pair[1] + p_BACK_WINDOW
        for pair in oreg_list
    )

def photom_cube_generate(photom_df, day_dic, oreg_list, 
                         parameter_dic=None,
                         title='', waves_override=None, rng=None):
    """Build per-trial photometry cubes, plus a random-baseline cube, for one session.

    For each wave (trial), the reference photometry frame is
    ``day_dic['combin_df'].loc[wave[0], 'frame_count_1']``. Each trial is then handled as follows:
      - once the forward window runs past the end of photom_df, this trial and all later
        ones are dropped (assumes waves are in time order);
      - if the back window would start before frame 0, the trial is skipped;
      - if the trial lies in an outlier region (within_oreg), it is recorded in
        outlier_trials and skipped;
      - otherwise it is preprocessed, together with a window aligned at a uniformly random
        frame (random baseline).
    NOTE(review): random baseline windows are not checked against outlier regions.

    Args:
        photom_df (pd.DataFrame): session photometry (from photom_wrapper).
        day_dic (dict): behavior dict; uses 'og_waves' and 'combin_df'.
        oreg_list: outlier regions as (start, end) photometry-frame pairs.
        parameter_dic (dict or None): forwarded to preprocess. None means a 6 Hz low-pass and
            z-scoring to the p_PRE_MOVE_WINDOW frames before the reference frame.
        title (str): prefix for per-trial plot titles passed to preprocess.
        waves_override: waves to use instead of day_dic['og_waves'].
        rng (np.random.Generator or None): source of random baseline frames; None uses the
            global np.random state.

    Returns:
        cube_dic: preprocess output name -> np.dstack of per-trial arrays (frames, channels, trials)
        trials_used: indices of the waves that were kept, ascending
        outlier_trials: indices of the waves skipped because of outlier regions
        output_dic_keys: column labels from the last preprocess call ([] if no trial was kept)
        rand_cube_dic: same layout as cube_dic, for the random-baseline windows
    """
    if parameter_dic is None:
        # Built per call (no shared mutable default). 'lowpass_threshold_2' is included so
        # that preprocess takes the single-cutoff path.
        parameter_dic = {'lowpass_threshold': 6, 'lowpass_threshold_2': None,
                         'norm_window': [p_BACK_WINDOW - p_PRE_MOVE_WINDOW, p_BACK_WINDOW]}

    phot_coldic = {key: i for i, key in enumerate(photom_df.keys())}
    print(phot_coldic)
    waves = day_dic['og_waves']
    if waves_override is not None:  # `!= None` is ambiguous for numpy arrays
        waves = waves_override
        print("Overriding Waves!")

    n_frames = photom_df.shape[0]
    # Random reference frames come from [p_BACK_WINDOW, n_frames - p_FORWARD_WINDOW), so the
    # whole trial window always fits inside photom_df.
    rand_low, rand_high = p_BACK_WINDOW, n_frames - p_FORWARD_WINDOW

    mats_dic = {}
    rand_mats_dic = {}
    trials_used = []
    outlier_trials = []
    output_dic_keys = []
    for trial, wave in enumerate(waves):
        print('TRIAL: ' + str(trial))
        refpoint_framecount = int(day_dic['combin_df'].loc[wave[0], 'frame_count_1'])

        if refpoint_framecount + p_FORWARD_WINDOW > n_frames:
            # forward window runs past the end of the recording
            print("STOPPING trial addition at trial: " + str(trial))
            break
        if refpoint_framecount - p_BACK_WINDOW < 0:
            # back window starts before the first frame
            print("SKIPPING TRIAL: ", trial)
            continue
        if within_oreg(refpoint_framecount, oreg_list):
            print("Trial found inside outlier region")
            outlier_trials.append(trial)
            continue

        trials_used.append(trial)
        trial_title = title + ' trial: ' + str(trial)
        output_dic, output_dic_keys = preprocess(photom_df, refpoint_framecount, day_dic['combin_df'],
                                                 phot_coldic, parameter_dic, plotter_title=trial_title)

        if rng is None:
            random_refpoint_framecount = np.random.randint(rand_low, rand_high)
        else:
            random_refpoint_framecount = int(rng.integers(rand_low, rand_high))
        rand_output_dic, _ = preprocess(photom_df, random_refpoint_framecount, day_dic['combin_df'],
                                        phot_coldic, parameter_dic, plotter_title=trial_title)

        for key, mat in output_dic.items():
            mats_dic.setdefault(key, []).append(mat)
            rand_mats_dic.setdefault(key, []).append(rand_output_dic[key])

    cube_dic = {key: np.dstack(mats) for key, mats in mats_dic.items()}
    rand_cube_dic = {key: np.dstack(mats) for key, mats in rand_mats_dic.items()}
    return cube_dic, trials_used, outlier_trials, output_dic_keys, rand_cube_dic


# Which files have no frame count

In [None]:
# Manipulandum files with no frame count; Big Run Part 1 skips them.
# (The list was produced by a helper in the nonshort manip bigrun notebook.)
noframe_count = [
    'RR20231108_A_2023-12-06.csv',
    'RR20231108_A_2023-12-12.csv',
    'RR20231108_B_2023-12-12.csv',
    'RR20231108_B_2023-12-06.csv',
    'RR20231108_D_2023-12-06.csv',
    'RR20231108_D_2023-12-12.csv',
    'RR20231108_C_2023-12-12.csv',
    'RR20231108_C_2023-12-06.csv',
]

# Initial Processing Starts Here

## Part 1 - Behavior

In [None]:
"""Big Run Part 1: build a gen Cage (cage2) from the raw behavior + photometry files.

IMPORTANT - if you loaded sess_cage directly (unpickling), don't run this cell.
- This builds a gen Cage from the bottom up, starting with behavior.
- To only change the photometry pipeline (preprocessing, outliers, etc.), start at Big Run Part 2.

For every file in manip_folder, loads the behavior into a day_dic (manip_wrapper) and the
photometry into a DataFrame (photom_wrapper), and stores both under the mouse in cage2.
NOTE - files go into a gen Cage, NOT a sess_cage.

Skips:
- files in blacklist_files_m (no push trials)
- files in noframe_count (no frame count)
"""

files_m = sorted('/' + entry for entry in os.listdir(datapath + manip_folder)
                 if not entry.startswith('.'))
mouse_folders = []
dates = []

cage2 = Cage()
for i, manip_file in enumerate(files_m):
    print("&&&&&&&&&&&&&")
    print(i, manip_file)
    if manip_file in blacklist_files_m:
        print("SKIPPING blacklist file")
        continue
    if manip_file[1:] in noframe_count:
        print("SKIPPING (no frame count)")
        continue

    # e.g. '/RR20240320_J_2024-04-26.csv' -> mouse '/RR20240320_J', date '2024_04_26'
    name_parts = manip_file.split('_')
    mouse_name = name_parts[0] + '_' + name_parts[1]
    mouse_folder = mouse_name + photom_addon
    mouse_folders.append(mouse_folder)
    date = name_parts[2].split('.')[0].replace('-', '_')

    print("mouse_name: ", mouse_name)
    day_dic = manip_wrapper(manip_file, datapath + manip_folder)
    photom_df = photom_wrapper(mouse_folder, date, title=manip_file)

    if mouse_name not in cage2.name_2_mouse:
        cage2.add_mouse(Mouse(mouse_name, mouse_folder))
    cage2.name_2_mouse[mouse_name].add_session(date, day_dic, manip_file, photom_df)

## Part 2 - Generate Photom Cubes

### Declaration of parameter dics for exploring different preprocessing parameters

In [None]:
# Candidate z-score baseline windows (row indices within a trial window): the
# p_PRE_MOVE_WINDOW frames ending at p_BACK_WINDOW, and the same span shifted 15 frames earlier.
norm_windows = [[p_BACK_WINDOW - p_PRE_MOVE_WINDOW - shift, p_BACK_WINDOW - shift]
                for shift in (0, 15)]
labels = ['minus1', 'minus1_alt']  # name suffixes, parallel to norm_windows
print(norm_windows)

In [None]:
lowps = [2,4,6,12] #low pass thresholds

# Every (cutoff, norm-window index) combination, cutoff-major
pdics_list_temp = [[cutoff, win_idx] for cutoff in lowps for win_idx in range(len(norm_windows))]

# One preprocessing parameter dict per combination, named '<cutoff>_<window label>'
pdic_list = [
    {'lowpass_threshold': cutoff,
     'lowpass_threshold_2': None,
     'norm_window': norm_windows[win_idx],
     'name': str(cutoff) + '_' + labels[win_idx]}
    for cutoff, win_idx in pdics_list_temp
]
pdic_list

### Note - run only one of the two cells below (not both)

### Big Run Part 2: FOR SESS CAGE (loaded in from pkl)

In [None]:
# Regenerate photometry cubes for every loaded session and every parameter set
for sessname in ordered_sessions:
    session = sess_cage.sessions[sessname]
    pieces = sessname.split('-')
    # outlier-region keys look like '/<mouse> <date>'
    oreg_list = loaded_oreg_dic[f'/{pieces[0]} {pieces[1]}']
    for parameter_dic in pdic_list:
        name = parameter_dic['name']
        results = photom_cube_generate(session['photom_df'], session['day_dic'], oreg_list, parameter_dic=parameter_dic)
        cube_dic_o, trials_used_o, outliers, cube_dic_keys, rand_cube_dic = results
        session['cube_dic_lowp_' + name] = cube_dic_o
        session['rand_cube_dic_lowp_' + name] = rand_cube_dic
        session['outlier_trials'] = outliers
        session['photom_trials_used'] = trials_used_o
        session['cube_dic_keys'] = cube_dic_keys
        
        

### Big Run Part 2: for GEN CAGE

In [None]:
# Older path: generate photometry cubes for every session stored in cage2
for mouse_name, mouse in cage2.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        for parameter_dic in pdic_list:
            print('SESSION: ', mouse_name, date)
            oreg_list = loaded_oreg_dic[mouse_name + ' ' + date]
            name = parameter_dic['name']
            results = photom_cube_generate(session['photom_df'], session['day_dic'], oreg_list, parameter_dic=parameter_dic)
            cube_dic_o, trials_used_o, outliers, cube_dic_keys, rand_cube_dic = results
            session['cube_dic_lowp_' + name] = cube_dic_o
            session['rand_cube_dic_lowp_' + name] = rand_cube_dic
            session['outlier_trials'] = outliers
            session['photom_trials_used'] = trials_used_o
            session['cube_dic_keys'] = cube_dic_keys
            

## Part 3 - Misc Cube Adjustments

In [None]:
"""Part 3: derive the per-session structures actually used downstream.

- og_waves, og_wcube_all, og_manip_dist, og_endpoints -> waves, wcube_all, manip_dist, endpoints:
  restricted to photom_trials_used (no outliers), so the behavior trial axis matches the photom cube.
- wave_dic: indices into `waves` for each trial type (rewarded/unrewarded, success/failure).
- new_og_endpoints / new_endpoints: endpoints taken from the robot state dropping, i.e. when the
  manipulandum itself registered the end of the push, rather than from a distance threshold.
  Rarely used so far, but a solid alternative push endpoint.
"""

def det_new_endpoints(day_dic, wave_type='og_waves'):
    """Robot-state-based endpoint for each wave in ``day_dic[wave_type]``.

    For each wave, scans ``behav_mat[:, 10]`` (the 'robot_state' column) over
    [wave[0], wave[0] + FORWARD_WINDOW) and returns the sample index just after the first
    step down by exactly 1. Returns 0 for a wave with no such step.
    """
    waves = day_dic[wave_type]
    robot_state = day_dic['behav_mat'][:,10]

    def _first_drop(onset):
        drops = np.flatnonzero(np.diff(robot_state[onset:onset+FORWARD_WINDOW]) == -1)
        return drops[0] + onset + 1 if drops.size else 0

    return [_first_drop(wave[0]) for wave in waves]

for mouse_name, mouse in cage2.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        day_dic = session['day_dic']
        # Trials kept by photom_cube_generate (in bounds, not in an outlier region)
        photom_trials_used = session['photom_trials_used']
        # A set makes the "keep og trial i if it was used" filters O(n) instead of O(n*m)
        used = set(photom_trials_used)

        # Restrict behavior to the photometry trials so the trial axes match the photom cubes
        waves = [wav for i, wav in enumerate(day_dic['og_waves']) if i in used]
        wcube_all = day_dic['og_wcube_all'][:,:,photom_trials_used]

        print(mouse_name, date)
        day_dic['waves'] = waves
        day_dic['wcube_all'] = wcube_all

        # wav[2] is a trial-type string whose '_'-separated parts are matched against
        # 'rewarded'/'unrewarded' (first part) and 'success'/'failure' (second part)
        type_parts = [wav[2].split('_') for wav in waves]
        rew_waves = [i for i, parts in enumerate(type_parts) if parts[0] == 'rewarded']
        unrew_waves = [i for i, parts in enumerate(type_parts) if parts[0] == 'unrewarded']
        suc_waves = [i for i, parts in enumerate(type_parts) if parts[1] == 'success']
        fail_waves = [i for i, parts in enumerate(type_parts) if parts[1] == 'failure']
        day_dic['wave_dic'] = {
            'rewarded': rew_waves, 'unrewarded': unrew_waves,
            'success': suc_waves, 'failure': fail_waves}

        # Same trial restriction for the distance traces and threshold-based endpoints
        day_dic['manip_dist'] = day_dic['og_manip_dist'][:,:,photom_trials_used]
        day_dic['endpoints'] = [endpt for i, endpt in enumerate(day_dic['og_endpoints']) if i in used]

        # Robot-state-based endpoints, for all og waves and for the kept waves
        new_og_endpoints = det_new_endpoints(day_dic, wave_type='og_waves')
        new_endpoints = det_new_endpoints(day_dic, wave_type='waves')
        day_dic['new_og_endpoints'] = new_og_endpoints
        day_dic['new_endpoints'] = new_endpoints

## Part 4 - Add reward-aligned photometry cubes (newer addition)

In [None]:
"""Part 4: like Part 2, but trials are centered on the reward pulse instead of movement initiation."""

# Same settings as Part 2's '2_minus1' entry: 2 Hz low-pass, z-scored to the p_PRE_MOVE_WINDOW
# frames ending at p_BACK_WINDOW (previously hard-coded as [129, 150]).
parameter_dic = {'lowpass_threshold': 2,
  'lowpass_threshold_2': None,
  'norm_window': [p_BACK_WINDOW - p_PRE_MOVE_WINDOW, p_BACK_WINDOW],
  'name': '2_minus1'}
for mouse_name, mouse in cage2.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        day_dic = session['day_dic']
        rew_wave_list = ch.determine_rew_wavelist(day_dic, FORWARD_WINDOW)
        day_dic['rew_waves'] = rew_wave_list
        # Sanity check: one reward-aligned wave per rewarded trial
        n_rewarded = len(day_dic['wave_dic']['rewarded'])
        assert n_rewarded == len(rew_wave_list), (
            f'{mouse_name} {date}: {n_rewarded} rewarded trials but {len(rew_wave_list)} reward waves')
        # photom_cube_generate aligns on wave[0], so wrap each reward entry in a one-element list
        rew_wave_list_mod = [[wave] for wave in rew_wave_list]
        oreg_list = loaded_oreg_dic[mouse_name + ' ' + date]
        name = parameter_dic['name']
        cube_dic_o, trials_used_o, outliers, cube_dic_keys, rand_cube_dic = photom_cube_generate(
            session['photom_df'], day_dic, oreg_list,
            parameter_dic=parameter_dic, waves_override=rew_wave_list_mod)
        session['rew_cube_dic_lowp_'+name] = cube_dic_o
        session['rew_outlier_trials'] = outliers
        session['rew_photom_trials_used'] = trials_used_o

    

In [None]:
# Print which sessions do not have any reward cube trials - could be useful downstream.
# The key is spelled out (it matches Part 4's '2_minus1' setting) instead of relying on the
# loop variable `name` leaking out of the previous cell.
rew_cube_key = 'rew_cube_dic_lowp_2_minus1'
for mouse_name, mouse in cage2.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        if not session[rew_cube_key]:
            print(mouse_name + ' ' + date)

# Sessions with no reward-aligned trials, in sess_cage key format '<mouse>-<date>'
rew_ses_notrials = ['RR20231108_B-2023_12_07', 'RR20240320_H-2024_04_19',
                    'RR20240320_H-2024_04_23', 'RR20240320_H-2024_04_24',
                    'RR20240320_H-2024_04_26']

## Part 5 - Pickle sessions from Big Run 

In [None]:
pkl_sav_folder = '/Pickles'

def serialize_sessions_from_cage(folder, cage):
    """Pickle every session of a gen Cage (e.g. cage2), one file per session.

    Each session is written to ``folder + mouse_name + '-' + date + '.pkl'``. Mouse names carry
    a leading '/' (e.g. '/RR20240320_J'), which makes this a file inside `folder`.

    Args:
        folder (str): output directory; created if it does not exist.
        cage (Cage): cage whose ``name_2_mouse[mouse].day_2_session[date]`` entries are written.
    """
    os.makedirs(folder, exist_ok=True)
    for mouse_name, mouse in cage.name_2_mouse.items():
        for date, session in mouse.day_2_session.items():
            print('SESSION: ', mouse_name, date)
            fname = mouse_name+'-'+date+'.pkl'
            print(fname)
            # the with-block closes the file; no explicit close() needed
            with open(folder+fname, 'wb') as f:
                pickle.dump(session, f)

def serialize_sess_cage(folder, cage):
    """(Re)pickle every session in a sessions_cage to ``folder + '/' + sessname + '.pkl'``.

    Args:
        folder (str): output directory; created if it does not exist.
        cage (sessions_cage): cage whose ``sessions`` dict is written, one file per entry.
    """
    os.makedirs(folder, exist_ok=True)
    for sessname, session in cage.sessions.items():
        fname = '/'+sessname+'.pkl'
        print(fname)
        # the with-block closes the file; no explicit close() needed
        with open(folder+fname, 'wb') as f:
            pickle.dump(session, f)

# Example: load one pickled session back
# with open(pickle_folder + example_ss, 'rb') as f:
#     loaded_obj = pickle.load(f)

### Only run one of the saving cells below

In [None]:
# Save a sessions cage (use this when sess_cage was loaded from pickles)
pickle_folder = f'{datapath}{pkl_sav_folder}/Manip_BigRun_Pickle'
serialize_sess_cage(pickle_folder, sess_cage)

In [None]:
# Save a gen cage (use this when cage2 was generated from scratch)
pickle_folder = f'{datapath}{pkl_sav_folder}'
serialize_sessions_from_cage(pickle_folder, cage2)

## Part 5.5 Compress cage and pickle

In [None]:
def gen_compressed_cage(sess_cage, session_list, mode='wheel'):
    """
    Build a slimmed-down copy of a sessions cage that keeps only selected fields.

    Parameters
    -----
    sess_cage: cage whose ``sessions`` dict maps session name -> session dict
    session_list: session names to copy; names found in the global ``blacklist`` are skipped
    mode: 'manip' or 'wheel'; selects which session keys and which ``day_dic`` keys are kept

    Returns
    -----
    a new sessions_cage where each kept session holds only the selected keys plus a
    trimmed 'day_dic'

    Raises
    -----
    ValueError if ``mode`` is neither 'manip' nor 'wheel'
    KeyError if a session lacks one of the selected keys
    """
    if mode == 'manip':
        cube_dic_name = 'cube_dic_lowp_2_minus1'
        rew_cube_dic_name = 'rew_cube_dic_lowp_2_minus1'
        rand_cube_dic_name = 'rand_cube_dic_lowp_2_minus1'
        keys_keep = ['photom_df', cube_dic_name, rew_cube_dic_name, rand_cube_dic_name, 'outlier_trials', 'photom_trials_used', 'cube_dic_keys', 'manip_file']
        day_dic_keys_keep = ['metadata', 'col_dic', 'waves', 'wcube_all', 'wave_dic', 'manip_dist', 'endpoints', 'new_endpoints']
    elif mode == 'wheel':
        cube_dic_name = 'cube_dic_lowp_2_minus1_alt'
        rand_cube_dic_name = 'rand_cube_dic_lowp_2_minus1_alt'
        keys_keep = ['photom_df', cube_dic_name, rand_cube_dic_name, 'outlier_trials', 'photom_trials_used', 'cube_dic_keys', 'dlc_file', 'rad_file']
        day_dic_keys_keep = ['waves','wcube_all','wave_dic','trial_defs','wheel_trans','stride_stance_dic','hand_peaks_troughs','foot_peaks_troughs']
    else:
        # previously fell through and failed later with an unbound keys_keep
        raise ValueError("mode must be 'manip' or 'wheel', got %r" % (mode,))

    comp_sess_cage = sessions_cage()
    for sessname in session_list:
        if sessname in blacklist:
            continue
        session = sess_cage.sessions[sessname]
        day_dic = session['day_dic']
        session_keep = {k: session[k] for k in keys_keep}
        # day_dic is replaced by its trimmed copy
        session_keep['day_dic'] = {k: day_dic[k] for k in day_dic_keys_keep}
        comp_sess_cage.add_sess(sessname, session_keep)
    return comp_sess_cage

In [None]:
# Compress the manipulandum sessions (keeping only the keys that gen_compressed_cage selects
# for mode='manip') and pickle the result, one file per session.
# NOTE(review): an older comment placed gen_compressed_cage in a "HELPER - Compress Cage"
# section under Important Classes and Helpers; the cell directly above also defines it,
# so whichever definition was executed last is the one used here.
compressed_manip_pkl_folder = datapath + pkl_sav_folder + '/Compressed_Manip'
compressed_manip_cage = gen_compressed_cage(sess_cage, ordered_sessions, mode='manip')
serialize_sess_cage(folder=compressed_manip_pkl_folder, cage=compressed_manip_cage)

# Big run Visualizations

## Control: front-half / back-half trial typing

In [None]:
# Control trial typing: split every session's waves (trials) into its first and second half
# by index and store both index arrays in the session's wave_dic.
for mouse_name, mouse in cage2.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        waves = session['day_dic']['waves']
        half = len(waves) // 2
        front_half, back_half = np.arange(half), np.arange(half, len(waves))
        session['day_dic']['wave_dic'].update(front_half=front_half, back_half=back_half)

## Vis Part 1: Plotting EACH Session

### Helpers

In [None]:
def plot_cubedics(parameter_dic, cube_dic_type, 
                  ylim_dic = {'raw_sig': [-1.5,1.5], 'lowpass_photom': [-1.5,1.5], 'deltaf_im_np': [-1,1], 'CH470_movcor_np': [-0.005,0.005], 'CH470_410_ratio_np': [0.998,1.002],
                              'zscores': [-4,4], 'zscores_ratio': [-4,4], 'f_f0': [0.998,1.002],
                             'CH470_410_uratio_np' : [0.998,1.002], 'zscores_uratio': [-4,4], 'f_f0_u': [0.998,1.002]
                             },
                  care_about = [True,True,True], heatmap=True,
                  save_subfolder='', save_genfolder = output_path + '/Preprocessing', save_label='', no_ylim=False,
                  first_only=True):
    """
    Summary
    -----
    Plots and (optionally) saves trial-averaged cubes for sessions in cage2
    (all trials, rewarded trials, unrewarded trials) via ph.visualize_cube.
    NOTE - if save_subfolder left as '', no saving occurs
    NOTE - with first_only=True (the default, matching the old debugging breaks) only the first
           session with data of the first mouse is plotted; pass first_only=False to plot ALL sessions

    Parameters
    -----
    parameter_dic: dictionary of preprocessing params (reads 'lowpass_threshold' and 'norm_window').
        should use a parameter_dic created in Big Run Part 2

    cube_dic_type: key of the cube to plot, e.g. 'raw_sig', 'lowpass_photom', 'deltaf_im_np',
        'CH470_movcor_np', 'zscores' (ylim_dic lists the types with default y-limits)

    ylim_dic: cube_dic_type -> y-limits; types missing from it are plotted with automatic limits

    care_about = if we want to plot [all_sess, rewarded, unrewarded]

    heatmap: forwarded to ph.visualize_cube

    save_subfolder, save_genfolder, save_label: output folder and file-name label

    no_ylim: if True, ignore ylim_dic and let the plot choose its limits

    first_only: stop after the first plotted session (and after the first mouse)
    """
    addon = '_lowp_' + str(parameter_dic['lowpass_threshold'])
    norm_window = parameter_dic['norm_window']
    ylim = None if no_ylim else ylim_dic.get(cube_dic_type)
    # saving settings do not depend on the session, so compute them once
    save_path = save_genfolder + save_subfolder
    save_flag = save_subfolder != ''
    # BACK_WINDOW is in 200 samples/s units; the plots want seconds
    time_offset = BACK_WINDOW / 200

    for mouse_name in cage2.name_2_mouse.keys():
        for date in cage2.name_2_mouse[mouse_name].day_2_session.keys():
            print(mouse_name, date)
            session = cage2.name_2_mouse[mouse_name].day_2_session[date]
            cube_dic = session['cube_dic' + addon]
            if len(cube_dic) == 0:
                print("NO WAVES")
                continue
            cube_dic_keys = session['cube_dic_keys'][cube_dic_type]
            col_dic = {key: i for i, key in enumerate(cube_dic_keys)}
            cube = cube_dic[cube_dic_type]
            rewarded_waves = session['day_dic']['wave_dic']['rewarded']
            unrewarded_waves = session['day_dic']['wave_dic']['unrewarded']

            # File names: '<mouse>-<date>-<type>-<label>.jpg' for all trials and the same stem
            # plus '-Rew' / '-Unrew' for the trial-type plots. The old save_title[:-5] slicing
            # cut one character off save_label and built the unrewarded name from the
            # already-modified rewarded name.
            stem = mouse_name + '-' + date + '-' + cube_dic_type + '-' + save_label
            typed_stem = stem if save_label else stem[:-1]  # no dangling '-' for an empty label

            #all waves
            if care_about[0]:
                all_title = mouse_name + '_' + date + ' ' + str(cube.shape[2]) + ' trials'
                ph.visualize_cube(cube, col_dic, time_offset=time_offset, title=all_title, ylim=ylim, norm_window=norm_window,
                                  save_flag=save_flag, save_path=save_path, save_title=stem + '.jpg', heatmap=heatmap)
            #rewarded
            if len(rewarded_waves) > 0 and care_about[1]:
                rew_photom_cube = cube[:, :, rewarded_waves]
                rew_title = 'Rewarded: ' + mouse_name + '_' + date + ' ' + str(len(rewarded_waves)) + ' trials '
                ph.visualize_cube(rew_photom_cube, col_dic, time_offset=time_offset, title=rew_title, ylim=ylim, norm_window=norm_window,
                                  save_flag=save_flag, save_path=save_path, save_title=typed_stem + '-Rew.jpg', heatmap=heatmap)
            #unrewarded
            if len(unrewarded_waves) > 0 and care_about[2]:
                unrew_photom_cube = cube[:, :, unrewarded_waves]
                unrew_title = 'Unrewarded: ' + mouse_name + '_' + date + ' ' + str(len(unrewarded_waves)) + ' trials '
                ph.visualize_cube(unrew_photom_cube, col_dic, time_offset=time_offset, title=unrew_title, ylim=ylim, norm_window=norm_window,
                                  save_flag=save_flag, save_path=save_path, save_title=typed_stem + '-Unrew.jpg', heatmap=heatmap)

            if first_only:
                break
        if first_only:
            break

## VP1 - pipeline

In [None]:
# cube_dic_type options include 'raw_sig', 'lowpass_photom', 'deltaf_im_np', 'CH470_movcor_np', 'zscores'
cube_dic_type = 'zscores'
plot_cubedics(
    parameter_dic,
    cube_dic_type,
    care_about=[True, False, False],  # all-trials plot only
    save_subfolder='',                # '' -> nothing is saved
    save_label='',
    no_ylim=False,
)

## Vis Part 2: parsing + avg cubes

In [None]:
"""
Parsing ordered_sessions (from sess_cage) to get groups_dic
groups_dic: a dictionary mapping time_zone ('early','mid','late') to a list of corresponding session names (to use to access sess_cage)
"""

ghikj_early_days = ['2024_04_15','2024_04_16','2024_04_17']
ghikj_mid_days = ['2024_04_19','2024_04_22']
ghik_late_days = ['2024_04_24','2024_04_25','2024_04_26']
j_late_days = ['2024_04_23','2024_04_24','2024_04_25']

abcd_early_days = ['2023_12_05','2023_12_07','2023_12_08']
f_early_days = ['2024_01_15','2024_01_16','2024_01_17']
abcd_late_days = ['2023_12_13','2023_12_14','2023_12_15']
f_late_days = ['2024_01_24','2024_01_25','2024_01_26']
f_mid_days = ['2024_01_19','2024_01_22']

early_sessions, late_sessions, mid_sessions = [],[],[]
for ses in ordered_sessions:
    date = ses.split('-')[1]
    mouse_ID = ses.split('-')[0][-1] 
    
    if date in abcd_early_days or date in f_early_days or date in ghikj_early_days:
        early_sessions.append(ses)
    elif date in abcd_late_days or date in f_late_days:
        late_sessions.append(ses)
    elif mouse_ID == 'J' and date in j_late_days:
        late_sessions.append(ses)
    elif mouse_ID != 'J' and date in ghik_late_days:
        late_sessions.append(ses)
    else:
        if mouse_ID in ['G','H','I','K','J'] and date in ghikj_mid_days:
            mid_sessions.append(ses)
        elif mouse_ID == 'F' and date in f_mid_days:
            mid_sessions.append(ses)
        elif mouse_ID not in ['G','H','I','K','J','F']:
            mid_sessions.append(ses)

groups_dic = {'early':early_sessions, 'mid': mid_sessions, 'late': late_sessions}
print(groups_dic)

In [None]:
def gen_behav_cube_lis(session_list, parameter_dic, sess_cage, cube_dic_type = 'zscores', trial_type=None):
    """
    Build one manipulandum-velocity cube per session.

    Velocity is the first difference of ``day_dic['manip_dist']`` along axis 0
    (a per-sample difference; it is not divided by the sampling interval).

    Parameters
    -----
    session_list: session names to look up in ``sess_cage.sessions``
    parameter_dic, cube_dic_type: accepted for signature parity with gen_cube_list; unused here
    sess_cage: object whose ``sessions`` dict maps session name -> session dict
    trial_type: optional key of ``day_dic['wave_dic']``; when given, only those trials
        (indices along axis 2) are kept

    Returns
    -----
    (cube_list, cube_dic_keys, ses_skip): velocity cubes in session order; always None for
    cube_dic_keys; and the sessions whose selected trial list was empty (their now-empty
    cubes are still included in cube_list)
    """
    velocity_cubes, empty_sessions = [], []
    for sessname in session_list:
        day_dic = sess_cage.sessions[sessname]['day_dic']
        velocity = np.diff(day_dic['manip_dist'], axis=0)
        if trial_type is not None:
            selected = day_dic['wave_dic'][trial_type]
            if len(selected) == 0:
                empty_sessions.append(sessname)
                print(sessname)
            velocity = velocity[:, :, selected]
        velocity_cubes.append(velocity)
    return velocity_cubes, None, empty_sessions


def gen_cube_list(session_list, parameter_dic, sess_cage, cube_dic_type = 'zscores', trial_type=None):
    """
    Collect one photometry cube per session from a sessions cage.

    Parameters
    -----
    session_list: session names to look up in ``sess_cage.sessions``
    parameter_dic: preprocessing params; 'name' selects the cube dictionary
        ``session['cube_dic_lowp_' + str(parameter_dic['name'])]``
    sess_cage: object whose ``sessions`` dict maps session name -> session dict
    cube_dic_type: which cube to take, e.g. 'raw_sig', 'lowpass_photom', 'deltaf_im_np', 'zscores'
    trial_type: optional key of ``day_dic['wave_dic']``; when given, only those trials
        (indices along axis 2) are kept

    Returns
    -----
    (cube_list, cube_dic_keys, ses_skip):
        cube_list - cubes in session order (sessions with no selected trials contribute an empty cube)
        cube_dic_keys - column names for cube_dic_type from the last session, or None when
            session_list is empty
        ses_skip - sessions whose selected trial list was empty
    """
    addon = '_lowp_' + str(parameter_dic['name'])
    cube_list = []
    ses_skip = []
    # bound up front so an empty session_list returns None instead of raising UnboundLocalError
    cube_dic_keys = None
    for sessname in session_list:
        session = sess_cage.sessions[sessname]
        cube = session['cube_dic' + addon][cube_dic_type]
        if trial_type is not None:
            wave_inds = session['day_dic']['wave_dic'][trial_type]
            if len(wave_inds) == 0:
                ses_skip.append(sessname)
                print(sessname)
            cube = cube[:, :, wave_inds]
        cube_dic_keys = session['cube_dic_keys'][cube_dic_type]
        cube_list.append(cube)
    return cube_list, cube_dic_keys, ses_skip

def multi_cube_plot(session_list, cube_list, cube_params, cube_dic_type, parameter_dic, title_addon='',
                    save_subfolder='', save_genfolder = output_path + '/Preprocessing', save_label='',
                   ylim_dic = {'raw_sig': [-1.5,1.5], 'lowpass_photom': [-1.5,1.5], 'deltaf_im_np': [-1,1], 'deltaf_np': [-0.005,0.005], 'zscores': [-4,4]},
                   behavior_flag = False):
    """
    Hierarchically average session cubes and plot the per-mouse grand average.

    Each cube is nan-averaged over trials (axis 2); the session averages are grouped per
    mouse with ``sh.only_group_by_mice`` and nan-averaged within each mouse; the per-mouse
    averages are then stacked along axis 2 into ``mastercube``.

    Parameters
    -----
    session_list: session names parallel to ``cube_list`` (used to group sessions by mouse)
    cube_list: per-session cubes with trials along axis 2
    cube_params: column names used to build the column dictionary for photometry plots
    cube_dic_type: signal held by the cubes; used for y-limits, title and file name
    parameter_dic: preprocessing params; 'norm_window' is read
    title_addon: text appended to the plot title (used by both photometry and behavior plots)
    save_subfolder: subfolder under ``save_genfolder``; '' disables saving (photometry plots only)
    save_genfolder, save_label: output folder and file-name prefix
    ylim_dic: cube_dic_type -> y-limits; types missing from it get automatic limits
    behavior_flag: if True, plot with ``ph.visualize_master_behavcube`` instead of ``ph.visualize_cube``

    Returns
    -----
    mastercube: per-mouse averages stacked along axis 2
    """
    col_dic = {elem: i for i, elem in enumerate(cube_params)}
    daycube_list = [np.nanmean(daycube, axis=2) for daycube in cube_list]  # average over trials
    mouse_ids, grouped_daycubes = sh.only_group_by_mice(session_list, daycube_list)  # sublists of session averages per mouse
    sesscube_list = [np.dstack(daycube_sublist) for daycube_sublist in grouped_daycubes]
    mastercube_list = [np.nanmean(sesscube, axis=2) for sesscube in sesscube_list]  # average over sessions per mouse
    mastercube = np.dstack(mastercube_list)
    title = cube_dic_type + '-' + str(parameter_dic) + '-' + title_addon

    save_title = '/' + save_label + '-' + cube_dic_type + '.jpg'
    save_path = save_genfolder + save_subfolder
    save_flag = save_subfolder != ''
    norm_window = parameter_dic['norm_window']
    time_offset = BACK_WINDOW / 200  # BACK_WINDOW is in 200 samples/s units -> seconds
    if not behavior_flag:
        ph.visualize_cube(mastercube, col_dic, time_offset, title=title, norm_window=norm_window,
                          save_flag=save_flag, save_path=save_path, save_title=save_title,
                          plot_3D=False, xlabel='Time (s)', ylabel='Z-Score', ylim=ylim_dic.get(cube_dic_type))
    else:
        # norm_window is in photometry frames (30/s); rescale it to behavior samples (200/s)
        new_norm_win = [(val / 30) * 200 for val in norm_window]
        # title was previously dropped here, so callers' title_addon never reached behavior plots
        ph.visualize_master_behavcube(mastercube, new_norm_win, 200, time_offset, title=title)
    return mastercube

## VP2 - SAVING

### Saving - NO trial types

In [None]:
# Grand-average z-scored photometry and manipulandum-velocity plots over ALL trials:
# first across every session, then separately for each time zone in groups_dic.
cube_dic_type = 'zscores'
param_dic_manip = {
    'lowpass_threshold': 2,
    'lowpass_threshold_2': None,
    'norm_window': [129, 150],
    'name': '2_minus1',
}
manip_fig_folder = output_path + '/Manip_Photom_figures'

# --- all sessions ---
cube_list, cube_dic_keys, ses_skip = gen_cube_list(
    ordered_sessions, param_dic_manip, sess_cage, cube_dic_type=cube_dic_type)
master_cube_all = multi_cube_plot(
    ordered_sessions, cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
    save_genfolder=manip_fig_folder, save_subfolder='/allses', save_label='alltrials')

# behavior cubes carry no column keys, so the photometry cube_dic_keys are passed along
b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(
    ordered_sessions, param_dic_manip, sess_cage, cube_dic_type=cube_dic_type)
b_master_cube_all = multi_cube_plot(
    ordered_sessions, b_cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
    save_subfolder='', behavior_flag=True)
subfolder = '/allses'
# ph.save_cube(master_cube_all,  subfolder, 'alltrials', cube_type = cube_dic_type)
# ph.save_cube(b_master_cube_all,  subfolder, 'alltrials_behav', cube_type = cube_dic_type, behav=True)

# --- per time zone ---
time_z_cubes = []  # NOTE(review): never filled below
for time_z, zone_sessions in groups_dic.items():
    # the third return value is the list of skipped sessions, despite its historical name
    cube_list, cube_dic_keys, cube_mean_vals = gen_cube_list(
        zone_sessions, param_dic_manip, sess_cage, cube_dic_type=cube_dic_type)
    cube = multi_cube_plot(
        zone_sessions, cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
        save_genfolder=manip_fig_folder, save_subfolder='/' + time_z, save_label='alltrials')

    b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(
        zone_sessions, param_dic_manip, sess_cage, cube_dic_type=cube_dic_type)
    b_master_cube_all = multi_cube_plot(
        zone_sessions, b_cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
        save_subfolder='', behavior_flag=True)
    subfolder = '/' + time_z
    # ph.save_cube(cube,  subfolder, 'alltrials', cube_type = cube_dic_type)
    # ph.save_cube(b_master_cube_all,  subfolder, 'alltrials_behav', cube_type = cube_dic_type, behav=True)

### Saving - YES trial types

In [None]:
# Same grand-average plots as the previous cell, but split by trial type (rewarded / unrewarded):
# first across all sessions, then for each time zone in groups_dic.
cube_dic_type = 'zscores'
param_dic_manip = {
    'lowpass_threshold': 2,
    'lowpass_threshold_2': None,
    'norm_window': [129, 150],
    'name': '2_minus1',
}
manip_fig_folder = output_path + '/Manip_Photom_figures'

for trial_type in ('rewarded', 'unrewarded'):
    # --- all sessions ---
    alltime_label = 'alltime_' + trial_type
    cube_list, cube_dic_keys, ses_skip = gen_cube_list(
        ordered_sessions, param_dic_manip, sess_cage,
        cube_dic_type=cube_dic_type, trial_type=trial_type)
    master_cube_all = multi_cube_plot(
        ordered_sessions, cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
        title_addon=alltime_label,
        save_genfolder=manip_fig_folder, save_subfolder='/allses', save_label=trial_type)

    # behavior cubes carry no column keys, so the photometry cube_dic_keys are passed along
    b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(
        ordered_sessions, param_dic_manip, sess_cage,
        cube_dic_type=cube_dic_type, trial_type=trial_type)
    b_master_cube_all = multi_cube_plot(
        ordered_sessions, b_cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
        save_subfolder='', behavior_flag=True, title_addon=alltime_label)
    subfolder = '/allses'
    # ph.save_cube(master_cube_all,  subfolder, trial_type, cube_type = cube_dic_type)
    # ph.save_cube(b_master_cube_all,  subfolder, trial_type + '_behav', cube_type = cube_dic_type, behav=True)

    # --- per time zone ---
    time_z_cubes = []  # NOTE(review): never filled below
    for time_z, zone_sessions in groups_dic.items():
        zone_label = time_z + '_' + trial_type
        # the third return value is the list of skipped sessions, despite its historical name
        cube_list, cube_dic_keys, cube_mean_vals = gen_cube_list(
            zone_sessions, param_dic_manip, sess_cage,
            cube_dic_type=cube_dic_type, trial_type=trial_type)
        cube = multi_cube_plot(
            zone_sessions, cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
            title_addon=zone_label,
            save_genfolder=manip_fig_folder, save_subfolder='/' + time_z, save_label=trial_type)

        b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(
            zone_sessions, param_dic_manip, sess_cage,
            cube_dic_type=cube_dic_type, trial_type=trial_type)
        b_master_cube_all = multi_cube_plot(
            zone_sessions, b_cube_list, cube_dic_keys, cube_dic_type, param_dic_manip,
            save_subfolder='', behavior_flag=True, title_addon=zone_label)
        subfolder = '/' + time_z
        # ph.save_cube(cube,  subfolder, trial_type, cube_type = cube_dic_type)
        # ph.save_cube(b_master_cube_all,  subfolder, trial_type + '_behav', cube_type = cube_dic_type, behav=True)