In [None]:
# %load ../../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `mpl.style.use('default')` resets the rcParams to matplotlib's defaults, so it must be
# applied BEFORE the customizations below; applied afterwards it would discard them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): the family is 'serif' but the font list is set for 'sans-serif' -- presumably
# `'serif': ['Times']` was intended; confirm before changing.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode.math as zmath

In [None]:
import holodeck.sam_simple

In [None]:
# Frequency of one inverse `YR` (holodeck constant); this is the frequency argument passed
# to every `gwb*` call in this notebook.
fobs_yr = 1.0 / YR

In [None]:
# Build the simplified SAM with its default arguments and evaluate its GWB at `fobs_yr`.
# `gwb_simple` is compared against the full SAM's value in the following cells.
sam_simple = holo.sam_simple.Simple_SAM()
gwb_simple = sam_simple.gwb(fobs_yr)
print(gwb_simple)

In [None]:
# Build the full SAM and evaluate its idealized GWB at `fobs_yr`.
# The commented-out lines show how the individual components could be passed explicitly;
# the active code only overrides the M-Mbulge relation and leaves the rest at their defaults.
# gsmf = holo.sam.GSMF_Schechter()
# gpf = holo.sam.GPF_Power_Law()
# gmt = holo.sam.GMT_Power_Law()
# The M-Mbulge relation is constructed from the simple SAM's (private) amplitude and
# power-law index attributes, with a reference mass of 1e11 MSOL.
# NOTE(review): presumably this makes both models use the same relation -- confirm.
mmbulge = holo.relations.MMBulge_Standard(
    mamp=sam_simple._mbh_star, mplaw=sam_simple._alpha_mbh_star, mref=1e11*MSOL
)
# sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=100)
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
gwb = sam.gwb_ideal(fobs_yr)
print(gwb)

In [None]:
# Display the full-SAM value, the simple-SAM value, and their ratio (full / simple)
gwb, gwb_simple, gwb/gwb_simple

In [None]:
# Draw `NUM` random test points at which the individual model components are compared.
# The global numpy RNG is seeded so that "Restart & Run All" reproduces the same points, and
# so that any tolerance failure raised by the checks below is repeatable.
# NOTE(review): assumes `zmath.random_power` also draws from numpy's global RNG -- confirm.
NUM = 100
SEED = 12345
np.random.seed(SEED)
# log-uniform between 1e8 and 1e12, scaled by MSOL
mgal = MSOL * (10.0 ** np.random.uniform(8, 12, NUM))
# log-uniform between 1e-3 and 1
qgal = 10.0 ** np.random.uniform(-3.0, 0.0, NUM)
redz = zmath.random_power([0.01, 1.5], 2.0, NUM)

In [None]:
def check(mm, qq, zz, check, simple):
    """Compare two sets of values point-by-point: print summaries and plot both sets.

    Note that inside this function the name `check` refers to the fourth argument, not to
    the function itself.

    Parameters
    ----------
    mm, qq, zz : ndarray
        Abscissa values for the first, second and third panel respectively.
        They are indexed with an `argsort` result, so they must support array indexing.
    check, simple : ndarray
        The two sets of values being compared (plotted as red '+' and blue 'x').

    Returns
    -------
    ndarray of bool
        Result of `np.isclose(check, simple, rtol=1e-2, atol=0.0)`.

    """
    # Signed difference, measured relative to the smaller of the two values at each point
    smaller = np.min([check, simple], axis=0)
    frac_err = (check - simple) / smaller
    matches = np.isclose(check, simple, rtol=1e-2, atol=0.0)

    # Full arrays first, then summary statistics; `check` before `simple` in both cases
    for describe in (zmath.str_array, zmath.stats_str):
        for values in (check, simple):
            print(describe(values))
    print(zmath.str_array(frac_err))
    print(zmath.stats_str(frac_err))
    print(matches)

    # One panel per abscissa: both value sets as scatter points, and the fractional
    # difference (sorted by abscissa) as a line on a twin y-axis
    fig, axes = plot.figax(ncols=3)
    for col, xvals in enumerate((mm, qq, zz)):
        panel = axes[col]
        panel.scatter(xvals, check, color='r', alpha=0.5, marker='+')
        panel.scatter(xvals, simple, color='b', alpha=0.5, marker='x')
        twin = panel.twinx()
        order = np.argsort(xvals)
        twin.plot(xvals[order], frac_err[order], alpha=0.25)

    return matches
    
def frac_diff(v1, v2):
    """Signed difference `v2 - v1`, relative to the element-wise smaller of the two inputs.

    Parameters
    ----------
    v1, v2 : scalar or ndarray
        Values to compare; they are stacked for the element-wise minimum, and subtracted.

    Returns
    -------
    scalar or ndarray
        `(v2 - v1) / min(v1, v2)`, evaluated element-wise.

    """
    reference = np.min([v1, v2], axis=0)
    delta = v2 - v1
    return delta / reference

## GSMF

In [None]:
# Compare the GSMF evaluated by the full SAM (`sam._gsmf`) and by the simple SAM at the
# random (mgal, redz) test points.
gsmf_check = sam._gsmf(mgal, redz)
gsmf_simple = sam_simple.gsmf(mgal, redz)
# Signed difference relative to the smaller of the two values (kept for interactive inspection)
err = (gsmf_check - gsmf_simple) / np.min([gsmf_check, gsmf_simple], axis=0)
# Purely relative comparison (`atol=0.0`), consistent with every other tolerance test in this
# notebook.  With numpy's default `atol=1e-8`, any two values that are both much smaller than
# 1e-8 in magnitude would be reported as "close" regardless of how much they differ.
print(np.isclose(gsmf_check, gsmf_simple, rtol=1e-2, atol=0.0))
# print(err)

# Red '+' : full SAM;  blue 'x' : simple SAM
fig, axes = plot.figax(ncols=2)
ax = axes[0]
ax.scatter(mgal, gsmf_check, color='r', alpha=0.5, marker='+')
ax.scatter(mgal, gsmf_simple, color='b', alpha=0.5, marker='x')

# NOTE(review): the GSMF above is evaluated from (mgal, redz) only, yet this panel uses `qgal`
# as the abscissa -- presumably `redz` was intended; confirm.
ax = axes[1]
ax.scatter(qgal, gsmf_check, color='r', alpha=0.5, marker='+')
ax.scatter(qgal, gsmf_simple, color='b', alpha=0.5, marker='x')
plt.show()

## GPF

In [None]:
# Evaluate the GPF with the full SAM (`sam._gpf`) and with the simple SAM at the same random
# (mgal, qgal, redz) test points.
gpf_check = sam._gpf(mgal, qgal, redz)
gpf_simple = sam_simple.gpf(mgal, qgal, redz)

# Print the summaries and draw the three comparison panels; the returned boolean array is
# not used here, so a mismatch does not raise.
check(mgal, qgal, redz, gpf_check, gpf_simple)
plt.show()

## GMT

In [None]:
# Evaluate the GMT with the full SAM (`sam._gmt`) and with the simple SAM at the same random
# (mgal, qgal, redz) test points, then print/plot the comparison.
# The boolean array returned by `check` is not used here, so a mismatch does not raise.
gmt_check = sam._gmt(mgal, qgal, redz)
gmt_simple = sam_simple.gmt(mgal, qgal, redz)
check(mgal, qgal, redz, gmt_check, gmt_simple)
plt.show()

## Galaxy NDens

In [None]:
# Internal consistency of the simple SAM: compare its public `ndens_galaxy` against its
# private `_ndens_galaxy_check` at the same test points.
# NOTE(review): presumably `_ndens_galaxy_check` is an independent implementation -- confirm.
sam_simple = holo.sam_simple.Simple_SAM()
n1 = sam_simple.ndens_galaxy(mgal, qgal, redz)
n2 = sam_simple._ndens_galaxy_check(mgal, qgal, redz)

# Print/plot only; the returned boolean array is not used, so a mismatch does not raise.
check(mgal, qgal, redz, n1, n2)
plt.show()

In [None]:
# Galaxy number density: full SAM (`sam._ndens_gal`) versus simple SAM (`ndens_galaxy`).
# NOTE: this rebinds the global `sam` to a model built with default arguments (no custom
# `mmbulge`); the cells below construct a new `sam` before using it again.
sam = holo.sam.Semi_Analytic_Model()
ndg_check = sam._ndens_gal(mgal, qgal, redz)
ndg_simple = sam_simple.ndens_galaxy(mgal, qgal, redz)

# Plain aliases of the two arrays above (no copy is made)
uu_check = ndg_check; vv_simple = ndg_simple
# Print/plot only; the returned boolean array is not used, so a mismatch does not raise.
check(mgal, qgal, redz, uu_check, vv_simple)
plt.show()

# MBH NDens

In [None]:
# MBH number density: full SAM (`sam._ndens_mbh`) versus simple SAM (`ndens_mbh`), evaluated
# at the same random (mgal, qgal, redz) test points.
sam_simple = holo.sam_simple.Simple_SAM()
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)

uu = sam._ndens_mbh(mgal, qgal, redz)
vv = sam_simple.ndens_mbh(mgal, qgal, redz)

close = check(mgal, qgal, redz, uu, vv)
# Render the diagnostic figure before (possibly) raising, so that it is available for
# inspection when the comparison fails.
plt.show()

# `close` is a numpy boolean array: reduce it with numpy rather than the builtin `all`
if not np.all(close):
    err_msg = "MBH Number-Density mismatch between SAM and Simple_SAM!"
    raise ValueError(err_msg)

In [None]:
# Consistency check: the SAM's gridded `static_binary_density` versus a direct evaluation of
# `sam._ndens_mbh` at the stellar masses corresponding to the same grid points.
# NOTE(review): presumably resetting `_density` discards a cached value so that the property
# below is recomputed -- confirm.
sam._density = None
ndens_sam = sam.static_binary_density
# Use a dedicated name for the grid redshifts: unpacking into `redz` would overwrite the
# random-sample `redz` that the earlier comparison cells rely on.
mbh_tot, mbh_rat, redz_grid = np.copy(sam.grid)
mbh_pri, mbh_sec = utils.m1m2_from_mtmr(mbh_tot, mbh_rat)
# Convert both MBH masses to stellar masses with the same M-Mbulge relation, without scatter
mst_pri, mst_sec = [mmbulge.mstar_from_mbh(_mbh, scatter=False) for _mbh in [mbh_pri, mbh_sec]]
mst_rat = mst_sec/mst_pri

ndens_ref = sam._ndens_mbh(mst_pri, mst_rat, redz_grid)

err = frac_diff(ndens_sam, ndens_ref)
print(zmath.stats_str(ndens_sam))
print(zmath.stats_str(ndens_ref))
print("err=", zmath.stats_str(err))

# Purely relative tolerance (`atol=0.0`)
if not np.allclose(ndens_sam, ndens_ref, rtol=1e-6, atol=0.0):
    err_msg = "sam mbh ndens does not match consistency check || error too large!"
    raise ValueError(err_msg)


## GWB Ideal

In [None]:
# `Simple_SAM.gwb_sam` evaluated with the full SAM, with `dlog10=False` versus `dlog10=True`;
# the two results are required to agree to within 1% (purely relative tolerance).
sam_simple = holo.sam_simple.Simple_SAM()
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)

hc1 = sam_simple.gwb_sam(fobs_yr, sam, dlog10=False)
hc2 = sam_simple.gwb_sam(fobs_yr, sam, dlog10=True)
print(hc1, hc2, hc1/hc2)
if not np.isclose(hc1, hc2, rtol=1e-2, atol=0.0):
    # Named `err_msg`, as in the other checks; `err` holds numeric error arrays elsewhere
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)

plt.show()

In [None]:
# Check that the simple SAM's MBH number density per unit mass (`dlog10=False`) and per unit
# log10-mass (`dlog10=True`) are related by the factor `mtot * ln(10)`.
sam_simple = holo.sam_simple.Simple_SAM()

# Place galaxy mass, mass ratio and redshift on three separate axes so that they broadcast
# against each other; `sam_simple.mbh` is indexed as 2D and gets a trailing third axis.
mg = sam_simple.mass_gal[:, np.newaxis, np.newaxis]
qg = sam_simple.mrat_gal[np.newaxis, :, np.newaxis]
rz = sam_simple.redz[np.newaxis, np.newaxis, :]
mtot = sam_simple.mbh[:, :, np.newaxis]

# a* : `dlog10=False` form;  b* : `dlog10=True` form
# *1 : derived from the direct `dlog10=False` call;  *2 : derived from the `dlog10=True` call
ndens_a1 = sam_simple.ndens_mbh(mg, qg, rz, dlog10=False)
ndens_b1 = ndens_a1 * mtot * np.log(10.0)
ndens_b2 = sam_simple.ndens_mbh(mg, qg, rz, dlog10=True)
ndens_a2 = ndens_b2 / (mtot * np.log(10.0))

print("dlog10(M)")
print(utils.stats(ndens_b1, prec=4))
print(utils.stats(ndens_b2, prec=4))
err = frac_diff(ndens_b1, ndens_b2)
print(utils.stats(err, prec=2))

print("dM")
print(utils.stats(ndens_a1, prec=4))
print(utils.stats(ndens_a2, prec=4))
err = frac_diff(ndens_a1, ndens_a2)
# NOTE(review): only strictly positive errors are summarized here, unlike the "dlog10(M)"
# case above, which summarizes all of them -- confirm that the filter is intended.
print(utils.stats(err[err > 0], prec=2))

In [None]:
# `Simple_SAM.gwb` with `dlog10=True` versus `dlog10=False`; the two results are required to
# agree to within 1% (purely relative tolerance).
sam_simple = holo.sam_simple.Simple_SAM()

hc1 = sam_simple.gwb(fobs_yr, dlog10=True)
hc2 = sam_simple.gwb(fobs_yr, dlog10=False)
print(hc1, hc2, hc1/hc2)
if not np.isclose(hc1, hc2, rtol=1e-2, atol=0.0):
    # Named `err_msg`, as in the other checks; `err` holds numeric error arrays elsewhere
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)

# check(mgal, qgal, redz, uu, vv)
plt.show()

In [None]:
# `Simple_SAM.gwb` versus `Simple_SAM.gwb_sam` (which is given the full SAM); the two results
# are required to agree to a relative tolerance of 1e-5.
sam_simple = holo.sam_simple.Simple_SAM()
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)

hc1 = sam_simple.gwb(fobs_yr)
hc2 = sam_simple.gwb_sam(fobs_yr, sam)
print(hc1, hc2, hc1/hc2)
if not np.isclose(hc1, hc2, rtol=1e-5, atol=0.0):
    # Named `err_msg`, as in the other checks; `err` holds numeric error arrays elsewhere
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)

# check(mgal, qgal, redz, uu, vv)
plt.show()

In [None]:
# Final comparison: the full SAM's `gwb_ideal` versus the simple SAM's `gwb`, both at
# `fobs_yr`; a `ValueError` is raised unless they agree to within 1%.
sam_simple = holo.sam_simple.Simple_SAM()
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)

hc1 = sam.gwb_ideal(fobs_yr)
hc2 = sam_simple.gwb(fobs_yr)

print(hc1, hc2, hc1/hc2)
# Purely relative tolerance (`atol=0.0`)
if not np.isclose(hc1, hc2, rtol=1e-2, atol=0.0):
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)


# Convergence Tests

In [None]:
# Evaluate the simple SAM's GWB for a range of `size` values, with both `dlog10` settings.
size = [10, 20, 40, 100, 200, 400]
print(size)
# One result per entry of `size`, for `dlog10=False` (reg) and `dlog10=True` (log)
gwb_reg = np.zeros(len(size))
gwb_log = np.zeros(len(size))
# `tqdm` is the `tqdm.notebook` module here (see the imports), so this is the notebook
# progress bar.
for ii, ss in enumerate(tqdm.tqdm_notebook(size)):
    ss = int(ss)
    print(ii, ss)
    # A new model is constructed for every `size` value
    sam_simp = holo.sam_simple.Simple_SAM(size=ss)
    gwb_reg[ii] = sam_simp.gwb(fobs_yr, dlog10=False)
    gwb_log[ii] = sam_simp.gwb(fobs_yr, dlog10=True)

In [None]:
# Plot the fractional deviation of each GWB estimate from a reference value, against `size`.
fig, ax = plot.figax()
# Reference ("truth"): mean of the two estimates at the largest `size`
truth = 0.5 * (gwb_reg[-1] + gwb_log[-1])
# truth = gwb_log[-1]
ax.plot(size, np.fabs(gwb_reg - truth) / truth, 'r-', alpha=0.5, label='reg')
ax.plot(size, np.fabs(gwb_log - truth) / truth, 'b--', alpha=0.5, label='log')
# Label the axes so that the figure can be read on its own
ax.set_xlabel('Simple_SAM size')
ax.set_ylabel('|GWB - truth| / truth')
ax.legend()
plt.show()