In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import tqdm


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR, PC
import holodeck as holo
from holodeck.sams import sam

import hasasia.sim as hsim

import sys
sys.path.append('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation')
import anatomy as anat

In [None]:
# NOTE(review): absolute local path — consider making this configurable (e.g. via a DATA_DIR constant).
file_path = '/Users/emigardiner/GWs/holodeck/output/brc_output/ss63_09Bparams_2023-06-22_uniform-09b_n500_r100_f40_l10'

# Read every dataset into memory; the context manager closes the HDF5 handle
# even if one of the expected datasets is missing.
with h5py.File(file_path + '/sam_lib.hdf5', 'r') as hdf:
    print(hdf.keys())
    sample_params = hdf['sample_params'][...]
    hc_ss = hdf['hc_ss'][...]
    hc_bg = hdf['hc_bg'][...]
    sspar = hdf['sspar'][...]
    bgpar = hdf['bgpar'][...]
    fobs_gw_cents = hdf['fobs'][:]

# Rebuild matching bin centres and edges from the stored centres.
# Assumes holo.utils.pta_freqs puts its first bin centre at 1/dur — TODO confirm.
dur = 1/fobs_gw_cents[0]
fobs_gw_cents, fobs_gw_edges = holo.utils.pta_freqs(dur, num=len(fobs_gw_cents))
# Unpacking fails loudly if hc_ss is not 4-dimensional.
nsamps, nfreqs, nreals, nloudest = hc_ss.shape


In [None]:
# Load parameter-space metadata for the library.
# allow_pickle=True lets np.load unpickle object arrays, which can execute arbitrary
# code — only use it on trusted, locally generated files such as this one.
# The context manager closes the archive even if a key is missing.
with np.load(file_path+'/PS_Uniform_09B.pspace.npz', allow_pickle=True) as npz:
    print(npz.files)
    param_names = npz['param_names']
    print(param_names)
    print(npz['class_name'])
    # nsamps comes from the library-loading cell above.
    lib_name = '%s v%s, %d samples' % (npz['class_name'], npz['librarian_version'], nsamps)
    print(lib_name)

In [None]:
# Parameter values drawn for the first library sample
# (presumably in the same order as param_names — confirm).
first_sample_params = sample_params[0]
print(first_sample_params)

In [None]:
# Frequency of every entry in the (samples, freqs, reals[, loudest]) arrays, so that
# ssfrq[s, f, r, l] == bgfrq[s, f, r] == fobs_gw_cents[f].
# np.repeat(...).reshape(nsamps, ...) would put the frequency-major repeat order onto the
# sample axis and mislabel frequencies; broadcasting onto axis 1 keeps the pairing right.
# .copy() turns the read-only broadcast view into an ordinary writable array.
ssfrq = np.broadcast_to(
    fobs_gw_cents[np.newaxis, :, np.newaxis, np.newaxis],
    (nsamps, nfreqs, nreals, nloudest),
).copy()
bgfrq = np.broadcast_to(
    fobs_gw_cents[np.newaxis, :, np.newaxis],
    (nsamps, nfreqs, nreals),
).copy()
print(f"{ssfrq.shape=}, {sspar[:,0].shape=}")

# Mass vs Frequency Marginalized Over Library

In [None]:
# Bin edges: total mass from a default SAM's mass grid (in solar masses), and
# frequency converted from Hz to nHz (multiply by 1e9) to match the axis label.
# Named sam_model so it does not shadow the `sam` imported from holodeck.sams.
sam_model = holo.sams.Semi_Analytic_Model()
mt_edges = sam_model.mtot/MSOL
ff_edges = fobs_gw_edges * 1e9

# Pool every sample, frequency, realization and loudest source. sspar[:, 0] flattens in
# (sample, freq, real, loudest) order, so the frequencies are broadcast onto that same
# layout. Repeating each frequency nsamps*nreals*nloudest times in a row would pair
# masses with the wrong frequencies.
ss_mass_par = sspar[:, 0, ...]
ssmtt = ss_mass_par.flatten()/MSOL
ssfrq = np.broadcast_to(fobs_gw_cents[:, np.newaxis, np.newaxis], ss_mass_par.shape).flatten() * 1e9
print(f"{ssmtt.shape=}, {ssfrq.shape=}")

hist, ffe, mte = np.histogram2d(ssfrq, ssmtt, bins=(ff_edges, mt_edges))
print(f"{ffe.shape=}, {mte.shape=}")

cmap = 'inferno'
fig, ax = plot.figax(xlabel=r'$M$ [M$_\odot$]', ylabel='$f$ [nHz]')
mtgrid, ffgrid = np.meshgrid(mt_edges, ff_edges)
print(f"{mtgrid.shape=}, {ffgrid.shape=}")
# Empty bins give log10(0) = -inf; silence the divide-by-zero warning.
with np.errstate(divide='ignore'):
    log_hist = np.log10(hist)
ax.pcolormesh(mtgrid, ffgrid, log_hist, cmap=cmap)

ax.set_title('All Uniform 09B');

# Mass vs Frequency for Individual Samples

In [None]:
# Order library sample indices with detstats.rank_samples. Later cells treat
# nsort[:5] as the "best" and nsort[-5:] as the "worst" samples.
# NOTE(review): fidx=1 presumably selects the frequency bin used for ranking — confirm.
nsort = detstats.rank_samples(hc_ss, hc_bg, fobs_gw_cents, fidx=1)
top_five = nsort[:5]
print(top_five)

In [None]:
def plot_best_index(best):
    """Plot total-mass vs. GW-frequency histograms for one library sample or for all of them.

    Draws two side-by-side panels. The left panel uses the background parameters
    (``bgpar``) and the right panel the single-source parameters (``sspar``). Each
    shows log10 of the number of entries falling in every (mass, frequency) bin.
    Reads the notebook globals ``sspar``, ``bgpar``, ``fobs_gw_cents``,
    ``fobs_gw_edges``, ``mt_edges``, ``sample_params`` and ``lib_name``.

    Parameters
    ----------
    best : int or None
        Index of the library sample to plot. If ``None``, all samples are pooled
        (marginalized over the library) and ``lib_name`` is used as the title.
    """
    # Frequencies are in Hz; multiply (not divide) by 1e9 to match the 'nHz' label.
    fobs_nhz = fobs_gw_cents * 1e9
    fedges_nhz = fobs_gw_edges * 1e9

    # One sample, or keep the leading sample axis when marginalizing.
    sel = slice(None) if best is None else best
    ssmtt = sspar[sel, 0, ...]/MSOL   # expected (..., nfreqs, nreals, nloudest)
    bgmtt = bgpar[sel, 0, ...]/MSOL   # expected (..., nfreqs, nreals)

    # Broadcast each bin's frequency onto the frequency axis of the mass arrays, so every
    # mass is paired with its own frequency whether or not a sample axis leads. Repeating
    # each frequency nsamps*... times in a row would mismatch the sample-major flatten order.
    ssfrq = np.broadcast_to(fobs_nhz[:, np.newaxis, np.newaxis], ssmtt.shape)
    bgfrq = np.broadcast_to(fobs_nhz[:, np.newaxis], bgmtt.shape)

    bins = (fedges_nhz, mt_edges)
    hist_ss, ffe, mte = np.histogram2d(ssfrq.ravel(), ssmtt.ravel(), bins=bins)
    hist_bg, _, _ = np.histogram2d(bgfrq.ravel(), bgmtt.ravel(), bins=bins)

    fig, axs = plot.figax(ncols=2, xlabel=r'$M$ [M$_\odot$]', ylabel='$f$ [nHz]', figsize=(8, 4))
    mtgrid, ffgrid = np.meshgrid(mte, ffe)

    panels = (
        (axs[0], hist_bg, r'$h_c^2$-weighted Background', 'viridis'),
        (axs[1], hist_ss, 'Single Sources', 'inferno'),
    )
    for ax, hist, title, cmap in panels:
        ax.set_title(title)
        ax.set_facecolor('k')
        # Empty bins give log10(0) = -inf; silence the warning. Such bins are
        # presumably left unpainted, so the black facecolor shows through.
        with np.errstate(divide='ignore'):
            log_hist = np.log10(hist)
        im = ax.pcolormesh(mtgrid, ffgrid, log_hist, cmap=cmap)
        plt.colorbar(im, ax=ax, label=r'$\log N$', orientation='horizontal', pad=0.2)
    # Only the left panel keeps the frequency label.
    axs[1].set_ylabel(None)

    if best is None:
        fig.suptitle(lib_name)
    else:
        # NOTE(review): this column order of sample_params (t_hard, Phi_0, M_char, mu, sigma,
        # ..., gamma_inner last) is hard-coded — confirm against param_names.
        pars = sample_params[best]
        fig.suptitle(
            r'sample %d:, $t_\mathrm{hard}$=%.1fGyr, $\gamma_\mathrm{inner}$=%.1f, $\Phi_0$=%.1f, '
            % (best, pars[0], pars[-1], pars[1], )
            + r'$\log M_\mathrm{char,0}={%.1f}$, $\log \mu_\mathrm{MMB}=%.1f$, $\sigma_\mathrm{MMB,dex}$=%.1f'
            % (pars[2], pars[3], pars[4]))
    # Apply the layout fix to both the single-sample and the marginalized figures.
    fig.tight_layout()

## Plot the five best-ranked samples

In [None]:
# One figure per sample, for the five best-ranked samples.
for sample_idx in nsort[:5]:
    plot_best_index(sample_idx)

## Plot the five worst-ranked samples

In [None]:
# One figure per sample, for the five worst-ranked samples.
for sample_idx in nsort[-5:]:
    plot_best_index(sample_idx)

# Marginalized Over All Samples

In [None]:
# Passing None pools every library sample into a single pair of panels.
plot_best_index(best=None)