In [None]:
# %load ../../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# The base style must be applied FIRST: `mpl.style.use('default')` resets the rcParams
# to their defaults, so calling it after the `mpl.rc(...)` lines (as was done before)
# silently discarded the font / line / mathtext customisations below.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): family is 'serif' but the font list is given for 'sans-serif' -- confirm intent
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode.math as zmath

In [None]:
import holodeck.simple_sam

In [None]:
# Frequency equal to one over the `YR` constant; passed to every `gwb_ideal` call in this notebook
fobs_yr = 1.0 / YR

In [None]:
# Reference value: `gwb_ideal` from the Simple_SAM built with its default arguments
sam_simple = holo.simple_sam.Simple_SAM()
gwb_simple = sam_simple.gwb_ideal(fobs_yr)
print(gwb_simple)

In [None]:
# Disabled variant with explicitly constructed GSMF / GPF / GMT components, kept for reference:
# gsmf = holo.sam.GSMF_Schechter()
# gpf = holo.sam.GPF_Power_Law()
# gmt = holo.sam.GMT_Power_Law()
# M-MBulge relation built from the Simple_SAM's own (private) amplitude and power-law index;
# this `mmbulge` object is reused by every Semi_Analytic_Model constructed in this notebook.
mmbulge = holo.host_relations.MMBulge_Standard(
    mamp=sam_simple._mbh_star, mplaw=sam_simple._alpha_mbh_star, mref=1e11*MSOL
)
# sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=100)
# Full SAM with default grids, evaluated at the same frequency as the Simple_SAM above
sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
gwb = sam.gwb_ideal(fobs_yr)
print(gwb)

In [None]:
# Display both values and the ratio of the full SAM result to the Simple_SAM result
gwb, gwb_simple, gwb/gwb_simple

In [None]:
def frac_diff(v1, v2):
    """Signed fractional difference ``(v2 - v1) / min(v1, v2)``, evaluated element-wise.

    Parameters
    ----------
    v1, v2 : scalar or array_like
        Values to compare.  They may be scalars, sequences or arrays of any mutually
        broadcastable shapes (e.g. a scalar reference against an array of trials).

    Returns
    -------
    ee : scalar or ndarray
        ``(v2 - v1)`` divided by the element-wise smaller of the two inputs.

    """
    # `np.asarray` lets plain lists through; `np.minimum` broadcasts, whereas stacking
    # the inputs with `np.min([v1, v2], axis=0)` required identical shapes.
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)
    ee = (v2 - v1) / np.minimum(v1, v2)
    return ee

def frac_truth(yy, truth):
    """Absolute fractional error of `yy` relative to `truth`.

    Parameters
    ----------
    yy : scalar or ndarray
        Value(s) to compare.
    truth : scalar, ndarray or None
        Reference value.  If `None`, `yy` is returned untouched.

    Returns
    -------
    scalar or ndarray
        ``|yy - truth| / truth``, or `yy` itself when no reference is given.

    """
    if truth is None:
        return yy
    return np.fabs(yy - truth) / truth

## Grid Bounds

Total Mass

In [None]:
# Scan settings: `mtot_extr` holds the log10(M / MSOL) limits, `widths` the trial ranges in dex
def_size = 80
mtot_extr = [5.0, 11.0]
widths = np.linspace(2.0, 8.0, 7)
print(widths)
num = len(widths)
gwb_lo = np.zeros(num)
gwb_hi = np.zeros(num)

# Fiducial model with the default grids, used as the comparison value when plotting
sam_fid = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
gwb_fid = sam_fid.gwb_ideal(fobs_yr)

for ii in tqdm.tqdm_notebook(range(num)):
    wid = widths[ii]
    # `gwb_lo`: range of width `wid` ending at the upper limit
    # `gwb_hi`: range of width `wid` starting at the lower limit
    lo = mtot_extr[1] - wid
    hi = mtot_extr[0] + wid
    trials = (
        (gwb_lo, lo, mtot_extr[1]),
        (gwb_hi, mtot_extr[0], hi),
    )
    for store, edge_lo, edge_hi in trials:
        # convert the log10 edges to masses; third entry is passed through unchanged
        # (presumably the number of grid points -- confirm against Semi_Analytic_Model)
        mtot_vals = (MSOL * (10.0 ** edge_lo), MSOL * (10.0 ** edge_hi), def_size)
        _sam = holo.sam.Semi_Analytic_Model(mtot=mtot_vals, mmbulge=mmbulge)
        store[ii] = _sam.gwb_ideal(fobs_yr)

    print(ii, wid, [lo, mtot_extr[1]], [mtot_extr[0], hi])


In [None]:
fig, ax = plot.figax(xscale='linear')
xvals = widths

# Reference for the fractional errors; set `truth = None` to plot the raw values instead
truth = gwb_fid

# fractional error w.r.t. `truth` (or the raw values when no reference is set)
y1, y2, fid = [
    test if truth is None else np.fabs(test - truth) / truth
    for test in (gwb_lo, gwb_hi, gwb_fid)
]

for yvals, fmt, lab in ((y1, 'r-', 'lo'), (y2, 'b--', 'hi')):
    ax.plot(xvals, yvals, fmt, alpha=0.5, label=lab)
ax.axhline(fid, color='k', ls='--', alpha=0.25)
ax.legend()
plt.show()

Mass Ratio

In [None]:
# Scan settings: `widths` are the trial mass-ratio ranges in dex below q = 1
def_size = 80
widths = np.linspace(1.0, 5.0, 9)
print(widths)
num = len(widths)
gwb_lo = np.zeros(num)

# Fiducial model with the default grids, used as the comparison value when plotting
sam_fid = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge)
gwb_fid = sam_fid.gwb_ideal(fobs_yr)

for ii in tqdm.tqdm_notebook(range(num)):
    wid = widths[ii]

    # mass ratio from 10^-wid up to 10^0; third entry is passed through unchanged
    # (presumably the number of grid points -- confirm against Semi_Analytic_Model)
    vals = (10.0 ** (-wid), 10.0 ** 0.0, def_size)
    _sam = holo.sam.Semi_Analytic_Model(mrat=vals, mmbulge=mmbulge)
    gwb_lo[ii] = _sam.gwb_ideal(fobs_yr)

    print(ii, wid)


In [None]:
fig, ax = plot.figax(xscale='linear')
xvals = widths

# Raw values are shown here; set `truth = gwb_fid` to plot fractional errors instead
truth = None

# fractional error w.r.t. `truth` (or the raw values when no reference is set)
y1, fid = [
    test if truth is None else np.fabs(test - truth) / truth
    for test in (gwb_lo, gwb_fid)
]

ax.plot(xvals, y1, 'r-', alpha=0.5, label='lo')
ax.axhline(fid, color='k', ls='--', alpha=0.25)
ax.legend()
plt.show()

Redshift

In [None]:
# Trial lower and upper redshift limits; the last entry of each list is the widest choice
def_size = 100
lo_vals = [1.0e0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]
hi_vals = [  2.0,  3.0,  4.0,  5.0, 6.0, 7.0, 8.0]
num = len(lo_vals)
assert num == len(hi_vals)
gwb_lo = np.zeros(num)
gwb_hi = np.zeros(num)

# Fiducial model: widest range on both sides
widest_lo = lo_vals[-1]
widest_hi = hi_vals[-1]
sam_fid = holo.sam.Semi_Analytic_Model(redz=(widest_lo, widest_hi, def_size), mmbulge=mmbulge)
gwb_fid = sam_fid.gwb_ideal(fobs_yr)

for ii in tqdm.tqdm_notebook(range(num)):
    trial_lo = lo_vals[ii]
    trial_hi = hi_vals[ii]

    # vary the lower limit, keeping the widest upper limit
    _sam = holo.sam.Semi_Analytic_Model(redz=(trial_lo, widest_hi, def_size), mmbulge=mmbulge)
    gwb_lo[ii] = _sam.gwb_ideal(fobs_yr)

    # vary the upper limit, keeping the widest lower limit
    _sam = holo.sam.Semi_Analytic_Model(redz=(widest_lo, trial_hi, def_size), mmbulge=mmbulge)
    gwb_hi[ii] = _sam.gwb_ideal(fobs_yr)

    print(ii, [trial_lo, widest_hi], [widest_lo, trial_hi])


In [None]:
fig, ax = plot.figax(xscale='linear')

# Reference for the fractional errors; set `truth = None` to plot the raw values instead
truth = gwb_fid

# fractional error w.r.t. `truth` (or the raw values when no reference is set)
y1, y2, fid = [
    test if truth is None else np.fabs(test - truth) / truth
    for test in (gwb_lo, gwb_hi, gwb_fid)
]

# no x-values are given: curves are drawn against the trial index
for yvals, fmt, lab in ((y1, 'r-', 'lo'), (y2, 'b--', 'hi')):
    ax.plot(yvals, fmt, alpha=0.5, label=lab)
ax.axhline(fid, color='k', ls='--', alpha=0.25)
ax.legend()
plt.show()

In [None]:
# Compare log- and lin-spaced redshift grids for several lower / upper limits
def_size = 61
lo_vals = [1e-2, 1e-4, 1e-6]
hi_vals = [2.0, 3.0, 4.0, 6.0, 8.0, 9.0, 9.9]
num_lo = len(lo_vals)
num_hi = len(hi_vals)

gwb_log = np.zeros((num_lo, num_hi))
gwb_lin = np.zeros((num_lo, num_hi))

# No fiducial model is built in this cell (the disabled candidates are listed below);
# `gwb_fid` therefore keeps whatever value an earlier cell assigned.
# sam_fid = holo.sam.Semi_Analytic_Model(redz=(1e-4, 10.0, 300), mmbulge=mmbulge)
# redz_fid = np.linspace(0.0, 10.0, 1000)
# sam_fid = holo.sam.Semi_Analytic_Model(redz=redz_fid, mmbulge=mmbulge)
# gwb_fid = sam_fid.gwb_ideal(fobs_yr)

# each spacing writes into its own output array, indexed [lower limit, upper limit]
for scale, store in (('log', gwb_log), ('lin', gwb_lin)):
    for jj, lo in enumerate(lo_vals):
        for kk, hi in enumerate(hi_vals):
            redz = zmath.spacing((lo, hi), scale=scale, num=def_size)
            _sam = holo.sam.Semi_Analytic_Model(redz=redz, mmbulge=mmbulge)
            store[jj, kk] = _sam.gwb_ideal(fobs_yr)


In [None]:
fig, ax = plot.figax()
xvals = hi_vals

# Reference for the fractional errors; set `truth = None` to plot the raw values instead
truth = gwb_fid

# one colour per lower limit: solid = log-spaced grid, dashed = lin-spaced grid
for jj, lo in enumerate(lo_vals):
    line, = ax.plot(xvals, frac_truth(gwb_log[jj, :], truth), label=f'log {lo:.1e}', alpha=0.5)
    ax.plot(
        xvals, frac_truth(gwb_lin[jj, :], truth),
        label=f'lin {lo:.1e}', ls='--', color=line.get_color(), alpha=0.5
    )

ax.legend()
plt.show()

## Number of Points

In [None]:
# Fiducial grid shape and the trial sizes; alternatives tried earlier:
#   fid_shape = [81, 50, 60]     num_points = [10, 20, 40, 100, 200]
fid_shape = [102, 101, 100]
num_points = [10, 20, 40, 100]
num = len(num_points)

# one row per varied quantity: a single axis at a time, or 'all' axes together
val_names = ['mtot', 'mrat', 'redz', 'all']
gwb_vals = np.zeros((len(val_names), num))

sam_fid = holo.sam.Semi_Analytic_Model(shape=fid_shape, mmbulge=mmbulge)
gwb_fid = sam_fid.gwb_ideal(fobs_yr)

for ii in tqdm.tqdm_notebook(range(num)):
    npts = num_points[ii]
    for jj, name in enumerate(val_names):
        if name == 'all':
            # a bare integer is passed as the shape for the 'all' case
            shape = npts
        else:
            shape = np.copy(fid_shape)
            shape[jj] = npts

        _sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge, shape=shape)
        print(ii, jj, shape, _sam.shape)
        gwb_vals[jj, ii] = _sam.gwb_ideal(fobs_yr)

# gwb_fid = gwb_vals[3, -1]

In [None]:
# Reference for the fractional errors; set `ref = None` to plot the raw values instead
ref = gwb_fid

fig, ax = plot.figax(xscale='lin')
# one curve per row of `gwb_vals`, labelled by the quantity that was varied
for name, row in zip(val_names, gwb_vals):
    yy = row if ref is None else np.fabs(row - ref) / ref
    ax.plot(num_points, yy, label=name)

# with raw values on the y-axis, mark the fiducial result for comparison
if ref is None:
    ax.axhline(gwb_fid, color='k', alpha=0.5, ls='--')

ax.legend()
plt.show()

In [None]:
# For each redshift grid size, vary the mtot size and the mrat size one at a time
redz_sizes = [60, 80, 100]

fid_shape = [102, 101, 100]
num_points = [10, 20, 40, 80, 100]
# num_points = [10, 20, 40]
num = len(num_points)

nz = len(redz_sizes)

gwb_mtot = np.zeros((nz, num))
gwb_mrat = np.zeros((nz, num))

sam_fid = holo.sam.Semi_Analytic_Model(shape=fid_shape, mmbulge=mmbulge)
gwb_fid = sam_fid.gwb_ideal(fobs_yr)

for ii, redz_num in enumerate(redz_sizes):
    shape = np.copy(fid_shape)
    shape[2] = redz_num
    print(ii, shape)
    for jj, npts in enumerate(num_points):
        # axis 0 of the shape feeds `gwb_mtot`, axis 1 feeds `gwb_mrat`
        for axis, store in ((0, gwb_mtot), (1, gwb_mrat)):
            trial_shape = np.copy(shape)
            trial_shape[axis] = npts
            print("\t", jj, trial_shape)
            _sam = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge, shape=trial_shape)
            store[ii, jj] = _sam.gwb_ideal(fobs_yr)

# gwb_fid = gwb_vals[3, -1]

In [None]:
# Reference for the fractional errors; set `ref = None` to plot the raw values instead
ref = gwb_fid

fig, ax = plot.figax(xscale='lin')
# Iterate over `redz_sizes` itself rather than a hard-coded `range(3)`, so the plot
# keeps working when the list of redshift sizes is edited.
for ii, redz_num in enumerate(redz_sizes):
    yy = gwb_mtot[ii, :]
    yy = np.fabs(yy - ref) / ref if ref is not None else yy
    cc, = ax.plot(num_points, yy, label=f'mtot z={redz_num}', alpha=0.65)
    # reuse the colour so the mtot (solid) and mrat (dashed) curves pair up per redshift size
    cc = cc.get_color()
    print(ii, gwb_mtot[ii, -1], gwb_mrat[ii, -1])

    yy = gwb_mrat[ii, :]
    yy = np.fabs(yy - ref) / ref if ref is not None else yy
    ax.plot(num_points, yy, label=f'mrat z={redz_num}', color=cc, ls='--', alpha=0.65)

# with raw values on the y-axis, mark the fiducial result for comparison
if ref is None:
    ax.axhline(gwb_fid, ls='--', alpha=0.75, lw=0.75, color='k')

ax.legend()
plt.show()

In [None]:
# NOTE(review): this cell repeats the plot of the cell directly above it -- consider removing one.
# Reference for the fractional errors; set `ref = None` to plot the raw values instead
ref = gwb_fid

fig, ax = plot.figax(xscale='lin')
# Iterate over `redz_sizes` itself rather than a hard-coded `range(3)`, so the plot
# keeps working when the list of redshift sizes is edited.
for ii, redz_num in enumerate(redz_sizes):
    yy = gwb_mtot[ii, :]
    yy = np.fabs(yy - ref) / ref if ref is not None else yy
    cc, = ax.plot(num_points, yy, label=f'mtot z={redz_num}', alpha=0.65)
    # reuse the colour so the mtot (solid) and mrat (dashed) curves pair up per redshift size
    cc = cc.get_color()
    print(ii, gwb_mtot[ii, -1], gwb_mrat[ii, -1])

    yy = gwb_mrat[ii, :]
    yy = np.fabs(yy - ref) / ref if ref is not None else yy
    ax.plot(num_points, yy, label=f'mrat z={redz_num}', color=cc, ls='--', alpha=0.65)

# with raw values on the y-axis, mark the fiducial result for comparison
if ref is None:
    ax.axhline(gwb_fid, ls='--', alpha=0.75, lw=0.75, color='k')

ax.legend()
plt.show()

## dlog10(M) vs. dM integration --- varying mass-array sizes

In [None]:
# Evaluate the Simple_SAM with both integration options for a range of `size` values
size = [10, 20, 40, 100, 200, 400]
num_sizes = len(size)
gwb_reg = np.zeros(num_sizes)
gwb_log = np.zeros(num_sizes)
for ii, ss in enumerate(tqdm.tqdm_notebook(size)):
    ss = int(ss)
    sam_simp = holo.simple_sam.Simple_SAM(size=ss)
    # same model instance, `dlog10` switched off ('reg') and on ('log')
    gwb_reg[ii] = sam_simp.gwb_ideal(fobs_yr, dlog10=False)
    gwb_log[ii] = sam_simp.gwb_ideal(fobs_yr, dlog10=True)

In [None]:
fig, ax = plot.figax()
# Reference: mean of the two results at the largest size (alternative: `truth = gwb_log[-1]`)
truth = 0.5 * (gwb_reg[-1] + gwb_log[-1])
for yvals, fmt, lab in ((gwb_reg, 'r-', 'reg'), (gwb_log, 'b--', 'log')):
    ax.plot(size, np.fabs(yvals - truth) / truth, fmt, alpha=0.5, label=lab)
ax.legend()
plt.show()

# Cumulative Distributions

## total mass

In [None]:
# Finely sampled, log-spaced total-mass grid handed to the SAM as an explicit array
NUM = 400
kw_name = 'mtot'
extr = np.array([1e4, 1.0e14]) * MSOL

# NOTE(review): `log` rebinds the logger assigned to `log` in the setup cell; the name is
# kept because the plotting cell that follows reads the grid from `log`.
log = zmath.spacing(extr, scale='log', num=NUM)
print(log[:5])
print(zmath.minmax(log))

# pass the grid under the keyword named by `kw_name`
vals = {kw_name: log}
sam_log = holo.sam.Semi_Analytic_Model(mmbulge=mmbulge, **vals)
# un-summed results; the `gwb_ideal` values are squared before being stored
num_log = sam_log._integrated_binary_density(sum=False)
gwb_log = sam_log.gwb_ideal(fobs_yr, sum=False) ** 2

In [None]:
fig, ax = plot.figax(xlabel='mtot')

ls = '-'
sc_name = 'log'
name = 'gwb'
vals = gwb_log

# collapse the two trailing axes, then accumulate along the remaining one in both directions
temp = np.sum(vals, axis=(1, 2))
# `gwb_log` was stored squared by the cell that built it, hence the square root of the sums
temp_forw = np.sqrt(np.cumsum(temp))
temp_back = np.sqrt(np.cumsum(temp[::-1])[::-1])

# forward sum: plotted against the grid without its first point, normalised to the total
xx = log[1:]/MSOL
yy = frac_truth(temp_forw, temp_forw[-1])
ax.plot(xx, yy, label=f'{sc_name} {name} forward', alpha=0.5, ls=ls)

# backward sum: plotted against the grid without its last point, normalised to the total
xx = log[:-1]/MSOL
yy = frac_truth(temp_back, temp_back[0])
ax.plot(xx, yy, label=f'{sc_name} {name} backward', alpha=0.5, ls=ls)

ax.legend(ncols=2)
plt.show()

## Redshift

In [None]:
# Finely sampled redshift grids over the same range, lin- and log-spaced
NUM = 400
extr = (1e-4, 10.0)
redz_lin = zmath.spacing(extr, scale='lin', num=NUM)
redz_log = zmath.spacing(extr, scale='log', num=NUM)
print(redz_lin[:5])
print(redz_log[:5])
print(zmath.minmax(redz_lin), zmath.minmax(redz_log))


def _sam_num_gwb2(redz):
    """Build a SAM on the given redshift grid; return it with its un-summed density and squared GWB."""
    model = holo.sam.Semi_Analytic_Model(redz=redz, mmbulge=mmbulge)
    density = model._integrated_binary_density(sum=False)
    gwb_sq = model.gwb_ideal(fobs_yr, sum=False) ** 2
    return model, density, gwb_sq


sam_log, num_log, gwb_log = _sam_num_gwb2(redz_log)
sam_lin, num_lin, gwb_lin = _sam_num_gwb2(redz_lin)

In [None]:
fig, ax = plot.figax(xlabel='redz')

# Each spacing is plotted against ITS OWN redshift grid.  Previously every curve used
# `redz_log[1:]`, which put the 'lin' results at the wrong redshifts.
cases = [
    (redz_log, [gwb_log, num_log], 'log', '-'),
    (redz_lin, [gwb_lin, num_lin], 'lin', '--'),
]

for zgrid, sc_vals, sc_name, ls in cases:
    # restart the colour cycle so 'log' and 'lin' curves of the same kind share a colour
    ax.set_prop_cycle(None)

    for vals, name in zip(sc_vals, ['gwb', 'num']):
        # collapse the two leading axes, then accumulate along the remaining one
        temp = np.sum(vals, axis=(0, 1))
        temp_forw = np.cumsum(temp)
        temp_back = np.cumsum(temp[::-1])[::-1]
        if name == 'gwb':
            # the gwb arrays were stored squared by the cell that built them
            temp_forw = np.sqrt(temp_forw)
            temp_back = np.sqrt(temp_back)

        # forward sum: grid without its first point, normalised to the total
        yy = frac_truth(temp_forw, temp_forw[-1])
        ax.plot(zgrid[1:], yy, label=f'{sc_name} {name} forward', alpha=0.5, ls=ls)

        # backward sum: grid without its last point (same convention as the total-mass
        # plot in this notebook), normalised to the total
        yy = frac_truth(temp_back, temp_back[0])
        ax.plot(zgrid[:-1], yy, label=f'{sc_name} {name} backward', alpha=0.5, ls=ls)

ax.legend(ncols=2)
plt.show()

# Different redz scalings

In [None]:
def _mixed_redz(extr, mix, num_lo, num_hi):
    """Redshift grid that is log-spaced on [extr[0], mix] and lin-spaced from `mix` to extr[1].

    `num_lo` points come from the log segment and `num_hi` from the linear one.
    """
    lo = zmath.spacing([extr[0], mix], 'log', num_lo)
    # one extra linear point is requested and the first one dropped
    # (presumably because it coincides with `mix`, the end of the log segment)
    hi = zmath.spacing([mix, extr[1]], 'lin', num_hi + 1)[1:]
    return np.concatenate([lo, hi])


# Grid sizes to test (the former `np.logspace(1, 3, 11)` assignment was dead code:
# it was overwritten on the very next line).
nums = zmath.spacing([10, 300], 'log', 5)
extr = (1e-3, 10.0)
print(nums)

gwb_lin_z = np.zeros_like(nums)
gwb_lin = np.zeros_like(nums)
gwb_log = np.zeros_like(nums)
gwb_mix_1 = np.zeros_like(nums)
gwb_mix_2 = np.zeros_like(nums)
gwb_mix_3 = np.zeros_like(nums)
gwb_mix_4 = np.zeros_like(nums)
gwb_mix_5 = np.zeros_like(nums)
gwb_list = [gwb_lin_z, gwb_lin, gwb_log, gwb_mix_1, gwb_mix_2, gwb_mix_3, gwb_mix_4, gwb_mix_5]

for ii, nn in enumerate(nums):
    nn = int(nn)

    # The grids are collected in a list (same order as `gwb_list`) instead of being bound
    # to `redz_lin` / `redz_log`, which other cells of this notebook use for plotting.
    redz_list = [
        # gwb_lin_z: linear, starting at exactly zero
        zmath.spacing([0.0, extr[1]], 'lin', nn),
        # gwb_lin / gwb_log: plain linear and log grids over `extr`
        zmath.spacing(extr, 'lin', nn),
        zmath.spacing(extr, 'log', nn),
        # gwb_mix_1: zero prepended, half log below 1.0 (one point fewer), half linear above
        np.concatenate([[0.0], _mixed_redz(extr, 1.0, nn // 2 - 1, nn - nn // 2)]),
        # gwb_mix_2: a fifth of the points log-spaced below 0.1
        _mixed_redz(extr, 0.1, nn // 5, nn - nn // 5),
        # gwb_mix_3: a quarter of the points log-spaced below 0.3
        _mixed_redz(extr, 0.3, nn // 4, nn - nn // 4),
        # gwb_mix_4: half of the points log-spaced below 1.0
        _mixed_redz(extr, 1.0, nn // 2, nn - nn // 2),
        # gwb_mix_5: a fixed 3 points log-spaced below 0.1
        _mixed_redz(extr, 0.1, 3, nn - 3),
    ]

    for redz, save in zip(redz_list, gwb_list):
        # `_sam` (not `sam`) so the fiducial model bound to `sam` earlier is left untouched
        _sam = holo.sam.Semi_Analytic_Model(redz=redz, mmbulge=mmbulge)
        save[ii] = _sam.gwb_ideal(fobs_yr, sum=True)
        

In [None]:
# fig, ax = plot.figax(ylim=[5.06e-16, 5.5e-16])
fig, ax = plot.figax()
# integer grid sizes, matching the `int(nn)` truncation applied when the grids were built
xvals = nums.astype(int)
print(nums, xvals)
gwb_list = [gwb_lin_z, gwb_lin, gwb_log, gwb_mix_1, gwb_mix_2, gwb_mix_3, gwb_mix_4, gwb_mix_5]
names = ['gwb_lin_z', 'gwb_lin', 'gwb_log', 'gwb_mix_1', 'gwb_mix_2', 'gwb_mix_3', 'gwb_mix_4', 'gwb_mix_5']
lines = [':', '--', '-.', '-', '-', '-', '-', '-']

# Reference: linear grid at the largest size.  Alternatives:
# truth = gwb_lin_z[-1]
# truth = gwb_mix_1[-1]
# truth = None
truth = gwb_lin[-1]

# The loop variable is deliberately NOT called `gwb`: that name holds the fiducial SAM
# result computed near the top of the notebook and must not be overwritten here.
for curve, nam, ls in zip(gwb_list, names, lines):
    yy = frac_truth(curve, truth)
    ax.plot(xvals, yy, label=nam, ls=ls)

ax.legend()
plt.show()