In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim


In [None]:
# --- model / sampling configuration ---
SHAPE = None      # passed as `sam_shape` to the library parameter space
NREALS = 30       # number of realizations per model
NFREQS = 40       # number of frequency bins
NLOUDEST = 10     # number of loudest single sources kept (last axis of the SS arrays)

CONSTRUCT = True       # if True, build the models from scratch
JUST_DETSTATS = True   # if True (and not CONSTRUCT), load saved models and recompute detstats

NPARS = 6  # number of values sampled for TARGET (not the number of model parameters)

# --- detection-statistics configuration ---
NPSRS = 40
NSKIES = 25
TARGET = 'hard_time'
# Axis / legend label describing TARGET. It previously described M_char,0 (a leftover from
# a `gsmf_mchar0_log10` run), which mislabeled every figure for TARGET='hard_time'.
# NOTE(review): add units once confirmed against the parameter-space definition.
TITLE = r'$\tau_\mathrm{hard}$'

In [None]:
import os

# NOTE(review): absolute, user-specific path -- make this configurable before sharing.
path = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_09B'
# Models and detstats are cached in a single .npz that is both read and (re)written.
load_data_from_file = os.path.join(path, f'{TARGET}_{NPARS}vars_clbrt_reals.npz')
save_data_to_file = load_data_from_file

# Fail early, before any expensive cell runs, when we are asked to load a file that is missing.
if not CONSTRUCT and not os.path.exists(load_data_from_file):
    raise FileNotFoundError(
        'load data file does not exist, you need to construct it.'
        f' (looked for {load_data_from_file!r})'
    )

# Construct Parameter Space

In [None]:
# Library parameter space. The cells below read its parameter names; `vary_parameter`
# is called without passing this instance in the original compute cell.
pspace = holo.param_spaces.PS_Uniform_09B(
    holo.log,
    nsamples=1,
    sam_shape=SHAPE,
    seed=None,
)

In [None]:
# Parameter names of the library, and a default parameter vector that sits
# half-way (0.5) along every parameter's range.
param_names = pspace.param_names
num_pars = len(param_names)
pars = np.full(num_pars, 0.5)

In [None]:
def vary_parameter(
        target_param,    # the name of the parameter, has to exist in `param_names`
        params_list=(0.0, 0.5, 1.0),  # the values we'll check
        pspace=None,
        nreals=NREALS, nfreqs=NFREQS,
        pars=None, save_dir=None,
        get_ds=False,  # unused; kept so existing calls keep working
        nloudest=NLOUDEST,
        ):
    """Run the SAM model at several values of one parameter, holding the others fixed.

    Parameters
    ----------
    target_param : str
        Name of the parameter to vary; must be in ``pspace.param_names``.
    params_list : iterable of float
        Values written into ``pars`` for ``target_param`` before calling
        ``pspace.normalized_params`` (presumably normalized to [0, 1] -- TODO confirm).
    pspace : parameter-space instance or None
        Library parameter space. If None, a ``PS_Uniform_09B`` is built at call time
        (previously it was built once when the function was defined).
    nreals, nfreqs, nloudest : int
        Passed through to ``holo.librarian.run_model``.
    pars : array-like or None
        Values for all parameters; defaults to 0.5 for each. The caller's array is
        copied, never modified.
    save_dir : str or None
        If given (and at least one model was run), results are saved to an ``.npz``
        file in this directory.
    get_ds : bool
        Unused.

    Returns
    -------
    data : list
        One ``run_model`` output per entry of ``params_list``.
    params : list
        The ``pspace.normalized_params`` output for each entry of ``params_list``.
    """
    if pspace is None:
        pspace = holo.param_spaces.PS_Uniform_09B(holo.log, nsamples=1, sam_shape=SHAPE, seed=None)

    # get the parameter names from this library-space
    param_names = pspace.param_names
    num_pars = len(pspace.param_names)
    print(f"{num_pars=} :: {param_names=}")

    # choose each parameter to be half-way across the range provided by the library
    if pars is None:
        pars = 0.5 * np.ones(num_pars)
    # filename tag reflects the starting parameter vector
    str_pars = str(pars).replace(" ", "_").replace("[", "").replace("]", "")
    # Work on a float copy: the loop overwrites one entry, which would otherwise
    # mutate the caller's array (or be truncated to int for an int-typed array).
    pars = np.array(pars, dtype=float)
    # Choose parameter to vary
    param_idx = param_names.index(target_param)

    data = []
    params = []
    sam = None
    hard_name = None
    for ii, par in enumerate(params_list):
        pars[param_idx] = par
        print(f"{ii=}, {pars=}")
        _params = pspace.normalized_params(pars)
        params.append(_params)
        # construct `sam` and `hard` instances based on these parameters
        sam, hard = pspace.model_for_params(_params, pspace.sam_shape)
        if isinstance(hard, holo.hardening.Fixed_Time_2PL_SAM):
            hard_name = 'Fixed Time'
        elif isinstance(hard, holo.hardening.Hard_GW):
            hard_name = 'GW Only'
        else:
            # previously left undefined, which crashed the save step with NameError
            hard_name = type(hard).__name__
        # run this model, retrieving binary parameters and single-source data
        _data = holo.librarian.run_model(sam, hard, nreals, nfreqs, nloudest=nloudest,
                                         gwb_flag=False, singles_flag=True, params_flag=True, details_flag=True)
        data.append(_data)

    # `sam` is only defined once at least one model has been run
    if save_dir is not None and sam is not None:
        str_shape = str(sam.shape).replace(", ", "_").replace("(", "").replace(")", "")
        filename = save_dir+'/%s_p%s_s%s.npz' % (target_param, str_pars, str_shape)
        np.savez(filename, data=data, params=params, hard_name=hard_name, shape=sam.shape, target_param=target_param)
        print('saved to %s' % filename)

    return (data, params)

## Calculate Data (if CONSTRUCT) and Individual DetStats (if CONSTRUCT or JUST_DETSTATS)

In [None]:
# Run (or load) the models, then compute calibrated detection statistics for each model.
# NPSRS and JUST_DETSTATS come from the configuration cell. They were previously
# re-assigned here, which silently overrode the config values.
if JUST_DETSTATS or CONSTRUCT:
    # get data (not dsdat_)
    if CONSTRUCT:
        params_list = np.linspace(0, 1, NPARS)
        # vary_parameter returns (data, params); unpacking three values raised ValueError
        data, params = vary_parameter(target_param=TARGET, params_list=params_list,
                                      pspace=pspace, get_ds=False)
    else:
        # allow_pickle is required for the stored lists of dicts -- only load files you created.
        with np.load(load_data_from_file, allow_pickle=True) as file:
            print(file.files)
            data = file['data']
            params = file['params']

    # Assumes every model shares one frequency grid (the original already used a single
    # model's grid for all); use the first model so a single-model run also works.
    fobs = data[0]['fobs_cents']
    dur = 1.0/fobs[0]          # duration from the lowest frequency bin center
    cad = 1.0/(2*fobs[-1])     # cadence from the highest frequency bin center (Nyquist)

    # get dsdat for each data/param
    dsdat = []
    for ii, _data in enumerate(data):
        print(f"{ii=}")

        # get strain info
        hc_bg = _data['hc_bg']
        hc_ss = _data['hc_ss']

        # get calibrated sigmas for each realization
        sigmas, avg_dps, std_dps = detstats.calibrate_every_real(hc_bg, fobs, NPSRS, maxtrials=1)

        # per-realization containers for the calibrated detstats
        dp_ss = np.zeros((NREALS, NSKIES))
        dp_bg = np.zeros(NREALS)
        snr_ss = np.zeros((NFREQS, NREALS, NSKIES, NLOUDEST))
        snr_bg = np.zeros(NREALS)
        gamma_ssi = np.zeros((NFREQS, NREALS, NSKIES, NLOUDEST))

        for rr in range(NREALS):
            # get psrs for the given calibrated realization
            psrs = detstats._build_pta(NPSRS, sigmas[rr], dur, cad)
            # use those psrs to calculate this realization's detstats
            _dp_bg, _snr_bg = detstats.detect_bg_pta(psrs, fobs, hc_bg[:, rr:rr+1], ret_snr=True)
            dp_bg[rr], snr_bg[rr] = _dp_bg.squeeze(), _snr_bg.squeeze()
            # NOTE(review): NSKIES is not passed to detect_ss_pta; the preallocated shapes
            # assume its internal sky count equals NSKIES -- TODO confirm.
            _dp_ss, _snr_ss, _gamma_ssi = detstats.detect_ss_pta(
                psrs, fobs, hc_ss[:, rr:rr+1], hc_bg[:, rr:rr+1], ret_snr=True)
            dp_ss[rr] = _dp_ss.squeeze()
            snr_ss[:, rr] = _snr_ss.squeeze()
            gamma_ssi[:, rr] = _gamma_ssi.squeeze()

        ev_ss = detstats.expval_of_ss(gamma_ssi)
        df_ss, df_bg = detstats.detfrac_of_reals(dp_ss, dp_bg)
        dsdat.append({
            'dp_ss': dp_ss, 'snr_ss': snr_ss, 'gamma_ssi': gamma_ssi,
            'dp_bg': dp_bg, 'snr_bg': snr_bg,
            # df_ss was computed but previously not stored
            'df_ss': df_ss, 'df_bg': df_bg, 'ev_ss': ev_ss,
        })

    np.savez(save_data_to_file, data=data, dsdat=dsdat, params=params)

In [None]:
# Quick look at the detection statistics of (up to) the first three models.
for ds in dsdat[:3]:
    for key in ('dp_bg', 'dp_ss', 'ev_ss'):
        print(f"holo.utils.stats(ds['{key}'])={holo.utils.stats(ds[key])!r}")

In [None]:
# if CONSTRUCT:
#     params_list = np.linspace(0,1,NPARS)
#     print(params_list)
#     data, params, dsdat = vary_parameter('gsmf_mchar0_log10', params_list=params_list, get_ds=True)
# else:
#     file = np.load(filename, allow_pickle=True)
#     print(file.files)
#     data = file['data']
#     params = file['params']
#     dsdat = file['dsdat']

In [None]:
# def_draw_ev(ax):
# plt.rcParams['mathtext.fontset'] = "cm"
# plt.rcParams["font.family"] = "serif"

def draw_skies_vs_bg(ax, skies_ss, dp_bg, label=None,
                     color='k', mean=True):
    """Draw a per-realization single-source statistic against background detection probability.

    Parameters
    ----------
    ax : matplotlib Axes (anything with an ``errorbar`` method)
    skies_ss : array
        Single-source statistic; the last axis is reduced over (presumably skies,
        i.e. shape (R, S) -- TODO confirm for every caller).
    dp_bg : array, shape (R,)
        Background detection probability per realization (x-values).
    label, color : passed to ``ax.errorbar``.
    mean : bool
        If True, the y-values are the mean over the last axis; otherwise the median.
        (Previously both branches used the mean, so ``mean=False`` had no effect.)

    Returns
    -------
    The object returned by ``ax.errorbar``.
    """
    xx = dp_bg
    center = np.mean if mean else np.median
    yy = center(skies_ss, axis=-1)
    yerr = np.std(skies_ss, axis=-1)  # error bars: standard deviation over the last axis

    hh = ax.errorbar(xx, yy, yerr, color=color, label=label,
                     linestyle='', capsize=3, marker='o', alpha=0.5)
    return hh


# def plot_evss_vs_dpbg(dsdat):
#     colors = cm.rainbow_r(np.linspace(0, 1, len(dsdat)))

#     fig, ax = plot.figax(xlabel='Background Detection Probability', 
#                          ylabel='$\langle$ Single Source Detections $\\rangle_\mathrm{skies}$')
#     for ii, ds in enumerate(dsdat):
#         draw_skies_vs_bg(ax, ds['ev_ss'], ds['dp_bg'], color=colors[ii])
#     # xx = fobs_cents*YR
#     # yy = dsdat['ev_ss']

#     return fig

def plot_dpss_vs_dpbg(dsdat, params, target_param=TARGET, use_ev=True, title=TITLE):
    """Plot sky-averaged single-source statistics against background detection probability.

    One colored errorbar series is drawn per model in ``dsdat``; its legend label is that
    model's value of ``target_param``.

    Parameters
    ----------
    dsdat : sequence of dict
        Per-model detstats with keys 'dp_bg' and 'ev_ss' or 'dp_ss'.
    params : sequence of mapping
        Per-model parameter values; ``params[ii][target_param]`` labels series ``ii``.
    target_param : str
        Parameter used for the legend labels.
    use_ev : bool
        If True, plot the expected number of single-source detections ('ev_ss');
        otherwise the single-source detection probability ('dp_ss').
    title : str
        Legend title.

    Returns
    -------
    The matplotlib Figure.
    """
    colors = cm.rainbow_r(np.linspace(0, 1, len(dsdat)))
    sslabel = 'Expected Number' if use_ev else 'Detection Probability'
    ss_key = 'ev_ss' if use_ev else 'dp_ss'
    fig, ax = plot.figax(
        xlabel='Background Detection Probability',
        # a stray '}' after %s used to be rendered literally in the label
        ylabel=r'$\langle$ Single Source %s $\rangle_\mathrm{skies}$' % sslabel,
    )
    handles = []
    for ii, ds in enumerate(dsdat):
        label = '%.2f' % params[ii][target_param]
        hh = draw_skies_vs_bg(ax, ds[ss_key], ds['dp_bg'], color=colors[ii], label=label)
        handles.append(hh)
    ax.legend(handles=handles, loc='lower left', title=title,
              ncols=4, title_fontsize=12)

    fig.tight_layout()
    return fig

# Annotation listing the parameter values of the first model.
# NOTE(review): TARGET varies across models, so a TARGET entry here only reflects params[0];
# the 2.5 is a hard-coded value -- TODO confirm which hardening index it represents.
text = r'GSMF: $\psi_0=%.2f, m_{\phi,0}=%.2f$' % (params[0]['gsmf_phi0'], params[0]['gsmf_mchar0_log10'])
# '%2f' (width 2, six decimals) was a typo for '%.2f'
text += '\n' + r'MMB: $\mu = %.2f, \epsilon_\mu=%.2f$ dex' % (params[0]['mmb_mamp_log10'], params[0]['mmb_scatter_dex'])
# hard_time goes with the \tau_hard label (it was previously printed as \gamma_inner)
text += '\n' + r'$da/dt: \tau_\mathrm{hard}=%.2f, \gamma_\mathrm{inner}=%.2f$' % (params[0]['hard_time'], 2.5)
print(text)

param_path = '/Users/emigardiner/GWs/holodeck/output/figures/params'

fig1 = plot_dpss_vs_dpbg(dsdat, params, use_ev=True)
ax1 = fig1.axes[0]
ax1.text(0.99, 0.01, text, transform=ax1.transAxes, verticalalignment='bottom', horizontalalignment='right')
fig1.tight_layout()
# fig1.savefig(param_path+f'/ss_vs_bg_w_{TARGET}_{NPARS}vars.png')

fig2 = plot_dpss_vs_dpbg(dsdat, params, use_ev=False)
ax2 = fig2.axes[0]
# position in fig2's own axes coordinates (this previously used fig1's transform)
ax2.text(0.99, 0.99, text, transform=ax2.transAxes, verticalalignment='top', horizontalalignment='right')
fig2.tight_layout()

In [None]:

# Ratio of the sky-averaged single-source detection probability to the background
# detection probability, for every realization, as a function of TARGET.
xlabel = TITLE
# The plotted quantity is DP_SS / DP_BG (the old gamma_SS/gamma_BG label was wrong).
ylabel = r'$DP_\mathrm{SS} / DP_\mathrm{BG}$'
colors = cm.rainbow_r(np.linspace(0, 1, len(params)))

fig, ax = plot.figax(xscale='linear',)
ax.set_ylabel(ylabel, fontsize=14)
ax.set_xlabel(xlabel, fontsize=14)

target = TARGET
xx = []
yy = []
for ii, par in enumerate(params):
    xval = par[target]
    # NOTE(review): realizations with dp_bg == 0 produce inf/nan here
    ratio = np.mean(dsdat[ii]['dp_ss'], axis=-1) / dsdat[ii]['dp_bg']
    xx.append(xval)
    yy.append(ratio)
    # one scatter call per model instead of one per realization
    ax.scatter(np.full(len(ratio), xval), ratio, color=colors[ii], alpha=0.5)
xx = np.array(xx)
yy = np.array(yy)
print(yy.shape)
plot.draw_med_conf(ax, xx, yy)

The same plot could be made with the points colored by another model parameter, or with a different quantity shown instead.

In [None]:
# Sanity check: the configured number of realizations, then the snr_ss shape of the first model.
for item in (NREALS, dsdat[0]['snr_ss'].shape):
    print(item)

# Get Mass Normalization (mass of the highest-SNR single source, for the colorbar)

In [None]:
# Get SS mass normalization.
# For each model and realization, record the `sspar[0]` value (presumably total mass --
# confirm) of the single source with the highest SNR over all frequencies, skies and
# loudest-source slots. These values set the colorbar normalization in the next cell.
mt_hisnr = np.zeros((len(data), NREALS))
for ii, dat in enumerate(data):
    snr_ss = dsdat[ii]['snr_ss']   # (F, R, S, L)
    ssmtot = dat['sspar'][0]       # indexed as [freq, real, loudest] below
    for rr in range(snr_ss.shape[1]):
        snr_r = snr_ss[:, rr]      # (F, S, L) for this realization
        # use the array's own shape instead of the NFREQS/NSKIES/NLOUDEST globals
        fidx, sidx, lidx = np.unravel_index(np.argmax(snr_r), snr_r.shape)
        mt_hisnr[ii, rr] = ssmtot[fidx, rr, lidx]

mt_hisnr = np.log10(mt_hisnr/MSOL)
vmin = np.min(mt_hisnr)
vmax = np.max(mt_hisnr)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

print(f"{vmin=}, {vmax=}")
print(mt_hisnr.shape)

# Plot with mass colorbar

In [None]:
# Same ratio plot as above, but with the sky-averaged expected number of single-source
# detections (EV_SS) in the numerator. Points are colored by the log mass of the
# highest-SNR single source in each realization (`mt_hisnr`, `norm` from the previous cell).
xlabel = TITLE
ylabel = r'$EV_\mathrm{SS} / DP_\mathrm{BG}$'
cmap = cm.rainbow

fig, ax = plot.figax(xscale='linear',)
ax.set_ylabel(ylabel, fontsize=14)
ax.set_xlabel(xlabel, fontsize=14)

target = TARGET
xx = []
yy = []
for ii, par in enumerate(params):
    xval = par[target]
    # NOTE(review): realizations with dp_bg == 0 produce inf/nan here
    ratio = np.mean(dsdat[ii]['ev_ss'], axis=-1) / dsdat[ii]['dp_bg']
    xx.append(xval)
    yy.append(ratio)
    # one scatter call per model; each point gets its own mass color
    ax.scatter(np.full(len(ratio), xval), ratio,
               color=cmap(norm(mt_hisnr[ii, :len(ratio)])), alpha=0.5)
xx = np.array(xx)
yy = np.array(yy)
print(yy.shape)
plot.draw_med_conf(ax, xx, yy)
# A bare ScalarMappable belongs to no Axes, so `ax=` tells colorbar where to take space from.
# The label's parenthesis is now balanced.
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax,
                    label=r'$\log(M_\mathrm{max\ SNR} / \mathrm{M}_\odot)$')
fig.tight_layout()

figloc = '/Users/emigardiner/GWs/holodeck/output/figures/params'
# fig.savefig(figloc+'/dpratio_vs_%s_w_mass_%dvars.png' % (TARGET, NPARS))