In [None]:
# !jt -t onedork -T
# !jt -r

In [None]:
%config Completer.use_jedi = False # To make auto-complete faster

#Reloads imported files automatically
%load_ext autoreload
%autoreload 2

import sys
# Add the directory two levels up to the import path; presumably the project root holding
# the src/, plotting/ and utils/ packages imported below.
sys.path.append('../../')

In [None]:
from IPython.display import display, HTML
# Inject CSS that widens the notebook container to 88% of the window.
display(HTML("<style>.container { width:88% !important; }</style>"))

In [None]:
import pandas as pd
import numpy as np
import scipy.stats as stats

import os
from collections import namedtuple
import warnings

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import colormaps as mplcmaps

from plotting.matplotlib_param_funcs import set_matplotlib_params,reset_rcParams
# Apply the project's default matplotlib settings (project helper from plotting.matplotlib_param_funcs).
set_matplotlib_params()

In [None]:
import src.compute_variables as CV
import src.bootstrap_errors as bootstrap
import src.montecarlo_errors as monte_carlo
from src.errorconfig import MonteCarloConfig,BootstrapConfig

import plotting.plotting_helpers as PH
import plotting.map_functions as mapf
import plotting.mixed_plots as MP

import utils.error_helpers as error_helpers
import utils.miscellaneous_functions as MF
import utils.coordinates as coordinates
import utils.load_sim as load_sim
import utils.load_data as load_data

In [None]:
# Global font size; set after set_matplotlib_params() so it overrides any font size set there.
plt.rcParams["font.size"] = 20

In [None]:
#CHOOSE
# Position variables used to build the kinematic symbol dictionary.
x_var = "l"
y_var = "b"

# Velocity components for the statistics: columns "v"+vel_x_var and "v"+vel_y_var (e.g. "vr", "vl").
vel_x_var = 'r'
vel_y_var = 'l'

In [None]:
# LaTeX degree symbol. This is a raw string so that "\c" is not parsed as an (invalid) escape sequence,
# which emits a SyntaxWarning on Python >= 3.12. The value is unchanged: ^\circ.
degree_symbol = r"^\circ"

# Symbols for the chosen position/velocity variables (used for axis labels).
symbol_dict = mapf.get_kinematic_symbols_dict(x_variable=x_var,
                                             y_variable=y_var,
                                             vel_x_variable=vel_x_var,
                                             vel_y_variable=vel_y_var)

# Units strings, e.g. units_dict["mean_vx"] is appended to velocity labels further down.
units_dict = mapf.get_kinematic_units_dict(degree_symbol=degree_symbol)

# Position symbols/units; here the degree symbol is wrapped in $...$ for mathtext.
pos_symbols_dict,pos_units_dict = mapf.get_position_symbols_and_units_dict(degree_symbol=r"$%s$"%degree_symbol)

# Human-readable titles for each statistic (anisotropy, correlation, tilt_abs, ...).
titles_dict = mapf.get_kinematic_titles_dict()

In [None]:
# Root directory containing data/ and graphs/. Set the MRES_GENERAL_PATH environment variable
# to run the notebook elsewhere. The default is the original location. os.path.join(..., "")
# guarantees the trailing separator that all the string concatenations below rely on.
general_path = os.path.join(os.environ.get("MRES_GENERAL_PATH", '/Users/luismi/Desktop/MRes_UCLan/'), "")

In [None]:
def get_save_path_spatial_cuts(base_path, spatial_cuts):
    """Return a directory under `base_path` named after `spatial_cuts`, creating it if needed.

    The path fragment comes from MF.combine_multiple_cut_dicts_into_str, called with
    cut_separator="_" and order_separator="/". The returned path ends with "/".
    """
    cuts_fragment = MF.combine_multiple_cut_dicts_into_str(spatial_cuts, cut_separator="_", order_separator="/")
    directory = base_path + cuts_fragment + "/"
    os.makedirs(directory, exist_ok=True)  # no error if it already exists
    return directory

In [None]:
# Statistic name -> function computing it from two velocity components (called as f(vx=..., vy=...)).
all_funcs_dict = dict(
    anisotropy=CV.calculate_anisotropy,
    correlation=CV.calculate_correlation,
    tilt_abs=CV.calculate_tilt,
)

# Load

In [None]:
# Options passed to the simulation loader (and, for zabs/R0/GSR, to the data loader).
zabs = True
# zabs = False

# Passed as R0 to both loaders; presumably the Sun-Galactic centre distance in kpc (TODO confirm).
R0 = 8.1

# Passed as GSR to both loaders; presumably selects Galactic-standard-of-rest velocities (TODO confirm).
GSR = True
# GSR = False

## Sim

In [None]:
# Which simulation to load (also the sub-folder name under data/).
sim_choice = "708main"
# sim_choice = "708mainDiff4"
# sim_choice = "708mainDiff5"

# Parameters encoded in the simulation file name by load_sim.build_filename.
rot_angle = 27
axisymmetric = False
pos_scaling = 1.7

filename = load_sim.build_filename(choice=sim_choice,rot_angle=rot_angle,R0=R0,axisymmetric=axisymmetric,zabs=zabs,pos_factor=pos_scaling,GSR=GSR)

In [None]:
# Folder with the simulation's numpy arrays.
np_path = general_path+f"data/{sim_choice}/numpy_arrays/"

# df0: full simulation sample (cuts are applied later).
df0 = load_sim.load_simulation(path=np_path,filename=filename)

## Data

In [None]:
# Passed as error_bool to the observational data loader.
obs_errors = True
# obs_errors = False

# zabs option intended for the observational data (separate from the simulation's zabs).
data_zabs = True
# data_zabs = False

In [None]:
data_path = general_path+"data/Observational_data/"

# Load the observational sample. Its zabs option is data_zabs, set in the previous cell for this
# purpose. Previously the simulation's zabs was passed here and data_zabs was never used.
data = load_data.load_and_process_data(data_path=data_path, error_bool=obs_errors, zabs=data_zabs, R0=R0, GSR=GSR, drop_unused=False)

# Data uncertainty histograms

In [None]:
# Variable whose uncertainty is histogrammed; its error column is var+"_error".
var = "d"

In [None]:
# Output folder for the uncertainty histograms of `var` (created if missing).
save_path = general_path + "graphs/Observations/Apogee/Uncertainties/" + f"{var}/"
MF.create_dir(save_path)
print(save_path)

In [None]:
# Selection applied to the observational data before histogramming (here a FeH range).
# cuts_dict = {}
cuts_dict = {"FeH":[-1,0.61]}

data_df = MF.apply_cuts_to_df(data, cuts_dict=cuts_dict)

In [None]:
# Whether the next figure is written to save_path.
save_bool = True
# save_bool = False

In [None]:
bins = 100
# log_bool = True
log_bool = False
# plot_range = [0,0.4]
plot_range = None

if True: # error hist
    # Absolute uncertainties of `var`, with their median marked in red.
    err_values = data_df[var+"_error"]
    err_median = err_values.median()
    var_units = mapf.get_units(var)

    fig,ax=plt.subplots()
    ax.hist(err_values,bins=bins,log=log_bool,range=plot_range)
    ax.axvline(err_median,label="Median: %s %s"%(MF.return_int_or_dec(err_median,2),var_units),color="red")
    unit_suffix = "" if var_units == "" else f" [{var_units}]"
    ax.set_xlabel(mapf.get_symbol(var+"_error") + unit_suffix)
    ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
    ax.legend()

if True: # filename, save
    # File name: var, cuts tag (if any), range (if any), bins, log flag, all "_"-separated.
    name_parts = [var]
    if len(cuts_dict) > 0:
        name_parts.append(MF.extract_str_from_cuts_dict(cuts_dict))
    if plot_range is not None:
        name_parts.append(f"{plot_range[0]}range{plot_range[1]}")
    name_parts.append(f"{bins}bins")
    if log_bool:
        name_parts.append("log")
    filename = "_".join(name_parts)

    print(filename)

    if save_bool:
        plt.savefig(save_path+filename+".png", bbox_inches="tight")
        print("Saved in",save_path)
    plt.show()

In [None]:
# Whether the next figure is written to save_path.
save_bool = True
# save_bool = False

In [None]:
bins = 100
# log_bool = True
log_bool = False
# plot_range = [0,0.2]
plot_range = None

if True: # fractional error hist
    # Fractional uncertainty sigma/|x|. Rows with x == 0 give +/-inf, which makes the automatic
    # histogram range non-finite, so non-finite values are dropped before plotting.
    frac_err_values = (data_df[var+"_error"]/np.abs(data_df[var])).replace([np.inf,-np.inf],np.nan).dropna()
    frac_err_median = frac_err_values.median()

    fig,ax=plt.subplots()

    ax.hist(frac_err_values,bins=bins,range=plot_range,log=log_bool)
    # Convert to a percentage *before* formatting. Multiplying an already-rounded fraction by 100
    # leaks float noise into the label (e.g. 100*0.07 = 7.000000000000001).
    ax.axvline(frac_err_median,
               label="Median $%s$"%(f"{100*frac_err_median:.0f}") +r"$~$%",color="red")
    ax.set_xlabel(mapf.get_symbol(var+"_fractionalerror"))
    ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
    ax.legend()

if True: # filename, save
    filename = f"{var}_frac"
    # Only add the cuts tag when there are cuts, matching the absolute-error histogram cell.
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict) if len(cuts_dict) > 0 else ""
    filename += f"_{plot_range[0]}range{plot_range[1]}" if plot_range is not None else ""
    filename += f"_{bins}bins"
    filename += "_log" if log_bool else ""
    
    print(filename)
    
    if save_bool:
        plt.savefig(save_path+filename+".png", bbox_inches="tight")
        print("Saved in",save_path)
    plt.show()

# MC (Monte Carlo) error propagation

In [None]:
def get_save_path_MC(perturbed_vars, data_bool):
    """Return the Monte Carlo output directory, creating each level if needed.

    Layout: general_path + "graphs/other_plots/MonteCarlo/<vars joined by ','>/" + "data/" or "model/".

    Parameters
    ----------
    perturbed_vars : iterable of str
        Names of the perturbed variables; joined with "," to name a sub-folder.
    data_bool : bool
        True selects the "data/" sub-folder, False the "model/" sub-folder.
    """
    mc_root = general_path + "graphs/other_plots/MonteCarlo/" + ",".join(perturbed_vars) + "/"
    MF.create_dir(mc_root)

    full_path = mc_root + ("data/" if data_bool else "model/")
    MF.create_dir(full_path)
    return full_path

In [None]:
# Sample for the MC study: True -> observational `data`, False -> simulation `df0`.
# data_bool = True
data_bool = False

# Variables perturbed in each Monte Carlo realisation.
# perturbed_vars = ["d","vr","pmra","pmdec"]
perturbed_vars = ["d"]

In [None]:
# Selection cuts for the chosen sample; the uncommented line in each branch is the active one.
if data_bool:
    cuts_dict = {"FeH":[-1,-0.21], "l":[-2,2], "R":[0,2]}
#     cuts_dict = {"FeH":[-0.21,0.61], "l":[-2,2], "R":[0,3.5]}
#     cuts_dict = {"FeH":[-0.21,0.61], "l":[-2,2], "R":[0,2]}
else:
# NOTE: the "nuclear disc" line below repeats the "R" key; in a dict literal the later value wins.
#     cuts_dict={"age":[0,4],"R":[0,5],"l":[-2,2],"b":[0,0.01], "R":[0,3.5]} # nuclear disc
#     cuts_dict={"age":[4,7],"l":[-2,2],"b":[3.51,6.6], "R":[0,3.5]} # young pop
#     cuts_dict={"age":[4,7],"l":[-2,2],"b":[3,6], "R":[0,2]} # young pop
    cuts_dict={"age":[9.5,10],"l":[-2,2],"b":[3.51,6.6], "R":[0,3.5]} # old pop

In [None]:
# Fully-cut sample (all cuts in cuts_dict) from the observations or the simulation.
df = MF.apply_cuts_to_df(data if data_bool else df0, cuts_dict=cuts_dict)
print(len(df),"stars")

In [None]:
# Statistic studied in the MC section (key of all_funcs_dict).
# func_name = "correlation"
func_name = "anisotropy"
# func_name = "tilt_abs"

func = all_funcs_dict[func_name]

# Statistic evaluated on the unperturbed, fully-cut sample.
true_value = func(vx=df["v"+vel_x_var].values,vy=df["v"+vel_y_var].values)
print(f"{func_name}: {true_value:.3f}")

In [None]:
# When distance is perturbed, cuts on distance-dependent quantities ("d", "R") are split out.
# They are passed to the MC config separately (presumably re-applied after each perturbation; TODO confirm).
affected_cuts_dict = (
    {key: bounds for key, bounds in cuts_dict.items() if key in ("d", "R")}
    if "d" in perturbed_vars
    else None
)

# Number of MC realisations / bootstrap resamples.
repeats = 500

## Multiple frac errors

In [None]:
# Output folder: <MC root>/multiple_frac_errors/<func_name>/ (each level created if missing).
save_path = get_save_path_MC(perturbed_vars=perturbed_vars,data_bool=data_bool)

save_path += "multiple_frac_errors/"
MF.create_dir(save_path)

save_path += func_name + "/"
MF.create_dir(save_path)

print(save_path)

In [None]:
# MC input: same sample with the affected (distance-dependent) cuts removed, so stars can move
# across those limits when perturbed. `df` is the fully-cut sample.
df_MC = MF.apply_cuts_to_df(data if data_bool else df0, cuts_dict=MF.clean_cuts_from_dict(cuts_dict,cuts_to_remove=affected_cuts_dict))

print(f"Before affected cuts: {len(df_MC)}")
print(f"After affected cuts: {len(df)}")

In [None]:
# SciPy bootstrap of the statistic on the fully-cut sample (68% CI); used as a reference in the plots.
boot_result_68 = bootstrap.scipy_bootstrap(vx=df["v"+vel_x_var],vy=df["v"+vel_y_var],function=all_funcs_dict[func_name],repeats=repeats,
                                           confidence_level=0.68)
# boot_result_95 = bootstrap.scipy_bootstrap(vx=df["v"+vel_x_var],vy=df["v"+vel_y_var],function=all_funcs_dict[func_name],repeats=repeats,
#                                            confidence_level=0.95)

In [None]:
min_frac_err = 0.05
max_frac_err = 0.35
frac_err_step = 0.025

# Fractional errors from min to max inclusive, spaced by frac_err_step. np.linspace with an explicit
# point count avoids np.arange's float-step rounding, which can silently add or drop the endpoint.
n_frac_errors = int(round((max_frac_err - min_frac_err) / frac_err_step)) + 1
frac_errors = np.linspace(min_frac_err, max_frac_err, n_frac_errors)

In [None]:
# One row of MC realisations per fractional error. Float storage pre-filled with NaN, following
# bootstrap_multiple_sampling_sizes. The previous None-filled object array gave object-dtype results
# downstream, and its `None in` check could never fire once assignments succeeded.
all_MC_results = np.full(shape=(len(frac_errors),repeats), fill_value=np.nan)

montecarloconfig = MonteCarloConfig(perturbed_vars=perturbed_vars,affected_cuts_dict=affected_cuts_dict,repeats=repeats,symmetric=False)

for i,frac_err in enumerate(frac_errors):
    print(f"{frac_err:.3f}",end="; ")
    
    montecarloconfig.error_frac = frac_err
    
    MC_result = monte_carlo.get_std_MC(function=func,df=df_MC,config=montecarloconfig,true_value=true_value,vel_x_var=vel_x_var,vel_y_var=vel_y_var)
    
    all_MC_results[i] = MC_result.MC_distribution
print()  # end the progress line

if np.isnan(all_MC_results).any():
    raise ValueError("Not all MC results were filled correctly")

In [None]:
show_boot_CI = True
# show_boot_CI = False

boot_CI_level = 68
# boot_CI_level = 95

# Bootstrap CI bounds drawn as reference lines in the boxplot figure.
if boot_CI_level == 68:
    CI_low,CI_high = boot_result_68.confidence_interval
elif boot_CI_level == 95:
    # boot_result_95 only exists if its (commented-out) computation in the bootstrap cell is enabled.
    if "boot_result_95" not in globals():
        raise NameError("boot_result_95 is not defined: enable its computation in the bootstrap cell first")
    CI_low,CI_high = boot_result_95.confidence_interval
else:
    raise ValueError(f"boot_CI_level must be 68 or 95, got {boot_CI_level}")

In [None]:
# Whether the next figure is written to save_path.
# save_bool = True
save_bool = False

In [None]:
fig,ax=plt.subplots()

# Colour each box by how far the mean of its MC realisations deviates from the true value.
deviations = np.mean(all_MC_results, axis=1) - true_value

norm = plt.Normalize(vmin=deviations.min(), vmax=deviations.max())
cmap = PH.choose_cmap(vmin=deviations.min(), vmax=deviations.max(),all_from_divergent_cmap=True)

colors = [cmap(norm(dev)) for dev in deviations]

if True: # plot
    
    # One box per fractional error, at x positions 0..len(frac_errors)-1.
    for f,frac_err in enumerate(frac_errors):
        ax.boxplot(x=all_MC_results[f],positions=[f],patch_artist=True,widths=0.75,medianprops={"color":"k"},
                   boxprops={"facecolor": colors[f]})
    
    ax.axhline(y=true_value,color="red",label="True value")
    
    if show_boot_CI:
        ax.axhline(y=CI_low,color="grey",linestyle="--")
        ax.axhline(y=CI_high,color="grey",linestyle="--",label=f"{boot_CI_level}% bootstrap CI")

if True: # axis, legend
    # Label every other box so the tick labels do not overlap ("" leaves a tick unlabelled).
    xticklabels = [f"{frac_err:.3f}" if f%2==0 else "" for f,frac_err in enumerate(frac_errors)]
    
    ax.set_xticks(ticks=range(len(frac_errors)),labels=xticklabels)
    
    ax.set(xlabel="Fractional distance error", ylabel=mapf.get_kinematic_titles_dict()[func_name])
    
    if func_name == "tilt_abs":
        # get_ylim must be *called*. Indexing the bound method (ax.get_ylim[0]) raised a TypeError.
        ax.set_ylim(bottom=max([-45,ax.get_ylim()[0]]))
    
    ax.legend()

if True: # filename and save
    filename = func_name
    filename += f"_boxplots_{min_frac_err}frac{max_frac_err}step{frac_err_step}"
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)
    filename += f"_{repeats}repeats"
    filename += f"_boot{boot_CI_level}CI"
    
    print(filename)
    
    if save_bool:
        print("Saving in",save_path)
        
        for fileformat in [".png",".pdf"]:
            plt.savefig(save_path+filename+fileformat, bbox_inches="tight", dpi=250)
            print(fileformat)
    plt.show()

In [None]:
# Overlay the 68% bootstrap distribution on the MC histograms.
show_boot_dist = True
# show_boot_dist = False

In [None]:
# Whether the next figure is written to save_path.
save_bool = True
# save_bool = False

In [None]:
fig,ax=plt.subplots()

# Index of the first fractional error drawn; earlier ones are skipped (and the filename starts there).
min_i = 0

if True: # range, colors
    # Shared histogram range covering every MC realisation and the 68% bootstrap distribution.
    hist_range = [
        min([np.min(all_MC_results), np.min(boot_result_68.bootstrap_distribution)]),
        max([np.max(all_MC_results), np.max(boot_result_68.bootstrap_distribution)])
    ]

    # Colour each MC histogram by the deviation of its mean from the true value.
    deviations = all_MC_results.mean(axis=1) - true_value
    
    norm = plt.Normalize(vmin=deviations.min(), vmax=deviations.max())
    # cmap = PH.choose_cmap(vmin=deviations.min(), vmax=deviations.max(),all_from_divergent_cmap=True, divergent_cmap=mplcmaps["seismic"])
    cmap=mplcmaps["jet"]
    colors = [cmap(norm(dev)) for dev in deviations]
    
if True: # plot
    # Draw order sets both overlap order and legend order.
    for i,frac_err in enumerate(frac_errors):
        if i < min_i: continue

        ax.hist(all_MC_results[i],alpha=0.5,bins=100,range=hist_range,histtype="barstacked",edgecolor="k",label=f"{frac_err:.3f}",\
                color=colors[i])

    if show_boot_dist:
        ax.hist(boot_result_68.bootstrap_distribution,alpha=0.5,bins=100,histtype="barstacked",range=hist_range,edgecolor="k",label="Bootstrap",color="grey")

    ax.axvline(x=true_value,color="red")
    ax.legend()
    ax.set(xlabel=mapf.get_kinematic_titles_dict()[func_name],ylabel=r"$N$")

if True: # filename and save
    filename = func_name
    filename += f"_hists_{MF.return_int_or_dec(frac_errors[min_i],2)}frac{MF.return_int_or_dec(max_frac_err,2)}step{frac_err_step}"
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)
    filename += f"_{repeats}repeats"
    filename += "_boot" if show_boot_dist else ""
    
    print(filename)
    
    if save_bool:
        print("Saving in",save_path)
        plt.savefig(save_path+filename+".png", bbox_inches="tight", dpi=250)
    plt.show()

### Multiple panels

In [None]:
# Only the distance is perturbed in the multi-panel study.
perturbed_vars = ["d"]

In [None]:
# Output folder: <MC root>/multiple_frac_errors/ (created if missing).
save_path = get_save_path_MC(perturbed_vars=perturbed_vars,data_bool=data_bool)

save_path += "multiple_frac_errors/"
MF.create_dir(save_path)

print(save_path)

In [None]:
# Cuts shared by all populations.
spatial_cuts = {"l":[-2,2],"b":[3.51,6.6]}
# Distance-dependent cut, passed to the MC config (so stars can cross it when perturbed).
affected_cut = {"R":[0,3.5]}

# One population per figure column (titled "Young" and "Old" in the plot).
pop_cuts_list = [
    {"age":[4,7]},
    {"age":[9.5,10]}
]

In [None]:
# Statistics shown as figure rows; func_prefix tags the saved file name.
func_list = ["anisotropy", "correlation", "tilt_abs"]; func_prefix = "anicorr"

In [None]:
# Per-population MC inputs: spatial + population cuts, but not the affected cut.
df_MC_ages = [
    MF.apply_cuts_to_df(df0, cuts_dict=[spatial_cuts,pop_cut]) for pop_cut in pop_cuts_list
]

In [None]:
min_frac_err = 0.05
max_frac_err = 0.35
frac_err_step = 0.025

# Fractional errors from min to max inclusive, spaced by frac_err_step. np.linspace with an explicit
# point count avoids np.arange's float-step rounding, which can silently add or drop the endpoint.
n_frac_errors = int(round((max_frac_err - min_frac_err) / frac_err_step)) + 1
frac_errors = np.linspace(min_frac_err, max_frac_err, n_frac_errors)

In [None]:
# MC realisations / bootstrap resamples per (population, statistic, fractional error).
repeats = 500

# Bootstrap confidence level in percent (divided by 100 when passed to SciPy).
boot_CI_level = 68

In [None]:
# Results indexed [population, statistic, ...]. Float arrays pre-filled with NaN so unfilled
# entries are detectable with np.isnan (same convention as bootstrap_multiple_sampling_sizes).
true_values = np.full(shape=(len(pop_cuts_list), len(func_list)), fill_value=np.nan)
boot_CIs = np.full(shape=(len(pop_cuts_list), len(func_list), 2), fill_value=np.nan)
all_MC_results = np.full(shape=(len(pop_cuts_list), len(func_list), len(frac_errors),repeats), fill_value=np.nan)

montecarloconfig = MonteCarloConfig(perturbed_vars=perturbed_vars,affected_cuts_dict=affected_cut,repeats=repeats)

# Loop variables have population-specific names so that the globals df, df_MC, func_name and func
# (used by the single-MC cells further down) are not silently overwritten by this cell.
for a,pop_df_MC in enumerate(df_MC_ages):
    print("age:",pop_cuts_list[a]["age"])
    
    # Fully-cut sample (with the affected cut) for the true value and the bootstrap.
    # The MC gets pop_df_MC, which does not have the affected cut applied.
    pop_df = MF.apply_cuts_to_df(pop_df_MC, cuts_dict=affected_cut)
    
    for f,pop_func_name in enumerate(func_list):
        print(pop_func_name)
        
        pop_func = all_funcs_dict[pop_func_name]
    
        true_values[a,f] = pop_func(vx=pop_df["v"+vel_x_var].values,vy=pop_df["v"+vel_y_var].values)
        
        boot_res = bootstrap.scipy_bootstrap(vx=pop_df["v"+vel_x_var],vy=pop_df["v"+vel_y_var],function=pop_func,repeats=repeats,
                                               confidence_level=boot_CI_level/100)
        
        boot_CIs[a,f] = boot_res.confidence_interval
        
        for e,frac_err in enumerate(frac_errors):
            print(f"{frac_err:.3f}",end="; " if e < len(frac_errors)-1 else "\n")

            montecarloconfig.error_frac = frac_err

            # Pass this population/statistic's own true value. The original passed the global
            # true_value, which belongs to the single-population section above.
            MC_result = monte_carlo.get_std_MC(function=pop_func,df=pop_df_MC,config=montecarloconfig,true_value=true_values[a,f],vel_x_var=vel_x_var,vel_y_var=vel_y_var)

            all_MC_results[a,f,e] = MC_result.MC_distribution
            
    print("\n")

if True: # check all arrays were filled correctly
    assert not np.isnan(true_values).any(), "Not all true values were filled correctly"
    assert not np.isnan(boot_CIs).any(), "Not all boot CIs were filled correctly"
    assert not np.isnan(all_MC_results).any(), "Not all MC results were filled correctly"

In [None]:
# Shade the bootstrap CI band in each panel.
show_boot_CI = True
# show_boot_CI = False

In [None]:
# Whether the next figure is written to save_path.
save_bool = True
# save_bool = False

In [None]:
fig,axs=plt.subplots(figsize=(16,15),ncols=len(pop_cuts_list),nrows=len(func_list),sharex=True,gridspec_kw={"hspace":0})

# Columns = populations (pop_cuts_list), rows = statistics (func_list).
# Panel-local names (panel_func_name, panel_true_value) keep this cell from overwriting the
# globals func_name / true_value that the single-MC cells below rely on.
for col in range(len(pop_cuts_list)):
    for row,panel_func_name in enumerate(func_list):
        
        ax = axs[row,col]
        panel_true_value = true_values[col,row]
        panel_MC_results = all_MC_results[col,row]   # (len(frac_errors), repeats)

        if True: # plot
            
            # Colour each box by the deviation of its MC mean from the panel's true value.
            deviations = np.mean(panel_MC_results, axis=1) - panel_true_value
            norm = plt.Normalize(vmin=deviations.min(), vmax=deviations.max())
            cmap = PH.choose_cmap(vmin=deviations.min(), vmax=deviations.max(),all_from_divergent_cmap=True)
            colors = [cmap(norm(dev)) for dev in deviations]

            for f,frac_err in enumerate(frac_errors):
                ax.boxplot(x=panel_MC_results[f],positions=[f],patch_artist=True,widths=0.75,medianprops={"color":"k"},
                           boxprops={"facecolor": colors[f]})

            ax.axhline(y=panel_true_value,color="blueviolet",label="True value", linestyle="--")
            
            # Decide on the zero line from the range of *all* boxes in the panel. The original used
            # the leftover loop index f, i.e. only the last fractional error's distribution.
            if PH.shall_plot_zero_line(minima=panel_MC_results.min(), maxima=panel_MC_results.max()):
                ax.axhline(y=0,color="grey",linestyle="dotted")
                

            if show_boot_CI:
                
                CI_low,CI_high = boot_CIs[col,row]
                ax.fill_between(x=ax.get_xlim(),y1=CI_low,y2=CI_high,color="grey",alpha=0.15,label=f"{boot_CI_level}% bootstrap CI")

        if True: # axis, title, legend
            # Label every other box so the tick labels do not overlap.
            xticklabels = len(frac_errors)*[None]
            for f,frac_err in enumerate(frac_errors):
                if f%2!=0:
                    continue
                xticklabels[f] = f"{frac_err:.3f}"

            ax.set_xticks(ticks=range(len(frac_errors)),labels=xticklabels)
            
            if col == 0:
                ax.set_ylabel(mapf.get_kinematic_titles_dict()[panel_func_name])
                
            if row == len(func_list) - 1:
                ax.set_xlabel("Fractional distance error")

            if panel_func_name == "tilt_abs":
                ax.set_ylim(bottom=max([-45,ax.get_ylim()[0]]))
            
            if row == 0:
                ax.set_title(["Young","Old"][col])
            if col+row==0:
                ax.legend()

if True: # filename and save
    filename = func_prefix
    filename += f"_boxplots_{min_frac_err}frac{max_frac_err}step{frac_err_step}"
    filename += "_" + MF.combine_multiple_cut_dicts_into_str(pop_cuts_list)
    filename += "_" + MF.combine_multiple_cut_dicts_into_str([spatial_cuts,affected_cut],order_separator="_")
    filename += f"_{repeats}repeats"
    filename += f"_boot{boot_CI_level}CI"
    
    print(filename)
    
    if save_bool:
        print("Saving in",save_path)
        
        for fileformat in [".png",".pdf"]:
            plt.savefig(save_path+filename+fileformat, bbox_inches="tight", dpi=300)
            print(fileformat)
    plt.show()

## Single MC error

In [None]:
# Fractional distance error for the single-MC run. None presumably means "use the data's own errors"
# (it selects the "_dataErr" file-name tag below; TODO confirm in MonteCarloConfig).
# frac_error = 0.1
frac_error = 0.2
# frac_error = None

In [None]:
# MC config with an asymmetric CI; bootstrap config with a symmetric CI (symmetric=True).
montecarloconfig = MonteCarloConfig(perturbed_vars=perturbed_vars,affected_cuts_dict=affected_cuts_dict,error_frac=frac_error,repeats=repeats,symmetric=False)

bootstrapconfig = BootstrapConfig(repeats=repeats,symmetric=True)

In [None]:
# MC input without the affected cuts; prints the fully-cut and MC sample sizes.
df_MC = MF.apply_cuts_to_df(data if data_bool else df0, cuts_dict=MF.clean_cuts_from_dict(cuts_dict,affected_cuts_dict))
print(len(df),len(df_MC))

In [None]:
# MC runs on df_MC (before the affected cuts); the bootstrap resamples the fully-cut df.
MC_result = monte_carlo.get_std_MC(function=func,df=df_MC,config=montecarloconfig,true_value=true_value,
                                   vel_x_var=vel_x_var,vel_y_var=vel_y_var)

boot_result = bootstrap.get_std_bootstrap(function=func,vx=df[f"v{vel_x_var}"].values,vy=df[f"v{vel_y_var}"].values,config=bootstrapconfig)

# MC_values was referenced but never defined (NameError). It is the MC distribution of the statistic.
MC_values = MC_result.MC_distribution
boot_values = boot_result.bootstrap_distribution

print(f"Mean\t MC: {np.mean(MC_values):.4f}. Boot: {np.mean(boot_values):.4f}")
print(f"Median\t MC: {np.median(MC_values):.4f}. Boot: {np.median(boot_values):.4f}")
print(f"CI low\t MC: {MC_result.confidence_interval[0]:.4f}. Boot: {boot_result.confidence_interval[0]:.4f}")
print(f"CI high\t MC: {MC_result.confidence_interval[1]:.4f}. Boot: {boot_result.confidence_interval[1]:.4f}")

In [None]:
# Whether the next figure is written to save_path.
# save_bool = True
save_bool = False

In [None]:
# Overlay the bootstrap distribution on the MC histogram.
boot_hist_bool = True
# boot_hist_bool = False
# Optional fixed bar width for the MC histogram (None = matplotlib's default).
MC_bar_width = None
# MC_bar_width = 0.001

if True: # boot & MC value hist
    fig,ax=plt.subplots()
    if MC_bar_width is not None:
        ax.hist(MC_result.MC_distribution,bins=50,color="blue",alpha=0.5,label=fr"MC values ($R={repeats}$)",width=MC_bar_width)
    else: # It doesn't like it if I pass width=None and I couldn't find what the default is
        ax.hist(MC_result.MC_distribution,bins=50,color="blue",alpha=0.5,label=fr"MC values ($R={repeats}$)")
    if boot_hist_bool:
        ax.hist(boot_result.bootstrap_distribution,bins=50,color="green",alpha=0.5, label=fr"Bootstrap values ($R={repeats}$)")
    ax.axvline(true_value,color="red",label="True value")
    # NOTE(review): x label looks up func_name in symbol_dict; other cells use the titles dict. Confirm the key exists.
    ax.set_xlabel(symbol_dict[func_name]); ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
    ax.legend()

if True: # filename and save
    filename = func_name
    filename += f"_{frac_error}fracErr" if frac_error is not None else "_dataErr"
    filename += f"_noBoot" if not boot_hist_bool else ""
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)
    filename += f"_{repeats}repeats"
    
    print(filename)
    
    if save_bool:
        plt.savefig(save_path+filename+".png", bbox_inches="tight")
        print("Saved in",save_path)
    plt.show()

## d-specific plots

In [None]:
# Whether the next figure is written to save_path.
save_bool = True
# save_bool = False

In [None]:
# Boolean mask: stars inside the R <= 3.5 limit before any perturbation (3.5 matches the R cut used above).
originally_within_Rmax = df_MC["R"] <= 3.5

In [None]:
if True: # plot scatter inside/outside limit
    
    fig,ax=plt.subplots()
    
    # Face colours for the three groups, in elements_dic order (also reused to colour the legend text).
    colors = ["grey","red","limegreen"]
    # label -> [mask over df_MC, marker size, face colour, edge line width, edge colour]
    # NOTE(review): `within_cut` is not defined anywhere in the visible notebook. It is presumably the
    # R <= 3.5 mask after one MC distance perturbation. Define it before running this cell.
    elements_dic = {
        "Stayed inside/outside": [(originally_within_Rmax&within_cut)|(~originally_within_Rmax&~within_cut), 1, colors[0],1,"grey"],
        "Moved outside": [originally_within_Rmax&~within_cut, 10, colors[1],0.5,"k"],
        "Moved inside": [~originally_within_Rmax&within_cut, 10, colors[2],0.5,"k"]
    }

    # NOTE: this reassigns the global y_var set in the "CHOOSE" cell.
    y_var = "vr"; ylabel = r"$v_r$" + units_dict["mean_vx"]
#     y_var = "d"; ylabel = r"$d~$[kpc]"

    for k in elements_dic:
        condition,size,color,lw,edgecolor = elements_dic[k]

        ax.scatter(x=df_MC[condition]["R"], y=df_MC[condition][y_var], color=color, s=size, label=f"{k} ({sum(condition)})",lw=lw,edgecolor=edgecolor)

    # Mark the R = 3.5 kpc selection limit.
    ax.axvline(x=3.5,color="grey",lw=1,linestyle="--")
    ax.text(x=3.5+0.1,y=0.85*ax.get_ylim()[1], s=r"$R=3.5~$kpc",color="grey",size=15)

    ax.set_xlabel(r"$R$ [kpc]"); ax.set_ylabel(ylabel)

    leg = ax.legend()

    # Colour each legend label like its group (relies on legend order == elements_dic order).
    for text,c in zip(leg.get_texts(),colors):
        text.set_color(c)
             
if True: # filename and save
    filename = f"{y_var}-R"
    filename += f"_{frac_error}fracErr" if frac_error is not None else "_dataErr"
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)

    print(filename)
    
    if save_bool:
        plt.savefig(save_path+filename+".png", bbox_inches="tight")
        print("Saved in",save_path)
    
    plt.show()

In [None]:
# Whether the next figure is written to save_path.
save_bool = True
# save_bool = False

In [None]:
bins = 50 if data_bool else 100

fig,ax=plt.subplots()

# One Gaussian perturbation of d: per-star sigma from "d_error" if that column exists, else frac_error*d.
# No random seed is set, so this example changes on every run.
MC_d = np.random.normal(loc=df_MC["d"],scale=df_MC["d_error"] if "d_error" in df_MC else frac_error*df_MC["d"]) # this is an example - in the std calculation this is calculated {repeats} number of times

ax.hist(MC_d,bins=bins,label=r"MC $d$")
ax.hist(df_MC["d"],bins=bins,color="red",alpha=0.4,label=r"Original $d$")
ax.hist(MC_d-df_MC["d"],bins=bins,color="k",histtype="step",label="Difference")
ax.set_xlabel(r"$d$ [kpc]"); ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
ax.legend()

if True:
    filename = "dHists"
    filename += f"_{frac_error}fracErr" if frac_error is not None else "_dataErr"
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)
    filename += f"_{bins}bins"
    
    print(filename)
    
    if save_bool:
        plt.savefig(save_path+filename+".png", bbox_inches="tight")
        print("Saved in",save_path)
    plt.show()

# Bootstrap

In [None]:
def bootstrap_multiple_sampling_sizes(df, function, sampling_sizes, config, vectorised=False, batch_size=None, tilt=False, absolute=True, verbose=True):
    """Run bootstrap.get_std_bootstrap once per sampling size and stack the results.

    For each size, ``config.sample_size`` is set to it and the bootstrap runs on the
    (vr, vl) velocities of ``df``. The caller's original ``config.sample_size`` is
    restored afterwards, also if an exception occurs. Previously the config was left at
    the last size.

    Parameters
    ----------
    df : DataFrame with "vr" and "vl" columns.
    function : statistic passed to the bootstrap.
    sampling_sizes : sequence of int resample sizes.
    config : bootstrap config; ``config.repeats`` sets the distribution length.
    vectorised, batch_size, tilt, absolute : forwarded to bootstrap.get_std_bootstrap.
    verbose : print each size as it is processed.

    Returns
    -------
    Result namedtuple with confidence_intervals (n_sizes, 2),
    bootstrap_distributions (n_sizes, config.repeats), standard_errors (n_sizes,)
    and biases (n_sizes,).

    Raises
    ------
    AssertionError if any output entry is still NaN after all runs.
    """
    n_sizes = len(sampling_sizes)
    confidence_intervals = np.full(shape=(n_sizes,2), fill_value=np.nan)
    bootstrap_distributions = np.full(shape=(n_sizes,config.repeats), fill_value=np.nan)
    standard_errors = np.full(shape=n_sizes, fill_value=np.nan)
    biases = np.full(shape=n_sizes, fill_value=np.nan)
    
    # sample_size is overwritten per run; remember it so the caller's config is not mutated.
    original_sample_size = getattr(config, "sample_size", None)
    try:
        for s,size in enumerate(sampling_sizes):
            config.sample_size = size
            res = bootstrap.get_std_bootstrap(function=function,vx=df.vr.values,vy=df.vl.values,tilt=tilt,absolute=absolute,
                                              config=config,vectorised=vectorised,batch_size=batch_size)
            
            confidence_intervals[s] = res.confidence_interval
            bootstrap_distributions[s] = res.bootstrap_distribution
            standard_errors[s] = res.standard_error
            biases[s] = res.bias
            
            if verbose:
                print(size,end="; ")
    finally:
        config.sample_size = original_sample_size
    if verbose:
        print("\n")
        
    # Every slot must have been overwritten; a leftover NaN means a run was skipped or returned NaN.
    for v, var in enumerate([confidence_intervals, bootstrap_distributions, standard_errors, biases]):
        assert not np.any(np.isnan(var)), f"Array {v} was not filled correctly"
        
    Result = namedtuple("Result", ["confidence_intervals", "bootstrap_distributions", "standard_errors", "biases"])
    
    return Result(confidence_intervals=confidence_intervals, bootstrap_distributions=bootstrap_distributions,
                  standard_errors=standard_errors, biases=biases)

In [None]:
def bootstrap_multiple_sampling_sizes_recursive(df, function, sampling_sizes, config, nested_config=None, vectorised=False, batch_size=None, tilt=False, absolute=True, verbose=True):
    """Run the nested bootstrap once per sampling size and stack the results.

    For each size, ``config.sample_size`` is set to it and
    bootstrap.get_std_bootstrap_recursive runs on the (vr, vl) velocities of ``df``.
    The caller's original ``config.sample_size`` is restored afterwards, also if an
    exception occurs.

    Parameters
    ----------
    df : DataFrame with "vr" and "vl" columns.
    function : statistic passed to the bootstrap.
    sampling_sizes : sequence of int resample sizes.
    config : outer bootstrap config; ``config.repeats`` sizes the nested arrays.
    nested_config : config for the inner bootstraps (forwarded unchanged).
    vectorised, batch_size : forwarded as boot_vectorised / boot_batch_size.
    tilt, absolute : forwarded to the bootstrap.
    verbose : print each size as it is processed.

    Returns
    -------
    Result namedtuple: confidence_intervals (n_sizes, 2), standard_errors (n_sizes,),
    biases (n_sizes,), nested_confidence_intervals (n_sizes, config.repeats, 2),
    nested_standard_errors and nested_biases (n_sizes, config.repeats).

    Raises
    ------
    AssertionError if any output entry is still NaN after all runs.
    """
    n_sizes = len(sampling_sizes)
    confidence_intervals = np.full(shape=(n_sizes,2), fill_value=np.nan)
    standard_errors = np.full(shape=n_sizes, fill_value=np.nan)
    biases = np.full(shape=n_sizes, fill_value=np.nan)
    
    nested_confidence_intervals = np.full(shape=(n_sizes, config.repeats, 2), fill_value=np.nan)
    nested_standard_errors = np.full(shape=(n_sizes,config.repeats), fill_value=np.nan)
    nested_biases = np.full(shape=(n_sizes, config.repeats), fill_value=np.nan)
    
    # sample_size is overwritten per run; remember it so the caller's config is not mutated.
    original_sample_size = getattr(config, "sample_size", None)
    try:
        for s,size in enumerate(sampling_sizes):
            config.sample_size = size
            
            res = bootstrap.get_std_bootstrap_recursive(function=function,vx=df.vr.values,vy=df.vl.values,tilt=tilt,absolute=absolute,
                                                        config=config, nested_config=nested_config, boot_vectorised=vectorised, boot_batch_size=batch_size)
            
            confidence_intervals[s] = res.confidence_interval
            standard_errors[s] = res.standard_error
            biases[s] = res.bias
            
            nested_confidence_intervals[s] = res.bootstrap_confidence_intervals
            nested_standard_errors[s] = res.bootstrap_standard_errors
            nested_biases[s] = res.bootstrap_biases
            
            if verbose:
                print(size,end="; ")
    finally:
        config.sample_size = original_sample_size
    if verbose:
        print("\n")
        
    # Every slot must have been overwritten; a leftover NaN means a run was skipped or returned NaN.
    for v, var in enumerate([confidence_intervals, standard_errors, biases, nested_confidence_intervals, nested_standard_errors, nested_biases]):
        assert not np.any(np.isnan(var)), f"Array {v} was not filled correctly"
        
    Result = namedtuple("Result", ["confidence_intervals", "standard_errors", "biases",
                                   "nested_confidence_intervals","nested_standard_errors","nested_biases"])
    
    return Result(confidence_intervals=confidence_intervals, standard_errors=standard_errors,biases=biases,
                  nested_confidence_intervals=nested_confidence_intervals,nested_standard_errors=nested_standard_errors,nested_biases=nested_biases)

## Method comparison

In [None]:
# general_path already ends with "/", so no leading slash here (the original produced "//").
# The folder is created, as for the other save paths, so that savefig does not fail.
save_path = general_path + "graphs/other_plots/Bootstrapping/comparison_scipy/"
MF.create_dir(save_path)

In [None]:
# Sample for the bootstrap method comparison (simulation only).
# cuts_dict = {"R":[0,2],"b":[3,3.01],"l":[-2,2],"age":[4,7]}
cuts_dict = {"R":[0,2],"b":[3,3.01],"l":[-2,2],"age":[9.5,10]}

df = MF.apply_cuts_to_df(df0, cuts_dict=cuts_dict)

vx = df.vr.values
vy = df.vl.values

print(len(df))

# Statistics to compare (same dict as the MC section); one figure panel each.
function_dict = all_funcs_dict
repeats = 10000

In [None]:
def compare_error_methods(df, func_name, function, repeats, ax, hist_range=None, common_labels_on=False, tilt=False,absolute=True,R_hat=None):
    """Overlay our bootstrap and SciPy's bootstrap of one statistic on ``ax``.

    The statistic is evaluated on the (vr, vl) velocities of ``df``. It is then resampled
    with bootstrap.get_std_bootstrap (symmetric=False) and bootstrap.scipy_bootstrap.
    The function draws both histograms, the true value, the bootstrap mean and median,
    the SciPy CI and its centre, our CI (used as offsets from the true value), the
    16th-84th percentile interval, mean +/- standard error and our symmetric CI. For
    "tilt_abs" it also draws the Roca-Fabrega (2014) analytic error.

    Parameters
    ----------
    df : DataFrame with "vr" and "vl" columns.
    func_name : name of the statistic; only "tilt_abs" changes what is drawn.
    function : statistic to evaluate.
    repeats : number of resamples for both bootstrap methods.
    ax : matplotlib Axes to draw on.
    hist_range : optional (min, max) range for both histograms.
    common_labels_on : attach legend labels to the artists (including the Roca-Fabrega lines).
    tilt, absolute, R_hat : forwarded to the statistic / bootstrap helpers.

    Returns None; the function only draws on ``ax``.
    """
    n = len(df)
    vx = df.vr.values
    vy = df.vl.values

    def label(text):
        # Labels only when requested, so panels without a legend do not carry stray entries.
        return text if common_labels_on else None
    
    true_value = error_helpers.apply_function(vx=vx,vy=vy,function=function,tilt=tilt,absolute=absolute,R_hat=R_hat)
    
    our_res = bootstrap.get_std_bootstrap(vx=vx,vy=vy,function=function,config=BootstrapConfig(repeats=repeats,symmetric=False),tilt=tilt,absolute=absolute,R_hat=R_hat)
    scipy_res = bootstrap.scipy_bootstrap(vx=vx,vy=vy,function=function,config=BootstrapConfig(repeats=repeats),tilt=tilt,absolute=absolute,R_hat=R_hat)
    
    # plot
    our_distribution, our_CI, our_se = our_res.bootstrap_distribution, our_res.confidence_interval, our_res.standard_error
    scipy_distribution, scipy_CI, scipy_se = scipy_res.bootstrap_distribution, scipy_res.confidence_interval, scipy_res.standard_error
    
    bootstrap_mean = np.mean(our_distribution)
        
    ax.hist(our_distribution,bins=75,label=label("Ours"), alpha=0.5,color="blue",range=hist_range)
    ax.hist(scipy_distribution,bins=75,color="orange",label=label("SciPy"), alpha=0.5,range=hist_range)

    ax.axvline(x=true_value,alpha=0.5,color="green",label=label("True value"))
    ax.axvline(x=bootstrap_mean,alpha=0.5,color="cyan",label=label("Bootstrap mean"))
    ax.axvline(x=np.median(our_distribution),alpha=0.5,color="red",label=label("Bootstrap median"))
    ax.axvline(x=np.mean(scipy_CI),alpha=0.5,color="orange",linestyle="--",label=label("Scipy CI centre"))
    
    ax.axvline(x=scipy_CI[0],color="orange",label=label("SciPy CI"))
    ax.axvline(x=scipy_CI[1],color="orange")

    # Our CI entries are used as distances below/above the true value.
    ax.axvline(x=true_value-our_CI[0],color="blue",linestyle="-", label=label("Our CI"))
    ax.axvline(x=true_value+our_CI[1],color="blue",linestyle="-")
    
    # Plain percentile interval (16th-84th, i.e. the central 68%).
    ax.axvline(x=np.percentile(our_distribution,16),color="k",linestyle="dotted", label=label("Percentile CI"))
    ax.axvline(x=np.percentile(our_distribution,84),color="k",linestyle="dotted")
    
    ax.axvline(x=bootstrap_mean-our_se,color="magenta",linestyle="--", label=label("Standard error"))
    ax.axvline(x=bootstrap_mean+our_se,color="magenta",linestyle="--")
    
    our_symmetric_CI = error_helpers.build_confidence_interval(values=our_distribution,central_value=true_value,symmetric=True)
    
    ax.axvline(x=true_value-our_symmetric_CI[0],color="pink",linestyle="--", label=label("Our symmetric CI"))
    ax.axvline(x=true_value+our_symmetric_CI[1],color="pink",linestyle="--")
    
    if func_name == "tilt_abs": # Roca-Fabrega
        
        roca_fabrega_error = MF.get_error_vertex_deviation_roca_fabrega(n=n,vx=vx,vy=vy)
        
        ax.axvline(x=true_value-roca_fabrega_error,color="grey",linestyle="dotted", label=label("Roca-Fabrega 2014"))
        ax.axvline(x=true_value+roca_fabrega_error,color="grey",linestyle="dotted")

In [None]:
# Whether the next figure is written to save_path.
save_bool = True
# save_bool = False

In [None]:
fig,axs = plt.subplots(figsize=(10,20),nrows=3)

# Panels whose artists get legend labels and a legend.
legend_rows = [0,1,2]

# Per-panel histogram ranges (None = automatic). The third panel (tilt_abs, see the filename check
# below) is clipped to [-45, 10].
hist_ranges = [None, None, [-45,10]]
for i,(func_name,func) in enumerate(function_dict.items()):
    compare_error_methods(df=df, func_name=func_name, function=func, repeats=repeats, ax=axs[i], hist_range=hist_ranges[i], common_labels_on=i in legend_rows)
    axs[i].set_xlabel(titles_dict[func_name])
    axs[i].set_ylabel(r"$N$")
    
    # Always give the last panel a legend. The original compared i with len(function_dict), which an
    # enumerate index can never equal.
    if i in legend_rows or i == len(function_dict) - 1:
        axs[i].legend()
    
if True: # filename, save, show
    filename = "hist"
    
    if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
        filename += "_anicorr"
    else:
        raise ValueError("Please specify map list name")
    
    filename += "_"+MF.combine_multiple_cut_dicts_into_str(all_cuts=cuts_dict,cut_separator="_",order_separator="_")
    filename += "_%srepeats"%str(repeats)
    
    print(filename)
    
    if save_bool:
        plt.savefig(save_path + filename + ".png", dpi=250, bbox_inches="tight")
        print("Saved in",save_path)
    plt.show()

### Normal vs vectorised (ours)

In [None]:
import time
import timeit

Timing for a single sample size

In [None]:
# Synthetic velocity components: independent Gaussians with different widths.
# A fixed seed makes the benchmark input reproducible from run to run.
rng = np.random.default_rng(42)

vx = rng.normal(size=50000)
# vy = rng.normal(size=50000, scale=100) + rng.normal(size=50000, scale=1)
vy = rng.normal(size=50000, scale=3)
# sample_size = len(vx)
sample_size = 5000
repeats = 5000

batch_size = None

func = CV.calculate_tilt
tilt = True; absolute = True

# func = CV.calculate_correlation
# tilt = False; absolute = False   # always define `absolute` so the calls below work on a fresh kernel

#####

# time.perf_counter is monotonic and high-resolution; time.time is a wall clock and
# can jump (NTP adjustments), which makes it unsuitable for benchmarking.
begin_time = time.perf_counter()

res_vectorised = bootstrap.get_std_bootstrap(vx=vx,vy=vy,function=func,tilt=tilt,absolute=absolute,vectorised=True,batch_size=batch_size,
                                             config=BootstrapConfig(sample_size=sample_size,repeats=repeats))

vector_time = time.perf_counter()

res = bootstrap.get_std_bootstrap(vx=vx,vy=vy,function=func,tilt=tilt,absolute=absolute,
                                  config=BootstrapConfig(sample_size=sample_size,repeats=repeats))

normal_time = time.perf_counter()

print("Vectorised time:",vector_time - begin_time)
print("Normal time:",normal_time - vector_time)

Timing across multiple sample sizes

In [None]:
max_sample_size = 5000
repeats = 5000
batch_size = None
timeit_repeats = 15  # executions per timeit measurement

vectorised_times = []
normal_times = []

# 10 log-spaced sample sizes from 50 up to max_sample_size
all_sample_sizes = np.int64(np.round(10**np.linspace(np.log10(50),np.log10(max_sample_size),10)))
logSampling = True

for sample_size in all_sample_sizes:
    print(sample_size, end="; ")

    def run_bootstrap(vectorised):
        """Run one bootstrap at the current sample size, either vectorised or plain."""
        vectorised_kwargs = {"vectorised": True, "batch_size": batch_size} if vectorised else {}
        return bootstrap.get_std_bootstrap(vx=vx, vy=vy, function=func, tilt=tilt, absolute=absolute,
                                           config=BootstrapConfig(sample_size=sample_size, repeats=repeats),
                                           **vectorised_kwargs)

    # timeit.timeit returns the TOTAL time of `timeit_repeats` calls
    vectorised_times.append(timeit.timeit(lambda: run_bootstrap(True), number=timeit_repeats))
    normal_times.append(timeit.timeit(lambda: run_bootstrap(False), number=timeit_repeats))

In [None]:
# Output directory for the vectorisation timing plots, derived from general_path
# (which already ends with "/") instead of repeating the absolute path.
save_path = general_path + "graphs/other_plots/bootstrapping/vectorisation_timeit/"

In [None]:
# Write the timing figure to disk (set to False to only preview it)
save_bool = True

In [None]:
fig,ax=plt.subplots()

# timeit returned the total over `timeit_repeats` runs; divide to get seconds per run
ax.plot(all_sample_sizes, np.array(vectorised_times)/timeit_repeats, color="green",label="Vectorised")
ax.plot(all_sample_sizes, np.array(normal_times)/timeit_repeats, color="blue",label="Normal")

ax.set_ylabel("Time per run [s]")
ax.set_xlabel("Sample size")
ax.legend()

if True: # filename and save
    filename = "timevsN"
    
    filename += f"_{min(all_sample_sizes)}size{max(all_sample_sizes)}"
    filename += f"_{len(all_sample_sizes)}steps" + ("Log" if logSampling else "")
    filename += f"_{repeats}repeats"
    filename += f"_{timeit_repeats}times"
    filename += f"_batchsize{batch_size}"
    
    print(filename)

    if save_bool:
        plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")
        
        # Report the real directory (this used to print the literal text "save_path")
        print("Saved in",save_path)
plt.show()

In [None]:
# Write the speed-up ratio figure to disk (set to False to only preview it)
save_bool = True

In [None]:
# Speed-up factor of the vectorised bootstrap over the normal one, vs sample size

# xlog = True
xlog = False

ylog = True
# ylog = False

fig,ax=plt.subplots()

ax.plot(all_sample_sizes, np.array(normal_times)/np.array(vectorised_times), color="green",label="Vectorised")

ax.set_ylabel(r"$t_\mathrm{normal} / t_\mathrm{vectorised}$")
ax.set_xlabel("Sample size")

if xlog:
    ax.set_xscale("log")
if ylog:
    ax.set_yscale("log")

ax.set_ylim(bottom=1)  # a ratio below 1 would mean no speed-up
ax.grid(which="both")

if xlog:
    ax.set_xlim(45,5600)

if True: # filename and save
    filename = "ratio"
    
    filename += f"_{min(all_sample_sizes)}size{max(all_sample_sizes)}"
    filename += f"_{len(all_sample_sizes)}steps" + ("Log" if logSampling else "")
    filename += f"_{repeats}repeats"
    filename += f"_{timeit_repeats}times"
    filename += f"_batchsize{batch_size}"
    
    if xlog and ylog:
        filename += "_logscale"
    elif xlog:
        filename += "_xlogscale"
    elif ylog:
        filename += "_ylogscale"
    
    print(filename)

    if save_bool:
        plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")
        
        # Report the real directory (this used to print the literal text "save_path")
        print("Saved in",save_path)
plt.show()

## Sampling distributions

In [None]:
# general_path already ends with "/", so no leading slash here (avoids "//" in the path)
save_path = general_path + "graphs/other_plots/Bootstrapping/sampling_distributions/"

In [None]:
repeats = 5000  # bootstrap resamples per sample size

# Same spatial region, two populations: age in [4,7] ("Young") and in [9.5,10] ("Old")
all_cuts_list = [
    {"R":[0,2],"b":[3,6],"l":[-2,2],"age":[4,7]},
    {"R":[0,2],"b":[3,6],"l":[-2,2],"age":[9.5,10]}
]

# Sample sizes per selection; the final entry is the size of the full selection
all_sizes = [
    [50,65,80,100,250,500,1000,5000] + [len(MF.apply_cuts_to_df(df0, cuts_dict=cuts))]
    for cuts in all_cuts_list
]

n_sizes = len(all_sizes[0])

function_dict = all_funcs_dict

In [None]:
# LaTeX symbols for the bootstrap mean of each estimator (used in histogram labels)
bootstrap_estimate_symbols_dict = dict(
    tilt_abs=r"$\overline{l^*_\mathrm{v}}$",
    anisotropy=r"$\overline{\beta_{rl}^*}$",
    correlation=r"$\overline{\rho_{rl}^*}$",
)

In [None]:
# For each selection and estimator: the full-selection ("true") value and the bootstrap
# distributions at every requested sample size.
total_N = np.zeros(shape=len(all_cuts_list))
true_values = np.zeros(shape=(len(all_cuts_list), len(function_dict)))
all_values = np.zeros(shape=(len(all_cuts_list), len(function_dict), n_sizes, repeats))

for c,(cuts,sizes) in enumerate(zip(all_cuts_list,all_sizes)):
    print(cuts,end="\n")

    df = MF.apply_cuts_to_df(df0, cuts_dict=cuts)
    
    total_N[c] = len(df)
    
    for f,(func_name,func) in enumerate(function_dict.items()):
        print(func_name,end="\n")
        
        true_values[c,f] = func(df.vr.values,df.vl.values)
        
        # Call bootstrap_multiple_sampling_sizes through its config-based API and read the
        # distributions by attribute name. The old `repeats=` kwarg and positional 4-tuple
        # unpacking depended on a stale signature and field order.
        res = bootstrap_multiple_sampling_sizes(df=df,function=func,sampling_sizes=sizes,config=BootstrapConfig(repeats=repeats))
        all_values[c,f] = res.bootstrap_distributions

In [None]:
# Write the sampling-distribution histograms to disk (set to False to only preview them)
save_bool = True

In [None]:
alpha = 0.5
bins = 50

cmap = mplcmaps['jet']

fig,axs=plt.subplots(figsize=(15,15),nrows=3,ncols=2,gridspec_kw={"hspace":0.25})

# Rows: kinematic quantities; columns: populations ("Young", "Old")
for row,func_name in enumerate(function_dict):
    
    for col,sizes in enumerate(all_sizes):
        
        max_size = total_N[col]
        true_val = true_values[col, row]
        ax = axs[row,col]
        units = mapf.get_units(func_name)
        # Decimals in the labels: 3 when the true value is a multiple of 10, otherwise 2
        ndec = 3 if true_val%10 == 0 else 2
        
        if row == 0:
            ax.set_title(["Young","Old"][col],fontsize="large")
        
        for s,size in enumerate(sizes):
            # Spread the sample sizes across the whole colormap
            color = cmap(int(256*s/(len(sizes)-1)))

            vals = all_values[col,row,s]

            if size != max_size:
                # "N=<size> (<bootstrap mean symbol>=<value><units>)". The opening
                # parenthesis was previously missing, leaving an unbalanced ")".
                val_label = r"$N=%i$ $($%s$=%.*f$%s$)$"%(size, bootstrap_estimate_symbols_dict[func_name], ndec, np.mean(vals), units)
            else:
                val_label = r"Total $(N=%i)$"%max_size

            ax.hist(vals,bins=bins,alpha=alpha,label=val_label,color=color)

        if True: # axvline, labels, legend
            true_val_label = r"True value $($%s$=%.*f$%s$)$"%(symbol_dict[func_name], ndec, true_val, units)

            ax.axvline(true_val,color="grey",label=true_val_label,linestyle="--")
            ax.legend(fontsize=10.5)
            ax.set_xlabel(mapf.get_kinematic_titles_dict("r","l")[func_name.removesuffix("_abs")] + units_dict[func_name])
            ax.set_ylabel(r"$N$")

if True: # filename, save, show
    filename = "hists"
    
    if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
        filename += "_anicorr"
    else:
        raise ValueError("Please specify map list name")
    
    filename += "_"+MF.combine_multiple_cut_dicts_into_str(all_cuts=all_cuts_list,cut_separator="_",order_separator="_")
    # Use the explicit count rather than the `sizes` variable leaked from the loop
    filename += "_%ssizes"%n_sizes
    filename += "_%srepeats"%str(repeats)
    
    print(filename)
    
    if save_bool:
        plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")
        print("Saved in",save_path)
    plt.show()

## Bias vs sample size N

In [None]:
# general_path already ends with "/", so no leading slash here (avoids "//" in the path)
save_path = general_path + "graphs/other_plots/Bootstrapping/bias/"

In [None]:
# Longer major / shorter minor x ticks (they show up on the log-scaled sample-size axis)
plt.rcParams.update({"xtick.major.size": 7, "xtick.minor.size": 3})

In [None]:
# 100 log-spaced sample sizes from 50 to 5000, rounded to integers
sampling_sizes = np.int64(np.round(10**np.linspace(np.log10(50),np.log10(5000),100)))
logSampling = True

# Same spatial region, two populations: age in [4,7] ("Young") and in [9.5,10] ("Old")
all_cuts_list = [
    {"R":[0,2],"b":[3,6],"l":[-2,2],"age":[4,7]},
    {"R":[0,2],"b":[3,6],"l":[-2,2],"age":[9.5,10]}
]

# Plot styling per population (index matches all_cuts_list)
colors, zorders = ["blue","red"], [0,1]

function_dict = all_funcs_dict

In [None]:
# Bootstrap settings for the bias study
bootstrapconfig = BootstrapConfig(
    repeats=5000,
    symmetric=False,
    from_mean=True,  # NOTE(review): presumably measures errors from the bootstrap mean; see src.errorconfig
)

In [None]:
# Grid dimensions: (selection, estimator, sample size)
bias_grid = (len(all_cuts_list), len(function_dict), len(sampling_sizes))

true_values = np.zeros(shape=bias_grid[:2])
bias_values = np.zeros(shape=bias_grid)
std_lowhigh_values = np.zeros(shape=bias_grid + (2,))  # (low, high) interval per sample size
bootstrap_distributions = np.zeros(shape=bias_grid + (bootstrapconfig.repeats,))

for c, cuts in enumerate(all_cuts_list):
    print(cuts, end="\n")

    df = MF.apply_cuts_to_df(df0, cuts_dict=cuts)

    for f, func_name in enumerate(function_dict):
        print(func_name, end="\n")
        estimator = all_funcs_dict[func_name]

        # Value on the full selection, used as the reference for the bias
        true_values[c, f] = estimator(df.vr.values, df.vl.values)

        res = bootstrap_multiple_sampling_sizes(df=df, function=estimator, sampling_sizes=sampling_sizes, config=bootstrapconfig)

        std_lowhigh_values[c, f] = res.confidence_intervals
        bootstrap_distributions[c, f] = res.bootstrap_distributions
        bias_values[c, f] = res.biases

In [None]:
# Preview only: set to True to write the bias figure to disk
save_bool = False

In [None]:
# bias vs N: bootstrap mean minus full-selection value, one row per quantity, one curve per population

fig, axs = plt.subplots(figsize=(8,10), nrows=len(function_dict), gridspec_kw={"hspace":0})

for f, (ax, func) in enumerate(zip(axs, function_dict)):

    for c in range(len(all_cuts_list)):
        bias_curve = np.mean(bootstrap_distributions[c,f], axis=-1) - true_values[c,f]
        ax.plot(sampling_sizes, bias_curve, color=colors[c], zorder=zorders[c], label=["Young","Old"][c])

    units = mapf.get_units(func)
    bias_units = (r"$[$" + units + r"$]$") if units != "" else ""

    ax.set_xlabel("Sample size")
    ax.set_ylabel("Bias %s" % bias_units)

    ax.set_xscale("log")
    ax.set_xlim(42, 6000)

    if f == 1:  # second row: fixed narrow bias range
        ax.set_ylim(-0.01, 0.01)

    panel_title = mapf.get_kinematic_titles_dict("r","l")[func.removesuffix("_abs")]
    ax.text(s=panel_title, x=0.75, y=0.75, transform=ax.transAxes, size=15)

    ax.axhline(y=0, color="grey", linestyle="--")

    if f == 0:
        ax.legend(loc="lower right")

if True: # filename, save, show
    filename = "_".join([
        "biasVsN",
        MF.combine_multiple_cut_dicts_into_str(all_cuts_list, cut_separator="_", order_separator="_"),
    ])
    filename += f"_{bootstrapconfig.repeats}repeats"
    filename += f"_{min(sampling_sizes)}size{max(sampling_sizes)}"
    filename += f"_{len(sampling_sizes)}steps" + ("Log" if logSampling else "")

    print(filename)

    if save_bool:
        plt.savefig(save_path+filename+".png", dpi=200, bbox_inches="tight")
        print("Saved in", save_path)
    plt.show()

In [None]:
# Write the true-value vs bootstrap-mean figure to disk (set to False to only preview it)
save_bool = True

In [None]:
# true value and bootstrap mean as function of N

# boot_std = True; surface_label = "Bootstrap standard deviation"
boot_std = False

boot_ci = True; ci_percentile = 68; surface_label = f"Bootstrap {ci_percentile}% interval"
# boot_ci = False

surface_alpha = 0.3

fig,axs = plt.subplots(figsize=(8,10),nrows=len(function_dict),gridspec_kw={"hspace":0})

for f,(ax,func) in enumerate(zip(axs,function_dict)):

    for c in range(len(all_cuts_list)):
        boot_mean = np.mean(bootstrap_distributions[c,f], axis=-1)
        
        ax.plot(sampling_sizes, len(sampling_sizes)*[true_values[c,f]], color=colors[c], zorder=zorders[c],linestyle="--")
        ax.plot(sampling_sizes, boot_mean, color=colors[c], zorder=zorders[c])
        
        if boot_std:
            # Band [mean - low, mean + high] from the stored (low, high) bootstrap errors.
            # This previously used an undefined `mean_values` and raised NameError.
            # NOTE(review): assumes confidence_intervals holds offsets from the mean (from_mean=True) — confirm.
            ax.fill_between(x=sampling_sizes,
                            y1=boot_mean-std_lowhigh_values[c,f,:,0],
                            y2=boot_mean+std_lowhigh_values[c,f,:,1],color=colors[c],alpha=surface_alpha)
        elif boot_ci:
            q_low = (100-ci_percentile)/2
            q_high = (100+ci_percentile)/2 # this is the same as 100-q_low

            ax.fill_between(x=sampling_sizes,
                            y1=np.percentile(bootstrap_distributions[c,f],axis=-1, q=q_low),
                            y2=np.percentile(bootstrap_distributions[c,f],axis=-1, q=q_high),
                            color=colors[c], alpha=surface_alpha)

if True: # labels, lims, legend
    
    for ax,func in zip(axs,function_dict):
        ax.set_ylabel(symbol_dict[func]+units_dict[func])
        ax.set_xlim(42,6000)
        ax.set_xscale("log")
    
    axs[-1].set_xlabel("Sample size")
    
    # Off-screen dummy artists so the legend shows generic (black) line/band styles
    axs[-1].plot([-1,-1],[0,0],color="k",linestyle="--",label="True value")
    axs[-1].plot([-1,-1],[0,0],color="k",linestyle="-",label="Bootstrap mean")
    
    if boot_std or boot_ci:
        axs[-1].fill_between([-1,-1],[0,0],[0,0],color="k",alpha=surface_alpha,label=surface_label)
    
    axs[-1].scatter([-1],[0],color="blue",label="Young")
    axs[-1].scatter([-1],[0],color="red",label="Old")

    axs[-1].legend(loc="upper right",fontsize=14)
    
    fig.align_labels()

if True: # filename, save, show
    
    filename = "bias"
    filename += "_" + MF.combine_multiple_cut_dicts_into_str(all_cuts_list,cut_separator="_",order_separator="_")
    
    filename += "_symmetricStd" if bootstrapconfig.symmetric else ""
    
    filename += f"_{bootstrapconfig.repeats}repeats"
    filename += f"_{min(sampling_sizes)}size{max(sampling_sizes)}"
    filename += f"_{len(sampling_sizes)}steps" + ("Log" if logSampling else "")
    
    filename += "_boot"
        
    nested_bootstrap_strings = [
        "Std" if boot_std else "",
        f"{ci_percentile}q" if boot_ci else ""
    ]

    # Comma-join only the enabled tags (same result as the old collapse-",," loop plus strip)
    filename += ",".join(tag for tag in nested_bootstrap_strings if tag)
    
    print(filename)
    
    if save_bool:
        plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")
        print("Saved in",save_path)
    plt.show()

## Standard error (SE) vs sample size N

In [None]:
# Root directory for the standard-error assumption-test outputs (general_path ends with "/")
base_path = f"{general_path}graphs/other_plots/bootstrapping/assumption_test/"

In [None]:
# Estimators whose standard error is studied (insertion order sets the plot rows)
function_dict = dict(
    anisotropy=CV.calculate_anisotropy,
    correlation=CV.calculate_correlation,
    tilt_abs=CV.calculate_tilt,
)

### Save

In [None]:
def save_sampling_arrays(save_path, base_filename, func_name, res, recursive=False, verbose=True):
    """Save the arrays of a multi-sample-size bootstrap result as ``.npy`` files.

    Every file is named ``save_path + base_filename + "_<func_name>_<suffix>"``
    (``np.save`` appends the ``.npy`` extension).

    Parameters
    ----------
    save_path : str
        Directory prefix, expected to end with a path separator.
    base_filename : str
        Common filename stem (e.g. encoding the sample-size grid).
    func_name : str
        Estimator name inserted into every filename.
    res : object
        Result exposing ``confidence_intervals``, ``standard_errors`` and ``biases``, plus either
        the ``nested_*`` arrays (``recursive=True``) or ``bootstrap_distributions``.
    recursive : bool
        Save the nested-bootstrap arrays instead of the raw distributions.
    verbose : bool
        Accepted for API compatibility; currently unused.
    """
    prefix = f"{save_path}{base_filename}_{func_name}_"

    # filename suffix -> attribute of `res`, in the order the files are written
    suffix_to_attr = {
        "CIs": "confidence_intervals",
        "SEs": "standard_errors",
        "biases": "biases",
    }
    if recursive:
        suffix_to_attr.update(
            nestedCIs="nested_confidence_intervals",
            nestedSEs="nested_standard_errors",
            nestedbiases="nested_biases",
        )
    else:
        suffix_to_attr["distributions"] = "bootstrap_distributions"

    for suffix, attr in suffix_to_attr.items():
        np.save(prefix + suffix, getattr(res, attr))

In [None]:
# Resamples drawn at each sample size to build its sampling distribution
sampling_repeats = 10000
# Resamples of the inner (nested) bootstrap run on each of those samples
bootstrap_repeats = 500

In [None]:
# True: also run and save the nested bootstrap errors. False: save only the sampling errors.
save_bootstrap_errors = True

In [None]:
# Log-spaced grid of 100 integer sample sizes between 50 and 5000
sampling_sizes = np.int64(np.round(10 ** np.linspace(np.log10(50), np.log10(5000), 100)))
logSampling = True

# Linear alternative: sampling_sizes = np.arange(50, 5000 + 50, 50); logSampling = False

In [None]:
# Outer (sampling) and inner (nested bootstrap) configurations, both resampling with replacement
sampling_config = BootstrapConfig(repeats=sampling_repeats, symmetric=False, replace=True)
bootstrap_config = BootstrapConfig(repeats=bootstrap_repeats, symmetric=False, replace=True)

bootstrap_vectorised = True  # set to False to use the plain (non-vectorised) bootstrap

batch_size = None

In [None]:
# Selections to process: one spatial region, paired with a young (age [4,7]) and an old
# (age [9.5,10]) population.
# Spatial regions used in earlier runs (each with the same two age cuts):
#   R[0,2]   b[3.5,4.5] l[-2,2]     R[0,3.5] b[3.5,4.5] l[-2,2]
#   R[0,2]   b[6,9]     l[-2,2]     R[0,3.5] b[6,13]    l[-2,2]
#   R[0,2]   b[3,6]     l[-2,2]     R[0,2]   b[1.5,2]   l[-2,2]
#   R[0,2]   b[3,6]     l[3,6]      R[0,3.5] b[3,6]     l[3,6]
#   R[0,3.5] b[1.5,2]   l[-2,2]
all_dicts = [
    {"spatial_cuts": {"R":[0,3.5],"b":[3,6],"l":[-2,2]}, "pop_cuts": {"age":[4,7]}},
    {"spatial_cuts": {"R":[0,3.5],"b":[3,6],"l":[-2,2]}, "pop_cuts": {"age":[9.5,10]}},
]


# When sampling WITHOUT replacement, warn if the largest sample exceeds 5% of the selection
# (the usual threshold beyond which finite-population effects matter).
# The previous condition `not (replace or max > 5%)` was inverted: it warned only when
# there WERE enough stars.
if not sampling_config.replace:
    for dic in all_dicts:
        star_number = len(MF.apply_cuts_to_df(df0, cuts_dict=[dic["spatial_cuts"],dic["pop_cuts"]]))
        if max(sampling_sizes) > 0.05*star_number:
            warnings.warn("Sampling without replacement but the max sampling size is larger than 5%% of the "
                          "population for the cuts %s and %s"%(dic["spatial_cuts"],dic["pop_cuts"]))

for dic in all_dicts:
    df = MF.apply_cuts_to_df(df0, cuts_dict=[dic["spatial_cuts"],dic["pop_cuts"]])

    # Output directory: <spatial cuts>/arrays/sampling<N>repeats/bootstrap<M>repeats/<pop cuts>/
    # Each level is created as the path is extended.
    save_path = get_save_path_spatial_cuts(base_path=base_path, spatial_cuts=dic["spatial_cuts"])
    for subdir in ("arrays", f"sampling{sampling_repeats}repeats", f"bootstrap{bootstrap_repeats}repeats"):
        save_path += subdir + "/"
        MF.create_dir(save_path)
    save_path += MF.extract_str_from_cuts_dict(dic["pop_cuts"]) + "/"
    MF.create_dir(save_path)

    # Filename stem shared by every array of this selection (encodes the sample-size grid)
    base_filename = f"{min(sampling_sizes)}size{max(sampling_sizes)}" + f"_{len(sampling_sizes)}steps" + ("Log" if logSampling else "")

    print("General filename:", base_filename)
    print("Saving in", save_path, "\n")

    for func in function_dict:
        print(func)
        estimator = function_dict[func]

        if save_bootstrap_errors:
            # Nested bootstrap: each sample also gets its own bootstrap error estimate
            res = bootstrap_multiple_sampling_sizes_recursive(df=df, function=estimator, sampling_sizes=sampling_sizes,
                                                              vectorised=bootstrap_vectorised, batch_size=batch_size,
                                                              config=sampling_config, nested_config=bootstrap_config)
        else:
            res = bootstrap_multiple_sampling_sizes(df=df, function=estimator, sampling_sizes=sampling_sizes,
                                                    config=sampling_config, vectorised=bootstrap_vectorised, batch_size=batch_size)

        save_sampling_arrays(save_path=save_path, func_name=func, base_filename=base_filename, res=res, recursive=save_bootstrap_errors)

        print("Arrays saved successfully.\n")

    print("\n")

### Plot

In [None]:
plt.rcParams.update({
    "font.size": 19,
    "legend.fontsize": 15,
    # Longer major / shorter minor ticks on both axes (visible on log-scaled plots)
    "xtick.major.size": 7,
    "xtick.minor.size": 3,
    "ytick.major.size": 7,
    "ytick.minor.size": 3,
})

In [None]:
from scipy.optimize import curve_fit

class Func:
    """Base class for a 1-D model fitted with ``scipy.optimize.curve_fit``.

    Subclasses set ``func`` (the model ``f(x, *params)``), ``name``, ``label`` (LaTeX legend
    text) and ``p0`` (initial parameter guess).
    """
    def fit(self,x,y):
        """Fit the model to (x, y) and store the optimal parameters in ``self.fit_params``."""
        optimal_params, _covariance = curve_fit(f=self.func, p0=self.p0, xdata=x, ydata=y)
        self.fit_params = optimal_params

class ReciprocalFunc(Func):
    """Model f(n) = a/(n+b) + c."""
    def __init__(self, p0=[39.65,100,0.02]):
        def reciprocal(x, a, b, c):
            return a/(x+b) + c
        self.func = reciprocal
        self.name = "reciprocal"
        self.label = r"$f(n)=\frac{a}{n+b}+c$"
        self.p0 = p0

class InverseSquareFunc(Func):
    """Model f(n) = a/sqrt(n), i.e. the 1/sqrt(n) scaling of a standard error."""
    def __init__(self, p0=[10]):
        def inverse_sqrt(x, a):
            return a/(np.sqrt(x))
        self.func = inverse_sqrt
        self.name = "inverse_square"
        # Explicit-parameter alternative: r"$f(n)=\frac{a}{\sqrt{n}}$"
        self.label = r"$f(n)\propto\frac{1}{\sqrt{n}}$"
        self.p0 = p0

def show_fit_params_text(ax, Func, x_eq=0.6,y_eq=0.42, color="grey", ndec_text=2, alpha=1, size=13, units=""):
    """Write the fitted parameter values of ``Func`` on ``ax`` at axes coordinates (x_eq, y_eq).

    One parameter is shown as ``$a=<value>$``; several as ``$(a,b,...)=(<v1>,<v2>,...)$``,
    followed by ``units``. Raises AttributeError if ``Func.fit`` has not been called.
    """
    if not hasattr(Func,"fit_params"):
        raise AttributeError("Fit the function first!")

    n_params = len(Func.p0)
    names = ",".join(["abcdefg"[i] for i in range(n_params)])
    placeholders = ",".join(["%s"] * n_params)

    if n_params > 1:
        template = r"$(" + names + ")=(" + placeholders + ")$"
    else:
        template = r"$" + names + "=" + placeholders + "$"
    template += units

    values = tuple(MF.return_int_or_dec(param, ndec_text) for param in Func.fit_params)
    ax.text(x=x_eq, y=y_eq, transform=ax.transAxes, color=color, alpha=alpha, size=size, s=template % values)

#### In bulk

In [None]:
def plot_error_vs_N(all_dicts, xlog_bool, ylog_bool, save_bool=True, show_bool=False, fit_bool=True, same_youngold_fits=True, logSampling=True):
    """Plot the standard error of each estimator in ``function_dict`` against sample size.

    Each row (one per entry of the global ``function_dict``) has a main panel with the error
    curves vs the global ``sampling_sizes`` (one curve per dict in ``all_dicts``). Two narrow
    "broken-axis" panels show the bootstrap error of the full selection at ``dic["total_N"]``
    for the first two dicts.

    Parameters
    ----------
    all_dicts : list of dict
        One dict per population. Keys read: "label", "color", "total_N", "spatial_cuts",
        "pop_cuts", "<func>_errors" and "<func>_bootstrap_error". Log-log plots also read
        "<func>_MC_d_0.2_error_low"/"_high". The separate fits assume at least two dicts.
    xlog_bool, ylog_bool : bool
        Logarithmic x / y axes. Monte-Carlo errors are drawn only when both are True.
    save_bool : bool
        Save the figure as PNG in the spatial-cuts directory under the global ``base_path``.
    show_bool : bool
        Show the figure; otherwise it is closed.
    fit_bool : bool
        Overlay a 1/sqrt(n) fit of the error curves.
    same_youngold_fits : bool
        If True (and the quantity is not "tilt_abs"), fit only the first population and draw
        a single grey fit. Otherwise fit the first two populations separately.
    logSampling : bool
        Whether ``sampling_sizes`` is log-spaced (only used in the filename).
    """
    max_sampling_size = max(sampling_sizes)
    major_locator = 1000  # major x-tick spacing on linear axes
    minor_locator = 250   # minor x-tick spacing on linear axes
    
    # fit_func = ReciprocalFunc(p0=[1,10,10])
    fit_func = InverseSquareFunc(p0=[1])

    # Axes-fraction position of the fitted-parameter text
    if xlog_bool and ylog_bool:
        x_eq,y_eq = 0.75, 0.15
    elif xlog_bool:
        x_eq,y_eq = 0.65,0.4
    else:
        x_eq,y_eq = 0.4, 0.4
        
    # Fixed y-limits per quantity; a log axis cannot start at 0
    hard_coded_ylims_dict = {
        "anisotropy": [0 if not ylog_bool else 0.001,0.3],
        "correlation": [0 if not ylog_bool else 0.001,0.149],
        "tilt_abs": [0 if not ylog_bool else 0.1,33]
    }

    if xlog_bool and ylog_bool:
        for k in hard_coded_ylims_dict:
            hard_coded_ylims_dict[k][1] *= 1.5
            
    # Monte-Carlo distance-error markers are only drawn on log-log plots
    MC_error_bool = xlog_bool and ylog_bool
    
    width_broken_axes = 0.13
    nrows = len(function_dict)

    # Columns: main panel + two narrow panels for the full-selection bootstrap errors
    fig,axs = plt.subplots(figsize=(8,10),ncols=3,nrows=nrows,gridspec_kw={"width_ratios":[1]+2*[width_broken_axes],"wspace":0.1,"hspace":0})

    for row,func in enumerate(function_dict):
        lax,cax,rax = axs[row]

        if True: # broken axes

            d = 0.02  # half-size of the diagonal break marks (axes fraction)
            d_factor = 1/width_broken_axes  # narrow panels need proportionally longer marks in axes units

            for ax in [lax,cax]:
                ax.spines['right'].set_visible(False)
                ax.tick_params(which='both',right=False)

            lax.plot((1-d,1+d), (-d,d), transform=lax.transAxes, color='k', clip_on=False,lw="1")
            lax.plot((1-d,1+d),(1-d,1+d), transform=lax.transAxes, color='k', clip_on=False,lw="1")

            cax.plot((1-d_factor*d,1+d_factor*d), (-d,d), transform=cax.transAxes, color='k', clip_on=False,lw="1")
            cax.plot((1-d_factor*d,1+d_factor*d), (1-d,1+d), transform=cax.transAxes, color='k', clip_on=False,lw="1")

            for ax in [cax,rax]:
                ax.spines['left'].set_visible(False)
                ax.tick_params(which='both',left=False)

                ax.plot((-d_factor*d,+d_factor*d), (1-d,1+d), transform=ax.transAxes, color='k', clip_on=False,lw="1")
                ax.plot((-d_factor*d,+d_factor*d), (-d,d), transform=ax.transAxes, color='k', clip_on=False,lw="1")

        if True: # plot
            for dic in all_dicts:
                lax.plot(sampling_sizes[sampling_sizes <= max_sampling_size],np.array(dic[func+"_errors"])[sampling_sizes <= max_sampling_size],
                         label=dic["label"],color=dic["color"],alpha=0.7)

            for dic,ax in zip(all_dicts,[cax,rax]):
                ax.scatter(dic["total_N"],dic[func+"_bootstrap_error"],marker="*",color=dic["color"])

                if MC_error_bool:
                    ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_low"],marker="v",color=dic["color"],s=23)
                    ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_high"],marker="^",color=dic["color"],s=23)

                ax.set_xticks([dic["total_N"]])

            lax.scatter(x=-100,y=0,marker="*",color="k",label="Bootstrap error") # just for the legend label

            if MC_error_bool:
                lax.scatter(x=-100,y=0,marker="v",color="k",label="MC 20% distance error",s=23) # just for the legend label

        if True: # ticks, lims, logscale

            if xlog_bool:
                lax.set_xscale("log")

            lax_leftlim = 40 if xlog_bool else 0
            lax.set_xlim(lax_leftlim,max_sampling_size+minor_locator*0.9)

            if not xlog_bool:
                lax.xaxis.set_major_locator(ticker.MultipleLocator(major_locator))
                lax.xaxis.set_minor_locator(ticker.MultipleLocator(minor_locator))

            for ax in [lax,cax,rax]:
                ax.tick_params(axis='x', which='major', pad=10)

                if ylog_bool:
                    ax.set_yscale("log")

                if ax in [cax,rax]:
                    ax.yaxis.set_ticklabels([])

                if func in hard_coded_ylims_dict:
                    ax.set_ylim(hard_coded_ylims_dict[func])
                else:
                    ax.set_ylim(bottom=0.005 if ylog_bool else 0)

                # Only the bottom row keeps x tick labels (rows share the x axis visually)
                if row != nrows - 1:
                    ax.xaxis.set_ticklabels([])

        if fit_bool:

            x_plot = np.linspace(min(sampling_sizes),max_sampling_size,500)

            if func != "tilt_abs" and same_youngold_fits:
                fit_func.fit(x = sampling_sizes[sampling_sizes<=max_sampling_size], y = np.array(all_dicts[0][func+"_errors"])[sampling_sizes<=max_sampling_size])
                lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color="grey", linestyle="--")
                show_fit_params_text(ax=lax,Func=fit_func,color="grey",x_eq=x_eq,y_eq=y_eq)
            else:
                for (pop_idx,color,y) in zip([0,1],["blue","red"],[y_eq-0.04,y_eq+0.04]):
                    fit_func.fit(x = sampling_sizes[sampling_sizes<=max_sampling_size], y = np.array(all_dicts[pop_idx][func+"_errors"])[sampling_sizes<=max_sampling_size])

                    lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color=color, linestyle="--",alpha=0.5)
                    show_fit_params_text(ax=lax,Func=fit_func,alpha=0.7,color=color,x_eq=x_eq,y_eq=y,units=mapf.get_units(func))

            lax.plot([-100,-100],[0,1],label=fit_func.label,linestyle="--",color="grey") # just for the legend label

        if True: # labels, legend, text
            if row == nrows-1:
                lax.text(s="Sample size",x=0.5,y=-0.25,transform=lax.transAxes,size="medium")

            if row == 0:
                fig.legend(loc=(0.65,0.77) if not (xlog_bool and ylog_bool) else (0.16,0.7),framealpha=0 if xlog_bool and ylog_bool else 1)
            elif row == 1:
                lax.set_ylabel(r"Standard error")

            if xlog_bool and ylog_bool:
                x_func_text = 0.98
            elif xlog_bool:
                x_func_text = 0.3
            else:
                x_func_text = 0.2
            fig.text(s=mapf.get_kinematic_titles_dict("r","l")[func.removesuffix("_abs")],x=x_func_text,y=0.83,transform=lax.transAxes,size=15)

            fig.align_labels()

    if True: # filename, save and show
        # The helper's parameters are (base_path, spatial_cuts). The previous keywords
        # `save_path=` / `spatial_cuts_dict=` raised a TypeError.
        save_path = get_save_path_spatial_cuts(base_path=base_path, spatial_cuts=all_dicts[0]["spatial_cuts"])

        if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
            filename = "anicorr"
        else:
            raise ValueError("Please specify map list name")
            
        # All dicts are expected to share the same spatial cuts, so take them from the first one
        # (previously this used whichever `dic` was left over from the plotting loop).
        filename += "_"+MF.extract_str_from_cuts_dict(all_dicts[0]["spatial_cuts"])

        for dic in all_dicts:
            filename += '_'+MF.extract_str_from_cuts_dict(dic["pop_cuts"])

        if xlog_bool and ylog_bool:
            filename += "_xylog"
        elif xlog_bool:
            filename += "_xlog"
        elif ylog_bool:
            filename += "_ylog"

        if MC_error_bool:
            filename += "_MC"

        # NOTE(review): `repeats` is the notebook-global value, not read from the loaded dicts — confirm it matches the data
        filename += f"_{repeats}repeats"
        filename += f"_{min(sampling_sizes)}size{max(sampling_sizes)}"
        filename += f"_{len(sampling_sizes)}steps" + ("Log" if logSampling else "")

        if not fit_bool:
            filename += "_noFit"
        elif not same_youngold_fits:
            filename += "_diffFits"

        print(filename)

        if save_bool:
            print("Saving in",save_path)
            plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")

        if show_bool:
            plt.show()
        else:
            plt.close()

In [None]:
# True: load results computed on the log-spaced sample-size grid. False: the linear grid.
logSampling = True

In [None]:
if not logSampling:
    # Linear sample-size grid: [young, old] results for each radial cut of the b in [3,6] region
    all_dicts_list = [
        [
            MF.load_dic_from_json(filename=base_path + f"3b6/{radial_cut}/-2l2/anicorr_{age_cut}_5000maxsize_5000repeats")
            for age_cut in ("4age7", "9.5age10")
        ]
        for radial_cut in ("0R3.5", "0R2")
    ]

    sampling_sizes = np.arange(50,5000+50,50)

    # Every population in a group must share the same spatial cuts
    for all_dicts in all_dicts_list:
        spatial_cuts = all_dicts[0]["spatial_cuts"]
        for dic in all_dicts:
            assert dic["spatial_cuts"] == spatial_cuts, "The spatial cuts were not the same across the dicts!"

In [None]:
if logSampling:
    # Log-spaced sample-size grid (500 repeats): [young, old] results per latitude/radial cut.
    # Earlier b in [3.5,4.5] runs (0R3.5 and 0R2) used 5000 repeats:
    #   <b>/<R>/-2l2/anicorr_<age>_5000repeats_50size5000_100stepsLog.json
    all_dicts_list = [
        [
            MF.load_dic_from_json(filename=base_path + f"{lat_cut}/{radial_cut}/-2l2/anicorr_{age_cut}_500repeats_50size5000_100stepsLog.json")
            for age_cut in ("4age7", "9.5age10")
        ]
        for lat_cut in ("1.5b2", "3b6")
        for radial_cut in ("0R3.5", "0R2")
    ]

    sampling_sizes = np.int64(np.round(10**np.linspace(np.log10(50),np.log10(5000),100)))

    # Every population in a group must share the same spatial cuts
    for all_dicts in all_dicts_list:
        spatial_cuts = all_dicts[0]["spatial_cuts"]
        for dic in all_dicts:
            assert dic["spatial_cuts"] == spatial_cuts, "The spatial cuts were not the same across the dicts!"

In [None]:
# (xlog, ylog) combinations to render; the linear grid also gets a fully linear plot
if logSampling:
    axis_scale_options = [(True, False), (True, True)]
else:
    axis_scale_options = [(False, False), (True, False), (True, True)]

for all_dicts in all_dicts_list:
    # First with a single shared young/old fit, then with separate fits
    for shared_fit in (True, False):
        for x_log_axis, y_log_axis in axis_scale_options:
            plot_error_vs_N(all_dicts, x_log_axis, y_log_axis, save_bool=True, show_bool=False,
                            logSampling=logSampling, same_youngold_fits=shared_fit)

#### Individually

Load the saved NumPy arrays

In [None]:
def fill_all_dicts_with_arrays(all_dicts, base_path, function_dict, sample_sizes_str, array_name_dict=None, calculate_correlation=False, verbose=True, df=None):
    """Load saved per-sample-size arrays from ``.npy`` files into each configuration dict, in place.

    For every dict in ``all_dicts`` and every function name in ``function_dict``, each entry
    ``array_name -> alias`` of ``array_name_dict`` loads

    * ``<load_path>_<func>_<array_name>.npy``        into ``dic[func + "_" + alias]``
    * ``<load_path>_<func>_nested<array_name>.npy``  into ``dic[func + "_bootstrap_" + alias]``

    where ``load_path`` is
    ``<spatial path>arrays/sampling<S>repeats/bootstrap<B>repeats/<pop string>/<sample_sizes_str>``.

    Parameters
    ----------
    all_dicts : list of dict
        Each dict must contain "spatial_cuts", "pop_cuts", "sampling_repeats" and
        "bootstrap_repeats". The dicts are modified in place.
    base_path : str
        Root folder passed to ``get_save_path_spatial_cuts``.
    function_dict : iterable of str
        Function names (dict keys are iterated), e.g. "anisotropy", "correlation".
    sample_sizes_str : str
        Filename prefix describing the sample-size grid, e.g. "50size5000_100stepsLog".
    array_name_dict : dict, optional
        Maps on-disk array names to dictionary aliases. Defaults to ``{"SEs": "errors"}``.
    calculate_correlation : bool
        If True, for the "correlation" function also store the correlation of the cut
        sample in ``dic["correlation"]``.
    verbose : bool
        Print progress messages.
    df : pandas.DataFrame, optional
        Data used for the correlation; defaults to the notebook-level ``df0``.
    """
    if array_name_dict is None:
        array_name_dict = {"SEs": "errors"}

    if verbose and calculate_correlation:
        print("Calculating correlation")

    for dic in all_dicts:
        spatial_path = get_save_path_spatial_cuts(base_path=base_path, spatial_cuts=dic["spatial_cuts"])

        # The load path depends only on the dict (not on the function), so build it once per dict.
        sampling_repeats_str = "sampling" + str(dic["sampling_repeats"]) + "repeats"
        bootstrap_repeats_str = "bootstrap" + str(dic["bootstrap_repeats"]) + "repeats"
        pop_str = MF.extract_str_from_cuts_dict(dic["pop_cuts"])
        load_path = f"{spatial_path}arrays/{sampling_repeats_str}/{bootstrap_repeats_str}/{pop_str}/{sample_sizes_str}"

        for func in function_dict:
            for array_name, alias in array_name_dict.items():
                dic[func + "_" + alias] = np.load(load_path + f"_{func}_" + array_name + ".npy")
                # The nested-bootstrap arrays are stored with a "nested" prefix on the array name
                dic[func + "_bootstrap_" + alias] = np.load(load_path + f"_{func}_nested" + array_name + ".npy")

            if func == "correlation" and calculate_correlation:
                source_df = df0 if df is None else df
                df_cut = MF.apply_cuts_to_df(df=source_df, cuts_dict=[dic["spatial_cuts"], dic["pop_cuts"]])
                dic["correlation"] = CV.calculate_correlation(vx=df_cut.vr.values, vy=df_cut.vl.values)

    if verbose:
        print("Arrays filled successfully")

In [None]:
# assumption test on standard error
# Young (age 4-7) and old (age 9.5-10) populations in the region 3<b<6, 0<R<3.5, -2<l<2,
# both run with 5000 sampling repeats and 500 bootstrap repeats.

all_dicts = [
    {
        "spatial_cuts": {"b":[3,6],"R":[0,3.5],"l":[-2,2]},
        "pop_cuts": {"age": age_range},
        "sampling_repeats": 5000,
        "bootstrap_repeats": 500
    }
    for age_range in ([4,7], [9.5,10])
]

sample_sizes_str = "50size5000_100stepsLog"

# on-disk array name -> dictionary key alias
array_name_dict = {
    "SEs": "errors"
}

fill_all_dicts_with_arrays(all_dicts=all_dicts, base_path=base_path, function_dict=function_dict,
                           sample_sizes_str=sample_sizes_str, array_name_dict=array_name_dict,
                           calculate_correlation=True)

# for filename (built from the repeat counts shared by both dicts)
repeats_str = f"_s{all_dicts[0]['sampling_repeats']}b{all_dicts[0]['bootstrap_repeats']}rep"
print("Repeats string is",repeats_str)

ylabel = "Standard error"
print("ylabel is",ylabel)

leg_label = "SE"
quantity = "errors"

colors = ["blue", "red"]
labels = ["Young SE", "Old SE"]

# Plot styling; no per-dict labels for the bootstrap mean/median curves
for dic, color, label in zip(all_dicts, colors, labels):
    dic.update({
        "color": color,
        "label": label,
        "linestyle": "-",
        "alpha": 0.7,
        "mean_boot_label": None,
        "median_boot_label": None,
    })

In [None]:
# assumption test on bias
# Young (age 4-7) and old (age 9.5-10) populations in the region 3<b<6, 0<R<3.5, -2<l<2,
# both run with 5000 sampling repeats and 500 bootstrap repeats; loads the bias arrays.

all_dicts = [
    {
        "spatial_cuts": {"b":[3,6],"R":[0,3.5],"l":[-2,2]},
        "pop_cuts": {"age": age_range},
        "sampling_repeats": 5000,
        "bootstrap_repeats": 500
    }
    for age_range in ([4,7], [9.5,10])
]

sample_sizes_str = "50size5000_100stepsLog"

# on-disk array name -> dictionary key alias
array_name_dict = {
    "biases": "biases"
}

# The plotting cells branch on `quantity` (hard-coded y-limits, fit_bool, min_y) before the
# plot loop re-derives it from array_name_dict, so it must agree with the arrays loaded here.
quantity = "biases"

fill_all_dicts_with_arrays(all_dicts=all_dicts,base_path=base_path,function_dict=function_dict,sample_sizes_str=sample_sizes_str,array_name_dict=array_name_dict)

# for filename
repeats_str = f"_s{5000}b{500}rep"
print("Repeats string is",repeats_str)

colors = ["blue", "red"]
labels = ["Young", "Old"]

ylabel = "Bias"
print("ylabel is",ylabel)

leg_label = "biases"
print("leg_label is",leg_label)

# Plot styling; no per-dict labels for the bootstrap mean/median curves
for dic, color, label in zip(all_dicts, colors, labels):
    dic["color"] = color
    dic["label"] = label

    dic["linestyle"] = "-"
    dic["alpha"] = 0.7

    dic["mean_boot_label"] = None
    dic["median_boot_label"] = None

In [None]:
# assess robustness of errors for different B
# Young (age 4-7) and old (age 9.5-10) populations in 3<b<6, 0<R<3.5, -2<l<2, comparing
# S,B = 5000,500 (arrays loaded from .npy below) with S,B = 1000,1000 (loaded from JSON).

# This cell plots standard errors: set the globals the plotting cells read accordingly.
array_name_dict = {
    "SEs": "errors"
}
quantity = "errors"
leg_label = "SE"

all_dicts = [
    {
        "spatial_cuts": {"b":[3,6],"R":[0,3.5],"l":[-2,2]},
        "pop_cuts": {"age":[4,7]},
        "sampling_repeats": 5000,
        "bootstrap_repeats": 500
    },
    MF.load_dic_from_json(filename=base_path+"3b6/0R3.5/-2l2/arrays/sampling1000repeats/bootstrap1000repeats/"+
                          "anicorr_4age7_1000repeats_50size5000_100stepsLog_sampleBootErrors.json"),
    {
        "spatial_cuts": {"b":[3,6],"R":[0,3.5],"l":[-2,2]},
        "pop_cuts": {"age":[9.5,10]},
        "sampling_repeats": 5000,
        "bootstrap_repeats": 500
    },
    MF.load_dic_from_json(filename=base_path+"3b6/0R3.5/-2l2/arrays/sampling1000repeats/bootstrap1000repeats/"+
                          "anicorr_9.5age10_1000repeats_50size5000_100stepsLog_sampleBootErrors.json"),
]

sample_sizes_str = "50size5000_100stepsLog"
# Only the .npy-backed dicts (indices 0 and 2) are filled here; the JSON dicts presumably
# already carry their arrays — confirm against the saved files.
fill_all_dicts_with_arrays(all_dicts=[all_dicts[0],all_dicts[2]],base_path=base_path,function_dict=function_dict,
                           sample_sizes_str=sample_sizes_str,array_name_dict=array_name_dict,calculate_correlation=True)

# for filename
repeats_str = "_s5000b500rep,s1000b1000rep"
print("Repeats string is",repeats_str)

colors = ["blue", "cyan", "red", "orange"]
labels = [r"Young $S,B=5000,500$", r"Young $S,B=1000,1000$", r"Old $S,B=5000,500$", r"Old $S,B=1000,1000$"]

# mean_boot_labels = labels
mean_boot_labels = 4*[None]

median_boot_labels = labels
# median_boot_labels = 4*[None]

# ylabel = "Standard error"
# ylabel = "Mean bootstrap SE"
ylabel = "Median bootstrap SE"
print("ylabel is",ylabel)

for i,dic in enumerate(all_dicts):
    dic["color"] = colors[i]
    dic["label"] = labels[i]
    dic["mean_boot_label"] = mean_boot_labels[i]
    dic["median_boot_label"] = median_boot_labels[i]

    dic["linestyle"] = "-"
    dic["alpha"] = 0.7

Load from JSON files

In [None]:
repeats = 1000  # repeat count embedded as "..._{repeats}repeats_..." in the JSON filenames below

In [None]:
# with vs without resampling
# Solid lines: runs under base_path (sampled with replacement); dashed lines: the same cuts
# under the sibling "without_replacement/" folder. Only the solid lines get legend labels.
# NOTE(review): assumes base_path contains "with_replacement"; otherwise split() returns the
# whole path and "without_replacement/" is simply appended to it.

base_path_no_resampling = base_path.split("with_replacement")[0] + "without_replacement/"

all_dicts = [
    MF.load_dic_from_json(filename=base_path+f"3b6/0R2/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog_sampleBootErrors.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R2/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog_sampleBootErrors.json"),

    MF.load_dic_from_json(filename=base_path_no_resampling+f"3b6/0R2/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path_no_resampling+f"3b6/0R2/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
]

colors = 2*["blue", "red"]
labels = ["Young", "Old"] + 2*[None]
linestyles = 2*["-"] + 2*["--"]

for dic, color, label, linestyle in zip(all_dicts, colors, labels, linestyles):
    dic["color"], dic["label"], dic["linestyle"], dic["alpha"] = color, label, linestyle, 0.7

In [None]:
# sampleBootErrors
# Young vs old population in 3<b<6, 0<R<3.5, -2<l<2, loaded from the "_sampleBootErrors" JSON files.
# (Alternative region: replace "3b6/0R3.5" with "1.5b2/0R2" in both paths.)

all_dicts = [
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog_sampleBootErrors.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog_sampleBootErrors.json"),
]

colors = ["blue", "red"]
labels = ["Young SE", "Old SE"]

for dic, color, label in zip(all_dicts, colors, labels):
    dic.update(color=color, label=label, linestyle="-", alpha=0.7)

In [None]:
# latitude, radius
# Colour encodes population and latitude strip (blue/red: 1.5<b<2, cyan/orange: 3<b<6);
# linestyle encodes the radial cut (solid: R<3.5, dash-dot: R<2, the 5000-repeat runs).
# Only the R<3.5 curve of each strip carries a legend label.

all_dicts = [
    MF.load_dic_from_json(filename=base_path+f"1.5b2/0R3.5/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+"1.5b2/0R2/-2l2/anicorr_4age7_5000repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+"3b6/0R2/-2l2/anicorr_4age7_5000repeats_50size5000_100stepsLog.json"),

    MF.load_dic_from_json(filename=base_path+f"1.5b2/0R3.5/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+"1.5b2/0R2/-2l2/anicorr_9.5age10_5000repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+"3b6/0R2/-2l2/anicorr_9.5age10_5000repeats_50size5000_100stepsLog.json"),
]

colors = ["blue","blue","cyan","cyan"] + ["red","red","orange","orange"]
labels = [r"Young $1.5^\circ<|b|<2^\circ$", None, r"Young $3^\circ<|b|<6^\circ$", None] +\
         [r"Old $1.5^\circ<|b|<2^\circ$", None, r"Old $3^\circ<|b|<6^\circ$", None]
linestyles = 4*["-","-."]

for dic, color, label, linestyle in zip(all_dicts, colors, labels, linestyles):
    dic.update(color=color, label=label, linestyle=linestyle, alpha=0.7)

In [None]:
# latitude
# Three latitude strips (1.5<b<2, 3<b<6, 6<b<13) at 0<R<3.5, -2<l<2 for young (blue shades)
# and old (red/violet/orange) populations.

all_dicts = [
    MF.load_dic_from_json(filename=base_path+f"1.5b2/0R3.5/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"6b13/0R3.5/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),

    MF.load_dic_from_json(filename=base_path+f"1.5b2/0R3.5/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"6b13/0R3.5/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
]

colors = ["blue","dodgerblue","cyan"] + ["red", "darkviolet", "orange"]
labels = [r"Young $1.5^\circ<|b|<2^\circ$", r"Young $3^\circ<|b|<6^\circ$", r"Young $6^\circ<|b|<13^\circ$"] + \
         [r"Old $1.5^\circ<|b|<2^\circ$", r"Old $3^\circ<|b|<6^\circ$", r"Old $6^\circ<|b|<13^\circ$"]

for dic, color, label in zip(all_dicts, colors, labels):
    dic.update(color=color, label=label, linestyle="-", alpha=0.7)

In [None]:
# longitude, radius
# Colour encodes population and longitude window (blue/red: -2<l<2, cyan/orange: 3<l<6);
# linestyle encodes the radial cut (solid: R<3.5, dash-dot: R<2). All at 3<b<6.
# Only the R<3.5 curve of each window carries a legend label.

all_dicts = [
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R2/-2l2/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/3l6/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R2/3l6/anicorr_4age7_{repeats}repeats_50size5000_100stepsLog.json"),

    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R2/-2l2/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R3.5/3l6/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+f"3b6/0R2/3l6/anicorr_9.5age10_{repeats}repeats_50size5000_100stepsLog.json"),
]

colors = ["blue", "blue", "cyan", "cyan"] + ["red","red","orange","orange"]
labels = [r"Young $|l|<2^\circ$", None, r"Young $3^\circ<l<6^\circ$", None] + [r"Old $|l|<2^\circ$", None, r"Old $3^\circ<l<6^\circ$", None]
linestyles = 4*["-","-."]

for dic, color, label, linestyle in zip(all_dicts, colors, labels, linestyles):
    dic.update(color=color, label=label, linestyle=linestyle, alpha=0.7)

Plot

In [None]:
# True: 100 log-spaced sample sizes from 50 to 5000 (the "50size5000_100stepsLog" grid)
# False: linear grid 50, 100, ..., 5000
logSampling = True

sampling_sizes = (
    np.int64(np.round(10**np.linspace(np.log10(50), np.log10(5000), 100)))
    if logSampling
    else np.arange(50, 5000 + 50, 50)
)

In [None]:
# Largest sample size shown on the x-axis
max_sampling_size = max(sampling_sizes)

# Tick spacing for linear x-axes. The plotting cell uses these when xlog_bool is False
# (x-limit padding and MultipleLocator ticks), so they must always be defined.
xtick_step = 500
major_locator = 1000
minor_locator = 250

In [None]:
ylog_bool = False  # True: logarithmic y-axis
xlog_bool = True   # True: logarithmic x-axis

# Warn when the spacing of the sample-size grid and the x-axis scaling disagree
if bool(logSampling) != bool(xlog_bool):
    print("WARNING: Using logSampling with non-log x-axis" if logSampling
          else "WARNING: Using non-log sampling with log x-axis")

In [None]:
# Per function: True -> fit only the first dict's errors and draw one grey curve;
# False -> fit every dict separately and draw each fit in that dict's colour.
same_fits_all_dicts = dict.fromkeys(["anisotropy", "correlation", "tilt_abs"], False)

# Alternative model: fit_func = ReciprocalFunc(p0=[1,10,10])
fit_func = InverseSquareFunc(p0=[1])

# True: print the fitted parameters on the plot at (x_eq, y_eq) in axes coordinates
fit_params_text_bool = False

if fit_params_text_bool:
    x_eq, y_eq = (0.75, 0.15) if (xlog_bool and ylog_bool) else (0.65, 0.4) if xlog_bool else (0.4, 0.4)

In [None]:
# True: overlay the analytic Pearson-correlation SE on the correlation panel instead of a fit
show_true_pearson_standard_error = True

def pearson_standard_error(n_plot, correlation):
    """Approximate standard error of a Pearson correlation coefficient.

    Returns (1 - rho^2) / sqrt(n - 3)  (https://doi.org/10.1525/collabra.87615).
    Other conventions in the literature use sqrt(n - 2)
    (https://en.youscribe.com/BookReader/Index/520541/?documentId=491664) or sqrt(n)
    (https://www.tandfonline.com/doi/abs/10.1080/01621459.1928.10502991).

    n_plot may be a scalar or a NumPy array of sample sizes; correlation is rho.
    """
    shrinkage = 1 - correlation**2
    return shrinkage / np.sqrt(n_plot - 3)

pearson_se_label = r"$\frac{1-\rho^2}{\sqrt{n-3}}$"

In [None]:
# Optional fixed y-limits keyed by function name. Functions not listed here keep their
# automatic upper limit and get the min_y lower bound in the plotting cell.
if quantity == "errors":
    hard_coded_ylims_dict = {
    #     "anisotropy": [0 if not ylog_bool else 0.001, 0.3],
    #     "correlation": [0 if not ylog_bool else 0.001, 0.149],
    #     "tilt_abs": [0 if not ylog_bool else 0.1, 33]
    }
elif quantity == "biases":
    hard_coded_ylims_dict = {
        "correlation": [-0.019,0.019],
    }
else:
    # Any other quantity gets no hard-coded limits, so the dict is never left undefined
    # or stale from an earlier run of this cell.
    hard_coded_ylims_dict = {}

# On log-log axes enlarge each hard-coded upper limit by 50%
if xlog_bool and ylog_bool:
    for k in hard_coded_ylims_dict:
        hard_coded_ylims_dict[k][1] *= 1.5

In [None]:
# True: keep two narrow side panels per row (after a broken x-axis) that show, for the first
# two dicts, dic[func + "_bootstrap_error"] at dic["total_N"]
brokenaxes = False

# Monte Carlo distance-error markers on the side panels
# (alternative: MC_error_bool = brokenaxes and xlog_bool and ylog_bool)
MC_error_bool = False

bootstrap_error_bool = brokenaxes

figsize = (8, 10) if brokenaxes else (10, 10)

if brokenaxes:
    # Shorter ticks for the narrow panels; note these rcParams are not reset when
    # brokenaxes is switched back off.
    plt.rcParams.update({
        "xtick.major.size": 6,
        "xtick.minor.size": 3,
        "ytick.major.size": 6,
        "ytick.minor.size": 3,
    })

In [None]:
# Curve fits are drawn only for standard errors; any other quantity (e.g. biases) disables them.
fit_requested = True  # set to False to switch fits off entirely
fit_bool = fit_requested and quantity == "errors"

In [None]:
show_actual_standard_errors = True   # False hides the actual (sampling) curves and their fits

show_sample_bootstrap_errors = True  # False hides every nested-bootstrap overlay

if show_sample_bootstrap_errors:

    # Mean of the nested bootstrap values (dashed line)
    show_bootstrap_mean = True
    bootstrap_mean_label = f"Mean of bootstrap {leg_label}"

    # Median of the nested bootstrap values (dash-dot line).
    # To enable, set True and define bootstrap_median_label = f"Median of bootstrap {leg_label}"
    show_bootstrap_median = False

    # Spread around the mean via error_helpers.build_confidence_interval (shaded band).
    # To enable, set True and define bootstrap_std_label = f"STD of bootstrap {leg_label}"
    show_nested_bootstrap_std = False
    nested_bootstrap_std_symmetric = False

    # Shaded central-percentile band of the nested bootstrap values.
    # Alternatives: 95 / "95q", 99.7 / "99.7q", 100 / "100q" (label "Full range of bootstrap ...")
    show_percentile_sample_bootstrap_errors_surface = True
    percentile = 68
    percentile_sample_bootstrap_error_surface_label = f"68% interval of bootstrap {leg_label}"
    filename_suffix_percentile_bootstrap = "68q"

In [None]:
# Lower y-limit per row (one entry per function) for rows without hard-coded limits;
# None keeps matplotlib's current lower limit.
if quantity == "errors":
    min_y = [0.001, 0.001, 0]
else:
    min_y = [None, None, None]

In [None]:
save_bool = True  # False: only display the figure below, without writing .png/.pdf files

In [None]:
# quantity vs N
# One row per function in function_dict: plots the chosen quantity (actual SEs or biases and/or
# nested-bootstrap summaries) against sample size for every dict in all_dicts, optionally with
# curve fits, then builds a descriptive filename and saves/shows the figure.
# Reads many notebook-level settings from the cells above (brokenaxes, xlog_bool, ylog_bool,
# fit_bool, show_* flags, hard_coded_ylims_dict, min_y, ylabel, repeats_str, array_name_dict, ...).

width_broken_axes = 0.13
nrows = len(function_dict)

# Three columns per row: the main panel (lax) plus two narrow panels (cax, rax) that are only
# kept when brokenaxes is True; otherwise they are deleted below.
fig,axs = plt.subplots(figsize=figsize,ncols=3,nrows=nrows,gridspec_kw={"width_ratios":[1]+2*[width_broken_axes],"wspace":0.1,"hspace":0})

for row,func in enumerate(function_dict):
    lax,cax,rax = axs[row]

    if brokenaxes: # broken axes

        # d: half-size (axes fraction) of the diagonal break marks; d_factor rescales it for the
        # narrow panels so the marks have the same on-screen size
        d = 0.02
        d_factor = 1/width_broken_axes
        
        for ax in [lax,cax]:
            ax.spines['right'].set_visible(False)
            ax.tick_params(which='both',right=False)
        
        lax.plot((1-d,1+d), (-d,d), transform=lax.transAxes, color='k', clip_on=False,lw="1")
        lax.plot((1-d,1+d),(1-d,1+d), transform=lax.transAxes, color='k', clip_on=False,lw="1")
        
        cax.plot((1-d_factor*d,1+d_factor*d), (-d,d), transform=cax.transAxes, color='k', clip_on=False,lw="1")
        cax.plot((1-d_factor*d,1+d_factor*d), (1-d,1+d), transform=cax.transAxes, color='k', clip_on=False,lw="1")
        
        for ax in [cax,rax]:
            ax.spines['left'].set_visible(False)
            ax.tick_params(which='both',left=False)
            
            ax.plot((-d_factor*d,+d_factor*d), (1-d,1+d), transform=ax.transAxes, color='k', clip_on=False,lw="1")
            ax.plot((-d_factor*d,+d_factor*d), (-d,d), transform=ax.transAxes, color='k', clip_on=False,lw="1")
    else:
        fig.delaxes(cax)
        fig.delaxes(rax)
    
    if True: # plot
        # NOTE: the loop variable `d` reuses the name of the break-mark size above; harmless,
        # because `d = 0.02` is re-assigned at the top of every row before it is used.
        for d,dic in enumerate(all_dicts):
            
            # The plotted quantity is the first value of the notebook-level array_name_dict
            # (e.g. "errors" or "biases"); this overrides any `quantity` set in earlier cells.
            quantity = list(array_name_dict.values())[0]
            
            if quantity == "biases":
                lax.axhline(y=0,color="grey",linestyle="dotted")
            
            if show_actual_standard_errors:
                lax.plot(sampling_sizes[sampling_sizes <= max_sampling_size],np.array(dic[func+"_"+quantity])[sampling_sizes <= max_sampling_size],\
                         label=dic["label"] if row==0 else None,color=dic["color"],alpha=dic["alpha"],linestyle=dic["linestyle"])
            
            if show_sample_bootstrap_errors:

                # dic[func+"_bootstrap_"+quantity] is reduced along axis=1, i.e. one row per sample size
                mean_bootstrap_errors = np.mean(dic[func+"_bootstrap_"+quantity],axis=1)
                
                if show_nested_bootstrap_std:
                    std_low,std_high = np.zeros(shape=(len(sampling_sizes))),np.zeros(shape=(len(sampling_sizes)))

                    for i,mean in enumerate(mean_bootstrap_errors):
                        std_low[i],std_high[i] = error_helpers.build_confidence_interval(central_value=mean, values=dic[func+"_bootstrap_"+quantity][i],\
                                                                                        symmetric=nested_bootstrap_std_symmetric)
                    
                    lax.fill_between(x=sampling_sizes, y1=mean_bootstrap_errors-std_low, y2=mean_bootstrap_errors+std_high, color=dic["color"], \
                                 alpha=0.25, linewidth=0 if show_percentile_sample_bootstrap_errors_surface else None)
                    
                if show_percentile_sample_bootstrap_errors_surface:
                    # Central `percentile`% band: from q_low to q_high percentiles
                    q_low = (100-percentile)/2
                    q_high = (100+percentile)/2 # this is the same as 100-q_low
                    
                    lax.fill_between(x=sampling_sizes, \
                                     y1=np.percentile(dic[func+"_bootstrap_"+quantity],axis=1, q=q_low),\
                                     y2=np.percentile(dic[func+"_bootstrap_"+quantity],axis=1, q=q_high),\
                                     color=dic["color"], alpha=0.1 if show_nested_bootstrap_std else 0.25)
                
                if show_bootstrap_mean:
                    lax.plot(sampling_sizes, mean_bootstrap_errors, color=dic["color"], linestyle="--", label=dic["mean_boot_label"])
                
                if show_bootstrap_median:
                    lax.plot(sampling_sizes, np.median(dic[func+"_bootstrap_"+quantity],axis=1),color=dic["color"],linestyle="-.",label=dic["median_boot_label"])
        
        # Black proxy artists at x<0 (outside the x-limits set below) that only add generic
        # legend entries; `dic` here is the last dict from the loop above.
        if show_sample_bootstrap_errors and row==0: # just for legend labels
            
            labely = np.mean(dic[func+"_"+quantity])
            
            if show_bootstrap_mean and dic["mean_boot_label"] is None:
                lax.plot([-2,-1],[labely,labely],color="k",linestyle="--",label=bootstrap_mean_label)
                
            if show_bootstrap_median and dic["median_boot_label"] is None:
                lax.plot([-2,-1],[labely,labely],color="k",linestyle="-.",label=bootstrap_median_label)
            
            if show_nested_bootstrap_std:
                lax.fill_between(x=[-2,-1],y1=[labely,labely],y2=[labely,labely],color="k",alpha=0.25,label=bootstrap_std_label)
            
            if show_percentile_sample_bootstrap_errors_surface:
                lax.fill_between(x=[-2,-1],y1=[labely,labely],y2=[labely,labely],color="k",alpha=0.1 if show_nested_bootstrap_std else 0.25,\
                                 label=percentile_sample_bootstrap_error_surface_label)
                    
        if brokenaxes:
            # zip stops at the shorter list, so only the first two dicts get a side panel each
            for dic,ax in zip(all_dicts,[cax,rax]):
                ax.scatter(dic["total_N"],dic[func+"_bootstrap_error"],marker="*",color=dic["color"])

                if MC_error_bool:
                    ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_low"],marker="v",color=dic["color"],s=23)
                    ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_high"],marker="^",color=dic["color"],s=23)

                ax.set_xticks([dic["total_N"]])

            if bootstrap_error_bool:
                lax.scatter(x=-100,y=0,marker="*",color="k",label="Bootstrap error") # just for the legend label

            if MC_error_bool:
                lax.scatter(x=-100,y=0,marker="v",color="k",label="MC 20% distance error",s=23) # just for the legend label
    
    if True: # ticks, lims, logscale
        
        if xlog_bool:
            lax.set_xscale("log")
        
        lax_leftlim = 45 if xlog_bool else 0
        # NOTE(review): the linear-axis branch needs `minor_locator` (and major_locator below)
        # to be defined in an earlier cell; otherwise this raises NameError when xlog_bool is False.
        lax_rightlim = max_sampling_size+minor_locator*0.7 if not xlog_bool else max_sampling_size+10**(MF.get_exponent(max_sampling_size))//2
        lax.set_xlim(lax_leftlim,lax_rightlim)
        
        if not xlog_bool:
            lax.xaxis.set_major_locator(ticker.MultipleLocator(major_locator))
            lax.xaxis.set_minor_locator(ticker.MultipleLocator(minor_locator))
#             lax.set_xticks([50]+list(np.arange(xtick_step,max_sampling_size+xtick_step,xtick_step)))
        
        # cax/rax may already have been deleted (brokenaxes False); setting their properties is harmless
        for ax in [lax,cax,rax]:
            if brokenaxes:
                ax.tick_params(axis='x', which='major', pad=10)
            
            if ylog_bool:
                ax.set_yscale("log")
                
            if ax in [cax,rax]:
                ax.yaxis.set_ticklabels([])
            
            if func in hard_coded_ylims_dict:
                ax.set_ylim(hard_coded_ylims_dict[func])
#             else:
#                 ax.set_ylim(bottom=0.005 if ylog_bool else 0)

            if not ylog_bool and quantity=="errors":
                if func == "anisotropy" and lax.get_ylim()[1] < 0.6:
                    ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
                if func == "correlation":
                    ax.yaxis.set_major_locator(ticker.MultipleLocator(0.05))
                if func == "tilt_abs":
                    ax.yaxis.set_major_locator(ticker.MultipleLocator(10))

            # Only the bottom row keeps x tick labels (rows share the x-axis visually, hspace=0)
            if row != nrows - 1:
                ax.xaxis.set_ticklabels([])
                # NOTE(review): duplicate of the line above; has no additional effect
                ax.xaxis.set_ticklabels([])
        
    if show_actual_standard_errors and fit_bool:

        x_plot = np.linspace(min(sampling_sizes),max_sampling_size,500)
        
        if same_fits_all_dicts[func]:
            # One shared fit to the first dict's errors, drawn in grey
            fit_func.fit(x = sampling_sizes[sampling_sizes<=max_sampling_size], y = np.array(all_dicts[0][func+"_errors"])[sampling_sizes<=max_sampling_size])
            lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color="grey", linestyle="dotted")
            
            if fit_params_text_bool:
                show_fit_params_text(ax=lax,Func=fit_func,color="grey",x_eq=x_eq,y_eq=y_eq)
        elif not (func == "correlation" and show_true_pearson_standard_error):
            # One fit per dict in its own colour (skipped for correlation when the analytic
            # Pearson SE is drawn instead, see below)
            for i, dic in enumerate(all_dicts):
                fit_func.fit(x = sampling_sizes[sampling_sizes<=max_sampling_size], y = np.array(dic[func+"_errors"])[sampling_sizes<=max_sampling_size])

                lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color=dic["color"], linestyle="dotted",alpha=0.5)
                
                if fit_params_text_bool:
                    show_fit_params_text(ax=lax,Func=fit_func,alpha=0.7,color=dic["color"],x_eq=x_eq,y_eq=y_eq - i*0.08,units=mapf.get_units(func))
        
        if row == 0:
            # NOTE(review): "anisotropy_errors" is hard-coded (presumably func+"_errors" was meant);
            # it only sets the y of an off-screen legend proxy, so the plot itself is unaffected.
            labely = np.mean(dic["anisotropy_errors"])
            lax.plot([-100,-99],[labely,labely],linestyle="dotted",color="grey",label=fit_func.label) # just for the legend label
    
    # Analytic Pearson SE overlay for the correlation row; x_plot comes from the block above,
    # which runs under the same show_actual_standard_errors and fit_bool condition
    if show_actual_standard_errors and fit_bool and func == "correlation" and show_true_pearson_standard_error:
        for dic in all_dicts:
            lax.plot(x_plot,pearson_standard_error(x_plot,dic["correlation"]),color=dic["color"],alpha=0.5,linestyle="dotted")
        
        lax.plot([0],[0],color="grey",linestyle="dotted",label=pearson_se_label) # for legend
    
    if True: # labels
        if row == nrows-1:
            if brokenaxes:
#                 lax.set_xlabel("Sample size"); lax.text(s=r"Total $N$",x=1.1,y=-0.25,transform=lax.transAxes,size="medium")
                lax.text(s="Sample size",x=0.5,y=-0.25,transform=lax.transAxes,size="medium")
            else:
                lax.set_xlabel("Sample size");
                
        lax.set_ylabel(ylabel + r" %s"%((r"$[$" + mapf.get_units(func) + r"$]$") if mapf.get_units(func) != "" else ""))
        
        fig.align_labels()
    
    if True: # text
        # Horizontal position (axes fraction) of the per-row function title
        if xlog_bool and ylog_bool:
            x_func_text = 0.98
        elif xlog_bool:
            x_func_text = 0.6 if quantity == "errors" else 1.15
        else:
            x_func_text = 0.2
        
        func_str = mapf.get_kinematic_titles_dict("r","l")[func.removesuffix("_abs")]
            
        if not brokenaxes:
            x_func_text -= 0.45 if xlog_bool else 0
        fig.text(s=func_str,x=x_func_text,y=0.85 if quantity=="errors" else 0.8,transform=lax.transAxes,size="small")
    
    if True: # legend
        if row == 0:
            if brokenaxes:
                fig.legend(loc=(0.65,0.77) if not (xlog_bool and ylog_bool) else (0.16,0.7),framealpha=0 if xlog_bool and ylog_bool else 1)
            else:
                lax.legend(loc="best" if not ylog_bool else "lower left", framealpha=0.5)
        elif row == 1:
            if show_actual_standard_errors and fit_bool and show_true_pearson_standard_error and func == "correlation":
                lax.legend(loc="lower left" if xlog_bool and ylog_bool else "upper right", framealpha=0.5)
        
    if True: # fix ylims if needed
        if func not in hard_coded_ylims_dict:
#             if lax.get_ylim()[0] < 0:
            for ax in [lax,cax,rax]:
                ax.set_ylim(bottom=min_y[row])
        
if True: # filename, save and show
    
    all_spatial_cuts = [dic["spatial_cuts"] for dic in all_dicts]
    save_path = get_save_path_spatial_cuts(base_path=base_path,spatial_cuts=all_spatial_cuts)
    
    filename = quantity
    
    if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
        filename += "_anicorr"
    else:
        raise ValueError("Please specify map list name")

    # Append each distinct population-cut string once
    pop_str = MF.extract_str_from_cuts_dict(all_dicts[0]["pop_cuts"])
    filename += "_"+pop_str
    for dic in all_dicts[1:]:
        if MF.extract_str_from_cuts_dict(dic["pop_cuts"]) not in filename:
            filename += '_'+MF.extract_str_from_cuts_dict(dic["pop_cuts"])
    
    filename += repeats_str
    filename += f"_{min(sampling_sizes)}size{max(sampling_sizes)}step{len(sampling_sizes)}"
    filename += "log" if logSampling else ""
    
    if xlog_bool and ylog_bool:
        filename += "_xylog"
    elif xlog_bool:
        filename += "_xlog"
    elif ylog_bool:
        filename += "_ylog"
    
    if quantity == "errors":
        if show_actual_standard_errors:
            if fit_bool:
                if all(same_fits_all_dicts.values()):
                    filename += "_sharedFits"
                if fit_params_text_bool:
                    filename += "_fitParams"
            else:
                filename += "_noFit"
        else:
            filename += "_noSE"
        
    if not show_sample_bootstrap_errors:
        filename += "_noBoot"
    else:
        
        filename += "_boot"
        
        nested_bootstrap_strings = [
            "Avg" if show_bootstrap_mean else "",
            "Med" if show_bootstrap_median else "",
            "Std" if show_nested_bootstrap_std else "",
            filename_suffix_percentile_bootstrap if show_percentile_sample_bootstrap_errors_surface else ""
        ]
        
        # Comma-join the enabled suffixes, collapsing the gaps left by disabled ones
        joint_boot_strings = str.join(",", nested_bootstrap_strings)
        while ",," in joint_boot_strings:
            joint_boot_strings = joint_boot_strings.replace(",,",",")
        
        filename += joint_boot_strings.strip(",")
        
    if brokenaxes:
        filename += "_broken"
        
        if MC_error_bool:
            filename += "_MC"
        
    print(filename)
    
    if save_bool:
        print("Saving in",save_path)
        
        for fileformat in [".png",".pdf"]:
            plt.savefig(save_path+filename+fileformat,dpi=200,bbox_inches="tight")
            print(fileformat)
    plt.show()

In [None]:
save_bool = True  # False: only display the histograms below, without writing the .png file

In [None]:
# distributions of standard errors
# For each function (rows) and three sample sizes (smallest, middle, largest; columns), histogram
# every dict's nested bootstrap SEs, marking their mean (solid line) and the actual SE (dotted).

sample_size_indices = np.array([0,len(sampling_sizes)//2,-1])

nrows = len(function_dict)

# squeeze=False keeps axs two-dimensional, so axs[row] is always a row of axes (even if nrows == 1)
fig,axs = plt.subplots(figsize=(20,15),ncols=len(sample_size_indices),nrows=nrows,gridspec_kw={"hspace":0.3},squeeze=False)

# The hand-picked young/old styling only covers two dicts; for any other number of dicts fall back
# to each dict's own "label" and "color" instead of raising IndexError.
if len(all_dicts) == 2:
    hist_labels = ["Young","Old"]
    actual_se_colors = ["cyan","orange"]
else:
    hist_labels = [dic["label"] for dic in all_dicts]
    actual_se_colors = [dic["color"] for dic in all_dicts]

# One x-label per row, in function_dict order (anisotropy, correlation, tilt_abs)
xlabels = ["Anisotropy bootstrap SE","Correlation bootstrap SE",r"Vertex deviation bootstrap SE $[^\circ]$"]

for row,func in enumerate(function_dict):
    row_axs = axs[row]

    for col,(ax,sample_idx) in enumerate(zip(row_axs,sample_size_indices)):

        for d,dic in enumerate(all_dicts):

            ax.hist(dic[func+"_bootstrap_errors"][sample_idx], color=dic["color"],bins=50, label=hist_labels[d],alpha=0.5)

            ax.axvline(np.mean(dic[func+"_bootstrap_errors"][sample_idx]), color=dic["color"])
            ax.axvline(dic[func+"_errors"][sample_idx], color=actual_se_colors[d],linestyle="dotted",lw=2)

        # Zero-height black lines: legend proxies for the mean and actual-SE markers
        ax.axvline(x=np.mean(dic[func+"_bootstrap_errors"][sample_idx]),ymin=0,ymax=0,color="k",label="Mean")
        ax.axvline(x=np.mean(dic[func+"_bootstrap_errors"][sample_idx]),ymin=0,ymax=0,color="k",linestyle="dotted",lw=2,label="Actual SE")

        # title, labels
        if row == 0:
            ax.set_title(f"Sample size {sampling_sizes[sample_idx]}")

        if col == 0:
            ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
        ax.set_xlabel(xlabels[row])

        if row == len(function_dict)//2 and col == len(sample_size_indices)//2: # legend only on the central panel
            ax.legend()

if True: # filename and save
    all_spatial_cuts = [dic["spatial_cuts"] for dic in all_dicts]
    save_path = get_save_path_spatial_cuts(base_path=base_path,spatial_cuts=all_spatial_cuts)

    filename = "bootstrapErrorHists"

    if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
        filename += "_anicorr"
    else:
        raise ValueError("Please specify map list name")

    # Append each distinct population-cut string once
    pop_str = MF.extract_str_from_cuts_dict(all_dicts[0]["pop_cuts"])
    filename += "_"+pop_str
    for dic in all_dicts[1:]:
        if MF.extract_str_from_cuts_dict(dic["pop_cuts"]) not in filename:
            filename += '_'+MF.extract_str_from_cuts_dict(dic["pop_cuts"])

    filename += f"_{repeats}repeats"
    filename += "_" + str.join(",",sampling_sizes[sample_size_indices].astype(str))

    print(filename)

    if save_bool:
        print("Saving in",save_path)
        plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")

    plt.show()