In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import kalepy as kale

import h5py


from holodeck import plot, detstats, utils, cosmo
import holodeck.sams.cyutils as sam_cyutils
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC, GYR
import holodeck as holo

import hasasia.sim as hsim

In [None]:
# Model configuration: SAM grid resolution and the PTA observed-frequency bins
shape = 40
fobs_cents, fobs_edges = utils.pta_freqs()

# Semi-analytic model plus a hardening model with a fixed 3 Gyr time parameter
sam = holo.sams.Semi_Analytic_Model(shape=shape)
hard = holo.hardening.Fixed_Time_2PL_SAM(
    sam=sam,
    time=3 * GYR,
)


In [None]:
# Number of loudest single sources to keep, and number of realizations
nloudest = 5
nreals = 10

# Library result: single-source and background strains plus their parameters
hc_ss, hc_bg, sspar, bgpar = sam.gwb(
    fobs_edges,
    hard,
    realize=nreals,
    loudest=nloudest,
    params=True,
)

In [None]:
# Summary statistics of each background parameter, labelled via `sings.par_names`.
# NOTE(review): a later cell re-binds `bgpar` to a zero-filled buffer; re-running this
# cell after that prints the buffer, not the `sam.gwb` result.
for idx in range(len(bgpar)):
    print(f"{sings.par_names[idx]}, {utils.stats(bgpar[idx])}")

In [None]:
# Summary statistics of each single-source parameter, labelled via `sings.par_names`.
# NOTE(review): a later cell re-binds `sspar` to a zero-filled buffer; re-running this
# cell after that prints the buffer, not the `sam.gwb` result.
for idx in range(len(sspar)):
    print(f"{sings.par_names[idx]}, {utils.stats(sspar[idx])}")

In [None]:
# Orbital frequencies: half of the observed GW frequencies (GW frequency assumed = 2x orbital)
fobs_orb_cents, fobs_orb_edges = fobs_cents / 2, fobs_edges / 2

In [None]:
# Final redshift and differential binary number, evaluated at the orbital-frequency centres
redz_final, diff_num = sam_cyutils.dynamic_binary_number_at_fobs(fobs_orb_cents, sam, hard, holo.cosmo)

# Integrate over the mass / ratio / redshift grid bins and the orbital-frequency bins
edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
number = sam_cyutils.integrate_differential_number_3dx1d(edges, diff_num)

In [None]:
# Shape of the binned number array (M, Q, Z, F)
print(tuple(number.shape))

### Step through the internals of `ss_gws_redz()`

In [None]:
# Reproduce the set-up of ss_gws_redz() on this notebook's grids.
# NOTE(review): this binds `redz` to the same array object as `redz_final` (no copy), so any
# in-place write to `redz` before it is re-assigned would also modify `redz_final`.
redz = redz_final

# Bin centres of the total-mass, mass-ratio and initial-redshift grids (edges[0], [1], [2])
mt, mr, rz = [kale.utils.midpoints(grid_edges) for grid_edges in edges[:3]]

# h_s^2 * f/df for every bin (equivalent to hc^2)
h2fdf = holo.gravwaves.char_strain_sq_from_bin_edges_redz(edges, redz)

# Rank the (M, Q, Z) bins loudest -> quietest using only the first frequency bin,
# then split the flat ranks back into one index array per axis.
indices = np.argsort(-h2fdf[..., 0].flatten())
unraveled = np.array(np.unravel_index(indices, (mt.size, mr.size, rz.size)))
msort, qsort, zsort = unraveled

In [None]:
print(utils.stats(redz))

# -1 is tolerated as a flag value; any other negative final redshift is an error here
bad_redz = (redz < 0) & (redz != -1)
if bad_redz.any():
    raise ValueError(f"{np.sum(bad_redz)} redz < 0 and !=-1 found in redz, in ss_gws_redz()")

In [None]:
# Convert redz to bin centres along the first three (M, Q, Z) axes by taking midpoints
# along each axis in turn. The frequency axis is left alone, since redz_final was
# evaluated at the orbital-frequency centres.
# Start from `redz_final` (not the current `redz`) so that re-running this cell does not
# midpoint an already-midpointed array; that would shrink it by one more element per axis
# and break alignment with `number`. Because `midpoints` returns a new array, the result
# no longer aliases `redz_final`, and later in-place masking of `redz` leaves it untouched.
redz = redz_final
for dd in range(3):
    redz = np.moveaxis(redz, dd, 0)
    redz = kale.utils.midpoints(redz, axis=0)
    redz = np.moveaxis(redz, 0, dd)


In [None]:
print(utils.stats(redz))

# Presumably, averaging a -1 flag with valid neighbours can yield negatives other than -1;
# report them instead of raising.
neg_not_flag = (redz < 0) & (redz != -1)
if neg_not_flag.any():
    err = f"{neg_not_flag.sum()} redz < 0 and !=-1 found in redz, in ss_gws_redz()"
    print(err)

In [None]:
n_redz_cells = redz.size  # total number of (M, Q, Z, F) bins
print(n_redz_cells)

In [None]:
# Flag every non-positive (or NaN) final redshift with -1, and compute comoving distances
# in cm for the remaining bins; flagged bins keep an infinite distance.
valid = redz > 0.0
dcom_final = np.full_like(redz, np.inf)
redz[~valid] = -1.0
dcom_final[valid] = cosmo.comoving_distance(redz[valid]).cgs.value

# Sanity checks on the distances
if np.any(dcom_final < 0):
    print('dcom_final<0 found')
if np.isnan(dcom_final).any():
    print('nan dcom_final found')

In [None]:
for arr in (redz, dcom_final):
    print(utils.stats(arr))

In [None]:

# Orbital-frequency bins taken from the integration edges
fobs_orb_edges = edges[-1]
fobs_orb_cents = kale.utils.midpoints(fobs_orb_edges)

# Rest-frame orbital frequency per bin; the frequency axis is broadcast over (M, Q, Z)
fobs_orb_grid = fobs_orb_cents[None, None, None, :]
frst_orb_cents = utils.frst_from_fobs(fobs_orb_grid, redz)  # (M,Q,Z,F,), final


In [None]:

# Binary separation from the rest-frame orbital frequency; total mass is broadcast over (Q, Z, F)
mt_grid = mt[:, None, None, None]
sepa = utils.kepler_sepa_from_freq(mt_grid, frst_orb_cents)  # (M,Q,Z,F) in cm
# Angular separation from separation and comoving distance (both in cm)
angs = utils.angs_from_sepa(sepa, dcom_final, redz)  # (M,Q,Z,F)

for arr in (sepa, angs):
    print(utils.stats(arr))

In [None]:
shape = number.shape
print(shape)
M, Q, Z, F = shape
L, R = nloudest, nreals

# Output buffers for the Python re-implementation of the loudest-source loop.
# NOTE(review): this re-binds `sspar`/`bgpar`, discarding the arrays returned by `sam.gwb`.
hc2ss = np.zeros_like(hc_ss)
hc2bg = np.zeros_like(hc_bg)
bgpar = np.zeros((7, F, R))
sspar = np.zeros((4, F, R, L))

In [None]:
# Pure-Python replica of the loudest-source / background loop, for debugging.
# For every realization `rr` and frequency bin `ff`:
#   * walk the (M, Q, Z) bins from loudest to quietest (order from msort/qsort/zsort),
#   * Poisson-draw the number of binaries in each bin (global numpy RNG, unseeded),
#   * put individual binaries into the L loudest slots until they are filled,
#   * add every remaining binary to the background, accumulating h2fdf-weighted sums.
# Final redshifts come from the bin-centred, -1-masked `redz`, which shares its grid with
# `number`, `h2fdf`, `dcom_final`, `sepa` and `angs`. They do NOT come from `redz_final`,
# whose first three axes are still on the un-midpointed grid.

# Reset outputs so that re-running this cell never keeps slots filled by an earlier run.
hc2ss[...] = 0.0
hc2bg[...] = 0.0
sspar[...] = 0.0
bgpar[...] = 0.0

for rr in range(R):
    for ff in range(F):
        ll = 0  # next free slot in the loudest list (0 = loudest of all)
        # background strain sum and h2fdf-weighted parameter sums
        sum_bg = 0.0
        m_bg = q_bg = z_bg = zfinal_bg = 0.0
        dcom_bg = sepa_bg = angs_bg = 0.0
        for bb in range(M*Q*Z):  # bins, loudest to quietest
            mm, qq, zz = msort[bb], qsort[bb], zsort[bb]
            num = np.random.poisson(number[mm, qq, zz, ff])
            cur = h2fdf[mm, qq, zz, ff]  # h^2 * f/df of current bin
            zfin = redz[mm, qq, zz, ff]  # bin-centred final redshift

            while (ll < L) and (num > 0) and (cur > 0):
                # store strain and parameters of the ll-th loudest source
                hc2ss[ff, rr, ll] = cur
                sspar[0, ff, rr, ll] = mt[mm]
                sspar[1, ff, rr, ll] = mr[qq]
                sspar[2, ff, rr, ll] = rz[zz]
                sspar[3, ff, rr, ll] = zfin

                # -1 is the masking flag; any other negative final redshift is unexpected
                if zfin < 0 and zfin != -1:
                    err = f"redz[{mm},{qq},{zz},{ff}] = {zfin} < 0"
                    print("ERROR IN CYUTILS:", err)

                num -= 1
                ll += 1

            if cur > 0 and num > 0:
                weight = num * cur
                sum_bg += weight  # tot bg h2fdf
                m_bg += weight * mt[mm]
                q_bg += weight * mr[qq]
                z_bg += weight * rz[zz]
                zfinal_bg += weight * zfin
                dcom_bg += weight * dcom_final[mm, qq, zz, ff]
                sepa_bg += weight * sepa[mm, qq, zz, ff]
                angs_bg += weight * angs[mm, qq, zz, ff]

        if np.any(sspar[3, ff, rr, :] < 0):
            print(f"{ff=}, {rr=}, {sspar[3,ff,rr,:]=}")
        if ll < L:  # at least one of the L loudest slots stayed empty
            print(f'not enough loudest at {ff=}, {rr=}, {ll=} ')

        hc2bg[ff, rr] = sum_bg  # background strain (hc^2)
        # Background averages, rows: 0 mass, 1 ratio, 2 initial redshift, 3 final redshift,
        # 4 comoving distance, 5 separation, 6 angular separation (all after hardening for 3-6).
        if sum_bg > 0:
            bg_sums = np.array([m_bg, q_bg, z_bg, zfinal_bg, dcom_bg, sepa_bg, angs_bg])
            bgpar[:, ff, rr] = bg_sums / sum_bg
        else:
            # no background binaries: averages are undefined (avoids ZeroDivisionError)
            bgpar[:, ff, rr] = np.nan


In [None]:
# Loudest-source h2fdf values for frequency bin 27, first realization
print(hc2ss[27, 0])

In [None]:
# Final-redshift grid at frequency bin 27
print(redz_final[..., 27])

In [None]:
# Locate the loudness rank `bb` of grid bin (19, 19, 19)
for bb, (mm, qq, zz) in enumerate(zip(msort, qsort, zsort)):
    if (mm, qq, zz) == (19, 19, 19):
        print(f"{bb=}, {mm}, {qq}, {zz}") # this is the quietest bin!

In [None]:
# h2fdf across all frequencies for grid bin (19, 19, 19)
print(h2fdf[19, 19, 19, :])

In [None]:
# Number of (M, Q, Z) grid bins
print(Z * Q * M)

In [None]:
# NOTE(review): duplicates the output of the preceding cell; candidate for deletion
print(Q * M * Z)

In [None]:
# Number of ranked bins (should equal M*Q*Z)
print(msort.size)

In [None]:
# Grid indices of the last-ranked bin (quietest at the first frequency bin)
for sorted_idx in (msort, qsort, zsort):
    print(sorted_idx[-1])

In [None]:
# Report every (frequency, realization, loudest-slot) whose stored final redshift is negative
for ff, rr, ll in np.ndindex(F, nreals, nloudest):
    if sspar[3, ff, rr, ll] < 0:
        print(f"{ff=}, {rr=}, {ll=}")

In [None]:
# Values of the negative stored final redshifts
print(sspar[3][sspar[3] < 0])

In [None]:
# Summary statistics of the stored final redshifts
print(utils.stats(sspar[3, ...]))

In [None]:
# Sum of all negative stored final redshifts
print(sspar[3][sspar[3] < 0].sum())

In [None]:
# Full array of stored final redshifts, shape (F, R, L)
print(sspar[3, :, :, :])