In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys

import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sams
import holodeck.gravwaves
from holodeck import cosmo, utils, plot, discrete, sams, host_relations
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Apply the base style FIRST: `mpl.style.use('default')` resets rcParams to matplotlib's
# defaults, so calling it after the customizations below would silently undo all of them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# 'Times' belongs in the *serif* list because the active family is 'serif';
# 'DejaVu Serif' (bundled with matplotlib) is the fallback when Times is not installed.
mpl.rc('font', **{'family': 'serif', 'serif': ['Times', 'DejaVu Serif'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Keep a short alias to holodeck's logger and have it report INFO and above.
(log := holo.log).setLevel(logging.INFO)

In [None]:
import compare_discrete

In [None]:
# ---- Define the GWB frequencies
freqs, freqs_edges = utils.pta_freqs()
print(f"{freqs.shape[0]=}, {freqs_edges.shape[0]=}")
NFREQS = freqs.shape[0]
NREALS = 100    # number of GWB realizations (used for both the SAM and the discrete populations)
NLOUD = 1000    # number of loudest single sources kept per frequency bin and realization

# ---- Create discrete population(s)
all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops = compare_discrete.create_dpops(
    allow_mbh0=True, mod_mmbulge=False, skip_evo=False, fsa_only=True,
    nreals=NREALS, nloudest=NLOUD,
)
# Later cells index `all_fsa_dpops[0]` directly; fail here with a clear message instead.
assert len(all_fsa_dpops) > 0, "expected at least one FSA discrete population"

In [None]:
# Summarize the discrete FSA population that the SAM is compared against below.
print(f"For now, just using first element of `all_fsa_dpops` ({len(all_fsa_dpops)=})")
print(
    f"(nfreq, nloud, nreals): {all_fsa_dpops[0].gwb.sspar[0].shape}",  # sspar axes: nfreq, nloud, nreals
    f"Inspiral timescale for all binaries = {all_fsa_dpops[0].tau / GYR} Gyr",
    sep="\n",
)

In [None]:
# ---- SAM with an explicit galaxy pair fraction (GPF) model, hardened with the same
#      fixed inspiral time (`tau`) as the first discrete FSA population.
print("creating SAM using Galaxy Pair Fraction (GPF) + Galaxy Merger Timescale (GMT)...")
sam = sams.Semi_Analytic_Model(
    gpf=sams.GPF_Power_Law(),
)

print("    ...calculating hardening")
hard = holo.hardening.Fixed_Time_2PL_SAM(
    sam,
    all_fsa_dpops[0].tau,
    sepa_init=1.0e4 * PC,
)

print("    ...creating gwb")
# `params=True` additionally returns single-source and background binary parameters
gwb_sam = sam.gwb(freqs_edges, hard, realize=NREALS, loudest=NLOUD, params=True)

# `gwb_new` was the previous default SAM GWB; it has no parameter output (hence the
# `gwb(..., params=True)` call above). ***NOTE*** it returns hc, not hc^2.
gwb_new_sam = sam.gwb_new(freqs_edges, hard, realize=NREALS)

In [None]:
# Unpack the output of `sam.gwb(..., params=True)`; the commented layouts below are
# what the shape prints are expected to confirm.
print(len(gwb_sam))
print(*(len(gwb_sam[k]) for k in range(4)))
print(*(gwb_sam[k].shape for k in range(4)), sep="\n")

gwb_sam_hcss, gwb_sam_hcbg = gwb_sam[0], gwb_sam[1]      # [nfreqs, nreals, nloudest], [nfreqs, nreals]
# total strain: loudest sources and background added in quadrature -> [nfreqs, nreals]
gwb_sam_hctot = np.sqrt((gwb_sam_hcss**2).sum(axis=2) + gwb_sam_hcbg**2)
gwb_sam_sspars, gwb_sam_bgpars = gwb_sam[2], gwb_sam[3]  # [nsspars, nfreqs, nreals, nloudest], [nbgpars, nfreqs, nreals]

print(*(arr.shape for arr in (gwb_sam_hcss, gwb_sam_hcbg, gwb_sam_hctot,
                              gwb_sam_sspars, gwb_sam_bgpars)), sep="\n")

# sspars (loudest single sources):
#   sspar[0,ff,rr,ll] = mt[mm]
#   sspar[1,ff,rr,ll] = mr[qq]
#   sspar[2,ff,rr,ll] = rz[zz]
#   sspar[3,ff,rr,ll] = redz_final[mm,qq,zz,ff]
# bgpars (background averages):
#   bgpar[0,ff,rr] = m_bg/sum_bg          # bg avg mass
#   bgpar[1,ff,rr] = q_bg/sum_bg          # bg avg ratio
#   bgpar[2,ff,rr] = z_bg/sum_bg          # bg avg redshift
#   bgpar[3,ff,rr] = zfinal_bg/sum_bg     # bg avg redshift after hardening
#   bgpar[4,ff,rr] = dcom_bg/sum_bg       # bg avg comoving distance after hardening
#   bgpar[5,ff,rr] = sepa_bg/sum_bg       # bg avg binary separation after hardening
#   bgpar[6,ff,rr] = angs_bg/sum_bg       # bg avg binary angular separation after hardening


In [None]:
# ---- Compare binary parameters of the SAM (black) with the first discrete FSA population (cyan)
#   thin lines : background-averaged parameters per frequency bin, one line per realization (SAM)
#   points     : parameters of the loudest single sources per frequency bin, averaged over realizations
# Each loudest-source rank is one column of a single `plot` call per panel, instead of
# 2*NLOUD separate calls per panel; SAM points are drawn first, discrete points on top.
NLOUD_PLOT = NLOUD   # loudest sources to overplot per bin; lower (e.g. 10) for a faster, less cluttered figure
dpop_cmp = all_fsa_dpops[0]

fig, axs = plt.subplots(nrows=2, ncols=2, sharex=True, figsize=[10, 10])
print(dpop_cmp.gwb.sspar[3].shape)  # nfreq, nloud, nreals

# (axes, parameter index into bgpars/sspars, y-label, y divisor, extra `Axes.set` options)
param_panels = [
    (axs[0, 0], 0, 'Mtot', MSOL, dict(yscale='log')),        # total mass [Msol]
    (axs[0, 1], 1, 'q', 1.0, dict()),                        # mass ratio
    (axs[1, 0], 2, 'z_init', 1.0, dict(ylim=(0, 2.5))),      # initial redshift
    (axs[1, 1], 3, 'z_final', 1.0, dict(ylim=(0, 2.5))),     # redshift after hardening
]
for ax, ipar, ylabel, yunit, extra in param_panels:
    ax.set(xlabel='frequency', ylabel=ylabel, xscale='log', **extra)

    # SAM background averages: [nfreqs, nreals] -> one thin line per realization
    ax.plot(freqs, gwb_sam_bgpars[ipar] / yunit, 'k', lw=0.2, alpha=0.2)

    # SAM loudest sources: [nfreqs, nreals, nloudest] -> mean over realizations -> [nfreqs, nloud]
    sam_loud = np.mean(gwb_sam_sspars[ipar, :, :, :NLOUD_PLOT], axis=1) / yunit
    ax.plot(freqs, sam_loud, 'k.', alpha=0.5)

    # discrete loudest sources: [nfreq, nloud, nreals] -> mean over realizations -> [nfreq, nloud],
    # plotted against the population's own frequency array
    dpop_loud = np.mean(dpop_cmp.gwb.sspar[ipar][:, :NLOUD_PLOT, :], axis=2) / yunit
    ax.plot(dpop_cmp.freqs, dpop_loud, 'c.', alpha=0.5)

fig.suptitle(f"SAM (black) vs. {dpop_cmp.lbl} (cyan)")


In [None]:
# Sanity check: shapes of the `scafa`, `tlook`, `sepa` and `mass` arrays on each population's `evo`.
for d in all_dpops:
    print(
        f"{d.evo.scafa.shape=}, "
        f"{d.evo.tlook.shape=}, "
        f"{d.evo.sepa.shape=}, "
        f"{d.evo.mass.shape=}"
    )

In [None]:
# ---- "No-GPF" SAM: constructed without a `gpf` argument, hardened with the same fixed
#      inspiral time (`all_fsa_dpops[0].tau`) as the GPF SAM for a like-for-like comparison.
# The explicit "\n" puts the parenthetical on its own indented line, like the other
# "    ..." progress messages (implicit string concatenation previously ran them together).
print("creating 'no-GPF' SAM using Galaxy Merger Rate (GMR)\n"
      "    (uses galaxy merger rates directly from RG15 instead of GPF+GMT)...")

sam_no_gpf = sams.Semi_Analytic_Model()

print("    ...calculating hardening for no-GPF SAM")
hard_no_gpf = holo.hardening.Fixed_Time_2PL_SAM(sam_no_gpf, all_fsa_dpops[0].tau, sepa_init=1.0e4*PC)
print("    ...creating gwb for no-GPF SAM")
# NOTE: `gwb_new` returns hc (not hc^2)
gwb_new_sam_no_gpf = sam_no_gpf.gwb_new(freqs_edges, hard_no_gpf, realize=NREALS)

In [None]:

#def get_dpop_amplitudes_at_freq(d,f):
    
#    freqs = np.array([1/YR, 1/(3*YR), 1/(10*YR)])
#    amps = np.zeros_like(freqs)
    
#    for i,f in enumerate(freqs):
#        # ---- find frequency bins closest to chosen freqs
#        idx = np.where(np.abs(d.freqs-f)==np.abs(d.freqs-f).min())[0]
#        amps[i] = d.gwb.back[idx,:].flatten()

#    return amps
        
##for d in tng_dpops + tng_fsa_dpops: #all_sim_dpops + all_fsa_dpops:
##for d in all_hires_fid_and_fsa_dpops: 
##for d in all_fsa_dpops: 
#for d in all_dpops: 
#    d.get_amplitudes_at_freqs()
#for d in all_fsa_dpops: 
#    d.get_amplitudes_at_freqs()

### NOTE: this is sketchy -- it only works if `freqs` is the same for all dpops and for the SAM
#ayr_sam = gwb_sam[d.idx_ayr,:].flatten()
#a3yr_sam = gwb_sam[d.idx_a3yr,:].flatten()
#a10yr_sam = gwb_sam[d.idx_a10yr,:].flatten()

#ayr_sam_no_gpf = gwb_sam_no_gpf[d.idx_ayr,:].flatten()
#a3yr_sam_no_gpf = gwb_sam_no_gpf[d.idx_a3yr,:].flatten()
#a10yr_sam_no_gpf = gwb_sam_no_gpf[d.idx_a10yr,:].flatten()


In [None]:
# ---- Distribution of the total GWB strain at a few reference frequencies:
#      SAM variants vs. each discrete FSA population.
freqs_to_plot = np.array([1/YR, 1/(3*YR), 1/(10*YR)])
fig, axs = plt.subplots(nrows=freqs_to_plot.size, ncols=1, sharex=True, figsize=[10, 10],
                        squeeze=False)
axs = axs[:, 0]   # `squeeze=False` keeps `axs` indexable even for a single frequency

for i, f in enumerate(freqs_to_plot):
    # nearest SAM frequency bin to the target frequency
    idx = np.argmin(np.abs(freqs - f))
    amp_sam = gwb_sam_hctot[idx, :].flatten()
    amp_new_sam = gwb_new_sam[idx, :].flatten()
    amp_new_sam_no_gpf = gwb_new_sam_no_gpf[idx, :].flatten()

    ax = axs[i]
    if i == freqs_to_plot.size - 1:
        # each panel is a different frequency, so label the generic strain rather than A_yr
        ax.set(xlabel=r'$\log_{10}(h_c)$')
    ax.set(ylabel='Probability Density')
    ax.grid(alpha=0.2)
    kale.dist1d(np.log10(amp_sam), density=True, hist=False, confidence=True, carpet=False,
                lw=4, color='k', label='SAM', ax=ax)
    kale.dist1d(np.log10(amp_new_sam), density=True, hist=False, confidence=True, carpet=False,
                lw=3, color='k', label='SAM new', ax=ax)
    kale.dist1d(np.log10(amp_new_sam_no_gpf), density=True, hist=False, confidence=True, carpet=False,
                lw=4, color='orchid', label='SAM (no GPF)', ax=ax)

    for d in all_fsa_dpops:
        # Compare whole arrays: elementwise `!=` on arrays of different lengths raises or
        # returns a bare scalar depending on the NumPy version.
        if np.array_equal(d.freqs, freqs):
            idx_sim = idx
        else:
            print("WARNING: SAM and sim have different frequency arrays.")
            idx_sim = np.argmin(np.abs(d.freqs - f))
        amp_sim = d.gwb.strain[idx_sim, :].flatten()   # `strain` is equivalent to `both`

        if np.any(~np.isfinite(amp_sim)):
            print(f"skipping {d.lbl} ({amp_sim.min()}, {amp_sim.max()})")
        else:
            print(f"{d.lbl} ({amp_sim.min()}, {amp_sim.max()})")
            kale.dist1d(np.log10(amp_sim), density=True, hist=False, confidence=False, carpet=False,
                        label=d.lbl, lw=d.lw, color=d.color, alpha=0.5, ax=ax)

    # f*YR is in units of 1/yr
    ax.set(title=f"Comparison of GWB amplitudes at f={f*YR:.1g} yr$^{{-1}}$")
    ax.legend()

fig.tight_layout()

In [None]:
def __plot_gwb(fobs, gwb, hc_ss=None, bglabel=None, sslabel=None, **kwargs):
    """Plot a GWB characteristic-strain spectrum on a new figure.

    Parameters
    ----------
    fobs : array_like
        GW frequencies; multiplied by `YR` for a 1/yr x-axis (presumably given in Hz).
    gwb : ndarray
        Background strain; indexed as [nfreqs, nreals] by `__draw_gwb` -- TODO confirm
        for the `hc_ss` path.
    hc_ss : ndarray or None
        Loudest single-source strains; when given, sources and background are drawn together.
    bglabel, sslabel : str or None
        Labels for background and single sources (only used when `hc_ss` is given).
    **kwargs
        Forwarded to the drawing routine.

    Returns
    -------
    fig : matplotlib Figure
    """
    xx = fobs * YR
    # Helpers are resolved explicitly: bare `figax`, `draw_gwb`, `draw_ss_and_gwb` and
    # `_twin_hz` are not defined in this notebook and raised NameError on call.
    # The LABEL_* constants are module-level globals defined in a later cell and are
    # looked up when this function is called.
    fig, ax = plot.figax(
        xlabel=LABEL_GW_FREQUENCY_YR,
        ylabel=LABEL_CHARACTERISTIC_STRAIN
    )
    if hc_ss is not None:
        # NOTE(review): assumes `holodeck.plot` provides `draw_ss_and_gwb` -- confirm.
        plot.draw_ss_and_gwb(ax, xx, hc_ss, gwb, sslabel=sslabel,
                             bglabel=bglabel, **kwargs)
    else:
        __draw_gwb(ax, xx, gwb, **kwargs)
    # NOTE(review): assumes `holodeck.plot` provides `_twin_hz` (secondary Hz axis) -- confirm;
    # it is skipped when absent rather than failing the whole plot.
    twin_hz = getattr(plot, '_twin_hz', None)
    if twin_hz is not None:
        twin_hz(ax)
    return fig

def __draw_gwb(ax, xx, gwb, nsamp=10, color=None, label=None, **kwargs):
    """Draw the median and confidence band of a GWB spectrum, plus optional sample curves.

    Parameters
    ----------
    ax : matplotlib Axes
    xx : ndarray
        x-values, one per row of `gwb`.
    gwb : ndarray
        Strain samples indexed as [nfreqs, nreals] (columns are realizations).
    nsamp : int or None
        Number of randomly chosen realizations to overplot as thin lines (0/None: none).
    color : color or None
        Color of all curves; the next color of the axes' cycle if None.
    label : str or None
        Legend label of the median line.
    **kwargs
        `plot` (dict of median-line options) is consumed here; everything else
        (e.g. `fracs`, `lw`, `ls`) is forwarded to `__draw_med_conf`.

    Returns
    -------
    The `(line, fill)` tuple returned by `__draw_med_conf`.
    """
    if color is None:
        # NOTE: private matplotlib API; advances the axes' color cycle
        color = ax._get_lines.get_next_color()

    # copy so `setdefault` below never writes into a dict owned by the caller
    kw_plot = dict(kwargs.pop('plot', None) or {})
    kw_plot.setdefault('color', color)
    hh = __draw_med_conf(ax, xx, gwb, plot=kw_plot, label=label, **kwargs)

    if (nsamp is not None) and (nsamp > 0):
        nsamp_max = gwb.shape[1]
        idx = np.random.choice(nsamp_max, min(nsamp, nsamp_max), replace=False)
        for ii in idx:
            ax.plot(xx, gwb[:, ii], color=color, alpha=0.25, lw=1.0, ls='-')

    return hh

def __draw_med_conf(ax, xx, vals, fracs=(0.50, 0.90), weights=None, plot=None,
                    fill=None, filter=False, label=None, lw=1.0, ls='-'):
    """Draw the median of `vals` as a line, with shaded confidence intervals around it.

    Parameters
    ----------
    ax : matplotlib Axes
    xx : ndarray
        x-values, one per row of `vals`.
    vals : ndarray
        Samples; quantiles are taken along the last axis (e.g. [nfreqs, nreals]).
    fracs : float or sequence of float
        Widths of the confidence intervals in [0, 1], each centered on the median
        (e.g. 0.68 -> 16th to 84th percentile).
    weights : array_like or None
        Passed to `kale.utils.quantiles`.
        NOTE(review): with `filter=True` these are not filtered alongside the values.
    plot, fill : dict or None
        Extra keyword arguments for the median line (`ax.plot`) and the bands
        (`ax.fill_between`). They are copied, never modified.
    filter : bool
        If True, only strictly positive values of each row enter that row's quantiles.
    label, lw, ls
        Label, line width and line style of the median line.

    Returns
    -------
    (hh, gg)
        The median line artist and the artist of the last band drawn (``None`` if no
        band was drawn).
    """
    # Fresh dicts on every call: mutable `{}` defaults would be shared between calls,
    # and `setdefault` would otherwise also write into the caller's dicts.
    plot = dict(plot or {})
    fill = dict(fill or {})
    plot.setdefault('alpha', 0.75)
    fill.setdefault('alpha', 0.2)

    percs = np.atleast_1d(fracs)
    assert np.all((0.0 <= percs) & (percs <= 1.0))

    # median first, then each fraction centered on 50% as a (lo, hi) pair, e.g. 0.68 ==> 0.16, 0.84
    inter_percs = [0.5] + [float(bound) for pp in percs for bound in (0.5 - pp/2, 0.5 + pp/2)]

    # Get percentiles; they go along the last axis
    if filter:
        rv = np.asarray([
            kale.utils.quantiles(vv[vv > 0.0], percs=inter_percs, weights=weights)
            for vv in vals
        ])
    else:
        rv = kale.utils.quantiles(vals, percs=inter_percs, weights=weights, axis=-1)

    med, *conf = rv.T
    # plot median
    hh, = ax.plot(xx, med, **plot, lw=lw, ls=ls, label=label)

    # Reshape confidence intervals to nice plotting shape: 2*P, X ==> (P, 2, X)
    conf = np.array(conf).reshape(len(percs), 2, xx.size)

    # bands default to the median line's color; explicit `fill` options take precedence
    fill_kw = dict(color=hh.get_color())
    fill_kw.update(fill)

    gg = None
    for lo, hi in conf:
        gg = ax.fill_between(xx, lo, hi, **fill_kw)

    return (hh, gg)

In [None]:
LABEL_GW_FREQUENCY_YR = r"GW Frequency $[\mathrm{yr}^{-1}]$"
LABEL_GW_FREQUENCY_HZ = r"GW Frequency $[\mathrm{Hz}]$"
LABEL_GW_FREQUENCY_NHZ = r"GW Frequency $[\mathrm{nHz}]$"
LABEL_SEPARATION_PC = r"Binary Separation $[\mathrm{pc}]$"
LABEL_CHARACTERISTIC_STRAIN = r"GW Characteristic Strain"
LABEL_HARDENING_TIME = r"Hardening Time $[\mathrm{Gyr}]$"
LABEL_CLC0 = r"$C_\ell / C_0$"

# ---- Median GWB spectra (with a `frac` confidence band) of the SAMs and discrete populations
fig, ax = plot.figax(
    xlabel=LABEL_GW_FREQUENCY_YR,
    ylabel=LABEL_CHARACTERISTIC_STRAIN
)

frac = 0.50   # width of the confidence band drawn around each median

# SAM spectra (both on the SAM frequency grid)
xx = freqs * YR
__draw_gwb(ax, xx, gwb_sam_hctot, nsamp=0, color='k', label='SAM gwb', lw=3, fracs=[frac])
__draw_gwb(ax, xx, gwb_new_sam, nsamp=0, color='b', label='SAM gwb new', lw=3, fracs=[frac])

# Discrete populations: the first (up to) two of each list, each on its own frequency grid.
# (population list, colors in order, line style)
dpop_styles = [
    (all_dpops, ('g', 'm'), dict(lw=1.5)),
    (all_fsa_dpops, ('darkgreen', 'purple'), dict(lw=2, ls='--')),
]
for dpops, colors, line_style in dpop_styles:
    for d, color in zip(dpops, colors):   # zip stops at the shorter sequence
        print(f"{d.lbl}: {d.freqs.size=}, {d.gwb.both.shape=}")
        __draw_gwb(ax, d.freqs * YR, d.gwb.both, nsamp=0, color=color,
                   label=d.lbl, fracs=[frac], **line_style)

ax.legend(loc='lower left', fontsize=12)
fig.savefig('gwb_compare_ill_tng100_main.png', dpi=300)

In [None]:
# GWB spectrum of the GPF SAM from the older `gwb_new` method (which returns hc, not hc^2)
fig = plot.plot_gwb(freqs, gwb_new_sam)

In [None]:
# GWB spectrum of the GPF SAM: loudest sources and background combined in quadrature
fig = plot.plot_gwb(freqs, gwb_sam_hctot)

In [None]:
# GWB spectrum of the no-GPF SAM from the `gwb_new` method
fig = plot.plot_gwb(freqs, gwb_new_sam_no_gpf)