In [None]:
# %load ../notebooks/init.ipy
# %reload_ext autoreload
# %autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from datetime import datetime
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot


# --- Holodeck ----
import holodeck as holo
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Reset to the default style FIRST (avoids dark backgrounds from dark theme vscode).
# `style.use('default')` restores the default rcParams, so calling it after the `mpl.rc(...)`
# lines below would silently discard the font / line / mathtext customizations.
mpl.style.use('default')
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.WARNING)

In [None]:
def draw_gwb(ax, xx, gwb, nsamp=10, color=None, label=None):
    """Draw the median, the interquartile band, and a few sample columns of `gwb`.

    Parameters
    ----------
    ax : axes-like object
        Must provide `plot` and `fill_between`.
    xx : array_like
        x-values, one per row of `gwb`.
    gwb : 2D ndarray
        Percentiles are taken along axis 1; individual columns are drawn as samples.
    nsamp : int or None
        Maximum number of randomly chosen columns to draw.  `None` or a non-positive value
        draws no samples.
    color : color or None
        Color used for every artist.  If `None`, the next color of the axes' cycle is used.
    label : str or None
        Legend label attached to the median line.

    """
    if color is None:
        color = ax._get_lines.get_next_color()

    # rows of the percentile result: median, lower quartile, upper quartile
    median, lower, upper = np.percentile(gwb, [50, 25, 75], axis=1)
    ax.plot(xx, median, alpha=0.5, color=color, label=label)
    ax.fill_between(xx, lower, upper, color=color, alpha=0.15)

    if (nsamp is None) or (nsamp <= 0):
        return

    # never request more samples than there are columns
    num_avail = gwb.shape[1]
    num_draw = np.min([nsamp, num_avail])
    for col in np.random.choice(num_avail, num_draw, replace=False):
        ax.plot(xx, gwb[:, col], color=color, alpha=0.25, lw=1.0, ls='-')

    return


# Make some nice plots

In [None]:
# Step-count argument passed to `holo.sam.evolve_eccen_uniform_single`
NSTEPS = 100

# Number of harmonics passed as `nharms` to the GWB calculations
NHARMS = 200
# SAM_SHAPE = (2, 3, 4)
# Passed as `shape` when constructing the semi-analytic model
SAM_SHAPE = 20

# Initial eccentricity and separation handed to the eccentricity evolution
INIT_ECCEN = 0.999
INIT_SEPA = 1.0 * PC

In [None]:
# Build the semi-analytic model with the configured grid shape
sam = holo.sam.Semi_Analytic_Model(shape=SAM_SHAPE)
# Comoving distance to each SAM redshift, as plain values in Mpc
dcom = cosmo.comoving_distance(sam.redz).to('Mpc').value
print("evolve")
# One shared separation / eccentricity track, starting from (INIT_SEPA, INIT_ECCEN)
sepa_evo, eccen_evo = holo.sam.evolve_eccen_uniform_single(sam, INIT_ECCEN, INIT_SEPA, NSTEPS)

print("interp and gwb")
# 23 log-spaced target frequencies between 1e-1 and 1e3 per year, divided by `YR`
gwfobs = np.logspace(-1, 3, 23) / YR

# faster, cython calculation
# NOTE(review): later cells sum this result over axis=1 and index harmonics along that axis,
# so it is presumably shaped (frequencies, harmonics) -- confirm against holodeck.
gwb_hcn2_test = holo.gravwaves.sam_calc_gwb_single_eccen(
    gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS
)

In [None]:
fig, ax = plot.figax(figsize=[5, 4], ylim=[1e-20, 1e-15])

# Total GWB: sum the squared strain over the harmonic axis, then take the square root
gwb = np.sqrt(np.sum(gwb_hcn2_test, axis=1))
ax.plot(gwfobs*YR, gwb, 'k--', alpha=0.75, lw=2.0)

# Number of harmonics to draw individually
nh = 100
# The first few harmonics get distinct named colors; the rest are drawn from a colormap
named_colors = ['blue', 'red', 'orange', 'brown', 'purple']
nnamed = len(named_colors)
# `mpl.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9; `plt.get_cmap` is
# available in both old and new releases.
colors = plt.get_cmap('viridis')(np.linspace(0.1, 0.9, nh-nnamed))
for ii in range(nh):
    is_named = (ii < nnamed)
    # Label the named harmonics and then every 20th harmonic; others stay out of the legend
    if is_named or ((ii+1) % 20 == 0):
        label = f"{ii+1:02d}"
    else:
        label = ""
    ls = '-'
    cc = named_colors[ii] if is_named else colors[ii-nnamed]
    alpha = 0.65 if is_named else 0.35
    ax.plot(gwfobs*YR, np.sqrt(gwb_hcn2_test[:, ii]), alpha=alpha, label=label, color=cc, ls=ls)

plot._twin_hz(ax)
ax.legend(title='harmonic', ncol=2, loc='lower left')
plt.show()

In [None]:
# Same inputs as the calculation above, but through the `_python_`-prefixed implementation.
# Unlike `sam_calc_gwb_single_eccen`, this call returns four values:
#     gwfobs_harms, hc2 (stored here as `gwb_check`), ecc_out, tau_out
gwfobs_harms, gwb_check, ecc_out, tau_out = holo.gravwaves._python_sam_calc_gwb_single_eccen(
    gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS
)

In [None]:
fig, ax = plot.figax(ylim=[1e-20, 1e-12])

# Solid lines: per-harmonic strain from `gwb_hcn2_test`, and (black) the quadrature sum over axis 1
ax.plot(gwfobs*YR, np.sqrt(gwb_hcn2_test), alpha=0.5)
gwb = np.sqrt(np.sum(gwb_hcn2_test, axis=1))
ax.plot(gwfobs*YR, gwb, 'k-', alpha=0.5)

# Dashed lines: the same quantities from `gwb_check`, drawn on top for visual comparison
ax.plot(gwfobs*YR, np.sqrt(gwb_check), ls='--', alpha=0.5)
gwb = np.sqrt(np.sum(gwb_check, axis=1))
ax.plot(gwfobs*YR, gwb, 'k--', alpha=0.5)

plot._twin_hz(ax)
plt.show()

In [None]:
# `ecc_out` has shape (M, Q, Z, F, N)  =  (mtot, mrat, redz, freqs, harmonics)
use_ecc = ecc_out[..., 0]   # get the n=1 harmonic, corresponding to the eccentricity at these exact frequencies
# keep the last mass-ratio bin and the first redshift bin  ==>  shape (M, F)
use_ecc = use_ecc[:, -1, 0, :]

xx = gwfobs * YR

fig, ax = plot.figax()
nmtot = len(sam.mtot)
# `mpl.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` works across versions
colors = plt.get_cmap("Spectral")(np.linspace(0.1, 0.9, nmtot))
# Label roughly five curves; guard against a zero stride (and ZeroDivisionError) on small grids
label_stride = max(1, nmtot // 5)
for ii, mt in enumerate(sam.mtot):
    mt = np.log10(mt/MSOL)
    lab = f"$10^{{{mt:.1f}}}$" if (ii % label_stride == 0) else ""
    ax.plot(xx, use_ecc[ii, :], label=lab, color=colors[ii])

plot._twin_hz(ax)
ax.legend()
plt.show()

In [None]:
def interp_xaxis(xnew, xx, yy, axis=-1, invalid=np.nan):
    """Linearly interpolate `yy` to the scalar `xnew`, independently for each row of `xx`.

    Parameters
    ----------
    xnew : scalar
        Abscissa value to interpolate to.
    xx : 2D ndarray
        Abscissa values.  The dimension given by `axis` pairs up with `yy`; the bracketing
        search takes the first entry larger than `xnew`, so each row should be increasing.
    yy : 1D ndarray
        Ordinate values shared by every row of `xx`.
    axis : int
        Axis of `xx` that runs along `yy`.
    invalid : scalar
        Value stored where `xnew` is not bracketed by a row of `xx`.

    Returns
    -------
    1D ndarray with one interpolated value per row of `xx` (after moving `axis` to the end).

    """
    assert np.ndim(xnew) == 0
    assert np.ndim(xx) == 2
    assert np.ndim(yy) == 1
    assert np.shape(xx)[axis] == np.shape(yy)[0]

    grid = np.moveaxis(xx, axis, -1)
    rows = np.arange(grid.shape[0])

    # first index in each row where the abscissa exceeds `xnew` (0 when no entry does)
    upper = np.argmax(grid > xnew, axis=-1)
    lower = upper - 1

    x_lo, x_hi = grid[rows, lower], grid[rows, upper]
    y_lo, y_hi = yy[lower], yy[upper]

    slope = (y_hi - y_lo) / (x_hi - x_lo)
    result = y_lo + slope * (xnew - x_lo)

    # `upper == 0` means `xnew` lies below the first point, or above every point in the row
    result[(upper == 0) | (x_hi < xnew)] = invalid
    return result

# Compare Different Eccentricities

In [None]:
# Settings for this section (these overwrite the values used in the cells above)
NSTEPS = 60
NHARMS = 100
# SAM_SHAPE = (2, 3, 4)
SAM_SHAPE = 20
INIT_SEPA = 10.0 * PC

sam = holo.sam.Semi_Analytic_Model(shape=SAM_SHAPE)
dcom = cosmo.comoving_distance(sam.redz).to('Mpc').value
# Target frequencies: `xx` is in units of 1/yr, `gwfobs` is the same grid divided by `YR`
xx = np.logspace(-2, 1, 23)
gwfobs = xx / YR

# For each initial eccentricity: evolve the track, compute the per-harmonic GWB, and record the
# wall-clock time (in seconds) that each of the two stages took.
initial_eccens = [0.0, 0.8, 0.9, 0.99, 0.995]
times_evo = []
times_gwb = []
gwb_harms = []
for init_ecc in initial_eccens:
    label = f"{init_ecc:.5f}"
    print(label)
    beg = datetime.now()
    sepa_evo, eccen_evo = holo.sam.evolve_eccen_uniform_single(sam, init_ecc, INIT_SEPA, NSTEPS)
    dur = datetime.now() - beg
    times_evo.append(dur.total_seconds())

    beg = datetime.now()
    gwb_hcn2 = holo.gravwaves.sam_calc_gwb_single_eccen(
        gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS
    )
    # store the result before stopping the timer; plotting happens in the next cell
    gwb_harms.append(gwb_hcn2)
    dur = datetime.now() - beg
    times_gwb.append(dur.total_seconds())

In [None]:
# One curve per initial eccentricity: total strain from the quadrature sum over axis 1
fig, ax = plot.figax()
for init_ecc, gwb_hcn2 in zip(initial_eccens, gwb_harms):
    label = f"{init_ecc:.5f}"
    gwb = np.sqrt(np.sum(gwb_hcn2, axis=1))
    ax.plot(xx, gwb, label=label)

plot._twin_hz(ax)
ax.legend()
plt.show()

# Discretize Eccentric Population

In [None]:
# Settings for the discretized-population section
SHAPE = 90         # passed as `shape` to the semi-analytic model
NREALS = 100       # number of realizations requested from the discretized calculations
NSTEPS = 123       # step-count argument for the eccentricity evolution
# NHARMS = 100
NHARMS = 5         # number of harmonics (`nharms`)
SEPA_INIT = 1.0 * PC
# Frequency-bin edges.  NOTE(review): the two arguments look like an observing duration and a
# cadence -- confirm against `utils.nyquist_freqs_edges`.
gwfobs_edges = utils.nyquist_freqs_edges(3*YR, 0.1*YR)
# gwfobs_edges = utils.nyquist_freqs_edges(1*YR, 0.01*YR)
# gwfobs_edges = utils.nyquist_freqs_edges()
# Bin centers: arithmetic midpoints of adjacent edges
gwfobs = 0.5 * (gwfobs_edges[1:] + gwfobs_edges[:-1])

sam = holo.sam.Semi_Analytic_Model(shape=SHAPE)

In [None]:
# Hardening model handed to `sam.gwb`.  Note that the `Hard_GW` class object itself is passed
# here, not an instance.
hard = holo.hardening.Hard_GW
# hard = holo.hardening.Fixed_Time.from_sam(sam, 2*GYR, exact=True, progress=False)
# Reference ("old") GWB with `NREALS` realizations, evaluated on the frequency-bin edges
gwb = sam.gwb(gwfobs_edges, realize=NREALS, hard=hard)

### Start with Circular

In [None]:
# Circular case: evolve from zero initial eccentricity, then compute the per-harmonic GWB at the
# bin-center frequencies
ECCEN_INIT = 0.0
sepa_evo, eccen_evo = holo.sam.evolve_eccen_uniform_single(sam, ECCEN_INIT, SEPA_INIT, NSTEPS)
gwb_ecc = holo.gravwaves.sam_calc_gwb_single_eccen(
    gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS
)

In [None]:
# Discretized version of the same calculation, producing `NREALS` realizations.
# Later cells index the result as [frequency, harmonic, realization].
gwb_ecc_disc = holo.gravwaves.sam_calc_gwb_single_eccen_discrete(
    gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS, nreals=NREALS,
)

In [None]:
fig, ax = plot.figax()
xx = gwfobs*YR

# Smooth (non-discretized) reference: quadrature sum over harmonics
ref = np.sqrt(np.sum(gwb_ecc, axis=1))
# Discretized harmonic calculation ("new") vs. the SAM's own `gwb` method ("old")
gwb_new = np.sqrt(np.sum(gwb_ecc_disc, axis=1))
gwb_old = np.copy(gwb)

ax.plot(xx, ref, 'k-', alpha=0.5)
# `draw_gwb` draws the median, the 25-75% band and up to 10 random realizations; use the shared
# helper instead of repeating its body here.
for gg, lab in zip([gwb_new, gwb_old], ['new', 'old']):
    draw_gwb(ax, xx, gg, label=lab)

ax.legend()
plt.show()

In [None]:
# Compare a single harmonic (index 1 along the harmonic axis) of the discretized calculation
# against the reference `gwb`, both normalized by the smooth result for that harmonic.
# NOTE(review): index 1 is presumably the n=2 harmonic, which carries the circular signal.
gwb_old = gwb
gwb_new = np.sqrt(gwb_ecc_disc[:, 1, :])
gwb_ref = np.sqrt(gwb_ecc[:, 1])

fig, ax = plot.figax()
xx = gwfobs * YR

# Median over realizations, and scatter of the realizations relative to the smooth value
med_new = np.median(gwb_new, axis=1)
med_old = np.median(gwb_old, axis=1)
std_new = np.std(gwb_new/gwb_ref[:, np.newaxis], axis=1)
std_old = np.std(gwb_old/gwb_ref[:, np.newaxis], axis=1)

# Plot against frequency (`xx`) rather than the implicit bin index, so that this figure shares
# its x-axis with the other figures in this section.
ax.plot(xx, med_new/gwb_ref, label='med new')
ax.plot(xx, med_old/gwb_ref, label='med old')

ax.plot(xx, std_new, label='std new')
ax.plot(xx, std_old, label='std old')

ax.legend()
plt.show()

### Eccentric

In [None]:
# Eccentric case: same steps as the circular case, starting from an initial eccentricity of 0.9
ECCEN_INIT = 0.9
sepa_evo, eccen_evo = holo.sam.evolve_eccen_uniform_single(sam, ECCEN_INIT, SEPA_INIT, NSTEPS)
gwb_ecc = holo.gravwaves.sam_calc_gwb_single_eccen(
    gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS
)

In [None]:
# Discretized realizations for the eccentric track (overwrites the circular result)
gwb_ecc_disc = holo.gravwaves.sam_calc_gwb_single_eccen_discrete(
    gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS, nreals=NREALS,
)

In [None]:
fig, ax = plot.figax()
xx = gwfobs*YR

# Smooth eccentric result (black line) vs. its discretized realizations ("new"), alongside the
# reference `gwb` computed earlier with `sam.gwb` ("old").
ref = np.sqrt(np.sum(gwb_ecc, axis=1))
gwb_new = np.sqrt(np.sum(gwb_ecc_disc, axis=1))
gwb_old = np.copy(gwb)

ax.plot(xx, ref, 'k-', alpha=0.5)
for gg, lab in zip([gwb_new, gwb_old], ['new', 'old']):
    # median, 25-75% band, and a few random realizations
    draw_gwb(ax, xx, gg, label=lab)


ax.legend()
plt.show()

### Vary Eccentricities

In [None]:
# Settings for the eccentricity sweep (these overwrite the values from the previous section)
SHAPE = 10
NREALS = 40
NSTEPS = 60
NHARMS = 50
SEPA_INIT = 1.0 * PC

# gwfobs_edges = utils.nyquist_freqs_edges(3*YR, 0.1*YR)
# gwfobs_edges = utils.nyquist_freqs_edges(1*YR, 0.01*YR)
gwfobs_edges = utils.nyquist_freqs_edges(20*YR, 0.01*YR)
# gwfobs_edges = utils.nyquist_freqs_edges()
print(gwfobs_edges.shape)
# gwfobs_edges = np.concatenate([gwfobs_edges[:20], gwfobs_edges[20:100:5], gwfobs_edges[100::10]])
# Thin the edges progressively: keep the first 10, then every 5th up to index 40, every 10th up
# to index 100, and every 20th after that.  The shape is printed before and after thinning.
gwfobs_edges = np.concatenate([gwfobs_edges[:10], gwfobs_edges[10:40:5], gwfobs_edges[40:100:10], gwfobs_edges[100::20]])
print(gwfobs_edges.shape)
# Bin centers: arithmetic midpoints of adjacent (thinned) edges
gwfobs = 0.5 * (gwfobs_edges[1:] + gwfobs_edges[:-1])

sam = holo.sam.Semi_Analytic_Model(shape=SHAPE)

In [None]:
# eccen_list = [0.0, 0.9, 0.925, 0.95, 0.975, 0.99, 0.995]
eccen_list = [0.0, 0.9, 0.95, 0.99]
# One discretized per-harmonic GWB per initial eccentricity, in the same order as `eccen_list`
gwb_eccen = []
for ii, ecc in enumerate(eccen_list):
    print(f"{ii:2d}, {ecc:.4f}")
    sepa_evo, eccen_evo = holo.sam.evolve_eccen_uniform_single(sam, ecc, SEPA_INIT, NSTEPS)
    print("\tevolved ==> gwb")
    _gwb = holo.gravwaves.sam_calc_gwb_single_eccen_discrete(
        gwfobs, sam, sepa_evo, eccen_evo, nharms=NHARMS, nreals=NREALS,
    )
    gwb_eccen.append(_gwb)
    # NOTE: no early `break` here -- every eccentricity must be computed, because the plotting
    # cells below pair `gwb_eccen` with `eccen_list`.

In [None]:
# Inspect the result for the first eccentricity (note: this displays the full array)
gwb_eccen[0]

In [None]:
# Raw strings are used for the TeX labels: `\m` is an invalid escape sequence in a normal
# string literal and triggers a warning in recent Python versions.
fig, ax = plot.figax(
    figsize=[6, 4], ylim=[0.5e-16, 1.2e-14], left=0.13, bottom=0.13, right=0.98, top=0.88,
    xlabel=r'Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain'
)

# One band per initial eccentricity.  The loop variable is deliberately not called `gwb`, so the
# reference spectrum stored under that name is not overwritten.
for ecc, gwb_by_harm in zip(eccen_list, gwb_eccen):
    # total strain per realization: quadrature sum over the harmonic axis
    gwb_tot = np.sqrt(np.sum(gwb_by_harm, axis=1))
    lab = f"{ecc:.3f}"
    draw_gwb(ax, gwfobs*YR, gwb_tot, label=lab, nsamp=5)

plot._twin_hz(ax)
plot._draw_plaw(ax, gwfobs*YR, f0=1)
ax.legend(title=r'$e(a = 1 \mathrm{pc})$')
output_path = Path(holo._PATH_OUTPUT, 'eccen_discretized.png')
fig.savefig(output_path, dpi=500)
plt.show()

In [None]:
# Show the last (highest-eccentricity) case that was actually computed.  A fixed index such as 5
# is out of range for the 4-element `eccen_list` and raises an IndexError.
sel_ecc = len(gwb_eccen) - 1
eccen = f"{eccen_list[sel_ecc]:.4f}"
title = rf"$e(a=1 \mathrm{{pc}}) = {eccen}$"
# indexed below as [frequency, harmonic, realization]
gwb = gwb_eccen[sel_ecc]

fig, ax = plot.figax(
    figsize=[6, 4], ylim=[1e-17, 1e-14], left=0.13, bottom=0.13, right=0.98, top=0.88,
    xlabel=r'Frequency $[\mathrm{yr}^{-1}]$', ylabel='Characteristic Strain'
)

# Median over realizations: per harmonic (`med`) and for the sum over harmonics (`med_tot`)
med = np.sqrt(np.median(gwb, axis=-1))
med_tot = np.sqrt(np.median(np.sum(gwb, axis=1), axis=-1))
nreals = gwb.shape[-1]

xx = gwfobs * YR

ax.plot(xx, med_tot, 'k--', label='tot')

# A few random realizations to overplot; never ask for more than exist
idx = np.random.choice(nreals, min(5, nreals), replace=False)
print(idx)

# The first `sel` harmonics are highlighted individually
sel = 3
few = med[:, :sel]
for ii in range(few.shape[1]):
    # grey underlay so the colored median stands out
    ax.plot(xx, few[:, ii], alpha=0.95, color='0.75', lw=2.0, zorder=8)
    cc, = ax.plot(xx, few[:, ii], alpha=0.75, label=f'{ii+1:02d}', zorder=10)
    cc = cc.get_color()
    for jj in idx:
        yy = np.sqrt(gwb[:, ii, jj])
        ax.plot(xx, yy, color=cc, alpha=0.25, zorder=9, lw=0.5)

    # interquartile band over realizations
    yy = np.percentile(gwb[:, ii, :], [25, 75], axis=-1)
    ax.fill_between(xx, *np.sqrt(yy), color=cc, alpha=0.25, zorder=9, lw=0.0)

# Remaining harmonics: medians only, colored along a colormap
others = med[:, sel:]
num = np.shape(others)[1]
cmap = 'Reds_r'
# `mpl.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` works across versions
colors = plt.get_cmap(cmap)(np.linspace(0.01, 0.85, num))
for ii in range(num):
    nn = ii + 1 + sel
    # label only the first of these harmonics and every 10th one
    lab = f"{nn:02d}" if nn%10 == 0 or ii == 0 else None
    cc = colors[ii]
    ax.plot(xx, others[:, ii], alpha=0.5, color=cc, label=lab)

ax.legend(title='harmonic', ncol=2)
fig.text(0.99, 1.0, title, ha='right', va='top')
plot._twin_hz(ax)
plot._draw_plaw(ax, gwfobs*YR, f0=1)
output_path = Path(holo._PATH_OUTPUT, f'eccen_harmonics_{eccen}.png')
fig.savefig(output_path, dpi=500)
plt.show()



# Simple Checks

## Look at eccentricity distribution function $g(n,e)$

In [None]:
fig, ax = plot.figax(figsize=[5, 3], yscale='log', ylim=[1e-4, 1e3], xlabel='Harmonic', ylabel='Relative Strain')
# Twin y-axis for the fraction of (summed) strain that lies above each harmonic
tw = ax.twinx(); tw.set(yscale='log', ylim=[1e-3, 1.2])
nn = np.arange(1, 20000+1)
# nn = np.logspace(0, 4, 10)
eccen_list = [0.0, 0.4, 0.8, 0.9, 0.98, 0.995]
fracs = [0.5, 0.95, 0.99]
# complete[i, j]: harmonic at which eccentricity i first exceeds fraction j of the summed strain
# (left at zero when that fraction is only reached at the last harmonic considered)
complete = np.zeros((len(eccen_list), len(fracs)))

for ii, ee in enumerate(eccen_list):
    # strain is proportional to    g(n,e) * (2/n) ^ 2
    gne = holo.utils.gw_freq_dist_func(nn, ee) * np.power(2.0 / nn, 2)
    # circular case is set by hand: all of the signal is placed in the n=2 harmonic
    if ee == 0.0:
        gne[:] = 0.0
        gne[nn == 2] = 1.0
    cc, = ax.plot(nn, gne, label=f'{ee:.4f}')
    cc = cc.get_color()
    # normalized cumulative sum over harmonics
    zz = np.cumsum(gne)
    zz = zz / zz[-1]
    tw.plot(nn, 1 - zz, ls='--', color=cc, alpha=0.65)

    # Mark the harmonics within which a given fraction of net strain is included
    for jj, pp in enumerate(fracs):
        idx = np.argmax(zz > pp)
        # dash pattern distinguishes the different fractions
        ls = [1, jj*2 + 1]
        # skip fractions that are only reached at the very last harmonic
        if idx == zz.size - 1:
            continue
        xx = nn[idx]
        tw.axvline(xx, color=cc, ls=(0, ls), alpha=0.5)
        complete[ii, jj] = xx

ax.legend(fontsize=8, loc='lower right')
plt.show()

In [None]:
# Harmonic needed to capture each strain fraction, as a function of (1 - e).
# Entries of `complete` that were never filled in remain zero.
fig, ax = plot.figax(xlabel='$1 - e$', ylabel='Harmonic')
for jj, frac in enumerate(fracs):
    comp = complete[:, jj]
    ax.plot(1-np.array(eccen_list), comp, label=f"{frac:.3f}", marker='x')

ax.legend(fontsize=8)
plt.show()

## Look at eccentricity evolution

In [None]:
eccen_list = [0.0, 0.4, 0.8, 0.9, 0.98, 0.99, 0.999]
# Separations in parsec, running from large to small
xx = np.logspace(1, -3, 100)
sepa = xx * PC
# Total mass used only for the frequency axis on top of the figure
mtot = 3.14e9 * MSOL

fig, ax = plot.figax(figsize=[4, 3], yscale='lin', xlabel='Separation [pc]', ylabel='Eccentricity', xlim=utils.minmax(xx))

# Right-hand side integrated by the RK4 stepper.  `Hard_GW` is only reachable through
# `holo.hardening`; a bare `Hard_GW` name is never defined in this notebook.
# NOTE(review): `deda` is presumably de/da for GW-driven evolution -- confirm in holodeck.
deda = holo.hardening.Hard_GW.deda

# `mpl.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` works across versions
colors = plt.get_cmap('Spectral')(np.linspace(0.1, 0.9, len(eccen_list)))
for ii, ecc_init in enumerate(eccen_list):
    eccen = np.zeros_like(sepa)
    eccen[0] = ecc_init
    # step along the separation grid, carrying the eccentricity forward
    for step in range(sepa.size - 1):
        a0 = sepa[step]
        da = sepa[step+1] - a0
        _, ecc_next = utils.rk4_step(deda, x0=a0, y0=eccen[step], dx=da)
        # keep the eccentricity within [0, 1)
        eccen[step+1] = np.clip(ecc_next, 0.0, 1.0 - 1e-6)

    # grey underlay, then the colored curve labelled by its initial eccentricity
    ax.plot(xx, eccen, color='0.5', lw=3.5, alpha=0.5)
    ax.plot(xx, eccen, label=f"{eccen[0]:.4f}", color=colors[ii])

ax.xaxis.set_inverted(True)
ax.legend(fontsize=8)

# Top axis: frequency corresponding to each separation, for the single total mass `mtot`
freqs = utils.kepler_freq_from_sepa(mtot, sepa)
tw = ax.twiny()
tw.xaxis.set_inverted(True)
# raw f-string: `\m` is an invalid escape sequence in a normal string literal
lab = rf'Frequency $[\mathrm{{yr}}^{{-1}}]$ for $M=10^{{{np.log10(mtot/MSOL):.1f}}}$'
tw.set(xlim=utils.minmax(freqs*YR), xscale='log', xlabel=lab)

plt.show()


# Eccentricity Evolution

In [None]:
# Fresh semi-analytic model for this section, using the `SAM_SHAPE` grid defined earlier
sam = holo.sam.Semi_Analytic_Model(shape=SAM_SHAPE)

## single, uniform eccentricity

Choose a single $a_0$ and $e_0$ for all binaries.
The $da/de$ rate is only a function of $a$ and $e$, and thus for a fixed initial $a_0$ and $e_0$, the eccentricity trajectory (versus separation) is identical for all binaries.  Use a fixed range of separations $a$ for all binaries and numerically integrate to find $e(a)$:

In [None]:
# Shared (separation, eccentricity) track, using 123 as the step-count argument.
# `INIT_ECCEN` and `INIT_SEPA` come from earlier cells of the notebook.
sepa_evo, eccen_evo = holo.sam.evolve_eccen_uniform_single(sam, INIT_ECCEN, INIT_SEPA, 123)

In [None]:
# Eccentricity along the evolution track, versus separation in parsec.
# The label is a raw string: `\m` is an invalid escape sequence in a normal string literal.
fig, ax = plot.figax(
    xlabel=r'Separation $[\mathrm{pc}]$',
    yscale='lin', ylabel='Eccentricity'
)
ax.plot(sepa_evo/PC, eccen_evo)
# evolution proceeds toward smaller separations, so flip the x-axis
ax.xaxis.set_inverted(True)
plt.show()

In [None]:
# We want to find parameters (including eccentricity) as a function of binary frequency.  Convert
# separations to binary rest-frame orbital frequencies, for every total mass of the SAM grid.
# Resulting shape: (M, A)  =  (total masses, separation steps)
frst_orb_evo = utils.kepler_freq_from_sepa(sam.mtot[:, np.newaxis], sepa_evo[np.newaxis, :])

In [None]:
# Eccentricity versus rest-frame orbital frequency: one curve per total mass.
# The label is a raw string: `\m` is an invalid escape sequence in a normal string literal.
fig, ax = plot.figax(
    xlabel=r'Frequency (orbital, rest-frame) $[\mathrm{yr}^{-1}]$',
    yscale='lin', ylabel='Eccentricity'
)
# transpose so that each column (one per total mass) is drawn as its own line
ax.plot(frst_orb_evo.T*YR, eccen_evo)
plt.show()

In [None]:
# fobs = np.array([0.1, 1.0, 10.0])
# Four log-spaced frequencies between 1e-2 and 1e1 per year
fobs = np.logspace(-2, 1, 4) / YR
# single redshift used to convert rest-frame track frequencies to the observer frame
redz = 0.15
fig, ax = plot.figax(scale='lin')
ebins = np.linspace(0.0, 1.0, 11)

# For each frequency (highest first), histogram the eccentricities interpolated from the track
# of every total-mass bin.  Bins whose track does not bracket the frequency yield NaN.
for ff in fobs[::-1]:
    ee = interp_xaxis(ff, frst_orb_evo / (1.0 + redz), eccen_evo, axis=-1)
    ax.hist(ee, bins=ebins, histtype='step', alpha=0.75, label=f"{ff*YR:05.2f}", lw=2.0)

ax.legend()
plt.show()

From numerically integrated $e(a)$ for each SAM parameter-space bin, interpolate to find eccentricity and hardening rate at all harmonics of target frequencies.

In [None]:
def interp_eccens_uniform(fobs, sam, sepa, forb_rst, eccen, nharms):
    """Interpolate a shared eccentricity track to every harmonic of the target frequencies.

    Parameters
    ----------
    fobs : (F,) ndarray
        Observer-frame target frequencies.
    sam : `Semi_Analytic_Model` instance
        Provides `mtot`, `mrat`, `redz` and `shape`.
    sepa : unused
        Accepted for interface compatibility; not referenced by this function.
    forb_rst : (M, E) ndarray
        Rest-frame orbital frequencies for each total mass and evolution step.  Each row is
        passed to `np.interp` as the x-coordinates, so it must be increasing.
    eccen : (E,) ndarray
        Eccentricity at each evolution step.
    nharms : int
        Number of harmonics; harmonic numbers run from 1 to `nharms`.

    Returns
    -------
    fobs_harms : (F, H) ndarray
        `fobs` divided by each harmonic number.
    ecc_out : (M, Q, Z, F, H) ndarray
        Interpolated eccentricity; NaN where the frequency falls outside of the track.
    tau_out : (M, Q, Z, F, H) ndarray
        Separation divided by the negated GW hardening rate at the interpolated eccentricity.

    """
    # NOTE: need to check for coalescences and set to zero
    # NOTE: need to check for frequencies below starting separation and set to zero

    assert np.ndim(fobs) == 1
    assert np.ndim(forb_rst) == 2
    assert np.ndim(eccen) == 1
    assert np.shape(forb_rst) == (sam.mtot.size, eccen.size)

    # (F, H)
    fobs_harms = fobs[:, np.newaxis] / np.arange(1, nharms+1)[np.newaxis, :]
    # NOTE: should sort `fobs_harms` into an ascending 1D array to speed up processes

    shape = sam.shape + np.shape(fobs_harms)
    tau_out = np.zeros(shape)
    ecc_out = np.zeros(shape)

    # Every output element is computed independently, so the loops are ordered such that the
    # work depending only on total mass (and on redshift) is done once, not once per harmonic.
    for ii, mt in enumerate(sam.mtot):
        # (Q,) component masses for this total mass and all mass ratios
        m1, m2 = utils.m1m2_from_mtmr(mt, sam.mrat)
        # (E,) rest-frame orbital frequencies along the track for this total mass
        frst = forb_rst[ii]
        for kk, zz in enumerate(sam.redz):
            zterm = 1.0 + zz
            # the track's frequencies in the observer frame
            fbin_obs = frst / zterm
            for (aa, bb), fo in np.ndenumerate(fobs_harms):
                # interpolate to target frequency; this applies the same to all mass-ratios
                ee = np.interp(fo, fbin_obs, eccen, left=np.nan, right=np.nan)
                sa = utils.kepler_sepa_from_freq(mt, fo * zterm)

                tau_out[ii, :, kk, aa, bb] = sa / -utils.gw_hardening_rate_dadt(m1, m2, sa, ee)
                ecc_out[ii, :, kk, aa, bb] = ee

    return fobs_harms, ecc_out, tau_out
    
# choose target observer-frame frequencies
gwfobs = np.logspace(-2, 1, 10) / YR

# interpolate to frequencies and their harmonics
# `fobs_harms` gives the frequency harmonics
#     shape (F, H)  for `F` frequencies (from `fobs`) and `H` harmonics
#     note that these are actually the orbital frequencies whose harmonic is the input GW frequency
#     for example: if `fobs_harms[a, 0]` = 20 nHz  then  `fobs_harms[a, 9]` = 2 nHz
#                  so the frequency whose 10th harmonic is 20 nHz is 2 nHz, and the eccentricity at
#                  that (orbital) frequency is `ecc_interp[m, q, z, a, 9]`
# `ecc_interp` is the eccentricity   for each SAM bin and freq-harmonic, shape (M, Q, Z, F, H)
# `tau_interp` is the hardening-time for each SAM bin and freq-harmonic, shape (M, Q, Z, F, H)
fobs_harms, ecc_interp, tau_interp = interp_eccens_uniform(gwfobs, sam, sepa_evo, frst_orb_evo, eccen_evo, nharms=NHARMS)

Check that interpolated eccentricities match up

In [None]:
# The label is a raw string: `\m` is an invalid escape sequence in a normal string literal.
fig, ax = plot.figax(
    xlabel=r'Frequency (orbital, observer-frame) $[\mathrm{yr}^{-1}]$',
    xlim=[1e-3, 1e1], ylim=[-0.1, 1.1],
    yscale='lin', ylabel='Eccentricity'
)

# `frst_orb_evo` has shape (M, E); redshift it into the observer frame
# (M, Z, E)
fobs_orb_evo = frst_orb_evo[:, np.newaxis, :] / (1.0 + sam.redz[np.newaxis, :, np.newaxis])

# (F*H,) interpolation frequencies; identical for every SAM bin, so flatten them only once
xx = fobs_harms.flatten()

# Show every third total-mass and redshift bin: the line is the integrated track, and the
# crosses are the interpolated values, which should fall on top of the line.
for ii, _ in enumerate(sam.mtot):
    if ii % 3 != 0:
        continue
    for kk, _ in enumerate(sam.redz):
        if kk % 3 != 0:
            continue
        cc, = ax.plot(fobs_orb_evo[ii, kk, :]*YR, eccen_evo)
        cc = cc.get_color()

        # the eccentricity is the same for all mass ratios, so use the first mass-ratio bin
        yy = ecc_interp[ii, 0, kk].flatten()
        ax.scatter(xx*YR, yy, marker='x', color=cc)

plt.show()

Check that the interpolated hardening timescales match up

In [None]:

# The x-label is a raw string: `\m` is an invalid escape sequence in a normal string literal.
# The y-label names what is actually plotted: the timescale a / (-da/dt), not the rate da/dt.
fig, ax = plot.figax(
    xlabel=r'Frequency (orbital, observer-frame) $[\mathrm{yr}^{-1}]$',
    xlim=[1e-3, 1e1],
    ylim=[1e5, 1e+26],
    yscale='log', ylabel='Hardening timescale a / |da/dt|'
)

# `frst_orb_evo` has shape (M, E); redshift it into the observer frame
# (M, Z, E)
forb_obs = frst_orb_evo[:, np.newaxis, :] / (1.0 + sam.redz[np.newaxis, :, np.newaxis])

# (M, Q)
m1, m2 = utils.m1m2_from_mtmr(sam.mtot[:, np.newaxis], sam.mrat[np.newaxis, :])

# (M, Q, 1)
m1, m2 = [mm[..., np.newaxis] for mm in [m1, m2]]
# (1, 1, E)
atemp, etemp = [aa[np.newaxis, np.newaxis, :] for aa in [sepa_evo, eccen_evo]]
# calculate hardening rate for integrated trajectories, and convert to a timescale
# (M, Q, E)
dadt_all = utils.gw_hardening_rate_dadt(m1, m2, atemp, etemp)
tau_all = atemp / -dadt_all

# (F*H,) interpolation frequencies; identical for every SAM bin
xx = fobs_harms.flatten()

# Show every sixth bin in each dimension: the line is the timescale along the integrated track,
# and the crosses are the interpolated values, which should fall on top of the line.
for ii, _ in enumerate(sam.mtot):
    if ii%6 != 0:
        continue
    for jj, _ in enumerate(sam.mrat):
        if jj%6 != 0:
            continue
        for kk, _ in enumerate(sam.redz):
            if kk%6 != 0:
                continue

            cc, = ax.plot(forb_obs[ii, kk, :]*YR, tau_all[ii, jj, :])
            cc = cc.get_color()

            yy = tau_interp[ii, jj, kk].flatten()
            yy = np.fabs(yy)
            ax.scatter(xx*YR, yy, marker='x', color=cc)

plt.show()

Calculate GWB assuming circular, ideal calculation

In [None]:
# units of [Mpc^-3]
ndens = sam.static_binary_density
print("ndens = ", utils.stats(ndens), ndens.shape)

# Rest-frame frequency: observed frequency times (1+z), divided by two.
# NOTE(review): the factor of two is presumably the GW -> orbital conversion for circular
# binaries (f_gw = 2 f_orb) -- confirm.
# (Z, F)
gw_frst = gwfobs[np.newaxis, :] * (1.0 + sam.redz[:, np.newaxis]) / 2.0
# (M, Z, F)
sa = utils.kepler_sepa_from_freq(sam.mtot[:, np.newaxis, np.newaxis], gw_frst[np.newaxis, :, :])
# (M, Q)
m1, m2 = utils.m1m2_from_mtmr(sam.mtot[:, np.newaxis], sam.mrat[np.newaxis, :])
# (M, Q, 1, 1)
m1, m2 = [mm[:, :, np.newaxis, np.newaxis] for mm in [m1, m2]]
# hardening timescale: separation over the (negated) hardening rate, with no eccentricity
# (M, Q, Z, F)
_dadt = utils.gw_hardening_rate_dadt(m1, m2, sa[:, np.newaxis, :, :], eccen=None)
tau = sa[:, np.newaxis, :, :] / -_dadt
print("tau = ", utils.stats(tau), tau.shape)

# (M, Q, 1, 1)
mchirp = utils.chirp_mass(m1, m2)
# (1, 1, Z, F)
gw_frst = gw_frst[np.newaxis, np.newaxis, :, :]
# (1, 1, Z, 1)
zterm = (1.0 + sam.redz[np.newaxis, np.newaxis, :, np.newaxis])
# comoving distance in [Mpc]
dcom = cosmo.comoving_distance(zterm - 1).to('Mpc').value
print("dcom = ", utils.stats(dcom), dcom.shape)
# 4 pi c D_c^2, with the speed of light converted to [Mpc/s]
dc_term = 4*np.pi*(SPLC/MPC) * (dcom**2)

hs = utils.gw_strain_source(mchirp, dcom*MPC, gw_frst)
print("hs = ", utils.stats(hs), hs.shape)

# integrand: number density * cosmological factor * timescale * squared source strain
# (M, Q, Z, F)
integ = ndens[..., np.newaxis] * dc_term * zterm * tau * (hs**2)

# integrate: trapezoid rule along log10(mtot), mrat and redz in turn.  Each pass moves the
# target axis to the front, averages adjacent slices times the bin widths (broadcast along the
# last axis), and moves the (now one shorter) axis back into place.
args = [np.log10(sam.mtot), sam.mrat, sam.redz]
for ii, xx in enumerate(args):
    integ = np.moveaxis(integ, ii, 0)
    dx = np.diff(xx)
    integ = dx * 0.5 * np.moveaxis(integ[1:] + integ[:-1], 0, -1)
    integ = np.moveaxis(integ, -1, ii)

# sum the per-bin contributions over (M, Q, Z) and convert squared strain to strain  ==> (F,)
gwb = np.sum(integ, axis=(0, 1, 2))
gwb = np.sqrt(gwb)

print(integ.shape, sam.shape, gwb.shape, gwb)

In [None]:
# Circular GWB versus frequency (in 1/yr), with a reference power law of amplitude 1e-15 at f0=1
fig, ax = plot.figax()
xx = gwfobs * YR
ax.plot(xx, gwb)
plot._draw_plaw(ax, xx, amp=1e-15, f0=1)
plt.show()

Calculate GWB from interpolated values

In [None]:
# `fobs_harms` (F, H)    `ecc_interp`, `tau_interp` (M, Q, Z, F, H)

# units of [Mpc^-3]
ndens = sam.static_binary_density
# use_ecc = np.zeros_like(ecc_interp)
use_ecc = ecc_interp

# Rest-frame frequency for each redshift, target frequency and harmonic
# (Z, F, H)
gw_frst = fobs_harms[np.newaxis, :, :] * (1.0 + sam.redz[:, np.newaxis, np.newaxis])
# (M, Z, F, H)
sa = utils.kepler_sepa_from_freq(sam.mtot[:, np.newaxis, np.newaxis, np.newaxis], gw_frst[np.newaxis, :, :, :])
# (M, Q)
m1, m2 = utils.m1m2_from_mtmr(sam.mtot[:, np.newaxis], sam.mrat[np.newaxis, :])
# (M, Q, 1, 1, 1)
m1, m2 = [mm[:, :, np.newaxis, np.newaxis, np.newaxis] for mm in [m1, m2]]
# hardening timescale at the interpolated eccentricities (recomputed here, not `tau_interp`)
# (M, Q, Z, F, H)
_dadt = utils.gw_hardening_rate_dadt(m1, m2, sa[:, np.newaxis, :, :, :], eccen=use_ecc)
tau = sa[:, np.newaxis, :, :, :] / -_dadt

# (M, Q, 1, 1, 1)
mchirp = utils.chirp_mass(m1, m2)
# (1, 1, Z, F, H)
gw_frst = gw_frst[np.newaxis, np.newaxis, :, :, :]
# (1, 1, Z, 1, 1)
zterm = (1.0 + sam.redz[np.newaxis, np.newaxis, :, np.newaxis, np.newaxis])
dcom = cosmo.comoving_distance(zterm - 1).to('Mpc').value
dc_term = 4*np.pi*(SPLC/MPC) * (dcom**2)

hs = utils.gw_strain_source(mchirp, dcom*MPC, gw_frst)
# harmonic numbers 1..H, broadcast along the last axis
nharms = np.arange(1, fobs_harms.shape[-1]+1)
# squared strain per harmonic:  hs^2 * (2/n)^2 * g(n, e)
hsn2 = (hs**2) * np.square(2 / nharms) * utils.gw_freq_dist_func(nharms, ee=use_ecc, recursive=False)

# (M, Q, Z, F, H)
integ = ndens[..., np.newaxis, np.newaxis] * dc_term * zterm * tau * hsn2

# integrate: trapezoid rule along log10(mtot), mrat and redz in turn (same scheme as in the
# circular calculation)
args = [np.log10(sam.mtot), sam.mrat, sam.redz]
for ii, xx in enumerate(args):
    integ = np.moveaxis(integ, ii, 0)
    dx = np.diff(xx)
    integ = dx * 0.5 * np.moveaxis(integ[1:] + integ[:-1], 0, -1)
    integ = np.moveaxis(integ, -1, ii)

# sum over (M, Q, Z) and convert squared strain to strain  ==> (F, H), one column per harmonic
gwb_harm = np.sum(integ, axis=(0, 1, 2))
gwb_harm = np.sqrt(gwb_harm)

print(integ.shape, sam.shape, gwb_harm.shape, gwb_harm)

In [None]:
fig, ax = plot.figax()
xx = gwfobs * YR
# Circular reference calculation (dotted blue) and a reference power law
ax.plot(xx, gwb, 'b:', lw=4.0, alpha=0.75)
plot._draw_plaw(ax, xx, amp=1e-15, f0=1)

nharms = gwb_harm.shape[-1]
assert nharms == NHARMS
# Draw roughly ten of the harmonics.  The stride is clamped to 1 so that fewer than ten
# harmonics does not produce a modulo-by-zero error.
stride = max(1, nharms // 10)
for ii in range(0, nharms, stride):
    ax.plot(xx, gwb_harm[:, ii], label=f'{ii+1:03d}')

# Total over all harmonics, summed in quadrature (dashed black)
temp = np.sqrt(np.sum(gwb_harm**2, axis=1))
ax.plot(xx, temp, 'k--')
ax.legend(fontsize=8)
plt.show()

## interpolate evolution integration and calculate GWB at the same time

In [None]:
def calc_gwb_with_interp(gwfobs, sam, sepa, forb_rst_evo, eccen_evo, nharms):
    """Interpolate a shared eccentricity track to each frequency harmonic and calculate the GWB.

    Parameters
    ----------
    gwfobs : (F,) array_like
        Observer-frame frequencies at which to calculate GWB.
    sam : `Semi_Analytic_Model` instance
    sepa : unused
        Not referenced by the active code (only by a commented-out interpolation below).
    forb_rst_evo : (M, E) array_like
        Rest-frame orbital frequencies of binaries, for each total-mass M and evolution step E.
    eccen_evo : (E,) array_like
        Eccentricities at each evolution step.  The same for all binaries, corresponding to fixed
        binary separations for all binaries.
    nharms : int
        Number of harmonics to use in calculating GWB.

    Returns
    -------
    gwfobs_harms : (F, H) ndarray
        `gwfobs` divided by each harmonic number 1..`nharms`.
    gwfr_check : (Z, F, H) ndarray
        Rest-frame frequency `gwfobs_harms * (1 + z)`, recorded for the first total-mass bin and
        asserted to be identical for all others.
    hc2 : (F, H) ndarray
        Integrand after trapezoid integration over log10(mtot), mrat and redz, and summation
        over the three SAM axes.
    hsn2, hs2, ecc_out, tau_out : (M, Q, Z, F, H) ndarray
        Per-bin intermediate quantities: harmonic-weighted squared strain, squared source strain,
        interpolated eccentricity, and hardening timescale.

    """

    # NOTE: need to check for coalescences and set to zero
    # NOTE: need to check for frequencies below starting separation and set to zero

    assert np.ndim(gwfobs) == 1
    assert np.ndim(forb_rst_evo) == 2
    assert np.ndim(eccen_evo) == 1
    assert np.shape(forb_rst_evo) == (sam.mtot.size, eccen_evo.size)

    # harmonic numbers n = 1..nharms, and the (2/n)^2 weight applied to each harmonic
    harm_nums = np.arange(1, nharms+1)
    two_over_nh_sq = (2.0 / harm_nums) ** 2

    # (M, Q, Z) units of [Mpc^-3]
    ndens = sam.static_binary_density

    # (F, H)
    gwfobs_harms = gwfobs[:, np.newaxis] / harm_nums[np.newaxis, :]

    # (Z,)
    dcom = cosmo.comoving_distance(sam.redz).to('Mpc').value

    # (Z, F, H)
    # gw_frst ==> frst_orb_harms
    # gw_frst = gwfobs_harms[np.newaxis, :, :] * (1.0 * sam.redz[:, np.newaxis, np.newaxis])

    # shape will be a tuple of (M, Q, Z, F, H)
    shape = sam.shape + np.shape(gwfobs_harms)
    # setup output arrays with shape (M, Q, Z, F, H)
    hc2 = np.zeros(shape)
    hs2 = np.zeros(shape)
    hsn2 = np.zeros(shape)
    tau_out = np.zeros(shape)
    ecc_out = np.zeros(shape)

    # (Z, F, H) consistency record of the rest-frame frequencies
    gwfr_check = np.zeros(shape[2:])


    # NOTE: should sort `gwfobs_harms` into an ascending 1D array to speed up processes

    for (aa, bb), gwfo in np.ndenumerate(gwfobs_harms):
        # harmonic number for this column (not used below; `harm_nums[bb]` is indexed directly)
        nh = harm_nums[bb]
        # iterate over mtot M
        for ii, mt in enumerate(sam.mtot):
            # (Q,) masses of each component for this total-mass, and all mass-ratios
            m1, m2 = utils.m1m2_from_mtmr(mt, sam.mrat)
            mchirp = utils.chirp_mass(m1, m2)

            # (E,) rest-frame orbital frequencies for this total-mass bin
            frst_evo = forb_rst_evo[ii]
            # iterate over redshifts Z
            for kk, zz in enumerate(sam.redz):
                # () scalar
                zterm = (1.0 + zz)
                dc = dcom[kk]   # this is still in units of [Mpc]
                dc_term = 4*np.pi*(SPLC/MPC) * (dc**2)
                # rest-frame frequency corresponding to target observer-frame frequency of GW observations
                gwfr = gwfo * zterm
                # the rest-frame frequency does not depend on total mass: store it for the first
                # mass bin, and verify that every other mass bin reproduces it exactly
                if ii > 0:
                    assert gwfr_check[kk, aa, bb] == gwfr
                else:
                    gwfr_check[kk, aa, bb] = gwfr
                sa = utils.kepler_sepa_from_freq(mt, gwfr)

                # interpolate to target (rest-frame) frequency
                # this is the same for all mass-ratios
                # () scalar; NaN when the frequency falls outside of the evolution track
                ecc = np.interp(gwfr, frst_evo, eccen_evo, left=np.nan, right=np.nan)
                # ecc_2 = np.interp(sa, sepa[::-1], eccen_evo[::-1], left=np.nan, right=np.nan)

                # da/dt values are negative, get a positive rate
                tau = -utils.gw_hardening_rate_dadt(m1, m2, sa, ecc)
                # convert to timescale
                tau = sa / tau
                tau_out[ii, :, kk, aa, bb] = tau
                ecc_out[ii, :, kk, aa, bb] = ecc

                # Calculate the GW spectral strain at each harmonic
                #    see: [Amaro-seoane+2010 Eq.9]
                # () scalar weight: (2/n)^2 * g(n, e)
                temp = two_over_nh_sq[bb] * utils.gw_freq_dist_func(harm_nums[bb], ee=ecc, recursive=False)
                # (Q,)
                hs2[ii, :, kk, aa, bb] = utils.gw_strain_source(mchirp, dc*MPC, gwfr) ** 2
                hsn2[ii, :, kk, aa, bb] = temp * hs2[ii, :, kk, aa, bb]

                # (Q,) integrand: density * cosmological factor * timescale * squared strain
                hc2[ii, :, kk, aa, bb] = ndens[ii, :, kk] * dc_term * zterm * tau * hsn2[ii, :, kk, aa, bb]

    # integrate: trapezoid rule along log10(mtot), mrat and redz in turn; each pass shortens the
    # corresponding axis by one
    args = [np.log10(sam.mtot), sam.mrat, sam.redz]
    for ii, xx in enumerate(args):
        hc2 = np.moveaxis(hc2, ii, 0)
        dx = np.diff(xx)
        hc2 = dx * 0.5 * np.moveaxis(hc2[1:] + hc2[:-1], 0, -1)
        hc2 = np.moveaxis(hc2, -1, ii)

    # sum the per-bin contributions over (M, Q, Z)  ==> (F, H)
    hc2 = np.sum(hc2, axis=(0, 1, 2))

    return gwfobs_harms, gwfr_check, hc2, hsn2, hs2, ecc_out, tau_out
    
# choose target observer-frame frequencies
# gwfobs = np.logspace(-2, 1, 10) / YR

# interpolate to frequencies and their harmonics
# `fobs_harms` gives the frequency harmonics
#     shape (F, H)  for `F` frequencies (from `fobs`) and `H` harmonics
#     note that these are actually the orbital-frequencies, whose harmonic is the input GW frequency
#     for example: if `fobs_harms[a, 0]` = 20 nHz  then  `fobs_harms[a, 10]` = 2 nHz
#                  so the frequency who's 10th harmonic is 20 nHz is 2 nHz, and the eccentricity at
#                  that (orbital) frequency is `ee[m, q, z, a, 10]`
#     target 
# `ecc_interp` is the eccentricity   for each SAM bin and freq-harmonic, shape (M, Q, Z, F, H)
# `tau_interp` is the hardening-time for each SAM bin and freq-harmonic, shape (M, Q, Z, F, H)
# fobs_harms_2, gwfr_check, gwb_2, hsn2_check, hs2_check, ecc_interp_2, tau_interp_2 = calc_gwb_with_interp(gwfobs, sam, sepa, forb_rst, eccen, nharms=NHARMS)
# gwb_2 = np.sqrt(gwb_2)

In [None]:
# sam_big = holo.sam.Semi_Analytic_Model()
# print("evolve")
# sepa_evo_big, eccen_evo_big = sam_evolve_eccen_uniform_single(sam_big, INIT_ECCEN, INIT_SEPA)
# print("interp and gwb")
# hc2_big = holo.gravwaves.sam_calc_gwb_0(gwfobs, sam_big, sepa_evo_big, eccen_evo_big, nharms=100)

In [None]:
# NHARMS = 23
# NHARMS = 200

# other

In [None]:
def dynamic_binary_number_eccentricity(sam, fobs_orb, eccen_init, eccen_init_sepa=10.0*PC):
    """Convert the SAM's static binary density into a number of binaries per log-frequency.

    d^4 N / [dlog10(M) dq dz dln(X)    <===    d^3 n / dlog10(M) dq dz

    d^2 N / dz dln(f_r) = (dn/dz) * (dt/d ln f_r) * (dz/dt) * (dVc/dz)
                        = (dn/dz) * (f_r / [df_r/dt]) * 4 pi c D_c^2 (1+z)
                        = `dens`  *      `tau`        *   `cosmo_fact`

    NOTE: the eccentricity inputs are not used in the calculation yet.  `eccen_init` is only
    checked to be broadcastable against `sam.shape`, and `eccen_init_sepa` is ignored; the
    hardening rate is evaluated without any eccentricity argument.

    Parameters
    ----------
    sam : `Semi_Analytic_Model` instance
        Provides `shape`, `edges`, `grid`, `mtot`, `mrat`, `redz` and `static_binary_density`.
    fobs_orb : (X,) array_like
        Observer-frame orbital frequencies.
    eccen_init : scalar or array_like broadcastable to `sam.shape`
        Initial eccentricities (currently unused beyond the shape check).
    eccen_init_sepa : scalar
        Separation at which `eccen_init` applies (currently unused).

    Returns
    -------
    edges : list
        `sam.edges` with `fobs_orb` appended as the final dimension.
    dnum : (M, Q, Z, X) ndarray
        Density times cosmological factor times hardening timescale.  Entries with a non-finite
        timescale are set to zero, and a warning is logged.

    """

    # Fail early if `eccen_init` cannot be broadcast to the SAM grid (the values are not used)
    eccen_init = np.asarray(eccen_init) * np.ones(sam.shape)

    fobs_orb = np.asarray(fobs_orb)
    xsize = fobs_orb.size
    edges = sam.edges + [fobs_orb, ]

    # shape: (M, Q, Z)
    dens = sam.static_binary_density   # d3n/[dz dlog10(M) dq]  units: [Mpc^-3]

    # (Z,) comoving-distance in [Mpc]
    dc = cosmo.comoving_distance(sam.redz).to('Mpc').value

    # (Z,) this is `(dVc/dz) * (dz/dt)` in units of [Mpc^3/s]
    cosmo_fact = 4 * np.pi * (SPLC/MPC) * np.square(dc) * (1.0 + sam.redz)

    # (M*Q*Z,) 1D arrays of each total-mass, mass-ratio, and redshift
    mt, mr, rz = [gg.ravel() for gg in sam.grid]

    # Convert from observer-frame orbital freq, to rest-frame orbital freq
    # (X, M*Q*Z)
    frst_orb = fobs_orb[:, np.newaxis] * (1.0 + rz[np.newaxis, :])
    sa = utils.kepler_sepa_from_freq(mt[np.newaxis, :], frst_orb)

    hard = holo.hardening.Hard_GW
    # (X, M*Q*Z), hardening rate, negative values, units of [cm/sec]
    dadt = hard.dadt(mt[np.newaxis, :], mr[np.newaxis, :], sa)

    # Calculate `tau = dt/dlnf_r = f_r / (df_r/dt)`
    # dfdt is positive (increasing frequency)
    dfdt, frst_orb = utils.dfdt_from_dadt(dadt, sa, frst_orb=frst_orb)
    tau = frst_orb / dfdt

    # convert `tau` to the correct shape, note that moveaxis MUST happen _before_ reshape!
    # (X, M*Q*Z) ==> (M*Q*Z, X)
    tau = np.moveaxis(tau, 0, -1)
    # (M*Q*Z, X) ==> (M, Q, Z, X)
    tau = tau.reshape(dens.shape + (xsize,))

    # (M, Q, Z) units: [1/s] i.e. number per second
    dnum = dens * cosmo_fact
    # (M, Q, Z, X) units: [] unitless, i.e. number
    dnum = dnum[..., np.newaxis] * tau

    bads = ~np.isfinite(tau)
    if np.any(bads):
        log.warning(f"Found {utils.frac_str(bads)} invalid hardening timescales.  Setting to zero densities.")
        dnum[bads] = 0.0

    return edges, dnum

In [None]:
# Random test population
NUM = 1000
# total masses: log-uniform between 1e6 and 1e10 solar masses
mt = MSOL * (10.0 ** np.random.uniform(6, 10, NUM))
# mass ratios: log-uniform between 1e-2 and 1
mr = (10.0 ** np.random.uniform(-2, 0, NUM))
# Redshifts: power-law PDF  p(z) ~ z^2  on [0, 1], offset by 0.01.  Drawn by inverse-transform
# sampling with numpy (the CDF is ~ z^3, so z = u^(1/3) for uniform u), because the `zmath`
# package is never imported in this notebook and referencing it raises a NameError.
RZ_PDF_INDEX = 2.0
rz = np.random.uniform(0.0, 1.0, NUM) ** (1.0 / (RZ_PDF_INDEX + 1.0)) + 0.01

kale.dist1d(rz)
plt.show()