In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from datetime import datetime
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

# Param_Space Class

In [None]:
class PS_Test(holo.librarian._Param_Space):
    """One-parameter space varying the Schechter GSMF `phi0` (`gsmf_phi0`).

    Models built by `model_for_number` use `holo.hardening.Hard_GW` as the hardening model.
    """

    def __init__(self, log, nsamples, sam_shape, seed):
        super().__init__(
            log, nsamples, sam_shape, seed,
            gsmf_phi0=holo.librarian.PD_Normal(-2.77, 0.3),
        )
        return

    def model_for_number(self, num):
        """Build the `(sam, hard)` pair for parameter sample number `num`."""
        params = self.param_dict(num)
        gsmf = holo.sam.GSMF_Schechter(phi0=params['gsmf_phi0'])
        # Pass the sampled GSMF and the requested grid shape to the SAM. Without them the
        # sampled `gsmf_phi0` has no effect and the `sam_shape` given to `__init__` is ignored.
        sam = holo.sam.Semi_Analytic_Model(
            gsmf=gsmf,
            shape=self.sam_shape,
            ZERO_DYNAMIC_STALLED_SYSTEMS=False,
            ZERO_GMT_STALLED_SYSTEMS=True,
        )
        hard = holo.hardening.Hard_GW
        return sam, hard

test = PS_Test(log, 4, 10, 12345)
fobs_edges = utils.nyquist_freqs_edges()
fobs_cents = utils.midpoints(fobs_edges, log=False)
sam, hard = test.model_for_number(0)
gwb = sam.gwb(fobs_edges, hard=hard, realize=20)
plot.plot_gwb(fobs_cents, gwb)
plt.show()

In [None]:
class PS_Test(holo.librarian._Param_Space):
    """One-parameter space varying the Schechter GSMF `phi0` (`gsmf_phi0`).

    Models built by `model_for_number` use `holo.hardening.Fixed_Time` hardening.
    """

    def __init__(self, log, nsamples, sam_shape, seed):
        super().__init__(
            log, nsamples, sam_shape, seed,
            gsmf_phi0=holo.librarian.PD_Normal(-2.77, 0.3),
        )
        return

    def model_for_number(self, num):
        """Build the `(sam, hard)` pair for parameter sample number `num`."""
        params = self.param_dict(num)
        gsmf = holo.sam.GSMF_Schechter(phi0=params['gsmf_phi0'])

        # Pass the sampled GSMF to the SAM; otherwise `gsmf_phi0` has no effect on the model.
        sam = holo.sam.Semi_Analytic_Model(
            gsmf=gsmf,
            ZERO_DYNAMIC_STALLED_SYSTEMS=True,
            ZERO_GMT_STALLED_SYSTEMS=False,
            shape=self.sam_shape,
        )
        # Fixed-time hardening constructed from this SAM with a 0.01 Gyr time argument.
        # (The previous cell uses `holo.hardening.Hard_GW` instead.)
        hard = holo.hardening.Fixed_Time.from_sam(sam, 0.01*GYR)
        return sam, hard

test = PS_Test(log, 4, 10, 12345)
fobs_edges = utils.nyquist_freqs_edges()
fobs_cents = utils.midpoints(fobs_edges, log=False)
sam, hard = test.model_for_number(0)
gwb = sam.gwb(fobs_edges, hard=hard, realize=20)
plot.plot_gwb(fobs_cents, gwb)
plt.show()

In [None]:
class PS_Test(holo.librarian._Param_Space):
    """Three-parameter space used only to inspect how the `PD_*` distributions are sampled.

    `model_for_number` is a placeholder and does not build any model.
    """

    def __init__(self, log, nsamples, sam_shape, seed):
        super().__init__(
            log, nsamples, sam_shape, seed,
            par1=holo.librarian.PD_Uniform(-1.0, 1.0),
            par2=holo.librarian.PD_Uniform_Log(10.0, 1000.0),
            par3=holo.librarian.PD_Normal(-3.0, 0.4),
        )
        return

    def model_for_number(self, num):
        # Placeholder: this space is only used to look at the sampled parameter values.
        sam = None
        hard = None
        return sam, hard

nsamps = 1000
test = PS_Test(log, nsamps, 50, 12345)

# Print a few randomly chosen samples; a seeded generator makes reruns show the same ones.
rng = np.random.default_rng(12345)
for ii in rng.choice(nsamps, 4, replace=False):
    print(f"\ntest sample {ii:4d} :: {test.params(ii)}  \n    {test.param_dict(ii)}")

# Histogram each parameter dimension. `squeeze=False` keeps `axes` a 2D array even when
# `test.ndims == 1`, so iterating over `axes[0]` works for any number of parameters.
fig, axes = plt.subplots(figsize=[10, 5], ncols=test.ndims, squeeze=False)
for ii, ax in enumerate(axes[0]):
    ax.grid(True, alpha=0.25)
    kale.dist1d(test._params[:, ii], ax=ax, density=False)

plt.show()


## Saving and reloading a `_Param_Space`

In [None]:
# Output goes to the directory the kernel is running in; report where that is.
output = Path('.').resolve()
print(output)

# Build a small instance of a library parameter space and write it to `output`.
# `fname` (the value returned by `save`) is what the next cell reloads.
# NOTE(review): positional args presumably mirror PS_Test's (log, nsamples, sam_shape, seed) -- confirm.
space = holo.param_spaces.PS_Broad_Uniform_02B(log, 100, (11, 12, 13), None)
fname = space.save(output)

In [None]:
# Re-create the parameter space from the file written by `space.save(...)` above, so the
# following cell can compare its attributes against the original `space`.
check = holo.param_spaces.PS_Broad_Uniform_02B.from_save(fname, log)

In [None]:
# Compare every non-dunder attribute of the original (`space`) and reloaded (`check`) spaces.
# Callables (methods) are only listed; data attributes must compare equal.
# NOTE(review): `_random_state` is skipped, presumably because save/load does not
# reproduce it exactly -- confirm.
mismatched = []
for kk in dir(space):
    if kk.startswith("__") or kk in ['_random_state']:
        continue
    v1 = getattr(space, kk)
    v2 = getattr(check, kk)
    print(kk, type(v1), type(v2))
    if callable(v1):
        continue
    # `np.all` collapses element-wise array comparisons as well as plain booleans.
    same = np.all(v1 == v2)
    print("\t", same)
    if not same:
        mismatched.append(kk)

# Report every differing attribute at once instead of stopping at the first one.
assert not mismatched, f"attributes differ after save/load: {mismatched}"

# Param_Dist Classes

## Normal

In [None]:
# Short alias for the `holo.librarian.PD_Normal` parameter-distribution class.
# TODO: this section has no demonstration yet (compare the PD_Lin_Log / PD_Log_Lin cells below).
PD_Normal = holo.librarian.PD_Normal

## LinLog (`PD_Lin_Log`)

In [None]:
# Evaluate the distribution on a dense grid over [0, 1] and show the mapping on log-log axes.
test = holo.librarian.PD_Lin_Log(0.01, 100.0, 0.1, 0.5)
xx = np.linspace(0.0, 1.0, 10000)
yy = test(xx)
print(utils.minmax(yy))

fig, ax = plt.subplots()
ax.loglog(xx, yy)
# Red dotted guides: horizontal at the cutoff value `_crit`, vertical at `_lofrac`.
ax.axhline(test._crit, color='r', ls=':')
ax.axvline(test._lofrac, color='r', ls=':')
plt.show()

### Change the fraction of the population below/above the cutoff

In [None]:
NUM = int(1e4)
crit = 0.1
# Seeded generator so the histogram and the assertion below are reproducible.
rng = np.random.default_rng(12345)
# Allowed |target - observed| for the fraction below `crit`. The binomial standard error
# is at most 0.5/sqrt(NUM), so this is at least a 4-sigma bound.
tol = 2.0 / np.sqrt(NUM)

BINS = 20
# Bin edges: linearly spaced below `crit`, log-spaced from `crit` up to 100.
e1 = np.linspace(0.01, crit, BINS, endpoint=False)
e2 = np.logspace(*np.log10([crit, 100.0]), BINS)
edges = np.concatenate([e1, e2])

fig, ax = plot.figax(scale='log')
for frac in [0.2, 0.5, 0.8]:
    test = holo.librarian.PD_Lin_Log(0.01, 100.0, crit, frac)
    # Transform uniform draws on [0, 1) into values from the distribution.
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}, tolerance:{tol:.4f}")
    assert np.isclose(frac, obs_frac, atol=tol)

plt.show()

### Change the location of the cutoff

In [None]:
NUM = int(1e4)
frac = 0.5
# Seeded generator so the histogram and the assertion below are reproducible.
rng = np.random.default_rng(12345)
# Allowed |target - observed| for the fraction below `crit`. The binomial standard error
# is at most 0.5/sqrt(NUM), so this is at least a 4-sigma bound.
tol = 2.0 / np.sqrt(NUM)

BINS = 20
edges = np.logspace(*np.log10([0.01, 100.0]), 2*BINS)

fig, ax = plot.figax(scale='log')
for crit in [0.1, 1.0, 10.0]:
    test = holo.librarian.PD_Lin_Log(0.01, 100.0, crit, frac)
    # Transform uniform draws on [0, 1) into values from the distribution.
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}, tolerance:{tol:.4f}")
    assert np.isclose(frac, obs_frac, atol=tol)

plt.show()

## LogLin (`PD_Log_Lin`)

In [None]:
# Evaluate the distribution on a dense grid over [0, 1] and show the mapping on log-log axes.
test = holo.librarian.PD_Log_Lin(0.01, 100.0, 0.1, 0.5)
xx = np.linspace(0.0, 1.0, 10000)
yy = test(xx)
print(utils.minmax(yy))

fig, ax = plt.subplots()
ax.loglog(xx, yy)
# Red dotted guides: horizontal at the cutoff value `_crit`, vertical at `_lofrac`.
ax.axhline(test._crit, color='r', ls=':')
ax.axvline(test._lofrac, color='r', ls=':')
plt.show()

### Change the fraction of the population below/above the cutoff

In [None]:
NUM = int(2e4)
crit = 0.1
# Seeded generator so the histogram and the assertion below are reproducible.
rng = np.random.default_rng(12345)
# Allowed |target - observed| for the fraction below `crit`. The binomial standard error
# is at most 0.5/sqrt(NUM), so this is at least a 4-sigma bound.
tol = 2.0 / np.sqrt(NUM)

BINS = 30
edges = np.logspace(*np.log10([0.01, 100.0]), BINS)

fig, ax = plot.figax(scale='log')
for frac in [0.2, 0.5, 0.8]:
    test = holo.librarian.PD_Log_Lin(0.01, 100.0, crit, frac)
    # Transform uniform draws on [0, 1) into values from the distribution.
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}, tolerance:{tol:.4f}")
    assert np.isclose(frac, obs_frac, atol=tol)

plt.show()

### Change the location of the cutoff

In [None]:
NUM = int(2e4)
frac = 0.5
# Seeded generator so the histogram and the assertion below are reproducible.
rng = np.random.default_rng(12345)
# Allowed |target - observed| for the fraction below `crit`. The binomial standard error
# is at most 0.5/sqrt(NUM), so this is at least a 4-sigma bound.
tol = 2.0 / np.sqrt(NUM)

BINS = 20

edges = np.logspace(*np.log10([0.01, 100.0]), 2*BINS)

fig, ax = plot.figax(scale='log')
for crit in [0.1, 1.0, 10.0]:
    test = holo.librarian.PD_Log_Lin(0.01, 100.0, crit, frac)
    # Transform uniform draws on [0, 1) into values from the distribution.
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}, tolerance:{tol:.4f}")
    assert np.isclose(frac, obs_frac, atol=tol)

plt.show()