In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl
import h5py
from tqdm import tqdm
import os
import kalepy as kale


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR, PC
import holodeck as holo

In [None]:
# Frequency-bin centres and edges from holodeck's pta_freqs() with its default arguments.
fobs_cents, fobs_edges = holo.utils.pta_freqs()

# Sampling interval whose Nyquist frequency equals the highest bin centre
# (presumably the observing cadence); printed in years.
cad = 1.0 / (2 * fobs_cents[-1])
print(cad / YR)

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    '''
    Build a colormap covering only the [minval, maxval] slice of another colormap.

    Adapted from https://stackoverflow.com/a/18926541

    Parameters
    ----------
    cmap : str or matplotlib Colormap
        Colormap name (looked up with ``plt.get_cmap``) or a colormap object.
    minval, maxval : float
        Fractions of the original colormap to start and stop sampling at.
        Equal values give a single-colour map.
    n : int
        Number of samples taken between ``minval`` and ``maxval``.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        New colormap named ``'trunc(<name>,<minval>,<maxval>)'``.
    '''
    base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
    sampled_colors = base(np.linspace(minval, maxval, n))
    label = f"trunc({base.name},{minval:.2f},{maxval:.2f})"
    return mpl.colors.LinearSegmentedColormap.from_list(label, sampled_colors)

# Truncated versions of standard matplotlib colormaps used by the plots below
# (presumably trimmed to avoid the extreme ends of each palette).
cmap_base = 'magma_r'
magma_r = truncate_colormap(cmap_base, minval=0, maxval=0.85)
blacks = truncate_colormap('binary', minval=0.4, maxval=1.0)

# Sequential palettes, each sampled from its lower cut up to the top of the map.
cmap_Blues = truncate_colormap('Blues', minval=0.4, maxval=1)
cmap_PuBuGn = truncate_colormap('PuBuGn', minval=0.2, maxval=1)
cmap_Greens = truncate_colormap('Greens', minval=0.4, maxval=1)
cmap_Oranges = truncate_colormap('Oranges', minval=0.4, maxval=1)
cmap_Purples = truncate_colormap('Purples', minval=0.4, maxval=1)

In [None]:
# --- Library dimensions (must match the files read by get_var_data) ---
SHAPE = None      # grid-shape tag in the library directory name (`..._shape{SHAPE}`)
NVARS = 21        # number of variations per target parameter (`_v{NVARS}`); smaller library: 6
NREALS = 500      # realizations (`_r{NREALS}`)
NFREQS = 40       # frequency bins
NLOUDEST = 10     # single sources kept per frequency bin and realization
NSKIES = 100      # sky realizations in the detstats files (`detstats_s{NSKIES}`)
BGL = 1           # `bgl` value in the detstats filenames

# --- Detection-statistics settings ---
# NOTE(review): NPSRS, TOL and MAXBADS are not used below; presumably they
# document how the detstats files were built.
NPSRS = 40
TOL = 0.01
MAXBADS = 5
# None for both selects the `_white` detstats files in get_var_data (if passed there).
RED_GAMMA = None
RED2WHITE = None

# --- Output ---
SAVEFIG = False   # figure-saving flag



In [None]:
def get_var_data(target, var=None, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE,
                 red_gamma=None, red2white=None,
                 path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz', bgl=1):
    """Load simulation data, parameters and detection statistics for one parameter library.

    Files are read from ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}/``:

    * ``data_params.npz`` -> arrays ``'data'`` and ``'params'``
    * ``detstats_s{nskies}_bgl{bgl}{noise}.npz`` -> array ``'dsdat'``, where ``noise`` is
      ``_r2w{red2white:.1f}_rg{red_gamma:.1f}`` when both are given, else ``_white``.

    Parameters
    ----------
    target : str
        Name of the varied parameter (used in the directory name).
    var : int or None
        If given, index into each loaded array and return only that variation.
    nvars, nreals, nskies, shape, bgl :
        Values encoded in the directory and file names.
    red_gamma, red2white : float or None
        Select the red-noise detstats file; both must be non-None to take effect.
    path : str
        Base output directory.

    Returns
    -------
    data, params, dsdat
        The loaded arrays (or their ``[var]`` element when ``var`` is not None).

    Raises
    ------
    FileNotFoundError
        If either file is missing. This is a subclass of ``Exception``, so callers
        catching ``Exception`` still work.
    """
    run_dir = f'{path}/{target}_v{nvars}_r{nreals}_shape{str(shape)}'
    load_data_from_file = f'{run_dir}/data_params.npz'
    if red_gamma is not None and red2white is not None:
        noise_tag = f'_r2w{red2white:.1f}_rg{red_gamma:.1f}'
    else:
        noise_tag = '_white'
    load_dets_from_file = f'{run_dir}/detstats_s{nskies}_bgl{bgl}{noise_tag}.npz'

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)
    if not os.path.exists(load_dets_from_file):
        err = f"load dets file '{load_dets_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # allow_pickle is required for these archives (object arrays). Pickle can execute
    # arbitrary code, so only load files produced locally by this pipeline.
    # `with` guarantees the archives are closed even if a key lookup fails.
    with np.load(load_data_from_file, allow_pickle=True) as npz:
        data = npz['data']
        params = npz['params']
    with np.load(load_dets_from_file, allow_pickle=True) as npz:
        dsdat = npz['dsdat']

    if var is not None:
        data, params, dsdat = data[var], params[var], dsdat[var]

    return data, params, dsdat

# Histogram Functions

In [None]:
def _snr_hists_for_var(target, var, mt_edges, dc_edges,
                       path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz'):
    """Build single-source and background SNR-weighted distance/mass histograms for one variation.

    Returns
    -------
    par : value of ``params[target]`` for this variation.
    hist : ndarray, shape (len(dc_edges)-1, len(mt_edges)-1)
        Single-source SNR summed per (distance, mass) bin over all frequencies,
        realizations, sky realizations and loudest sources.
    bghist : ndarray, same shape
        Background SNR summed per (distance, mass) bin of the background parameters.
    """
    data, params, dsdat = get_var_data(target=target, var=var, nskies=NSKIES, nvars=NVARS,
                                       bgl=BGL, path=path)
    par = params[target]

    # --- single sources ---
    # SNR is (F,R,S,L); source mass/distance are (F,R,L) and identical for every sky S.
    sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])
    mtt = (sspar[0]/MSOL).reshape(NFREQS, NREALS, NLOUDEST)
    dcm = (sspar[4]/MPC).reshape(NFREQS, NREALS, NLOUDEST)
    snr = np.reshape(dsdat['snr_ss'], (NFREQS, NREALS, NSKIES, NLOUDEST))
    # Histogram weights add linearly. Summing SNR over skies first therefore gives the
    # same bins (up to float rounding) as repeating every source NSKIES times. It also
    # avoids materialising several (F,R,S,L)-sized copies of mass and distance.
    snr_sky_sum = snr.sum(axis=2)
    hist, _, _ = np.histogram2d(dcm.ravel(), mtt.ravel(),
                                bins=(dc_edges, mt_edges), weights=snr_sky_sum.ravel())

    # --- background ---
    # One background SNR per realization, shared by every frequency bin -> (F,R).
    bgsnr = np.broadcast_to(np.ravel(dsdat['snr_bg']), (NFREQS, NREALS))
    # bgpar is indexed at [0] (mass) and [4] (distance); each is assumed to be (F,R).
    bgpar = data['bgpar']
    bgmtt = bgpar[0]/MSOL
    bgdcm = bgpar[4]/MPC
    bghist, _, _ = np.histogram2d(bgdcm.ravel(), bgmtt.ravel(),
                                  bins=(dc_edges, mt_edges), weights=bgsnr.ravel())
    return par, hist, bghist


def hist_min_med_max(TARGET, mt_edges, dc_edges):
    """SNR-weighted distance/mass histograms at the first, middle and last variation of TARGET.

    Histograms have distance on axis 0 (``dc_edges``, Mpc) and mass on axis 1
    (``mt_edges``, Msol). The same ``BGL``/``NSKIES``/``NVARS`` settings are used for all
    three variations.

    Returns
    -------
    dict with keys ``hist_min``, ``hist_med``, ``hist_max`` (single sources),
    ``bghist_min``, ``bghist_med``, ``bghist_max`` (background),
    ``par_min``, ``par_med``, ``par_max`` (parameter values), ``mt_edges``, ``dc_edges``.
    """
    # Variation indices: first (0), middle (10, the midpoint when NVARS=21), last (-1).
    par_min, hist_min, bghist_min = _snr_hists_for_var(TARGET, 0, mt_edges, dc_edges)
    par_med, hist_med, bghist_med = _snr_hists_for_var(TARGET, 10, mt_edges, dc_edges)
    par_max, hist_max, bghist_max = _snr_hists_for_var(TARGET, -1, mt_edges, dc_edges)

    rv = dict(hist_min=hist_min, hist_med=hist_med, hist_max=hist_max,
              bghist_min=bghist_min, bghist_med=bghist_med, bghist_max=bghist_max,
              par_min=par_min, par_med=par_med, par_max=par_max,
              mt_edges=mt_edges, dc_edges=dc_edges)
    return rv


def draw_contours(ax, TARGET, mt_edges, dc_edges,
                  levels=np.linspace(3.5,5,8), colors=None, load_from=None):
    """Draw log10 contours of the min/med/max SNR histograms for one parameter.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    TARGET : str
        Parameter name; only used when ``load_from`` is None, in which case the
        histograms are recomputed with ``hist_min_med_max``.
    mt_edges, dc_edges : array_like
        Mass and distance bin edges. Their midpoints form the contour grid
        (mass on x, distance on y).
    levels : array_like
        Contour levels, in log10 of the histogram values.
    colors : sequence of 3 colours or None
        Fixed colours for the min/med/max contours. If None, the truncated
        Greens/Blues/Purples colormaps are used for contours and fixed hex
        colours for the legend.
    load_from : str or None
        Path to an ``.npz`` containing ``hist_*`` and ``par_*`` arrays.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Three legend handles labelled with the parameter values (2 decimals).
    """
    hist_keys = ('hist_min', 'hist_med', 'hist_max')
    par_keys = ('par_min', 'par_med', 'par_max')
    if load_from is not None:
        archive = np.load(load_from)
        hists = [archive[key] for key in hist_keys]
        pars = [archive[key] for key in par_keys]
        archive.close()
    else:
        result = hist_min_med_max(TARGET, mt_edges=mt_edges, dc_edges=dc_edges,)
        hists = [result[key] for key in hist_keys]
        pars = [result[key] for key in par_keys]

    x_grid = holo.utils.midpoints(mt_edges)
    y_grid = holo.utils.midpoints(dc_edges)

    fallback_cmaps = (cmap_Greens, cmap_Blues, cmap_Purples)
    for idx, hist in enumerate(hists):
        if colors is None:
            color_kw = {'cmap': fallback_cmaps[idx]}
        else:
            color_kw = {'colors': colors[idx]}
        ax.contour(x_grid, y_grid, np.log10(hist), levels=levels, **color_kw)

    legend_colors = ['#1e8144', "#347ebb", '#6e56a6'] if colors is None else colors
    return [
        mpl.lines.Line2D([0], [0], label=f"{par:.2f}", color=legend_colors[idx])
        for idx, par in enumerate(pars)
    ]

# Build and Save Histogram Data

In [None]:
BUILD_ARRAYS = True  # when True, the next cell recomputes and saves the histogram .npz files

In [None]:

NBINS = 40          # number of distance bin EDGES (np.geomspace below); used in the filenames
TAKE = 8            # tag in the saved filenames (`..._tk{TAKE}_...`); plotting reads the same tag
MT_IDX_MIN = 18     # SAM total-mass edges are sliced [MT_IDX_MIN:MT_IDX_MAX]
MT_IDX_MAX = -9
DC_EDGE_MIN = 2e1   # distance edges, Mpc
DC_EDGE_MAX = 8e3
FIGDATA_DIR = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'

if BUILD_ARRAYS:
    # Mass edges come from the default SAM grid (Msol); distance edges are log-spaced (Mpc).
    sam = holo.sams.Semi_Analytic_Model(shape=None)
    mt_edges = sam.mtot[MT_IDX_MIN:MT_IDX_MAX]/MSOL
    dc_edges = np.geomspace(DC_EDGE_MIN, DC_EDGE_MAX, NBINS)

    targets = [
        'gsmf_phi0', 'gsmf_mchar0_log10',
        'mmb_mamp_log10', 'mmb_scatter_dex',
        'hard_time', 'hard_gamma_inner',
    ]
    # np.savez does not create missing directories.
    os.makedirs(FIGDATA_DIR, exist_ok=True)
    for target in tqdm(targets):
        rv = hist_min_med_max(target, mt_edges=mt_edges, dc_edges=dc_edges,)
        print(f"{holo.utils.stats(rv['hist_max'])},\n{holo.utils.stats(rv['bghist_max'])}")
        filename = f'{FIGDATA_DIR}/mt_dc_hist_tk{TAKE}_{target}_{NBINS}bins.npz'
        # rv holds exactly the hist_*, bghist_*, par_*, mt_edges and dc_edges arrays
        # that the plotting cell reads back.
        np.savez(filename, **rv)

# Kale Contours

kalepy's default contour sigmas are `_DEF_SIGMAS = [0.5, 1.0, 1.5, 2.0]`; the cell below computes quantiles for `sigmas=[0.5, 1.0, 1.5]` instead.

In [None]:
# Contour quantiles for 0.5, 1.0 and 1.5 sigma (via a private kalepy helper).
# NOTE(review): `quantiles` is not passed to plot_all_targets_single below (it is called
# with smooth=True only, so quantiles=None); pass quantiles=quantiles to use these levels.
quantiles, sigmas = kale.plot._default_quantiles(sigmas=[0.5,1.0,1.5])

### Single Column

In [None]:
def _load_mt_dc_hist_file(target, nbins):
    """Load one target's saved mass/distance SNR histograms (written by the build cell).

    Returns ``(hists, bghists, pars, mt_edges, dc_edges)``; the first three are
    ``[min, med, max]`` lists.
    """
    filename = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'
    filename = filename + f'/mt_dc_hist_tk{TAKE}_{target}_{nbins}bins.npz'
    # `with` closes the archive even if a key is missing.
    with np.load(filename) as rv:
        hists = [rv['hist_min'], rv['hist_med'], rv['hist_max']]
        bghists = [rv['bghist_min'], rv['bghist_med'], rv['bghist_max']]
        pars = [rv['par_min'], rv['par_med'], rv['par_max']]
        mt_edges, dc_edges = rv['mt_edges'], rv['dc_edges']
    return hists, bghists, pars, mt_edges, dc_edges


def _format_mass_axis_ticks(ax):
    """Set decade major ticks on x, mirror inward-pointing ticks to top/right, add unlabeled x minors."""
    ax.set_xticks([10**6, 10**7, 10**8, 10**9, 10**10, 10**11])
    ax.tick_params(axis='y', which='both', right=True, direction='in')
    ax.tick_params(axis='x', which='both', top=True, direction='in')
    # Minor x ticks at 0.1..0.9 times each power of ten, without labels.
    x_minor = mpl.ticker.LogLocator(base=10.0, subs=np.arange(1, 10.0) * 0.1, numticks=10)
    ax.xaxis.set_minor_locator(x_minor)
    ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())


def plot_all_targets_single(
        NBINS=NBINS,
        quantiles=None, smooth=None):
    """Plot single-source and background SNR contours for six parameters in a 3x2 grid.

    Each panel loads the histograms saved by the build cell (``TAKE`` and ``NBINS`` select the
    file). It draws the min/med/max single-source contours with that panel's colour family,
    and the background contours in dashed grey.

    Parameters
    ----------
    NBINS : int
        Bin tag used in the saved filenames.
    quantiles, smooth :
        Passed through to ``kale.plot.draw_contour2d``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    targets = [
        'gsmf_phi0', 'gsmf_mchar0_log10',
        'mmb_mamp_log10', 'mmb_scatter_dex',
        'hard_time', 'hard_gamma_inner',
    ]

    green_colors = ['#98d594', '#2e984e', '#00441b']
    blue_colors = ['#94c4df', '#2e7ebc', '#09306b']
    orange_colors = ['#fda762', '#e2540a', '#7f2704']

    # Legend colours per panel; they pair with `cmaps` below (same order as `targets`).
    catcolors = [green_colors, green_colors,
                 orange_colors, orange_colors,
                 blue_colors, blue_colors]
    # Position sampled from each colormap for the min/med/max contours.
    cmap_idx = [0.4, 0.7, 1.0]
    cmaps = ['Greens', 'Greens',
             'Oranges', 'Oranges',
             'Blues', 'Blues', ]
    # Loop-invariant legend-line alpha, 1 - sqrt(0.2) ~= 0.553.
    legend_alpha = 1 - np.sqrt(1 - 0.8)

    # Raw strings: '\m' and '\o' are invalid escape sequences in normal string literals.
    xlabel = r'Mass [$\mathrm{M}_\odot$]'
    ylabel = r'Distance [$\mathrm{Mpc}$]'
    fig, axs = plot.figax_single(nrows=3, ncols=2, sharey=True, sharex=True,
                                 height=7)
    fig.text(0.55, 0.075, xlabel, ha='center', va='bottom', )
    plt.subplots_adjust(wspace=0, hspace=0)
    for ii, ax in enumerate(tqdm(axs.flatten())):
        target = targets[ii]
        colors = catcolors[ii]
        if ii == 2:
            # Only the middle-left panel carries the shared y label.
            ax.set_ylabel(ylabel)
        title = plot.PARAM_KEYS[target]

        hists, bghists, pars, mt_edges, dc_edges = _load_mt_dc_hist_file(target, NBINS)

        for hh in range(len(hists)):
            # Single-source contours. minval == maxval gives a single-colour map.
            # Histograms are (distance, mass), so swap to (mass, distance) for kalepy.
            cmap = truncate_colormap(cmaps[ii], cmap_idx[hh], cmap_idx[hh])
            kale.plot.draw_contour2d(ax, [mt_edges, dc_edges],
                                     np.swapaxes(hists[hh], 0, 1), cmap=cmap,
                                     outline=False, quantiles=quantiles, smooth=smooth)

            # Background contours: dashed, semi-transparent grey.
            cmap = truncate_colormap(cm.Greys, cmap_idx[hh], cmap_idx[hh])
            kale.plot.draw_contour2d(ax, [mt_edges, dc_edges],
                                     np.swapaxes(bghists[hh], 0, 1), cmap=cmap,
                                     outline=False, quantiles=quantiles, smooth=smooth,
                                     alpha=0.5, linestyles='--', linewidth=1)

        handles = [
            mpl.lines.Line2D([0], [0], label=f"{par}", color=color, alpha=legend_alpha)
            for par, color in zip(pars, colors)
        ]
        ax.legend(handles=handles, title=title, bbox_to_anchor=(0.02, 0), loc='lower left',
                  handletextpad=0.25, borderpad=0.25, labelspacing=0.25, frameon=0, handlelength=1.0,
                  borderaxespad=0.25, alignment='left', ncols=3, columnspacing=0.3)

        _format_mass_axis_ticks(ax)

    return fig

fig = plot_all_targets_single(smooth=True)

# Honour the SAVEFIG flag from the configuration cell. Previously the figure was always
# written, regardless of the flag.
if SAVEFIG:
    savepath = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots/snr_contours'
    # savefig does not create missing directories.
    os.makedirs(savepath, exist_ok=True)
    savename = f"{savepath}/snr_bgkale_contours_single.png"
    fig.savefig(savename, dpi=300, bbox_inches='tight')