In [None]:
%reload_ext autoreload
%autoreload 2

# Builtin packages
import argparse
from datetime import datetime
from importlib import reload
from pathlib import Path

import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tqdm.notebook as tqdm

import kalepy as kale
# import kalepy.utils
# import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.librarian
from holodeck import cosmo, utils, plot
from holodeck.constants import GYR, YR
# from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG

# Silence annoying numpy errors
# np.seterr(divide='ignore', invalid='ignore', over='ignore')
# warnings.filterwarnings("ignore", category=UserWarning)

# Reuse `holo.log` as the notebook-wide logger and set its level to WARNING;
# later cells pass this same object into the parameter-space constructors.
log = holo.log
log.setLevel(log.WARNING)

In [None]:
# Configuration for a small demonstration parameter space.
NSAMPLES = 10
SAM_SHAPE = (20, 21, 22)
PSPACE_CLASS = holo.librarian.PS_Classic_Phenom_Uniform

pspace = PSPACE_CLASS(holo.log, nsamples=NSAMPLES, sam_shape=SAM_SHAPE)

# Assemble the whole summary first, then emit it in one go: name, shapes, and the
# extrema of every parameter (one indented line per parameter).
summary = [
    f"pspace '{pspace.name}'",
    f"\tlibrary shape={pspace.lib_shape} (samples, dimensions)",
    f"\tSAM grid shape={pspace.sam_shape}",
    "\tparameters:",
]
summary.extend(
    f"\t\t{pp}: {extr}" for pp, extr in zip(pspace.param_names, pspace.extrema)
)
print("\n".join(summary))

In [None]:
# Corner plot of the sampled parameters: scatter-only off-diagonal panels,
# and histogram + carpet (no KDE curve) on the diagonal.
panels_2d = dict(scatter=True, hist=False, contour=False)
panels_1d = dict(carpet=True, hist=True, density=False)
kale.corner(pspace.param_samples.T, labels=pspace.param_names, dist2d=panels_2d, dist1d=panels_1d)

# Manually run test library

## Run the components for a single simulation

In [None]:
# Build the `args` namespace through `holo.librarian._setup_argparse`: the positional
# arguments are given as a list and the remaining options are pre-filled via a Namespace.
# NOTE(review): `comm` is not assigned in any cell visible in this notebook (presumably an
# MPI communicator — TODO confirm); this cell raises NameError until it is defined.
args = holo.librarian._setup_argparse(
    comm,
    # ["PS_Broad_Uniform_02B", "../output/broad-uniform-02b_test"],
    ["PS_Uniform_07A", "../output/uniform-07a_test"],
    namespace=argparse.Namespace(nsamples=10, nreals=100, sam_shape=60, recreate=True),
)
# Attach the notebook logger so downstream calls that read `args.log` can use it.
args.log = log
print(args)

# Look up the parameter-space class by name, then replace the name with an instance of it.
space = getattr(holo.param_spaces, args.param_space)
space = space(args.log, args.nsamples, args.sam_shape, args.seed)
# Seed the global numpy RNG so the same sample index is drawn on every run.
np.random.seed(12345)
pnum = np.random.choice(args.nsamples)
print(f"{pnum=}")


In [None]:
# NOTE(review): everything in this cell is disabled.  `fobs_cents` is consumed by the
# fitting and plotting cells further down, but its only active assignment in this notebook
# is inside the later `PS_Test` cells, so a fresh top-to-bottom run hits a NameError there
# unless the frequency grid below is restored.
# pta_dur = args.pta_dur * YR
# nfreqs = args.nfreqs
# hifr = nfreqs/pta_dur
# pta_cad = 1.0 / (2 * hifr)
# fobs_cents = holo.utils.nyquist_freqs(pta_dur, pta_cad)
# fobs_edges = holo.utils.nyquist_freqs_edges(pta_dur, pta_cad)

# print(space.param_dict(pnum))

# sam, hard = space(pnum)
# gwb = sam.gwb(fobs_edges, realize=args.nreals, hard=hard)
# print(f"{utils.stats(gwb)=}")


In [None]:
# Run the librarian for the single sample index chosen above; the return value is kept
# in `rv` for inspection.
rv = holo.librarian.run_sam_at_pspace_num(args, space, pnum)

In [None]:
# NOTE(review): absolute, machine-specific path with a hard-coded sample number (p000002);
# it is not derived from the `pnum` drawn above — confirm this is the intended file.
fname = ("/Users/lzkelley/Programs/nanograv/holodeck/output/uniform-07a_test/sims/"
         "lib-sams_gwb-ss__p000002.npz")

# `np.load` on an .npz returns an archive object that holds the file open; the context
# manager closes it as soon as the arrays we need have been read into memory.
with np.load(fname) as data:
    print(list(data.keys()))
    bg = data['hc_bg']
    ss = data['hc_ss']
    gwb = data['gwb']
print(bg.shape, ss.shape, gwb.shape)

# Consistency check: summing `hc_ss**2` over its last axis and adding `hc_bg**2` in
# quadrature is compared (via summary statistics) with the stored `gwb`.
test = np.sum(ss**2, axis=-1)
test = np.sqrt(bg**2 + test)
print(test.shape, gwb.shape)
print(utils.stats(gwb))
print(utils.stats(test))

In [None]:
# Fit the spectra with the librarian's two fitting routines, using the librarian's own
# default bin-count lists.
# NOTE(review): `fobs_cents` has no active assignment in the cells above this one
# (only commented-out code); it must be defined before this cell runs.
plaw_nbins, fit_plaw, fit_plaw_med = holo.librarian.fit_spectra_plaw(
    fobs_cents, gwb, holo.librarian.FITS_NBINS_PLAW,
)
turn_nbins, fit_turn, fit_turn_med = holo.librarian.fit_spectra_turn(
    fobs_cents, gwb, holo.librarian.FITS_NBINS_TURN,
)

# Bundle all fit products under the keys passed on to the plotting call.
fit_data = {
    "fit_plaw_nbins": plaw_nbins,
    "fit_plaw": fit_plaw,
    "fit_plaw_med": fit_plaw_med,
    "fit_turn_nbins": turn_nbins,
    "fit_turn": fit_turn,
    "fit_turn_med": fit_turn_med,
}


In [None]:
# Plot the spectra together with the fit results collected in `fit_data`.
fig = holo.librarian.make_gwb_plot(fobs_cents, gwb, fit_data)

## Run all sample points

In [None]:
# Build the run configuration for a small test library.  The helper lives in
# `holo.librarian`; the bare name `_setup_argparse` is never imported into this notebook.
# NOTE(review): the single-simulation cell earlier in this notebook passes an extra leading
# `comm` argument to the same helper — confirm which call matches the installed signature.
args = holo.librarian._setup_argparse(
    ["PS_Uniform_07_GW", "../output/uniform-07_gw_test"],
    namespace=argparse.Namespace(nsamples=10, nreals=4, sam_shape=11),
)
args.log = log
print(args)

# Look up the parameter-space class by name, instantiate it, and save it to the output
# directory before any simulations are run.
space = getattr(holo.param_spaces, args.param_space)
space = space(args.log, args.nsamples, args.sam_shape, args.seed)
space_fname = space.save(args.output)
log.info(f"saved parameter space {space} to {space_fname}")

# Run every sample point in sequence, with a notebook progress bar.
for pnum in tqdm.trange(args.nsamples):
    holo.librarian.run_sam_at_pspace_num(args, space, pnum)

In [None]:
# Switch the shared logger to DEBUG for the combine step; the level stays at DEBUG for
# all later cells because it is never reset.
log.setLevel(log.DEBUG)
# Combine the per-sample simulation files into one library file and keep its path.
lib_fname = holo.librarian.sam_lib_combine(args.output, log, path_sims=args.output_sims)

In [None]:
# Dump the contents of the combined library: every top-level dataset with its shape and
# summary statistics, followed by every file-level attribute.
with h5py.File(lib_fname, 'r') as data:
    print("datasets:")
    for kk in list(data.keys()):
        dset = data[kk]
        print("\t", kk, dset.shape)
        print("\t\t", utils.stats(dset))

    print("attributes:")
    for kk, val in data.attrs.items():
        print("\t", kk, val)

In [None]:
# Upper limit on how many per-sample figures to draw (the library may hold fewer samples).
MAX_PLOTS = 10

# Read everything needed into memory, then release the file before doing any work.
with h5py.File(lib_fname, 'r') as data:
    xx = data['fobs'][()] * YR   # presumably frequencies converted to 1/yr — TODO confirm units
    gwb = data['gwb'][()]
    hc_bg = data['hc_bg'][()]
    hc_ss = data['hc_ss'][()]
print(gwb.shape, hc_bg.shape, hc_ss.shape)

# Rebuild the total from `hc_bg` and `hc_ss` (summed in quadrature over the last axis)
# and report the fractional difference from the stored `gwb`.
hc_gwb = np.sqrt(hc_bg **2 + np.sum(hc_ss**2, axis=-1))
diff = (gwb - hc_gwb)/hc_gwb
print(utils.stats(diff))

# One figure per sample: median (over the last axis) of the stored and the rebuilt spectra.
for ii in range(min(MAX_PLOTS, len(gwb))):
    fig, ax = plot.figax()
    ax.plot(xx, np.median(gwb[ii], axis=-1), label='gwb (stored)')
    ax.plot(xx, np.median(hc_gwb[ii], axis=-1), label='sqrt(hc_bg^2 + sum hc_ss^2)')
    ax.set_title(f"sample {ii}", fontsize=10)
    ax.legend()
    

In [None]:
# NOTE(review): `breaker` is not defined anywhere in the visible notebook, so this call
# raises NameError — presumably a deliberate stop for "Run All"; confirm before removing.
breaker()

# Param_Space Class

In [None]:
class PS_Test(holo.librarian._Param_Space):
    """Minimal one-parameter test space: `gsmf_phi0` is drawn from `PD_Normal(-2.77, 0.3)`."""

    def __init__(self, log, nsamples, sam_shape, seed):
        super().__init__(
            log, nsamples, sam_shape, seed,
            gsmf_phi0=holo.librarian.PD_Normal(-2.77, 0.3),
        )

    def model_for_number(self, num):
        """Build the `(sam, hard)` pair for sample number `num`.

        The sampled `gsmf_phi0` is used to construct the GSMF, which is handed to the
        semi-analytic model so that the sampled parameter actually affects the result.
        `hard` is the `Hard_GW` class itself (not an instance).

        NOTE(review): `self.sam_shape` is not forwarded here, so the model uses its own
        default grid shape — confirm whether that is intended.
        """
        params = self.param_dict(num)
        gsmf = holo.sam.GSMF_Schechter(phi0=params['gsmf_phi0'])
        sam = holo.sam.Semi_Analytic_Model(
            gsmf=gsmf,   # previously constructed but never passed on
            ZERO_DYNAMIC_STALLED_SYSTEMS=False,
            ZERO_GMT_STALLED_SYSTEMS=True,
        )
        hard = holo.hardening.Hard_GW
        return sam, hard
    
# Instantiate the test space (4 samples, SAM shape 10, seed 12345) and build the model
# for the first sample.
test = PS_Test(log, 4, 10, 12345)
# Frequency bin edges from the utility's defaults; bin centers are linear midpoints.
fobs_edges = utils.nyquist_freqs_edges()
fobs_cents = utils.midpoints(fobs_edges, log=False)
sam, hard = test.model_for_number(0)
# 20 realizations of the spectrum, plotted against the bin centers.
gwb = sam.gwb(fobs_edges, hard=hard, realize=20)
plot.plot_gwb(fobs_cents, gwb)
plt.show()

In [None]:
class PS_Test(holo.librarian._Param_Space):
    """One-parameter test space (`gsmf_phi0`) paired with a fixed-time hardening model."""

    def __init__(self, log, nsamples, sam_shape, seed):
        super().__init__(
            log, nsamples, sam_shape, seed,
            gsmf_phi0=holo.librarian.PD_Normal(-2.77, 0.3),
        )

    def model_for_number(self, num):
        """Build the `(sam, hard)` pair for sample number `num`.

        The sampled `gsmf_phi0` is used to construct the GSMF, which is handed to the
        semi-analytic model so that the sampled parameter actually affects the result.
        The SAM grid uses `self.sam_shape`; hardening is `Fixed_Time` built from the SAM
        with a time argument of `0.01*GYR`.
        """
        params = self.param_dict(num)
        gsmf = holo.sam.GSMF_Schechter(phi0=params['gsmf_phi0'])

        sam = holo.sam.Semi_Analytic_Model(
            gsmf=gsmf,   # previously constructed but never passed on
            ZERO_DYNAMIC_STALLED_SYSTEMS=True,
            ZERO_GMT_STALLED_SYSTEMS=False,
            shape=self.sam_shape,
        )
        # hard = holo.hardening.Hard_GW
        hard = holo.hardening.Fixed_Time.from_sam(sam, 0.01*GYR)
        return sam, hard
    
# Same driver as the previous cell, now exercising the redefined `PS_Test`.
test = PS_Test(log, 4, 10, 12345)
# Frequency bin edges from the utility's defaults; bin centers are linear midpoints.
fobs_edges = utils.nyquist_freqs_edges()
fobs_cents = utils.midpoints(fobs_edges, log=False)
sam, hard = test.model_for_number(0)
# 20 realizations of the spectrum, plotted against the bin centers.
gwb = sam.gwb(fobs_edges, hard=hard, realize=20)
plot.plot_gwb(fobs_cents, gwb)
plt.show()

In [None]:
class PS_Test(holo.librarian._Param_Space):
    """Three-parameter space used only to exercise the sampling distributions.

    `par1` is `PD_Uniform(-1.0, 1.0)`, `par2` is `PD_Uniform_Log(10.0, 1000.0)` and
    `par3` is `PD_Normal(-3.0, 0.4)`.  No physical models are attached.
    """

    def __init__(self, log, nsamples, sam_shape, seed):
        distributions = dict(
            par1=holo.librarian.PD_Uniform(-1.0, 1.0),
            par2=holo.librarian.PD_Uniform_Log(10.0, 1000.0),
            par3=holo.librarian.PD_Normal(-3.0, 0.4),
        )
        super().__init__(log, nsamples, sam_shape, seed, **distributions)

    def model_for_number(self, num):
        """Placeholder: returns `(None, None)` instead of real SAM / hardening models."""
        return None, None
    
nsamps = 1000
test = PS_Test(log, nsamps, 50, 12345)

# Spot-check four distinct, randomly chosen samples in both array and dict form.
picks = np.random.choice(nsamps, 4, replace=False)
for ii in picks:
    print(f"\ntest sample {ii:4d} :: {test.params(ii)}  \n    {test.param_dict(ii)}")

# One panel per parameter, each showing the 1D distribution of that parameter's samples.
fig, axes = plt.subplots(figsize=[10, 5], ncols=test.ndims)
for ax, column in zip(axes, test._params.T):
    ax.grid(True, alpha=0.25)
    kale.dist1d(column, ax=ax, density=False)

plt.show()


## Saving a _Param_Space class

In [None]:
# Build a parameter space (100 samples, SAM shape (11, 12, 13), no seed) and save it.
space = holo.param_spaces.PS_Broad_Uniform_02B(log, 100, (11, 12, 13), None)
# The save target is the resolved current working directory, i.e. wherever the kernel runs.
output = Path('.').resolve()
print(output)
fname = space.save(output)

In [None]:
# Reload the space from the file written in the previous cell, for comparison with `space`.
check = holo.param_spaces.PS_Broad_Uniform_02B.from_save(fname, log)

In [None]:
# Attribute-by-attribute round-trip check between the original and the reloaded space.
# Dunder names and the random state are skipped; callables are listed but not compared.
SKIPPED = ['_random_state']
for kk in dir(space):
    if kk.startswith("__") or kk in SKIPPED:
        continue
    v1, v2 = getattr(space, kk), getattr(check, kk)
    print(kk, type(v1), type(v2))
    if not callable(v1):
        test = (v1 == v2)
        same = np.all(test)
        print("\t", same)
        assert same
        print(same)

## Generate new SAMs from existing PSpace

# Param_Dist Classes

## Normal

In [None]:
# Short local alias for the normal parameter distribution class.
PD_Normal = holo.librarian.PD_Normal

## LinLog

In [None]:
# Evaluate the distribution's mapping on a dense grid of inputs in [0, 1] and draw it
# on log-log axes.
test = holo.librarian.PD_Lin_Log(0.01, 100.0, 0.1, 0.5)
xx = np.linspace(0.0, 1.0, 10000)
yy = test(xx)
print(utils.minmax(yy))

fig, ax = plt.subplots()
ax.loglog(xx, yy)
# Mark the stored `_crit` on the output axis and `_lofrac` on the input axis.
for draw_line, level in ((ax.axhline, test._crit), (ax.axvline, test._lofrac)):
    draw_line(level, color='r', ls=':')
plt.show()

### Change the fraction of population below/above cutoff

In [None]:
NUM = int(1e4)
crit = 0.1

BINS = 20
# Linearly spaced bin edges below the cutoff, logarithmically spaced edges above it.
e1 = np.linspace(0.01, crit, BINS, endpoint=False)
e2 = np.logspace(*np.log10([crit, 100.0]), BINS)
edges = np.concatenate([e1, e2])

# Seeded local generator so the statistical assertion below gives the same result on
# every run instead of depending on the global RNG state.
rng = np.random.default_rng(12345)

fig, ax = plot.figax(scale='log')
for frac in [0.2, 0.5, 0.8]:
    test = holo.librarian.PD_Lin_Log(0.01, 100.0, crit, frac)
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    # Fraction of draws that landed below the cutoff; should match the requested `frac`.
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}", 1.0/np.sqrt(NUM))
    # Tolerance is twice the 1/sqrt(N) scale printed above.
    assert np.isclose(frac, obs_frac, atol=2.0/np.sqrt(NUM))

plt.show()

### Change the location of the cutoff

In [None]:
NUM = int(1e4)
frac = 0.5

BINS = 20
edges = np.logspace(*np.log10([0.01, 100.0]), 2*BINS)

# Seeded local generator so the statistical assertion below gives the same result on
# every run instead of depending on the global RNG state.
rng = np.random.default_rng(12345)

fig, ax = plot.figax(scale='log')
for crit in [0.1, 1.0, 10.0]:
    test = holo.librarian.PD_Lin_Log(0.01, 100.0, crit, frac)
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    # Fraction of draws that landed below the cutoff; should match `frac` for every `crit`.
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}", 1.0/np.sqrt(NUM))
    # Tolerance is twice the 1/sqrt(N) scale printed above.
    assert np.isclose(frac, obs_frac, atol=2.0/np.sqrt(NUM))

plt.show()

## LogLin

In [None]:
# Evaluate the distribution's mapping on a dense grid of inputs in [0, 1] and draw it
# on log-log axes.
test = holo.librarian.PD_Log_Lin(0.01, 100.0, 0.1, 0.5)
xx = np.linspace(0.0, 1.0, 10000)
yy = test(xx)
print(utils.minmax(yy))

fig, ax = plt.subplots()
ax.loglog(xx, yy)
# Mark the stored `_crit` on the output axis and `_lofrac` on the input axis.
for draw_line, level in ((ax.axhline, test._crit), (ax.axvline, test._lofrac)):
    draw_line(level, color='r', ls=':')
plt.show()

### Change the fraction of population below/above cutoff

In [None]:
NUM = int(2e4)
crit = 0.1

BINS = 30
edges = np.logspace(*np.log10([0.01, 100.0]), BINS)

# Seeded local generator so the statistical assertion below gives the same result on
# every run instead of depending on the global RNG state.
rng = np.random.default_rng(12345)

fig, ax = plot.figax(scale='log')
for frac in [0.2, 0.5, 0.8]:
    test = holo.librarian.PD_Log_Lin(0.01, 100.0, crit, frac)
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    # Fraction of draws that landed below the cutoff; should match the requested `frac`.
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}", 1.0/np.sqrt(NUM))
    # Tolerance is twice the 1/sqrt(N) scale printed above.
    assert np.isclose(frac, obs_frac, atol=2.0/np.sqrt(NUM))

plt.show()

### Change the location of the cutoff

In [None]:
NUM = int(2e4)
frac = 0.5

BINS = 20

edges = np.logspace(*np.log10([0.01, 100.0]), 2*BINS)

# Seeded local generator so the statistical assertion below gives the same result on
# every run instead of depending on the global RNG state.
rng = np.random.default_rng(12345)

fig, ax = plot.figax(scale='log')
for crit in [0.1, 1.0, 10.0]:
    test = holo.librarian.PD_Log_Lin(0.01, 100.0, crit, frac)
    xx = test(rng.uniform(0.0, 1.0, size=NUM))
    kale.dist1d(xx, ax=ax, edges=edges, density=True, probability=False)
    # Fraction of draws that landed below the cutoff; should match `frac` for every `crit`.
    obs_frac = np.count_nonzero(xx < crit) / xx.size
    print(f"target:{frac:.2f}, result:{obs_frac:.4f}", 1.0/np.sqrt(NUM))
    # Tolerance is twice the 1/sqrt(N) scale printed above.
    assert np.isclose(frac, obs_frac, atol=2.0/np.sqrt(NUM))

plt.show()

## Piecewise Uniform

In [None]:
# Piecewise-uniform distribution over three intervals, weighted [1.0, 2.0, 1.0].
edges = [-1.0, 5.0, 6.0, 7.0]
test = holodeck.librarian.PD_Piecewise_Uniform_Mass(edges, [1.0, 2.0, 1.0])

# Push sorted uniform draws through the distribution and histogram the outputs using
# the interval edges themselves as bins.
xx = np.sort(np.random.uniform(size=1000))
yy = test(xx)
print(utils.minmax(yy))
fig, ax = plt.subplots()
x, y, _ = ax.hist(yy, histtype='step', density=True, bins=edges)
plt.show()

In [None]:
# Same intervals and weights as the previous cell, using the `_Density` variant.
edges = [-1.0, 5.0, 6.0, 7.0]
test = holodeck.librarian.PD_Piecewise_Uniform_Density(edges, [1.0, 2.0, 1.0])

# Push sorted uniform draws through the distribution and histogram the outputs using
# the interval edges themselves as bins.
xx = np.sort(np.random.uniform(size=1000))
yy = test(xx)
print(utils.minmax(yy))
fig, ax = plt.subplots()
x, y, _ = ax.hist(yy, histtype='step', density=True, bins=edges)
plt.show()

In [None]:
# Four-interval piecewise-uniform density, histogrammed with fine bins.
edges = [0.1, 1.0, 2.0, 9.0, 11.0]
test = holodeck.librarian.PD_Piecewise_Uniform_Density(edges, [2.5, 1.5, 1.0, 1.5])

yy = test(np.sort(np.random.uniform(size=2000)))
print(utils.minmax(yy))

scale = 'linear'   # alternative: 'log'
ax = plt.gca()
ax.set(xscale=scale)
# 20 bin edges spanning the range of `edges`, spaced according to `scale`.
xx = kale.utils.spacing(edges, scale, num=20)
ax.hist(yy, histtype='step', density=True, bins=xx)
# tw = ax.twinx()
# tw.hist(yy, histtype='step', density=True, bins=30)
plt.show()

In [None]:
# Candidate piecewise densities for individual model parameters; only one is active.
# test = holodeck.librarian.PD_Piecewise_Uniform_Density([-3.5, -3.0, -2.0, -1.5], [2.0, 1.0, 2.0])   # gsmf_phi0
# test = holodeck.librarian.PD_Piecewise_Uniform_Density([10.5, 11.0, 12.0, 12.5], [2.0, 1.0, 2.0])   # gsmf_mchar0_log10
test = holodeck.librarian.PD_Piecewise_Uniform_Density([7.5, 8.0, 9.0, 9.5], [1.5, 1.0, 2.0])   # mmb_mamp_log10

yy = test(np.sort(np.random.uniform(size=2000)))
print(utils.minmax(yy))

scale = 'linear'   # alternative: 'log'
ax = plt.gca()
ax.set(xscale=scale)
# 20 bin edges spanning the range of the sampled values, spaced according to `scale`.
xx = kale.utils.spacing(yy, scale, num=20)
print(xx)
ax.hist(yy, histtype='step', density=True, bins=xx)
# tw = ax.twinx()
# tw.hist(yy, histtype='step', density=True, bins=30)
plt.show()

# Fit Spectra

In [None]:
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
lib_path = (
    "/Users/lzkelley/Programs/nanograv/15yr_astro_libraries/"
    # "uniform-05a_2023-05-02_n1000_r100_f40"
    # "uniform-07a_new_n500_r100_f40"
    "uniform-07a_new_n500_r100_f40"
    "/sam_lib.hdf5"
)
lib_path = Path(lib_path)
assert lib_path.exists(), f"library file not found: {lib_path}"
print(lib_path.parent)

# The file handle is deliberately left open: a later cell reads `library['gwb']` again.
library = h5py.File(lib_path, 'r')
print(list(library.keys()))
gwb = library['gwb'][()]
fobs = library['fobs'][()]
print(f"{gwb.shape=} {utils.stats(gwb)=}")
params = library['sample_params'][()]
param_names = library.attrs['param_names'].astype('str')
print(param_names)
# Report each parameter's sampled range.  `sample_params` is assumed to be laid out as
# (samples, parameters), matching the "(samples, dimensions)" library shape printed near
# the top of this notebook; `params[ii]` would instead pick out a single sample's full
# parameter vector.  TODO confirm the layout against the library writer.
for ii, name in enumerate(param_names):
    print(f"{ii=}, {name=}, {params[:, ii].shape=}, {utils.minmax(params[:, ii])=}")

In [None]:
nsamps, nfreqs, nreals = gwb.shape
print(f"{nsamps=} {nfreqs=} {nreals=}")

# Hard-coded sample / realization indices; the random draw below is kept for reference.
# seed = np.random.randint(0, 999999)
# seed = 419587
# print(f"{seed=}")
# np.random.seed(seed)
# ss = np.random.choice(nsamps)
# rr = np.random.choice(nreals)
ss, rr = 129, 51
print(f"{ss=}, {rr=}")

xx = fobs
hc = gwb[ss, :, rr]

# All realizations of the chosen sample, with the single chosen realization overlaid.
fig, ax = plot.figax()
plot.draw_gwb(ax, xx, gwb[ss])
ax.plot(xx, hc, 'k-', alpha=0.5)

plt.show()


In [None]:
# Convert the selected realization into the two other representations.
rho = utils.char_strain_to_rho(fobs, hc, 1/fobs[0])
psd = utils.char_strain_to_psd(fobs, hc)

fig, axes = plot.figax(figsize=[12, 4], ncols=3, xscale='log')
values = [hc, rho, psd]
names = ["hc", "rho", "psd"]

# One panel per quantity.  Both fitters are applied to every panel's quantity (not only
# the PSD), each with a reference frequency of 1/YR, and overplotted as dashed curves.
for ax, yy, title in zip(axes, values, names):
    ax.plot(xx, yy, alpha=0.75)
    ax.set_title(title, fontsize=10)

    for fitter in (utils.fit_powerlaw_psd, utils.fit_turnover_psd):
        fits, func = fitter(xx, yy, 1.0/YR)
        zz = func(xx, *fits)
        ax.plot(xx, zz, ls='--', alpha=0.5)

plt.show()

In [None]:
# Convert the whole library to PSD; `fobs` is given new leading and trailing axes so it
# broadcasts along the middle axis of `gwb`.
psd = utils.char_strain_to_psd(fobs[np.newaxis, :, np.newaxis], library['gwb'][()])
# NOTE(review): the single-simulation section of this notebook calls `fit_spectra_plaw`
# with three arguments and unpacks three return values, whereas this call passes two
# arguments and unpacks two — confirm which form matches the installed `holo.librarian`.
nbins_list, fits_plaw = holo.librarian.fit_spectra_plaw(fobs, psd)

In [None]:
# Interactive IPython help lookup (shows the docstring); has no effect on the analysis.
kale.plot.dist2d?

In [None]:
# Median of the fits over axis 1 (presumably realizations — TODO confirm axis meaning).
fits_plaw_med = np.median(fits_plaw, axis=1)

# One panel per entry of `nbins_list`, rather than a hard-coded count of 5, so the cell
# keeps working if the list of bin counts changes.  `squeeze=False` keeps `axes` 2D even
# for a single panel.
num_panels = len(nbins_list)
fig, axes = plt.subplots(figsize=[12, 4], ncols=num_panels, squeeze=False)
for ii, ax in enumerate(axes[0]):
    # Fit parameters for this bin count, transposed so each row is one parameter.
    temp = fits_plaw_med[:, ii, :].T
    ax.set_title(f"nbins={nbins_list[ii]}", fontsize=10)
    kale.plot.dist2d(temp, ax=ax)

plt.show()

In [None]:
# Median of the fits over axis 1 (presumably realizations — TODO confirm axis meaning).
fits_plaw_med = np.median(fits_plaw, axis=1)

fig, ax = plt.subplots(figsize=[6, 6])
# Dashed reference lines at y = -13/3 and x = -15 (presumably the expected spectral index
# and a reference log-amplitude — TODO confirm).
ax.axhline(-13/3, ls='--', color='0.5')
ax.axvline(-15, ls='--', color='0.5')
# Overlay contours for every bin-count entry along axis 1, instead of assuming exactly 5.
for ii in range(fits_plaw_med.shape[1]):
    temp = fits_plaw_med[:, ii, :].T
    kale.plot.dist2d(temp, ax=ax, hist=False, scatter=False, sigmas=[1, 2, 3], median=False)

plt.show()