In [None]:
import os
from pathlib import Path

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import h5py

import kalepy as kale

import holodeck as holo
import holodeck.ems
from holodeck.constants import MSOL, PC, YR, GYR, SPLC, EDDT

# Get maximum-likelihood population parameters from the 15yr data

In [None]:
# Directory holding the 15yr astrophysics-analysis chains (ceffyl, `astroprior_hdall`).
# Set the HOLO_15YR_CHAINS_DIR environment variable to point elsewhere instead of
# editing this machine-specific default path.
path_data = os.environ.get(
    "HOLO_15YR_CHAINS_DIR",
    "/Users/lzkelley/programs/nanograv/15yr_astro_data/phenom/ceffyl_chains/astroprior_hdall",
)
path_data = Path(path_data).resolve()
print(path_data)
assert path_data.is_dir()
fname_pars = path_data.joinpath("pars.txt")
fname_chains = path_data.joinpath("chain_1.0.txt")
print(fname_pars)
print(fname_chains)
assert fname_chains.is_file() and fname_pars.is_file()

# `atleast_1d` / `ndmin=2` keep the per-parameter indexing below valid even when
# the files contain a single parameter (loadtxt would otherwise squeeze dimensions).
chain_pars = np.atleast_1d(np.loadtxt(fname_pars, dtype=str))
chains = np.loadtxt(fname_chains, ndmin=2)
npars = len(chain_pars)
assert chains.shape[1] >= npars, (
    f"chain file has {chains.shape[1]} columns, but {npars} parameters are listed"
)

# Get maximum likelihood parameters: for each parameter, take the peak of a KDE of
# its samples (with reflecting boundaries at the sample min/max).
mlpars = {}
# `squeeze=False` guarantees a 2D array of Axes, even when `npars == 1`
fig, axes = plt.subplots(figsize=[10, 1.5*npars], nrows=npars, squeeze=False)
plt.subplots_adjust(hspace=0.75)
for ii, ax in enumerate(axes[:, 0]):
    name = str(chain_pars[ii])
    ax.set(xlabel=name)
    vals = chains[:, ii]
    extr = holo.utils.minmax(vals)
    xx, yy = kale.density(vals, reflect=extr)
    kale.dist1d(vals, ax=ax, density=True, carpet=1000)
    xmax = xx[np.argmax(yy)]
    ax.axvline(xmax, color='firebrick')
    mlpars[name] = xmax

plt.show()

In [None]:
# Report the estimated maximum-likelihood value of every population parameter,
# one right-aligned name per line beneath a header.
report_lines = ["Maximum Likelihood binary population parameters:"]
report_lines += [f"\t{kk:>20s}: {vv:+.4e}" for kk, vv in mlpars.items()]
print("\n".join(report_lines))

# Construct population

In [None]:
# Parameter space used by the 15yr astrophysics analysis.  Instantiating it at the
# maximum-likelihood parameters above yields the SAM and its hardening model.
pspace = holo.param_spaces.PS_Uniform_09B
sam, hard = pspace.model_for_params(mlpars)

In [None]:
# Band of interest expressed as orbital periods [yr], plus the number of
# log-spaced orbital-frequency bins used to cover it.
tvals = [10.0, 0.1]   # yr
NFBINS = 10
print(f"Considering orbital periods between {tvals} yrs, {NFBINS} bins")

# periods [yr] -> frequencies [1/yr] -> [1/sec]; keep only their log10 band limits
fobs_orb_edges = np.log10((1 / np.array(tvals)) / YR)
# log-spaced bin edges [1/sec] spanning the band
fobs_orb_edges = np.logspace(fobs_orb_edges[0], fobs_orb_edges[1], NFBINS + 1)
fobs = fobs_orb_cents = holo.utils.midpoints(fobs_orb_edges)

# Differential binary number evaluated at the frequency-bin centers ...
redz_final, diff_num = holo.sams.cyutils.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo,
)
# ... integrated over the (mtot, mrat, redz) grid and the frequency bins
edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
number = holo.sams.cyutils.integrate_differential_number_3dx1d(edges, diff_num)
print(f"Loaded {number.sum():.1e} binaries across frequency range")

# Calculate damped random walk (DRW) parameters

In [None]:
# Number of Eddington-fraction draws per (mtot, mrat, redz) grid cell
fedd_num = 10
# We don't care about orbital frequency here, so drop the last (frequency) edges
# and work with the log-midpoint centers of the (mtot, mrat, redz) grid only.
cents = [holo.utils.midpoints(ee, log=True) for ee in edges[:-1]]
mesh = [mm.flatten() for mm in np.meshgrid(*cents, indexing='ij')]
size = mesh[0].size
# `fedd_num` draws per grid cell, one value for each of the two binary components
shape = (size, fedd_num, 2)
# NOTE(review): these draws (and the `scatter=True` DRW parameters below) are
# presumably random and unseeded, so saved outputs differ between runs — confirm.
fedd = holo.utils.log_normal_base_10(0.1, 0.5, shape)
# Map any value above unity back below it (f -> 1/f)
above_one = fedd > 1.0
fedd[above_one] = 1.0 / fedd[above_one]
fedd = fedd.reshape(-1, 2)

# Repeat each grid-cell value `fedd_num` times so rows line up with rows of `fedd`
mt, mr, rz = [np.repeat(mm, fedd_num) for mm in mesh]
m1, m2 = holo.utils.m1m2_from_mtmr(mt, mr)

# Binaries per grid cell (summed over frequency bins), split evenly across the draws
num = np.repeat(number.sum(axis=-1).flatten(), fedd_num) / fedd_num
assert num.size == mt.size == fedd.shape[0], (
    "binary numbers do not align with the (mtot, mrat, redz, fedd) samples"
)

# DRW parameters for each component, first without and then with scatter
scatter = False
imag_1, taus_1, sfis_1 = holo.ems.drw.drw_params(m1, fedd[:, 0], scatter=scatter)
imag_2, taus_2, sfis_2 = holo.ems.drw.drw_params(m2, fedd[:, 1], scatter=scatter)

scatter = True
imag_1_scatter, taus_1_scatter, sfis_1_scatter = holo.ems.drw.drw_params(m1, fedd[:, 0], scatter=scatter)
imag_2_scatter, taus_2_scatter, sfis_2_scatter = holo.ems.drw.drw_params(m2, fedd[:, 1], scatter=scatter)

In [None]:
fname = Path("./drw_params.hdf5").resolve()

with h5py.File(fname, 'w') as out:
    # Per-sample inputs live at the root of the file
    root_datasets = (
        ("m1", m1), ("m2", m2), ("num", num), ("redz", rz),
        ("fedd1", fedd[:, 0]), ("fedd2", fedd[:, 1]),
    )
    for key, data in root_datasets:
        out.create_dataset(key, data=data)

    # DRW outputs: one group per variant, each holding the same six dataset names
    drw_keys = ("imag1", "imag2", "taus1", "taus2", "sfis1", "sfis2")
    drw_groups = (
        ("mean", (imag_1, imag_2, taus_1, taus_2, sfis_1, sfis_2)),
        ("scatter", (imag_1_scatter, imag_2_scatter, taus_1_scatter,
                     taus_2_scatter, sfis_1_scatter, sfis_2_scatter)),
    )
    for group_name, arrays in drw_groups:
        group = out.create_group(group_name)
        for key, data in zip(drw_keys, arrays):
            group.create_dataset(key, data=data)

print(f"Saved to {fname}, size {holo.utils.get_file_size(fname)}")