In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl
from tqdm import tqdm


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim
import os


Set up

In [None]:
# ---- Library dimensions ----------------------------------------------------
# SHAPE is passed on as `sam_shape` and is embedded in file names ('shapeNone').
SHAPE = None
# NREALS appears as the 'r' tag in file names and as the first axis when the
# background statistic is reshaped.  Alternative value kept for reference: 20
NREALS = 500
# NFREQS and NLOUDEST are not referenced by the cells visible in this notebook.
NFREQS = 40
NLOUDEST = 10

# ---- Notebook switches -----------------------------------------------------
# BUILD_ARRAYS: when True, the ratio arrays are rebuilt and saved to disk.
BUILD_ARRAYS = False
# SAVEFIG: when True, the final figure is written to disk.
SAVEFIG = True
# TOL and MAXBADS are not referenced by the cells visible in this notebook.
TOL = 0.01
MAXBADS = 5

# ---- Variations and sky realizations ---------------------------------------
# NVARS appears as the 'v' tag in file names.  Alternative value: 6
NVARS = 21
# NPSRS is not referenced by the cells visible in this notebook.
NPSRS = 40
# NSKIES appears as the 's' tag in file names and as the second axis when the
# background statistic is reshaped.  Alternative value: 15
NSKIES = 100

Get the parameter names from the parameter space

In [None]:
# Instantiate the parameter space with a single sample; in the cells visible
# here it is only used to read off the list of parameter names.
pspace = holo.param_spaces.PS_Uniform_07_GW(
    holo.log,
    nsamples=1,
    sam_shape=SHAPE,
    seed=None,
)
param_names = pspace.param_names

Functions for loading the data files and the detection-statistics ("dets") files

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma = None, red2white=None,
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_7GW'
):
    """Load the saved data, parameters and detection statistics for one target.

    Two files are read from the directory
    ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}``:

    * ``data_params.npz`` (keys ``'data'`` and ``'params'``)
    * ``detstats_s{nskies}_<noise>.npz`` (key ``'dsdat'``), where ``<noise>`` is
      ``r2w{red2white:.1f}_rg{red_gamma:.1f}`` when BOTH `red_gamma` and
      `red2white` are given, and ``white`` otherwise.

    Parameters
    ----------
    target : str
        Name of the varied parameter; used as the prefix of the directory name.
    nvars, nreals, nskies, shape :
        Tags inserted into the directory / file names.
    red_gamma, red2white : float or None
        Red-noise tags selecting the detection-statistics file.
        NOTE: if only one of the two is given, the white-noise file is loaded
        (unchanged legacy behaviour) -- pass both to select a red-noise file.
    path : str
        Directory that contains the per-target directories.

    Returns
    -------
    data, params, dsdat
        The arrays stored under the keys ``'data'``, ``'params'`` and ``'dsdat'``.

    Raises
    ------
    FileNotFoundError
        If either file does not exist.  (Subclass of ``Exception``, so existing
        ``except Exception`` handlers keep working.)
    """
    # if path == '/Users/emigardiner/GWs/holodeck/output/anatomy_redz':
    run_dir = path+f'/{target}_v{nvars}_r{nreals}_shape{str(shape)}'
    load_data_from_file = run_dir+'/data_params.npz'
    load_dets_from_file = run_dir+f'/detstats_s{nskies}'
    # else:
    #     load_data_from_file = path+f'/{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}.npz'
    #     load_dets_from_file = path+f'/{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}_ds'
    if red_gamma is not None and red2white is not None:
        load_dets_from_file = load_dets_from_file+f'_r2w{red2white:.1f}_rg{red_gamma:.1f}'
    else:
        load_dets_from_file = load_dets_from_file+'_white'
    load_dets_from_file = load_dets_from_file+'.npz'

    # Check both files up front so that nothing is opened if one is missing.
    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)
    if not os.path.exists(load_dets_from_file):
        err = f"load dets file '{load_dets_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # SECURITY: allow_pickle=True can execute arbitrary code while unpickling;
    # only load files that you (or a trusted pipeline) produced.
    # The `with` blocks guarantee the archives are closed even if a key is missing.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']

    with np.load(load_dets_from_file, allow_pickle=True) as file:
        dsdat = file['dsdat']

    return data, params, dsdat

def get_ratio_arrays(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE,
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_7GW/figdata', red=False,
    ):
    """Load the cached ratio arrays for one target.

    Reads ``{path}/ratio_arrays_{target}_v{nvars}_r{nreals}_s{nskies}_shape{shape}.npz``.

    Parameters
    ----------
    target : str
        Name of the varied parameter; part of the file name.
    nvars, nreals, nskies, shape :
        Tags inserted into the file name.
    path : str
        Directory that contains the cached ``.npz`` files.
    red : bool
        If True, also return the arrays stored under ``'y1p5_ratio'`` and
        ``'y3p0_ratio'``.  A ``KeyError`` is raised if the file was saved
        without those keys.

    Returns
    -------
    (xx, y0p0) if `red` is False, else (xx, y0p0, y1p5, y3p0)
        ``xx`` is the array stored under ``'xx_params'`` and ``y0p0`` the one
        stored under ``'yy_ratio'``.
    """
    filename = path+f'/ratio_arrays_{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}.npz'
    # The context manager closes the archive on every exit path; the arrays are
    # read into memory at key access, so they remain valid after it is closed.
    with np.load(filename) as file:
        xx = file['xx_params']
        y0p0 = file['yy_ratio']
        if not red:
            return xx, y0p0
        y1p5 = file['y1p5_ratio']
        y3p0 = file['y3p0_ratio']
    return xx, y0p0, y1p5, y3p0

In [None]:
if BUILD_ARRAYS:
    # Rebuild, for every target, the arrays that `get_ratio_arrays` reads back:
    # the varied-parameter values and the ev_ss / dp_bg ratio.
    targets = [
        'gsmf_phi0',  'mmb_mamp_log10',
        'gsmf_mchar0_log10', 'mmb_scatter_dex', ]
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_7GW/figdata'
    # np.savez does not create missing directories, so make sure it exists.
    os.makedirs(path, exist_ok=True)
    for target in tqdm(targets):
        print(target)

        # white noise only
        data, params, dsdat = get_data(target)
        xx = []
        yy = []
        for pp, par in enumerate(params):
            xx.append(par[target])
            # Repeat each 'dp_bg' value NSKIES times so that it lines up
            # element-wise with 'ev_ss' as an (NREALS, NSKIES) array.
            dp_bg = np.repeat(dsdat[pp]['dp_bg'], NSKIES).reshape(NREALS, NSKIES)
            ev_ss = dsdat[pp]['ev_ss']
            yy.append(ev_ss/dp_bg)

        # Red-noise variants are currently disabled.  To enable them, repeat the
        # loop above on the output of
        #     get_data(target, red_gamma=-1.5, red2white=1)   -> y1p5
        #     get_data(target, red_gamma=-3.0, red2white=1)   -> y3p0
        # and pass `y1p5_ratio=y1p5, y3p0_ratio=y3p0` to np.savez below; these
        # are the keys that `get_ratio_arrays(..., red=True)` reads.

        np.savez(path+f'/ratio_arrays_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz',
                xx_params = xx, yy_ratio = yy,
         )

In [None]:
# file = np.load('/Users/emigardiner/GWs/holodeck/output/anatomy_7GW/mmb_scatter_dex_v21_r500_shapeNone/detstats_s100_r2w1.0_rg-3.0.npz')
# print(file.files)

In [None]:
# Raw string: '\m' is not a valid escape sequence, so the non-raw form emits a
# warning on current Python and is slated to become an error.  The resulting
# string value is identical.
ylabel = r'$EV_\mathrm{SS} / DP_\mathrm{BG}$'
# One entry per panel, in the order returned by `axs.flatten()`.
targets = [
    'gsmf_phi0',  'gsmf_mchar0_log10',
    'mmb_mamp_log10', 'mmb_scatter_dex',
    ]
# Line / band colour for each panel (same order as `targets`).
col0p0 = [
    '#336948',  '#336948',
    '#9e5c41',  '#9e5c41',
]

fig, axs = plot.figax_single(nrows=2, ncols=2, sharey=True, sharex=False, xscale='linear',
                             height=5)
for ii, ax in enumerate(axs.flatten()):
    target = targets[ii]
    col = col0p0[ii]
    ax.set_xlabel(plot.PARAM_KEYS[target])
    # Panels 0 and 2 are the only ones given a y-label (y axis is shared).
    if ii == 0 or ii == 2:
        ax.set_ylabel(ylabel)
    xx, yy0p0 = get_ratio_arrays(target, red=False)

    for pp in [50, 95]:
        # Median and the central pp% interval, taken over axes 1 and 2 of the
        # ratio array, leaving one value per entry of `xx`.
        med, *conf = np.percentile(yy0p0, [50, 50-pp/2, 50+pp/2], axis=(1,2))
        ax.plot(xx, med, alpha=0.9, color=col)
        ax.fill_between(xx, *conf, color=col, alpha=0.25)

# Red-noise curves are currently disabled.  They can be overlaid in the loop
# above using the extra arrays returned by `get_ratio_arrays(target, red=True)`,
# with a matching legend added on `axs[0,0]`.
plt.subplots_adjust(wspace=0.05, hspace=0.32)
# fig.tight_layout()

if SAVEFIG:
    savedir = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots'
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(savedir, exist_ok=True)
    savename = savedir+f'/ratio_vs_GWvar{NVARS}_wn_pp5095.png'
    fig.savefig(savename, dpi=100)

In [None]:
# ylabel = '$EV_\mathrm{SS} / DP_\mathrm{BG}$'
# targets = [
#     'gsmf_phi0',  'mmb_mamp_log10', 
#     'gsmf_mchar0_log10', 'mmb_scatter_dex', 
#     ]
# col0p0 = [
#     '#336948', '#9e5c41', 
#     '#336948', '#9e5c41', 
# ]
# col1p5 = [
#     '#4da169', '#e67739', 
#     '#4da169', '#e67739', 
# ]
# col3p0 = [
#     '#8fcf91', '#fda363', 
#     '#8fcf91', '#fda363', 
# ]

# fig, axs = plot.figax_double(nrows=2, ncols=3, sharey=True, sharex=False, xscale='linear')
# for ii, ax in enumerate(axs.flatten()):
#     ax.set_xlabel(plot.PARAM_KEYS[targets[ii]])
#     if ii == 0 or ii == 3:
#         ax.set_ylabel(ylabel)
#     xx, yy0p0, yy1p5, yy3p0 = get_ratio_arrays(targets[ii])
#     col=col0p0[ii]
#     # for pp in [50, 95]:
#     for pp in [68]: 
#         med, *conf = np.percentile(yy0p0, [50, 50-pp/2, 50+pp/2], axis=(1,2))
#         ax.plot(xx, med, alpha=0.9, color=col0p0[ii])
#         ax.fill_between(xx, *conf, color=col0p0[ii], alpha=0.25)

#         med, *conf = np.percentile(yy1p5, [50, 50-pp/2, 50+pp/2], axis=(1,2))
#         ax.plot(xx, med, alpha=0.9, color=col1p5[ii])
#         ax.fill_between(xx, *conf, color=col1p5[ii], alpha=0.25)

#         med, *conf = np.percentile(yy3p0, [50, 50-pp/2, 50+pp/2], axis=(1,2))
#         ax.plot(xx, med, alpha=0.9, color=col3p0[ii])
#         ax.fill_between(xx, *conf, color=col3p0[ii], alpha=0.25)

# handles = [mpl.lines.Line2D([0], [0], label='$\gamma_\mathrm{red}=0.0$', color = col0p0[0]),
#            mpl.lines.Line2D([0], [0], label='$\gamma_\mathrm{red}=-1.5$', color = col1p5[0]),
#            mpl.lines.Line2D([0], [0], label='$\gamma_\mathrm{red}=-3.0$', color = col3p0[0])]
# axs[0,0].legend(handles=handles, loc='upper left', ncol=2)

# plt.subplots_adjust(wspace=0)

# # fig.tight_layout()

# if SAVEFIG:
#     savename = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots'
#     savename = savename+f'/ratio_vs_GWvar{NVARS}_r2w_pp50_redzfixed.png'
#     # savename = savename+'/ratio_vs_var6_wn_pp5095.png'
#     fig.savefig(savename, dpi=100)