In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.interpolate
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Quiet output: ignore numpy floating-point errors (divide-by-zero, invalid, overflow)
# and every UserWarning.
np.seterr(over='ignore', invalid='ignore', divide='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plot styling, collected in a single rcParams update.
# NOTE(review): 'font.sans-serif' is set to Times while the family is 'serif', so Times
# is never selected; 'font.serif' was probably intended — confirm before changing.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.sans-serif': ['Times'],
    'font.size': 15,
    'lines.solid_capstyle': 'round',
    'mathtext.fontset': 'cm',
    'grid.alpha': 0.5,
})

# Show INFO-level messages from holodeck's logger.
log = holo.log
log.setLevel(logging.INFO)

In [None]:
# ---- Create initial population
# Binary population from holodeck's `Pop_Illustris` model.
pop = holo.population.Pop_Illustris()

# ---- Apply population modifiers

# resample to increase the number of binaries
# NOTE(review): `resample=5.0` presumably multiplies the binary count by ~5 — confirm.
mod_resamp = holo.population.PM_Resample(resample=5.0)
# modify population (in-place)
pop.modify(mod_resamp)

# Magic Power-Law Evolution

## Demonstrate the functional form (hardening timescale vs. separation)

In [None]:
# Short alias for `holo.evolution.Fixed_Time`, used throughout the cells below.
Fixed_Time = holo.evolution.Fixed_Time

In [None]:
# Hardening timescale a/|da/dt| of the `Fixed_Time` power-law model vs. separation,
# for several normalizations.  `mtot`, `mrat`, `rchar`, `g1`, `g2` are reused by the
# "Diagnostics" cells below.
rads = np.logspace(-4, 4, 100)    # separations [pc]
mtot = 1.0e9 * MSOL
mrat = 0.2
g1 = -1.0                         # power-law indices passed to `_dadt_dedt`
g2 = +2.5
rchar = 300.0 * PC                # characteristic radius (marked with a dashed line)

fig, ax = plot.figax(xlabel='Separation [pc]', ylabel='$a / |da/dt|$ [yr]')

for norm in [1e7, 1e8, 1e9]:
    dadt, _ = Fixed_Time._dadt_dedt(mtot, mrat, rads*PC, norm, rchar, g1, g2)
    # `_dadt_dedt` is given separations as `rads*PC`, so the timescale must use the same
    # units (dividing `rads` in pc by da/dt mixes units); then express it in years.
    # Assumes da/dt comes back in the same unit system as `PC` and `YR` — TODO confirm.
    yy = (rads * PC) / np.fabs(dadt) / YR
    ax.plot(rads, yy, label=f"$10^{{{np.log10(norm):.1f}}}$")

ax.axvline(rchar/PC, ls='--')

ax.legend(title='norm')
plt.show()
    

## Uniform merger time

In [None]:
# Same total merger time (2 Gyr) for every binary; `Fixed_Time` is the alias defined above.
fix_time = 2.0 * GYR
fixed = Fixed_Time.from_pop(pop, fix_time)
evo = holo.evolution.Evolution(pop, fixed)
evo.evolve()

In [None]:
# Check that the evolution reproduces the requested merger time: the lookback time
# elapsed between the first and last columns of `evo.tlook` (presumably the first and
# last evolution steps), divided by `fix_time`, should ideally be peaked at 1.
time = evo.tlook
dt = time[:, 0] - time[:, -1]

fig, ax = plot.figax(scale='lin', xlabel='Time: actual/specified', ylabel='density')
kale.dist1d(dt/fix_time, density=True)
plt.show()

In [None]:
# Evolution diagnostics evaluated on separations from 1e-4 to 1e4 pc.
sepa = np.logspace(-4, 4, 100) * PC
plot.plot_evo(evo, sepa=sepa)
plt.show()

In [None]:
# Discrete GW background from the evolved population, with 10 realizations (`nreals`).
# NOTE(review): `nyquist_freqs(20.0, 0.3)` presumably means a 20-yr baseline sampled every
# 0.3 yr, with the division by YR converting frequencies from 1/yr to code units — confirm.
freqs = holo.utils.nyquist_freqs(20.0, 0.3) / YR
gwb = holo.gravwaves.GW_Discrete(evo, freqs, nreals=10)
gwb.emit()

In [None]:
# Plot the GW background computed in the previous cell.
plot.plot_gwb(gwb)
plt.show()

## Callable Merger Time

In [None]:
# The merger time comes from a callable model instead of a single number.
# NOTE(review): `from_pop` presumably evaluates the callable per binary — confirm.
fix_time = holo.sam.GMT_Power_Law()
fixed = Fixed_Time.from_pop(pop, fix_time)
evo = holo.evolution.Evolution(pop, fixed)
evo.evolve()

In [None]:
# Distribution of the integrated merger times produced with the callable `GMT_Power_Law`
# model.  There is no single "specified" time to normalize by here, so the plotted
# quantity is the merger time itself, in Gyr.
time = evo.tlook
# elapsed lookback time between the first and last columns (presumably evolution steps)
dt = time[:, 0] - time[:, -1]
dt = dt / GYR
print(utils.stats(dt))

fig, ax = plot.figax(scale='lin', xlabel='Merger time [Gyr]', ylabel='density')
kale.dist1d(dt, density=True)
plt.show()

# Diagnostics

Calculate the normalization needed to obtain a particular total (integrated) merger time, then integrate with it to confirm the target time is recovered.

In [None]:
# Find the normalization whose integrated time equals `time`, then integrate with that
# normalization to confirm the round trip (`tot/time` should be ~1).
# NOTE: relies on `mtot`, `mrat`, `rchar`, `g1`, `g2` from the functional-form cell above.
time = 2.5 * GYR

args = [mtot, mrat, rchar, g1, g2, 1e4*PC]

norm = Fixed_Time._get_norm(time, *args)[0]
print(f"{norm=:.2e}")
tot = Fixed_Time._time_total(norm, *args)[0]
print(f"{tot/GYR=:.2e} {tot/time=:.2e}")


In [None]:
# Vectorized round trip: `_get_norm_chunk` on many random binaries at once; integrating
# with the returned normalizations should reproduce each target time (`tot/time` ~ 1).
# Sampled inputs use `*_samp` names so the scalar `mtot`, `mrat`, `rchar`, `time` used by
# the cells above are not overwritten (those cells stay re-runnable).
NUM = int(2e3)
rng = np.random.default_rng(12345)   # fixed seed: reproducible diagnostics
mtot_samp = MSOL * 10.0 ** rng.uniform(6, 10, NUM)
mrat_samp = 10.0 ** rng.uniform(-4, 0, NUM)
time_samp = rng.uniform(0.0, 10.0, NUM) * GYR
rchar_samp = PC * 10.0 ** rng.uniform(-1, 2)   # one scalar value shared by all binaries

args = [mtot_samp, mrat_samp, rchar_samp, g1, g2, 1e4*PC]

print("time [Gyr]:", utils.stats(time_samp/GYR))
norm = Fixed_Time._get_norm_chunk(time_samp, *args)
print("norm:", utils.stats(norm))

tot = Fixed_Time._time_total(norm, *args)
print("tot [Gyr]:", utils.stats(tot/GYR), " tot/time:", utils.stats(tot/time_samp))

In [None]:
# Training sample for the normalization interpolant: random binaries with
# mtot in [1e6, 1e11] MSOL, mrat in [1e-5, 1], target time in (0, 20] Gyr and
# rm in [1e3, 1e5] pc; rchar = 10 pc and power-law indices (-1.0, +2.5).
NUM = int(1e4)
mt = 10.0 ** np.random.uniform(6, 11, NUM) * MSOL
mr = 10.0 ** np.random.uniform(-5, 0, NUM)
# `1 - random_sample` lies in (0, 1], so target times are strictly positive.
td = 20.0 * (1.0 - np.random.random_sample(NUM)) * GYR
rm = 10.0 ** np.random.uniform(3, 5, NUM) * PC

norm = Fixed_Time._get_norm_chunk(td, mt, mr, 10*PC, -1.0, +2.5, rm)

print("td [Gyr]:", utils.stats(td/GYR))

# Samples whose normalization failed are excluded when building the interpolants.
valid = np.isfinite(norm) & (norm > 0.0)
print("valid = ", utils.frac_str(valid, 4), np.all(valid))

In [None]:
# Interpolants of log10(norm) over log10 of (mt/MSOL, mr, td/GYR, rm/PC).
# `LinearNDInterpolator` returns NaN outside the convex hull of its samples, so a
# nearest-neighbour interpolant is kept as a backup for those points.
units = [MSOL, 1.0, GYR, PC]    # also the default units used by `test_and_check`
points = [mt, mr, td, rm]
points = [pp/uu for pp, uu in zip(points, units)]
points = np.log10(points).T     # shape (NUM, 4)
# Drop samples whose normalization failed (non-finite or <= 0): their log10 is NaN/-inf
# and would poison every interpolated value in their neighbourhood.
points = points[valid]
values = np.log10(norm[valid])
interp = sp.interpolate.LinearNDInterpolator(points, values)
backup = sp.interpolate.NearestNDInterpolator(points, values)

In [None]:
def test_and_check(interp, backup, rchar, gamma_one, gamma_two, num=1e2, debug=True, units=None):
    """Compare interpolated normalizations with direct `Fixed_Time._get_norm_chunk` values.

    Draws `num` random binaries (total mass, mass ratio, target time, separation `rm`),
    evaluates `interp` on their log10 coordinates (falling back to `backup` wherever
    `interp` is non-finite), and computes the exact normalizations for comparison.

    Parameters
    ----------
    interp : callable
        Primary interpolant of log10(norm), called on an array of shape (num, 4) holding
        log10 of (mtot, mrat, time, rm) divided by `units`.
    backup : callable
        Fallback with the same call signature, used only where `interp` is non-finite.
    rchar, gamma_one, gamma_two
        Model parameters passed through to `Fixed_Time._get_norm_chunk`.
    num : int or float
        Number of test points (converted with `int`).
    debug : bool
        If True, print the coordinates of points where `interp` failed.
    units : sequence of 4 floats, optional
        Units dividing (mtot, mrat, time, rm) before taking log10.
        Defaults to (MSOL, 1.0, GYR, PC).

    Returns
    -------
    tests : ndarray
        Interpolated normalizations (after applying `backup`).
    checks : ndarray
        Directly-calculated normalizations.
    test_points : ndarray, shape (num, 4)
        The log10 coordinates that were tested.
    num_bad : int
        Number of points where `interp` returned a non-finite value.

    Raises
    ------
    RuntimeError
        If any value is still non-finite after applying `backup`.
    """
    if units is None:
        units = [MSOL, 1.0, GYR, PC]

    NUM = int(num)
    _mt = 10.0 ** np.random.uniform(6, 11, NUM) * MSOL
    _mr = 10.0 ** np.random.uniform(-4, 0, NUM)
    # `1 - random_sample` lies in (0, 1], so target times are strictly positive.
    _td = 20.0 * (1.0 - np.random.random_sample(NUM)) * GYR
    _rm = 10.0 ** np.random.uniform(3, 5, NUM) * PC

    test_points = [pp/uu for pp, uu in zip([_mt, _mr, _td, _rm], units)]
    test_points = np.log10(test_points).T
    tests = 10.0 ** interp(test_points)

    bads = ~np.isfinite(tests)
    num_bad = np.count_nonzero(bads)
    if (num_bad > 0) and debug:
        print(f"WARNING: found non-finite test values {utils.frac_str(bads)}")
        for tt in test_points.T:
            print(f"\t{tt[bads]:}")

    # Only call the backup where the primary interpolant failed (and only if it did).
    if num_bad > 0:
        tests[bads] = 10.0 ** backup(test_points[bads])
    bads = ~np.isfinite(tests)
    if np.any(bads):
        msg = f"non-finite test values after backup {utils.frac_str(bads)}"
        print(f"WARNING: {msg}")
        # A bare `raise` has no active exception to re-raise here; raise explicitly.
        raise RuntimeError(msg)

    checks = Fixed_Time._get_norm_chunk(_td, _mt, _mr, rchar, gamma_one, gamma_two, _rm)
    bads = ~np.isfinite(checks)
    if np.any(bads):
        print(f"WARNING: found non-finite check values {utils.frac_str(bads)}")
        for tt in test_points.T:
            print(f"\t{tt[bads]:}")

    return tests, checks, test_points, num_bad
    

In [None]:
# One accuracy check of the interpolants built above: `frac` is interpolated / exact
# normalization (ideally ~1), and `num_bad` counts points that needed the backup.
tests, checks, test_points, num_bad = test_and_check(interp, backup, 10.0*PC, -1.0, +2.5, debug=False)
frac = tests/checks
print(f"{num_bad=} = {num_bad/tests.size:.2e} ::: {utils.stats(frac, prec=4)}")

In [None]:
# Convergence study: rebuild the interpolant with increasing numbers of sample points and
# record (a) how many test points needed the nearest-neighbour backup and (b) the
# -1σ / median / +1σ quantiles of interpolated / exact normalization.
nums_list = [1e3, 3e3, 1e4, 3e4, 1e5]
nums_bad = np.zeros_like(nums_list)
errors = np.zeros((nums_bad.size, 3))

for ii, num in enumerate(utils.tqdm(nums_list)):
    # `int(num)`: a sample count should be an integer (the list holds floats like 1e3).
    # Loop-local names keep the `interp`/`backup` built earlier from being overwritten.
    interp_nn, backup_nn = Fixed_Time._calculate_interpolant(10.0*PC, -1.0, +2.5, num_points=int(num))
    tests, checks, test_points, nbad = \
        test_and_check(interp_nn, backup_nn, 10.0*PC, -1.0, +2.5, debug=False)
    fracs = tests / checks
    nums_bad[ii] = nbad
    errors[ii, :] = utils.quantiles(fracs, sigmas=[-1, 0, 1])
    

In [None]:
# Number of test points (out of those drawn by `test_and_check`) where the linear
# interpolant failed and the nearest-neighbour backup was used.
fig, ax = plot.figax(yscale='lin', xlabel='Interpolant sample points', ylabel='Backup fallbacks')
ax.plot(nums_list, nums_bad)
plt.show()

# Interpolated / exact normalization: median, with the -1σ .. +1σ band shaded.
# Uses `plot.figax` like every other figure in this notebook.
fig, ax = plot.figax(yscale='lin', xlabel='Interpolant sample points', ylabel='norm: interpolated / exact')
ax.plot(nums_list, errors[:, 1])
ax.fill_between(nums_list, errors[:, 0], errors[:, -1], alpha=0.2)
ax.axhline(1.0, ls='--', color='0.5')
plt.show()