In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl
from tqdm import tqdm


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim
import os


In [None]:
# ---- Sizes embedded in the on-disk file names ---------------------------------
# The loader helpers below build paths such as
# ``{target}_v{NVARS}_r{NREALS}_shape{SHAPE}`` and ``..._s{NSKIES}...``.
SHAPE = None
NVARS = 21      # alternative: 6
NREALS = 500    # alternative: 20
NSKIES = 100    # alternative: 15

# ---- Other run settings (not referenced in the visible cells) -----------------
NFREQS = 40
NLOUDEST = 10
NPSRS = 40
TOL = 0.01
MAXBADS = 5

# ---- Notebook behaviour flags -------------------------------------------------
BUILD_ARRAYS = False
SAVEFIG = True
SHOW_GW = True

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma=None, red2white=None,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz', ssn='_ssn',
):
    """Load the parameter-sweep data and detection statistics for one target.

    Parameters
    ----------
    target : str
        Name of the varied parameter; selects the run directory
        ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}``.
    nvars, nreals, nskies, shape
        Values embedded in the directory / file names.
    red_gamma, red2white : float or None
        When BOTH are given, the red-noise detstats file
        (suffix ``_r2w{red2white:.1f}_rg{red_gamma:.1f}``) is loaded; otherwise
        the white-noise file (suffix ``_white``) is loaded.
        NOTE(review): passing only one of them silently falls back to white noise.
    path : str
        Root output directory. NOTE(review): hard-coded absolute local path.
    ssn : str
        Suffix inserted after ``detstats_s{nskies}`` in the detstats file name.

    Returns
    -------
    data, params, dsdat
        The ``'data'`` and ``'params'`` entries of ``data_params.npz`` and the
        ``'dsdat'`` entry of the detstats ``.npz`` file.

    Raises
    ------
    FileNotFoundError
        If either file is missing (a subclass of ``Exception``, so existing
        ``except Exception`` handlers still catch it).
    """
    run_dir = f'{path}/{target}_v{nvars}_r{nreals}_shape{str(shape)}'
    load_data_from_file = f'{run_dir}/data_params.npz'

    if red_gamma is not None and red2white is not None:
        noise_tag = f'_r2w{red2white:.1f}_rg{red_gamma:.1f}'
    else:
        noise_tag = '_white'
    load_dets_from_file = f'{run_dir}/detstats_s{nskies}{ssn}{noise_tag}.npz'

    # Check both files up front (data first, then dets) before loading anything.
    for label, fname in (('data', load_data_from_file), ('dets', load_dets_from_file)):
        if not os.path.exists(fname):
            raise FileNotFoundError(
                f"load {label} file '{fname}' does not exist, you need to construct it.")

    # allow_pickle=True is required because the stored entries are presumably
    # object arrays (they are later indexed by key); only load trusted,
    # locally generated files this way.
    # Context managers guarantee the NpzFile handles are closed even if a key is missing.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']
    print(target, "got data")

    with np.load(load_dets_from_file, allow_pickle=True) as file:
        print(target, "loaded dets")
        print(file.files)
        dsdat = file['dsdat']

    return data, params, dsdat

def get_ratio_arrays(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, debug=False,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata', ssn='_ssn',
        red=False,
):
    """Load cached ratio arrays for one target.

    Reads ``{path}/ratio_arrays_{target}_v{nvars}_r{nreals}_s{nskies}_shape{shape}{ssn}.npz``.

    Parameters
    ----------
    target, nvars, nreals, nskies, shape, ssn
        Values embedded in the file name.
    debug : bool
        If True, print the file name and the keys stored in it.
    path : str
        Directory containing the cached arrays.
        NOTE(review): hard-coded absolute local path.
    red : bool
        If True, also return the ``'y1p5_ratio'`` and ``'y3p0_ratio'`` entries.
        NOTE(review): these keys must exist in the file, otherwise ``KeyError``.

    Returns
    -------
    tuple
        ``(xx, y0p0)`` when ``red`` is False, otherwise
        ``(xx, y0p0, y1p5, y3p0)``, where ``xx`` is the ``'xx_params'`` entry
        and ``y0p0`` the ``'yy_ratio'`` entry.
    """
    filename = path+f'/ratio_arrays_{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}{ssn}.npz'
    # Context manager closes the file even if a requested key is missing.
    with np.load(filename) as file:
        if debug:
            print(f"{filename}\n{file.files}")
        xx = file['xx_params']
        y0p0 = file['yy_ratio']
        if red:
            y1p5 = file['y1p5_ratio']
            y3p0 = file['y3p0_ratio']

    if red:
        return xx, y0p0, y1p5, y3p0
    return xx, y0p0

In [None]:
# Build and cache, for each target, the ratio ev_ss / dp_bg for every parameter
# value; the plotting cell reads these files back via ``get_ratio_arrays``.
path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
targets = [
    # 'gsmf_phi0',
    # 'gsmf_mchar0_log10',
    'mmb_mamp_log10',
    # 'mmb_scatter_dex',
    # 'hard_time',
    # 'hard_gamma_inner',
    ]
ssn = '_ssn'
os.makedirs(path, exist_ok=True)  # np.savez does not create missing directories
for target in tqdm(targets):
    # Pass ``ssn`` through so the loaded detstats match the suffix written into
    # the output file name below.
    data, params, dsdat = get_data(target, ssn=ssn)
    xx = []
    yy = []
    for pp, par in enumerate(tqdm(params)):
        xx.append(par[target])
        # dp_bg holds NREALS values; as an (NREALS, 1) column it broadcasts
        # across the second axis of ev_ss (assumed (NREALS, NSKIES) — TODO confirm).
        dp_bg = np.reshape(dsdat[pp]['dp_bg'], (NREALS, 1))
        dp_ss = dsdat[pp]['ev_ss']
        yy.append(dp_ss/dp_bg)
    np.savez(path+f'/ratio_arrays_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}{ssn}.npz',
             xx_params=xx, yy_ratio=yy)


In [None]:
ylabel = plot.LABEL_EVDP_RATIO
targets = [
    'gsmf_phi0', 'gsmf_mchar0_log10',
    'mmb_mamp_log10',  'mmb_scatter_dex',
    'hard_time',  'hard_gamma_inner'
    ]

col0p0 = [
    '#336948', '#336948',
    '#9e5c41', '#9e5c41',
    '#2d839f', '#2d839f',
    ]

col_gw = [
    'k', 'k',
    'k', 'k',
    'k', 'k'
    ]

# set which arrays are using ssn
ssn_arr = [
    '_ssn', '_ssn',
    '_ssn', '',
    '_ssn', '_ssn'
]

# One (ymin, ymax) pair per subplot row, repeated for the two columns so that
# ylims[ii] lines up with axs.flatten()[ii] (row-major order) -> shape (6, 2).
ylims = np.repeat(
    np.array([[1.5e-3, 1.5e1], [1.5e-3, 9e0], [1.5e-3, 7e1]]), 2, axis=0)

# Central-interval widths (in percent) drawn as shaded bands around the median.
CONF_WIDTHS = (50, 95)

fig, axs = plot.figax_single(
    nrows=3, ncols=2, sharey=False, sharex=False, xscale='linear', height=7)
assert len(targets) == len(col0p0) == len(ssn_arr) == len(ylims) == len(axs.flatten()), \
    "need exactly one target/colour/ssn/ylim entry per subplot"

fig.text(0.04, 0.5, ylabel, ha='left', va='center', rotation=90)
plt.subplots_adjust(wspace=0.05, hspace=0.35)


for ii, ax in enumerate(axs.flatten()):
    xx, yy = get_ratio_arrays(targets[ii], red=False, ssn=ssn_arr[ii])
    ax.set_xlabel(plot.PARAM_KEYS[targets[ii]])
    ax.tick_params(axis='y', which='both', right=True, left=True, direction='in')

    # Statistics are taken over axes 1 and 2 of yy, one value per x entry.
    # The median is drawn once; each interval width adds one shaded band.
    med = np.percentile(yy, 50, axis=(1, 2))
    ax.plot(xx, med, alpha=0.9, color=col0p0[ii])
    for pp in CONF_WIDTHS:
        conf = np.percentile(yy, [50-pp/2, 50+pp/2], axis=(1, 2))
        ax.fill_between(xx, *conf, color=col0p0[ii], alpha=0.25)
    ax.set_ylim(ylims[ii])
    if ii % 2 == 1:
        # Right-column panels use the same y-range as their left neighbour;
        # hide their tick labels (major and minor) without fixing a formatter.
        ax.tick_params(axis='y', which='both', labelleft=False)
