In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl
from tqdm import tqdm

from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim
import os


In [None]:
# --- Simulation-library dimensions (these appear in the on-disk file names) ---
SHAPE = None        # grid shape; formatted into the data directory names
NREALS = 500        # number of realizations (an alternate value used before: 20)
NFREQS = 40
NLOUDEST = 1        # number of loudest single sources kept when resampling

# --- Workflow switches and tolerances ---
# NOTE(review): BUILD_ARRAYS, SAVEFIG, TOL and MAXBADS are defined here; check
# which later cells actually consult them.
BUILD_ARRAYS = False
SAVEFIG = False
TOL = 0.01
MAXBADS = 5

# --- Variation / detection-statistics sizes ---
NVARS = 21          # presumably the number of variations per parameter (alternate: 6)
NPSRS = 40
NSKIES = 100        # alternate value used before: 15

# Parameters that were varied, one simulation library per name.
PARAM_NAMES = [
    'hard_time',
    'gsmf_phi0',
    'gsmf_mchar0_log10',
    'mmb_mamp_log10',
    'mmb_scatter_dex',
    'hard_gamma_inner',
]

### Truncate colormaps

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a colormap covering only the fraction [minval, maxval] of `cmap`.

    Adapted from https://stackoverflow.com/a/18926541

    Parameters
    ----------
    cmap : str or matplotlib Colormap
        Colormap object, or the name of a registered colormap.
    minval, maxval : float
        Lower and upper fractions of the source colormap to keep.
    n : int
        Number of colours sampled from the kept range.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        New colormap named ``trunc(<source name>,<minval>,<maxval>)``.
    """
    base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
    new_name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=base.name, a=minval, b=maxval)
    sampled_colors = base(np.linspace(minval, maxval, n))
    return mpl.colors.LinearSegmentedColormap.from_list(new_name, sampled_colors)


# Truncated sequential colormaps. The low end of these maps is dropped,
# presumably to avoid near-white colours on a white background.
cmap_Blues = truncate_colormap('Blues', minval=0.4, maxval=1)
cmap_PuBuGn = truncate_colormap('PuBuGn', minval=0.2, maxval=1)
cmap_Greens = truncate_colormap('Greens', minval=0.4, maxval=1)
cmap_Oranges = truncate_colormap('Oranges', minval=0.4, maxval=1)
cmap_Greys = truncate_colormap('Greys', minval=0.4, maxval=0.8)

# Load data from single-parameter variations


In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, nloudest=NLOUDEST,
        red_gamma=None, red2white=None,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz',
):
    """Load the saved ``data`` and ``params`` arrays for one varied parameter.

    Reads ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}/data_params.npz``.

    Parameters
    ----------
    target : str
        Name of the varied parameter; used as the directory-name prefix.
    nvars, nreals, shape
        Values formatted into the directory name.
    nskies, nloudest, red_gamma, red2white
        Accepted for call-site compatibility; not used when loading.
    path : str
        Base output directory.

    Returns
    -------
    data : numpy.ndarray
        The stored ``data`` array (elements are indexed with string keys by
        callers, e.g. ``data[i]['hc_ss']``).
    params : numpy.ndarray
        The stored ``params`` array.

    Raises
    ------
    FileNotFoundError
        If the data file does not exist. This subclasses ``Exception``, so
        existing ``except Exception`` handlers still catch it.
    """
    load_data_from_file = path + f'/{target}_v{nvars}_r{nreals}_shape{str(shape)}/data_params.npz'

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # allow_pickle is needed because the stored arrays hold Python objects.
    # Unpickling can execute arbitrary code, so only load trusted files.
    # The context manager closes the file even if a key is missing.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']

    return data, params

# Calculate and save arrays

In [None]:
# Varied parameters, in the same order as PARAM_NAMES in the config cell.
targets = [
    'hard_time',
    'gsmf_phi0',
    'gsmf_mchar0_log10',
    'mmb_mamp_log10',
    'mmb_scatter_dex',
    'hard_gamma_inner',
]

In [None]:
def resample_loudest(hc_ss, hc_bg, nloudest):
    """Keep the first `nloudest` single sources and fold the rest into the background.

    Parameters
    ----------
    hc_ss : numpy.ndarray
        Single-source characteristic strains, e.g. [F,R,L]. The last axis
        indexes the stored loud sources; the first `nloudest` are kept, so
        presumably they are sorted loudest-first.
    hc_bg : numpy.ndarray
        Background characteristic strain, e.g. [F,R] (broadcastable against
        ``hc_ss[..., 0]``).
    nloudest : int
        Number of single sources to keep; must be <= ``hc_ss.shape[-1]``.

    Returns
    -------
    new_hc_ss : numpy.ndarray
        ``hc_ss[..., :nloudest]``.
    new_hc_bg : numpy.ndarray
        Background with every demoted source added in quadrature, so that
        ``hc_bg**2 + sum(hc_ss**2, axis=-1)`` is conserved.

    Raises
    ------
    ValueError
        If `nloudest` exceeds the number of stored single sources.
    """
    if nloudest > hc_ss.shape[-1]:  # check for valid nloudest
        err = f"{nloudest=} for detstats must be <= nloudest of hc data"
        raise ValueError(err)

    # Add *all* demoted sources (indices nloudest .. L-1) in quadrature.
    # The previous `nloudest:-1` slice silently dropped the last one.
    new_hc_bg = np.sqrt(hc_bg**2 + np.sum(hc_ss[..., nloudest:]**2, axis=-1))
    new_hc_ss = hc_ss[..., :nloudest]

    return new_hc_ss, new_hc_bg

In [None]:
def resample_par(hc_ss, hc_bg, sspar, bgpar, nloudest):
    """Keep parameters of the first `nloudest` sources; strain-weight the rest into the background.

    Shapes (P = number of parameters; the leading dimension of `sspar` and
    `bgpar` must match so the weighted sum below broadcasts):

        hc_ss [F,R,L]
        hc_bg [F,R]
        sspar [P,F,R,L]
        bgpar [P,F,R]

    Parameters
    ----------
    hc_ss, hc_bg : numpy.ndarray
        The *original* (not yet resampled) single-source and background
        strains, used as weights.
    sspar, bgpar : numpy.ndarray
        Single-source and background parameters.
    nloudest : int
        Number of single sources to keep; must be <= ``sspar.shape[-1]``.

    Returns
    -------
    new_sspar : numpy.ndarray
        ``sspar[..., :nloudest]``.
    new_bgpar : numpy.ndarray
        hc^2-weighted average of the background parameters and the parameters
        of every demoted source::

            (hc_bg**2 * bgpar + sum_l hc_ss_l**2 * sspar_l)
            / (hc_bg**2 + sum_l hc_ss_l**2)

    Raises
    ------
    ValueError
        If `nloudest` exceeds the number of stored single sources.
    """
    if nloudest > sspar.shape[-1]:  # check for valid nloudest
        err = f"{nloudest=} for detstats must be <= nloudest of hc data"
        raise ValueError(err)

    new_sspar = sspar[..., :nloudest]

    # Strains of every demoted source (indices nloudest .. L-1). The previous
    # `nloudest:-1` slice dropped the last one from numerator and denominator.
    demoted_hc2 = hc_ss[..., nloudest:]**2
    new_par = hc_bg**2 * bgpar + np.sum(demoted_hc2 * sspar[..., nloudest:], axis=-1)
    new_sum = hc_bg**2 + np.sum(demoted_hc2, axis=-1)
    new_bgpar = new_par / new_sum[np.newaxis, ...]

    return new_sspar, new_bgpar

### Check that resampling works

In [None]:
# Sanity check on variation index 10 of 'hard_time': summarise the total
# strain sqrt(hc_bg**2 + sum(hc_ss**2)) before and after resampling to the
# single loudest source. The two summaries should agree if resample_loudest
# conserves the summed power.
data, params = get_data('hard_time')
hc_ss, hc_bg = data[10]['hc_ss'], data[10]['hc_bg']
print(hc_ss.shape)

hc_tot = np.sqrt(hc_bg**2 + np.sum(hc_ss**2, axis=-1))
print(holo.utils.stats(hc_tot))

hc_ss_new, hc_bg_new = resample_loudest(hc_ss, hc_bg, nloudest=1)
print(hc_ss_new.shape)
hc_tot_new = np.sqrt(hc_bg_new**2 + np.sum(hc_ss_new**2, axis=-1))
print(holo.utils.stats(hc_tot_new))

In [None]:
def build_hcpar_arrays(
        target, nloudest=NLOUDEST,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz',
):
    """Compute and cache h_c and binary-parameter arrays for five variations of `target`.

    For variation indices 0, 5, 10, 15 and 20, this function:

    - loads the saved data;
    - resamples to the `nloudest` loudest single sources, folding the rest
      into the background;
    - stores, per variation:

        yy_ss : [hc_ss[..., 0], sspar[0, ..., 0], sspar[4, ..., 0]]
        yy_bg : [hc_bg,         bgpar[0],         bgpar[4]]

    The notebook's comments and axis labels treat parameter index 0 as mass
    and index 4 as final comoving distance (confirm against holodeck's
    single_sources parameter ordering).

    Parameters
    ----------
    target : str
        Name of the varied parameter.
    nloudest : int
        Number of loudest single sources to keep.
    path : str
        Base output directory. It is used both to read the data and to write
        the cache; previously it was silently overwritten by a hard-coded path.

    Output
    ------
    Writes ``{path}/figdata/hcpar/hcpar_arrays_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{SHAPE}[_l{nloudest}].npz``
    with keys ``xx``, ``yy_ss``, ``yy_bg`` and ``labels``. The ``_l`` suffix is
    omitted when ``nloudest == 10``.
    """
    parvars = [0, 5, 10, 15, 20]
    labels = []
    yy_ss = []
    yy_bg = []
    data, params = get_data(target, nvars=NVARS, nskies=NSKIES, nreals=NREALS,
                            path=path)
    fobs_cents = data[0]['fobs_cents']
    xx = fobs_cents * YR  # frequencies scaled by YR (plotted against LABEL_GW_FREQUENCY_YR)

    for var in parvars:
        labels.append(f"{params[var][target]:.2f}")

        hc_ss_old = data[var]['hc_ss']
        hc_bg_old = data[var]['hc_bg']
        hc_ss, hc_bg = resample_loudest(hc_ss_old, hc_bg_old, nloudest)

        # Expand the stored single-source parameters, then apply par_units to both sets.
        sspar = sings.all_sspars(fobs_cents, data[var]['sspar'])
        bgpar = data[var]['bgpar'] * sings.par_units[:, np.newaxis, np.newaxis]
        sspar = sspar * sings.par_units[:, np.newaxis, np.newaxis, np.newaxis]

        # Weight with the *original* strains, so every demoted source contributes.
        sspar, bgpar = resample_par(hc_ss_old, hc_bg_old, sspar, bgpar, nloudest)

        # Quantities to plot; single-source entries use only the loudest source.
        yy_ss.append([hc_ss[..., 0], sspar[0, ..., 0], sspar[4, ..., 0]])
        yy_bg.append([hc_bg, bgpar[0], bgpar[4]])

    save_dir = path + '/figdata/hcpar'
    os.makedirs(save_dir, exist_ok=True)
    save_name = save_dir + f'/hcpar_arrays_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}'
    if nloudest != 10:
        save_name += f"_l{nloudest}"
    save_name += '.npz'
    np.savez(save_name, xx=xx, yy_ss=yy_ss, yy_bg=yy_bg, labels=labels)

In [None]:
# Rebuild the cached h_c / parameter arrays for every varied parameter.
# NOTE(review): the BUILD_ARRAYS flag from the config cell is not consulted
# here, so running this cell always recomputes and overwrites the .npz files.
for target in tqdm(list(PARAM_NAMES)):
    build_hcpar_arrays(target=target)

# Load arrays

In [None]:
def load_hcpar_arrays(target, nloudest=NLOUDEST,
    path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz'):
    """Load the cached h_c / parameter arrays for `target`.

    Reads ``{path}/figdata/hcpar/hcpar_arrays_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{SHAPE}[_l{nloudest}].npz``.
    The ``_l`` suffix is omitted when ``nloudest == 10``, matching the naming
    used by build_hcpar_arrays.

    Parameters
    ----------
    target : str
        Name of the varied parameter.
    nloudest : int
        Number of loudest single sources the cache was built with.
    path : str
        Base output directory.

    Returns
    -------
    xx, yy_ss, yy_bg, labels : numpy.ndarray
        The arrays stored under the keys of the same names.

    Raises
    ------
    FileNotFoundError
        Propagated from ``np.load`` if the cache file does not exist.
    """
    load_name = path + f'/figdata/hcpar/hcpar_arrays_{target}_v{NVARS}_r{NREALS}_s{NSKIES}_shape{str(SHAPE)}'
    if nloudest != 10:
        load_name += f"_l{nloudest}"
    load_name += '.npz'

    # The context manager closes the archive even if a key is missing.
    with np.load(load_name) as file:
        xx = file['xx']
        yy_ss = file['yy_ss']
        yy_bg = file['yy_bg']
        labels = file['labels']
    return xx, yy_ss, yy_bg, labels

# Draw Functions

### draw_95ci()

In [None]:
def draw_95ci(ax, xx, yy_ss, yy_bg, ii, colors, bgcolors=None,
                parvars=[0,5,10,15,20]):
    """Draw dashed background medians and shaded 95% single-source bands.

    Entries of `parvars` equal to 5 or 15 are skipped.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    xx : array_like
        x-values shared by every curve.
    yy_ss, yy_bg : sequence
        ``yy_ss[k][ii]`` / ``yy_bg[k][ii]`` hold the single-source / background
        samples for the k-th entry of `parvars`. Statistics are taken over the
        last axis.
    ii : int
        Index of the plotted quantity.
    colors : sequence
        Indexed by the values in `parvars`; used for the single-source bands.
    bgcolors : sequence, optional
        Indexed like `colors`; used for the background lines. Defaults to `colors`.
    parvars : list of int
        Variation indices to consider.

    Returns
    -------
    handles : list
        The background-median line artists, in drawing order.
    """
    line_colors = colors if bgcolors is None else bgcolors
    skipped = (5, 15)

    # Background: median over the last axis.
    handles = []
    for k, var in enumerate(parvars):
        if var in skipped:
            continue
        (bg_line,) = ax.plot(xx, np.median(yy_bg[k][ii], axis=-1),
                             color=line_colors[var], lw=1, linestyle='--')
        handles.append(bg_line)

    # Single sources: central 95% interval over the last axis.
    half_width = 95 / 2
    bounds = [50 - half_width, 50 + half_width]
    for k, var in enumerate(parvars):
        if var in skipped:
            continue
        lower, upper = np.percentile(yy_ss[k][ii], bounds, axis=-1)
        ax.fill_between(xx, lower, upper, alpha=0.15, color=colors[var])

    return handles
        


### draw_lims()

In [None]:
def draw_lims(ax, xx, yy_ss, yy_bg, ii, colors, bgcolors=None,
                parvars=[0,10,20]):
    """Draw background medians, single-source extrema as limit markers, and 68% bands.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    xx : array_like
        x-values shared by every curve.
    yy_ss, yy_bg : sequence
        ``yy_ss[k][ii]`` / ``yy_bg[k][ii]`` hold the single-source / background
        samples for the k-th entry of `parvars`. Statistics are taken over the
        last axis.
    ii : int
        Index of the plotted quantity. For ``ii == 2`` the single-source
        minimum is drawn with lower-limit markers (bar up to the median).
        Otherwise the maximum is drawn with upper-limit markers (bar down to
        the median).
    colors : sequence
        Indexed by the values in `parvars`; used for single sources and for
        the legend proxies.
    bgcolors : sequence, optional
        Indexed like `colors`; used for the dashed background lines.
        Defaults to `colors`.
    parvars : list of int
        Variation indices to draw.

    Returns
    -------
    handles : list of matplotlib.lines.Line2D
        Proxy lines (not added to `ax`) in the single-source colours, for a legend.
    """
    line_colors = colors if bgcolors is None else bgcolors

    # Dashed background medians; legend proxies use the single-source colour.
    handles = []
    for k, var in enumerate(parvars):
        (_bg_line,) = ax.plot(xx, np.median(yy_bg[k][ii], axis=-1),
                              color=line_colors[var], lw=1, linestyle='--')
        handles.append(mpl.lines.Line2D([0], [0], color=colors[var]))

    # Single-source extrema drawn as one-sided limit markers.
    for k, var in enumerate(parvars):
        samples = yy_ss[k][ii]
        ymed = np.median(samples, axis=-1)
        ymax = np.max(samples, axis=-1)
        ymin = np.min(samples, axis=-1)
        style = dict(color=colors[var], alpha=0.15, capsize=0.5,
                     marker='o', markersize=1, linestyle='')
        if ii == 2:
            ax.errorbar(xx, ymin, yerr=(ymin - ymin, ymed - ymin), lolims=True, **style)
        else:
            ax.errorbar(xx, ymax, yerr=(ymax - ymed, ymax - ymax), uplims=True, **style)

    # Shaded central 68% interval of the single sources.
    half_width = 68 / 2
    bounds = [50 - half_width, 50 + half_width]
    for k, var in enumerate(parvars):
        ax.fill_between(xx, *np.percentile(yy_ss[k][ii], bounds, axis=-1),
                        alpha=0.15, color=colors[var])

    return handles

### draw_tris()

In [None]:
def draw_tris(ax, xx, yy_ss, yy_bg, ii, colors, bgcolors=None,
                parvars=[0,5,10,15,20]):
    """Plot background medians, single-source extrema as tri markers, and 68% bands.

    Entries of `parvars` equal to 5 or 15 are skipped.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    xx : array_like
        x-values shared by every curve.
    yy_ss, yy_bg : sequence
        ``yy_ss[k][ii]`` / ``yy_bg[k][ii]`` hold the single-source / background
        samples for the k-th entry of `parvars`. Statistics are taken over the
        last axis.
    ii : int
        Index of the plotted quantity. For ``ii == 2`` the single-source
        minimum is marked with a '2' tri marker; otherwise the maximum is
        marked with a '1' tri marker.
    colors : sequence
        Indexed by the values in `parvars`; used for single sources and for
        the legend proxies.
    bgcolors : sequence, optional
        Indexed like `colors`; used for the dashed background lines.
        Defaults to `colors`.
    parvars : list of int
        Variation indices to consider.

    Returns
    -------
    handles : list of matplotlib.lines.Line2D
        Proxy lines (not added to `ax`) in the single-source colours, for a legend.
    """
    line_colors = colors if bgcolors is None else bgcolors
    skipped = (5, 15)

    # Dashed background medians; legend proxies use the single-source colour.
    handles = []
    for k, var in enumerate(parvars):
        if var in skipped:
            continue
        (_bg_line,) = ax.plot(xx, np.median(yy_bg[k][ii], axis=-1),
                              color=line_colors[var], lw=1, linestyle='--')
        handles.append(mpl.lines.Line2D([0], [0], color=colors[var]))

    # Single-source extremum per x: minimum for ii == 2, otherwise maximum.
    for k, var in enumerate(parvars):
        if var in skipped:
            continue
        if ii == 2:
            extreme, marker = np.min(yy_ss[k][ii], axis=-1), '2'
        else:
            extreme, marker = np.max(yy_ss[k][ii], axis=-1), '1'
        ax.scatter(xx, extreme, color=colors[var], alpha=0.75,
                   marker=marker, s=10, linestyle='')

    # Shaded central 68% interval of the single sources.
    half_width = 68 / 2
    bounds = [50 - half_width, 50 + half_width]
    for k, var in enumerate(parvars):
        if var in skipped:
            continue
        ax.fill_between(xx, *np.percentile(yy_ss[k][ii], bounds, axis=-1),
                        alpha=0.25, color=colors[var])

    return handles

# Three separate figures: GSMF, MMB and hardening parameters

In [None]:

# One figure per parameter family (GSMF, MMB, hardening). Rows are h_c, mass
# and comoving distance; columns are the two parameters varied in that family.
color_maps = [
    cmap_Greens(np.linspace(0, 1, NVARS)),
    cmap_Oranges(np.linspace(0, 1, NVARS)),
    cmap_Blues(np.linspace(0, 1, NVARS)),
]
grey_map = cmap_Greys(np.linspace(0, 1, NVARS))

# Raw strings: '\ ', '\m' and '\o' are invalid Python escape sequences
# (SyntaxWarning since Python 3.12); the string values are unchanged.
ylabels = [r'$h_c$', r'$M\ [\mathrm{M}_\odot]$', r'$d_c\ [\mathrm{Mpc}]$']
names = ['gsmf', 'mmb', 'hard']
target_groups = [
    ['gsmf_phi0', 'gsmf_mchar0_log10'],
    ['mmb_mamp_log10', 'mmb_scatter_dex'],
    ['hard_time', 'hard_gamma_inner'],
]
save_path = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots/new_spectra'

# `group_targets` (not `targets`) avoids clobbering the global list defined earlier.
for tt, group_targets in enumerate(target_groups):
    fig, axs = plot.figax_single(ncols=2, nrows=3, sharex=True, height=5.5)
    fig.text(0.5, 0.06, plot.LABEL_GW_FREQUENCY_YR, ha='center')
    for ax, ylabel in zip(axs[:, 0], ylabels):
        ax.set_ylabel(ylabel)

    for cc, target in enumerate(group_targets):
        xx, yy_ss, yy_bg, labels = load_hcpar_arrays(target, nloudest=NLOUDEST)
        print(labels)
        for rr, ax in enumerate(axs[:, cc]):
            handles = draw_tris(ax, xx, yy_ss, yy_bg, ii=rr,
                                colors=color_maps[tt], bgcolors=grey_map)

            # Log y ticks: base-10 major ticks (numticks=5), unlabelled minor ticks.
            ax.yaxis.set_major_locator(mpl.ticker.LogLocator(base=10.0, numticks=5))
            ax.yaxis.set_minor_locator(mpl.ticker.LogLocator(
                base=10.0, subs=np.arange(1, 10.0) * 0.1, numticks=10))
            ax.yaxis.set_minor_formatter(mpl.ticker.NullFormatter())

            ax.tick_params(axis='y', which='both', right=True, left=True, direction='in')
            ax.tick_params(axis='x', which='both', top=True, direction='in')

            # Give both columns of a row a shared y-range; hide duplicate labels on the right.
            if cc > 0:
                ylim0 = axs[rr, 0].get_ylim()
                ylim1 = axs[rr, 1].get_ylim()
                ylim = (np.min([ylim0[0], ylim1[0]]), np.max([ylim0[1], ylim1[1]]))
                axs[rr, 0].set_ylim(ylim)
                ax.set_ylim(ylim)
                ax.set_yticklabels([])

        # draw_tris returns three handles (variations 0, 10, 20), so pick the matching labels.
        legend_labels = [labels[0], labels[2], labels[4]]
        axs[0, cc].legend(
            handles=handles, labels=legend_labels,
            bbox_to_anchor=(0.5, 1.02), bbox_transform=axs[0, cc].transAxes,
            loc='lower center', ncols=3, title=plot.PARAM_KEYS[target],
            borderpad=0.1, title_fontsize=12, labelspacing=0.25, handlelength=1.5,
            handletextpad=0.3, columnspacing=0.5, frameon=False,
        )
    fig.subplots_adjust(hspace=0, wspace=0)

    # Honour the SAVEFIG switch from the config cell instead of always writing files.
    if SAVEFIG:
        os.makedirs(save_path, exist_ok=True)
        save_name = save_path + f'/trilims_greybg_{names[tt]}.png'
        fig.savefig(save_name, dpi=300, bbox_inches='tight')