In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl
import os

from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim


In [None]:
# Parameter varied across the loaded runs, a LaTeX label for it, and the number of
# variations; TARGET and NVARS together select the input file name below.
TARGET = 'hard_time'  # EDIT AS NEEDED
TITLE = r'$\tau_\mathrm{hard}$'  # EDIT AS NEEDED (raw string: '\m' is not a valid escape)
NVARS = 21

In [None]:
# Directory holding the pre-computed parameter-variation results. Set the
# ANATOMY_DATA_PATH environment variable to override the default absolute path.
path = os.environ.get(
    'ANATOMY_DATA_PATH',
    '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_09B',
)
load_data_from_file = os.path.join(path, f'{TARGET}_{NVARS}vars_clbrt_reals.npz')

# Fail early, naming the missing file, rather than inside np.load.
if not os.path.exists(load_data_from_file):
    raise FileNotFoundError(
        f"load data file '{load_data_from_file}' does not exist, you need to construct it.")
print(load_data_from_file)

# Load Data from File

In [None]:
# Load the pre-computed results. allow_pickle=True is needed for object arrays; only use
# it on files you generated yourself -- unpickling untrusted data can execute code.
# The context manager closes the archive even if a key lookup below fails.
with np.load(load_data_from_file, allow_pickle=True) as file:
    print(file.files)
    data = file['data']      # per-variation results (elements indexed by string keys below)
    params = file['params']  # per-variation parameter records, indexed by TARGET below
    dsdat = file['dsdat']    # per-variation detection-statistics results

# Get Shapes, Edges, Frequency Info

In [None]:
# Dimensions of the single-source strain array; only data[0] is inspected, so these
# are assumed identical across variations -- TODO confirm.
hc_ss_shape = data[0]['hc_ss'].shape
nfreqs, nreals, nloudest = hc_ss_shape
# Observed-frequency bin centres (presumably in Hz, holodeck's cgs convention).
fobs_cents = data[0]['fobs_cents']

In [None]:
# Total-mass bin edges from a default SAM's mass grid, converted to solar masses.
sam = holo.sams.Semi_Analytic_Model()
mm_edges = sam.mtot / MSOL

# Frequencies are assumed to be stored in Hz; nHz = Hz * 1e9 (multiply, not divide).
HZ_TO_NHZ = 1e9

# Frequency bin edges [nHz].
ff_edges = data[0]['fobs_edges'] * HZ_TO_NHZ

# One frequency per single source / background entry, repeated so the flat order
# matches sspar[0].flatten() and bgpar[0].flatten() (assumes those are frequency-major).
ssfrq = np.repeat(fobs_cents, nreals*nloudest) * HZ_TO_NHZ  # nHz
bgfrq = np.repeat(fobs_cents, nreals) * HZ_TO_NHZ  # nHz

In [None]:
# Scratch check of log10(28) (about 1.447); printed only.
log10_of_28 = np.log10(28)
print(log10_of_28)

# Plot Mass vs Frequency

In [None]:
# For every parameter variation: histogram single-source and background total masses
# against frequency, then save a two-panel (background | single-source) figure.
# Colour limits are shared across all variations so the panels are comparable.
saveloc = f'/Users/emigardiner/GWs/holodeck/output/figures/params/mass_vs_freq_{TARGET}_{NVARS}vars'
os.makedirs(saveloc, exist_ok=True)  # savefig does not create missing directories
vmin = 0

# A later cell reshapes `ssfrq` to 3D; ravel keeps this cell re-runnable by restoring
# the flat order that matches `sspar[0].flatten()`.
ssfrq_flat = np.ravel(ssfrq)
bgfrq_flat = np.ravel(bgfrq)


def _mass_freq_loghists(dat):
    """Return log10 (single-source, background) 2D histograms of total mass vs frequency.

    Bins are `mm_edges` (mass) x `ff_edges` (frequency). Empty bins become -inf,
    which pcolormesh masks, so they show the black axes facecolor.
    """
    ssmtt = dat['sspar'][0].flatten() / MSOL  # msol
    bgmtt = dat['bgpar'][0].flatten() / MSOL  # msol
    hist_ss, _, _ = np.histogram2d(ssmtt, ssfrq_flat, bins=(mm_edges, ff_edges))
    hist_bg, _, _ = np.histogram2d(bgmtt, bgfrq_flat, bins=(mm_edges, ff_edges))
    with np.errstate(divide='ignore'):  # log10(0) -> -inf is intended
        return np.log10(hist_ss), np.log10(hist_bg)


# Compute each variation's histograms once; reuse them for both colour limits and plots.
log_hists = [_mass_freq_loghists(dat) for dat in data]

# Shared colour ceilings: largest log-count over all variations, floored at 0.
vmax_ss = max([0] + [np.max(log_ss) for log_ss, _ in log_hists])
vmax_bg = max([0] + [np.max(log_bg) for _, log_bg in log_hists])

# Bin-edge grids for pcolormesh (identical for every variation).
ffgrid, mtgrid = np.meshgrid(ff_edges, mm_edges)

for ii, (log_ss, log_bg) in enumerate(log_hists):
    parm = params[ii]
    fig, axs = plot.figax(ncols=2, ylabel=r'$M$ [M$_\odot$]', xlabel='$f$ [nHz]', figsize=(8, 4))

    # (axis, title, log-histogram, colormap, colour ceiling): each panel uses its OWN ceiling.
    panels = (
        (axs[0], '$h_c^2$-weighted Background', log_bg, 'viridis', vmax_bg),
        (axs[1], 'Single Sources', log_ss, 'inferno', vmax_ss),
    )
    for pax, ptitle, plog, pcmap, pvmax in panels:
        pax.set_title(ptitle)
        pax.set_facecolor('k')
        im = pax.pcolormesh(ffgrid, mtgrid, plog, cmap=pcmap, vmin=vmin, vmax=pvmax)
        plt.colorbar(im, ax=pax, label=r'$\log N$', orientation='horizontal', pad=0.2)
    axs[1].set_ylabel(None)

    fig.suptitle(f"{TARGET}={parm[TARGET]:.2f}")
    fig.tight_layout()
    fig.savefig(saveloc + f'/{TARGET}_mvsf_{ii:02d}.png')
    # Figures are saved to disk (view the PNGs); closing avoids piling up open figures.
    plt.close(fig)

# Plot Detection Probability for masses

In [None]:
# Detection-probability comparison for the LAST parameter variation only.
dat = data[-1]
parm = params[-1]
dsd = dsdat[-1]

# Single-source total masses [Msol]; presumably shaped (nfreqs, nreals, nloudest) -- TODO confirm.
ssmtt = dat['sspar'][0]/MSOL
# Matching per-source frequencies, shaped (nfreqs, nreals, nloudest); Hz -> nHz is *1e9.
ssfrq = np.repeat(fobs_cents, nreals*nloudest).reshape(nfreqs, nreals, nloudest) * 1e9

# Single-source detection probability averaged over axis 2
# (presumably sky realizations -- TODO confirm the axis order of 'gamma_ssi').
gamma_ssi = np.mean(dsd['gamma_ssi'], axis=2)
print(gamma_ssi.shape)
# Background detection probability; broadcast along axis 1 below, so assumed shape (nreals,).
dp_bg = dsd['dp_bg']
print(dp_bg.shape)

# Ratio of each single source's detection probability to its realization's background one.
ratio = gamma_ssi/dp_bg[np.newaxis, :, np.newaxis]
# These three shapes must agree for the per-source plot that follows.
print(ratio.shape, ssmtt.shape, ssfrq.shape)

## Normalize for detstats

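Ratio of each single source's detection probability (`gamma_ssi`, averaged over its third axis) to the background detection probability (`dp_bg`) of the same realization, for the last parameter variation, shown as $\log_{10}$ in frequency–mass space.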

In [None]:
# Colour each single source by log10(gamma_ssi / dp_bg): 0 where the two detection
# probabilities are equal, positive where the single-source probability is larger.
with np.errstate(divide='ignore', invalid='ignore'):
    colval = np.log10(ratio)

# ratio == 0 gives -inf (and dp_bg == 0 would give inf/nan); derive colour limits
# from, and plot only, the finite entries so the norm is well defined.
finite = np.isfinite(colval)
vmin = np.min(colval[finite])
vmax = np.max(colval[finite])

# TwoSlopeNorm requires vmin < vcenter < vmax and raises otherwise; fall back to a
# linear norm when every value lies on one side of zero.
if vmin < 0 < vmax:
    norm = mpl.colors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
else:
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

cmap = cm.rainbow
fig, ax = plot.figax(xlabel='$f$ [nHz]', ylabel=r'$M$ [M$_\odot$]')
# pcolormesh only accepts 2D grids, but ssfrq was reshaped to (nfreqs, nreals, nloudest),
# so draw each single source as a point coloured through `norm`/`cmap` instead.
im = ax.scatter(ssfrq[finite], ssmtt[finite], c=colval[finite], cmap=cmap, norm=norm,
                s=4, alpha=0.5)  # marker size / alpha: tune for the number of sources
cbar = plt.colorbar(im, ax=ax, label=r'$\log_{10}(\gamma_\mathrm{ss} / \mathrm{DP}_\mathrm{bg})$')

In [None]:
# Summary of the log10 detection-probability ratios, then the colour limits in use.
colval_summary = holo.utils.stats(colval)
print(colval_summary)
print(*(vmin, vmax))

In [None]:
# Inspect the normalised colour values: `norm` was built from the log10 values in
# `colval`, so apply it to those rather than to the linear `ratio`.
print(norm(colval.flatten()))