In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim
import os


In [None]:
# --- Run configuration --------------------------------------------------------
SHAPE = None        # grid shape; encoded in the data-file names as `shape{SHAPE}`
NREALS = 500        # number of realizations (alternative: 20)
NFREQS = 40         # number of frequency bins
NLOUDEST = 10       # number of loudest single sources stored per bin

BUILD_ARRAYS = False
SAVEFIG = False
TOL = 0.01
MAXBADS = 5

NVARS = 3           # number of parameter variations in the data file (alternative: 6)

NPSRS = 40
NSKIES = 25         # alternative: 15

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma = None, red2white=None,
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_09B'
):
    """Load the middle parameter variation of a saved anatomy run.

    Parameters
    ----------
    target : str
        Name of the varied parameter; used as the data-file prefix.
    nvars, nreals, shape :
        Values encoded in the data-file name. ``nvars // 2`` also selects
        which variation is returned.
    nskies : int
        Only encoded in the file name when ``path`` is the ``anatomy_09``
        directory.
    red_gamma, red2white :
        Currently unused.
    path : str
        Directory containing the ``.npz`` file.

    Returns
    -------
    data, params
        Entry ``nvars // 2`` (the middle variation) of the file's ``'data'``
        and ``'params'`` arrays.

    Raises
    ------
    FileNotFoundError
        If the expected ``.npz`` file does not exist.
    """
    # The anatomy_09 run additionally encodes nskies in its file names.
    if path == '/Users/emigardiner/GWs/holodeck/output/anatomy_09':
        fname = f'{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}.npz'
    else:
        fname = f'{target}_v{nvars}_r{nreals}_shape{str(shape)}.npz'
    load_data_from_file = os.path.join(path, fname)

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # Use the `nvars` argument (not the global NVARS) so the selected variation
    # always matches the file that was opened.
    mid = nvars // 2
    # allow_pickle is required for object entries; only load trusted files.
    # The context manager closes the archive even if indexing fails.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data'][mid]
        params = file['params'][mid]

    return data, params

In [None]:
# Middle variation of the 'hard_time' run from get_data's default directory.
data, params = get_data(target='hard_time')

In [None]:
# Pull out the arrays the following plotting cells work with.
hc_ss, hc_bg, sspar, bgpar, fobs_cents = (
    data[key] for key in ('hc_ss', 'hc_bg', 'sspar', 'bgpar', 'fobs_cents')
)

In [None]:
# Object array of edge colors: black for index 0, default (None) elsewhere.
# NOTE(review): not referenced by the plotting cells below, which pick the edge color inline.
edgecolors = np.full(NLOUDEST, None, dtype=object)
edgecolors[0] = 'k'
print(edgecolors)

In [None]:
# Characteristic-strain spectra: a reference power law, the background median
# and percentile bands, and `nsamp` randomly chosen realizations (background
# lines plus single sources as points).
fig, ax = plot.figax_single(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN)


nsamp = 5    # number of sample GWB spectra to plot

# presumably frequency in 1/yr (matches the x-axis label)
xx = fobs_cents * YR
# NOTE(review): xx_ss is never used below; candidate for removal.
xx_ss = np.repeat(xx, NLOUDEST).reshape(NFREQS, NLOUDEST)

# plot a reference, pure power-law  strain spectrum:   h_c(f) = 1e-15 * (f * yr) ^ -2/3
yy = 1e-15 * np.power(xx, -2.0/3.0)
ax.plot(xx, yy, 'k--', alpha=0.25, lw=2.0)

# Plot the median GWB spectrum (median over the last axis, presumably realizations)
ax.plot(xx, np.median(hc_bg, axis=-1), 'k-')



# Plot `nsamp` random spectra
seed = 67233 # more spread out
# seed = 98068 # includes random super high
# seed = np.random.randint(99999)   # get a random number
print(seed)                       # print it out so we can reuse it if desired
np.random.seed(seed)              # set the random seed

# never request more samples than there are realizations
nsamp = np.min([nsamp, NREALS])

# one color per sampled realization (5 entries, so nsamp must stay <= 5)
colors = [
    '#ff7f0f', # orange
    '#6a3d9a', # purple
    '#f0027f', # pink
    '#a6d853', # green
    '#15becf', # teal
]

# select random realizations to plot
idx = np.random.choice(NREALS, nsamp, replace=False)
for aa, ii in enumerate(idx):

    # edgecolors = np.repeat(colors[aa], NLOUDEST).reshape(4, NLOUDEST) # idk why this isnt working
    # edgecolors = np.swapaxes(edgecolors, 0,1)
    # edgecolors[0,:] = np.array([1, 1, 1, 1])

    # Scatter only the first 5 of the NLOUDEST single sources; index 0
    # (presumably the loudest) gets a black edge.
    for ll in range(5):
        edgecolor = 'k' if ll==0 else None
        ax.scatter(xx, hc_ss[:,ii,ll], color=colors[aa], alpha=0.3, edgecolor=edgecolor,
                   s=20)

# Background spectrum of each sampled realization, in the matching color
for aa, ii in enumerate(idx):
    ax.plot(xx, hc_bg[:,ii], linestyle='-', alpha=0.75, color=colors[aa])

# NOTE: tight_layout runs before the percentile bands below are added.
fig.tight_layout()
# fig.savefig(f'/Users/emigardiner/GWs/holodeck/output/figures/bigplots/hc_midvars_{seed}.png', dpi=100)

# Shade the central 50% and 98% intervals of hc_bg over the last axis
for pp in [50, 98]:
    percs = pp / 2
    percs = [50 - percs, 50 + percs]
    ax.fill_between(xx, *np.percentile(hc_bg, percs, axis=-1), alpha=0.25, color='k')

# plt.show()

# All parameters

In [None]:
print(sings.par_labels)

# Expand the raw single-source parameters with all_sspars, then scale each
# parameter row by its factor in par_units (broadcast over the trailing axes).
sspar = sings.all_sspars(fobs_cents, data['sspar']) * sings.par_units[:, None, None, None]
bgpar = data['bgpar'] * sings.par_units[:, None, None]

In [None]:
# Parameters vs. frequency: background median and central 50%/95% bands
# (black), `nsamp` random realizations (colored lines), and the first 3
# single sources of each sampled realization (points).
fig, axs = plot.figax_double(nrows=2, ncols=3, sharex=True, height=5)

xx = fobs_cents*YR
nsamp = 5

# One color per sampled realization. Defined here (same palette as the strain
# plot) so this cell does not silently pick up a `colors` left by another cell.
colors = [
    '#ff7f0f', # orange
    '#6a3d9a', # purple
    '#f0027f', # pink
    '#a6d853', # green
    '#15becf', # teal
]

# Plot `nsamp` random realizations; fixed seed so the same ones are chosen
seed = 67233 # more spread out
# seed = np.random.randint(99999)   # get a random number
print(seed)                       # print it out so we can reuse it if desired
np.random.seed(seed)              # set the random seed

# select random realizations to plot
idx = np.random.choice(NREALS, nsamp, replace=False)

# Panels: strain, total mass, mass ratio, comoving distance, separation [pc],
# angular separation (rows 0, 1, 4, 5, 6 of sspar / bgpar).
yy_ss = [hc_ss, sspar[0], sspar[1],
         sspar[4], sspar[5], sspar[6]]
yy_bg = [hc_bg, bgpar[0], bgpar[1],
         bgpar[4], bgpar[5], bgpar[6]]
# Raw string: '\o' in a plain literal is an invalid escape (SyntaxWarning on Python >= 3.12).
ylabels = ['Characteristic Strain', r'Total Mass [M$_\odot$]', 'Mass Ratio, $q$',
           'Com. Distance [Mpc]', 'Separation [pc]', 'Separation [rad]']
for ii, ax in enumerate(axs.flatten()):
    # Median of the background over the last axis (presumably realizations)
    ax.plot(xx, np.median(yy_bg[ii], axis=-1), 'k-')
    for aa, nn in enumerate(idx):
        ax.plot(xx, yy_bg[ii][:,nn], linestyle='-', alpha=0.75, color=colors[aa])

    # First 3 single sources of each sampled realization; index 0 gets a black edge
    for aa, nn in enumerate(idx):
        for ll in range(3):
            edgecolor = 'k' if ll==0 else None
            ax.scatter(xx, yy_ss[ii][:,nn,ll], color=colors[aa], alpha=0.3, edgecolor=edgecolor,
                       s=20)

    # Shade the central 50% and 95% intervals of the background
    for pp in [50, 95]:
        percs = pp / 2
        percs = [50 - percs, 50 + percs]
        ax.fill_between(xx, *np.percentile(yy_bg[ii], percs, axis=-1), alpha=0.25, color='k')

    # x-label only on the bottom row (axes share x)
    if ii>=3:
        ax.set_xlabel(plot.LABEL_GW_FREQUENCY_YR)
    ax.set_ylabel(ylabels[ii])

fig.tight_layout()
# fig.savefig('/Users/emigardiner/GWs/holodeck/output/figures/bigplots'
#             +f'/params_midvars_{seed}.png')

# Varying a Single Parameter


### Truncated colormaps

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Build a colormap from the ``[minval, maxval]`` slice of ``cmap``.

    Adapted from https://stackoverflow.com/a/18926541

    Parameters
    ----------
    cmap : str or colormap
        Colormap object, or a name resolved with ``plt.get_cmap``.
    minval, maxval : float
        Fractional range of the original colormap to keep.
    n : int
        Number of evenly spaced samples used to build the new colormap.

    Returns
    -------
    LinearSegmentedColormap
        Named ``'trunc(<name>,<minval>,<maxval>)'``.
    """
    base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=base.name, a=minval, b=maxval)
    samples = base(np.linspace(minval, maxval, n))
    return mpl.colors.LinearSegmentedColormap.from_list(label, samples)


# Keep only the upper part of each sequential colormap (drop the low end).
cmap_Blues = truncate_colormap('Blues', minval=0.4, maxval=1)
cmap_PuBuGn = truncate_colormap('PuBuGn', minval=0.2, maxval=1)
cmap_Greens = truncate_colormap('Greens', minval=0.4, maxval=1)

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma = None, red2white=None,
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'
):
    """Load every parameter variation of a saved anatomy run.

    NOTE: this redefines (and shadows) the ``get_data`` defined earlier in the
    notebook. Unlike that version, it returns the complete ``'data'`` and
    ``'params'`` arrays rather than only the middle variation, and the file
    name never includes ``nskies``.

    Parameters
    ----------
    target : str
        Name of the varied parameter; used as the data-file prefix.
    nvars, nreals, shape :
        Values encoded in the file name
        ``{target}_v{nvars}_r{nreals}_shape{shape}.npz``.
    nskies, red_gamma, red2white :
        Currently unused.
    path : str
        Directory containing the ``.npz`` file.

    Returns
    -------
    data, params : numpy.ndarray
        The full ``'data'`` and ``'params'`` arrays stored in the file.

    Raises
    ------
    FileNotFoundError
        If the expected ``.npz`` file does not exist.
    """
    load_data_from_file = os.path.join(
        path, f'{target}_v{nvars}_r{nreals}_shape{str(shape)}.npz')

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # allow_pickle is required for object entries; only load trusted files.
    # The context manager closes the archive even if a key is missing.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']

    return data, params

# Hard Time, 21 variations (5 plotted)

In [None]:
# Load the hard_time run and collect, for selected variations, the strain,
# mass, and distance of the background and of single-source index 0.
# NOTE(review): this reassigns the configuration "constants" NVARS/NSKIES/NREALS
# and the globals hc_ss/hc_bg/sspar/bgpar from earlier cells.
TARGET = 'hard_time'
NVARS = 21
NSKIES = 50  # NOTE(review): the current get_data ignores nskies (not in the file name)
NREALS = 500

parvars = [0,5,10,15,20]  # indices of the variations to plot
yy_ss = []  # per variation: [strain, mass, distance] of single-source index 0
yy_bg = []  # per variation: [strain, mass, distance] of the background
data, params = get_data(TARGET, nvars=NVARS, nskies=NSKIES, nreals=NREALS,
                        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz')
for vv, var in enumerate(parvars):
    hc_ss = data[var]['hc_ss']
    hc_bg = data[var]['hc_bg']

    sspar = data[var]['sspar']
    bgpar = data[var]['bgpar']

    # NOTE(review): fobs_cents comes from the earlier anatomy_09B load, not from
    # data[var]; assumes both runs share the same frequency bins — TODO confirm.
    sspar = sings.all_sspars(fobs_cents, sspar)
    bgpar = bgpar*sings.par_units[:,np.newaxis,np.newaxis]
    sspar = sspar*sings.par_units[:,np.newaxis,np.newaxis,np.newaxis]


    # parameters to plot
    _yy_ss = [hc_ss[...,0], sspar[0,...,0], #sspar[1,...,0], # sspar[2,],  # strain, mass, mass ratio,
            sspar[4,...,0]] # final comoving distance, single loudest only

    _yy_bg = [hc_bg, bgpar[0], #bgpar[1],  # strain, mass, mass ratio, initial redshift, final com distance
            bgpar[4],]
    yy_ss.append(_yy_ss)
    yy_bg.append(_yy_bg)

In [None]:
# Hard-time variations, one panel per quantity: background median (line) and
# central 95% band (shading), plus the median of single-source index 0 with
# min-to-max error bars. Colors follow the variation index `var`.
# Raw strings: '\o' and '\m' in plain literals are invalid escapes (SyntaxWarning on Python >= 3.12).
ylabels = ['Char. Strain', r'Mass [M$_\odot$]', #'Mass Ratio',
        'Distance [Mpc]',]

fig, axs = plot.figax_single(nrows=3, ncols=1, sharex=True, height=10)
xx = fobs_cents*YR

colors = cmap_Blues(np.linspace(0, 1, NVARS))

for ii, ax in enumerate(axs.flatten()):
    # Background median per variation; keep the line handles for the legend
    handles=[]
    for vv, var in enumerate(parvars):
        hh, = ax.plot(xx, np.median(yy_bg[vv][ii], axis=-1), color=colors[var])
        handles.append(hh)

    # Single source: median over the last axis with full min-max range.
    # ymin / ymax are the lower / upper error-bar lengths, not the extrema.
    for vv, var in enumerate(parvars):
        ymed = np.median(yy_ss[vv][ii], axis=-1)
        ymax = np.max(yy_ss[vv][ii], axis=-1)-ymed
        ymin = ymed - np.min(yy_ss[vv][ii], axis=-1)
        ax.errorbar(xx, ymed, yerr=(ymin, ymax), color=colors[var], alpha=0.5,
                capsize=3, marker=None)
        # Color by `var` (previously colors[vv]) so markers match their line and error bar.
        ax.scatter(xx, ymed, marker='o', color=colors[var], alpha=0.8, s=20)

    # Shade the central 95% interval of the background for each variation
    for vv, var in enumerate(parvars):
        for pp in [95]:
            percs = pp / 2
            percs = [50 - percs, 50 + percs]
            ax.fill_between(xx, *np.percentile(yy_bg[vv][ii], percs, axis=-1), alpha=0.25, color=colors[var])

    ax.set_ylabel(ylabels[ii])

axs[-1].set_xlabel(plot.LABEL_GW_FREQUENCY_YR)
fig.subplots_adjust(wspace=0, hspace=0)
plot._twin_hz(axs[0], nano=True)

# Legend entries: the hard_time value of each plotted variation
labels = [f"{params[var][TARGET]:.2f}" for var in parvars]

fig.legend(handles=handles, labels=labels, bbox_to_anchor=(0.5,0.02), loc='lower center',
           ncols=len(parvars), title=r'$\tau_\mathrm{hard}$')

In [None]:
# Hard-time variations, one panel per quantity: background median (thick line)
# and the central 95% band of single-source index 0 (shading).
# Raw strings: '\o' and '\m' in plain literals are invalid escapes (SyntaxWarning on Python >= 3.12).
ylabels = ['Char. Strain', r'Mass [M$_\odot$]',# 'Mass Ratio',
        'Distance [Mpc]',]

fig, axs = plot.figax_single(nrows=3, ncols=1, sharex=True, height=10)
xx = fobs_cents*YR

colors = cmap_Blues(np.linspace(0, 1, NVARS))
# colors = cmap_PuBuGn(np.linspace(0, 1, NVARS))

for ii, ax in enumerate(axs.flatten()):
    # Background median per variation; keep the line handles for the legend
    handles=[]
    for vv, var in enumerate(parvars):
        hh, = ax.plot(xx, np.median(yy_bg[vv][ii], axis=-1), color=colors[var], lw=3)
        handles.append(hh)

    # Shade the central 95% interval of the single source for each variation
    for vv, var in enumerate(parvars):
        for pp in [95,]:
            percs = pp / 2
            percs = [50 - percs, 50 + percs]
            ax.fill_between(xx, *np.percentile(yy_ss[vv][ii], percs, axis=-1), alpha=0.25, color=colors[var])

    ax.set_ylabel(ylabels[ii])

axs[-1].set_xlabel(plot.LABEL_GW_FREQUENCY_YR)
fig.subplots_adjust(wspace=0, hspace=0)
plot._twin_hz(axs[0], nano=True)

# Legend entries: the hard_time value of each plotted variation
labels = [f"{params[var][TARGET]:.2f}" for var in parvars]

fig.legend(handles=handles, labels=labels, bbox_to_anchor=(0.5,0.02), loc='lower center',
           ncols=len(parvars), title=r'$\tau_\mathrm{hard}$')

# GSMF Phi0, 3 variations

In [None]:
# Load the gsmf_phi0 run and collect, for every variation, the strain, mass,
# and distance of the background and of single-source index 0.
# NOTE(review): this reassigns the configuration "constants" NVARS/NREALS and
# the globals hc_ss/hc_bg/sspar/bgpar/yy_ss/yy_bg from earlier cells.
TARGET = 'gsmf_phi0'
NVARS = 3
NREALS = 500

parvars = [0,1,2]  # indices of the variations to plot (all of them)
yy_ss = []  # per variation: [strain, mass, distance] of single-source index 0
yy_bg = []  # per variation: [strain, mass, distance] of the background
data, params = get_data(TARGET, nvars=NVARS, nreals=NREALS,
                        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz')
print(data.shape)  # sanity check of the number of stored variations
for vv, var in enumerate(parvars):
    hc_ss = data[var]['hc_ss']
    hc_bg = data[var]['hc_bg']

    sspar = data[var]['sspar']
    bgpar = data[var]['bgpar']

    # NOTE(review): fobs_cents comes from the earlier anatomy_09B load, not from
    # data[var]; assumes both runs share the same frequency bins — TODO confirm.
    sspar = sings.all_sspars(fobs_cents, sspar)
    bgpar = bgpar*sings.par_units[:,np.newaxis,np.newaxis]
    sspar = sspar*sings.par_units[:,np.newaxis,np.newaxis,np.newaxis]


    # parameters to plot
    _yy_ss = [hc_ss[...,0], sspar[0,...,0], #sspar[1,...,0], # sspar[2,],  # strain, mass, mass ratio,
            sspar[4,...,0]] # final comoving distance, single loudest only

    _yy_bg = [hc_bg, bgpar[0], #bgpar[1],  # strain, mass, mass ratio, initial redshift, final com distance
            bgpar[4],]
    yy_ss.append(_yy_ss)
    yy_bg.append(_yy_bg)

In [None]:
# GSMF phi0 variations, one panel per quantity: background median (line) and
# central 95% band (shading), plus the median of single-source index 0 with
# min-to-max error bars. Colors follow the variation index `var`.
# Raw string: '\o' in a plain literal is an invalid escape (SyntaxWarning on Python >= 3.12).
ylabels = ['Char. Strain', r'Mass [M$_\odot$]', #'Mass Ratio',
        'Distance [Mpc]',]

fig, axs = plot.figax_single(nrows=3, ncols=1, sharex=True, height=10)
xx = fobs_cents*YR

colors = cmap_Greens(np.linspace(0, 1, NVARS))

for ii, ax in enumerate(axs.flatten()):
    # Background median per variation; keep the line handles for the legend
    handles=[]
    for vv, var in enumerate(parvars):
        hh, = ax.plot(xx, np.median(yy_bg[vv][ii], axis=-1), color=colors[var])
        handles.append(hh)

    # Single source: median over the last axis with full min-max range.
    # ymin / ymax are the lower / upper error-bar lengths, not the extrema.
    for vv, var in enumerate(parvars):
        ymed = np.median(yy_ss[vv][ii], axis=-1)
        ymax = np.max(yy_ss[vv][ii], axis=-1)-ymed
        ymin = ymed - np.min(yy_ss[vv][ii], axis=-1)
        ax.errorbar(xx, ymed, yerr=(ymin, ymax), color=colors[var], alpha=0.5,
                capsize=3, marker=None)
        # Color by `var` (previously colors[vv]) so markers always match their line.
        ax.scatter(xx, ymed, marker='o', color=colors[var], alpha=0.8, s=20)

    # Shade the central 95% interval of the background for each variation
    for vv, var in enumerate(parvars):
        for pp in [95]:
            percs = pp / 2
            percs = [50 - percs, 50 + percs]
            ax.fill_between(xx, *np.percentile(yy_bg[vv][ii], percs, axis=-1), alpha=0.25, color=colors[var])

    ax.set_ylabel(ylabels[ii])

axs[-1].set_xlabel(plot.LABEL_GW_FREQUENCY_YR)
fig.subplots_adjust(wspace=0, hspace=0)
plot._twin_hz(axs[0], nano=True)

# Legend entries: the gsmf_phi0 value of each plotted variation
labels = [f"{params[var][TARGET]:.2f}" for var in parvars]

fig.legend(handles=handles, labels=labels, bbox_to_anchor=(0.5,0.02), loc='lower center',
           ncols=len(parvars), title=plot.PARAM_KEYS[TARGET])

In [None]:
# GSMF phi0 variations, one panel per quantity: background median (thick line)
# and the central 95% band of single-source index 0 (shading).
# Raw string: '\o' in a plain literal is an invalid escape (SyntaxWarning on Python >= 3.12).
ylabels = ['Char. Strain', r'Mass [M$_\odot$]',# 'Mass Ratio',
        'Distance [Mpc]',]

fig, axs = plot.figax_single(nrows=3, ncols=1, sharex=True, height=10)
xx = fobs_cents*YR

colors = cmap_Greens(np.linspace(0, 1, NVARS))
# colors = cmap_PuBuGn(np.linspace(0, 1, NVARS))

for ii, ax in enumerate(axs.flatten()):
    # Background median per variation; keep the line handles for the legend
    handles=[]
    for vv, var in enumerate(parvars):
        hh, = ax.plot(xx, np.median(yy_bg[vv][ii], axis=-1), color=colors[var], lw=3)
        handles.append(hh)

    # Shade the central 95% interval of the single source for each variation
    for vv, var in enumerate(parvars):
        for pp in [95,]:
            percs = pp / 2
            percs = [50 - percs, 50 + percs]
            ax.fill_between(xx, *np.percentile(yy_ss[vv][ii], percs, axis=-1), alpha=0.25, color=colors[var])

    ax.set_ylabel(ylabels[ii])

axs[-1].set_xlabel(plot.LABEL_GW_FREQUENCY_YR)
fig.subplots_adjust(wspace=0, hspace=0)
plot._twin_hz(axs[0], nano=True)

# Legend entries: the gsmf_phi0 value of each plotted variation
labels = [f"{params[var][TARGET]:.2f}" for var in parvars]

fig.legend(handles=handles, labels=labels, bbox_to_anchor=(0.5,0.02), loc='lower center',
           ncols=len(parvars), title=plot.PARAM_KEYS[TARGET])