In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR, PC
import holodeck as holo
from holodeck.sams import sam

import hasasia.sim as hsim

import sys
sys.path.append('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation')
import anatomy as anat

In [None]:
# When True, the `if RECONSTRUCT_FLAG:` cells below rebuild the SAM/hardening models,
# recompute the GWB strain and hardening times, and save them to .npz files.
# When False those cells are skipped and the (unguarded) load cells read previously saved files.
RECONSTRUCT_FLAG = False

# Try again

In [None]:
# Build a default semi-analytic model on a shape-10 grid.
# NOTE(review): this rebinds `sam`, shadowing the `holodeck.sams.sam` module imported above;
# the visible cells below never read this instance (they rebind `sam` inside their loops).
sam = holo.sams.Semi_Analytic_Model(shape=10)

# Nothing below this makes sense

# Get PSpace Info

In [None]:
# Use one saved sweep file to recover the grid shape, the parameter sets, and per-run outputs.
# `with` guarantees the archive is closed even if one of the reads fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/hard_time_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    print(npz.files)
    data = npz['data']
    params = npz['params']
    hard_name = npz['hard_name']
    shape = npz['shape']
    # Saved strings come back as 0-d arrays, which are unhashable; convert so the name can be
    # used as a key into the `params` entries (e.g. `params[i][target_param]`).
    target_param = str(npz['target_param'])

# get param names
pspace = holo.param_spaces.PS_Uniform_09A(holo.log, nsamples=1, sam_shape=shape, seed=None)
param_names = pspace.param_names
print(param_names)
print(f"{shape=}")
print(f"{data[0]['hc_ss'].shape}")
print(f"{data[0].keys()=}")

# # set directory path
# sam_loc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'
# save_dir=sam_loc+'/figures'       

In [None]:
# Shape of the first 'gwb_params' entry stored for the first run.
print(data[0]['gwb_params'][0].shape)

### Make Model

In [None]:
# Rebuild the SAM + hardening model for the second parameter set and reuse its frequency bins.
sam1, hard1 = pspace.model_for_params(params[1])
fobs_gw_cents, fobs_gw_edges = data[1]['fobs_cents'], data[1]['fobs_edges']
# hc_ss is unpacked as (frequencies, realizations, loudest sources).
NFREQS, NREALS, NLOUDEST = data[0]['hc_ss'].shape
print(f"{NFREQS=}, {NREALS=}, {NLOUDEST=}")

In [None]:
# Convert GW-frequency bin centres to orbital frequency (f_orb = f_gw / 2).
fobs_orb_cents = 0.5 * fobs_gw_cents
edges, dnum, redz_final, dets = sam1._dynamic_binary_number_at_fobs_consistent(
    hard1,
    fobs_orb_cents,
    details=True,
)

In [None]:
# Pull hardening rate, separation, and hardening time out of the details dict.
dadt, sepa, tau = (dets[key] for key in ('dadt', 'sepa', 'tau'))

### Timescales

In [None]:
# def tau_from_dadt(dadt, sepa):
#     """ tau = dt/dlna = dt/(da/a) = a*(dt/da) = a/(da/dt)"""
#     tau = (sepa)/dadt
#     return tau

# Plot Hardening Time vs. Separation

In [None]:
# Show the parameter-name list exposed by holodeck.single_sources.
print(sings.par_names)

In [None]:
# Inspect shapes and summary statistics of the separation and hardening-time arrays.
print(f"{sepa.shape=}, {tau.shape=}")
print(holo.utils.stats(sepa))
print(holo.utils.stats(tau))

In [None]:
# Spot-check separation vs. GW frequency for a 3x3x3 sample of grid bins
# (presumably the mtot, mrat, redz axes of the SAM grid -- TODO confirm).
# The former `mm=qq=zz=-1` initialisations were dead stores, overwritten by the loops.
fig, ax = plt.subplots()
for mm in [5, 50, 80]:
    for qq in [5, 50, 80]:
        for zz in [5, 50, 80]:
            ax.plot(fobs_gw_cents, sepa[mm, qq, zz] / PC)
# fobs_gw_cents is in Hz (it is scaled by 1e9 to get nHz elsewhere); sepa/PC is in pc.
ax.set(xlabel='GW Frequency [Hz]', ylabel='Binary Separation [pc]')

In [None]:
# Shape of the separation array returned in `dets`.
print(sepa.shape)

In [None]:
xlabels = ['Binary Separation [pc]', 'GW Frequency [nHz]']
ylabels = ['Hardening Time [Gyr]', 'GW Characteristic Strain']

fig, axs = plot.figax(nrows=2, figsize=(5, 6))
for ax, xlabel, ylabel in zip(axs, xlabels, ylabels):
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

x1 = sepa / PC
y1 = dets['tau'] / GYR

# Draw every 10th bin along each of the first three grid axes.
for mm in np.arange(0, 90, 10):
    for qq in np.arange(0, 80, 10):
        for zz in np.arange(0, 100, 10):
            axs[0].plot(x1[mm, qq, zz], y1[mm, qq, zz], alpha=0.5)

# Separations span orders of magnitude: use log axes with separation decreasing to the right,
# matching the hardening-time panels in plot_current() (and the old set_xlim(1e3, 1e-3) intent).
# set_inverted(True) is idempotent, unlike invert_xaxis(), which toggles.
axs[0].set(xscale='log', yscale='log')
axs[0].xaxis.set_inverted(True)

# Lay out after plotting so the final tick labels are accounted for.
fig.tight_layout()


# x2_bg = fobs_gw_cents





# Let's just copy Luke's notebook

### function to construct evolution data

In [None]:
def construct_evolution(params, nsteps, param_name=None, fileloc=None):
    """Recompute hardening-time evolution and GWB strain for a parameter sweep, then save it.

    For each entry of ``params`` this builds a SAM and hardening model with
    ``pspace.model_for_params``, computes single-source and background characteristic
    strain with ``sam.gwb`` over ``fobs_gw_edges``, and evaluates the hardening time
    ``tau`` at ``nsteps`` log-spaced separations (1e3 pc down to 1e-3 pc) for bins with
    mass ratio in (0.2, 1], in two total-mass windows: (3e8, 3e9] and (3e9, 3e10] MSOL.

    Parameters
    ----------
    params : sequence
        Parameter sets accepted by ``pspace.model_for_params``; each must support
        ``params[i][param_name]``.
    nsteps : int
        Number of separation steps; used for both the separation grid and the filename.
    param_name : str, optional
        Key of the varied parameter. Defaults to the notebook-level ``target_param``.
    fileloc : str, optional
        Output directory including the trailing '/'. Defaults to the
        ``anatomy_uniform09A_fullshape/`` directory used throughout this notebook.

    Returns
    -------
    dict
        The values passed to ``np.savez`` for
        ``fileloc + 'evol_<param_name>_<nsteps>steps.npz'`` (``taus``, ``taus_high``,
        ``target_param_list``, ``hcss``, ``hcbg``, ``nsteps``, ``sepa`` and the four
        selection ranges).

    Notes
    -----
    Also reads the notebook globals ``pspace``, ``fobs_gw_edges``, ``NLOUDEST``, ``NREALS``.
    """
    # tqdm is never imported in this notebook (the old `tqdm.tqdm(...)` raised NameError).
    # Use it when available; the progress bar is cosmetic, so fall back to plain iteration.
    try:
        from tqdm.auto import tqdm as _progress
    except ImportError:
        def _progress(iterable, **kwargs):
            return iterable

    if param_name is None:
        param_name = target_param  # notebook-level global, as in the original behaviour
    if fileloc is None:
        fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'

    mtot_range = [3e8*MSOL, 3e9*MSOL]
    mtot_hirng = [3e9*MSOL, 3e10*MSOL]
    mrat_range = [0.2, 1.0]
    redz_range = [0, np.inf]

    # Separation grid from 1e3 pc down to 1e-3 pc. Built from the `nsteps` argument:
    # the global NSTEPS used previously could disagree with the saved `nsteps`/filename.
    sepa = np.logspace(-3, 3, nsteps)[::-1] * PC

    target_param_list = []  # value of the varied parameter for each run
    hcss = []
    hcbg = []
    taus = []
    taus_high = []

    for _params in _progress(params):
        target_param_list.append(_params[param_name])

        sam, hard = pspace.model_for_params(_params)

        # single-source and background strain between the given frequency-bin edges
        _hcss_step, _hcbg_step = sam.gwb(fobs_gw_edges, hard,
                                         loudest=NLOUDEST, realize=NREALS)
        hcss.append(_hcss_step)
        hcbg.append(_hcbg_step)

        # binary properties (including the hardening time) at the target separations
        _edges, _dnum, _redz_final, _details = sam._dynamic_binary_number_at_sepa_consistent(
            hard, sepa, details=True)

        # Boolean masks over the grid axes, broadcast to a 3-D (mtot, mrat, redz) selection.
        sel_mtot = (mtot_range[0] < sam.mtot) & (sam.mtot <= mtot_range[1])
        sel_himt = (mtot_hirng[0] < sam.mtot) & (sam.mtot <= mtot_hirng[1])
        sel_mrat = (mrat_range[0] < sam.mrat) & (sam.mrat <= mrat_range[1])
        sel_redz = (redz_range[0] < sam.redz) & (sam.redz <= redz_range[1])
        sel = (
            sel_mtot[:, np.newaxis, np.newaxis]
            & sel_mrat[np.newaxis, :, np.newaxis]
            & sel_redz[np.newaxis, np.newaxis, :]
        )
        sel_high = (
            sel_himt[:, np.newaxis, np.newaxis]
            & sel_mrat[np.newaxis, :, np.newaxis]
            & sel_redz[np.newaxis, np.newaxis, :]
        )

        taus.append(_details['tau'][sel].T)
        taus_high.append(_details['tau'][sel_high].T)

    results = dict(
        taus=taus, taus_high=taus_high, target_param_list=target_param_list,
        hcss=hcss, hcbg=hcbg, nsteps=nsteps, sepa=sepa,
        mtot_range=mtot_range, mtot_hirng=mtot_hirng, mrat_range=mrat_range, redz_range=redz_range,
    )
    filename = fileloc + 'evol_%s_%dsteps.npz' % (param_name, nsteps)
    print(f"{filename=}")
    np.savez(filename, **results)
    return results


# Varying Hard Time

In [None]:
# Parameter sets and frequency-bin centres for the hard_time sweep.
# `with` closes the archive even if one of the reads fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/hard_time_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    print(npz.files)
    params = npz['params']
    fobs_gw_cents = npz['data'][0]['fobs_cents']

In [None]:
if RECONSTRUCT_FLAG:
    # tqdm is never imported in this notebook (the old `tqdm.tqdm(...)` raised NameError).
    # Use it when available; otherwise iterate without a progress bar.
    try:
        from tqdm.auto import tqdm as _progress
    except ImportError:
        def _progress(iterable, **kwargs):
            return iterable

    NSTEPS = 20

    mtot_range = [3e8*MSOL, 3e9*MSOL]
    mtot_hirng = [3e9*MSOL, 3e10*MSOL]
    mrat_range = [0.2, 1.0]
    redz_range = [0, np.inf]

    # hard_time binary lifetimes, one per entry of `params`
    times_list = []
    # separation grid: NSTEPS log-spaced points from 1e3 pc down to 1e-3 pc
    sepa = np.logspace(-3, 3, NSTEPS)[::-1] * PC

    time_hcss = []
    time_hcbg = []
    time_taus = []
    time_taus_high = []

    # Iterate over target lifetimes
    for _params in _progress(params):
        times_list.append(_params['hard_time'])
        sam, hard = pspace.model_for_params(_params)

        # single-source and background strain between the given frequency-bin edges
        _hcss_step, _hcbg_step = sam.gwb(fobs_gw_edges, hard,
                                         loudest=NLOUDEST, realize=NREALS)
        time_hcss.append(_hcss_step)
        time_hcbg.append(_hcbg_step)

        # binary properties (including the hardening time) at the target separations
        _edges, _dnum, _redz_final, _details = sam._dynamic_binary_number_at_sepa_consistent(
            hard, sepa, details=True)

        # Boolean masks over the grid axes, broadcast to a 3-D (mtot, mrat, redz) selection.
        sel_mtot = (mtot_range[0] < sam.mtot) & (sam.mtot <= mtot_range[1])
        sel_himt = (mtot_hirng[0] < sam.mtot) & (sam.mtot <= mtot_hirng[1])
        sel_mrat = (mrat_range[0] < sam.mrat) & (sam.mrat <= mrat_range[1])
        sel_redz = (redz_range[0] < sam.redz) & (sam.redz <= redz_range[1])
        sel = (
            sel_mtot[:, np.newaxis, np.newaxis]
            & sel_mrat[np.newaxis, :, np.newaxis]
            & sel_redz[np.newaxis, np.newaxis, :]
        )
        sel_high = (
            sel_himt[:, np.newaxis, np.newaxis]
            & sel_mrat[np.newaxis, :, np.newaxis]
            & sel_redz[np.newaxis, np.newaxis, :]
        )

        # underscore names so the notebook-level `tau` from the details cell is not clobbered
        _tau = _details['tau'][sel].T
        _tau_high = _details['tau'][sel_high].T
        time_taus.append(_tau)
        time_taus_high.append(_tau_high)

## save results

In [None]:
if RECONSTRUCT_FLAG:
    # Bundle the hard_time sweep into a single archive.
    # NOTE: `filename` is relative, so the archive is written to the current working
    # directory; `fileloc` is defined here but not prepended.
    fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'
    filename = 'evol_hard_time_%dsteps.npz' % NSTEPS
    np.savez(
        filename,
        taus=time_taus,
        taus_high=time_taus_high,
        target_param_list=times_list,
        hcss=time_hcss,
        hcbg=time_hcbg,
        nsteps=NSTEPS,
        sepa=sepa,
        mtot_range=mtot_range,
        mtot_hirng=mtot_hirng,
        mrat_range=mrat_range,
        redz_range=redz_range,
    )


## load hard time results

In [None]:
target_param = 'hard_time'
NSTEPS = 20
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'
filename = 'evol_%s_%dsteps.npz' % (target_param, NSTEPS)

# construct_evolution() writes into `fileloc`, while the inline save cells write into the
# working directory. Prefer `fileloc` and fall back to the working directory.
try:
    evol_file = np.load(fileloc + filename)
except FileNotFoundError:
    evol_file = np.load(filename)

with evol_file:
    taus = evol_file['taus']
    taus_high = evol_file['taus_high']
    target_param_list = evol_file['target_param_list']
    hcss = evol_file['hcss']
    hcbg = evol_file['hcbg']
    nsteps = int(evol_file['nsteps'])  # stored as a 0-d array
    sepa = evol_file['sepa']
    mtot_range = evol_file['mtot_range']
    mtot_hirng = evol_file['mtot_hirng']
    mrat_range = evol_file['mrat_range']
    redz_range = evol_file['redz_range']

## Plot Results

In [None]:
def plot_current():
    """Draw a 2x2 summary figure for the currently loaded evolution results.

    Top row: median hardening time [Gyr] vs. binary separation [pc], one colour per run,
    for ``taus`` (left, titled with ``mtot_range``) and ``taus_high`` (right, titled with
    ``mtot_hirng``). The legend lists ``target_param_list`` under the title ``target_param``.
    Bottom-left: median background strain ``hcbg`` vs. GW frequency [nHz], with
    ``hcss[..., 0]`` of every realization scattered on top. Bottom-right only shares the
    bottom-left limits and is left empty.

    Reads notebook globals: sepa, taus, taus_high, hcbg, hcss, target_param_list,
    target_param, fobs_gw_cents, mtot_range, mtot_hirng.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, axs = plot.figax_double(height=7, nrows=2, ncols=2, hspace=0.35, bottom=0.1)

    # Keep the original three colours when they suffice; with more runs, sample a colormap
    # instead of raising IndexError.
    base_colors = ['tab:green', 'tab:blue', 'tab:orange']
    nruns = max(len(taus), len(hcbg))
    if nruns <= len(base_colors):
        colors = base_colors
    else:
        colors = [tuple(rgba) for rgba in cm.viridis(np.linspace(0.0, 1.0, nruns))]

    # ------------------------   Ax Row 0   ----------------------------
    ax = axs[0, 0]
    ax1 = axs[0, 1]

    # np.asarray: the ranges are lists when freshly built and arrays when loaded from .npz.
    ax.set_title(f'Mass Range: {np.asarray(mtot_range)/MSOL}')
    ax1.set_title(f"Mass Range: {np.asarray(mtot_hirng)/MSOL}")
    ax1.sharex(ax)
    ax1.sharey(ax)

    for axis in [ax, ax1]:
        axis.set(xlabel=plot.LABEL_SEPARATION_PC, ylabel=plot.LABEL_HARDENING_TIME, xscale='log', yscale='log')
        # invert_xaxis() toggles, and the two axes share x-limits, so calling it on both
        # cancelled the inversion; set_inverted(True) is idempotent.
        axis.xaxis.set_inverted(True)

    xx = sepa / PC
    labels = []
    handles = []
    for ii, tau in enumerate(taus):
        hh = plot.draw_med_conf_color(ax, xx, tau / GYR, fracs=[0.5], filter=True, color=colors[ii])
        handles.append(hh[0])
        labels.append(f"${target_param_list[ii]:.1f}$")

        plot.draw_med_conf_color(ax1, xx, taus_high[ii] / GYR, fracs=[0.5], filter=True, color=colors[ii])

    ax.legend(handles, labels, loc='lower left',
              ncol=len(handles), title=target_param, title_fontsize=14)

    # ----------------------------- Ax Row 1 --------------------------------
    ax = axs[1, 0]
    ax1 = axs[1, 1]

    ax1.sharex(ax)
    ax1.sharey(ax)

    for axis in [ax, ax1]:
        axis.set(xlabel=plot.LABEL_GW_FREQUENCY_NHZ, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN, xscale='log', yscale='log')

    xx = fobs_gw_cents * 1e9  # Hz -> nHz

    for ii, yy in enumerate(hcbg):
        plot.draw_med_conf_color(ax, xx, yy, fracs=[0.5], filter=False, color=colors[ii])
        ss = hcss[ii]
        # one scatter per realization (axis 1), showing only index 0 on the last axis
        for rr in range(ss.shape[1]):
            ax.scatter(xx, ss[:, rr, 0], color=colors[ii], alpha=0.5, s=5)

    return fig

fig = plot_current()



# Varying hard_gamma_inner

In [None]:
target_param = 'hard_gamma_inner'
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % target_param,
             allow_pickle=True) as npz:
    params = npz['params']

if RECONSTRUCT_FLAG:
    NSTEPS = 20
    # Pass NSTEPS (just set), not the lowercase `nsteps` left over from loading the
    # hard_time results, which need not match.
    construct_evolution(params, NSTEPS)

## save results

In [None]:
if RECONSTRUCT_FLAG:
    # construct_evolution() in the previous cell already computed and saved this parameter's
    # results to `fileloc + filename`, but it does not set the notebook-level taus/hcss/...,
    # which here still hold an earlier section's results. Re-saving those globals stored the
    # wrong sweep under this filename. Instead, mirror the freshly written archive into the
    # working directory, where the load cell below looks.
    fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'
    filename = 'evol_%s_%dsteps.npz' % (target_param, NSTEPS)
    print(f"{filename=}")
    with np.load(fileloc + filename) as fresh:
        np.savez(filename, **{key: fresh[key] for key in fresh.files})


## Load hard_gamma_inner results

In [None]:
target_param = 'hard_gamma_inner'
NSTEPS = 20
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'
filename = 'evol_%s_%dsteps.npz' % (target_param, NSTEPS)

# construct_evolution() writes into `fileloc`, while the inline save cells write into the
# working directory. Prefer `fileloc` and fall back to the working directory.
try:
    evol_file = np.load(fileloc + filename)
except FileNotFoundError:
    evol_file = np.load(filename)

with evol_file:
    taus = evol_file['taus']
    taus_high = evol_file['taus_high']
    target_param_list = evol_file['target_param_list']
    hcss = evol_file['hcss']
    hcbg = evol_file['hcbg']
    nsteps = int(evol_file['nsteps'])  # stored as a 0-d array
    sepa = evol_file['sepa']
    mtot_range = evol_file['mtot_range']
    mtot_hirng = evol_file['mtot_hirng']
    mrat_range = evol_file['mrat_range']
    redz_range = evol_file['redz_range']

## plot results

In [None]:
# Draw the summary figure from the hard_gamma_inner results loaded above.
fig = plot_current()

# Varying mmb_mamp_log10

In [None]:
target_param = 'mmb_mamp_log10'
# `with` closes the archive even if the read fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % target_param,
             allow_pickle=True) as npz:
    params = npz['params']

In [None]:
if RECONSTRUCT_FLAG:
    # tqdm is never imported in this notebook (the old `tqdm.tqdm(...)` raised NameError).
    # Use it when available; otherwise iterate without a progress bar.
    try:
        from tqdm.auto import tqdm as _progress
    except ImportError:
        def _progress(iterable, **kwargs):
            return iterable

    NSTEPS = 20

    mtot_range = [3e8*MSOL, 3e9*MSOL]
    mtot_hirng = [3e9*MSOL, 3e10*MSOL]
    mrat_range = [0.2, 1.0]
    redz_range = [0, np.inf]

    # value of the varied parameter, one per entry of `params`
    target_param_list = []
    # separation grid: NSTEPS log-spaced points from 1e3 pc down to 1e-3 pc
    sepa = np.logspace(-3, 3, NSTEPS)[::-1] * PC

    hcss = []
    hcbg = []
    taus = []
    taus_high = []

    for _params in _progress(params):
        target_param_list.append(_params[target_param])
        sam, hard = pspace.model_for_params(_params)

        # single-source and background strain between the given frequency-bin edges
        _hcss_step, _hcbg_step = sam.gwb(fobs_gw_edges, hard,
                                         loudest=NLOUDEST, realize=NREALS)
        hcss.append(_hcss_step)
        hcbg.append(_hcbg_step)

        # binary properties (including the hardening time) at the target separations
        _edges, _dnum, _redz_final, _details = sam._dynamic_binary_number_at_sepa_consistent(
            hard, sepa, details=True)

        # Boolean masks over the grid axes, broadcast to a 3-D (mtot, mrat, redz) selection.
        sel_mtot = (mtot_range[0] < sam.mtot) & (sam.mtot <= mtot_range[1])
        sel_himt = (mtot_hirng[0] < sam.mtot) & (sam.mtot <= mtot_hirng[1])
        sel_mrat = (mrat_range[0] < sam.mrat) & (sam.mrat <= mrat_range[1])
        sel_redz = (redz_range[0] < sam.redz) & (sam.redz <= redz_range[1])
        sel = (
            sel_mtot[:, np.newaxis, np.newaxis]
            & sel_mrat[np.newaxis, :, np.newaxis]
            & sel_redz[np.newaxis, np.newaxis, :]
        )
        sel_high = (
            sel_himt[:, np.newaxis, np.newaxis]
            & sel_mrat[np.newaxis, :, np.newaxis]
            & sel_redz[np.newaxis, np.newaxis, :]
        )

        _tau = _details['tau'][sel].T
        _tau_high = _details['tau'][sel_high].T
        taus.append(_tau)
        taus_high.append(_tau_high)

## save results

In [None]:
if RECONSTRUCT_FLAG:
    fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'
    filename = 'evol_%s_%dsteps.npz' % (target_param, NSTEPS)
    print(f"{filename=}")
    # Include `sepa` like the other save cells; the load cells read it back with file['sepa'].
    np.savez(filename, taus=taus, taus_high=taus_high, target_param_list=target_param_list,
             hcss=hcss, hcbg=hcbg, nsteps=NSTEPS, sepa=sepa,
             mtot_range=mtot_range, mtot_hirng=mtot_hirng, mrat_range=mrat_range, redz_range=redz_range)


## plot results

In [None]:
# Load this section's results before plotting. Without this, with RECONSTRUCT_FLAG False the
# figure showed whatever results were loaded last (hard_gamma_inner) under the
# mmb_mamp_log10 legend title.
NSTEPS = 20
fileloc = '/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/'
filename = 'evol_%s_%dsteps.npz' % (target_param, NSTEPS)
try:
    evol_file = np.load(fileloc + filename)
except FileNotFoundError:
    evol_file = np.load(filename)

with evol_file:
    taus = evol_file['taus']
    taus_high = evol_file['taus_high']
    target_param_list = evol_file['target_param_list']
    hcss = evol_file['hcss']
    hcbg = evol_file['hcbg']
    nsteps = int(evol_file['nsteps'])  # stored as a 0-d array
    # Older archives from this section were saved without `sepa`; keep the current grid then.
    if 'sepa' in evol_file.files:
        sepa = evol_file['sepa']
    mtot_range = evol_file['mtot_range']
    mtot_hirng = evol_file['mtot_hirng']
    mrat_range = evol_file['mrat_range']
    redz_range = evol_file['redz_range']

fig = plot_current()

# Varying mmb_scatter_dex

In [None]:
target_param = 'mmb_scatter_dex'
# `with` closes the archive even if the read fails.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/%s_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz'
             % target_param,
             allow_pickle=True) as npz:
    params = npz['params']