In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl
import h5py
from tqdm import tqdm
import os


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR, PC
import holodeck as holo
from holodeck.sams import sam

import hasasia.sim as hsim

import sys
sys.path.append('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation')
import anatomy as anat

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    '''
    Build a colormap that spans only the [minval, maxval] fraction of `cmap`.

    Adapted from https://stackoverflow.com/a/18926541

    Parameters
    ----------
    cmap : str or matplotlib colormap
        Colormap object, or a name understood by `plt.get_cmap`.
    minval, maxval : float
        Lower / upper fraction of the original colormap to keep.
    n : int
        Number of colors sampled from the original colormap.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        New colormap named 'trunc(<name>,<minval>,<maxval>)'.
    '''
    base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=base.name, a=minval, b=maxval)
    sampled_colors = base(np.linspace(minval, maxval, n))
    return mpl.colors.LinearSegmentedColormap.from_list(label, sampled_colors)

# Truncated colormaps: each keeps only the [minval, maxval] fraction of its base map.
cmap_base = 'magma_r'
magma_r = truncate_colormap(cmap_base, minval=0, maxval=0.85)
blacks = truncate_colormap('binary', minval=0.4, maxval=1.0)

# Sequential maps used for the min / median / max contour sets.
cmap_Blues = truncate_colormap('Blues', minval=0.4, maxval=1)
cmap_PuBuGn = truncate_colormap('PuBuGn', minval=0.2, maxval=1)
cmap_Greens = truncate_colormap('Greens', minval=0.4, maxval=1)
cmap_Oranges = truncate_colormap('Oranges', minval=0.4, maxval=1)
cmap_Purples = truncate_colormap('Purples', minval=0.4, maxval=1)

In [None]:
# --- dimensions of the saved simulation arrays ---
SHAPE = None      # SAM grid shape, appears in output directory names as `shape{SHAPE}`
NREALS = 500
NFREQS = 40
NLOUDEST = 10

# --- run control ---
BUILD_ARRAYS = False   # when True, recompute and save the histogram arrays
SAVEFIG = False
TOL = 0.01
MAXBADS = 5

# --- parameter-variation grid (appears in output paths as `_v{NVARS}`) ---
NVARS = 21             # alternative grid used elsewhere: NVARS = 6

# --- detection-statistics settings ---
NPSRS = 40
NSKIES = 100
RED_GAMMA = None
RED2WHITE = None



In [None]:
def get_var_data(target, var=None, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE,
                 red_gamma=None, red2white=None,
                 path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz'):
    """
    Load simulation data, parameters and detection statistics for one target parameter.

    Files are read from
    ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}/data_params.npz`` and
    ``.../detstats_s{nskies}{suffix}.npz``. The suffix is
    ``_r2w{red2white:.1f}_rg{red_gamma:.1f}`` when both red-noise settings are
    given, otherwise ``_white``.

    Parameters
    ----------
    target : str
        Name of the varied parameter (part of the run directory name).
    var : int or None
        Index into the variation axis. If None, the full arrays are returned.
    nvars, nreals, nskies, shape :
        Values used only to build the file paths.
    red_gamma, red2white : float or None
        Red-noise settings selecting the detstats file. Both must be given to
        select a red-noise file; if only one is given a warning is emitted and
        the white-noise file is used.
    path : str
        Base output directory.

    Returns
    -------
    data, params, dsdat
        Contents of the 'data', 'params' and 'dsdat' arrays, indexed by `var`
        when `var` is not None.

    Raises
    ------
    FileNotFoundError
        If either .npz file does not exist (subclass of Exception, so existing
        ``except Exception`` handlers still catch it).
    """
    run_dir = path + f'/{target}_v{nvars}_r{nreals}_shape{str(shape)}'
    load_data_from_file = run_dir + '/data_params.npz'
    load_dets_from_file = run_dir + f'/detstats_s{nskies}'
    if red_gamma is not None and red2white is not None:
        load_dets_from_file += f'_r2w{red2white:.1f}_rg{red_gamma:.1f}'
    else:
        if (red_gamma is None) != (red2white is None):
            # A partial red-noise spec would otherwise silently load white-noise results.
            import warnings
            warnings.warn("Only one of red_gamma / red2white was given; "
                          "loading the white-noise detstats file.")
        load_dets_from_file += '_white'
    load_dets_from_file += '.npz'

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)
    if not os.path.exists(load_dets_from_file):
        err = f"load dets file '{load_dets_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # allow_pickle is needed for object arrays; only use with trusted, locally produced files.
    # `with` guarantees the file handle is closed even if a key lookup fails.
    with np.load(load_data_from_file, allow_pickle=True) as npz:
        data = npz['data']
        params = npz['params']
    with np.load(load_dets_from_file, allow_pickle=True) as npz:
        dsdat = npz['dsdat']

    if var is not None:
        data, params, dsdat = data[var], params[var], dsdat[var]

    return data, params, dsdat

# Histogram Functions

In [None]:
def _snr_weighted_hist(TARGET, var, mt_edges, dc_edges, path):
    """
    SNR-weighted 2D histogram (comoving distance x total mass) for one variation.

    Returns
    -------
    hist : ndarray, shape (len(dc_edges)-1, len(mt_edges)-1)
        Sum of single-source SNR per bin.
    par : object
        ``params[TARGET]`` for this variation index.
    """
    data, params, dsdat = get_var_data(target=TARGET, var=var, nskies=NSKIES, nvars=NVARS,
                                       path=path)
    snr = dsdat['snr_ss']
    sspar = sings.all_sspars(data['fobs_cents'], data['sspar'])

    def _tile_over_skies(values):
        # Repeat each element NSKIES times, view as (NFREQS, NREALS, NLOUDEST, NSKIES),
        # then swap the last two axes -> (NFREQS, NREALS, NSKIES, NLOUDEST) so the
        # flattened order lines up with snr_ss (presumably that layout -- confirm).
        tiled = np.repeat(values, NSKIES).reshape(NFREQS, NREALS, NLOUDEST, NSKIES)
        return np.swapaxes(tiled, -1, -2)

    mtt = _tile_over_skies(sspar[0] / MSOL)  # sspar[0] binned against mt_edges [MSOL]
    dcm = _tile_over_skies(sspar[4] / MPC)   # sspar[4] binned against dc_edges [Mpc]

    hist, _, _ = np.histogram2d(dcm.flatten(), mtt.flatten(),
                                bins=(dc_edges, mt_edges), weights=snr.flatten())
    return hist, params[TARGET]


def hist_min_med_max(TARGET, mt_edges, dc_edges, var_med=None,
                     path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz'):
    """
    SNR-weighted (comoving distance, total mass) histograms for the minimum,
    median and maximum variation of parameter `TARGET`.

    Parameters
    ----------
    TARGET : str
        Name of the varied parameter.
    mt_edges, dc_edges : array_like
        Total-mass [MSOL] and comoving-distance [Mpc] bin edges.
    var_med : int or None
        Variation index treated as the median. Defaults to ``NVARS // 2``
        (10 for the NVARS=21 grid, matching the previous hard-coded value).
    path : str
        Base output directory passed to `get_var_data`.

    Returns
    -------
    dict
        Keys hist_min, hist_med, hist_max (2D arrays) and par_min, par_med,
        par_max (parameter values for variation indices 0, var_med, -1).
    """
    if var_med is None:
        var_med = NVARS // 2

    hist_min, par_min = _snr_weighted_hist(TARGET, 0, mt_edges, dc_edges, path)
    hist_med, par_med = _snr_weighted_hist(TARGET, var_med, mt_edges, dc_edges, path)
    hist_max, par_max = _snr_weighted_hist(TARGET, -1, mt_edges, dc_edges, path)

    return dict(hist_min=hist_min, hist_med=hist_med, hist_max=hist_max,
                par_min=par_min, par_med=par_med, par_max=par_max)


def draw_contours(ax, TARGET, title, mt_edges=None, dc_edges=None,
                  levels=np.linspace(3.5,5,8), colors=None, load_from=None):
    """
    Draw log10(sum SNR) contours for the min / median / max variation of `TARGET`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on. x is total-mass bin centers, y is comoving-distance bin centers.
    TARGET : str
        Varied parameter name. Used to compute the histograms when `load_from` is None.
    title : str
        Currently unused; kept for call compatibility.
    mt_edges, dc_edges : array_like or None
        Total-mass [MSOL] and comoving-distance [Mpc] bin edges. If None, the
        notebook-level ``mt_edges`` / ``dc_edges`` are looked up at call time.
        Binding them as defaults at definition time failed on a fresh kernel,
        because they are defined in a later cell.
    levels : array_like
        Contour levels in log10(sum SNR).
    colors : sequence of 3 colors or None
        Colors for the min / med / max contours. If None, the truncated
        Greens / Blues / Purples colormaps are used.
    load_from : str or None
        Path to an .npz with hist_* and par_* keys. If None, the histograms are
        computed with `hist_min_med_max`.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Legend handles labelled with the min / med / max parameter values.
    """
    if mt_edges is None:
        mt_edges = globals()['mt_edges']
    if dc_edges is None:
        dc_edges = globals()['dc_edges']

    if load_from is None:
        rv = hist_min_med_max(TARGET, mt_edges=mt_edges, dc_edges=dc_edges,)
        hists = [rv['hist_min'], rv['hist_med'], rv['hist_max']]
        pars = [rv['par_min'], rv['par_med'], rv['par_max']]
    else:
        # `with` closes the npz even if a key is missing.
        with np.load(load_from) as rv:
            hists = [rv['hist_min'], rv['hist_med'], rv['hist_max']]
            pars = [rv['par_min'], rv['par_med'], rv['par_max']]

    mt_cents = holo.utils.midpoints(mt_edges)
    dc_cents = holo.utils.midpoints(dc_edges)

    # Empty bins give log10(0) = -inf. matplotlib masks non-finite Z values,
    # so only the divide-by-zero warning is silenced here.
    with np.errstate(divide='ignore'):
        log_hists = [np.log10(hh) for hh in hists]

    if colors is None:
        styles = [dict(cmap=cmap_Greens), dict(cmap=cmap_Blues), dict(cmap=cmap_Purples)]
    else:
        styles = [dict(colors=colors[0]), dict(colors=colors[1]), dict(colors=colors[2])]
    for log_hist, style in zip(log_hists, styles):
        ax.contour(mt_cents, dc_cents, log_hist, levels=levels, **style)

    if colors is None:
        # Single legend colors standing in for the min / med / max colormaps.
        colors = ['#1e8144', "#347ebb", '#6e56a6' ]
    handles = [
        mpl.lines.Line2D([0], [0], label=f"{pars[ii]:.2f}", color=colors[ii])
        for ii in range(3)
    ]
    return handles


def draw_kale_contours2d(ax, TARGET, title, mt_edges=None, dc_edges=None,
                  levels=np.linspace(3.5,5,8), colors=None, load_from=None):
    """
    Draw log10(sum SNR) contours for the min / median / max variation of `TARGET`.

    Despite its name, this function does not currently call `kale`. It draws the
    same contour sets as `draw_contours`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on. x is total-mass bin centers, y is comoving-distance bin centers.
    TARGET : str
        Varied parameter name. Used to compute the histograms when `load_from` is None.
    title : str
        Currently unused; kept for call compatibility.
    mt_edges, dc_edges : array_like or None
        Total-mass [MSOL] and comoving-distance [Mpc] bin edges. If None, the
        notebook-level ``mt_edges`` / ``dc_edges`` are looked up at call time.
    levels : array_like
        Contour levels in log10(sum SNR).
    colors : sequence of 3 colors or None
        Colors for the min / med / max contours. If None, the truncated
        Greens / Blues / Purples colormaps are used.
    load_from : str or None
        Path to an .npz with hist_* and par_* keys. If None, the histograms are
        computed with `hist_min_med_max`.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Legend handles labelled with the min / med / max parameter values.
    """
    if mt_edges is None:
        mt_edges = globals()['mt_edges']
    if dc_edges is None:
        dc_edges = globals()['dc_edges']

    if load_from is None:
        rv = hist_min_med_max(TARGET, mt_edges=mt_edges, dc_edges=dc_edges,)
        hists = [rv['hist_min'], rv['hist_med'], rv['hist_max']]
        pars = [rv['par_min'], rv['par_med'], rv['par_max']]
    else:
        with np.load(load_from) as rv:
            hists = [rv['hist_min'], rv['hist_med'], rv['hist_max']]
            pars = [rv['par_min'], rv['par_med'], rv['par_max']]

    mt_cents = holo.utils.midpoints(mt_edges)
    dc_cents = holo.utils.midpoints(dc_edges)

    # Empty bins give log10(0) = -inf, which matplotlib masks; silence the warning.
    with np.errstate(divide='ignore'):
        log_hists = [np.log10(hh) for hh in hists]

    # Draw all three contour sets in both branches. The explicit-colors branch
    # previously skipped hist_min and called plt.colorbar with the undefined
    # names `im` / `ht_max`, which always raised NameError.
    if colors is None:
        styles = [dict(cmap=cmap_Greens), dict(cmap=cmap_Blues), dict(cmap=cmap_Purples)]
    else:
        styles = [dict(colors=colors[0]), dict(colors=colors[1]), dict(colors=colors[2])]
    for log_hist, style in zip(log_hists, styles):
        ax.contour(mt_cents, dc_cents, log_hist, levels=levels, **style)

    if colors is None:
        colors = ['#1e8144', "#347ebb", '#6e56a6' ]
    handles = [
        mpl.lines.Line2D([0], [0], label=f"{pars[ii]:.2f}", color=colors[ii])
        for ii in range(3)
    ]
    return handles

# Compute and save all histogram data

In [None]:
# Histogram binning settings
NBINS = 40          # number of comoving-distance bin edges
TAKE = 4            # only used to label the output filenames in this cell
mt_idx_min = 30     # slice of the SAM total-mass grid used as mass bin edges
mt_idx_max = -8
dc_edge_min = 3e1   # comoving-distance edge range [Mpc]
dc_edge_max = 1e4

# Bin edges.
# NOTE: this rebinds `sam`, shadowing the `holodeck.sams.sam` module imported at
# the top of the notebook; later cells see the model instance.
sam = holo.sams.Semi_Analytic_Model()
mt_edges = sam.mtot[mt_idx_min:mt_idx_max]/MSOL     # total-mass edges [MSOL]
dc_edges = np.geomspace(dc_edge_min, dc_edge_max, NBINS)

targets = ['gsmf_phi0', 'gsmf_mchar0_log10', 'mmb_mamp_log10', 'mmb_scatter_dex',
           'hard_time', 'hard_gamma_inner']

# Directory that receives one .npz of min/med/max histograms per target.
FIGDATA_DIR = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz/figdata'

if BUILD_ARRAYS:
    os.makedirs(FIGDATA_DIR, exist_ok=True)
    for TARGET in tqdm(targets):
        rv = hist_min_med_max(TARGET, mt_edges=mt_edges, dc_edges=dc_edges,)
        hist_min, hist_med, hist_max = rv['hist_min'], rv['hist_med'], rv['hist_max']
        par_min, par_med, par_max = rv['par_min'], rv['par_med'], rv['par_max']
        # os.path.join inserts the separator; plain concatenation wrote
        # '.../figdatamt_dc_hist_...' beside, not inside, figdata/.
        filename = os.path.join(
            FIGDATA_DIR, f'mt_dc_hist_tk{TAKE}_{TARGET}_{NBINS}bins_tk1.npz')
        np.savez(filename,
                 hist_min=hist_min, hist_med=hist_med, hist_max=hist_max,
                 par_min=par_min, par_med=par_med, par_max=par_max)