### Event Distribution of Mice over Time

In [None]:
import sys
sys.path.insert(0, '../scripts')

import numpy as np
import direction_transition as dit
from meals import find_meals_paper
from accuracy import graph_single_stats
from preprocessing import read_excel_by_sheet
from path import *
from intervals import mean_pellet_collect_time, plot_retrieval_time_by_block, perform_T_test


import os

# Output locations: Figure 4 panels, plus per-mouse transition plots for the supplement.
export_root = '../export/Figure 4'
spl_root = '../export/Supplementary/WT_Transitions'
os.makedirs(export_root, exist_ok=True)
os.makedirs(spl_root, exist_ok=True)

# Meal-detection settings forwarded to find_meals_paper and dit.get_transition_info.
# NOTE(review): the time unit is not visible here (presumably seconds) — confirm.
time_threshold = 60
pellet_count_threshold = 2

# Quick sanity check of cohort sizes (sheet lists presumably come from `from path import *`).
len(rev_male_sheets), len(rev_female_sheets)

In [None]:
# Per-mouse first-meal statistics for the reversal female cohort (Figure 4).
# Each list below receives exactly one entry per sheet in rev_female_sheets.
female_block_fir_meal = []       # avg_good_time returned by dit.first_meal_stats
female_block_fir_meal_prop = []  # avg_ratio returned by dit.first_meal_stats
female_meal_avg_acc = []         # mean of the meal accuracies returned by find_meals_paper
female_n_blocks = []             # number of blocks produced by dit.split_data_to_blocks
temp = []                        # fraction of blocks with First_Good_Meal_Time < Block_Time

for sheet in rev_female_sheets:
    temp_acc = []  # meal accuracies pooled across every block of this mouse
    meals = []     # all meals of this mouse (kept for interactive inspection)
    data = read_excel_by_sheet(sheet, rev_female_path, cumulative_accuracy=False)
    blocks = dit.split_data_to_blocks(data)

    for block in blocks:
        meal, meal_acc = find_meals_paper(block,
                                          time_threshold=time_threshold,
                                          pellet_threshold=pellet_count_threshold)
        meals.extend(meal)
        temp_acc.extend(meal_acc)

    data_stats = dit.get_transition_info(blocks, [time_threshold, pellet_count_threshold], reverse=False)
    first_meal_blocks = data_stats['First_Good_Meal_Time'].tolist()
    block_time = data_stats['Block_Time'].tolist()

    # Count blocks whose first good meal time is below the block time.
    # NaN entries (presumably blocks without a good meal) compare False and are not counted.
    each = sum(first < total for first, total in zip(first_meal_blocks, block_time))
    # A sheet with no blocks yields NaN instead of raising ZeroDivisionError.
    temp.append(each / len(blocks) if len(blocks) else np.nan)
    female_n_blocks.append(len(blocks))

    avg_ratio, avg_time, avg_good_time = dit.first_meal_stats(data_stats, ignore_inactive=True)
    female_block_fir_meal.append(avg_good_time)
    # np.mean([]) emits a RuntimeWarning and returns nan; produce nan explicitly instead.
    female_meal_avg_acc.append(np.mean(temp_acc) if temp_acc else np.nan)
    female_block_fir_meal_prop.append(avg_ratio)

    # Per-mouse transition plots for the supplementary figure; uncomment to regenerate.
    # dit.graph_tranition_stats(data_stats, blocks, sheet, export_path=os.path.join(spl_root, f'female_{sheet}_transitions.svg'))

In [None]:
# Figure 4: number of blocks per mouse.
graph_single_stats(
    female_n_blocks,
    stats_name='Number of Blocks',
    unit='',
    group_name='Female',
    violin_width=0.3,
    export_path=os.path.join(export_root, 'WT_number_of_blocks.svg'),
)

In [None]:
# Figure 4: per-mouse mean time of the first good meal.
graph_single_stats(
    female_block_fir_meal,
    stats_name='Mean 1st Meal Time',
    unit='minutes',
    group_name='Female',
    violin_width=0.3,
    export_path=os.path.join(export_root, 'WT_1st_meal_absolute_time.svg'),
)

In [None]:
# Figure 4: first-meal time normalised by block time (avg_ratio from dit.first_meal_stats).
graph_single_stats(
    female_block_fir_meal_prop,
    stats_name='1st Meal and Block Time Ratio',
    unit='%',
    group_name='Female',
    violin_width=0.3,
    export_path=os.path.join(export_root, 'WT_1st_meal_normalized.svg'),
)

In [None]:
# Figure 4: per-mouse average meal accuracy.
graph_single_stats(
    female_meal_avg_acc,
    stats_name='Meal Accuracy',
    unit='%',
    group_name='Female',
    violin_width=0.3,
    export_path=os.path.join(export_root, 'WT_meal_accuracy.svg'),
)

### Learning Score

In [None]:
# Settings for the learning-score analysis (Figure 3).
export_root = '../export/Figure 3'
os.makedirs(export_root, exist_ok=True)

action_prop = 0.75     # forwarded to dit.learning_result / dit.learning_score
block_prop = 0.6       # NOTE(review): unused — the scoring cell passes block_prop=1; confirm intent
meal_config = [60, 2]  # mirrors [time_threshold, pellet_count_threshold] used for Figure 4
day = 3                # forwarded as `day` to split_data_to_blocks and the retrieval-time helpers

# Per-mouse accumulators filled by the next cell.
female_scores, female_learning_result, female_blocks_list = [], [], []

In [None]:
# Build per-mouse learning metrics for the reversal female cohort.
# Accumulators are reset here so that re-running this cell does not append
# a second copy of every mouse (they were previously only initialised in the
# settings cell).
female_scores = []
female_learning_result = []
female_blocks_list = []

for sheet in rev_female_sheets:
    data = read_excel_by_sheet(sheet, rev_female_path, cumulative_accuracy=False)
    blocks = dit.split_data_to_blocks(data, day=day)
    # NOTE(review): data_stats is not used below; kept in case get_transition_info
    # mutates `blocks` — confirm and drop if it does not.
    data_stats = dit.get_transition_info(blocks, meal_config=meal_config, reverse=False)
    female_blocks_list.append(blocks)
    female_learning_result.append(dit.learning_result(blocks, action_prop=action_prop))
    # NOTE(review): block_prop is hard-coded to 1 although the settings cell defines
    # block_prop = 0.6; confirm which value is intended.
    female_scores.append(dit.learning_score(blocks, block_prop=1, action_prop=action_prop))

In [None]:
# Figure 3: overall learning-score trend for the wild-type cohort.
dit.plot_learning_score_trend(
    [female_blocks_list],
    ['Wild Type'],
    export_path="../export/Figure 3/WT_learning_score_overall.svg",
)

In [None]:
# Figure 3: overall pellet-in-meal ratio trend for the wild-type cohort.
dit.plot_pellet_ratio_trend(
    [female_blocks_list],
    ['Wild Type'],
    export_path="../export/Figure 3/WT_pellet_in_meal_overall.svg",
)

In [None]:
# Figure 3: per-mouse learning scores at the chosen action proportion.
# round() instead of int(): int(0.57 * 100) == 56 because of binary float error,
# which would put the wrong percentage in the exported file name.
dit.graph_learning_score_single(
    female_scores,
    group_name='Female',
    proportion=action_prop,
    export_path=os.path.join(export_root, f'WT_{round(action_prop * 100)}_learning_score.svg'),
)

In [None]:
# Figure 3: per-mouse learning results.
# NOTE(review): proportion=0.25 differs from action_prop (0.75) used to compute
# female_learning_result — possibly 1 - action_prop on purpose; confirm.
dit.graph_learning_results_single(
    female_learning_result,
    group_name='Female',
    proportion=0.25,
    export_path=os.path.join(export_root, 'WT_learning_result.svg'),
)

### Pellet Retrieval Analysis

In [None]:
# Supplementary output; per-mouse retrieval-time plots go into a WT_retrieval_time subfolder.
export_root = '../export/Supplementary'
os.makedirs(os.path.join(export_root, 'WT_retrieval_time'), exist_ok=True)

# Accumulators filled by the next cell (one entry per reversal female sheet).
female_all_times, female_mean, female_pred, female_slope = [], [], [], []
rev_female_time_dict = {}

In [None]:
# Pellet-retrieval statistics for each reversal female mouse; also writes one
# retrieval-time-by-block plot per mouse. Accumulators are reset here so that
# re-running the cell does not append duplicate entries.
female_all_times = []
female_mean = []
female_pred = []
female_slope = []
rev_female_time_dict = {}

for sheet in rev_female_sheets:
    times, mean, std = mean_pellet_collect_time(rev_female_path, sheet, remove_outlier=True, n_stds=3, day=day)
    _, pred, slope = plot_retrieval_time_by_block(
        rev_female_path, sheet, day=day, n_stds=3,
        export_path=os.path.join(export_root, 'WT_retrieval_time', f'female_{sheet}.svg'),
    )
    female_all_times.append(times)
    female_mean.append(mean)
    female_pred.append(pred)
    female_slope.append(slope)
    rev_female_time_dict[sheet] = mean

In [None]:
# Supplementary: per-mouse mean pellet-retrieval time (reversal cohort).
graph_single_stats(
    female_mean,
    'Pellet Retrieval Time',
    group_name='Female',
    unit='minutes',
    violin_width=0.3,
    export_path=os.path.join(export_root, 'wt_retrieval_mean.svg'),
)

In [None]:
# Drop the single largest predicted retrieval time (first occurrence), presumably an outlier.
# NOTE(review): mutates female_pred in place, so re-running this cell removes another value.
del female_pred[female_pred.index(max(female_pred))]

In [None]:
# Supplementary: per-mouse predicted pellet-retrieval time from the best-fit line.
graph_single_stats(
    female_pred,
    'Predicted Pellet Retrieval Time',
    group_name='Female',
    unit='minutes',
    violin_width=0.3,
    export_path=os.path.join(export_root, 'wt_retrieval_pred.svg'),
)

In [None]:
# Supplementary: per-mouse slope of the retrieval-time best-fit line.
graph_single_stats(
    female_slope,
    'Retrieval Time Best-fit Line Slope',
    group_name='Female',
    unit='',
    violin_width=0.3,
    export_path=os.path.join(export_root, 'wt_retrieval_slope.svg'),
)

In [None]:
# FR1 female pellet-retrieval times, plotted for Figure 2 further below.
# NOTE(review): this rebinds female_all_times (discarding the reversal-cohort times
# collected above) and uses n_stds=2 where the reversal analysis uses 3 — confirm both.
female_all_times, female_means, female_stds = [], [], []
fr1_female_dict = {}

for sheet in fr1_female_sheets:
    times, mean, std = mean_pellet_collect_time(fr1_female_path, sheet, remove_outlier=True, n_stds=2, day=3)
    fr1_female_dict[sheet] = mean
    female_all_times.append(times)
    female_means.append(mean)
    female_stds.append(std)

In [None]:
# Drop the two largest FR1 mean retrieval times (first occurrence each time).
# NOTE(review): mutates female_means in place; re-running this cell removes two more values.
del female_means[female_means.index(max(female_means))]
del female_means[female_means.index(max(female_means))]

In [None]:
# Figure 2: per-mouse mean FR1 pellet-retrieval time.
graph_single_stats(
    female_means,
    'Pellet Retrieval Time',
    group_name='Female',
    unit='minutes',
    violin_width=0.3,
    export_path=os.path.join('../export/Figure 2/', 'wt_fr1_retrieval_mean.svg'),
)