In [None]:
# %load ../init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence noisy numpy floating-point warnings (divide / invalid / overflow) and all
# UserWarnings for the whole kernel session.
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `mpl.style.use('default')` resets rcParams back to matplotlib's defaults, so it has
#       to run *before* the customizations below; placed after them it silently discards
#       the font / line / mathtext settings.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): the family is 'serif' but the font list is given for 'sans-serif', so
#               'Times' is presumably not what gets selected -- confirm the intent.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Use holodeck's logger, at INFO verbosity
log = holo.log
log.setLevel(logging.INFO)

# Reproduce Error

In [None]:
# Instantiate the four component models (all with their default parameters) ...
gsmf, gpf, gmt, mmbulge = (
    holo.sam.GSMF_Schechter(),     # Galaxy Stellar-Mass Function (GSMF)
    holo.sam.GPF_Power_Law(),      # Galaxy Pair Fraction         (GPF)
    holo.sam.GMT_Power_Law(),      # Galaxy Merger Time           (GMT)
    holo.sam.MMBulge_Simple(),     # M-MBulge Relation            (MMB)
)

# ... and combine them into the semi-analytic model used by the rest of the notebook.
sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge)

In [None]:
# Observed GW-frequency bins from Nyquist sampling, restricted to a single narrow band.
# Other bands that were tried:
#     (1.0/YR, 2.0/YR)
#     (2.0/YR, 3.0/YR)
band_lo, band_hi = 4.0/YR, 5.0/YR

# presumably the arguments are (observing duration, cadence) -- TODO confirm
fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)
idx = np.logical_and(fobs > band_lo, fobs < band_hi)
# Number of frequency bins that fall strictly inside the band
print(np.count_nonzero(idx))
fobs = fobs[idx]

# Hardening model handed to the SAM; note the class itself is passed, not an instance
hard = holo.evolution.Hard_GW

In [None]:
# Sample binaries from the SAM (with the `hard` hardening model) at the selected `fobs`.
# Unpacked here as: sampled values, per-sample weights, grid edges, and `dens`
# (presumably the differential number density on that grid -- TODO confirm).
vals, weights, edges, dens = holo.sam.sample_sam_with_hardening(
        sam, hard, fobs=fobs, cut_below_mass=3e7, limit_merger_time=4*GYR, 
        # Notes from earlier runs at other thresholds
        # (runtime ==> presumably sample counts -- TODO confirm)   [frequency range used]
        # sample_threshold=1.0e2,   # 2.5s ==> 4.1e6, 3.5e9   [fobs < 1/YR]
        # sample_threshold=1.0e3,   # 22.8s ==> 4.0e7, 3.5e9   [fobs < 1/YR]
        # sample_threshold=1.0e4,   # 257s ==> 4.0e8, 3.5e9   [fobs < 1/YR]
        # NOTE(review): the timing note on the next line repeats the 1.0e4 numbers and may
        #               be stale for `sample_threshold=None` -- confirm.
        sample_threshold=None,   # 257s ==> 4.0e8, 3.5e9   [fobs < 0.5/YR]
        interpolate=False
)

In [None]:
# Unpack the four sampled coordinates; only `fo` is used by the visible cells.
# presumably (total mass, mass ratio, redshift, observed frequency) -- TODO confirm
mt, mr, rz, fo = vals
# Report the number of samples, the number of weights, and the summed weight
print(f"{fo.size=:.4e}, {weights.size=:.4e}, {np.sum(weights)=:.4e}")

In [None]:
# None
# (label: presumably the `sample_threshold` in effect when this cell was last run -- TODO confirm)
# NOTE: `plot_error` is defined in a later cell of this notebook, so this cell fails on a
#       fresh kernel unless that cell is run first.
# Unweighted histogram of the sampled frequencies against the `fobs` bin edges
fig = plot_error(fo, fobs)
# fig = plot_error(fo, fobs, weights=weights)
plt.show()

In [None]:
# 1e4
# (label: presumably the `sample_threshold` in effect when this cell was last run -- TODO confirm)
# NOTE: `plot_error` is defined in a later cell of this notebook, so this cell fails on a
#       fresh kernel unless that cell is run first.
# Two figures: unweighted, then weighted by the sample weights
fig = plot_error(fo, fobs)
fig = plot_error(fo, fobs, weights=weights)
plt.show()

In [None]:
# 1e3
# (label: presumably the `sample_threshold` in effect when this cell was last run -- TODO confirm)
# NOTE: `plot_error` is defined in a later cell of this notebook, so this cell fails on a
#       fresh kernel unless that cell is run first.
# Two figures: unweighted, then weighted by the sample weights
fig = plot_error(fo, fobs)
fig = plot_error(fo, fobs, weights=weights)
plt.show()

In [None]:
# 1e3
# NOTE(review): same label and same code as the preceding cell; presumably kept for the
#               output of a separate run -- confirm, or fix the label.
# NOTE: `plot_error` is defined in a later cell of this notebook, so this cell fails on a
#       fresh kernel unless that cell is run first.
# Two figures: unweighted, then weighted by the sample weights
fig = plot_error(fo, fobs)
fig = plot_error(fo, fobs, weights=weights)
plt.show()

In [None]:
# 1e3
# NOTE(review): same label and same code as the two preceding cells; presumably kept for
#               the output of a separate run -- confirm, or fix the label.
# NOTE: `plot_error` is defined in a later cell of this notebook, so this cell fails on a
#       fresh kernel unless that cell is run first.
# Two figures: unweighted, then weighted by the sample weights
fig = plot_error(fo, fobs)
fig = plot_error(fo, fobs, weights=weights)
plt.show()

In [None]:
def double_bins(xx):
    """Refine a set of bin edges by adding the midpoint between each adjacent pair.

    Parameters
    ----------
    xx : array_like
        Bin edges.

    Returns
    -------
    ndarray
        The input edges together with the values from `kale.utils.midpoints(xx)`,
        de-duplicated and in ascending order.

    """
    mids = kale.utils.midpoints(xx)
    # `np.unique` returns its result already sorted, so no separate sort is required
    return np.unique(np.concatenate([xx, mids]))

def plot_error(fo, bins, weights=None, **kw):
    """Histogram sampled frequencies at three successively finer binnings on one axis.

    The first histogram uses `bins` as given; the next two use the edges refined once and
    twice by `double_bins`.  All three are drawn on the same log-x axis so that structure
    finer than the original bins shows up.

    Parameters
    ----------
    fo : array_like
        Sampled frequencies.  Multiplied by `YR` before plotting.
    bins : array_like
        Bin edges, scaled by `YR` in the same way as `fo`.
    weights : array_like or None
        Optional per-sample weights passed to `Axes.hist`.
    **kw
        Extra keyword arguments forwarded to every `Axes.hist` call; these override the
        defaults (`rwidth=0.9`, `histtype='step'`, `density=True`).

    Returns
    -------
    fig
        The figure object returned by `plot.figax`.

    """
    fig, ax = plot.figax(figsize=[8, 2], xscale='log') # , xlim=[0.9*fo.min(), fo.min()*20])
    plt.tight_layout()

    # Defaults first, then caller overrides.  (Previously `kw` was overwritten here, so any
    # keyword arguments given by the caller were silently dropped.)
    hist_kw = dict(weights=weights, rwidth=0.9, histtype='step', density=True)
    hist_kw.update(kw)

    data = fo*YR
    xx = bins*YR
    ax.hist(data, bins=xx, **hist_kw)

    # Two levels of refinement: each pass inserts midpoints between the current edges
    for _ in range(2):
        xx = double_bins(xx)
        ax.hist(data, bins=xx, **hist_kw)

    return fig

# Unweighted histogram of the sampled frequencies against the `fobs` bin edges
fig = plot_error(fo, fobs)
# fig = plot_error(fo, fobs, weights=weights)
plt.show()

In [None]:
# Weighted vs. unweighted distribution of the sampled frequencies, 200 bins each
fig, ax = plot.figax()
step_kw = dict(bins=200, rwidth=0.9, histtype='step')

# Primary axis: histogram weighted by the sample weights
ax.hist(fo, weights=weights, **step_kw)

# Twin axis (log-scaled y): raw sample counts, dashed red
tw = ax.twinx()
tw.set_yscale('log')
tw.hist(fo, color='r', ls='--', **step_kw)

plt.show()

# Reproduce with kalepy directly in 2D

In [None]:
# Threshold passed to `kale.sample_outliers` below
sample_threshold = 5.0e4
# Grid edges and the differential number `dnum` from the SAM at the selected `fobs`
edges, dnum = sam.number_from_hardening(hard, fobs=fobs)
# Edges used for sampling: log10 of the first dimension, the second and third unchanged,
# and the *natural* log of the last.  The two different bases are inverted consistently
# after sampling (`10.0 **` and `np.e **` below).
log_edges = [np.log10(edges[0]), edges[1], edges[2], np.log(edges[3])]

# integrate each bin to convert from probability-density to mass
# NOTE: _integrate_differential_number() has log-vs-lin spacings hardcoded! use `edges` as is
mass = holo.sam._integrate_differential_number(edges, dnum, freq=True)
# sample binaries from distribution, using appropriate spacing as needed
# BUG: should the density used for proportional sampling `dnum` be log(density) ?!
# vals, weights = kale.sample_outliers(edges, dnum, sample_threshold, mass=mass)

# NOTE: this overwrites `vals`, `weights` and `edges` from the earlier sampling cell
vals, weights = kale.sample_outliers(log_edges, dnum, sample_threshold, mass=mass)
# Convert the sampled values back from log-space (base 10 for the first dimension, base e for the last)
vals[0] = 10.0 ** vals[0]
vals[3] = np.e ** vals[3]

In [None]:
# In units of 1e9: total number of elements in `vals`, and number of samples with weight exactly 1.0
# NOTE(review): `vals.size` counts every element of `vals` (all coordinate rows), not the
#               number of samples; `weights.size` may be what was intended -- confirm.
vals.size/1e9, np.count_nonzero(weights == 1.0)/1e9

In [None]:
# Shape of the per-bin integrated array returned by `_integrate_differential_number`
mass.shape

In [None]:
# Weighted histogram of the last sampled coordinate (already converted back from log-space)
# against the `fobs` bin edges
fig = plot_error(vals[3], fobs, weights=weights)
plt.show()

In [None]:
# Reduce the problem to the last dimension only and sample it in 1D.
thresh = 10.0e6
# Per-bin totals along the last axis: `mass` summed over the first three axes
test = mass.sum(axis=(0, 1, 2))
# Integrate `dnum` over the first three dimensions one at a time (log10 spacing for the
# first, the raw edges for the other two), then sum what remains of those axes.
# presumably `trapz(..., cumsum=False)` keeps the integrated axis as per-interval values,
# which is why a sum over axes 0-2 follows -- TODO confirm
dd = holo.utils.trapz(dnum, np.log10(edges[0]), axis=0, cumsum=False)
dd = holo.utils.trapz(dd, edges[1], axis=1, cumsum=False)
dd = holo.utils.trapz(dd, edges[2], axis=2, cumsum=False)
dd = dd.sum(axis=(0, 1, 2))

# Shapes of the two 1D arrays, and the total of `test` in units of 1e9
print(test.shape, dd.shape, test.sum()/1e9)

# 1D sampling over the natural-log edges of the last dimension
vv, ww = kale.sample_outliers([log_edges[-1]], dd, thresh, mass=test)


In [None]:
# Shapes of the 1D samples and weights, and the number of samples with weight exactly 1.0
vv.shape, ww.shape, np.count_nonzero(ww == 1.0)

In [None]:
# Natural-log edges of the last dimension, used as the coarse histogram bins
le = log_edges[-1]
print(le.shape, dd.shape, test.shape)
# Linear x-axis because the plotted values are already logarithms
fig, ax = plot.figax(xscale='lin', xlim=[-16.5, -15.5])
# ax.plot(le, dd)
# kale.plot.draw_hist1d(ax, le, test)

# Shared settings for both histograms: dashed lines, weighted by the 1D sample weights
kw = dict(
    # ax=ax, ls='--', 
    ax=ax, ls='--', weights=ww, probability=True, density=True,
)

# Histogram of the 1D samples on the original edges ...
h1a, e1a, _ = kale.plot.hist1d(vv[0], edges=le, **kw)
# ... and on edges refined once with `double_bins` (defined in an earlier cell)
oo = double_bins(le)
h2a, e2a, _ = kale.plot.hist1d(vv[0], edges=oo, **kw)

plt.show()