In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# `style.use('default')` resets rcParams, so it must run BEFORE any custom `mpl.rc`
# calls -- placed after them (as before) it silently undid the font/lines/mathtext setup.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'Times' is a serif face but is assigned to the sans-serif list; with
# family='serif' this entry has no effect -- confirm whether 'serif': ['Times'] was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
# Path to the SAM library HDF5 file. Override with the HOLODECK_SAM_LIB environment
# variable instead of editing an absolute, machine-specific path.
fname = os.environ.get(
    "HOLODECK_SAM_LIB",
    "/Users/lzkelley/programs/nanograv/holodeck/output/eccen-02_2023-01-31_02_n100_s30_r100_f20/sam_lib.hdf5",
)

print(fname)
# Re-running this cell would otherwise leak the previously opened file handle.
if isinstance(globals().get('data'), h5py.File):
    data.close()
data = h5py.File(fname, 'r')

print("Keys:", data.keys())
for kk, vv in data.items():
    # Only datasets have a shape; groups (e.g. 'parameters') are skipped.
    if isinstance(vv, h5py.Dataset):
        print("\t", kk, vv.shape)

print("\nAttributes:")
for kk, vv in data.attrs.items():
    print("\t", kk, vv)


In [None]:
fobs = data['fobs'][:]
xx = fobs * YR

# Which figure to draw:
#   True  -> median (over realizations) GWB spectrum of every sample, with power-law references
#   False -> distribution of log10(h_c) at the frequency bin nearest 1/yr
PLOT_SPECTRA = False

if PLOT_SPECTRA:
    fig, ax = plot.figax(xlabel='Frequency $[\\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
    plot._twin_hz(ax)
    # (nsamples, nfreqs, nreals)
    yy = data['gwb'][:]
    # ==> (nfreqs, nsamples, nreals)
    yy = np.moveaxis(yy, 1, 0)
    # ==> (nfreqs, nsamples): median over realizations
    yy = np.median(yy, axis=-1)
    ax.set(ylim=[1e-16, 3e-14])
    ax.plot(xx, yy, alpha=0.1, color='k', lw=1)

    for amp in [1e-15, 2e-15, 3e-15]:
        plot._draw_plaw(ax, xx, amp, 1.0, color='r', alpha=1.0, zorder=100)
else:
    # frequency bin closest to 1/yr
    ff = np.argmin(np.abs(xx - 1.0))
    title = f"Characteristic Strain at ${xx[ff]:.2f} \\, \\mathrm{{ yr}}^{{-1}}$"
    # x-axis here is log10 of the strain, not frequency, so no twin-Hz axis
    fig, ax = plot.figax(xlabel='$\\log_{10}(h_c)$', ylabel='Probability Density')
    ax.set(xscale='linear', yscale='linear', xlim=[-17, -13], title=title)
    # all samples and realizations at this frequency; drop non-positive values before log10
    yy = data['gwb'][:, ff, :].flatten()
    yy = np.log10(yy[yy > 0.0])
    kale.dist1d(yy, ax=ax, density=True, carpet=True)

plt.show()

In [None]:
# Number of labeled lines to show in legends (also used by later cells).
num_label = 5

NUM = 60                # maximum number of randomly chosen samples to plot
NUM_REALS_SHOW = 5      # individual realizations drawn per chosen sample
SEED = 12345            # fixed seed so the chosen subset is reproducible
percs = [10, 90]

fobs = data['fobs'][:]
xx = fobs * YR
fig, ax = plot.figax(xlabel='Frequency $[\\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)
ax.set(ylim=[1e-16, 3e-14])

# (nsamples, nfreqs, nreals)
yy = data['gwb'][:]

nsamp = yy.shape[0]
choose = min(NUM, nsamp)
rng = np.random.default_rng(SEED)
sel = rng.choice(nsamp, choose, replace=False)
yy = yy[sel]

# ==> (nfreqs, nsamples, nreals)
yy = np.moveaxis(yy, 1, 0)

med = np.median(yy, axis=-1)

# Loop over the samples actually selected: the library may hold fewer than `NUM`.
for ii in range(choose):
    line, = ax.plot(xx, med[:, ii], alpha=0.5, lw=1)
    cc = line.get_color()

    conf = np.percentile(yy[:, ii, :], percs, axis=-1)
    ax.fill_between(xx, *conf, alpha=0.15, color=cc, lw=1)

    # a few individual realizations, shape (nfreqs, <=NUM_REALS_SHOW)
    ax.plot(xx, yy[:, ii, :NUM_REALS_SHOW], alpha=0.25, color=cc, lw=1)

# reference power laws: wide pale underlay plus thin red line
for amp in [1e-15, 2e-15, 3e-15]:
    plot._draw_plaw(ax, xx, amp, 1.0, color='0.95', alpha=0.70, zorder=99, lw=2.0, ls='-')
    plot._draw_plaw(ax, xx, amp, 1.0, color='r', alpha=0.75, zorder=100, lw=1.0)

plt.show()

In [None]:
# Summary statistics of each sampled parameter. `sample_params` is (nsamples, nparams),
# so each column holds the values of one parameter across samples.
param_names = [nn.decode() for nn in data.attrs['param_names']]
sample_params = data['sample_params'][:]
for name, pvals in zip(param_names, sample_params.T):
    print(name, utils.stats(pvals))

In [None]:
# Corner plot of the sampled parameters, with `mmb_amp` shown in log10.
# h5py slicing reads into a fresh in-memory array, so the in-place log10 below
# does not touch the file.
spars = data['sample_params'][:]
names = [dd.decode() for dd in data.attrs['param_names']]
print(names, spars.shape)
idx = names.index('mmb_amp')
print(idx, names[idx])
spars[:, idx] = np.log10(spars[:, idx])
# `kale.corner` builds its own figure; a preceding `plt.figure()` only left an empty one behind.
corner, _ = kale.corner(
    spars.T,
    dist2d=dict(contour=False),
    dist1d=dict(probability=True, density=False, hist=True),
    labels=names,
)
fig = corner.fig
fig.savefig('temp_params.png')
plt.show()

In [None]:
# One figure per parameter: every sample's median spectrum (and inter-quartile band),
# colored by that sample's value of the parameter.
percs = [25, 75]
xx = data['fobs'][:] * YR
gwb = data['gwb'][:]                  # (nsamples, nfreqs, nreals)
spars = data['sample_params'][:]      # (nsamples, nparams)
print(gwb.shape, spars.shape)

spars = spars.T                       # (nparams, nsamples)
num = gwb.shape[0]
colors = plt.get_cmap('Spectral')(np.linspace(0.05, 0.95, num))
# Label roughly `num_label` lines; clamp the stride so small libraries don't divide by zero.
label_stride = max(num // num_label, 1)

for ii, pname in enumerate(data.attrs['param_names']):
    pname = pname.decode()
    fig, ax = plot.figax(
        figsize=[6, 4], ylim=[0.5e-16, 3e-13], left=0.13, bottom=0.13, right=0.98, top=0.88,
        xlabel="Frequency [1/yr]", ylabel="Characteristic Strain",
    )
    plot._twin_hz(ax)

    # order samples by this parameter so colors run monotonically with its value
    vals = spars[ii]
    sort_idx = np.argsort(vals)
    sv = vals[sort_idx]
    gwb_sorted = gwb[sort_idx]

    for jj, (pv, yy) in enumerate(zip(sv, gwb_sorted)):
        # `num <= num_label` is checked first so the modulo never sees a zero stride
        show_label = (num <= num_label) or (jj % label_stride == 0) or (jj == num - 1)
        med = np.median(yy, axis=-1)
        conf = np.percentile(yy, percs, axis=-1)
        label = f"{pv:.3e}" if show_label else ""
        ax.plot(xx, med, color=colors[jj], label=label, lw=1.0)
        ax.fill_between(xx, *conf, color=colors[jj], alpha=0.2, lw=0.5)

    ax.legend(title=pname, fontsize=8, loc='upper right')
    output_path = Path(holo._PATH_OUTPUT, f'eccen_lib_{pname}.png')
    fig.savefig(output_path, dpi=300)

plt.show()

In [None]:
# One figure per parameter: samples are binned by the parameter's grid values
# (`data['parameters'][pname]`), and each bin's spectra (all samples x realizations)
# are summarized by their median and inter-quartile band.
xx = data['fobs'][:] * YR
gwb = data['gwb'][:]                  # (nsamples, nfreqs, nreals)
spars = data['sample_params'][:]      # (nsamples, nparams)
print(gwb.shape, spars.shape)

spars = spars.T                       # (nparams, nsamples)
for ii, pname in enumerate(data.attrs['param_names']):
    # attribute names are stored as bytes; decode so legend titles read cleanly
    pname = pname.decode()
    fig, ax = plot.figax(xlabel="Frequency [1/yr]", ylabel="Characteristic Strain")
    plot._twin_hz(ax)

    pvals = data['parameters'][pname][:]
    print(ii, pname, pvals[:3], "...", pvals[-3:])
    sort_idx = np.argsort(spars[ii])
    sv = spars[ii][sort_idx]
    gwb_sorted = gwb[sort_idx]
    # For each grid value, the index of the first sorted sample strictly ABOVE it
    # (len(sv) when none are). Bin jj is then [previous hi, hi). Unlike the previous
    # argmax/None bookkeeping, bins past the top sample stay empty instead of
    # collapsing to slice(None, ...) and selecting every sample.
    bin_hi = np.searchsorted(sv, pvals, side='right')

    num_pvals = len(pvals)
    colors = plt.get_cmap('Spectral')(np.linspace(0.1, 0.9, num_pvals))
    # clamp so fewer grid values than `num_label` doesn't divide by zero
    label_stride = max(num_pvals // num_label, 1)

    lo = 0
    for jj, (pv, hi) in enumerate(zip(pvals, bin_hi)):
        cut = slice(lo, hi)
        # the lo-end of the next bin is the hi-end of this one
        lo = hi
        if cut.stop <= cut.start:
            # no samples in this bin; np.percentile cannot handle an empty axis
            continue

        # GWB spectra in this bin, shape (samples, freqs, reals)
        yy = gwb_sorted[cut]
        # ==> (freqs, samples, reals)
        yy = np.moveaxis(yy, 1, 0)
        # ==> (freqs, samples*reals)
        yy = yy.reshape(yy.shape[0], -1)

        med = np.median(yy, axis=-1)
        conf = np.percentile(yy, [25, 75], axis=-1)
        show_label = (num_pvals <= num_label) or (jj % label_stride == 0)
        label = f"{pv:.1e}" if show_label else ""
        ax.plot(xx, med, color=colors[jj], label=label)
        ax.fill_between(xx, *conf, color=colors[jj], alpha=0.2)

    ax.legend(title=pname)

plt.show()
    

In [None]:
# kale.corner(data['params'][:].T)

In [None]:
# Re-open the library file, closing the existing handle first so it is not leaked.
if isinstance(globals().get('data'), h5py.File):
    data.close()
data = h5py.File(fname, 'r')
print("Keys:", data.keys())
for kk, vv in data.items():
    # only datasets have a shape; groups are skipped
    if isinstance(vv, h5py.Dataset):
        print("\t", kk, vv.shape)

In [None]:
# Observed frequencies converted from Hz to 1/yr.
xx = YR * data['fobs'][()]
print(xx)

In [None]:
# Strain at a single frequency bin versus the `mmb_amp` grid index: mean and standard
# deviation over all samples/realizations sharing each grid index.
fidx = 0
# (nsamples, nreals) at frequency bin `fidx` (previously hardcoded to 0, ignoring `fidx`)
gwb = data['gwb'][:, fidx, :]

# NOTE(review): assumes the last column of `lhs_grid` is the mmb_amp grid index -- confirm.
mmb_idx = data['lhs_grid'][:, -1]
mmb_sort_idx = np.argsort(mmb_idx)
# Group by grid index instead of a hardcoded 100; the reshape is only valid when
# every grid index occurs equally often.
_, bin_counts = np.unique(mmb_idx, return_counts=True)
assert np.all(bin_counts == bin_counts[0]), "unequal sample counts per mmb_amp index"
num_bins = bin_counts.size
gwb = gwb[mmb_sort_idx].reshape(num_bins, -1)
gwb_mean = np.mean(gwb, axis=-1)
gwb_std = np.std(gwb, axis=-1)

fig, ax = plt.subplots()
ax.set(xlabel='mmb_amp index', ylabel=f'strain at frequency {fidx+1:d}/T')
ax.grid(True)
# separate name so the frequency array `xx` used by other cells is not clobbered
bins = np.arange(num_bins)
ax.plot(bins, gwb_mean, 'r.-', alpha=0.75, label='mean')
ax.errorbar(bins, gwb_mean, gwb_std, fmt='|', alpha=0.5)
ax.plot(bins, gwb_std, 'g--', label='std')
ax.legend()

# save before showing; use a separate name so the HDF5 path in `fname` is not clobbered
fig_path = Path('~/hard04b_001.png').expanduser()
fig.savefig(fig_path)
plt.show()
print(fig_path)

In [None]:


# Mean (over realizations) GWB spectrum of the first sample.
# Recompute the frequency axis here: a previous cell may have rebound `xx`.
xx = data['fobs'][:] * YR
fig, ax = plot.figax(xlabel='Frequency $[\\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain')
plot._twin_hz(ax)

ax.plot(xx, data['gwb'][0, :, :].mean(axis=-1))

plt.show()