# SAM - LISA Detection Rates 

In [None]:
from pathlib import Path

import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt

import legwork as lw

import holodeck as holo
import holodeck.librarian
from holodeck import cosmo, utils, plot
from holodeck.constants import GYR, YR, PC, MSOL

In [None]:
# LISA mission duration [yr]; sets the observing time of the sensitivity curve and the plot label
LISA_DUR_YR = 5.0

Get the LISA sensitivity curve from the [`legwork` package](https://github.com/TeamLEGWORK/LEGWORK).

In [None]:
# observing time and frequency grid, both as astropy quantities
lisa_mission_dur = LISA_DUR_YR * u.yr
fobs = np.logspace(-7, 0, 1000) * u.Hz  # 1000 log-spaced frequencies in [1e-7, 1] Hz

# --- LISA sensitivity: convert the PSD into a characteristic strain, h_c = sqrt(f * S_n(f))
lisa_psd = lw.psd.power_spectral_density(f=fobs, t_obs=lisa_mission_dur)
lisa_hc = np.sqrt(fobs * lisa_psd)

# --- plot the sensitivity curve
fig, ax = plt.subplots()
ax.loglog(fobs, lisa_hc)
ax.set(xlabel='GW frequency [Hz]', ylabel='Characteristic Strain')
plt.show()

In [None]:
def is_lisa_detectable(ff, hc, fl, hl, snr=3.0):
    """Determine which binaries (ISCO frequencies and strains) are detectable (above the LISA curve).

    Binaries whose frequencies fall outside of the range spanned by `fl` are automatically
    treated as non-detectable.

    Arguments
    ---------
    ff : array_like of float
        Frequencies of binaries (at ISCO).  Units must match `fl`; typically [Hz].
    hc : array_like of float
        Characteristic strains of binaries (at ISCO).  Must broadcast against `ff`.
    fl : array_like of float
        Frequencies of the LISA sensitivity curve.  Units must match `ff`; typically [Hz].
    hl : array_like of float
        Characteristic strains of the LISA sensitivity curve.
    snr : float, optional
        Multiplicative threshold: a binary is considered detectable when its characteristic
        strain exceeds ``snr`` times the sensitivity curve at the binary's frequency.

    Returns
    -------
    sel : ndarray of bool
        Whether or not the corresponding binary is detectable.
        Matches the (broadcast) shape of `ff` and `hc`.

    Notes
    -----
    `astropy` quantities (e.g. the legwork curve built with `u.Hz`) are reduced to their bare
    numerical values, so the caller is responsible for passing matching frequency units.

    """
    # strip any astropy units so that log-interpolation and comparisons operate on plain floats
    ff = np.asarray(ff)
    hc = np.asarray(hc)
    fl = np.asarray(fl)
    hl = np.asarray(hl)

    # use logarithmic interpolation to find the LISA sensitivity curve at the binary frequencies
    # if the binary frequencies are outside of the LISA band, `NaN` values are returned
    sens_at_ff = utils.interp(ff, fl, hl)
    # select binaries above the (scaled) sensitivity curve; comparisons against `NaN` values
    # (i.e. outside of the band) evaluate to False
    sel = (hc > sens_at_ff*snr)
    return sel

## SAM LISA Detection Rates

### Build SAM

In [None]:
# --- components of the semi-analytic model (SAM)
mmbulge = holo.host_relations.MMBulge_KH2013()       # M-Mbulge relation (Kormendy & Ho 2013)
gsmf = holo.sams.components.GSMF_Double_Schechter()  # galaxy stellar-mass function
gmr = holo.sams.components.GMR_Illustris()           # galaxy merger rate

# assemble the SAM from its components
sam = holo.sams.sam.Semi_Analytic_Model(mmbulge=mmbulge, gsmf=gsmf, gmr=gmr)

In [None]:
# ---- Number density of binary mergers
# ``d^3 n / [dlog10M dq dz]`` in units of [Mpc^-3]
# evaluated on the grid of `sam.edges` (total mass, mass ratio, redshift); it is reshaped/indexed as
# (M, Q, Z) in the rate calculation below
ndens = sam.static_binary_density

In [None]:
# SAM grid edges: total mass [grams], mass ratio, and redshift
mtot, mrat, redz = sam.edges
# 3D (M, Q, Z) versions of each grid coordinate; 'ij' indexing keeps that axis order
mt, mr, rz = np.meshgrid(mtot, mrat, redz, indexing='ij')
# comoving distance to every grid point
dc = cosmo.z_to_dcom(rz)

# component masses and chirp mass at every grid point
m1, m2 = utils.m1m2_from_mtmr(mt, mr)
mc = utils.chirp_mass_mtmr(mt, mr)

# Place all binaries at the ISCO, find the corresponding frequency, strain, and characteristic strain
risco = utils.rad_isco(mt)
# NOTE(review): this Kepler frequency is passed below as `frst_orb`, i.e. it is an *orbital*
# frequency.  For circular binaries the GW frequency is twice the orbital one, yet `fisco_obs` is
# later compared against the LISA curve, which is plotted versus *GW* frequency.  Confirm whether
# a factor of 2 is intended.
fisco_rst = utils.kepler_freq_from_sepa(mt, risco)
# redshift the rest-frame frequency into the observer frame
fisco_obs = fisco_rst / (1.0 + rz)
# GW strain amplitude of each source
hs = utils.gw_strain_source(mc, dc, fisco_rst)
# GW-driven hardening rate at the ISCO, converted into a frequency evolution rate
dadt = utils.gw_hardening_rate_dadt(m1, m2, risco)
dfdt, _ = utils.dfdt_from_dadt(dadt, risco, mtot=mt, frst_orb=fisco_rst)
print("hs = ", utils.stats(hs))

# number of cycles spent near frequency f:  N ~ f^2 / (df/dt)
ncycles = fisco_rst**2 / dfdt
print("ncycles = ", utils.stats(ncycles))

# characteristic strain:  h_c = sqrt(N) * h_s
hc = np.sqrt(ncycles) * hs
print("hc = ", utils.stats(hc))

### Compare Binaries to LISA Sensitivity Curve

In [None]:
fig, ax = plt.subplots()
ax.set(xlabel='Frequency [Hz]', ylabel='Characteristic Strain')

# --- plot the LISA sensitivity curve
lab = f"LISA ({LISA_DUR_YR:.1f} yr)"
ax.loglog(fobs, lisa_hc, label=lab)

# --- plot ISCO characteristic-strains
# color based on chirp-mass [Msol]
smap = plot.smap(mc/MSOL, log=True)
colors = smap.to_rgba(mc.flatten()/MSOL)
# find which grid points are detectable (above the LISA curve); `sel` is reused in later cells
ff = fisco_obs.flatten()
hh = hc.flatten()
sel = is_lisa_detectable(ff, hh, fobs, lisa_hc)
print(f"Fraction of detectable grid points: {utils.frac_str(sel)}")
# plot: non-detectable points faint and small, detectable points opaque and larger
ax.scatter(ff[~sel], hh[~sel], alpha=0.01, s=1, facecolor=colors[~sel], label='ISCO binaries')
ax.scatter(ff[sel], hh[sel], alpha=0.9, s=4, facecolor=colors[sel])

# raw string: '\o' is an invalid escape sequence in a regular string literal
fig.colorbar(smap, ax=ax, label=r'Chirp Mass $[M_\odot]$')
# attach the legend to the data axes explicitly (not whichever axes is "current")
ax.legend(markerscale=4.0)
plt.show()

### Calculate Rates

Above, we computed `ndens`, the differential number density of binaries in bins of total mass, mass ratio, and redshift:
$$\frac{d^3 n}{d\log_{10}M \, dq \, dz}.$$
The number-density is
$$n = \frac{dN}{dV_c}$$
for comoving volume $V_c$.

$$\frac{dN}{dt}
    = \int \frac{d^2 N}{dV_c dz} \frac{dz}{dt} \frac{d V_c}{dz} \frac{1}{1+z}dz 
    = \int \frac{d n}{dz} \frac{dz}{dt} \frac{d V_c}{dz} \frac{1}{1+z}dz,$$
where the factor of $1/(1+z)$ converts from rest-frame time (RHS) to observer-frame time (LHS).

In addition to redshift, we must also integrate `ndens` over total mass and mass ratio. The integrands in the above equations have no explicit $M$ or $q$ dependence, so those integrals can be performed independently of the redshift integral.

In [None]:
# Get cosmological factors
# (Z,) units of [1/sec]
dzdt = 1.0 / cosmo.dtdz(redz)
# `ndens` is in units of [Mpc^-3], so make sure dVc/dz matches: [Mpc^3]
dVcdz = cosmo.dVcdz(redz, cgs=False).to('Mpc3').value

# --- Use trapezoid rule to integrate over redshift (last dimension of `ndens`)
# (Z,) redshift-dependent part of the integrand: dz/dt * dVc/dz / (1+z)
integ = dzdt * dVcdz / (1.0 + redz)
# (M, Q, Z)
integ = ndens * integ
# multiply by boolean array of detectable elements (i.e. zero out non-detectable binaries)
# `sel` comes from the plotting cell above, where it was computed on the flattened grid
integ *= sel.reshape(ndens.shape)
# (Z-1,)
dz = np.diff(redz)
# perform 'integration', but don't sum over redshift bins
# (M, Q, Z-1)
rate = 0.5 * (integ[:, :, :-1] + integ[:, :, 1:]) * dz

# ---- Integrate over mass and mass-ratio (again trapezoid rule, keeping per-bin values)
# (M-1,)
dlogm = np.diff(np.log10(mtot))
# (Q-1,)
dq = np.diff(mrat)
# (M-1, Q, Z-1)
rate = 0.5 * (rate[:-1, :, :] + rate[1:, :, :]) * dlogm[:, np.newaxis, np.newaxis]
# (M-1, Q-1, Z-1)  per-bin detection rate in [1/sec]
rate = 0.5 * (rate[:, :-1, :] + rate[:, 1:, :]) * dq[np.newaxis, :, np.newaxis]

In [None]:
# total detection rate: sum over all (M, Q, Z) bins, converted from [1/sec] to [1/yr]
rate_per_yr = rate.sum() * YR
print(f"Rate of detections is {rate_per_yr:.2e} [1/yr]")

In [None]:
fig, axes = plt.subplots(figsize=[8, 3], ncols=3, sharey=True)
plt.subplots_adjust(wspace=0.02)

units = [MSOL, 1.0, 1.0]
# cumulative direction per parameter:
#   -1 => rate of binaries *above* each x value (sum from the high end)
#   +1 => rate of binaries *below* each x value (sum from the low end)
direct = [-1, -1, +1]
# raw strings: '\o' and '\m' are invalid escape sequences in regular string literals
labels = [r'total mass $[M_\odot]$', 'mass ratio', 'redshift']
ylab = r'detection rate $[1/\mathrm{yr}]$'
for ii, ax in enumerate(axes):
    ax.grid(True, alpha=0.15)
    ax.set(xscale='log', yscale='log', xlabel=labels[ii])
    # marginalize over the other two parameters: (N-1,) detection rate per bin in [1/yr]
    rr = np.moveaxis(rate, ii, 0)
    rr = np.sum(rr, axis=(1, 2)) * YR

    # (N,) bin edges; bin `j` spans edges [j, j+1]
    xx = sam.edges[ii] / units[ii]

    if direct[ii] < 0:
        # yy[j] = rate in bins j..end, i.e. everything above the *lower* edge of bin j
        yy = np.cumsum(rr[::-1])[::-1]
        xx = xx[:-1]
    else:
        # yy[j] = rate in bins 0..j, i.e. everything below the *upper* edge of bin j
        yy = np.cumsum(rr)
        xx = xx[1:]
    ax.plot(xx, yy, lw=2.0)

axes[0].set(ylabel=ylab, ylim=[1e-1, 1e3])
plt.show()

# New Calculation

In [None]:
# SAM grid shape; `None` presumably selects the library's default grid -- confirm
SHAPE = None
# SHAPE = (31, 32, 33)
mmbulge = holo.host_relations.MMBulge_KH2013()
gsmf = holo.sams.components.GSMF_Double_Schechter()

# alternative merger model (galaxy merger rate from Illustris), kept for reference:
# gmr = holo.sams.components.GMR_Illustris()
# sam = holo.sams.sam.Semi_Analytic_Model(gsmf=gsmf, gmr=gmr, mmbulge=mmbulge, shape=SHAPE)

# power-law merger-time (GMT) and pair-fraction (GPF) components
gmt = holo.sams.components.GMT_Power_Law()
gpf = holo.sams.components.GPF_Power_Law()
sam = holo.sams.sam.Semi_Analytic_Model(gsmf=gsmf, gmt=gmt, gpf=gpf, mmbulge=mmbulge, shape=SHAPE)

# hardening model with a fixed binary lifetime of 2 Gyr
# NOTE: the next cell rebuilds `sam` and `hard` with fitted parameters, overriding these
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 2*GYR)

In [None]:
# SAM grid shape; `None` presumably selects the library's default grid -- confirm
SHAPE = None
# SHAPE = (31, 32, 33)

# Model parameters (presumably a maximum-likelihood point, per the `mlpars` name -- confirm source).
# 'hard_time' is in [Gyr] (it is multiplied by `GYR` below); '*_log10' values are log10 quantities.
mlpars = {
    'hard_time': 7.851539308157039,
    'gsmf_phi0': -1.8839740086540187,
    'gsmf_mchar0_log10': 11.235424795005963,
    'mmb_mamp_log10': 8.818001401400009,
    'mmb_scatter_dex': 0.2587876888563057,
    'hard_gamma_inner': -0.8609000063794954
}

# M-Mbulge relation with fitted amplitude and scatter
mmbulge = holo.host_relations.MMBulge_KH2013(
    mamp_log10=mlpars['mmb_mamp_log10'],
    scatter_dex=mlpars['mmb_scatter_dex']
)
# single-Schechter galaxy stellar-mass function with fitted normalization and characteristic mass
gsmf = holo.sams.GSMF_Schechter(
    phi0=mlpars['gsmf_phi0'],
    mchar0_log10=mlpars['gsmf_mchar0_log10'],
)
# alternative components, kept for reference:
# gsmf = holo.sams.components.GSMF_Double_Schechter()
# gmr = holo.sams.components.GMR_Illustris()
# sam = holo.sams.sam.Semi_Analytic_Model(gsmf=gsmf, gmr=gmr, mmbulge=mmbulge, shape=SHAPE)

gmt = holo.sams.components.GMT_Power_Law()
gpf = holo.sams.components.GPF_Power_Law()
sam = holo.sams.sam.Semi_Analytic_Model(gsmf=gsmf, gmt=gmt, gpf=gpf, mmbulge=mmbulge, shape=SHAPE)

# fixed-lifetime hardening model using the fitted lifetime and inner power-law index
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, mlpars['hard_time']*GYR, gamma_inner=mlpars['hard_gamma_inner'])

In [None]:
# Event rates of the SAM binaries under hardening model `hard`.  With `integrate=False` the rates are
# presumably left per grid bin (they are integrated in the next cell) -- confirm.
# NOTE: this rebinds the globals `rate` and `hc` computed in earlier cells.
redz_final, rate, fisco, hc = sam.rate_chirps(hard, integrate=False)

In [None]:
# Integrate the per-bin event rates and print the total rate, converted to [1/yr]
integ = sam._integrate_event_rate(rate)
print(integ.sum()*YR)

In [None]:
fig, axes = plt.subplots(figsize=[16, 7], ncols=3)
# raw string: '\o' is an invalid escape sequence in a regular string literal
labels = [r'total mass $[M_\odot]$', 'mass ratio', 'redshift (initial)']
# plot masses in [Msol] (edges are in grams), consistent with the cumulative-rate plot
units = [MSOL, 1.0, 1.0]
for ii, ax in enumerate(axes):
    # show the two dimensions that remain after summing over dimension `ii`
    jj = (ii+1) % 3
    kk = (ii+2) % 3
    ax.set(xscale='log', yscale='log', xlabel=labels[jj], ylabel=labels[kk])

    xx = sam.edges[jj] / units[jj]
    yy = sam.edges[kk] / units[kk]
    zz = np.sum(rate, axis=ii) * YR
    # mask empty bins: log10(0) = -inf would otherwise break the color normalization
    zz = np.log10(np.ma.masked_less_equal(zz, 0.0))

    # after summing, the remaining axes are ordered (min(jj, kk), max(jj, kk)); make them (jj, kk)
    if jj > kk:
        zz = zz.T

    # `pcolormesh` expects C with shape (len(y), len(x)), i.e. (kk, jj)
    pcm = ax.pcolormesh(xx, yy, zz.T, shading='auto')
    plt.colorbar(pcm, ax=ax, orientation='horizontal', label=r'$\log_{10}$(rate [1/yr])')


plt.show()

In [None]:
# Detectable rate using holodeck's own LISA model (rather than the legwork curve used above).
# NOTE(review): presumably `lisa(fisco, hc)` returns a per-bin detectability weight/mask that
# broadcasts against `rate` -- confirm against `holo.gravwaves.LISA`.
lisa = holo.gravwaves.LISA()
new_rate = lisa(fisco, hc) * rate
new_rate = sam._integrate_event_rate(new_rate)
# detectable rate vs. total rate, both converted to [1/yr]
print(new_rate.sum()*YR, integ.sum()*YR)

# Sample 15yr Posteriors

In [None]:
path_data = "/Users/lzkelley/Programs/nanograv/15yr_astro_data/"
# directory containing the ceffyl chains for the "astroprior_hdall" model
path_model = Path(path_data) / "phenom" / "ceffyl_chains" / "astroprior_hdall"
# alternative: a local copy of the model directory (would replace `path_model`)
# path_model = Path("./data/astroprior_hdall").resolve()
print(path_model)
# explicit checks rather than `assert`, which is skipped when Python runs with `-O`
if not path_model.is_dir():
    raise FileNotFoundError(f"Model directory does not exist: {path_model}")
fname_pars = path_model / "pars.txt"
fname_chains = path_model / "chain_1.0.txt"
print(fname_pars)
print(fname_chains)
missing = [str(fn) for fn in (fname_chains, fname_pars) if not fn.is_file()]
if missing:
    raise FileNotFoundError(f"Missing chain file(s): {missing}")

In [None]:
# parameter names, one per chain column (`atleast_1d` keeps a single name as a 1D array)
chain_pars = np.atleast_1d(np.loadtxt(fname_pars, dtype=str))
# chain samples, one row per sample (`ndmin=2` keeps single-row / single-column files 2D)
chains = np.loadtxt(fname_chains, ndmin=2)
npars = len(chain_pars)
nsamps = len(chains)
print(f"{nsamps=}, {npars=} | {chain_pars}")
if chains.shape[1] < npars:
    raise ValueError(f"Chains have {chains.shape[1]} columns, but {npars} parameter names were given!")

# draw a single (unseeded) random sample from the chains
idx = np.random.choice(nsamps)
print(idx)
pars = {cp: chains[idx, ii] for ii, cp in enumerate(chain_pars)}
print(pars)


In [None]:
# Choose the appropriate Parameter Space (from 15yr astro analysis)
pspace = holo.librarian.PS_Classic_Phenom_Uniform
# Load SAM and hardening model for desired parameters
# NOTE(review): the posterior sample `pars` drawn above is NOT used -- `{}` is passed instead,
# presumably yielding the parameter space's default values.  Use the commented line to model the
# sampled parameters (check that the chain parameter names match the parameter-space names).
# sam, hard = pspace.model_for_params(pars)
sam, hard = pspace.model_for_params({})

In [None]:
# Recompute the (per-bin, un-integrated -- presumably, via `integrate=False`) event rates for the
# parameter-space model; rebinds the globals `redz_final`, `rate`, `fisco`, and `hc`.
redz_final, rate, fisco, hc = sam.rate_chirps(hard, integrate=False)

In [None]:
# Total (all binaries) event rate for the parameter-space model, converted to [1/yr]
integ = sam._integrate_event_rate(rate)
print(integ.sum()*YR)

In [None]:
# Zero out the rates of binaries not above the legwork LISA curve, integrate, and print in [1/yr].
# NOTE(review): assumes `fisco` from `rate_chirps` is an observer-frame frequency in [Hz] directly
# comparable to `fobs`, and that `fisco`/`hc` broadcast against `rate` -- confirm.
sel = is_lisa_detectable(fisco, hc, fobs, lisa_hc)
integ = sam._integrate_event_rate(sel * rate)
print(integ.sum()*YR)

In [None]:
import glob

# sort for a deterministic file (and therefore row) order; `glob` order is arbitrary
pattern = "/Users/lzkelley/Programs/nanograv/holodeck/lisa-calc/output/*.txt"
files = sorted(glob.glob(pattern))
if len(files) == 0:
    raise FileNotFoundError(f"No output files match {pattern!r}")

# `ndmin=2` keeps single-row files 2D so that all arrays concatenate along the sample axis
data = np.concatenate([np.loadtxt(fil, ndmin=2) for fil in files], axis=0)
print(data.shape)

In [None]:
import kalepy as kale

# NOTE(review): column 0 is scaled by 10 before taking the log -- presumably a unit conversion
# to match the axis label (e.g. a per-10-yr count); confirm against the code writing these files.
vals = 10*data[:, 0]
# drop non-positive values, which have no finite log10 and would break the density estimate
vals = vals[vals > 0.0]
kale.dist1d(np.log10(vals), density=True, probability=True)

ax = plt.gca()
# raw strings: '\#', '\l', and '\m' are invalid escape sequences in regular string literals
ax.set(ylabel=r"$d\#/d\log_{10}(N/\mathrm{yr})$", xlabel=r"$\log_{10}(N/\mathrm{yr})$")
plt.show()