In [None]:
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os.path as path
import itertools

In [None]:
# Read the three summary tables (.csv files) of one analysis run into dataframes.
analysis_path = "D:\\Users\\mickey\\Data\\analysis\\prediction_p10_1s_20210106163829"
metric_stat_df, metric_stat_all_df, metric_stat_bin_df = (
    pd.read_csv(path.join(analysis_path, csv_name))
    for csv_name in (
        'prediction_metric_stats.csv',
        'prediction_metric_all_stats.csv',
        'prediction_metric_bin_stats.csv',
    )
)

In [None]:
def bootstrap_est(data,n_boot,f,rng=None):
    """Bootstrap the sampling distribution of a statistic.

    Entries along the first axis of ``data`` are resampled with replacement
    ``n_boot`` times (each resample has as many entries as ``data``); ``f`` is
    evaluated on every resample and the results are stacked along a new
    leading axis.

    Parameters
    ----------
    data : array-like
        Observations, indexed along axis 0. Converted with ``np.asanyarray``,
        so lists and other array-likes are accepted.
    n_boot : int
        Number of bootstrap resamples; must be at least 1.
    f : callable
        Statistic to evaluate, called as ``f(resample)``.
    rng : None, int or numpy.random.Generator, optional
        Source of randomness. ``None`` (default) draws from the global
        ``np.random`` state, exactly as before, so ``np.random.seed`` still
        controls the result. Anything else is passed to
        ``np.random.default_rng`` to make a call reproducible on its own.

    Returns
    -------
    numpy.ndarray
        The ``n_boot`` values of ``f`` stacked along axis 0.

    Raises
    ------
    ValueError
        If ``n_boot`` is smaller than 1.
    """
    if n_boot < 1:
        # np.stack would fail on an empty list anyway; fail early with a clear message.
        raise ValueError("n_boot must be at least 1")
    data = np.asanyarray(data)
    n_sample = data.shape[0]
    sampler = np.random if rng is None else np.random.default_rng(rng)
    est = []
    for _ in range(n_boot):
        # An integer population is sampled as if it were np.arange(n_sample).
        idx = sampler.choice(n_sample,size=n_sample,replace=True)
        est.append(f(data[idx]))
    return np.stack(est,axis=0)

In [None]:
# Every corr_mean entry is stored as text: brackets and newlines are removed,
# the remainder is split on single spaces and the non-empty tokens become floats.
corr_lists = [
    [
        float(token)
        for token in entry.replace('[','').replace(']','').replace('\n','').split(' ')
        if token != ''
    ]
    for entry in metric_stat_df.corr_mean.values
]
# Pool the values of all rows into one flat array.
corr_means = np.array([value for row in corr_lists for value in row])
# Disabled alternative (bootstrap of the mean):
# corr_mean_bsd = bootstrap_est(corr_means,100,lambda x: np.nanmean(x))
# NaN-ignoring 2.5th / 97.5th percentiles and mean of the pooled values.
corr_ci = np.nanpercentile(corr_means, [2.5, 97.5])
corr_mean = np.nanmean(corr_means)
print(corr_mean, corr_ci)

In [None]:
# Every rpe_mean entry is stored as text: brackets and newlines are removed,
# the remainder is split on single spaces and the non-empty tokens become floats.
rpe_lists = [
    [
        float(token)
        for token in entry.replace('[','').replace(']','').replace('\n','').split(' ')
        if token != ''
    ]
    for entry in metric_stat_df.rpe_mean.values
]
# Pool the values of all rows, then keep only those strictly below 10.0
# (NaN compares False, so NaN entries are dropped here as well).
rpe_pooled = np.array([value for row in rpe_lists for value in row])
rpe_means = rpe_pooled[rpe_pooled < 10.0]
# Disabled alternative, carried over from the correlation cell (still refers to corr_means):
# corr_mean_bsd = bootstrap_est(corr_means,100,lambda x: np.nanmean(x))
# NaN-ignoring 2.5th / 97.5th percentiles and mean of the retained values.
rpe_ci = np.nanpercentile(rpe_means, [2.5, 97.5])
rpe_mean = np.nanmean(rpe_means)
print(rpe_mean, rpe_ci)

In [None]:
# Build one tick label per row of metric_stat_all_df:
# "<last two characters of the grandparent directory name>.<parent directory name>".
rec_label = [
    f'{path.basename(day_dir)[-2:]}.{path.basename(rec_dir)}'
    for rec_dir, day_dir in (
        (path.dirname(file_path), path.dirname(path.dirname(file_path)))
        for file_path in metric_stat_all_df.file_path
    )
]
print(rec_label)

In [None]:
# Last two characters of the grandparent directory name of every file.
rec_day = list(map(
    lambda file_path: path.basename(path.dirname(path.dirname(file_path)))[-2:],
    metric_stat_all_df.file_path,
))

# One x position per row of metric_stat_all_df.
rec_positions = np.arange(len(metric_stat_all_df))

plt.figure(figsize=(8, 2), dpi=80)
# Shaded band between the 2.5 and 97.5 CI columns, with the mean drawn on top.
plt.fill_between(
    rec_positions,
    metric_stat_all_df['rpe_ci_2.5'],
    metric_stat_all_df['rpe_ci_97.5'],
    alpha=0.2,
    label='CI',
)
plt.plot(metric_stat_all_df.rpe_mean, label='mean')
plt.xticks(ticks=rec_positions, labels=rec_label, rotation=90)
plt.title('RPE, AR model (p = 10)')
plt.xlabel('Recording')
plt.ylabel('RPE')
plt.legend(loc=0)
# Write the figure next to the tables it was computed from.
plt.savefig(path.join(analysis_path, 'rpe_v_files.png'))

In [None]:
# Earlier, hand-picked selection (disabled):
# model_use_idx = np.array([1,2,3,4,9,10,11,12,13,14,19,20,21,22,23,24,26,27,28,29,30,31,32,33,34])
# display(metric_stat_all_df.iloc[model_use_idx])
# Positional indices of the rows whose rpe_mean is strictly below 2.0.
rpe_below_cutoff = (metric_stat_all_df.rpe_mean < 2.0).values
model_use_idx = np.arange(rpe_below_cutoff.size)[rpe_below_cutoff]

In [None]:
# Show the label of the row with the smallest rpe_mean.
rec_label[np.argmin(metric_stat_all_df.rpe_mean.values)]

In [None]:
# File paths of the rows of metric_stat_all_df selected by model_use_idx.
files_used = list(metric_stat_all_df.file_path.iloc[model_use_idx])
# Per-row boolean mask over metric_stat_bin_df: True where the row belongs to a
# selected file. Membership is tested against a set instead of scanning the list
# once per row; the mask itself stays a plain list of bools.
files_used_lookup = set(files_used)
stat_bin_use = [f in files_used_lookup for f in metric_stat_bin_df.file_path]
# Label every bin row through its own file_path rather than repeating each label
# a fixed 10 times, which silently assumed exactly 10 bin rows per file, stored
# in the same order as metric_stat_all_df.
label_by_file = dict(zip(metric_stat_all_df.file_path, rec_label))
file_labels = metric_stat_bin_df.file_path.map(label_by_file)
if file_labels.isna().any():
    # A bin row without a matching summary row would otherwise get a NaN label.
    raise ValueError('metric_stat_bin_df has file_path values that are missing from metric_stat_all_df')
metric_stat_bin_df['file_label'] = file_labels.values

In [None]:
# Line plot of rpe_bin_mean against bin_t, one hue per file_label, using only
# the bin rows flagged in stat_bin_use.
selected_bins = metric_stat_bin_df[stat_bin_use]
g = sns.lineplot(
    data=selected_bins,
    x='bin_t',
    y='rpe_bin_mean',
    hue="file_label",
    palette='coolwarm',
    ci=None,
)
g.set_title("Mean RPE v. Time")
# Save the figure alongside the input tables.
mean_rpe_png = path.join(analysis_path, "mean_rpe_v_time.png")
g.get_figure().savefig(mean_rpe_png)
# g.set_legend(rec_label)