In [None]:
# %load ../init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `mpl.style.use('default')` resets rcParams to matplotlib's defaults, so it must run
# BEFORE the custom `mpl.rc(...)` calls; otherwise the font/line/mathtext settings are discarded.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Use the package logger and show INFO-level (and above) messages in the notebook.
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode
import zcode.math as zmath
import zcode.plot as zplot

In [None]:
# Build the semi-analytic model (SAM) from its four component relations.
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.relations.MMBulge_Standard()     # M-MBulge Relation            (MMB)
# Grid size handed to the SAM as `shape`; alternatives are kept commented out.
# NOTE(review): presumably the number of grid points along each of the model's three
# dimensions -- confirm the axis order against `Semi_Analytic_Model`.
# shape = (150, 151, 152)
shape = (30, 31, 32)
# shape = None

sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=shape)

In [None]:
# Frequency array (`fobs`) passed to every GWB calculation in the following cells.
# NOTE(review): the two arguments are presumably (observing duration, cadence), expressed in
# seconds through the `YR` constant -- confirm against `utils.nyquist_freqs`.
# fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)
fobs = utils.nyquist_freqs(10.0*YR, 1.0*YR)
# fobs = kale.utils.spacing(fobs, scale='log', num=fobs.size)

In [None]:
# Number of realizations: passed as `realize=` to `sam.gwb` and used as the loop count
# (and column count) for the sampled GWB arrays.
NUM_REALS = 10

In [None]:
# Two GWB estimates from the same SAM and frequencies, differing only in `realize`:
#   - `realize=False`     : plotted below as a single dotted curve;
#   - `realize=NUM_REALS` : plotted below through `plot_gwb`, i.e. summarized by the median
#                           and 25-75 percentile band taken over the last axis.
gwb_smooth = sam.gwb(fobs, realize=False)
gwb_rough = sam.gwb(fobs, realize=NUM_REALS)

In [None]:
# Output arrays: one row per frequency bin (`fobs.size - 1`), one column per realization.
gwb = np.zeros((fobs.size-1, NUM_REALS))
gwf = np.zeros((fobs.size-1, NUM_REALS))
# Each iteration draws an independent sampled population and fills column `ii` of both arrays.
# `gwf_freqs` is overwritten on every pass, so only the last realization's value is kept.
for ii in utils.tqdm(range(NUM_REALS)):
    gwf_freqs, gwf[:, ii], gwb[:, ii] = holo.sam.sampled_gws_from_sam(
        sam, fobs=fobs, hard=holo.evolution.Hard_GW, cut_below_mass=1e6*MSOL, limit_merger_time=None,
        sample_threshold=100, poisson_inside=False,
    )
    # gwb[:, ii] = np.sqrt(_gwb[:]**2 + gwf**2)
    # break

In [None]:
# New figure; `ax` is a module-level name that `plot_gwb` reads at call time.
fig, ax = plot.figax()

def plot_gwb(xx, yy, **kwargs):
    """Draw the median of `yy` as a dashed line with a shaded 25-75 percentile band.

    Draws on the module-level `ax` (looked up when the function is called).

    Parameters
    ----------
    xx : array_like
        Abscissa values.
    yy : array_like
        Values whose last axis is collapsed by the median / percentiles.
    **kwargs
        Forwarded to `ax.plot` (e.g. `label`); `ls` is already set to '--'.

    Returns
    -------
    None
    """
    # Median line first, so that the band can reuse the color matplotlib picked for it.
    line, = ax.plot(xx, np.median(yy, axis=-1), ls='--', **kwargs)
    band_color = line.get_color()
    lower, upper = np.percentile(yy, [25, 75], axis=-1)
    ax.fill_between(xx, lower, upper, color=band_color, alpha=0.25)

# Bin-center frequencies, converted for plotting.
xx = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

# amp = 10e-16
# yy = amp * np.power(xx, -2/3)
# ax.plot(xx, yy, 'k--', alpha=0.25)

# Dotted curve: `realize=False` result; dashed + band: the `NUM_REALS` realizations.
ax.plot(xx, gwb_smooth, 'b:')
plot_gwb(xx, gwb_rough, label='roughed')
# `gwb` is 1D when it holds a single sampled spectrum, 2D (freq, realization) otherwise.
if np.ndim(gwb) == 1:
    ax.plot(xx, gwb, 'k-', alpha=0.25)
else:
    plot_gwb(xx, gwb, label='sampled')
    # plot_gwb(xx, np.sqrt(gwb**2 + gwf**2), label='sampled')

plt.legend()
plt.show()
# fig.savefig('temp.png')

# Manual

In [None]:
# Rebuild the SAM for the manual sampling check (these names replace the ones defined above).
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.relations.MMBulge_Standard()     # M-MBulge Relation            (MMB)
hard = holo.evolution.Hard_GW
# Grid shape: exactly one choice is active; the alternatives stay commented out.
# (`shape = (30, 31, 32)` used to be assigned and then immediately overwritten.)
# shape = (150, 151, 152)
# shape = (30, 31, 32)
shape = (30, 31, 300)
# shape = None

# Both `shape` and `redz` are passed; the printed `sam.shape` shows the grid actually used.
sam = holo.sam.Semi_Analytic_Model(
    gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=shape,
    redz=[0.0, 6.0, 32]
)
print(f"{sam.shape=}")

# NOTE: this redefines `fobs` (0.1*YR as the second argument instead of the 1.0*YR used earlier).
fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)

# Draw one sampled population; `vals` and `weights` are used by the following cells,
# `edges` and `mass` by the number-count comparison further down.
vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
    sam, hard, fobs=fobs, cut_below_mass=None, limit_merger_time=None,
    sample_threshold=10.0, poisson_inside=True, poisson_outside=True,
)
print(f"{weights.size=:.4e}, {weights.sum()=:.8e}")

In [None]:
# Recompute both reference spectra for the rebuilt `sam` and the redefined `fobs`.
# `NUM_REALS` must already be defined by an earlier cell.
gwb_smooth = sam.gwb(fobs, realize=False)
gwb_rough = sam.gwb(fobs, realize=NUM_REALS)

In [None]:
# Work on a copy so the sampled `weights` array itself is never modified here.
use_weights = np.copy(weights)
print(kale.utils.stats_str(use_weights))

# Selection applied to the samples; `slice(None)` keeps everything.
# The commented alternatives select sub-populations by weight.
# sel = (use_weights > 1.0)
# sel = (use_weights < 1.0)
# sel = (use_weights != 1.0)
# sel = ~sel
sel = slice(None)

use_weights = use_weights[sel]
# use_weights = np.random.poisson(use_weights)
print(kale.utils.stats_str(use_weights))

# print(kale.utils.stats_str(use_weights[sel]))
# print(kale.utils.stats_str(use_weights[~sel]))
# print(f"weights: {zmath.frac_str(sel)}")

# GW spectrum from the selected samples; the same selection is applied to the columns of `vals`.
gff, gwf, gwb = holo.sam._gws_from_samples(vals[:, sel], use_weights, fobs)
fig, ax = plot.figax()
# Here the x-axis is the bin-center frequency as-is (no conversion by `YR`).
ff = kale.utils.midpoints(fobs)
ax.plot(ff, gwb)
ax.plot(gff, gwf, 'rx', alpha=0.5)

# Reference curves: `realize=False` spectrum (dotted) and realizations (median + band).
ax.plot(ff, gwb_smooth, 'b:')
plot_gwb(ff, gwb_rough, label='roughed')

plt.show()

In [None]:
# Per-sample GW strain computed directly from the sampled values.
# Rows of `vals` as used here: [0], [1] -> inputs of `m1m2_from_mtmr` (presumably total mass and
# mass ratio, going by the helper's name -- confirm), [2] -> redshift, [3] -> observed frequency.
mc = utils.chirp_mass(*utils.m1m2_from_mtmr(vals[0], vals[1]))
rz = vals[2].copy()
# Clamp redshifts from below at 0.1 (on the copy, `vals` is left untouched).
rz[rz < 1e-1] = 1e-1
fo = vals[3].copy()
frst = utils.frst_from_fobs(fo, rz)
dc = cosmo.comoving_distance(rz).cgs.value
hs = utils.gw_strain_source(mc, dc, frst)
print(hs.shape, fo.shape, weights.shape)

In [None]:
# Diagnostics: properties of the samples with the largest weighted strain.
# idx = (weights > 1e6)

# Weighted strain per sample, sqrt(weight * hs^2); select entries above 1e-13.
temp = np.sqrt(weights * hs**2)
idx = (temp > 1e-13)

# idx = ()

# Number selected vs. total, then summary statistics of the selected samples.
print(f"{np.count_nonzero(idx):.4e}, {len(idx):.4e}")
print(f"{kale.utils.stats_str(mc[idx]/MSOL)}")
print(f"{kale.utils.stats_str(rz[idx])}")
print(f"{kale.utils.stats_str(dc[idx]/MPC)}")
print(f"{kale.utils.stats_str(hs[idx])}")
print(f"{kale.utils.stats_str(weights[idx])}")
print(f"{kale.utils.stats_str(np.sqrt(weights[idx]*hs[idx]**2))}")



In [None]:
# Build the GWB by hand: sum weight * hs^2 of the samples falling into each frequency bin.
#
# The samples are sorted by frequency into NEW names instead of overwriting `fo` / `hs`:
# frequencies, strains and weights then always share one ordering, and re-running this cell
# cannot misalign them.
order = np.argsort(fo)
fo_sort = fo[order]
hs_sort = hs[order]
ww = weights[order]

# Bin index of every (sorted) sample; `np.digitize` returns `i` such that
# `fobs[i-1] <= x < fobs[i]`, hence the `- 1` to get zero-based bin numbers.
idx = np.digitize(fo_sort, fobs) - 1
gwb = np.zeros(fobs.size-1)

for ii in utils.tqdm(range(gwb.size)):
    sel = (idx == ii)
    # Use the sorted weights `ww` here: `sel` and `hs_sort` are in sorted order, so indexing
    # the unsorted `weights` would pair each strain with another sample's weight.
    temp = ww[sel] * (hs_sort[sel] ** 2)
    gwb[ii] = np.sum(temp)

gwb = np.sqrt(gwb)
# Scale by sqrt(f / df), using bin-center frequencies and bin widths.
gwb *= np.sqrt(kale.utils.midpoints(fobs)/np.diff(fobs))

In [None]:
# New figure; `ax` is a module-level name that `plot_gwb` reads at call time.
fig, ax = plot.figax()

def plot_gwb(xx, yy, **kwargs):
    """Draw the median of `yy` as a dashed line with a shaded 25-75 percentile band.

    Draws on the module-level `ax` (looked up when the function is called).

    Parameters
    ----------
    xx : array_like
        Abscissa values.
    yy : array_like
        Values whose last axis is collapsed by the median / percentiles.
    **kwargs
        Forwarded to `ax.plot` (e.g. `label`); `ls` is already set to '--'.

    Returns
    -------
    None
    """
    # Median line first, so that the band can reuse the color matplotlib picked for it.
    line, = ax.plot(xx, np.median(yy, axis=-1), ls='--', **kwargs)
    band_color = line.get_color()
    lower, upper = np.percentile(yy, [25, 75], axis=-1)
    ax.fill_between(xx, lower, upper, color=band_color, alpha=0.25)

# Bin-center frequencies, converted for plotting.
xx = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

# Reference power law with index -2/3 (note: 10e-16 is the same number as 1e-15).
amp = 10e-16
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25)
# Hand-binned GWB from the previous cell (1D, one value per frequency bin).
ax.plot(xx, gwb, 'k-', alpha=0.25)

# SAM references: `realize=False` spectrum (dotted) and realizations (median + band).
ax.plot(xx, gwb_smooth, 'b:')
plot_gwb(xx, gwb_rough, label='roughed')
# plot_gwb(xx, gwb, label='sampled')

plt.legend()
plt.show()

# Compare Properties of Sampled Population to pure-SAM

In [None]:
# Draw `NUM_REALS` independent sampled populations from the current `sam` / `fobs`.
# NOTE: `vals` and `weights` are rebound here from single arrays to per-realization lists.
NUM_REALS = 10
vals = []
weights = []
for ii in holo.utils.tqdm(range(NUM_REALS)):
    # `edges`, `dens` and `mass` are overwritten every pass; the last realization's are kept.
    _vals, _weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
            sam, holo.evolution.Hard_GW, fobs=fobs,
            sample_threshold=5.0, cut_below_mass=None, limit_merger_time=None,
    )
    vals.append(_vals)
    weights.append(_weights)

## Number of Sources vs. Frequency

In [None]:
# Weighted number of sampled sources per frequency bin, for each realization.
num_fobs = fobs.size
# Only count samples whose first `vals` row exceeds this threshold.
MASS_MIN = 1.0e8 * MSOL
# MASS_MIN = 0.0
# One row per entry of `fobs` (one more than the number of bins); the extra last row
# collects samples at or above the final edge and is dropped when plotting.
num = np.zeros((num_fobs, NUM_REALS))

for ii in holo.utils.tqdm(range(NUM_REALS)):
    fo = vals[ii][-1]    # last row of `vals`: observed frequency of each sample
    idx = np.digitize(fo, fobs) - 1
    for jj in range(num_fobs):
        sel = (idx == jj) & (vals[ii][0] > MASS_MIN)
        num[jj, ii] = np.sum(weights[ii][sel])

# Same count from the SAM grid itself: keep first-axis bins whose center is above `MASS_MIN`,
# then sum `mass` over every axis except the last one.
sel = (kale.utils.midpoints(edges[0]) > MASS_MIN)
idx = list(np.arange(mass.ndim))
idx.pop(mass.ndim - 1)
sam_num = np.sum(mass[sel], axis=tuple(idx))

In [None]:
# Compare sampled number counts (blue, mean over realizations) with the SAM grid (red).
fig, ax = plot.figax()
xx = fobs*YR
# ax.plot(xx, num)
# `[:-1]` drops the extra last row of `num` so the length matches the bin centers.
ax.plot(kale.utils.midpoints(xx), num.mean(axis=-1)[:-1], 'b-')
ax.plot(kale.utils.midpoints(xx), sam_num, 'r--')

# Same SAM counts drawn against the bin edges via the histogram-step helper.
aa, bb = plot._get_hist_steps(xx, sam_num)
ax.plot(aa, bb, 'r:')
# ax.set(xlim=[0.95e-1, 2.5e-1], ylim=[2e9, 4e10])
plt.show()

# Chirp-mass Distribution

In [None]:
# Chirp masses of the sampled binaries.
# NOTE(review): `mchirp` is never appended to, and the `break` stops after the first
# realization -- this looks like an unfinished / debugging loop; confirm the intent.
mchirp = []
for ii in range(NUM_REALS):
    vv = vals[ii]
    # First the two component masses, then the chirp mass computed from them.
    mc = utils.m1m2_from_mtmr(vv[0], vv[1])
    mc = utils.chirp_mass(*mc)
    print(mc.shape, utils.stats(mc/MSOL))
    break

## Sampling in Log vs. Linear Space

In [None]:
def func(xx):
    """Test distribution: `xx**1.5 * exp(-xx)`, evaluated element-wise.

    Parameters
    ----------
    xx : scalar or ndarray
        Point(s) at which to evaluate.

    Returns
    -------
    scalar or ndarray
        Same shape as `xx`.
    """
    power_law = np.power(xx, +1.5)
    cutoff = np.exp(-xx)
    return power_law * cutoff

# Compare sampling a gridded distribution in linear vs. logarithmic coordinates.
# Target total number (a float; it is also what gets passed as `nsamp` below).
NUM = 1e4
# xx = np.logspace(-2, 1, 100)
xx = kale.utils.spacing([1e-2, 1e1], scale='log', num=100)
yy = func(xx)
# Cumulative sum of the function values, normalized so that its final value is `NUM`;
# `yy` is rescaled by the same factor.
Y = np.cumsum(yy)
norm = NUM / Y[-1]
yy *= norm
Y *= norm

# Finite-difference densities of the cumulative curve, per unit x and per unit ln(x).
dydx = np.diff(Y) / np.diff(xx)
dydlnx = np.diff(Y) / np.diff(np.log(xx))
xc = kale.utils.midpoints(xx)

# Four combinations: {dydx, dydlnx} sampled on a {linear, log} grid.  The log-grid samples
# (`a1`, `b1`) are drawn in ln(x) and exponentiated back to x.
aa = kale.sample_grid(xc, dydx, nsamp=NUM)
a1 = kale.sample_grid(np.log(xc), dydx, nsamp=NUM)
a1 = np.e ** a1
bb = kale.sample_grid(xc, dydlnx, nsamp=NUM)
b1 = kale.sample_grid(np.log(xc), dydlnx, nsamp=NUM)
b1 = np.e ** b1
# ha, _ = np.histogram(aa, bins=xx)

# Same content in both panels; only the first panel is switched to linear axes.
fig, axes = plot.figax(ncols=2)
axes[0].set(xscale='linear', yscale='linear')
for ax in axes:
    ax.plot(xx, yy)
    ax.plot(xc, dydx)
    ax.plot(xc, dydlnx)
    # Solid histogram: linear-grid samples; dashed, same color: log-grid samples.
    *_, p = ax.hist(aa, bins=xx, histtype='step', label='dydx')
    c = p[0].get_edgecolor()
    ax.hist(a1, bins=xx, histtype='step', ls='--', color=c)
    *_, p = ax.hist(bb, bins=xx, histtype='step', label='dydlnx')
    c = p[0].get_edgecolor()
    ax.hist(b1, bins=xx, histtype='step', ls='--', color=c)

# `plt.legend()` acts on the current axes only.
plt.legend()
plt.show()

# kalepy sampling fractional binaries

In [None]:
def func(xx):
    """Test distribution: `xx**1.5 * exp(-xx)`, evaluated element-wise.

    Parameters
    ----------
    xx : scalar or ndarray
        Point(s) at which to evaluate.

    Returns
    -------
    scalar or ndarray
        Same shape as `xx`.
    """
    power_law = np.power(xx, +1.5)
    cutoff = np.exp(-xx)
    return power_law * cutoff

# Compare plain grid sampling with weighted "outlier" sampling, averaged over many draws.
NUM = 1e3
xx = kale.utils.spacing([1e-2, 1e1], scale='log', num=100)
yy = func(xx)
# Cumulative sum normalized so that its final value is `NUM`; `yy` is rescaled to match.
Y = np.cumsum(yy)
norm = NUM / Y[-1]
yy *= norm
Y *= norm
dydx = np.diff(Y) / np.diff(xx)
xc = kale.utils.midpoints(xx)

NREALS = 100
# NSAMP = 1e2
NSAMP = NUM
nbins = xx.size - 1
# Histogram counts per bin (rows) and per realization (columns):
# `dist` for plain samples, `wdist` for weighted samples.
dist = np.zeros((nbins, NREALS))
wdist = np.zeros((nbins, NREALS))
for rr in range(NREALS):
    # Plain sampling (note: `NSAMP` is passed as a float here, but as `int(NSAMP)` below).
    ss = kale.sample_grid(xc, dydx, nsamp=NSAMP)
    dist[:, rr], _ = np.histogram(ss, bins=xx)
    # Weighted sampling with threshold 10.0; returns samples and their weights.
    ss, ww = kale.sample_outliers(xc, dydx, 10.0, nsamp=int(NSAMP))
    ss = ss.squeeze()
    # print(np.shape(ss), np.shape(ww))
    wdist[:, rr], _ = np.histogram(ss, bins=xx, weights=ww)
    # wdist[:, rr], _ = np.histogram(ss, bins=xx)

fig, ax = plot.figax()
ax.plot(xx, yy)
# ax.plot(xc, dydx)

# Realization-averaged histograms: plain sampling (red) vs. weighted sampling (blue).
ave = np.mean(dist, axis=-1)
ax.plot(xc, ave, 'r--')

ave = np.mean(wdist, axis=-1)
ax.plot(xc, ave, 'b--')

plt.show()