In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import h5py
import healpy as hp
import kalepy as kale

import holodeck as holo
import holodeck.anisotropy as anis
from holodeck import detstats, plot, utils
from holodeck.constants import YR, MSOL, GYR, SPLC, PC, MPC

# Functions

In [None]:
def sam_model(sam, hard,
        dur=16.03*YR, cad=0.2*YR, use_redz=True):
    """Evolve a SAM population with a hardening model and compute binary numbers and strains.

    Frequencies are the Nyquist GW-frequency grid for an observation of duration `dur`
    sampled at cadence `cad`; orbital frequencies are half of the GW frequencies.

    Parameters
    ----------
    sam : holodeck SAM instance
        Provides the ``mtot``, ``mrat`` and ``redz`` grid edges.
    hard : `holo.hardening.Fixed_Time_2PL_SAM` or `holo.hardening.Hard_GW`
        Hardening model used to evolve binaries.
    dur : float
        Observing duration, in the same time units as `YR` (default 16.03 yr).
    cad : float
        Observing cadence, in the same time units as `YR` (default 0.2 yr).
    use_redz : bool
        If True, strains are computed with the evolved redshifts `redz_final`;
        otherwise the `anis` strain helpers are called without them
        (presumably falling back to the initial redshift-bin values).

    Returns
    -------
    vals : dict
        Keys: 'hard', 'sam', 'edges', 'number', 'diff_num', 'redz_final', 'hs_cents',
        'hs_edges', 'fobs_gw_cents', 'fobs_gw_edges', 'fobs_orb_cents',
        'fobs_orb_edges', 'hard_name'.

    Raises
    ------
    TypeError
        If `hard` is not a `Fixed_Time_2PL_SAM` or `Hard_GW` instance.
        (`TypeError` subclasses `Exception`, so existing `except Exception` handlers still work.)
    """
    # GW-frequency grid (centers and edges) for the given observing duration/cadence
    fobs_gw_cents = utils.nyquist_freqs(dur, cad)
    fobs_gw_edges = utils.nyquist_freqs_edges(dur, cad)
    # orbital frequency is half the GW frequency
    fobs_orb_cents = fobs_gw_cents/2.0
    fobs_orb_edges = fobs_gw_edges/2.0

    if isinstance(hard, holo.hardening.Fixed_Time_2PL_SAM):
        hard_name = 'Fixed Time'
    elif isinstance(hard, holo.hardening.Hard_GW):
        hard_name = 'GW Only'
    else:
        raise TypeError("'hard' must be an instance of 'Fixed_Time_2PL_SAM' or 'Hard_GW'")

    # final (evolved) redshifts and differential binary number at each orbital-frequency center
    redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, holo.cosmo)
    edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
    number = holo.sam_cython.integrate_differential_number_3dx1d(edges, diff_num)
    if use_redz:
        hs_cents = anis.strain_amp_at_bin_centers_redz(edges, redz_final)
        hs_edges = anis.strain_amp_at_bin_edges_redz(edges, redz_final)
    else:
        hs_cents = anis.strain_amp_at_bin_centers_redz(edges)
        hs_edges = anis.strain_amp_at_bin_edges_redz(edges)

    vals = {
        'hard': hard, 'sam': sam, 'edges': edges, 'number': number, 'diff_num': diff_num,
        'redz_final': redz_final, 'hs_cents': hs_cents, 'hs_edges': hs_edges,
        'fobs_gw_cents': fobs_gw_cents, 'fobs_gw_edges': fobs_gw_edges,
        'fobs_orb_cents': fobs_orb_cents, 'fobs_orb_edges': fobs_orb_edges,
        'hard_name': hard_name,
    }
    return vals

def integrate_mm(dnum, edges):
    """Integrate dN/dlog10(M) over total mass (axis 0), using log10 of ``edges[0]`` as abscissa."""
    log10_mtot = np.log10(edges[0])
    return utils.trapz(dnum, log10_mtot, axis=0, cumsum=False)
def integrate_qq(dnum, edges):
    """Integrate dN/dq over mass ratio (axis 1), using ``edges[1]`` as abscissa."""
    mrat_edges = edges[1]
    return utils.trapz(dnum, mrat_edges, axis=1, cumsum=False)
def integrate_zz(dnum, edges):
    """Integrate dN/dz over redshift (axis 2), using ``edges[2]`` as abscissa."""
    redz_edges = edges[2]
    return utils.trapz(dnum, redz_edges, axis=2, cumsum=False)
def integrate_ff(dnum, fobs_gw_edges):
    """Convert dN/dln(f) into a number per frequency bin.

    Multiplies `dnum` by the natural-log width of each bin of `fobs_gw_edges`,
    broadcasting along the last axis of `dnum`.
    """
    dlnf = np.diff(np.log(fobs_gw_edges))
    return dnum * dlnf

In [None]:
def strain_amp_at_bin_centers_redz(edges, redz=None):
    """Calculate GW strain amplitude at bin centers, using final or initial redshifts.

    Parameters
    ----------
    edges : list of four 1D arrays
        Grid edges ``[mtot, mrat, redz, fobs_orb]``; the last entry should be
        observer-frame orbital-frequency bin edges.
    redz : ndarray or None
        If given: redshifts on the 4D grid of shape (M, Q, Z, Fc), where M, Q, Z are
        edge counts and Fc is the number of frequency-bin centers. They are averaged to
        midpoints along the first three axes, giving (M-1, Q-1, Z-1, Fc). Entries with
        redshift <= 0 are given an infinite comoving distance (presumably binaries that
        have already coalesced).
        If None: the midpoints of the redshift edges ``edges[2]`` are used for every
        binary (i.e. initial redshifts).

    Returns
    -------
    hs : ndarray
        Strain amplitude from `utils.gw_strain_source`, presumably broadcast to
        (M-1, Q-1, Z-1, Fc).
    """
    assert len(edges) == 4
    assert np.all([np.ndim(ee) == 1 for ee in edges])

    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    # redshifts are defined across 4D grid, shape (M, Q, Z, Fc)
    #    where M, Q, Z are edges and Fc is frequency centers
    # find midpoints of redshifts in M, Q, Z dimensions, to end up with (M-1, Q-1, Z-1, Fc)
    if redz is not None:
        for dd in range(3):
            # move axis `dd` to the front, take midpoints along it, then move it back
            redz = np.moveaxis(redz, dd, 0)
            redz = kale.utils.midpoints(redz, axis=0)
            redz = np.moveaxis(redz, 0, dd)
        # non-positive redshifts get infinite distance; only positive ones are evaluated
        dc = +np.inf * np.ones_like(redz)
        sel = (redz > 0.0)
        dc[sel] = holo.cosmo.comoving_distance(redz[sel]).cgs.value
    else:
        # initial redshifts: redshift-bin centers, shaped to broadcast as (1, 1, Z-1, 1)
        redz = kale.utils.midpoints(edges[2])[np.newaxis, np.newaxis, :, np.newaxis]
        dc = holo.cosmo.comoving_distance(redz).cgs.value

    # ---- calculate GW strain ----
    mt = kale.utils.midpoints(edges[0])
    mr = kale.utils.midpoints(edges[1])
    mc = utils.chirp_mass_mtmr(mt[:, np.newaxis], mr[np.newaxis, :])
    mc = mc[:, :, np.newaxis, np.newaxis]

    # convert from observer-frame to rest-frame; still using frequency-bin centers
    fr = utils.frst_from_fobs(fc[np.newaxis, np.newaxis, np.newaxis, :], redz)

    hs = utils.gw_strain_source(mc, dc, fr)
    return hs

In [None]:
# Small (5, 6, 7) SAM grid with fixed-time two-power-law hardening (3 Gyr parameter)
sam = holo.sam.Semi_Analytic_Model(shape=(5, 6, 7))
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 3 * GYR)
vals = sam_model(sam, hard)

# strains with initial (bin-center) redshifts vs. evolved final redshifts
hs_init, hs_final = (
    strain_amp_at_bin_centers_redz(vals['edges'], redz=rz)
    for rz in (None, vals['redz_final'])
)


In [None]:
def redz_allN(sam, hard, fobs_orb, steps=200, details=False):
    """Get correct (final) redshifts for full binary-number calculation.

    Slower but more correct than old `dynamic_binary_number`.
    Same as new cython implementation `holo.sam_cython.dynamic_binary_number_at_fobs`, which is
    more than 10x faster.
    LZK 2023-05-11

    Method: each (mtot, mrat) binary is evolved over `steps` log-spaced separations, from the
    hardening model's initial separation ``hard._sepa_init`` down to the ISCO. The cumulative
    evolution time is added to the galaxy-merger time ``sam._gmt_time`` to get the redshift at
    each step. That redshift is then interpolated in log-frequency to the target `fobs_orb`.

    NOTE(review): this previously failed for `Fixed_Time_2PL_SAM` ("BUG doesn't work for
    Fixed_Time_2PL") because `hard.dadt` received (M, Q, S) arrays where 1D arrays of matching
    size are expected. Inputs are now flattened and the result reshaped back. Confirm the
    output against the cython implementation.

    Parameters
    ----------
    sam : SAM instance
        Uses ``edges``, ``static_binary_density``, ``mtot``, ``mrat``, ``redz``,
        ``_gmt_time`` and ``shape``.
    hard : hardening instance
        Uses ``_sepa_init``, ``_norm`` and ``dadt(mt, mr, sepa, norm=...)``.
    fobs_orb : array_like, (X,)
        Observer-frame orbital frequencies at which to evaluate redshifts.
    steps : int
        Number of log-spaced separation steps per binary.
    details : bool
        Currently unused; only referenced by the commented-out continuation below.

    Returns
    -------
    redz_final : ndarray, shape ``sam.shape + (X,)``
        Interpolated redshift of each binary at each frequency in `fobs_orb`.
    """
    fobs_orb = np.asarray(fobs_orb)
    # `edges` and `dens` are only used by the commented-out number calculation below
    edges = sam.edges + [fobs_orb, ]

    # shape: (M, Q, Z)
    dens = sam.static_binary_density   # d3n/[dlog10(M) dq dz]  units: [Mpc^-3]

    # start from the hardening model's initial separation
    rmax = hard._sepa_init
    # (M,) end at the ISCO
    rmin = utils.rad_isco(sam.mtot)
    # Choose steps for each binary, log-spaced between rmin and rmax
    extr = np.log10([rmax * np.ones_like(rmin), rmin])     # (2, M)
    rads = np.linspace(0.0, 1.0, steps)[np.newaxis, :]     # (1, S)
    # (M, S)  =  (M, 1) * (1, S)
    rads = extr[0][:, np.newaxis] + (extr[1] - extr[0])[:, np.newaxis] * rads
    rads = 10.0 ** rads

    # (M, Q, S)
    mt, mr, rads, norm = np.broadcast_arrays(
        sam.mtot[:, np.newaxis, np.newaxis],
        sam.mrat[np.newaxis, :, np.newaxis],
        rads[:, np.newaxis, :],
        hard._norm[:, :, np.newaxis],
    )

    # `hard.dadt` expects 1D arrays of matching size, so evaluate it on flattened copies and
    # restore the (M, Q, S) grid shape, which the integration along axis=-1 below requires.
    # NOTE(review): assumes `dadt` is evaluated element-wise per binary — confirm.
    grid_shape = mt.shape
    dadt_evo = hard.dadt(mt.flatten(), mr.flatten(), rads.flatten(), norm=norm.flatten())
    dadt_evo = np.reshape(dadt_evo, grid_shape)
    print(f"{utils.stats(dadt_evo*YR/PC)=}")

    # (M, Q, S-1)
    # Integrate (inverse) hardening rates to calculate total lifetime to each separation
    times_evo = -utils.trapz_loglog(-1.0 / dadt_evo, rads, axis=-1, cumsum=True)
    print(f"{utils.stats(times_evo/GYR)=}")
    # Combine the binary-evolution time, with the galaxy-merger time
    # (M, Q, Z, S-1)
    rz = sam.redz[np.newaxis, np.newaxis, :, np.newaxis]
    times_tot = times_evo[:, :, np.newaxis, :] + sam._gmt_time[:, :, :, np.newaxis]
    redz_evo = utils.redz_after(times_tot, redz=rz)

    # convert from separations to rest-frame orbital frequencies
    # (M, Q, S)
    frst_orb_evo = utils.kepler_freq_from_sepa(mt, rads)
    # (M, Q, Z, S)
    fobs_orb_evo = frst_orb_evo[:, :, np.newaxis, :] / (1.0 + rz)

    # ---- interpolate to target frequencies
    # `ndinterp` interpolates over 1th dimension

    # (M, Q, Z, S-1)  ==>  (M*Q*Z, S-1)
    # drop the first step's frequency so frequencies align with the S-1 cumulative times
    fobs_orb_evo, redz_evo = [mm.reshape(-1, steps-1) for mm in [fobs_orb_evo[:, :, :, 1:], redz_evo]]
    # (M*Q*Z, X)
    redz_final = utils.ndinterp(fobs_orb, fobs_orb_evo, redz_evo, xlog=True, ylog=False)
    print(f"{utils.stats(redz_final)=}")

    # (M*Q*Z, X) ===> (M, Q, Z, X)
    redz_final = redz_final.reshape(sam.shape + (fobs_orb.size,))

    return redz_final

        # coal = (redz_final > 0.0)
        # frst_orb = fobs_orb * (1.0 + redz_final)
        # frst_orb[frst_orb < 0.0] = 0.0
        # redz_final[~coal] = -1.0

        # # (M, Q, Z, X) comoving-distance in [Mpc]
        # dc = np.zeros_like(redz_final)
        # dc[coal] = holo.cosmo.comoving_distance(redz_final[coal]).to('Mpc').value

        # # (M, Q, Z, X) this is `(dVc/dz) * (dz/dt)` in units of [Mpc^3/s]
        # cosmo_fact = np.zeros_like(redz_final)
        # cosmo_fact[coal] = 4 * np.pi * (SPLC/MPC) * np.square(dc[coal]) * (1.0 + redz_final[coal])

        # # (M, Q) calculate chirp-mass
        # mt = sam.mtot[:, np.newaxis, np.newaxis, np.newaxis]
        # mr = sam.mrat[np.newaxis, :, np.newaxis, np.newaxis]

        # # Convert from observer-frame orbital freq, to rest-frame orbital freq
        # sa = utils.kepler_sepa_from_freq(mt, frst_orb)
        # print(f"{utils.stats(sa/PC)=}")
        # mt, mr, sa, norm = np.broadcast_arrays(mt, mr, sa, hard._norm[:, :, np.newaxis, np.newaxis])
        # # hardening rate, negative values, units of [cm/sec]
        # dadt = hard.dadt(mt, mr, sa, norm=norm)
        # print(f"{utils.stats(dadt*YR/PC)=}")
        # # Calculate `tau = dt/dlnf_r = f_r / (df_r/dt)`
        # # dfdt is positive (increasing frequency)
        # dfdt, frst_orb = utils.dfdt_from_dadt(dadt, sa, frst_orb=frst_orb)
        # tau = frst_orb / dfdt

        # # (M, Q, Z, X) units: [1/s] i.e. number per second
        # dnum = dens[..., np.newaxis] * cosmo_fact * tau
        # dnum[~coal] = 0.0

        # if details:
        #     tau[~coal] = 0.0
        #     dadt[~coal] = 0.0
        #     sa[~coal] = 0.0
        #     cosmo_fact[~coal] = 0.0
        #     # (M, Q, X)  ==>  (M, Q, Z, X)
        #     dets = dict(tau=tau, cosmo_fact=cosmo_fact, dadt=dadt, fobs=fobs_orb, sepa=sa)
        #     return edges, dnum, redz_final, dets

        # sam._redz_final = redz_final

        # return edges, dnum, redz_final

In [None]:
fobs_orb = vals['fobs_orb_cents']
redz_final = vals['redz_final']
# Store the result under a new name: assigning it to `redz_allN` would replace the function
# with an ndarray, so re-running this cell would fail ("'numpy.ndarray' object is not callable").
redz_final_allN = redz_allN(sam, hard, fobs_orb)

# plt.scatter(np.arange(redz_final.size), redz_final.flatten())

In [None]:
# Distribution of strain amplitudes: initial vs. final redshifts
fig, ax = plot.figax(ylabel='N', xlabel='$h_s$')
ax.hist(np.ravel(hs_init), bins=500, histtype='step', ls='solid', label='hs_init')
ax.hist(np.ravel(hs_final), bins=500, histtype='step', ls='solid', label='hs_final')
ax.legend()

In [None]:
fobs_orb_cents = vals['fobs_orb_cents']
# Direct call to the cython routine (the same call `sam_model` makes internally)
redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
    fobs_orb_cents,
    sam,
    hard,
    holo.cosmo,
)