In [None]:
# %load ../init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import matplotlib.pyplot as plt

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import tqdm

import holodeck as holo
from holodeck import utils, plot, cosmo
from holodeck.constants import YR, MSOL
from holodeck.sams import cyutils as sam_cyutils

# Plot styling: reset matplotlib to its default style, draw figures inline,
# and render inline figures in IPython's 'retina' (high-DPI) format.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
%matplotlib inline
%config InlineBackend.figure_format = 'retina'


In [None]:
# --- Run configuration ---
# NUM_FBINS sets the cadence below: cad = dur / (2 * NUM_FBINS), so the Nyquist
# frequency 1 / (2 * cad) equals NUM_FBINS / dur.  Presumably this yields NUM_FBINS
# GW frequency bins from `nyquist_freqs_edges` -- TODO confirm against that helper.
NUM_FBINS = 14
PTA_DUR = 16.03     # [yrs]
NREALS = 1000       # default `nreals` (number of realizations) for the model function
NLOUDEST = 5        # default `nloudest`, forwarded as `loudest=` to the single-source calculation
SHAPE = None        # passed as the third positional arg of the parameter space; presumably
                    # the SAM grid shape, with None meaning the library default (TODO confirm)

space_class = holo.param_spaces.PS_Uniform_07A
# NOTE(review): positional args presumably (log, nsamples, sam_shape, seed) -- confirm
# against the PS_Uniform_07A signature before relying on this.
space = space_class(holo.log, 0, SHAPE, None)

dur = PTA_DUR * YR                  # PTA duration converted via holodeck's YR constant
pta_cad = dur / (2 * NUM_FBINS)     # cadence chosen so that f_Nyquist = NUM_FBINS / dur
fobs_gw_edges = holo.utils.nyquist_freqs_edges(dur, pta_cad)

# Parameter draws: each row is mapped onto `space.param_names` in order by
# `run_single_params_model_cython` (5 values per row).
draw_params = [
    [ 4.57784231, -1.51291368, 10.90450461,  8.85735088,  0.52998213],
    [ 6.96691128, -2.38054765, 11.08484247,  9.29421616,  0.45471499],
    [ 0.37454247, -2.10501603, 11.62087475,  8.84882612,  0.09778727]
]

In [None]:
def run_single_params_model_cython(space, params, fobs_gw_edges, nreals=NREALS, nloudest=NLOUDEST):
    """Build the SAM + hardening model for one parameter draw and compute strains.

    The single-source strain (`hc_ss`, `hc_bg`) and the GWB-only strain (`gwb`)
    are both computed from the same `number` grid, so the two can be compared
    directly.

    Parameters
    ----------
    space : parameter-space instance
        Must provide ``param_names``, ``sam_shape`` and ``model_for_params``.
    params : sequence of float
        One value per entry of ``space.param_names``, in the same order.
    fobs_gw_edges : array_like
        Edges of the observer-frame GW frequency bins.
    nreals : int
        Number of realizations, forwarded to both strain calculations.
    nloudest : int
        Forwarded as ``loudest=`` to ``holo.single_sources.ss_gws_redz``.

    Returns
    -------
    dict
        ``fobs`` (GW bin centers), ``fobs_edges`` (GW bin edges), ``gwb``,
        ``hc_ss`` and ``hc_bg``.

    Raises
    ------
    ValueError
        If ``len(params)`` differs from ``len(space.param_names)``.
    """
    param_names = space.param_names
    # Guard against silently dropping (or failing cryptically on) mismatched draws.
    if len(params) != len(param_names):
        raise ValueError(
            f"Expected {len(param_names)} parameters {list(param_names)}, got {len(params)}!"
        )
    pars = dict(zip(param_names, params))
    print('pars:', pars)

    sam, hard = space.model_for_params(pars, sam_shape=space.sam_shape)

    # The SAM calculations take orbital frequencies, here half the GW frequencies
    # (presumably assuming circular orbits emitting at 2 * f_orb).
    fobs_gw_edges = np.asarray(fobs_gw_edges)
    fobs_gw_cents = utils.midpoints(fobs_gw_edges)
    fobs_orb_cents = fobs_gw_cents / 2.0
    fobs_orb_edges = fobs_gw_edges / 2.0

    redz_final, diff_num = sam_cyutils.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, cosmo
    )

    print('Calculating cython number')
    # Grid edges along total mass, mass ratio, redshift, and orbital frequency.
    edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
    number = sam_cyutils.integrate_differential_number_3dx1d(edges, diff_num)

    print('Calculating cython ss')
    hc_ss, hc_bg = holo.single_sources.ss_gws_redz(
        edges, redz_final, number, realize=nreals, loudest=nloudest, params=False
    )

    print('Calculating cython gwb')
    gwb = holo.gravwaves._gws_from_number_grid_integrated_redz(edges, redz_final, number, nreals)

    return dict(
        fobs=fobs_gw_cents, fobs_edges=fobs_gw_edges, gwb=gwb,
        hc_ss=hc_ss, hc_bg=hc_bg,
    )

# Run the model for the first parameter draw (keyword arguments make the mapping explicit).
data = run_single_params_model_cython(
    space=space,
    params=draw_params[0],
    fobs_gw_edges=fobs_gw_edges,
)

In [None]:
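# Optional: save the results of the run above for later comparison.
# NOTE: the path below is machine-specific; replace it with a project-relative path before enabling.
# np.savez('/Users/emigardiner/GWs/holodeck/output/random/test_ss_gwb_data.npz', 
#          fobs = data['fobs'], fobs_edges=data['fobs_edges'], gwb=data['gwb'],
#          hc_ss=data['hc_ss'], hc_bg=data['hc_bg'])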

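Compare two strain calculations:

- **SS**: the single-source calculation behind `sam.gwb()`, via `holo.single_sources.ss_gws_redz`. Its total strain is the loudest sources plus the background, added in quadrature.
- **GW**: the GWB-only calculation behind `sam.new_gwb()`, via `holo.gravwaves._gws_from_number_grid_integrated_redz`.

Both use the same `number` grid, so their medians and percentiles should agree.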

In [None]:
fig, axs = plot.figax(
    figsize=(8, 8), nrows=2, ncols=1,
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN,
)

# Frequencies in units of 1/yr for plotting.
xx = data['fobs'] * YR

# SS: loudest single sources (summed over the last axis of `hc_ss`) plus background, in quadrature.
hc_ss_sq = np.sum(data['hc_ss'] ** 2, axis=-1)
y1 = np.sqrt(hc_ss_sq + data['hc_bg'] ** 2)
y2 = data['gwb']

labels = ['SS', 'GW']
for label, yy in zip(labels, [y1, y2]):
    # Median line plus shaded 25-75th percentile band, taken over the last axis
    # (presumably realizations).
    med, lo, hi = np.percentile(yy, [50, 25, 75], axis=-1)
    line, = axs[0].plot(xx, med, alpha=0.5, label=label)
    axs[0].fill_between(xx, lo, hi, color=line.get_color(), alpha=0.1)

ratio = np.median(y1, axis=-1) / np.median(y2, axis=-1)
axs[1].plot(xx, ratio, 'k--', label='SS/GW')
axs[1].set_ylabel('ratio')

for ax in axs:
    ax.legend()
    plot._twin_hz(ax)

plt.show()

In [None]:
fig, axs = plot.figax(
    figsize=(8, 8), nrows=2, ncols=1,
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN
)

# Frequencies in units of 1/yr for plotting.
xx = data['fobs'] * YR

# SS: loudest single sources (summed over the last axis of `hc_ss`) plus background, in quadrature.
y1 = np.sqrt(np.sum(data['hc_ss']**2, axis=-1) + data['hc_bg']**2)
y2 = data['gwb']

# Percentiles over the last axis (presumably realizations), computed once and reused by both panels.
percentiles = [10, 25, 50, 75, 90]
y1_p10, y1_p25, y1_p50, y1_p75, y1_p90 = np.percentile(y1, percentiles, axis=-1)
y2_p10, y2_p25, y2_p50, y2_p75, y2_p90 = np.percentile(y2, percentiles, axis=-1)

labels = ['SS', 'GW']
bands = [
    (y1_p10, y1_p25, y1_p50, y1_p75, y1_p90),
    (y2_p10, y2_p25, y2_p50, y2_p75, y2_p90),
]
for label, (yy_p10, yy_p25, yy_med, yy_p75, yy_p90) in zip(labels, bands):
    # Plot this array's own median; do not rely on a `med` variable left over from another cell.
    line, = axs[0].plot(xx, yy_med, alpha=0.5, label=label)
    cc = line.get_color()

    # Inner band: 25-75th percentile; outer band: 10-90th percentile.
    axs[0].fill_between(xx, yy_p25, yy_p75, color=cc, alpha=0.2, linestyle='--')
    axs[0].fill_between(xx, yy_p10, yy_p90, color=cc, alpha=0.1, linestyle=':')

# Ratio of SS to GW at matching percentiles.
axs[1].plot(xx, y1_p50/y2_p50, 'k-', label='median SS/GW')
axs[1].plot(xx, y1_p10/y2_p10, 'c:', label='10th percentile')
axs[1].plot(xx, y1_p25/y2_p25, 'b--', label='25th percentile')
axs[1].plot(xx, y1_p75/y2_p75, 'r--', label='75th percentile')
axs[1].plot(xx, y1_p90/y2_p90, 'm:', label='90th percentile')

axs[1].set_ylabel('ratio')
for ax in axs:
    ax.legend()
    plot._twin_hz(ax)

plt.show()