In [None]:
import numpy as np
import scipy as sp
import scipy.stats
import matplotlib.pyplot as plt

import holodeck2 as holo
from holodeck2 import physics, utils
from holodeck2.constants import YR

In [None]:
# Construct the `holo.sam.SAM` model with its default arguments; the cells below
# call its methods (presumably SAM = semi-analytic model — TODO confirm).
sam = holo.sam.SAM()

In [None]:
# Number density on the model's 3D grid, plus the grid edges along each dimension
# (the plotting cells below label the three dimensions m1, m2, z).
edges_3d, ndens_3d = sam.number_density_3d()
# The plotting cells pair edges_3d[i] with axis i of ndens_3d, so the density must be
# sampled at the grid edges; fail here with a clear message rather than deep in matplotlib.
assert ndens_3d.shape == tuple(len(ee) for ee in edges_3d), (
    f"{ndens_3d.shape=} does not match edge lengths {[len(ee) for ee in edges_3d]}"
)

In [None]:
# GW frequency bins (centers and edges) from holodeck2's PTA helper; the plot below
# multiplies the centers by YR for its x-axis.
fobs_gw_cents, fobs_gw_edges = physics.pta_freqs()
# Combine the 3D grid/density with the frequency edges into 4D grid centers and
# (presumably) expected binary numbers per bin.
# NOTE(review): axis order of numb_4d is not visible here — TODO confirm.
cents_4d, numb_4d = holo.sam.number_expect_4d_gwonly_instant(fobs_gw_edges, edges_3d, ndens_3d)

In [None]:
# Presumably the number of random GWB realizations to draw — TODO confirm.
N_REALIZE = 100
# Seed the RNG so the realizations are reproducible under Restart & Run All.
# NOTE(review): assumes holodeck2 draws from NumPy's global RNG — TODO confirm;
# if it uses its own Generator this seed has no effect.
SEED = 12345
np.random.seed(SEED)

gwb = holo.sam.gws_from_number_expect_instant(fobs_gw_edges, cents_4d, numb_4d, realize=N_REALIZE)

In [None]:
# Collapse axes 1-3 of the GWB array and take the square root, giving one spectrum
# per remaining column. The result gets a NEW name: overwriting `gwb` made this cell
# fail on re-run (axis=(1, 2, 3) no longer exists on the reduced array).
print(f"{gwb.shape=}")
gwb_amp = np.sqrt(np.sum(gwb, axis=(1, 2, 3)))
print(f"{gwb_amp.shape=}")

fig, ax = plt.subplots()
ax.set(
    xscale='log', xlabel='GW frequency [1/yr]',
    # presumably the characteristic strain h_c — TODO confirm
    yscale='log', ylabel='GWB amplitude',
)
ax.plot(fobs_gw_cents*YR, gwb_amp)
plt.show()

In [None]:
# 1D marginals: for each grid dimension, sum the 3D density over the other two
# dimensions and plot it against that dimension's grid edges.
fig, axes = plt.subplots(figsize=[12, 5], ncols=3)
labels = ['m1', 'm2', 'z']
ymax = np.sum(ndens_3d)
ymin = ymax / 1e10  # show ten decades below the grid total

for dim, (ax, label) in enumerate(zip(axes, labels)):
    grid = edges_3d[dim]
    ax.set(
        xscale='log', xlabel=label,
        yscale='log', ylabel='Density', ylim=[ymin, ymax],
    )

    summed_dims = tuple(kk for kk in range(3) if kk != dim)
    marginal = np.sum(ndens_3d, axis=summed_dims)

    ax.plot(grid, marginal)

plt.show()

In [None]:
# Corner plot of the 3D number density: 1D marginals on the diagonal, 2D marginals
# (log10 color scale) in the lower triangle; the upper triangle is hidden.
fig, axes = plt.subplots(figsize=[12, 10], ncols=3, nrows=3)
labels = ['m1', 'm2', 'z']
n_dims = len(labels)
ymax = np.sum(ndens_3d)
ymin = ymax / 1e10  # diagonal panels show ten decades below the grid total

for (ii, jj), ax in np.ndenumerate(axes):
    # Panel (ii, jj) has parameter jj on its x-axis; only ii >= jj is drawn.
    if ii < jj:
        ax.set_visible(False)
        continue

    ax.set(xscale='log', yscale='log')
    # Label only the outer edges of the grid: x on the bottom row, y on the left column.
    if ii == n_dims - 1:
        ax.set_xlabel(labels[jj])
    if jj == 0:
        ax.set_ylabel('Density' if ii == jj else labels[ii])

    xx = edges_3d[jj]

    # ---- 1D: sum over every dimension except jj
    if ii == jj:
        margin = tuple(kk for kk in range(n_dims) if kk != jj)
        yy = np.sum(ndens_3d, axis=margin)

        ax.plot(xx, yy)
        ax.set(ylim=[ymin, ymax])

    # ---- 2D: sum over the single remaining dimension
    else:
        yy = edges_3d[ii]
        mesh = np.meshgrid(xx, yy, indexing='ij')

        margin = tuple(kk for kk in range(n_dims) if kk not in (ii, jj))
        # Remaining axes are (jj, ii) in that order since jj < ii, matching the 'ij' mesh.
        zz = np.sum(ndens_3d, axis=margin)

        # np.ma.log10 masks empty (<= 0) bins instead of producing -inf with a warning.
        ax.pcolormesh(*mesh, np.ma.log10(zz), shading='gouraud')

plt.show()