In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import os

from holodeck import plot, detstats, utils
from holodeck.constants import YR, MSOL, MPC, GYR
import holodeck as holo

In [None]:
# --- library / detection-statistics dimensions ---
SHAPE = None      # used in the run-directory name ("shape{SHAPE}")
NREALS = 500      # number of realizations
NFREQS = 40       # number of frequency bins
NLOUDEST = 10     # number of loudest single sources kept per bin
NVARS = 21        # number of parameter variations (presumably indexed 0..NVARS-1)
NPSRS = 40        # number of pulsars
NSKIES = 100      # number of sky realizations

# --- workflow switches ---
CONSTRUCT = False
JUST_DETSTATS = False
SAVEFIG = True
TOL = 0.01
MAXBADS = 5

# --- red-noise settings (only echoed in plot_dp's annotation in this notebook) ---
RED_GAMMA = None
RED_AMP = None

# --- parameter being varied ---
TARGET = 'mmb_scatter_dex'  # EDIT AS NEEDED
TITLE = plot.PARAM_KEYS[TARGET]  # EDIT AS NEEDED

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma=None, red2white=None,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz', ssn='_ssn',
):
    """Load library data/params and detection statistics for one varied parameter.

    Files are looked up in ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}/``:
    ``data_params.npz`` (keys ``data`` and ``params``) and
    ``detstats_s{nskies}{ssn}{suffix}.npz`` (key ``dsdat``).

    Parameters
    ----------
    target : str
        Name of the varied parameter; prefix of the run directory.
    nvars, nreals, nskies, shape
        Values embedded in the directory / file names.
    red_gamma, red2white : float or None
        If BOTH are given, the red-noise detstats file
        (suffix ``_r2w{red2white:.1f}_rg{red_gamma:.1f}``) is loaded; otherwise the
        ``_white`` file is loaded (a warning is issued if only one was given).
    path : str
        Root output directory.
    ssn : str
        Suffix inserted after ``detstats_s{nskies}`` in the detstats filename.

    Returns
    -------
    data, params, dsdat : ndarray
        Arrays stored in the two ``.npz`` files.

    Raises
    ------
    FileNotFoundError
        If either file is missing (subclass of ``Exception``, so existing
        ``except Exception`` handlers still catch it).
    """
    import warnings

    run_dir = os.path.join(path, f'{target}_v{nvars}_r{nreals}_shape{str(shape)}')
    load_data_from_file = os.path.join(run_dir, 'data_params.npz')

    dets_name = f'detstats_s{nskies}{ssn}'
    if red_gamma is not None and red2white is not None:
        dets_name += f'_r2w{red2white:.1f}_rg{red_gamma:.1f}'
    else:
        if (red_gamma is None) != (red2white is None):
            # Only one red-noise setting given: previously this silently fell back to white noise.
            warnings.warn("only one of red_gamma/red2white given; loading white-noise detstats")
        dets_name += '_white'
    load_dets_from_file = os.path.join(run_dir, dets_name + '.npz')

    # Check both files up front so we fail before any (slow) loading happens.
    for label, fname in (('data', load_data_from_file), ('dets', load_dets_from_file)):
        if not os.path.exists(fname):
            err = f"load {label} file '{fname}' does not exist, you need to construct it."
            raise FileNotFoundError(err)

    # NOTE(review): allow_pickle=True can execute arbitrary code from the file;
    # only load outputs you produced yourself.
    # `with` guarantees the NpzFile is closed even if a key lookup raises.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']
    print(target, "got data")

    with np.load(load_dets_from_file, allow_pickle=True) as file:
        print(target, "loaded dets")
        print(file.files)
        dsdat = file['dsdat']

    return data, params, dsdat

In [None]:
# Load via TARGET so the data always matches the labels (TITLE, fig.text) derived from it.
data, params, dsdat = get_data(TARGET)

In [None]:
def plot_dp(fobs_cents, dp_ss, dp_bg, gamma_ssi, ax_avg=(0,3)):
    """Plot background, individual single-source, and overall single-source detection probabilities.

    Parameters
    ----------
    fobs_cents : array_like
        Frequency-bin centers; multiplied by ``YR`` for the x-axis.
    dp_ss : array_like
        Overall single-source detection probability. Only ``dp_ss[0]`` is drawn,
        as one faint horizontal line per entry (one per sky).
    dp_bg : array_like
        Background detection probability. Only ``dp_bg[0]`` is drawn.
    gamma_ssi : array_like, 4D
        Individual single-source detection probabilities, laid out as
        (freq, realization, sky, loudest).
    ax_avg : tuple of int
        Axes of ``gamma_ssi`` over which frequency is averaged, weighted by
        ``gamma_ssi``. The median of the remaining values is drawn as a vertical line.

    Returns
    -------
    fig : matplotlib Figure
    """
    fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Detection Probability')

    xx = fobs_cents*YR
    gamma_ssi = np.asarray(gamma_ssi)
    if gamma_ssi.ndim != 4:
        raise ValueError(f"gamma_ssi must be 4D (freq, real, sky, loudest), got shape {gamma_ssi.shape}")
    # Take dimensions from the data rather than the NFREQS/NREALS/... globals.
    nfreqs = gamma_ssi.shape[0]

    y1 = dp_bg[0]  # 1
    y2 = gamma_ssi.reshape(nfreqs, -1)  # F, R*S*L
    y3 = dp_ss[0]  # S
    # Each element [f, r, s, l] of the broadcast array equals xx[f].
    xx_4d = np.broadcast_to(np.reshape(xx, (-1, 1, 1, 1)), gamma_ssi.shape)
    favg = np.average(xx_4d, weights=gamma_ssi, axis=ax_avg)

    label1 = 'BG Detprob'
    label2 = 'Individual SS Detprob'
    label3 = 'Overall SS Detprob'
    label4 = r'dp-weighted $\langle f \rangle$'  # raw string: avoids invalid '\l' escape

    h1 = ax.axhline(y1)
    h2 = plot.draw_med_conf_color(ax, xx, y2, color='orange')
    h3 = None  # stays None if there are no skies, so the legend can skip it
    for dp_sky in y3:
        h3 = ax.axhline(dp_sky, color='tab:red', alpha=0.2)
    h4 = ax.axvline(np.median(favg), color='g')

    drawn = [(hh, ll) for hh, ll in zip([h1, h2, h3, h4], [label1, label2, label3, label4])
             if hh is not None]
    ax.legend(handles=[hh for hh, _ in drawn], labels=[ll for _, ll in drawn], loc='upper right')
    ax.text(1.0, 0., f"{RED_AMP=}, {RED_GAMMA=}", horizontalalignment='right', verticalalignment='bottom',
            transform=ax.transAxes)

    return fig

# average over all axes (frequencies, realizations, skies, and loudest sources)

In [None]:
# Every 5th parameter variation; frequency averaged over all four gamma_ssi axes.
for pp in range(0, 21, 5):
    fobs_cents = data[pp]['fobs_cents']
    dp_ss, dp_bg, gamma_ssi = (dsdat[pp][key] for key in ('dp_ss', 'dp_bg', 'gamma_ssi'))
    fig = plot_dp(fobs_cents, dp_ss, dp_bg, gamma_ssi, ax_avg=(0, 1, 2, 3))
    fig.text(0, 0, f"{TARGET}={params[pp][TARGET]}", ha='left', va='bottom')

# average over frequencies, skies, and loudest sources; median over (strain) realizations

In [None]:
# Keep the realization axis: average over axes (0, 2, 3), then plot_dp takes the median.
for pp in range(0, 21, 5):
    fobs_cents = data[pp]['fobs_cents']
    dp_ss, dp_bg, gamma_ssi = (dsdat[pp][key] for key in ('dp_ss', 'dp_bg', 'gamma_ssi'))
    fig = plot_dp(fobs_cents, dp_ss, dp_bg, gamma_ssi, ax_avg=(0, 2, 3))
    fig.text(0, 0, f"{TARGET}={params[pp][TARGET]}", ha='left', va='bottom')

# average only over frequencies and loudest sources; median over realizations and skies

In [None]:
# Keep realization and sky axes: average over axes (0, 3), then plot_dp takes the median.
for pp in range(0, 21, 5):
    fobs_cents = data[pp]['fobs_cents']
    dp_ss, dp_bg, gamma_ssi = (dsdat[pp][key] for key in ('dp_ss', 'dp_bg', 'gamma_ssi'))
    fig = plot_dp(fobs_cents, dp_ss, dp_bg, gamma_ssi, ax_avg=(0, 3))
    fig.text(0, 0, f"{TARGET}={params[pp][TARGET]}", ha='left', va='bottom')

In [None]:
# Repeats the full-average view (same ax_avg as the first loop). It also leaves
# fobs_cents / gamma_ssi bound to the last variation (pp=20) for the cells below.
for pp in range(0, 21, 5):
    fobs_cents = data[pp]['fobs_cents']
    dp_ss, dp_bg, gamma_ssi = (dsdat[pp][key] for key in ('dp_ss', 'dp_bg', 'gamma_ssi'))
    fig = plot_dp(fobs_cents, dp_ss, dp_bg, gamma_ssi, ax_avg=(0, 1, 2, 3))
    fig.text(0, 0, f"{TARGET}={params[pp][TARGET]}", ha='left', va='bottom')

In [None]:
# dp-weighted mean frequency per (realization, sky), using the arrays left over from the
# last loop iteration above (pp=20). gamma_ssi is laid out as (freq, real, sky, loudest).
print(fobs_cents.shape)
print(gamma_ssi.shape)
assert gamma_ssi.shape[0] == np.size(fobs_cents), "gamma_ssi's first axis must be frequency"
# Repeat each frequency once per (real, sky, loudest) entry so that xx lines up with
# gamma_ssi.flatten() (C order: frequency is the slowest-varying axis).
xx = np.repeat(fobs_cents, gamma_ssi[0].size)*YR
yy = gamma_ssi.flatten()
favg = np.average(xx.reshape(gamma_ssi.shape),
                  weights=yy.reshape(gamma_ssi.shape), axis=(0, -1))

In [None]:
print(np.shape(favg))  # expected (realizations, skies)

In [None]:
# Every individual single-source detection probability vs. frequency, with the per-frequency
# median/confidence band and the median dp-weighted <f> (green line).
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Individual SS Detprob')
ax.scatter(xx, yy, s=4, alpha=0.1)
plot.draw_med_conf(ax, fobs_cents*YR, gamma_ssi.reshape(gamma_ssi.shape[0], -1))
ax.axvline(np.median(favg), color='green');

In [None]:
# Overlay dp-weighted <f> for randomly chosen (realization, sky) pairs to show its spread.
# A seeded generator makes the figure reproducible; bounds come from favg's own shape.
# Note: re-running this cell adds another set of lines to the same axes.
NSAMPLE_LINES = 300
line_rng = np.random.default_rng(42)
sample_rr = line_rng.integers(0, favg.shape[0], size=NSAMPLE_LINES)
sample_ss = line_rng.integers(0, favg.shape[1], size=NSAMPLE_LINES)
for rr, ss in zip(sample_rr, sample_ss):
    ax.axvline(favg[rr, ss], linestyle='dashed', color='green', alpha=0.5)
fig

In [None]:
# 2D histogram of individual-source detection probability vs. frequency.
# xx is in 1/yr (fobs*YR). The pta_freqs() edges are presumably in the same units as
# data['fobs_cents'] (Hz), so scale them by YR too; otherwise no sample falls inside the bins.
_, fobs_edges = holo.utils.pta_freqs()
fobs_edges_yr = fobs_edges*YR
dp_edges = np.geomspace(1e-6, 1.0, num=20)
hist, _, _ = np.histogram2d(yy, xx, bins=(dp_edges, fobs_edges_yr))  # shape (dp bins, freq bins)

fig, ax = plt.subplots()
# np.ma.log10 masks empty bins instead of producing -inf (and a divide-by-zero warning).
mesh = ax.pcolormesh(fobs_edges_yr, dp_edges, np.ma.log10(hist))
ax.set_yscale('log')  # dp_edges are log-spaced over six decades
ax.set(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Individual SS Detprob')
fig.colorbar(mesh, ax=ax, label=r'$\log_{10}$(count)');

In [None]:

# Toy example for checking how np.average combines `weights` with `axis`:
# 10 evenly spaced values from 1 to 10 with rapidly decaying weights.
weights = np.array([2, 3, 2, 1, 0.05, 0.01, 0, 0, 0, 0])
arr = np.linspace(1, 10, num=weights.size)

In [None]:
# Expand the 1D toy arrays into NTOY_REALS noisy "realizations" (columns): row i holds
# point i, so axis 0 mirrors the frequency axis and axis 1 the realization axis.
# This overwrites `weights`/`arr`, so re-run the previous cell before re-running this one.
assert np.ndim(weights) == 1 and np.ndim(arr) == 1, "re-run the previous cell first"
NTOY_REALS = 5
toy_rng = np.random.default_rng(12345)  # seeded for reproducibility
noise = toy_rng.uniform(-0.1, 0.1, size=(weights.size, NTOY_REALS))
# Clip at zero: the weights stand in for detection probabilities, which cannot be negative.
weights = np.clip(weights[:, None] + noise, 0.0, None)  # shape (10, NTOY_REALS)
arr = np.repeat(arr[:, None], NTOY_REALS, axis=1)        # shape (10, NTOY_REALS)

In [None]:
# Overall weighted mean (axis=None), then one weighted mean per realization (axis=0).
for avg_axis in (None, 0):
    print(np.average(arr, weights=weights, axis=avg_axis))