In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl
from tqdm import tqdm

from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim
import os


In [None]:
# ---- Run / library dimensions ------------------------------------------------
SHAPE = None
NREALS = 500
NFREQS = 40
NLOUDEST = 1
NVARS = 5            # alternative used before: NVARS = 6

# ---- Workflow switches and tolerances -----------------------------------------
BUILD_ARRAYS = False   # True -> call detstats.build_hcpar_arrays for every PARAM_NAMES entry
SAVEFIG = False
TOL = 0.01
MAXBADS = 5

# ---- Other settings ------------------------------------------------------------
NPSRS = 40
NSKIES = 500           # alternative used before: NSKIES = 15

# Parameters that can be varied, and the variation indices (0 .. NVARS-1) to use.
PARAM_NAMES = [
    'hard_time',
    'gsmf_phi0',
    'gsmf_mchar0_log10',
    'mmb_mamp_log10',
    'mmb_scatter_dex',
    'hard_gamma_inner',
]
PARVARS = list(range(NVARS))   # == [0, 1, 2, 3, 4]

### Truncate colormaps

In [None]:
# Truncated sequential colormaps: the (0.4, 1) sub-range for the coloured maps
# and (0.4, 0.9) for greys -- assuming truncate_colormap(name, lo, hi) keeps
# that fractional range of the named colormap.
cmap_Blues, cmap_Greens, cmap_Oranges = [
    plot.truncate_colormap(cmap_name, 0.4, 1)
    for cmap_name in ('Blues', 'Greens', 'Oranges')
]
cmap_Greys = plot.truncate_colormap('Greys', 0.4, 0.9)

# Vary Parameter


In [None]:
# def get_data(
#         target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, nloudest=NLOUDEST,
#         red_gamma = None, red2white=None,
#     path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'     
# ):

#     load_data_from_file = path+f'/{target}_v{nvars}_r{nreals}_shape{str(shape)}/data_params.npz' 

#     if os.path.exists(load_data_from_file) is False:
#         err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
#         raise Exception(err)
#     # if os.path.exists(load_dets_from_file) is False:
#     #     err = f"load dets file '{load_dets_from_file}' does not exist, you need to construct it."
#     #     raise Exception(err)
#     file = np.load(load_data_from_file, allow_pickle=True)
#     data = file['data']
#     params = file['params']
#     file.close()

#     # file = np.load(load_dets_from_file, allow_pickle=True)
#     # dsdat = file['dsdat']
#     # file.close()

#     return data, params

In [None]:
# One figure per parameter group (gsmf, mmb, hard). Each column is one varied
# parameter; rows are h_c, mass and distance. Only variations `parvars` are drawn.
parvars = [0, 2, 4]
nvars = 5
nreals = 500
nloudest = 1

# Foreground colours: one per *plotted* variation, indexed by position `vv`.
color_maps = [
    cmap_Greens(np.linspace(0, 1, len(parvars))),
    cmap_Oranges(np.linspace(0, 1, len(parvars))),
    cmap_Blues(np.linspace(0, 1, len(parvars))),
]

# Background greys: one per *possible* variation, indexed by variation number
# `var` (matching the `bgcolors[var]` convention used elsewhere in this notebook).
grey_map = cmap_Greys(np.linspace(0, 1, nvars))
bgcolors = cmap_Greys(np.linspace(0, 1, nvars))

# ylim = [
#     (2e-19, 2e-12), # hc
#     (6.5e5, 2e11), # mass in Msun
#     (3e1, 5e3)
# ]

# Raw strings: the LaTeX backslashes are not valid Python escape sequences.
ylabels = [r'$h_c$', r'$M\ [\mathrm{M}_\odot]$', r'$d_c\ [\mathrm{Mpc}]$']
names = ['gsmf', 'mmb', 'hard']
param_groups = [
    ['gsmf_phi0', 'gsmf_mchar0_log10'],
    ['mmb_mamp_log10', 'mmb_scatter_dex'],
    ['hard_time', 'hard_gamma_inner'],
]

for tt, targets in enumerate(param_groups):
    colors = color_maps[tt]

    fig, axs = plot.figax_single(ncols=2, nrows=3, sharex=True, height=5.5)
    fig.text(0.5, 0.06, plot.LABEL_GW_FREQUENCY_YR, ha='center')
    for ii, ax in enumerate(axs[:, 0]):
        ax.set_ylabel(ylabels[ii])

    for cc, target in enumerate(targets):
        # Load exactly the variations that `colors` was sized for
        # (this was previously hard-coded to [0, 2, 4]).
        xx, yy_ss, yy_bg, labels = detstats.get_hcpar_arrays(
            target, nloudest=nloudest, parvars=parvars,
        )
        print(labels)
        print(yy_ss.shape)

        for rr, ax in enumerate(axs[:, cc]):
            # Background median (dashed grey); legend proxies use the foreground colour.
            handles = []
            for vv, var in enumerate(parvars):
                ax.plot(xx, np.median(yy_bg[vv][rr], axis=-1),
                        color=bgcolors[var], lw=1, linestyle='--')
                handles.append(mpl.lines.Line2D([0], [0], color=colors[vv]))

            # Single-source extrema: 2.5th percentile in the distance row (rr == 2),
            # 97.5th percentile in the other rows.
            for vv, var in enumerate(parvars):
                if rr == 2:
                    ymin = np.percentile(yy_ss[vv][rr], 50 - 95/2, axis=-1)
                    ax.scatter(xx, ymin, color=colors[vv], alpha=0.75,
                               marker='2', s=10, linestyle='')
                else:
                    ymax = np.percentile(yy_ss[vv][rr], 50 + 95/2, axis=-1)
                    ax.scatter(xx, ymax, color=colors[vv], alpha=0.75,
                               marker='1', s=10, linestyle='')

            # Single-source 68% interval band.
            for vv, var in enumerate(parvars):
                for pp in [68, ]:
                    percs = [50 - pp / 2, 50 + pp / 2]
                    ax.fill_between(xx, *np.percentile(yy_ss[vv][rr], percs, axis=-1),
                                    alpha=0.25, color=colors[vv])

            # Log-scale y ticks on both sides, pointing inwards.
            y_major = mpl.ticker.LogLocator(base=10.0, numticks=5)
            ax.yaxis.set_major_locator(y_major)
            y_minor = mpl.ticker.LogLocator(base=10.0, subs=np.arange(1, 10.0) * 0.1, numticks=10)
            ax.yaxis.set_minor_locator(y_minor)
            ax.yaxis.set_minor_formatter(mpl.ticker.NullFormatter())

            ax.tick_params(axis='y', which='both', right=True, left=True, direction='in')
            ax.tick_params(axis='x', which='both', top=True, direction='in')

            # Share the union of both columns' y-limits across each row.
            if cc > 0:
                ylim0 = axs[rr, 0].get_ylim()
                ylim1 = axs[rr, 1].get_ylim()
                ylim = (np.min([ylim0[0], ylim1[0]]), np.max([ylim0[1], ylim1[1]]))
                axs[rr, 0].set_ylim(ylim)
                ax.set_ylim(ylim)
                ax.set_yticklabels([])

        # Legend above the top panel: labels are positional over `parvars`.
        legend_labels = [labels[vv] for vv in range(len(parvars))]
        axs[0, cc].legend(
            handles=handles, labels=legend_labels,
            bbox_to_anchor=(0.5, 1.02), bbox_transform=axs[0, cc].transAxes,
            loc='lower center', ncols=len(parvars), title=plot.PARAM_KEYS[target],
            borderpad=0.1, title_fontsize=12, labelspacing=0.25, handlelength=1.5,
            handletextpad=0.3, columnspacing=0.5, frameon=False)
    plt.subplots_adjust(hspace=0, wspace=0)

    # Only write to disk when requested in the config cell.
    # NOTE: absolute, machine-specific output path.
    if SAVEFIG:
        save_path = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots/new_spectra'
        save_name = save_path + f'/uplims_greybg_l{nloudest}_{names[tt]}.png'
        fig.savefig(save_name, dpi=300, bbox_inches='tight')

In [None]:
# Regenerate the hc/parameter arrays for every parameter in PARAM_NAMES,
# only when BUILD_ARRAYS is switched on in the config cell.
if BUILD_ARRAYS:
    for target in tqdm(PARAM_NAMES):
        detstats.build_hcpar_arrays(
            target,
            nvars=NVARS,
            nreals=NREALS,
            nloudest=NLOUDEST,
            parvars=PARVARS,
            gw_only=False,
        )

# Draw Functions

### draw_95ci()

In [None]:
def draw_95ci(ax, xx, yy_ss, yy_bg, ii, colors, bgcolors=None,
                parvars=PARVARS, skip_vars=(2, 4)):
    """Draw background medians and 95% single-source bands on one axis.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw on.
    xx : array-like
        x values shared by every curve.
    yy_ss, yy_bg : indexable
        Single-source / background values, indexed as ``[vv][ii]`` where ``vv``
        is the position in `parvars`; the last axis is reduced (median or
        percentiles) -- presumably over realizations, TODO confirm.
    ii : int
        Row/quantity index into ``yy_*[vv]``.
    colors : indexable
        Band colours, indexed by variation number ``var``.
    bgcolors : indexable, optional
        Background-line colours, indexed by ``var``; defaults to `colors`.
    parvars : sequence of int
        Variation numbers contained in `yy_ss` / `yy_bg`.
    skip_vars : container of int
        Variation numbers not drawn. The default ``(2, 4)`` reproduces the
        previously hard-coded exclusion.

    Returns
    -------
    list
        The background median lines that were drawn, in drawing order.
    """
    if bgcolors is None:
        bgcolors = colors

    # (position, variation number) pairs that are actually drawn.
    shown = [(vv, var) for vv, var in enumerate(parvars) if var not in skip_vars]

    # Background medians (dashed).
    handles = []
    for vv, var in shown:
        hh, = ax.plot(xx, np.median(yy_bg[vv][ii], axis=-1), color=bgcolors[var],
                      lw=1, linestyle='--')
        handles.append(hh)

    # Single-source 95% interval (2.5th to 97.5th percentile).
    percs = [50 - 95 / 2, 50 + 95 / 2]
    for vv, var in shown:
        ax.fill_between(xx, *np.percentile(yy_ss[vv][ii], percs, axis=-1),
                        alpha=0.15, color=colors[var])

    return handles
        


### draw_lims()

In [None]:
def draw_lims(ax, xx, yy_ss, yy_bg, ii, colors, bgcolors=None,
                parvars=PARVARS):
    """Draw background medians, single-source limit arrows and 68% bands.

    For row ``ii == 2`` each variation gets an errorbar at the single-source
    minimum with a lower-limit arrow and an upper whisker reaching the median.
    For any other row it gets an errorbar at the maximum with an upper-limit
    arrow and a lower whisker reaching the median.

    Parameters
    ----------
    ax : matplotlib Axes
    xx : array-like
        Shared x values.
    yy_ss, yy_bg : indexable
        Indexed ``[position in parvars][ii]``; the last axis is reduced.
    ii : int
        Row/quantity index.
    colors, bgcolors : indexable
        Indexed by variation number; `bgcolors` defaults to `colors`.
    parvars : sequence of int
        Variation numbers present in the data.

    Returns
    -------
    list
        Legend proxy ``Line2D`` handles in `colors`, one per variation.
    """
    bg_palette = colors if bgcolors is None else bgcolors
    lower_limit_row = (ii == 2)

    # Background medians (dashed) plus coloured legend proxies.
    legend_proxies = []
    for pos, var in enumerate(parvars):
        (bg_line,) = ax.plot(xx, np.median(yy_bg[pos][ii], axis=-1),
                             color=bg_palette[var], lw=1, linestyle='--')
        legend_proxies.append(mpl.lines.Line2D([0], [0], color=colors[var]))

    # Single-source limit markers.
    for pos, var in enumerate(parvars):
        samples = yy_ss[pos][ii]
        y_med = np.median(samples, axis=-1)
        y_hi = np.max(samples, axis=-1)
        y_lo = np.min(samples, axis=-1)
        if lower_limit_row:
            ax.errorbar(xx, y_lo, yerr=(y_lo - y_lo, y_med - y_lo), color=colors[var],
                        alpha=0.15, capsize=0.5, lolims=True, marker='o',
                        markersize=1, linestyle='')
        else:
            ax.errorbar(xx, y_hi, yerr=(y_hi - y_med, y_hi - y_hi), color=colors[var],
                        alpha=0.15, capsize=0.5, uplims=True, marker='o',
                        markersize=1, linestyle='')

    # Single-source 68% interval (16th to 84th percentile).
    band = [50 - 68 / 2, 50 + 68 / 2]
    for pos, var in enumerate(parvars):
        ax.fill_between(xx, *np.percentile(yy_ss[pos][ii], band, axis=-1),
                        alpha=0.15, color=colors[var])

    return legend_proxies

### draw_tris()

In [None]:
def draw_tris(ax, xx, yy_ss, yy_bg, ii, colors, bgcolors=None,
                parvars=PARVARS, skip_vars=(2, 4)):
    """Plot single-source extrema as tri markers and 68% bands over background medians.

    Data are indexed positionally (``yy[vv]`` with ``vv`` the position in
    `parvars`), consistent with the other draw helpers. The extrema previously
    indexed ``yy[var]``, which is only equivalent when ``parvars == range(n)``.

    Parameters
    ----------
    ax : matplotlib Axes
    xx : array-like
        Shared x values.
    yy_ss, yy_bg : indexable
        Indexed ``[position in parvars][ii]``; the last axis is reduced.
    ii : int
        Row/quantity index. Row 2 shows minima (marker '2'); the other rows
        show maxima (marker '1').
    colors, bgcolors : indexable
        Indexed by variation number ``var``; `bgcolors` defaults to `colors`.
    parvars : sequence of int
        Variation numbers present in the data.
    skip_vars : container of int
        Variation numbers not drawn. The default ``(2, 4)`` reproduces the
        previously hard-coded exclusion.

    Returns
    -------
    list
        Legend proxy ``Line2D`` handles in `colors`, one per drawn variation.
    """
    if bgcolors is None:
        bgcolors = colors

    # (position, variation number) pairs that are actually drawn.
    shown = [(vv, var) for vv, var in enumerate(parvars) if var not in skip_vars]

    # Background median (dashed) plus coloured legend proxies.
    handles = []
    for vv, var in shown:
        ax.plot(xx, np.median(yy_bg[vv][ii], axis=-1), color=bgcolors[var],
                lw=1, linestyle='--')
        handles.append(mpl.lines.Line2D([0], [0], color=colors[var]))

    # Single-source extrema.
    for vv, var in shown:
        if ii == 2:
            ymin = np.min(yy_ss[vv][ii], axis=-1)
            ax.scatter(xx, ymin, color=colors[var], alpha=0.75,
                       marker='2', s=10, linestyle='')
        else:
            ymax = np.max(yy_ss[vv][ii], axis=-1)
            ax.scatter(xx, ymax, color=colors[var], alpha=0.75,
                       marker='1', s=10, linestyle='')

    # Single-source 68% interval (16th to 84th percentile).
    percs = [50 - 68 / 2, 50 + 68 / 2]
    for vv, var in shown:
        ax.fill_between(xx, *np.percentile(yy_ss[vv][ii], percs, axis=-1),
                        alpha=0.25, color=colors[var])

    return handles

# 3 Separate Singles: one figure per parameter group (gsmf, mmb, hard)

In [None]:

# One figure per parameter group; each column is one varied parameter and each
# row one quantity (h_c, mass, distance), drawn with `draw_tris`.
color_maps = [
    cmap_Greens(np.linspace(0, 1, NVARS)),
    cmap_Oranges(np.linspace(0, 1, NVARS)),
    cmap_Blues(np.linspace(0, 1, NVARS)),
]

grey_map = cmap_Greys(np.linspace(0, 1, NVARS))

# ylim = [
#     (2e-19, 2e-12), # hc
#     (6.5e5, 2e11), # mass in Msun
#     (3e1, 5e3)
# ]

# `draw_tris` (called with its defaults) does not draw variation numbers 2 and 4.
# The legend labels must come from the same subset, positionally over PARVARS.
# Previously the labels were labels[0], labels[2], labels[4], which did not
# match the handles returned by draw_tris.
tris_skipped = (2, 4)
shown_idx = [vv for vv, var in enumerate(PARVARS) if var not in tris_skipped]

# Raw strings: the LaTeX backslashes are not valid Python escape sequences.
ylabels = [r'$h_c$', r'$M\ [\mathrm{M}_\odot]$', r'$d_c\ [\mathrm{Mpc}]$']
names = ['gsmf', 'mmb', 'hard']
param_groups = [
    ['gsmf_phi0', 'gsmf_mchar0_log10'],
    ['mmb_mamp_log10', 'mmb_scatter_dex'],
    ['hard_time', 'hard_gamma_inner'],
]

for tt, targets in enumerate(param_groups):

    fig, axs = plot.figax_single(ncols=2, nrows=3, sharex=True, height=5.5)
    fig.text(0.5, 0.06, plot.LABEL_GW_FREQUENCY_YR, ha='center')
    for ii, ax in enumerate(axs[:, 0]):
        ax.set_ylabel(ylabels[ii])

    for cc, target in enumerate(targets):
        xx, yy_ss, yy_bg, labels = detstats.get_hcpar_arrays(
            target, nloudest=NLOUDEST, parvars=PARVARS)
        print(labels)
        print(yy_ss.shape)
        for rr, ax in enumerate(axs[:, cc]):
            handles = draw_tris(ax, xx, yy_ss, yy_bg, ii=rr, colors=color_maps[tt], bgcolors=grey_map)

            # Log-scale y ticks on both sides, pointing inwards.
            y_major = mpl.ticker.LogLocator(base=10.0, numticks=5)
            ax.yaxis.set_major_locator(y_major)
            y_minor = mpl.ticker.LogLocator(base=10.0, subs=np.arange(1, 10.0) * 0.1, numticks=10)
            ax.yaxis.set_minor_locator(y_minor)
            ax.yaxis.set_minor_formatter(mpl.ticker.NullFormatter())

            ax.tick_params(axis='y', which='both', right=True, left=True, direction='in')
            ax.tick_params(axis='x', which='both', top=True, direction='in')

            # Share the union of both columns' y-limits across each row.
            if cc > 0:
                ylim0 = axs[rr, 0].get_ylim()
                ylim1 = axs[rr, 1].get_ylim()
                ylim = (np.min([ylim0[0], ylim1[0]]), np.max([ylim0[1], ylim1[1]]))
                axs[rr, 0].set_ylim(ylim)
                ax.set_ylim(ylim)
                ax.set_yticklabels([])

        # Legend: labels for exactly the variations draw_tris drew.
        legend_labels = [labels[vv] for vv in shown_idx]
        axs[0, cc].legend(
            handles=handles, labels=legend_labels,
            bbox_to_anchor=(0.5, 1.02), bbox_transform=axs[0, cc].transAxes,
            loc='lower center', ncols=len(legend_labels), title=plot.PARAM_KEYS[target],
            borderpad=0.1, title_fontsize=12, labelspacing=0.25, handlelength=1.5,
            handletextpad=0.3, columnspacing=0.5, frameon=False)
    plt.subplots_adjust(hspace=0, wspace=0)

    # Only write to disk when requested in the config cell.
    # NOTE: absolute, machine-specific output path.
    if SAVEFIG:
        save_path = '/Users/emigardiner/GWs/holodeck/output/figures/bigplots/new_spectra'
        save_name = save_path + f'/trilims_greybg_l{NLOUDEST}_{names[tt]}.png'
        fig.savefig(save_name, dpi=300, bbox_inches='tight')

# l5: compare the loudest single source for nloudest = 1, 5, 10

In [None]:
# Build the hc/parameter arrays for hard_gamma_inner with a small configuration:
# 3 variations (0, 1, 2), 100 realizations, nloudest = 1.
target = 'hard_gamma_inner'
nloudest = 1
nreals = 100
nvars = 3
parvars = list(range(nvars))   # == [0, 1, 2]
xx, yy_ss, yy_bg, labels = detstats.build_hcpar_arrays(
    target,
    nvars=nvars, nreals=nreals, nloudest=nloudest,
    parvars=parvars, gw_only=False,
)

In [None]:
# Compare the `which_loudest`-th loudest single source for nloudest = 1, 5, 10
# (one column each; rows are h_c, mass, distance). Reads `target` from the
# previous cell.
which_loudest = 0
n_shown = 3  # every `parvars` choice in the loop below has three entries

xlabel = plot.LABEL_GW_FREQUENCY_YR
# Raw strings: the LaTeX backslashes are not valid Python escape sequences.
ylabels = [r'$h_c$', r'$M\ [\mathrm{M}_\odot]$', r'$d_c\ [\mathrm{Mpc}]$']
# Sized explicitly rather than via an `nvars` left over from an earlier cell.
colors = cmap_Blues(np.linspace(0, 1, n_shown))
bgcolors = cmap_Greys(np.linspace(0, 1, n_shown))

fig, axs = plot.figax(ncols=3, nrows=3, figsize=(6, 7))

for row, ax in enumerate(axs[:, 0]):
    ax.set_ylabel(ylabels[row])
axs[2, 1].set_xlabel(xlabel)

for ll, nloudest in enumerate([1, 5, 10]):
    axs_col = axs[:, ll]
    axs_col[0].set_title(f"L={nloudest}, {which_loudest}th source", fontsize=12)

    # nloudest=1 uses the 5-variation / 500-realization arrays (plotting
    # variations 0, 2, 4); larger nloudest uses 3 variations / 100 realizations.
    if nloudest > 1:
        parvars, nvars, nreals = [0, 1, 2], 3, 100
    else:
        parvars, nvars, nreals = [0, 2, 4], 5, 500

    xx, yy_ss, yy_bg, labels = detstats.build_hcpar_arrays(
            target, nloudest=nloudest, nreals=nreals, parvars=parvars, nvars=nvars,
            ss_zero=False)
    print(f"{yy_ss.shape=}")
    yy_ss = yy_ss[..., which_loudest]
    print(labels)
    print(yy_ss.shape)

    for rr, ax in enumerate(axs_col):

        # Background median (dashed).
        for vv, var in enumerate(parvars):
            ax.plot(xx, np.median(yy_bg[vv][rr], axis=-1), color=bgcolors[vv], lw=1, linestyle='--')

        # Single-source extrema: minimum in the distance row (rr == 2), maximum
        # otherwise. This previously tested the stale `ii` from the ylabel loop,
        # which is always 2, so every row showed minima.
        for vv, var in enumerate(parvars):
            if rr == 2:
                ymin = np.min(yy_ss[vv][rr], axis=-1)
                ax.scatter(xx, ymin, color=colors[vv], alpha=0.75,
                           marker='2', s=10, linestyle='')
            else:
                ymax = np.max(yy_ss[vv][rr], axis=-1)
                ax.scatter(xx, ymax, color=colors[vv], alpha=0.75,
                           marker='1', s=10, linestyle='')

        # Single-source 68% interval band.
        for vv, var in enumerate(parvars):
            for pp in [68, ]:
                percs = [50 - pp / 2, 50 + pp / 2]
                ax.fill_between(xx, *np.percentile(yy_ss[vv][rr], percs, axis=-1),
                                alpha=0.25, color=colors[vv])

        if ll > 0:
            ax.set_yticklabels([])

# TODO: no legend is drawn for this figure (legend handles/labels were
# previously built here but never used).
plt.subplots_adjust(hspace=0, wspace=0)