In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py


from holodeck import plot, detstats
import holodeck.single_sources as ss
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

# Choose a Parameter Space

In [None]:
# SAM grid shape passed as `sam_shape` below; the remaining constants are forwarded to
# `holo.librarian.run_model` (presumably: realizations, frequency bins, loudest sources kept per bin)
SHAPE = (30,25,35)
NREALS = 30
NFREQS = 40
NLOUDEST = 10

# construct a param_space instance, note that `nsamples` and `seed` don't matter here for how we'll use this
pspace = holo.param_spaces.PS_Uniform_09A(holo.log, nsamples=1, sam_shape=SHAPE, seed=None)

# get the parameter names from this library-space
param_names = pspace.param_names
num_pars = len(param_names)
print(f"{num_pars=} :: {param_names=}")

# choose each parameter to be half-way across the range provided by the library
pars = 0.5 * np.ones(num_pars)
# Convert the normalized [0, 1] coordinates into actual parameter values. This is the same
# conversion used by the commented-out `model_for_params(pspace.normalized_params(pars))` line
# below, so `params` now describes the model that is actually built. (Multiplying `pars` by
# `pspace.param_samples` only rescaled the stored random samples.)
params = pspace.normalized_params(pars)
print(f"{pars=}")
print(f"{params=}")

# construct `sam` and `hard` instances based on these parameters,
# using otherwise all default parameters for this library
sam, hard = pspace.model_for_normalized_params(pars)
# sam, hard = pspace.model_for_params(pspace.normalized_params(pars)) #this is way slower, but why??

# run this model, retrieving binary parameters and the GWB
data = holo.librarian.run_model(sam, hard, NREALS, NFREQS, NLOUDEST,
                                gwb_flag=False, singles_flag=True, params_flag=True, details_flag=True)
print(f"retrieved data: {data.keys()=}")

**Open question (TODO):** what are `bin_params` and `gwb_params`?

In [None]:
# Pull the mid-range ("_mid") model outputs out of the `run_model` result dictionary.
fobs_cents = data['fobs_cents']
hc_ss_mid, hc_bg_mid = data['hc_ss'], data['hc_bg']
# ss.all_sspars combines the frequency bin centers with the stored single-source parameters
sspar_mid = ss.all_sspars(fobs_cents, data['sspar'])
bgpar_mid = data['bgpar']
binpar_names = ss.par_names

# Report shapes and parameter names/values; each line mirrors the `f"{expr=}"` format.
for label, value in (
    ("hc_ss_mid.shape", hc_ss_mid.shape),
    ("bgpar_mid.shape", bgpar_mid.shape),
    ("sspar_mid.shape", sspar_mid.shape),
    ("param_names", param_names),
    ("params", params),
    ("binpar_names", binpar_names),
):
    print(f"{label}={value!r}")

# Plot hc vs. bin pars

In [None]:
# Show the binary-parameter labels and their unit factors, one per line.
print(ss.par_labels, ss.par_units, sep='\n')

In [None]:
def draw_hs_vs_par(ax, xx_ss=None, yy_ss=None, xx_bg=None, yy_bg=None, color_ss='r', color_bg='k', fast_ss=True,
                   show_medians = False, show_ci=False, show_reals=True):
    """Scatter characteristic strain against one parameter on a single axes.

    Parameters
    ----------
    ax : axes-like
        Target axes; only its ``scatter`` method is called.
    xx_ss, yy_ss : ndarray or None
        Single-source x values and strains. Both must be given for single sources to be drawn.
        When ``fast_ss`` is False they are indexed as ``[:, rr, :]`` (3-D), one colour per index
        ``rr`` of axis 1.
    xx_bg, yy_bg : ndarray or None
        Background x values and strains; drawn with 'x' markers when both are given.
    color_ss, color_bg : color
        Colours for single-source (fast mode) and background points.
    fast_ss : bool
        True: all single sources in one scatter call using ``color_ss``.
        False: one scatter call per index of axis 1 (``len(yy_ss[0])`` calls), coloured from ``cm.rainbow``.
    show_medians, show_ci : bool
        Currently ignored (placeholders, not implemented yet).
    show_reals : bool
        If False, nothing is drawn.
    """
    if not show_reals:
        return

    have_ss = xx_ss is not None and yy_ss is not None
    have_bg = xx_bg is not None and yy_bg is not None

    if have_ss and fast_ss:
        ax.scatter(xx_ss.flatten(), yy_ss.flatten(), marker='o', s=15, alpha=0.1, color=color_ss)
    elif have_ss:
        num_slices = len(yy_ss[0])
        palette = cm.rainbow(np.linspace(0, 1, num_slices))
        for rr, slice_color in enumerate(palette):
            ax.scatter(xx_ss[:, rr, :].flatten(), yy_ss[:, rr, :].flatten(),
                       marker='o', s=10, alpha=0.1, color=slice_color)

    if have_bg:
        ax.scatter(xx_bg.flatten(), yy_bg.flatten(), marker='x', s=15, alpha=0.1, color=color_bg)




def plot_hs_vs_binpars(fobs_cents, hc_ss, hc_bg, sspar, bgpar, color_ss='r', color_bg='k', fast_ss=True):
    """Plot characteristic strain against six binary parameters on a 2x3 grid.

    Panels: mtot (0), mrat (1), redz_init (2), dc_final (4), sepa_final (5), angs_final (6).
    The numbers index ``ss.par_labels`` / ``ss.par_units`` and the leading axis of
    ``sspar`` / ``bgpar``.

    Parameters
    ----------
    fobs_cents : array_like
        Frequency bin centers. Currently unused; kept for interface compatibility.
    hc_ss, hc_bg : ndarray
        Single-source / background characteristic strain (the y values of every panel).
    sspar : ndarray
        Single-source binary parameters, parameter index on the leading axis followed by
        three further axes (units are broadcast as ``units[:, None, None, None]``).
    bgpar : ndarray
        Background binary parameters, parameter index on the leading axis followed by two
        further axes.
    color_ss, color_bg, fast_ss
        Forwarded to `draw_hs_vs_par`.

    Returns
    -------
    fig : matplotlib Figure
    """
    idx = [0, 1, 2, 4, 5, 6]

    labels = ss.par_labels[idx]
    units = ss.par_units[idx]
    # scale each selected parameter by its `par_units` factor, broadcasting over the trailing axes
    xx_ss = sspar[idx] * units[:, np.newaxis, np.newaxis, np.newaxis]
    xx_bg = bgpar[idx] * units[:, np.newaxis, np.newaxis]
    print(f"{xx_ss.shape=}")
    print(f"{xx_bg.shape=}")

    yy_ss = hc_ss
    yy_bg = hc_bg
    print(f"{yy_ss.shape=}")
    print(f"{yy_bg.shape=}")

    fig, axs = holo.plot.figax(
        nrows=2, ncols=3, sharey=True, figsize=(12,6))
    for ax in axs[:,0]:
        ax.set_ylabel(holo.plot.LABEL_CHARACTERISTIC_STRAIN)
    for ii, (ax, label) in enumerate(zip(axs.flatten(), labels)):
        ax.set_xlabel(label)
        # Keyword arguments: positionally, a 9th argument would bind to `show_medians`,
        # not to a colour parameter.
        draw_hs_vs_par(ax, xx_ss=xx_ss[ii], yy_ss=yy_ss, xx_bg=xx_bg[ii], yy_bg=yy_bg,
                       color_ss=color_ss, color_bg=color_bg, fast_ss=fast_ss)

    fig.tight_layout()
    return fig

# Draw h_c against the binary parameters for the mid-range model (single-colour fast scatter).
fig = plot_hs_vs_binpars(
    fobs_cents=fobs_cents,
    hc_ss=hc_ss_mid,
    hc_bg=hc_bg_mid,
    sspar=sspar_mid,
    bgpar=bgpar_mid,
    fast_ss=True,
)

# Calculate detstats

In [None]:
# Background detection statistics for the PTA.
# NOTE(review): called here with no arguments. It presumably needs inputs such as the
# pulsar/PTA setup, frequencies and the background strain (e.g. `hc_bg_mid`). TODO: confirm the
# signature in `holodeck.detstats` and pass them; otherwise this cell may raise and break
# Restart & Run All.
detstats.detect_bg_pta()

# Plot everything vs freqs

In [None]:
# Scratch check: does appending one extra row via `np.append` + reshape keep the existing rows
# intact? A seeded Generator makes the check reproducible on Restart & Run All.
scratch_rng = np.random.default_rng(42)
arr = scratch_rng.uniform(0, 1, 100).reshape(5, 20)
print(arr.shape)
app = scratch_rng.uniform(0, 1, 20)
# `np.append` without `axis` flattens both inputs, so the result must be reshaped; derive the
# new shape from `arr` instead of hard-coding (6, 20)
tot = np.append(arr, app).reshape(arr.shape[0] + 1, arr.shape[1])
print(tot.shape)
print(np.all(arr[2] == tot[2]))
# explicit row stacking gives the same array without the flatten/reshape round trip
assert np.array_equal(tot, np.vstack([arr, app]))
assert np.array_equal(tot[:-1], arr)

In [None]:
def plot_everything_vs_freqs(fobs_cents, hc_ss, hc_bg, sspar, bgpar, dp_ss, dp_bg, df_ss, df_bg,
                             color_ss='r', color_bg='k', fast_ss=True):
    """Plot characteristic strain against six binary parameters (work in progress).

    Draws a 2x3 grid with characteristic strain on the shared y-axis. The x-axes are mtot (0),
    mrat (1), redz_init (2), dc_final (4), sepa_final (5) and angs_final (6). The numbers index
    ``ss.par_labels`` / ``ss.par_units`` and the leading axis of ``sspar`` / ``bgpar``.

    Parameters
    ----------
    fobs_cents : array_like
        Frequency bin centers. Not used yet.
    hc_ss, hc_bg : ndarray
        Single-source / background characteristic strain.
    sspar, bgpar : ndarray
        Single-source / background binary parameters, parameter index on the leading axis.
    dp_ss, dp_bg, df_ss, df_bg : array_like
        Not used yet. TODO: add panels for these once their shapes relative to
        ``hc_ss`` / ``hc_bg`` are settled (meaning and shape not shown here -- confirm).
    color_ss, color_bg, fast_ss
        Forwarded to `draw_hs_vs_par`.

    Returns
    -------
    fig : matplotlib Figure
    """
    idx = [0, 1, 2, 4, 5, 6]

    labels = ss.par_labels[idx]
    units = ss.par_units[idx]
    # scale each selected parameter by its `par_units` factor; leading axis has len(idx) entries
    xx_ss = sspar[idx] * units[:, np.newaxis, np.newaxis, np.newaxis]
    xx_bg = bgpar[idx] * units[:, np.newaxis, np.newaxis]
    # NOTE: an earlier draft appended `hc_ss`, `dp_ss` and `dp_bg` to `xx_ss` here with `np.append`.
    # That had an unbalanced parenthesis, and `np.append` without `axis=` flattens everything to
    # 1-D, which breaks the `xx_ss[ii]` panel indexing below. To add quantities, stack them along
    # axis 0 (e.g. `np.concatenate([xx_ss, hc_ss[np.newaxis]], axis=0)`) and add matching labels
    # and panels.
    print(f"{xx_ss.shape=}")
    print(f"{xx_bg.shape=}")

    yy_ss = hc_ss
    yy_bg = hc_bg
    print(f"{yy_ss.shape=}")
    print(f"{yy_bg.shape=}")

    fig, axs = holo.plot.figax(
        nrows=2, ncols=3, sharey=True, figsize=(12,6))
    for ax in axs[:,0]:
        ax.set_ylabel(holo.plot.LABEL_CHARACTERISTIC_STRAIN)
    for ii, (ax, label) in enumerate(zip(axs.flatten(), labels)):
        ax.set_xlabel(label)
        # Keyword arguments: positionally, a 9th argument would bind to `show_medians`.
        draw_hs_vs_par(ax, xx_ss=xx_ss[ii], yy_ss=yy_ss, xx_bg=xx_bg[ii], yy_bg=yy_bg,
                       color_ss=color_ss, color_bg=color_bg, fast_ss=fast_ss)

    fig.tight_layout()
    return fig