In [None]:
from pathlib import Path

import numpy as np
import scipy as sp
import scipy.stats
import matplotlib.pyplot as plt

import holodeck2 as holo
from holodeck2 import physics, utils, cosmo
import holodeck2.constants
from holodeck2.constants import YR

In [None]:
def save_cosmo_grid(cosmo, num_redz=2000, fname="cosmology.txt", log_zmin=-5.0, log_zmax=3.0):
    """Tabulate cosmological quantities on a log-spaced redshift grid and write them to text.

    The file begins with two header lines: a parameter summary (H0, Om0, Ob0, Ode0, N)
    and a column legend. These are followed by one row per redshift with the columns
    redshift, scale factor, comoving distance, comoving volume, lookback time [Myr]
    and ``cosmo.efunc``. Every value is written with the ``%.8e`` format.

    Parameters
    ----------
    cosmo : object
        Astropy-style cosmology providing ``scale_factor``, ``comoving_distance``,
        ``comoving_volume``, ``lookback_time`` and ``efunc``, and the attributes
        ``H0``, ``Om0``, ``Ob0`` and ``Ode0``. Distances and volumes are written via
        ``.value`` in whatever units the cosmology returns. The column legend labels
        them Mpc and Mpc^3; this assumes astropy's defaults (TODO confirm).
    num_redz : int
        Number of redshift grid points.
    fname : str or path-like
        Output path. Relative paths are resolved against the current working directory.
    log_zmin, log_zmax : float
        log10 of the first and last redshift of the grid.

    Returns
    -------
    numpy.ndarray
        Lookback time in Myr at each grid redshift.
    """
    fname = Path(fname).absolute()
    print(f"{fname=}")

    redz = np.logspace(log_zmin, log_zmax, num_redz)
    scafa = cosmo.scale_factor(redz)
    dcom = cosmo.comoving_distance(redz).value
    vcom = cosmo.comoving_volume(redz).value
    tlook = cosmo.lookback_time(redz).to('Myr').value
    efunc = cosmo.efunc(redz)

    legend = (
        f"# H0={cosmo.H0.value:.4e} Om0={cosmo.Om0:.4e} Ob0={cosmo.Ob0:.4e} "
        f"Ol0={cosmo.Ode0:.4e} | N={num_redz}"
    )
    print(legend)

    # One row per redshift, one column per quantity.
    columns = np.column_stack([redz, scafa, dcom, vcom, tlook, efunc])
    header = legend + "\n" + "# redshift scafa dcom[Mpc] vcom[Mpc3] tlook[Myr] efunc"
    # comments="" because both header lines already carry their own leading "#".
    np.savetxt(fname, columns, fmt="%.8e", delimiter=" ", header=header, comments="")

    print(f"Saved to file '{fname}'")
    return tlook

# Write the tabulated cosmology grid (to the current working directory) and keep
# the returned lookback times [Myr] around for inspection.
tlook = save_cosmo_grid(cosmo, num_redz=2000)

In [None]:
# Reference values printed as C/C++ array declarations, presumably for pasting
# into compiled test code. The RNG is seeded so the fixture is reproducible.
TEST_DATA_SEED = 1234
NUM_TEST_POINTS = 10
rng = np.random.default_rng(TEST_DATA_SEED)

# Redshifts log-uniform over [1e-5, 1e3], sorted ascending.
redz = np.power(10.0, np.sort(rng.uniform(-5, 3, NUM_TEST_POINTS)))
test_data = {
    "redz": redz,
    "scafa": cosmo.scale_factor(redz),
    "dcom": cosmo.comoving_distance(redz).value,
    "vcom": cosmo.comoving_volume(redz).value,
    "tlook": cosmo.lookback_time(redz).to('Myr').value,
    "efunc": cosmo.efunc(redz),
}

for name, values in test_data.items():
    values_str = ", ".join(f"{v:.8e}" for v in values)
    decl = f"{name}[]"
    # Trailing ';' so each printed line is a complete C/C++ declaration.
    print(f"double {decl:8s} = {{{values_str}}};")

In [None]:
# Emit each numeric constant in holodeck2.constants as a C++ `constexpr float`.
# dir() lists every name in the module namespace, including any imported modules
# or functions; formatting those with `.8e` would raise TypeError, so only real
# numbers (excluding bools) are emitted. Leading-underscore names are skipped too,
# since such identifiers are reserved in C++ in many contexts.
for name in dir(holodeck2.constants):
    if name.startswith('_'):
        continue
    val = getattr(holodeck2.constants, name)
    if isinstance(val, bool) or not isinstance(val, (int, float, np.integer, np.floating)):
        continue
    print(f"constexpr float {name:10s} = {val:.8e};")

In [None]:
# Construct the SAM with its default configuration (no arguments are passed here).
# NOTE(review): "SAM" presumably means semi-analytic model — confirm in holodeck2.sam.
sam = holo.sam.SAM()

In [None]:
# Grid edges and number density on the 3D grid. The plotting cells treat the three
# axes as (m1, m2, z) per their labels — TODO confirm ordering against holodeck2.sam.
edges_3d, ndens_3d = sam.number_density_3d()

In [None]:
# Frequency bin centers and edges from physics.pta_freqs().
fobs_gw_cents, fobs_gw_edges = physics.pta_freqs()
# Per-bin expected binary numbers (per the function name) on the 4D grid formed by
# the frequency edges plus the 3D SAM grid.
cents_4d, numb_4d = holo.sam.number_expect_4d_gwonly_instant(
    fobs_gw_edges,
    edges_3d,
    ndens_3d,
)

In [None]:
# Passed through as `realize=`; presumably the number of random realizations — TODO confirm.
NUM_REALIZE = 100
gwb = holo.sam.gws_from_number_expect_instant(
    fobs_gw_edges, cents_4d, numb_4d, realize=NUM_REALIZE
)

In [None]:
# Reduce the GWB array to one spectrum per remaining axis. The result is stored under a
# new name so re-running this cell does not re-reduce an already-reduced `gwb`
# (which would fail on the missing axes).
# NOTE(review): assumes gwb is (freq, d1, d2, d3, realization) holding squared-strain
# contributions, so the root-sum over axes 1-3 is the total strain — TODO confirm.
print(f"{gwb.shape=}")
gwb_rss = np.sqrt(np.sum(gwb, axis=(1, 2, 3)))
print(f"{gwb_rss.shape=}")

fig, ax = plt.subplots()
# Assumes fobs is in Hz and YR in seconds, so the x values are in 1/yr.
ax.loglog(fobs_gw_cents * YR, gwb_rss)
ax.set(xlabel="Frequency [1/yr]", ylabel="GWB (root-sum over axes 1-3)")
plt.show()

In [None]:
# 1D marginals: collapse ndens_3d onto each of its three axes in turn.
labels = ['m1', 'm2', 'z']
ymax = np.sum(ndens_3d)   # grand total, used as the upper y-limit of every panel
ymin = ymax / 1e10        # show ten decades below the total
fig, axes = plt.subplots(figsize=[12, 5], ncols=3)

for dim, (ax, label) in enumerate(zip(axes, labels)):
    ax.set(
        xscale='log', xlabel=label,
        yscale='log', ylabel='Density', ylim=[ymin, ymax],
    )
    # Sum over every axis except the one being plotted.
    other_axes = tuple(k for k in range(3) if k != dim)
    marginal = np.sum(ndens_3d, axis=other_axes)
    ax.plot(edges_3d[dim], marginal)

plt.show()

In [None]:
# Corner plot of ndens_3d: 1D marginals on the diagonal, 2D marginals below it.
fig, axes = plt.subplots(figsize=[12, 10], ncols=3, nrows=3)
labels = ['m1', 'm2', 'z']
ymax = np.sum(ndens_3d)
ymin = ymax / 1e10
nrows = axes.shape[0]

for (ii, jj), ax in np.ndenumerate(axes):
    # The upper triangle would duplicate the lower one; hide it.
    if ii < jj:
        ax.set_visible(False)
        continue

    ax.set(xscale='log', yscale='log')
    # Label only the outer edge of the grid to avoid clutter.
    if ii == nrows - 1:
        ax.set_xlabel(labels[jj])
    if jj == 0:
        ax.set_ylabel('Density' if ii == jj else labels[ii])

    xx = edges_3d[jj]

    # ---- 1D: sum over the two other dimensions
    if ii == jj:
        margin = [0, 1, 2]
        margin.pop(jj)
        yy = np.sum(ndens_3d, axis=tuple(margin))

        ax.plot(xx, yy)
        ax.set(ylim=[ymin, ymax])

    # ---- 2D: sum over the single remaining dimension
    else:
        yy = edges_3d[ii]
        mesh = np.meshgrid(xx, yy, indexing='ij')

        margin = [0, 1, 2]
        margin.pop(ii)  # ii > jj, so pop ii first to keep jj's index valid
        margin.pop(jj)
        zz = np.sum(ndens_3d, axis=tuple(margin))

        # np.ma.log10 masks empty (<= 0) bins instead of producing -inf plus a
        # divide-by-zero warning; -inf would otherwise break the color normalization.
        ax.pcolormesh(*mesh, np.ma.log10(zz), shading='gouraud')

plt.show()