# Loading data

In [4]:
import os
import copy
import numpy
import modules.io as io
import modules.plot_func as pltt
import modules.traj_analysis as tran
import modules.helper_func_class as misc
import modules.process_mouse_trials_lib as plib
import matplotlib.pyplot as plt
import scipy.stats
import warnings

# Silence every warning for the rest of the notebook (keeps cell output readable).
# NOTE(review): this also hides potentially useful numpy/scipy/matplotlib warnings.
warnings.filterwarnings('ignore')
#import matplotlib.image as mpimg


"""

THE SUFFIXES IN THE VARIABLE NAMES

*** rt suffix -> relative target experiments
                 this means STATIC entrance,
                 since the target is always
                 positioned in the same spot
                 relative to the entrance
                 the mouse takes

*** ft suffix -> fixed target experiments
                 this means RANDOM entrance,
                 since the target is in a
                 different spot in every trial,
                 relative to the entrance the mouse
                 takes

*** 2t suffix -> two target experiments
                 mice are trained in two locations, consecutively

*** rt90 suffix    -> relative target experiments stored in
                      ./experiments/relative_target_90deg
                      (their R90 probe files match 'mpos_*R90*')

*** p2_R180 suffix -> two-target rotation experiments, pooled from the
                      two_targets_rot, two_targets_rot_fem and
                      two_targets_rot_mixsex directories

"""

# Glob patterns matching the per-mouse data directories of each experiment.
mouse_traj_dir_rt      = r'./experiments/relative_target/mouse_*'
mouse_traj_dir_rt90    = r'./experiments/relative_target_90deg/mouse_*'
mouse_traj_dir_ft      = r'./experiments/fixed_target/mouse_*'
mouse_traj_dir_2t      = r'./experiments/two_target_no_cues/mouse_*'
mouse_traj_dir_p2_R180 = [r'./experiments/two_targets_rot/mouse_*'       ,
                          r'./experiments/two_targets_rot_fem/mouse_*'   ,
                          r'./experiments/two_targets_rot_mixsex/mouse_*']

selfint_file_variable  = 'selfint_st_rt'

# Directory where figures are written when save_output_figures is True.
output_dir = 'figs/paper/active_sensing'
# exist_ok=True tolerates an already-existing directory but, unlike the former
# `except FileExistsError: pass`, still raises if a regular *file* occupies this
# path, instead of silently continuing and failing later when saving figures.
os.makedirs(output_dir, exist_ok=True)

# Colour palette shared by the figures.
colors = pltt.get_gradient(5, 'blue2')

# Explicit RGBA colour, components scaled to [0, 1].
color_lightblue   = numpy.array((65, 102, 216, 255)) / 255
color_darkblue    = pltt.get_gradient(2, 'blue')[1]
color_lightred    = pltt.get_gradient(2, 'red')[0]
color_darkred     = pltt.get_gradient(2, 'red')[1]
# An opaque alpha (1.0) is appended to the yellow entries
# (presumably get_gradient returns RGB-only rows here -- confirm).
color_lightyellow = numpy.append(pltt.get_gradient(10, 'yellow')[2], 1.0)
color_darkyellow  = numpy.append(pltt.get_gradient(2, 'yellow')[-1], 1.0)

color_h1, color_h2 = color_lightblue, color_darkblue

# Default matplotlib font sizes (points) applied to every figure of this notebook.
SMALL_FONTSIZE  = 12
MEDIUM_FONTSIZE = 14
LARGE_FONTSIZE  = 16
for _rc_group, _rc_params in (
        ('font',   dict(size=SMALL_FONTSIZE)),        # default text size
        ('axes',   dict(titlesize=SMALL_FONTSIZE)),   # axes titles
        ('axes',   dict(labelsize=MEDIUM_FONTSIZE)),  # x and y axis labels
        ('xtick',  dict(labelsize=SMALL_FONTSIZE)),   # x tick labels
        ('ytick',  dict(labelsize=SMALL_FONTSIZE)),   # y tick labels
        ('legend', dict(fontsize=SMALL_FONTSIZE)),    # legend text
        ('figure', dict(titlesize=LARGE_FONTSIZE)),   # figure title
):
    plt.rc(_rc_group, **_rc_params)
del _rc_group, _rc_params


# Figure geometry. FIGSIZE_1PANEL is presumably a matplotlib figsize (inches);
# axes positions are figure fractions laid out as [xmin, ymin, width, height].
FIGSIZE_1PANEL = numpy.array((4,3),dtype=float)
AX_POS_1PANEL  = numpy.array( [ 0.125, 0.125, 0.9, 0.88 ] ) # [xmin,ymin,width,height]
AX_POS_2PANELS = numpy.array([   AX_POS_1PANEL/2.0,
                               [ 1.5*AX_POS_1PANEL[0]+AX_POS_1PANEL[2]/2,AX_POS_1PANEL[1]/2,AX_POS_1PANEL[2]/2-0.5*AX_POS_1PANEL[0],AX_POS_1PANEL[3]/2 ] ])


# Linear-regression helpers: evaluate a fitted model `lr`, i.e. any object exposing
# `.intercept` and `.slope` (for instance the result of scipy.stats.linregress).
def linreg_lin_func(x, lr):
    """Return the linear model ``lr.intercept + lr.slope * x``."""
    return lr.intercept + lr.slope * x


def linreg_pl_func(x, lr):
    """Return the power-law model ``exp(lr.intercept + log(x) * lr.slope)``.

    Presumably `lr` was fitted on (log x, log y); `x` must be positive.
    """
    return numpy.exp(lr.intercept + numpy.log(x)*lr.slope)


def linreg_exp_func(x, lr):
    """Return the exponential model ``exp(lr.intercept + x * lr.slope)``.

    Presumably `lr` was fitted on (x, log y).
    """
    return numpy.exp(lr.intercept + x*lr.slope)


# t-test aux function
p_significant = 0.05
q_FDR         = 0.05 # FDR-level for thresholding p_values correcting for false-discovery rate in multiple comparisons


def check_significance(ttest_res):
    """Return True when ``ttest_res.pvalue`` is below ``p_significant``.

    `ttest_res` is any object with a ``pvalue`` attribute (e.g. a scipy.stats
    t-test result). The module-level ``p_significant`` is read at call time, so
    reassigning it later changes the threshold.
    """
    return ttest_res.pvalue < p_significant

# Load the relative-target (rt) and fixed-target (ft) trajectories, then cut each
# track after the mouse reaches the food.
# NOTE(review): an earlier comment here said "MAT files"; the file format is handled
# inside io.load_trial_file -- confirm.

time_delay_after_food  = 3.0 # sec; passed as time_delay_after_food to tran.remove_path_after_food below

# all_trials[k][m] -> mouse m in trial k
# Training sessions grouped by trial. The ft set is capped at the number of rt trial
# groups (max_trial_number=len(all_trials_rt)); all_trials_ft_full is loaded with
# skip_15_relative_target=False and use_extra_trials_relative_target=True, up to trial 18.
all_trials_rt,trial_labels_rt = io.load_trial_file(mouse_traj_dir_rt,load_only_training_sessions_relative_target=True ,skip_15_relative_target=True ,use_extra_trials_relative_target=False,sort_by_trial=True,fix_nan=True,remove_after_food=False,align_to_top=True,group_by='trial',return_group_by_keys=True)
all_trials_ft,trial_labels_ft = io.load_trial_file(mouse_traj_dir_ft,load_only_training_sessions_relative_target=True ,skip_15_relative_target=True ,use_extra_trials_relative_target=False,sort_by_trial=True,fix_nan=True,remove_after_food=False,align_to_top=True,group_by='trial',return_group_by_keys=True,max_trial_number=len(all_trials_rt))
all_trials_ft_full,_          = io.load_trial_file(mouse_traj_dir_ft,load_only_training_sessions_relative_target=True ,skip_15_relative_target=False,use_extra_trials_relative_target=True ,sort_by_trial=True,fix_nan=True,remove_after_food=False,align_to_top=True,group_by='trial',return_group_by_keys=True,max_trial_number=18)
# Probe sessions (files matching 'mpos_*Probe_*'); the *_p1_complete lists keep the full tracks.
all_trials_rt_p1_complete     = io.load_trial_file(mouse_traj_dir_rt,file_name_expr='mpos_*Probe_*',load_only_training_sessions_relative_target=False,skip_15_relative_target=False,use_extra_trials_relative_target=True,remove_after_food=False,sort_by_trial=True,fix_nan=True,align_to_top=True)
all_trials_ft_p1_complete     = io.load_trial_file(mouse_traj_dir_ft,file_name_expr='mpos_*Probe_*',load_only_training_sessions_relative_target=False,skip_15_relative_target=False,use_extra_trials_relative_target=True,remove_after_food=False,sort_by_trial=True,fix_nan=True,align_to_top=True)
# Trim after the food (main target forced). The probe tracks use copy_tracks=True,
# presumably so that the *_p1_complete lists stay untrimmed -- confirm in tran.
all_trials_rt                 = tran.remove_path_after_food(all_trials_rt,r_target=None,return_t_to_food=False,force_main_target=True,hole_horizon=None,time_delay_after_food=time_delay_after_food)
all_trials_ft                 = tran.remove_path_after_food(all_trials_ft,r_target=None,return_t_to_food=False,force_main_target=True,hole_horizon=None,time_delay_after_food=time_delay_after_food)
all_trials_ft_full            = tran.remove_path_after_food(all_trials_ft_full,r_target=None,return_t_to_food=False,force_main_target=True,hole_horizon=None,time_delay_after_food=time_delay_after_food)
all_trials_rt_p1              = tran.remove_path_after_food(all_trials_rt_p1_complete,r_target=None,return_t_to_food=False,force_main_target=True,hole_horizon=None,copy_tracks=True,time_delay_after_food=time_delay_after_food)
all_trials_ft_p1              = tran.remove_path_after_food(all_trials_ft_p1_complete,r_target=None,return_t_to_food=False,force_main_target=True,hole_horizon=None,copy_tracks=True,time_delay_after_food=time_delay_after_food)

# R180 probe (files matching 'mpos_*_R180_1*'): a 5 cm hole horizon is used only for
# mice 14 and 16 (None for every other mouse). Each track is cut at its reverse target
# with no extra delay, regrouped by trial (only the first group, [0], is kept) and
# rotated with the alignment vector (0,-1).
hole_horizon_rt_R180          = 5.0 # cm
get_hole_horiz_rt_R180        = lambda tr: hole_horizon_rt_R180 if int(tr.mouse_number) in [14,16] else None
all_trials_rt_R180            = io.load_trial_file(mouse_traj_dir_rt,file_name_expr='mpos_*_R180_1*',load_only_training_sessions_relative_target=False,skip_15_relative_target=True,use_extra_trials_relative_target=True,sort_by_trial=True,fix_nan=True,remove_after_food=False,group_by='none')#,max_trial_number=n_trials_to_use)
all_trials_rt_R180            = io.group_track_list([ tran.remove_path_after_food(tr,r_target=tr.r_target_reverse,return_t_to_food=False,force_main_target=False,hole_horizon=get_hole_horiz_rt_R180(tr),time_delay_after_food=0.0) for tr in all_trials_rt_R180 ],group_by='trial',return_group_keys=False)[0]
all_trials_rt_R180            = plib.rotate_trial_file(all_trials_rt_R180,(0,-1),return_only_track=True)

# Set to True to write the figures produced below into output_dir.
save_output_figures = False

In [26]:
# Training trials of the 90-degree relative-target experiment (at most 21 trials),
# grouped by trial; the loader itself removes the path after the food here.
all_trials_rt_R90 = io.load_trial_file(
    mouse_traj_dir_rt90,
    file_name_expr='mpos_*',
    load_only_training_sessions_relative_target=True,
    skip_15_relative_target=False,
    use_extra_trials_relative_target=False,
    sort_by_trial=True,
    fix_nan=True,
    remove_after_food=True,
    group_by='trial',
    max_trial_number=21,
)
# Rotate the tracks with the alignment vector (0, -1), as done for the other datasets.
all_trials_rt_R90 = plib.rotate_trial_file(all_trials_rt_R90, (0, -1), return_only_track=True)


In [31]:
# R90 probe trials: loaded untrimmed, then every track is cut at its reverse target
# with no extra delay, regrouped by trial (first group kept) and rotated.
all_trials_rt_R90_probe = io.load_trial_file(
    mouse_traj_dir_rt90, file_name_expr='mpos_*R90*',
    load_only_training_sessions_relative_target=False,
    skip_15_relative_target=False, use_extra_trials_relative_target=False,
    sort_by_trial=True, fix_nan=True, remove_after_food=False, group_by='none',
)
# NOTE(review): this reuses the R180 hole-horizon rule (5 cm only for mice 14 and 16);
# confirm it is intended for the R90 cohort.
all_trials_rt_R90_probe = io.group_track_list(
    [
        tran.remove_path_after_food(
            tr,
            r_target=tr.r_target_reverse,
            return_t_to_food=False,
            force_main_target=False,
            hole_horizon=get_hole_horiz_rt_R180(tr),
            time_delay_after_food=0.0,
        )
        for tr in all_trials_rt_R90_probe
    ],
    group_by='trial',
    return_group_keys=False,
)[0]
all_trials_rt_R90_probe = plib.rotate_trial_file(all_trials_rt_R90_probe, (0, -1), return_only_track=True)


In [32]:
# Notebook display: inspect the processed R90 probe track files.
all_trials_rt_R90_probe

[trackfile(file_name='Raw data-Hidden Food Maze-04Nov2022-Trial   103.xlsx', day='13', exper_date='04Nov2022', is_reverse=1, keep_between_targets=False, mouse_gender='M', mouse_number='2', remove_after_food=False, start_location='SW', start_quadrant=2, trial='R90', arena_picture='BKGDimage-20221104_cropped.png', arena_picture_extent=[-105.53609805924413, 92.53703779366703, -72.46542805100178, 76.50943533697632]),
 trackfile(file_name='Raw data-Hidden Food Maze-04Nov2022-Trial   105.xlsx', day='13', exper_date='04Nov2022', is_reverse=1, keep_between_targets=False, mouse_gender='M', mouse_number='4', remove_after_food=False, start_location='NE', start_quadrant=4, trial='R90', arena_picture='BKGDimage-20221104_cropped.png', arena_picture_extent=[-105.53609805924413, 92.53703779366703, -72.46542805100178, 76.50943533697632])]

In [25]:
# Notebook display: inspect the track files of the last loaded R90 training trial group.
all_trials_rt_R90[-1]

[trackfile(file_name='Raw data-Hidden Food Maze-04Nov2022-Trial    98.xlsx', day='12', exper_date='04Nov2022', is_reverse=0, keep_between_targets=False, mouse_gender='M', mouse_number='1', remove_after_food=False, start_location='SW', start_quadrant=2, trial='21', arena_picture='BKGDimage-20221104_cropped.png', arena_picture_extent=[-105.53609805924413, 92.53703779366703, -72.46542805100178, 76.50943533697632]),
 trackfile(file_name='Raw data-Hidden Food Maze-04Nov2022-Trial    99.xlsx', day='12', exper_date='04Nov2022', is_reverse=0, keep_between_targets=False, mouse_gender='M', mouse_number='2', remove_after_food=False, start_location='SE', start_quadrant=3, trial='21', arena_picture='BKGDimage-20221104_cropped.png', arena_picture_extent=[-105.53609805924413, 92.53703779366703, -72.46542805100178, 76.50943533697632]),
 trackfile(file_name='Raw data-Hidden Food Maze-04Nov2022-Trial   100.xlsx', day='12', exper_date='04Nov2022', is_reverse=0, keep_between_targets=False, mouse_gender='M

# Rotated trajectories examples

In [None]:
# To pick up edits to the project modules without restarting the kernel, reload them,
# e.g. `import importlib; importlib.reload(tran); importlib.reload(pltt)`
# (previously done by deleting their sys.modules entries and re-importing).

# --- Probe 2 of the two-target rotation experiment: which part of each path to keep
use_reverse_targets  = True
keep_between_targets = False
stop_at_food         = True
filename_expr        = 'mpos_*Probe2_*'
mouse_traj_dir       = r'./experiments/two_targets_rot/mouse_*'

hole_horizon          = 10.0 # cm
time_delay_after_food = 1.0 # sec (rebinds the 3.0 s value set in the data-loading cell)

input_tracks = io.load_trial_file(
    mouse_traj_dir, file_name_expr=filename_expr,
    align_to_top=True, fix_nan=True, sort_by_trial=True,
    return_group_by_keys=False, remove_after_food=False,
)
# keep_between_targets takes precedence over stop_at_food;
# with both flags False the tracks are left untrimmed.
if stop_at_food and not keep_between_targets:
    input_tracks = tran.remove_path_after_food(
        input_tracks, force_main_target=False, return_t_to_food=False,
        hole_horizon=hole_horizon, time_delay_after_food=time_delay_after_food,
        copy_tracks=True, use_reverse_targets=use_reverse_targets,
    )
elif keep_between_targets:
    input_tracks = tran.keep_path_between_targets(
        input_tracks, return_t_in_targets=False,
        hole_horizon=hole_horizon, time_delay_after_food=time_delay_after_food,
        copy_tracks=True, use_reverse_targets=use_reverse_targets,
    )

# --- Plot parameters
# NOTE(review): trim_trajectories is not read by any visible code (its consumer below
# is commented out, and the plot call passes trim_trajectories=False literally).
trim_trajectories  = True
traj_alpha         = 0.5

start_vec_align       = (0,-1)

# --- Parameters of tran.find_slowing_down_close_to_hole
ignore_entrance_positions   = False
normalize_by                = 'max'
hole_horizon_hole_check     = 3.0 #cm
threshold_method            = 'ampv'
velocity_amplitude_fraction = 0.2
use_velocity_minima         = True
velocity_min_prominence     = 5.0 # cm/s

# Disabled trimming of the tracks to plot. NOTE(review): it references
# all_trials_p2_complete / all_trials_p2, which are not defined in the visible code.
#if trim_trajectories:
#    tracks_to_plot = []
#    for tr,tr_trim in zip(all_trials_p2_complete,all_trials_p2):
#        t0 = 0.0 # seconds
#        if (int(tr.mouse_number) == 35):
#            t0 = 35.0 # the mouse doesn't move in the first 35 seconds
#        if (int(tr.mouse_number) == 34):
#            t0 = 20.0 # the mouse doesn't move in the first 20 seconds
#        tr_new = tran.slice_track_by_time(tr,t0=t0,t1=tr_trim.time[-1]+1,copy_track=True)
#        tr_new.time -= tr_new.time[0]
#        tracks_to_plot.append(tr_new)
#else:
#    tracks_to_plot = copy.deepcopy(all_trials_p2_complete)

# Flattened deep copy of the loaded tracks, laid out on a 1 x 4 panel grid.
tracks_to_plot    = list(misc.flatten_list(copy.deepcopy(input_tracks), only_lists=True))
nrows_ncols_tuple = (1, 4)
first_trial       = 0

all_trials = plib.rotate_trial_file(tracks_to_plot, start_vec_align, True)

# One (k, t, r, v_th) tuple per track, unpacked into four parallel sequences.
k_slow, t_slow, r_slow, v_th = misc.unpack_list_of_tuples([
    tran.find_slowing_down_close_to_hole(
        tr,
        hole_horizon_hole_check,
        threshold_method=threshold_method,
        gamma=velocity_amplitude_fraction,
        return_pos_from='hole',
        ignore_entrance_positions=ignore_entrance_positions,
        use_velocity_minima=use_velocity_minima,
        velocity_min_prominence=velocity_min_prominence,
    )
    for tr in all_trials
])


# Figure size (30 x 22 scaled by 0.55; presumably inches) and RGBA colours in [0, 1]
# for the two targets. NOTE(review): color_orange is not used in the visible code.
fig_size           = tuple(numpy.array((30,22))*0.55)
color_red          = numpy.array((255, 66, 66,255))/255
color_blue         = numpy.array(( 10, 30,211,255))/255
color_orange       = pltt.get_gradient(10,'orange')[0]
# Draw every rotated track in its own panel, coloured along time with the 'cool'
# colormap. Targets are labelled 'REL A'/'REL B' (triangles) and 'A'/'B' (circles).
# Returns the axes grid, the drawn artists and the panel indices (return_panel_ind=True).
ax,lines,panel_ind = pltt.plot_all_tracks_2targets(all_trials,hole_horizon,time_delay_after_food,
                                       traj1Args=dict(color                  = plt.get_cmap('cool'),
                                                      line_gradient_variable = 'time',
                                                      alpha                  = traj_alpha,
                                                      show_colorbar    = False,
                                                      show_target=True,show_reverse_target=True,show_alt_target=True,show_reverse_alt_target=True,
                                                      startArgs        = dict(marker='s',markeredgewidth=3  ,markersize=10,color='k'       ,fillstyle='full',markerfacecolor='w',label='Start',             labelArgs=dict(fontsize=12,va='top'   ,ha='left' ,color='k',pad=(4,-1)  ) ),
                                                      targetArgs       = dict(marker='^',markeredgewidth=2  ,markersize=10,color=color_blue,fillstyle='none',label='REL B',             labelArgs=dict(fontsize=16,va='bottom',ha='left' , fontweight='bold',color=color_blue,pad=( 2, 2))),
                                                      targetAltArgs    = dict(marker='^',markeredgewidth=2  ,markersize=10,color=color_red ,fillstyle='none',label='REL A',             labelArgs=dict(fontsize=16,va='bottom',ha='left' , fontweight='bold',color=color_red ,pad=( 2, 2))),
                                                      targetRevArgs    = dict(marker='o',markeredgewidth=2  ,markersize= 9,color=color_blue,fillstyle='none',label='B'    ,             labelArgs=dict(fontsize=16,va='bottom',ha='left' , fontweight='bold',color=color_blue,pad=( 2, 2))),
                                                      targetAltRevArgs = dict(marker='o',markeredgewidth=2  ,markersize= 9,color=color_red ,fillstyle='none',label='A'    ,zorder=10000,labelArgs=dict(fontsize=16,va='bottom',ha='left' , fontweight='bold',color=color_red ,pad=( 2, 2)))),
                                       traj2Args=dict(color=(0,0,0),alpha=traj_alpha),start_align_vector=start_vec_align,
                                       trim_trajectories=False,return_panel_ind=True,show_arena_holes=True,
                                       probe_title='',fig_size=fig_size,nrows_ncols_tuple=nrows_ncols_tuple,show_probetitle=False,fix_mouse_order_in_probe=False)

# Mark, in every panel, the start position of the track (tr.r_start) labelled
# 'Start on training'; the y coordinate is negated before plotting
# (presumably to match the plotting orientation -- confirm).
for a,tr in zip(ax.flatten(),all_trials):
    r    = tr.r_start.copy()
    r[1] = -r[1]
    pltt.plot_point(r,'Start on training',ax=a,fmt='d',color=0.6*numpy.ones(3),pointArgs=dict(markeredgewidth=3,markersize=10,markerfacecolor='w',fillstyle='full'),pad=(5,4),ha='left',fontsize=12,labelcolor=0.5*numpy.ones(3))
# All text artists except the 'Start' labels; only used by the disabled manual
# label adjustments below.
text_h = [tt for tt in misc.flatten_list(lines) if (('text' in str(type(tt))) and (tt.get_text() != 'Start'))]
#text_h[11].set(va='top')   # A,i=0,j=3
#text_h[3].set(ha='center') # A,i=1,j=1
#text_h[-3].set(va='top')   # A,i=1,j=2
#text_h[-1].set(ha='right') # A,i=1,j=3
#text_h[6].set(va='top')    # B,i=0,j=1
#text_h[4].set(va='top',ha='center',position=numpy.array(text_h[4].get_position())-numpy.array((2,6)))    # B,i=0,j=0


# Pack the arena panels tightly (small vertical gap between panel rows).
ax = pltt.tight_arena_panels(ax,adjust_title_position=False,dy_amid_panels=0.02)

# Overlay, on every panel, the positions where the mouse slowed down near a hole
# (k_slow/r_slow from tran.find_slowing_down_close_to_hole). Marker size and edge
# colour encode the fraction of the trial elapsed. The first column also gets the
# row label and a separator line; the last column gets a time-course colorbar.
cax = []
for k,(ind,r,tr) in enumerate(zip(k_slow,r_slow,tracks_to_plot)):
    # time of each slow-down event as a fraction of the last time sample
    # (assumes tr.time starts at 0 -- TODO confirm)
    t_seq = tr.time[ind]/tr.time[-1]
    i,j = numpy.unravel_index(k,nrows_ncols_tuple)
    pts=pltt.plot_trajectory_points(r,ax=ax[i,j],use_scatter=True,s=1e2*t_seq,marker='o',linewidth=1,edgecolor=plt.get_cmap('cool')(t_seq),zorder=9000,alpha=0.8)
    pts.set_facecolor('none')
    if j == 0:
        ax[i,j].text(-75,-50,'Probe B-A 180$^\\circ$',rotation=90,va='bottom',ha='center',fontsize=20,fontweight='bold')
        # Save both axis ranges so they can be restored after drawing the separator,
        # which extends beyond the panel (clip disabled) and could rescale the axes.
        # Bug fix: y_lim previously read get_xlim(), which overwrote the y-range
        # with the x-range when restored.
        x_lim = ax[i,j].get_xlim()
        y_lim = ax[i,j].get_ylim()
        lh = pltt.plot_horizontal_lines(-70,ax[i,j],xMin=-55,xMax=120*8,color='k',linewidth=1.5,linestyle='-')
        lh.set_clip_on(False)
        ax[i,j].set_xlim(x_lim)
        ax[i,j].set_ylim(y_lim)
    if j == (nrows_ncols_tuple[1]-1):
        cax.append(pltt._add_colorbar(ax[i,j],'cool',w_fraction_of_ax=0.03,h_fraction_of_ax=0.2,p0=(1.01,1.1),minmax_tick_labels=('start','end'),title='Time\ncourse',titleArgs=dict(color='k',fontsize=12),ticklabelsArgs=dict(color='k')))


# Optionally write the figure to output_dir (the file name records which slow-down
# detection method was used), then display it.
if save_output_figures:
    if use_velocity_minima:
        detection_label = 'vmin'
    else:
        detection_label = 'slowdown'
    fileName = f'{output_dir}/twotargets_ROTPr2_trajectories_ALL_TRIALS_PROBE_A-B_{detection_label}.png'
    print(' ... saving ', fileName)
    plt.savefig(fileName, format='png', dpi=300, facecolor=(1,1,1,1), bbox_inches='tight')

plt.show()