In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import holodeck as holo
import holodeck.anisotropy as anis
from holodeck.constants import YR, MSOL, GYR
from holodeck import utils

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import kalepy as kale

# Calculate and store useful variables for a SAM + hardening model

In [None]:
def sam_model(sam, hard,
        dur=16.03*YR, cad=0.2*YR, use_redz=True):
    """Evolve a SAM population with a hardening model and bundle the results.

    Parameters
    ----------
    sam : holo.sam.Semi_Analytic_Model
        Semi-analytic model supplying the ``mtot``, ``mrat`` and ``redz`` grid edges.
    hard : holo.hardening.Fixed_Time_2PL_SAM or holo.hardening.Hard_GW
        Binary hardening model.
    dur : float
        Observing duration, in the same time units as ``YR``; passed to the
        Nyquist-frequency helpers.
    cad : float
        Observing cadence, in the same time units as ``YR``.
    use_redz : bool
        If True, strain amplitudes are computed with the ``redz_final`` values
        returned by the evolution; otherwise the ``anis`` strain functions are
        called without a redshift argument (presumably initial-grid redshifts -- confirm).

    Returns
    -------
    vals : dict
        The inputs, grid ``edges``, binned ``number``, ``diff_num``, ``redz_final``,
        strains ``hs_cents``/``hs_edges``, GW and orbital frequency centers/edges,
        and a human-readable ``hard_name``.

    Raises
    ------
    TypeError
        If ``hard`` is not one of the two supported hardening classes.
    """
    fobs_gw_cents = utils.nyquist_freqs(dur, cad)
    fobs_gw_edges = utils.nyquist_freqs_edges(dur, cad)
    # orbital frequencies are taken as half the GW frequencies
    fobs_orb_cents = fobs_gw_cents / 2.0
    fobs_orb_edges = fobs_gw_edges / 2.0

    if isinstance(hard, holo.hardening.Fixed_Time_2PL_SAM):
        hard_name = 'Fixed Time'
    elif isinstance(hard, holo.hardening.Hard_GW):
        hard_name = 'GW Only'
    else:
        # TypeError subclasses Exception, so existing ``except Exception`` handlers still catch it
        raise TypeError("'hard' must be an instance of 'Fixed_Time_2PL_SAM' or 'Hard_GW'")

    redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, holo.cosmo)
    edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
    number = holo.sam_cython.integrate_differential_number_3dx1d(edges, diff_num)
    if use_redz:
        hs_cents = anis.strain_amp_at_bin_centers_redz(edges, redz_final)
        hs_edges = anis.strain_amp_at_bin_edges_redz(edges, redz_final)
    else:
        hs_cents = anis.strain_amp_at_bin_centers_redz(edges)
        hs_edges = anis.strain_amp_at_bin_edges_redz(edges)

    vals = {
        'hard':hard, 'sam':sam, 'edges':edges, 'number': number, 'diff_num':diff_num, 'redz_final':redz_final,
        'hs_cents':hs_cents, 'hs_edges':hs_edges, 'fobs_gw_cents':fobs_gw_cents, 'fobs_gw_edges':fobs_gw_edges,
        'fobs_orb_cents':fobs_orb_cents, 'fobs_orb_edges':fobs_orb_edges, 'hard_name':hard_name
    }
    return vals


def strain_amp_at_bin_centers_redz(edges, redz=None):
    """ Calculate strain amplitude at bin centers, with final or initial redz.

    Notebook-local counterpart of ``anis.strain_amp_at_bin_centers_redz``
    (same name and signature).

    Parameters
    ----------
    edges : list of four 1D arrays
        Grid edges ``[mtot, mrat, redz, fobs_orb]``; the last entry should be
        observer-frame orbital-frequency edges.
    redz : ndarray or None
        Redshifts on the 4D grid ``(M, Q, Z, Fc)``: edges in M, Q, Z and
        frequency-bin centers. If None, the midpoints of ``edges[2]`` are used.

    Returns
    -------
    hs : ndarray
        Strain amplitude from ``utils.gw_strain_source``, shape ``(M-1, Q-1, Z-1, Fc)``.
    """
    assert len(edges) == 4
    assert np.all([np.ndim(ee) == 1 for ee in edges])

    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    # redshifts are defined across 4D grid, shape (M, Q, Z, Fc)
    #    where M, Q, Z are edges and Fc is frequency centers
    # find midpoints of redshifts in M, Q, Z dimensions, to end up with (M-1, Q-1, Z-1, Fc)
    if redz is not None:
        for dd in range(3):
            redz = np.moveaxis(redz, dd, 0)
            redz = kale.utils.midpoints(redz, axis=0)
            redz = np.moveaxis(redz, 0, dd)
        # non-positive redshifts get an infinite comoving distance
        # (NOTE(review): presumably flags binaries that never reach this frequency -- confirm)
        dc = +np.inf * np.ones_like(redz)
        sel = (redz > 0.0)
        dc[sel] = holo.cosmo.comoving_distance(redz[sel]).cgs.value
    else:
        # initial redshifts: bin centers of the z-grid, broadcast along the Z axis only
        redz = kale.utils.midpoints(edges[2])[np.newaxis, np.newaxis, :, np.newaxis]
        dc = holo.cosmo.comoving_distance(redz).cgs.value

    # ---- calculate GW strain ----
    mt = kale.utils.midpoints(edges[0])
    mr = kale.utils.midpoints(edges[1])
    mc = utils.chirp_mass_mtmr(mt[:, np.newaxis], mr[np.newaxis, :])
    mc = mc[:, :, np.newaxis, np.newaxis]

    # convert from observer-frame to rest-frame; still using frequency-bin centers
    fr = utils.frst_from_fobs(fc[np.newaxis, np.newaxis, np.newaxis, :], redz)

    hs = utils.gw_strain_source(mc, dc, fr)
    return hs

In [None]:
# Build a SAM on a 20-point grid, harden it with Fixed_Time_2PL_SAM (3 Gyr argument), and evolve it
sam = holo.sam.Semi_Analytic_Model(shape=20)
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 3 * GYR)
vals = sam_model(sam=sam, hard=hard)


# Print diagnostic info for the anomalous ("weird") `rz_final` values

In [None]:
mm, qq, zz, ff = 15, 18, 15, 28 # weird spot

fobs_orb_cents = vals['fobs_orb_cents']
# Reuse the evolution ``sam_model`` already ran with these same inputs
# (fobs_orb_cents, sam, hard, holo.cosmo) instead of recomputing it.
redz_final = vals['redz_final']
diff_num = vals['diff_num']

hs_final = holo.gravwaves.strain_amp_from_bin_edges_redz(vals['edges'], redz_final)
hs_init = anis.strain_amp_at_bin_centers_redz(vals['edges'], redz=None)
hs_final_anis = strain_amp_at_bin_centers_redz(vals['edges'], redz=vals['redz_final'])
# check, I am calculating hs correctly (tolerance-based: exact float equality is
# sensitive to operation order between the two implementations)
print(np.allclose(hs_final, hs_final_anis))

print('hs_init:', utils.stats(hs_init), f", {hs_init.shape=}")
print('hs_final:', utils.stats(hs_final), f", {hs_final.shape=}")

rz_init = kale.utils.midpoints(sam.redz, axis=0)
# average final redshifts over the M, q, z edge axes -> (M-1, Q-1, Z-1, Fc), matching hs
rz_final = redz_final
for dd in range(3):
    rz_final = np.moveaxis(rz_final, dd, 0)
    rz_final = kale.utils.midpoints(rz_final, axis=0)
    rz_final = np.moveaxis(rz_final, 0, dd)

print('M(mm=%d)=%.2e M_sol, q(qq=%d)=%.2e, f_obs,orb(ff=%d)=%.2f/yr'
      % (mm, utils.midpoints(vals['edges'][0])[mm]/MSOL,
         qq, utils.midpoints(vals['edges'][1])[qq],
         ff, fobs_orb_cents[ff]*YR))
# inspect the weird z-bin and its two neighbours
for iz in (zz - 1, zz, zz + 1):
    print('zz=%d, z_init=%.2f, z_final=%.2f, hs_init=%.2e, hs_final=%.2e'
          % (iz, rz_init[iz], rz_final[mm,qq,iz,ff], hs_init[mm,qq,iz,ff], hs_final[mm,qq,iz,ff]))

# Plots of the weird spot

In [None]:
def _trapz_over(dnum, xx, axis):
    """Trapezoid-integrate ``dnum`` over grid ``xx`` along ``axis`` with ``utils.trapz(..., cumsum=False)``."""
    return utils.trapz(dnum, xx, axis=axis, cumsum=False)


def integrate_mm(dnum, edges):
    """Integrate dN/dlog10(M) over total mass (axis 0), using log10 of the mass edges."""
    return _trapz_over(dnum, np.log10(edges[0]), 0)


def integrate_qq(dnum, edges):
    """Integrate dN/dq over mass ratio (axis 1)."""
    return _trapz_over(dnum, edges[1], 1)


def integrate_zz(dnum, edges):
    """Integrate dN/dz over redshift (axis 2)."""
    return _trapz_over(dnum, edges[2], 2)


def integrate_ff(dnum, fobs_gw_edges):
    """Multiply dN/dln(f) by each frequency bin's width in ln(f); no summation is done."""
    dlnf = np.diff(np.log(fobs_gw_edges))
    return dnum * dlnf

In [None]:
def calc_integrated(vals):
    """Compute h_s^2- and h_s^4-weighted binary numbers, one parameter (M, q, z) at a time.

    For each parameter, ``diff_num`` is integrated over the other two parameters
    and multiplied by the frequency-bin widths (``integrate_*`` helpers). The
    result is weighted by ``hs_edges`` reduced with
    ``utils.midpoints_multiax(..., log=True)`` over those same two axes, raised
    to the 2nd or 4th power. It is then integrated with ``utils.trapz``
    (``cumsum=False``) along the remaining parameter.

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``; reads 'diff_num', 'edges', 'fobs_gw_edges',
        'hs_cents', 'hs_edges' and 'number'.

    Returns
    -------
    tuple
        ``(xx, numh2_mm, numh2_qq, numh2_zz, numh4_mm, numh4_qq, numh4_zz,
        numh2_cy, numh4_cy, numh2_fl, numh4_fl)``.

        - ``xx``: stacked log-midpoints of the M [Msol], q and z edges.
        - ``*_cy``: ``vals['number']`` times ``hs_cents`` squared / to the 4th power.
        - ``*_fl``: the same products, using ``floor(vals['number'])`` instead.
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    hs_cents = vals['hs_cents']
    hs_edges = vals['hs_edges']

    cents_mm, cents_qq, cents_zz = [utils.midpoints(edges[ax], log=True) for ax in (0, 1, 2)]

    def _hs_weighted(dnum, integrate, other_axes):
        # strain reduced along the axes already integrated out of ``dnum``
        hs_mid = utils.midpoints_multiax(hs_edges, axis=other_axes, log=True)
        return integrate(dnum * hs_mid**2, edges), integrate(dnum * hs_mid**4, edges)

    dnum_mm = integrate_ff(integrate_zz(integrate_qq(diff_num, edges), edges), fobs_gw_edges)
    numh2_mm, numh4_mm = _hs_weighted(dnum_mm, integrate_mm, (1, 2))

    dnum_qq = integrate_ff(integrate_zz(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    numh2_qq, numh4_qq = _hs_weighted(dnum_qq, integrate_qq, (0, 2))

    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    numh2_zz, numh4_zz = _hs_weighted(dnum_zz, integrate_zz, (0, 1))

    number = vals['number']
    numh2_cy = number * hs_cents**2
    numh4_cy = number * hs_cents**4

    number_floor = np.floor(number)
    numh2_fl = number_floor * hs_cents**2
    numh4_fl = number_floor * hs_cents**4

    xx = np.array([cents_mm / MSOL, cents_qq, cents_zz])

    return (xx, numh2_mm, numh2_qq, numh2_zz, numh4_mm, numh4_qq, numh4_zz,
            numh2_cy, numh4_cy, numh2_fl, numh4_fl)



def plot_integrated_vary(vals, mm_arr, qq_arr, zz_arr, ff_arr, vary='ff'):
    """Plot h_s^2- and h_s^4-weighted binary numbers against M, q and z for selected bins.

    Figure layout:

    - Top row: thick translucent lines of ``numh2_cy`` and dashed lines of the
      integrated ``numh2_{mm,qq,zz}``.
    - Bottom row: circles for ``numh4_cy`` and crosses for ``numh4_{mm,qq,zz}``.

    One set of curves is drawn for every combination of the given bin indices.

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``.
    mm_arr, qq_arr, zz_arr, ff_arr : sequence of int
        Bin indices in total mass, mass ratio, redshift and frequency.
    vary : {'mm', 'qq', 'zz', 'ff'}
        Which index list the colours follow; any other value falls back to 'ff'.

    Returns
    -------
    fig : matplotlib Figure
    title : str
    """
    # colour by position within the list selected by ``vary`` (default: frequency)
    vary_arr = {'mm': mm_arr, 'qq': qq_arr, 'zz': zz_arr}.get(vary, ff_arr)
    colors = cm.rainbow_r(np.linspace(0, 1, len(vary_arr)))

    (xx, numh2_mm, numh2_qq, numh2_zz, numh4_mm, numh4_qq, numh4_zz,
            numh2_cy, numh4_cy, numh2_fl, numh4_fl) = calc_integrated(vals)

    fig, axs = holo.plot.figax(
        nrows=2, ncols=3, figsize=(16,8))

    # raw strings: '\o', '\i', '\l' are invalid escape sequences in normal literals
    xlabels = np.array([r'M ($M_\odot$)', 'q', 'z'])
    ylabels_h2 = np.array([r'$\int h_s^2 dN/d\log (M)$', r'$\int h_s^2 dN/dq$', r'$\int h_s^2 dN/dz$'])
    ylabels_h4 = np.array([r'$\int h_s^4 dN/d\log (M)$', r'$\int h_s^4 dN/dq$', r'$\int h_s^4 dN/dz$'])
    fobs_gw_cents = vals['fobs_gw_cents']

    for mi, mm in enumerate(mm_arr):
        for qi, qq in enumerate(qq_arr):
            for zi, zz in enumerate(zz_arr):
                for fi, ff in enumerate(ff_arr):
                    color = colors[{'mm': mi, 'qq': qi, 'zz': zi}.get(vary, fi)]
                    # first combination (by position, so duplicate values don't relabel)
                    first = (mi == 0 and qi == 0 and zi == 0 and fi == 0)

                    yy_h2 = np.array([numh2_mm[:,qq,zz,ff], numh2_qq[mm,:,zz,ff], numh2_zz[mm,qq,:,ff]])
                    yy_h4 = np.array([numh4_mm[:,qq,zz,ff], numh4_qq[mm,:,zz,ff], numh4_zz[mm,qq,:,ff]])
                    cy_h2 = np.array([numh2_cy[:,qq,zz,ff], numh2_cy[mm,:,zz,ff], numh2_cy[mm,qq,:,ff]])
                    cy_h4 = np.array([numh4_cy[:,qq,zz,ff], numh4_cy[mm,:,zz,ff], numh4_cy[mm,qq,:,ff]])

                    label0 = (r'$M=%.2e\ M_\odot$, $q=%.2f$, $z=%.2f$, $f=%.2f$/yr'
                              % (xx[0,mm], xx[1,qq], xx[2,zz], fobs_gw_cents[ff]*YR))
                    for ii, ax in enumerate(axs[0,:]):  # h2
                        label = label0 if ii == 0 else None
                        cylabel = r'$h_s^2 \times \int dN/dx$' if (ii == 0 and first) else None
                        ax.plot(xx[ii], cy_h2[ii], label=cylabel, linestyle='-', alpha=0.35, lw=4, color=color)
                        ax.plot(xx[ii], yy_h2[ii], label=label, linestyle='--', alpha=0.75, color=color, lw=1)
                    for ii, ax in enumerate(axs[1,:]):  # h4
                        ax.scatter(xx[ii], cy_h4[ii], label=None, marker='o', alpha=0.35, lw=3, s=10, color=color)
                        ax.scatter(xx[ii], yy_h4[ii], label=None, marker='x', alpha=0.75, color=color)

    for ii, ax in enumerate(axs[0,:]):
        ax.set_ylabel(ylabels_h2[ii])
    for ii, ax in enumerate(axs[1,:]):
        ax.set_ylabel(ylabels_h4[ii])
        ax.set_xlabel(xlabels[ii])
        ax.sharex(axs[0,ii])
    fig.legend(ncols=3, bbox_to_anchor=(0.05,0), loc='upper left', fontsize=14)
    title = ('%s, %s, Varying %s' % (str(vals['hard_name']), str(vals['sam'].shape), vary))

    fig.suptitle(title)
    fig.tight_layout()
    return fig, title


In [None]:
def calc_strain_amp(vals):
    """Compute strain amplitudes at bin centres for the initial and final redshifts.

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``; reads 'edges' and 'redz_final'.

    Returns
    -------
    hs_initz, hs_final : ndarray
        ``anis.strain_amp_at_bin_centers_redz`` called with ``redz=None`` and
        with ``vals['redz_final']``, respectively.
    cents_mm, cents_qq, cents_zz : ndarray
        ``utils.midpoints(..., log=True)`` of the mass, mass-ratio and redshift edges.
    """
    edges = vals['edges']
    cents_mm, cents_qq, cents_zz = (utils.midpoints(edges[ax], log=True) for ax in range(3))

    hs_final = anis.strain_amp_at_bin_centers_redz(edges, vals['redz_final'])
    hs_initz = anis.strain_amp_at_bin_centers_redz(edges, redz=None)

    return hs_initz, hs_final, cents_mm, cents_qq, cents_zz

def plot_strain_amp_vary(vals, mm_arr, qq_arr, zz_arr, ff_arr, vary='ff', initz=True):
    """Plot strain amplitude against M, q and z for selected bins.

    Solid lines use the final redshifts. When ``initz`` is True, dashed lines
    show the initial-redshift strain as well.

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``.
    mm_arr, qq_arr, zz_arr, ff_arr : sequence of int
        Bin indices in total mass, mass ratio, redshift and frequency.
    vary : {'mm', 'qq', 'zz', 'ff'}
        Which index list the colours follow; any other value falls back to 'ff'.
    initz : bool
        Also draw the initial-redshift strain.

    Returns
    -------
    fig : matplotlib Figure
    title : str
    """
    hs_initz, hs_final, cents_mm, cents_qq, cents_zz = calc_strain_amp(vals)
    fobs_gw_cents = vals['fobs_gw_cents']

    # colour by position within the list selected by ``vary`` (default: frequency)
    vary_arr = {'mm': mm_arr, 'qq': qq_arr, 'zz': zz_arr}.get(vary, ff_arr)
    colors = cm.rainbow_r(np.linspace(0, 1, len(vary_arr)))

    # raw strings: '\m', '\o', '\ ' are invalid escape sequences in normal literals
    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',  # replaced per-axis below
        ncols=3, figsize=(16,5),
        sharey=True
    )

    xx = np.array([cents_mm/MSOL, cents_qq, cents_zz])
    xlabels = np.array([r'M ($M_\odot$)', 'q', 'z'])
    ylabels = np.array(['$h_s (M)$', '$h_s(q)$', '$h_s(z)$'])

    for mi, mm in enumerate(mm_arr):
        for qi, qq in enumerate(qq_arr):
            for zi, zz in enumerate(zz_arr):
                for fi, ff in enumerate(ff_arr):
                    color = colors[{'mm': mi, 'qq': qi, 'zz': zi}.get(vary, fi)]
                    first = (mi == 0 and qi == 0 and zi == 0 and fi == 0)

                    yy_initz = np.array([hs_initz[:,qq,zz,ff], hs_initz[mm,:,zz,ff], hs_initz[mm,qq,:,ff]])
                    yy_final = np.array([hs_final[:,qq,zz,ff], hs_final[mm,:,zz,ff], hs_final[mm,qq,:,ff]])
                    label_ii = (r'$M=%.2e\ M_\odot$, $q=%.2f$, $z=%.2f$, $f=%.2f$/yr'
                                % (cents_mm[mm]/MSOL, cents_qq[qq], cents_zz[zz], fobs_gw_cents[ff]*YR))

                    for ii, ax in enumerate(axs):
                        label = label_ii if ii == 0 else None
                        if initz:
                            label_init = 'initial z' if (ii == 0 and first) else None
                            ax.plot(xx[ii], yy_initz[ii], label=label_init, linestyle='--', alpha=0.65,
                                color=color)
                        ax.plot(xx[ii], yy_final[ii], label=label, linestyle='-', alpha=0.75,
                                color=color)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])

    fig.legend(ncols=3, bbox_to_anchor=(0.05,0), loc='upper left', fontsize=14)
    title = '%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape))
    fig.suptitle(title)
    # the original only referenced ``fig.tight_layout`` without calling it
    fig.tight_layout()
    return fig, title

In [None]:
sam = holo.sam.Semi_Analytic_Model(shape=20)
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 3 * GYR)
vals = sam_model(sam=sam, hard=hard)

# One (M, q, z) bin, sweeping over frequency bins; shared by both figures.
sweep = dict(
    mm_arr=[15], qq_arr=[18], zz_arr=[18],
    ff_arr=[0, 4, 9, 14, 19, 24, 29, 34, 38],
    vary='ff',
)
fig, title = plot_strain_amp_vary(vals, **sweep)
fig, title = plot_integrated_vary(vals, **sweep)