In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import os
from tqdm import tqdm

from holodeck import plot, detstats, utils
from holodeck.constants import YR, MSOL, MPC, GYR
import holodeck as holo

Setup

In [None]:
# --- Run sizes: these appear in the cached file names, so they must match the runs on disk.
NVARS = 21        # parameter variations per target (alternative run: 6)
NREALS = 500      # realizations (alternative run: 20)
NSKIES = 100      # sky realizations (alternative run: 15)
NFREQS = 40       # frequency bins
NLOUDEST = 10     # loudest sources kept (last axis of the weighted arrays below)
SHAPE = None      # SAM shape: passed as sam_shape and formatted into file names
NPSRS = 40        # presumably the number of pulsars in the PTA

# --- Detection-statistics settings
TOL = 0.01
MAXBADS = 5

# --- Workflow flags
BUILD_ARRAYS = False  # True: recompute and cache the figdata/freq_* arrays
SAVEFIG = True

# --- Frequency statistic to build (the build cells use MEDIAN first, then AVG)
MEDIAN = False  # gamma_ssi-weighted mean over freq & loudest axes -> 'freq_means_*'
AVG = True      # single gamma_ssi-weighted mean over all axes     -> 'freq_avg_*'

Get parameter names

In [None]:
# Build the parameter space only to read its parameter names; one sample suffices.
pspace = holo.param_spaces.PS_Uniform_09B(
    holo.log,
    nsamples=1,
    sam_shape=SHAPE,
    seed=None,
)
param_names = pspace.param_names

Helper functions: load the data and detection-statistics (dets) files, and add a frequency twin axis

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma = None, red2white=None,
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'
):
    """Load the simulation data, parameters, and detection statistics of one target.

    Files are read from the run directory
    ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}``:
    ``data_params.npz`` and ``detstats_s{nskies}{noise_tag}.npz``.

    Parameters
    ----------
    target : str
        Name of the varied parameter; selects the run directory.
    nvars, nreals, nskies : int
        Number of variations, realizations, and skies encoded in the file names.
    shape : object
        SAM shape, formatted into the directory name with ``str``.
    red_gamma, red2white : float or None
        Red-noise settings. Only when BOTH are given is the red-noise detstats file
        (``_r2w{red2white:.1f}_rg{red_gamma:.1f}``) loaded; otherwise the
        white-noise file (``_white``) is used.
    path : str or os.PathLike
        Base output directory.

    Returns
    -------
    data, params, dsdat
        ``data`` and ``params`` from ``data_params.npz``; ``dsdat`` from the detstats file.

    Raises
    ------
    FileNotFoundError
        If either file is missing (a subclass of ``Exception``, so existing
        handlers still catch it).
    """
    run_dir = os.path.join(path, f'{target}_v{nvars}_r{nreals}_shape{str(shape)}')
    load_data_from_file = os.path.join(run_dir, 'data_params.npz')

    if red_gamma is not None and red2white is not None:
        noise_tag = f'_r2w{red2white:.1f}_rg{red_gamma:.1f}'
    else:
        noise_tag = '_white'
    load_dets_from_file = os.path.join(run_dir, f'detstats_s{nskies}{noise_tag}.npz')

    for kind, fname in (('data', load_data_from_file), ('dets', load_dets_from_file)):
        if not os.path.exists(fname):
            raise FileNotFoundError(
                f"load {kind} file '{fname}' does not exist, you need to construct it.")

    # allow_pickle is required because entries are indexed by string keys elsewhere
    # (presumably object arrays of dicts); only load trusted, locally produced files.
    # The context managers close the npz handles even if a key lookup fails.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']

    with np.load(load_dets_from_file, allow_pickle=True) as file:
        dsdat = file['dsdat']

    return data, params, dsdat

In [None]:
def _twiny_hz(ax, nano=True, fs=10, label=True, **kw):
    """Add a right-hand twin y-axis showing ``ax``'s y-range in Hz (or nHz).

    The y-values of ``ax`` are taken to be frequencies in yr^-1 (as plotted in
    this notebook); dividing by ``YR`` converts them to Hz.

    The twin's limits are copied from ``ax`` once, at call time. Call this only
    after ``ax`` and every axis sharing its y-axis have been plotted, or the two
    scales can disagree.

    Parameters
    ----------
    ax : matplotlib Axes
        Parent axes whose y-range is mirrored.
    nano : bool
        Show nHz if True, Hz otherwise.
    fs : float
        Font size of the twin-axis label.
    label : bool
        Whether to add a label to the twin axis.
    **kw
        Forwarded to ``set_ylabel``.

    Returns
    -------
    The twin Axes.
    """
    tw = ax.twinx()
    tw.grid(False)
    ylim = np.array(ax.get_ylim()) / YR
    if nano:
        ylim *= 1e9

    tw.set(ylim=ylim, yscale=ax.get_yscale())
    if label:
        # Without nano the twin shows Hz (not yr^-1), so use the Hz label.
        ylabel = plot.LABEL_GW_FREQUENCY_NHZ if nano else plot.LABEL_GW_FREQUENCY_HZ
        tw.set_ylabel(ylabel, fontsize=fs, **kw)
    return tw

### Build arrays

In [None]:
if BUILD_ARRAYS:
    # For each parameter variation of each target, reduce the single-source
    # frequencies to a gamma_ssi-weighted statistic and cache it under figdata/.
    if not (MEDIAN or AVG):
        raise ValueError("Set MEDIAN or AVG to True to select which frequency statistic to build.")
    targets = [
        'gsmf_phi0', 'mmb_mamp_log10', 'hard_time',
        'gsmf_mchar0_log10', 'mmb_scatter_dex', 'hard_gamma_inner']
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
    for target in tqdm(targets):
        print(target)
        xx = []  # parameter value of each variation
        yy = []  # weighted frequency statistic of each variation

        # white noise only
        data, params, dsdat = get_data(target)
        # View the bin-center frequencies with shape (NFREQS, NREALS, NSKIES, NLOUDEST)
        # so they can be weighted element-wise by gamma_ssi (np.average requires
        # matching shapes). Broadcasting avoids materializing a repeated copy.
        fobs = np.asarray(data[0]['fobs_cents'])
        freqs = np.broadcast_to(
            fobs[:, np.newaxis, np.newaxis, np.newaxis],
            (NFREQS, NREALS, NSKIES, NLOUDEST))

        for pp, par in enumerate(params):
            xx.append(par[target])
            dpssi = dsdat[pp]['gamma_ssi']
            if MEDIAN:
                # NB: despite the flag name, this is a weighted *mean* over the
                # frequency (first) and loudest-source (last) axes.
                freq_mean = np.average(freqs, weights=dpssi, axis=(0, -1))
            else:  # AVG: one weighted mean over all axes
                freq_mean = np.average(freqs, weights=dpssi)
            yy.append(freq_mean)

        stat_tag = 'freq_means' if MEDIAN else 'freq_avg'
        saveloc = path + f'/{stat_tag}_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz'
        np.savez(saveloc, xx_params=xx, yy_fmeans=yy)
# With BUILD_ARRAYS False, the plotting cells load the cached figdata/freq_avg_* files.


In [None]:
# 2x3 overview: gamma_ssi-weighted average single-source frequency versus each
# varied parameter, with a nHz twin axis on the right-most column.
targets = [
    'gsmf_phi0', 'gsmf_mchar0_log10',
    'mmb_mamp_log10', 'mmb_scatter_dex',
    'hard_time', 'hard_gamma_inner'
    ]
# Color by model component (the target-name prefix) rather than by panel position,
# so gsmf_* / mmb_* / hard_* parameters keep their component's color regardless of
# panel order, matching the single-column figures.
component_colors = {'gsmf': '#336948', 'mmb': '#9e5c41', 'hard': '#2d839f'}
colors = [component_colors[target.split('_')[0]] for target in targets]
path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'

ylabel = 'Frequency [yr$^{-1}$]'
ylabel_nHz = 'Frequency [nHz]'

fig, axs = plot.figax_double(nrows=2, ncols=3, sharey=True, sharex=False, xscale='linear')

fig.text(0.06, 0.5, ylabel, ha='right', va='center', rotation='vertical')
fig.text(1.02, 0.5, ylabel_nHz, ha='right', va='center', rotation='vertical')
fig.subplots_adjust(wspace=0)

panels = axs.flatten()
for ax, target, color in zip(panels, targets, colors):
    ax.set_xlabel(plot.PARAM_KEYS[target])
    fname = path + f'/freq_avg_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz'
    with np.load(fname) as file:  # close the npz handle once the arrays are read
        xx = file['xx_params']
        yy = file['yy_fmeans'] * YR
    ax.plot(xx, yy, color=color)

# Add the nHz twin axes only after every panel is drawn: the y-axis is shared, so a
# twin created mid-loop would copy limits that later panels can still expand.
for ax in panels[2::3]:  # right-most column
    _twiny_hz(ax, label=False, nano=True)

if SAVEFIG:
    saveloc = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots'
    savename = saveloc + '/favg_double.png'
    fig.savefig(savename, dpi=100, bbox_inches='tight')

In [None]:

# 3x2 single-column version: one row per model component (GSMF, MMB, hardening).
targets = [
    'gsmf_phi0', 'gsmf_mchar0_log10',
    'mmb_mamp_log10', 'mmb_scatter_dex',
    'hard_time', 'hard_gamma_inner']
colors = [
    '#336948', '#336948',
    '#9e5c41', '#9e5c41',
    '#2d839f', '#2d839f',
]
path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'

ylabel = 'Frequency [yr$^{-1}$]'
ylabel_nHz = 'Frequency [nHz]'
fig, axs = plot.figax_single(nrows=3, ncols=2,
                             sharey=True, sharex=False, xscale='linear',
                             height=7)

panels = axs.flatten()
for ax, target, color in zip(panels, targets, colors):
    ax.set_xlabel(plot.PARAM_KEYS[target])
    fname = path + f'/freq_avg_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz'
    with np.load(fname) as file:  # close the npz handle once the arrays are read
        xx = file['xx_params']
        yy = file['yy_fmeans'] * YR
    print(f"{yy.shape=}")
    ax.plot(xx, yy, color=color)

# Add the nHz twin axes only after every panel is drawn: the y-axis is shared, so a
# twin created mid-loop would copy limits that later panels can still expand.
for ax in panels[1::2]:  # right-hand column
    _twiny_hz(ax, label=False)

fig.text(0.02, 0.5, ylabel, ha='right', va='center', rotation='vertical')
fig.text(1.12, 0.5, ylabel_nHz, ha='right', va='center', rotation='vertical')
fig.subplots_adjust(wspace=0.05, hspace=0.35)
if SAVEFIG:
    saveloc = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots'
    savename = saveloc + '/favg_single.png'
    fig.savefig(savename, dpi=100, bbox_inches='tight')

# Add GW-only comparison

In [None]:
if BUILD_ARRAYS:
    # Same reduction as for the redz runs, but for the GW-only (anatomy_7GW) runs.
    # Only the GSMF and MMB targets are built; results are cached as figdata/*_gw_*.
    if not (MEDIAN or AVG):
        raise ValueError("Set MEDIAN or AVG to True to select which frequency statistic to build.")
    targets = [
        'gsmf_phi0',  'mmb_mamp_log10',
        'gsmf_mchar0_log10', 'mmb_scatter_dex', ]
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
    for target in tqdm(targets):
        print(target)
        xx = []  # parameter value of each variation
        yy = []  # weighted frequency statistic of each variation

        # white noise only
        data, params, dsdat = get_data(target, path = '/Users/emigardiner/GWs/holodeck/output/anatomy_7GW')
        # Frequency bin centers from the first variation (was `data[var]`, an
        # undefined name). Broadcast to (NFREQS, NREALS, NSKIES, NLOUDEST) so they
        # can be weighted element-wise by gamma_ssi without a repeated copy.
        fobs = np.asarray(data[0]['fobs_cents'])
        freqs = np.broadcast_to(
            fobs[:, np.newaxis, np.newaxis, np.newaxis],
            (NFREQS, NREALS, NSKIES, NLOUDEST))

        for pp, par in enumerate(params):
            xx.append(par[target])
            dpssi = dsdat[pp]['gamma_ssi']
            if MEDIAN:
                # NB: despite the flag name, this is a weighted *mean* over the
                # frequency (first) and loudest-source (last) axes.
                freq_mean = np.average(freqs, weights=dpssi, axis=(0, -1))
            else:  # AVG: one weighted mean over all axes
                freq_mean = np.average(freqs, weights=dpssi)
            yy.append(freq_mean)

        stat_tag = 'freq_means_gw' if MEDIAN else 'freq_avg_gw'
        saveloc = path + f'/{stat_tag}_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz'
        np.savez(saveloc, xx_params=xx, yy_fmeans=yy)
# With BUILD_ARRAYS False, the GW comparison plot loads the cached figdata/freq_avg_gw_* files.


In [None]:

# Weighted-average frequency (colored) compared with the GW-only-hardening runs
# (black dashed) for each parameter.
targets = [
    'gsmf_phi0', 'gsmf_mchar0_log10',
    'mmb_mamp_log10', 'mmb_scatter_dex',
    'hard_time', 'hard_gamma_inner']
colors = [
    '#336948', '#336948',
    '#9e5c41', '#9e5c41',
    '#2d839f', '#2d839f',
]
path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'

ylabel = 'Frequency [yr$^{-1}$]'
ylabel_nHz = 'Frequency [nHz]'
fig, axs = plot.figax_single(nrows=3, ncols=2,
                             sharey=True, sharex=False, xscale='linear',
                             height=7)

panels = axs.flatten()
for ii, (ax, target) in enumerate(zip(panels, targets)):
    ax.set_xlabel(plot.PARAM_KEYS[target])
    fname = path + f'/freq_avg_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz'
    with np.load(fname) as file:  # close the npz handle once the arrays are read
        xx = file['xx_params']
        yy = file['yy_fmeans'] * YR
    print(f"{yy.shape=}")
    ax.plot(xx, yy, color=colors[ii], alpha=0.9)

    if ii < 4:
        # GW-only counterpart of this GSMF / MMB parameter, on its own x-values.
        fname_gw = path + f'/freq_avg_gw_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz'
        with np.load(fname_gw) as file:
            xx_gw = file['xx_params']
            yy_gw = file['yy_fmeans'] * YR
    else:
        # freq_avg_gw_* files are only built for the gsmf/mmb targets, so the
        # hardening panels (hard_time, hard_gamma_inner) get a constant reference:
        # the central variation (presumably fiducial) of the last GW-only run
        # loaded (mmb_scatter_dex, ii == 3). The central index replaces the old
        # hard-coded 10, which only worked for NVARS == 21.
        if ii == 4:
            yy_gw_ref = yy_gw[len(yy_gw) // 2]
        xx_gw = xx
        yy_gw = np.full(len(xx), yy_gw_ref)

    ax.plot(xx_gw, yy_gw, color='k', alpha=0.75, linestyle='--')

    # print min and max freq
    print(f"{np.min(yy)=}, {np.max(yy)=}, {np.min(yy_gw)=}, {np.max(yy_gw)=}")

# Add the nHz twin axes only after every panel is drawn: the y-axis is shared, so a
# twin created mid-loop would copy limits that later panels can still expand.
for ax in panels[1::2]:  # right-hand column
    _twiny_hz(ax, label=False)

fig.text(0.04, 0.5, ylabel, ha='right', va='center', rotation='vertical')
fig.text(1.09, 0.5, ylabel_nHz, ha='right', va='center', rotation='vertical')
fig.subplots_adjust(wspace=0.05, hspace=0.35)
if SAVEFIG:
    saveloc = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots'
    savename = saveloc + '/favg_gw_single.png'
    fig.savefig(savename, dpi=100, bbox_inches='tight')