In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy floating-point warnings (divide / invalid / overflow) and all UserWarnings
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `style.use('default')` resets *every* rcParam, so it must run first -- placed after the
#       `mpl.rc` calls it silently discards the font / line / mathtext customizations.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# 'Times' is a serif face, so it goes in the 'serif' list (used because family='serif');
# keep matplotlib's default serif fonts as fallbacks in case Times is not installed.
mpl.rc('font', **{'family': 'serif', 'serif': ['Times'] + list(mpl.rcParams['font.serif']), 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
def gwb_sa_gw_only(sam, fobs_gw):
    """Semi-analytic GWB characteristic strain assuming purely GW-driven binary evolution.

    Evaluates the standard (Phinney 2001) result

        h_c^2(f) = [4 pi / (3 c^2)] (pi f)^{-4/3} Int dn (G M_c)^{5/3} (1+z)^{-1/3}

    on the semi-analytic model's (mtot, mrat, redz) grid, integrating the binary number density
    over each grid cell and summing over all cells.

    Parameters
    ----------
    sam : holodeck.sam.Semi_Analytic_Model
        Model providing ``static_binary_density`` (``d^3 n / [dlog10(M) dq dz]``, divided by
        ``MPC**3`` below, i.e. assumed to be per Mpc^3), ``grid`` and ``edges``.
    fobs_gw : float or array_like of float
        Observer-frame GW frequencies (same units as ``1/YR`` in this notebook).

    Returns
    -------
    hc : (F,) ndarray
        Characteristic strain at each of the F requested frequencies.
    """
    ndens = sam.static_binary_density   # This is ``d^3 n / [dlog10(M) dq dz]``
    # Append a trailing frequency axis to each grid array: (M, Q, Z) ==> (M, Q, Z, 1)
    mt, mr, rz = [vv[..., np.newaxis] for vv in sam.grid]
    mc = utils.chirp_mass_mtmr(mt, mr)
    # `atleast_1d` so that a scalar frequency also works; shape (1, 1, 1, F)
    fogw = np.atleast_1d(fobs_gw)[np.newaxis, np.newaxis, np.newaxis, :]

    # The factor of 3 comes from dE/dln(f_r) = (pi^{2/3} / 3) (G M_c)^{5/3} f_r^{2/3} / G
    pref = 4 * np.pi * np.power(np.pi * fogw, -4.0/3.0) / (3 * SPLC**2)
    # Number density converted from per-Mpc^3 to per-cm^3
    integ = (ndens[..., np.newaxis] / MPC**3) * np.power(NWTG*mc, 5.0/3.0) / np.power(1+rz, 1.0/3.0)
    hc = pref * utils._integrate_grid_differential_number(sam.edges, integ, freq=False)
    # Sum the contributions of every (M, q, z) cell at each frequency
    hc = np.sum(hc, axis=(0, 1, 2))
    return np.sqrt(hc)

In [None]:
# Number of realizations to draw, and the semi-analytic model's grid `shape`
NREALS = 30
SHAPE = 100

# Semi-analytic model plus a purely GW-driven hardening prescription
sam = holo.sam.Semi_Analytic_Model(shape=SHAPE)
hard = holo.evolution.Hard_GW()

In [None]:
# Quick look at the semi-analytic GW-only background at the frequencies from `nyquist_freqs`
# (arguments presumably: 20 yr duration, 0.2 yr cadence -- confirm against its signature)
ff = utils.nyquist_freqs(20*YR, 0.2*YR)
gwb_sa = gwb_sa_gw_only(sam, ff)
fig, ax = plot.figax()
xx = ff * YR   # frequencies in [1/yr]
plot._draw_plaw(ax, xx, f0=1)   # presumably a reference power-law normalized at f0 = 1/yr
ax.plot(xx, gwb_sa, 'b+-')
ax.set(xlabel='GW frequency [1/yr]', ylabel='Characteristic strain $h_c$')
plt.show()

In [None]:
# A single GW-frequency bin used for the MC vs. semi-analytic comparisons
DUR = 100.0 * YR      # observing duration
FBIN = 0              # index of the frequency bin to use
FMIN = 1 / DUR        # fundamental (lowest) frequency
DF = FMIN / 2.0       # half of the bin width
# Center of bin number FBIN; the edges are placed symmetrically around it so that the center and the
# edges stay consistent for any FBIN (not just FBIN == 0)
fobs_gw = FMIN * (FBIN + 1)
fobs_gw_edges = fobs_gw + np.array([-DF, +DF])
fobs_gw_all = [fobs_gw_edges[0], fobs_gw, fobs_gw_edges[1]]
print(fobs_gw_edges*YR, "1/yr")

In [None]:
# Compare the SAM's own realized GWB (`sam.gwb`) against GWs computed from discretely sampled binaries
gwb_test = sam.gwb(fobs_gw_edges, realize=NREALS, hard=hard)

gff = np.zeros(NREALS)
gwf = np.zeros(NREALS)
gwb = np.zeros(NREALS)

# The sampler takes orbital-frequency bin edges (GW frequency / 2). This is loop-invariant, so compute it once.
fobs_orb_edges = fobs_gw_edges / 2.0
for ii in range(NREALS):
    # `fobs_orb` is returned in `edges[3]`, and vals[3] is also observer-frame orbital frequencies
    vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
        sam, hard, fobs_orb=fobs_orb_edges, sample_threshold=1e2, poisson_inside=True, poisson_outside=True
    )
    # NOTE(review): presumably (foreground frequency, foreground strain, background strain) -- confirm;
    # later cells combine them as sqrt(gwb**2 + gwf**2)
    gff[ii], gwf[ii], gwb[ii] = holo.gravwaves._gws_from_samples(vals, weights, fobs_gw_edges)

In [None]:
# Semi-analytic reference strain at the lower edge, center and upper edge of the bin.
# Computed here so this cell does not depend on a later cell having been run first.
gwb_sa_ref = gwb_sa_gw_only(sam, fobs_gw_all)
fobs_gw_all, gwb_sa_ref

In [None]:
# Distributions over realizations: SAM `gwb` (red) vs. sampled binaries (black, background and foreground
# combined in quadrature); dotted blue lines mark the semi-analytic reference values
gwb_mc = np.sqrt(gwb**2 + gwf**2)
gwb_sa_ref = gwb_sa_gw_only(sam, fobs_gw_all)

kw = {'density': True, 'hist': True, 'carpet': True, 'confidence': True, 'quantiles': [0.5]}
kale.dist1d(gwb_test, color='r', **kw)
kale.dist1d(gwb_mc, color='k', **kw)

ax = plt.gca()
for ref_hc in gwb_sa_ref:
    ax.axvline(ref_hc, color='b', ls=':', alpha=0.5)

plt.show()

In [None]:
# Ratio of every SAM realization to every sampled-binary realization (all pairs)
ratio = np.divide.outer(gwb_test, gwb_mc)
np.mean(ratio), np.median(ratio), np.std(ratio)

# Try different SAM grid sizes

In [None]:
# Re-run this cell once per grid size used in the comparison below.  The results are stored under the
# name `data_{SHAPE}`, derived from SHAPE, so the label always matches the grid actually used.
NREALS = 30
SHAPE = 10
hard = holo.evolution.Hard_GW()
sam = holo.sam.Semi_Analytic_Model(shape=SHAPE)

gwb_test = sam.gwb(fobs_gw_edges, realize=NREALS, hard=hard)

gff = np.zeros(NREALS)
gwf = np.zeros(NREALS)
gwb = np.zeros(NREALS)

fobs_orb_edges = fobs_gw_edges / 2.0   # orbital-frequency bin edges; loop-invariant
for ii in range(NREALS):
    # `fobs_orb` is returned in `edges[3]`
    vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
        sam, hard, fobs_orb=fobs_orb_edges, sample_threshold=1e2, poisson_inside=True, poisson_outside=True
    )
    gff[ii], gwf[ii], gwb[ii] = holo.gravwaves._gws_from_samples(vals, weights, fobs_gw_edges)

sa = gwb_test
mc = np.sqrt(gwf**2 + gwb**2)
globals()[f"data_{SHAPE}"] = dict(gwb_sa=sa, gwb_mc=mc)

kw = dict(density=True, hist=True, carpet=True, confidence=True, quantiles=[0.5])
kale.dist1d(sa, color='r', **kw)
kale.dist1d(mc, color='k', **kw)

In [None]:
# Collect the per-grid-size results into (n_sizes, NREALS) arrays, one row per grid size.
# Keeping size -> result in one mapping prevents `nums` and `data` from drifting out of step.
runs = {20: data_20, 30: data_30, 40: data_40, 50: data_50, 60: data_60, 80: data_80, 100: data_100}
nums = list(runs.keys())
data = list(runs.values())
nd = len(data)
mc = np.zeros((nd, NREALS))
sa = np.zeros((nd, NREALS))
for row, run in enumerate(data):
    mc[row] = run['gwb_mc']
    sa[row] = run['gwb_sa']


In [None]:
fig, ax = plot.figax()
xx = nums   # SAM grid `shape` used for each run

# Median and interquartile band over realizations: sampled binaries (black) and SAM `gwb` (red)
for values, color, label in [(mc, 'k', 'sampled binaries'), (sa, 'r', 'SAM gwb')]:
    ax.plot(xx, np.median(values, axis=1), color=color, label=label)
    ax.fill_between(xx, *np.percentile(values, [25, 75], axis=1), color=color, alpha=0.1)
ax.set(xlabel='SAM grid shape', ylabel='Characteristic strain $h_c$')
ax.legend()

# Ratio SAM / sampled on a secondary axis (blue)
ratio = sa/mc
tw = ax.twinx()
tw.plot(xx, np.median(ratio, axis=1), 'b-')
tw.fill_between(xx, *np.percentile(ratio, [25, 75], axis=1), color='b', alpha=0.1)
tw.set_ylabel('ratio (SAM / sampled)', color='b')

plt.show()

# Try different tweaks to the calculation (e.g. bin edges vs. bin centers)

In [None]:
# Compare two ways of getting the number of binaries per bin: integrating over frequency inside
# `_integrate_grid_differential_number` (freq=True) vs. integrating (M, q, z) first and then
# trapezoid-integrating over frequency.
# Uses the GW-frequency bin edges `fobs_gw_edges` (no `fobs_edges` variable exists in this notebook).
fobs_orb_edges = fobs_gw_edges / 2.0
# NOTE(review): the sampler cells pass `fobs_orb=`; confirm `fobs=` matches `dynamic_binary_number`'s signature
edges, dnum = sam.dynamic_binary_number(hard, fobs=fobs_orb_edges)
number_0 = holo.utils._integrate_grid_differential_number(edges, dnum, freq=True)
number_1 = holo.utils._integrate_grid_differential_number(edges, dnum, freq=False)
# Dividing by f presumably turns a per-ln(f) density into per-f before integrating over f (edges[3]) -- confirm
number_1 = holo.utils.trapz(number_1/edges[3], edges[3], axis=3, cumsum=False)

ratio = np.nan_to_num(number_1/number_0)
print(kale.utils.stats_str(ratio))

In [None]:
# Bin centers along every grid dimension, and full 'ij'-indexed meshes of both centers and edges
cents = [kale.utils.midpoints(edge_arr, log=False) for edge_arr in edges]
cgrid = np.meshgrid(*cents, indexing='ij')
egrid = np.meshgrid(*edges, indexing='ij')

# Frequency-bin widths (linear and logarithmic), with midpoints taken along the (M, q, z) axes so they
# line up with the cell-centered quantities
freq_mesh = egrid[-1]
df = kale.utils.midpoints(np.diff(freq_mesh, axis=-1), axis=(0, 1, 2))
dlnf = kale.utils.midpoints(np.diff(np.log(freq_mesh), axis=-1), axis=(0, 1, 2))
print(np.shape(cgrid), np.shape(egrid), df.shape, dlnf.shape)

In [None]:
def hs_from_grid(grid):
    """Per-source GW strain for every point of a (mtot, mrat, redz, fobs_orb) grid.

    ``grid`` is indexed as ``grid[0]`` total mass, ``grid[1]`` mass ratio, ``grid[2]`` redshift and
    ``grid[3]`` observer-frame *orbital* frequency; the frequency is converted to the rest frame
    before the strain is computed.
    """
    mtot, mrat, redz, fobs_orb = grid[0], grid[1], grid[2], grid[3]
    mchirp = utils.chirp_mass(*utils.m1m2_from_mtmr(mtot, mrat))
    dcom = cosmo.comoving_distance(redz).cgs.value
    frst_orb = utils.frst_from_fobs(fobs_orb, redz)
    return utils.gw_strain_source(mchirp, dcom, frst_orb)

# Strain evaluated at bin centers vs. the midpoint of the strains evaluated at bin edges
hs_cents, hs_edges = (hs_from_grid(mesh) for mesh in (cgrid, egrid))
hs_edge_cents = kale.utils.midpoints(hs_edges, axis=None)

ratio = np.divide(hs_cents, hs_edge_cents)
print(kale.utils.stats_str(ratio))

In [None]:
# Number-weighted RMS strain per bin: integrate dnum * hs^2 over each bin, normalize by the bin's number,
# take the square root, then compare with the midpoint of the edge strains
hs2_dnum = dnum * hs_edges**2
print(utils.stats(hs2_dnum))
hs2_number = utils._integrate_grid_differential_number(edges, hs2_dnum, freq=True)
hs_weight = np.nan_to_num(np.sqrt(hs2_number / number_0))

ratio = hs_weight / hs_edge_cents
print(kale.utils.stats_str(ratio))

In [None]:
# Characteristic strain summed over (M, q, z) cells with two frequency-bin normalizations:
# f / df (hc_0) vs. 1 / dln(f) (hc_1)
num_hs2 = number_0 * hs_cents**2
cell_axes = (0, 1, 2)
hc_0 = np.sqrt(np.sum(num_hs2 * cgrid[-1] / df, axis=cell_axes))
hc_1 = np.sqrt(np.sum(num_hs2 / dlnf, axis=cell_axes))
print(hc_0, hc_1)

ratio = hc_1 / hc_0
print(ratio)

In [None]:
# Reference SAM GWB realizations; `gwb_ref` is required by the comparison cells below.
# Uses the GW-frequency bin edges `fobs_gw_edges` (no `fobs_edges` variable exists in this notebook).
gwb_ref = sam.gwb(fobs_gw_edges, realize=NREALS, hard=hard)

In [None]:
# SAM realizations and sampled-binary GWs in the same frequency bin, following the same conventions as
# the first comparison cell: GW-frequency edges for `sam.gwb` / `_gws_from_samples`, orbital-frequency
# edges (GW / 2) for the sampler
gwb_test = sam.gwb(fobs_gw_edges, realize=NREALS, hard=hard)

gff = np.zeros(NREALS)
gwf = np.zeros(NREALS)
gwb = np.zeros(NREALS)

fobs_orb_edges = fobs_gw_edges / 2.0   # loop-invariant
for ii in range(NREALS):
    vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
        sam, hard, fobs_orb=fobs_orb_edges, sample_threshold=1e2, poisson_inside=True, poisson_outside=True
    )
    gff[ii], gwf[ii], gwb[ii] = holo.gravwaves._gws_from_samples(vals, weights, fobs_gw_edges)

In [None]:
# Fractional difference (sampled - reference) / reference of the mean and the median, each computed only
# over values strictly inside the interquartile range (25th-75th percentile) to suppress outliers
tot = np.sqrt(gwb**2 + gwf**2)
aves = []
meds = []
for val in [gwb_ref, tot]:
    lo, hi = np.percentile(val, [25, 75])
    inner = val[(lo < val) & (val < hi)]
    aves.append(np.mean(inner))
    meds.append(np.median(inner))

# Loop variable deliberately not named `vals`, which would clobber the sampler's `vals` array
for stat in [aves, meds]:
    diff = (stat[1] - stat[0]) / stat[0]
    print(f"{diff:.4f}")

In [None]:
# Overlay: reference SAM realizations (default color), SAM test realizations (red),
# sampled binaries with foreground and background combined in quadrature (black)
gwb_tot = np.sqrt(gwf**2 + gwb**2)
kw = {'density': True, 'hist': True, 'carpet': True, 'confidence': True, 'quantiles': [0.5]}
kale.dist1d(gwb_ref, **kw)
kale.dist1d(gwb_test, color='r', **kw)
kale.dist1d(gwb_tot, color='k', **kw)

# Where should strains be evaluated relative to bins: at centers, at edges, or as a weighted average?

In [None]:
import zcode.math as zmath

In [None]:
# Toy population: total masses in [1e6, 1e10] Msol drawn from a power-law (index -2), with 20 log-spaced bins
SEED = 1234
NUM = 1000
# NOTE(review): seeding assumes `zmath.random_power` draws from numpy's global RNG -- confirm
np.random.seed(SEED)
masses = zmath.random_power([1e6, 1e10], -2, NUM)
bin_edges = zmath.spacing(masses, 'log', 20)
bin_cents = zmath.midpoints(bin_edges)

fig, ax = plot.figax()
# NOTE(review): the 0.6 per-sample weight is unexplained -- document or remove
ax.hist(masses, bins=bin_edges, weights=0.6*np.ones_like(masses), alpha=0.2)
ax.set(xlabel='Total mass [$M_\\odot$]', ylabel='Weighted count')
plt.show()

In [None]:
# Strain of each toy binary at a fixed mass ratio, redshift and GW frequency, and a comparison of per-bin
# strain estimates: RMS over the binaries in each bin (steps), strain at the bin edges, strain at the bin
# centers, and the linear / log midpoints of the edge strains
mrat = 0.2
redz = 0.1
dcom = cosmo.comoving_distance(redz).cgs.value
fobs = 1.0 / YR                          # observer-frame GW frequency
frst_orb = (fobs / 2.0) * (1.0 + redz)   # rest-frame orbital frequency


def chirp_mass_from_msol(mtot_msol):
    """Chirp mass of binaries with total mass `mtot_msol` [Msol] at the fixed mass ratio `mrat`."""
    return utils.chirp_mass(*utils.m1m2_from_mtmr(mtot_msol * MSOL, mrat))


mchirp = chirp_mass_from_msol(masses)
print(utils.stats(mchirp))
hs = utils.gw_strain_source(mchirp, dcom, frst_orb)
print(utils.stats(hs))

# RMS strain of the binaries falling in each mass bin (NaN for empty bins)
hs_bins, *_ = sp.stats.binned_statistic(masses, hs**2, statistic='mean', bins=bin_edges)
hs_bins = np.sqrt(hs_bins)

mchirp_edges = chirp_mass_from_msol(bin_edges)
hs_edges = utils.gw_strain_source(mchirp_edges, dcom, frst_orb)
hs_edges_cent_lin = zmath.midpoints(hs_edges, log=False)
hs_edges_cent_log = zmath.midpoints(hs_edges, log=True)

mchirp_cents = chirp_mass_from_msol(bin_cents)
hs_cents = utils.gw_strain_source(mchirp_cents, dcom, frst_orb)

fig, axes = plot.figax(ncols=2)

ax = axes[0]
plot.draw_hist_steps(ax, bin_edges, hs_bins)
ax.scatter(bin_edges, hs_edges, color='r', marker='+', alpha=0.2, label='at edges')
ax.scatter(bin_cents, hs_cents, color='b', marker='x', alpha=0.2, label='at centers')
ax.scatter(bin_cents, hs_edges_cent_lin, color='g', marker='|', alpha=0.2, label='edge midpoint (lin)')
ax.scatter(bin_cents, hs_edges_cent_log, color='yellow', marker='.', alpha=0.5, label='edge midpoint (log)')
ax.set(xlabel='Total mass [$M_\\odot$]', ylabel='GW strain $h_s$')
ax.legend()

plt.show()