In [None]:
# %load ../../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from datetime import datetime
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy errors (divide / invalid / overflow) and all UserWarnings
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# `style.use('default')` resets (nearly) every rcParam back to matplotlib's defaults, so it must be
# applied BEFORE the custom overrides below -- otherwise the font / lines / mathtext settings are lost.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): with family='serif' the 'sans-serif' list is never consulted; if Times was the
# intended font this key probably should be 'serif' -- confirm before changing.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode.math as zmath
import zcode.plot as zplot

def draw_gwb(ax, xx, gwb, nsamp=10, color=None, label=None, plot_kwargs=None, rng=None):
    """Plot the median and interquartile band of an ensemble of GWB realizations.

    A few individual realizations can also be drawn on top.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    xx : (F,) array_like
        x-coordinates (e.g. frequencies) for each row of `gwb`.
    gwb : (F, R) array_like
        Values at `F` x-positions for `R` realizations; statistics are taken over axis 1.
    nsamp : int or None
        Number of individual realizations to overplot. They are chosen at random without
        replacement, and the count is capped at `R`. `None` or a value `<= 0` disables this.
    color : color or None
        Color for everything drawn. If `None`, the next color of the axes' color cycle is used.
    label : str or None
        Legend label attached to the median line.
    plot_kwargs : dict or None
        Extra keyword arguments passed to `ax.plot` for the median line only.
    rng : numpy.random.Generator or None
        Generator used to pick the overplotted realizations. If `None`, the global `np.random`
        state is used, which was the original behavior.

    Returns
    -------
    hh : matplotlib Line2D
        Handle of the median line, e.g. for building custom legends.
    """
    # Avoid a shared mutable default argument
    if plot_kwargs is None:
        plot_kwargs = {}
    if color is None:
        # NOTE: private matplotlib API -- advances the axes' property cycle
        color = ax._get_lines.get_next_color()

    gwb = np.asarray(gwb)
    # Median line, then the 25th / 75th percentiles as the shaded band
    mm, *conf = np.percentile(gwb, [50, 25, 75], axis=1)
    hh, = ax.plot(xx, mm, alpha=0.5, color=color, label=label, **plot_kwargs)
    ax.fill_between(xx, *conf, color=color, alpha=0.15)

    if (nsamp is not None) and (nsamp > 0):
        nsamp_max = gwb.shape[1]
        choice = np.random.choice if rng is None else rng.choice
        idx = choice(nsamp_max, min(nsamp, nsamp_max), replace=False)
        for ii in idx:
            ax.plot(xx, gwb[:, ii], color=color, alpha=0.25, lw=1.0, ls='-')

    return hh


In [None]:
# SAM grid `shape` parameter and binary lifetime for the fixed-time hardening model
SHAPE = 20
TIME = 3 * GYR

# Semi-analytic-model ingredients
gsmf = holo.sam.GSMF_Schechter()                    # galaxy stellar-mass function (GSMF)
gpf = holo.sam.GPF_Power_Law()                      # galaxy pair fraction (GPF)
gmt = holo.sam.GMT_Power_Law()                      # galaxy merger time (GMT)
mmbulge = holo.host_relations.MMBulge_Standard()    # M-MBulge relation (MMB)

sam = holo.sam.Semi_Analytic_Model(
    gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge,
    shape=SHAPE,
)
# Alternative hardening model previously tried here: `holo.hardening.Hard_GW()`
hard = holo.hardening.Fixed_Time.from_sam(sam, TIME)

In [None]:
# Observed-frequency bin edges (arguments presumably duration and cadence -- confirm) and log-midpoints
fobs_edges = utils.nyquist_freqs_edges(20*YR, 0.2*YR)
fobs = utils.midpoints(fobs_edges, log=True)
# Many independent GWB realizations; the plotting helper takes statistics over the realization axis
gwb = sam.gwb(fobs_edges, hard=hard, realize=30)

In [None]:
# Frequencies in units of 1/yr for the x-axis
xx = fobs * YR
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN)
draw_gwb(ax, xx, gwb)

plot._twin_hz(ax)
plt.show()

# Save a high-resolution copy to the home directory and report its size
fname = Path("~/test.png").expanduser().resolve()
fig.savefig(fname, dpi=300)
print(fname, utils.get_file_size(fname))

# Calculate binary ages during evolution from the hardening model

In [None]:
# Coarser SAM grid (shape=11) with a fixed-lifetime hardening model of 1 Gyr
sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=11)
# hard = holo.hardening.Fixed_Time.from_sam(sam, GYR, exact=True)
hard = holo.hardening.Fixed_Time.from_sam(sam, GYR, exact=False)

# Number of log-spaced separation samples along each binary's evolution
STEPS = 22
# () -- initial separation of the hardening model (private attribute)
rmax = hard._sepa_init
# (M,) -- ISCO radius (`utils.rad_isco`) for each total mass of the SAM grid; inner end of the evolution
rmin = utils.rad_isco(sam.mtot)

# (2, M) -- log10 of the outer and inner separation bounds for each mass bin
extr = np.log10([rmax * np.ones_like(rmin), rmin])
rads = np.linspace(0.0, 1.0, STEPS)[np.newaxis, :]
# (M, X) -- log-uniform separations going from `rmax` (X=0) to `rmin` (X=-1) for each mass bin
rads = extr[0][:, np.newaxis] + (extr[1] - extr[0])[:, np.newaxis] * rads
rads = 10.0 ** rads
# (M, Q, Z, X) -- broadcast total-mass, mass-ratio and redshift grids against the per-mass separations
mt, mr, rz, rads = np.broadcast_arrays(sam.mtot[:, np.newaxis, np.newaxis, np.newaxis], sam.mrat[np.newaxis, :, np.newaxis, np.newaxis], sam.redz[np.newaxis, np.newaxis, :, np.newaxis], rads[:, np.newaxis, np.newaxis, :])

# (X, M*Q*Z) -- flatten the binary grid; each column is one binary's separation track
mt, mr, rz, rads = [mm.reshape(-1, STEPS).T for mm in [mt, mr, rz, rads]]
# (X, M*Q*Z) --- `Fixed_Time.dadt` will only accept this shape
dadt = hard.dadt(mt, mr, rads)
# Integrate (inverse) hardening rates to calculate total lifetime
# NOTE(review): assumes rmax > rmin (separations decrease along axis 0) and dadt < 0, so -1/dadt is
# positive for the log-log integral and the result is negated to give positive times. With
# `cumsum=True` this presumably yields the elapsed time at each of the X-1 intervals -- confirm
# (the next cell pairs it with `frst_orb_evo[1:, :]`).
times = -utils.trapz_loglog(-1.0 / dadt, rads, axis=0, cumsum=True)
# Summary statistics of the final cumulative time (at the innermost separation), in Gyr
print(utils.stats(times[-1, :]/GYR))

## Interpolate to target frequencies

In [None]:
# (X, M*Q*Z) -- rest-frame orbital frequency along each binary's separation track
frst_orb_evo = utils.kepler_freq_from_sepa(mt, rads)

# `rz` is shaped (X, M*Q*Z) and is constant for all X
# (F, M*Q*Z) -- rest-frame GW frequency of every observed frequency bin, for each binary
frst_gw = fobs[:, np.newaxis] * (1.0 + rz[0, np.newaxis, :])

# GW frequency = 2 x orbital frequency. The first step is dropped so that `xx` lines up with the
# cumulative `times` (which presumably has X-1 entries along axis 0 -- confirm).
xx = frst_orb_evo[1:, :]*2.0
yy = times
# (F, M*Q*Z) -- log-log interpolation of each binary's time at the target GW frequencies
times_new = utils.ndinterp(frst_gw.T, xx.T, yy.T, xlog=True, ylog=True).T

print(f"{frst_gw.shape=}, {times_new.shape=}")

fig, ax = plot.figax(xlabel='rest-frame GW frequency [1/yr]', ylabel='time [Gyr]')

# Spot-check a few random binaries: evolution track (dots) vs. interpolated values (crosses).
# `xx`/`yy` are reused so the plotted track is exactly the interpolation input.
nbins = times.shape[-1]
for ii in np.random.choice(nbins, min(3, nbins), replace=False):
    line, = ax.plot(xx[:, ii]*YR, yy[:, ii]/GYR, alpha=0.5, marker='.')
    ax.scatter(frst_gw[:, ii]*YR, times_new[:, ii]/GYR, color=line.get_color(), marker='x', alpha=0.5)

plt.show()


In [None]:
# Sanity check of `utils.ndinterp` on random toy data. Each row of `xx`/`yy` is an independent
# curve (x sorted within the row), and the points `xnew[ii]` are compared against row `ii`.
rng = np.random.default_rng(12345)   # local, seeded generator: reproducible and leaves global state alone
xx = np.sort(rng.uniform(0.0, 1.0, size=(3, 4)), axis=-1)
yy = rng.uniform(0.0, 1.0, size=xx.shape)

# (3, 2) -- two sorted target x-values per row
xnew = np.sort(rng.uniform(0.0, 1.0, (xx.shape[0], 2)), axis=1)
print(xnew)
ynew = utils.ndinterp(xnew, xx, yy)

fig, ax = plt.subplots()
ax.set(xlabel='x', ylabel='y', title='utils.ndinterp: curves (lines) vs. interpolants (x)')
for ii in range(xx.shape[0]):
    line, = ax.plot(xx[ii, :], yy[ii, :])
    ax.scatter(xnew[ii], ynew[ii, :], color=line.get_color(), alpha=0.5, marker='x')

plt.show()

# Compare GWBs with different stalling/coalescing cuts

In [None]:
# Larger SAM grid, longer binary lifetime and more realizations for the comparison below
SHAPE = 30
TIME = 5 * GYR
REALS = 100

# Observed-frequency bin edges and their log-midpoints
fobs_edges = utils.nyquist_freqs_edges(10*YR, 0.02*YR)
fobs = utils.midpoints(fobs_edges, log=True)

sam = holo.sam.Semi_Analytic_Model(
    gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=SHAPE,
)
# Fixed-lifetime hardening model; an earlier variant used `from_sam(sam, GYR, exact=True)`
hard = holo.hardening.Fixed_Time.from_sam(sam, TIME, exact=False)

In [None]:
gwbs = []
flags = []
# Every combination of (zero_coalesced, zero_stalled), in the order
# (False, False), (False, True), (True, False), (True, True)
for coal in (False, True):
    for stall in (False, True):
        print()
        flags.append([coal, stall])
        gwbs.append(sam.gwb(fobs_edges, hard, realize=REALS, zero_coalesced=coal, zero_stalled=stall))
        print()

In [None]:
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN)
fig.text(0.99, 0.99, f"lifetime = {TIME/GYR:.1f} [Gyr]", ha='right', va='top', fontsize=10)

xx = fobs * YR
for gwb, (coalesced, stalled) in zip(gwbs, flags):
    # Dashed median line for the runs made with `zero_stalled=True`
    style = {'ls': '--'} if stalled else {}
    draw_gwb(ax, xx, gwb, nsamp=None, label=f"{coalesced}, {stalled}", plot_kwargs=style)

ax.legend(title='Coalesced, Stalled')
plot._twin_hz(ax)
plt.show()

# Save a high-resolution copy to the home directory and report its size
fname = Path("~/coal-stall.png").expanduser()
fig.savefig(fname, dpi=400)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")

In [None]:
SHAPE = 30
REALS = 100

# Observed-frequency bins from `nyquist_freqs_edges` with its default arguments
fobs_edges = utils.nyquist_freqs_edges()
fobs = utils.midpoints(fobs_edges, log=True)

sam = holo.sam.Semi_Analytic_Model(
    gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=SHAPE,
)

In [None]:
times_list = [1e-1, 1.0, 5.0, 10.0]   # binary lifetimes [Gyr]
gwb_times = []
flag_times = []
for time in times_list:
    hard = holo.hardening.Fixed_Time.from_sam(sam, time * GYR, exact=False)
    # For each lifetime: the GWB with `zero_stalled=True` first, then with `zero_stalled=False`
    flags = [True, False]
    gwbs = [sam.gwb(fobs_edges, hard, realize=REALS, zero_stalled=flag) for flag in flags]
    gwb_times.append(gwbs)
    flag_times.append(flags)
       

In [None]:
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN)

xx = fobs * YR
time_lines = []
time_labels = []
for time, gwbs, flags in zip(times_list, gwb_times, flag_times):
    # First curve of each lifetime takes the next cycle color; the remaining curves reuse it
    color = None
    flag_lines = []
    flag_labels = []
    for ii, (gwb, flag) in enumerate(zip(gwbs, flags)):
        hh = draw_gwb(ax, xx, gwb, nsamp=None, plot_kwargs=dict(ls='--') if flag else {}, color=color)
        flag_lines.append(hh)
        flag_labels.append(str(flag))
        color = hh.get_color() if color is None else color
        if ii == 1:
            # One legend entry per lifetime, taken from its second curve
            time_lines.append(hh)
            time_labels.append(f"{time:5.2f}")

# The line-style ('stalled') legend reuses the handles from the last lifetime only
leg = zplot.legend(ax, time_lines, time_labels, loc='ur', title='lifetime [Gyr]')
zplot.legend(ax, flag_lines, flag_labels, prev=leg, loc='ll', title='stalled')
plot._twin_hz(ax)
plt.show()

# Save a high-resolution copy to the home directory and report its size
fname = Path("~/stall.png").expanduser()
fig.savefig(fname, dpi=400)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")