In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
# import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import utils, plot, cosmo
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Suppress numpy floating-point warnings (division, invalid values, overflow)
np.seterr(over='ignore', invalid='ignore', divide='ignore')
# Hide every `UserWarning`, whichever package raises it
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings, applied in a single rcParams update
# NOTE(review): 'Times' is placed in the *sans-serif* font list while the active
# family is 'serif', so it presumably has no effect on rendered text — confirm intent.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.sans-serif': ['Times'],
    'font.size': 15,
    'lines.solid_capstyle': 'round',
    'mathtext.fontset': 'cm',
    'grid.alpha': 0.5,
})

# Use holodeck's logger in this notebook, at INFO verbosity
log = holo.log
log.setLevel(logging.INFO)

In [None]:
# Ingredients of the semi-analytic model (SAM), built in the order they are passed on
gsmf, gpf, gmt, mmbulge = (
    holo.sam.GSMF_Schechter(),      # Galaxy Stellar-Mass Function (GSMF)
    holo.sam.GPF_Power_Law(),       # Galaxy Pair Fraction         (GPF)
    holo.sam.GMT_Power_Law(),       # Galaxy Merger Time           (GMT)
    holo.sam.MMBulge_Simple(),      # M-MBulge Relation            (MMB)
)

# Assemble the model from the four components, each with its default parameters
sam = holo.sam.Semi_Analytic_Model(
    gsmf=gsmf,
    gpf=gpf,
    gmt=gmt,
    mmbulge=mmbulge,
)

In [None]:
# Choose observed GW-Frequency bins based on nyquist sampling
# NOTE(review): the two arguments are presumably the observing duration (10 yr) and
# the sampling cadence (0.1 yr), both converted to seconds via `YR` — confirm against
# `utils.nyquist_freqs`.
# The cells below treat `fobs` as bin *edges*: curves are plotted at its midpoints and
# per-bin arrays are allocated with `fobs.size - 1` entries.
fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)

In [None]:
# Settings for drawing a discrete binary population from the SAM
sample_opts = dict(
    sample_threshold=5.0,
    cut_below_mass=3e7,
    limit_merger_time=4*GYR,
)

# Sample the SAM with the `Hard_GW` hardening model over the `fobs` grid
vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
    sam, holo.evolution.Hard_GW, fobs=fobs, **sample_opts
)

In [None]:
# Convert the sampled binaries into GW quantities using a private helper of `holo.sam`.
# The plotting cell below compares `gwf` and `gwb` element-wise, draws `gwb` at the
# midpoints of `fobs`, and draws `gwf` at its own frequencies `gwf_freqs`.
gwf_freqs, gwf, gwb = holo.sam._gws_from_samples(vals, weights, fobs)

In [None]:
# GWB computed directly from the SAM on the `fobs` grid, for comparison with the
# sampled population.
# NOTE(review): `realize=False` presumably returns the smooth expectation value and
# `realize=100` returns 100 random realizations along the last axis (the plots take
# percentiles over `axis=-1`) — confirm against `sam.gwb`.
gwb_smooth = sam.gwb(fobs, realize=False)
gwb_rough = sam.gwb(fobs, realize=100)

In [None]:
fig, ax = plot.figax()

# Bin-center frequencies: every per-bin GWB curve in this cell is drawn against these
xx = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

# Reference power law with index -2/3, normalized to `amp` at xx = 1
amp = 10e-16   # i.e. 1e-15
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25)

# Optional rescaling of the plotted strains; currently disabled (unity)
# ff = np.sqrt(xx / np.diff(fobs*YR))
ff = 1.0

# `gwb` from the sampled population, one value per frequency bin
ax.plot(xx, gwb*ff, 'k-')

# `gwf` values are drawn at their own frequencies (`gwf_freqs`), which get a separate
# x-array so that `xx` keeps meaning "bin centers" for the SAM curves further down.
# Filled markers: `gwf` exceeds `gwb` in that bin; hollow markers: it does not.
idx = (gwf > gwb)
xx_src = gwf_freqs * YR   # [1/sec] ==> [1/yr]
ax.scatter(xx_src[idx], (gwf*ff)[idx], color='r', s=20, alpha=0.5)
ax.scatter(xx_src[~idx], (gwf*ff)[~idx], edgecolor='r', facecolor='none', s=20, alpha=0.5)

# GWB computed directly from the SAM: smooth curve, plus the 25th-75th percentile band
# over realizations.  These are per-bin quantities, so they belong at the bin centers
# (as in the later comparison cells), not at the `gwf_freqs` positions.
ax.plot(xx, gwb_smooth, 'b--')
ax.fill_between(xx, *np.percentile(gwb_rough, [25, 75], axis=-1), color='b', alpha=0.25)

plt.show()

In [None]:
# Number of independently sampled populations to generate
SAMP_NREALS = 100

# Sampling settings shared by every realization (no mass cut, no merger-time limit)
sampler_kwargs = dict(
    fobs=fobs,
    hard=holo.evolution.Hard_GW,
    sample_threshold=5.0,
    cut_below_mass=None,
    limit_merger_time=None,
)

# One row per frequency bin, one column per realization
gwb = np.zeros((fobs.size-1, SAMP_NREALS))
for real in utils.tqdm(range(SAMP_NREALS)):
    gwf_freqs, gwf, _gwb = holo.sam.sampled_gws_from_sam(sam, **sampler_kwargs)
    # Combine the two returned strain terms in quadrature for this realization
    gwb[:, real] = np.sqrt(np.square(_gwb) + np.square(gwf))

In [None]:
# New figure for comparing the GWB estimates; `plot_gwb` below draws onto this `ax`
fig, ax = plot.figax()

def plot_gwb(xx, yy, target_ax=None, **kwargs):
    """Draw the median of `yy` as a line, with a shaded 25th-75th percentile band.

    Parameters
    ----------
    xx : array_like
        x-coordinates of the curve.
    yy : array_like
        Values to summarize; the median and percentiles are taken over the last axis.
    target_ax : axes-like or None
        Object providing `plot` and `fill_between`.  If `None` (default), the
        notebook-level `ax` is used, so existing two-argument calls are unaffected.
    **kwargs
        Forwarded to `target_ax.plot` (e.g. `label`).  The line is dashed unless
        `ls` or `linestyle` is given explicitly.

    """
    if target_ax is None:
        target_ax = ax
    # Only supply the dashed default when the caller has not chosen a linestyle,
    # otherwise the duplicated keyword would raise a TypeError.
    if ('ls' not in kwargs) and ('linestyle' not in kwargs):
        kwargs['ls'] = '--'

    line, = target_ax.plot(xx, np.median(yy, axis=-1), **kwargs)
    band = np.percentile(yy, [25, 75], axis=-1)
    # Shade the band in whatever color the line was assigned
    target_ax.fill_between(xx, *band, color=line.get_color(), alpha=0.25)
    return

# Bin-center frequencies of the `fobs` grid
xx = YR * kale.utils.midpoints(fobs)   # [1/sec] ==> [1/yr]

# Reference power law with index -2/3, normalized to 1e-15 at xx = 1
amp = 1e-15
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25)

# SAM results: smooth curve, then median and percentile band over realizations
ax.plot(xx, gwb_smooth, 'b:')
plot_gwb(xx, gwb_rough, label='roughed')

# `gwb` may be 2D (one column per sampled realization) or 1D (a single estimate)
if np.ndim(gwb) != 1:
    plot_gwb(xx, gwb, label='sampled')
else:
    ax.plot(xx, gwb, 'k-', alpha=0.25)

plt.legend()
plt.show()

In [None]:
# New figure for comparing the GWB estimates; `plot_gwb` below draws onto this `ax`
fig, ax = plot.figax()

def plot_gwb(xx, yy, target_ax=None, **kwargs):
    """Draw the median of `yy` as a line, with a shaded 25th-75th percentile band.

    Parameters
    ----------
    xx : array_like
        x-coordinates of the curve.
    yy : array_like
        Values to summarize; the median and percentiles are taken over the last axis.
    target_ax : axes-like or None
        Object providing `plot` and `fill_between`.  If `None` (default), the
        notebook-level `ax` is used, so existing two-argument calls are unaffected.
    **kwargs
        Forwarded to `target_ax.plot` (e.g. `label`).  The line is dashed unless
        `ls` or `linestyle` is given explicitly.

    """
    if target_ax is None:
        target_ax = ax
    # Only supply the dashed default when the caller has not chosen a linestyle,
    # otherwise the duplicated keyword would raise a TypeError.
    if ('ls' not in kwargs) and ('linestyle' not in kwargs):
        kwargs['ls'] = '--'

    line, = target_ax.plot(xx, np.median(yy, axis=-1), **kwargs)
    band = np.percentile(yy, [25, 75], axis=-1)
    # Shade the band in whatever color the line was assigned
    target_ax.fill_between(xx, *band, color=line.get_color(), alpha=0.25)
    return

# NOTE(review): this cell and the one after it are identical; one can probably go.
# Bin-center frequencies of the `fobs` grid
xx = YR * kale.utils.midpoints(fobs)   # [1/sec] ==> [1/yr]

# Reference power law with index -2/3, normalized to 1e-15 at xx = 1
amp = 1e-15
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25)

# SAM results (smooth curve; median and band over realizations), then the
# median and band of `gwb` over its last axis
ax.plot(xx, gwb_smooth, 'b:')
for curves, tag in ((gwb_rough, 'roughed'), (gwb, 'sampled')):
    plot_gwb(xx, curves, label=tag)

plt.legend()
plt.show()

In [None]:
# New figure for comparing the GWB estimates; `plot_gwb` below draws onto this `ax`
fig, ax = plot.figax()

def plot_gwb(xx, yy, target_ax=None, **kwargs):
    """Draw the median of `yy` as a line, with a shaded 25th-75th percentile band.

    Parameters
    ----------
    xx : array_like
        x-coordinates of the curve.
    yy : array_like
        Values to summarize; the median and percentiles are taken over the last axis.
    target_ax : axes-like or None
        Object providing `plot` and `fill_between`.  If `None` (default), the
        notebook-level `ax` is used, so existing two-argument calls are unaffected.
    **kwargs
        Forwarded to `target_ax.plot` (e.g. `label`).  The line is dashed unless
        `ls` or `linestyle` is given explicitly.

    """
    if target_ax is None:
        target_ax = ax
    # Only supply the dashed default when the caller has not chosen a linestyle,
    # otherwise the duplicated keyword would raise a TypeError.
    if ('ls' not in kwargs) and ('linestyle' not in kwargs):
        kwargs['ls'] = '--'

    line, = target_ax.plot(xx, np.median(yy, axis=-1), **kwargs)
    band = np.percentile(yy, [25, 75], axis=-1)
    # Shade the band in whatever color the line was assigned
    target_ax.fill_between(xx, *band, color=line.get_color(), alpha=0.25)
    return

# NOTE(review): this cell repeats the previous one verbatim; one can probably go.
# Bin-center frequencies of the `fobs` grid
xx = YR * kale.utils.midpoints(fobs)   # [1/sec] ==> [1/yr]

# Reference power law with index -2/3, normalized to 1e-15 at xx = 1
amp = 1e-15
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25)

# Smooth SAM curve first, then the two median-and-band summaries
ax.plot(xx, gwb_smooth, 'b:')
labelled_curves = [(gwb_rough, 'roughed'), (gwb, 'sampled')]
for curves, tag in labelled_curves:
    plot_gwb(xx, curves, label=tag)

plt.legend()
plt.show()

# Manual GWB calculation from sampled binaries

In [None]:
# Re-sample the SAM, this time with no lower mass cut and no merger-time limit.
# This replaces `vals`, `weights`, `edges`, `dens` and `mass` from the earlier sampling cell.
sample_opts = dict(
    sample_threshold=5.0,
    cut_below_mass=None,
    limit_merger_time=None,
)
vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
    sam, holo.evolution.Hard_GW, fobs=fobs, **sample_opts
)

In [None]:
# Inspect `edges` with kalepy's `jshape` helper; the result is shown as the cell output.
# NOTE(review): presumably this reports the length of each (possibly unequal) edge array — confirm.
kale.utils.jshape(edges)

In [None]:
# Bin the weighted samples onto the grid defined by `edges` (one column of `vals.T`
# per grid dimension).  Only the histogram is kept: unpacking the returned edges into
# `_` would overwrite IPython's last-output variable for the rest of the session.
counts = np.histogramdd(vals.T, bins=edges, weights=weights)[0]
print(counts.shape)

In [None]:
import zcode
import zcode.math as zmath

In [None]:
# Plot log10 of the values (and of the shared color limits) when True
LOG = True

# NOTE(review): each `None` inserts a new length-1 axis (numpy newaxis) and each `0`
# selects the first element of an existing axis, so after `.squeeze()` this is the
# slice at index 0 of the first two dimensions.  If the intent was to *keep* the first
# and last dimensions, `slice(None)` is needed instead — confirm.
idx = (None, 0, 0, None)
print(counts[idx].squeeze().shape)
print(mass[idx].squeeze().shape)

fig, axes = plt.subplots(figsize=[10, 4], ncols=2)

# One color scale shared by both panels, spanning the extrema of `counts` and `mass`
# NOTE(review): `limit=[0.5, None]` presumably bounds the lower extremum at 0.5 — confirm
# against `zmath.minmax`.
extr = zmath.minmax(mass, prev=zmath.minmax(counts), limit=[0.5, None])
if LOG:
    extr = np.log10(extr)
smap = plot.smap(extr)
print(f"{extr=}")

# Left panel: `counts`; right panel: `mass`.  The loop variable is deliberately not
# called `ax`, so the notebook-level `ax` used by other cells is left untouched.
for panel, qq in zip(axes, [counts, mass]):
    qq = qq[idx].squeeze()
    if LOG:
        qq = np.log10(qq)
    panel.pcolormesh(qq, cmap=smap.cmap, norm=smap.norm)

# `smap` comes from `plot.smap` rather than from a plotting call, so it is presumably
# not attached to any axes; name the parent axes explicitly instead of relying on
# matplotlib to guess one.  The last panel is the one space is taken from.
plt.colorbar(smap, ax=axes[-1])
plt.show()

In [None]:
# Rows of `vals` used here: 0 and 1 feed `m1m2_from_mtmr`, 2 is `rz`, 3 is `fo`
component_masses = utils.m1m2_from_mtmr(*vals[:2])
mc = utils.chirp_mass(*component_masses)
rz, fo = vals[2], vals[3]

# Frequency converted from `fo` using `rz`, and comoving distance as a plain cgs value
frst = utils.frst_from_fobs(fo, rz)
dc = cosmo.comoving_distance(rz).cgs.value

# GW strain of each sampled binary
hs = utils.gw_strain_source(mc, dc, frst)
print(hs.shape, fo.shape, weights.shape)

In [None]:
# Assign each sample to a frequency bin; for increasing `fobs`, index i means
# fobs[i] <= fo < fobs[i+1].  No sorting is needed — `np.digitize` accepts unsorted
# input — and leaving `fo`, `hs` and `weights` untouched keeps them aligned with
# `vals` if earlier cells are re-run.
idx = np.digitize(fo, fobs) - 1
num_bins = fobs.size - 1

# Samples below the first edge or at/above the last edge belong to no bin
in_range = (idx >= 0) & (idx < num_bins)

# Weighted sum of squared strains per bin, accumulated in one vectorized pass;
# bins that receive no samples stay at zero.
gwb_sq = np.bincount(
    idx[in_range],
    weights=(weights * (hs ** 2))[in_range],
    minlength=num_bins,
)

gwb = np.sqrt(gwb_sq)
print(gwb)

In [None]:
# New figure for comparing the manually binned GWB with the SAM results; `plot_gwb` draws onto this `ax`
fig, ax = plot.figax()

def plot_gwb(xx, yy, target_ax=None, **kwargs):
    """Draw the median of `yy` as a line, with a shaded 25th-75th percentile band.

    Parameters
    ----------
    xx : array_like
        x-coordinates of the curve.
    yy : array_like
        Values to summarize; the median and percentiles are taken over the last axis.
    target_ax : axes-like or None
        Object providing `plot` and `fill_between`.  If `None` (default), the
        notebook-level `ax` is used, so existing two-argument calls are unaffected.
    **kwargs
        Forwarded to `target_ax.plot` (e.g. `label`).  The line is dashed unless
        `ls` or `linestyle` is given explicitly.

    """
    if target_ax is None:
        target_ax = ax
    # Only supply the dashed default when the caller has not chosen a linestyle,
    # otherwise the duplicated keyword would raise a TypeError.
    if ('ls' not in kwargs) and ('linestyle' not in kwargs):
        kwargs['ls'] = '--'

    line, = target_ax.plot(xx, np.median(yy, axis=-1), **kwargs)
    band = np.percentile(yy, [25, 75], axis=-1)
    # Shade the band in whatever color the line was assigned
    target_ax.fill_between(xx, *band, color=line.get_color(), alpha=0.25)
    return

# Bin-center frequencies of the `fobs` grid
xx = YR * kale.utils.midpoints(fobs)   # [1/sec] ==> [1/yr]

# Reference power law with index -2/3, normalized to 1e-15 at xx = 1
amp = 1e-15
yy = amp * np.power(xx, -2/3)

# Faint black lines: dashed for the reference power law, solid for the 1D `gwb`
faint = dict(alpha=0.25)
ax.plot(xx, yy, 'k--', **faint)
ax.plot(xx, gwb, 'k-', **faint)

# SAM results: smooth curve, then median and percentile band over realizations
ax.plot(xx, gwb_smooth, 'b:')
plot_gwb(xx, gwb_rough, label='roughed')
# plot_gwb(xx, gwb, label='sampled')

plt.legend()
plt.show()