In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
from tqdm import tqdm
import os


from holodeck import plot, detstats, utils
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR, PC
import holodeck as holo
from holodeck.sams import sam

import hasasia.sim as hsim

In [None]:
# When True, the setup cells below call construct_evolution() to recompute and save evolution data.
RECONSTRUCT_FLAG = True

In [None]:
# ---- model / realization sizes (several appear in data file names, see get_data) ----
SHAPE = None      # appears in file names as f'shape{SHAPE}'
NREALS = 500      # alternative used for quick runs: 20
NFREQS = 40
NLOUDEST = 10

NSTEPS = 20       # number of separation steps passed to construct_evolution

NVARS = 21        # alternative: 6

NPSRS = 40
NSKIES = 100      # alternative: 15

PARAM_NAMES = [
    'gsmf_phi0',
    'gsmf_mchar0_log10',
    'mmb_mamp_log10',
    'mmb_scatter_dex',
    'hard_time',
    'hard_gamma_inner',
]

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma=None, red2white=None,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz',
):
    """Load the saved ``data`` and ``params`` arrays for one target-parameter study.

    Parameters
    ----------
    target : str
        Name of the varied parameter; first component of the file name.
    nvars, nreals : int
        Number of variations / realizations, encoded in the file name.
    nskies : int
        Number of sky realizations; only used in the ``anatomy_09B`` file-naming scheme.
    shape : object
        Inserted into the file name as ``shape{shape}``.
    red_gamma, red2white : optional
        Currently unused; kept so existing calls keep working.
    path : str
        Directory holding the data. The ``anatomy_09B`` directory uses a flat
        ``<target>_v.._r.._s.._shape...npz`` layout; any other directory uses
        ``<target>_v.._r.._shape.../data_params.npz``.

    Returns
    -------
    data, params : ndarray
        The ``'data'`` and ``'params'`` entries of the npz file.

    Raises
    ------
    FileNotFoundError
        If the expected file does not exist. This subclasses ``Exception``, so
        callers that caught the previous generic ``Exception`` still catch it.
    """
    if path == '/Users/emigardiner/GWs/holodeck/output/anatomy_09B':
        load_data_from_file = path + f'/{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}.npz'
    else:
        load_data_from_file = path + f'/{target}_v{nvars}_r{nreals}_shape{str(shape)}/data_params.npz'

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # allow_pickle is needed for object arrays (``data`` entries are indexed by string keys
    # downstream, so presumably dicts). Only use this on trusted, locally-generated files.
    # The context manager guarantees the archive is closed even if a key is missing.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']

    return data, params

# Evolution analysis adapted from Luke's notebook

### function to construct evolution data

In [None]:
def construct_evolution(target_param, params, nsteps,
    pspace=None,
    fileloc='/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09B/',
):
    """Construct and save hardening-timescale evolution data.

    For each parameter set in ``params``, build the SAM and hardening model via
    ``pspace.model_for_params``. Then evolve the binaries over ``nsteps`` separations
    (log-spaced from 1e3 pc down to 1e-3 pc). Finally, collect the hardening
    timescales ``tau`` of grid bins falling inside each total-mass range.

    Parameters
    ----------
    target_param : str
        Name of the varied parameter; ``params[i][target_param]`` is recorded and
        the name is used for the output file.
    params : sequence
        Parameter sets (indexable by parameter name), one per model.
    nsteps : int
        Number of separation steps.
    pspace : param-space instance or None
        Used to build each model. If None, a new
        ``holo.param_spaces.PS_Uniform_09B(holo.log, nsamples=1, sam_shape=None, seed=None)``
        is created on each call (previously it was built once at definition time).
    fileloc : str
        Output directory.

    Saves
    -----
    ``<fileloc>/evol_<target_param>.npz`` containing:

    taus : object ndarray, shape [nparams, nmasses]
        Each element is ``details['tau'][sel].T`` for that model and mass range.
        Element shapes differ between mass ranges, hence the object dtype.
        Load with ``allow_pickle=True``.
    target_param_list : [nparams]
        The target parameter's value for each model.
    nsteps, sepa, mtot_ranges
        Inputs / grids used for the calculation.

    Returns
    -------
    filename : str
        Path of the saved npz file.
    """
    if pspace is None:
        pspace = holo.param_spaces.PS_Uniform_09B(holo.log, nsamples=1, sam_shape=None, seed=None)

    mtot_ranges = [
        [1e7*MSOL, 1e8*MSOL],
        [1e8*MSOL, 1e9*MSOL],
        [1e9*MSOL, 1e10*MSOL],
        [1e10*MSOL, 1e11*MSOL],
    ]
    mrat_range = [0.2, 1.0]
    redz_range = [0, np.inf]

    # range of binary separations to evaluate, from wide to tight
    sepa = np.logspace(-3, 3, nsteps)[::-1] * PC

    target_param_list = []
    # object array: each (model, mass-range) entry holds an array whose size depends on
    # how many grid bins fall in that mass range, so the data is ragged
    taus = np.empty((len(params), len(mtot_ranges)), dtype=object)

    for tt, _params in enumerate(tqdm(params)):
        target_param_list.append(_params[target_param])

        sam_model, hard = pspace.model_for_params(_params)

        # access the static_binary_density property to set up the SAM's internal variables
        _ = sam_model.static_binary_density

        # calculate binary properties at the target separations
        _edges, _dnum, _redz_final, details = sam_model._dynamic_binary_number_at_sepa_consistent(
            hard, sepa, details=True)  # it would be better to save these details when first calculated

        # mass-ratio and redshift cuts do not depend on the mass range: compute once per model
        sel_mrat = (mrat_range[0] < sam_model.mrat) & (sam_model.mrat <= mrat_range[1])
        sel_redz = (redz_range[0] < sam_model.redz) & (sam_model.redz <= redz_range[1])

        for mm, (mtot_lo, mtot_hi) in enumerate(mtot_ranges):
            sel_mtot = (mtot_lo < sam_model.mtot) & (sam_model.mtot <= mtot_hi)
            sel = (
                sel_mtot[:, np.newaxis, np.newaxis]
                & sel_mrat[np.newaxis, :, np.newaxis]
                & sel_redz[np.newaxis, np.newaxis, :]
            )
            taus[tt, mm] = details['tau'][sel].T

    # save results
    filename = os.path.join(fileloc, f'evol_{target_param}.npz')
    print(f"{filename=}")
    np.savez(filename, taus=taus, target_param_list=target_param_list,
             nsteps=nsteps, sepa=sepa,
             mtot_ranges=mtot_ranges)
    return filename


# hard_gamma_inner

In [None]:
target = 'hard_gamma_inner'

data, params = get_data(target)
# Keep only three of the variations: indices 0, 10 and 20.
data = [data[idx] for idx in (0, 10, 20)]
params = [params[idx] for idx in (0, 10, 20)]

fobs_gw_cents = data[0]['fobs_cents']
_, fobs_gw_edges = holo.utils.pta_freqs()

In [None]:
if RECONSTRUCT_FLAG:
    # Recompute and save evolution data for the selected variations.
    construct_evolution(target, params, NSTEPS)

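## Load hard_time results

For reference, `construct_evolution` saves its output as shown below. Note that the load cell further down reads an older file format with additional keys (`taus_high`, `hcss`, `hcbg`, ...):

    np.savez(filename, taus=taus, target_param_list=target_param_list,
            nsteps=nsteps, sepa=sepa,
            mtot_ranges = mtot_ranges)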

In [None]:
def load_evolution(target,
    fileloc='/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09B/',
):
    """Load the evolution file ``evol_<target>.npz`` written by ``construct_evolution``.

    Parameters
    ----------
    target : str
        Name of the varied parameter used when the file was saved.
    fileloc : str
        Directory containing the file (a trailing slash is optional).

    Returns
    -------
    dict
        Maps every key stored in the npz file to its loaded array.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    """
    filename = os.path.join(fileloc, f'evol_{target}.npz')
    if not os.path.exists(filename):
        raise FileNotFoundError(f"evolution file '{filename}' does not exist, run construct_evolution first.")
    # allow_pickle: `taus` is saved as an object array of ragged per-mass-range arrays.
    # Only use on trusted, locally-generated files.
    with np.load(filename, allow_pickle=True) as file:
        return {key: file[key] for key in file.files}

In [None]:
target_param = 'hard_time'
# `nsteps` is not defined earlier on a fresh kernel (it first appears in a later section),
# so take it from the config constant; otherwise the filename line raises NameError.
nsteps = NSTEPS
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape'
filename = fileloc+'/evol_%s_%dsteps.npz' % (target_param, nsteps)
# Arrays are read into memory on access, so they stay valid after the archive is closed.
with np.load(filename) as file:
    taus = file['taus']
    taus_high = file['taus_high']
    target_param_list = file['target_param_list']
    hcss = file['hcss']
    hcbg = file['hcbg']
    nsteps = file['nsteps']
    sepa = file['sepa']
    mtot_range = file['mtot_range']
    mtot_hirng = file['mtot_hirng']
    mrat_range = file['mrat_range']
    redz_range = file['redz_range']

# Plot current function

In [None]:
def plot_current():
    """Plot hardening timescales and GW strain for the currently loaded evolution data.

    Reads notebook globals set by the preceding load cell: ``sepa``, ``taus``,
    ``taus_high``, ``target_param_list``, ``target_param``, ``mtot_range``,
    ``mtot_hirng``, ``hcbg``, ``hcss`` and ``fobs_gw_cents``.

    Layout (2x2):
      * [0, 0]: hardening time [Gyr] vs separation [pc] for ``taus`` (``mtot_range``)
      * [0, 1]: same for ``taus_high`` (``mtot_hirng``)
      * [1, 0]: ``hcbg`` median vs frequency [nHz], with the loudest single source of
        every realization from ``hcss`` overplotted as points
      * [1, 1]: axes are configured but nothing is drawn on them

    Returns
    -------
    fig : matplotlib Figure
    """
    fig, axs = plot.figax_double(height=7, nrows=2, ncols=2, hspace=0.35, bottom=0.1)

    colors = ['tab:green', 'tab:blue', 'tab:orange']

    def color_for(ii):
        # cycle so that more than len(colors) parameter values do not raise IndexError
        return colors[ii % len(colors)]

    # ------------------------   Ax Row 0   ----------------------------
    xx = sepa/PC
    ax = axs[0, 0]
    ax1 = axs[0, 1]

    ax.set_title(f'Mass Range: {mtot_range/MSOL}')
    ax1.set_title(f"Mass Range: {mtot_hirng/MSOL}")
    ax1.sharex(ax)
    ax1.sharey(ax)

    for axis in [ax, ax1]:
        axis.set(xlabel=plot.LABEL_SEPARATION_PC, ylabel=plot.LABEL_HARDENING_TIME, xscale='log', yscale='log')
    # Inverted once, after configuring both panels; the x-axes are shared via sharex above.
    ax1.invert_xaxis()

    labels = []
    handles = []
    for ii, tau in enumerate(taus):
        hh = plot.draw_med_conf_color(ax, xx, tau/GYR, fracs=[0.5], filter=True, color=color_for(ii))
        handles.append(hh[0])
        labels.append(f"${target_param_list[ii]:.1f}$")

        plot.draw_med_conf_color(ax1, xx, taus_high[ii]/GYR, fracs=[0.5], filter=True, color=color_for(ii))

    # max(1, ...) keeps legend() valid when no series were drawn
    ax.legend(handles, labels, loc='lower left',
              ncol=max(1, len(handles)), title=target_param, title_fontsize=14)

    # ----------------------------- Ax Row 1 --------------------------------
    ax = axs[1, 0]
    ax1 = axs[1, 1]

    ax1.sharex(ax)
    ax1.sharey(ax)

    for axis in [ax, ax1]:
        axis.set(xlabel=plot.LABEL_GW_FREQUENCY_NHZ, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN, xscale='log', yscale='log')

    xx = fobs_gw_cents*1e9  # nHz

    for ii, yy in enumerate(hcbg):
        plot.draw_med_conf_color(ax, xx, yy, fracs=[0.5], filter=False, color=color_for(ii))
        # hcss[ii] is indexed as [freq, realization, loudest-rank] (assumed from the original
        # per-realization loop); take only the single loudest source of every realization
        loudest = np.asarray(hcss[ii])[:, :, 0]
        # one scatter call instead of one per realization; x repeats per realization
        ax.scatter(np.repeat(xx, loudest.shape[1]), loudest.ravel(),
                   color=color_for(ii), alpha=0.5, s=5)

    return fig

## Plot Results

In [None]:
# Plot the hard_time evolution data loaded in the cell above, using the plot_current definition executed most recently.
fig = plot_current()

# gamma_inner

In [None]:
target = 'hard_gamma_inner'
nsteps = NSTEPS

data, params = get_data(target)
fobs_gw_cents = data[0]['fobs_cents']
_, fobs_gw_edges = holo.utils.pta_freqs()

if RECONSTRUCT_FLAG:
    # construct_evolution's parameter is `nsteps` (lowercase); `NSTEPS=` raised TypeError.
    construct_evolution(target_param=target, params=params, nsteps=nsteps)

## load gamma_inner results

In [None]:
# plot_current() titles its legend with the global `target_param`. Without this line
# it would still hold 'hard_time' from the previous section.
target_param = target
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09B/'
filename = fileloc+'evol_%s_%dsteps.npz' % (target, nsteps)
# NOTE(review): construct_evolution in this notebook writes 'evol_<target>.npz' (no step count).
# It does not store taus_high/hcss/hcbg/mtot_range/mtot_hirng/mrat_range/redz_range, so this cell
# expects an older file format. Confirm which file is intended.
with np.load(filename) as file:
    taus = file['taus']
    taus_high = file['taus_high']
    target_param_list = file['target_param_list']
    hcss = file['hcss']
    hcbg = file['hcbg']
    nsteps = file['nsteps']
    sepa = file['sepa']
    mtot_range = file['mtot_range']
    mtot_hirng = file['mtot_hirng']
    mrat_range = file['mrat_range']
    redz_range = file['redz_range']

## plot results

In [None]:
def plot_current():
    """Plot hardening time vs separation for the two loaded total-mass ranges (1x2 figure).

    This definition replaces the earlier 2x2 ``plot_current`` for every cell executed
    after it. It reads the notebook globals ``sepa``, ``taus``, ``taus_high``,
    ``target_param_list``, ``target_param``, ``mtot_range`` and ``mtot_hirng``.

    The previous draft of this function referenced undefined names (``mass``, ``ax1``,
    ``axs``) and a second row of axes that a 1x2 figure does not have, so it always
    raised NameError. The strain panels are therefore not drawn here.
    TODO(review): the draft computed ``utils.kepler_freq_from_sepa(mass, sepa)``, presumably
    for a GW-frequency axis; reinstate once the intended ``mass`` is decided.

    Returns
    -------
    fig : matplotlib Figure
    """
    fig, axs = plot.figax_single(ncols=2, sharey=True)
    # flatten in case the helper returns a (1, 2) array instead of (2,)
    ax, ax1 = np.ravel(axs)

    xx = sepa/PC
    colors = ['tab:green', 'tab:blue', 'tab:orange']

    def color_for(ii):
        # cycle so that more than len(colors) parameter values do not raise IndexError
        return colors[ii % len(colors)]

    ax.set_title(f'Mass Range: {mtot_range/MSOL}')
    ax1.set_title(f"Mass Range: {mtot_hirng/MSOL}")
    # y is already shared via figax_single(sharey=True); share x as well
    ax1.sharex(ax)

    for axis in [ax, ax1]:
        axis.set(xlabel=plot.LABEL_SEPARATION_PC, xscale='log', yscale='log')
    ax.set_ylabel(plot.LABEL_HARDENING_TIME)
    # Inverted once; the x-axes are shared via sharex above.
    ax1.invert_xaxis()

    labels = []
    handles = []
    for ii, tau in enumerate(taus):
        hh = plot.draw_med_conf_color(ax, xx, tau/GYR, fracs=[0.5], filter=True, color=color_for(ii))
        handles.append(hh[0])
        labels.append(f"${target_param_list[ii]:.1f}$")

        plot.draw_med_conf_color(ax1, xx, taus_high[ii]/GYR, fracs=[0.5], filter=True, color=color_for(ii))

    # max(1, ...) keeps legend() valid when no series were drawn
    ax.legend(handles, labels, loc='lower left',
              ncol=max(1, len(handles)), title=target_param, title_fontsize=14)

    return fig

In [None]:
# Plot the evolution data loaded above, using the plot_current definition executed most recently (the one just above).
fig = plot_current()

# mmb_mamp_log10

In [None]:
target_param = 'mmb_mamp_log10'
nsteps = NSTEPS
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % target_param, allow_pickle=True) as npz:
    params = npz['params']
if RECONSTRUCT_FLAG:
    # construct_evolution's parameter is `nsteps` (lowercase); `NSTEPS=` raised TypeError.
    construct_evolution(target_param=target_param, params=params, nsteps=nsteps)

## load mmb_mamp_log10 results

In [None]:
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape'
filename = fileloc+'/evol_%s_%dsteps.npz' % (target_param, nsteps)
# Pull every array plot_current() needs; the context manager closes the archive afterwards.
with np.load(filename) as file:
    taus = file['taus']
    taus_high = file['taus_high']
    target_param_list = file['target_param_list']
    hcss = file['hcss']
    hcbg = file['hcbg']
    nsteps = file['nsteps']
    sepa = file['sepa']
    mtot_range = file['mtot_range']
    mtot_hirng = file['mtot_hirng']
    mrat_range = file['mrat_range']
    redz_range = file['redz_range']

## plot results

In [None]:
# Plot the mmb_mamp_log10 evolution data loaded in the cell above, using the plot_current definition executed most recently.
fig = plot_current()

# mmb_scatter_dex

In [None]:
target_param = 'mmb_scatter_dex'
nsteps = 20
# Only the parameter sets are needed from the 09A anatomy file.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % target_param, allow_pickle=True) as npz:
    params = npz['params']
if RECONSTRUCT_FLAG:
    construct_evolution(target_param=target_param, params=params, nsteps=nsteps)

## load mmb_scatter_dex results

In [None]:
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape'
filename = fileloc+'/evol_%s_%dsteps.npz' % (target_param, nsteps)
# Pull every array plot_current() needs; the context manager closes the archive afterwards.
with np.load(filename) as file:
    taus = file['taus']
    taus_high = file['taus_high']
    target_param_list = file['target_param_list']
    hcss = file['hcss']
    hcbg = file['hcbg']
    nsteps = file['nsteps']
    sepa = file['sepa']
    mtot_range = file['mtot_range']
    mtot_hirng = file['mtot_hirng']
    mrat_range = file['mrat_range']
    redz_range = file['redz_range']

## plot results

In [None]:
# Plot the mmb_scatter_dex evolution data loaded in the cell above, using the plot_current definition executed most recently.
fig = plot_current()

# gsmf_phi0

In [None]:
target_param = 'gsmf_phi0'
nsteps = NSTEPS
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % target_param, allow_pickle=True) as npz:
    params = npz['params']
if RECONSTRUCT_FLAG:
    # construct_evolution's parameter is `nsteps` (lowercase); `NSTEPS=` raised TypeError.
    construct_evolution(target_param=target_param, params=params, nsteps=nsteps)

## load gsmf_phi0 results

In [None]:
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape'
filename = fileloc+'/evol_%s_%dsteps.npz' % (target_param, nsteps)
# Pull every array plot_current() needs; the context manager closes the archive afterwards.
with np.load(filename) as file:
    taus = file['taus']
    taus_high = file['taus_high']
    target_param_list = file['target_param_list']
    hcss = file['hcss']
    hcbg = file['hcbg']
    nsteps = file['nsteps']
    sepa = file['sepa']
    mtot_range = file['mtot_range']
    mtot_hirng = file['mtot_hirng']
    mrat_range = file['mrat_range']
    redz_range = file['redz_range']

## plot results

In [None]:
# Plot the gsmf_phi0 evolution data loaded in the cell above, using the plot_current definition executed most recently.
fig = plot_current()

# gsmf_mchar0_log10

In [None]:
target_param = 'gsmf_mchar0_log10'
nsteps = NSTEPS
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % target_param, allow_pickle=True) as npz:
    params = npz['params']
if RECONSTRUCT_FLAG:
    # construct_evolution's parameter is `nsteps` (lowercase); `NSTEPS=` raised TypeError.
    construct_evolution(target_param=target_param, params=params, nsteps=nsteps)

## load gsmf_mchar0_log10 results

In [None]:
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape'
filename = fileloc+'/evol_%s_%dsteps.npz' % (target_param, nsteps)
# Pull every array plot_current() needs; the context manager closes the archive afterwards.
with np.load(filename) as file:
    taus = file['taus']
    taus_high = file['taus_high']
    target_param_list = file['target_param_list']
    hcss = file['hcss']
    hcbg = file['hcbg']
    nsteps = file['nsteps']
    sepa = file['sepa']
    mtot_range = file['mtot_range']
    mtot_hirng = file['mtot_hirng']
    mrat_range = file['mrat_range']
    redz_range = file['redz_range']

## plot results

In [None]:
# Plot the gsmf_mchar0_log10 evolution data loaded in the cell above, using the plot_current definition executed most recently.
fig = plot_current()