In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from datetime import datetime
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence noisy numpy floating-point warnings (e.g. log10 of zeros, 0/0, overflow)
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `mpl.style.use('default')` resets rcParams wholesale, so it must come *before* the
# custom `mpl.rc(...)` settings; otherwise the font/line/mathtext settings are discarded.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'Times' is listed under 'sans-serif' while the family is 'serif', so it is
# not used; presumably `'serif': ['Times']` was intended — confirm before changing.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
obs_amp_med = 2.4e-15
obs_amp_5_95 = [1.8e-15, 3.1e-15]
_obs_pow_plaw = 3.2  # +0.6 - 0.5
obs_plaw = (_obs_pow_plaw - 3) / -2.0
print(obs_plaw)

# Compare Libraries

## Power-Law Fits

In [None]:
# Shared ("SHARE") library outputs to compare against each other.
path_output = Path(holo._PATH_OUTPUT).joinpath("share")
paths = [
    "astro-01-gw_2023-03-03_n10000_s61-81-101_r100_f40_SHARE",
    "astro-02-gw_2023-03-03_n10000_s61-81-101_r100_f40_SHARE",
    "astro-tight-02-gw_2023-03-03_n10000_s61-81-101_r100_f40_SHARE",
    "broad-uniform-01-gw_2023-03-03_n10000_s61-81-101_r100_f40_SHARE"
]
fname = "sam_lib.hdf5"
# Full path to each library file; report which ones are actually present on disk.
files = [path_output / path / fname for path in paths]
for path, file in zip(paths, files):
    print(f"{file.parts[-2]} :: exists = {file.exists()}")

In [None]:
# Overlay the power-law fit distributions (amplitude at 1/yr vs. power-law index) of each
# library in a single corner plot. The (samples, realizations) grid is thinned to keep the
# KDEs fast.
skip_reals = 6
skip_samps = 6
num_bins = 10   # use the fits performed with this number of frequency bins

handles = []
labels = []
axes = None
for fil in files:
    lab = fil.parts[-2].split("_")[0]
    print(lab)
    with h5py.File(fil, 'r') as data:
        nbins = data.attrs['fit_nbins'].tolist()
        # thin the first two axes, then flatten to (N, len(nbins))
        lamp = data['fit_lamp'][::skip_samps, ::skip_reals, :].reshape(-1, len(nbins))
        plaw = data['fit_plaw'][::skip_samps, ::skip_reals, :].reshape(-1, len(nbins))

    nb_idx = nbins.index(num_bins)
    aa = lamp[:, nb_idx]
    bb = plaw[:, nb_idx]

    idx = np.isfinite(aa) & np.isfinite(bb)
    corner = kale.Corner(
        [aa[idx], bb[idx]],
        limits=[[-18, -13], [-2, 0]],
        # raw strings: '\l', '\m', '\g' are invalid Python escapes (SyntaxWarning)
        labels=[r'$\log10(A_{\mathrm{yr}^{-1}})$', r'$\gamma_\mathrm{h}$'],
        axes=axes
    )
    hand = corner.plot(
        dist1d=dict(carpet=False, confidence=False),
        dist2d=dict(scatter=False, hist=False, median=False, sigmas=[1, 2]),
    )
    handles.append(hand)
    labels.append(lab)
    fig = corner.fig
    axes = corner.axes

for ax in axes.flatten():
    ax.grid(True, alpha=0.25)

# Bottom-left: 2D joint distribution (x = amplitude, y = power-law index)
axes[1, 0].axhline(-2.0/3.0, color='k', ls='--', alpha=0.5)
axes[1, 0].axhline(obs_plaw, color='r', ls='--', alpha=0.5)
axes[1, 0].axvline(np.log10(obs_amp_med), color='r', ls='--', alpha=0.5)
axes[1, 0].axvspan(*np.log10(obs_amp_5_95), color='r', ls='--', alpha=0.15)

# Top-left: 1D amplitude distribution
axes[0, 0].axvline(np.log10(obs_amp_med), color='r', ls='--', alpha=0.5)
axes[0, 0].axvspan(*np.log10(obs_amp_5_95), color='r', ls='--', alpha=0.15)

# Bottom-right: 1D index distribution (index along the y-axis, hence `axhline`)
axes[1, 1].axhline(-2.0/3.0, color='k', ls='--', alpha=0.5)
axes[1, 1].axhline(obs_plaw, color='r', ls='--', alpha=0.5)

# legend entries are library names, not bin counts
corner.legend(handles, labels, title='library')
fname = f"comparison_all-plaw-fits_skip-s{skip_samps}-r{skip_reals}.png"
fname = Path(holo._PATH_OUTPUT).joinpath(fname)
# save before showing, consistent with the other plotting cells
fig.savefig(fname, dpi=400)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")
plt.show()


## General: parameter-space test libraries (uniform vs. normal)

In [None]:
# Parameter-space test libraries. The last '-'-separated piece of each directory's first
# '_' token (e.g. "uniform", "normal") serves as the label in the plots below.
path_output = Path(holo._PATH_OUTPUT).resolve()
paths = [
    "ps-test-uniform_2023-03-01_n200_s50_r100_f40_SHARE",
    "ps-test-normal_2023-03-01_n200_s50_r100_f40_SHARE",
]
fname = "sam_lib.hdf5"
files = [path_output / path / fname for path in paths]
for path, file in zip(paths, files):
    print(file.name, file.exists())

In [None]:
# Plot the GWB spectrum of every sample, one panel per test library.
fig, axes = plt.subplots(figsize=[10, 5], ncols=2, sharey=True, sharex=True)
for ax, file in zip(axes, files):
    ax.set(xscale='log', yscale='log')
    ax.grid(True, alpha=0.25)
    title = file.parts[-2].split("_")[0].split("-")[-1]
    ax.set_title(title, fontsize=12)
    # Read what is needed into memory and close the file (the handle was previously leaked).
    with h5py.File(file, 'r') as data:
        fobs = data['fobs'][()]
        gwb = data['gwb'][()]

    # presumably gwb is (nsamples, nfreqs, nreals); draw one spectrum set per sample
    for spectrum in gwb:
        plot.draw_gwb(ax, fobs, spectrum)

plt.show()


In [None]:
# Corner plot of the (amplitude, index) power-law fits for each test library.
fit_idx = 1   # index along the last axis of the fit arrays; presumably the 10-bin fit — TODO confirm via attrs['fit_nbins']
thin = 10     # keep every `thin`-th point to speed up the KDE

for file in files:
    label = file.parts[-2].split("_")[0].split("-")[-1]
    with h5py.File(file, 'r') as data:
        lamp = data['fit_lamp'][:, :, fit_idx]
        plaw = data['fit_plaw'][:, :, fit_idx]
    vals = [lamp.flatten()[::thin], plaw.flatten()[::thin]]

    # Clip each variable to its [2, 98] percentile range. Percentiles are computed over the
    # finite values only: a single NaN would make np.percentile return NaN and reject every
    # point. Non-finite values themselves fail the comparisons and are dropped.
    idx = np.ones_like(vals[0], dtype=bool)
    for vv in vals:
        lo, hi = np.percentile(vv[np.isfinite(vv)], [2, 98])
        idx &= (lo < vv) & (vv < hi)

    vals = [vv[idx] for vv in vals]
    corner, _ = kale.corner(vals)
    corner.fig.text(0.5, 0.5, label, va='center', ha='center')

plt.show()


In [None]:
# 1D distribution of each fit parameter (columns) for each test library (rows).
fit_keys = ['fit_lamp', 'fit_plaw']
fit_idx = 1   # index along the last axis of the fit arrays; presumably the 10-bin fit — TODO confirm

fig, axes = plt.subplots(
    figsize=[10, 5], ncols=len(fit_keys), nrows=len(files), sharex='col', squeeze=False
)

for ii, file in enumerate(files):
    title = file.parts[-2].split("_")[0].split("-")[-1]
    # read each file once and close it (previously reopened for every axis and leaked)
    with h5py.File(file, 'r') as data:
        fits = {key: data[key][:, :, fit_idx] for key in fit_keys}

    for jj, key in enumerate(fit_keys):
        ax = axes[ii, jj]
        ax.set(xscale='linear', yscale='linear', xlabel=key)
        ax.grid(True, alpha=0.25)
        ax.set_title(title, fontsize=12)

        vals = fits[key]
        ave = np.mean(vals, axis=-1)   # mean over the second axis (presumably realizations)
        print(utils.stats(ave))
        kale.dist1d(vals.flatten(), ax=ax)

plt.show()


## Single File

In [None]:
# Single library file examined in the rest of this notebook.
# NOTE(review): hard-coded, user-specific absolute path; consider deriving it from
# `holo._PATH_OUTPUT` or an environment variable so others can run this notebook.
FNAME = (
    "/Users/lzkelley/programs/nanograv/holodeck/output/savio-runs/"
    "astro-01_2023-03-03_n10000_s61-81-101_r100_f40_SHARE/sam_lib.hdf5"
)

path = Path(FNAME)

# Deliberately left open (no `with` block): the cells below keep reading from `data`.
data = h5py.File(path, 'r')
print("Keys:", data.keys())
for key, item in data.items():
    # only entries exposing a `.shape` (datasets) are listed; groups are skipped
    if not hasattr(item, 'shape'):
        continue
    print("\t", key, item.shape)

print("\nAttributes:")
for key, value in data.attrs.items():
    print("\t", key, value)


# Plot Strain Amplitude at 1/yr

In [None]:
# Distribution of log10 characteristic strain in the frequency bin closest to 1/yr, pooled
# over all samples and realizations, compared with the observed amplitude.
fobs = data['fobs'][:]
xx = fobs * YR
ff = np.argmin(np.fabs(xx - 1.0))   # index of the frequency bin nearest 1/yr

# The x-axis here is log10(h_c), not frequency: no frequency label and no Hz twin axis.
fig, ax = plot.figax(xlabel=r'$\log_{10}(h_c)$', ylabel='Distribution')
# escaped backslashes: '\,' is an invalid Python escape (SyntaxWarning)
title = f"Characteristic Strain at ${xx[ff]:.2f} \\, \\mathrm{{ yr}}^{{-1}}$"
ax.set(xscale='linear', yscale='linear', xlim=[-18, -13], title=title)

hcyr = data['gwb'][:, ff, :]          # h_c at ~1/yr for every sample and realization
hc_flat = hcyr.flatten()
log_hc = np.log10(hc_flat[hc_flat > 0.0])   # drop non-positive values before taking the log
kale.dist1d(log_hc, ax=ax, density=False, carpet=True, confidence=True, quantiles=[0.68, 0.98])
med = np.median(log_hc)
std = np.std(log_hc)
fig.text(0.99, 0.99, f"med: {med:.2f}, std: {std:.2f}", ha='right', va='top', fontsize=8)

# KDE of the same values on a fixed grid, then a Gaussian fit to that KDE
grid = np.linspace(-18, -13, 100)
grid, kde = kale.density(log_hc, points=grid, probability=True)
ax.plot(grid, kde, 'b--')

popt, pcov = utils.fit_gaussian(grid, kde, [1.0, -15, 2.0])
print(popt)
fit = utils._func_gaussian(grid, *popt)
# the fitted center is in log10 space, so label it as log10 A
ax.plot(grid, fit, 'k:', label=f"$\\log_{{10}} A_{{\\mathrm{{yr}}^{{-1}}}} = {popt[1]:.2f} \\pm {popt[2]:.2f}$")

lab = f"${plot.scientific_notation(obs_amp_med, dollar=False)} = 10^{{{np.log10(obs_amp_med):.2f}}}$"

ax.axvline(np.log10(obs_amp_med), color='r', ls='--', alpha=0.5, label=lab)
ax.axvspan(*np.log10(obs_amp_5_95), color='r', alpha=0.1, label="[5,95]%")

ax.legend(fontsize=8)
fname = path.parent.joinpath("hc-amp-inv-yr.png")
# save before showing, consistent with the other plotting cells
fig.savefig(fname, dpi=400)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")
plt.show()

In [None]:
# Scatter h_c(~1/yr) (`hcyr`, from the previous cell) against each sampled parameter.
pnames = [pn.decode() for pn in data.attrs['param_names']]
print(pnames)
npar = len(pnames)
fig, axes = plot.figax(
    figsize=[5*np.sqrt(npar), 5], ncols=npar,
    sharey=True, xscale='linear',
    left=0.03, right=0.99
)
# loop-invariant: median of `hcyr` over its last axis (presumably realizations)
hc_med = np.median(hcyr, axis=-1)
for ii, (pn, ax) in enumerate(zip(pnames, axes)):
    ax.set(xlabel=pn)
    if ii == 0:
        ax.set_ylabel('Characteristic Strain')

    pvals = data['sample_params'][:, ii]
    order = np.argsort(pvals)
    pvals_sorted = pvals[order]
    # broadcast each sample's parameter value across all of that sample's h_c values
    ax.scatter(pvals_sorted[:, np.newaxis] * np.ones_like(hcyr), hcyr[order], alpha=0.15, marker='.', s=4)
    ax.scatter(pvals_sorted, hc_med[order], alpha=0.5, marker='.', s=4)

    ax.axhline(obs_amp_med, color='r', ls='--', alpha=0.5)
    ax.axhspan(*obs_amp_5_95, color='r', alpha=0.1)

fname = path.parent.joinpath("hc-amp-inv-yr_parameters.png")
fig.savefig(fname, dpi=200)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")
plt.show()

# Power-Law Fits

In [None]:
# Power-law fits to the median spectrum of each sample: one contour set per number of fitted
# frequency bins.
lamp = data['fit_med_lamp']
plaw = data['fit_med_plaw']
nbins = data.attrs['fit_nbins']
print(lamp.shape, nbins)

handles = []
labels = []
axes = None
for ii, nb in enumerate(nbins):
    aa = lamp[:, ii]
    bb = plaw[:, ii]
    idx = np.isfinite(aa) & np.isfinite(bb)
    corner = kale.Corner(
        [aa[idx], bb[idx]],
        limits=[[-20, -12], [-3, 2]],
        # raw strings: '\l', '\m', '\g' are invalid Python escapes (SyntaxWarning)
        labels=[r'$\log10(A_{\mathrm{yr}^{-1}})$', r'$\gamma_\mathrm{h}$'],
        axes=axes
    )
    hand = corner.plot(
        dist1d=dict(carpet=False, confidence=False),
        dist2d=dict(scatter=False, hist=False, median=False, sigmas=[1, 2]),
    )
    handles.append(hand)
    # a missing or non-positive bin count is labeled "all"
    labels.append(f"{nb}" if (nb is not None) and (nb > 0) else "all")
    fig = corner.fig
    axes = corner.axes

for ax in axes.flatten():
    ax.grid(True, alpha=0.25)

# Bottom-left: 2D joint distribution (x = amplitude, y = power-law index)
axes[1, 0].axhline(-2.0/3.0, color='k', ls='--', alpha=0.5)
axes[1, 0].axhline(obs_plaw, color='r', ls='--', alpha=0.5)
axes[1, 0].axvline(np.log10(obs_amp_med), color='r', ls='--', alpha=0.5)
axes[1, 0].axvspan(*np.log10(obs_amp_5_95), color='r', ls='--', alpha=0.15)

# Top-left: 1D amplitude distribution
axes[0, 0].axvline(np.log10(obs_amp_med), color='r', ls='--', alpha=0.5)
axes[0, 0].axvspan(*np.log10(obs_amp_5_95), color='r', ls='--', alpha=0.15)

# Bottom-right: 1D index distribution (index along the y-axis, hence `axhline`)
axes[1, 1].axhline(-2.0/3.0, color='k', ls='--', alpha=0.5)
axes[1, 1].axhline(obs_plaw, color='r', ls='--', alpha=0.5)

corner.legend(handles, labels, title='n bins')
fname = path.parent.joinpath("med-plaw-fits.png")
fig.savefig(fname, dpi=400)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")
plt.show()


In [None]:
# Quick look at the shape of the per-realization amplitude-fit dataset.
print((lamp := data['fit_lamp']).shape)

In [None]:
# Power-law fits to every (thinned) sample/realization: one contour set per number of
# fitted frequency bins.
skip_reals = 10
skip_samps = 10

lamp = data['fit_lamp']
plaw = data['fit_plaw']
print(lamp.shape, lamp.size)

nbins = data.attrs['fit_nbins']
# thin the first two axes, then flatten to (N, len(nbins))
lamp = lamp[::skip_samps, ::skip_reals, :].reshape(-1, len(nbins))
plaw = plaw[::skip_samps, ::skip_reals, :].reshape(-1, len(nbins))
print(lamp.shape, lamp.size)

handles = []
labels = []
axes = None
for ii, nb in enumerate(nbins):
    aa = lamp[:, ii]
    bb = plaw[:, ii]
    idx = np.isfinite(aa) & np.isfinite(bb)
    corner = kale.Corner(
        [aa[idx], bb[idx]],
        limits=[[-20, -12], [-3, 2]],
        # raw strings: '\l', '\m', '\g' are invalid Python escapes (SyntaxWarning)
        labels=[r'$\log10(A_{\mathrm{yr}^{-1}})$', r'$\gamma_\mathrm{h}$'],
        axes=axes
    )
    hand = corner.plot(
        dist1d=dict(carpet=False, confidence=False),
        dist2d=dict(scatter=False, hist=False, median=False, sigmas=[1, 2]),
    )
    handles.append(hand)
    # a missing or non-positive bin count is labeled "all"
    labels.append(f"{nb}" if (nb is not None) and (nb > 0) else "all")
    fig = corner.fig
    axes = corner.axes

for ax in axes.flatten():
    ax.grid(True, alpha=0.25)

# Bottom-left: 2D joint distribution (x = amplitude, y = power-law index)
axes[1, 0].axhline(-2.0/3.0, color='k', ls='--', alpha=0.5)
axes[1, 0].axhline(obs_plaw, color='r', ls='--', alpha=0.5)
axes[1, 0].axvline(np.log10(obs_amp_med), color='r', ls='--', alpha=0.5)
axes[1, 0].axvspan(*np.log10(obs_amp_5_95), color='r', ls='--', alpha=0.15)

# Top-left: 1D amplitude distribution
axes[0, 0].axvline(np.log10(obs_amp_med), color='r', ls='--', alpha=0.5)
axes[0, 0].axvspan(*np.log10(obs_amp_5_95), color='r', ls='--', alpha=0.15)

# Bottom-right: 1D index distribution (index along the y-axis, hence `axhline`)
axes[1, 1].axhline(-2.0/3.0, color='k', ls='--', alpha=0.5)
axes[1, 1].axhline(obs_plaw, color='r', ls='--', alpha=0.5)

corner.legend(handles, labels, title='n bins')
fname = f"all-plaw-fits_skip-s{skip_samps}-r{skip_reals}.png"
fname = path.parent.joinpath(fname)
fig.savefig(fname, dpi=400)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")
plt.show()


# Plot some random spectra

In [None]:
# GWB spectra of a random subset of samples. For each sample this draws the median over
# realizations (line), the `percs` percentile band (shaded) and a few individual
# realizations (thin lines).
num_label = 5    # number of labeled lines in the "spectra vs. parameters" figures below
NUM = 100        # maximum number of samples to draw
NUM_REALS = 5    # individual realizations drawn per sample
SEED = 12345     # fixed seed so the selected samples are reproducible
percs = [10, 90]

fobs = data['fobs'][:]
xx = fobs * YR
fig, ax = plot.figax(xlabel='Frequency $[\\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)
ax.set(ylim=[1e-18, 3e-13])

# gwb: (nsamples, nfreqs, nreals). Only the selected samples are read from disk,
# instead of loading the whole dataset.
gwb_dset = data['gwb']
nsamp = gwb_dset.shape[0]
use_num = min(NUM, nsamp)
print(f"{use_num=}")
rng = np.random.default_rng(SEED)
# h5py fancy indexing requires strictly increasing indices
sel = np.sort(rng.choice(nsamp, use_num, replace=False))
yy = gwb_dset[sel]

# ==> (nfreqs, nsamples, nreals)
yy = np.moveaxis(yy, 1, 0)
med = np.median(yy, axis=-1)

for ii in range(use_num):
    line, = ax.plot(xx, med[:, ii], alpha=0.5, lw=1)
    color = line.get_color()

    conf = np.percentile(yy[:, ii, :], percs, axis=-1)
    ax.fill_between(xx, *conf, alpha=0.15, color=color, lw=1)

    ax.plot(xx, yy[:, ii, :NUM_REALS], alpha=0.25, color=color, lw=1)

# Reference power laws: a wide light line under a thin red one (an outline effect).
# The 1.0 argument is presumably the reference frequency in 1/yr — TODO confirm.
for amp in [1e-15, 2e-15, 3e-15]:
    plot._draw_plaw(ax, xx, amp, 1.0, color='0.95', alpha=0.70, zorder=99, lw=2.0, ls='-')
    plot._draw_plaw(ax, xx, amp, 1.0, color='r', alpha=0.75, zorder=100, lw=1.0)

plt.show()

# Plot parameter space (corner plot)

In [None]:
# Corner plot of the sampled parameters, excluding the outer 1% tails of each parameter.
spars = data['sample_params'][:].copy()
idx = np.ones_like(spars[:, 0], dtype=bool)
# NOTE: the loop variable must not be `sp`, which would clobber the `scipy as sp` import.
for par_vals in spars.T:
    lo, hi = np.percentile(par_vals, [1, 99])
    idx &= (lo < par_vals) & (par_vals < hi)

names = [dd.decode() for dd in data.attrs['param_names']]
corner, _ = kale.corner(
    spars[idx].T,
    dist2d=dict(contour=False),
    dist1d=dict(probability=True, density=False, hist=True),
    labels=names,
)
fig = corner.fig
# adjust this specific figure rather than whatever pyplot considers "current"
fig.subplots_adjust(left=0.04, right=0.98, bottom=0.02, top=0.96)
fname = path.parent.joinpath("parameter-samples.png")
fig.savefig(fname, dpi=200)
print(f"Saved to {fname}, size {utils.get_file_size(fname)}")
plt.show()

# Plot spectra versus different parameters

In [None]:
# One figure per sampled parameter: the median GWB spectrum of every `skip`-th sample,
# with samples sorted by that parameter and colored by its value.
num_label = 5       # approximate number of labeled lines per figure
skip = 10           # plot only every `skip`-th sample (in sorted order)
percs = [25, 75]
draw_conf = False   # set True to also shade the `percs` interval across realizations

xx = data['fobs'][:] * YR
gwb_all = data['gwb'][:]            # read once; presumably (nsamples, nfreqs, nreals)
spars = data['sample_params'][:]
print(gwb_all.shape, spars.shape)

spars = spars.T
cmap = mpl.colormaps['Spectral']    # `mpl.cm.get_cmap` is removed in matplotlib >= 3.9
for ii, pname in enumerate(data.attrs['param_names']):
    pname = pname.decode()
    fig, ax = plot.figax(
        figsize=[6, 4], ylim=[0.5e-16, 3e-13], left=0.13, bottom=0.13, right=0.98, top=0.88,
        xlabel="Frequency [1/yr]", ylabel="Characteristic Strain",
    )
    plot._twin_hz(ax)

    vals = spars[ii]
    sort_idx = np.argsort(vals)
    sv = vals[sort_idx]
    gwb = gwb_all[sort_idx]         # fancy indexing returns a sorted copy

    num = gwb.shape[0]
    colors = cmap(np.linspace(0.05, 0.95, num))
    # Spacing between labeled lines. Guard against num < num_label, which previously
    # raised ZeroDivisionError from `jj % 0`.
    label_every = max(num // num_label, 1)

    for jj in range(0, num, skip):
        pv = sv[jj]
        use_label = (num <= num_label) or (jj % label_every == 0) or (jj == num - 1)
        yy = gwb[jj]
        med = np.median(yy, axis=-1)
        label = f"{pv:.3e}" if use_label else ""
        ax.plot(xx, med, color=colors[jj], label=label, lw=1.0, alpha=0.5)
        if draw_conf:
            conf = np.percentile(yy, percs, axis=-1)
            ax.fill_between(xx, *conf, color=colors[jj], alpha=0.2, lw=0.5)

    ax.legend(title=pname, fontsize=8, loc='upper right')
    # Save next to the library file like the other single-file figures. The old shared
    # 'eccen-lib_*' name was stale and would collide between libraries.
    output_path = path.parent.joinpath(f'spectra-vs-params_{pname}_skip-{skip}.png')
    fig.savefig(output_path, dpi=300)
    print(f"Saved to {output_path}, size {utils.get_file_size(output_path)}")

plt.show()