In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
import math
import plotly.express as px
import plotly.graph_objects as go
from itertools import combinations
import copy
import pickle
import gc

from scipy.ndimage import gaussian_filter1d
from scipy.signal import butter, filtfilt
from scipy.stats import linregress

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

#These lines allow us to import functions from my python func with helper functions
import sys

sys.path.insert(0, '/Users/charliehuang/Documents/Photometry_pipeline/data_analysis_helperfuncs')
import behav_datanal as bd
import statistics_helper as sh
import metric_helper as mh
import metrics_util as mu
import preprocess_helper as pph

%load_ext autoreload

%autoreload 2
import importlib
importlib.reload(bd)
%config IPCompleter.greedy=True

In [None]:
# Wheel parameters
w_BACK_WINDOW = 500
w_PRE_MOVE_WINDOW = 70
w_FORWARD_WINDOW = 500

# Manip parameters
m_BACK_WINDOW = 1000
m_PRE_MOVE_WINDOW = 140
m_FORWARD_WINDOW = 1000

# p_* parameters (presumably the photometry windows -- TODO confirm).
# p_BACK_WINDOW is the reference value the regression start/end points are offset from later on.
p_BACK_WINDOW = 150
p_FORWARD_WINDOW = 150
p_PRE_MOVE_WINDOW = 21

# Root data folder. Set the PHOTOMETRY_DATAPATH environment variable to point at your own copy
# of the data; when it is unset the original hard-coded location is used.
datapath = os.environ.get('PHOTOMETRY_DATAPATH', '/Users/charliehuang/Documents/python_work/data/Photometry')
manip_folder = '/Photometry_Manipulandum'

photom_addon = '_2C3T4B'
fluor_folder = '/Photometry_Fluorescence'

arduino_folder = '/Photometry_Wheel'
radians_folder = '/radians'
dlc_folder = '/DLC'
rename_dic = {'A':'G', 'B':'H', 'C':'I', 'D':'J', 'E':'K'}
output_path = datapath + '/Outputs'
# Pickle locations, all relative to datapath
pkl_folder = '/Pickles'
wheel_pkl = pkl_folder + '/Wheel_BigRun_Pickle'
manip_pkl = pkl_folder + '/Manip_BigRun_Pickle'

# Session names that gen_compressed_cage leaves out of the compressed cages
blacklist = ['RR20240320_J-2024_04_26']

# Parameter dictionaries handed to the regression / random-cube helpers; 'name' is the suffix of
# the cube_dic key ('cube_dic_lowp_' + name) looked up inside each session.
manip_param_dic = {'lowpass_threshold': 2, 'lowpass_threshold_2': None, 'norm_window': [90,120], 'name':  '2_minus1'}
wheel_param_dic = {'lowpass_threshold': 2, 'lowpass_threshold_2': None, 'norm_window': [90,120], 'name':  '2_minus1_alt'}

def full_mouse_name(mouse_ID):
    """Return the full mouse name for a single-letter mouse ID.

    IDs 'G' through 'K' get the 'RR20240320_' prefix, 'F' gets 'RR20231109_',
    and every other ID falls back to 'RR20231108_'.
    """
    if mouse_ID == 'F':
        prefix = 'RR20231109_'
    elif mouse_ID in ('G', 'H', 'I', 'J', 'K'):
        prefix = 'RR20240320_'
    else:
        prefix = 'RR20231108_'
    return prefix + mouse_ID

# Sess cage class

In [None]:
class sessions_cage:
    """Thin container that maps a session key to its session object."""

    def __init__(self):
        # key -> session object; populated through add_sess
        self.sessions = dict()

    def add_sess(self, key, session):
        """Store `session` under `key`, replacing any entry already stored for that key."""
        self.sessions[key] = session

    def show_sessions(self):
        """Print the keys of all stored sessions."""
        print(self.sessions.keys())
    
def load_pickle_file(pkl_file, path):
    """Load and return the object stored in the pickle file at `path + pkl_file`.

    Args:
        pkl_file: file name appended directly to `path` (callers pass it with a leading '/').
        path: folder prefix; no separator is inserted between `path` and `pkl_file`.

    Returns:
        The deserialized object.
    """
    full_path = path + pkl_file
    print(full_path)
    # SECURITY: pickle.load can execute arbitrary code -- only load pickles you produced yourself.
    # The `with` block closes the file, so no explicit close() is needed.
    with open(full_path, 'rb') as f:
        return pickle.load(f)
            
def serialize_sess_cage(folder, cage):
    """Pickle every session of `cage` into `folder`, one '<sessname>.pkl' file per session.

    Args:
        folder: destination directory; it is created (including parents) when missing.
        cage: object whose `.sessions` dict maps session name -> picklable session object.
    """
    # Without this, open() raises FileNotFoundError the first time a compressed folder is written.
    os.makedirs(folder, exist_ok=True)
    for sessname, session in cage.sessions.items():
        fname = '/' + sessname + '.pkl'
        print(fname)
        with open(folder + fname, 'wb') as f:  # binary mode: pickle writes bytes
            pickle.dump(session, f)

# (0.0) Save compressed versions of the wheel/manip sess_cage (don't run this section if you already have compressed cages)

## Load WHEEL sess cage (don't load both wheel and manip cages at a time)

In [None]:
# Load every pickle in the wheel big-run folder into a cage, keyed by the file name up to its first '.'
wheel_sess_cage = sessions_cage()
for file in os.listdir(datapath + wheel_pkl):
    if not file.startswith('.'):  # skip dot-files
        key = file.split('.')[0]
        obj = load_pickle_file('/' + file, datapath + wheel_pkl)
        wheel_sess_cage.add_sess(key, obj)
# Session names in sorted order
w_ordered_sessions = sorted(wheel_sess_cage.sessions.keys())

## Load MANIP sess cage (don't load both manip and wheel)

In [None]:
# Load every pickle in the manip big-run folder into a cage, keyed by the file name up to its first '.'
manip_sess_cage = sessions_cage()
for file in os.listdir(datapath + manip_pkl):
    if not file.startswith('.'):  # skip dot-files
        key = file.split('.')[0]
        obj = load_pickle_file('/' + file, datapath + manip_pkl)
        manip_sess_cage.add_sess(key, obj)
# Session names in sorted order
m_ordered_sessions = sorted(manip_sess_cage.sessions.keys())

## COMPRESS CAGE OF CHOICE

In [None]:
class compressed_sessions_cage:
    """Container for reduced ("compressed") sessions, mapping session key -> session object."""

    def __init__(self):
        # key -> compressed session; populated through add_sess
        self.sessions = dict()

    def add_sess(self, key, session):
        """Store `session` under `key`, replacing any entry already stored for that key."""
        self.sessions[key] = session

    def show_sessions(self):
        """Print the keys of all stored sessions."""
        print(self.sessions.keys())
        
# Keeping: photom_df, some cube_dic, outlier_trials, photom_trials_used, cube_dic_keys, dlc_file, rad_file
# from daydic: keeping everything except for radians_df, and combin_df

def gen_compressed_cage(sess_cage, session_list, mode='wheel'):
    """Build a cage that keeps only a fixed subset of each session's entries.

    For every name in `session_list` that is not in the module-level `blacklist`, copies the
    mode-specific top-level keys of the session plus a reduced 'day_dic' into a new
    compressed_sessions_cage.

    Args:
        sess_cage: cage whose `.sessions[sessname]` is a dict containing 'day_dic' and the kept keys.
        session_list: session names to process.
        mode: 'wheel' or 'manip'; selects which session keys and day_dic keys are kept.

    Returns:
        compressed_sessions_cage holding the reduced sessions.

    Raises:
        ValueError: if `mode` is neither 'wheel' nor 'manip'.
        KeyError: if a session (or its day_dic) lacks one of the keys to keep.
    """
    # Validate mode before doing any work; previously an unknown mode surfaced as an
    # UnboundLocalError inside the loop (or went unnoticed for an empty session_list).
    if mode == 'manip':
        cube_dic_name = 'cube_dic_lowp_2_minus1'
        keys_keep = ['photom_df', cube_dic_name, 'outlier_trials', 'photom_trials_used', 'cube_dic_keys', 'manip_file']
        day_dic_keys_keep = ['manip_data', 'metadata', 'col_dic', 'og_waves', 'og_summary', 'behav_mat', 'og_wcube_all', 'og_manip_dist', 'og_endpoints', 'waves', 'wcube_all', 'wave_dic', 'manip_dist', 'endpoints']
    elif mode == 'wheel':
        cube_dic_name = 'cube_dic_lowp_2_minus1_alt'
        keys_keep = ['photom_df', cube_dic_name, 'outlier_trials', 'photom_trials_used', 'cube_dic_keys', 'dlc_file', 'rad_file']
        day_dic_keys_keep = ['waves','wcube_all','wave_dic','trial_defs','wheel_trans','stride_stance_dic','hand_peaks_troughs','foot_peaks_troughs','og_waves','og_wcube_all']
    else:
        raise ValueError("mode must be 'wheel' or 'manip', got " + repr(mode))

    comp_sess_cage = compressed_sessions_cage()
    for sessname in session_list:
        if sessname in blacklist:
            continue
        session = sess_cage.sessions[sessname]
        day_dic = session['day_dic']
        session_keep = {k: session[k] for k in keys_keep}
        session_keep['day_dic'] = {k: day_dic[k] for k in day_dic_keys_keep}
        comp_sess_cage.add_sess(sessname, session_keep)
    return comp_sess_cage

## Compress and serialize wheel cage

In [None]:
# Reduce every wheel session to the kept keys, then write one '<sessname>.pkl' per session
# into <datapath>/Pickles/Compressed_Wheel.
compressed_wheel_cage = gen_compressed_cage(wheel_sess_cage, w_ordered_sessions, mode='wheel')
compressed_wheel_pkl_folder = datapath + pkl_folder + '/Compressed_Wheel'
serialize_sess_cage(compressed_wheel_pkl_folder, compressed_wheel_cage)

## Compress and serialize manip cage

In [None]:
# Reduce every manip session to the kept keys, then write one '<sessname>.pkl' per session
# into <datapath>/Pickles/Compressed_Manip.
compressed_manip_cage = gen_compressed_cage(manip_sess_cage, m_ordered_sessions, mode='manip')
compressed_manip_pkl_folder = datapath + pkl_folder + '/Compressed_Manip'
serialize_sess_cage(compressed_manip_pkl_folder, compressed_manip_cage)

# (1.0) Load in compressed sess cages (start here if you already have compressed cages)

In [None]:
print(pkl_folder)
# Folders (relative to datapath) that hold the compressed per-session pickles
manip_pkl_comp = pkl_folder + '/Compressed_Manip'
wheel_pkl_comp = pkl_folder + '/Compressed_Wheel'

In [None]:
# Load the compressed manip cage: one pickle per session, keyed by the file name up to its first '.'
compressed_manip_cage = sessions_cage()
for file in os.listdir(datapath + manip_pkl_comp):
    if not file.startswith('.'):  # skip dot-files
        key = file.split('.')[0]
        obj = load_pickle_file('/' + file, datapath + manip_pkl_comp)
        compressed_manip_cage.add_sess(key, obj)
# Session names in sorted order
m_ordered_sessions = sorted(compressed_manip_cage.sessions.keys())

In [None]:
# Session lists per trial type (manip)
# We use these specific lists for trial-specific parsing,
# e.g. running the linear regression only on rewarded trials.
# A session can appear in both lists; it is left out of a list when its wave_dic entry
# for that trial type is empty.

m_ordered_sessions_rew = []
m_ordered_sessions_unrew = []
for sessname in m_ordered_sessions:
    day_dic = compressed_manip_cage.sessions[sessname]['day_dic']
    if len(day_dic['wave_dic']['rewarded']) != 0:
        m_ordered_sessions_rew.append(sessname)
    if len(day_dic['wave_dic']['unrewarded']) != 0:
        m_ordered_sessions_unrew.append(sessname)

In [None]:
# Load the compressed wheel cage: one pickle per session, keyed by the file name up to its first '.'
compressed_wheel_cage = sessions_cage()
for file in os.listdir(datapath + wheel_pkl_comp):
    if not file.startswith('.'):  # skip dot-files
        key = file.split('.')[0]
        obj = load_pickle_file('/' + file, datapath + wheel_pkl_comp)
        compressed_wheel_cage.add_sess(key, obj)

# All sessions for wheel, in sorted order
w_ordered_sessions = sorted(compressed_wheel_cage.sessions.keys())

In [None]:
# Session lists per trial type (wheel): sessions that contain good trials vs bad trials.
# A session can appear in both lists; it is left out of a list when its wave_dic entry
# for that trial type is empty.
w_ordered_sessions_good = []
w_ordered_sessions_bad = []
for sessname in w_ordered_sessions:
    day_dic = compressed_wheel_cage.sessions[sessname]['day_dic']
    if len(day_dic['wave_dic']['good']) != 0:
        w_ordered_sessions_good.append(sessname)
    if len(day_dic['wave_dic']['bad']) != 0:
        w_ordered_sessions_bad.append(sessname)

# (2.0) Groups Dics (maps time zone -> list of session names)

In [None]:
# Default two-day windows per time zone for mice A, B, C, D, F
abcdf_ez_erly = ['2024_02_05', '2024_02_06'] # mouse F overrides this window through exceptions_dic
abcdf_ez_late = ['2024_02_08', '2024_02_09']
abcdf_hd_erly = ['2024_02_12', '2024_02_13']
abcdf_hd_late = ['2024_02_15', '2024_02_16']
abcdf_time_list = [abcdf_ez_erly,abcdf_ez_late,abcdf_hd_erly,abcdf_hd_late]

# Default two-day windows per time zone for mice G, H, I, J, K
ghijk_ez_erly = ['2024_04_29', '2024_04_30']
ghijk_ez_late = ['2024_05_02', '2024_05_03']
ghijk_hd_erly = ['2024_05_06', '2024_05_07']
ghijk_hd_late = ['2024_05_09', '2024_05_10']
ghijk_time_list = [ghijk_ez_erly,ghijk_ez_late,ghijk_hd_erly,ghijk_hd_late]

# Per-mouse overrides, keyed '<mouse_ID>_<timezone>'; these replace the default window
exceptions_dic = {'F_ez_erly': ['2024_02_05', '2024_02_07'], 'K_ez_erly': ['2024_04_30', '2024_05_01'], 'K_hd_late': ['2024_05_09', '2024_05_11']}

# time zone -> list of '<full mouse name>-<date>' session names.
# The *_time_list lists are ordered like the keys of this dict, so index i selects the window.
wheel_groups_dic = {'ez_erly': [], 'ez_late': [], 'hd_erly':[], 'hd_late':[]}
for i, timezone in enumerate(list(wheel_groups_dic.keys())):
    for mouse_ID in ['A','B','C','D','F','G','H','I','J','K']:
        mouse_name = full_mouse_name(mouse_ID)
        exception_key = mouse_ID + '_' + timezone
        if exception_key in exceptions_dic:
            days = exceptions_dic[exception_key]
        elif mouse_ID in ['A','B','C','D','F']:
            days = abcdf_time_list[i]
        else:
            days = ghijk_time_list[i]
        sessnames = [mouse_name + '-' + date for date in days]
        wheel_groups_dic[timezone] += sessnames

In [None]:
# Dates that define the early / mid / late manip time zones for mice G-K
# (J has its own late window).
ghikj_early_days = ['2024_04_15','2024_04_16','2024_04_17']
ghikj_mid_days = ['2024_04_19','2024_04_22']
ghik_late_days = ['2024_04_24','2024_04_25','2024_04_26']
j_late_days = ['2024_04_23','2024_04_24','2024_04_25']

# Same for mice A-D and F (A-D have no explicit mid list; see the fallback below).
abcd_early_days = ['2023_12_05','2023_12_07','2023_12_08']
f_early_days = ['2024_01_15','2024_01_16','2024_01_17']
abcd_late_days = ['2023_12_13','2023_12_14','2023_12_15']
f_late_days = ['2024_01_24','2024_01_25','2024_01_26']
f_mid_days = ['2024_01_19','2024_01_22']

early_sessions, late_sessions, mid_sessions = [],[],[]
for ses in m_ordered_sessions:
    # Session names look like '<mouse name>-<date>'; the mouse ID is the last character of the name part.
    date = ses.split('-')[1]
    mouse_ID = ses.split('-')[0][-1] 

    # Early is matched on the date alone, against every early list.
    if date in abcd_early_days or date in f_early_days or date in ghikj_early_days:
        early_sessions.append(ses)
    elif date in abcd_late_days or date in f_late_days:
        late_sessions.append(ses)
    elif mouse_ID == 'J' and date in j_late_days:
        late_sessions.append(ses)
    elif mouse_ID != 'J' and date in ghik_late_days:
        late_sessions.append(ses)
    else:
        # G-K and F sessions count as mid only on their listed mid days; sessions of those mice
        # on any other remaining date are left out of every group.
        if mouse_ID in ['G','H','I','K','J'] and date in ghikj_mid_days:
            mid_sessions.append(ses)
        elif mouse_ID == 'F' and date in f_mid_days:
            mid_sessions.append(ses)
        # For every other mouse, anything that is neither early nor late is mid.
        elif mouse_ID not in ['G','H','I','K','J','F']:
            mid_sessions.append(ses)

# time zone -> list of manip session names
manip_groups_dic = {'early':early_sessions, 'mid': mid_sessions, 'late': late_sessions}

# (3.0) Linear Regression Pipeline

In [None]:
# Regression window, expressed as offsets from p_BACK_WINDOW
# (presumably sample indices into the photometry cube; 30 samples ~ 1 s per the notes below -- TODO confirm).
startpoint = p_BACK_WINDOW - 30 #1 second back
endpoint = p_BACK_WINDOW + 60 #2 seconds forward

"""
EXPLAINER: So here we are calling a wrapper function that will help generate a dictionary. This dictionary
maps session name to a sub-dictionary containing metric values (ie: linear regression R^2 and single reg regression R^2)

In these dictionary-generating functions (which are all stored in metric_util.py), we supply:
1. the param_dic we used to generate 
the respective manip cage (ie: manip_param_dic for compressed_manip_cage) 
2. the list of sessions we care about (for generating dictionary - just supply the ordered_sessions not a subset cuz you can specify 
a subset for plotting/comparison later)
3. and the cage
4. There are also important parameters like: photom_cube_type ('default', 'rew', 'control'). Note that if the cage you supply in IS already
a cage full of only randomized cubes, just keep photom_cube_type as default. These key words are for accessing different cube_dics that 
are already stored inside of the cage (so they were generated during bigrun)
5. keep save_flag as False. If you do want to turn them on, look through the code and make sure you specify a proper path for saving to occur

"""
# Baseline regression dictionaries over all sessions, using the 'default' cubes of each cage.
reg_dic_manip = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default')
reg_dic_wheel = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default')

In [None]:
# Note the different session list here: only sessions that have rewarded trials,
# combined with photom_cube_type='rew'.
reg_dic_manip_rew = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions_rew, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='rew')

In [None]:
# Tighter bounds around 0: window from p_BACK_WINDOW-10 to p_BACK_WINDOW+20

reg_dic_manip_tight = mu.regression_wrapper(p_BACK_WINDOW-10, p_BACK_WINDOW+20, manip_param_dic, m_ordered_sessions, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default')
reg_dic_manip_rew_tight = mu.regression_wrapper(p_BACK_WINDOW-10, p_BACK_WINDOW+20, manip_param_dic, m_ordered_sessions_rew, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='rew')

### Declaring random cubes

In [None]:
#These are not for bootstrapping, but just for barcharts + t tests 

"""
NOTE: IMPORTANT REFERENCE ON RANDOM CUBES

# RC1: random initiation points
# RC2: trial scrambling
# RC3: trial shifting
# RC4 (1+2): random initiation points + trial scrambling
# RC5 (1+3): random initiation points + trial shifting

"""

# All controls below reuse startpoint/endpoint and the full session lists; they differ only in
# photom_cube_type ('control' vs 'default') and in the shuffle_trials / random_adjusts flags.

# RC1: random initiation points

reg_dic_manip_rc1 = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='control')
reg_dic_wheel_rc1 = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='control')

# RC2: trial scrambling

reg_dic_manip_rc2 = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', shuffle_trials=True)

reg_dic_wheel_rc2 = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', shuffle_trials=True)

# RC3: trial shifting (random adjusts = True)

reg_dic_manip_rc3 = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', random_adjusts=True)

reg_dic_wheel_rc3 = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', random_adjusts=True)

# RC4 (1+2): random initiation points + trial scrambling

reg_dic_manip_rc4 = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='control', shuffle_trials=True)
reg_dic_wheel_rc4 = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='control', shuffle_trials=True)

# RC5 (1+3): random initiation points + trial shifting

reg_dic_manip_rc5 = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='control', random_adjusts=True)
reg_dic_wheel_rc5 = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='control', random_adjusts=True)

## Grouping and plotting: R^2

### Cross behavior

In [None]:
"""
EXPLAINER: 
Once you have the desired metric dictionary (your parameters being trial type, behavior, type of cube (movement, reward, random), other
metric related hyperparameters ie: start and endpoint) now we parse through these dictionaries and group session averages by mice (via 
mh.generate_groups_vector -- the groups correspond to the mice). Then we get mice averages (via mh.means_per_group).

Finally, we use metric_helper's plot_barchart. Note that I have no automation on the title and labels, so that is something you need to 
mindfully change. Also note that you CAN select which mice you compare in plot_barchart (which also computes a two sample paired t test)



"""


# Manip vs wheel over all sessions: group the 'lr_r2_scores' by mouse, average per mouse, then plot.
mouse_IDs, group1 = mh.generate_groups_vector(m_ordered_sessions, reg_dic_manip, key='lr_r2_scores')
group_manip_lr_r2 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(w_ordered_sessions, reg_dic_wheel, key='lr_r2_scores')
group_wheel_lr_r2 = mh.means_per_group(group2)
mh.plot_barchart(group_manip_lr_r2, group_wheel_lr_r2, ['manip','wheel'], title='Manip vs Wheel (-30->+60) LR r^2', 
                 label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])

#group_manip_lr_r2, group_wheel_lr_r2

In [None]:
# Manip 'early' sessions vs wheel 'ez_erly' sessions, per-mouse mean of 'lr_r2_scores'.
mouse_IDs, group1 = mh.generate_groups_vector(manip_groups_dic['early'], reg_dic_manip, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(wheel_groups_dic['ez_erly'], reg_dic_wheel, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['manip','wheel'], title='Manip vs Wheel (-30->+60) LR r^2, Early vs EZ early, ', 
                 label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])

In [None]:
# Manip 'late' sessions vs wheel 'ez_late' sessions, per-mouse mean of 'lr_r2_scores'.
mouse_IDs, group1 = mh.generate_groups_vector(manip_groups_dic['late'], reg_dic_manip, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(wheel_groups_dic['ez_late'], reg_dic_wheel, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['manip','wheel'], title='Manip vs Wheel (-30->+60) LR r^2, Late vs EZ Late, '
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

### movement init vs rew

In [None]:

# Manip: default cubes over all sessions vs 'rew' cubes over the rewarded-session list
# (the two session lists may differ).
mouse_IDs, group1 = mh.generate_groups_vector(m_ordered_sessions, reg_dic_manip, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(m_ordered_sessions_rew, reg_dic_manip_rew, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['movement init','rew init'], title='Manip Movement vs Rew Init (-30->+60) LR r^2 '
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

In [None]:

# Same comparison as the previous cell, using the dictionaries computed on the tighter -10 -> +20 window.
mouse_IDs, group1 = mh.generate_groups_vector(m_ordered_sessions, reg_dic_manip_tight, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(m_ordered_sessions_rew, reg_dic_manip_rew_tight, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['movement init','rew init'], title='Manip Movement vs Rew Init (-10->+20) LR r^2 '
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

### Cross time

In [None]:
# Manip 'early' vs 'late' sessions, per-mouse mean of 'lr_r2_scores'.
mouse_IDs, group1 = mh.generate_groups_vector(manip_groups_dic['early'], reg_dic_manip, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(manip_groups_dic['late'], reg_dic_manip, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
# Title fix: the manip groups are early/mid/late; the "EZ early vs EZ Late" suffix was copied
# from the wheel cell and mislabelled this figure.
mh.plot_barchart(group1, group2, ['manip early','manip late'], title='Manip - early vs late (-30->+60) LR r^2'
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

In [None]:
# Wheel 'ez_erly' vs 'ez_late' sessions, per-mouse mean of 'lr_r2_scores'.
mouse_IDs, group1 = mh.generate_groups_vector(wheel_groups_dic['ez_erly'], reg_dic_wheel, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(wheel_groups_dic['ez_late'], reg_dic_wheel, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['wheel ez early','wheel ez late'], title='Wheel - ez early vs late (-30->+60) LR r^2, EZ early vs EZ Late, '
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

In [None]:
# Wheel 'hd_erly' vs 'hd_late' sessions, per-mouse mean of 'lr_r2_scores'.
mouse_IDs, group1 = mh.generate_groups_vector(wheel_groups_dic['hd_erly'], reg_dic_wheel, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(wheel_groups_dic['hd_late'], reg_dic_wheel, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['wheel Hard early','wheel Hard late'], title='Wheel - Hard early vs late (-30->+60) LR r^2'
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

### Cross Controls

In [None]:
# One bar chart per random control (RC1..RC5, in the order of control_names): manip vs that control.
control_names = ['rand inits', 'scrambled trials', 'shifted trials','rand inits + scrambled', 'rand inits + shifting']
for i,reg_dic_control in enumerate([reg_dic_manip_rc1,reg_dic_manip_rc2,reg_dic_manip_rc3,reg_dic_manip_rc4,reg_dic_manip_rc5]):
    # group1 does not depend on the control; it is recomputed on every iteration
    mouse_IDs, group1 = mh.generate_groups_vector(m_ordered_sessions, reg_dic_manip, key='lr_r2_scores')
    group1_means = mh.means_per_group(group1)
    mouse_IDs, group2 = mh.generate_groups_vector(m_ordered_sessions, reg_dic_control, key='lr_r2_scores')
    group2_means = mh.means_per_group(group2)
    mh.plot_barchart(group1_means, group2_means, ['manip','manip rc' + str(i+1)], title='Manip vs Control (-30->+60) LR r^2 -- ' + control_names[i], 
                    label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])

    #group_manip_lr_r2, group_wheel_lr_r2

In [None]:
# One bar chart per random control (RC1..RC5, in the order of control_names): wheel vs that control.
control_names = ['rand inits', 'scrambled trials', 'shifted trials','rand inits + scrambled', 'rand inits + shifting']
for i,reg_dic_control in enumerate([reg_dic_wheel_rc1,reg_dic_wheel_rc2,reg_dic_wheel_rc3,reg_dic_wheel_rc4,reg_dic_wheel_rc5]):
    # group1 does not depend on the control; it is recomputed on every iteration
    mouse_IDs, group1 = mh.generate_groups_vector(w_ordered_sessions, reg_dic_wheel, key='lr_r2_scores')
    group1_means = mh.means_per_group(group1)
    mouse_IDs, group2 = mh.generate_groups_vector(w_ordered_sessions, reg_dic_control, key='lr_r2_scores')
    group2_means = mh.means_per_group(group2)
    mh.plot_barchart(group1_means, group2_means, ['Wheel','Wheel rc' + str(i+1)], title='Wheel vs Control (-30->+60) LR r^2 -- ' + control_names[i], 
                    label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])

    #group_manip_lr_r2, group_wheel_lr_r2

### Cross Trial Types

In [None]:
#Manip rew vs unrew
# Re-declare the -30 -> +60 window, then compute one regression dictionary per manip trial type.
# Each call uses the session list that actually contains trials of that type.
startpoint = p_BACK_WINDOW - 30
endpoint = p_BACK_WINDOW + 60
reg_dic_manip_rewtrials = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions_rew, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', trial_type='rewarded')
reg_dic_manip_unrewtrials = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions_unrew, compressed_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', trial_type='unrewarded')

In [None]:
# Manip rewarded vs unrewarded trials, per-mouse mean of 'lr_r2_scores'
# (the two session lists may differ).
mouse_IDs, group1 = mh.generate_groups_vector(m_ordered_sessions_rew, reg_dic_manip_rewtrials, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(m_ordered_sessions_unrew, reg_dic_manip_unrewtrials, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['Manip Rew Trials','Manip Unrew Trials'], title='Manip - Rew vs Unrew (-30->+60) LR r^2'
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

In [None]:
# One regression dictionary per wheel trial type ('good' / 'bad'), reusing the current
# startpoint/endpoint; each call uses the session list that contains trials of that type.
reg_dic_wheel_goodtrials = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions_good, compressed_wheel_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', trial_type='good')
reg_dic_wheel_badtrials = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions_bad, compressed_wheel_cage, 
                                      stpt_label='', plot=False, save_flag=False, 
                                      photom_cube_type='default', trial_type='bad')

In [None]:
# Wheel good vs bad trials, per-mouse mean of 'lr_r2_scores'
mouse_IDs, group1 = mh.generate_groups_vector(w_ordered_sessions_good, reg_dic_wheel_goodtrials, key='lr_r2_scores')
group1 = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(w_ordered_sessions_bad, reg_dic_wheel_badtrials, key='lr_r2_scores')
group2 = mh.means_per_group(group2)
mh.plot_barchart(group1, group2, ['Wheel Good Trials','Wheel Bad Trials'], title='Wheel - Good vs Bad (-30->+60) LR r^2'
                 , label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim = [0,1])

## grouping and plotting: single reg corr

In [None]:
# Print the stored single-regression labels of one session to check them against sr_ref below.
print(reg_dic_manip['RR20231108_A-2023_12_05']['sr_ref'])
# Column i of 'sr_r2_scores' is plotted under label sr_ref[i].
sr_ref = ['DCN-Thal', 'SNr-Thal', 'SNr-DCN']
for i, reg_label in enumerate(sr_ref):
    mouse_IDs, group1 = mh.generate_groups_matrix(m_ordered_sessions, reg_dic_manip, col_ind=i, key='sr_r2_scores')
    group1 = mh.means_per_group(group1)
    mouse_IDs, group2 = mh.generate_groups_matrix(w_ordered_sessions, reg_dic_wheel, col_ind=i, key='sr_r2_scores')
    group2 = mh.means_per_group(group2)
    # Title fix: reg_dic_manip / reg_dic_wheel were computed on the -30 -> +60 window, not -10 -> +30.
    mh.plot_barchart(group1, group2, ['manip','wheel'], title='Manip vs Wheel (-30->+60) SR R^2: ' + reg_label + ' pval: ', 
                    label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])
    

In [None]:
# Pair each manip time zone with its wheel counterpart: early <-> ez_erly, late <-> ez_late.
for mtime_z, wtime_z in zip(['early','late'],['ez_erly','ez_late']):
    # Column i of 'sr_r2_scores' is plotted under label sr_ref[i].
    sr_ref = ['DCN-Thal', 'SNr-Thal', 'SNr-DCN']
    for i, reg_label in enumerate(sr_ref):
        mouse_IDs, group1 = mh.generate_groups_matrix(manip_groups_dic[mtime_z], reg_dic_manip, col_ind=i, key='sr_r2_scores')
        group1 = mh.means_per_group(group1)
        mouse_IDs, group2 = mh.generate_groups_matrix(wheel_groups_dic[wtime_z], reg_dic_wheel, col_ind=i, key='sr_r2_scores')
        group2 = mh.means_per_group(group2)
        # Title fix: reg_dic_manip / reg_dic_wheel were computed on the -30 -> +60 window, not -10 -> +30.
        mh.plot_barchart(group1, group2, ['manip','wheel'], title=mtime_z + ', Manip vs Wheel (-30->+60) SR R^2: ' + reg_label + ' pval: ', 
                        label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])
        

## grouping and plotting: LR coeffs

In [None]:
# Print the stored coefficient labels of one session to check them against lr_coeffs_ref below.
print(reg_dic_manip['RR20231108_A-2023_12_05']['lr_coeffs_labels'])
# Column i of 'lr_coeffs' is plotted under label lr_coeffs_ref[i].
lr_coeffs_ref = ['DCN','SNr']
for i, reg_label in enumerate(lr_coeffs_ref):
    mouse_IDs, group1 = mh.generate_groups_matrix(m_ordered_sessions, reg_dic_manip, col_ind=i, key='lr_coeffs')
    group1 = mh.means_per_group(group1)
    mouse_IDs, group2 = mh.generate_groups_matrix(w_ordered_sessions, reg_dic_wheel, col_ind=i, key='lr_coeffs')
    group2 = mh.means_per_group(group2)
    # Title fix: reg_dic_manip / reg_dic_wheel were computed on the -30 -> +60 window, not -10 -> +30.
    mh.plot_barchart(group1, group2, ['manip','wheel'], title='Manip vs Wheel (-30->+60) LR coeff: ' + reg_label + ' pval: ', 
                    label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[-1,1])
    

# Random Cube Generation and metric compute

In [None]:
def gen_random_cube(sessname, cage, parameter_dic, oreg_list=None, shuffle_trials = False):
    """Generate a random cube_dic for one session through pph.random_cube_generate.

    Args:
        sessname: key of the session inside `cage.sessions`.
        cage: cage whose session dict provides 'photom_df' and 'day_dic'.
        parameter_dic: forwarded unchanged to pph.random_cube_generate.
        oreg_list: list forwarded to pph.random_cube_generate; a fresh empty list is used when
            omitted, so nothing is shared between calls.
        shuffle_trials: forwarded to pph.random_cube_generate.

    Returns:
        The cube_dic returned by pph.random_cube_generate (its second return value is discarded).
    """
    if oreg_list is None:
        oreg_list = []
    session = cage.sessions[sessname]
    # NOTE(review): the windows are hard-coded to 120 here, while the module-level p_BACK_WINDOW is
    # 150 and random_cage_generate passes 150. Callers index the result with p_BACK_WINDOW-based
    # start/end points -- confirm that 120 is intended.
    cube_dic, _ = pph.random_cube_generate(session['photom_df'], session['day_dic'], oreg_list, parameter_dic,
                                           shuffle_trials=shuffle_trials, phot_coldic_override=None,
                                           p_BACK_WINDOW=120, p_FORWARD_WINDOW=120)
    return cube_dic

def gen_random_cage(cage, session_list, parameter_dic, oreg_list=None, shuffle_trials=False):
    """Build a cage of random cube_dics, one per session in `session_list`.

    Each session of the returned cage is a one-entry dict stored under
    'rand_cube_dic_lowp_' + parameter_dic['name'].

    Args:
        cage: source cage handed to gen_random_cube.
        session_list: session names to process.
        parameter_dic: forwarded to gen_random_cube; its 'name' entry forms the stored key.
        oreg_list: list forwarded to gen_random_cube for every session; a fresh empty list is
            created per call when omitted, so nothing is shared between calls.
        shuffle_trials: forwarded to gen_random_cube.

    Returns:
        sessions_cage mapping session name -> {cube_dic_name: random cube_dic}.
    """
    if oreg_list is None:
        oreg_list = []
    rand_cage = sessions_cage()
    for sessname in session_list:
        random_cube_dic = gen_random_cube(sessname, cage, parameter_dic, oreg_list=oreg_list, shuffle_trials=shuffle_trials)
        cube_dic_name = 'rand_cube_dic_lowp_' + parameter_dic['name']
        rand_cage.add_sess(sessname, {cube_dic_name: random_cube_dic})
    return rand_cage

def free_up_memory(cage):
    """Drop this function's reference to `cage` and run a garbage-collection pass.

    Note: `del cage` only unbinds the local parameter name. Any reference the caller still
    holds keeps the object alive, so the caller has to `del` its own variable as well.
    """
    del cage
    gc.collect()
        

In [None]:
# generate random cage
rand_manip_cage = gen_random_cage(compressed_manip_cage, m_ordered_sessions, manip_param_dic)

# compute_metric wrt rand_cage
# NOTE(review): gen_random_cube builds these cubes with 120/120 windows, while the start/end
# points below are offsets from p_BACK_WINDOW (150) -- confirm the window alignment.
startpoint = p_BACK_WINDOW - 30
endpoint = p_BACK_WINDOW + 60
reg_dic_manip_rand = mu.regression_wrapper(startpoint, endpoint, manip_param_dic, m_ordered_sessions, rand_manip_cage, 
                                      stpt_label='', plot=False, save_flag=False, photom_cube_type='control')

# collect group means wrt reg_dic
mouse_IDs, group_rm = mh.generate_groups_vector(m_ordered_sessions, reg_dic_manip_rand, key='lr_r2_scores')
group_rm = mh.means_per_group(group_rm)

# delete large objects and garbage collect
del rand_manip_cage
gc.collect()

# optional 
# del reg_dic_manip_rand
# gc.collect()

In [None]:
# Wheel counterpart of the previous cell: generate the random cage, compute the regression
# dictionary on it, keep only the per-mouse means, then free the cage.
rand_wheel_cage = gen_random_cage(compressed_wheel_cage, w_ordered_sessions, wheel_param_dic)
# NOTE(review): gen_random_cube builds these cubes with 120/120 windows, while the start/end
# points below are offsets from p_BACK_WINDOW (150) -- confirm the window alignment.
startpoint = p_BACK_WINDOW - 30
endpoint = p_BACK_WINDOW + 60
reg_dic_wheel_rand = mu.regression_wrapper(startpoint, endpoint, wheel_param_dic, w_ordered_sessions, rand_wheel_cage, 
                                      stpt_label='', plot=False, save_flag=False, photom_cube_type='control')
mouse_IDs, group_rw = mh.generate_groups_vector(w_ordered_sessions, reg_dic_wheel_rand, key='lr_r2_scores')
group_rw = mh.means_per_group(group_rw)
del rand_wheel_cage
gc.collect()

In [None]:
# Plot the random-cube group means computed in the two cells above (group_rm: manip, group_rw: wheel).
# This cell previously plotted group1/group2, which are leftovers from whichever earlier plotting
# cell ran last, and its title named a -10 -> +30 window although the random-cube regressions use
# -30 -> +60.
mh.plot_barchart(group_rm, group_rw, ['manip','wheel'], title='Random cubes: Manip vs Wheel (-30->+60) LR r^2', 
                 label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])

# Bootstrapping

## RC 1: Bootstrapping: rand trial inits

In [None]:
# this will save an ultra compressed cage (no behavioral info except wavedic)
def random_cage_generate(compressed_cage, session_list, parameter_dic, cube_dic_key='zscores'):
    """Build a new cage holding one randomized photometry cube dict per session.

    For every session name in ``session_list`` this reads ``'photom_df'`` and
    ``'day_dic'`` from ``compressed_cage.sessions[sessname]`` and passes them to
    ``pph.random_cube_generate`` (with ``shuffle_trials=False`` and an empty
    outlier list). The result is stored in a fresh ``sessions_cage`` under the key
    ``'cube_dic_lowp_' + parameter_dic['name']``; nothing else from the source
    session is copied over.

    Args:
        compressed_cage: Cage whose ``.sessions[sessname]`` entries provide
            ``'photom_df'`` and ``'day_dic'``.
        session_list: Iterable of session names to process.
        parameter_dic: Parameters forwarded to ``pph.random_cube_generate``. Its
            ``'name'`` entry is used to build the per-session key.
        cube_dic_key (str, optional): Defaults to 'zscores'. When not ``None``,
            only this entry of the generated cube dict is kept (saves memory);
            indexing fails if the key is absent. Pass ``None`` to keep every key
            of the generated cube dict.

    Returns:
        The newly built ``sessions_cage``.
    """
    session_key = 'cube_dic_lowp_' + parameter_dic['name']
    random_cage = sessions_cage()
    for sessname in session_list:
        session = compressed_cage.sessions[sessname]

        # One session needs an explicit channel-name -> column-index mapping.
        pc_override = None
        if sessname == 'RR20240320_G-2024_05_07': # this is a wheel session
            templis = ['CH1-470', 'CH1-410', 'Thal-470', 'Thal-410', 'DCN-470', 'DCN-410','SNr-470', 'SNr-410']
            pc_override = {key: i for i, key in enumerate(templis)}
            print(pc_override)

        oreg_list = [] #NOTE to richard. change this later when outliers have been determined for wheel too pls
        # Window sizes come from the notebook-level constants (both 150 in the
        # config cell) instead of a duplicated literal, so they cannot drift apart
        # from the startpoint/endpoint arithmetic that is based on p_BACK_WINDOW.
        # The second return value (output keys) is not needed here.
        cube_dic, _ = pph.random_cube_generate(
            session['photom_df'], session['day_dic'], oreg_list, parameter_dic,
            phot_coldic_override=pc_override,
            p_BACK_WINDOW=p_BACK_WINDOW, p_FORWARD_WINDOW=p_FORWARD_WINDOW,
            shuffle_trials=False)

        if cube_dic_key is not None:
            # Keep only the requested entry to limit memory use.
            cube_dic = {cube_dic_key: cube_dic[cube_dic_key]}
        random_cage.add_sess(sessname, {session_key: cube_dic})
    gc.collect()
    return random_cage
        

In [None]:
# RC1 bootstrap: 100 iterations of random trial inits, for wheel and then manip.
# `startpoint` / `endpoint` are NOT set in this cell; they come from whichever
# earlier cell last assigned them.
rc1_wheel_grouplis = []
rc1_manip_grouplis = []

for i in range(100):
    # ---- wheel ----
    random_wheel_cage = random_cage_generate(
        compressed_wheel_cage, w_ordered_sessions, wheel_param_dic,
        cube_dic_key='zscores',
    )
    reg_dic_wheel_rand_inits = mu.regression_wrapper(
        startpoint, endpoint, wheel_param_dic, w_ordered_sessions, random_wheel_cage,
        stpt_label='', plot=False, save_flag=False, photom_cube_type='default',
    )
    mouse_IDs, group1 = mh.generate_groups_vector(
        w_ordered_sessions, reg_dic_wheel_rand_inits, key='lr_r2_scores'
    )
    group1 = mh.means_per_group(group1)
    rc1_wheel_grouplis.append(group1)
    # Release the per-iteration wheel objects before building the manip ones.
    del reg_dic_wheel_rand_inits
    del random_wheel_cage

    # ---- manip ----
    random_manip_cage = random_cage_generate(
        compressed_manip_cage, m_ordered_sessions, manip_param_dic,
        cube_dic_key='zscores',
    )
    reg_dic_manip_rand_inits = mu.regression_wrapper(
        startpoint, endpoint, manip_param_dic, m_ordered_sessions, random_manip_cage,
        stpt_label='', plot=False, save_flag=False, photom_cube_type='default',
    )
    mouse_IDs, group2 = mh.generate_groups_vector(
        m_ordered_sessions, reg_dic_manip_rand_inits, key='lr_r2_scores'
    )
    group2 = mh.means_per_group(group2)
    rc1_manip_grouplis.append(group2)
    del reg_dic_manip_rand_inits
    del random_manip_cage
    gc.collect()

In [None]:
# Save bootstrap data (this is important for the 10k run)
# Wheel and manip results are written to two *different* files. The manip
# filename used to be identical to the wheel one, so the second dump silently
# overwrote the wheel results.

filename_wheel_rc1 = 'rc1_bootstrap_100_wheel_data.pkl'
filepath = datapath + '/misc_pickles' + '/' + filename_wheel_rc1
with open(filepath, 'wb') as f:  # the with-block closes the file; no f.close() needed
    pickle.dump(rc1_wheel_grouplis, f)

filename_manip_rc1 = 'rc1_bootstrap_100_manip_data.pkl'
filepath = datapath + '/misc_pickles' + '/' + filename_manip_rc1
with open(filepath, 'wb') as f:
    pickle.dump(rc1_manip_grouplis, f)


In [None]:
# Compare the 100 RC1 (random inits) wheel bootstrap group means against
# group_wheel_lr_r2, which is computed in a cell outside this section.
# Signature noted by the author (presumably that of mh.bootstrap_hist -- TODO confirm):
# def bootstrap_hist(control_arr, exp_group, title=''):
mh.bootstrap_hist(rc1_wheel_grouplis, group_wheel_lr_r2, title='Wheel RC1 (random inits) : 100 iter')

In [None]:
# Manip counterpart of the wheel RC1 plot: the 100 RC1 bootstrap group means
# against group_manip_lr_r2 (computed in a cell outside this section).
mh.bootstrap_hist(rc1_manip_grouplis, group_manip_lr_r2, title='Manip RC1 (random inits) : 100 iter')

## RC 3: Bootstrapping: Rand shifts (Exp vs Nobehav_Shtrial)

In [None]:
# RC3 bootstrap: 1000 iterations of the regression with random_adjusts=True on the
# real (compressed) cages. `startpoint` / `endpoint` come from an earlier cell.
wshifts_grouplis = []
mshifts_grouplis = []
for i in range(1000):
    # Manip is regressed before wheel; keep this order in case the helpers share
    # random state.
    reg_dic_manip_shifts = mu.regression_wrapper(
        startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage,
        random_adjusts=True,
        stpt_label='', plot=False, save_flag=False, photom_cube_type='default',
    )
    reg_dic_wheel_shifts = mu.regression_wrapper(
        startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
        random_adjusts=True,
        stpt_label='', plot=False, save_flag=False, photom_cube_type='default',
    )

    # Per-group means of 'lr_r2_scores' for this iteration: wheel, then manip.
    mouse_IDs, group1 = mh.generate_groups_vector(
        w_ordered_sessions, reg_dic_wheel_shifts, key='lr_r2_scores'
    )
    group1 = mh.means_per_group(group1)
    wshifts_grouplis.append(group1)

    mouse_IDs, group2 = mh.generate_groups_vector(
        m_ordered_sessions, reg_dic_manip_shifts, key='lr_r2_scores'
    )
    group2 = mh.means_per_group(group2)
    mshifts_grouplis.append(group2)

    # Free the per-iteration regression dicts.
    del reg_dic_manip_shifts
    del reg_dic_wheel_shifts
    gc.collect()


In [None]:
# Save bootstrap data (this is important for the 10k run)
# Each list is pickled to its own file under <datapath>/misc_pickles. The
# with-blocks close the files on exit, so no explicit f.close() is needed.

filename_wheelshifts = 'wheel_shifts_1000_list.pkl'
filepath = datapath + '/misc_pickles' + '/' + filename_wheelshifts
with open(filepath, 'wb') as f:
    pickle.dump(wshifts_grouplis, f)

filename_manipshifts = 'manip_shifts_1000_list.pkl'
filepath = datapath + '/misc_pickles' + '/' + filename_manipshifts
with open(filepath, 'wb') as f:
    pickle.dump(mshifts_grouplis, f)


In [None]:
# Reload bootstrap data
# SECURITY: pickle.load can execute arbitrary code; only load files that this
# notebook wrote itself.
# NOTE(review): the plotting cells below read mshifts_grouplis / wshifts_grouplis,
# not loaded_mps / loaded_wps -- confirm which of the two is intended after a
# kernel restart.

filename_manipshifts = 'manip_shifts_1000_list.pkl'
filepath = datapath + '/misc_pickles' + '/' + filename_manipshifts
with open(filepath, 'rb') as f:  # closed automatically when the block exits
    loaded_mps = pickle.load(f) # deserialize using load()

filename_wheelshifts = 'wheel_shifts_1000_list.pkl'
filepath = datapath + '/misc_pickles' + '/' + filename_wheelshifts
with open(filepath, 'rb') as f:
    loaded_wps = pickle.load(f) # deserialize using load()

In [None]:

    
# Stack the per-iteration manip group means into an array and compare them with
# group_manip_lr_r2 (computed in a cell outside this section).
mshifts_arr = np.asarray(mshifts_grouplis)
mh.bootstrap_hist(
    mshifts_arr,
    group_manip_lr_r2,
    title='manip exp vs shifts bootstrap (1000 iter)',
)

In [None]:
# Stack the per-iteration wheel group means into an array and compare them with
# group_wheel_lr_r2 (computed in a cell outside this section).
wshifts_arr = np.array(wshifts_grouplis)
# Call the helper through the `mh` module like every other bootstrap_hist call in
# this notebook; the bare name is not defined in the visible cells and would raise
# NameError on a fresh kernel.
mh.bootstrap_hist(wshifts_arr, group_wheel_lr_r2, title='wheel exp vs shifts bootstrap (1000 iter)')

#group_manip_lr_r2, group_wheel_lr_r2

# Mutual Information Pipeline

In [None]:
# Mutual-information dicts for manip and wheel over the window
# [p_BACK_WINDOW - 30, p_BACK_WINDOW + 60].
startpoint = p_BACK_WINDOW - 30
endpoint = p_BACK_WINDOW + 60

mi_dic_manip = mu.gen_mi_dic(
    startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage,
    photom_cube_type='default',
)
mi_dic_wheel = mu.gen_mi_dic(
    startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
    photom_cube_type='default',
)

In [None]:
# Per-group means of the 'dcn_mi' entries, manip vs wheel.
mouse_IDs, group1 = mh.generate_groups_vector(m_ordered_sessions, mi_dic_manip, key='dcn_mi')
group_manip_dcn_mi = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(w_ordered_sessions, mi_dic_wheel, key='dcn_mi')
group_wheel_dcn_mi = mh.means_per_group(group2)
# Plot the MI means computed above. The call previously passed the LR r^2 groups
# (group_manip_lr_r2 / group_wheel_lr_r2), so the figure titled 'dcn_MI' did not
# show MI at all.
# NOTE(review): ylim=[0,1] was inherited from the r^2 plots -- confirm it suits MI values.
mh.plot_barchart(group_manip_dcn_mi, group_wheel_dcn_mi, ['manip','wheel'], title='Manip vs Wheel (-30->+60) dcn_MI',
                 label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])

#group_manip_lr_r2, group_wheel_lr_r2

In [None]:
# Per-group means of the 'snr_mi' entries, manip vs wheel. They are stored under
# their own names so the DCN results (group_*_dcn_mi) are no longer overwritten.
mouse_IDs, group1 = mh.generate_groups_vector(m_ordered_sessions, mi_dic_manip, key='snr_mi')
group_manip_snr_mi = mh.means_per_group(group1)
mouse_IDs, group2 = mh.generate_groups_vector(w_ordered_sessions, mi_dic_wheel, key='snr_mi')
group_wheel_snr_mi = mh.means_per_group(group2)
# Plot the MI means computed above. The call previously passed the LR r^2 groups,
# so the figure titled 'snr_MI' did not show MI at all.
# NOTE(review): ylim=[0,1] was inherited from the r^2 plots -- confirm it suits MI values.
mh.plot_barchart(group_manip_snr_mi, group_wheel_snr_mi, ['manip','wheel'], title='Manip vs Wheel (-30->+60) snr_MI',
                 label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'], ylim=[0,1])

#group_manip_lr_r2, group_wheel_lr_r2

# Peaks characterization

In [None]:
# Peaks dicts for manip and wheel over the window
# [p_BACK_WINDOW - 30, p_BACK_WINDOW + 60].
startpoint = p_BACK_WINDOW - 30
endpoint = p_BACK_WINDOW + 60

peaks_dic_manip = mu.gen_peaks_dic(
    startpoint, endpoint, manip_param_dic, m_ordered_sessions, compressed_manip_cage,
    photom_cube_type='default',
)
peaks_dic_wheel = mu.gen_peaks_dic(
    startpoint, endpoint, wheel_param_dic, w_ordered_sessions, compressed_wheel_cage,
    photom_cube_type='default',
)

In [None]:
# One bar chart per entry of `reg`: per-group means of column `i` of the 'peaks'
# entries, manip vs wheel. `reg[i]` is only used to label the figure.
reg = ['ch1','dcn','thal','snr']
for i in range(len(reg)):
    # manip
    mouse_IDs, group1 = mh.generate_groups_matrix(
        m_ordered_sessions, peaks_dic_manip, col_ind=i, key='peaks'
    )
    group_manip_peaks = mh.means_per_group(group1)

    # wheel
    mouse_IDs, group2 = mh.generate_groups_matrix(
        w_ordered_sessions, peaks_dic_wheel, col_ind=i, key='peaks'
    )
    group_wheel_peaks = mh.means_per_group(group2)

    mh.plot_barchart(
        group_manip_peaks, group_wheel_peaks, ['manip','wheel'],
        title='PEAKS: ' + reg[i] + ' Manip vs Wheel (-30->+60)',
        label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'],
    )

#group_manip_lr_r2, group_wheel_lr_r2

In [None]:
# One bar chart per entry of `reg`: per-group means of column `i` of the manip
# 'peaks' entries, comparing the sessions in manip_groups_dic['early'] with those
# in manip_groups_dic['late'].
reg = ['ch1','dcn','thal','snr']
for i in range(len(reg)):
    # early sessions
    mouse_IDs, group1 = mh.generate_groups_matrix(
        manip_groups_dic['early'], peaks_dic_manip, col_ind=i, key='peaks'
    )
    group_manip_peaks_early = mh.means_per_group(group1)

    # late sessions
    mouse_IDs, group2 = mh.generate_groups_matrix(
        manip_groups_dic['late'], peaks_dic_manip, col_ind=i, key='peaks'
    )
    group_manip_peaks_late = mh.means_per_group(group2)

    mh.plot_barchart(
        group_manip_peaks_early, group_manip_peaks_late, ['early','late'],
        title='PEAKS: ' + reg[i] + ' Manip early vs late (-30->+60)',
        label=['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'],
    )

#group_manip_lr_r2, group_wheel_lr_r2