In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import holodeck as holo
import holodeck.anisotropy as anis
from holodeck.constants import YR, MSOL, GYR
from holodeck import utils

import matplotlib.pyplot as plt
import numpy as np

* `dens` $= d^3 n / [d\log_{10}M \, dq \, dz]$, in units of $\mathrm{Mpc}^{-3}$
* `dnum` $= d^4 N / [d\log_{10}M \, dq \, dz \, d\ln f]$
* `number` $= dN / d\ln f$

In [None]:

nbins = 10

# redshift edges: nbins+1 log-spaced values from 0.001 to 10
edges = np.geomspace(10**-3, 10**1, nbins+1)  # z bins
print('z edges:', edges)


def setup(
        dnum = np.geomspace(10**0, 10**6, nbins+1), # d^2N / dz dlnf 
        hs_exp_edges = np.linspace(-20, -15, nbins+1), 
        z_edges = None,
):
    """Build a toy number/strain distribution over redshift and integrate it two ways.

    Compares multiplying the per-bin integrated number by a bin-center strain
    (``num * h^n``) against integrating ``dnum * h^n`` over each bin (``int h^n``).

    Parameters
    ----------
    dnum : array, length nbins+1
        Differential number d^2N / dz dlnf sampled at the redshift edges.
    hs_exp_edges : array, length nbins+1
        log10 of the strain amplitude at each redshift edge.
    z_edges : array or None
        Redshift edges to integrate over. Defaults to the notebook-level ``edges``
        (kept for backward compatibility; pass it explicitly to avoid hidden state).

    Returns
    -------
    num, dnum, hs_cents, dnum_str, hs_str, numh2, numh4, inth2, inth4
        ``num``, ``inth2`` and ``inth4`` are all computed with ``cumsum=False``,
        so the two sides of each comparison are integrated the same way.
    """
    if z_edges is None:
        z_edges = edges  # notebook-level redshift edges
    print('dnum:', dnum)
    dnum_str = 'dnum decreasing' if dnum[0] > dnum[-1] else 'dnum increasing'

    # cumsum=False to match inth2/inth4 below; otherwise `num` would be integrated
    # with trapz's default accumulation and the h^n comparisons would not line up.
    num = holo.utils.trapz(dnum, z_edges, cumsum=False)
    print('num:', num)

    # strain amplitudes from the log10 edges; centers are midpoints in log10 space
    hs_exp_cents = holo.utils.midpoints(hs_exp_edges)
    hs_edges = 10**hs_exp_edges
    hs_cents = 10**hs_exp_cents
    hs_str = 'hs decreasing' if hs_cents[0] > hs_cents[-1] else 'hs increasing'
    print('hs_cents:', hs_cents)
    print('hs_edges:', hs_edges)

    # integrate first, then weight by the bin-center strain
    numh2 = num*hs_cents**2
    numh4 = num*hs_cents**4

    # weight by the edge strain first, then integrate over each bin
    inth2 = holo.utils.trapz(dnum*hs_edges**2, z_edges, cumsum=False)
    inth4 = holo.utils.trapz(dnum*hs_edges**4, z_edges, cumsum=False)

    return num, dnum, hs_cents, dnum_str, hs_str, numh2, numh4, inth2, inth4


def plot_integral_vs_z_edges(): # plots using whatever the nb variables r set to
    """Plot the per-bin integrals against each bin's lower redshift edge.

    Uses notebook globals: ``edges`` and ``nbins``, plus ``dnum``, ``hs_cents``,
    ``dnum_str``, ``hs_str``, ``numh2``, ``numh4``, ``inth2``, ``inth4`` as assigned
    from the most recent ``setup(...)`` call.

    Left panel compares the h^2 quantities, right panel the h^4 quantities.
    Returns the matplotlib figure.
    """
    fig, (ax_h2, ax_h4) = holo.plot.figax(ncols=2, figsize=(12,5),
                                          xlabel='z edges', ylabel='integral')
    z_lo = edges[:-1]  # lower edge of every redshift bin

    panels = (
        (ax_h2, inth2, numh2, r'$\int(\mathrm{dnum}*h^2)$', r'$h^2* \int(\mathrm{dnum})$'),
        (ax_h4, inth4, numh4, r'$\int(\mathrm{dnum}*h^4)$', r'$h^4* \int(\mathrm{dnum})$'),
    )
    for ax, weighted_int, int_then_weight, lbl_a, lbl_b in panels:
        ax.plot(z_lo, weighted_int, label=lbl_a)
        ax.plot(z_lo, int_then_weight, label=lbl_b)
        ax.legend()

    title = 'nbins=%d, z increasing -> %s (%.2e to %.2e) and %s (%.2e to %.2e)' % (
        nbins, dnum_str, dnum[0], dnum[-1], hs_str, hs_cents[0], hs_cents[-1])
    fig.suptitle(title, fontsize=12)
    fig.tight_layout()
    return fig


In [None]:

# dnum increasing (10 -> 100) and h_s increasing (1e-17 -> 10^-16.5) with z
(num, dnum, hs_cents, dnum_str, hs_str,
 numh2, numh4, inth2, inth4) = setup(
    dnum=np.geomspace(10**1, 10**2, nbins + 1),  # d^2N / dz dlnf
    hs_exp_edges=np.linspace(-17, -16.5, nbins + 1),
)
fig = plot_integral_vs_z_edges()

In [None]:
# dnum decreasing (1e8 -> 1e-15) and h_s decreasing (1e-10 -> 1e-35) with z
(num, dnum, hs_cents, dnum_str, hs_str,
 numh2, numh4, inth2, inth4) = setup(
    dnum=np.geomspace(10**8, 10**-15, nbins + 1),  # d^2N / dz dlnf
    hs_exp_edges=np.linspace(-10, -35, nbins + 1),
)
fig = plot_integral_vs_z_edges()

In [None]:
# dnum increasing (1e-5 -> 1e6) while h_s decreases (1e-15 -> 1e-25) with z
(num, dnum, hs_cents, dnum_str, hs_str,
 numh2, numh4, inth2, inth4) = setup(
    dnum=np.geomspace(10**-5, 10**6, nbins + 1),  # d^2N / dz dlnf
    hs_exp_edges=np.linspace(-15, -25, nbins + 1),
)
fig = plot_integral_vs_z_edges()


In [None]:
# dnum increasing (1e-3 -> 1e6) and h_s increasing (1e-20 -> 1e-15) with z
(num, dnum, hs_cents, dnum_str, hs_str,
 numh2, numh4, inth2, inth4) = setup(
    dnum=np.geomspace(10**-3, 10**6, nbins + 1),  # d^2N / dz dlnf
    hs_exp_edges=np.linspace(-20, -15, nbins + 1),
)
fig = plot_integral_vs_z_edges()

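Integrating first and then multiplying by $h_s^2$ (or $h_s^4$), rather than integrating $\mathrm{dnum}\times h_s^n$ directly, only makes the values slightly smaller.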

In [None]:
# `num` as returned by the most recent setup(...) call above
print(num)

In [None]:
# the same `num`, rounded down to integers
print(np.floor(num))

In [None]:
def sam_model(sam, hard,
        dur=16.03*YR, cad=0.2*YR, use_redz=True):
    """Evolve a SAM population with hardening model `hard` and collect binned quantities.

    Parameters
    ----------
    sam : holo.sam.Semi_Analytic_Model
    hard : holo.hardening.Fixed_Time_2PL_SAM or holo.hardening.Hard_GW
    dur, cad : float
        Observing duration and cadence (in holodeck time units, e.g. ``16.03*YR``),
        used to build the Nyquist GW frequency bins.
    use_redz : bool
        If True, strain amplitudes are computed with ``redz_final`` from the
        hardening evolution; otherwise without a final redshift (the initial grid).

    Returns
    -------
    dict with keys 'hard', 'sam', 'edges', 'number', 'diff_num', 'redz_final',
    'hs_cents', 'hs_edges', 'fobs_gw_cents', 'fobs_gw_edges', 'fobs_orb_cents',
    'fobs_orb_edges', 'hard_name'. ``edges`` is ``[mtot, mrat, redz, fobs_orb_edges]``.

    Raises
    ------
    TypeError
        If `hard` is not one of the two supported hardening models. It is checked
        before any computation. TypeError subclasses Exception, so existing
        ``except Exception`` handlers still catch it.
    """
    if isinstance(hard, holo.hardening.Fixed_Time_2PL_SAM):
        hard_name = 'Fixed Time'
    elif isinstance(hard, holo.hardening.Hard_GW):
        hard_name = 'GW Only'
    else:
        raise TypeError("'hard' must be an instance of 'Fixed_Time_2PL_SAM' or 'Hard_GW'")

    fobs_gw_cents = utils.nyquist_freqs(dur, cad)
    fobs_gw_edges = utils.nyquist_freqs_edges(dur, cad)
    # orbital frequency = GW frequency / 2
    fobs_orb_cents = fobs_gw_cents/2.0
    fobs_orb_edges = fobs_gw_edges/2.0

    redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, holo.cosmo)
    edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
    number = holo.sam_cython.integrate_differential_number_3dx1d(edges, diff_num)
    if use_redz:
        hs_cents = anis.strain_amp_at_bin_centers_redz(edges, redz_final)
        hs_edges = anis.strain_amp_at_bin_edges_redz(edges, redz_final)
    else:
        hs_cents = anis.strain_amp_at_bin_centers_redz(edges)
        hs_edges = anis.strain_amp_at_bin_edges_redz(edges)

    vals = {
        'hard':hard, 'sam':sam, 'edges':edges, 'number': number, 'diff_num':diff_num, 'redz_final':redz_final,
        'hs_cents':hs_cents, 'hs_edges':hs_edges, 'fobs_gw_cents':fobs_gw_cents, 'fobs_gw_edges':fobs_gw_edges, 
        'fobs_orb_cents':fobs_orb_cents, 'fobs_orb_edges':fobs_orb_edges, 'hard_name':hard_name
    }
    return vals

def integrate_mm(dnum, edges): # integrate dN/dlog10M
    """Per-bin trapezoid integral along axis 0, over log10 of the mass edges ``edges[0]``."""
    log10_mass_edges = np.log10(edges[0])
    return utils.trapz(dnum, log10_mass_edges, axis=0, cumsum=False)
def integrate_qq(dnum, edges): # integrate dN/dq
    """Per-bin trapezoid integral along axis 1, over the mass-ratio edges ``edges[1]``."""
    q_edges = edges[1]
    return utils.trapz(dnum, q_edges, axis=1, cumsum=False)
def integrate_zz(dnum, edges): # integrate dN/dz
    """Per-bin trapezoid integral along axis 2, over the redshift edges ``edges[2]``."""
    z_edges = edges[2]
    return utils.trapz(dnum, z_edges, axis=2, cumsum=False)
def integrate_ff(dnum, fobs_gw_edges): # dN/dlnf -> N per frequency bin
    """Multiply by each frequency bin's width in ln f.

    ``np.diff(np.log(fobs_gw_edges))`` broadcasts against the last axis of `dnum`,
    so that axis must have ``len(fobs_gw_edges) - 1`` entries.
    """
    dlnf = np.diff(np.log(fobs_gw_edges))
    return dnum * dlnf

In [None]:
sam = holo.sam.Semi_Analytic_Model(shape=10)
vals = sam_model(sam, hard=holo.hardening.Hard_GW())
edges_mm = vals['edges'][0]
cents_mm = utils.midpoints(edges_mm, log=True)
edges_qq = vals['edges'][1]
cents_qq = utils.midpoints(edges_qq, log=True)
edges_zz = vals['edges'][2]
cents_zz = utils.midpoints(edges_zz, log=True)
diff_num = vals['diff_num']

np.set_printoptions(precision=3)
print('M edges (g)', edges_mm)
print('q edges', edges_qq)
print('z edges', edges_zz)


edges = vals['edges']
fobs_gw_edges = vals['fobs_gw_edges']
fobs_gw_cents = vals['fobs_gw_cents']
# Each dnum_X keeps axis X differential and integrates over the other two grid
# axes (per bin) and over ln f.
dnum_mm = integrate_ff(integrate_zz(integrate_qq(diff_num, edges), edges), fobs_gw_edges)
# was integrate_qq(integrate_mm(...)): integrated over q (the kept axis) and never over z
dnum_qq = integrate_ff(integrate_zz(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
print(dnum_mm.shape, dnum_qq.shape, dnum_zz.shape)

In [None]:
# shape check, then a visual check of the redshift-edge spacing (log y-axis)
print(np.ravel(edges_zz).shape)
plt.scatter(range(len(edges_zz)), edges_zz)
plt.yscale('log')

In [None]:
def plot_dnum_dpar(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Plot the differential number along each grid parameter for chosen slices.

    For every index combination in (mm_arr, qq_arr, zz_arr, ff_arr):
    panel 0 shows dN/dlog(M) vs. the M edges at fixed (q, z, f),
    panel 1 dN/dq vs. the q edges at fixed (M, z, f),
    panel 2 dN/dz vs. the z edges at fixed (M, q, f).
    The two other grid axes are integrated per bin and the frequency axis is
    multiplied by d ln f.

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``.
    mm_arr, qq_arr, zz_arr, ff_arr : sequences of int
        Bin indices held fixed when slicing.

    Returns
    -------
    matplotlib Figure
    """
    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(18,5),
        sharey=True
    )

    edges = vals['edges']
    edges_mm, edges_qq, edges_zz = edges[0], edges[1], edges[2]
    cents_mm = utils.midpoints(edges_mm, log=True)
    cents_qq = utils.midpoints(edges_qq, log=True)
    cents_zz = utils.midpoints(edges_zz, log=True)

    diff_num = vals['diff_num']
    fobs_gw_edges = vals['fobs_gw_edges']
    fobs_gw_cents = vals['fobs_gw_cents']
    dnum_mm = integrate_ff(integrate_zz(integrate_qq(diff_num, edges), edges), fobs_gw_edges)
    dnum_qq = integrate_ff(integrate_zz(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = [r'$dN/d\log (M)$', '$dN/dq$', '$dN/dz$']

    # Plain lists rather than np.array: the M, q and z grids can have different
    # lengths (e.g. shape=None), and NumPy >= 1.24 refuses to build ragged arrays.
    xx = [edges_mm/MSOL, edges_qq, edges_zz]

    def _labels(mm, qq, zz, ff):
        # legend text naming the indices held fixed in each of the three panels
        fyr = fobs_gw_cents[ff]*YR
        return [
            '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], fyr),
        ]

    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    yy = [dnum_mm[:,qq,zz,ff], dnum_qq[mm,:,zz,ff], dnum_zz[mm,qq,:,ff]]
                    labels = _labels(mm, qq, zz, ff)
                    for ii, ax in enumerate(axs):
                        ax.plot(xx[ii], yy[ii], label=labels[ii], alpha=0.75)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend(loc='upper right')

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    fig.tight_layout()  # was `fig.tight_layout` (attribute lookup, never called)
    return fig

fig = plot_dnum_dpar(vals, mm_arr=[2,7], qq_arr=[7,], zz_arr=[7], ff_arr=[2,32])

In [None]:
# quick reminder of how np.floor behaves on non-integer values
arr = np.linspace(0.5, 3, num=10)
print(arr)
print(np.floor(arr))

In [None]:
def plot_number(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Compare per-bin binary numbers from several integration methods.

    Slices along M, q and z are overplotted for:
      * ``vals['number']`` from the cython integrator (solid),
      * chained per-bin trapezoid integration of ``diff_num`` (dashed),
      * ``utils._integrate_grid_differential_number`` times d ln f (dotted),
      * the cython number rounded down with ``np.floor`` (dash-dot).

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``. The frequency bins are read from here, not from
        notebook globals, so results for different `vals` cannot get mixed up.
    mm_arr, qq_arr, zz_arr, ff_arr : sequences of int
        Bin-center indices held fixed when slicing.

    Returns
    -------
    matplotlib Figure
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    fobs_gw_cents = vals['fobs_gw_cents']
    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)
    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)

    cynum = vals['number']
    tznum = integrate_zz(dnum_zz, edges)
    utnum = utils._integrate_grid_differential_number(edges, diff_num, freq=False)
    utnum = utnum * np.diff(np.log(fobs_gw_edges))
    flnum = np.floor(cynum)

    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(16,5),
        sharey=True
    )
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = ['$N (M)$', '$N(q)$', '$N(z)$']

    # lists, not np.array: the three grids may differ in length (ragged)
    xx = [cents_mm/MSOL, cents_qq, cents_zz]
    # (array, legend suffix, linestyle), plotted in this order in every panel
    methods = (
        (cynum, ' cython', '-'),
        (tznum, ' trapz', '--'),
        (utnum, ' utils', ':'),
        (flnum, ' floor', '-.'),
    )

    def _cuts(arr, mm, qq, zz, ff):
        # 1D cuts through an (M, q, z, f) array along M, q and z respectively
        return [arr[:, qq, zz, ff], arr[mm, :, zz, ff], arr[mm, qq, :, ff]]

    def _labels(mm, qq, zz, ff):
        # legend text naming the indices held fixed in each of the three panels
        fyr = fobs_gw_cents[ff]*YR
        return [
            '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], fyr),
        ]

    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    labels = _labels(mm, qq, zz, ff)
                    cuts = [(_cuts(arr, mm, qq, zz, ff), sfx, ls) for arr, sfx, ls in methods]
                    for ii, ax in enumerate(axs):
                        for yy, sfx, ls in cuts:
                            ax.plot(xx[ii], yy[ii], label=labels[ii]+sfx, linestyle=ls, alpha=0.75)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend()

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    fig.tight_layout()  # was `fig.tight_layout` (never called)
    return fig

fig = plot_number(vals,  mm_arr=[7], qq_arr=[7,], zz_arr=[7], ff_arr=[32])

The cython, trapezoid (`trapz`), and `utils` integrations agree, so the integration is working correctly.

In [None]:
def plot_strain_amp(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Compare bin-center strain amplitudes computed with and without ``redz_final``.

    Slices along M, q and z are drawn dashed for ``redz=None`` ('initial z') and
    solid for ``vals['redz_final']`` ('final z'). Summary statistics of the ratio
    ``hs_final / hs_initz`` are printed first.

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``. Frequencies are read from here, not from globals.
    mm_arr, qq_arr, zz_arr, ff_arr : sequences of int
        Bin-center indices held fixed when slicing.

    Returns
    -------
    matplotlib Figure
    """
    edges = vals['edges']
    fobs_gw_cents = vals['fobs_gw_cents']
    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)

    hs_final = anis.strain_amp_at_bin_centers_redz(edges, vals['redz_final'])
    hs_initz = anis.strain_amp_at_bin_centers_redz(edges, redz=None)
    print(holo.utils.stats(hs_final/hs_initz))

    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(16,5),
        sharey=True
    )

    # lists, not np.array: the three grids may differ in length (ragged)
    xx = [cents_mm/MSOL, cents_qq, cents_zz]
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = ['$h_s (M)$', '$h_s(q)$', '$h_s(z)$']

    def _cuts(arr, mm, qq, zz, ff):
        # 1D cuts through an (M, q, z, f) array along M, q and z respectively
        return [arr[:, qq, zz, ff], arr[mm, :, zz, ff], arr[mm, qq, :, ff]]

    def _labels(mm, qq, zz, ff):
        # legend text naming the indices held fixed in each of the three panels
        fyr = fobs_gw_cents[ff]*YR
        return [
            '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], fyr),
        ]

    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    yy_initz = _cuts(hs_initz, mm, qq, zz, ff)
                    yy_final = _cuts(hs_final, mm, qq, zz, ff)
                    labels = _labels(mm, qq, zz, ff)
                    for ii, ax in enumerate(axs):
                        ax.plot(xx[ii], yy_initz[ii], label=labels[ii]+' initial z', linestyle='--', alpha=0.75)
                        ax.plot(xx[ii], yy_final[ii], label=labels[ii]+' final z', linestyle='-', alpha=0.75)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend()

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    fig.tight_layout()  # was `fig.tight_layout` (never called)
    return fig

fig = plot_strain_amp(vals, mm_arr=[6,], qq_arr=[7,], zz_arr=[5], ff_arr=[1,8,])

In [None]:
def plot_number_times_h2(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Plot N * h_s^2 along M, q and z for three ways of computing the bin number N.

    N comes from ``vals['number']`` (cython, solid), chained per-bin trapezoid
    integration of ``diff_num`` (dashed), and
    ``utils._integrate_grid_differential_number`` times d ln f (dotted). Each is
    multiplied by ``vals['hs_cents']**2``.

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``. Frequencies are read from here, not from globals.
    mm_arr, qq_arr, zz_arr, ff_arr : sequences of int
        Bin-center indices held fixed when slicing.

    Returns
    -------
    matplotlib Figure
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    fobs_gw_cents = vals['fobs_gw_cents']
    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)
    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    hs_cents = vals['hs_cents']

    cynum = vals['number']*hs_cents**2
    tznum = integrate_zz(dnum_zz, edges)*hs_cents**2
    utnum = utils._integrate_grid_differential_number(edges, diff_num, freq=False)
    utnum = utnum * np.diff(np.log(fobs_gw_edges))*hs_cents**2

    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(16,5),
        sharey=True
    )
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = [r'$N (M)\times h_s^2$', r'$N(q)\times h_s^2$', r'$N(z)\times h_s^2$']

    # lists, not np.array: the three grids may differ in length (ragged)
    xx = [cents_mm/MSOL, cents_qq, cents_zz]
    # (array, legend suffix, linestyle), plotted in this order in every panel
    methods = (
        (cynum, ' cython', '-'),
        (tznum, ' trapz', '--'),
        (utnum, ' utils', ':'),
    )

    def _cuts(arr, mm, qq, zz, ff):
        # 1D cuts through an (M, q, z, f) array along M, q and z respectively
        return [arr[:, qq, zz, ff], arr[mm, :, zz, ff], arr[mm, qq, :, ff]]

    def _labels(mm, qq, zz, ff):
        # legend text naming the indices held fixed in each of the three panels
        fyr = fobs_gw_cents[ff]*YR
        return [
            '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], fyr),
        ]

    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    labels = _labels(mm, qq, zz, ff)
                    cuts = [(_cuts(arr, mm, qq, zz, ff), sfx, ls) for arr, sfx, ls in methods]
                    for ii, ax in enumerate(axs):
                        for yy, sfx, ls in cuts:
                            ax.plot(xx[ii], yy[ii], label=labels[ii]+sfx, linestyle=ls, alpha=0.75)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend()

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    fig.tight_layout()  # was `fig.tight_layout` (never called)
    return fig

fig = plot_number_times_h2(vals, mm_arr=[6,], qq_arr=[7,], zz_arr=[5], ff_arr=[1,8,])

In [None]:


def plot_integrated(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Compare ``h_s^n x N`` with ``int h_s^n dN/dx`` for n = 2 (top row) and n = 4 (bottom row).

    Columns are slices along M, q and z through each index combination in
    (mm_arr, qq_arr, zz_arr, ff_arr). Each slice gets one color, and each panel shows:
      * ``vals['number'] * vals['hs_cents']**n``  (h^2: thick line, h^4: dots),
      * the same with the number rounded down by ``np.floor``  (h^2: thin line, h^4: '+'),
      * the per-bin integral along x of ``dnum * h_s**n``, with h_s log-midpointed
        over the other two grid axes  (h^2: dashed, h^4: 'x').

    Parameters
    ----------
    vals : dict
        Output of ``sam_model``. Frequencies and strains are read from here, not
        from notebook globals.
    mm_arr, qq_arr, zz_arr, ff_arr : sequences of int
        Bin-center indices held fixed when slicing. Only the first panel of the first
        combination gets the generic cython/rounded legend entries.

    Returns
    -------
    matplotlib Figure
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    fobs_gw_cents = vals['fobs_gw_cents']
    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)

    hs_cents = vals['hs_cents']
    hs_edges = vals['hs_edges']

    # strain midpointed over the two axes integrated out for each x; computed once
    # and reused for both powers
    hs_mid_mm = utils.midpoints_multiax(hs_edges, axis=(1,2), log=True)
    hs_mid_qq = utils.midpoints_multiax(hs_edges, axis=(0,2), log=True)
    hs_mid_zz = utils.midpoints_multiax(hs_edges, axis=(0,1), log=True)

    dnum_mm = integrate_ff(integrate_zz(integrate_qq(diff_num, edges), edges), fobs_gw_edges)
    numh2_mm = integrate_mm(dnum_mm*hs_mid_mm**2, edges)
    numh4_mm = integrate_mm(dnum_mm*hs_mid_mm**4, edges)

    dnum_qq = integrate_ff(integrate_zz(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    numh2_qq = integrate_qq(dnum_qq*hs_mid_qq**2, edges)
    numh4_qq = integrate_qq(dnum_qq*hs_mid_qq**4, edges)

    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    numh2_zz = integrate_zz(dnum_zz*hs_mid_zz**2, edges)
    numh4_zz = integrate_zz(dnum_zz*hs_mid_zz**4, edges)

    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels_h2 = [r'$\int h_s^2 dN/d\log (M)$', r'$\int h_s^2 dN/dq$', r'$\int h_s^2 dN/dz$']
    ylabels_h4 = [r'$\int h_s^4 dN/d\log (M)$', r'$\int h_s^4 dN/dq$', r'$\int h_s^4 dN/dz$']

    cynum = vals['number']
    numh2_cy = cynum*hs_cents**2
    numh4_cy = cynum*hs_cents**4

    flnum = np.floor(cynum)
    numh2_fl = flnum*hs_cents**2
    numh4_fl = flnum*hs_cents**4

    # lists, not np.array: the three grids may differ in length (ragged)
    xx = [cents_mm/MSOL, cents_qq, cents_zz]

    fig, axs = holo.plot.figax(
        nrows=2, ncols=3, figsize=(16,8))

    def _cuts(arr, mm, qq, zz, ff):
        # 1D cuts through an (M, q, z, f) array along M, q and z respectively
        return [arr[:, qq, zz, ff], arr[mm, :, zz, ff], arr[mm, qq, :, ff]]

    def _labels(mm, qq, zz, ff):
        # legend text naming the indices held fixed in each of the three panels
        fyr = fobs_gw_cents[ff]*YR
        return [
            '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], fyr),
            r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], fyr),
        ]

    first_combo = (mm_arr[0], qq_arr[0], zz_arr[0], ff_arr[0])
    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    yy_h2 = [numh2_mm[:,qq,zz,ff], numh2_qq[mm,:,zz,ff], numh2_zz[mm,qq,:,ff]]
                    yy_h4 = [numh4_mm[:,qq,zz,ff], numh4_qq[mm,:,zz,ff], numh4_zz[mm,qq,:,ff]]
                    cy_h2 = _cuts(numh2_cy, mm, qq, zz, ff)
                    cy_h4 = _cuts(numh4_cy, mm, qq, zz, ff)
                    fl_h2 = _cuts(numh2_fl, mm, qq, zz, ff)
                    fl_h4 = _cuts(numh4_fl, mm, qq, zz, ff)
                    labels = _labels(mm, qq, zz, ff)
                    is_first = (mm, qq, zz, ff) == first_combo

                    for ii, ax in enumerate(axs[0,:]): # h2
                        show = (ii == 0) and is_first
                        cylabel = r'$h_s^2 \times \int dN/dx$' if show else None
                        fllabel = r'rounded $h_s^2 \times \int dN/dx$' if show else None
                        ll, = ax.plot(xx[ii], cy_h2[ii], label=cylabel, linestyle='-', alpha=0.35, lw=4)
                        cc = ll.get_color()
                        ax.plot(xx[ii], fl_h2[ii], label=fllabel, linestyle='-', alpha=0.65, color=cc, lw=2)
                        ax.plot(xx[ii], yy_h2[ii], label=labels[ii], linestyle='--', alpha=0.75, color=cc, lw=1)
                    for ii, ax in enumerate(axs[1,:]): # h4
                        show = (ii == 0) and is_first
                        # fixed: these legend entries previously said h_s^2 in the h^4 row
                        cylabel = r'$h_s^4 \times \int dN/dx$' if show else None
                        fllabel = r'rounded $h_s^4 \times \int dN/dx$' if show else None
                        ll = ax.scatter(xx[ii], cy_h4[ii], label=cylabel, marker='o', alpha=0.35, lw=3, s=10)
                        cc = ll.get_facecolors()
                        ax.scatter(xx[ii], fl_h4[ii], label=fllabel, marker='+', alpha=0.65, color=cc)
                        ax.scatter(xx[ii], yy_h4[ii], label=labels[ii], marker='x', alpha=0.75, color=cc)

    for ii, ax in enumerate(axs[0,:]):
        ax.set_ylabel(ylabels_h2[ii])
        ax.legend(fontsize=8)
    for ii, ax in enumerate(axs[1,:]):
        ax.set_ylabel(ylabels_h4[ii])
        ax.set_xlabel(xlabels[ii])
        ax.sharex(axs[0,ii])
        ax.legend(fontsize=8)

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    fig.tight_layout()
    return fig

fig = plot_integrated(vals, mm_arr=[3,6,7], qq_arr=[-1,], zz_arr=[7,], ff_arr=[20])

# All together: fixed-time hardening model

In [None]:
# 20-bin grid evolved with the fixed-time hardening model (3 Gyr argument)
sam = holo.sam.Semi_Analytic_Model(shape=20)
vals = sam_model(
    sam=sam,
    hard=holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR),
)
print(vals['hard_name'])

In [None]:
# fixed-time model: differential number, strain, bin number, strain-weighted integrals
fig = plot_dnum_dpar(vals, mm_arr=[2, 7], qq_arr=[17], zz_arr=[17], ff_arr=[2, 32])
fig = plot_strain_amp(vals, mm_arr=[6], qq_arr=[17], zz_arr=[15], ff_arr=[1, 8])
fig = plot_number(vals, mm_arr=[7], qq_arr=[17], zz_arr=[17], ff_arr=[32])
fig = plot_integrated(vals, mm_arr=[7, 15], qq_arr=[17], zz_arr=[17], ff_arr=[32])

In [None]:
# same as above, but with a low-z slice (index 5) added
fig = plot_integrated(vals, mm_arr=[7, 15], qq_arr=[17], zz_arr=[5, 17], ff_arr=[32])

# All together: GW-only hardening

In [None]:
sam_20 = holo.sam.Semi_Analytic_Model(shape=20)
vals_20GW = sam_model(sam=sam_20, hard=holo.hardening.Hard_GW())
# was print(vals['hard_name']), which reports the earlier fixed-time run, not this one
print(vals_20GW['hard_name'])

In [None]:
# GW-only model: differential number, strain, bin number, strain-weighted integrals
fig = plot_dnum_dpar(vals_20GW, mm_arr=[2, 7], qq_arr=[17], zz_arr=[17], ff_arr=[2, 32])
fig = plot_strain_amp(vals_20GW, mm_arr=[6], qq_arr=[17], zz_arr=[15], ff_arr=[1, 8])
fig = plot_number(vals_20GW, mm_arr=[7], qq_arr=[17], zz_arr=[17], ff_arr=[2, 32])
fig = plot_integrated(vals_20GW, mm_arr=[7, 15], qq_arr=[17], zz_arr=[17], ff_arr=[32])

In [None]:
# two frequency slices (indices 2 and 32)
fig = plot_integrated(vals_20GW, mm_arr=[7, 15], qq_arr=[17], zz_arr=[17], ff_arr=[2, 32])

In [None]:
# two redshift slices (indices 5 and 17)
fig = plot_integrated(vals_20GW, mm_arr=[7, 15], qq_arr=[17], zz_arr=[5, 17], ff_arr=[32])

# Comparing dnum methods

## Full grid (`shape=None`), GW-only hardening

In [None]:
# shape=None: presumably the library's default (full-resolution) grid
sam_full = holo.sam.Semi_Analytic_Model(shape=None)
vals_fullGW = sam_model(
    sam=sam_full,
    hard=holo.hardening.Hard_GW(),
)

In [None]:
# redshift edges of the full grid (edges[2] is sam.redz)
print(vals_fullGW['edges'][2])

In [None]:
# redshift edge value vs. its index; plot(y) uses x = 0..len(y)-1 by default
plt.plot(vals_fullGW['edges'][2])
plt.yscale('log')

In [None]:
# full-grid GW-only model; mm index -15 counts from the high-mass end
fig = plot_dnum_dpar(vals_fullGW, mm_arr=[2, 7], qq_arr=[17], zz_arr=[65], ff_arr=[2, 32])
fig = plot_strain_amp(vals_fullGW, mm_arr=[6], qq_arr=[17], zz_arr=[65], ff_arr=[1, 8])
fig = plot_number(vals_fullGW, mm_arr=[7], qq_arr=[17], zz_arr=[65], ff_arr=[2, 32])
fig = plot_integrated(vals_fullGW, mm_arr=[7, -15], qq_arr=[17], zz_arr=[65], ff_arr=[32])

In [None]:
# two redshift slices (indices 55 and 85)
fig = plot_integrated(vals_fullGW, mm_arr=[7, -15], qq_arr=[17], zz_arr=[55, 85], ff_arr=[32])

## Full grid (`shape=None`), fixed-time hardening

In [None]:
# rebuilds the same full grid as above, then evolves it with fixed-time hardening
sam_full = holo.sam.Semi_Analytic_Model(shape=None)
vals_fullFT = sam_model(
    sam=sam_full,
    hard=holo.hardening.Fixed_Time_2PL_SAM(sam_full, 3*GYR),
)

In [None]:
# full-grid fixed-time model; mm index -15 counts from the high-mass end
fig = plot_dnum_dpar(vals_fullFT, mm_arr=[2, 7], qq_arr=[17], zz_arr=[65], ff_arr=[2, 32])
fig = plot_strain_amp(vals_fullFT, mm_arr=[6], qq_arr=[17], zz_arr=[65], ff_arr=[1, 8])
fig = plot_number(vals_fullFT, mm_arr=[7], qq_arr=[17], zz_arr=[65], ff_arr=[2, 32])
fig = plot_integrated(vals_fullFT, mm_arr=[7, -15], qq_arr=[17], zz_arr=[65], ff_arr=[32])

In [None]:
# two redshift slices (indices 55 and 85)
fig = plot_integrated(vals_fullFT, mm_arr=[7, -15], qq_arr=[17], zz_arr=[55, 85], ff_arr=[32])

In [None]:
# NOTE(review): identical to the previous cell; one of the two can be deleted
fig = plot_integrated(vals_fullFT, mm_arr=[7, -15], qq_arr=[17,], zz_arr=[55,85], ff_arr=[32])

# Comparing strain amplitudes at the initial vs. final redshift

**Conclusion:**
The $h_s$ differences are almost negligible, *unless* `hs_final` is 0 because `redz_final` is $-1$ (the binary never reaches GW emission at that frequency). This cuts off the loudest sources, which have the lowest redshifts and the largest masses.


In [None]:
# fixed-time model on a 20-bin grid, strains computed with the final redshifts
sam = holo.sam.Semi_Analytic_Model(shape=20)
vals = sam_model(
    sam,
    hard=holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR),
    use_redz=True,
)
fig = plot_strain_amp(vals, mm_arr=[14], qq_arr=[18], zz_arr=[15], ff_arr=[32])
fig.axes[0].set_ylim(10**-16, 10**-11)

In [None]:

redz_final = vals['redz_final']
edges_zz = vals['edges'][2]
cents_zz = utils.midpoints(edges_zz, log=True)
# grid index where redz_final departs most from the grid redshift edge (broadcast on axis 2)
print(np.unravel_index(
    np.argmax(np.abs(redz_final - edges_zz[None, None, :, None])),
    redz_final.shape,
))

In [None]:
hs_final = anis.strain_amp_at_bin_edges_redz(vals['edges'], vals['redz_final'])
hs_initz = anis.strain_amp_at_bin_edges_redz(vals['edges'], redz=None)

# spot-check one hand-picked cell (indices assume the shape=20 grid built above)
print(hs_final[14, 0, 17, 32])
print(hs_initz[14, 0, 17, 32])

# index of the largest |final - initial| strain difference
print(np.unravel_index(np.argmax(np.abs(hs_final - hs_initz)), hs_initz.shape))
# presumably the argmax index above, hard-coded from an earlier run -- TODO confirm
print(hs_final[19, 19, 0, 39])
print(hs_initz[19, 19, 0, 39])

In [None]:

# every grid cell: final redshift vs. edge strain computed with that redshift
plt.scatter(redz_final.flatten(), hs_final.flatten())

In [None]:
# grid redshift edges vs. final redshift for the (M=19, q=19, f=39) cell
plt.scatter(edges_zz, redz_final[19,19,:,39])