In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence noisy numpy floating-point warnings (divide-by-zero, invalid value, overflow)
# and all `UserWarning`s for the whole session.
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings.
# `mpl.style.use` resets rcParams to the style's values, so it must run BEFORE the
# customizations below; applied afterwards it would silently discard them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'Times' is listed under 'sans-serif' while the family is 'serif', so the
# serif family does not pick up Times -- confirm whether {'serif': ['Times']} was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Use holodeck's package logger in this notebook and show INFO-level messages.
log = holo.log
log.setLevel(logging.INFO)

In [None]:
# Locate the per-parameter-point output files of a SAM library run.
LIB_NAME = "test_2022-11-15"   # run directory inside holodeck's output path
regex = "lib_sams__p*.npz"     # NOTE: a glob pattern (not a regex); '*' is the zero-padded point number
PATH = Path(holo._PATH_OUTPUT).resolve().joinpath(LIB_NAME)
files = sorted(PATH.glob(regex))
num_files = len(files)
print(PATH, f"\n\texists={PATH.exists()}", f"\n\tfound {num_files} files")

# Check that the files are numbered contiguously from zero (p000000, p000001, ...), because
# the assembly step maps the ii-th sorted file onto parameter-space number ii.
expected_files = (PATH.joinpath(regex.replace('*', f"{ii:06d}")) for ii in range(num_files))
# Require at least one file: with none found, `all(...)` would be vacuously True.
all_exist = (num_files > 0) and all(fname.exists() for fname in expected_files)

print(f"All files exist?  {all_exist}")

In [None]:
class Parameter_Space:

    def __init__(
        self,
        # gsmf_phi0=[-3.35, -2.23, 7],
        # gsmf_phi0=[-3.61, -1.93, 7],
        times=[1e-2, 10.0, 7],   # [Gyr]
        # gsmf_alpha0=[-1.56, -0.92, 5],
        # mmb_amp=[0.39e9, 0.61e9, 9], mmb_plaw=[1.01, 1.33, 11]
        mmb_amp=[0.1e9, 1.0e9, 9], mmb_plaw=[0.8, 1.5, 11]
    ):

        # self.gsmf_phi0 = np.linspace(*gsmf_phi0)
        self.times = np.logspace(*np.log10(times[:2]), times[2])
        # self.gsmf_alpha0 = np.linspace(*gsmf_alpha0)
        self.mmb_amp = np.linspace(*mmb_amp)
        self.mmb_plaw = np.linspace(*mmb_plaw)
        pars = [
            self.times,   # [Gyr]
            # self.gsmf_phi0,
            # self.gsmf_alpha0,
            self.mmb_amp,
            self.mmb_plaw
        ]
        self.names = [
            'times',
            # 'gsmf_phi0',
            # 'gsmf_alpha0',
            'mmb_amp',
            'mmb_plaw'
        ]
        self.params = np.meshgrid(*pars, indexing='ij')
        self.shape = self.params[0].shape
        self.size = np.product(self.shape)
        self.params = np.moveaxis(self.params, 0, -1)

        pass

    def number_to_index(self, num):
        idx = np.unravel_index(num, self.shape)
        return idx

    def index_to_number(self, idx):
        num = np.ravel_multi_index(idx, self.shape)
        return num

    def param_dict_for_number(self, num):
        idx = self.number_to_index(num)
        pars = self.params[idx]
        rv = {nn: pp for nn, pp in zip(self.names, pars)}
        return rv

    def params_for_number(self, num):
        idx = self.number_to_index(num)
        pars = self.params[idx]
        return pars

    def sam_for_number(self, num):
        params = self.params_for_number(num)

        # gsmf_phi0, mmb_amp, mmb_plaw = params
        time, mmb_amp, mmb_plaw = params

        gsmf = holo.sam.GSMF_Schechter()
        gpf = holo.sam.GPF_Power_Law()
        gmt = holo.sam.GMT_Power_Law()
        mmbulge = holo.relations.MMBulge_KH2013(mamp=mmb_amp*MSOL, mplaw=mmb_plaw)

        sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge)
        hard = holo.evolution.Fixed_Time.from_sam(sam, time*GYR, exact=True, progress=False)
        return sam, hard

In [None]:

# Assemble the per-point output files into one GWB library array and save it as HDF5.
space = Parameter_Space()
shape = space.shape

# Read the frequency grid and number of realizations from the first file; every other
# file is checked against these below.  The context manager closes the .npz handle.
with np.load(files[0]) as data:
    fobs = data['fobs']
    fobs_edges = data['fobs_edges']
    nreals = data['nreals'][()]
    temp_gwb = data['gwbspec']
nfreqs = fobs.size

# Each file's `gwbspec` must be (nfreqs, nreals).
assert np.ndim(temp_gwb) == 2
assert temp_gwb.shape[0] == nfreqs
assert temp_gwb.shape[1] == nreals

if len(files) != space.size:
    # Fewer files leave grid points at zero; more files cannot be mapped onto the grid.
    log.warning(f"found {len(files)} files but parameter space has {space.size} points!")

gwb_shape = list(shape) + [nfreqs, nreals]
names = space.names + ['freqs', 'reals']
gwb = np.zeros(gwb_shape)

# The ii-th sorted file is assumed to be parameter-space number ii; the parameters stored
# in each file are compared against the grid to verify that mapping.
for ii, fil in enumerate(files):
    with np.load(fil) as npz:
        assert np.allclose(fobs, npz['fobs'])
        assert np.allclose(fobs_edges, npz['fobs_edges'])
        pars = [npz[key][()] for key in space.names]
        gwb_ii = npz['gwbspec']

    idx = space.number_to_index(ii)
    assert np.allclose(pars, space.params[idx])
    gwb[idx] = gwb_ii

print(utils.stats(gwb))
print(utils.frac_str(gwb > 0.0))

out_filename = PATH.joinpath('sam_lib.hdf5')
with h5py.File(out_filename, 'w') as h5:
    h5.create_dataset('params', data=space.params)
    h5.create_dataset('fobs', data=fobs)
    h5.create_dataset('fobs_edges', data=fobs_edges)
    h5.create_dataset('gwb', data=gwb)
    # A list of `str` becomes a numpy unicode ('<U') array, which h5py cannot store;
    # save the axis names as fixed-length byte strings instead (decode when reading).
    h5.create_dataset('names', data=np.array(names, dtype='S'))

print(f"Saved to {out_filename}, size: {utils.get_file_size(out_filename)}")

In [None]:
# Quick-look plot of the GWB spectrum in the last library file: median over realizations,
# with the 25th-75th percentile range shaded.  Distinct names are used so that the
# assembled library arrays (`gwb`, `fobs`) from the previous cell are not overwritten.
last_file = files[-1]
print(last_file)
with np.load(last_file) as last_data:
    for kk, vv in last_data.items():
        print(f"\t{kk}: {vv.shape}")
    fobs_last = last_data['fobs']
    gwb_last = last_data['gwbspec']

# Frequencies in units of 1/yr (assumes `fobs` is in Hz and `YR` is one year in seconds).
xx = fobs_last * YR
fig, ax = plot.figax()
ax.plot(xx, np.median(gwb_last, axis=1))
ax.fill_between(xx, *np.percentile(gwb_last, [25, 75], axis=1), alpha=0.1)
ax.set(xlabel='Frequency [1/yr]', ylabel='gwbspec', title=last_file.name)

plt.show()