In [None]:
# %load ../../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# `style.use('default')` resets *all* rcParams, so it must be applied before any customization;
# otherwise the font / line / mathtext settings below are silently discarded.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode.math as zmath

In [None]:
import holodeck.simple_sam

In [None]:
# Reference observer-frame frequency used throughout the comparisons: 1/yr (in the units of `YR`)
fobs_yr = 1.0 / YR

In [None]:
# Baseline: GWB from the simplified SAM with all default parameters, at 1/yr
sam_simple = holo.simple_sam.Simple_SAM()
gwb_simple = sam_simple.gwb_ideal(fobs_yr)
print(gwb_simple)

In [None]:
# Full SAM whose M-Mbulge relation copies the simple SAM's amplitude and slope, so both models map
# host galaxies to black holes identically; every other component uses the SAM defaults.
mmbulge = holo.host_relations.MMBulge_Standard(
    mamp_log10=sam_simple._mbh_star_log10,
    mplaw=sam_simple._alpha_mbh_star,
    mref=1e11*MSOL,
)
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
gwb = sam.gwb_ideal(fobs_yr)
print(gwb)

In [None]:
# Full-SAM GWB, simple-SAM GWB, and their ratio (a ratio near 1 means the models agree)
gwb, gwb_simple, gwb/gwb_simple

In [None]:
# Random sample points for the component-level comparisons below.
# Seed first so the drawn points (and any mismatch they expose) are reproducible on re-run.
np.random.seed(12345)
NUM = 100
mgal = MSOL * (10.0 ** np.random.uniform(8, 12, NUM))   # galaxy masses, log-uniform in [1e8, 1e12] Msol
qgal = 10.0 ** np.random.uniform(-3.0, 0.0, NUM)        # mass ratios, log-uniform in [1e-3, 1]
# redshifts in [0.01, 1.5] drawn with `zcode.math.random_power` (index 2.0 -- see zcode docs)
zgal = zmath.random_power([0.01, 1.5], 2.0, NUM)

In [None]:
def check(mm, qq, zz, check, simple):
    """Compare two sets of model values, print diagnostics, and plot them against each coordinate.

    Parameters
    ----------
    mm, qq, zz : np.ndarray
        Sample coordinates (same length as `check` / `simple`); one panel is drawn per array.
    check : np.ndarray
        Values from the model under test (plotted as red '+').
    simple : np.ndarray
        Reference values (plotted as blue 'x').

    Returns
    -------
    isclose : np.ndarray of bool
        Element-wise ``np.isclose(check, simple, rtol=1e-2, atol=0.0)``.
    """
    # Fractional difference (check - simple) / min(check, simple); `frac_diff(simple, check)`
    # computes exactly that, but without producing inf / nan where either value is zero.
    err = frac_diff(simple, check)
    isclose = np.isclose(check, simple, rtol=1e-2, atol=0.0)

    print(zmath.str_array(check))
    print(zmath.str_array(simple))
    print(zmath.stats_str(check))
    print(zmath.stats_str(simple))
    print(zmath.str_array(err))
    print(zmath.stats_str(err))
    print(isclose)

    fig, axes = plot.figax(ncols=3)
    for ax, xx in zip(axes, (mm, qq, zz)):
        ax.scatter(xx, check, color='r', alpha=0.5, marker='+')
        ax.scatter(xx, simple, color='b', alpha=0.5, marker='x')
        # fractional error on a twin y-axis, drawn as a line sorted along this coordinate
        tw = ax.twinx()
        ii = np.argsort(xx)
        tw.plot(xx[ii], err[ii], alpha=0.25)

    return isclose
    
def frac_diff(v1, v2):
    """Element-wise fractional difference ``(v2 - v1) / min(v1, v2)``.

    Where either value is zero the denominator is replaced by 1 (so the absolute difference is
    returned), and where both are zero the result is 0.

    Parameters
    ----------
    v1, v2 : scalar or array_like
        Values to compare; broadcast against each other.

    Returns
    -------
    np.ndarray
        Array with the broadcast shape of the inputs (0-d for scalar inputs).
    """
    v1, v2 = np.broadcast_arrays(np.asarray(v1, dtype=float), np.asarray(v2, dtype=float))
    either_zero = (v1 == 0.0) | (v2 == 0.0)
    # `np.where` instead of in-place item assignment so numpy-scalar / 0-d inputs also work
    denom = np.where(either_zero, 1.0, np.minimum(v1, v2))
    ee = (v2 - v1) / denom
    return np.where((v1 == 0.0) & (v2 == 0.0), 0.0, ee)

def frac_truth(yy, truth):
    """Return ``|yy - truth| / truth``, or `yy` unchanged when `truth` is None."""
    if truth is None:
        return yy
    return np.fabs(yy - truth) / truth

# Check Components

## GSMF

In [None]:
# Compare the GSMF of the full and simple SAMs at the random sample points.
# The GSMF is a function of (mass, redshift) only, so it is plotted against those two coordinates.
gsmf_check = sam._gsmf(mgal, zgal)
gsmf_simple = sam_simple.gsmf(mgal, zgal)
err = frac_diff(gsmf_check, gsmf_simple)
print(utils.stats(err))
print(np.isclose(gsmf_check, gsmf_simple, rtol=1e-2))

fig, axes = plot.figax(ncols=2)
for ax, xx in zip(axes, (mgal, zgal)):
    ax.scatter(xx, gsmf_check, color='r', alpha=0.5, marker='+')
    ax.scatter(xx, gsmf_simple, color='b', alpha=0.5, marker='x')
plt.show()

### Vary Parameters

In [None]:

# Vary one GSMF parameter at a time and require agreement (rtol=1e-6) between the full-SAM
# `GSMF_Schechter` and the simple SAM.  Each entry: [GSMF_Schechter kwarg, Simple_SAM kwarg, values].
# phi0=-2.77, phiz=-0.27, mref0=1.737801e11*MSOL, mrefz=0.0, alpha0=-1.24, alphaz=-0.03
params = [
    # ['phi0', 'gsmf_phi0_const', [-4.067, -2.13]],
    # ['phiz', 'gsmf_phiz', [-1.123, 0.0, +0.592]],
    # ['alpha0', 'gsmf_alpha0_const', [-2.123, 0.0, +0.1592]],
    ['alphaz', 'gsmf_alphaz', [-1.0023, 0.0, +0.4521592, -0.23158381946956436]],
    # ['mref0_log10', 'gsmf_log10m0', [12.42333, 11.00258, 10.5257]],
]

full_class = holo.sam.GSMF_Schechter
full_class_name = 'gsmf'

for full_name, simp_name, vals in params:
    print(full_name, simp_name)
    for vv in vals:
        kw_full, kw_simp = {full_name: vv}, {simp_name: vv}
        print(kw_full, kw_simp)
        sam_full = holo.sam.Semi_Analytic_Model(
            mmbulge=mmbulge, **{full_class_name: full_class(**kw_full)}
        )
        sam_simp = holo.simple_sam.Simple_SAM(**kw_simp)

        full = sam_full._gsmf(mgal, zgal)
        simp = sam_simp.gsmf(mgal, zgal)
        err = frac_diff(full, simp)
        for arr in (full, simp, err):
            print(arr[:3])
        print(utils.stats(err))
        assert np.allclose(full, simp, rtol=1e-6, atol=0.0)

        


## GPF

In [None]:
# GPF: full SAM with a default `GPF_Power_Law` (plus the shared `mmbulge`) vs. a default simple SAM
gpf = holo.sam.GPF_Power_Law()
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge, gpf=gpf)
sam_simple = holo.simple_sam.Simple_SAM()

gpf_check = sam._gpf(mgal, qgal, zgal)
gpf_simple = sam_simple.gpf(mgal, qgal, zgal)

# diagnostic only: the returned `isclose` array is discarded, nothing is asserted here
check(mgal, qgal, zgal, gpf_check, gpf_simple)
plt.show()

### Vary Parameters

In [None]:

# Vary one GPF parameter at a time and require agreement (rtol=1e-6) between the full-SAM
# `GPF_Power_Law` and the simple SAM's `gpf`.
# Each entry: [GPF_Power_Law kwarg, Simple_SAM kwarg, values]; the kwarg pairings match the
# `params` table in the "Vary Parameters" section.
params = [
    ['frac_norm_allq', 'gpf_norm', [0.01, 0.025, 0.05]],
    ['malpha', 'gpf_alpha', [-1.0, 0.0, +1.0]],
    ['zbeta', 'gpf_beta', [0.0, 0.8, +2.0]],
    ['qgamma', 'gpf_gamma', [-1.0, 0.0, +1.0]],
]

full_class = holo.sam.GPF_Power_Law
full_class_name = 'gpf'

for full_name, simp_name, vals in params:
    print(full_name, simp_name)
    for vv in vals:
        kw_full = {full_name: vv}
        kw_simp = {simp_name: vv}
        print(kw_full, kw_simp)
        instance = full_class(**kw_full)
        sam_full = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge, **{full_class_name: instance})
        sam_simp = holo.simple_sam.Simple_SAM(**kw_simp)

        full = sam_full._gpf(mgal, qgal, zgal)
        simp = sam_simp.gpf(mgal, qgal, zgal)
        err = frac_diff(full, simp)
        print(full[:3])
        print(simp[:3])
        print(err[:3])
        print(utils.stats(err))
        assert np.allclose(full, simp, rtol=1e-6, atol=0.0)


## GMT

In [None]:
# GMT: reuses `sam` / `sam_simple` from the GPF cell (neither overrides the GMT, so both use defaults)
gmt_check = sam._gmt(mgal, qgal, zgal)
gmt_simple = sam_simple.gmt(mgal, qgal, zgal)
check(mgal, qgal, zgal, gmt_check, gmt_simple)
plt.show()

### Vary Parameters

In [None]:

# Vary one GMT parameter at a time and require agreement (rtol=1e-6) between the full-SAM
# `GMT_Power_Law` and the simple SAM's `gmt`.
# Each entry: [GMT_Power_Law kwarg, Simple_SAM kwarg, values].
params = [
    ['time_norm', 'gmt_norm', [0.1*GYR, 3.0*GYR]],
    ['malpha', 'gmt_alpha', [-1.0, +1.0]],
    ['zbeta', 'gmt_beta', [-1.0, +1.0]],
    ['qgamma', 'gmt_gamma', [-1.0, +1.0]],
]

full_class = holo.sam.GMT_Power_Law
full_class_name = 'gmt'

for full_name, simp_name, vals in params:
    print(full_name, simp_name)
    for vv in vals:
        kw_full, kw_simp = {full_name: vv}, {simp_name: vv}
        # no `mmbulge` passed: only the GMT component itself is compared
        sam_full = holo.sam.Semi_Analytic_Model(**{full_class_name: full_class(**kw_full)})
        sam_simp = holo.simple_sam.Simple_SAM(**kw_simp)

        full = sam_full._gmt(mgal, qgal, zgal)
        simp = sam_simp.gmt(mgal, qgal, zgal)
        err = frac_diff(full, simp)
        for arr in (full, simp, err):
            print(arr[:3])
        print(utils.stats(err))
        assert np.allclose(full, simp, rtol=1e-6, atol=0.0)


## Galaxy NDens

In [None]:
# Internal consistency of the simple SAM: `ndens_galaxy` vs. its alternative `_ndens_galaxy_check`
sam_simple = holo.simple_sam.Simple_SAM()
n1 = sam_simple.ndens_galaxy(mgal, qgal, zgal)
n2 = sam_simple._ndens_galaxy_check(mgal, qgal, zgal)

check(mgal, qgal, zgal, n1, n2)
plt.show()

In [None]:
# Galaxy number density: full SAM (all default components) vs. the simple SAM
sam = holo.sam.Semi_Analytic_Model()
ndg_check = sam._ndens_gal(mgal, qgal, zgal)
ndg_simple = sam_simple.ndens_galaxy(mgal, qgal, zgal)

check(mgal, qgal, zgal, ndg_check, ndg_simple)
plt.show()

## MBH NDens

In [None]:
# MBH number density: full SAM (with the shared `mmbulge`) vs. the simple SAM; every point must agree
sam_simple = holo.simple_sam.Simple_SAM()
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)

uu, vv = sam._ndens_mbh(mgal, qgal, zgal), sam_simple.ndens_mbh(mgal, qgal, zgal)

close = check(mgal, qgal, zgal, uu, vv)
if not np.all(close):
    raise ValueError("MBH Number-Density mismatch between SAM and Simple_SAM!")

plt.show()

In [None]:
# Self-consistency of the full SAM: `static_binary_density` on its (M_tot, q, z) grid must match
# `_ndens_mbh` evaluated at the corresponding primary stellar mass and stellar mass ratio
# (mapped through the M-Mbulge relation without scatter).
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
ndens_sam = sam.static_binary_density
mbh_tot, mbh_rat, redz = np.copy(sam.grid)
mbh_pri, mbh_sec = utils.m1m2_from_mtmr(mbh_tot, mbh_rat)
mst_pri, mst_sec = [mmbulge.mstar_from_mbh(_mbh, scatter=False) for _mbh in [mbh_pri, mbh_sec]]
mst_rat = mst_sec/mst_pri

ndens_ref = sam._ndens_mbh(mst_pri, mst_rat, redz)

err = frac_diff(ndens_sam, ndens_ref)
print(zmath.stats_str(ndens_sam))
print(zmath.stats_str(ndens_ref))
print("err=", zmath.stats_str(err))

bads = ~np.isclose(ndens_sam, ndens_ref, rtol=1e-6, atol=1e-14)
if np.any(bads):
    print(ndens_sam[bads])
    print(ndens_ref[bads])
    print(err[bads])
    err_msg = "sam mbh ndens does not match consistency check || error too large!"
    raise ValueError(err_msg)


## GWB Ideal

In [None]:
# `Simple_SAM.gwb_sam` on the full SAM must give the same GWB (to 1%) whether the number density
# is integrated per dM (`dlog10=False`) or per dlog10(M) (`dlog10=True`).
sam_simple = holo.simple_sam.Simple_SAM(size=100)
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)

hc1 = sam_simple.gwb_sam(fobs_yr, sam, dlog10=False)
hc2 = sam_simple.gwb_sam(fobs_yr, sam, dlog10=True)

err = frac_diff(hc1, hc2)
print(hc1, hc2, err)
if not np.isclose(hc1, hc2, rtol=1e-2, atol=0.0):
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)

In [None]:
# Check the simple SAM's conversion of the MBH number density between the per-dM and
# per-dlog10(M) forms: dn/dlog10(M) = dn/dM * M * ln(10), with M the total MBH mass `mtot`.
sam_simple = holo.simple_sam.Simple_SAM()

# Broadcast the galaxy-mass, mass-ratio and redshift grids to 3D (mass, ratio, redshift);
# `mbh` is presumably indexed (mass, ratio), so it gets a trailing redshift axis -- TODO confirm.
mg = sam_simple.mass_gal[:, np.newaxis, np.newaxis]
qg = sam_simple.mrat_gal[np.newaxis, :, np.newaxis]
rz = sam_simple.redz[np.newaxis, np.newaxis, :]
mtot = sam_simple.mbh[:, :, np.newaxis]

# compute each form directly (a1, b2) and derive it from the other form (b1, a2)
ndens_a1 = sam_simple.ndens_mbh(mg, qg, rz, dlog10=False)
ndens_b1 = ndens_a1 * mtot * np.log(10.0)
ndens_b2 = sam_simple.ndens_mbh(mg, qg, rz, dlog10=True)
ndens_a2 = ndens_b2 / (mtot * np.log(10.0))

print("dlog10(M)")
print(utils.stats(ndens_b1, prec=4))
print(utils.stats(ndens_b2, prec=4))
err = frac_diff(ndens_b1, ndens_b2)
print(utils.stats(err, prec=2))
assert np.allclose(ndens_b1, ndens_b2)

print("dM")
print(utils.stats(ndens_a1, prec=4))
print(utils.stats(ndens_a2, prec=4))
err = frac_diff(ndens_a1, ndens_a2)
# stats over strictly positive errors only
print(utils.stats(err[err > 0], prec=2))
assert np.allclose(ndens_a1, ndens_a2)

In [None]:
# `Simple_SAM.gwb_ideal` must give the same GWB (to 1%) with `dlog10=True` and `dlog10=False`
sam_simp = holo.simple_sam.Simple_SAM()

hc1 = sam_simp.gwb_ideal(fobs_yr, dlog10=True)
hc2 = sam_simp.gwb_ideal(fobs_yr, dlog10=False)
err = frac_diff(hc1, hc2)
print(hc1, hc2, err)
if not np.isclose(hc1, hc2, rtol=1e-2, atol=0.0):
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)

In [None]:
# The simple SAM's own GWB must match (to 1e-5) its GWB computed from the full SAM's population
sam_simp = holo.simple_sam.Simple_SAM()
sam_full = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)

hc1 = sam_simp.gwb_ideal(fobs_yr)
hc2 = sam_simp.gwb_sam(fobs_yr, sam_full)
err = frac_diff(hc1, hc2)
print(hc1, hc2, err)
if not np.isclose(hc1, hc2, rtol=1e-5, atol=0.0):
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)

In [None]:
# Full vs. simple SAM GWB with a non-default GPF `malpha` (`gpf_alpha` in the simple SAM),
# evaluated at `freq_mult` times the 1/yr reference frequency; must agree to 1%.
val = -0.8
freq_mult = 0.1

gpf = holo.sam.GPF_Power_Law(malpha=val)
sam_simp = holo.simple_sam.Simple_SAM(gpf_alpha=val)
sam_full = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge, gpf=gpf)

hc1 = sam_full.gwb_ideal(fobs_yr*freq_mult)
hc2 = sam_simp.gwb_ideal(fobs_yr*freq_mult)

err = frac_diff(hc1, hc2)

print(hc1, hc2, err)
if not np.isclose(hc1, hc2, rtol=1e-2, atol=0.0):
    err_msg = f"{hc1=:.8e} (gwb) vs. {hc2=:.8e} || error too large!"
    raise ValueError(err_msg)


# Vary Parameters

In [None]:
# Fresh random sample points for the randomized parameter comparisons (same distributions as before)
NUM = 100
mgal = MSOL * (10.0 ** np.random.uniform(8, 12, NUM))
qgal = 10.0 ** np.random.uniform(-3.0, 0.0, NUM)
zgal = zmath.random_power([0.01, 1.5], 2.0, NUM)

In [None]:
# Sanity check before varying anything: with default parameters (and the shared `mmbulge`)
# the full and simple SAM GWBs must agree to 1%.
sam_full = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
sam_simp = holo.simple_sam.Simple_SAM()
gwb_full = sam_full.gwb_ideal(fobs_yr)
gwb_simp = sam_simp.gwb_ideal(fobs_yr)
print("\t", gwb_full)
print("\t", gwb_simp)
err = frac_diff(gwb_full, gwb_simp)
print("\t", err)

assert np.allclose(gwb_full, gwb_simp, rtol=1e-2, atol=0.0)


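Define every parameter to be varied. Each key is the `Simple_SAM` keyword argument, and each value is `[bounds, [component, class], [kwarg, transform]]`. The fields mean:

- `bounds` gives the distribution that random values are drawn from: a uniform range `[lo, hi]`, or a normal distribution `[mean, -sigma_lo, +sigma_hi]`.
- `component` and `class` identify the `Semi_Analytic_Model` component the parameter belongs to (e.g. `'gsmf'`, `holo.sam.GSMF_Schechter`).
- `kwarg` is that class's corresponding argument name.
- `transform` is an optional function applied to the value before it is passed to the class (or `None`).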

In [None]:
# Parameters for the randomized comparison.  Key: the `Simple_SAM` keyword argument.
# Value: [bounds, [SAM component name, component class], [component kwarg, transform]]
#   - bounds: [lo, hi] -> uniform draw; [mean, -sigma_lo, +sigma_hi] -> (asymmetric) normal draw
#   - component name: the `Semi_Analytic_Model` keyword that receives the component instance
#   - transform: callable applied to the drawn value before it is passed to the component class,
#     or None to pass it unchanged (the `Simple_SAM` always receives the untransformed value)
params = {
    'gsmf_phi0_const': [
        [-2.77, -0.29, +0.27],
        ['gsmf', holo.sam.GSMF_Schechter],
        ['phi0', None]
    ],
    'gsmf_phiz': [
        [-0.27, -0.21, +0.23],
        ['gsmf', holo.sam.GSMF_Schechter],
        ['phiz', None]
    ],
    'gsmf_log10m0': [
        [+11.24, -0.17, +0.20],
        ['gsmf', holo.sam.GSMF_Schechter],
        ['mref0_log10', None],
    ],
    'gsmf_alpha0_const': [
        [-1.24, -0.16, +0.16],
        ['gsmf', holo.sam.GSMF_Schechter],
        ['alpha0', None]
    ],
    'gsmf_alphaz': [
        [-0.03, -0.14, +0.16],
        ['gsmf', holo.sam.GSMF_Schechter],
        ['alphaz', None]
    ],
    
    'gpf_norm': [
        [0.02, 0.03], 
        ['gpf', holo.sam.GPF_Power_Law],
        ['frac_norm_allq', None]
    ],
    'gpf_alpha': [
        [-0.2, +0.2], 
        ['gpf', holo.sam.GPF_Power_Law],
        ['malpha', None]
    ],
    'gpf_beta': [
        [0.6, 1.0], 
        ['gpf', holo.sam.GPF_Power_Law],
        ['zbeta', None]
    ],
    'gpf_gamma': [
        [-0.2, +0.2], 
        ['gpf', holo.sam.GPF_Power_Law],
        ['qgamma', None]
    ],

    # NOTE(review): these bounds are used as-is (transform is None), whereas the GMT cell above
    # passes `time_norm` / `gmt_norm` values multiplied by GYR -- confirm the intended units.
    'gmt_norm': [
        [0.1, 2.0], 
        ['gmt', holo.sam.GMT_Power_Law],
        ['time_norm', None]
    ],
    'gmt_alpha': [
        [-0.2, +0.2], 
        ['gmt', holo.sam.GMT_Power_Law],
        ['malpha', None]
    ],
    'gmt_beta': [
        [-2, +1], 
        ['gmt', holo.sam.GMT_Power_Law],
        ['zbeta', None]
    ],
    'gmt_gamma': [
        [-0.2, +0.2], 
        ['gmt', holo.sam.GMT_Power_Law],
        ['qgamma', None]
    ],

    'mbh_star_log10': [
        [8.17, -0.32, +0.35], 
        ['mmbulge', holo.host_relations.MMBulge_Standard],
        ['mamp_log10', None]
    ],
    'alpha_mbh_star': [
        [1.01, -0.10, +0.08], 
        ['mmbulge', holo.host_relations.MMBulge_Standard],
        ['mplaw', None]
    ],
}

In [None]:
# Choose some frequencies at which to compare
# freqs = utils.nyquist_freqs(5*YR, 0.2*YR)
freqs = np.asarray([fobs_yr])

def get_param_from_bounds(bounds):
    """Draw one random parameter value described by `bounds`.

    Parameters
    ----------
    bounds : sequence of float
        - ``[lo, hi]``: draw uniformly in ``[lo, hi)``.
        - ``[mean, sigma_lo, sigma_hi]``: draw from a normal distribution around `mean`.
          The signs of the two sigmas are ignored (the table writes them as ``-lo, +hi``).
          If ``|sigma_lo| == |sigma_hi|`` a symmetric normal is used.  Otherwise, with equal
          probability, a half-normal deviation of width ``|sigma_lo|`` is subtracted from the mean
          or one of width ``|sigma_hi|`` is added.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If `bounds` does not have 2 or 3 entries.
    """
    if len(bounds) == 2:
        return np.random.uniform(*bounds)

    if len(bounds) == 3:
        mean = bounds[0]
        # use magnitudes: a negative `scale` makes `np.random.normal` raise
        sig_lo, sig_hi = np.fabs(bounds[1]), np.fabs(bounds[2])
        if sig_lo == sig_hi:
            return np.random.normal(mean, sig_lo)
        # pick a side first, then draw only that side's half-normal deviation
        if np.random.rand() < 0.5:
            return mean - np.fabs(np.random.normal(0.0, sig_lo))
        return mean + np.fabs(np.random.normal(0.0, sig_hi))

    raise ValueError(
        f"`bounds` must have 2 (uniform) or 3 (normal) entries, got {len(bounds)}: {bounds}"
    )


def _report_and_assert_close(vals_full, vals_simp, rtol, atol=0.0):
    """Print leading values and error stats of two results, then assert they agree within tolerance.

    Parameters
    ----------
    vals_full, vals_simp : array_like
        Results to compare (matching shapes).
    rtol, atol : float
        Tolerances passed to `np.allclose`.

    Returns
    -------
    err : np.ndarray
        ``frac_diff(vals_full, vals_simp)``.

    Raises
    ------
    AssertionError
        If any pair of values is not close.
    """
    err = frac_diff(vals_full, vals_simp)
    print("\t", np.ravel(vals_full)[:3])
    print("\t", np.ravel(vals_simp)[:3])
    print("\t", utils.stats(err), np.shape(err))
    assert np.allclose(vals_full, vals_simp, rtol=rtol, atol=atol)
    return err


def compare_sam_with_parameter(freqs, par, params, xx=None):
    """Compare `Semi_Analytic_Model` against `Simple_SAM` with a single parameter changed.

    Builds a `Simple_SAM` with ``par=xx`` and an equivalent full SAM whose corresponding
    component is constructed with that value.  It then checks, in order:
    galaxy number density, the simple SAM's internal galaxy-ndens consistency, MBH number density
    at the sample points, the full SAM's binary-density grid, and the GWB.
    Uses the notebook globals `mgal`, `qgal`, `zgal` (sample points) and `mmbulge` (default relation).

    Parameters
    ----------
    freqs : array_like
        Frequencies at which the GWB is compared.
    par : str
        Key into `params`; also the `Simple_SAM` keyword argument.
    params : dict
        Maps names to ``[bounds, [component name, component class], [component kwarg, transform]]``.
    xx : float or None
        Parameter value; if None, drawn with `get_param_from_bounds(bounds)`.

    Raises
    ------
    AssertionError
        If any comparison is outside tolerance.
    """
    vals = params[par]
    print("\n", par, vals)
    bounds, (par_class_name, par_class), (arg_name, arg_func) = vals

    if xx is None:
        xx = get_param_from_bounds(bounds)

    # ---- Initialize ----

    # Simple SAM: `par` is directly its keyword argument
    simp_kwargs = {par: xx}
    print("\t", simp_kwargs)
    sam_simp = holo.simple_sam.Simple_SAM(**simp_kwargs)

    # Full SAM: build the component class holding this parameter (value transformed if needed)
    if arg_func is not None:
        xx = arg_func(xx)
    if par_class_name == 'mmbulge':
        # start from the simple SAM's M-Mbulge values and override only the varied one
        class_kwargs = dict(mamp_log10=sam_simp._mbh_star_log10, mplaw=sam_simp._alpha_mbh_star)
    else:
        class_kwargs = {}
    class_kwargs[arg_name] = xx
    print("\t", class_kwargs)
    par_inst = par_class(**class_kwargs)

    # Always pass an `mmbulge` so both models share one; it is replaced if that is what we vary
    sam_kwargs = {'mmbulge': mmbulge, par_class_name: par_inst}
    sam_full = holo.sam.Semi_Analytic_Model(**sam_kwargs)
    # the M-Mbulge relation `sam_full` actually uses (needed for the grid mapping below)
    mmbulge_used = sam_kwargs['mmbulge']

    # ---- Test / Compare Models ----

    print("galaxy number density")
    full = sam_full._ndens_gal(mgal, qgal, zgal)
    simp = sam_simp.ndens_galaxy(mgal, qgal, zgal)
    _report_and_assert_close(full, simp, rtol=1e-3)

    # simple SAM internal consistency: `ndens_galaxy` vs. `_ndens_galaxy_check`
    simp_check = sam_simp._ndens_galaxy_check(mgal, qgal, zgal)
    _report_and_assert_close(simp, simp_check, rtol=1e-3)

    print("MBH number density")
    print("\tscatter")
    full = sam_full._ndens_mbh(mgal, qgal, zgal)
    simp = sam_simp.ndens_mbh(mgal, qgal, zgal)
    _report_and_assert_close(full, simp, rtol=1e-3)

    print("\tgrid")
    # map the full SAM's (M_tot, q, z) grid to primary stellar mass and stellar mass ratio
    full = sam_full.static_binary_density
    mbh_tot, mbh_rat, rz = np.copy(sam_full.grid)
    mbh_pri, mbh_sec = utils.m1m2_from_mtmr(mbh_tot, mbh_rat)
    mst_pri, mst_sec = [
        mmbulge_used.mstar_from_mbh(_mbh, scatter=False) for _mbh in (mbh_pri, mbh_sec)
    ]
    mst_rat = mst_sec / mst_pri

    simp = sam_full._ndens_mbh(mst_pri, mst_rat, rz)
    _report_and_assert_close(full, simp, rtol=1e-3, atol=1e-16)

    simp = sam_simp.ndens_mbh(mst_pri, mst_rat, rz)
    _report_and_assert_close(full, simp, rtol=1e-3, atol=1e-16)

    print("GWB")
    gwb_full = sam_full.gwb_ideal(freqs)
    gwb_simp = sam_simp.gwb_ideal(freqs)
    _report_and_assert_close(gwb_full, gwb_simp, rtol=1e-2)

# Parameters to test -- use `list(params.keys())` to check all of them.
# Currently only `gsmf_alphaz` is checked, at a fixed value far outside its bounds
# ([-0.03, -0.14, +0.16]); set `pval = None` to draw random values from each parameter's bounds.
params_list = ['gsmf_alphaz']
pval = -2.23158381946956436

for par in params_list:
    compare_sam_with_parameter(freqs, par, params, xx=pval)

## gsmf_alphaz

In [None]:
# Simple-SAM galaxy number density with `gsmf_alphaz = pval` (the value tested above) vs. the
# default simple SAM, with |fractional difference| plotted against each sample coordinate.
# Both models are built here so the cell does not depend on earlier kernel state.
sam_simp_alphaz = holo.simple_sam.Simple_SAM(gsmf_alphaz=pval)
sam_simp_ref = holo.simple_sam.Simple_SAM()
v1 = sam_simp_alphaz.ndens_galaxy(mgal, qgal, zgal)
v2 = sam_simp_ref.ndens_galaxy(mgal, qgal, zgal)

err = frac_diff(v1, v2)
print(zmath.minmax(err))

fig, axes = plot.figax(ncols=3)
for xx, ax in zip([mgal/MSOL, qgal, zgal], axes):
    print(zmath.minmax(xx))
    ax.scatter(xx, np.fabs(err))

plt.show()

In [None]:
# Full vs. simple SAM total GWB for a single `gsmf_alphaz` value.  With `sum=False`, `gwb_ideal`
# returns unsummed per-bin values (3D, summed over axes (1, 2) in the cells below); squaring,
# summing everything, and taking the sqrt gives the total.
ff = fobs_yr
gsmf_alphaz = +1.0

gsmf = holo.sam.GSMF_Schechter(alphaz=gsmf_alphaz)
# pass the shared `mmbulge`, as the other full-vs-simple GWB comparisons do, so both models use
# the same M-Mbulge relation
sam_full = holo.sam.Semi_Analytic_Model(gsmf=gsmf, mmbulge=mmbulge)
sam_simp = holo.simple_sam.Simple_SAM(gsmf_alphaz=gsmf_alphaz)
gwb_full = sam_full.gwb_ideal(ff, sum=False) ** 2
gwb_simp = sam_simp.gwb_ideal(ff, sum=False) ** 2
temp_full = np.sqrt(gwb_full.sum())
temp_simp = np.sqrt(gwb_simp.sum())
err = frac_diff(temp_full, temp_simp)
print(temp_full, temp_simp, err)


In [None]:
# Simple-SAM MBH masses (last column of `mbh`) vs. full-SAM `mtot` grid, in Msol; show both shapes
aa = sam_simp.mbh[:, -1]/MSOL
bb = sam_full.mtot/MSOL
aa.shape, bb.shape

In [None]:
# Plot the two mass grids against each other with the 1:1 line, zoomed into ~1e4 Msol
plt.loglog(bb, aa, marker='.')
xx = np.logspace(3, 12, 100)
plt.loglog(xx, xx, 'k--')
ax = plt.gca()
ax.set(
    xlim=[9e3, 2e4],
    ylim=[5e3, 2e4],
)
plt.show()

In [None]:
# Cumulative GWB vs. total mass: sum the squared per-bin values over axes (1, 2), accumulate along
# axis 0 (plotted against total mass), and take the sqrt.  The twin axis shows each curve's
# fractional deviation from `truth`, the full model's total.
sum_axes = (1, 2)
full = np.sqrt(np.cumsum(np.sum(gwb_full, axis=sum_axes)))
simp = np.sqrt(np.cumsum(np.sum(gwb_simp, axis=sum_axes)))

truth = full[-1]   # alternatives: simp[-1], or np.mean([full, simp], axis=0)

fig, ax = plot.figax()
tw = ax.twinx()

for xx, yy, lab in zip([sam_full.mtot, sam_simp.mbh[:, -1]], [full, simp], ['full', 'simp']):
    xx = xx[1:] / MSOL
    ax.plot(xx, yy, label=lab, alpha=0.5)
    tw.plot(xx, frac_truth(yy, truth), label=lab, alpha=0.5, ls='--')

ax.legend()
plt.show()


In [None]:
# Same total-GWB comparison as the `gsmf_alphaz` cell, but with a large GPF `malpha` / `gpf_alpha`
val = 4.0
ff = fobs_yr
gpf = holo.sam.GPF_Power_Law(malpha=val)
sam_full = holo.sam.Semi_Analytic_Model(gpf=gpf, mmbulge=mmbulge)
sam_simp = holo.simple_sam.Simple_SAM(gpf_alpha=val)
gwb_full = sam_full.gwb_ideal(ff, sum=False) ** 2
gwb_simp = sam_simp.gwb_ideal(ff, sum=False) ** 2
temp_full = np.sqrt(gwb_full.sum())
temp_simp = np.sqrt(gwb_simp.sum())
err = frac_diff(temp_full, temp_simp)
print(temp_full, temp_simp, err)


In [None]:
# Cumulative GWB vs. total mass (see the `gsmf_alphaz` version above), with the fractional
# deviation from the full model's total shown on a log-scaled twin axis.
sum_axes = (1, 2)
full = np.sqrt(np.cumsum(np.sum(gwb_full, axis=sum_axes)))
simp = np.sqrt(np.cumsum(np.sum(gwb_simp, axis=sum_axes)))

truth = full[-1]   # alternatives: simp[-1], or np.mean([full, simp], axis=0)

fig, ax = plot.figax()
tw = ax.twinx()
tw.set(yscale='log', ylim=[1e-4, 1e-1])

for xx, yy, lab in zip([sam_full.mtot, sam_simp.mbh[:, -1]], [full, simp], ['full', 'simp']):
    xx = xx[1:] / MSOL
    ax.plot(xx, yy, label=lab, alpha=0.5)
    tw.plot(xx, frac_truth(yy, truth), label=lab, alpha=0.5, ls='--')

ax.legend()
plt.show()


# GWB Redshift

In [None]:
# Full vs. simple SAM GWB at 1/yr with `redz_prime` enabled (passed straight through to `gwb_ideal`)
ff = fobs_yr
redz_prime = True
sam_full = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
sam_simp = holo.simple_sam.Simple_SAM()
gwb_full = sam_full.gwb_ideal(ff, redz_prime=redz_prime)
gwb_simp = sam_simp.gwb_ideal(ff, redz_prime=redz_prime)
# gwb_full = sam_full.gwb_ideal(ff, sum=False) ** 2
# gwb_simp = sam_simp.gwb_ideal(ff, sum=False) ** 2
err = frac_diff(gwb_full, gwb_simp)
print(gwb_full, gwb_simp, err)