In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
import math
import plotly.express as px
import plotly.graph_objects as go
import scipy
import cv2
import random
from itertools import combinations
import copy
import pickle

from scipy.ndimage import gaussian_filter1d
from scipy.signal import butter, filtfilt
from scipy.stats import linregress

#These lines allow us to import functions from my python func with helper functions
import sys

sys.path.insert(0, '/Users/charliehuang/Documents/Photometry_pipeline/data_analysis_helperfuncs')
import behav_data_analysis as bd
import dlc_helper as dh
import wheel_helper as wh
import photom_helper as ph
import statistics_helper as sh
import stride_stance_helper as stsh

%load_ext autoreload

%autoreload 2
import importlib
importlib.reload(bd)
%config IPCompleter.greedy=True

# Important Parameters

In [None]:
# Notebook-wide configuration: trial window sizes and data locations.

# Not used anywhere in the visible part of this notebook yet.
WARP_LENGTH = 1000

# Trial window sizes expressed in behavior (wheel video) frames.
# NOTE(review): the "0.7 seconds" remark below and the manip_fps=100 default in
# preprocess() suggest behavior frames run at 100 Hz - confirm.
BACK_WINDOW = 500
PRE_MOVE_WINDOW = 70 # 0.7 seconds
FORWARD_WINDOW = 500

# The same windows converted to photometry frames using a 30/100 ratio
# (presumably 30 Hz photometry vs 100 Hz behavior - confirm). The p_ prefix
# marks the photometry-frame versions.
p_BACK_WINDOW = int(30*(BACK_WINDOW/100))
p_FORWARD_WINDOW = int(30*(FORWARD_WINDOW/100))
p_PRE_MOVE_WINDOW = int(30*(PRE_MOVE_WINDOW/100))

# Root data directory. Every *_folder below is a '/'-prefixed fragment that is
# string-concatenated onto this root (or onto another fragment).
# NOTE(review): this is an absolute, machine-specific path.
datapath = '/Users/charliehuang/Documents/python_work/data/Photometry'
manip_folder = '/Photometry_Manipulandum'

# Suffix appended to a mouse name to form its fluorescence folder name.
photom_addon = '_2C3T4B'
fluor_folder = '/Photometry_Fluorescence'

arduino_folder = '/Photometry_Wheel'
radians_folder = '/radians'
dlc_folder = '/DLC'
# Maps the mouse letter used in one of the light on/off CSVs to the letter used
# elsewhere in this notebook.
rename_dic = {'A':'G', 'B':'H', 'C':'I', 'D':'J', 'E':'K'}
output_path = datapath + '/Outputs'

# Important Classes and Helpers

## Class - Gen Cage

In [None]:
class Mouse:
    """One animal: its name, its data folder, and its sessions keyed by date."""

    def __init__(self, name, mouse_folder):
        self.name, self.mouse_folder = name, mouse_folder
        # date -> session dictionary (filled in by add_session)
        self.day_2_session = {}

    def add_session(self, date, day_dic, rad_file, dlc_file, photom_df):
        """Register (or overwrite) the session recorded on `date`."""
        self.day_2_session[date] = dict(
            day_dic=day_dic,
            rad_file=rad_file,
            dlc_file=dlc_file,
            photom_df=photom_df,
        )
    
class Cage:
    """Registry of Mouse objects, keyed by mouse name ("gen cage")."""

    def __init__(self):
        print("fresh new cage")
        # mouse name -> Mouse
        self.name_2_mouse = {}

    def add_mouse(self, mouse: "Mouse"):
        """Register `mouse` under `mouse.name`, replacing any existing entry.

        The argument is required. The previous signature used the Mouse class
        itself as a default value (where a type annotation was intended), so a
        call without arguments failed with a confusing AttributeError.
        """
        self.name_2_mouse[mouse.name] = mouse

    def get_mouse(self, mouse_name):
        """Return the Mouse registered under `mouse_name` (KeyError if absent)."""
        return self.name_2_mouse[mouse_name]
    

def full_mouse_name(mouse_ID):
    """Return the full mouse name: a cohort prefix followed by the mouse letter.

    'F' gets its own prefix, letters G-K share one, and every other ID falls
    back to the third prefix.
    """
    if mouse_ID == 'F':
        prefix = 'RR20231109_'
    elif mouse_ID in ('G', 'H', 'I', 'J', 'K'):
        prefix = 'RR20240320_'
    else:
        prefix = 'RR20231108_'
    return prefix + mouse_ID

## CLASS - Sessions Cage

In [None]:
# Folder (relative to datapath) holding one pickle per session.
pkl_folder = '/Pickles/Wheel_BigRun_Pickle'


class sessions_cage:
    """Flat container of sessions keyed by session name ("sess cage")."""

    def __init__(self):
        self.sessions = dict()

    def add_sess(self, key, session):
        """Store `session` under `key`, replacing any previous entry."""
        self.sessions.update({key: session})

    def show_sessions(self):
        """Print the keys of all stored sessions."""
        print(self.sessions.keys())
    
def load_pickle_file(pkl_file):
    """Load one pickled session from datapath + pkl_folder.

    Args:
        pkl_file (str): file name including its leading '/'.

    Returns:
        The unpickled object.
    """
    pkl_path = datapath + pkl_folder + pkl_file
    print(pkl_path)
    # SECURITY: pickle.load can execute arbitrary code; only load pickles that
    # this pipeline wrote itself.
    # The with-block closes the file, so no explicit close() is needed.
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)

In [None]:
# Rebuild sess_cage from every (non-hidden) pickle in the pickle folder.
sess_cage = sessions_cage()
for file in os.listdir(datapath + pkl_folder):
    if not file.startswith('.'):
        # The session key is the file name up to its first '.'
        key = file.partition('.')[0]
        sess_cage.add_sess(key, load_pickle_file('/' + file))

In [None]:
# Session names in sorted order, used as the canonical iteration order.
ordered_sessions = sorted(sess_cage.sessions)
ordered_sessions

# Wrappers

## Radians Wrapper

In [None]:
def interpolate_frame_count(radians, wheel_trans, photom_df_length):
    """Map each wheel frame to a photometry frame count.

    Frames between wheel_trans[0] and wheel_trans[1] (inclusive) are scaled
    linearly so the span covers photom_df_length photometry frames. Every frame
    outside that span is assigned 0.

    Args:
        radians: array-like with a .shape; only its length is used.
        wheel_trans: [first, last] wheel frame indices bounding the span.
        photom_df_length (int): number of photometry frames.

    Returns:
        Integer numpy array, one entry per wheel frame.
    """
    first, last = wheel_trans[0], wheel_trans[1]
    conv_rate = (last - first) / (photom_df_length)
    print('HERE: wheel 2 photom rate: ', conv_rate)

    frame_inds = np.arange(radians.shape[0])
    # Everything outside the transition span maps to 0.
    frame_inds[:first] = 0
    frame_inds[last + 1:] = 0

    span = slice(first, last + 1)
    scaled = (frame_inds[span] - first) * (1 / conv_rate)
    frame_inds[span] = scaled.astype(int)  # truncates toward zero
    return frame_inds

# radians_path = datapath + arduino_folder + radians_folder
# ex_file = '/radians_RR20231108_C_2024-02-08-125142-0000_flippedDLC_resnet50_wheel_behaviorSep12shuffle1_500000_filtered.csv'
def radians_wrapper(file_r, dlc_file, wheel_trans, photom_df_length, path = datapath+arduino_folder+radians_folder):
    """Load one session's wheel (radians) and DLC data and build its day_dic.

    Run via Big Run Part 1. Keys prefixed with "og" hold ALL detected trials
    (waves); a subset is selected later, once photometry bounds are known, and
    stored under the un-prefixed names.

    Args:
        file_r (str): radians csv file name (with leading '/'), read from `path`.
        dlc_file (str): DLC file name (with leading '/'), read from
            datapath + arduino_folder + dlc_folder.
        wheel_trans: [first, last] wheel frames bounding the photometry span.
        photom_df_length (int): number of photometry frames in the session.
        path (str): folder containing `file_r`.

    Returns:
        dict with keys 'radians_df', 'wheel_trans', 'combin_df', 'og_waves'
        and 'og_wcube_all'.
    """
    rd_df = pd.read_csv(path+file_r, index_col = 0)
    dlc_df, bodyparts = dh.gen_dlc_df(datapath+arduino_folder+dlc_folder+dlc_file)

    radians_interp = rd_df['radians_interp']
    radians_inds = interpolate_frame_count(radians_interp, wheel_trans, photom_df_length)

    # Wheel columns plus the interpolated photometry frame count, then the DLC
    # columns placed side by side.
    combin_df_temp = pd.DataFrame.from_dict({
        'radians': rd_df['radians'],
        'radians_likelihood': rd_df['radians_likelihood'],
        'radians_interp': radians_interp,
        'frame_count_1': radians_inds,
    })
    combin_df = pd.concat([combin_df_temp, dlc_df], axis=1)

    og_waves = det_waves(radians_interp, wheel_trans, combin_df, photom_df_length, title=file_r[:31])
    og_wcube_all = bd.gen_manip_cube(
        np.expand_dims(radians_interp, axis=1),
        og_waves,
        back_window=BACK_WINDOW,
        forward_window=FORWARD_WINDOW,
    )

    return {
        'radians_df': rd_df,
        'wheel_trans': wheel_trans,
        'combin_df': combin_df,
        'og_waves': og_waves,
        'og_wcube_all': og_wcube_all,
    }

def det_waves(radians, wheel_trans, combin_df, photom_length, title=''):
    """Detect movement-initiation trials by thresholding, then plot them.

    Candidate bounds come from wh.compute_threshed_disp and are filtered by
    evaluate_cbounds. The trace is plotted with a red line at each accepted
    trial start and grey lines at the two wheel_trans frames.

    Returns:
        The list of accepted candidate bounds.
    """
    lookrange = [0, len(radians)]
    spec_dic = {
        "thresh": 0.002,
        "filter_param": 3,
        "order": 1,
        "clus_max_interval": 300,
        "clus_min_range": 20,
    }
    _, cbounds = wh.compute_threshed_disp(radians, lookrange, spec_dic, plot=False)
    trial_cbounds = evaluate_cbounds(cbounds, wheel_trans, radians, combin_df, photom_length)

    plt.figure(figsize=(20, 8))
    plt.plot(radians)
    plt.title(title + ' now with evaluated cbounds')
    for bounds in trial_cbounds:
        plt.axvline(bounds[0], c='r', alpha=0.7)
    plt.xlim(lookrange)
    for edge in (wheel_trans[0], wheel_trans[1]):
        plt.axvline(edge, c='0.3')
    return trial_cbounds

def evaluate_cbounds(cbounds, wheel_trans, radians, combin_df, photom_length):
    """Keep only candidate movements usable as trials.

    A candidate (start, end) is kept when
      1. both start and end lie inside the wheel_trans span,
      2. its back/forward windows fit inside the photometry recording,
      3. the mean position changes by at least 0.05 between the pre and post
         windows, and
      4. both pre and post windows are still (mean |diff| < 0.001).

    Args:
        cbounds: iterable of (start, end) wheel-frame pairs.
        wheel_trans: [first, last] wheel frames bounding the photometry span.
        radians: wheel position trace, sliced by wheel frame.
        combin_df: DataFrame with a 'frame_count_1' column, looked up by wheel frame.
        photom_length (int): number of photometry frames.

    Returns:
        list of the accepted entries of `cbounds`.
    """
    trial_cbounds = []
    min_meanval_change_thresh = 0.05
    baseline_diffs_thresh = 0.001
    span_start, span_end = wheel_trans[0], wheel_trans[1]
    for cbound in cbounds:
        cb = cbound[0]
        cbe = cbound[1]
        assert cbe > cb
        # frame_count_1 is 0 for every wheel frame outside the wheel_trans span,
        # so the photometry check below cannot detect an end point that lies
        # after the span (0 + window never exceeds the length). Reject those
        # explicitly instead of relying on the sentinel.
        if cb < span_start or cbe > span_end:
            continue
        # A negative slice start would wrap around instead of giving the
        # intended pre-movement window.
        if cb - BACK_WINDOW < 0:
            continue
        start_frame = combin_df.loc[cb, 'frame_count_1']
        end_frame = combin_df.loc[cbe, 'frame_count_1']
        if start_frame - p_BACK_WINDOW < 0 or end_frame + p_FORWARD_WINDOW > photom_length:
            continue
        pre, post = radians[cb-BACK_WINDOW:cb], radians[cbe: cbe+FORWARD_WINDOW]
        pre_mean, post_mean = np.mean(pre), np.mean(post)
        pre_abs_diff_mean = np.mean(np.abs(np.diff(pre)))
        post_abs_diff_mean = np.mean(np.abs(np.diff(post)))
        moved_enough = abs(post_mean - pre_mean) >= min_meanval_change_thresh
        still_before_and_after = (pre_abs_diff_mean < baseline_diffs_thresh
                                  and post_abs_diff_mean < baseline_diffs_thresh)
        if moved_enough and still_before_and_after:
            trial_cbounds.append(cbound)

    return trial_cbounds

## Photom Wrapper

In [None]:
"""_summary_
Helper functions for loading in the raw photometry data.
"""
def find_subfolder_name(mouse_folder, date, datapath=datapath, fluor_folder=fluor_folder):
    """Locate the fluorescence csv for one mouse and date.

    Looks in datapath + fluor_folder + mouse_folder for the single non-hidden
    entry whose name starts with `date`.

    Returns:
        str: '/<subfolder>/Fluorescence.csv', relative to the mouse folder.

    Raises:
        AssertionError: if zero or several entries match `date`.
    """
    mouse_dir = datapath + fluor_folder + mouse_folder
    matches = [name for name in os.listdir(mouse_dir)
               if not name.startswith('.') and name.startswith(date)]
    # Same exception type as before, but with a message that says what went wrong.
    assert len(matches) == 1, (
        "expected exactly one subfolder of {!r} starting with {!r}, found {}: {}"
        .format(mouse_dir, date, len(matches), matches)
    )
    return '/' + matches[0] + '/Fluorescence.csv'

def photom_wrapper(mouse_folder, date, datapath=datapath, fluor_folder=fluor_folder, title=''):
    """Load a session's fluorescence csv and subtract the background ROI.

    CH5 (control background ROI) is subtracted from CH2 (DCN), CH3 (Thal) and
    CH4 (SNr) separately for the 470 and 410 columns. CH1 is copied through
    WITHOUT background subtraction. The resulting frame is also plotted.

    Returns:
        sig_df: DataFrame with columns, in order, CH1-470, CH1-410, DCN-470,
        DCN-410, Thal-470, Thal-410, SNr-470, SNr-410.
    """
    file_f = find_subfolder_name(mouse_folder, date, datapath=datapath, fluor_folder=fluor_folder)
    print('file f: ', file_f)
    print(datapath, fluor_folder, mouse_folder, file_f)
    flur_df = pd.read_csv(datapath + fluor_folder + mouse_folder + file_f, header=1)
    print('flur_df columns: ', flur_df.columns)

    # Control background ROI, keyed by wavelength.
    background = {wavelength: flur_df['CH5-' + wavelength] for wavelength in ('410', '470')}

    # Column insertion order matters: later code indexes these columns by position.
    dat = {'CH1-470': flur_df['CH1-470'], 'CH1-410': flur_df['CH1-410']}
    for chan_num, reg_name in ((2, 'DCN'), (3, 'Thal'), (4, 'SNr')):
        for wavelength in ('470', '410'):
            raw_column = flur_df['CH' + str(chan_num) + '-' + wavelength]
            dat[reg_name + '-' + wavelength] = raw_column - background[wavelength]

    sig_df = pd.DataFrame.from_dict(dat)
    sig_df.plot(figsize=(20,10))
    plt.title(title)
    plt.legend()
    return sig_df

## Preprocess Wrapper

### helper functions (low pass, high pass)

In [None]:
def apply_butter_lowpass(photom, thresh, sampling_rate=30):
    """Zero-phase 4th-order Butterworth low-pass applied to each column of `photom`.

    Args:
        photom: 2-D array; filtering runs along axis 0.
        thresh: cutoff frequency, in the same units as `sampling_rate`.
        sampling_rate: sampling frequency passed to scipy's `butter` as `fs`.
    """
    numer, denom = butter(4, thresh, btype='low', fs=sampling_rate)

    def _filter_column(column):
        return filtfilt(numer, denom, column)

    return np.apply_along_axis(_filter_column, 0, photom)

def apply_butter_highpass(photom, thresh, sampling_rate=30):
    """Zero-phase 2nd-order Butterworth high-pass applied to each column of `photom`.

    Uses even-symmetric edge padding in filtfilt.

    Args:
        photom: 2-D array; filtering runs along axis 0.
        thresh: cutoff frequency, in the same units as `sampling_rate`.
        sampling_rate: sampling frequency passed to scipy's `butter` as `fs`.
    """
    numer, denom = butter(2, thresh, btype='high', fs=sampling_rate)

    def _filter_column(column):
        return filtfilt(numer, denom, column, padtype='even')

    return np.apply_along_axis(_filter_column, 0, photom)

### main wrappers - preprocess and load photometry cubes

In [None]:
def preprocess(photom, refpoint_framecount, combin_df, phot_coldic, parameter_dic, manip_fps=100, photom_fps=30, plotter_title=''):
    """Preprocess one trial-sized window of photometry around a reference frame.

    Called once per trial by photom_cube_generate. Steps:
      -1. pick the window [ref - p_BACK_WINDOW, ref + p_FORWARD_WINDOW - 1]
       0. extract the raw signal for that window
       1. Butterworth low-pass (optionally a different cutoff for 410 columns)
       2. per region, linearly fit 410 onto 470 and compute (470 - fit) / fit
       3. z-score against the mean/std of the normalization window

    Args:
        photom: photometry DataFrame; rows are selected with .loc, so its index
            is expected to hold frame numbers.
        refpoint_framecount (int): photometry frame of the trial reference point.
        combin_df: unused here; kept for interface compatibility.
        phot_coldic (dict): column name -> column position in `photom`.
        parameter_dic (dict): needs 'lowpass_threshold' and 'norm_window'
            ([start, end] row offsets inside the window). Optional
            'lowpass_threshold_2': when present and not None, it is the cutoff
            used for the 410 columns.
        manip_fps, photom_fps, plotter_title: unused here; kept for interface
            compatibility.

    Returns:
        (output_dic, output_dic_keys): arrays per processing stage, and the
        matching lists of column names.
    """
    # STEP -1: Determine photometry frame boundaries
    phot_bounds = [refpoint_framecount-p_BACK_WINDOW, refpoint_framecount+p_FORWARD_WINDOW-1]
    print(phot_bounds)

    # STEP 0: Extracting raw signal (raw signal - control background)
    # .loc slicing is inclusive on both ends, hence the -1 above.
    raw_sig = photom.loc[phot_bounds[0]:phot_bounds[1]].to_numpy()
    raw_sig_keys = list(photom.keys())
    print('temp photom shape', raw_sig.shape)

    # STEP 1: Low Pass Filtering - noise correction
    # .get() so that parameter dicts without 'lowpass_threshold_2' (such as the
    # default of photom_cube_generate) do not raise KeyError.
    lowpass_threshold_2 = parameter_dic.get('lowpass_threshold_2')
    if lowpass_threshold_2 is None:
        lowpass_photom = apply_butter_lowpass(raw_sig, parameter_dic['lowpass_threshold']) #4th order butterworth lowpass
    else: # differentially lowpass 470 and 410 signal
        # NOTE: positions are hard-coded; this assumes 8 columns alternating
        # 470/410 and does not consult phot_coldic.
        raw_470 = raw_sig[:,[0,2,4,6]]
        raw_410 = raw_sig[:,[1,3,5,7]]
        arr_470 = apply_butter_lowpass(raw_470, parameter_dic['lowpass_threshold']) #4th order butterworth lowpass
        arr_410 = apply_butter_lowpass(raw_410, lowpass_threshold_2) #4th order butterworth lowpass
        # Re-interleave back into the original 470/410 column order.
        lowpass_photom = np.array([arr_470[:,0],arr_410[:,0],arr_470[:,1],arr_410[:,1],arr_470[:,2],arr_410[:,2],arr_470[:,3],arr_410[:,3]]).T
    lowpass_photom_keys = list(photom.keys())
    print('lowpass photom shape', lowpass_photom.shape)

    # STEP 2: Motion correction, per region
    deltf_intermed = {} # 470 as-is plus the 410 linearly fitted onto the 470
    CH470_movcor = {}   # (470 - fitted 410) / fitted 410, one column per region
    regions =  ['CH1','DCN','Thal','SNr']
    for reg in regions:
        chan_470 = lowpass_photom[:,phot_coldic[reg+'-470']]
        chan_410 = lowpass_photom[:,phot_coldic[reg+'-410']]
        slope, intercept, r_value, p_value, std_err = linregress(x=chan_410, y=chan_470) #from scipy.stats
        chan_410_fitted = intercept + slope * chan_410

        deltf_intermed[reg+'-470'] = chan_470
        deltf_intermed[reg+'-410adj'] = chan_410_fitted
        CH470_movcor[reg+'-470_movcorr'] = (chan_470 - chan_410_fitted)/chan_410_fitted

    # STEP 2.1: Save CH470 (untouched) and CH410 (linearly aligned to ch470)
    deltaf_im_df = pd.DataFrame.from_dict(deltf_intermed)
    deltaf_im_keys = list(deltaf_im_df.keys())
    deltaf_im_np = deltaf_im_df.to_numpy()

    # STEP 2.2: Save movement corrected CH470:  (470-410_a)/410_a
    CH470_movcor_df = pd.DataFrame.from_dict(CH470_movcor)
    CH470_movcor_keys = list(CH470_movcor_df.keys())
    CH470_movcor_np = CH470_movcor_df.to_numpy()

    # STEP 3: Normalization to the pre-movement period - z-scoring
    p_start = parameter_dic['norm_window'][0]
    p_end = parameter_dic['norm_window'][1]

    # STEP 3.1: z-scores on CH470_movcor, using mean/std of the norm window only
    # NOTE: a constant signal inside the norm window gives std == 0 and
    # therefore inf/nan z-scores; this is not guarded against.
    means = np.mean(CH470_movcor_np[p_start:p_end,:], axis=0)
    stds = np.std(CH470_movcor_np[p_start:p_end,:], axis=0)
    zscores = (CH470_movcor_np-means)/stds
    zscores_keys = ['CH1_zscore', 'DCN_zscore', 'Thal_zscore', 'SNr_zscore']

    # output_dic contains: numpy arrays
    output_dic = {'raw_sig':raw_sig, 'lowpass_photom':lowpass_photom,
                  'deltaf_im_np': deltaf_im_np, 'CH470_movcor_np':CH470_movcor_np, 'zscores':zscores}

    # output_dic_keys contains: lists (with the names for channels for respective numpy arrays)
    output_dic_keys = {'raw_sig':raw_sig_keys, 'lowpass_photom':lowpass_photom_keys,
                       'deltaf_im_np': deltaf_im_keys, 'CH470_movcor_np':CH470_movcor_keys, 'zscores':zscores_keys}
    return output_dic, output_dic_keys

def within_oreg(refpoint_framecount, oreg_list):
    """Tell whether a photometry frame falls in (or too close to) an outlier region.

    Each pair in `oreg_list` is widened by p_FORWARD_WINDOW below its start and
    by p_BACK_WINDOW above its end before testing. Used by photom_cube_generate.

    Returns:
        bool: True if the frame lies inside any widened region.
    """
    return any(
        pair[0] - p_FORWARD_WINDOW <= refpoint_framecount <= pair[1] + p_BACK_WINDOW
        for pair in oreg_list
    )

def photom_cube_generate(photom_df, day_dic, oreg_list, 
                         parameter_dic ={'lowpass_threshold':6, 'lowpass_threshold_2': None, 'norm_window':[p_BACK_WINDOW-p_PRE_MOVE_WINDOW, p_BACK_WINDOW]},
                         title='', phot_coldic_override=None):
    """Build per-stage photometry cubes (time x channel x trial) for one session.

    For every wave in day_dic['og_waves'] the photometry window around the
    wave's start frame is preprocessed; a second, randomly placed window is
    preprocessed alongside it as a control.

    Args:
        photom_df: photometry DataFrame for the session.
        day_dic (dict): needs 'og_waves' and 'combin_df' (with 'frame_count_1').
        oreg_list: list of [start, end] outlier regions in photometry frames.
        parameter_dic (dict): preprocessing parameters, forwarded to preprocess.
            The default now carries 'lowpass_threshold_2': None so that it is a
            complete parameter set. It is never mutated here.
        title (str): prefix for the plot title forwarded to preprocess.
        phot_coldic_override (dict | None): replaces the column-name -> position
            map derived from photom_df when given.

    Returns:
        (cube_dic, trials_used, outlier_trials, output_dic_keys, rand_cube_dic).
        output_dic_keys is {} when no trial could be used (previously this case
        raised UnboundLocalError at the return statement).
    """
    phot_coldic = {key:i for i,key in enumerate(photom_df.keys())}
    if phot_coldic_override is not None:
        print("USING OVERRIDE on photcoldic!")
        phot_coldic = phot_coldic_override
    print(phot_coldic)
    waves = day_dic['og_waves']
    combin_df = day_dic['combin_df']
    n_photom_frames = photom_df.shape[0]
    mats_dic = {}
    rand_mats_dic = {}

    trials_used = []
    outlier_trials = []
    output_dic_keys = {}  # stays empty if every trial is skipped
    for trial, wave in enumerate(waves):
        print('TRIAL: ' + str(trial))
        refpoint_framecount = int(combin_df.loc[wave[0]]['frame_count_1'])

        # Random control reference point whose window fits inside the recording.
        # Drawn for every trial (even ones skipped below), with integer bounds.
        # NOTE: no seed is set here, so the random cubes differ between runs.
        random_refpoint_framecount = np.random.randint(p_BACK_WINDOW, n_photom_frames - p_FORWARD_WINDOW)

        # trials_used excludes trials which are outside of the bounds
        if refpoint_framecount + p_FORWARD_WINDOW > n_photom_frames:
            # Forward window runs past the recording. Stops at the first such
            # trial (presumably waves are in chronological order - confirm).
            print("STOPPING trial addition at trial: " + str(trial))
            break
        if refpoint_framecount - p_BACK_WINDOW < 0: #back frame is less than 0
            print("SKIPPING TRIAL: ", trial)
            continue
        if within_oreg(refpoint_framecount, oreg_list):
            print("Trial found inside outlier region")
            outlier_trials.append(trial)
            continue

        trials_used.append(trial)
        trial_title = title+' trial: ' + str(trial)
        output_dic, output_dic_keys = preprocess(photom_df, refpoint_framecount, combin_df, phot_coldic, parameter_dic, plotter_title=trial_title)
        rand_output_dic, _ = preprocess(photom_df, random_refpoint_framecount, combin_df, phot_coldic, parameter_dic, plotter_title=trial_title)
        for key, mat in output_dic.items():
            mats_dic.setdefault(key, []).append(mat)
            rand_mats_dic.setdefault(key, []).append(rand_output_dic[key])

    # Stack the per-trial matrices along a third (trial) axis.
    cube_dic = {key: np.dstack(mats) for key, mats in mats_dic.items()}
    rand_cube_dic = {key: np.dstack(mats) for key, mats in rand_mats_dic.items()}
    return cube_dic, trials_used, outlier_trials, output_dic_keys, rand_cube_dic


# load in light on data, also blacklist

In [None]:
# Loading in light on database - written like this cuz there are two files

def _collect_light_windows(df, rename=None, compact_date=False):
    """Parse one light on/off table into (windows, blacklist).

    Args:
        df: DataFrame with 'Session', 'First Frame Light On' and
            'Last Frame Light On' columns.
        rename (dict | None): optional mapping applied to the mouse letter.
        compact_date (bool): True when the date part of 'Session' is written
            as YYYYMMDD and must be split into YYYY_MM_DD.

    Returns:
        windows: session name -> [first frame, last frame] (ints)
        blacklist: session names whose first light-on frame is missing (NaN)
    """
    windows, blacklist = {}, []
    for i in range(df.shape[0]):
        row = df.loc[i]
        ssplit = row['Session'].split('_')
        mouse_key = ssplit[1]
        mouse_letter = mouse_key if rename is None else rename[mouse_key]
        mouse_name = '/' + ssplit[0]+'_'+mouse_letter
        date = ssplit[2]
        if compact_date:
            date = date[0:4]+'_'+date[4:6]+'_'+date[6:]
        date = date.replace('-','_')
        print(mouse_name, mouse_key, date)

        session_name = mouse_name + '-'+date
        if np.isnan(row['First Frame Light On']):
            blacklist.append(session_name)
        else:
            windows[session_name] = [int(row['First Frame Light On']), int(row['Last Frame Light On'])]
    return windows, blacklist


# First file: mouse letters are renamed through rename_dic.
wheel_lightonlightoff = '/Wheel_Vid_LightOnOFF_clean_052024.csv'
light_df  = pd.read_csv(datapath + arduino_folder + wheel_lightonlightoff)
light_dic, light_blacklist = _collect_light_windows(light_df, rename=rename_dic)

# Second file: letters are used as-is and dates are written without separators.
wheel_lightonlightoff = '/Wheel_Vid_LightOnOFF_clean.csv'
light_df2  = pd.read_csv(datapath + arduino_folder + wheel_lightonlightoff)
light_dic2, light_blacklist2 = _collect_light_windows(light_df2, compact_date=True)

# Keep the blacklist entries from BOTH files. Previously the list was reset
# before the second file, silently dropping any sessions blacklisted by the first.
light_blacklist = light_blacklist + light_blacklist2

light_dic.update(light_dic2) #combines the two dictionaries

# light_blacklist = ['/RR20231109_F-2024_02_06'] # this is the result

# Initial Processing Starts Here

## Part 1 - Behavior

In [None]:
# Sorted list of radians files (each with a leading '/'), skipping hidden files.
mouse_folders = []
files_r = sorted(
    '/' + radians_file
    for radians_file in os.listdir(datapath+arduino_folder+radians_folder)
    if not radians_file.startswith('.')
)

In [None]:
"""_summary_

IMPORTANT - if you loaded in the sess_cage directly (unpickling), don't run this cell
- this is for generating a gen cage from the bottom up (starting with behavior)
- if you want to just do modifications to the photometry pipeline (ie: preprocessing, outliers, etc) start with big run part 2

runs through files in manip_folder and loads behavioral data into day_dic for each session.
NOTE - loads files into a gen Cage, NOT a sess_cage

avoids:
- files in light_blacklist

"""
cage = Cage()

for radians_file in files_r:
    # The DLC file name is the radians file name without its first 9
    # characters ('/radians_'), with the leading '/' put back.
    dlc_file = '/' + radians_file[9:]

    # File name pieces: [1] and [2] form the mouse name, [3] starts with the date.
    name_parts = radians_file.split('_')
    mouse_name = '/' + '_'.join(name_parts[1:3])
    date = name_parts[3][:10].replace('-','_')
    session_name = mouse_name + '-' +date

    if session_name in light_blacklist:
        print("skipping: " + str(session_name) + " because no light on light off data")
        continue
    wheel_trans = light_dic[session_name]

    mouse_folder =  mouse_name + photom_addon
    mouse_folders.append(mouse_folder)
    print(date, mouse_name)

    # Part 1: load in the photom df
    photom_df = photom_wrapper(mouse_folder, date, title=mouse_name + ' ' + date)

    # Part 2: load in the behavior
    day_dic = radians_wrapper(radians_file, dlc_file, wheel_trans, photom_df.shape[0], path =  datapath+arduino_folder+radians_folder)

    # Create the mouse on first sight, then attach the session to it.
    if mouse_name not in cage.name_2_mouse:
        cage.add_mouse(Mouse(mouse_name, mouse_folder))
    cage.name_2_mouse[mouse_name].add_session(date, day_dic, radians_file, dlc_file, photom_df)

## Part 2 - photom cubes

### Declaration of parameter dics for exploring different preprocessing parameters

In [None]:
# Two candidate normalization windows (in photometry frames): the
# pre-movement window ending at the reference point, and the same window
# shifted 15 frames earlier.
labels = ['minus1', 'minus1_alt']
norm_windows = [
    [p_BACK_WINDOW-p_PRE_MOVE_WINDOW-shift, p_BACK_WINDOW-shift]
    for shift in (0, 15)
]
print(norm_windows)

In [None]:
# One parameter dict per (low-pass cutoff, normalization window) combination.
lowps = [1,2,4,6,12]
pdics_list_temp = [[elem, i] for elem in lowps for i in range(len(norm_windows))]
pdic_list = [
    {
        'lowpass_threshold': lowp,
        'lowpass_threshold_2': None,
        'norm_window': norm_windows[window_idx],
        'name': str(lowp) + '_' + labels[window_idx],
    }
    for lowp, window_idx in pdics_list_temp
]

pdic_list

### Note - only run either one of the cells below (not both)

### Big Run Part 2: FOR SESS CAGE (loaded in from pkl)

In [None]:
# Generate photometry cubes for every session in sess_cage, once per parameter dict.
for sessname in ordered_sessions:
    session = sess_cage.sessions[sessname]
    oreg_list = [] # NOTE - I have NOT done outlier detection for the wheel photom dataset
    pc_override = None
    if sessname == 'RR20240320_G-2024_05_07':
        # This one session gets a different column order (Thal before DCN).
        templis = ['CH1-470', 'CH1-410', 'Thal-470', 'Thal-410', 'DCN-470', 'DCN-410','SNr-470', 'SNr-410']
        pc_override = dict(zip(templis, range(len(templis))))
        print(pc_override)
    for parameter_dic in pdic_list:
        name = parameter_dic['name']
        cube_dic_o, trials_used_o, outliers, cube_dic_keys, rand_cube_dic = photom_cube_generate(
            session['photom_df'], session['day_dic'], oreg_list,
            parameter_dic=parameter_dic, phot_coldic_override=pc_override)
        # The last three entries are overwritten on every pass of this inner loop.
        session.update({
            'cube_dic_lowp_'+name: cube_dic_o,
            'rand_cube_dic_lowp_'+name: rand_cube_dic,
            'outlier_trials': outliers,
            'photom_trials_used': trials_used_o,
            'cube_dic_keys': cube_dic_keys,
        })

### Big Run Part 2: for GEN CAGE

In [None]:
# Generate photometry cubes for every session in the gen cage, once per parameter dict.
for mouse_name, mouse in cage.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        print('SESSION: ', mouse_name, date)
        oreg_list = [] # NOTE - I have NOT done outlier detection for the wheel photom dataset
        for parameter_dic in pdic_list:
            name = parameter_dic['name']
            cube_dic_o, trials_used_o, outliers, cube_dic_keys, rand_cube_dic = photom_cube_generate(
                session['photom_df'], session['day_dic'], oreg_list, parameter_dic=parameter_dic)
            # The last three entries are overwritten on every pass of this inner loop.
            session.update({
                'cube_dic_lowp_'+name: cube_dic_o,
                'rand_cube_dic_lowp_'+name: rand_cube_dic,
                'outlier_trials': outliers,
                'photom_trials_used': trials_used_o,
                'cube_dic_keys': cube_dic_keys,
            })


## Part 3 - Misc Cube Adjustments

In [None]:
"""_summary_
Just makes some adjustments listed below
- og_waves, og_wcube_all, og_endpoints -> waves, wcube_all, endpoints (these are selected from the ogs to NOT include outliers)
    - this just allows waves, wcube_all (the behavior cube) to match the trial-dimension of the photom cube
- note diff from manip part 3 - we'll make wave_dic later cuz no trial typing yet
"""
# Restrict the behavior data to the trials that made it into the photometry cube.
for mouse_name, mouse in cage.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        day_dic = session['day_dic']
        photom_trials_used = session['photom_trials_used']
        used_trials = set(photom_trials_used)

        waves = [wav for i, wav in enumerate(day_dic['og_waves']) if i in used_trials]
        # The third axis of the behavior cube is indexed by trial.
        wcube_all = day_dic['og_wcube_all'][:,:,photom_trials_used]

        day_dic['waves'] = waves
        day_dic['wcube_all'] = wcube_all

## Part 3.5 - Trial Definitions (using stride and stance)

In [None]:
# NOTE - the stride/stance and trial-typing code was copied from the
# photometry_wheel_analysis notebook into a .py helper (imported as stsh); the
# two wrapper functions below are described by the author as modifying
# sess_cage in place.
# NOTE(review): Part 3 above writes 'waves'/'wcube_all' into `cage` (the gen
# cage), while these calls operate on `sess_cage` - confirm sess_cage already
# carries the keys these pipelines need.
stsh.stride_stance_pipeline(sess_cage, ordered_sessions)
stsh.trial_type_pipeline(sess_cage, ordered_sessions)

## Part 4 - pickle sessions from Big Run 

In [None]:
pkl_sav_folder = '/Pickles'

def serialize_sessions_from_cage(folder, cage):
    """Pickle every session of a gen cage, one file per session.

    Each session is written to folder + '<mouse_name>-<date>.pkl'. Nothing is
    inserted between `folder` and the mouse name, so the mouse name is expected
    to supply the leading '/'.

    Args:
        folder (str): destination folder (no trailing separator).
        cage: object exposing name_2_mouse, a dict of objects that each expose
            day_2_session (date -> session).
    """
    for mouse_name, mouse in cage.name_2_mouse.items():
        for date, session in mouse.day_2_session.items():
            print('SESSION: ', mouse_name, date)
            fname = mouse_name+'-'+date+'.pkl'
            print(fname)
            # Binary mode; the with-block closes the file on exit.
            with open(folder+fname, 'wb') as f:
                pickle.dump(session, f)
            
def serialize_sess_cage(folder, cage):
    """(Re)pickle a sess cage, one file per session.

    Each session is written to folder + '/<sessname>.pkl'.

    Args:
        folder (str): destination folder (no trailing separator).
        cage: object exposing sessions, a dict of session name -> session.
    """
    for sessname, session in cage.sessions.items():
        fname = '/'+sessname+'.pkl'
        print(fname)
        # Binary mode; the with-block closes the file on exit.
        with open(folder+fname, 'wb') as f:
            pickle.dump(session, f)

### only run one of the below saving cells

In [None]:
# Save a sessions cage (use this cell if sess_cage was loaded from pickles).
# NOTE: this writes into the same folder the session pickles are loaded from,
# so existing files with the same session names are overwritten.
pickle_folder = datapath+pkl_sav_folder + '/Wheel_BigRun_Pickle'
serialize_sess_cage(pickle_folder, sess_cage)

In [None]:
# Save a gen cage (use this cell if `cage` was built from raw data in Part 1).
# Each session becomes its own pickle named '<mouse_name>-<date>.pkl'.
pickle_folder = datapath+pkl_sav_folder+'/Wheel_BigRun_Pickle'
serialize_sessions_from_cage(pickle_folder, cage)

## Part 5: Compress and save pickle

In [None]:
def gen_compressed_cage(sess_cage, session_list, mode='wheel'):
    """Build a slimmed-down copy of a sess cage that keeps only selected keys.

    Args:
        sess_cage: cage whose .sessions maps session name -> session dict.
        session_list: names of the sessions to include.
        mode (str): 'wheel' or 'manip'; selects which session-level and
            day_dic-level keys are kept.

    Returns:
        A new sessions_cage holding shallow copies of the kept entries.

    Raises:
        ValueError: if `mode` is neither 'wheel' nor 'manip'.
        KeyError: if a session (or its day_dic) lacks one of the kept keys.
    """
    if mode == 'manip':
        cube_dic_name = 'cube_dic_lowp_2_minus1'
        rew_cube_dic_name = 'rew_cube_dic_lowp_2_minus1'
        rand_cube_dic_name = 'rand_cube_dic_lowp_2_minus1'
        keys_keep = ['photom_df', cube_dic_name, rew_cube_dic_name, rand_cube_dic_name, 'outlier_trials', 'photom_trials_used', 'cube_dic_keys', 'manip_file']
        day_dic_keys_keep = ['metadata', 'col_dic', 'waves', 'wcube_all', 'wave_dic', 'manip_dist', 'endpoints', 'new_endpoints']
    elif mode == 'wheel':
        cube_dic_name = 'cube_dic_lowp_2_minus1_alt'
        rand_cube_dic_name = 'rand_cube_dic_lowp_2_minus1_alt'
        keys_keep = ['photom_df', cube_dic_name, rand_cube_dic_name, 'outlier_trials', 'photom_trials_used', 'cube_dic_keys', 'dlc_file', 'rad_file']
        day_dic_keys_keep = ['waves','wcube_all','wave_dic','trial_defs','wheel_trans','stride_stance_dic','hand_peaks_troughs','foot_peaks_troughs']
    else:
        # Fail early and clearly instead of hitting an undefined name later (or
        # silently returning an empty cage when session_list is empty).
        raise ValueError("mode must be 'manip' or 'wheel', got {!r}".format(mode))

    comp_sess_cage = sessions_cage()
    for sessname in session_list:
        session = sess_cage.sessions[sessname]
        day_dic = session['day_dic']
        day_dic_keep = {k: day_dic[k] for k in day_dic_keys_keep}
        session_keep = {k: session[k] for k in keys_keep}
        session_keep['day_dic'] = day_dic_keep
        comp_sess_cage.add_sess(sessname, session_keep)
    return comp_sess_cage

In [None]:
# gen_compressed_cage is defined in the cell directly above; see its
# keys_keep / day_dic_keys_keep lists for which session attributes are kept.

# Build the compressed wheel cage and pickle it, one file per session, into
# its own folder.
compressed_wheel_cage = gen_compressed_cage(sess_cage, ordered_sessions, mode='wheel')
compressed_wheel_pkl_folder = datapath + pkl_sav_folder + '/Compressed_Wheel'
serialize_sess_cage(compressed_wheel_pkl_folder, compressed_wheel_cage)

# PASS 2: Big Run Visualizations

## (optional) adding front half back half trial types

In [None]:
# Give every session a fresh wave_dic holding front/back halves of its trials,
# both in original order and after a random shuffle.
for mouse_name, mouse in cage.name_2_mouse.items():
    for date, session in mouse.day_2_session.items():
        day_dic = session['day_dic']
        day_dic['wave_dic'] = {}   # replaces any existing wave_dic
        wave_dic = day_dic['wave_dic']
        waves = day_dic['waves']

        wlen = len(waves)
        midpoint = int(wlen/2)
        wave_inds = np.arange(wlen)
        wave_dic['front_half'] = wave_inds[:midpoint]
        wave_dic['back_half'] = wave_inds[midpoint:]

        # NOTE: unseeded shuffle, so the random halves differ between runs.
        rand_wave_inds = np.random.permutation(wave_inds)
        wave_dic['rand_front_half'] = rand_wave_inds[:midpoint]
        wave_dic['rand_back_half'] = rand_wave_inds[midpoint:]
        

## Vis Part 1: Plotting EACH Session

### Helpers

In [None]:
def plot_cubedics(parameter_dic, cube_dic_type, 
                  ylim_dic = {'raw_sig': [-1.5,1.5], 'lowpass_photom': [-1.5,1.5], 'deltaf_im_np': [-1,1], 'CH470_movcor_np': [-0.005,0.005], 'CH470_410_ratio_np': [0.998,1.002],
                              'zscores': [-4,4], 'zscores_ratio': [-4,4], 'f_f0': [0.998,1.002],
                             'CH470_410_uratio_np' : [0.998,1.002], 'zscores_uratio': [-4,4], 'f_f0_u': [0.998,1.002]
                             },
                  care_about = [True,True,True], heatmap=True,
                  save_subfolder='', save_genfolder = output_path + '/Preprocessing', save_label='', no_ylim=False,
                  first_session_only=True) :
    """
    Summary
    -----
    Plots and (optionally) saves plots of ch1, dcn, thal, snr - averaged across TRIALS
    Sessions are read from the notebook-level `cage`.
    NOTE - if save_subfolder left as '', no saving occurs

    Parameters
    -----
    parameter_dic: dictionary of preprocessing params. should use a parameter_dic created in Big Run Part 2
        Keys read here: 'name' (selects session['cube_dic_lowp_<name>']) and 'norm_window'.

    cube_dic_type: string selected from ['raw_sig', 'lowpass_photom', 'deltaf_im_np', 'deltaf_np', 'zscores']

    # select_from: 'raw_sig', 'lowpass_photom' 'deltaf_im_np' 'CH470_movcor_np' 'CH470_410_ratio_np' 'zscores' 'zscores_ratio' 'f_f0'

    ylim_dic: maps cube_dic_type -> [ymin, ymax]. A type missing from the dictionary is plotted with
        ylim=None, exactly as when no_ylim is True.

    care_about: three flags enabling, in order, the all-trials, front-half and back-half plots.

    heatmap: forwarded to ph.visualize_cube.

    save_subfolder, save_genfolder, save_label: figures go to save_genfolder + save_subfolder and are
        named '<mouse>-<date>-<cube_dic_type>-<save_label>[-front|-back].jpg'.

    no_ylim: if True, ylim_dic is ignored and ylim=None is passed to the plotter.

    first_session_only: if True (default, the historical debugging behaviour) only the first mouse is
        visited and plotting stops after its first session that has data. Pass False to plot ALL sessions.

    Returns
    -----
    None
    """
    if parameter_dic is None:
        raise TypeError("parameter_dic must be a preprocessing parameter dictionary (ie: pdic_list[0]), not None")

    addon = '_lowp_'+parameter_dic['name']
    norm_window = parameter_dic['norm_window'] 
    # Unknown cube types fall back to ylim=None instead of raising KeyError.
    ylim = None if no_ylim else ylim_dic.get(cube_dic_type)

    save_path = save_genfolder + save_subfolder
    save_flag = save_subfolder != ''

    def _draw(sub_cube, col_dic, title, save_title):
        # Single place where every plot of this function is issued, so all three share the same options.
        ph.visualize_cube(sub_cube, col_dic, time_offset = BACK_WINDOW/100, title=title, ylim=ylim, norm_window=norm_window,
                          save_flag=save_flag, save_path=save_path, save_title=save_title, heatmap=heatmap)

    # (enabled?, wave_dic key, title prefix, file-name suffix) for the two half-session plots
    half_plots = [(care_about[1], 'front_half', 'FRONT trials: ', '-front'),
                  (care_about[2], 'back_half', 'BACK trials: ', '-back')]

    for mouse_name in cage.name_2_mouse.keys():
        for date in cage.name_2_mouse[mouse_name].day_2_session.keys():
            print(mouse_name, date)
            session = cage.name_2_mouse[mouse_name].day_2_session[date]    
            wave_dic = session['day_dic']['wave_dic']
            cube_dic = session['cube_dic'+addon]
            cube_dic_keys = session['cube_dic_keys'][cube_dic_type]
            col_dic = {key: i for i,key in enumerate(cube_dic_keys)}
            
            if len(cube_dic) == 0:
                print("NO WAVES")
                continue
            cube = cube_dic[cube_dic_type]     

            print(np.max(cube), np.min(cube))
            session_tag = mouse_name + '_' + date
            file_stem = mouse_name + '-' + date + '-' + cube_dic_type + '-' + save_label

            #all waves
            if care_about[0]:
                all_title = session_tag + ' ' + str(cube.shape[2]) + ' trials'
                _draw(cube, col_dic, all_title, file_stem + '.jpg')

            for enabled, wave_key, title_prefix, file_suffix in half_plots:
                if not enabled:
                    continue
                trial_inds = wave_dic[wave_key]
                if len(trial_inds) == 0:
                    # Nothing to draw for this half; the other half and the session bookkeeping still run.
                    continue
                half_cube = cube[:,:,trial_inds]
                half_title = title_prefix + session_tag + ' ' + str(half_cube.shape[2]) + ' trials'
                _draw(half_cube, col_dic, half_title, file_stem + file_suffix + '.jpg')
            
            if first_session_only:
                break
        if first_session_only:
            break

## VP1 pipeline

In [None]:
# Valid cube_dic_type choices: 'raw_sig', 'lowpass_photom', 'deltaf_im_np', 'CH470_movcor_np', 'zscores'
cube_dic_type = 'zscores'
# plot_cubedics indexes into parameter_dic, so this cell fails until the None placeholder is replaced.
parameter_dic = None #TEMP - Please fill this in!!! - ie: pdic_list[0]
# save_subfolder='' -> nothing is written to disk; care_about enables the all / front-half / back-half plots.
plot_cubedics(parameter_dic, cube_dic_type, save_subfolder='', save_label='', care_about=[True,True,True], no_ylim=False)

## VP 2: parsing and helpers

In [None]:
"""
Parsing ordered_sessions (from sess_cage) to get groups_dic
groups_dic: a dictionary mapping time_zone ('early','mid','late') to a list of corresponding session names (to use to access sess_cage)
"""
# The keys actually built below are 'ez_erly', 'ez_late', 'hd_erly' and 'hd_late'.
# Session names have the form full_mouse_name(<ID>) + '-' + <date>.

# Default recording dates of the A/B/C/D/F mice, one pair of days per time zone.
abcdf_ez_erly = ['2024_02_05', '2024_02_06'] #f exception
abcdf_ez_late = ['2024_02_08', '2024_02_09']
abcdf_hd_erly = ['2024_02_12', '2024_02_13']
abcdf_hd_late = ['2024_02_15', '2024_02_16']
abcdf_time_list = [abcdf_ez_erly,abcdf_ez_late,abcdf_hd_erly,abcdf_hd_late]

# Default recording dates of the G/H/I/J/K mice, in the same time-zone order.
ghijk_ez_erly = ['2024_04_29', '2024_04_30']
ghijk_ez_late = ['2024_05_02', '2024_05_03']
ghijk_hd_erly = ['2024_05_06', '2024_05_07']
ghijk_hd_late = ['2024_05_09', '2024_05_10']
ghijk_time_list = [ghijk_ez_erly,ghijk_ez_late,ghijk_hd_erly,ghijk_hd_late]

# Per-mouse overrides, keyed '<mouse ID>_<time zone>'; these replace the default dates above.
exceptions_dic = {'F_ez_erly': ['2024_02_05', '2024_02_07'], 'K_ez_erly': ['2024_04_30', '2024_05_01'], 'K_hd_late': ['2024_05_09', '2024_05_11']}

# ordered_sessions = []
# for mouse_name in cage.name_2_mouse.keys():
#     for date in cage.name_2_mouse[mouse_name].day_2_session.keys():
#         ordered_sessions.append(mouse_name[1:] + '-' + date)

groups_dic = {'ez_erly': [], 'ez_late': [], 'hd_erly':[], 'hd_late':[]}
for i, timezone in enumerate(groups_dic):
    for mouse_ID in ['A','B','C','D','F','G','H','I','J','K']:
        mouse_name = full_mouse_name(mouse_ID)
        exception_key = mouse_ID + '_' + timezone
        # Pick the dates: explicit exception first, otherwise the default list of the mouse's group.
        if exception_key in exceptions_dic:
            days = exceptions_dic[exception_key]
        elif mouse_ID in ['A','B','C','D','F']:
            days = abcdf_time_list[i]
        else:
            days = ghijk_time_list[i]
        sessnames = [mouse_name + '-' + date for date in days]
        groups_dic[timezone].extend(sessnames)
                

In [None]:
# Helper Plotters
def gen_behav_cube_lis(session_list, parameter_dic, sess_cage, cube_dic_type = 'zscores', trial_type=None, derivative=True):
    """
    Summary
    -----
    Collects one behavior cube (session['day_dic']['wcube_all']) per session name.

    Parameters
    -----
    session_list: session names used as keys into sess_cage.sessions
    parameter_dic, cube_dic_type: accepted for signature parity with gen_cube_list; not read here
    sess_cage: object exposing a `sessions` mapping of session name -> session dictionary
    trial_type: optional key of session['day_dic']['wave_dic']; when given, only those trial
        indices (third axis) are kept
    derivative: if True the cube is differenced along axis 0 (one row shorter than the input)

    Returns
    -----
    list of behavior cubes (one per entry of session_list, in the same order),
    None (placeholder where gen_cube_list returns its cube_dic_keys),
    list of session names whose selected trial set was empty (their empty cube is still in the list)
    """
    behav_cubes = []
    empty_sessions = []
    for name in session_list:
        day_dic = sess_cage.sessions[name]['day_dic']
        behav_cube = day_dic['wcube_all']
        if derivative:
            # velocity = derivative of the stored trace over time
            behav_cube = np.diff(behav_cube, axis=0)

        if trial_type is not None:
            selected = day_dic['wave_dic'][trial_type]
            if len(selected) == 0:
                # Reported, not dropped: the output must stay aligned with session_list.
                empty_sessions.append(name)
                print(name)
            behav_cube = behav_cube[:, :, selected]
        behav_cubes.append(behav_cube)
    return behav_cubes, None, empty_sessions

def gen_cube_list(session_list, parameter_dic, sess_cage, cube_dic_type = 'zscores', trial_type=None):
    """
    Summary
    -----
    Collects one photometry cube per session name from sess_cage.

    Parameters
    -----
    session_list: session names used as keys into sess_cage.sessions
    parameter_dic: preprocessing parameter dictionary; only 'name' is read, to build the session key
        'cube_dic_lowp_<name>'
    sess_cage: object exposing a `sessions` mapping of session name -> session dictionary
    cube_dic_type: string selected from ['raw_sig', 'lowpass_photom', 'deltaf_im_np', 'deltaf_np', 'zscores']
    trial_type: optional key of session['day_dic']['wave_dic']; when given, only those trial
        indices (third axis) are kept

    Returns
    -----
    cube_list: one cube per entry of session_list, in the same order
    cube_dic_keys: session['cube_dic_keys'][cube_dic_type] of the LAST session visited, or None
        when session_list is empty
    ses_skip: session names whose selected trial set was empty. Their (empty) cube is still in
        cube_list so the list stays aligned with session_list.
    """
    cube_list = []
    ses_skip = []
    # Bound up front so an empty session_list returns cleanly instead of raising UnboundLocalError.
    cube_dic_keys = None
    cube_dic_name = 'cube_dic' + '_lowp_' + parameter_dic['name']
    for sessname in session_list:
        session = sess_cage.sessions[sessname]
        cube = session[cube_dic_name][cube_dic_type]

        if trial_type is not None:
            wave_inds = session['day_dic']['wave_dic'][trial_type]
            if len(wave_inds) == 0:
                ses_skip.append(sessname)
                print(sessname)
            cube = cube[:,:,wave_inds]
        cube_dic_keys = session['cube_dic_keys'][cube_dic_type]
        cube_list.append(cube)
    return cube_list, cube_dic_keys, ses_skip

def multi_cube_plot(session_list, cube_list, cube_params, cube_dic_type, parameter_dic, title_addon='',
                    save_subfolder='', save_genfolder = output_path + '/Preprocessing', save_label='',
                   ylim_dic = {'raw_sig': [-1.5,1.5], 'lowpass_photom': [-1.5,1.5], 'deltaf_im_np': [-1,1], 'deltaf_np': [-0.005,0.005], 'zscores': [-4,4]},
                   behavior_flag = False):
    """
    Summary
    -----
    Averages per-session cubes into one "master" cube (one slice per mouse) and plots it.
    Averaging is hierarchical: trials -> session mean, sessions of a mouse -> mouse mean,
    then the mouse means are stacked along the third axis.
    NOTE - if save_subfolder left as '', no saving occurs

    Parameters
    -----
    session_list: session names, aligned one-to-one with cube_list
    cube_list: list of cubes whose third axis is trials
    cube_params: ordered column names of the cubes; mapped to column indices for the photometry plot
    cube_dic_type: label used in the title, in the file name and to look up the y-limits
    parameter_dic: preprocessing parameter dictionary; 'norm_window' is read here
    title_addon: text appended to the plot title
    save_subfolder, save_genfolder, save_label: the figure is written to save_genfolder + save_subfolder
        as '/<save_label>-<cube_dic_type>.jpg'
    ylim_dic: maps cube_dic_type -> [ymin, ymax]; a missing type is plotted with ylim=None
    behavior_flag: if True the cube is drawn with ph.visualize_master_behavcube instead of
        ph.visualize_cube

    Returns
    -----
    mastercube: the per-mouse mean cubes stacked along the third axis
    """
    col_dic = {elem: i for i, elem in enumerate(cube_params)}    
    daycube_list = [np.nanmean(daycube, axis=2) for daycube in cube_list] #list of cubes averaged across trials
    mouse_ids, grouped_daycubes = sh.only_group_by_mice(session_list, daycube_list) #makes a list of cubes per mouse (sublists in the list grouped_daycubes)
    sesscube_list = [np.dstack(daycube_sublist) for daycube_sublist in grouped_daycubes]
    mastercube_list = [np.nanmean(sesscube, axis=2) for sesscube in sesscube_list]
    mastercube = np.dstack(mastercube_list)
    title=cube_dic_type + '-' + str(parameter_dic) + '-' + title_addon

    save_title = '/' + save_label + '-' + cube_dic_type + '.jpg'
    save_path = save_genfolder + save_subfolder
    save_flag = save_subfolder != ''
    norm_window=parameter_dic['norm_window']
    if not behavior_flag:
        ph.visualize_cube(mastercube, col_dic, BACK_WINDOW/100, title=title, norm_window=norm_window,
                      save_flag=save_flag, save_path=save_path, save_title=save_title,
                      plot_3D=False, xlabel='Time (s)', ylabel='Z-Score', ylim=ylim_dic.get(cube_dic_type))
    else:
        # norm_window is rescaled by 100/30 before being handed to the behavior plotter.
        # NOTE(review): presumably 30 Hz photometry frames -> 100 Hz behavior frames -- confirm.
        new_norm_win = [(val/30)*100 for val in norm_window]
        # Positional 100 and 5 are frame_rate and time_offset of
        # visualize_master_behavcube(cube, norm_window, frame_rate, time_offset, title= '', save_flag=False, save_path = '', save_title = '',heatmap=True)
        # The save arguments are forwarded so that a non-empty save_subfolder also saves behavior figures.
        ph.visualize_master_behavcube(mastercube, new_norm_win, 100, 5,
                                      save_flag=save_flag, save_path=save_path, save_title=save_title)
    return mastercube


In [None]:
# example run of gen_cube_list
# cube_dic_type = 'zscores'
# param_dic_wheel = {'lowpass_threshold': 2, 'lowpass_threshold_2': None, 'norm_window': [105, 135],'name': '2_minus1_alt'}
# cube_list, cube_dic_keys, ses_skip = gen_cube_list(ordered_sessions, param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, trial_type='good')

## VP2 - SAVING

### Saving - NO TRIAL TYPE

In [None]:
# Builds, plots and saves master cubes computed over ALL trials: first pooled over every session
# in ordered_sessions, then separately for each time zone of groups_dic.
cube_dic_type = 'zscores'
param_dic_wheel = {'lowpass_threshold': 2, 'lowpass_threshold_2': None, 'norm_window': [114, 135],'name': '2_minus1_alt'}

# photometry cubes, all sessions
cube_list, cube_dic_keys, ses_skip = gen_cube_list(ordered_sessions, param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type)
master_cube_all = multi_cube_plot(ordered_sessions,cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, 
                                  save_genfolder = output_path + '/Wheel_Photom_figures', save_subfolder='/allses', save_label='alltrials')

# behavior cubes, time derivative (gen_behav_cube_lis defaults to derivative=True)
# b_cube_dic_keys is always None, which is why the photometry cube_dic_keys is passed to multi_cube_plot below.
b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(ordered_sessions, param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type)
b_master_cube_all = multi_cube_plot(ordered_sessions,b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, 
                                    save_genfolder = output_path + '/Wheel_Behav_figures', save_subfolder='/allses', save_label='alltrials_deriv', 
                                    behavior_flag=True)

# behavior cubes, undifferentiated trace (derivative=False)
b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(ordered_sessions, param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, derivative=False)
b_master_cube_all_disp = multi_cube_plot(ordered_sessions,b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, 
                                    save_genfolder = output_path + '/Wheel_Behav_figures', save_subfolder='/allses', save_label='alltrials_disp',
                                    behavior_flag=True)

# persist the three pooled master cubes
subfolder = '/Wheel' + '/allses'
ph.save_cube(master_cube_all,  subfolder, 'alltrials', cube_type = cube_dic_type)
ph.save_cube(b_master_cube_all,  subfolder, 'alltrials_behav', cube_type = cube_dic_type, behav=True)
ph.save_cube(b_master_cube_all_disp,  subfolder, 'alltrials_behav_disps', cube_type = cube_dic_type, behav=True)

# NOTE: time_z_cubes is initialised but nothing is appended to it in this cell.
time_z_cubes = []
# Same three cubes again, once per time zone; cube_list, cube_dic_keys and the b_* names are overwritten on each pass.
for time_z in groups_dic.keys():
    # NOTE: the third value returned by gen_cube_list is its ses_skip list, despite the name cube_mean_vals.
    cube_list, cube_dic_keys, cube_mean_vals = gen_cube_list(groups_dic[time_z], param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type)
    cube = multi_cube_plot(groups_dic[time_z],cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, 
                           save_genfolder = output_path + '/Wheel_Photom_figures', save_subfolder='/'+time_z, save_label='alltrials')
    
    b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(groups_dic[time_z], param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type)
    b_master_cube_all = multi_cube_plot(groups_dic[time_z],b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel,
                            save_genfolder = output_path + '/Wheel_Behav_figures', save_subfolder='/'+time_z, save_label='alltrials_deriv',
                            behavior_flag=True)
    
    b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(groups_dic[time_z], param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, derivative=False)
    b_master_cube_all_disp = multi_cube_plot(groups_dic[time_z],b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, 
                            save_genfolder = output_path + '/Wheel_Behav_figures', save_subfolder='/'+time_z, save_label='alltrials_disp',
                            behavior_flag=True)
    
    # persist the per-time-zone cubes under /Wheel/<time zone>
    subfolder = '/Wheel' + '/' + time_z
    ph.save_cube(cube,  subfolder, 'alltrials', cube_type = cube_dic_type)
    ph.save_cube(b_master_cube_all,  subfolder, 'alltrials_behav', cube_type = cube_dic_type, behav=True)
    ph.save_cube(b_master_cube_all_disp,  subfolder, 'alltrials_behav_disps', cube_type = cube_dic_type, behav=True)

### SAVING - YES TRIAL TYPE

In [None]:
"""_summary_
you could commment out the behavior stuff (lines with variables named 'b_...') if you just want to vis/save photom results
"""
# Same pipeline as the all-trials saving cell, repeated for each trial type ('good', 'bad').
# trial_type is a key of session['day_dic']['wave_dic'] and also names the saved cubes.
cube_dic_type = 'zscores'
param_dic_wheel = {'lowpass_threshold': 2, 'lowpass_threshold_2': None, 'norm_window': [114, 135],'name': '2_minus1_alt'}

for trial_type in ['good','bad']:
    # photom cubes
    cube_list, cube_dic_keys, ses_skip = gen_cube_list(ordered_sessions, param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, trial_type=trial_type)
    master_cube_all = multi_cube_plot(ordered_sessions,cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, 
                                      save_genfolder = output_path + '/Wheel_Photom_figures', save_subfolder='/allses', save_label=trial_type)

    # behavior cubes - derivative of rads
    # save_subfolder='' -> these behavior figures are displayed but not saved
    b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(ordered_sessions, param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, trial_type=trial_type)
    b_master_cube_all = multi_cube_plot(ordered_sessions,b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, save_subfolder='',
                                        behavior_flag=True)

    # behavior cubes - not derivative of rads (keep in mind, there's no normalization rn - these just be raw values)
    b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(ordered_sessions, param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, trial_type=trial_type, derivative=False)
    b_master_cube_all_disp = multi_cube_plot(ordered_sessions,b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, save_subfolder='',
                                        behavior_flag=True)

    # persist the pooled (all sessions) cubes for this trial type
    subfolder = '/Wheel' + '/allses'
    ph.save_cube(master_cube_all,  subfolder, trial_type, cube_type = cube_dic_type)
    ph.save_cube(b_master_cube_all,  subfolder, trial_type + '_behav', cube_type = cube_dic_type, behav=True)
    ph.save_cube(b_master_cube_all_disp,  subfolder, trial_type + '_behav_disps', cube_type = cube_dic_type, behav=True)

    # NOTE: time_z_cubes is initialised but nothing is appended to it in this cell.
    time_z_cubes = []
    # traversing through time zones
    for time_z in groups_dic.keys():
        # NOTE: the third value returned by gen_cube_list is its ses_skip list, despite the name cube_mean_vals.
        cube_list, cube_dic_keys, cube_mean_vals = gen_cube_list(groups_dic[time_z], param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, trial_type=trial_type)
        cube = multi_cube_plot(groups_dic[time_z],cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel,
                               save_genfolder = output_path + '/Wheel_Photom_figures', save_subfolder='/'+time_z, save_label=trial_type)
        
        b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(groups_dic[time_z], param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, trial_type=trial_type)
        b_master_cube_all = multi_cube_plot(groups_dic[time_z],b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel,save_subfolder='',
                                            behavior_flag=True)
        
        b_cube_list, b_cube_dic_keys, b_ses_skip = gen_behav_cube_lis(groups_dic[time_z], param_dic_wheel, sess_cage, cube_dic_type=cube_dic_type, derivative=False, trial_type=trial_type)
        b_master_cube_all_disp = multi_cube_plot(groups_dic[time_z],b_cube_list,cube_dic_keys, cube_dic_type, param_dic_wheel, save_subfolder='',
                                        behavior_flag=True)
        
        # persist the per-time-zone cubes for this trial type under /Wheel/<time zone>
        subfolder = '/Wheel' + '/' + time_z
        ph.save_cube(cube,  subfolder, trial_type, cube_type = cube_dic_type)
        ph.save_cube(b_master_cube_all,  subfolder, trial_type+'_behav', cube_type = cube_dic_type, behav=True)
        ph.save_cube(b_master_cube_all_disp,  subfolder, trial_type+'_behav_disps', cube_type = cube_dic_type, behav=True)

In [None]:
# # EXAMPLE RUN OF CONTROL PLOT - FRONT HALF VS BACK HALF
# parameter_dic = pdic_list[0] 

# cube_dic_type='zscores'
# subfold_name = ''

# # select_from:raw_sig', 'lowpass_photom' 'deltaf_im_np' 'CH470_movcor_np' 'CH470_410_ratio_np' 'zscores' 'zscores_ratio' 'f_f0', CH470_410_uratio_np, zscores_uratio', 'f_f0_u'
# #groups_dic keys: 'ez_erly', 'ez_late', 'hd_erly', 'hd_late'

# for time_z in groups_dic.keys():
#     print(time_z)
#     cube_list,cube_dic_keys,ses_skip = gen_cube_list(groups_dic[time_z], parameter_dic, cube_dic_type=cube_dic_type, random=False)
#     multi_cube_plot(groups_dic[time_z],cube_list,cube_dic_keys, cube_dic_type, parameter_dic, title_addon=time_z.upper(), save_subfolder=subfold_name, save_label=time_z + '-front_half-')
#     cube_list,cube_dic_keys,ses_skip = gen_cube_list(groups_dic[time_z], parameter_dic, cube_dic_type=cube_dic_type, random=True)
#     multi_cube_plot(groups_dic[time_z],cube_list,cube_dic_keys, cube_dic_type, parameter_dic, title_addon=time_z.upper() + '  -RAND-', save_subfolder=subfold_name, save_label=time_z + '-front_half-')
    

# (ARCHAIC) Generate radians pipeline

## Wrappers

In [None]:
def generate_radians(dlc_df,title=''):
    """
    Summary
    -----
    Turns the tracked positions of four markers ('singlet', 'doublet', 'triplet', 'quadruplet')
    into one cumulative angle trace and stores it on dlc_df. Diagnostic figures are drawn along the way.

    Steps
      1. Fit a circle (taubinSVD) to every non-NaN marker position to get the centre (xc, yc).
      2. For every frame choose which marker to follow: the only visible one, or - when several are
         visible - one that equals or neighbours the marker followed in the previous frame.
      3. Angle of each marker = arccos of the cosine between (marker - centre) and the reference
         vector (-1, 0). Frame-to-frame angle changes of the followed marker are accumulated; a
         change of marker contributes 0. The trace is shifted so its first valid sample is 0.

    Parameters
    -----
    dlc_df: DataFrame with columns '<marker>_x', '<marker>_y' and '<marker>_likelihood' for the four
        markers. MODIFIED IN PLACE: columns 'radians' and 'radians_likelihood' are added, then
        dh.interpolate_radians(dlc_df, 0.5, first_valid_frame, len_thresh=1e5) is called
        (presumably adds the 'radians_interp' column plotted at the end - confirm in dlc_helper).

    title: title put on the diagnostic figures

    Returns
    -----
    None

    Raises
    -----
    IndexError if no frame yields a valid accumulated angle.
    """
    # STEP 1
    from circle_fit import taubinSVD

    markers = ['singlet', 'doublet', 'triplet', 'quadruplet']
    # ['singlet_x', 'singlet_y', 'doublet_x', ...]: x/y column pairs in marker order
    markers_full = [marker + axis for marker in markers for axis in ('_x', '_y')]

    def determine_circle_params(dlc_df):
        # Pool every fully tracked (x, y) of all markers and fit one circle through them.
        marker_dat = []
        for marker in markers:
            pc = dlc_df[[marker+'_x', marker+'_y']].to_numpy()
            marker_dat.append(pc[~np.isnan(pc).any(axis=1)])
        marker_dat = np.vstack(marker_dat)

        xc, yc, r, sigma = taubinSVD(marker_dat)
        plt.figure()
        plt.plot(marker_dat[:,0], marker_dat[:,1], 'o')
        plt.plot(xc, yc, 'o', c='r')
        ax = plt.gca()
        ax.set_aspect('equal', adjustable='box')
        plt.title('all markers and estimated circle')
        print("coords: ", xc, yc)
        print("radius: ", r)
        return xc, yc, r

    xc, yc, r = determine_circle_params(dlc_df)

    #STEP 2
    mdlc = dlc_df[markers_full].to_numpy()
    # A marker counts as present in a frame when at least one of its two coordinates is not NaN.
    coord_present = ~np.isnan(mdlc)
    marker_present = coord_present[:, 0::2] | coord_present[:, 1::2]
    n_markers = len(markers)

    def marker_helper(row, prev_marker):
        # Returns the index (into `markers`) of the marker to follow in frame `row`, or np.nan.
        present = np.flatnonzero(marker_present[row])
        if len(present) == 0: #case 0: no markers -> return np.nan
            return np.nan
        if len(present) == 1: #case 1: only one marker -> return that marker
            return int(present[0])
        if not np.isnan(prev_marker): #several markers: need a previous marker to disambiguate
            # Highest index first; accept the previous marker itself, a direct neighbour, or the
            # wrap-around pair singlet <-> quadruplet (index distance n_markers - 1).
            for candidate in present[::-1]:
                if abs(int(candidate) - prev_marker) in (0, 1, n_markers - 1):
                    return int(candidate)
        return np.nan

    prev_marker = np.nan
    markers_func = []
    for i in range(dlc_df.shape[0]):
        prev_marker = marker_helper(i, prev_marker)
        markers_func.append(prev_marker)
        
    # STEP 3
    centers_sub  = (mdlc - np.array([[xc, yc]] * 4).flatten())
    # Diagnostic: centred quadruplet positions, drawn on the circle-fit figure.
    ax = plt.gca()
    ax.set_aspect('equal', adjustable='box')
    plt.plot(centers_sub[:,6], centers_sub[:,7], 'o')
    ref_vec = np.array([-1,0])
    plt.plot(ref_vec[0], ref_vec[1], 'o', c='r')
    # Two sample frames are highlighted; skipped when the recording is shorter than that.
    for sample_frame, colour in ((4000, 'g'), (6000, 'r')):
        if sample_frame < centers_sub.shape[0]:
            plt.plot(centers_sub[sample_frame,6], centers_sub[sample_frame,7], 'o', c=colour)

    plt.figure(figsize=(15,8))

    rads_dic = {}
    ref_vec_magnitude = np.linalg.norm(ref_vec)
    for i, mark in enumerate(markers):
        vectors = centers_sub[:,i*2:(i+1)*2]
        cosines = np.dot(vectors, ref_vec) / (ref_vec_magnitude * np.linalg.norm(vectors, axis=1))
        # Rounding can push a cosine marginally outside [-1, 1], which arccos would turn into NaN.
        rads_dic[mark] = np.arccos(np.clip(cosines, -1.0, 1.0))
            
    for m in markers:
        plt.plot(rads_dic[m], label=m)
    plt.legend()
    plt.title(title)

    # Angle of the followed marker in each frame (NaN where no marker is followed).
    rads = [np.nan if np.isnan(m) else rads_dic[markers[m]][i] for i, m in enumerate(markers_func)]
    plt.figure(figsize=(15,8))
    plt.plot(rads)
    plt.title(title)

    m_prev = np.nan
    prev_val = 0
    trads = 0 #this is the accumulator
    tracker = []
    likelihoods = []

    for i, m_ind in enumerate(markers_func):
        if np.isnan(m_ind):
            tracker.append(np.nan)  
            likelihoods.append(0)
            continue

        likelihoods.append(dlc_df[markers[m_ind]+'_likelihood'].iloc[i])
        # NaN never compares equal to anything, so the "no previous marker yet" test needs np.isnan.
        if np.isnan(m_prev) or m_ind == m_prev: #if continuity
            trads += rads[i]-prev_val
        # else: transition point, add 0 to the accumulator (heuristic)
        prev_val = rads[i]
        m_prev = m_ind
        tracker.append(trads)

    plt.plot(tracker)
    plt.title(title)

    # Re-zero the trace on its first valid sample.
    tracker = np.array(tracker)
    non_nan_index = np.where(~np.isnan(tracker))[0][0]
    non_nan_value = tracker[non_nan_index]
    tracker = tracker - non_nan_value
    plt.figure(figsize=(15,8))
    plt.plot(tracker)
    plt.title(title)

    # OUTPUT
    dlc_df['radians'] = tracker
    dlc_df['radians_likelihood'] = likelihoods

    dh.interpolate_radians(dlc_df, 0.5, non_nan_index, len_thresh=1e5)
    plt.figure(figsize=(15,8))
    plt.plot(dlc_df['radians_interp'])
    plt.title(title)
    
# Example use case: 
# dlc_df, bodyparts = dh.gen_dlc_df(datapath+wheel_base+dlc_base_new+'/'+file)
# generate_radians(dlc_df, title=mouse_name+ ' ' + date)