
Commit

Use lookback time instead of age of the universe when interpolating, for much higher accuracy.

Add tests for consistent dynamic_binary_number calculation.
lzkelley committed Apr 13, 2024
1 parent 20b3c64 commit ce0d88b
Showing 5 changed files with 392 additions and 95 deletions.
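
The conceptual change: elapsed binary-evolution time is now subtracted from the lookback time of the starting redshift and converted back to redshift, rather than being added to the age of the universe (as `utils.redz_after` appears to do in the old code path below). A minimal, self-contained sketch of the two mathematically equivalent formulations, using astropy's Planck15 in place of holodeck's `cosmo` helpers; the numbers are illustrative only:

import astropy.units as u
from astropy.cosmology import Planck15, z_at_value

z_init = 0.5                    # redshift of galaxy merger (illustrative)
dt = 200.0 * u.Myr              # time spent hardening to some separation (illustrative)

# age-of-the-universe formulation (old behavior): t_age(z_final) = t_age(z_init) + dt
z_age = z_at_value(Planck15.age, Planck15.age(z_init) + dt)

# lookback-time formulation (this commit): t_lbk(z_final) = t_lbk(z_init) - dt
z_lbk = z_at_value(Planck15.lookback_time, Planck15.lookback_time(z_init) - dt)

# identical in exact arithmetic, since t_age(z) = t_age(0) - t_lbk(z); the accuracy
# difference arises from which quantity gets tabulated and interpolated on a grid
print(z_age, z_lbk)
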
5 changes: 4 additions & 1 deletion holodeck/librarian/param_spaces.py
@@ -158,7 +158,7 @@ class _PS_Astro_Strong(_Param_Space):
"""

__version__ = "0.1"
__version__ = "0.2"

DEFAULTS = dict(
# Hardening model (phenom 2PL)
@@ -199,6 +199,9 @@ class _PS_Astro_Strong(_Param_Space):
mmb_mamp=0.49e9, # 0.49e9 + 0.06 - 0.05 [Msol]
mmb_plaw=1.17, # 1.17 ± 0.08
mmb_scatter_dex=0.28, # no uncertainties given
# bulge fraction
bf_sigmoid_lo=0.4,
bf_sigmoid_hi=0.8,
)

@classmethod
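
The two new `bf_sigmoid_*` defaults bound a sigmoid bulge fraction. Purely as an illustration of such a bounded sigmoid (the pivot mass and width below are hypothetical; the actual functional form used by holodeck is not shown in this diff):

import numpy as np

def bulge_frac_sigmoid(mstar, lo=0.4, hi=0.8, mpivot=1.0e11, width_dex=0.5):
    """Bulge fraction rising smoothly from `lo` to `hi` with stellar mass [Msol] (hypothetical form)."""
    xx = (np.log10(mstar) - np.log10(mpivot)) / width_dex
    return lo + (hi - lo) / (1.0 + np.exp(-xx))

print(bulge_frac_sigmoid(np.array([1.0e10, 1.0e11, 1.0e12])))   # ~[0.45, 0.60, 0.75]
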
97 changes: 58 additions & 39 deletions holodeck/sams/sam.py
@@ -466,12 +466,18 @@ def dynamic_binary_number_at_fobs(self, hard, fobs_orb, use_cython=True, **kwarg
return grid, dnum, redz_final

def _dynamic_binary_number_at_fobs_consistent(self, hard, fobs_orb, steps=200, details=False):
"""Get correct redshifts for full binary-number calculation.
r"""Calculate the differential number of binaries in at each grid point, at each frequency.
Slower but more correct than old `dynamic_binary_number`.
Same as new cython implementation `sam_cyutils.dynamic_binary_number_at_fobs`, which is
more than 10x faster.
LZK 2023-05-11
See :meth:`dynamic_binary_number_at_fobs` for general information.
This is the python implementation for self-consistent binary evolution (hardening), i.e. for
evolution models that are able to evolve binaries from galaxy merger until the target
frequencies.
This function should produce the same results as the new cython implementation in:
:func:`holodeck.sams.sam_cyutils.dynamic_binary_number_at_fobs`, which is more than 10x
faster. This python implementation is maintained for diagnostic purposes, and for
functionality when cython is not available.
# BUG doesn't work for Fixed_Time_2PL
@@ -480,79 +486,92 @@ def _dynamic_binary_number_at_fobs_consistent(self, hard, fobs_orb, steps=200, details=False):
edges = self.edges + [fobs_orb, ]

# shape: (M, Q, Z)
dens = self.static_binary_density # d3n/[dlog10(M) dq dz] units: [Mpc^-3]
dens = self.static_binary_density # d3n/[dlog10(M) dq dz] units: [cMpc^-3]

# ---- Choose the binary separations over which to integrate the binary evolution.

# Start at large separations (galaxy merger) and evolve to small separations (coalescence).

# start from the hardening model's initial separation
rmax = hard._sepa_init
# (M,) end at the ISCO
# end at the ISCO
# (M,)
rmin = utils.rad_isco(self.mtot)
# Choose steps for each binary, log-spaced between rmin and rmax
extr = np.log10([rmax * np.ones_like(rmin), rmin]) # (2,M,)
rads = np.linspace(0.0, 1.0, steps+1)[np.newaxis, :] # (1,X)
# (M, S) = (M,1) * (1,S)
rads = np.linspace(0.0, 1.0, steps+1)[np.newaxis, :] # (1,S)
# (M, S) <== (M,1) * (1,S)
rads = extr[0][:, np.newaxis] + (extr[1] - extr[0])[:, np.newaxis] * rads
rads = 10.0 ** rads

# ---- Calculate binary hardening rate (da/dt) at each separation, for each grid point

# broadcast arrays to a consistent shape
# (M, Q, S)
mt, mr, rads, norm = np.broadcast_arrays(
self.mtot[:, np.newaxis, np.newaxis],
self.mrat[np.newaxis, :, np.newaxis],
rads[:, np.newaxis, :],
hard._norm[:, :, np.newaxis],
)
# calculate hardening rate (negative values, in units of [cm/s])
dadt_evo = hard.dadt(mt, mr, rads, norm=norm)

# ---- Integrate evolution
# to find times and redshifts at which binaries reach each separation

# (M, Q, S-1)
# Integrate (inverse) hardening rates to calculate total lifetime to each separation
times_evo = -utils.trapz_loglog(-1.0 / dadt_evo, rads, axis=-1, cumsum=True)
# ~~~~ RIEMANN integration ~~~~
# times_evo = 2.0 * np.diff(rads, axis=-1) / (dadt_evo[..., 1:] + dadt_evo[..., :-1])
# times_evo = np.cumsum(times_evo, axis=-1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# times_evo = -utils.trapz_loglog(-1.0 / dadt_evo, rads, axis=-1, cumsum=True)

times_evo = 2.0 * np.diff(rads, axis=-1) / (dadt_evo[..., 1:] + dadt_evo[..., :-1])
# for ss in range(steps):
# print(f"py {ss:03d} : {rads[8, 0, ss]:.6e} ==> {rads[8, 0, ss+1]:.6e} == {times_evo[8, 0, ss]:.6e}")
times_evo = np.cumsum(times_evo, axis=-1)
# add array of zero time-delays at starting point (i.e. before the first step)
# with same shape as a slice at a single step
zpad = np.zeros_like(times_evo[..., 0])
times_evo = np.concatenate([zpad[..., np.newaxis], times_evo], axis=-1)
# print(f"{times_evo[8, 0, :]=}")

# Combine the binary-evolution time, with the galaxy-merger time
# (M, Q, Z, S-1)
rz = self.redz[np.newaxis, np.newaxis, :, np.newaxis]
times_tot = times_evo[:, :, np.newaxis, :]
# ---- Convert from time to redshift

# initial redshift (of galaxy merger)
rz = self.redz[np.newaxis, np.newaxis, :, np.newaxis] # (1, 1, Z, 1)

tlbk_init = cosmo.z_to_tlbk(rz)
tlbk = tlbk_init - times_evo[:, :, np.newaxis, :]
# Combine the binary-evolution time, with the galaxy-merger time (if it is defined)
if self._gmt_time is not None:
times_tot += self._gmt_time[:, :, :, np.newaxis]
tlbk -= self._gmt_time[:, :, :, np.newaxis]

redz_evo = utils.redz_after(times_tot, redz=rz)
# for ss in range(steps):
# print(f"py {ss:03d} : t={times_evo[8, 0, ss]:.6e} z={redz_evo[8, 0, 11, ss]:.6e}")
# (M, Q, Z, S)
redz_evo = cosmo.tlbk_to_z(tlbk)

#! age of the universe version of calculation is MUCH less accurate !#
# Use age-of-the-universe
# times_tot = times_evo[:, :, np.newaxis, :]
# # Combine the binary-evolution time, with the galaxy-merger time (if it is defined)
# if self._gmt_time is not None:
# times_tot += self._gmt_time[:, :, :, np.newaxis]
# redz_evo = utils.redz_after(times_tot, redz=rz)
#! ---------------------------------------------------------------- !#

# ---- interpolate to target frequencies

# convert from separations to rest-frame orbital frequencies
# (M, Q, S)
frst_orb_evo = utils.kepler_freq_from_sepa(mt, rads)
# (M, Q, Z, S)
fobs_orb_evo = frst_orb_evo[:, :, np.newaxis, :] / (1.0 + redz_evo)

# ---- interpolate to target frequencies
# `ndinterp` interpolates over 1th dimension

# print(f"{frst_orb_evo[8, 0, :]=}")
# print(f"{fobs_orb=}")
# print(f"{fobs_orb_evo[8, 0, 11, :]=}")
# print(f"{redz_evo[8, 0, 11, :]=}")

# (M, Q, Z, S-1) ==> (M*Q*Z, S-1)
# (M, Q, Z, S) ==> (M*Q*Z, S)
fobs_orb_evo, redz_evo = [tt.reshape(-1, steps+1) for tt in [fobs_orb_evo, redz_evo]]
# `ndinterp` interpolates over 1th dimension
# (M*Q*Z, X)
redz_final = utils.ndinterp(fobs_orb, fobs_orb_evo, redz_evo, xlog=False, ylog=False)
redz_final = utils.ndinterp(fobs_orb, fobs_orb_evo, redz_evo, xlog=True, ylog=False)

# (M*Q*Z, X) ===> (M, Q, Z, X)
# (M, Q, Z, X) <=== (M*Q*Z, X)
redz_final = redz_final.reshape(self.shape + (fobs_orb.size,))
# _test_times = _test_times.reshape(self.shape + (fobs_orb.size,))

# print(f"{redz_final[8, 0, 11, :]=}")
# print(f"{_test_times[8, 0, 11, :]=}")

coal = (redz_final > 0.0)
frst_orb = fobs_orb * (1.0 + redz_final)
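
A condensed, self-contained sketch of the python pipeline above for a single binary (instead of the full (M, Q, Z) grid), assuming purely GW-driven hardening in place of `hard.dadt` and astropy's Planck15 in place of holodeck's `cosmo`/`utils` helpers; all numbers are illustrative:

import numpy as np
import astropy.units as u
from astropy.constants import G, c
from astropy.cosmology import Planck15, z_at_value

m1, m2 = 1.0e9 * u.Msun, 3.0e8 * u.Msun
mtot = m1 + m2
z_init = 0.5                                      # redshift of galaxy merger

# separation grid, log-spaced from an initial separation down to the ISCO (cf. `rads`)
r_init = (0.1 * u.pc).to_value(u.cm)
r_isco = (6.0 * G * mtot / c**2).to_value(u.cm)
rads = np.geomspace(r_init, r_isco, 201) * u.cm

# GW-driven hardening rate da/dt (Peters 1964, circular orbits); negative values
dadt = -(64.0 / 5.0) * G**3 * m1 * m2 * mtot / (c**5 * rads**3)

# time to reach each separation, averaging da/dt over each bin (cf. `times_evo`)
dt = (2.0 * np.diff(rads) / (dadt[1:] + dadt[:-1])).to_value(u.s)
times = np.concatenate([[0.0], np.cumsum(dt)]) * u.s

# lookback time at each separation --> redshift (the change made in this commit)
tlbk = Planck15.lookback_time(z_init) - times
redz = np.array([float(z_at_value(Planck15.lookback_time, tt)) if tt > 0 else 0.0 for tt in tlbk])

# rest-frame orbital frequency (Kepler) and observer-frame orbital frequency
frst_orb = np.sqrt(G * mtot / rads**3).to_value(u.Hz) / (2.0 * np.pi)
fobs_orb = frst_orb / (1.0 + redz)

# interpolate the redshift at which a target observed orbital frequency is reached
fobs_target = 1.0 / (10.0 * u.yr).to_value(u.s)   # ~3.2 nHz
print(np.interp(fobs_target, fobs_orb, redz))
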
29 changes: 26 additions & 3 deletions holodeck/sams/sam_cyutils.pyx
@@ -144,9 +144,10 @@ def integrate_differential_number_3dx1d(edges, dnum):

# each edge should have the same length as the corresponding dimension of `dnum`
shape = [len(ee) for ee in edges]
err = f"Shape of edges={shape} does not match dnum={np.shape(dnum)}"
# except the last edge (freq), where `dnum` should be 1-shorter
shape[-1] -= 1
assert np.shape(dnum) == tuple(shape)
assert np.shape(dnum) == tuple(shape), err
# the number will be shaped as one-less the size of each dimension of `dnum`
new_shape = [sh-1 for sh in dnum.shape]
# except for the last dimension (freq) which is the same shape
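
For reference, the shape convention asserted above, with illustrative sizes (`F` frequency bin centers against `F + 1` frequency edges):

import numpy as np

M, Q, Z, F = 10, 11, 12, 9
edges = [np.empty(M), np.empty(Q), np.empty(Z), np.empty(F + 1)]   # last edge has F+1 entries
dnum = np.empty((M, Q, Z, F))              # one shorter than its edge on the frequency axis only
num = np.empty((M - 1, Q - 1, Z - 1, F))   # integrated: one less per axis, except frequency
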
@@ -643,7 +644,11 @@ cdef int _dynamic_binary_number_at_fobs_2pwl(
)

# Find time to move from left- to right- edges: dt = da / (da/dt)
# average da/dt on the left- and right- edges of the bin (i.e. trapezoid rule)
dt = 2.0 * (sepa_right - sepa_left) / (dadt_left + dadt_right)
# if ii == 8 and jj == 0:
# printf("cy %03d : %.2e ==> %.2e == %.2e\n", step, sepa_left, sepa_right, dt)

time_evo += dt

# ---- Iterate over starting redshift bins
@@ -657,8 +662,8 @@

# if we pass the age of the universe, this binary has stalled, no further redshifts will work
# NOTE: if `gmt_time` decreases faster than redshift bins increase the universe age,
# then systems in later `redz` bins may no longer stall, so we still need to calculate them
# i.e. we can NOT use a `break` statement here
# then systems in later `redz` bins may no longer stall, so we still need to calculate them.
# i.e. we can NOT use a `break` statement here, must use `continue` statement.
if time_left > age_universe:
continue

@@ -684,6 +689,9 @@
if redz_right < 0.0:
redz_right = 0.0

# if ii == 8 and jj == 0 and kk == 11:
# printf("cy %03d : t=%.2e z=%.2e\n", step, time_right, redz_right)

# convert to frequencies
fobs_orb_left = frst_orb_left / (1.0 + redz_left)
fobs_orb_right = frst_orb_right / (1.0 + redz_right)
@@ -724,6 +732,21 @@ cdef int _dynamic_binary_number_at_fobs_2pwl(
# get comoving distance
dcom = interp_at_index(new_interp_idx, new_time, tage_interp_grid, dcom_interp_grid)

# if (ii == 0) and (jj == 0) and (kk == 0):
# printf("cy f=%03d (step=%03d)\n", ff, step)
# printf(
# "fl=%.6e, f=%.6e, fr=%.6e ==> tl=%.6e, t=%.6e, tr=%.6e\n",
# fobs_orb_left, ftarget, fobs_orb_right,
# time_left, new_time, time_right
# )
# printf(
# "interp (%d) time: %.6e, %.6e, %.6e ==> z: %.6e, %.6e, %.6e\n",
# new_interp_idx,
# tage_interp_grid[new_interp_idx], new_time, tage_interp_grid[new_interp_idx+1],
# redz_interp_grid[new_interp_idx], new_redz, redz_interp_grid[new_interp_idx+1],
# )
# printf("======> z=%.6e\n", new_redz)

# Store redshift
redz_final[ii, jj, kk, ff] = new_redz

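
A toy version of the stall check above, showing why `continue` rather than `break` is needed when the total delay is not monotonic across starting-redshift bins; the numbers are made up for illustration (in the real code `time_left` combines the galaxy-merger time and the binary-evolution time):

age_universe = 13.0                                # Gyr (illustrative)
time_left_per_redz_bin = [14.0, 9.0, 15.0, 8.0]    # Gyr, non-monotonic across `redz` bins

for kk, time_left in enumerate(time_left_per_redz_bin):
    if time_left > age_universe:
        continue     # this bin stalls; a `break` would also skip bins 1 and 3, which do not stall
    print(f"redz bin {kk}: binary reaches the target frequencies")
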
93 changes: 91 additions & 2 deletions holodeck/sams/tests/test_sam.py
@@ -128,9 +128,8 @@ def test_dbn_gw_only():
NUM_FREQS = 9
fobs_gw_cents, fobs_gw_edges = holo.utils.pta_freqs(PTA_DUR, NUM_FREQS)
fobs_orb_cents = fobs_gw_cents / 2.0
# fobs_orb_edges = fobs_gw_edges / 2.0

# (1)
# (1) make sure it runs

grid_py, dnum_py, redz_final_py = sam.dynamic_binary_number_at_fobs(hard_gw, fobs_orb_cents, use_cython=False)
grid_cy, dnum_cy, redz_final_cy = sam.dynamic_binary_number_at_fobs(hard_gw, fobs_orb_cents, use_cython=True)
@@ -157,6 +156,96 @@ def test_dbn_gw_only():
print(f"{utils.stats(redz_final_cy[bads])=}")
assert not np.any(bads), f"Found {utils.frac_str(bads)} inconsistent `redz_final` b/t python and cython calcs!"

return


def test_dbn_phenom():
"""Test the dynamic_binary_number method using Phenomenological evolution.
(1) runs without error
(2) dnum values are consistent between cython and python
(3) redz_final values are consistent between cython and python
"""

shape = (10, 11, 12)
sam = holo.sams.Semi_Analytic_Model(shape=shape)
TIME = 1.0e9 * YR
hard_phenom = holo.hardening.Fixed_Time_2PL_SAM(sam, TIME, num_steps=300)

PTA_DUR = 20.0 * YR
NUM_FREQS = 9
fobs_gw_cents, fobs_gw_edges = holo.utils.pta_freqs(PTA_DUR, NUM_FREQS)
fobs_orb_cents = fobs_gw_cents / 2.0
fobs_orb_edges = fobs_gw_edges / 2.0

# we'll allow differences at very low redshifts, where numerical differences become significant
ALLOW_BADS_BELOW_REDZ = 1.0e-2

# (1) make sure it runs

grid_py, dnum_py, redz_final_py = sam.dynamic_binary_number_at_fobs(hard_phenom, fobs_orb_cents, use_cython=False)
grid_cy, dnum_cy, redz_final_cy = sam.dynamic_binary_number_at_fobs(hard_phenom, fobs_orb_cents, use_cython=True)
edges_py = grid_py[:-1] + [fobs_orb_edges,]
edges_cy = grid_cy[:-1] + [fobs_orb_edges,]

redz_not_ignore = (redz_final_py > ALLOW_BADS_BELOW_REDZ) | (redz_final_cy > ALLOW_BADS_BELOW_REDZ)

# (2) the same dnum values are zero

zeros_py = (dnum_py == 0.0)
zeros_cy = (dnum_cy == 0.0)

# ignore mismatch at low-redshifts
bads = (zeros_py != zeros_cy) & redz_not_ignore
if np.any(bads):
print(f"{utils.frac_str(bads)=}")
print(f"{utils.stats(dnum_py[bads])=}")
print(f"{utils.stats(dnum_cy[bads])=}")
assert not np.any(bads), "Zero points in `dnum` do not match between python and cython!"

# (3) dnum consistent between cython- and python- versions of calculation

# ignore mismatch at low-redshifts
bads = ~np.isclose(dnum_py, dnum_cy, rtol=1e-1) & redz_not_ignore
if np.any(bads):
errs = (dnum_py - dnum_cy) / dnum_cy
print(f"{utils.frac_str(bads)=}")
print(f"{utils.stats(errs)=}")
print(f"{utils.stats(errs[bads])=}")
print(f"{dnum_py[bads][:10]=}")
print(f"{dnum_cy[bads][:10]=}")
print(f"{errs[bads][:10]=}")
print(f"{utils.stats(dnum_py[bads])=}")
print(f"{utils.stats(dnum_cy[bads])=}")
assert not np.any(bads), f"Found {utils.frac_str(bads)} inconsistent `dnum` b/t python and cython calcs!"

# (4) redz_final consistent between cython- and python- versions of calculation

# ignore mismatch at low-redshifts
bads = (~np.isclose(redz_final_py, redz_final_cy, rtol=1e-2)) & redz_not_ignore
if np.any(bads):
print(f"{utils.frac_str(bads)=}")
print(f"{redz_final_py[bads][:10]=}")
print(f"{redz_final_cy[bads][:10]=}")
print(f"{utils.stats(redz_final_py[bads])=}")
print(f"{utils.stats(redz_final_cy[bads])=}")
assert not np.any(bads), f"Found {utils.frac_str(bads)} inconsistent `redz_final` b/t python and cython calcs!"

# (5) make sure that ALL numbers of binaries are consistent

num_py = holo.sams.sam_cyutils.integrate_differential_number_3dx1d(edges_py, dnum_py)
num_cy = holo.sams.sam_cyutils.integrate_differential_number_3dx1d(edges_cy, dnum_cy)

# Make sure that `atol` is also set to a reasonable value
bads = ~np.isclose(num_py, num_cy, rtol=1e-2, atol=1.0e-1)
if np.any(bads):
print(f"{utils.frac_str(bads)=}")
print(f"{num_py[bads][:10]=}")
print(f"{num_cy[bads][:10]=}")
print(f"{utils.stats(num_py[bads])=}")
print(f"{utils.stats(num_cy[bads])=}")
assert not np.any(bads), f"Found {utils.frac_str(bads)} inconsistent `num` b/t python and cython calcs!"

return
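
A quick note on the `atol` choice above: with many near-empty grid cells, a relative tolerance alone flags negligible absolute differences (illustrative values):

import numpy as np

num_py = np.array([1.0e-8, 2.5, 0.0])       # e.g. binary counts from the python path
num_cy = np.array([3.0e-8, 2.51, 1.0e-12])  # e.g. binary counts from the cython path

print(np.isclose(num_py, num_cy, rtol=1e-2, atol=0.0))   # [False  True False]
print(np.isclose(num_py, num_cy, rtol=1e-2, atol=1e-1))  # [ True  True  True]
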
