In [1]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys

import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sams
import holodeck.gravwaves
from holodeck import cosmo, utils, plot, discrete, sams, host_relations
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Apply the base style FIRST: `style.use('default')` resets rcParams to their defaults,
# so calling it after the `mpl.rc(...)` lines (as before) silently discarded all of them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): the family is 'serif' but the font list is given for 'sans-serif';
# confirm whether `'serif': ['Times']` was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Use holodeck's logger at INFO level for the rest of the notebook.
log = holo.log
log.setLevel(logging.INFO)

In [2]:
import compare_discrete

In [3]:
# Build the discrete populations; the call returns a 4-item result that is unpacked below.
tmp = compare_discrete.create_dpops(
    allow_mbh0=True,
    mod_mmbulge=False,
    skip_evo=False,
    fsa_only=True,
    nreals=10,
)
(all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops) = tmp

Setting inspiral timescale tau = 1.0 Gyr.
Setting fixed init binary sep = 10000.0 pc.

Creating Discrete_Pop class instance 'fsa-mm-TNG100-1' with tau=3.15576e+16, fixed_sepa=3.0856775814913676e+22
 fname=galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-001.hdf5
fname = /Users/lblecha/holodeck/holodeck/data/galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-001.hdf5
DEBUG: in population: self.scafa.min()=0.116345263, self.scafa.max()=1.0, np.median(self.scafa)=0.374999594
num with scafa=1: 11
gal_rads.min()=1.326763564100145e+21, gal_rads.max()=7.282709693743754e+23
self.sepa.min()=3.0856775814913676e+22, self.sepa.max()=3.0856775814913676e+22
No zero-mass BHs found in this merger tree file!
self._use_mstar_tot_as_mbulge=False
self.mbulge.min()=1.3306179568569139e+41, self.mbulge.max()=2.5837386273508947e+45
self.mstar_tot.min()=2.2618726543725402e+41, self.mstar_tot.max()=3.765812268004308e+45
No zero-mass BHs found in this merger tree file!
sample volume = 3.987e+79 [cgs] = 1.357e+

GW frequencies:   0%|          | 0/40 [00:00<?, ?it/s]


DEBUG: fobs_gw=1.976798990270053e-09

*** DEBUG *** mrat.min()=0.00022569315672513172, 0.999960413329724
temp.shape=(3246,), gwb_harms.shape=(1,), num_pois.shape=(3246, 10), both.shape=(10,), temp_to_sort.shape=(3246, 10)
idx_loud.shape=(3246, 10), loud.shape=(3246, 10)
mchirp.shape=(3246,), redz.shape=(3246,), (3246, 10)=
mchirp_loud.shape=(5, 10), redz_loud.shape=(5, 10)
loud.shape=(5, 10), mchirp_loud.shape=(5, 10), redz_loud.shape=(5, 10)
mchirp: min=1.035e+06, max=2.493e+09, med=4.355e+06
mchirp_loud: min=6.334e+08, max=2.493e+09, med=9.604e+08
mpri: min=1.191e+06, max=7.378e+09, med=1.07e+07
mpri_loud: min=7.305e+08, max=6.972e+09, med=2.417e+09
mrat: min=0.0002257, max=1, med=0.5436
mrat_loud: min=0.03719, max=0.9921, med=0.5552
redz: min=5.071e-05, max=4.895, med=1.398
redz_loud: min=0.04593, max=0.6924, med=0.1663
sepa: min=0.002031, max=0.0869, med=0.007048
sepa_loud: min=0.05259, max=0.08539, med=0.06003

DEBUG: fobs_gw=3.953597980540106e-09

*** DEBUG *** mrat.min()=0.0002

In [4]:
# Number of populations in `all_fsa_dpops`; later cells index it as `all_fsa_dpops[0]`.
print(len(all_fsa_dpops))

1


In [5]:
# ---- Define the GWB frequencies
# `freqs` is the x-axis of the plots below; `freqs_edges` is what gets passed
# to `sam.gwb` / `sam.gwb_new`.
freqs, freqs_edges = utils.pta_freqs()

In [6]:
print("creating SAM using Galaxy Pair Fraction (GPF) + Galaxy Merger Timescale (GMT)...")

# Passing a galaxy pair fraction selects the GPF + GMT model named in the message above.
sam = sams.Semi_Analytic_Model(gpf=sams.GPF_Power_Law())

print("    ...calculating hardening")
# Reuse the inspiral timescale of the first FSA population; start binaries at 1e4 * PC.
hard = holo.hardening.Fixed_Time_2PL_SAM(
    sam,
    all_fsa_dpops[0].tau,
    sepa_init=1.0e4*PC,
)

print("    ...creating gwb")
#gwb_sam = sam.gwb_new(freqs_edges, hard, realize=50)
# This call returns four parts; their shapes are printed in a later cell.
gwb_sam = sam.gwb(
    freqs_edges,
    hard,
    realize=50,
    loudest=10,
    params=True,
)


creating SAM using Galaxy Pair Fraction (GPF) + Galaxy Merger Timescale (GMT)...
16:07:46 INFO : Galaxy pair-fraction provided, using galaxy pair-fraction and merger-time. [sam.py:__init__]
    ...calculating hardening
    ...creating gwb

 *** in _static_binary_density() *** 

** using GPF in _ndens_gal() **
GPF_USES_MTOT=False, GMT_USES_MTOT=False, GSMF_USES_MTOT=False
self._gpf(mass_gpf, mstar_rat, redz).shape=(91, 81, 101), self._gpf(mass_gpf, mstar_rat, redz).min()=0.03335999733439941, self._gpf(mass_gpf, mstar_rat, redz).max()=0.22698277091741006
gal_merger_rate.shape=(91, 81, 101), gal_merger_rate.max()=4.3373311313467944e-17, gal_merger_rate.min()=1.922987182177158e-18
16:07:47 INFO : Adding MMbulge scatter (2.8000e-01) [sam.py:static_binary_density]
16:07:47 INFO : 	dens bef: (5.35e-282, 1.12e-124, 4.81e-26, 1.36e-03, 1.59e-02, 2.90e-02, 6.20e-02) [sam.py:static_binary_density]
16:08:06 INFO : Scatter added after 19.040027 sec [sam.py:static_binary_density]
16:08:06 INFO : 	de

In [7]:
# `gwb_sam` is a multi-part result, not a single array: this prints how many parts it has.
print(len(gwb_sam))

4


In [8]:
# First line: length of the leading axis of each of the four parts of `gwb_sam`.
print(*(len(gwb_sam[k]) for k in range(4))) #, len(gwb_sam_new))
# Then the full shape of each part, one per line:
#   gwb_sam[0] -- hc2ss:  [nfreqs, nreals, nloudest]
#   gwb_sam[1] -- hc2bg:  [nfreqs, nreals]
#   gwb_sam[2] -- sspars: [nsspars, nfreqs, nreals, nloudest]
#   gwb_sam[3] -- bgpars: [nbgpars, nfreqs, nreals]
print(*(gwb_sam[k].shape for k in range(4)), sep='\n')
# sspars:
# sspar[0,ff,rr,ll] = mt[mm]
# sspar[1,ff,rr,ll] = mr[qq]
# sspar[2,ff,rr,ll] = rz[zz]
# sspar[3,ff,rr,ll] = redz_final[mm,qq,zz,ff]
# bgpars:
# bgpar[0,ff,rr] = m_bg/sum_bg # bg avg mass
# bgpar[1,ff,rr] = q_bg/sum_bg # bg avg ratio
# bgpar[2,ff,rr] = z_bg/sum_bg # bg avg redshift
# bgpar[3,ff,rr] = zfinal_bg/sum_bg # bg avg redshift after hardening
# bgpar[4,ff,rr] = dcom_bg/sum_bg # bg avg comoving distance after hardening
# bgpar[5,ff,rr] = sepa_bg/sum_bg # bg avg binary separation after hardening
# bgpar[6,ff,rr] = angs_bg/sum_bg # bg avg binary angular separation after hardening


40 40 4 7
(40, 50, 10)
(40, 50)
(4, 40, 50, 10)
(7, 40, 50)


In [None]:
# Scratch check of list repetition: `*` repeats the reference, so all four
# entries are the same (3, 2, 4) array object.
foo = 4 * [np.zeros((3, 2, 4))]
print(f"{len(foo)} {foo[0].shape}")

In [None]:
# Same SAM and hardening model as `gwb_sam`, but with loudest=100 instead of 10.
gwb_sam_L100 = sam.gwb(
    freqs_edges,
    hard,
    realize=50,
    loudest=100,
    params=True,
)
# Same SAM and hardening model run through `gwb_new`; later cells use this as a plain array.
gwb_sam_new = sam.gwb_new(
    freqs_edges,
    hard,
    realize=50,
)

In [None]:
# Total mass of the loudest single sources vs. the background average, per frequency bin.
# Per the layout notes in the shape-printing cell, index 0 of `gwb_sam[2]` (sspars) and of
# `gwb_sam[3]` (bgpars) is the mass entry.
fig, ax = plt.subplots()
ax.set(
    xscale='log',
    yscale='log',
    xlabel=r'GW Frequency $[\mathrm{Hz}]$',   # presumably `freqs` is in Hz -- confirm
    ylabel=r'Total Mass $[M_\odot]$',
)
# Loop over however many loudest sources were stored rather than hard-coding 10,
# so this cell keeps working if `loudest` is changed where `gwb_sam` is built.
num_loudest = gwb_sam[2].shape[-1]
for ll in range(num_loudest):
    # mtot of the ll-th loudest source in each freq bin, avg over nreals
    ax.plot(freqs, np.mean(gwb_sam[2][0,:,:,ll], axis=1)/MSOL, '.')
# avg mtot of bg sources for each freq for each real
ax.plot(freqs, gwb_sam[3][0,:,:]/MSOL, 'k', lw=0.5, alpha=0.1);

In [None]:
# Print the shapes of the evolution arrays (`d.evo`) of every population in `all_dpops`.
# The `{expr=}` f-string form prints the expression text as well as its value.
for d in all_dpops:
    print(f"{d.evo.scafa.shape=}, {d.evo.tlook.shape=}, {d.evo.sepa.shape=}, {d.evo.mass.shape=}")

In [None]:
print("creating 'no-GPF' SAM using Galaxy Merger Rate (GMR) "
      "    (uses galaxy merger rates directly from RG15 instead of GPF+GMT)...")

# No `gpf` argument here, unlike the GPF+GMT model built earlier.
sam_no_gpf = sams.Semi_Analytic_Model()

print("    ...calculating hardening for no-GPF SAM")
# Same timescale and starting separation as the hardening model used for `sam`.
hard_no_gpf = holo.hardening.Fixed_Time_2PL_SAM(
    sam_no_gpf,
    all_fsa_dpops[0].tau,
    sepa_init=1.0e4*PC,
)

print("    ...creating gwb for no-GPF SAM")
gwb_sam_no_gpf = sam_no_gpf.gwb_new(
    freqs_edges,
    hard_no_gpf,
    realize=50,
)

In [None]:

#def get_dpop_amplitudes_at_freq(d,f):
    
#    freqs = np.array([1/YR, 1/(3*YR), 1/(10*YR)])
#    amps = np.zeros_like(freqs)
    
#    for i,f in enumerate(freqs):
#        # ---- find frequency bins closest to chosen freqs
#        idx = np.where(np.abs(d.freqs-f)==np.abs(d.freqs-f).min())[0]
#        amps[i] = d.gwb.back[idx,:].flatten()

#    return amps
        
##for d in tng_dpops + tng_fsa_dpops: #all_sim_dpops + all_fsa_dpops:
##for d in all_hires_fid_and_fsa_dpops: 
##for d in all_fsa_dpops: 
#for d in all_dpops: 
#    d.get_amplitudes_at_freqs()
#for d in all_fsa_dpops: 
#    d.get_amplitudes_at_freqs()

### this is sketchy: it only works if freqs is the same for all dpops and for the SAM
#ayr_sam = gwb_sam[d.idx_ayr,:].flatten()
#a3yr_sam = gwb_sam[d.idx_a3yr,:].flatten()
#a10yr_sam = gwb_sam[d.idx_a10yr,:].flatten()

#ayr_sam_no_gpf = gwb_sam_no_gpf[d.idx_ayr,:].flatten()
#a3yr_sam_no_gpf = gwb_sam_no_gpf[d.idx_a3yr,:].flatten()
#a10yr_sam_no_gpf = gwb_sam_no_gpf[d.idx_a10yr,:].flatten()


In [None]:
# Distribution of GWB amplitudes over realizations, SAM vs. discrete populations,
# at the frequency bins closest to 1/yr, 1/(3 yr) and 1/(10 yr) (one panel each).
fig, axs = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=[10,10])

freqs_to_plot = np.array([1/YR, 1/(3*YR), 1/(10*YR)])

for i, f in enumerate(freqs_to_plot):
    # Index of the SAM frequency bin closest to the target frequency.
    idx = np.argmin(np.abs(freqs - f))
    # `gwb_sam` is the 4-part result of `sam.gwb(..., params=True)` and cannot be sliced as
    # `gwb_sam[idx, :]`. Use the `gwb_new` array for the same SAM instead; the no-GPF
    # comparison also comes from `gwb_new`.
    # assumes both arrays are [nfreqs, nreals] -- TODO confirm
    amp_sam = gwb_sam_new[idx, :].flatten()
    amp_sam_no_gpf = gwb_sam_no_gpf[idx, :].flatten()

    axs[i].set(ylabel='Probability Density')
    axs[i].grid(alpha=0.2)
    kale.dist1d(np.log10(amp_sam), density=True, hist=False, confidence=True, carpet=False,
                lw=4, color='k', label='SAM', ax=axs[i])
    kale.dist1d(np.log10(amp_sam_no_gpf), density=True, hist=False, confidence=True, carpet=False,
                lw=4, color='orchid', label='SAM (no GPF)', ax=axs[i])

    for d in all_fsa_dpops:
        # `array_equal` also copes with arrays of different length, where an
        # elementwise `!=` does not give a meaningful answer.
        if np.array_equal(d.freqs, freqs):
            idx_sim = idx
        else:
            print("WARNING: SAM and sim have different frequency arrays.")
            idx_sim = np.argmin(np.abs(d.freqs - f))
        amp_sim = d.gwb.both[idx_sim, :].flatten()

        if np.any(~np.isfinite(amp_sim)):
            print(f"skipping {d.lbl} ({amp_sim.min()}, {amp_sim.max()})")
        else:
            print(f"{d.lbl} ({amp_sim.min()}, {amp_sim.max()})")
            kale.dist1d(np.log10(amp_sim), density=True, hist=False, confidence=False, carpet=False,
                        label=d.lbl, lw=d.lw, color=d.color, alpha=0.5, ax=axs[i])

    # `f*YR` is the frequency in units of 1/yr, so label it as such (was "YR").
    axs[i].set(title=rf"Comparison of GWB amplitudes at f={f*YR:.1g} yr$^{{-1}}$")
    axs[i].legend()

# The x-axis is shared, so only the bottom panel is labelled. Each panel is the amplitude at
# its own frequency, so use a generic strain label rather than A_yr (which only describes
# the 1/yr panel).
axs[-1].set(xlabel=r'$\log_{10}(h_c)$')

In [None]:
def __plot_gwb(fobs, gwb, hc_ss=None, bglabel=None, sslabel=None, **kwargs):
    """Plot a GWB spectrum, optionally with single sources, against frequency in 1/yr.

    Notebook-local variant of `holo.plot.plot_gwb`. The helpers it needs were previously
    called by bare name (`figax`, `draw_ss_and_gwb`, `draw_gwb`, `_twin_hz`), none of which
    exist in the notebook namespace, so any call raised NameError. They are now resolved
    through the imported `plot` module, or through the notebook-local `__draw_gwb`.

    The `LABEL_*` constants are defined in a separate notebook cell, which must have been
    run before this function is called.

    Parameters
    ----------
    fobs : array
        Frequencies; multiplied by `YR` to form the x-axis values.
    gwb : array
        Background spectrum handed to the drawing helper.
    hc_ss : array or None
        If given, single sources are drawn together with the background.
    bglabel, sslabel : str or None
        Legend labels, only used when `hc_ss` is given.
    **kwargs
        Forwarded to the drawing helper.

    Returns
    -------
    fig
        The figure created by `plot.figax`.
    """
    xx = fobs * YR
    fig, ax = plot.figax(
        xlabel=LABEL_GW_FREQUENCY_YR,
        ylabel=LABEL_CHARACTERISTIC_STRAIN
    )
    if hc_ss is not None:
        # NOTE(review): assumes `holodeck.plot` exposes `draw_ss_and_gwb` -- confirm.
        plot.draw_ss_and_gwb(ax, xx, hc_ss, gwb, sslabel=sslabel,
                             bglabel=bglabel, **kwargs)
    else:
        # Use the notebook-local drawing helper defined alongside this function.
        __draw_gwb(ax, xx, gwb, **kwargs)
    # NOTE(review): assumes `holodeck.plot` exposes `_twin_hz` -- confirm.
    plot._twin_hz(ax)
    return fig

def __draw_gwb(ax, xx, gwb, nsamp=10, color=None, label=None, **kwargs):
    """Draw a GWB spectrum as median + intervals, plus a few individual realizations.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    xx : x-values.
    gwb : 2D array whose columns (second axis) are individual realizations.
    nsamp : int or None
        How many randomly chosen columns of `gwb` to overplot as thin lines;
        `None` or a non-positive value draws none.
    color : color or None
        Line color; if `None`, the next color of the axes' color cycle is taken.
    label : legend label for the median line.
    **kwargs
        An optional `plot` dict is extracted and used as line keywords; everything
        else is forwarded to `__draw_med_conf`.

    Returns
    -------
    Whatever `__draw_med_conf` returns.
    """
    if color is None:
        color = ax._get_lines.get_next_color()

    line_kw = kwargs.pop('plot', {})
    line_kw.setdefault('color', color)
    handles = __draw_med_conf(ax, xx, gwb, plot=line_kw, label=label, **kwargs)

    # No individual realizations requested: done.
    if nsamp is None or not (nsamp > 0):
        return handles

    num_reals = gwb.shape[1]
    chosen = np.random.choice(num_reals, np.min([nsamp, num_reals]), replace=False)
    for real in chosen:
        ax.plot(xx, gwb[:, real], color=color, alpha=0.25, lw=1.0, ls='-')

    return handles

def __draw_med_conf(ax, xx, vals, fracs=(0.50, 0.90), weights=None, plot=None,
                    fill=None, filter=False, label=None, lw=1.0, ls='-'):
    """Draw the median of `vals` as a line and shade central intervals around it.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    xx : ndarray
        x-values; `xx.size` is used to reshape the interval bounds.
    vals : array
        Values to summarise; quantiles are taken along the last axis
        (presumably shaped (xx.size, nreals) -- confirm against callers).
    fracs : scalar or sequence of scalars
        Central fractions in [0, 1], e.g. 0.68 ==> the 16%-84% interval.
        May be empty, in which case only the median is drawn.
    weights : passed through to `kale.utils.quantiles`.
    plot : dict or None
        Extra keywords for the median `ax.plot` call ('alpha' defaults to 0.75).
        Must not contain `lw`, `ls` or `label`, which are passed separately.
    fill : dict or None
        Extra keywords for `ax.fill_between` ('alpha' defaults to 0.2; 'color'
        defaults to the color of the median line).
    filter : bool
        If True, only the strictly positive entries of each row of `vals` are used.
    label, lw, ls : label, line width and line style of the median line.

    Returns
    -------
    (hh, gg) : the median line handle and the handle of the last shaded interval
        (`None` when `fracs` is empty).
    """
    # Work on copies: the defaults used to be shared `{}` objects that `setdefault`
    # mutated in place, and a dict passed in by the caller was modified as well.
    plot = {} if plot is None else dict(plot)
    fill = {} if fill is None else dict(fill)
    plot.setdefault('alpha', 0.75)
    fill.setdefault('alpha', 0.2)
    percs = np.atleast_1d(fracs)
    assert np.all((0.0 <= percs) & (percs <= 1.0))

    # Median (50%) first, then each target fraction centred as a pair around 50%,
    # e.g.  68 ==> [16,84].  Built with a plain loop so an empty `fracs` is valid.
    inter_percs = [0.5]
    for pp in percs:
        inter_percs.extend([float(0.5 - pp/2), float(0.5 + pp/2)])

    # Get percentiles; they go along the last axis
    if filter:
        rv = [
            kale.utils.quantiles(vv[vv > 0.0], percs=inter_percs, weights=weights)
            for vv in vals
        ]
        rv = np.asarray(rv)
    else:
        rv = kale.utils.quantiles(vals, percs=inter_percs, weights=weights, axis=-1)

    med, *conf = rv.T
    # plot median
    hh, = ax.plot(xx, med, **plot, lw=lw, ls=ls, label=label)

    # Reshape confidence intervals to nice plotting shape
    # 2*P, X ==> (P, 2, X)
    conf = np.array(conf).reshape(len(percs), 2, xx.size)

    # Shaded regions take the median line's color unless `fill` overrides it.
    kw_fill = dict(color=hh.get_color())
    kw_fill.update(fill)

    # plot each confidence interval; `gg` stays None if there are none
    gg = None
    for lo, hi in conf:
        gg = ax.fill_between(xx, lo, hi, **kw_fill)

    return (hh, gg)

In [None]:
# Axis-label strings for the figures. `LABEL_GW_FREQUENCY_YR` and
# `LABEL_CHARACTERISTIC_STRAIN` are the ones used by the plotting code in this notebook.
LABEL_GW_FREQUENCY_YR = r"GW Frequency $[\mathrm{yr}^{-1}]$"
LABEL_GW_FREQUENCY_HZ = r"GW Frequency $[\mathrm{Hz}]$"
LABEL_GW_FREQUENCY_NHZ = r"GW Frequency $[\mathrm{nHz}]$"
LABEL_SEPARATION_PC = r"Binary Separation $[\mathrm{pc}]$"
LABEL_CHARACTERISTIC_STRAIN = r"GW Characteristic Strain"
LABEL_HARDENING_TIME = r"Hardening Time $[\mathrm{Gyr}]$"
LABEL_CLC0 = r"$C_\ell / C_0$"

# Compare the SAM GWB spectrum with the discrete populations (median + central interval).
fig, ax = plot.figax(
    xlabel=LABEL_GW_FREQUENCY_YR,
    ylabel=LABEL_CHARACTERISTIC_STRAIN
)

frac = 0.50

# `gwb_sam` is the 4-part result of `sam.gwb(..., params=True)`: it has no `.shape` and
# cannot be passed where a single array is expected. Draw the `gwb_new` array for the
# same SAM instead.
xx = freqs * YR
__draw_gwb(ax, xx, gwb_sam_new, nsamp=0, color='k', label='SAM', lw=3, fracs=[frac])

# (populations, colors, line style) -- `zip` with the two colors limits each group to
# its first two populations, matching the previous hand-written stanzas.
dpop_groups = [
    (all_dpops, ['g', 'm'], dict(lw=1.5)),
    (all_fsa_dpops, ['darkgreen', 'purple'], dict(lw=2, ls='--')),
]
for dpops, colors, style in dpop_groups:
    for d, color in zip(dpops, colors):
        # Each population is drawn against its own frequency grid.
        xx = d.freqs * YR
        print(f"{d.lbl}: {xx.size=}, {d.gwb.both.shape=}, {gwb_sam_new.shape=}")
        __draw_gwb(ax, xx, d.gwb.both, nsamp=0, color=color,
                   label=d.lbl, fracs=[frac], **style)

ax.legend(loc='lower left', fontsize=12)
fig.savefig('gwb_compare_ill_tng100_main.png', dpi=300)

In [None]:
# GWB spectrum of the GPF SAM (`sam`), as computed by `sam.gwb_new`.
fig = holo.plot.plot_gwb(freqs, gwb_sam_new)

In [None]:
# `gwb_sam` is a 4-part result (single sources, background, and two parameter arrays),
# so it cannot be passed whole as the `gwb` array. Pass the background as `gwb` and the
# loudest single sources as `hc_ss`.
# NOTE(review): the shape-printing cell labels these parts hc2ss / hc2bg; confirm they
# are strains rather than squared strains before reading values off this figure.
fig = holo.plot.plot_gwb(freqs, gwb_sam[1], hc_ss=gwb_sam[0])

In [None]:
# GWB spectrum of the no-GPF SAM (`sam_no_gpf`), as computed by `gwb_new`.
fig = holo.plot.plot_gwb(freqs, gwb_sam_no_gpf)