This notebook is for testing and developing the sampling of `_Population_Discrete` models, typically the Illustris-based population model (`Discrete_Illustris`; instantiated below as `holo.population.Pop_Illustris`). (Re)sampling of the population is performed using `kalepy`.

In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR
import holodeck.gravwaves

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# `style.use('default')` resets *all* rcParams to matplotlib's defaults, so it must run
# BEFORE the custom settings below; running it last silently discarded them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): with family='serif' only the `font.serif` list is consulted, so the
# 'sans-serif' entry below has no effect -- `'serif': ['Times']` may have been intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import holodeck.extensions

# GW-frequency grid for a 10 yr baseline sampled every 0.1 yr
fobs_gw = holo.utils.nyquist_freqs(dur=10.0*YR, cad=0.1*YR)

# the realizer is constructed on orbital frequencies (half of the GW frequencies)
fobs_orb = 0.5 * fobs_gw
realizer = holo.extensions.Realizer(fobs_orb)

In [None]:
# Draw four realizations; only the second return value of `realizer()` is kept.
samples = []
for _ in tqdm.trange(4):
    _count, realization = realizer()
    samples.append(realization)

## Calculate the GWB without fully sampling the Universe

In [None]:
# Population: Illustris-based binaries with a resampling modifier
# (argument 1.0 -- presumably the resampling factor; confirm in `PM_Resample`).
resamp = holo.population.PM_Resample(1.0)
pop = holo.population.Pop_Illustris(mods=resamp)
print(f"{pop.size=}")

# Hardening model: `Fixed_Time` with 2 Gyr, built from the population; then evolve.
fixed = holo.hardening.Fixed_Time.from_pop(pop, 2.0 * GYR)
evo = holo.evolution.Evolution(pop, fixed)
evo.evolve()

In [None]:
# Reference GWB: discretized ("volumetric") calculation on a 10 yr / 0.1 yr frequency grid,
# with 100 realizations.
freqs_volumetric = holo.utils.nyquist_freqs(dur=10.0*YR, cad=0.1*YR)
gwb_volumetric = holo.gravwaves.GW_Discrete(evo, freqs_volumetric, nreals=100)
gwb_volumetric.emit()

## Compare different GWB calculation methods

In [None]:
def _calc(fobs_gw_edges, evo):
    """Compute one realization of the GWB using three different methods.

    Parameters
    ----------
    fobs_gw_edges : ndarray
        Edges of the observer-frame GW-frequency bins (F+1 values -> F bins).
    evo : holodeck.evolution.Evolution
        Evolved binary population; must provide ``at``, ``_sample_volume`` and
        ``_sample_universe``.

    Returns
    -------
    gwb_direct : ndarray
        Strain from Poisson-sampling the number of binaries that each simulated binary
        represents at every bin center, then summing the squared source strains directly.
    gwb_vals : ndarray
        Strain from passing the per-bin binary parameters and (independently drawn) Poisson
        weights through ``holo.gravwaves._gws_from_samples``; its ``gwf`` and ``gwb`` outputs
        are combined in quadrature, matching ``gwb_samples``.
    other_direct : ndarray
        Third value returned by ``evo._sample_universe`` (presumably its own direct
        estimate -- TODO confirm).
    gwb_samples : ndarray
        Strain from fully sampling the universe with ``evo._sample_universe``; ``gwf`` and
        ``gwb`` combined in quadrature.
    """
    fobs_gw_cents = kale.utils.midpoints(fobs_gw_edges, log=False)
    dlnf = np.diff(np.log(fobs_gw_edges))

    # convert from GW to orbital frequencies
    fobs_orb_cents = fobs_gw_cents / 2.0
    fobs_orb_edges = fobs_gw_edges / 2.0

    params = ['mass', 'sepa', 'dadt', 'scafa']
    data_fobs = evo.at('fobs', fobs_orb_cents, params=params)

    redz = cosmo.a_to_z(data_fobs['scafa'])
    # entries with non-positive (or NaN) redshift are excluded from every estimate below
    valid = (redz > 0.0)
    # rest-frame GW- and orbital-frequencies
    frst_gw_cents = utils.frst_from_fobs(fobs_gw_cents[np.newaxis, :], redz)
    frst_orb_cents = frst_gw_cents / 2.0
    dcom = cosmo.z_to_dcom(redz)
    m1, m2 = np.moveaxis(data_fobs['mass'], -1, 0)
    dfdt, _ = utils.dfdt_from_dadt(data_fobs['dadt'], data_fobs['sepa'], frst_orb=frst_orb_cents)

    # expected number of binaries represented by each entry in each frequency bin
    lambda_factor = utils.lambda_factor_dlnf(frst_orb_cents, dfdt, redz, dcom=dcom) / evo._sample_volume
    num_binaries = lambda_factor * dlnf[np.newaxis, :]

    # ---- Direct: Poisson-sample binary counts and sum squared strains ----
    mchirp = utils.chirp_mass(m1, m2)
    hs2 = utils.gw_strain_source(mchirp, dcom, frst_orb_cents) ** 2
    gwb_direct = np.zeros_like(hs2)
    gwb_direct[valid] = hs2[valid] * np.random.poisson(num_binaries[valid])
    gwb_direct = np.sqrt(np.sum(gwb_direct, axis=0) / dlnf)

    # ---- Sample values: the same binaries, treated as weighted samples ----
    mt, mr = utils.mtmr_from_m1m2(m1[valid], m2[valid])
    # broadcast `fobs` to match the shape of binaries, then select valid entries
    fo = np.broadcast_to(fobs_orb_cents, redz.shape)[valid]
    vals = np.asarray([mt, mr, redz[valid], fo])
    weights_vals = np.random.poisson(num_binaries[valid])
    _, gwf_vals, gwb_vals = holo.gravwaves._gws_from_samples(vals, weights_vals, fobs_gw_edges)
    # combine both strain components, consistent with `gwb_samples` below
    gwb_vals = np.sqrt(gwb_vals**2 + gwf_vals**2)

    # ---- Sample the full universe ----
    _names, samples, other_direct, *_ = evo._sample_universe(fobs_orb_edges, down_sample=None)
    weights_samp = np.ones_like(samples[0])
    # bin with the function's own frequency edges (previously read the global `freqs`)
    _, gwf_samp, gwb_samp = holo.gravwaves._gws_from_samples(samples, weights_samp, fobs_gw_edges)
    gwb_samples = np.linalg.norm([gwf_samp, gwb_samp], axis=0)

    return gwb_direct, gwb_vals, other_direct, gwb_samples


def calc(fobs_gw_edges, evo, nreals):
    """Run `_calc` for `nreals` realizations and stack the results column-wise.

    Parameters
    ----------
    fobs_gw_edges : ndarray
        Edges of the observer-frame GW-frequency bins (F+1 values -> F bins).
    evo : holodeck.evolution.Evolution
        Evolved population, forwarded unchanged to `_calc`.
    nreals : int
        Number of realizations.

    Returns
    -------
    gwb_direct, gwb_vals, other_direct, gwb_samples : ndarray, each of shape (F, nreals)
        Column ``rr`` holds the corresponding value returned by the ``rr``-th `_calc` call.
    """
    nbins = fobs_gw_edges.size - 1
    gwb_direct, gwb_vals, other_direct, gwb_samples = (np.zeros((nbins, nreals)) for _ in range(4))

    for rr in tqdm.trange(nreals):
        direct, vals, other, samps = _calc(fobs_gw_edges, evo)
        gwb_direct[:, rr] = direct
        gwb_vals[:, rr] = vals
        other_direct[:, rr] = other
        gwb_samples[:, rr] = samps

    return gwb_direct, gwb_vals, other_direct, gwb_samples

freqs = holo.utils.nyquist_freqs(dur=1*YR, cad=0.01*YR)
gwb_direct, gwb_vals, other_direct, gwb_samples = calc(freqs, evo, nreals=30)
gwb_volumetric = holo.gravwaves.GW_Discrete(evo, freqs, nreals=100)
gwb_volumetric.emit()

In [None]:
fig, ax = plot.figax(figsize=[10, 8], xlabel='Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)


# ---- Discrete
gwb_list = [gwb_direct, gwb_vals, gwb_samples, gwb_volumetric.both]
gwb_names = ['gwb_direct', 'gwb_vals', 'samples', 'gwb_volumetric']
# gwb_list = [gwb_samples, gwb_volumetric.both]
# gwb_names = ['samples', 'gwb_volumetric']
xvals_list = [freqs*YR] * len(gwb_list)
xx = utils.minmax(np.concatenate(xvals_list))
plot._draw_plaw(ax, xx, 1e-15, 1)
kw = dict(yfilter=True, percs=[10, 90])
for xx, gwb, lab in zip(xvals_list, gwb_list, gwb_names):
    hh, _ = plot.draw_med_conf(ax, xx, gwb, label=lab, **kw)
    col = hh.get_color()
    size = gwb.shape[1]
    sel = np.min([size, 5])
    sel = np.random.choice(size, sel, replace=False)

    xx = kale.utils.midpoints(xx) if xx.size == gwb.shape[0]+1 else xx
    ax.plot(xx, gwb[:, sel], color=col, alpha=0.35, lw=0.5)

# ax.set(xlim=[1, 10], ylim=[1e-17, 2e-15])
ax.legend()
plt.show()

## Fully sample Universe

In [None]:
NUM = 10
freqs_samples = np.logspace(0, 1, NUM) / YR

# `_sample_universe` is unpacked into five values elsewhere in this notebook; only the
# parameter names and the samples are needed here, so discard the rest.
names, samples_10, *_ = evo._sample_universe(freqs_samples)
num_samp_10 = samples_10[0].size
print(names, samples_10[0].shape, f"{num_samp_10:.4e}")

In [None]:
# re-display the summary of the 10-frequency sampling
print(names, np.shape(samples_10[0]), format(num_samp_10, ".4e"))

In [None]:
# GW-frequency bin edges used to bin the samples: 1 yr baseline, 0.1 yr cadence
fobs = holo.utils.nyquist_freqs(dur=1.0*YR, cad=0.1*YR)

In [None]:
# every sample carries unit weight (no down-sampling was applied)
weights_10 = np.ones_like(samples_10[0])
# use the same `_gws_from_samples` (holo.gravwaves) as the other cells in this notebook
gff_10, gwf_10, gwb_10 = holo.gravwaves._gws_from_samples(samples_10, weights_10, fobs)

In [None]:
NUM = 100
freqs_samples = np.logspace(0, 1, NUM) / YR

# `_sample_universe` is unpacked into five values elsewhere in this notebook; only the
# parameter names and the samples are needed here, so discard the rest.
names, samples_100, *_ = evo._sample_universe(freqs_samples)
num_samp_100 = samples_100[0].size
print(names, samples_100[0].shape, f"{num_samp_100:.4e}")

In [None]:
# re-display the summary of the 100-frequency sampling
print(names, np.shape(samples_100[0]), format(num_samp_100, ".4e"))

In [None]:
# every sample carries unit weight (no down-sampling was applied)
weights_100 = np.ones_like(samples_100[0])
# use the same `_gws_from_samples` (holo.gravwaves) as the other cells in this notebook
gff_100, gwf_100, gwb_100 = holo.gravwaves._gws_from_samples(samples_100, weights_100, fobs)

In [None]:
fig, ax = plot.figax(xlabel=r'Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)

# Reference: volumetric GWB, median with the 25-75 percentile band over realizations.
# NOTE(review): assumes `gwb_volumetric` was computed on `freqs_volumetric`; if it was
# re-assigned on another frequency grid, this x-axis would be misaligned.
xx = freqs_volumetric * YR
med, *conf = np.percentile(gwb_volumetric.back, [50, 25, 75], axis=-1)
hh, = ax.plot(xx, med, 'k--')
ax.fill_between(xx, *conf, alpha=0.1, color=hh.get_color())

# Fully-sampled universe with 10 and 100 sampling frequencies:
# `gwf` drawn as points at `gff`, `gwb` drawn as steps over the `fobs` bins.
for gff_n, gwf_n, gwb_n in [(gff_10, gwf_10, gwb_10), (gff_100, gwf_100, gwb_100)]:
    xx = gff_n * YR
    hh = ax.scatter(xx, gwf_n)
    plot.draw_hist_steps(ax, fobs*YR, gwb_n, color=hh.get_facecolor())

plt.show()

## Test: down-sampled sampling of the Universe vs. the volumetric GWB

In [None]:
# Recompute the reference (volumetric) GWB on the 10 yr / 0.1 yr frequency grid,
# with 100 realizations, for the comparisons below.
freqs_volumetric = holo.utils.nyquist_freqs(dur=10.0*YR, cad=0.1*YR)
gwb_volumetric = holo.gravwaves.GW_Discrete(evo, freqs_volumetric, nreals=100)
gwb_volumetric.emit()

In [None]:
# Down-sampling factor passed to `_sample_universe` (None disables down-sampling)
DOWN = 1e2
# coarse GW-frequency bin edges: 10 yr baseline, 2 yr cadence
fobs = holo.utils.nyquist_freqs(dur=10.0*YR, cad=2*YR)
# NOTE(review): `_calc` passes *orbital* edges (GW / 2) to `_sample_universe`, while the GW
# edges are passed directly here -- confirm which convention `_sample_universe` expects.
freqs = fobs

REALS = 10
gff = np.zeros((fobs.size - 1, REALS))
gwf, gwb, check_direct, check_vals = (np.zeros_like(gff) for _ in range(4))

for rr in tqdm.trange(REALS):
    names, samples, check_direct[:, rr], vals, vals_weights = evo._sample_universe(freqs, down_sample=DOWN)
    weights = np.ones_like(samples[0])
    if DOWN is not None:
        # presumably each retained sample stands for DOWN binaries -- confirm
        weights *= DOWN
        vals_weights *= DOWN
    gff[:, rr], gwf[:, rr], gwb[:, rr] = holo.gravwaves._gws_from_samples(samples, weights, fobs)

    # "vals" estimate: Poisson-draw the weights, then combine both strain outputs in quadrature
    _gff, _gwf, check_vals[:, rr] = holo.gravwaves._gws_from_samples(vals, np.random.poisson(vals_weights), fobs)
    check_vals[:, rr] = np.sqrt(check_vals[:, rr]**2 + _gwf**2)

In [None]:
fig, ax = plot.figax(xlabel=r'Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)
plot._draw_plaw(ax, [0.1, 10], 1e-15, 1)

# reference: volumetric calculation on the 10-yr grid
plot.draw_med_conf(ax, freqs_volumetric * YR, gwb_volumetric.back, color='k', label='volumetric')

# estimates from the down-sampled realizations; `gwf` and `gwb` combined in quadrature
gwb_both = np.sqrt(gwb**2 + gwf**2)
xx = fobs * YR
for arr, lab in ((check_direct, 'direct'), (check_vals, 'vals'), (gwb_both, 'samples')):
    plot.draw_med_conf(ax, xx, arr, label=lab)

ax.legend()
plt.show()

In [None]:
NUM = 10
DOWN = None
freqs = np.logspace(0, 1, NUM) / YR
fobs = holo.utils.nyquist_freqs(dur=1.0*YR, cad=0.1*YR)

# `_sample_universe` is unpacked into five values elsewhere in this notebook; keep only
# the parameter names and the samples.
names, samples, *_ = evo._sample_universe(freqs, down_sample=DOWN)
num_samp = samples[0].size
print(names, samples[0].shape, f"{num_samp:.4e}")

weights = np.ones_like(samples[0])
if DOWN is not None:
    weights *= DOWN
# same `_gws_from_samples` (holo.gravwaves) as the other cells in this notebook
gff, gwf, gwb = holo.gravwaves._gws_from_samples(samples, weights, fobs)


fig, ax = plot.figax(xlabel=r'Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)

# reference: volumetric GWB median with the 25-75 percentile band
xx = freqs_volumetric * YR
med, *conf = np.percentile(gwb_volumetric.back, [50, 25, 75], axis=-1)
hh, = ax.plot(xx, med, 'k--')
ax.fill_between(xx, *conf, alpha=0.1, color=hh.get_color())

xx = gff * YR
hh = ax.scatter(xx, gwf)
plot.draw_hist_steps(ax, fobs*YR, gwb, color=hh.get_facecolor())

plt.show()

In [None]:
NUM = 100
DOWN = None
freqs = np.logspace(0, 1, NUM) / YR
fobs = holo.utils.nyquist_freqs(dur=1.0*YR, cad=0.1*YR)

# `_sample_universe` is unpacked into five values elsewhere in this notebook; keep only
# the parameter names and the samples.
names, samples, *_ = evo._sample_universe(freqs, down_sample=DOWN)
num_samp = samples[0].size
print(names, samples[0].shape, f"{num_samp:.4e}")

weights = np.ones_like(samples[0])
if DOWN is not None:
    weights *= DOWN
# same `_gws_from_samples` (holo.gravwaves) as the other cells in this notebook
gff, gwf, gwb = holo.gravwaves._gws_from_samples(samples, weights, fobs)


fig, ax = plot.figax(xlabel=r'Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)

# reference: volumetric GWB median with the 25-75 percentile band
xx = freqs_volumetric * YR
med, *conf = np.percentile(gwb_volumetric.back, [50, 25, 75], axis=-1)
hh, = ax.plot(xx, med, 'k--')
ax.fill_between(xx, *conf, alpha=0.1, color=hh.get_color())

xx = gff * YR
hh = ax.scatter(xx, gwf)
plot.draw_hist_steps(ax, fobs*YR, gwb, color=hh.get_facecolor())

plt.show()

In [None]:
# Same configuration as the previous cell (presumably a new random realization).
NUM = 100
DOWN = None
freqs = np.logspace(0, 1, NUM) / YR
fobs = holo.utils.nyquist_freqs(dur=1.0*YR, cad=0.1*YR)

# `_sample_universe` is unpacked into five values elsewhere in this notebook; keep only
# the parameter names and the samples.
names, samples, *_ = evo._sample_universe(freqs, down_sample=DOWN)
num_samp = samples[0].size
print(names, samples[0].shape, f"{num_samp:.4e}")

weights = np.ones_like(samples[0])
if DOWN is not None:
    weights *= DOWN
# same `_gws_from_samples` (holo.gravwaves) as the other cells in this notebook
gff, gwf, gwb = holo.gravwaves._gws_from_samples(samples, weights, fobs)


fig, ax = plot.figax(xlabel=r'Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)

# reference: volumetric GWB median with the 25-75 percentile band
xx = freqs_volumetric * YR
med, *conf = np.percentile(gwb_volumetric.back, [50, 25, 75], axis=-1)
hh, = ax.plot(xx, med, 'k--')
ax.fill_between(xx, *conf, alpha=0.1, color=hh.get_color())

xx = gff * YR
hh = ax.scatter(xx, gwf)
# `gwb` is binned on `fobs`, so its edges (not the `freqs` sampling points) define the steps
plot.draw_hist_steps(ax, fobs*YR, gwb, color=hh.get_facecolor())

plt.show()

## Other

There should be a better way to sample frequencies: the power-law index of the hardening rate is known, and therefore so is the expected number of sources as a function of frequency.
* Do normal `kalepy` resampling, and then override the frequencies manually?
* Do grid-resampling in frequency, and KDE resampling for other parameters?