In [None]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import healpy as hp
import kalepy as kale

import holodeck as holo
from holodeck import detstats, anisotropy, plot, utils
from holodeck.constants import YR, MSOL, GYR

# Sato-Polito Process for SGWB Modeling

1) Predict number of binaries in each bin in redshift, mass, and frequency

2) Obtain number of binaries for each realization in each bin by Poisson sampling

3) Uniformly sample angular positions (in $\cos \theta$ and $\phi$)

4) Get strain amplitude of each using eq. (7)
$$ h^2(z, \mathcal{M}, f) = \frac{32 \pi^{4/3}}{5 c^8} \frac{(1+z)^{10/3}}{d^2_L (z)} (\mathcal{G} \mathcal{M})^{10/3} f^{4/3} $$

5) Plug strain amplitude into Eq. (17)

$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{f}{4\pi \Delta f}   \int d \vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h^2 (f,\vec{\theta})   \bigg)^2 
+ \big( \frac{f}{4 \pi \Delta f}\big)^2 \int d\vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h^4 (f, \vec{\theta})
$$


## Applying to our methods


$\frac{d N_{\Delta f}}{d \vec{\theta}}$ is the number in that frequency bin, at that angle

$ h_c^2 = \frac{f}{df} h_s^2 $ so this is the same as
$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{1}{4\pi}   \int d \vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h_c^2 (f,\vec{\theta})   \bigg)^2 
+ \bigg( \frac{1}{4 \pi}\bigg)^2 \int d\vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h_c^4 (f, \vec{\theta} )
$$

$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{1}{4\pi}   \int d \vec{\theta} h_c^2 (f)   \bigg)^2 
+ \bigg( \frac{1}{4 \pi} \bigg)^2 \int d\vec{\theta} h_c^4 (f )
$$

We already have a characteristic strain at each pixel, having previously done $\int \frac{dN_{\Delta f}}{dM dq dz} h_c^2(M,q,z) dM dq dz$ and distributing this total strain among pixels (using an evenly spread background and individual sources at individual pixels). Now, we just need to integrate that $h_c$ over all the pixels.

$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{1}{4\pi}   \int d \vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h_c^2 (f,\vec{\theta})   \bigg)^2 
+ \bigg( \frac{1}{4 \pi}\bigg)^2 \int d\vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h_c^4 (f, \vec{\theta} )
$$

$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{1}{4\pi}   \int d \vec{\theta} h_c^2 (f)   \bigg)^2 
+ \bigg( \frac{1}{4 \pi} \bigg)^2 \int d\vec{\theta} h_c^4 (f )
$$
where $\vec{\theta}$ is now just position bin, not also M,q,z

Need to convert $ d\vec{\theta}$ to $d (\mathrm{pixel})$


$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{1}{4\pi}   \sum  d\theta d\phi  h_c^2 (f)   \bigg)^2 
+ \bigg( \frac{1}{4 \pi} \bigg)^2 \sum d\theta d\phi h_c^4 (f )
$$

The $1/4\pi$ prefactor remains, but it cancels if I normalize $C_\ell$ by $C_0$.

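The pixel area, $dA = d\theta \, d\phi$, is the same for all pixels, so it can be pulled out of the integral/sum. (Note: `hp.nside2pixarea(nside)` returns this area in steradians by default; square degrees are only returned with `degrees=True`.)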

$$ \frac{C_\ell (f)}{C_0 (f)} = \frac{\delta_{\ell 0}\delta_{m0} \bigg( \sum_\mathrm{pixels}   A_\mathrm{pix} h_c^2 (f)   \bigg)^2 
+ \sum_\mathrm{pixels}  A_\mathrm{pix} h_c^4 (f )}{\delta_{0 0}\delta_{00} \bigg( \sum_\mathrm{pixels}   A_\mathrm{pix} h_c^2 (f)   \bigg)^2 
+ \sum_\mathrm{pixels}  A_\mathrm{pix} h_c^4 (f )}
$$
Simplifying for $\ell > 0$ (the $\delta_{\ell 0}$ term then vanishes in the numerator, but not in $C_0$):
$$ \frac{C_{\ell>0} (f)}{C_0 (f)} = \frac{\sum_\mathrm{pixels}  h_c^4 (f )}{ A_\mathrm{pix} \bigg( \sum_\mathrm{pixels}  h_c^2 (f)   \bigg)^2 
+ \sum_\mathrm{pixels}   h_c^4 (f )}
$$ 

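**Caution:** the next two expressions move the sum over pixels outside the fraction. A ratio of sums is not, in general, a sum of ratios, so these are not equivalent to the expression above (each term below reduces to $1/(A_\mathrm{pix}+1)$ regardless of $h_c$).

$$ \frac{C_{\ell>0} (f)}{C_0 (f)} = \sum_\mathrm{pixels} \frac{ h_c^4 (f )}{ A_\mathrm{pix} \big(  h_c^2 (f)   \big)^2 
+  h_c^4 (f )}
$$ 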


$$ \frac{C_{\ell>0} (f)}{C_0 (f)} = \sum_\mathrm{pixels} \frac{ h_c^4 (f )}{ A_\mathrm{pix} h_c^4 (f)  
+  h_c^4 (f )}
$$ 

There should be an A_pixel in the denom of every integral.


# Set Up
### Read in Strain Data

In [None]:
# Directory holding the pre-computed strain library, and the HDF5 file inside it.
sspath = '/Users/emigardiner/GWs/holodeck/output/2023-05-16-mbp-ss19_uniform05A_n1000_r50_d20_f30_l2000_p0/'
hdfname = sspath+'ss_lib.hdf5'

# Read everything inside a context manager so the file handle is released
# even if one of the expected datasets is missing and the lookup raises.
with h5py.File(hdfname, 'r') as ssfile:
    print(list(ssfile.keys()))
    hc_ss = ssfile['hc_ss'][...]   # single-source characteristic strains
    hc_bg = ssfile['hc_bg'][...]   # background characteristic strains
    fobs = ssfile['fobs'][:]
    dfobs = ssfile['dfobs'][:]

# hc_ss axes: (samples N, frequencies F, realizations R, loudest sources L)
shape = hc_ss.shape
nsamps, nfreqs, nreals, nloudest = shape[:4]
print('N,F,R,L =', nsamps, nfreqs, nreals, nloudest)


### Get best sample

In [None]:
# Reference characteristic strain that the samples are ranked against.
# NOTE(review): going by the variable name this is presumably the amplitude at f = 1/(10 yr) — confirm.
hc_ref15_10yr = 11.2*10**-15
# nsort is used below as the sample ordering (nsort[0] = best sample) and
# fidx as the index into `fobs` at which the comparison is made.
nsort, fidx, hc_ref15 = detstats.rank_samples(hc_ss, hc_bg, fobs, hc_ref=hc_ref15_10yr, ret_all=True)
print(hc_ref15)

### Get healpix map

In [None]:
# HEALPix resolution parameter handed to the map builder.
nside=32
# Build the sky maps for the top-ranked sample (nsort[0]) from its
# single-source and background characteristic strains.
moll_hc = anisotropy.healpix_map(hc_ss[nsort[0]], hc_bg[nsort[0]], nside=nside)

In [None]:
# Realization index to display.
rr=0
# Mollweide view of the map for realization `rr` at frequency index `fidx`;
# the title reports the frequency in units of 1/yr (fobs * YR).
hp.mollview(moll_hc[rr,fidx], title='Sample %d, Realization %d, $f$=%.2f yr$^{-1}$' % (nsort[0], rr, fobs[fidx]*YR))

# Calculate Anisotropy

$$ \frac{C_{\ell>0} (f)}{C_0 (f)} = \sum_\mathrm{pixels} \frac{ h_c^4 (f )}{ A_\mathrm{pix} h_c^4 (f)  
+  h_c^4 (f )}
$$ 

In [None]:
# Sanity check of the map array dimensions.
# NOTE(review): the mollview cell indexes this as [realization, frequency] while the
# ClC0_analytic docstring states (F,R,npix) — confirm the axis order.
print(moll_hc.shape)

In [None]:
def ClC0_analytic(moll_hc):
    """Calculate C_l/C_0 for l>0 from per-pixel characteristic strains.

    Uses the pixel-sum form of Sato-Polito & Kamionkowski Eq. 17 derived in
    the markdown above, evaluated as a ratio of sums over pixels:

        C_{l>0}/C_0 = sum(h_c^4) / ( A_pix * sum(h_c^2)^2 + sum(h_c^4) )

    The previous implementation summed the per-pixel quantity
    h_c^4 / (A_pix*h_c^4 + h_c^4), which equals 1/(A_pix + 1) for every
    non-zero pixel and therefore did not depend on the map at all (and
    produced NaN for pixels with zero strain).

    Parameters
    ----------
    moll_hc : (..., npix) array
        Characteristic strain per HEALPix pixel; the last axis must be the
        pixel axis.  The docstring of the original version gives the full
        shape as (F,R,npix).

    Returns
    -------
    ClC0 : array with the shape of `moll_hc` minus its last axis
        Ratio C_{l>0}/C_0 for each leading index.
    """
    moll_hc = np.asarray(moll_hc)

    # Infer the HEALPix resolution from the number of pixels on the last axis.
    nside = hp.npix2nside(moll_hc.shape[-1])
    area = hp.nside2pixarea(nside)

    # Sum over pixels first, then form the ratio.
    sum_h2 = np.sum(moll_hc**2, axis=-1)
    sum_h4 = np.sum(moll_hc**4, axis=-1)
    ClC0 = sum_h4 / (area * sum_h2**2 + sum_h4)

    return ClC0

print(hp.nside2pixarea(nside))

In [None]:
# Evaluate the analytic ratio on the maps of the best sample.
ClC0 = ClC0_analytic(moll_hc)
print(ClC0.shape) # F, R

In [None]:
nshow=20

# Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=r'$C_{\ell>0}/C_0$')
xx = fobs*YR
# One curve per realization; never index past the realizations actually available.
for rr in range(min(nshow, ClC0.shape[1])):
    yy = ClC0[:,rr]
    ax.plot(xx, yy, color='tab:orange')
# ax.set_ylim(10**-6, 10**0)
plot._twin_hz(ax, nano=False)

# Back to dN/bin
## Fresh Set Up

$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{f}{4\pi \Delta f}   \int d \vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h^2 (f,\vec{\theta})   \bigg)^2 
+ \big( \frac{f}{4 \pi \Delta f}\big)^2 \int d\vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h^4 (f, \vec{\theta})
$$


In [None]:
# Observing duration and cadence, expressed as multiples of the YR constant.
dur, cad = 16.03*YR, 0.2*YR
# GW frequency bin centers and bin edges derived from that duration/cadence.
fobs_gw_cents = utils.nyquist_freqs(dur,cad)
fobs_gw_edges = utils.nyquist_freqs_edges(dur,cad)
# sam = holo.sam.Semi_Analytic_Model()
sam = holo.sam.Semi_Analytic_Model(shape=20)  # faster version


In [None]:
# Orbital frequencies are half the GW frequencies.
fobs_orb_cents = fobs_gw_cents/2.0
fobs_orb_edges = fobs_gw_edges/2.0
# hard = holo.hardening.Hard_GW()
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR)
# Differential number of binaries (and their redshifts) evaluated at the frequency-bin centers.
redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo)
# Grid edges in (total mass, mass ratio, redshift, orbital frequency).
edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
# Integrate the differential number to a number of binaries per bin.
number = holo.sam_cython.integrate_differential_number_3dx1d(edges, diff_num)

# Human-readable name of the hardening model, used later as a figure title.
if isinstance(hard, holo.hardening.Fixed_Time_2PL_SAM):
    hard_name = 'Fixed Time'
elif isinstance(hard, holo.hardening.Hard_GW):
    hard_name = 'GW Only'
else:
    # Fall back to the class name so `hard_name` is always defined for later cells.
    hard_name = type(hard).__name__

In [None]:
# Strain amplitude for every (M,Q,Z,F) bin, using the redshifts returned with the differential number.
hs = holo.gravwaves.strain_amp_from_bin_edges_redz(edges, redz_final)
# Both arrays are expected to share the same 4D bin shape.
print(hs.shape) # (M,Q,Z,F)
print(number.shape) # (M,Q,Z,F)

In [None]:
def Cl_analytic_from_num(fobs_orb_edges, number, hs, realize = False):
    """Evaluate C_0 and C_{l>0} with Eq. (17) of Sato-Polito & Kamionkowski.

    This first version works on the expected number of binaries per bin only;
    a later cell in this notebook redefines the same name with Poisson sampling.

    Parameters
    ----------
    fobs_orb_edges : 1D array
        Observed orbital frequency bin edges (one more entry than there are
        frequency bins).
    number : (M,Q,Z,F) NDarray
        Number of sources in each M,q,z,f bin.
    hs : (M,Q,Z,F) NDarray
        Strain amplitude of each M,q,z,f bin.
    realize : any
        Accepted for signature compatibility; not used by this version.

    Returns
    -------
    C0 : (F,) NDarray
        Monopole term: the shot-noise term plus the squared mean term.
    Cl : (F,) NDarray
        Shot-noise term, identical for every l>0.
    """
    widths = np.diff(fobs_orb_edges)                  # frequency bin widths
    centers = kale.utils.midpoints(fobs_orb_edges)    # frequency bin centers

    # Common prefactor f / (4 pi df) of both terms in Eq. (17).
    prefactor = centers / (4*np.pi * widths)

    # Sum over the mass, mass-ratio and redshift axes, leaving frequency.
    sum_h2 = np.sum(number*hs**2, axis=(0,1,2))
    sum_h4 = np.sum(number*hs**4, axis=(0,1,2))

    Cl = prefactor**2 * sum_h4
    C0 = Cl + (prefactor * sum_h2)**2

    return C0, Cl

C0, Cl = Cl_analytic_from_num(fobs_orb_edges, number, hs)

Need dN/dbin not just bin, so do this strain multiplication INSIDE the number function

In [None]:
nshow=20

# Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_HZ, ylabel=r'$C_{\ell>0}/C_0$')
xx = fobs_gw_cents
yy = Cl/C0 # (F,)
ax.plot(xx, yy, color='tab:orange', label='Eq (17) with cython number, full sam')
# ax.set_ylim(10**-6, 10**0)
plot._twin_yr(ax, nano=False)
ax.legend()

In [None]:
def Cl_analytic_from_num(fobs_orb_edges, number, hs, realize = False):
    """Evaluate C_0 and C_{l>0} with Eq. (17) of Sato-Polito & Kamionkowski,
    optionally Poisson sampling the number of binaries in every bin.

    Parameters
    ----------
    fobs_orb_edges : 1D array
        Observed orbital frequency bin edges.
    number : (M,Q,Z,F) NDarray
        Number of sources in each M,q,z,f bin.
    hs : (M,Q,Z,F) NDarray
        Strain amplitude of each M,q,z,f bin.
    realize : boolean or integer
        * value accepted by `utils.isinteger`: draw that many Poisson
          realizations; a trailing realization axis is added.
        * True: draw a single realization via `holo.gravwaves.poisson_as_needed`.
        * anything else (e.g. False): use `number` as given.

    Returns
    -------
    C0 : (F,R) or (F,) NDarray
        C_0
    Cl : (F,R) or (F,) NDarray
        C_{l>0}, the same for every l in the shot-noise approximation.
    """
    df = np.diff(fobs_orb_edges)                 # frequency bin widths
    fc = kale.utils.midpoints(fobs_orb_edges)    # frequency bin centers

    if utils.isinteger(realize):
        # One Poisson draw per bin and realization; shape becomes (M,Q,Z,F,R).
        sample_shape = number.shape + (realize,)
        number = np.random.poisson(number[..., np.newaxis], size=sample_shape)
        # Give the remaining arrays a trailing axis so they broadcast over realizations.
        df, fc, hs = (arr[..., np.newaxis] for arr in (df, fc, hs))
    elif realize is True:
        number = holo.gravwaves.poisson_as_needed(number)

    # Common prefactor f / (4 pi df) of both terms in Eq. (17).
    prefactor = fc/(4*np.pi*df)

    # Sum over the mass, mass-ratio and redshift axes.
    sum_h2 = np.sum(number*hs**2, axis=(0,1,2))
    sum_h4 = np.sum(number*hs**4, axis=(0,1,2))

    Cl = prefactor**2 * sum_h4
    C0 = Cl + (prefactor * sum_h2)**2

    return C0, Cl

C0, Cl = Cl_analytic_from_num(fobs_orb_edges, number, hs, realize=False)
C0_many, Cl_many = Cl_analytic_from_num(fobs_orb_edges, number, hs, realize=20)

In [None]:
nshow=20

# Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_HZ, ylabel=r'$C_{\ell>0}/C_0$')
xx = fobs_gw_cents
yy = Cl/C0 # (F,)
ax.plot(xx, yy, color='tab:grey', label='Eq (17) with cython number, full sam')

# One faint curve per realization; only the first carries the legend label.
# The loop is clamped so it never indexes past the realizations that exist.
for rr in range(min(nshow, Cl_many.shape[1])):
    label = 'Poisson number/bin realization' if rr == 0 else None
    ax.plot(xx, Cl_many[:,rr]/C0_many[:,rr], color='tab:orange', alpha=0.25, label=label)
# ax.set_ylim(10**-6, 10**0)
plot._twin_yr(ax, nano=False)
ax.legend()

In [None]:
# Shape check: the realized result carries a trailing realization axis (F,R).
print(Cl_many.shape)

# Plot everything

In [None]:
# Number of individual curves to show in the plots below.
nshow=20



def draw_analytic(ax, Cl, C0, fobs_gw_cents, color='tab:orange', label='Eq. 17 analytic', lw=2):
    """Draw the ratio Cl/C0 against frequency as one dash-dot curve.

    Parameters
    ----------
    ax : axes-like object providing `plot`
    Cl, C0 : (F,) arrays
        Numerator and denominator of the plotted ratio.
    fobs_gw_cents : (F,) array
        x-values (frequency bin centers).
    color, label, lw :
        Passed through to `ax.plot`.
    """
    ratio = Cl/C0 # (F,)
    ax.plot(fobs_gw_cents, ratio, linestyle='dashdot', lw=lw, color=color, label=label)

def draw_reals(ax, Cl_many, C0_many, fobs_gw_cents,  color='tab:orange', label= 'Poisson number/bin realization',
                show_ci=False, show_reals=True, show_median=False, nshow=20):
    """Draw the ratio Cl_many/C0_many for a set of realizations.

    Parameters
    ----------
    ax : axes-like object providing `plot` and `fill_between`
    Cl_many, C0_many : (F,R) arrays
        Numerator and denominator, one column per realization.
    fobs_gw_cents : (F,) array
        x-values (frequency bin centers).
    color : passed to every artist drawn.
    label : legend label attached to the first realization curve only.
    show_ci : bool
        Shade the central 50% and 98% intervals across realizations.
    show_reals : bool
        Draw individual realization curves.
    show_median : bool
        Draw the median across realizations.
    nshow : int
        Maximum number of individual realizations to draw.  This used to be
        read from a module-level variable; it is now an explicit argument
        and is clamped to the number of realizations available.
    """
    xx = fobs_gw_cents
    yy = Cl_many/C0_many # (F,R)

    if show_median:
        ax.plot(xx, np.median(yy, axis=-1), color=color)

    if show_ci:
        for pp in [50, 98]:
            half = pp/2
            # lower/upper edges of the central pp% interval across realizations
            lower, upper = np.percentile(yy, [50-half, 50+half], axis=-1)
            ax.fill_between(xx, lower, upper, color=color, alpha=0.1)

    if show_reals:
        for rr in range(min(nshow, yy.shape[-1])):
            if rr == 0:
                # only the first curve carries the legend label
                ax.plot(xx, yy[:,rr], color=color, alpha=0.15, linestyle='-',
                        label=label)
            else:
                ax.plot(xx, yy[:,rr], color=color, alpha=0.25, linestyle='-')

def draw_spk(ax, label='SP & K Rough Estimate'):
    """Draw three hard-coded reference points as a dashed lime-green line.

    The x-values are plotted exactly as written below; the previous version
    divided them by YR and multiplied by YR again before plotting, which had
    no effect other than floating-point rounding.
    NOTE(review): presumably frequencies in Hz (the callers label the axis
    in Hz) read off Sato-Polito & Kamionkowski — confirm the source.

    Parameters
    ----------
    ax : axes-like object providing `plot`
    label : legend label for the line
    """
    spk_xx = np.array([3.5e-9, 1.25e-8, 1e-7])
    spk_yy = np.array([1e-5, 1e-3, 1e-1])
    ax.plot(spk_xx, spk_yy, label=label, color='limegreen', ls='--')

def draw_bayes(ax, lmax, colors = ('k', 'b', 'r', 'g', 'c', 'm')):
    """Draw hard-coded C_l values, normalised by their first column, at five frequencies.

    Parameters
    ----------
    ax : axes-like object providing `plot`
    lmax : int
        Number of columns of the table to draw (one curve each).
    colors : sequence of str
        One color per curve.  An immutable default is used so the default
        cannot be modified between calls.

    Raises
    ------
    ValueError
        If `lmax` exceeds the number of table columns or of colors, raised
        before anything is drawn.

    NOTE(review): curve `ll` is column `ll` divided by column 0 but is
    labelled l = ll+1, so the curve labelled l=1 is identically 1 — confirm
    which multipole each column holds.
    """
    xx_Nihan = np.array([2.0, 4.0, 5.9, 7.9, 9.9]) *10**-9 # Hz

    # rows follow xx_Nihan (one per frequency), columns are the curves
    Cl_nihan = np.array([
    [0.20216773, 0.14690035, 0.09676646, 0.07453352, 0.05500382, 0.03177427],
    [0.21201336, 0.14884939, 0.10545698, 0.07734305, 0.05257189, 0.03090662],
    [0.20840993, 0.14836757, 0.09854803, 0.07205384, 0.05409881, 0.03305785],
    [0.19788951, 0.15765126, 0.09615489, 0.07475364, 0.0527356 , 0.03113331],
    [0.20182648, 0.14745265, 0.09681202, 0.0746824 , 0.05503161, 0.0317012 ]])

    navail = min(Cl_nihan.shape[1], len(colors))
    if lmax > navail:
        raise ValueError("`lmax` ({}) exceeds the {} curves available!".format(lmax, navail))

    for ll in range(lmax):
        ax.plot(xx_Nihan, Cl_nihan[:,ll]/Cl_nihan[:,0],
                    label = '$l=%d$' % (ll+1),
                color=colors[ll], marker='o', ms=8)
        
def draw_sim(ax, xx, Cl_best, lmax, nshow, show_ci=True, show_reals=True):
    """Draw C_l/C_0 (l = 1..lmax) for a set of best-ranked samples.

    For every sample the ratio is first reduced to its median over
    realizations; the median, confidence bands and individual curves drawn
    here are then taken across samples.

    Parameters
    ----------
    ax : axes-like object providing `plot` and `fill_between`
    xx : (F,) array
        x-values (frequencies).
    Cl_best : (B,F,R,l) array
        C_l for B samples, F frequencies, R realizations; index 0 of the
        last axis is C_0.
    lmax : int
        Number of multipoles l>0 to draw (at most six colors are defined).
    nshow : int
        Maximum number of individual samples to draw; clamped to B.
    show_ci : bool
        Shade the central 50% and 98% intervals across samples.
    show_reals : bool
        Draw the individual sample curves.
    """
    yy = Cl_best[:,:,:,1:]/Cl_best[:,:,:,0,np.newaxis] # (B,F,R,l)
    # Median over realizations, which live on axis 2.  The previous axis=-1
    # reduced over the multipole axis instead, so the indexing by `ll` below
    # selected realizations rather than multipoles.
    yy = np.median(yy, axis=2) # (B,F,l)

    nshow = min(nshow, yy.shape[0])   # never index past the available samples

    colors = ['k', 'b', 'r', 'g', 'c', 'm']
    for ll in range(lmax):
        ax.plot(xx, np.median(yy[:,:,ll], axis=0), color=colors[ll]) #, label='median of samples, $l=%d$' % ll)
        if show_ci:
            for pp in [50, 98]:
                half = pp/2
                lower, upper = np.percentile(yy[:,:,ll], [50-half, 50+half], axis=0)
                ax.fill_between(xx, lower, upper, alpha=0.1, color=colors[ll])
        if show_reals:
            for bb in range(nshow):
                ax.plot(xx, yy[bb,:,ll], color=colors[ll], linestyle=':', alpha=0.1,
                                 linewidth=1, label=None)

def plot_ClC0():
    """Plot C_{l>0}/C_0 versus GW frequency: analytic curve, Poisson
    realizations, the S-P & K reference line and the hard-coded Bayesian points.

    Reads the notebook-level variables `Cl`, `C0`, `Cl_many`, `C0_many` and
    `fobs_gw_cents`, and returns the figure.
    """
    # Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
    fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_HZ, ylabel=r'$C_{\ell>0}/C_0$')
    draw_analytic(ax, Cl, C0, fobs_gw_cents)
    draw_reals(ax, Cl_many, C0_many, fobs_gw_cents)
    draw_spk(ax)
    draw_bayes(ax, lmax=6)
    # ax.set_ylim(10**-6, 10**0)
    plot._twin_yr(ax, nano=False)
    # legend placed below the axes
    fig.legend(bbox_to_anchor=(0,-0.15), loc='upper left', bbox_transform = ax.transAxes, ncols=3)
    return fig

fig = plot_ClC0()

In [None]:
# sph_harm_file = np.load('/Users/emigardiner/GWs/holodeck/output/brc_output/ss51-2023-05-22_uniform_07a_n1000_r100_f40_l2000/anisotropy/sph_harm_lmax6_nside32_nbest100.npz')

# # load ss info
# shape = sph_harm_file['ss_shape']
# nsamps, nfreqs, nreals, nloudest = shape[0], shape[1], shape[2], shape[3]
# fobs = sph_harm_file['fobs']

# # load ranking info
# nsort = sph_harm_file['nsort']
# fidx = sph_harm_file['fidx']
# hc_tt = sph_harm_file['hc_tt']
# hc_ref15 = sph_harm_file['hc_ref15']

# # load harmonics info
# nside = sph_harm_file['nside']
# lmax  = sph_harm_file['lmax']
# moll_hc_best = sph_harm_file['moll_hc_best']
# Cl_best = sph_harm_file['Cl_best']
# nbest = len(moll_hc_best)

# sph_harm_file.close()

In [None]:
def plot_ClC0():
    """Same comparison plot as the previous cell, with the x-range limited.

    Reads the notebook-level variables `Cl`, `C0`, `Cl_many`, `C0_many`,
    `fobs_gw_cents` and `fobs`, and returns the figure.
    """
    # Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
    fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_HZ, ylabel=r'$C_{\ell>0}/C_0$')
    draw_analytic(ax, Cl, C0, fobs_gw_cents)
    draw_reals(ax, Cl_many, C0_many, fobs_gw_cents)
    draw_spk(ax)
    draw_bayes(ax, lmax=6)
    # draw_sim(ax, fobs_gw_cents, Cl_best, lmax, nshow=10)
    # ax.set_ylim(10**-6, 10**0)
    plot._twin_yr(ax, nano=False)
    # NOTE(review): the lower limit uses `fobs` read from the strain library file,
    # not `fobs_gw_cents` that the curves are drawn against — confirm this is intended.
    ax.set_xlim(fobs[0]- 10**(-10), 1/YR)

    # legend placed below the axes
    fig.legend(bbox_to_anchor=(0,-0.15), loc='upper left', bbox_transform = ax.transAxes, ncols=3)
    return fig

fig = plot_ClC0()

# From Dnum/dens

* dens = d^3 n / [dlog10M dq dz] in units of [Mpc^-3] 
= number density of binaries, per unit redshift, mass-ratio, and log10 of mass

* dnum = d^4N / dlog10M dq dz dlnf

* number = dN /dlnf


$$ C_\ell (f) = \delta_{\ell 0}\delta_{m0} \bigg( \frac{f}{4\pi \Delta f}   \int d \vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h^2 (f,\vec{\theta})   \bigg)^2 
+ \big( \frac{f}{4 \pi \Delta f}\big)^2 \int d\vec{\theta} \frac{d N_{\Delta f}}{d \vec{\theta}} h^4 (f, \vec{\theta})
$$


## Cl_analytic_from_dnum

In [None]:
def strain_amp_at_bin_edges(edges, redz=None):
    """Strain amplitude evaluated at the mass, mass-ratio and redshift grid
    edges and at the frequency-bin centers.

    Parameters
    ----------
    edges : list of four 1D arrays
        Grid edges; index 0 is used as total mass, 1 as mass ratio, 2 as
        redshift and the last as observer-frame orbital frequency.
    redz : array or None
        Redshifts to use for the distance and the rest-frame frequency.
        If None, the redshift edges `edges[2]` are used.
        NOTE(review): presumably shaped (M,Q,Z,F) on grid edges and frequency
        centers so that it broadcasts with the chirp mass — confirm.

    Returns
    -------
    hs_edges : NDarray
        Result of `utils.gw_strain_source` for the broadcast chirp mass,
        comoving distance and rest-frame frequency.
    """
    assert len(edges) == 4
    assert np.all([np.ndim(ee) == 1 for ee in edges])

    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    if redz is not None:
        # entries with redz <= 0 keep an infinite distance
        dc = +np.inf * np.ones_like(redz)
        sel = (redz > 0.0)
        dc[sel] = holo.cosmo.comoving_distance(redz[sel]).cgs.value
    else:
        redz = edges[2][np.newaxis,np.newaxis,:,np.newaxis]
        dc = holo.cosmo.comoving_distance(redz).cgs.value

    # ---- calculate GW strain ----
    mt = (edges[0])
    mr = (edges[1])
    mc = utils.chirp_mass_mtmr(mt[:, np.newaxis], mr[np.newaxis, :])
    mc = mc[:, :, np.newaxis, np.newaxis]

    # convert from observer-frame to rest-frame; still using frequency-bin centers
    fr = utils.frst_from_fobs(fc[np.newaxis, np.newaxis, np.newaxis, :], redz)

    hs_edges = utils.gw_strain_source(mc, dc, fr)
    return hs_edges

def strain_amp_at_bin_centers_redz(edges, redz=None):
    """Strain amplitude evaluated at the centers of the mass, mass-ratio,
    redshift and frequency bins.

    Parameters
    ----------
    edges : list of four 1D arrays
        Grid edges; index 0 is used as total mass, 1 as mass ratio, 2 as
        redshift and the last as observer-frame orbital frequency.
    redz : (M,Q,Z,Fc) array or None
        Redshifts on the M,Q,Z grid edges and frequency centers.  They are
        averaged to bin midpoints along the first three axes.  If None, the
        midpoints of the redshift edges `edges[2]` are used.

    Returns
    -------
    hs : NDarray
        Result of `utils.gw_strain_source` for the broadcast chirp mass,
        comoving distance and rest-frame frequency.
    """
    assert len(edges) == 4
    assert np.all([np.ndim(ee) == 1 for ee in edges])

    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    # redshifts are defined across 4D grid, shape (M, Q, Z, Fc)
    #    where M, Q, Z are edges and Fc is frequency centers
    # find midpoints of redshifts in M, Q, Z dimensions, to end up with (M-1, Q-1, Z-1, Fc)
    if redz is not None:
        for dd in range(3):
            redz = np.moveaxis(redz, dd, 0)
            redz = kale.utils.midpoints(redz, axis=0)
            redz = np.moveaxis(redz, 0, dd)
        # entries with redz <= 0 keep an infinite distance
        dc = +np.inf * np.ones_like(redz)
        sel = (redz > 0.0)
        dc[sel] = holo.cosmo.comoving_distance(redz[sel]).cgs.value
    else:
        redz = kale.utils.midpoints(edges[2])[np.newaxis,np.newaxis,:,np.newaxis]
        dc = holo.cosmo.comoving_distance(redz).cgs.value

    # ---- calculate GW strain ----
    mt = kale.utils.midpoints(edges[0])
    mr = kale.utils.midpoints(edges[1])
    mc = utils.chirp_mass_mtmr(mt[:, np.newaxis], mr[np.newaxis, :])
    mc = mc[:, :, np.newaxis, np.newaxis]

    # convert from observer-frame to rest-frame; still using frequency-bin centers
    fr = utils.frst_from_fobs(fc[np.newaxis, np.newaxis, np.newaxis, :], redz)

    hs = utils.gw_strain_source(mc, dc, fr)
    return hs


def Cl_analytic_from_dnum(edges, dnum, redz=None, realize=False):
    """Calculate C_0 and C_{l>0} using Eq. (17) of Sato-Polito & Kamionkowski,
    starting from the differential number of binaries.

    Parameters
    ----------
    edges : list of four 1D arrays
        Grid edges in total mass, mass ratio, redshift and observed orbital
        frequency (in that order).
    dnum : NDarray
        dN / [ dlog10M dq dz dlnf ] on the M,Q,Z grid edges.
        NOTE(review): presumably evaluated at the frequency-bin centers, as
        produced for `diff_num` earlier in this notebook — confirm.
    redz : NDarray or None
        Redshifts forwarded to the strain helpers; None uses the redshift
        grid itself.
    realize : False or integer
        False: integrate `dnum * hs**2` and `dnum * hs**4` over the grid.
        integer: Poisson sample that many realizations of the number of
        binaries per bin and use the bin-center strains.

    Returns
    -------
    C0 : (F,) or (F,R) NDarray
    Cl : (F,) or (F,R) NDarray
        C_{l>0}, the shot-noise term, identical for every l>0.

    Raises
    ------
    ValueError
        If `realize` is neither False nor an integer.
    """
    fobs_orb_edges = edges[-1]
    fobs_gw_edges = fobs_orb_edges * 2.0

    df = np.diff(fobs_orb_edges)                 #: frequency bin widths
    fc = kale.utils.midpoints(fobs_orb_edges)    #: use frequency-bin centers for strain (more accurate!)
    dlnf = np.diff(np.log(fobs_gw_edges))        #: logarithmic frequency bin widths, (F,)

    def _integrate_mqz(integrand):
        """Integrate over log10(M), q and z within each bin, then multiply by dln(f)."""
        # cumsum=False: per-bin integrals are needed because the bins are summed below.
        # NOTE(review): holo.utils.trapz is believed to return cumulative sums by
        # default — confirm against the installed holodeck version.
        out = utils.trapz(integrand, np.log10(edges[0]), axis=0, cumsum=False)
        out = utils.trapz(out, edges[1], axis=1, cumsum=False)
        out = utils.trapz(out, edges[2], axis=2, cumsum=False)
        # (F,) broadcasts against the trailing frequency axis
        return out * dlnf

    if realize is False:
        hs_edges = strain_amp_at_bin_edges(edges, redz)
        numh2 = _integrate_mqz(dnum*hs_edges**2)
        numh4 = _integrate_mqz(dnum*hs_edges**4)

    elif utils.isinteger(realize):
        # add reals axis
        hs_cents = strain_amp_at_bin_centers_redz(edges, redz)[...,np.newaxis]
        df = df[:,np.newaxis]
        fc = fc[:,np.newaxis]

        # `number` is used as the number of binaries per bin (it is Poisson
        # sampled directly), so it is not multiplied by dln(f) a second time;
        # this matches how Cl_analytic_from_num uses the same quantity.
        number = holo.sam_cython.integrate_differential_number_3dx1d(edges, dnum)
        shape = number.shape + (realize,)
        number = holo.gravwaves.poisson_as_needed(number[...,np.newaxis] * np.ones(shape))

        numh2 = number * hs_cents**2
        numh4 = number * hs_cents**4

    else:
        err = "`realize` ({}) must be one of {{False, integer}}!".format(realize)
        raise ValueError(err)

    # Common prefactor f / (4 pi df) of both terms in Eq. (17).
    prefactor = fc / (4*np.pi * df)

    # Sum over the mass, mass-ratio and redshift axes.
    delta_term = (prefactor * np.sum(numh2, axis=(0,1,2)))**2
    Cl = prefactor**2 * np.sum(numh4, axis=(0,1,2))

    C0 = Cl + delta_term

    return C0, Cl

C0_dnum, Cl_dnum = Cl_analytic_from_dnum(edges, diff_num)
C0_dnum_reals, Cl_dnum_reals = Cl_analytic_from_dnum(edges, diff_num, realize=10)
C0_redz, Cl_redz = Cl_analytic_from_dnum(edges, diff_num, redz_final)
C0_redz_reals, Cl_redz_reals = Cl_analytic_from_dnum(edges, diff_num, redz_final, realize=10)

In [None]:
# Scratch check: indexing a length-4 array with [:, np.newaxis] gives shape (4, 1).
arr = np.array([1,2,3,4,])
print(arr[:,np.newaxis].shape)

In [None]:
# Compare result shapes without and with the realization axis.
print(C0_dnum.shape)
print(C0_dnum_reals.shape)

note that Cl_best does not use the same model as the mockups for Sato-Polito method here!

In [None]:
def plot_ClC0():
    """Compare C_{l>0}/C_0 from the three calculations (integrated number,
    dnum with grid redshifts, dnum with `redz_final`), each with its Poisson
    realizations, plus the S-P & K line and the hard-coded Bayesian points.

    Reads the notebook-level result arrays, `fobs_gw_cents` and `fobs`, and
    returns the figure.
    """
    # Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
    fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_HZ, ylabel=r'$C_{\ell>0}/C_0$')
    draw_analytic(ax, Cl, C0, fobs_gw_cents, label='analytic from integrated num')
    draw_reals(ax, Cl_many, C0_many, fobs_gw_cents, label=None, color='tab:orange')

    draw_analytic(ax, Cl_dnum, C0_dnum, fobs_gw_cents, label='analytic from dnum, hs from z_init', color='deeppink')
    draw_reals(ax, Cl_dnum_reals, C0_dnum_reals, fobs_gw_cents, label=None, color='deeppink')

    draw_analytic(ax, Cl_redz, C0_redz, fobs_gw_cents, label='analytic from dnum, hs from z_final', color='indigo')
    draw_reals(ax, Cl_redz_reals, C0_redz_reals, fobs_gw_cents, label=None, color='indigo')

    draw_spk(ax, label='S-P & K')
    draw_bayes(ax, lmax=6)
    # draw_sim(ax, fobs_gw_cents, Cl_best, lmax, show_ci=True, show_reals=True, nshow=10)
    # ax.set_ylim(10**-6, 10**0)
    plot._twin_yr(ax, nano=False)
    # NOTE(review): the lower limit uses `fobs` read from the strain library file,
    # not `fobs_gw_cents` that the curves are drawn against — confirm this is intended.
    ax.set_xlim(fobs[0]- 10**(-10), 1/YR)

    # legend placed below the axes
    fig.legend(bbox_to_anchor=(0,-0.15), loc='upper left', bbox_transform = ax.transAxes, ncols=3)
    return fig

fig = plot_ClC0()
# suptitle takes a single text argument, so the hardening-model name and the
# SAM grid shape are combined into one string.
fig.suptitle('%s, %s' % (hard_name, sam.shape))

In [None]:
def plot_ClC0():
    """Compare C_{l>0}/C_0 from the three calculations, showing for each the
    analytic curve (thick), the median and confidence bands over realizations
    and the individual realizations, plus the S-P & K line and the hard-coded
    Bayesian points.

    Reads the notebook-level result arrays, `fobs_gw_cents` and `fobs`, and
    returns the figure.
    """
    # Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
    fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_HZ, ylabel=r'$C_{\ell>0}/C_0$')
    draw_analytic(ax, Cl, C0, fobs_gw_cents, label='analytic from integrated num',
                  lw=4)
    draw_reals(ax, Cl_many, C0_many, fobs_gw_cents, label=None, color='tab:orange',
               show_ci=True, show_median=True)

    draw_analytic(ax, Cl_dnum, C0_dnum, fobs_gw_cents, label='analytic from dnum, hs from z_init', color='deeppink',
                  lw=4)
    draw_reals(ax, Cl_dnum_reals, C0_dnum_reals, fobs_gw_cents, label=None, color='deeppink',
               show_ci=True, show_median=True)

    draw_analytic(ax, Cl_redz, C0_redz, fobs_gw_cents, label='analytic from dnum, hs from z_final', color='indigo',
                  lw=4)
    draw_reals(ax, Cl_redz_reals, C0_redz_reals, fobs_gw_cents, label=None, color='indigo',
               show_ci=True, show_median=True)

    draw_spk(ax, label='S-P & K')
    draw_bayes(ax, lmax=6)
    # draw_sim(ax, fobs_gw_cents, Cl_best, lmax, show_ci=False, show_reals=True, nshow=20)
    # ax.set_ylim(10**-6, 10**0)
    plot._twin_yr(ax, nano=False)
    # NOTE(review): the lower limit uses `fobs` read from the strain library file,
    # not `fobs_gw_cents` that the curves are drawn against — confirm this is intended.
    ax.set_xlim(fobs[0]- 10**(-10), 1/YR)

    # legend placed below the axes
    fig.legend(bbox_to_anchor=(0,-0.15), loc='upper left', bbox_transform = ax.transAxes, ncols=3)
    return fig

fig = plot_ClC0()

In [None]:
def plot_ClC0():
    """Compare C_{l>0}/C_0 from the three calculations only (no external
    reference curves), showing for each the analytic curve (thick), the
    median and confidence bands over realizations and the individual
    realizations.

    Reads the notebook-level result arrays, `fobs_gw_cents` and `fobs`, and
    returns the figure.
    """
    # Raw string so that '\ell' reaches matplotlib without Python treating '\e' as an escape.
    fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_HZ, ylabel=r'$C_{\ell>0}/C_0$')
    draw_analytic(ax, Cl, C0, fobs_gw_cents, label='analytic from integrated num',
                  lw=4)
    draw_reals(ax, Cl_many, C0_many, fobs_gw_cents, label=None, color='tab:orange',
               show_ci=True, show_median=True)

    draw_analytic(ax, Cl_dnum, C0_dnum, fobs_gw_cents, label='analytic from dnum, hs from z_init', color='deeppink',
                  lw=4)
    draw_reals(ax, Cl_dnum_reals, C0_dnum_reals, fobs_gw_cents, label=None, color='deeppink',
               show_ci=True, show_median=True)

    draw_analytic(ax, Cl_redz, C0_redz, fobs_gw_cents, label='analytic from dnum, hs from z_final', color='indigo',
                  lw=4)
    draw_reals(ax, Cl_redz_reals, C0_redz_reals, fobs_gw_cents, label=None, color='indigo',
               show_ci=True, show_median=True)
    plot._twin_yr(ax, nano=False)
    # NOTE(review): the lower limit uses `fobs` read from the strain library file,
    # not `fobs_gw_cents` that the curves are drawn against — confirm this is intended.
    ax.set_xlim(fobs[0]- 10**(-10), 1/YR)

    # legend placed below the axes, one entry per row
    fig.legend(bbox_to_anchor=(0,-0.15), loc='upper left', bbox_transform = ax.transAxes, ncols=1)
    return fig

fig = plot_ClC0()