# Test Binary Evolution

In [None]:
print("Hello")

import numpy as np
import holodeck as holo
from datetime import datetime
import holodeck.sam_cython
from holodeck.constants import YR, GYR
from holodeck import utils, cosmo

# --- Run configuration ------------------------------------------------------
SHAPE = 20              # grid `shape` handed to Semi_Analytic_Model
DEF_PTA_DUR = 16.03     # [yrs]
DEF_NUM_FBINS = 7       # number of observed GW-frequency bins kept

# Observed GW-frequency bin edges from Nyquist sampling over the PTA duration
# (0.1 yr cadence), truncated to the first DEF_NUM_FBINS bins (= +1 edges).
fobs_edges = utils.nyquist_freqs_edges(DEF_PTA_DUR*YR, cad=0.1*YR)[:DEF_NUM_FBINS + 1]
fobs_cents = utils.midpoints(fobs_edges)
# Orbital-frequency bin centers, taken here as half of the GW-frequency centers.
fobs_orb_cents = fobs_cents / 2.0

# --- Semi-Analytic Model components -----------------------------------------
gsmf = holo.sam.GSMF_Schechter()               # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()                 # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()                 # Galaxy Merger Time           (GMT)
mmbulge = holo.relations.MMBulge_Standard()    # M-MBulge Relation            (MMB)

sam = holo.sam.Semi_Analytic_Model(
    gsmf=gsmf,
    gpf=gpf,
    gmt=gmt,
    mmbulge=mmbulge,
    shape=SHAPE,
    ZERO_GMT_STALLED_SYSTEMS=True,
)

# Hardening model built from the SAM with a 2 Gyr time parameter.
# Previously tried alternative: hard = holo.hardening.Hard_GW
hard = holo.hardening.Fixed_Time.from_sam(sam, 2 * GYR)

# --- Cython binary evolution, bracketed by wall-clock timestamps ------------
print("START", datetime.now())
redz_final, diff_num = holodeck.sam_cython.dynamic_binary_number(
    fobs_orb_cents, sam, hard, cosmo, nsteps=200,
)
print("END", datetime.now())

In [None]:
# Reference result from the SAM's own (non-Cython) implementation.
# NOTE(review): the next cell reads `sam._redz_final`, presumably set by this
# call — confirm.
edges, dnum = sam.dynamic_binary_number(
    hard,
    fobs_orb_cents,
    zero_stalled=True,
    zero_coalesced=True,
)

In [None]:
# Compare final redshifts: Cython result (`redz_final`) vs SAM result (`old_redz`).
old_redz = sam._redz_final
sel1 = redz_final > 0.0     # positive in the Cython result
sel2 = old_redz > 0.0       # positive in the SAM result
sel = sel1 & sel2           # positive in both

for values, mask in ((redz_final, sel1), (old_redz, sel2)):
    print(utils.stats(values[mask]))
for mask in (sel1, sel2, sel):
    print(utils.frac_str(mask))

# Ratio only where both are positive, so the division is well-defined.
print(utils.stats(redz_final[sel] / old_redz[sel]))

In [None]:
# Display (old, new) final redshifts for the last bin of the first two axes,
# all entries of the third axis, and the first entry of the fourth axis.
tuple(zz[-1, -1, :, 0] for zz in (old_redz, redz_final))

In [None]:
# Compare binary numbers: Cython result (`diff_num`) vs SAM result (`dnum`).
for counts in (dnum, diff_num):
    print(utils.stats(counts))

sel1 = diff_num > 0.0   # positive in the Cython result
sel2 = dnum > 0.0       # positive in the SAM result
sel = sel1 & sel2       # positive in both
for mask in (sel1, sel2, sel):
    print(utils.frac_str(mask))

# Ratio only where both are positive, so the division is well-defined.
print(utils.stats(diff_num[sel] / dnum[sel]))

In [None]:
# Display the orbital-frequency bin centers passed to both binary-number calls.
fobs_orb_cents

In [None]:
# `plt` was used below but never imported anywhere in the notebook (NameError).
import matplotlib.pyplot as plt

# Entries positive in exactly one of the two compared results (XOR of masks).
# NOTE(review): sel1/sel2 come from whichever comparison cell ran last
# (normally the binary-number comparison) — confirm execution order.
bads = sel1 ^ sel2
# bads = (sel1 & ~sel2)    # positive only in the sel1 result
# bads = (~sel1 & sel2)    # positive only in the sel2 result
print(utils.frac_str(bads))
# Single-argument np.where -> tuple with one index array per dimension.
bads = np.where(bads)

# One row per dimension; squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(figsize=[10, 10], nrows=len(bads), squeeze=False)
for ii, (ax, bb) in enumerate(zip(axes[:, 0], bads)):
    ax.set(xscale='log', yscale='log')
    # Histogram the edge values where mismatches occur, binned on those same edges.
    ax.hist(edges[ii][bb], bins=edges[ii], density=False)

plt.show()

In [None]:
# `MPC` was referenced but never imported (NameError); pull it from the same
# constants module that provides YR and GYR.
from holodeck.constants import MPC

# Display the cosmology's `_grid_dcom` grid divided by MPC
# (presumably comoving distances converted to Mpc — confirm).
cosmo._grid_dcom / MPC