In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim


In [None]:
# ---- Data-set dimensions ----
SHAPE = None      # passed as `sam_shape` to the parameter space and embedded in the data file name
NREALS = 100      # number of realizations (R axis) per parameter variation
NFREQS = 40       # number of frequency bins (F axis)
NLOUDEST = 10     # size of the loudest-source axis (L axis)

# ---- Run options ----
CONSTRUCT = False      # if False, the input .npz file must already exist (checked in the next cell)
JUST_DETSTATS = False  # not referenced in the cells below
SAVEFIG = True         # write figures to disk
TOL=0.01               # not referenced in the cells below
MAXBADS=5              # not referenced in the cells below

NVARS = 21  # number of parameter variations of TARGET

NPSRS = 40   # not referenced in the cells below; presumably the number of pulsars — confirm
NSKIES = 100 # number of sky realizations (S axis) per realization
# TARGET = 'hard_time' # EDIT AS NEEDED
# TITLE = r'$\tau_\mathrm{hard}$'   # EDIT AS NEEDED


TARGET = 'gsmf_mchar0_log10' # parameter varied across variations (EDIT AS NEEDED)
# Raw string: '\m' and '\o' are invalid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+); the value is unchanged.
TITLE = r'GSMF: log($M_\mathrm{char,0} / \mathrm{M}_\odot$)'   # EDIT AS NEEDED

In [None]:
import os

# Directory holding the pre-computed anatomy_09B outputs.
# NOTE(review): absolute, machine-specific path — TODO make configurable.
path = '/Users/emigardiner/GWs/holodeck/output/anatomy_09B'
load_data_from_file = path+f'/{TARGET}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}.npz' 
# save_data_to_file = path+f'/{TARGET}_v{NVARS}vars_clbrt_pta.npz' 

# When not constructing the data here, fail early (with the offending path)
# if the pre-computed file is missing.
if not CONSTRUCT and not os.path.exists(load_data_from_file):
    raise FileNotFoundError(
        'load data file does not exist, you need to construct it.'
        f' (missing: {load_data_from_file})'
    )
print(load_data_from_file)
# print(save_data_to_file)

## Get Param Names

In [None]:
# Single-sample instance of the 09B parameter space (no fixed seed); this cell
# only extracts its parameter names.
pspace = holo.param_spaces.PS_Uniform_09B(
    holo.log,
    nsamples=1,
    sam_shape=SHAPE,
    seed=None,
)
param_names = pspace.param_names

## Load Data

In [None]:
# Load the pre-computed results.  The arrays are indexed with string keys below
# (e.g. params[ii][key], dsdat[ii]['dp_bg']), so they presumably hold pickled
# Python objects, hence allow_pickle=True — only load files you trust.
# The context manager closes the file even if one of the key lookups raises.
with np.load(load_data_from_file, allow_pickle=True) as file:
    print(file.files)
    data = file['data']
    params = file['params']
    dsdat = file['dsdat']


# Plot

In [None]:

def draw_skies_vs_bg(ax, skies_ss, dp_bg, label=None,
                     color='k', mean=True):
    """Plot a per-realization single-source statistic against the background DP.

    Parameters
    ----------
    ax : matplotlib Axes (anything providing ``errorbar``)
        Axes to draw on.
    skies_ss : array_like
        Single-source statistic; it is reduced over its last axis (skies), so
        e.g. shape (R, S) gives one point per realization R.
    dp_bg : array_like, shape (R,)
        Background detection probability per realization; used as x.
    label, color :
        Forwarded to ``ax.errorbar``.
    mean : bool
        Central value is the mean over the last axis if True, the median otherwise.
        The error bar is always the standard deviation over the last axis.

    Returns
    -------
    The container returned by ``ax.errorbar``.
    """
    xx = dp_bg  # shape (R,)
    # Previously both branches used np.mean, making ``mean=False`` a no-op.
    yy = np.mean(skies_ss, axis=-1) if mean else np.median(skies_ss, axis=-1)  # shape (R,)
    yerr = np.std(skies_ss, axis=-1)  # shape (R,)

    return ax.errorbar(xx, yy, yerr, color=color, label=label,
                       linestyle='', capsize=3, marker='o', alpha=0.5)

def plot_dpss_vs_dpbg(dsdat, params, target_param=TARGET, use_ev=True, title=TITLE):
    """Plot the single-source statistic against background detection probability.

    One errorbar series (via ``draw_skies_vs_bg``) is drawn per entry of ``dsdat``,
    colored along ``cm.rainbow_r`` and labelled with that entry's ``target_param``.

    Parameters
    ----------
    dsdat : sequence of dict-like
        Per-variation detection statistics; reads 'dp_bg' and either 'ev_ss'
        (``use_ev=True``) or 'dp_ss' (``use_ev=False``).
    params : sequence of dict-like
        Per-variation parameters, in the same order as ``dsdat``.
    target_param : str
        Key of ``params[ii]`` whose value (formatted '%.2f') labels each series.
    use_ev : bool
        Plot 'ev_ss' ("Expected Number") if True, else 'dp_ss' ("Detection Probability").
    title : str
        Legend title.

    Returns
    -------
    The matplotlib Figure.
    """
    colors = cm.rainbow_r(np.linspace(0, 1, len(dsdat)))
    sslabel = 'Expected Number' if use_ev else 'Detection Probability'
    # Raw string avoids invalid escapes; the stray '}' after %s (rendered as
    # literal text) has been removed.
    fig, ax = plot.figax(
        xlabel='Background Detection Probability',
        ylabel=r'$\langle$ Single Source %s $\rangle_\mathrm{skies}$' % sslabel,
    )
    handles = []
    for ii, ds in enumerate(dsdat):
        label = '%.2f' % params[ii][target_param]
        detss = ds['ev_ss'] if use_ev else ds['dp_ss']
        hh = draw_skies_vs_bg(ax, detss, ds['dp_bg'], color=colors[ii], label=label)
        handles.append(hh)
    ax.legend(handles=handles, loc='lower left', title=title,
              ncols=4, title_fontsize=12)

    fig.tight_layout()
    return fig

# Annotation summarizing the parameters of the first variation (params[0]).
# Raw strings avoid invalid escape sequences; '\n' is concatenated separately.
text = r'GSMF: $\psi_0=%.2f, m_{\phi,0}=%.2f$' % (params[0]['gsmf_phi0'], params[0]['gsmf_mchar0_log10'])
# '%.2f' (was '%2f', a width of 2 with default 6-digit precision).
text += '\n' + r'MMB: $\mu = %.2f, \epsilon_\mu=%.2f$ dex' % (params[0]['mmb_mamp_log10'], params[0]['mmb_scatter_dex'])
# hard_time is the quantity labelled \tau_hard (see the commented hard_time TITLE
# in the config cell); it was previously printed under the gamma_inner label.
# NOTE(review): gamma_inner is hard-coded to 2.5 rather than read from params — confirm.
text += '\n' + r'$da/dt: \gamma_\mathrm{inner}=%.2f, \tau_\mathrm{hard}=%.2f$' % (2.5, params[0]['hard_time'])
print(text)


# Figure 1: expected number of single-source detections vs background DP.
fig1 = plot_dpss_vs_dpbg(dsdat, params, use_ev=True)
ax1 = fig1.axes[0]
ax1.text(0.99, 0.01, text, transform=ax1.transAxes,
         verticalalignment='bottom', horizontalalignment='right')
fig1.tight_layout()
if SAVEFIG:
    param_path = '/Users/emigardiner/GWs/holodeck/output/figures/params'
    fig1.savefig(param_path+f'/ss_vs_bg_{TARGET}_{NVARS}vars_clbrtd.png')


# Figure 2: single-source detection probability vs background DP.
fig2 = plot_dpss_vs_dpbg(dsdat, params, use_ev=False)
ax2 = fig2.axes[0]
# Use fig2's own axes transform (fig1's was used previously by mistake).
ax2.text(0.99, 0.99, text, transform=ax2.transAxes,
         verticalalignment='top', horizontalalignment='right')
fig2.tight_layout()

In [None]:

# xlabel=TITLE
# ylabel='$\gamma_\mathrm{SS}/\gamma_\mathrm{BG}$'
# colors = cm.rainbow_r(np.linspace(0, 1, len(params)))


# fig, ax = plot.figax(xscale='linear',)
# ax.set_ylabel(ylabel, fontsize=14)
# ax.set_xlabel(xlabel, fontsize=14)

# target=TARGET
# xx = []
# yy = []
# for ii, par in enumerate(params):
#     xx.append(params[ii][target])
#     dp_bg = dsdat[ii]['dp_bg']
#     dp_ss = np.mean(dsdat[ii]['dp_ss'], axis=-1)
#     yy.append(dp_ss/dp_bg)
#     for rr in range(len(yy[0])):
#         ax.scatter(xx[ii], yy[ii][rr], color=colors[ii], alpha=0.5)
# xx = np.array(xx)
# yy = np.array(yy)
# print(yy.shape)
# # for rr, yi in enumerate(yy):
# #     ax.scatter(xx,yy[:,rr])
# plot.draw_med_conf(ax, xx, yy)
# # for rr in range(len(yy[0])):
# #     ax.scatter(xx,yy[:,rr], alpha=0.5, color=colors[])

TODO: make the same plot, but color the points by another model parameter, or show a different parameter.

# Get Mass Normalization

In [None]:
# # Get SS mass normalization
# use_snr=False
# # ssmtot = []
# # ss_dp = []
# if use_snr:
#     mt_hisnr= np.zeros((NVARS, NREALS))
#     for ii,dat in enumerate(data):
#         # ssmtot.append(dat['sspar'][0])
#         snr_ss = dsdat[ii]['snr_ss'] # (F,R,S,L)
#         # print(snr_ss.shape)
#         ssmtot = dat['sspar'][0]
#         # print(ssmtot.shape)
#         # for each realization, find the frequency with the loudest single source
#         for rr in range(len(snr_ss[0])):
#             argmax = np.argmax(snr_ss[:,rr,:,:])
#             fidx, sidx, lidx = np.unravel_index(argmax, (NFREQS, NSKIES, NLOUDEST))
#             # print(f"{fidx=}, {rr=}, {sidx=}, {lidx=}")
#             mt_hisnr[ii,rr] = ssmtot[fidx, rr, lidx]

#     mt_hisnr = np.log10(mt_hisnr/MSOL)
#     vmin = np.min(mt_hisnr)
#     vmax = np.max(mt_hisnr)
#     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

# # print(f"{vmin=}, {vmax=}")
# # print(mt_hisnr.shape)

# Colorbar Normalization Options

In [None]:

#### use mass of max-DP single source for each realization
norm_by_maxDPmass = True
if norm_by_maxDPmass:
    normcol = np.zeros((NVARS, NREALS))
    for ii, dat in enumerate(data):
        gamma_ssi = dsdat[ii]['gamma_ssi'] # (F,R,S,L)
        ssmtot = dat['sspar'][0] # (F,R,L)
        # for each realization, find the frequency, sky, and loudest with the loudest single source
        for rr in range(NREALS):
            argmax = np.argmax(gamma_ssi[:,rr,:,:])
            ff, ss, ll = np.unravel_index(argmax, (NFREQS, NSKIES, NLOUDEST))
            normcol[ii,rr] = ssmtot[ff,rr,ll]
    normcol = np.log10(normcol/MSOL)

#### could include other colorbar options here, to set normcol values

norm = mpl.colors.Normalize(vmin=np.min(normcol), vmax=np.max(normcol))

In [None]:
print(np.shape(dsdat[0]['dp_bg']))  # background DP array shape, first variation

# Normalize with per-sky maximum-gamma_ssi masses

In [None]:
# Report the shapes of the per-source arrays for the first variation.
for key in ('gamma_ssi', 'ev_ss'):
    print(f"dsdat[0]['{key}'].shape={dsdat[0][key].shape!r}")

In [None]:
# Shape of the `sspar` array for the last variation.  Indexing `data` directly
# (instead of relying on the loop variable `dat` leaked from an earlier cell)
# lets this cell run on a fresh kernel.
print(data[-1]['sspar'].shape)

In [None]:
# For every (variation, realization, sky), find the single source with the
# largest gamma_ssi over frequencies and loudest-source slots, and record
# log10 of its sspar[0] value in solar masses as the color value.
cc = np.zeros((len(data), NREALS, NSKIES))
for ii, dat in enumerate(data):
    gamma_ssi = np.asarray(dsdat[ii]['gamma_ssi'])  # (F,R,S,L)
    ssmtot = np.asarray(dat['sspar'][0])  # (F,R,L)
    nfreqs, nreals, nskies, nloud = gamma_ssi.shape
    # Reorder to (R,S,F,L) and flatten (F,L) in C order, so each row's argmax
    # equals np.argmax(gamma_ssi[:, rr, ss, :]) (same first-occurrence ties).
    flat = np.moveaxis(gamma_ssi, 0, 2).reshape(nreals, nskies, nfreqs * nloud)
    fidx, lidx = np.unravel_index(np.argmax(flat, axis=-1), (nfreqs, nloud))  # each (R,S)
    ridx = np.arange(nreals)[:, np.newaxis]  # broadcasts against (R,S)
    cc[ii] = np.log10(ssmtot[fidx, ridx, lidx] / MSOL)

norm = mpl.colors.Normalize(vmin=np.min(cc), vmax=np.max(cc))

# Plot $EV_\mathrm{SS}/DP_\mathrm{BG}$ vs. the target parameter (mass colorbar currently commented out)

In [None]:
# Shape of the x-values (one per parameter variation) used in the plot below.
# Computed from `params` directly: `xx` is only defined in the next cell, so
# referencing it here failed on Restart & Run All.
print(np.shape([par[TARGET] for par in params]))

In [None]:
# EV_SS / DP_BG versus the varied parameter: median and 50% / 95% intervals
# over all realizations and skies of each variation.
xlabel = TITLE
# ylabel = r'$\gamma_\mathrm{SS}/\gamma_\mathrm{BG}$'
ylabel = r'$EV_\mathrm{SS} / DP_\mathrm{BG}$'
clabel = r'$\log M / [\mathrm{M}_\odot]$'  # only used by the (disabled) colorbar below
cmap = cm.rainbow

fig, ax = plot.figax(xscale='linear',)
ax.set_ylabel(ylabel, fontsize=14)
ax.set_xlabel(xlabel, fontsize=14)

target = TARGET
xx = []
yy = []
for ii, par in enumerate(params):
    xx.append(par[target])
    # dp_bg has one value per realization (R,); broadcast it across skies to
    # divide the (R,S) single-source expected numbers.
    dp_bg = np.asarray(dsdat[ii]['dp_bg'])[:, np.newaxis]
    ev_ss = dsdat[ii]['ev_ss']
    yy.append(ev_ss/dp_bg)

x1 = np.array(xx)
yy = np.array(yy)

xx = np.repeat(x1, NREALS*NSKIES).reshape(NVARS, NREALS, NSKIES)
print(f"{xx.shape=}, {yy.shape=}, {cc.shape=}")
# sax = ax.scatter(xx, yy, c=cc, cmap=cmap, norm=norm, alpha=0.5,)
# cbar = plt.colorbar(sax, ax=ax, label=clabel)

# Sort by x so the median line and the bands are drawn left to right.
order = np.argsort(x1)
col = 'tab:blue'
med = np.percentile(yy, 50, axis=(1, 2))
ax.plot(x1[order], med[order], alpha=0.5, color=col)  # median drawn once (was drawn per band)
for pp in [50, 95]:
    lo, hi = np.percentile(yy, [50-pp/2, 50+pp/2], axis=(1, 2))
    ax.fill_between(x1[order], lo[order], hi[order], color=col, alpha=0.25)
# cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), label='log ($M(\mathrm{max\ EV_\mathrm{SS}}) / [\mathrm{M}_\odot]$)',)
# fig.tight_layout()

# figloc = '/Users/emigardiner/GWs/holodeck/output/figures/params'
# if SAVEFIG: fig.savefig(figloc+'/dpratio_vs_%s_w_mass_%dvars_clbrtd.png' % (TARGET, NVARS))