In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl
import h5py
from tqdm import tqdm
import os
import kalepy as kale


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR, PC
import holodeck as holo

In [None]:
# PTA frequency bins. `cad` is the sampling interval whose Nyquist frequency
# (1 / (2 * cad)) equals the last frequency-bin center.
fobs_cents, fobs_edges = holo.utils.pta_freqs()
cad = 1.0 / (fobs_cents[-1] * 2)
print(cad / YR)  # cadence expressed in units of YR

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    '''
    Build a new colormap from the sub-range [minval, maxval] of an existing one.

    Adapted from https://stackoverflow.com/a/18926541

    Parameters
    ----------
    cmap : str or matplotlib colormap
        Base colormap, or the registered name of one.
    minval, maxval : float
        Fractional positions (0-1) in the base colormap bounding the kept range.
        Passing equal values yields a single-color colormap.
    n : int
        Number of color samples drawn from the base colormap.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
    '''
    base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=base.name, a=minval, b=maxval)
    samples = base(np.linspace(minval, maxval, n))
    return mpl.colors.LinearSegmentedColormap.from_list(label, samples)

# Truncated colormaps: each keeps only a sub-range of its base colormap.
cmap_base = 'magma_r'
magma_r = truncate_colormap(cmap_base, 0, 0.85)
blacks = truncate_colormap('binary', 0.4, 1.0)

# Sequential families keep the range [lower, 1] of the base colormap.
cmap_Blues, cmap_PuBuGn, cmap_Greens, cmap_Oranges, cmap_Purples = (
    truncate_colormap(base_name, lower, 1)
    for base_name, lower in (
        ('Blues', 0.4),
        ('PuBuGn', 0.2),
        ('Greens', 0.4),
        ('Oranges', 0.4),
        ('Purples', 0.4),
    )
)

In [None]:
# --- run configuration ---
# Library layout: used to build the output directory name and to reshape arrays.
SHAPE = None                             # `shape{SHAPE}` tag in the output directory name
NREALS, NFREQS, NLOUDEST = 500, 40, 10   # realizations (R), frequency bins (F), loudest-source axis (L)

# Workflow switches.
BUILD_ARRAYS, SAVEFIG = False, False

# NOTE(review): TOL and MAXBADS are not read by any active cell shown here.
TOL, MAXBADS = 0.01, 5

NVARS = 21   # number of saved parameter variations (alternative run: 6)

NPSRS, NSKIES = 40, 100      # pulsars; sky realizations (S)
RED_GAMMA = RED2WHITE = None  # None selects the white-noise detstats files



In [None]:
def get_var_data(target, var=None, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE,
                 red_gamma=None, red2white=None,
                 path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz'):
    """
    Load saved model data, parameters, and detection statistics for a target's variations.

    Files are looked up under ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}/``:

    * ``data_params.npz`` -- arrays ``data`` and ``params``
    * ``detstats_s{nskies}_ssn_r2w{red2white:.1f}_rg{red_gamma:.1f}.npz`` when both
      ``red_gamma`` and ``red2white`` are given, otherwise
      ``detstats_s{nskies}_ssn_white.npz`` -- array ``dsdat``

    Parameters
    ----------
    target : str
        Name of the varied parameter (directory prefix).
    var : int or None
        Index of the variation to return. ``None`` returns the full arrays.
    nvars, nreals, nskies, shape
        Values encoded in the directory / file names.
    red_gamma, red2white : float or None
        Red-noise settings encoded in the detstats filename. Both must be given to
        select a red-noise file; otherwise the white-noise file is used.
    path : str
        Root output directory.

    Returns
    -------
    data, params, dsdat
        Entries (or full arrays, if ``var`` is None) of the three stored arrays.

    Raises
    ------
    FileNotFoundError
        If either file is missing. This subclasses ``Exception``, so existing
        ``except Exception`` handlers still catch it.
    """
    run_dir = path + f'/{target}_v{nvars}_r{nreals}_shape{str(shape)}'
    data_file = run_dir + '/data_params.npz'
    if red_gamma is not None and red2white is not None:
        noise_tag = f'_r2w{red2white:.1f}_rg{red_gamma:.1f}'
    else:
        noise_tag = '_white'
    dets_file = run_dir + f'/detstats_s{nskies}_ssn{noise_tag}.npz'

    # check both files up front (data first, then dets) so the error names the missing one
    for label, fname in (('data', data_file), ('dets', dets_file)):
        if not os.path.exists(fname):
            raise FileNotFoundError(
                f"load {label} file '{fname}' does not exist, you need to construct it.")

    # allow_pickle: the stored arrays presumably hold pickled Python objects -- only
    # load trusted files. Context managers close the files even if a key is missing.
    with np.load(data_file, allow_pickle=True) as npz:
        data = npz['data']
        params = npz['params']
    with np.load(dets_file, allow_pickle=True) as npz:
        dsdat = npz['dsdat']

    if var is not None:
        data, params, dsdat = data[var], params[var], dsdat[var]
    return data, params, dsdat

# Get Edges

In [None]:
# # get edges
# NBINS = 40
# sam = holo.sams.Semi_Analytic_Model()
# mt_edges = sam.mtot[25:-15]/MSOL
# dc_edges = np.geomspace(1e1, 1e4, NBINS)
# sn_edges = np.geomspace(2.e-6,515, NBINS)

# Get Hist Data

In [None]:
# TARGET = 'hard_time'
# var = None
# data, params, dsdat = get_var_data(target=TARGET, var=0, nskies=NSKIES, nvars=NVARS,
#                         path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'  )
# print(params)

In [None]:
# TARGET = 'hard_time'
# hist_mt = []
# hist_dc = []
# bghist_mt = []
# bghist_dc = []

# for var in [0,-1]:
#     data, params, dsdat = get_var_data(target=TARGET, var=var, nskies=NSKIES, nvars=NVARS,
#                         path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'  )

#     sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
#     bgpar = data['bgpar']

#     # total mass
#     _ssmtt = sspar[0]/MSOL
#     _ssmtt = np.repeat(_ssmtt, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
#     _ssmtt = np.swapaxes(_ssmtt, -2, -1)

#     _bgmtt = bgpar[0]/MSOL # F,R
#     print(f"{_bgmtt.shape=}")

#     # comoving distance
#     _ssdcm = sspar[4]/MPC
#     _ssdcm = np.repeat(_ssdcm, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
#     _ssdcm = np.swapaxes(_ssdcm, -2, -1)

#     _bgdcm = bgpar[4]/MPC

#     # snr
#     _snssi = dsdat['snr_ss']
#     _snrbg = dsdat['snr_bg']
#     _snrbg = np.repeat(_snrbg, NFREQS).reshape(NREALS, NFREQS)
#     _snrbg = np.swapaxes(_snrbg, 0, 1)
    

#     # get total mass histograms
#     _hist_mt, sne, mte, = np.histogram2d(
#         _snssi.flatten(), _ssmtt.flatten(), bins=(sn_edges, mt_edges))
#     hist_mt.append(_hist_mt)

#     _bghist_mt, sne, mte, = np.histogram2d(
#         _snrbg.flatten(), _bgmtt.flatten(), bins=(sn_edges, mt_edges))
#     bghist_mt.append(_bghist_mt)

#     # get comoving distance histograms
#     _hist_dc, sne, dce, = np.histogram2d(
#         _snssi.flatten(), _ssdcm.flatten(), bins=(sn_edges, dc_edges))
#     hist_dc.append(_hist_dc)

#     _bghist_dc, sne, dce = np.histogram2d(
#         _snrbg.flatten(), _bgdcm.flatten(), bins=(sn_edges, dc_edges))
#     bghist_dc.append(_bghist_dc)


# Try again: bin single-source SNR on a mass–distance grid

In [None]:
# # get edges
# NBINS = 20
# sam = holo.sams.Semi_Analytic_Model()
# mt_edges = sam.mtot[25:-15]/MSOL
# dc_edges = np.geomspace(1e1, 1e4, NBINS)
# sn_edges = np.geomspace(2.e-6,515, NBINS)

# dcgrid, mtgrid, = np.meshgrid(dc_edges, mt_edges, )

In [None]:
# snr = dsdat['snr_ss']
# sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
# mtt = np.repeat(sspar[0]/MSOL, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# mtt = np.swapaxes(mtt, -1, -2)
# dcm = np.repeat(sspar[4]/MPC, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# dcm = np.swapaxes(dcm, -1, -2)


In [None]:
# snr_grid = np.zeros((len(mt_edges)-1, len(dc_edges)-1))
# sum_grid = np.zeros_like(snr_grid)

In [None]:
# print(snr_grid.shape)
# print(mtt.shape, dcm.shape, snr.shape)

In [None]:
# for mm in range(len(mt_edges)-1):
#     # inmass = np.where(np.logical_and(mtt>mt_edges[0], mtt<mt_edges[mm+1])
#     for dd in range(len(dc_edges)-1):
#         inbin = np.where(
#             np.logical_and(
#                 np.logical_and(mtt>mt_edges[mm], mtt<mt_edges[mm+1]),
#                 np.logical_and(dcm>dc_edges[dd], dcm<dc_edges[dd+1])
#             )
#         )
#         snr_grid[mm,dd] = np.mean(snr[inbin])


# min grid

In [None]:
# VAR = 0
# data, params, dsdat = get_var_data(target=TARGET, var=VAR, nskies=NSKIES, nvars=NVARS,
#                     path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'  )
# ht_min = params[TARGET]

# snr = dsdat['snr_ss']
# sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
# mtt = np.repeat(sspar[0]/MSOL, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# mtt = np.swapaxes(mtt, -1, -2)
# dcm = np.repeat(sspar[4]/MPC, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# dcm = np.swapaxes(dcm, -1, -2)


# min_grid = np.zeros_like(snr_grid)
# for mm in range(len(mt_edges)-1):
#     # inmass = np.where(np.logical_and(mtt>mt_edges[0], mtt<mt_edges[mm+1])
#     for dd in range(len(dc_edges)-1):
#         inbin = np.where(
#             np.logical_and(
#                 np.logical_and(mtt>mt_edges[mm], mtt<mt_edges[mm+1]),
#                 np.logical_and(dcm>dc_edges[dd], dcm<dc_edges[dd+1])
#             )
#         )
#         min_grid[mm,dd] = np.sum(snr[inbin])


# max grid

In [None]:
# VAR = -1
# data, params, dsdat = get_var_data(target=TARGET, var=VAR, nskies=NSKIES, nvars=NVARS,
#                     path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'  )
# ht_max = params[TARGET]

# snr = dsdat['snr_ss']
# sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
# mtt = np.repeat(sspar[0]/MSOL, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# mtt = np.swapaxes(mtt, -1, -2)
# dcm = np.repeat(sspar[4]/MPC, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# dcm = np.swapaxes(dcm, -1, -2)


# max_grid = np.zeros_like(snr_grid)
# for mm in range(len(mt_edges)-1):
#     # inmass = np.where(np.logical_and(mtt>mt_edges[0], mtt<mt_edges[mm+1])
#     for dd in range(len(dc_edges)-1):
#         inbin = np.where(
#             np.logical_and(
#                 np.logical_and(mtt>mt_edges[mm], mtt<mt_edges[mm+1]),
#                 np.logical_and(dcm>dc_edges[dd], dcm<dc_edges[dd+1])
#             )
#         )
#         max_grid[mm,dd] = np.sum(snr[inbin])

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mtgrid, dcgrid, max_grid)
# plt.colorbar(im, ax=ax, label='$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_max)

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mtgrid, dcgrid, min_grid)
# plt.colorbar(im, ax=ax, label='$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_min)

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mtgrid, dcgrid, np.log10(max_grid))
# plt.colorbar(im, ax=ax, label='$\log \sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_max)

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mtgrid, dcgrid, np.log10(min_grid))
# plt.colorbar(im, ax=ax, label='$\log \sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_min)

In [None]:
# # ratio
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mtgrid, dcgrid, (max_grid/min_grid))
# plt.colorbar(im, ax=ax, label='[$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$] / [$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$]' 
#              % (ht_max, ht_min))

In [None]:
# # ratio
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mtgrid, dcgrid, np.log10(max_grid/min_grid))
# plt.colorbar(im, ax=ax, label='log [$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$] / [$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$]' 
#              % (ht_max, ht_min))

In [None]:
# fig, ax = plot.figax()
# ax.pcolormesh(mtgrid, dcgrid, snr_grid)
# ax.set_xscale('log')
# ax.set_yscale('log')

# min hist

In [None]:
# mt_grid, dc_grid, = np.meshgrid(mt_edges, dc_edges, )

In [None]:
# VAR = 0
# data, params, dsdat = get_var_data(target=TARGET, var=VAR, nskies=NSKIES, nvars=NVARS,
#                     path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'  )
# ht_min = params[TARGET]

# snr = dsdat['snr_ss']
# sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
# mtt = np.repeat(sspar[0]/MSOL, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# mtt = np.swapaxes(mtt, -1, -2)
# dcm = np.repeat(sspar[4]/MPC, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# dcm = np.swapaxes(dcm, -1, -2)

# hist_min, dc_ed, mt_ed = np.histogram2d(dcm.flatten(), mtt.flatten(), 
#                                         bins=(dc_edges, mt_edges), weights=snr.flatten())

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mt_grid, dc_grid, (hist_min))
# plt.colorbar(im, ax=ax, label='$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_min)

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mt_grid, dc_grid, np.log10(hist_min))
# plt.colorbar(im, ax=ax, label='$\log \sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_min)

# med hist

In [None]:
# VAR = 10
# data, params, dsdat = get_var_data(target=TARGET, var=VAR, nskies=NSKIES, nvars=NVARS,
#                     path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'  )
# ht_med = params[TARGET]

# snr = dsdat['snr_ss']
# sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
# mtt = np.repeat(sspar[0]/MSOL, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# mtt = np.swapaxes(mtt, -1, -2)
# dcm = np.repeat(sspar[4]/MPC, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# dcm = np.swapaxes(dcm, -1, -2)


# hist_med, dc_ed, mt_ed = np.histogram2d(dcm.flatten(), mtt.flatten(), 
#                                         bins=(dc_edges, mt_edges), weights=snr.flatten())

# max hist

In [None]:
# VAR = -1
# data, params, dsdat = get_var_data(target=TARGET, var=VAR, nskies=NSKIES, nvars=NVARS,
#                     path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'  )
# ht_max = params[TARGET]

# snr = dsdat['snr_ss']
# sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
# mtt = np.repeat(sspar[0]/MSOL, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# mtt = np.swapaxes(mtt, -1, -2)
# dcm = np.repeat(sspar[4]/MPC, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
# dcm = np.swapaxes(dcm, -1, -2)


# hist_max, dc_ed, mt_ed = np.histogram2d(dcm.flatten(), mtt.flatten(), 
#                                         bins=(dc_edges, mt_edges), weights=snr.flatten())

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mt_grid, dc_grid, (hist_max))
# plt.colorbar(im, ax=ax, label='$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_max)

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# im = ax.pcolormesh(mt_grid, dc_grid, np.log10(hist_max))
# plt.colorbar(im, ax=ax, label='log$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_max)

In [None]:
# mt_cents = holo.utils.midpoints(mt_edges)
# dc_cents= holo.utils.midpoints(dc_edges)

In [None]:
# # sum
# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# ax.contour(mt_cents, dc_cents, (hist_max), cmap=cm.Purples)
# ax.contour(mt_cents, dc_cents, (hist_med), cmap=cm.Blues)
# ax.contour(mt_cents, dc_cents, (hist_min), cmap=cm.Greens)
# # plt.colorbar(im, ax=ax, label='log$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_max)

In [None]:
# # sum
# levels = np.linspace(3.5,5,8)
# # levels=None

# fig, ax = plot.figax(xlabel='Mass', ylabel='Distance',)
# ax.contour(mt_cents, dc_cents, np.log10(hist_min), cmap=cmap_Greens, levels=levels)
# ax.contour(mt_cents, dc_cents, np.log10(hist_med), cmap=cmap_Blues, levels=levels)
# ax.contour(mt_cents, dc_cents, np.log10(hist_max), cmap=cmap_Purples, levels=levels)
# # plt.colorbar(im, ax=ax, label='log$\sum \mathrm{SNR} (\\tau_\mathrm{hard}=%.2f \mathrm{Gyr})$' % ht_max)
# handles = [
#     mpl.lines.Line2D([0], [0], label=f"{ht_min:.2f}", color='#1e8144'),
#     mpl.lines.Line2D([0], [0], label=f"{ht_med:.2f}", color="#347ebb"),
#     mpl.lines.Line2D([0], [0], label=f"{ht_max:.2f}", color="#6e56a6")
# ]
# # labels = [f"{ht_min:.2f}", f"{ht_med:.2f}", labelf"{ht_max:.2f}", color='#1e8144']
# ax.legend(handles=handles, title='$\\tau_\mathrm{hard}$', loc='upper left')

# Histogram Functions

In [None]:
def _expand_over_skies(arr):
    """Replicate a per-source array over sky realizations and order axes as (F, R, S, L).

    Assumes ``arr`` has NFREQS*NREALS*NLOUDEST elements in (F, R, L) order, so that the
    result lines up element-for-element with ``dsdat['snr_ss']`` (F, R, S, L).
    """
    out = np.repeat(arr, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
    return np.swapaxes(out, -1, -2)


def _snr_hists_for_var(TARGET, var, mt_edges, dc_edges, nvars, path):
    """Load variation ``var`` of ``TARGET`` and return ``(par, ss_hist, bg_hist)``.

    ``ss_hist`` / ``bg_hist`` are SNR-weighted 2D histograms over
    (comoving distance [Mpc], total mass [Msol]) for single sources and for the background.
    """
    data, params, dsdat = get_var_data(target=TARGET, var=var, nskies=NSKIES, nvars=nvars,
                                       path=path)
    par = params[TARGET]

    # single sources: SNR is (F,R,S,L); source parameters are per (F,R,L)
    snr = dsdat['snr_ss']
    sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
    mtt = _expand_over_skies(sspar[0]/MSOL)  # total mass [Msol]
    dcm = _expand_over_skies(sspar[4]/MPC)   # comoving distance [Mpc]
    ss_hist, _, _ = np.histogram2d(dcm.flatten(), mtt.flatten(),
                                   bins=(dc_edges, mt_edges), weights=snr.flatten())

    # background: one SNR per realization, replicated over frequencies -> (F,R), flattened
    bgsnr = np.repeat(dsdat['snr_bg'], NFREQS).reshape(NREALS, NFREQS)  # (R,F)
    bgsnr = np.swapaxes(bgsnr, 0, 1).flatten()
    # bgpar rows 0 (mass) and 4 (distance) are used; each is presumably (F,R)
    bgpar = data['bgpar']
    bgmtt = bgpar[0]/MSOL
    bgdcm = bgpar[4]/MPC
    bg_hist, _, _ = np.histogram2d(bgdcm.flatten(), bgmtt.flatten(),
                                   bins=(dc_edges, mt_edges), weights=bgsnr)
    return par, ss_hist, bg_hist


def hist_min_med_max(TARGET, mt_edges, dc_edges, nvars=None,
                     path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz'):
    """
    SNR-weighted (comoving distance, total mass) histograms for three variations of ``TARGET``.

    Loads the first (``var=0``), middle (``var=nvars//2``) and last (``var=-1``) saved
    variations and, for each, sums single-source and background SNR into bins.

    Parameters
    ----------
    TARGET : str
        Name of the varied parameter (directory prefix and key into ``params``).
    mt_edges : array_like
        Total-mass bin edges [Msol].
    dc_edges : array_like
        Comoving-distance bin edges [Mpc].
    nvars : int, optional
        Number of saved variations; defaults to the global ``NVARS`` at call time.
        The middle variation is ``nvars // 2`` (10 when NVARS = 21).
    path : str
        Root output directory passed to ``get_var_data``.

    Returns
    -------
    dict
        ``hist_{min,med,max}`` (single-source SNR sums) and ``bghist_{min,med,max}``
        (background SNR sums), each of shape (len(dc_edges)-1, len(mt_edges)-1), plus
        ``par_{min,med,max}`` (the value of ``TARGET`` in each variation).
        NOTE(review): "min"/"max" assume variations are stored in increasing parameter order.
    """
    if nvars is None:
        nvars = NVARS

    rv = {}
    for label, var in (('min', 0), ('med', nvars // 2), ('max', -1)):
        par, ss_hist, bg_hist = _snr_hists_for_var(TARGET, var, mt_edges, dc_edges,
                                                   nvars=nvars, path=path)
        rv[f'hist_{label}'] = ss_hist
        rv[f'bghist_{label}'] = bg_hist
        rv[f'par_{label}'] = par
    return rv


def draw_contours(ax, TARGET, mt_edges, dc_edges,
                  levels=np.linspace(3.5,5,8), colors=None, load_from=None):
    """
    Draw log10 SNR-sum contours for the min / med / max variations of ``TARGET`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
    TARGET : str
        Varied parameter name.
    mt_edges, dc_edges : array_like
        Total-mass [Msol] and comoving-distance [Mpc] bin edges; contours are drawn at
        the bin midpoints.
    levels : array_like
        Contour levels, in log10 of the summed SNR.
    colors : sequence of 3 colors, optional
        One color per variation (min, med, max). If None, the truncated Greens / Blues /
        Purples colormaps are used.
    load_from : str, optional
        ``.npz`` file with ``hist_*`` and ``par_*`` arrays. If None, the histograms are
        computed with ``hist_min_med_max``.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Legend handles labelled with the min / med / max parameter values.
    """
    keys = ('hist_min', 'hist_med', 'hist_max', 'par_min', 'par_med', 'par_max')
    if load_from is None:
        rv = hist_min_med_max(TARGET, mt_edges=mt_edges, dc_edges=dc_edges,)
        vals = {key: rv[key] for key in keys}
    else:
        # context manager closes the file even if a key is missing
        with np.load(load_from) as npz:
            vals = {key: npz[key] for key in keys}
    hists = [vals['hist_min'], vals['hist_med'], vals['hist_max']]
    pars = [vals['par_min'], vals['par_med'], vals['par_max']]

    mt_cents = holo.utils.midpoints(mt_edges)
    dc_cents = holo.utils.midpoints(dc_edges)

    if colors is None:
        styles = [dict(cmap=cmap_Greens), dict(cmap=cmap_Blues), dict(cmap=cmap_Purples)]
    else:
        styles = [dict(colors=colors[0]), dict(colors=colors[1]), dict(colors=colors[2])]

    # empty bins give log10(0) = -inf; matplotlib masks non-finite Z, so just silence the warning
    with np.errstate(divide='ignore'):
        for hist, style in zip(hists, styles):
            ax.contour(mt_cents, dc_cents, np.log10(hist), levels=levels, **style)

    if colors is None:
        # single representative colors for the legend, matching the colormap families above
        colors = ['#1e8144', "#347ebb", '#6e56a6']
    handles = [
        mpl.lines.Line2D([0], [0], label=f"{par:.2f}", color=color)
        for par, color in zip(pars, colors)
    ]
    return handles

# Save all hist data

In [None]:
BUILD_ARRAYS = False  # set True to recompute and save the histogram arrays in the next cell

In [None]:

# Binning for the mass-distance SNR histograms. These globals are also the defaults
# of the plotting functions defined further down.
NBINS = 40
TAKE = 6  # version tag in the figdata filenames; NOTE(review): an older comment read "take 5: ssn"
MT_IDX_MIN = 30
MT_IDX_MAX = -1
DC_EDGE_MIN = 3e1
DC_EDGE_MAX = 1e4

# get edges: total-mass edges [Msol] are a slice of the SAM mass grid,
# comoving-distance edges [Mpc] are log-spaced
sam = holo.sams.Semi_Analytic_Model()
mt_edges = sam.mtot[MT_IDX_MIN:MT_IDX_MAX]/MSOL
dc_edges = np.geomspace(DC_EDGE_MIN, DC_EDGE_MAX, NBINS)

targets = [
    # 'gsmf_phi0', 'gsmf_mchar0_log10', 'mmb_mamp_log10', 'mmb_scatter_dex',
    # 'hard_time',
    'hard_gamma_inner',
]

if BUILD_ARRAYS:
    figdata_dir = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
    for TARGET in tqdm(targets):
        rv = hist_min_med_max(TARGET, mt_edges=mt_edges, dc_edges=dc_edges,)
        filename = figdata_dir + f'/mt_dc_hist_tk{TAKE}_{TARGET}_{NBINS}bins.npz'
        np.savez(filename,
                 hist_min=rv['hist_min'], hist_med=rv['hist_med'], hist_max=rv['hist_max'],
                 bghist_min=rv['bghist_min'], bghist_med=rv['bghist_med'],
                 bghist_max=rv['bghist_max'],
                 par_min=rv['par_min'], par_med=rv['par_med'], par_max=rv['par_max'])
        # tqdm.write keeps the progress bar intact
        tqdm.write(f"saved {filename}")

### Plot individual targets from saved hist data

In [None]:

def plot_one_target(TARGET, title, NBINS=NBINS,
                    MT_IDX_MIN=MT_IDX_MIN, MT_IDX_MAX=MT_IDX_MAX,
                    DC_EDGE_MIN=DC_EDGE_MIN, DC_EDGE_MAX=DC_EDGE_MAX,
                    levels=np.linspace(3.5,5,8), load_from=None):
    """
    Make a single-panel figure of min / med / max SNR contours for one target.

    The mass edges are a slice ``[MT_IDX_MIN:MT_IDX_MAX]`` of the SAM total-mass grid
    (in Msol); distance edges are ``NBINS`` log-spaced values in
    ``[DC_EDGE_MIN, DC_EDGE_MAX]``. ``levels`` and ``load_from`` are forwarded to
    ``draw_contours``; ``title`` becomes the legend title.

    Returns
    -------
    fig : matplotlib Figure
    """
    mass_edges = holo.sams.Semi_Analytic_Model().mtot[MT_IDX_MIN:MT_IDX_MAX] / MSOL
    dist_edges = np.geomspace(DC_EDGE_MIN, DC_EDGE_MAX, NBINS)

    fig, ax = plot.figax_single(xlabel='Mass', ylabel='Distance',)
    legend_handles = draw_contours(
        ax, TARGET,
        mt_edges=mass_edges, dc_edges=dist_edges,
        levels=levels, load_from=load_from,
    )
    ax.legend(handles=legend_handles, title=title, loc='upper left')
    return fig


targets = ['hard_gamma_inner']

# Plot each target from its cached histogram file (mt_dc_hist_tk{TAKE}_{target}_{NBINS}bins.npz).
for TARGET in tqdm(targets):
    title = plot.PARAM_KEYS[TARGET]
    filename = ('/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
                f'/mt_dc_hist_tk{TAKE}_{TARGET}_{NBINS}bins.npz')
    print(filename)
    print(title)
    fig = plot_one_target(TARGET, title, load_from=filename)

# Kale Contours

kalepy contours use its default sigma levels: `_DEF_SIGMAS = [0.5, 1.0, 1.5, 2.0]`.

In [None]:
# Total mass [Msol] at index 35 of the SAM mass grid (presumably to help choose MT_IDX_MIN).
print("{:.2e}".format(sam.mtot[35] / MSOL))

In [None]:
# TAKE = 3
# MT_IDX_MIN=30 
# MT_IDX_MAX=-8
# DC_EDGE_MIN=3e1
# DC_EDGE_MAX=1e4

# # get edges
# sam = holo.sams.Semi_Analytic_Model()
# mt_edges = sam.mtot[MT_IDX_MIN:MT_IDX_MAX]/MSOL
# dc_edges = np.geomspace(DC_EDGE_MIN, DC_EDGE_MAX, NBINS)

In [None]:
# target = 'hard_time'

# green_colors = ['#98d594', '#2e984e', '#00441b']
# blue_colors = ['#94c4df', '#2e7ebc',  '#09306b']
# orange_colors = ['#fda762', '#e2540a', '#7f2704']

# filename = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
# load_from = filename+f'/mt_dc_hist_tk{TAKE}_{target}_{NBINS}bins.npz'

# rv = np.load(load_from)
# hist_min, hist_med, hist_max = rv['hist_min'], rv['hist_med'], rv['hist_max']
# par_min, par_med, par_max = rv['par_min'], rv['par_med'], rv['par_max']
# rv.close()

# mt_cents = holo.utils.midpoints(mt_edges)
# dc_cents= holo.utils.midpoints(dc_edges)

In [None]:
# fig, ax = plot.figax(xlabel='Mass [M$_\odot$]', ylabel='Distance [Mpc]')
# cmap_idx = [0.4, 0.7, 1.0]
# for ii, hist in enumerate([hist_min, hist_med, hist_max]):
#     cmap = truncate_colormap('Blues', cmap_idx[ii], cmap_idx[ii])
#     kale.plot.draw_contour2d(ax, [mt_edges, dc_edges], 
#                          np.swapaxes(hist,0,1), cmap=cmap, outline=False)

In [None]:
# fig, ax = plot.figax(xlabel='Mass [M$_\odot$]', ylabel='Distance [Mpc]')
# cmap_idx = [0.4, 0.7, 1.0]
# for ii, hist in enumerate([hist_min, hist_med, hist_max]):
#     cmap = truncate_colormap('Blues', cmap_idx[ii], cmap_idx[ii])
#     kale.plot.draw_contour2d(ax, [mt_edges, dc_edges], 
#                          np.swapaxes(hist,0,1), cmap=cmap, outline=True)

In [None]:
# Contour quantiles (and the corresponding sigmas) for 0.5, 1.0 and 1.5 sigma.
# NOTE(review): `_default_quantiles` is a private kalepy helper and may change between
# kalepy versions.
quantiles, sigmas = kale.plot._default_quantiles(sigmas=[0.5,1.0,1.5])

In [None]:
# sigmas = [1.0, 2.0]
# quantiles = 1.0 - np.exp(-0.5 * np.square(sigmas))
# print(quantiles)

# fig = plot_all_targets_catcolors(targets, NBINS=NBINS,
#                     MT_IDX_MIN=MT_IDX_MIN, MT_IDX_MAX=MT_IDX_MAX,
#                     DC_EDGE_MIN=DC_EDGE_MIN, DC_EDGE_MAX=DC_EDGE_MAX,
#                     smooth=True)
# savepath = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots/snr_contours'
# savename = f"{savepath}/snr_smooth_contours.png"
# fig.savefig(savename, dpi=300)

In [None]:
# sigmas = [1.0, 2.0]
# quantiles = 1.0 - np.exp(-0.5 * np.square(sigmas))
# print(quantiles)

# fig = plot_all_targets_catcolors(targets, NBINS=NBINS,
#                     MT_IDX_MIN=MT_IDX_MIN, MT_IDX_MAX=MT_IDX_MAX,
#                     DC_EDGE_MIN=DC_EDGE_MIN, DC_EDGE_MAX=DC_EDGE_MAX,
#                      quantiles=quantiles)

# Single

In [None]:
def plot_all_targets_single(
        targets, NBINS=NBINS,
        MT_IDX_MIN=MT_IDX_MIN, MT_IDX_MAX=MT_IDX_MAX, DC_EDGE_MIN=DC_EDGE_MIN, DC_EDGE_MAX=DC_EDGE_MAX,
        quantiles=None, smooth=None, take=None):
    """
    Plot kalepy SNR contours for up to six targets on a 3x2 grid of shared-axis panels.

    For each target the cached file
    ``.../figdata/mt_dc_hist_tk{take}_{target}_{NBINS}bins.npz`` is loaded. The
    single-source histograms for the min / med / max variations are drawn in the panel's
    color family, and the background histograms are overlaid in grey with dashed lines.

    Parameters
    ----------
    targets : sequence of str
        Parameter names, one per panel, filled row-major. At most six are used; panels
        beyond ``len(targets)`` are left empty (previously this raised IndexError).
    NBINS, MT_IDX_MIN, MT_IDX_MAX, DC_EDGE_MIN, DC_EDGE_MAX
        Binning used to rebuild the mass / distance edges; must match the cached files.
    quantiles, smooth
        Forwarded to ``kale.plot.draw_contour2d``.
    take : int, optional
        Version tag in the cached filenames; defaults to the global ``TAKE`` at call time.

    Returns
    -------
    fig : matplotlib Figure
    """
    if take is None:
        take = TAKE

    green_colors = ['#98d594', '#2e984e', '#00441b']
    blue_colors = ['#94c4df', '#2e7ebc', '#09306b']
    orange_colors = ['#fda762', '#e2540a', '#7f2704']

    # one color family per panel (row-major): 2 green, 2 orange, 2 blue
    catcolors = [green_colors, green_colors,
                 orange_colors, orange_colors,
                 blue_colors, blue_colors]
    cmaps = ['Greens', 'Greens',
             'Oranges', 'Oranges',
             'Blues', 'Blues']
    # colormap positions used for the min / med / max variations
    cmap_idx = [0.4, 0.7, 1.0]

    # get edges
    sam = holo.sams.Semi_Analytic_Model()
    mt_edges = sam.mtot[MT_IDX_MIN:MT_IDX_MAX]/MSOL
    dc_edges = np.geomspace(DC_EDGE_MIN, DC_EDGE_MAX, NBINS)

    figdata_dir = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
    # NOTE(review): legend-line alpha; the meaning of the 0.8 constant is not documented here
    legend_alpha = 1 - np.sqrt(1 - 0.8)

    # make figure
    xlabel = r'Mass [$\mathrm{M}_\odot$]'
    ylabel = r'Distance [$\mathrm{Mpc}$]'
    fig, axs = plot.figax_single(nrows=3, ncols=2, sharey=True, sharex=True,
                                 height=7)
    fig.text(0.55, 0.075, xlabel, ha='center', va='bottom', )
    plt.subplots_adjust(wspace=0, hspace=0)
    axes_flat = axs.flatten()
    axes_flat[2].set_ylabel(ylabel)  # middle-left panel carries the shared y label

    # zip stops at the shorter sequence, so fewer than six targets no longer raises
    panels = list(zip(axes_flat, targets))
    for ii, (ax, target) in enumerate(tqdm(panels)):
        colors = catcolors[ii]
        title = plot.PARAM_KEYS[target]

        # load histogram data; context manager closes the file even on a missing key
        filename = figdata_dir + f'/mt_dc_hist_tk{take}_{target}_{NBINS}bins.npz'
        with np.load(filename) as rv:
            hists = [rv['hist_min'], rv['hist_med'], rv['hist_max']]
            bghists = [rv['bghist_min'], rv['bghist_med'], rv['bghist_max']]
            pars = [rv['par_min'], rv['par_med'], rv['par_max']]

        for hh, (hist, bghist) in enumerate(zip(hists, bghists)):
            # single-source contours; histograms are (dist, mass), kalepy wants (mass, dist)
            cmap = truncate_colormap(cmaps[ii], cmap_idx[hh], cmap_idx[hh])
            kale.plot.draw_contour2d(ax, [mt_edges, dc_edges],
                                     np.swapaxes(hist, 0, 1), cmap=cmap,
                                     outline=False, quantiles=quantiles, smooth=smooth)

            # background contours
            cmap = truncate_colormap(cm.Greys, cmap_idx[hh], cmap_idx[hh])
            kale.plot.draw_contour2d(ax, [mt_edges, dc_edges],
                                     np.swapaxes(bghist, 0, 1), cmap=cmap,
                                     outline=False, quantiles=quantiles, smooth=smooth,
                                     alpha=0.5, linestyles='--', linewidth=1)

        # make legend
        handles = [
            mpl.lines.Line2D([0], [0], label=f"{par:.2f}", color=color, alpha=legend_alpha)
            for par, color in zip(pars, colors)
        ]
        leg = ax.legend(handles=handles, title=title, loc='upper left',
                        handletextpad=0.25, borderpad=0.25, labelspacing=0.25, frameon=0,
                        borderaxespad=0.25)
        leg._legend_box.align = "left"

    return fig

fig = plot_all_targets_single(targets, smooth=True)

# Only write to disk when the SAVEFIG switch from the config cell is on.
if SAVEFIG:
    savepath = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots/snr_contours'
    os.makedirs(savepath, exist_ok=True)
    savename = f"{savepath}/snr_bgkale_contours_single.png"
    fig.savefig(savename, dpi=300)