In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import tqdm


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR, PC
import holodeck as holo
from holodeck.sams import sam

import hasasia.sim as hsim

import sys
sys.path.append('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation')
import anatomy as anat

# Load Data

In [None]:
# Use one file to get the array shapes, frequency bins and parameter names
# (assumed to be shared by the other files loaded below -- TODO confirm).
# NOTE: allow_pickle unpickles arbitrary objects; only load trusted files.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/hard_time_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    print(f"{npz.files=}")
    data = npz['data']
    print(f"{data[0].keys()=}")
    shape = npz['shape']
    print(f"{shape=}")

# NpzFile re-reads (and unpickles) an array on every key access, so take the
# frequency bins from the copy of `data` already in memory.
fobs_gw_cents = data[0]['fobs_cents']
fobs_gw_edges = data[0]['fobs_edges']

# get param names
pspace = holo.param_spaces.PS_Uniform_09A(holo.log, nsamples=1, sam_shape=shape, seed=None)
param_names = pspace.param_names
print(f"{param_names=}")

# Reference sample (index 1) used by the exploratory cells below.
hc_ss = data[1]['hc_ss']
hc_bg = data[1]['hc_bg']
bgpar = data[1]['bgpar']
sspar = data[1]['sspar']
sspar = sings.all_sspars(fobs_gw_cents=fobs_gw_cents, sspar=sspar)
print(f"{sings.par_names=}")
nfreqs, nreals, nloudest = hc_ss.shape
print(f"{nfreqs=}, {nreals=}, {nloudest=},")

# Dev: total mass vs. final comoving distance (background)

In [None]:
# Quick look at the background of the reference sample: 2D histogram of total
# mass [Msol] vs. final comoving distance [Mpc].
mtot = bgpar[0].flatten()/MSOL
dcom = bgpar[np.where(sings.par_names=='dcom_final')].flatten()/MPC

nbins=30
# NaN-aware limits for both axes so a stray NaN cannot turn the edges into NaN
mt_edges = np.geomspace(np.nanmin(mtot), np.nanmax(mtot), nbins)
print(np.min(mtot), np.max(mtot))
print(mt_edges)
dc_edges = np.geomspace(np.nanmin(dcom), np.nanmax(dcom), nbins)
hist, mtbins, dcbins = np.histogram2d(mtot, dcom, bins=(mt_edges, dc_edges))

MT, DC = np.meshgrid(mtbins, dcbins)

fig, ax = plot.figax(xlabel=sings.par_labels[0],
                     ylabel=sings.par_labels[4])
# histogram2d returns hist[i_mtot, j_dcom], while pcolormesh(X, Y, C) expects
# C[row=y (dcom), col=x (mtot)]; transpose, otherwise the square grid hides the swap.
ax.pcolormesh(MT, DC, hist.T)
fig.tight_layout()

# Plotting functions

In [None]:
def _append_freqs_to_pars(data=data[1], debug=False, short_labels=True):
    """Return one sample's parameter arrays with GW frequency appended as a final parameter.

    Parameters
    ----------
    data : dict
        One element of a loaded ``data`` array, with keys 'hc_ss', 'hc_bg',
        'sspar', 'bgpar' and (optionally) 'fobs_cents'.
        NOTE: the default is bound once, when this cell runs.
    debug : bool
        If True, print the resulting names, units and labels.
    short_labels : bool
        Use compact labels for d_com and frequency. The frequency label is then
        in nHz, and its unit factor is switched to Hz -> nHz to match.

    Returns
    -------
    dict
        'sspar', 'bgpar' : parameter arrays with the frequency of each entry
        stacked as the last row along axis 0;
        'par_names', 'par_units', 'par_labels' : per-parameter metadata, where
        ``value * par_units[i]`` gives the value in the units of ``par_labels[i]``.
    """
    hc_ss = data['hc_ss']
    hc_bg = data['hc_bg']
    # Prefer the frequency-bin centers stored with this sample; fall back to the
    # notebook-level ones from the "Load Data" cell.
    fobs_cents = data['fobs_cents'] if 'fobs_cents' in data else fobs_gw_cents

    sspar = sings.all_sspars(fobs_cents, data['sspar'])
    bgpar = data['bgpar']

    # Repeat each bin center over the trailing (realization / loudest) axes so it
    # lines up element-for-element with hc_ss / hc_bg, then stack as one more row.
    ss_freqs = np.repeat(fobs_cents, hc_ss[0].size).reshape(hc_ss.shape)
    bg_freqs = np.repeat(fobs_cents, hc_bg[0].size).reshape(hc_bg.shape)
    sspar = np.concatenate([sspar, ss_freqs[np.newaxis, ...]], axis=0)
    bgpar = np.concatenate([bgpar, bg_freqs[np.newaxis, ...]], axis=0)

    par_names = np.append(sings.par_names, 'freqs')
    par_units = np.append(sings.par_units, YR)  # Hz -> 1/yr
    # object dtype so replacing a label below can never be silently truncated
    par_labels = np.append(sings.par_labels,
                           r'GW Frequency $f_\mathrm{obs}\ \mathrm{yr}^{-1}$').astype(object)

    if short_labels:
        par_labels[4] = r'$d_\mathrm{com}$ (Mpc)'
        par_labels[-1] = r'GW $f_\mathrm{obs}$ (nHz)'
        # The short label is in nHz, so scale Hz -> nHz (this also matches the
        # nHz `freq_edges` used for histogramming).
        par_units[-1] = 1e9

    if debug: print(f"{par_names=}\n{par_units=}\n{par_labels=}")

    pardat = dict(sspar=sspar, bgpar=bgpar, par_names=par_names, par_units=par_units, par_labels=par_labels)

    return pardat

In [None]:
# Shape of the reference sample's single-source parameter array.
print(f"{sspar.shape}")

# Get bin edges

In [None]:
# Bin edges shared by the number-density plots below.
mtot_edges = holo.sams.Semi_Analytic_Model().mtot /MSOL   # [Msol]
freq_edges = fobs_gw_edges*10**9                           # Hz -> nHz
print(freq_edges)

# Some single sources have negative final comoving distances (see the section
# on negative distances below), and np.geomspace cannot span zero, so only
# positive (non-NaN) distances set the range.
dcom_all = np.concatenate([np.ravel(bgpar[4]), np.ravel(sspar[4])])
dcom_pos = dcom_all[dcom_all > 0]
dcom_min = np.min(dcom_pos)
dcom_max = np.max(dcom_pos)
print(f"{dcom_min/MPC=}, {dcom_max/MPC=}")
dcom_edges = np.geomspace(dcom_min, dcom_max, 30)/MPC

In [None]:
def draw_2D_hist(ax, xx, yy, xx_edges, yy_edges, cmap='viridis'):
    """Draw a log10-count 2D histogram of (xx, yy) on `ax`, with a colorbar.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    xx, yy : array_like
        Samples for the x and y axes; flattened, so any matching shapes work.
    xx_edges, yy_edges : array_like
        Bin edges along x and y.
    cmap : str
        Colormap name.

    Returns
    -------
    im, cbar
        The QuadMesh returned by `pcolormesh` and its colorbar.
    """
    # histogram2d(y, x) puts y on axis 0 (rows) and x on axis 1 (columns),
    # which is the C[row, col] layout pcolormesh expects.
    counts, yy_bins, xx_bins = np.histogram2d(np.ravel(yy), np.ravel(xx),
                                              bins=(yy_edges, xx_edges))
    xgrid, ygrid = np.meshgrid(xx_bins, yy_bins)
    # np.ma.log10 masks empty bins instead of emitting -inf plus a divide warning
    im = ax.pcolormesh(xgrid, ygrid, np.ma.log10(counts), cmap=cmap)
    cbar = plt.colorbar(im, ax=ax)
    return im, cbar

def plot_number_densities(data, xx_idx = [0], yy_idx=[4,7], nbins=25,
                          ylim0=None, ylim1=None, xlim=None,
                          xx_edges=None, yy_edges=None):
    """Plot 2D number-density histograms for the background and the single sources.

    Column 0 shows the h_c^2-weighted background (`bgpar`) and column 1 the single
    sources (`sspar`). Row ``i`` plots parameter ``yy_idx[i]`` against ``xx_idx[0]``,
    after GW frequency has been appended as the last parameter by
    `_append_freqs_to_pars`.

    Parameters
    ----------
    data : dict
        One element of a loaded ``data`` array.
    xx_idx : list of int
        One-element list with the x-axis parameter index (default: total mass).
    yy_idx : list of int
        Y-axis parameter index for each row (default: final comoving distance, frequency).
    nbins : int
        Unused (binning comes from `xx_edges` / `yy_edges`); kept for compatibility.
    ylim0, ylim1 : tuple or None
        Y limits of the first and second rows.
    xlim : tuple or None
        X limits, applied to the bottom row (x is shared).
    xx_edges : array_like or None
        X bin edges; defaults to the global ``mtot_edges``.
    yy_edges : sequence of array_like or None
        Y bin edges, one per entry of `yy_idx`; defaults to the globals
        ``[dcom_edges, freq_edges]``, which match the default `yy_idx`.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If the number of y bin-edge arrays does not match ``len(yy_idx)``.
    """
    if xx_edges is None:
        xx_edges = mtot_edges
    if yy_edges is None:
        yy_edges = [dcom_edges, freq_edges]
    if len(yy_edges) != len(yy_idx):
        raise ValueError(f"need one y bin-edge array per yy_idx entry, "
                         f"got {len(yy_edges)} for {len(yy_idx)}")

    # add frequencies to par arrays
    pardat = _append_freqs_to_pars(data)
    par_units = pardat['par_units']
    par_labels = pardat['par_labels']
    sspar = pardat['sspar']
    bgpar = pardat['bgpar']

    nrows, ncols = len(yy_idx), 2
    xlabel = par_labels[xx_idx][0]
    ylabels = par_labels[yy_idx]

    fig, axs = plot.figax_double(nrows=nrows, ncols=ncols,
                                 sharex=True, figsize=(7,6))
    # index as axs[row, col] even when only one row is requested
    axs = np.atleast_2d(axs)
    axs[0, 0].set_title('$h_c^2$-weighted background', fontsize=10)
    axs[0, 1].set_title('single sources', fontsize=10)
    for row, label in enumerate(ylabels):
        axs[row, 0].set_ylabel(label, fontsize=10)
    if ylim0 is not None:
        axs[0, 0].set_ylim(ylim0)
    if ylim1 is not None and nrows > 1:
        axs[1, 0].set_ylim(ylim1)
    for ax in axs[-1, :]:
        ax.set_xlabel(xlabel, fontsize=10)
        if xlim is not None:
            ax.set_xlim(xlim)
    for row in range(nrows):
        axs[row, 1].sharey(axs[row, 0])

    # draw background
    xx = bgpar[xx_idx] * par_units[xx_idx]
    for row, par in enumerate(yy_idx):
        yy = bgpar[par] * par_units[par]
        draw_2D_hist(axs[row, 0], xx, yy, xx_edges=xx_edges, yy_edges=yy_edges[row])

    # draw single sources
    xx = sspar[xx_idx] * par_units[xx_idx]
    for row, par in enumerate(yy_idx):
        yy = sspar[par] * par_units[par]
        # patch: some single sources have negative final values; pin them to the
        # smallest positive value so they remain on the log-spaced grid.
        # `yy` is a fresh array, so pardat is not modified.
        negative = yy < 0
        if np.any(negative):
            positive = yy > 0
            yy[negative] = np.min(yy[positive]) if np.any(positive) else np.nan
        draw_2D_hist(axs[row, 1], xx, yy,
                     xx_edges=xx_edges, yy_edges=yy_edges[row], cmap='inferno')

    fig.tight_layout()
    return fig

fig = plot_number_densities(data[1])

# Problem with negative final comoving distances and redshifts

In [None]:


# For each suspicious final parameter, report how many single sources came out
# negative and list the offending values (distances shown in Mpc). After the loop
# `idx` holds the dcom_final index, which the next cell uses.
for _par, _msg, _scale in (('redz_final', 'out of %d sources have z_final<0', None),
                           ('dcom_final', 'out of %d sources have dcom_final<0', MPC)):
    idx = np.where(sings.par_names == _par)
    _vals = sspar[idx]
    _is_neg = _vals < 0
    print(np.sum(_is_neg), _msg % sspar[0].size)
    _neg_vals = _vals[_is_neg]
    print(_neg_vals if _scale is None else _neg_vals / _scale)

In [None]:
# Summary statistics of the single-source final comoving distances [Mpc]; the index
# is looked up here so the cell does not depend on `idx` left over from another cell.
print(holo.utils.stats(sspar[np.where(sings.par_names == 'dcom_final')]/MPC))

In [None]:
# Count single sources with a negative total mass and list them in solar masses.
idx = np.where(sings.par_names=='mtot')
print(np.sum(sspar[idx]<0), 'out of %d sources have mtot<0' % sspar[0].size)
# masses convert with MSOL (as in `mtot / MSOL` elsewhere), not with a distance unit
print((sspar[idx])[sspar[idx]<0]/MSOL)

# Plots

# hard_time

In [None]:
TARGET='hard_time'
# 'data' / 'params' hold per-sample dicts, hence allow_pickle; only load trusted files.
# The context manager closes the file even if a key lookup fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % TARGET, allow_pickle=True) as npz:
    data = npz['data']
    params = npz['params']

# Mass-vs-frequency panels are drawn in the "Just the Mass vs Frequency" section:
# `plot_mass_freq` is defined in a later cell, so calling it here would fail on
# Restart & Run All.
for ii in range(len(data)):
    fig = plot_number_densities(data[ii])
    fig.suptitle('%s = %.2e' % (TARGET, params[ii][TARGET]), fontsize=10)
    fig.tight_layout()

# gsmf_phi0

In [None]:
TARGET='gsmf_phi0'
# 'data' / 'params' hold per-sample dicts, hence allow_pickle; only load trusted files.
# The context manager closes the file even if a key lookup fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % TARGET, allow_pickle=True) as npz:
    data = npz['data']
    params = npz['params']

# one number-density figure per sample, titled with the varied parameter value
for ii in range(len(data)):
    fig = plot_number_densities(data[ii])
    fig.suptitle('%s = %.2e' % (TARGET, params[ii][TARGET]), fontsize=10)
    fig.tight_layout()

# gsmf_mchar0_log10

In [None]:
TARGET='gsmf_mchar0_log10'
# 'data' / 'params' hold per-sample dicts, hence allow_pickle; only load trusted files.
# The context manager closes the file even if a key lookup fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % TARGET, allow_pickle=True) as npz:
    data = npz['data']
    params = npz['params']

# one number-density figure per sample, titled with the varied parameter value
for ii in range(len(data)):
    fig = plot_number_densities(data[ii])
    fig.suptitle('%s = %.2e' % (TARGET, params[ii][TARGET]), fontsize=10)
    fig.tight_layout()

# mmb_mamp_log10

In [None]:
TARGET='mmb_mamp_log10'
# 'data' / 'params' hold per-sample dicts, hence allow_pickle; only load trusted files.
# The context manager closes the file even if a key lookup fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % TARGET, allow_pickle=True) as npz:
    data = npz['data']
    params = npz['params']

# one number-density figure per sample, titled with the varied parameter value
for ii in range(len(data)):
    fig = plot_number_densities(data[ii])
    fig.suptitle('%s = %.2e' % (TARGET, params[ii][TARGET]), fontsize=10)
    fig.tight_layout()

# mmb_scatter_dex

In [None]:
TARGET='mmb_scatter_dex'
# 'data' / 'params' hold per-sample dicts, hence allow_pickle; only load trusted files.
# The context manager closes the file even if a key lookup fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % TARGET, allow_pickle=True) as npz:
    data = npz['data']
    params = npz['params']

# one number-density figure per sample, titled with the varied parameter value
for ii in range(len(data)):
    fig = plot_number_densities(data[ii])
    fig.suptitle('%s = %.2e' % (TARGET, params[ii][TARGET]), fontsize=10)
    fig.tight_layout()

# hard_gamma_inner

In [None]:
TARGET='hard_gamma_inner'
# 'data' / 'params' hold per-sample dicts, hence allow_pickle; only load trusted files.
# The context manager closes the file even if a key lookup fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % TARGET, allow_pickle=True) as npz:
    data = npz['data']
    params = npz['params']

# one number-density figure per sample, titled with the varied parameter value
for ii in range(len(data)):
    fig = plot_number_densities(data[ii])
    fig.suptitle('%s = %.2e' % (TARGET, params[ii][TARGET]), fontsize=10)
    fig.tight_layout()

# Just the mass vs. frequency

In [None]:
TARGET='hard_time'
# 'data' / 'params' hold per-sample dicts, hence allow_pickle; only load trusted files.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % TARGET, allow_pickle=True) as npz:
    data = npz['data']
    params = npz['params']

def plot_mass_freq(data=data[1], mt_edges=None):
    """Plot total mass vs. GW frequency histograms for one sample.

    Left panel: h_c^2-weighted background (`bgpar`); right panel: single sources
    (`sspar`). Colors show log10 of the number of entries per bin; empty bins
    are left blank.

    Parameters
    ----------
    data : dict
        One element of a loaded ``data`` array, with keys 'sspar', 'bgpar',
        'fobs_edges' and (optionally) 'fobs_cents'.
        NOTE: the default is bound once, when this cell runs.
    mt_edges : array_like or None
        Total-mass bin edges in solar masses. If None, they are taken from a default
        ``holo.sams.Semi_Analytic_Model()``; pass precomputed edges (e.g.
        ``mtot_edges``) to avoid building a model on every call.

    Returns
    -------
    fig : matplotlib Figure
    """
    # Frequencies are converted Hz -> nHz to match the axis label.
    ff_edges = data['fobs_edges']*1e9
    fobs_cents = data['fobs_cents'] if 'fobs_cents' in data else fobs_gw_cents
    if mt_edges is None:
        mt_edges = holo.sams.Semi_Analytic_Model().mtot /MSOL

    # get masses and frequencies
    sspar = sings.all_sspars(fobs_cents, data['sspar'])
    bgpar = data['bgpar']

    # Masses are flattened in C order; repeating each bin center once per element
    # of the trailing (realization / loudest) axes lines the frequencies up with them.
    ssmtt = sspar[0].flatten()/MSOL
    ssfrq = np.repeat(fobs_cents, sspar[0][0].size)*1e9

    bgmtt = bgpar[0].flatten()/MSOL
    bgfrq = np.repeat(fobs_cents, bgpar[0][0].size)*1e9

    # calculate histograms: frequency on axis 0 (y, rows), mass on axis 1 (x, columns)
    hist_ss, ffe, mte = np.histogram2d(ssfrq, ssmtt, bins=(ff_edges, mt_edges))
    hist_bg, _, _ = np.histogram2d(bgfrq, bgmtt, bins=(ff_edges, mt_edges))

    # plot
    fig, axs = plot.figax(ncols=2, xlabel=r'$M$ [M$_\odot$]', ylabel = '$f$ [nHz]', figsize=(8,4))
    mtgrid, ffgrid = np.meshgrid(mte, ffe)

    ax = axs[0]
    ax.set_title('$h_c^2$-weighted Background')
    # np.ma.log10 masks empty bins instead of producing -inf with a warning
    im = ax.pcolormesh(mtgrid, ffgrid, np.ma.log10(hist_bg), cmap='viridis')
    plt.colorbar(im, ax=ax, label=r'$\log N$', orientation='horizontal', pad=0.2)

    ax = axs[1]
    ax.set_title('Single Sources')
    ax.set_ylabel(None)
    im = ax.pcolormesh(mtgrid, ffgrid, np.ma.log10(hist_ss), cmap='inferno')
    plt.colorbar(im, ax=ax, label=r'$\log N$', orientation='horizontal', pad=0.2)

    return fig

# assign so the returned Figure is not displayed a second time by the cell output
fig = plot_mass_freq(data[1])

# if best is None:
#     fig.suptitle(lib_name)
# else:
#     fig.suptitle(
#         'sample %d:, $t_\mathrm{hard}$=%.1fGyr, $\gamma_\mathrm{inner}$=%.1f, $\Phi_0$=%.1f, '
#         % (best, sample_params[best,0], sample_params[best,-1], sample_params[best,1], ) 
#         + '$\log M_\mathrm{char,0}={%.1f}$, $\log \mu_\mathrm{MMB}=%.1f$, $\sigma_\mathrm{MMB,dex}$=%.1f' 
#         % (sample_params[best,2], sample_params[best,3], sample_params[best,4]))
#     fig.tight_layout()