In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim
import os


In [None]:
# ---- Run configuration (values used to locate / shape the loaded data) ----
SHAPE = None        # formatted into the data file name by get_data
NREALS = 500        # number of realizations; alternative used before: 20
NFREQS = 40         # number of frequency bins; must match the stored arrays
NLOUDEST = 10       # number of single sources kept per bin/realization

# Flags / tolerances (not referenced by the cells visible here)
BUILD_ARRAYS = False
SAVEFIG = False
TOL = 0.01
MAXBADS = 5

NVARS = 21          # number of parameter variations in the file; alternative: 6

NPSRS = 40          # presumably number of pulsars -- not used in the visible cells
NSKIES = 100        # number of sky realizations; alternative: 15
# NSKIES = 15

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma=None, red2white=None,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_09B',
):
    """Load the middle parameter variation from a saved anatomy ``.npz`` file.

    The file is looked up at
    ``{path}/{target}_v{nvars}_r{nreals}_s{nskies}_shape{shape}.npz`` and must
    contain ``'data'`` and ``'params'`` arrays indexed by variation (loaded with
    ``allow_pickle=True``, so presumably object arrays).

    Parameters
    ----------
    target : str
        File-name prefix (presumably the name of the varied parameter).
    nvars : int
        Number of variations stored in the file; entry ``nvars // 2`` is returned.
    nreals, nskies : int
        Realization / sky counts; only used to build the file name.
    shape : object
        Formatted with ``str()`` into the file name.
    red_gamma, red2white : float or None
        Currently unused (the detection-statistics loading they configured is
        disabled); accepted so existing calls keep working.
    path : str
        Directory containing the file.

    Returns
    -------
    data, params
        The ``'data'`` and ``'params'`` entries at index ``nvars // 2``.

    Raises
    ------
    FileNotFoundError
        If the file does not exist. This subclasses ``Exception``, so existing
        ``except Exception`` handlers still catch it.
    """
    load_data_from_file = path + f'/{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}.npz'

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # Middle variation of *this* file. The old code used the global NVARS, which
    # selected the wrong entry (or raised IndexError) whenever nvars != NVARS.
    mid = nvars // 2

    # allow_pickle is needed for the stored objects, but unpickling can run
    # arbitrary code -- only load files you produced yourself.
    # The context manager closes the NpzFile even if a key lookup fails.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data'][mid]
        params = file['params'][mid]

    return data, params

In [None]:
# Load the middle variation of the 'hard_time' run (default counts/path from the config cell).
data, params = get_data(target='hard_time')

In [None]:
# Pull the individual arrays out of the loaded record, in the same lookup order.
hc_ss, hc_bg, sspar, bgpar, fobs_cents = (
    data[key] for key in ('hc_ss', 'hc_bg', 'sspar', 'bgpar', 'fobs_cents')
)

In [None]:
# Object array of per-source edge colours: black outline on the first entry, none elsewhere.
# NOTE(review): not referenced by the plotting cells below, which compute `edgecolor` inline.
edgecolors = np.full(NLOUDEST, None, dtype=object)
edgecolors[0] = 'k'
print(edgecolors)

In [None]:
fig, ax = plot.figax_single(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN)

nsamp = 5                         # number of sample GWB spectra to plot
nloud_plot = min(5, NLOUDEST)     # single sources plotted per sample; cannot exceed NLOUDEST

xx = fobs_cents * YR

# plot a reference, pure power-law strain spectrum:   h_c(f) = 1e-15 * (f * yr) ^ -2/3
yy = 1e-15 * np.power(xx, -2.0/3.0)
ax.plot(xx, yy, 'k--', alpha=0.25, lw=2.0)

# Plot the median GWB spectrum over the last axis of hc_bg (realizations)
ax.plot(xx, np.median(hc_bg, axis=-1), 'k-')

# Plot `nsamp` random spectra
seed = 67233 # more spread out
# seed = 98068 # includes random super high
# seed = np.random.randint(99999)   # get a random number
print(seed)                       # print it out so we can reuse it if desired
np.random.seed(seed)              # set the random seed

nsamp = min(nsamp, NREALS)

colors = [
    '#6a3d9a', # purple
    '#ff7f0f', # orange
    '#15becf', # teal
    '#f0027f', # pink
    '#a6d853', # green
]

# select random realizations to plot
idx = np.random.choice(NREALS, nsamp, replace=False)

# Single-source strains: hc_ss is indexed [freq, realization, source]; source 0
# (presumably the loudest) gets a black outline. Colours cycle if nsamp > len(colors).
for aa, ii in enumerate(idx):
    color = colors[aa % len(colors)]
    for ll in range(nloud_plot):
        edgecolor = 'k' if ll == 0 else None
        ax.scatter(xx, hc_ss[:, ii, ll], color=color, alpha=0.3, edgecolor=edgecolor)

# Background spectra of the same realizations, drawn on top of the scatter points
for aa, ii in enumerate(idx):
    ax.plot(xx, hc_bg[:, ii], linestyle='-', alpha=0.75, color=colors[aa % len(colors)])

# plot contours at 50% and 98% confidence intervals
for pp in [50, 98]:
    half = pp / 2
    percs = [50 - half, 50 + half]
    ax.fill_between(xx, *np.percentile(hc_bg, percs, axis=-1), alpha=0.25, color='k')

# Lay out / save only after every artist (including the contours) has been added.
fig.tight_layout()
# fig.savefig(f'/Users/emigardiner/GWs/holodeck/output/figures/bigplots/hc_midvars_{seed}.png', dpi=100)

# all parameters

In [None]:
print(sings.par_labels)

# Re-read the raw parameters from `data` so this cell is safe to re-run,
# expand the single-source parameters, then scale both by the per-parameter units
# (units broadcast along the leading parameter axis).
bgpar, sspar = data['bgpar'], data['sspar']
sspar = sings.all_sspars(fobs_cents, sspar)

units = sings.par_units
bgpar = bgpar * units[:, None, None]
sspar = sspar * units[:, None, None, None]

In [None]:
fig, axs = plot.figax_double(nrows=2, ncols=3, sharex=True)

xx = fobs_cents*YR
nsamp = 5

# Select `nsamp` random realizations.
# TODO: `idx` and `yy_ss` are not plotted yet (intended for single-source overlays).
seed = 67233 # more spread out
# seed = np.random.randint(99999)   # get a random number
print(seed)                       # print it out so we can reuse it if desired
np.random.seed(seed)              # set the random seed
idx = np.random.choice(NREALS, nsamp, replace=False)

# Quantities per panel: strain, total mass, mass ratio, final comoving distance,
# final separation, final angular separation. Parameter indices 2 and 3 are skipped.
yy_ss = [hc_ss, sspar[0], sspar[1], sspar[4], sspar[5], sspar[6]]
yy_bg = [hc_bg, bgpar[0], bgpar[1], bgpar[4], bgpar[5], bgpar[6]]
# Raw strings: '\m' / '\o' are invalid escapes in normal literals (SyntaxWarning on
# Python 3.12+); the resulting label text is unchanged.
ylabels = [r'Characteristic Strain, $h_c$', r'Total Mass, $M_\mathrm{tot}$ [M$_\odot$]', r'Mass Ratio, $q$',
           r'Comoving Distance, $d_\mathrm{com}$ [Mpc]', 'Separation [pc]', 'Separation [rad]']

for ii, (ax, yy, ylabel) in enumerate(zip(axs.flatten(), yy_bg, ylabels)):
    # Median over the last axis (realizations)
    ax.plot(xx, np.median(yy, axis=-1), 'k-')

    # 50% and 95% confidence bands
    for pp in [50, 95]:
        half = pp / 2
        ax.fill_between(xx, *np.percentile(yy, [50 - half, 50 + half], axis=-1), alpha=0.25, color='k')

    # x labels only on the bottom row (panels 3-5 of the 2x3 grid)
    if ii >= 3:
        ax.set_xlabel(plot.LABEL_GW_FREQUENCY_YR)
    ax.set_ylabel(ylabel)

fig.tight_layout()

In [None]:
# print(holo.utils.stats(sspar[3]))

In [None]:
# Scratch check: a single NaN poisons a plain sum (np.nansum would skip it).
xx = np.linspace(start=0, stop=20, num=100)
xx[5] = np.nan
print(xx.sum())