In [None]:
import os
import sys
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

import numpy as np
import time as time
import glob
import random
import re
from collections import defaultdict
import pickle

import constants

# Set up matplotlib and use a nicer set of plot parameters
# Reset the inline backend's rc overrides to an empty dict before configuring matplotlib.
%config InlineBackend.rc = {}
import matplotlib as mpl
# Math text uses the 'stixsans' font set; figures get a white background.
mpl.rc('mathtext',fontset='stixsans')
mpl.rc('figure', facecolor="white")
#matplotlib.rc_file("../../templates/matplotlibrc")
import matplotlib.pyplot as plt
#import matplotlib.colors as colors
# %matplotlib notebook

# Project-local helper; presumably applies the shared plot styling -- see the
# `plotting` module to confirm what it changes.
from plotting import plot_preamble
plot_preamble()

# Directory that is globbed for the Vega best-fit `.npz` files.
# The commented-out line is an alternative output location kept for reference.
#VEGA_OUTPUT_DIR = os.path.join(constants.BIAS_DIR_BASE, 'xcorr', 'vega-realistic-dispersion', 'bias-dispersion', 'output')
VEGA_OUTPUT_DIR = os.path.join(constants.BIAS_DIR_BASE, 'xcorr', 'vega', 'output')

In [None]:
# Best-fit file names look like
#   ...bestfit_survey_<survey>_mass_<lower>_<upper>...npz
# The three capture groups are the survey name and the two mass-bin edges
# (still strings at this point).
regex = re.compile(r'.*bestfit_survey_([^_]*)_mass_([^_]*)_([^_]*).*\.npz')


def _values_per_mass_bin():
    """Innermost level: any missing key starts out as an empty list."""
    return defaultdict(list)


def _mass_bins_per_param():
    """Middle level: any missing key starts out as a dict of lists."""
    return defaultdict(_values_per_mass_bin)


# Three-level auto-creating mapping: outer key -> middle key -> inner key -> list.
param_bootstrap_dict = defaultdict(_mass_bins_per_param)

In [None]:
# Gather every parameter stored in the Vega best-fit files into
#   param_bootstrap_dict[survey][parameter][(logmass_lb, logmass_ub)] -> list of values
# (one value appended per file), and track the smallest lower mass-bin edge seen.

# Start from an empty mapping so that re-running this cell does not append the
# same files a second time.
param_bootstrap_dict.clear()
min_mass = np.inf

# sorted() makes the processing order independent of the filesystem.
for p in sorted(glob.glob(f'{VEGA_OUTPUT_DIR}/*.npz')):
    match = regex.match(p)
    if match is None:
        # An .npz file that does not follow the best-fit naming scheme;
        # report it and move on instead of failing on `None.group`.
        print(f'Skipping {p}: file name does not match the best-fit pattern')
        continue
    survey, logmass_lb, logmass_ub = match.group(1, 2, 3)
    logmass_lb, logmass_ub = float(logmass_lb), float(logmass_ub)
    min_mass = min(logmass_lb, min_mass)
    # The context manager closes the underlying .npz archive once it has been read.
    with np.load(p) as params:
        for key in params.files:
            val = params[key]
            if val.ndim == 1:
                # Only the first element of a 1-D array is kept.
                val = val[0]
            param_bootstrap_dict[survey][key][(logmass_lb, logmass_ub)].append(val)

In [None]:
# Folder that receives the per-survey figures.
figure_folder = os.path.join(constants.FIG_DIR_BASE, 'bias-err')

# Read the pickled (mass, bias) pair of the theoretical relation. The file handle
# is only needed for the load itself; the array work happens afterwards.
theoretical_path = os.path.join(
    constants.BIAS_DIR_BASE, 'theoretical', 'sim_bias_binwidth_0.1_binstep_0.1.pickle'
)
with open(theoretical_path, 'rb') as f:
    theoretical_mass, theoretical_bias = pickle.load(f)

theoretical_mass = np.array(theoretical_mass)
theoretical_bias = np.array(theoretical_bias)

# Keep only the points whose mass lies strictly above the smallest lower
# mass-bin edge found among the best-fit files.
mask = theoretical_mass > min_mass
theoretical_mass, theoretical_bias = theoretical_mass[mask], theoretical_bias[mask]

# For every survey: draw one panel per fitted parameter showing the 16th-84th
# percentile band of the best-fit values in each mass bin, then save the figure
# as PDF and PNG. For 'bias_QSO' the per-bin medians are also pickled to disk.

# Make sure the output locations exist before anything is written to them.
bias_output_dir = os.path.join(constants.BIAS_DIR_BASE, 'vega')
os.makedirs(figure_folder, exist_ok=True)
os.makedirs(bias_output_dir, exist_ok=True)

for survey, param_dict in param_bootstrap_dict.items():
    n_panels = len(param_dict)
    # squeeze=False always yields a 2-D array of Axes, so a survey with a single
    # parameter (where plt.subplots would otherwise return a bare Axes) still works.
    fig, axes = plt.subplots(
        n_panels, 1, sharex=True, squeeze=False, figsize=(5, 2 * n_panels),
        # The last panel is drawn shorter than the others.
        gridspec_kw={'height_ratios': [1] * (n_panels - 1) + [0.3]},
    )
    axes = axes[:, 0]
    for ax, (param_name, mass_dict) in zip(axes, param_dict.items()):
        # Unique, sorted edges over all (lower, upper) mass-bin keys.
        bin_boundaries = sorted({edge for bounds in mass_dict for edge in bounds})
        x = []
        lower_bound = []
        median = []
        upper_bound = []
        # Only bins spanned by two consecutive edges are looked up.
        for bin_lb, bin_ub in zip(bin_boundaries[:-1], bin_boundaries[1:]):
            # .get() instead of [] so that a gap between bins is not inserted
            # into the defaultdict as an empty entry.
            samples = mass_dict.get((bin_lb, bin_ub))
            if not samples:
                continue
            x.append(np.mean([bin_lb, bin_ub]))
            lower_bound.append(np.percentile(samples, 16))
            median.append(np.median(samples))
            upper_bound.append(np.percentile(samples, 84))
        lower_bound, upper_bound = np.array(lower_bound), np.array(upper_bound)
        if param_name == 'bias_QSO':
            ax.set_ylabel(r'$b_{g}$')
        elif param_name == 'reduced_chisq':
            # Plotted as the offset from 1, matching the axis label.
            ax.set_ylabel(r'$\chi_r^2 - 1$')
            lower_bound -= 1
            upper_bound -= 1
        else:
            ax.set_ylabel(param_name)
        ax.fill_between(x, lower_bound, upper_bound, alpha=0.7, color='darkred', label='Vega best fits')
        if param_name == 'bias_QSO':
            # Store the bin centres and median bias for this survey.
            with open(os.path.join(bias_output_dir, f'hmass_bias_{survey}.pickle'), 'wb') as f:
                pickle.dump((x, median), f)
            ax.plot(theoretical_mass, theoretical_bias, color='grey', label=r'Theoretical $b_{g}-M_h$', zorder=-99, alpha=0.5)
            ax.set_ylim(None, 15)
        if param_name == 'sigma_velo_disp_gauss_QSO' and survey == 'CLAMATO':
            ax.axhline(3, color='black', ls='--', label=f'{survey} True dispersion')
        if param_name != 'reduced_chisq':
            ax.legend()
    # With sharex=True only the bottom panel carries the x label.
    axes[-1].set_xlabel(r'$log_{10}(M_h\; / \;M_\odot)$')
    fig.tight_layout()
    fig.savefig(os.path.join(figure_folder, f'{survey}_err.pdf'))
    fig.savefig(os.path.join(figure_folder, f'{survey}_err.png'), transparent=False, dpi=400)
    plt.show()
    # Release the figure so that looping over many surveys does not accumulate them.
    plt.close(fig)