In [None]:
# %load ../init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `mpl.style.use('default')` resets rcParams to matplotlib's defaults, so it has to run
# BEFORE the customizations below; when it ran after them, the font/lines/mathtext settings
# were silently discarded.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): family is 'serif' but the font list is given for 'sans-serif' -- confirm
# whether `'serif': ['Times']` was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Use holodeck's logger for the whole notebook, at INFO verbosity
log = holo.log
log.setLevel(logging.INFO)

In [None]:
# ---- Construct the semi-analytic model (SAM) from its four component relations
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.host_relations.MMBulge_Standard()     # M-MBulge Relation            (MMB)
# NOTE: `Hard_GW` is stored without being called; later cells pass it in this form to
# `sam.dynamic_binary_number` and to the sampling functions.
hard = holo.hardening.Hard_GW
# Grid-shape options that have been tried; only the uncommented assignment is used.
# shape = (150, 151, 152)
# shape = (30, 31, 32)
# shape = (60, 61, 62)
shape = 60
# shape = None

sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=shape)

In [None]:
# Observed GW frequencies.  Later cells treat `fobs` as bin edges: they plot against
# `kale.utils.midpoints(fobs)` and allocate `fobs.size-1` bins.
# NOTE(review): the two arguments look like an observing duration and a cadence -- confirm
# against `utils.nyquist_freqs`.
# fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)
fobs = utils.nyquist_freqs(10.0*YR, 1.0*YR)
# fobs = kale.utils.spacing(fobs, scale='log', num=fobs.size)

In [None]:
# GWB computed directly from the SAM grid, with no realizations requested (`realize=False`)
gwb_smooth = sam.gwb(fobs, realize=False)

In [None]:
# NOTE(review): `vals` and `weights` are not assigned by any earlier cell; they are first
# created in the "Manual" section further down, so this cell fails on a fresh
# Restart & Run All.
gwf_freqs, gwf, gwb = holo.sam._gws_from_samples(vals, weights, fobs)

In [None]:
# Grid-based GWB without realizations, and with `realize=100`
# (`gwb_smooth` repeats the assignment made two cells above)
gwb_smooth = sam.gwb(fobs, realize=False)
gwb_rough = sam.gwb(fobs, realize=100)

In [None]:
# Compare the sampled GWB against the smooth SAM GWB, and save an annotated figure.
# NOTE(review): this cell depends on hidden state and does not run top-to-bottom:
#   - `plot_gwb` is defined in a later cell, where it takes no `ax` keyword and returns None
#     (so `ax` would be overwritten with None here);
#   - `args`, `zplot` and `zio` are not defined or imported anywhere in the visible notebook.
fig, ax = plot.figax()

ax = plot_gwb(fobs, gwb, color='r')
plot_gwb(fobs, gwb_smooth, ax=ax)

# fractional error of every realization relative to the smooth curve, then its RMS
err = gwb_smooth[:, None]
err = (gwb - err) / err
err = np.sqrt(np.mean(err**2))
print(f"overall error = {err:+.8e}")
# mean fractional error in the lowest frequency bin only
err_lo = (gwb[0] - gwb_smooth[0]) / gwb_smooth[0]
err_lo = np.mean(err_lo)
print(f"lowest freq error = {err_lo:+.8e}")

title = str(args) + f" shape:{sam.shape}"
label = f"err[0]={err_lo:+.4e} RMS={err:+.4e}"

fig = ax.get_figure()
ax.set_title(title, fontsize=10)
zplot.text(ax, label, loc='ur', fontsize=10)
fname = 'fig.png'
# NOTE(review): presumably picks a filename that does not overwrite an existing file -- confirm
fname = zio.modify_exists(fname)
fig.savefig(fname)
plt.show()

In [None]:
# 2D histogram of the values in `data['mass']` as a function of position along one grid axis.
# NOTE(review): `data`, `args` and `zmath`/`zplot` are not defined or imported in the
# visible notebook; this cell relies on hidden kernel state.
names = ['M', 'q', 'z', 'f']
axis = 2
xx = data['edges'][axis]
thresh = args['sample_threshold']
print(f"{thresh=}")

# move the chosen axis to the front and flatten all of the others
mass = np.copy(data['mass'])
mass = np.moveaxis(mass, axis, 0)
xbins = mass.shape[0]
mass = np.reshape(mass, (xbins, -1))
extr = zmath.minmax(mass)   # NOTE: `extr` is not used below
yvals = zmath.spacing(mass, 'log', 30, limit=[1e-2, None])

# one histogram (over `yvals` bins) for each position along the chosen axis
hist = np.zeros((xbins, yvals.size-1))
for ff in range(xbins):
    hist[ff, :], _ = np.histogram(mass[ff], bins=yvals)

fig, ax = plot.figax()
ax.set(title=names[axis])
smap = zplot.smap(hist, scale='log')
pcm = ax.pcolormesh(xx, yvals, hist.T, cmap=smap.cmap, norm=smap.norm)

plt.colorbar(pcm)
plt.show()


In [None]:
# Number of realizations used by the following cells (reassigned to 10 in a later section)
NUM_REALS = 30

In [None]:
# Grid-based GWB: without realizations, and with `NUM_REALS` realizations
gwb_smooth = sam.gwb(fobs, realize=False)
gwb_rough = sam.gwb(fobs, realize=NUM_REALS)

In [None]:
# Sample the SAM `NUM_REALS` times; each call fills one column of `gwf` and of `gwb`
# (rows are the `fobs.size-1` frequency bins).
# NOTE(review): no random seed is set, so results differ between runs.
gwb = np.zeros((fobs.size-1, NUM_REALS))
gwf = np.zeros((fobs.size-1, NUM_REALS))
for ii in utils.tqdm(range(NUM_REALS)):
    gwf_freqs, gwf[:, ii], gwb[:, ii] = holo.sam.sampled_gws_from_sam(
        sam, fobs=fobs, hard=holo.hardening.Hard_GW, cut_below_mass=1e6*MSOL, limit_merger_time=None,
        sample_threshold=10, poisson_inside=True, poisson_outside=True,
    )
    # gwb[:, ii] = np.sqrt(_gwb[:]**2 + gwf**2)
    # break

In [None]:
# Figure for the grid-vs-sampled GWB comparison drawn at the end of this cell
fig, ax = plot.figax(figsize=[6, 3])

def plot_gwb(xx, yy, ax=None, **kwargs):
    """Plot the median and interquartile range of a set of GWB realizations.

    Parameters
    ----------
    xx : array_like, shape (F,)
        Abscissa values (e.g. frequencies).
    yy : array_like, shape (F,) or (F, R)
        Values to plot.  For 2D input the last axis (realizations) is reduced to its median
        (dashed line) and 25th-75th percentile range (shaded band).  1D input has no
        realizations to reduce, so it is drawn as a single dashed line.
    ax : matplotlib axes or None
        Axes to draw on.  If `None`, the current pyplot axes (`plt.gca()`) are used.
    **kwargs
        Passed on to `ax.plot` (e.g. `color`, `label`).

    Returns
    -------
    ax : matplotlib axes
        The axes that were drawn on.

    """
    if ax is None:
        ax = plt.gca()

    yy = np.asarray(yy)
    if yy.ndim < 2:
        ax.plot(xx, yy, ls='--', **kwargs)
        return ax

    line, = ax.plot(xx, np.median(yy, axis=-1), ls='--', **kwargs)
    lo, hi = np.percentile(yy, [25, 75], axis=-1)
    # shade the interquartile range in the same color as the median line
    ax.fill_between(xx, lo, hi, color=line.get_color(), alpha=0.25)
    return ax

# Plot against frequency-bin midpoints, converted to units of 1/yr
xx = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

# amp = 10e-16
# yy = amp * np.power(xx, -2/3)
# ax.plot(xx, yy, 'k--', alpha=0.25)

# smooth grid GWB (dotted), grid realizations, and the sampled GWB
ax.plot(xx, gwb_smooth, 'b:')
plot_gwb(xx, gwb_rough, label='grid')
if np.ndim(gwb) == 1:
    # a single sampled realization: plot it directly
    ax.plot(xx, gwb, 'k-', alpha=0.25)
else:
    # multiple realizations: combine `gwb` and `gwf` in quadrature before plotting
    # plot_gwb(xx, gwb, label='sampled')
    plot_gwb(xx, np.sqrt(gwb**2 + gwf**2), label='sampled')

plt.legend()
plt.show()
fig.savefig('temp.png')

In [None]:
# NOTE(review): `breaker` is not defined anywhere in the visible notebook; presumably this
# is a deliberate NameError used to halt "Run All" at this point -- confirm, or remove.
breaker()

# CDF of strain distributions

In [None]:
# Print the first element of the smooth SAM GWB next to the mean of the first row of the
# sampled GWB; `pad` keeps the two '=' signs aligned.
msg = f" [{fobs[0]*YR:.4e},{fobs[-1]*YR:.4e}]/yr "
pad = " "*len(msg)
print(f"GWB SAM between{msg}= {gwb_smooth[0]:.4e}")
print(f"GWB sampled    {pad}= {gwb[0].mean():.4e}")


In [None]:
# Display the grid shape of the current SAM
sam.shape

In [None]:
# Module-level settings read by `sam_func` and `sample_func` below
FBIN = 0                 # index of the frequency bin being examined
sample_threshold = 10    # threshold passed to `kale.sample_outliers`
# differential number of binaries on the grid, and its integral over grid cells
edges, dnum = sam.dynamic_binary_number(hard, fobs=fobs)
number = holo.utils._integrate_grid_differential_number(edges, dnum, freq=True)

# breaker()

def sam_func(edges, dnum, number):
    """Characteristic strain in frequency bin `FBIN`, computed on the SAM grid.

    The strain of each grid cell is evaluated at the `dnum`-weighted center of the cell, and
    cells are combined using the number of binaries per log-frequency interval.

    Reads the module-level constant `FBIN` to select the frequency bin.

    Parameters
    ----------
    edges : list of 4 arrays
        Grid edges; the first three are used for mass, mass-ratio and redshift (in that
        order) and the last for observed frequency.
    dnum : ndarray
        Differential number of binaries, evaluated at the grid edges.
    number : ndarray
        Number of binaries in each grid cell.

    Returns
    -------
    hc : scalar
        Characteristic strain in the selected frequency bin.
    hs : ndarray
        Strain evaluated at the weighted center of each cell (NaNs replaced with zero).
    number : ndarray
        Number of binaries per cell, divided by the log-width of the frequency bin.

    """
    print("==== sam_func ====")
    # restrict to the two frequency edges (one bin) being examined
    fobs = edges[-1][FBIN:FBIN+2]
    dnum = dnum[..., FBIN:FBIN+2]
    number = number[..., FBIN:FBIN+1]

    grid = np.meshgrid(*edges[:-1], indexing='ij')

    coms = [cc[..., np.newaxis] for cc in grid]
    # broadcast the three grid coordinates and the frequency edges to a common 4D shape.
    # NumPy >= 2.0 returns a tuple here; convert to a list so that the elements can be
    # reassigned in the loop below.
    coms = list(np.broadcast_arrays(*coms, fobs[np.newaxis, np.newaxis, np.newaxis, :]))

    # ---- find weighted bin centers
    # get unweighted centers
    cent = kale.utils.midpoints(dnum, log=False, axis=(0, 1, 2, 3))
    # get weighted centers for each dimension
    for ii, cc in enumerate(coms):
        coms[ii] = kale.utils.midpoints(dnum * cc, log=False, axis=(0, 1, 2, 3)) / cent

    # ---- calculate GW strain at bin centroids
    mc = utils.chirp_mass(*utils.m1m2_from_mtmr(coms[0], coms[1]))
    dc = cosmo.comoving_distance(coms[2]).cgs.value
    fr = utils.frst_from_fobs(coms[3], coms[2])
    hs = utils.gw_strain_source(mc, dc, fr/2.0)

    dlogf = np.diff(np.log(fobs))
    print(f"{dlogf=} {1/dlogf=}")
    dlogf = dlogf[np.newaxis, np.newaxis, np.newaxis, :]

    # number of binaries per unit log-frequency
    number = number / dlogf
    hs = np.nan_to_num(hs)
    # sum over the first three axes, then take the single remaining frequency bin
    hc = np.sqrt(np.sum(number*np.square(hs), axis=(0, 1, 2)))[0]
    return hc, hs, number


def sample_func(edges, dnum, number, **sample_kwargs):
    """Characteristic strain in frequency bin `FBIN`, computed from sampled binaries.

    Binaries are drawn with `kale.sample_outliers`, and only those whose sampled frequency
    falls inside bin `FBIN` are kept.

    Reads the module-level names `FBIN` and `sample_threshold`.

    Parameters
    ----------
    edges : list of 4 arrays
        Grid edges, used here as mass, mass-ratio, redshift and observed frequency.
    dnum : ndarray
        Differential number of binaries, evaluated at the grid edges.
    number : ndarray
        Number of binaries in each grid cell (passed to the sampler as `mass`).
    **sample_kwargs
        Additional keyword arguments passed to `kale.sample_outliers`.

    Returns
    -------
    gwb : scalar
        Characteristic strain from the selected samples.
    hs : ndarray
        Strain of each selected sample.
    weights : ndarray
        Weight of each selected sample, multiplied by `1 / dln(f)` of the frequency bin.

    """
    print("==== sample_func ====")
    # sample in log10 of the first dimension and natural-log of the frequency
    edges_sample = [np.log10(edges[0]), edges[1], edges[2], np.log(edges[3])]

    vals, weights = kale.sample_outliers(
        edges_sample, dnum, sample_threshold, mass=number, **sample_kwargs
    )

    # convert the sampled values back to linear space
    vals[0] = 10.0 ** vals[0]
    vals[3] = np.e ** vals[3]

    # keep only the samples inside the target frequency bin [lo, hi)
    # (the previously computed, unused selection fraction is dropped: it divided by the
    # number of samples and so raised ZeroDivisionError when nothing was sampled)
    fextr = [edges[3][FBIN], edges[3][FBIN+1]]
    idx = (fextr[0] <= vals[3]) & (vals[3] < fextr[1])
    weights = weights[idx]
    vals = [vv[idx] for vv in vals]

    mc = utils.chirp_mass(*utils.m1m2_from_mtmr(vals[0], vals[1]))
    rz = vals[2]
    fo = vals[3]
    frst = utils.frst_from_fobs(fo, rz)
    dc = cosmo.comoving_distance(rz).cgs.value
    hs = utils.gw_strain_source(mc, dc, frst/2.0)

    # cycles = 0.5 * np.sum(fextr) / np.diff(fextr)[0]
    cycles = 1.0 / np.diff(np.log(fextr))[0]
    weights = weights * cycles
    print(f"{cycles=}")
    gwb = np.sqrt(np.sum(weights * (hs ** 2)))
    return gwb, hs, weights
    

# Evaluate the same frequency bin with both methods and report the fractional difference.
# The per-cell / per-sample strains and weights are kept for the distribution plots below.
hc_sam, _hs1, _nn1 = sam_func(edges, dnum, number)
hc_sample, _hs2, _nn2 = sample_func(edges, dnum, number)
err = (hc_sample - hc_sam) / hc_sam
# output recorded from an earlier run:
# sam=3.9640e-15 sample=5.0756e-15 :: err=2.8044e-01
print(f"sam={hc_sam:.4e} sample={hc_sample:.4e} :: {err=:.4e}")

In [None]:
# Compare the strain distributions of the grid calculation (1, solid) and the sampled
# calculation (2, dashed).
# NOTE(review): `zmath` and `zplot` are not imported in the visible notebook.
hs1 = np.copy(_hs1.flatten())
nn1 = np.copy(_nn1.flatten())
hs2 = np.copy(_hs2)
nn2 = np.copy(_nn2)

# sort both sets by strain, so that the cumulative sums below run from low to high strain
idx1 = np.argsort(hs1)
idx2 = np.argsort(hs2)
hs1 = hs1[idx1]
nn1 = nn1[idx1]
hs2 = hs2[idx2]
nn2 = nn2[idx2]

fig, ax = plot.figax()

# common log-spaced strain bins for both distributions
# NOTE(review): `filter='>'` presumably restricts to positive values -- confirm
hs_edges = zmath.minmax(hs1, prev=zmath.minmax(hs2, filter='>'), filter='>')
hs_edges = zmath.spacing(hs_edges, 'log', 100)

# weighted histograms of strain
col = 'k'
kw = dict(histtype='step', lw=1.0, alpha=0.75)
hist1, *_ = ax.hist(hs1, bins=hs_edges, weights=nn1, color=col, ls='-', **kw)
hist2, *_ = ax.hist(hs2, bins=hs_edges, weights=nn2, color=col, ls='--', **kw)

# cumulative characteristic strain, adding sources in order of increasing strain
y1 = np.sqrt(np.cumsum(nn1*(hs1**2)))
y2 = np.sqrt(np.cumsum(nn2*(hs2**2)))

col = 'b'
tw = zplot.twin_axis(ax, pos=1.0, color=col, scale='lin')
tw.plot(hs1, y1, ls='-', color=col)
tw.plot(hs2, y2, ls='--', color=col)

# col = 'r'
# y1 = np.cumsum(nn1)
# y2 = np.cumsum(nn2)

# tw = zplot.twin_axis(ax, pos=1.1, color=col)
# tw.plot(hs1, y1, ls='-', color=col)
# tw.plot(hs2, y2, ls='--', color=col)

# ax.set(xlim=[1e-18, 1e-14])
fig.savefig('errdist.png')
plt.show()


In [None]:
# For grid cells whose strain lies in (1e-17, 1e-16), show where they fall along each of
# the first three grid dimensions (one panel per dimension).
indices = np.where((1e-17 < _hs1.squeeze()) & (_hs1.squeeze() < 1e-16))
fig, axes = plot.figax(ncols=3)
ww = _nn1.squeeze()
kw = dict(histtype='step', lw=1.0, alpha=0.75)
for ii, (ax, idx) in enumerate(zip(axes, indices)):
    ee = kale.utils.midpoints(edges[ii])
    
    print(ii, utils.stats(ee))
    # unweighted count of selected cells
    *_, cc = ax.hist(ee[idx], bins=edges[ii], **kw)
    cc = cc[0].get_facecolor()
    # the same cells, weighted by their number of binaries (dashed, same color)
    ax.hist(ee[idx], bins=edges[ii], weights=ww[indices], color=cc, ls='--', **kw)
    # all cells, with weights summed over the other two dimensions
    marg = [0, 1, 2]
    marg.pop(ii)
    *_, cc = ax.hist(ee, bins=edges[ii], weights=ww.sum(axis=tuple(marg)), **kw)
    
plt.show()


In [None]:
# Weights of the sources above a strain of 3e-18, for the grid (1) and sampled (2) cases.
# NOTE(review): uses `hs1`, `nn1`, `hs2`, `nn2` as left by an earlier cell, and `zmath`,
# which is not imported in the visible notebook.
i1 = (hs1 > 3e-18)
i2 = (hs2 > 3e-18)
w1 = nn1[i1]
w2 = nn2[i2]
print(zmath.frac_str(i1), w1.size, np.sum(w1))
print(zmath.frac_str(i2), w2.size, np.sum(w2))
print(f"{zmath.stats_str(w1)=}")
print(f"{zmath.stats_str(w2)=}")

In [None]:
# Cumulative number of sources and cumulative weighted strain-squared, as functions of
# strain, for the grid (1, solid) and sampled (2, dashed) calculations.
# NOTE(review): `zplot` is not imported in the visible notebook.
hs1 = np.copy(_hs1.flatten())
nn1 = np.copy(_nn1.flatten())
hs2 = np.copy(_hs2)
nn2 = np.copy(_nn2)

yscale = 'lin'
fig, ax = plot.figax(yscale=yscale)
color = 'k'

# sort by strain so that the cumulative sums run from low to high strain
idx1 = np.argsort(hs1)
idx2 = np.argsort(hs2)
hs1 = hs1[idx1]
nn1 = nn1[idx1]
hs2 = hs2[idx2]
nn2 = nn2[idx2]

# cumulative weight (black)
y1 = np.cumsum(nn1)
y2 = np.cumsum(nn2)
ax.plot(hs1, y1, ls='-', color=color)
ax.plot(hs2, y2, ls='--', color=color)
# compare the totals
y1 = y1[-1]
y2 = y2[-1]
err = (y2 - y1) / y1
print(f"weight: {y1:.4e}, {y2:.4e} :: {err:.4e}")

# cumulative weight * strain^2 (blue, on a twin axis)
color = 'b'
tw1 = zplot.twin_axis(ax, color=color, pos=1.0, scale=yscale)
y1 = np.cumsum(nn1*(hs1**2))
y2 = np.cumsum(nn2*(hs2**2))
tw1.plot(hs1, y1, ls='-', color=color)
tw1.plot(hs2, y2, ls='--', color=color)
# compare the totals
_y1 = y1[-1]
_y2 = y2[-1]
err = (_y2 - _y1) / _y1
print(f"weighted strain: {_y1:.4e}, {_y2:.4e} :: {err:.4e}")

## Try to duplicate problem using slice of 4D parameter space

In [None]:
def slice_func(edges, dnum, number, zbin, fbin, sample_threshold):
    """Compare grid and sampled GWB within a single redshift bin and frequency bin.

    Parameters
    ----------
    edges : list of 4 arrays
        Grid edges, used here as mass, mass-ratio, redshift and observed frequency.
        The input list is not modified.
    dnum : ndarray
        Differential number of binaries, evaluated at the grid edges.
    number : ndarray
        Number of binaries in each grid cell.
    zbin : int
        Index of the redshift bin to select.
    fbin : int
        Index of the frequency bin to select.
    sample_threshold : scalar
        Threshold passed to `kale.sample_outliers`.

    Returns
    -------
    gwb_grid : scalar
        Characteristic strain from the grid calculation.
    gwb_sample : scalar
        Characteristic strain from the sampled calculation.

    """
    # Copy into a list (not an array): the edges have different lengths in each dimension,
    # so `np.array(edges)` would be ragged, and the slices below change two of the lengths.
    edges = [np.asarray(ee) for ee in edges]
    edges[-2] = edges[-2][zbin:zbin+2]
    edges[-1] = edges[-1][fbin:fbin+2]
    fobs = edges[-1]
    dnum = dnum[..., zbin:zbin+2, fbin:fbin+2]
    number = number[..., zbin:zbin+1, fbin:fbin+1]

    # ---- Grid Calculation

    # find weighted bin centers
    # (NumPy >= 2.0 returns a tuple from `meshgrid`; a list is needed for reassignment below)
    coms = list(np.meshgrid(*edges, indexing='ij'))
    # get unweighted centers
    cent = kale.utils.midpoints(dnum, log=False, axis=(0, 1, 2, 3))
    # get weighted centers for each dimension
    for ii, cc in enumerate(coms):
        coms[ii] = kale.utils.midpoints(dnum * cc, log=False, axis=(0, 1, 2, 3)) / cent

    # calculate GW strain at bin centroids
    mc = utils.chirp_mass(*utils.m1m2_from_mtmr(coms[0], coms[1]))
    dc = cosmo.comoving_distance(coms[2]).cgs.value
    fr = utils.frst_from_fobs(coms[3], coms[2])
    hs = utils.gw_strain_source(mc, dc, fr/2.0)

    dlogf = np.diff(np.log(fobs))
    dlogf = dlogf[np.newaxis, np.newaxis, np.newaxis, :]

    # Number per log-frequency, used only for the grid calculation.  `number` itself is
    # left untouched so that the sampler below receives the raw counts per cell, as in
    # `sample_func`; the sampled weights get their own 1/dln(f) factor (`cycles`).
    number_per_logf = number / dlogf
    hs = np.nan_to_num(hs)
    # sum over the first three axes, then take the single remaining frequency bin
    gwb_grid = np.sqrt(np.sum(number_per_logf*np.square(hs), axis=(0, 1, 2)))[0]

    # ---- Sampled Calculation

    edges_sample = [np.log10(edges[0]), edges[1], edges[2], np.log(edges[3])]

    vals, weights = kale.sample_outliers(
        edges_sample, dnum, sample_threshold, mass=number,
    )

    # convert the sampled values back to linear space
    vals[0] = 10.0 ** vals[0]
    vals[3] = np.e ** vals[3]

    mc = utils.chirp_mass(*utils.m1m2_from_mtmr(vals[0], vals[1]))
    rz = vals[2]
    fo = vals[3]
    frst = utils.frst_from_fobs(fo, rz)
    dc = cosmo.comoving_distance(rz).cgs.value
    hs = utils.gw_strain_source(mc, dc, frst/2.0)

    # `fobs` holds exactly the two edges of the selected frequency bin
    cycles = 1.0 / np.diff(np.log(fobs))[0]
    weights = weights * cycles
    gwb_sample = np.sqrt(np.sum(weights * (hs ** 2)))

    return gwb_grid, gwb_sample


# Run the single-slice comparison for the lowest frequency bin and the redshift edge
# closest to z = 0.1.
# NOTE(review): `zmath` is not imported in the visible notebook.
fbin = 0
zbin = zmath.argnearest(edges[2], 0.1)
# zbin = 30
print(f"{fbin=}, {zbin=} :: {YR*edges[-1][[fbin, fbin+1]]}, {edges[-2][[zbin, zbin+1]]}")
gwb_grid, gwb_sample = slice_func(edges, dnum, number, zbin, fbin, sample_threshold=1e4)
err = (gwb_sample - gwb_grid) / gwb_grid
print(f"{gwb_grid=:.4e}, {gwb_sample=:.4e}, {err=:.4e}")

In [None]:
# Display the third set of grid edges (used as redshift in the cells above)
edges[2]

# Manual

In [None]:
# ---- Rebuild the SAM (overwriting the earlier `sam`, `fobs`, `edges`) and draw one sample
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.host_relations.MMBulge_Standard()     # M-MBulge Relation            (MMB)
hard = holo.hardening.Hard_GW
# shape = (150, 151, 152)
shape = (30, 31, 32)
# shape = (30, 31, 300)
# shape = None

# NOTE(review): `redz` is given explicitly in addition to `shape`; which one sets the
# redshift grid is decided inside `Semi_Analytic_Model` -- check the printed shape.
sam = holo.sam.Semi_Analytic_Model(
        gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=shape,
        redz=[1e-2, 6.0, 32]
)
print(f"{sam.shape=}")

fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)

# sampled binaries (`vals`, `weights`) along with the grid they were drawn from
# NOTE(review): no random seed is set, so the sample differs between runs.
vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
        sam, hard, fobs=fobs, cut_below_mass=None, limit_merger_time=None,
        sample_threshold=10.0, poisson_inside=True, poisson_outside=True,
)
print(f"{weights.size=:.4e}, {weights.sum()=:.8e}")

In [None]:
# Display the shape of each set of grid edges
kale.utils.jshape(edges)

In [None]:
# Bin the weighted samples back onto the grid they were drawn from.
# `vals` is transposed because `np.histogramdd` expects one row per sample.
counts, _ = np.histogramdd(vals.T, bins=edges, weights=weights)
print(counts.shape)

In [None]:
# Grid-based GWB for the rebuilt SAM: without realizations, and with `NUM_REALS` realizations
gwb_smooth = sam.gwb(fobs, realize=False)
gwb_rough = sam.gwb(fobs, realize=NUM_REALS)

In [None]:
# Compute the GWB from a (selectable) subset of the samples and compare to the grid GWB.
# Work on copies so that `vals` and `weights` stay untouched.
use_weights = np.copy(weights)
use_vals = np.copy(vals)
print(kale.utils.stats_str(use_weights))

# sel = (use_vals[2] < 0.5)
# use_vals[2][sel] *= 2.0

# Selection of samples to use; `slice(None)` keeps everything.  Alternatives are kept below.
sel = slice(None)
# sel = (vals[2] > 0.5)
# sel = (use_weights > 1.0)
# sel = (use_weights < 1.0)
# sel = (use_weights != 1.0)
# sel = ~sel

use_weights = use_weights[sel]
# use_weights = np.random.poisson(use_weights)
print(kale.utils.stats_str(use_weights))

# print(kale.utils.stats_str(use_weights[sel]))
# print(kale.utils.stats_str(use_weights[~sel]))
# print(f"weights: {zmath.frac_str(sel)}")

gff, gwf, gwb = holo.sam._gws_from_samples(use_vals[:, sel], use_weights, fobs)
fig, ax = plot.figax()
ff = kale.utils.midpoints(fobs)
ax.plot(ff, gwb)
ax.plot(gff, gwf, 'rx', alpha=0.5)

ax.plot(ff, gwb_smooth, 'b:')
# NOTE: `plot_gwb` is defined in an earlier cell of this notebook
plot_gwb(ff, gwb_rough, label='roughed')

plt.show()

# Compare Properties of Sampled Population to pure-SAM

In [None]:
# Draw `NUM_REALS` independent samples of the SAM; `vals` and `weights` become lists with
# one entry per realization (replacing the arrays of the same names from earlier cells).
# `edges`, `dens` and `mass` are overwritten on each iteration.
NUM_REALS = 10
vals = []
weights = []
for ii in holo.utils.tqdm(range(NUM_REALS)):
    _vals, _weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
            sam, holo.hardening.Hard_GW, fobs=fobs,
            sample_threshold=5.0, cut_below_mass=None, limit_merger_time=None,
    )
    vals.append(_vals)
    weights.append(_weights)

## Number of Sources vs. Frequency

In [None]:
# Total weight of sampled sources above `MASS_MIN` in each frequency bin, per realization,
# and the corresponding number from the SAM grid.
num_fobs = fobs.size
MASS_MIN = 1.0e8 * MSOL
# MASS_MIN = 0.0/
# NOTE: one row per frequency *edge*; the plotting cell below drops the last row
num = np.zeros((num_fobs, NUM_REALS))

for ii in holo.utils.tqdm(range(NUM_REALS)):
    fo = vals[ii][-1]
    # index of the frequency bin that each sample falls in
    idx = np.digitize(fo, fobs) - 1
    for jj in range(num_fobs):
        sel = (idx == jj) & (vals[ii][0] > MASS_MIN)
        num[jj, ii] = np.sum(weights[ii][sel])
        
# grid version: select cells along the first axis whose midpoint exceeds `MASS_MIN`, then
# sum over every axis except the last
sel = (kale.utils.midpoints(edges[0]) > MASS_MIN)
idx = list(np.arange(mass.ndim))
idx.pop(mass.ndim - 1)
sam_num = np.sum(mass[sel], axis=tuple(idx))

In [None]:
# Number of sources vs. frequency [1/yr]: mean over sampled realizations (blue) against
# the SAM grid (red; dashed at bin midpoints and dotted as histogram steps)
fig, ax = plot.figax()
xx = fobs*YR
# ax.plot(xx, num)
# `[:-1]` drops the extra last row of `num`, which has one row per frequency edge
ax.plot(kale.utils.midpoints(xx), num.mean(axis=-1)[:-1], 'b-')
ax.plot(kale.utils.midpoints(xx), sam_num, 'r--')

aa, bb = plot._get_hist_steps(xx, sam_num)
ax.plot(aa, bb, 'r:')
# ax.set(xlim=[0.95e-1, 2.5e-1], ylim=[2e9, 4e10])
plt.show()

# Chirp-mass Distribution

In [None]:
# Print chirp-mass statistics (in solar masses) of the sampled binaries.
# NOTE: the `break` stops after the first realization, and `mchirp` is never filled.
mchirp = []
for ii in range(NUM_REALS):
    vv = vals[ii]
    mc = utils.m1m2_from_mtmr(vv[0], vv[1])
    mc = utils.chirp_mass(*mc)
    print(mc.shape, utils.stats(mc/MSOL))
    break

## Distance/Redshift Distribution

In [None]:
# ---- Rebuild the SAM and frequencies for the distance/redshift checks
# (same construction as in the "Manual" section above)
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.host_relations.MMBulge_Standard()     # M-MBulge Relation            (MMB)
hard = holo.hardening.Hard_GW
# shape = (150, 151, 152)
shape = (30, 31, 32)
# shape = (30, 31, 300)
# shape = None

sam = holo.sam.Semi_Analytic_Model(
        gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=shape,
        redz=[0.01, 6.0, 32]
)
print(f"{sam.shape=}")

fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)

In [None]:
# Consistency check: trapezoid-rule integral of `cosmo.dVcdz` over redshift against
# `cosmo.comoving_volume`, out to z = 1e-3.
z = 1.0e-3

zz = np.linspace(0.0, z, 100)
# dv = cosmo.differential_comoving_volume(zz) * 4*np.pi
dv = cosmo.dVcdz(zz, cgs=False)
# trapezoid rule over the full redshift range
v1 = np.sum(0.5 * (dv[1:] + dv[:-1]) * np.diff(zz))
v2 = cosmo.comoving_volume(zz)
print(v1, v2[-1], v1/v2[-1])
# the same comparison for the first redshift interval only
aa = v2[1]
bb = 0.5 * (dv[0] + dv[1]) * zz[1]
print(aa, bb, aa/bb)

In [None]:
# Draw a single sample from the rebuilt SAM; `vals` and `weights` are arrays again here
# (they were lists of realizations in the previous section).
vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
        sam, hard, fobs=fobs, cut_below_mass=None, limit_merger_time=None,
        sample_threshold=10.0, poisson_inside=True, poisson_outside=True,
)
print(f"{weights.size=:.4e}, {weights.sum()=:.8e}")

In [None]:
# Distribution of sampled binaries in redshift (top) and comoving distance (bottom),
# using 30 log-spaced bins.
NBINS = 30
dc = cosmo.comoving_distance(vals[2]).to('Mpc').value
nsamp = weights.sum()
dcmax = dc.max()
# uniform number density: total weight over the sphere out to the largest sampled distance
# NOTE: this overwrites the `dens` returned by `sample_sam_with_hardening`
dens = nsamp / ((4.0*np.pi*dcmax**3)/3.0)

fig, axes = plot.figax(nrows=2)

# redshift: once with log-spaced bins, once with the SAM's own redshift edges
ax = axes[0]
xx = vals[2]
bins = kale.utils.spacing(xx, 'log', NBINS)
ax.hist(xx, bins=bins, histtype='step', weights=weights)
ax.hist(xx, bins=sam.edges[2], histtype='step', weights=weights)

# comoving distance, with a reference curve from the uniform density
ax = axes[1]
xx = dc
bins = kale.utils.spacing(xx, 'log', NBINS)
ax.hist(xx, bins=bins, histtype='step', weights=weights)
# NOTE(review): the dashed curve is the number *enclosed* within each bin edge, while the
# histogram is the number *per bin* -- confirm this is the intended comparison.
yy = ((4.0*np.pi*bins**3)/3.0) * dens
ax.plot(bins, yy, 'k--')

plt.show()

In [None]:
# Same redshift / comoving-distance figure as the previous cell, with 10 bins instead of 30.
# NOTE(review): only `NBINS` differs from the previous cell; a shared function taking
# `NBINS` as a parameter would remove the duplication.
NBINS = 10
dc = cosmo.comoving_distance(vals[2]).to('Mpc').value
nsamp = weights.sum()
dcmax = dc.max()
# uniform number density: total weight over the sphere out to the largest sampled distance
dens = nsamp / ((4.0*np.pi*dcmax**3)/3.0)

fig, axes = plot.figax(nrows=2)

# redshift: once with log-spaced bins, once with the SAM's own redshift edges
ax = axes[0]
xx = vals[2]
bins = kale.utils.spacing(xx, 'log', NBINS)
ax.hist(xx, bins=bins, histtype='step', weights=weights)
ax.hist(xx, bins=sam.edges[2], histtype='step', weights=weights)

# comoving distance, with a reference curve from the uniform density
ax = axes[1]
xx = dc
bins = kale.utils.spacing(xx, 'log', NBINS)
ax.hist(xx, bins=bins, histtype='step', weights=weights)
yy = ((4.0*np.pi*bins**3)/3.0) * dens
ax.plot(bins, yy, 'k--')

plt.show()

# Toy Failure

In [None]:
def log_norm_10(xx, mm, ss):
    """Probability density of a log-normal distribution specified in base 10.

    Parameters
    ----------
    xx : scalar or array_like
        Positions at which to evaluate the density.
    mm : scalar
        Location of the distribution in linear space (the density's Gaussian in ln(x) is
        centered on ln(mm)).
    ss : scalar
        Width of the distribution in dex (i.e. in log10 of `xx`).

    Returns
    -------
    yy : scalar or ndarray
        Density with respect to `xx`, evaluated at each `xx`.

    """
    # convert the width from dex to natural-log units, and the center to natural-log space
    sigma = np.log(10.0 ** ss)
    mu = np.log(mm)
    # normalization; note that log10(e) * ln(10) is unity
    norm = (np.log10(np.e) * np.log(10.0)) / (xx * sigma * np.sqrt(2.0*np.pi))
    # Gaussian in ln(x)
    zscore = (np.log(xx) - mu) / sigma
    return norm * np.exp(-(np.square(zscore) / 2.0))

# Check `log_norm_10` against random draws: histogram vs. analytic PDF (left axis) and
# empirical vs. integrated CDF (right axis).
# NOTE(review): `zmath` is not imported in the visible notebook, and no random seed is set.
mm, ss = 10.0, 1.0
xx = zmath.log_normal_base_10(mm, ss, size=int(1e5))
xx = np.sort(xx)
fig, ax = plot.figax(yscale='lin')
tw = ax.twinx()

bins = zmath.spacing(xx, 'log', 100, log_stretch=0.1)
ax.hist(xx, bins=bins, density=True, color='k', histtype='step', alpha=0.5)

# ax.hist(xx, bins=zmath.spacing(xx, 'log', 100, log_stretch=0.1),
#         density=False, histtype='step', color='k', weights=np.ones_like(xx)/xx.size, ls=':', alpha=0.5)

# empirical CDF of the sorted draws
tw.plot(xx, np.arange(xx.size)/(xx.size-1), 'k--', zorder=10, alpha=0.5)

pdf = log_norm_10(xx, mm, ss)
ax.plot(xx, pdf, 'r-', alpha=0.5)

# trapezoid-rule integral of the PDF between consecutive draws, accumulated into a CDF
pmf = 0.5*np.diff(xx)*(pdf[1:] + pdf[:-1])
pmf = np.concatenate([[0.0], pmf])
cdf = np.cumsum(pmf)
# ax.plot(xx, pmf, 'r:', alpha=0.5)

tw.plot(xx, cdf, 'r--', lw=2.0, alpha=0.5)

# NOTE: these overwrite `mm` and `ss` with sample statistics, which are not used afterwards
# in this cell
mm = xx.mean()
ss = np.log10(xx).std()
plt.show()

## Toy 1

In [None]:
# Toy population: log-normal draws plus a 10% admixture of log-uniform draws spanning
# a factor of 4 beyond the log-normal extremes on each side.
# NOTE(review): `zmath` is not imported in the visible notebook, and no random seed is set.
# This cell also overwrites the notebook-level `vals` and `edges`.
NUM = 1e5
NUM = int(NUM)
# vals = np.random.lognormal(3.0, 1.0, size=NUM)
vals = zmath.log_normal_base_10(100.0, 1.0, NUM)
vv = 10.0**np.random.uniform(*np.log10([0.25*vals.min(), 4.0*vals.max()]), NUM//10)
vals = np.concatenate([vals, vv])

fig, ax = plot.figax(scale='log')
edges = zmath.spacing(vals, 'log', 100)
hist, edges, _ = ax.hist(vals, bins=edges)

plt.show()

In [None]:
# Resample the toy histogram with `kale.sample_outliers` and compare to the original draws.
thresh = 1000.0
ee = zmath.midpoints(edges, log=True)
# density: counts per unit bin-width
dist = hist / np.diff(edges)
print(ee.shape, dist.shape, hist.shape)
# NOTE(review): the log-midpoints `ee` are passed in the edges position, with one density
# value per original bin -- confirm that this matches what `sample_outliers` expects.
xx, ww = kale.sample_outliers(ee, dist, thresh)
print(ww.size, zmath.stats_str(ww))

fig, ax = plot.figax()
kw = dict(histtype='step', alpha=0.5, lw=2.0)
*_, p1 = ax.hist(xx, bins=ee, weights=ww, **kw)
*_, p2 = ax.hist(vals, bins=edges, **kw)
# reuse each histogram's color (RGB only) for the matching carpet plot
c1 = p1[0].get_edgecolor()[:3]
c2 = p2[0].get_edgecolor()[:3]
kale.carpet(xx, weights=ww, ax=ax, yave=50, ystd=4, alpha=0.3, color=c1)
kale.carpet(vals, ax=ax, yave=30, ystd=3, alpha=0.3, color=c2)

# cumulative sums of value**power for the resampled (1) and original (2) populations;
# currently only used by the commented-out plot below
power = 1.0
idx = np.argsort(xx)
x1 = xx[idx]
x2 = sorted(vals)
y1 = np.cumsum(ww[idx]*np.power(x1, power))
y2 = np.cumsum(np.power(x2, power))

# tw = ax.twinx()
# # tw.set(yscale='log')
# tw.plot(x1, y1)
# tw.plot(x2, y2)
# print(xx.max(), vals.max())

plt.show()

## Toy 2

In [None]:
def toy_pdf(xx, mm, ss):
    """Log-normal probability density, with the width given in dex.

    Parameters
    ----------
    xx : scalar or array_like
        Positions at which to evaluate the density.
    mm : scalar
        Location in linear space; the Gaussian in ln(x) is centered on ln(mm).
    ss : scalar
        Width in dex (log10 units).

    Returns
    -------
    scalar or ndarray
        Density with respect to `xx`, evaluated at each `xx`.

    """
    ln_width = np.log(10.0 ** ss)     # width in natural-log units
    ln_center = np.log(mm)            # center in natural-log units
    denom = xx * ln_width * np.sqrt(2.0*np.pi)
    # log10(e) * ln(10) is unity, so this is 1 / denom
    prefac = np.log10(np.e) * np.log(10.0) / denom
    exponent = np.square((np.log(xx) - ln_center) / ln_width) / 2.0
    return prefac * np.exp(-exponent)

# Toy test of `kale.sample_outliers` using an analytic PDF (`toy_pdf`) and a bin "mass"
# obtained by integrating that PDF on a finer grid.
# NOTE(review): `zmath` and `zplot` are not imported in the visible notebook.
mm = 1.0e6
ss = 3.0e-1
NUM = int(1e5)
np.random.seed(1234)
xx = zmath.log_normal_base_10(mm, ss, size=NUM)

xedges = zmath.spacing(xx, 'log', 100, log_stretch=0.2)
pdf = toy_pdf(xedges, mm, ss)

# integrate the PDF over each bin with the trapezoid rule on a subdivided grid
hifac = 10
xedges_hi = kale.utils.subdivide(xedges, num=hifac, log=True)
pdf_hi = toy_pdf(xedges_hi, mm, ss)
pmf_hi = 0.5 * np.diff(xedges_hi) * (pdf_hi[:-1] + pdf_hi[1:])
pmf = np.zeros(xedges.size-1)
for ii in range(pmf.size):
    # NOTE(review): assumes `subdivide(..., num=hifac)` yields `hifac+1` fine bins per
    # original bin -- confirm against kalepy
    lo = (hifac+1) * ii
    hi = (hifac+1) * (ii + 1)
    # expected number of draws in this bin
    pmf[ii] = np.sum(pmf_hi[lo:hi]) * NUM
    
hist, _ = np.histogram(xx, bins=xedges)
print(f"{pmf.sum()=:.4e}")

# yscale, ylim = 'lin', [0.0, pdf.max()*1.2]
yscale, ylim = 'log', [pdf.max() * 1e-4, 2.0*pdf.max()]
fig, ax = plot.figax(figsize=[10, 6], yscale=yscale, ylim=ylim)
tw = ax.twinx()

# left axis: analytic PDF and normalized histogram of the direct draws
ax.plot(xedges, pdf, 'k--', alpha=0.5)
ax.hist(xx, bins=xedges, density=True, color='k', histtype='step', alpha=0.5)

# right axis: raw counts of the direct draws
l1 = zplot.plot_hist_line(tw, xedges, hist, lw=2.0, alpha=0.5)
# zplot.plot_hist_line(tw, xedges, pmf, ls=(0, [1, 1]), color='r', lw=2.0, alpha=0.5)
c1 = l1.get_color()

# ---- Sample

thresh = 1000.0
zz, ww = kale.sample_outliers(xedges, pdf, thresh, mass=pmf)
print(ww.size, ww.sum(), zmath.stats_str(ww))

# kw = dict(histtype='step', alpha=0.5, lw=2.0)
# *_, p1 = ax.hist(xx, bins=ee, weights=ww, **kw)
# *_, p2 = ax.hist(vals, bins=edges, **kw)
# c1 = p1[0].get_edgecolor()[:3]
# c2 = p2[0].get_edgecolor()[:3]
# right axis: weighted counts of the `sample_outliers` draws
*_, p2 = tw.hist(zz, bins=xedges, density=False, histtype='step', alpha=0.5, lw=2.0, weights=ww)
c2 = p2[0].get_edgecolor()[:3]


# NOTE(review): `c1` was taken from the direct-draw histogram line but colors the carpet of
# the resampled points `zz`, and `c2` vice versa -- confirm whether the colors are swapped.
kale.carpet(zz, weights=ww, ax=tw, yave=-300, ystd=60, alpha=0.3, color=c1)
kale.carpet(xx, ax=tw, yave=-800, ystd=50, alpha=0.3, color=c2)

# power = 3.0
# idx = np.argsort(zz)
# x1 = zz[idx]
# x2 = sorted(xx)
# y1 = np.cumsum(ww[idx]*np.power(x1, power))
# y2 = np.cumsum(np.power(x2, power))
# tw = ax.twinx()
# # tw.set(yscale='log')
# tw.plot(x1, y1)
# tw.plot(x2, y2)
# err = (y2[-1] - y1[-1]) / y1[-1]
# print(y1[-1], y2[-1], err)

fig.savefig('temp123.png')
plt.show()

# Fraction of GWB From Outliers

In [None]:
# ---- Rebuild the SAM for the outlier-fraction test
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.host_relations.MMBulge_Standard()     # M-MBulge Relation            (MMB)
hard = holo.hardening.Hard_GW
shape = 90
threshold = 10      # sample threshold used when sampling below; also shown in the figure title

sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=shape)

In [None]:
# Frequencies and smooth (no realizations) grid GWB for the outlier-fraction test
# fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)
fobs = utils.nyquist_freqs(10.0*YR, 1.0*YR)
gwb_smooth = sam.gwb(fobs, realize=False)

In [None]:
# Sample the SAM and compute the strain of every sampled binary by hand, instead of going
# through the packaged routines (kept commented out for reference).
# gwf_freqs, gwf, gwb = holo.sam.sampled_gws_from_sam(
#     sam, fobs=fobs, hard=holo.hardening.Hard_GW, cut_below_mass=None, limit_merger_time=None,
#     sample_threshold=10, poisson_inside=True, poisson_outside=True,
# )

vals, weights, edges, dens, mass = holo.sam.sample_sam_with_hardening(
    sam, hard, fobs=fobs, sample_threshold=threshold)
# gff, gwf, gwb = _gws_from_samples(vals, weights, fobs)
mc = utils.chirp_mass(*utils.m1m2_from_mtmr(vals[0], vals[1]))
rz = vals[2]    # redshift
fo = vals[3]    # observed frequency
frst = utils.frst_from_fobs(fo, rz)
dc = cosmo.comoving_distance(rz).cgs.value
# strain of each sample; half of the rest-frame frequency is passed, as in `sample_func`
hs = utils.gw_strain_source(mc, dc, frst/2.0)

In [None]:
# For each frequency bin compute:
#   gwb      : characteristic strain from all samples in the bin
#   num_frac : fraction of the samples in the bin that have weight exactly 1.0
#   gwb_frac : fraction of the bin's strain-squared contributed by those samples
# (weight-1.0 samples are taken here as the individually drawn "outlier" binaries, per
# the section title -- confirm against `sample_sam_with_hardening`)
fbins = fobs.size - 1
gwb = np.zeros(fbins)
gwb_frac = np.zeros(fbins)
num_frac = np.zeros(fbins)

for ii in utils.tqdm(range(fbins)):
    lo, hi = fobs[ii], fobs[ii+1]
    # bin-center frequency over bin width
    ncycles = 0.5 * (lo + hi) / (hi - lo)
    idx = (lo <= fo) & (fo < hi)
    hh = hs[idx]
    ww = weights[idx]

    if ww.size == 0:
        # No samples fell in this bin: the strain stays zero and both fractions are
        # undefined.  (Previously this raised ZeroDivisionError on `0 / idx.size`.)
        num_frac[ii] = np.nan
        gwb_frac[ii] = np.nan
        continue

    # select the weight-1.0 samples *before* the weights are rescaled
    idx = (ww == 1.0)
    ww = ww * ncycles
    num_frac[ii] = np.count_nonzero(idx) / idx.size
    h1 = ww[idx] * (hh[idx]**2)
    gwb[ii] = np.sum(ww * (hh**2))
    gwb_frac[ii] = np.sum(h1) / gwb[ii]

# convert the summed strain-squared to strain
gwb = np.sqrt(gwb)

In [None]:
# Plot the sampled (solid) and smooth (dashed) GWB, with the number- and strain-fractions
# from weight-1.0 samples on a twin axis; save without overwriting earlier figures.
# NOTE(review): `zmath`, `zplot` and `zio` are not imported in the visible notebook.
fig, ax = plot.figax()
xx = zmath.midpoints(fobs) * YR

ax.plot(xx, gwb, 'k-')
ax.plot(xx, gwb_smooth, 'k--')
# fractional error of the sampled GWB in the lowest frequency bin
err = (gwb[0] - gwb_smooth[0]) / gwb_smooth[0]

tw = zplot.twin_axis(ax, color='b')

# xx = fobs
tw.plot(xx, num_frac, 'b-', label='num')
tw.plot(xx, gwb_frac, 'b--', label='pow')

ax.set_title(f"{sam.shape=}, {threshold=}, {err=:+.4e}", fontsize=10)
# NOTE(review): the labelled lines are on `tw`; `plt.legend()` uses the current axes --
# confirm the legend picks them up.
plt.legend()

fname = 'gwb-frac_by-outliers.png'
fname = zio.modify_exists(fname)
fig.savefig(fname)
print(f"saved to {fname=}")
plt.show()

In [None]:
# Illustration of two ways to normalize a cumulative weighted sum of `aa` with weights `bb`
# (weights decrease linearly from 1 to 0).
NUM = 100
aa = np.arange(NUM)
# bb = np.ones_like(aa)
bb = np.linspace(0.0, 1.0, NUM)[::-1]

# solid: normalized by the total weight
cc = np.cumsum(aa*bb)/np.sum(bb)
# dashed: normalized by the cumulative weight so far (a running weighted mean)
dd = np.cumsum(aa*bb)/np.cumsum(bb)

fig, ax = plt.subplots()
ax.plot(aa, cc, 'k-')
ax.plot(aa, dd, 'k--')

plt.show()
