In [None]:
# Notebook theme via the jupyterthemes `jt` CLI. Uncomment the next line for the dark "onedork" theme.
# !jt -t onedork -T
# `-r` resets the notebook to the default theme.
!jt -r

In [None]:
%config Completer.use_jedi = False # To make auto-complete faster

# Reload imported project modules automatically whenever their source files change
%load_ext autoreload
%autoreload 2

import sys
# Add the directory two levels up to the import path (presumably the repo root holding src/, plotting/, utils/)
sys.path.append('../../')

In [None]:
from IPython.display import display, HTML
# Widen the notebook's main container to 88% of the browser window
display(HTML(data="<style>.container { width:88% !important; }</style>"))

In [None]:
import pandas as pd
import numpy as np

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

from plotting.matplotlib_param_funcs import set_matplotlib_params,reset_rcParams
# Apply the project's shared matplotlib settings (plotting/matplotlib_param_funcs)
set_matplotlib_params()

In [None]:
import src.compute_variables as CV
import src.compute_errors as CE
from src.errorconfig import MonteCarloConfig,BootstrapConfig

import plotting.map_functions as mapf
import plotting.mixed_plots as MP

import utils.miscellaneous_functions as MF
import utils.coordinates as coordinates
import utils.load_sim as load_sim
import utils.load_data as load_data

In [None]:
 # Default font size for all figures; this runs after set_matplotlib_params(), so it takes precedence
 plt.rcParams.update({"font.size": 20})

In [None]:
# CHOOSE: position variables of the map axes (presumably Galactic longitude l and latitude b)
x_var, y_var = "l", "b"

# CHOOSE: velocity components; columns are read as "v" + name, e.g. "vr" and "vl"
vel_x_var, vel_y_var = 'r', 'l'

In [None]:
# LaTeX degree symbol. A raw string is used so "\c" is not parsed as an (invalid) escape sequence,
# which raises a DeprecationWarning/SyntaxWarning in recent Python versions.
degree_symbol = r"^\circ"

# Axis/velocity symbols and units used in plot labels
symbol_dict = mapf.get_kinematic_symbols_dict(x_variable=x_var,
                                             y_variable=y_var,
                                             vel_x_variable=vel_x_var,
                                             vel_y_variable=vel_y_var)

units_dict = mapf.get_kinematic_units_dict(degree_symbol=degree_symbol)

pos_symbols_dict,pos_units_dict = mapf.get_position_symbols_and_units_dict(degree_symbol=r"$%s$"%degree_symbol)

In [None]:
# Statistic name -> function in src.compute_variables that computes it from (vx, vy)
all_funcs_dict = dict(
    correlation=CV.calculate_correlation,
    anisotropy=CV.calculate_anisotropy,
    tilt_abs=CV.calculate_tilt,
)

In [None]:
import os

# Root of the project's data/ and graphs/ trees. Set the MRES_GENERAL_PATH environment variable
# to run on another machine instead of editing this absolute path.
general_path = os.environ.get("MRES_GENERAL_PATH", '/Users/luismi/Desktop/MRes_UCLan/')
# Downstream paths are built by string concatenation, so a trailing slash is required
if not general_path.endswith("/"):
    general_path += "/"

In [None]:
def get_save_path_spatial_cuts(save_path, spatial_cuts_dict):
    """Extend ``save_path`` with one directory per spatial cut, creating each level.

    Levels are nested in a fixed order: the ``b``/``z`` cut, then the ``d``/``R``/``x`` cut,
    then the ``l``/``y`` cut. Each level is named ``<low><variable><high>/``, with both bounds
    formatted by ``MF.return_int_or_dec(bound, 2)``. Keys belonging to none of these groups
    (e.g. population cuts) are ignored.

    Raises:
        ValueError: if two variables of the same group appear in ``spatial_cuts_dict``.

    Returns:
        The extended path, ending in "/".
    """
    nesting_groups = (["b","z"], ["d","R","x"], ["l","y"])

    for group in nesting_groups:
        present = [name for name in spatial_cuts_dict if name in group]

        if not present:
            continue
        if len(present) > 1:
            raise ValueError(f"Did not expect more than one variable from `{group}`. If it was not a mistake, please specify the order.")

        name = present[0]
        bounds = spatial_cuts_dict[name]
        save_path += f"{MF.return_int_or_dec(bounds[0],2)}{name}{MF.return_int_or_dec(bounds[1],2)}/"
        MF.create_dir(save_path)

    return save_path

# Load

In [None]:
zabs = True   # alternative: False (passed as zabs= to the loaders)
R0 = 8.1      # passed as R0= to the loaders (presumably the solar Galactocentric radius in kpc — confirm)
GSR = True    # alternative: False (passed as GSR= to the loaders)

## Simulation

In [None]:
sim_choice = "708main"  # alternatives: "708mainDiff4", "708mainDiff5"

rot_angle = 27
axisymmetric = False
pos_scaling = 1.7

# Name of the stored simulation arrays for this configuration
filename = load_sim.build_filename(
    choice=sim_choice,
    rot_angle=rot_angle,
    R0=R0,
    axisymmetric=axisymmetric,
    zabs=zabs,
    pos_factor=pos_scaling,
    GSR=GSR,
)

In [None]:
# Load the simulation from its numpy arrays (default) or from a cached pickle
load_chunk = False

if not load_chunk:
    np_path = general_path+f"data/{sim_choice}/numpy_arrays/"

    df0 = load_sim.load_simulation(path=np_path,filename=filename)
else:
    # The pickle cache exists only for this one configuration
    if sim_choice == "708main" and rot_angle == 27 and not axisymmetric and zabs and pos_scaling == 1.7:
        pickle_name = "df_bulge_zabs.pkl"
        # NOTE: unpickling can execute arbitrary code — only load trusted local files
        df0 = pd.read_pickle("708main_simulation/"+pickle_name)
    else:
        raise ValueError("No cached pickle exists for this simulation configuration; set load_chunk = False.")

## Data

In [None]:
obs_errors = True  # alternative: False (passed as error_bool= when loading the data)
data_zabs = True   # alternative: False (|z| option for the observational data)

In [None]:
data_path = general_path+"data/Observational_data/"

# Use the data-specific |z| switch set in the previous cell (data_zabs), not the simulation's zabs
data = load_data.load_and_process_data(data_path=data_path, error_bool=obs_errors, zabs=data_zabs, R0=R0, GSR=GSR, drop_unused=False)

# Data uncertainty histograms

In [None]:
# Variable whose uncertainty is histogrammed below (columns var and var + "_error")
var = "d"

In [None]:
# Output directory for this variable's uncertainty histograms
save_path = f"{general_path}graphs/Observations/Apogee/Uncertainties/{var}/"
MF.create_dir(save_path)
print(save_path)

In [None]:
# Selection applied before histogramming (presumably an [Fe/H] range); use {} to keep every star
cuts_dict = {"FeH": [-1, 0.61]}
data_df = MF.apply_cuts_to_df(data, cuts_dict=cuts_dict)

In [None]:
save_bool = True  # set to False to display the figure without writing a PNG

In [None]:
bins = 100
log_bool = False   # True for a logarithmic y axis
plot_range = None  # e.g. [0, 0.4] to restrict the histogram range

# Histogram of the absolute uncertainty of `var`, with its median marked
err_values = data_df[var+"_error"]
err_median = err_values.median()
var_units = mapf.get_units(var)

fig, ax = plt.subplots()
ax.hist(err_values, bins=bins, log=log_bool, range=plot_range)
ax.axvline(err_median, label="Median: %s %s" % (MF.return_int_or_dec(err_median, 2), var_units), color="red")
x_units = f" [{var_units}]" if var_units != "" else ""
ax.set_xlabel(mapf.get_symbol(var+"_error") + x_units)
ax.set_ylabel(r"$N$", rotation=0, labelpad=20)
ax.legend()

# Filename encodes the cuts, range, binning and y scale
filename = var
if len(cuts_dict) > 0:
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)
if plot_range is not None:
    filename += f"_{plot_range[0]}range{plot_range[1]}"
filename += f"_{bins}bins"
if log_bool:
    filename += "_log"

print(filename)

if save_bool:
    plt.savefig(save_path+filename+".png", bbox_inches="tight")
    print("Saved in",save_path)
plt.show()

In [None]:
save_bool = True  # set to False to display the figure without writing a PNG

In [None]:
bins = 100
log_bool = False   # True for a logarithmic y axis
plot_range = None  # e.g. [0, 0.2] to restrict the histogram range

# Fractional uncertainty sigma / |value| of `var` for every star
frac_err_values = data_df[var+"_error"]/np.abs(data_df[var])
frac_err_median = frac_err_values.median()

fig,ax=plt.subplots()

ax.hist(frac_err_values,bins=bins,range=plot_range,log=log_bool)
# The median is shown as a whole percentage. Formatting 100*median directly avoids the float noise
# (e.g. "7.000000000000001") produced by multiplying an already-rounded fraction by 100.
ax.axvline(frac_err_median, label="Median $%.0f$" % (100*frac_err_median) + r"$~$%", color="red")
ax.set_xlabel(mapf.get_symbol(var+"_fractionalerror"))
ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
ax.legend()

filename = f"{var}_frac"
# As for the absolute-error histogram, the "_<cuts>" part is added only when there are cuts
if len(cuts_dict) > 0:
    filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)
if plot_range is not None:
    filename += f"_{plot_range[0]}range{plot_range[1]}"
filename += f"_{bins}bins"
if log_bool:
    filename += "_log"

print(filename)

if save_bool:
    plt.savefig(save_path+filename+".png", bbox_inches="tight")
    print("Saved in",save_path)
plt.show()

# Monte Carlo (MC) errors

In [None]:
def get_save_path_MC(perturbed_vars, data_bool):
    """Return the Monte Carlo figure directory for ``perturbed_vars``, creating it if needed.

    Layout: ``<general_path>graphs/other_plots/MonteCarlo/<v1,v2,...>/`` followed by ``data/``
    when ``data_bool`` is true, else ``model/``. Both levels are created with ``MF.create_dir``.
    """
    vars_dir = general_path + "graphs/other_plots/MonteCarlo/" + ",".join(perturbed_vars) + "/"
    MF.create_dir(vars_dir)

    full_path = vars_dir + ("data/" if data_bool else "model/")
    MF.create_dir(full_path)
    return full_path

In [None]:
data_bool = True  # True: observational `data`; False: simulation `df0`
perturbed_vars = ["d"]  # full set: ["d","vr","pmra","pmdec"]

In [None]:
save_path = get_save_path_MC(data_bool=data_bool, perturbed_vars=perturbed_vars)
print(save_path)

In [None]:
if not data_bool:
    # Simulation selections (alternatives):
    #   nuclear disc: {"age":[0,4],"l":[-2,2],"b":[0,0.01],"R":[0,3.5]}
    #   young pop:    {"age":[4,7],"l":[-2,2],"b":[3.5,4],"R":[0,3.5]}
    #   young pop:    {"age":[4,7],"l":[-2,2],"b":[3,6],"R":[0,2]}
    cuts_dict = {"age":[9.5,10],"l":[-2,2],"b":[3,6], "R":[0,2]} # old pop
else:
    # Observational selection (alternatives: {"FeH":[-0.21,0.61],"l":[-2,2],"R":[0,3.5]} or with "R":[0,2])
    cuts_dict = {"FeH":[-1,-0.21], "l":[-2,2], "R":[0,2]}

In [None]:
# Apply the selection to the observational data or to the simulation
df = MF.apply_cuts_to_df(df0 if not data_bool else data, cuts_dict=cuts_dict)
print(len(df), "stars")

In [None]:
func_name = "correlation"  # alternative: "anisotropy"
func = all_funcs_dict[func_name]

# Statistic evaluated on the unperturbed velocities of the selection
true_value = func(vx=df[f"v{vel_x_var}"], vy=df[f"v{vel_y_var}"])
print(f"{func_name}: {true_value}")

In [None]:
# Cuts on d/R that a distance perturbation can move stars across (None when d is not perturbed)
affected_cuts_dict = (
    {name: bounds for name, bounds in cuts_dict.items() if name in ["d","R"]}
    if "d" in perturbed_vars
    else None
)

repeats = 5000

## Multiple MC errors

In [None]:
# Bootstrap spread of the statistic over `repeats` resamples of the selection
std_boot, _, boot_vals = CE.get_std_bootstrap(
    function=func,
    vx=df[f"v{vel_x_var}"].values,
    vy=df[f"v{vel_y_var}"].values,
    bootstrapconfig=BootstrapConfig(repeats=repeats),
)

print("Bootstrap error:", std_boot)

In [None]:
# Same selection as `df` but without the cuts that the MC perturbation can move stars across
cuts_without_affected = MF.clean_cuts_from_dict(cuts_dict, cuts_to_remove=affected_cuts_dict)
df_MC = MF.apply_cuts_to_df(df0 if not data_bool else data, cuts_dict=cuts_without_affected)
print(len(df), len(df_MC))

In [None]:
# Fractional errors 5%..20% in 5% steps. linspace fixes the count, and rounding removes float noise
# (np.arange gave 0.15000000000000002), which would otherwise leak into labels and filenames.
frac_errors = np.round(np.linspace(0.05, 0.20, 4), 2)

In [None]:

# One MC configuration per fractional error in `frac_errors`, keyed by that fraction.
# (The single-error workflow below defines its own `frac_error` / `montecarloconfig`.)
montecarloconfigs = {
    float(frac): MonteCarloConfig(perturbed_vars=perturbed_vars,affected_cuts_dict=affected_cuts_dict,
                                  error_frac=float(frac),repeats=repeats,symmetric=False)
    for frac in frac_errors
}

## Single MC error

In [None]:
frac_error = 0.2  # alternatives: 0.1, or None to use the data's own uncertainties ("_dataErr")

In [None]:
# Resampling configurations shared by the MC and bootstrap estimates below
montecarloconfig = MonteCarloConfig(
    perturbed_vars=perturbed_vars,
    affected_cuts_dict=affected_cuts_dict,
    error_frac=frac_error,
    repeats=repeats,
    symmetric=False,
)

bootstrapconfig = BootstrapConfig(symmetric=True, repeats=repeats)

In [None]:
# Same selection as `df`, minus the cuts that the MC perturbation can move stars across
df_MC = MF.apply_cuts_to_df(
    data if data_bool else df0,
    cuts_dict=MF.clean_cuts_from_dict(cuts_dict, affected_cuts_dict),
)
print(len(df), len(df_MC))

In [None]:
# MC spread (low/high) and bootstrap spread of the chosen statistic
std_MC_low, std_MC_high, MC_values, within_cut = CE.get_std_MC(
    function=func, df=df_MC, montecarloconfig=montecarloconfig, true_value=true_value,
    vel_x_var=vel_x_var, vel_y_var=vel_y_var,
)

std_boot, _, boot_vals = CE.get_std_bootstrap(
    function=func,
    vx=df[f"v{vel_x_var}"].values,
    vy=df[f"v{vel_y_var}"].values,
    bootstrapconfig=bootstrapconfig,
)

for stat_name, stat in (("Mean", np.mean), ("Median", np.median)):
    print(f"{stat_name}\t MC: {stat(MC_values)}. Boot: {stat(boot_vals)}")
print(f"Std\t MC low,high: {std_MC_low},{std_MC_high}. Boot: {std_boot}")

In [None]:
save_bool = True  # set to False to display the figure without writing a PNG

In [None]:
boot_hist_bool = True  # False to draw only the MC histogram
MC_bar_width = None    # e.g. 0.001 to force the MC bar width

fig, ax = plt.subplots()
# hist() does not accept width=None here, so `width` is only passed when a value was chosen
mc_extra_kwargs = {} if MC_bar_width is None else {"width": MC_bar_width}
ax.hist(MC_values, bins=50, color="blue", alpha=0.5, label=fr"MC values ($R={repeats}$)", **mc_extra_kwargs)
if boot_hist_bool:
    ax.hist(boot_vals, bins=50, color="green", alpha=0.5, label=fr"Bootstrap values ($R={repeats}$)")
ax.axvline(true_value, color="red", label="True value")
ax.set_xlabel(symbol_dict[func_name])
ax.set_ylabel(r"$N$", rotation=0, labelpad=20)
ax.legend()

filename = (
    func_name
    + (f"_{frac_error}fracErr" if frac_error is not None else "_dataErr")
    + ("" if boot_hist_bool else "_noBoot")
    + "_" + MF.extract_str_from_cuts_dict(cuts_dict)
    + f"_{repeats}repeats"
)

print(filename)

if save_bool:
    plt.savefig(save_path + filename + ".png", bbox_inches="tight")
    print("Saved in", save_path)
plt.show()

## Distance (d)-specific plots

In [None]:
save_bool = True  # set to False to display the figure without writing a PNG

In [None]:
# Stars inside the radial cut before the MC perturbation. The frame is df_MC (the one passed to
# CE.get_std_MC; `df_errors` was never defined). The limit comes from the active R cut, presumably
# the same cut behind `within_cut`; 3.5 kpc is the fallback when there is no R cut.
originally_within_Rmax = df_MC["R"] <= (cuts_dict["R"][1] if "R" in cuts_dict else 3.5)

In [None]:
# Scatter of R against `scatter_y_var`, colouring stars by whether the MC perturbation moved them
# across the radial cut. Uses df_MC, the frame passed to CE.get_std_MC; `within_cut` is assumed to
# be a boolean mask aligned with its rows — TODO confirm.
R_max = cuts_dict["R"][1] if "R" in cuts_dict else 3.5  # radial limit actually in use
originally_within_Rmax = df_MC["R"] <= R_max

fig,ax=plt.subplots()

colors = ["grey","red","limegreen"]
# label: [mask, marker size, face colour, edge line width, edge colour]
elements_dic = {
    "Stayed inside/outside": [(originally_within_Rmax&within_cut)|(~originally_within_Rmax&~within_cut), 1, colors[0],1,"grey"],
    "Moved outside": [originally_within_Rmax&~within_cut, 10, colors[1],0.5,"k"],
    "Moved inside": [~originally_within_Rmax&within_cut, 10, colors[2],0.5,"k"]
}

# A local name, so the notebook-level `y_var` ("b", used for symbol_dict) is not overwritten
scatter_y_var = "vr"; ylabel = r"$v_r$" + units_dict["mean_vx"]
# scatter_y_var = "d"; ylabel = r"$d~$[kpc]"

for k in elements_dic:
    condition,size,color,lw,edgecolor = elements_dic[k]

    ax.scatter(x=df_MC[condition]["R"], y=df_MC[condition][scatter_y_var], color=color, s=size, label=f"{k} ({sum(condition)})",lw=lw,edgecolor=edgecolor)

ax.axvline(x=R_max,color="grey",lw=1,linestyle="--")
ax.text(x=R_max+0.1,y=0.85*ax.get_ylim()[1], s=rf"$R={R_max}~$kpc",color="grey",size=15)

ax.set_xlabel(r"$R$ [kpc]"); ax.set_ylabel(ylabel)

leg = ax.legend()

# Colour each legend entry like its markers
for text,c in zip(leg.get_texts(),colors):
    text.set_color(c)

filename = f"{scatter_y_var}-R"
filename += f"_{frac_error}fracErr" if frac_error is not None else "_dataErr"
filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)

print(filename)

if save_bool:
    plt.savefig(save_path+filename+".png", bbox_inches="tight")
    print("Saved in",save_path)

plt.show()

In [None]:
save_bool = True  # set to False to display the figure without writing a PNG

In [None]:
bins = 50 if data_bool else 100

fig,ax=plt.subplots()

# One example MC realisation of the distances for df_MC (the frame passed to CE.get_std_MC;
# `df_errors` was never defined). The spread matches what the filename reports: the fractional
# error when `frac_error` is set, otherwise the catalogue's own `d_error` column.
d_scale = frac_error*df_MC["d"] if frac_error is not None else df_MC["d_error"]
MC_d = np.random.normal(loc=df_MC["d"],scale=d_scale)

ax.hist(MC_d,bins=bins,label=r"MC $d$")
ax.hist(df_MC["d"],bins=bins,color="red",alpha=0.4,label=r"Original $d$")
ax.hist(MC_d-df_MC["d"],bins=bins,color="k",histtype="step",label="Difference")
ax.set_xlabel(r"$d$ [kpc]"); ax.set_ylabel(r"$N$",rotation=0,labelpad=20)
ax.legend()

filename = "dHists"
filename += f"_{frac_error}fracErr" if frac_error is not None else "_dataErr"
filename += "_" + MF.extract_str_from_cuts_dict(cuts_dict)
filename += f"_{bins}bins"

print(filename)

if save_bool:
    plt.savefig(save_path+filename+".png", bbox_inches="tight")
    print("Saved in",save_path)
plt.show()

# Downsampling errors

In [None]:
def get_sampling_errors(df, function, sampling_sizes, symmetric=True, replacement=False, repeat=500, tilt=False, absolute=True, verbose=False, vel_x_var="r", vel_y_var="l"):
    """Resampling standard error of ``function`` for each requested sample size.

    For every size in ``sampling_sizes``, ``CE.get_std_bootstrap`` recomputes the statistic
    ``repeat`` times on samples of that size drawn from ``df`` (with or without replacement).

    Args:
        df: frame with velocity columns ``"v"+vel_x_var`` and ``"v"+vel_y_var``.
        function: statistic of (vx, vy), e.g. ``CV.calculate_correlation``.
        sampling_sizes: non-empty sequence of sample sizes.
        symmetric, replacement, repeat: forwarded to ``BootstrapConfig``.
        tilt, absolute: forwarded to ``CE.get_std_bootstrap``.
        verbose: print each size as it is processed.
        vel_x_var, vel_y_var: velocity component names (default "r"/"l", i.e. ``vr``/``vl``).

    Returns:
        ``(errors_low, errors_high, values)``: float arrays with one entry per size, or scalars when a
        single size is given. ``values`` holds the resampled values of the LAST size only.

    Raises:
        ValueError: if ``sampling_sizes`` is empty.
    """
    if len(sampling_sizes) == 0:
        raise ValueError("`sampling_sizes` must contain at least one sample size.")

    # Velocity arrays are the same for every size, so extract them once
    vx = df["v"+vel_x_var].values
    vy = df["v"+vel_y_var].values

    errors_low = np.empty_like(sampling_sizes,dtype=float)
    errors_high = np.empty_like(sampling_sizes,dtype=float)

    for s,size in enumerate(sampling_sizes):
        errors_low[s],errors_high[s],values = CE.get_std_bootstrap(function=function,vx=vx,vy=vy,tilt=tilt,absolute=absolute,
                                     bootstrapconfig=BootstrapConfig(bootstrap_size=size,replacement=replacement,symmetric=symmetric,repeats=repeat))

        if verbose:
            print(size,end="; ")
    if verbose:
        print("\n")

    # A single requested size is returned as scalars rather than length-1 arrays
    if len(errors_low) == 1:
        return errors_low[0], errors_high[0], values
    return errors_low, errors_high, values

## Value histograms

In [None]:
all_cuts_dict = {"R":[0,2],"b":[3,6],"l":[-2,2],"age": [4,7]}
df = MF.apply_cuts_to_df(df0, cuts_dict=all_cuts_dict)

function = CV.calculate_correlation
true_val = function(df.vr.values, df.vl.values)
print("True value:", true_val)

# MC with a 20% fractional distance error; only the R cut is affected by perturbing d
montecarloconfig = MonteCarloConfig(
    perturbed_vars=["d"], repeats=500, error_frac=0.2, symmetric=False,
    affected_cuts_dict={"R": all_cuts_dict["R"]},
)

df_MC = MF.apply_cuts_to_df(df0, cuts_dict=montecarloconfig.clean_value_cuts_dict(all_cuts_dict))
std_MC_low, std_MC_high, MC_values, _ = CE.get_std_MC(
    df=df_MC, function=function, true_value=true_val, vel_x_var="r", vel_y_var="l",
    tilt=False, absolute=True, montecarloconfig=montecarloconfig,
)
print("Std MC:", std_MC_low, std_MC_high)

# Full-size resampling with replacement, i.e. an ordinary bootstrap
std, _, boot_vals = get_sampling_errors(
    df=df, function=function, repeat=500, replacement=True, symmetric=True,
    sampling_sizes=[len(df)], verbose=False,
)
print("Std boot:", std)

In [None]:
bins = 50
# Sizes drawn WITHOUT replacement must be smaller than the selection (sampling would presumably
# fail otherwise). The final full-size entry is drawn WITH replacement, i.e. a bootstrap.
sizes = [s for s in [50,100,300,1000,5000,10000] if s < len(df)] + [len(df)]

fig,ax=plt.subplots()

for s in sizes:
    replacement = s == len(df)
    label = f"Bootstrap ({s})" if replacement else s

    std,_,vals = get_sampling_errors(df=df,function=function,repeat=500,replacement=replacement,symmetric=True,sampling_sizes=[s],verbose=True)

    ax.hist(vals,bins=bins,alpha=0.7,label=label)

# Same number of bins as the resampled histograms, so the counts are comparable (the default was 10)
ax.hist(MC_values,bins=bins,color="cyan",label="MC 20% distance error",alpha=0.7)

ax.set_yscale("log")

ax.axvline(true_val,color="grey",label="True value",linestyle="--")
ax.legend()
ax.set_xlabel("Correlation")
ax.set_ylabel(r"$N$")
plt.show()

## Error vs N

In [None]:
# Root directory for the down-sampling error results (JSON) and figures
base_path = f"{general_path}graphs/other_plots/downsampling_errors/"

In [None]:
# Statistics analysed in the error-vs-N study. The key order matters: it sets the plot row order,
# and the save/plot cells check it to pick the "anicorr" filename prefix.
function_dict = dict(
    anisotropy=CV.calculate_anisotropy,
    correlation=CV.calculate_correlation,
    tilt_abs=CV.calculate_tilt,
)

### Save

In [None]:
MC_symmetric = False  # True for symmetric MC error estimates
repeats = 500         # resamples per error estimate

In [None]:
# Sample sizes for the error-vs-N curves: 100 values log-spaced from 50 to 5000, rounded to integers.
# Linear alternative: np.arange(50,5000+50,50) together with logSampling = False
logSampling = True
sampling_sizes = np.round(10**np.linspace(np.log10(50), np.log10(5000), 100)).astype(np.int64)

MC_perturbed_vars_list = [["d"]]  # variable sets perturbed for the MC errors
MC_error_fracs = [0.1, 0.2]       # fractional errors used for each MC estimate

# Populations to analyse: every combination of latitude cut, radial cut and age population, in the
# order b -> R -> (young, old). Each entry is a fresh dict, because the loop below fills it with results.
# (Earlier runs also used the latitude cut [3.5, 4.5].)
all_dicts = [
    {
        "spatial_cuts": {"R": list(R_range), "b": list(b_range), "l": [-2, 2]},
        "pop_cuts": {"age": list(age_range)},
        "label": label,
        "color": color,
    }
    for b_range in ([1.5, 2], [3, 6])
    for R_range in ([0, 2], [0, 3.5])
    for age_range, label, color in (([4, 7], "Young", "blue"), ([9.5, 10], "Old", "red"))
]

# Sampling without replacement needs more stars than the largest requested sample size
for pop_entry in all_dicts:
    n_selected = len(MF.apply_cuts_to_df(df0, cuts_dict=[pop_entry["spatial_cuts"], pop_entry["pop_cuts"]]))
    assert n_selected > max(sampling_sizes), \
        f"There are not enough stars (namely {n_selected}) for the dict with cuts %s and %s" % (pop_entry["spatial_cuts"], pop_entry["pop_cuts"])

# For every population dict: compute the down-sampling errors, the bootstrap error and the MC
# errors of each statistic in `function_dict`, store them in the dict (mutated in place) and save it.
# Note: this loop rebinds notebook-level names such as df, func, df_MC, montecarloconfig,
# MC_values, save_path and filename.
for dic in all_dicts:
    df = MF.apply_cuts_to_df(df0, cuts_dict=[dic["spatial_cuts"],dic["pop_cuts"]])
    
    dic["total_N"] = len(df)
    
    print(dic["label"])
    for func in function_dict:
        print(func)
        
        # Error at every size in `sampling_sizes`, sampling WITHOUT replacement.
        # NOTE(review): list() below fails if sampling_sizes has a single entry, because
        # get_sampling_errors then returns a scalar instead of an array.
        standard_errors,*_ =\
            get_sampling_errors(df=df,function=function_dict[func],repeat=repeats,replacement=False,symmetric=True,sampling_sizes=sampling_sizes,\
                                tilt=func=="tilt_abs", absolute=True, verbose=True)
        dic[func+"_errors"] = list(standard_errors)
        
        # Statistic on the full selection.
        # NOTE(review): unlike the resampling calls (tilt=..., absolute=True), no tilt/absolute flags are
        # passed here — confirm that CV.calculate_tilt's defaults give the absolute tilt for "tilt_abs".
        dic[func] = function_dict[func](df.vr.values,df.vl.values)
        
        print("True value:",dic[func])
        
        # Bootstrap error: a single full-size sample drawn WITH replacement
        dic[func+"_bootstrap_error"],*_ = get_sampling_errors(df=df,function=function_dict[func],sampling_sizes=[len(df)],repeat=repeats,\
                                                        tilt=func=="tilt_abs", absolute=True, replacement=True)
        
        print("Bootstrap error:",dic[func+"_bootstrap_error"])
        
        for perturbed_vars in MC_perturbed_vars_list: # MC error
            
            # When d is perturbed, the d/R cuts are dropped from the starting selection
            # (presumably they are re-applied per realisation inside CE.get_std_MC)
            affected_cuts_dict = {k:v for k,v in dic["spatial_cuts"].items() if k in ["d","R"]} if "d" in perturbed_vars else None
            
            df_MC = df if affected_cuts_dict is None else MF.apply_cuts_to_df(df0, cuts_dict=\
                      MF.clean_cuts_from_dict(cuts_dict=[dic["spatial_cuts"],dic["pop_cuts"]], cuts_to_remove=affected_cuts_dict))
            
            for frac in MC_error_fracs:

                montecarloconfig = MonteCarloConfig(perturbed_vars=perturbed_vars,repeats=repeats,error_frac=frac,symmetric=MC_symmetric,\
                                                    affected_cuts_dict=affected_cuts_dict)

                perturbed_vars_str = str.join(",",perturbed_vars)
                
                # Stored under keys like "correlation_MC_d_0.2_error_low" / "..._error_high"
                dic[func+f"_MC_{perturbed_vars_str}_{frac}_error_low"],dic[func+f"_MC_{perturbed_vars_str}_{frac}_error_high"],MC_values,_ =\
                    CE.get_std_MC(df=df_MC,function=function_dict[func],true_value=dic[func],vel_x_var="r",vel_y_var="l",tilt=func=="tilt_abs", absolute=True,\
                                  montecarloconfig=montecarloconfig)

                print(f"MC {perturbed_vars_str} errors {frac} frac:",dic[func+f"_MC_{perturbed_vars_str}_{frac}_error_low"],dic[func+f"_MC_{perturbed_vars_str}_{frac}_error_high"])
        
        print("\n")
    
    if True: # save as json
        # Directory nested by the spatial cuts, e.g. <base_path>1.5b2/0R2/-2l2/
        save_path = get_save_path_spatial_cuts(save_path=base_path,spatial_cuts_dict=dic["spatial_cuts"])

        if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
            filename = "anicorr"
        else:
            raise ValueError("Please specify map list name")

        filename += '_'+MF.extract_str_from_cuts_dict(dic["pop_cuts"])

        filename += f"_{repeats}repeats"
        filename += f"_{min(sampling_sizes)}size{max(sampling_sizes)}"
        filename += f"_{len(sampling_sizes)}steps" + ("Log" if logSampling else "")

        # Presumably MF.save_dic_as_json appends ".json" (the message below assumes so)
        MF.save_dic_as_json(dic=dic, filename=save_path+filename)
        
        print("Saved dic as",filename+".json")
        print("In",save_path)
    
    print("\n")

### Plot

In [None]:
from scipy.optimize import curve_fit

class Func():
    """Base class for a 1-D model fitted with ``scipy.optimize.curve_fit``.

    Subclasses set ``self.func`` (a callable ``func(x, *params)``), ``self.p0`` (the initial guess,
    whose length fixes the number of parameters), ``self.name`` and ``self.label`` (legend text).
    """
    def fit(self,x,y):
        """Fit ``self.func`` to (x, y) from ``self.p0``; store the best-fit parameters in ``fit_params`` and return them."""
        self.fit_params,_ = curve_fit(f=self.func, p0=self.p0, xdata=x,ydata=y)
        return self.fit_params

class ReciprocalFunc(Func):
    """Model f(n) = a/(n+b) + c."""
    def __init__(self, p0=(39.65,100,0.02)):
        self.func = lambda x,a,b,c: a/(x+b)+c
        self.name = "reciprocal"
        self.label = r"$f(n)=\frac{a}{n+b}+c$"
        # Copy, so instances never share (and mutate) one default initial-guess list
        self.p0 = list(p0)

class InverseSquareFunc(Func):
    """Model f(n) = a/sqrt(n)."""
    def __init__(self, p0=(10,)):
        self.func = lambda x,a: a/(np.sqrt(x))
        self.name = "inverse_square"
#         self.label = r"$f(n)=\frac{a}{\sqrt{n}}$"
        self.label = r"$f(n)\propto\frac{1}{\sqrt{n}}$"
        # Copy, so instances never share (and mutate) one default initial-guess list
        self.p0 = list(p0)
        
def show_fit_params_text(ax, Func, x_eq=0.6,y_eq=0.42, color="grey", ndec_text=2, alpha=1, size=13, units=""):
    """Write the fitted parameter values of ``Func`` onto ``ax`` at axes-fraction (x_eq, y_eq).

    The text reads ``$(a,b,c)=(v_a,v_b,v_c)$`` for several parameters, or ``$a=v_a$`` for one,
    followed by ``units``. Each value is formatted with ``MF.return_int_or_dec(value, ndec_text)``.

    Raises:
        AttributeError: if ``Func`` has no ``fit_params`` (it has not been fitted).
    """
    if not hasattr(Func, "fit_params"):
        raise AttributeError("Fit the function first!")

    n_params = len(Func.p0)
    names = ",".join("abcdefg"[i] for i in range(n_params))
    slots = ",".join(["%s"] * n_params)
    if n_params > 1:
        template = "$(" + names + ")=(" + slots + ")$"
    else:
        template = "$" + names + "=" + slots + "$"
    template += units

    formatted = tuple(MF.return_int_or_dec(value, ndec_text) for value in Func.fit_params)
    ax.text(x=x_eq, y=y_eq, transform=ax.transAxes, color=color, alpha=alpha, size=size, s=template % formatted)

#### In bulk

In [None]:
def plot_error_vs_N(all_dicts, xlog_bool, ylog_bool, save_bool=True, show_bool=False, fit_bool=True, same_youngold_fits=True, logSampling=True):
    """Plot down-sampling standard errors against sample size, one row per statistic.

    Rows follow ``function_dict`` order. Each row's left panel shows ``dic[func+"_errors"]`` against
    ``sampling_sizes`` for every population dict. The two narrow right panels ("broken x axis")
    show, at each population's ``total_N``, its bootstrap error and, on log-log axes only, the
    20% distance MC error bounds.

    Args:
        all_dicts: population dicts as saved by the "Save" cell (``label``, ``color``, ``total_N``,
            ``spatial_cuts``, ``pop_cuts`` and the per-statistic error entries). All dicts are expected
            to share their ``spatial_cuts``; the first one determines the output directory and filename.
        xlog_bool, ylog_bool: logarithmic x / y axes.
        save_bool: save the PNG under ``base_path``.
        show_bool: show the figure; otherwise it is closed.
        fit_bool: overlay a 1/sqrt(n) fit.
        same_youngold_fits: fit only the first population (grey) instead of the first two separately.
        logSampling: only used to tag the filename with "Log".

    Uses the notebook globals ``sampling_sizes``, ``function_dict``, ``base_path`` and ``repeats``.
    """
    max_sampling_size = max(sampling_sizes)
    major_locator = 1000
    minor_locator = 250

    # fit_func = ReciprocalFunc(p0=[1,10,10])
    fit_func = InverseSquareFunc(p0=[1])

    # Axes-fraction position of the fitted-parameter text, depending on the axis scales
    if xlog_bool and ylog_bool:
        x_eq,y_eq = 0.75, 0.15
    elif xlog_bool:
        x_eq,y_eq = 0.65,0.4
    else:
        x_eq,y_eq = 0.4, 0.4

    hard_coded_ylims_dict = {
        "anisotropy": [0 if not ylog_bool else 0.001,0.3],
        "correlation": [0 if not ylog_bool else 0.001,0.149],
        "tilt_abs": [0 if not ylog_bool else 0.1,33]
    }

    # Extra head-room on log-log axes
    if xlog_bool and ylog_bool:
        for k in hard_coded_ylims_dict:
            hard_coded_ylims_dict[k][1] *= 1.5

    # MC error markers are only drawn on the log-log version of the figure
    MC_error_bool = xlog_bool and ylog_bool

    width_broken_axes = 0.13
    nrows = len(function_dict)

    # Three columns per row: the main panel plus two narrow panels forming a broken x axis
    fig,axs = plt.subplots(figsize=(8,10),ncols=3,nrows=nrows,gridspec_kw={"width_ratios":[1]+2*[width_broken_axes],"wspace":0.1,"hspace":0})

    for row,func in enumerate(function_dict):
        lax,cax,rax = axs[row]

        if True: # broken axes

            d = 0.02
            # The narrow panels are 1/width_broken_axes times thinner, so their break marks are scaled up in x
            d_factor = 1/width_broken_axes

            for ax in [lax,cax]:
                ax.spines['right'].set_visible(False)
                ax.tick_params(which='both',right=False)

            lax.plot((1-d,1+d), (-d,d), transform=lax.transAxes, color='k', clip_on=False,lw="1")
            lax.plot((1-d,1+d),(1-d,1+d), transform=lax.transAxes, color='k', clip_on=False,lw="1")

            cax.plot((1-d_factor*d,1+d_factor*d), (-d,d), transform=cax.transAxes, color='k', clip_on=False,lw="1")
            cax.plot((1-d_factor*d,1+d_factor*d), (1-d,1+d), transform=cax.transAxes, color='k', clip_on=False,lw="1")

            for ax in [cax,rax]:
                ax.spines['left'].set_visible(False)
                ax.tick_params(which='both',left=False)

                ax.plot((-d_factor*d,+d_factor*d), (1-d,1+d), transform=ax.transAxes, color='k', clip_on=False,lw="1")
                ax.plot((-d_factor*d,+d_factor*d), (-d,d), transform=ax.transAxes, color='k', clip_on=False,lw="1")

        if True: # plot
            for dic in all_dicts:
                lax.plot(sampling_sizes[sampling_sizes <= max_sampling_size],np.array(dic[func+"_errors"])[sampling_sizes <= max_sampling_size],\
                         label=dic["label"],color=dic["color"],alpha=0.7)

            # Each population gets its own narrow panel, placed at its total number of stars
            for dic,ax in zip(all_dicts,[cax,rax]):
                ax.scatter(dic["total_N"],dic[func+"_bootstrap_error"],marker="*",color=dic["color"])

                if MC_error_bool:
                    ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_low"],marker="v",color=dic["color"],s=23)
                    ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_high"],marker="^",color=dic["color"],s=23)

                ax.set_xticks([dic["total_N"]])

            lax.scatter(x=-100,y=0,marker="*",color="k",label="Bootstrap error") # just for the legend label

            if MC_error_bool:
                lax.scatter(x=-100,y=0,marker="v",color="k",label="MC 20% distance error",s=23) # just for the legend label

        if True: # ticks, lims, logscale

            if xlog_bool:
                lax.set_xscale("log")

            lax_leftlim = 40 if xlog_bool else 0
            lax.set_xlim(lax_leftlim,max_sampling_size+minor_locator*0.9)

            if not xlog_bool:
                lax.xaxis.set_major_locator(ticker.MultipleLocator(major_locator))
                lax.xaxis.set_minor_locator(ticker.MultipleLocator(minor_locator))

            for ax in [lax,cax,rax]:
                ax.tick_params(axis='x', which='major', pad=10)

                if ylog_bool:
                    ax.set_yscale("log")

                if ax in [cax,rax]:
                    ax.yaxis.set_ticklabels([])

                if func in hard_coded_ylims_dict:
                    ax.set_ylim(hard_coded_ylims_dict[func])
                else:
                    ax.set_ylim(bottom=0.005 if ylog_bool else 0)

                # Only the bottom row keeps x tick labels
                if row != nrows - 1:
                    ax.xaxis.set_ticklabels([])

        if fit_bool:

            x_plot = np.linspace(min(sampling_sizes),max_sampling_size,500)

            if func != "tilt_abs" and same_youngold_fits:
                fit_func.fit(x = sampling_sizes[sampling_sizes<=max_sampling_size], y = np.array(all_dicts[0][func+"_errors"])[sampling_sizes<=max_sampling_size])
                lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color="grey", linestyle="--")
                show_fit_params_text(ax=lax,Func=fit_func,color="grey",x_eq=x_eq,y_eq=y_eq)
            else:
                for (pop_idx,color,y) in zip([0,1],["blue","red"],[y_eq-0.04,y_eq+0.04]):
                    fit_func.fit(x = sampling_sizes[sampling_sizes<=max_sampling_size], y = np.array(all_dicts[pop_idx][func+"_errors"])[sampling_sizes<=max_sampling_size])

                    lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color=color, linestyle="--",alpha=0.5)
                    show_fit_params_text(ax=lax,Func=fit_func,alpha=0.7,color=color,x_eq=x_eq,y_eq=y,units=mapf.get_units(func))

            lax.plot([-100,-100],[0,1],label=fit_func.label,linestyle="--",color="grey") # just for the legend label

        if True: # labels, legend, text
            if row == nrows-1:
    #             lax.set_xlabel("Sample size"); lax.text(s=r"Total $N$",x=1.1,y=-0.25,transform=lax.transAxes,size="medium")
                lax.text(s="Sample size",x=0.5,y=-0.25,transform=lax.transAxes,size="medium")

            if row == 0:
                fig.legend(loc=(0.65,0.77) if not (xlog_bool and ylog_bool) else (0.16,0.7),framealpha=0 if xlog_bool and ylog_bool else 1)
            elif row == 1:
                lax.set_ylabel(r"Standard error")

            if xlog_bool and ylog_bool:
                x_func_text = 0.98
            elif xlog_bool:
                x_func_text = 0.3
            else:
                x_func_text = 0.2
            fig.text(s=mapf.get_kinematic_titles_dict("r","l")[func.removesuffix("_abs")],x=x_func_text,y=0.83,transform=lax.transAxes,size=15)

            fig.align_labels()

    if True: # filename, save and show
        save_path = get_save_path_spatial_cuts(save_path=base_path, spatial_cuts_dict=all_dicts[0]["spatial_cuts"])

        if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
            filename = "anicorr"
        else:
            raise ValueError("Please specify map list name")

        # Use the first dict explicitly (previously a `dic` leaked from an inner plotting loop was used)
        filename += "_"+MF.extract_str_from_cuts_dict(all_dicts[0]["spatial_cuts"])

        for dic in all_dicts:
            filename += '_'+MF.extract_str_from_cuts_dict(dic["pop_cuts"])

        if xlog_bool and ylog_bool:
            filename += "_xylog"
        elif xlog_bool:
            filename += "_xlog"
        elif ylog_bool:
            filename += "_ylog"

        if MC_error_bool:
            filename += "_MC"

        filename += f"_{repeats}repeats"
        filename += f"_{min(sampling_sizes)}size{max(sampling_sizes)}"
        filename += f"_{len(sampling_sizes)}steps" + ("Log" if logSampling else "")

        if not fit_bool:
            filename += "_noFit"
        elif not same_youngold_fits:
            filename += "_diffFits"

        print(filename)

        if save_bool:
            print("Saving in",save_path)
            plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")

        if show_bool:
            plt.show()
        else:
            plt.close()

In [None]:
logSampling = True  # set to False to load the linearly spaced results instead

In [None]:
if not logSampling:
    # Linearly spaced runs (50, 100, ..., 5000; 5000 repeats): one [young, old] pair per radial cut
    all_dicts_list = [
        [
            MF.load_dic_from_json(filename=base_path+f"3b6/{R_dir}/-2l2/anicorr_{age_str}_5000maxsize_5000repeats")
            for age_str in ("4age7", "9.5age10")
        ]
        for R_dir in ("0R3.5", "0R2")
    ]

    sampling_sizes = np.arange(50, 5000+50, 50)

    # Every pair must share its spatial cuts. This loop leaves the module-level names
    # `all_dicts`, `spatial_cuts` and `dic` bound to the last group.
    for all_dicts in all_dicts_list:
        spatial_cuts = all_dicts[0]["spatial_cuts"]
        for dic in all_dicts:
            assert dic["spatial_cuts"] == spatial_cuts, "The spatial cuts were not the same across the dicts!"

In [None]:
if logSampling:
    # One figure per spatial-cut directory; each figure compares the two age-cut populations.
    # Disabled alternative: the 5000-repeat runs in "3.5b4.5/0R3.5/-2l2/" and "3.5b4.5/0R2/-2l2/"
    # (files "anicorr_<age>_5000repeats_50size5000_100stepsLog.json").
    log_cut_dirs = [
        "1.5b2/0R3.5/-2l2/",
        "1.5b2/0R2/-2l2/",
        "3b6/0R3.5/-2l2/",
        "3b6/0R2/-2l2/",
    ]
    log_age_tags = ["4age7", "9.5age10"]

    all_dicts_list = []
    for cut_dir in log_cut_dirs:
        group = []
        for age_tag in log_age_tags:
            json_path = base_path + cut_dir + "anicorr_" + age_tag + "_500repeats_50size5000_100stepsLog.json"
            group.append(MF.load_dic_from_json(filename=json_path))
        all_dicts_list.append(group)

    # 100 log-spaced integer sizes from 50 to 5000, as encoded in the filenames ("50size5000_100stepsLog").
    log_lo, log_hi = np.log10(50), np.log10(5000)
    sampling_sizes = np.int64(np.round(10**np.linspace(log_lo, log_hi, 100)))

    # Dicts plotted together must share identical spatial cuts.
    for all_dicts in all_dicts_list:
        spatial_cuts = all_dicts[0]["spatial_cuts"]
        assert all(d["spatial_cuts"] == spatial_cuts for d in all_dicts), "The spatial cuts were not the same across the dicts!"

In [None]:
# Render every figure variant for each spatial-cut group (saved to disk, not shown).
# Log-sampled runs are only drawn on a log x-axis.
if logSampling:
    axis_log_combos = [(True, False), (True, True)]
else:
    axis_log_combos = [(False, False), (True, False), (True, True)]

# First pass keeps plot_error_vs_N's default fit option; second pass passes same_youngold_fits=False.
fit_variants = [{}, {"same_youngold_fits": False}]

for all_dicts in all_dicts_list:
    for fit_kwargs in fit_variants:
        for xlog, ylog in axis_log_combos:
            plot_error_vs_N(all_dicts, xlog, ylog, save_bool=True, show_bool=False,
                            logSampling=bool(logSampling), **fit_kwargs)

#### Individually: all four populations (young/old, R<3.5 and R<2) in a single figure

In [None]:
# Populations drawn together in the single combined figure below.
all_dicts = [
    MF.load_dic_from_json(filename=base_path+"3b6/0R3.5/-2l2/anicorr_4age7_500repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+"3b6/0R3.5/-2l2/anicorr_9.5age10_500repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+"3b6/0R2/-2l2/anicorr_4age7_500repeats_50size5000_100stepsLog.json"),
    MF.load_dic_from_json(filename=base_path+"3b6/0R2/-2l2/anicorr_9.5age10_500repeats_50size5000_100stepsLog.json"),
#     MF.load_dic_from_json(filename=base_path+"3b6/0R3.5/-2l2/anicorr_4age7_5000maxsize_5000repeats.json"),
#     MF.load_dic_from_json(filename=base_path+"3b6/0R3.5/-2l2/anicorr_9.5age10_5000maxsize_5000repeats.json"),
]

colors = ["blue","red","cyan","orange"]
labels = ["Young (R<3.5)","Old (R<3.5)", "Young (R<2)","Old (R<2)"]

# Per-population legend entries are currently hidden (label=None), as before;
# set True to label each curve with `labels`.
show_population_labels = False

# zip() would silently skip dicts without a color/label, which only surfaces later as a KeyError.
assert len(all_dicts) <= len(colors) and len(all_dicts) <= len(labels), "Need a color and a label for every dict"

for dic,color,label in zip(all_dicts,colors,labels):
    dic["color"] = color
    dic["label"] = label if show_population_labels else None

In [None]:
# logSampling = False
logSampling = True

# Sample sizes matching the loaded runs: 100 log-spaced integers from 50 to 5000,
# or the linear grid 50, 100, ..., 5000.
if not logSampling:
    sampling_sizes = np.arange(50,5000+50,50)
else:
    log_lo, log_hi = np.log10(50), np.log10(5000)
    sampling_sizes = np.int64(np.round(10**np.linspace(log_lo, log_hi, 100)))

In [None]:
# Largest sample size drawn on the x-axis, plus tick spacing used when the x-axis is linear.
max_sampling_size = max(sampling_sizes)
major_locator, minor_locator = 1000, 250
xtick_step = 500  # NOTE(review): appears to be used only by a commented-out set_xticks call

In [None]:
# Axis scales for the combined figure (False -> linear).
# ylog_bool = False
ylog_bool = True

# xlog_bool = False
xlog_bool = True

# Log-spaced sample sizes are only meaningful on a log x-axis.
if logSampling:
    if not xlog_bool:
        raise ValueError("Cannot use logSampling with non-log x-axis")

In [None]:
# Overlay fitted error-vs-sample-size curves on the main panels.
fit_bool = True
# fit_bool = False

# Per statistic: True -> a single grey fit to all_dicts[0]; False -> separate fits to all_dicts[0] and all_dicts[1]
# (for "correlation", the analytic Pearson SE is drawn instead when show_true_pearson_standard_error is True).
same_fits_all_dicts = {
    "anisotropy": False,
    "correlation": False,
    "tilt_abs": False
}

# fit_func = ReciprocalFunc(p0=[1,10,10])
fit_func = InverseSquareFunc(p0=[1])

# Whether to write the fitted parameters on the plot.
# fit_params_text_bool = True
fit_params_text_bool = False

# Anchor (in axes fractions) for the fit-parameter text. Defined unconditionally so downstream plotting code
# can reference x_eq/y_eq regardless of fit_params_text_bool (otherwise a fresh kernel raises NameError).
if xlog_bool and ylog_bool:
    x_eq,y_eq = 0.75, 0.15
elif xlog_bool:
    x_eq,y_eq = 0.65,0.4
else:
    x_eq,y_eq = 0.4, 0.4

In [None]:
# Overlay the analytic standard error of Pearson's correlation on the correlation row.
# show_true_pearson_standard_error = False
show_true_pearson_standard_error = True

def pearson_standard_error(n_plot, correlation):
    """Approximate standard error of a Pearson correlation coefficient.

    Parameters
    ----------
    n_plot : int or numpy array
        Sample size(s); evaluated element-wise.
    correlation : float
        Correlation coefficient rho.

    Returns
    -------
    (1 - rho**2) / sqrt(n_plot - 3)
    """
    # n-3 form: https://doi.org/10.1525/collabra.87615
    # Alternatives kept for reference:
    #   (1-correlation**2)/np.sqrt(n_plot - 2)  https://en.youscribe.com/BookReader/Index/520541/?documentId=491664
    #   (1-correlation**2)/np.sqrt(n_plot)      https://www.tandfonline.com/doi/abs/10.1080/01621459.1928.10502991
    one_minus_rho_sq = 1 - correlation**2
    return one_minus_rho_sq / np.sqrt(n_plot - 3)

# Legend label matching the formula implemented above.
pearson_se_label = r"$\frac{1-\rho^2}{\sqrt{n-3}}$"

In [None]:
# Optional fixed y-limits per statistic; currently all disabled, so the plotting cell leaves y-limits alone.
hard_coded_ylims_dict = {
#     "anisotropy": [0 if not ylog_bool else 0.001,0.3],
#     "correlation": [0 if not ylog_bool else 0.001,0.149],
#     "tilt_abs": [0 if not ylog_bool else 0.1,33]
}

# On log-log axes, scale every fixed upper limit by 1.5 (lists are modified in place).
if xlog_bool and ylog_bool:
    for ylims in hard_coded_ylims_dict.values():
        ylims[1] *= 1.5

In [None]:
# Broken x-axis layout: extra narrow panels showing the full-sample bootstrap (and MC) errors.
# brokenaxes = True
brokenaxes = False

# MC distance-error markers only on broken, log-log figures; bootstrap markers whenever axes are broken.
MC_error_bool = brokenaxes and xlog_bool and ylog_bool
bootstrap_error_bool = brokenaxes

if brokenaxes:
    figsize = (8,10)
    # Smaller ticks suit the narrow side panels.
    for tick_axis in ("xtick", "ytick"):
        plt.rcParams[f"{tick_axis}.major.size"] = 6
        plt.rcParams[f"{tick_axis}.minor.size"] = 3
else:
    figsize = (10,10)
plt.rcParams["font.size"] = 19

In [None]:
# True -> write the figure as PNG before showing it; False -> only show it.
save_bool = False
# save_bool = True

In [None]:
# Layout: one row per statistic in function_dict. Columns are [main panel: error vs sample size,
# narrow panel for all_dicts[0] at its total N, narrow panel for all_dicts[1] at its total N].
width_broken_axes = 0.13
nrows = len(function_dict)

# squeeze=False keeps `axs` two-dimensional even when function_dict has a single entry.
fig,axs = plt.subplots(figsize=figsize,ncols=3,nrows=nrows,squeeze=False,
                       gridspec_kw={"width_ratios":[1]+2*[width_broken_axes],"wspace":0.1,"hspace":0})

# Sample-size grid for fitted curves and for the analytic Pearson SE. Defined once, outside `fit_bool`,
# so the Pearson curve can still be drawn when fitting is switched off.
x_plot = np.linspace(min(sampling_sizes),max_sampling_size,500)
# Sampled sizes that are shown (and fitted).
shown = sampling_sizes <= max_sampling_size

for row,func in enumerate(function_dict):
    lax,cax,rax = axs[row]

    # --- broken-axis decoration, or drop the side panels ---
    if brokenaxes:
        d = 0.02  # half-length of the diagonal break marks, in axes fraction
        d_factor = 1/width_broken_axes  # narrow panels are width_broken_axes as wide, so stretch their marks

        for ax in [lax,cax]:
            ax.spines['right'].set_visible(False)
            ax.tick_params(which='both',right=False)

        lax.plot((1-d,1+d), (-d,d), transform=lax.transAxes, color='k', clip_on=False,lw="1")
        lax.plot((1-d,1+d),(1-d,1+d), transform=lax.transAxes, color='k', clip_on=False,lw="1")

        cax.plot((1-d_factor*d,1+d_factor*d), (-d,d), transform=cax.transAxes, color='k', clip_on=False,lw="1")
        cax.plot((1-d_factor*d,1+d_factor*d), (1-d,1+d), transform=cax.transAxes, color='k', clip_on=False,lw="1")

        for ax in [cax,rax]:
            ax.spines['left'].set_visible(False)
            ax.tick_params(which='both',left=False)

            ax.plot((-d_factor*d,+d_factor*d), (1-d,1+d), transform=ax.transAxes, color='k', clip_on=False,lw="1")
            ax.plot((-d_factor*d,+d_factor*d), (-d,d), transform=ax.transAxes, color='k', clip_on=False,lw="1")
    else:
        fig.delaxes(cax)
        fig.delaxes(rax)

    # --- data ---
    # Standard error vs sample size for every population.
    for dic in all_dicts:
        lax.plot(sampling_sizes[shown],np.array(dic[func+"_errors"])[shown],
                 label=dic["label"] if row==0 else None,color=dic["color"],alpha=0.7)

    # Full-sample bootstrap (and optionally MC distance) errors; only the first two dicts get a side panel.
    for dic,ax in zip(all_dicts,[cax,rax]):
        ax.scatter(dic["total_N"],dic[func+"_bootstrap_error"],marker="*",color=dic["color"])

        if MC_error_bool:
            ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_low"],marker="v",color=dic["color"],s=23)
            ax.scatter(dic["total_N"],dic[func+"_MC_d_0.2_error_high"],marker="^",color=dic["color"],s=23)

        ax.set_xticks([dic["total_N"]])

    # Legend-only markers placed off-screen (x=-100).
    if bootstrap_error_bool:
        lax.scatter(x=-100,y=0,marker="*",color="k",label="Bootstrap error")

    if MC_error_bool:
        lax.scatter(x=-100,y=0,marker="v",color="k",label="MC 20% distance error",s=23)

    # --- ticks, limits, log scales ---
    if xlog_bool:
        lax.set_xscale("log")

    lax_leftlim = 40 if xlog_bool else 0
    lax_rightlim = max_sampling_size+minor_locator*0.9 if not logSampling else max_sampling_size+10**(MF.get_exponent(max_sampling_size))
    lax.set_xlim(lax_leftlim,lax_rightlim)

    if not xlog_bool:
        lax.xaxis.set_major_locator(ticker.MultipleLocator(major_locator))
        lax.xaxis.set_minor_locator(ticker.MultipleLocator(minor_locator))

    for ax in [lax,cax,rax]:
        ax.tick_params(axis='x', which='major', pad=10)

        if ylog_bool:
            ax.set_yscale("log")

        if ax in [cax,rax]:
            ax.yaxis.set_ticklabels([])

        if func in hard_coded_ylims_dict:
            ax.set_ylim(hard_coded_ylims_dict[func])

        # Rows share the x-axis visually; only the bottom row keeps x tick labels.
        if row != nrows - 1:
            ax.xaxis.set_ticklabels([])

    # --- fits ---
    if fit_bool:
        if same_fits_all_dicts[func]:
            # One fit to the first population, drawn in grey.
            fit_func.fit(x = sampling_sizes[shown], y = np.array(all_dicts[0][func+"_errors"])[shown])
            lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color="grey", linestyle="--")

            if fit_params_text_bool:
                show_fit_params_text(ax=lax,Func=fit_func,color="grey",x_eq=x_eq,y_eq=y_eq)
        elif not (func == "correlation" and show_true_pearson_standard_error):
            # Separate fits for the first two populations, coloured like their data curves.
            # Text positions are only computed when the parameters are actually written on the plot.
            text_ys = [y_eq-0.04, y_eq+0.04] if fit_params_text_bool else [None, None]
            for pop_idx,text_y in zip([0,1],text_ys):
                color = all_dicts[pop_idx]["color"]
                fit_func.fit(x = sampling_sizes[shown], y = np.array(all_dicts[pop_idx][func+"_errors"])[shown])

                lax.plot(x_plot, fit_func.func(x_plot, *fit_func.fit_params), color=color, linestyle="--",alpha=0.5)

                if fit_params_text_bool:
                    show_fit_params_text(ax=lax,Func=fit_func,alpha=0.7,color=color,x_eq=x_eq,y_eq=text_y,units=mapf.get_units(func))

        if row == 0:
            # Legend-only proxy for the fit curves; holds no data, so it cannot affect the axes limits.
            lax.plot([],[],linestyle="--",color="grey",label=fit_func.label)

    # --- analytic Pearson standard error on the correlation row ---
    if func == "correlation" and show_true_pearson_standard_error:
        for dic in all_dicts:
            lax.plot(x_plot,pearson_standard_error(x_plot,dic["correlation"]),color=dic["color"],alpha=0.5,linestyle="dotted")

        lax.plot([0],[0],color="k",linestyle="dotted",label=pearson_se_label) # for legend

    # --- labels, legend, panel titles ---
    if row == nrows-1:
        if brokenaxes:
            lax.text(s="Sample size",x=0.5,y=-0.25,transform=lax.transAxes,size="medium")
        else:
            lax.set_xlabel("Sample size")

    if xlog_bool and ylog_bool:
        x_func_text = 0.98
    elif xlog_bool:
        x_func_text = 0.6
    else:
        x_func_text = 0.2

    func_str = mapf.get_kinematic_titles_dict("r","l")[func.removesuffix("_abs")]

    if not brokenaxes:
        x_func_text -= 0.27 if xlog_bool else 0
    fig.text(s=func_str,x=x_func_text,y=0.83,transform=lax.transAxes,size="small")

    if row == 0:
        if brokenaxes:
            fig.legend(loc=(0.65,0.77) if not (xlog_bool and ylog_bool) else (0.16,0.7),framealpha=0 if xlog_bool and ylog_bool else 1)
        else:
            lax.legend(loc="lower left")
    elif row == 1:
        lax.set_ylabel(r"Standard error")

        if show_true_pearson_standard_error and func == "correlation":
            lax.legend(loc="lower left" if xlog_bool and ylog_bool else "upper right")

# Align y-labels across rows once all labels exist.
fig.align_labels()
        
# --- filename, save and show ---
# Output directory is derived from the spatial cuts of the first dict.
save_path = get_save_path_spatial_cuts(save_path=base_path, spatial_cuts_dict=all_dicts[0]["spatial_cuts"])

if list(function_dict.keys()) == ["anisotropy","correlation","tilt_abs"]:
    filename = "anicorr"
else:
    raise ValueError("Please specify map list name")

# Spatial-cut tag from all_dicts[0], the same dict used for save_path, rather than whichever
# `dic` was left over from the plotting loops.
filename += "_"+MF.extract_str_from_cuts_dict(all_dicts[0]["spatial_cuts"])

for pop_dic in all_dicts:
    filename += '_'+MF.extract_str_from_cuts_dict(pop_dic["pop_cuts"])

if xlog_bool and ylog_bool:
    filename += "_xylog"
elif xlog_bool:
    filename += "_xlog"
elif ylog_bool:
    filename += "_ylog"

# NOTE(review): `repeats` is not defined in this section and must come from an earlier cell;
# confirm it matches the loaded files (their names say "500repeats").
filename += f"_{repeats}repeats"
filename += f"_{min(sampling_sizes)}size{max(sampling_sizes)}"
filename += f"_{len(sampling_sizes)}steps" + ("Log" if logSampling else "")

if fit_bool:
    # In this section the fit choice is per statistic (same_fits_all_dicts); `same_youngold_fits`
    # is not set here, so tag "_diffFits" whenever any statistic uses separate fits.
    if not all(same_fits_all_dicts.values()):
        filename += "_diffFits"
    if not fit_params_text_bool:
        filename += "_noFitParams"
else:
    filename += "_noFit"

if MC_error_bool:
    filename += "_MC"

if not brokenaxes:
    filename += "_noBroken"

print(filename)

if save_bool:
    print("Saving in",save_path)
    plt.savefig(save_path+filename+".png",dpi=200,bbox_inches="tight")
plt.show()