In [None]:
# %load ../../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from datetime import datetime
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence noisy numpy floating-point errors (divide-by-zero, invalid, overflow) and UserWarnings
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Apply the base style FIRST: `mpl.style.use('default')` restores matplotlib's default rcParams,
# so running it after the `mpl.rc` customizations below would silently discard them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'Times' is listed under 'sans-serif' while the family is 'serif', so it does not
# influence the serif font that is actually chosen; perhaps `'serif': ['Times']` was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load the holodeck logger and set its verbosity (use `logging.INFO` for quieter output)
log = holo.log
log.setLevel(logging.DEBUG)

In [None]:
import zcode.math as zmath
import zcode.plot as zplot

In [None]:
# Re-assert DEBUG verbosity on the holodeck logger, using the level constant it exposes
debug_level = log.DEBUG
log.setLevel(debug_level)

In [None]:
# Display (current logger level, DEBUG constant, INFO constant) for a quick sanity check
(log.level, log.DEBUG, log.INFO)

In [None]:
# --- Model configuration ---
hard_time = 1.0e-1 * GYR   # lifetime handed to the `Fixed_Time` hardening model
shape = 40                 # passed as the SAM `shape`
nreals = 100               # number of GWB realizations

# --- PTA frequency configuration ---
# The cadence is chosen so the Nyquist frequency, 1/(2*cadence), equals nfreqs / duration.
pta_dur = 16.03 * YR
nfreqs = 40
hifr = nfreqs / pta_dur
pta_cad = 1.0 / (2 * hifr)
fobs_cents = holo.utils.nyquist_freqs(pta_dur, pta_cad)
fobs_edges = holo.utils.nyquist_freqs_edges(pta_dur, pta_cad)

# --- Build the population, its hardening, and the GWB realizations ---
sam = holo.sam.Semi_Analytic_Model(shape=shape)
hard = holo.hardening.Fixed_Time.from_sam(sam, hard_time, progress=False)
gwb = sam.gwb(fobs_edges, realize=nreals, hard=hard)

plot.plot_gwb(fobs_cents, gwb)
plt.show()


In [None]:
# Bin counts to fit with; 0 is presumably "use every bin" (a later cell labels 0 as 'all') — TODO confirm
nbins = [5, 10, 123, 0]
(
    _,
    fit_lamp, fit_plaw,           # per-realization fit parameters
    fit_med_lamp, fit_med_plaw,   # `fit_med_*` parameters (presumably fits to the median spectrum)
) = holo.librarian.fit_spectra(fobs_cents, gwb, nbins=nbins)

In [None]:
# Distributions of fitted log10-amplitudes (left) and power-law indices (right) across realizations,
# one curve per `nbins` choice; dashed vertical lines mark the corresponding `fit_med_*` values.
fig, axes = plt.subplots(figsize=[10, 5], ncols=2)
panels = zip(
    [fit_med_lamp, fit_med_plaw],
    [fit_lamp, fit_plaw],
    axes,
    [r"$\log_{10}$ amplitude", "power-law index"],
)
for med, fit_vals, ax, xlabel in panels:
    for ii, nn in enumerate(nbins):
        # An all-zero column means no usable fit for this `nbins` (presumably e.g. more bins
        # than available frequencies) — skip it rather than plotting a spike at zero.
        if np.all(fit_vals[:, ii] == 0.0):
            continue
        color = ax._get_lines.get_next_color()
        label = 'all' if nn in [0, None] else str(nn)
        kale.dist1d(fit_vals[:, ii], ax=ax, label=label, color=color)
        ax.axvline(med[ii], ls='--', color=color)

    ax.set_xlabel(xlabel)
    ax.legend(title="nbins")

plt.show()


In [None]:
# Compare the GWB realizations against a reference f^{-2/3} spectrum and against the
# `fit_med_*` power-law fits (presumably fits to the median spectrum) for two `nbins` choices.
fig = plot.plot_gwb(fobs_cents, gwb)
ax = fig.axes[0]

xx = fobs_cents * YR   # frequencies in units of 1/yr
yy = 1e-15 * np.power(xx, -2.0/3.0)
# Raw string: in a normal literal `\c` is an invalid escape sequence (SyntaxWarning on Python >= 3.12);
# the resulting label text is unchanged.
ax.plot(xx, yy, 'r-', alpha=0.5, lw=1.0, label=r"$10^{-15} \cdot f_\mathrm{yr}^{-2/3}$")

gwb_fits = holo.librarian.get_gwb_fits_data(fobs_cents, gwb)

# Overlay the fit for the second `nbins` entry (dotted) and for the last entry (dashed)
for ls, idx in zip([":", "--"], [1, -1]):
    med_lamp = gwb_fits['fit_med_lamp'][idx]   # log10 of the amplitude
    med_plaw = gwb_fits['fit_med_plaw'][idx]   # power-law index
    yy = (10.0 ** med_lamp) * (xx ** med_plaw)
    nbins_label = gwb_fits['fit_nbins'][idx]
    # `nbins` of 0 or None is labelled 'all'
    nbins_label = 'all' if nbins_label in [0, None] else nbins_label
    ax.plot(xx, yy, color='k', ls=ls, alpha=0.5, lw=2.0, label=str(nbins_label) + " bins")

# Summary text produced by the fitting routine, one " | "-separated entry per line
fit_text = gwb_fits['fit_label'].replace(" | ", "\n")
fig.text(0.99, 0.99, fit_text, fontsize=6, ha='right', va='top')

ax.legend()
plt.show()


In [None]:
# Power-law fits to the median GWB spectrum for `nbins` = 5, 10 and None (presumably all bins)
fig = plot.plot_gwb(fobs_cents, gwb)
ax = fig.axes[0]

xx = fobs_cents * YR
# Median over realizations, computed once and reused for every fit below
gwb_med = np.median(gwb, axis=-1)
ax.plot(xx, gwb_med, 'k:', label='median')

for nn in [5, 10, None]:
    # `fit_powerlaw` returns its own frequency grid alongside the fitted (amplitude, index);
    # keep it separate from `xx` so the loop does not clobber the plotting grid.
    fit_xx, amp, gamma = holo.librarian.fit_powerlaw(fobs_cents, gwb_med, nn)
    label = "all bins" if nn is None else f"{nn} bins"
    ax.plot(fit_xx, amp * (fit_xx ** gamma), ls='--', label=label)

ax.legend()
plt.show()


In [None]:
# Fit a power law (with `nbins=5`) to every individual realization, overlay each fit on its
# realization, then show the joint distribution of the fitted parameters.
fig = plot.plot_gwb(fobs_cents, gwb, nsamp=None)
ax = fig.axes[0]

xx = fobs_cents * YR
yy = np.median(gwb, axis=-1)
ax.plot(xx, yy, 'k-')

nreals = gwb.shape[1]

# (R, 2): columns are [amplitude, power-law index] for each realization
plaw_fits = np.zeros((nreals, 2))
for nn in range(nreals):
    yy = gwb[:, nn]
    fit_xx, amp, gamma = holo.librarian.fit_powerlaw(fobs_cents, yy, 5)
    plaw_fits[nn, :] = amp, gamma
    line, = ax.plot(fit_xx, plaw_fits[nn, 0] * (fit_xx ** plaw_fits[nn, 1]), ls='--', alpha=0.5)
    # Draw the realization itself in the same color as its fit
    ax.plot(fobs_cents*YR, yy, color=line.get_color(), alpha=0.5)

plt.show()

# Corner plot in (log10 amplitude, index) space
draw_fits = plaw_fits.copy()
draw_fits[:, 0] = np.log10(draw_fits[:, 0])

kale.corner(draw_fits.T)
plt.show()


In [None]:
# Build a SAM + hardening model from one specific set of parameter values and plot its GWB.
# NOTE(review): these values look like a single sample from a parameter-space library — confirm provenance.

# Raw parameter values; unit-suffixed names mark the ones converted below
hard_time_log10_gyr = -2.2957907176750907   # log10(hardening time / Gyr)
hard_gamma_inner = -1.3335554512862717      # passed as Fixed_Time `gamma_sc`
gsmf_phi0 = -2.802178096487384
gsmf_mchar0 = 11.704311872442908            # passed as `mchar0_log10`
gsmf_alpha0 = -1.7179504809027346
gpf_zbeta = 2.397456708546681
gpf_qgamma = 0.4609649227136603
gmt_norm_gyr = 0.5765308121579338           # GMT `time_norm`, in Gyr
gmt_zbeta = -0.26777937808636665
mmb_amp_log10_msol = 8.301258575486393      # log10(M-Mbulge amplitude / Msol)
mmb_plaw = 0.4785954601355894
mmb_scatter = 0.12386778329303819           # passed as `scatter_dex`
gwb_nreals = 10                             # number of GWB realizations

# Convert to holodeck's internal units via the `GYR` / `MSOL` constants
hard_time = (10.0 ** hard_time_log10_gyr) * GYR
gmt_norm = gmt_norm_gyr * GYR
mmb_amp = (10.0 ** mmb_amp_log10_msol) * MSOL

gsmf = holo.sam.GSMF_Schechter(phi0=gsmf_phi0, mchar0_log10=gsmf_mchar0, alpha0=gsmf_alpha0)
gpf = holo.sam.GPF_Power_Law(qgamma=gpf_qgamma, zbeta=gpf_zbeta)
gmt = holo.sam.GMT_Power_Law(time_norm=gmt_norm, zbeta=gmt_zbeta)
mmbulge = holo.host_relations.MMBulge_KH2013(mamp=mmb_amp, mplaw=mmb_plaw, scatter_dex=mmb_scatter)

sam = holo.sam.Semi_Analytic_Model(
    gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge,
    shape=20
)
hard = holo.hardening.Fixed_Time.from_sam(
    sam, hard_time, gamma_sc=hard_gamma_inner,
    progress=False
)

# PTA frequencies: cadence chosen so the Nyquist frequency equals nfreqs / duration
pta_dur = 16.03 * YR
nfreqs = 40
hifr = nfreqs/pta_dur
pta_cad = 1.0 / (2 * hifr)
fobs_cents = holo.utils.nyquist_freqs(pta_dur, pta_cad)
fobs_edges = holo.utils.nyquist_freqs_edges(pta_dur, pta_cad)
gwb = sam.gwb(fobs_edges, realize=gwb_nreals, hard=hard)

plot.plot_gwb(fobs_cents, gwb)
plt.show()


In [None]:
# SAM with `shape=None` (presumably the library's default grid) plus a 1 Gyr fixed-time
# hardening model built without normalization interpolation.
SHAPE = None
TIME = 1.0 * GYR

sam = holo.sam.Semi_Analytic_Model(shape=SHAPE)
hard = holo.hardening.Fixed_Time.from_sam(
    sam,
    TIME,
    interpolate_norm=False,
)

In [None]:
# Sample each binary's separation from the hardening model's initial separation down to the ISCO,
# then integrate the inverse hardening rate to get the cumulative time to reach each separation.
STEPS = 500   # separation samples per binary

# Limits: scalar start `()`, per-total-mass end `(M,)` at the ISCO
rmax = hard._sepa_init
rmin = utils.rad_isco(sam.mtot)
# Alternative end point: `hard._TIME_TOTAL_RMIN * np.ones_like(sam.mtot)`

# log10 of both limits, each shaped (M,)
log_r_start = np.log10(rmax * np.ones_like(rmin))
log_r_end = np.log10(rmin)

# (1, X) fractional positions along the log-separation interval
frac = np.linspace(0.0, 1.0, STEPS)[np.newaxis, :]
# (M, X) log-spaced separations running from rmax to rmin
rads = 10.0 ** (log_r_start[:, np.newaxis] + (log_r_end - log_r_start)[:, np.newaxis] * frac)

# (M, Q, Z, X): every (mass, ratio, redshift) combination gets its mass's separation track
grids = np.broadcast_arrays(
    sam.mtot[:, np.newaxis, np.newaxis, np.newaxis],
    sam.mrat[np.newaxis, :, np.newaxis, np.newaxis],
    sam.redz[np.newaxis, np.newaxis, :, np.newaxis],
    rads[:, np.newaxis, np.newaxis, :],
)
# (X, M*Q*Z) --- `Fixed_Time.dadt` will only accept this shape
mt, mr, rz, rads = (grid.reshape(-1, STEPS).T for grid in grids)

dadt = hard.dadt(mt, mr, rads)
# Integrate (inverse) hardening rates along the separation axis to get the lifetime to each separation
times_evo = -utils.trapz_loglog(-1.0 / dadt, rads, axis=0, cumsum=True)


In [None]:
# Distribution of the cumulative evolution time at the final sampled separation (the ISCO), in Gyr
tt = times_evo[-1, :]/GYR
fig, ax = plot.figax(scale='lin')
print(utils.stats(tt))
# Draw explicitly onto the axes created above rather than relying on pyplot's "current axes"
kale.dist1d(tt, ax=ax, density=True)
ax.set(xlabel='Time to reach ISCO [Gyr]', ylabel='Density')
plt.show()