In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import holodeck as holo
import holodeck.anisotropy as anis
from holodeck.constants import YR, MSOL, GYR
from holodeck import utils

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

* `dens` = $d^3 n / [d\log_{10}M \, dq \, dz]$, in units of [Mpc$^{-3}$]
* `dnum` = $d^4 N / [d\log_{10}M \, dq \, dz \, d\ln f]$
* `number` = $dN / d\ln f$

# Functions

sam_model()

In [None]:
def sam_model(sam, hard,
        dur=16.03*YR, cad=0.2*YR, use_redz=True):
    """Evaluate a SAM + hardening model on a frequency grid and bundle the results.

    Parameters
    ----------
    sam : semi-analytic model instance; its `mtot`, `mrat`, and `redz`
        attributes are used as the first three grid edges.
    hard : hardening model; must be an instance of
        `holo.hardening.Fixed_Time_2PL_SAM` or `holo.hardening.Hard_GW`.
    dur, cad : duration and cadence passed to `utils.nyquist_freqs` and
        `utils.nyquist_freqs_edges` to build the GW frequency grid.
    use_redz : if True, the strain amplitudes are evaluated with the
        `redz_final` array returned by the cython routine; otherwise the
        strain helpers are called with the edges only.

    Returns
    -------
    dict holding the inputs, the grid edges, the differential and integrated
    numbers, the final redshifts, the strain amplitudes at bin centers and
    edges, the GW and orbital frequency grids, and a display name for `hard`.

    Raises
    ------
    Exception
        If `hard` is neither of the two supported hardening classes.
    """
    # GW frequency grid; orbital frequencies are half the GW frequencies.
    gw_cents = utils.nyquist_freqs(dur,cad)
    gw_edges = utils.nyquist_freqs_edges(dur,cad)
    orb_cents = gw_cents/2.0
    orb_edges = gw_edges/2.0

    # Human-readable name of the hardening model (used in figure titles).
    if isinstance(hard, holo.hardening.Fixed_Time_2PL_SAM):
        hard_name = 'Fixed Time'
    elif isinstance(hard, holo.hardening.Hard_GW):
        hard_name = 'GW Only'
    else:
        raise Exception("'hard' must be an instance of 'Fixed_Time_2PL_SAM' or 'Hard_GW'")

    redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
        orb_cents, sam, hard, holo.cosmo)
    edges = [sam.mtot, sam.mrat, sam.redz, orb_edges]
    number = holo.sam_cython.integrate_differential_number_3dx1d(edges, diff_num)

    # Optional positional redshift argument shared by both strain helpers.
    redz_args = (redz_final,) if use_redz else ()
    hs_cents = anis.strain_amp_at_bin_centers_redz(edges, *redz_args)
    hs_edges = anis.strain_amp_at_bin_edges_redz(edges, *redz_args)

    return {
        'hard': hard,
        'sam': sam,
        'edges': edges,
        'number': number,
        'diff_num': diff_num,
        'redz_final': redz_final,
        'hs_cents': hs_cents,
        'hs_edges': hs_edges,
        'fobs_gw_cents': gw_cents,
        'fobs_gw_edges': gw_edges,
        'fobs_orb_cents': orb_cents,
        'fobs_orb_edges': orb_edges,
        'hard_name': hard_name,
    }

def integrate_mm(dnum, edges):
    """Integrate `dnum` along axis 0 against log10 of `edges[0]` (dN/dlogM).

    Delegates to `utils.trapz` with `cumsum=False`.
    """
    log10_mtot = np.log10(edges[0])
    return utils.trapz(dnum, log10_mtot, axis=0, cumsum=False)
def integrate_qq(dnum, edges):
    """Integrate `dnum` along axis 1 against `edges[1]` (dN/dq).

    Delegates to `utils.trapz` with `cumsum=False`.
    """
    mrat = edges[1]
    return utils.trapz(dnum, mrat, axis=1, cumsum=False)
def integrate_zz(dnum, edges):
    """Integrate `dnum` along axis 2 against `edges[2]` (dN/dz).

    Delegates to `utils.trapz` with `cumsum=False`.
    """
    redz = edges[2]
    return utils.trapz(dnum, redz, axis=2, cumsum=False)
def integrate_ff(dnum, fobs_gw_edges):
    """Multiply `dnum` by the natural-log width of each frequency bin.

    The widths are `np.diff(np.log(fobs_gw_edges))`, which broadcast against
    the trailing axis of `dnum`.
    """
    dlnf = np.diff(np.log(fobs_gw_edges))
    return dnum*dlnf

plot_dnum_dpar() vs each edge parameter

In [None]:
def plot_dnum_dpar(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Plot 1D slices of the differential number along the M, q, and z grids.

    For every combination of indices drawn from `mm_arr`, `qq_arr`, `zz_arr`,
    and `ff_arr`, one curve is added to each of three panels: the quantity
    integrated over the *other* grid parameters (and multiplied by the
    frequency-bin log-width), plotted against the grid edges of M, q, or z.

    Parameters
    ----------
    vals : dict as returned by `sam_model` (uses 'edges', 'diff_num',
        'fobs_gw_edges', 'fobs_gw_cents', 'hard_name', and 'sam').
    mm_arr, qq_arr, zz_arr, ff_arr : iterables of integer indices into the
        mass, mass-ratio, redshift, and frequency axes respectively.

    Returns
    -------
    fig : the figure created by `holo.plot.figax`.
    """
    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(18,5),
        sharey=True
    )

    diff_num = vals['diff_num']
    edges = vals['edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    fobs_gw_cents = vals['fobs_gw_cents']

    edges_mm, edges_qq, edges_zz = edges[0], edges[1], edges[2]
    cents_mm = utils.midpoints(edges_mm, log=True)
    cents_qq = utils.midpoints(edges_qq, log=True)
    cents_zz = utils.midpoints(edges_zz, log=True)

    # Integrate over the two "other" parameters, then apply the frequency bin width.
    dnum_mm = integrate_ff(integrate_zz(integrate_qq(diff_num, edges), edges), fobs_gw_edges)
    dnum_qq = integrate_ff(integrate_zz(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)

    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = [r'$dN/d\log (M)$', '$dN/dq$', '$dN/dz$']

    # Plain list rather than np.array: the three grids generally have different
    # lengths, and NumPy >= 1.24 refuses to build a ragged array implicitly.
    xx = [edges_mm/MSOL, edges_qq, edges_zz]

    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    yy = [dnum_mm[:,qq,zz,ff], dnum_qq[mm,:,zz,ff], dnum_zz[mm,qq,:,ff]]
                    freq_yr = fobs_gw_cents[ff]*YR
                    labels = [
                        '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], freq_yr),
                    ]
                    for ii, ax in enumerate(axs):
                        ax.plot(xx[ii], yy[ii], label=labels[ii], alpha=0.75)

    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend(loc='upper right')

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    # The original referenced `fig.tight_layout` without calling it (a no-op).
    fig.tight_layout()
    return fig

plot_number() by cython, utils, trapz, and rounding down

In [None]:
def plot_number(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Compare four ways of obtaining the per-bin binary number along M, q, z.

    The four variants plotted on each panel are:
      * 'cython' : `vals['number']` as stored by `sam_model`,
      * 'trapz'  : successive `integrate_mm/qq/zz/ff` calls on `diff_num`,
      * 'utils'  : `utils._integrate_grid_differential_number(..., freq=False)`
                   multiplied by the frequency-bin log-widths,
      * 'floor'  : `np.floor` of the cython number.

    Parameters
    ----------
    vals : dict as returned by `sam_model`.
    mm_arr, qq_arr, zz_arr, ff_arr : iterables of integer indices into the
        mass, mass-ratio, redshift, and frequency axes respectively.

    Returns
    -------
    fig : the figure created by `holo.plot.figax`.
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    fobs_gw_cents = vals['fobs_gw_cents']
    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)
    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)

    cynum = vals['number']
    tznum = integrate_zz(dnum_zz, edges)
    utnum = utils._integrate_grid_differential_number(edges, diff_num, freq=False)
    utnum = utnum * np.diff(np.log(fobs_gw_edges))
    flnum = np.floor(cynum)

    # (array, legend suffix, linestyle), in the order they are drawn on each panel.
    variants = [
        (cynum, ' cython', '-'),
        (tznum, ' trapz', '--'),
        (utnum, ' utils', ':'),
        (flnum, ' floor', '-.'),
    ]

    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(16,5),
        sharey=True
    )
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = ['$N (M)$', '$N(q)$', '$N(z)$']

    # Plain list rather than np.array: the three grids generally have different
    # lengths, and NumPy >= 1.24 refuses to build a ragged array implicitly.
    xx = [cents_mm/MSOL, cents_qq, cents_zz]
    full = slice(None)
    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    # One index tuple per panel: free axis is M, q, or z.
                    picks = [(full, qq, zz, ff), (mm, full, zz, ff), (mm, qq, full, ff)]
                    freq_yr = fobs_gw_cents[ff]*YR
                    labels = [
                        '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], freq_yr),
                    ]

                    for ii, ax in enumerate(axs):
                        for arr, suffix, lsty in variants:
                            ax.plot(xx[ii], arr[picks[ii]], label=labels[ii]+suffix,
                                    linestyle=lsty, alpha=0.75)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend()

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    # The original referenced `fig.tight_layout` without calling it (a no-op).
    fig.tight_layout()
    return fig

plot_strain_amp() for initial and final z

In [None]:
def plot_strain_amp(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Plot strain amplitude at bin centers using initial vs. final redshifts.

    `hs_final` is computed with `vals['redz_final']`, `hs_initz` with
    `redz=None`; summary statistics of their ratio are printed. For every
    index combination, both are drawn as 1D slices along M, q, and z.

    Parameters
    ----------
    vals : dict as returned by `sam_model`.
    mm_arr, qq_arr, zz_arr, ff_arr : iterables of integer indices into the
        mass, mass-ratio, redshift, and frequency axes respectively.

    Returns
    -------
    fig : the figure created by `holo.plot.figax`.
    """
    edges = vals['edges']
    # Previously read from an undefined name, i.e. only worked if a global
    # `fobs_gw_cents` happened to exist in the kernel.
    fobs_gw_cents = vals['fobs_gw_cents']
    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)

    hs_final = anis.strain_amp_at_bin_centers_redz(edges, vals['redz_final'])
    hs_initz = anis.strain_amp_at_bin_centers_redz(edges, redz=None)
    print(holo.utils.stats(hs_final/hs_initz))

    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(16,5),
        sharey=True
    )

    # Plain list rather than np.array: the three grids generally have different
    # lengths, and NumPy >= 1.24 refuses to build a ragged array implicitly.
    xx = [cents_mm/MSOL, cents_qq, cents_zz]
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = ['$h_s (M)$', '$h_s(q)$', '$h_s(z)$']

    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    yy_initz = [hs_initz[:,qq,zz,ff], hs_initz[mm,:,zz,ff], hs_initz[mm,qq,:,ff]]
                    yy_final = [hs_final[:,qq,zz,ff], hs_final[mm,:,zz,ff], hs_final[mm,qq,:,ff]]
                    freq_yr = fobs_gw_cents[ff]*YR
                    labels = [
                        '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], freq_yr),
                    ]

                    for ii, ax in enumerate(axs):
                        ax.plot(xx[ii], yy_initz[ii], label=labels[ii]+' initial z',
                                linestyle='--', alpha=0.75)
                        ax.plot(xx[ii], yy_final[ii], label=labels[ii]+' final z',
                                linestyle='-', alpha=0.75)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend()

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    # The original referenced `fig.tight_layout` without calling it (a no-op).
    fig.tight_layout()
    return fig

plot_number_times_h2()

In [None]:
def plot_number_times_h2(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Plot (number of binaries) x h_s^2 along M, q, z for three number estimates.

    The three number estimates are the stored cython number
    (`vals['number']`), the successive-trapezoid result, and
    `utils._integrate_grid_differential_number(..., freq=False)` times the
    frequency-bin log-widths; each is multiplied by `vals['hs_cents']**2`.

    Parameters
    ----------
    vals : dict as returned by `sam_model`.
    mm_arr, qq_arr, zz_arr, ff_arr : iterables of integer indices into the
        mass, mass-ratio, redshift, and frequency axes respectively.

    Returns
    -------
    fig : the figure created by `holo.plot.figax`.
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    hs_cents = vals['hs_cents']
    # These must be bound before `dnum_zz` is computed; the original assigned
    # `fobs_gw_edges` only after using it, which raises UnboundLocalError.
    fobs_gw_edges = vals['fobs_gw_edges']
    fobs_gw_cents = vals['fobs_gw_cents']

    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)
    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)

    cynum = vals['number']*hs_cents**2
    tznum = integrate_zz(dnum_zz, edges)*hs_cents**2
    utnum = utils._integrate_grid_differential_number(edges, diff_num, freq=False)
    utnum = utnum * np.diff(np.log(fobs_gw_edges))*hs_cents**2

    # (array, legend suffix, linestyle), in the order they are drawn on each panel.
    variants = [
        (cynum, ' cython', '-'),
        (tznum, ' trapz', '--'),
        (utnum, ' utils', ':'),
    ]

    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(16,5),
        sharey=True
    )
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = [r'$N (M)\times h_s^2$', r'$N(q)\times h_s^2$', r'$N(z)\times h_s^2$']

    # Plain list rather than np.array: the three grids generally have different
    # lengths, and NumPy >= 1.24 refuses to build a ragged array implicitly.
    xx = [cents_mm/MSOL, cents_qq, cents_zz]
    full = slice(None)
    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    # One index tuple per panel: free axis is M, q, or z.
                    picks = [(full, qq, zz, ff), (mm, full, zz, ff), (mm, qq, full, ff)]
                    freq_yr = fobs_gw_cents[ff]*YR
                    labels = [
                        '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], freq_yr),
                    ]

                    for ii, ax in enumerate(axs):
                        for arr, suffix, lsty in variants:
                            ax.plot(xx[ii], arr[picks[ii]], label=labels[ii]+suffix,
                                    linestyle=lsty, alpha=0.75)
    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])
        ax.legend()

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    # The original referenced `fig.tight_layout` without calling it (a no-op).
    fig.tight_layout()
    return fig

plot_integrated()

In [None]:
def plot_integrated(vals, mm_arr, qq_arr, zz_arr, ff_arr):
    """Plot h_s^2- and h_s^4-weighted numbers along M, q, and z (2x3 panels).

    Top row: h_s^2 weighting (lines); bottom row: h_s^4 weighting (scatter).
    On every panel three estimates are drawn for each index combination:
      * the stored number `vals['number']` times `hs_cents**n`,
      * `np.floor` of that number times `hs_cents**n`,
      * the differential number weighted by the strain at the appropriate
        midpoints (`utils.midpoints_multiax` of `hs_edges`) and then
        integrated over the panel's own parameter.

    Parameters
    ----------
    vals : dict as returned by `sam_model`.
    mm_arr, qq_arr, zz_arr, ff_arr : indexable sequences of integer indices
        into the mass, mass-ratio, redshift, and frequency axes respectively
        (their first elements decide which curves carry the shared legend
        entries).

    Returns
    -------
    fig : the figure created by `holo.plot.figax`.
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    cents_mm = utils.midpoints(edges[0], log=True)
    cents_qq = utils.midpoints(edges[1], log=True)
    cents_zz = utils.midpoints(edges[2], log=True)

    hs_cents = vals['hs_cents']
    hs_edges = vals['hs_edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    # Previously read from an undefined name, i.e. only worked if a global
    # `fobs_gw_cents` happened to exist in the kernel.
    fobs_gw_cents = vals['fobs_gw_cents']

    # For each parameter: integrate over the other two (and frequency width),
    # weight by the strain at midpoints of those other two axes, then
    # integrate over the parameter itself.
    dnum_mm = integrate_ff(integrate_zz(integrate_qq(diff_num, edges), edges), fobs_gw_edges)
    hs_mid_mm = utils.midpoints_multiax(hs_edges, axis=(1,2), log=True)
    numh2_mm = integrate_mm(dnum_mm*hs_mid_mm**2, edges)
    numh4_mm = integrate_mm(dnum_mm*hs_mid_mm**4, edges)

    dnum_qq = integrate_ff(integrate_zz(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    hs_mid_qq = utils.midpoints_multiax(hs_edges, axis=(0,2), log=True)
    numh2_qq = integrate_qq(dnum_qq*hs_mid_qq**2, edges)
    numh4_qq = integrate_qq(dnum_qq*hs_mid_qq**4, edges)

    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    hs_mid_zz = utils.midpoints_multiax(hs_edges, axis=(0,1), log=True)
    numh2_zz = integrate_zz(dnum_zz*hs_mid_zz**2, edges)
    numh4_zz = integrate_zz(dnum_zz*hs_mid_zz**4, edges)

    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels_h2 = [r'$\int h_s^2 dN/d\log (M)$', r'$\int h_s^2 dN/dq$', r'$\int h_s^2 dN/dz$']
    ylabels_h4 = [r'$\int h_s^4 dN/d\log (M)$', r'$\int h_s^4 dN/dq$', r'$\int h_s^4 dN/dz$']

    cynum = vals['number']
    numh2_cy = cynum*hs_cents**2
    numh4_cy = cynum*hs_cents**4

    flnum = np.floor(cynum)
    numh2_fl = flnum*hs_cents**2
    numh4_fl = flnum*hs_cents**4

    # Plain list rather than np.array: the three grids generally have different
    # lengths, and NumPy >= 1.24 refuses to build a ragged array implicitly.
    xx = [cents_mm/MSOL, cents_qq, cents_zz]

    fig, axs = holo.plot.figax(
        nrows=2, ncols=3, figsize=(16,8))

    for mm in mm_arr:
        for qq in qq_arr:
            for zz in zz_arr:
                for ff in ff_arr:
                    yy_h2 = [numh2_mm[:,qq,zz,ff], numh2_qq[mm,:,zz,ff], numh2_zz[mm,qq,:,ff]]
                    yy_h4 = [numh4_mm[:,qq,zz,ff], numh4_qq[mm,:,zz,ff], numh4_zz[mm,qq,:,ff]]
                    cy_h2 = [numh2_cy[:,qq,zz,ff], numh2_cy[mm,:,zz,ff], numh2_cy[mm,qq,:,ff]]
                    cy_h4 = [numh4_cy[:,qq,zz,ff], numh4_cy[mm,:,zz,ff], numh4_cy[mm,qq,:,ff]]
                    fl_h2 = [numh2_fl[:,qq,zz,ff], numh2_fl[mm,:,zz,ff], numh2_fl[mm,qq,:,ff]]
                    fl_h4 = [numh4_fl[:,qq,zz,ff], numh4_fl[mm,:,zz,ff], numh4_fl[mm,qq,:,ff]]
                    freq_yr = fobs_gw_cents[ff]*YR
                    labels = [
                        '$q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (cents_qq[qq], cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $z=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_zz[zz], freq_yr),
                        r'$M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr' % (cents_mm[mm]/MSOL, cents_qq[qq], freq_yr),
                    ]
                    # Shared legend entries are attached only to the very first
                    # curve on the first panel of each row.
                    first_combo = (mm==mm_arr[0] and qq==qq_arr[0]
                                   and zz==zz_arr[0] and ff==ff_arr[0])

                    for ii, ax in enumerate(axs[0,:]): # h2
                        if ii==0 and first_combo:
                            cylabel = r'$h_s^2 \times$ cynum'
                            fllabel = r'$h_s^2 \times$ flnum'
                        else:
                            cylabel = None
                            fllabel = None
                        ll, = ax.plot(xx[ii], cy_h2[ii], label=cylabel, linestyle='-', alpha=0.35, lw=4)
                        cc = ll.get_color()
                        ax.plot(xx[ii], fl_h2[ii], label=fllabel, linestyle='-', alpha=0.65, color=cc, lw=2)
                        ax.plot(xx[ii], yy_h2[ii], label=labels[ii], linestyle='--', alpha=0.75, color=cc, lw=1)
                    for ii, ax in enumerate(axs[1,:]): # h4
                        if ii==0 and first_combo:
                            # The plotted values are weighted by h_s^4 on this row.
                            cylabel = r'$h_s^4 \times$ cynum'
                            fllabel = r'$h_s^4 \times$ flnum'
                        else:
                            cylabel = None
                            fllabel = None
                        ll = ax.scatter(xx[ii], cy_h4[ii], label=cylabel, marker='o', alpha=0.35, lw=3, s=10)
                        # Reuse the face color of the first scatter for the other two.
                        cc = ll.get_facecolors()
                        ax.scatter(xx[ii], fl_h4[ii], label=fllabel, marker='+', alpha=0.65, color=cc)
                        ax.scatter(xx[ii], yy_h4[ii], label=labels[ii], marker='x', alpha=0.75, color=cc)

    for ii, ax in enumerate(axs[0,:]):
        ax.set_ylabel(ylabels_h2[ii])
        ax.legend(fontsize=8)
    for ii, ax in enumerate(axs[1,:]):
        ax.set_ylabel(ylabels_h4[ii])
        ax.set_xlabel(xlabels[ii])
        ax.sharex(axs[0,ii])
        ax.legend(fontsize=8)

    fig.suptitle('%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape)))
    fig.tight_layout()
    return fig

plot_integrated_vary()

In [None]:
def calc_integrated(vals):
    """Compute the h_s^2- and h_s^4-weighted number arrays used by the plots.

    Parameters
    ----------
    vals : dict as returned by `sam_model` (uses 'diff_num', 'edges',
        'fobs_gw_edges', 'hs_cents', 'hs_edges', and 'number').

    Returns
    -------
    tuple
        `(xx, numh2_mm, numh2_qq, numh2_zz, numh4_mm, numh4_qq, numh4_zz,
        numh2_cy, numh4_cy, numh2_fl, numh4_fl)` where

        * `xx` holds the bin centers `[M/MSOL, q, z]`. It is a regular 2D
          array when the three grids have equal length (as before) and a
          length-3 object array otherwise; `xx[i]` and `xx[i][j]` work in
          both cases.
        * `numh{2,4}_{mm,qq,zz}` are the differential numbers weighted by
          the strain at midpoints of the other two axes and integrated over
          the named parameter.
        * `numh{2,4}_cy` are `vals['number']` times `hs_cents**n`.
        * `numh{2,4}_fl` are `np.floor(vals['number'])` times `hs_cents**n`.
    """
    diff_num = vals['diff_num']
    edges = vals['edges']
    fobs_gw_edges = vals['fobs_gw_edges']
    hs_cents = vals['hs_cents']
    hs_edges = vals['hs_edges']

    dnum_mm = integrate_ff(integrate_zz(integrate_qq(diff_num, edges), edges), fobs_gw_edges)
    hs_mid_mm = utils.midpoints_multiax(hs_edges, axis=(1,2), log=True)
    numh2_mm = integrate_mm(dnum_mm*hs_mid_mm**2, edges)
    numh4_mm = integrate_mm(dnum_mm*hs_mid_mm**4, edges)

    dnum_qq = integrate_ff(integrate_zz(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    hs_mid_qq = utils.midpoints_multiax(hs_edges, axis=(0,2), log=True)
    numh2_qq = integrate_qq(dnum_qq*hs_mid_qq**2, edges)
    numh4_qq = integrate_qq(dnum_qq*hs_mid_qq**4, edges)

    dnum_zz = integrate_ff(integrate_qq(integrate_mm(diff_num, edges), edges), fobs_gw_edges)
    hs_mid_zz = utils.midpoints_multiax(hs_edges, axis=(0,1), log=True)
    numh2_zz = integrate_zz(dnum_zz*hs_mid_zz**2, edges)
    numh4_zz = integrate_zz(dnum_zz*hs_mid_zz**4, edges)

    cynum = vals['number']
    numh2_cy = cynum*hs_cents**2
    numh4_cy = cynum*hs_cents**4

    flnum = np.floor(cynum)
    numh2_fl = flnum*hs_cents**2
    numh4_fl = flnum*hs_cents**4

    cents = [
        utils.midpoints(edges[0], log=True)/MSOL,
        utils.midpoints(edges[1], log=True),
        utils.midpoints(edges[2], log=True),
    ]
    if len({len(cc) for cc in cents}) == 1:
        # Equal-length grids: same regular 2D array the original produced.
        xx = np.array(cents)
    else:
        # Grids of different lengths: NumPy >= 1.24 raises on implicit ragged
        # arrays, so build the object array explicitly, one entry per grid.
        xx = np.empty(len(cents), dtype=object)
        for ii, cc in enumerate(cents):
            xx[ii] = cc

    return (xx, numh2_mm, numh2_qq, numh2_zz, numh4_mm, numh4_qq, numh4_zz,
            numh2_cy, numh4_cy, numh2_fl, numh4_fl)



def plot_integrated_vary(vals, mm_arr, qq_arr, zz_arr, ff_arr, vary='mm'):
    """Plot h_s^2- and h_s^4-weighted numbers, color-coded by one varied index.

    Same 2x3 layout as `plot_integrated` (top row h_s^2 as lines, bottom row
    h_s^4 as scatter), but curves are colored with `cm.rainbow_r` according
    to the position of the index selected by `vary`, and a single
    figure-level legend is drawn.

    Parameters
    ----------
    vals : dict as returned by `sam_model`.
    mm_arr, qq_arr, zz_arr, ff_arr : sized, indexable sequences of integer
        indices into the mass, mass-ratio, redshift, and frequency axes.
    vary : 'mm', 'qq', or 'zz' to color by the mass, mass-ratio, or redshift
        index; any other value colors by the frequency index.

    Returns
    -------
    fig : the figure created by `holo.plot.figax`.
    title : the string used as the figure's suptitle.
    """
    # Number of colors = number of values of the varied index ('ff' by default).
    n_colors = {'mm': len(mm_arr), 'qq': len(qq_arr), 'zz': len(zz_arr)}.get(vary, len(ff_arr))
    colors = cm.rainbow_r(np.linspace(0,1,n_colors))

    (xx, numh2_mm, numh2_qq, numh2_zz, numh4_mm, numh4_qq, numh4_zz,
            numh2_cy, numh4_cy, numh2_fl, numh4_fl) = calc_integrated(vals)

    fig, axs = holo.plot.figax(
        nrows=2, ncols=3, figsize=(16,8))

    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels_h2 = [r'$\int h_s^2 dN/d\log (M)$', r'$\int h_s^2 dN/dq$', r'$\int h_s^2 dN/dz$']
    ylabels_h4 = [r'$\int h_s^4 dN/d\log (M)$', r'$\int h_s^4 dN/dq$', r'$\int h_s^4 dN/dz$']
    fobs_gw_cents = vals['fobs_gw_cents']

    for mi,mm in enumerate(mm_arr):
        for qi,qq in enumerate(qq_arr):
            for zi,zz in enumerate(zz_arr):
                for fi,ff in enumerate(ff_arr):
                    color = colors[{'mm': mi, 'qq': qi, 'zz': zi}.get(vary, fi)]

                    # Plain lists: the three slices generally differ in length.
                    yy_h2 = [numh2_mm[:,qq,zz,ff], numh2_qq[mm,:,zz,ff], numh2_zz[mm,qq,:,ff]]
                    yy_h4 = [numh4_mm[:,qq,zz,ff], numh4_qq[mm,:,zz,ff], numh4_zz[mm,qq,:,ff]]
                    cy_h2 = [numh2_cy[:,qq,zz,ff], numh2_cy[mm,:,zz,ff], numh2_cy[mm,qq,:,ff]]
                    cy_h4 = [numh4_cy[:,qq,zz,ff], numh4_cy[mm,:,zz,ff], numh4_cy[mm,qq,:,ff]]
                    # `xx[i][j]` (not `xx[i, j]`) so this also works when the
                    # grids have different lengths and `xx` is not a 2D array.
                    label0 = r'$M=%.2e\ M_\odot$, $q=%.2f$, $z=%.2f$, $f=%.2f$/yr' % (
                        xx[0][mm], xx[1][qq], xx[2][zz], fobs_gw_cents[ff]*YR)
                    first_combo = (mm==mm_arr[0] and qq==qq_arr[0]
                                   and zz==zz_arr[0] and ff==ff_arr[0])

                    for ii, ax in enumerate(axs[0,:]): # h2
                        # Only the first panel contributes entries to the figure legend.
                        label = label0 if ii==0 else None
                        cylabel = r'$h_s^2 \times \int dN/dx$' if (ii==0 and first_combo) else None
                        ll, = ax.plot(xx[ii], cy_h2[ii], label=cylabel, linestyle='-', alpha=0.35, lw=4, color=color)
                        cc = ll.get_color()
                        ax.plot(xx[ii], yy_h2[ii], label=label, linestyle='--', alpha=0.75, color=cc, lw=1)
                    for ii, ax in enumerate(axs[1,:]): # h4
                        ll = ax.scatter(xx[ii], cy_h4[ii], label=None, marker='o', alpha=0.35, lw=3, s=10, color=color)
                        # Reuse the face color of the first scatter for the second.
                        cc = ll.get_facecolors()
                        ax.scatter(xx[ii], yy_h4[ii], label=None, marker='x', alpha=0.75, color=cc)

    for ii, ax in enumerate(axs[0,:]):
        ax.set_ylabel(ylabels_h2[ii])
    for ii, ax in enumerate(axs[1,:]):
        ax.set_ylabel(ylabels_h4[ii])
        ax.set_xlabel(xlabels[ii])
        ax.sharex(axs[0,ii])
    fig.legend(ncols=3, bbox_to_anchor=(0.05,0), loc='upper left', fontsize=14)
    title = ('%s, %s, Varying %s' % (str(vals['hard_name']), str(vals['sam'].shape), vary))

    fig.suptitle(title)
    fig.tight_layout()
    return fig, title


plot_strain_amp_vary()

In [None]:


def calc_strain_amp(vals):
    """Return strain amplitudes at bin centers for initial and final redshifts.

    Parameters
    ----------
    vals : dict as returned by `sam_model` (uses 'edges' and 'redz_final').

    Returns
    -------
    hs_initz : strain amplitudes from `anis.strain_amp_at_bin_centers_redz`
        called with `redz=None`.
    hs_final : strain amplitudes from the same helper called with
        `vals['redz_final']`.
    cents_mm, cents_qq, cents_zz : `utils.midpoints(..., log=True)` of the
        first three entries of `vals['edges']`.
    """
    edges = vals['edges']
    mtot_edges, mrat_edges, redz_edges = edges[0], edges[1], edges[2]

    cents_mm = utils.midpoints(mtot_edges, log=True)
    cents_qq = utils.midpoints(mrat_edges, log=True)
    cents_zz = utils.midpoints(redz_edges, log=True)

    hs_final = anis.strain_amp_at_bin_centers_redz(edges, vals['redz_final'])
    hs_initz = anis.strain_amp_at_bin_centers_redz(edges, redz=None)

    return hs_initz, hs_final, cents_mm, cents_qq, cents_zz

def plot_strain_amp_vary(vals, mm_arr, qq_arr, zz_arr, ff_arr, vary='ff', initz=True):
    """Plot strain amplitude slices along M, q and z for every index combination.

    For each (mm, qq, zz, ff) combination three 1D slices of the strain
    amplitude grid are drawn, one per panel: versus total mass, mass ratio
    and redshift.  Curves are colored by the index that is being varied.

    Parameters
    ----------
    vals : dict
        Model dictionary as returned by `sam_model` (needs 'edges',
        'redz_final', 'fobs_gw_cents', 'hard_name' and 'sam').
    mm_arr, qq_arr, zz_arr, ff_arr : sequences of int
        Indices into the bin centers of mass, mass ratio, redshift and
        frequency respectively.
    vary : str
        Which index sets the line color: 'mm', 'qq' or 'zz'; any other value
        colors by the frequency index.
    initz : bool
        If True, also draw (dashed) the strain computed without final redshifts.

    Returns
    -------
    fig : the figure
    title : str, the suptitle text (so callers can extend it)

    """
    hs_initz, hs_final, cents_mm, cents_qq, cents_zz = calc_strain_amp(vals)
    fobs_gw_cents = vals['fobs_gw_cents']

    # position of the varied index within (mm, qq, zz, ff); default is frequency
    vary_pos = {'mm': 0, 'qq': 1, 'zz': 2}.get(vary, 3)
    num_colors = len((mm_arr, qq_arr, zz_arr, ff_arr)[vary_pos])
    colors = cm.rainbow_r(np.linspace(0, 1, num_colors))

    # raw strings: the LaTeX backslashes are not valid Python escape sequences
    fig, axs = holo.plot.figax(
        xlabel='edge parameters',
        ylabel=r'$dN/d(\mathrm{edge parameter})$',
        ncols=3, figsize=(16,5),
        sharey=True
    )

    # plain lists (not np.array) so the three axes may have different lengths
    xx = [cents_mm/MSOL, cents_qq, cents_zz]
    xlabels = [r'M ($M_\odot$)', 'q', 'z']
    ylabels = ['$h_s (M)$', '$h_s(q)$', '$h_s(z)$']

    for mi, mm in enumerate(mm_arr):
        for qi, qq in enumerate(qq_arr):
            for zi, zz in enumerate(zz_arr):
                for fi, ff in enumerate(ff_arr):
                    color = colors[(mi, qi, zi, fi)[vary_pos]]

                    yy_initz = [hs_initz[:,qq,zz,ff], hs_initz[mm,:,zz,ff], hs_initz[mm,qq,:,ff]]
                    yy_final = [hs_final[:,qq,zz,ff], hs_final[mm,:,zz,ff], hs_final[mm,qq,:,ff]]
                    label_ii = (r'$M=%.2e\ M_\odot$, $q=%.2f$, $z=%.2f$, $f=%.2f$/yr'
                                % (cents_mm[mm]/MSOL, cents_qq[qq], cents_zz[zz], fobs_gw_cents[ff]*YR))
                    # only the very first dashed curve gets a legend entry
                    first_combo = (mi == 0 and qi == 0 and zi == 0 and fi == 0)

                    for ii, ax in enumerate(axs):
                        # label each combination once (first panel) to keep the legend short
                        label = label_ii if ii == 0 else None
                        if initz:
                            label_init = 'initial z' if (ii == 0 and first_combo) else None
                            ax.plot(xx[ii], yy_initz[ii], label=label_init, linestyle='--', alpha=0.65,
                                    color=color)
                        ax.plot(xx[ii], yy_final[ii], label=label, linestyle='-', alpha=0.75,
                                color=color)

    for ii, ax in enumerate(axs):
        ax.set_xlabel(xlabels[ii])
        ax.set_ylabel(ylabels[ii])

    fig.legend(ncols=3, bbox_to_anchor=(0.05,0), loc='upper left', fontsize=14)
    title = '%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape))
    fig.suptitle(title)
    # was `fig.tight_layout` (attribute access only), so the layout was never applied
    fig.tight_layout()
    return fig, title

# Orange Spike

## FT, shape 20

In [None]:
sam=holo.sam.Semi_Analytic_Model(shape=20)
valsFT20 = sam_model(sam=sam, hard=holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR))

In [None]:
vals=valsFT20

In [None]:
# fig = plot_dnum_dpar(vals, mm_arr=[2,7], qq_arr=[17,], zz_arr=[17], ff_arr=[2,32])
# fig = plot_strain_amp(vals, mm_arr=[6,], qq_arr=[17,], zz_arr=[15], ff_arr=[1,8,])
# fig = plot_number(vals,  mm_arr=[7], qq_arr=[17,], zz_arr=[17], ff_arr=[32])
# fig = plot_integrated(vals, mm_arr=[7,15], qq_arr=[17,], zz_arr=[17,], ff_arr=[32])

In [None]:
print(vals['fobs_gw_cents'][3]/2, utils.midpoints(vals['edges'][3], log=True)[3])

In [None]:
# vary mass, finding weird things happen
fig,title = plot_integrated_vary(vals, mm_arr=[0,4,9,14,18], qq_arr=[17,], zz_arr=[17,], ff_arr=[32], vary='mm')

### first ff dropoff

In [None]:
# vary ff, finding weird things happen at high frequencies
fig,title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[17,], zz_arr=[17,], ff_arr=[0,10,20,30,38], vary='ff')
# find where we have dropoff, between 10 and 20
fig,title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[17,], zz_arr=[15,], ff_arr=[10,12,14,16,18,20], vary='ff')
fig.suptitle(title+', first dropoff between ff=10 and ff=20')
fig,title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[17,], zz_arr=[15,], ff_arr=[16,17,18,], vary='ff')
fig.suptitle(title+', first dropoff at ff=15 = fobs_orb=%.2e/yr' % (utils.midpoints(vals['edges'][3], log=True)[15]*YR))

### first ff discontinuity

In [None]:

# find where we have full discontinuity, between 20 and 30
fig,title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[17,], zz_arr=[15,], ff_arr=[20,24,26,28,30], vary='ff')
fig.suptitle(title+', first discontinuity between ff=20 and ff=30')
# first discontinuity between 24 and 26
fig,title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[17,], zz_arr=[15,], ff_arr=[24,25,26], vary='ff')
fig.suptitle(title+', first discontinuity at ff=25 = fobs_orb = %.2e /yr' % (utils.midpoints(vals['edges'][3], log=True)[25]*YR))
# fig,title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[17,], zz_arr=[15,], ff_arr=[25,], vary='ff')

### first qq discontinuity

The qq dependence only shows up for low mm or high zz, so the curves should overlap at mm=15.
The overlap does exist at mm=15, ff=28.

discontinuity is independent of qq

In [None]:
# vary qq (mass-ratio index) at fixed mm=15, zz=17, ff=28
# original note (truncated): "everything else looks normal at these weird [spots?]"
fig,title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[0,4,9,14,18,], zz_arr=[17,], ff_arr=[28,], vary='qq')
fig.suptitle(title+', discontinuity independent of qq')

### first mm discontinuity

In [None]:
fig,title = plot_integrated_vary(vals, mm_arr=[7,12,15,18], qq_arr=[17,], zz_arr=[17,], ff_arr=[28,], vary='mm')
fig.suptitle(title+', first discontinuity between mm=12 and mm=15')
fig,title = plot_integrated_vary(vals, mm_arr=[14,15], qq_arr=[17,], zz_arr=[17,], ff_arr=[28,], vary='mm')
fig.suptitle(title+', first discontinuity at mm=15=%.2e' % (utils.midpoints(vals['edges'][0], log=True)[15]/MSOL))

### varying z, find exact weird spots

In [None]:

fig,title = plot_integrated_vary(vals, mm_arr=[15], qq_arr=[18,], zz_arr=[10,12,14,16,18], ff_arr=[28,], vary='zz')
fig.suptitle(title+', discontinuity at zz=16, $z$=%.2e' % (utils.midpoints(vals['edges'][2], log=True)[16]))

In [None]:
zz_cents = utils.midpoints(vals['edges'][2], log=True)
for zz in range(len(zz_cents)):
    print('zz=',zz, zz_cents[zz])

# weird stuff happens between z=1 and z=3

### look at hs for the weird case

In [None]:
fig, title = plot_strain_amp_vary(vals, mm_arr=[15,], qq_arr=[18,], zz_arr=[18,], ff_arr=[0,4,9,14,19,24,29,34,38], vary='ff')
fig, title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[18,], zz_arr=[18,], ff_arr=[0,4,9,14,19,24,29,34,38], vary='ff')

# Calculating hs

In [None]:
print(vals['fobs_gw_cents'].shape)

In [None]:
print(holo.utils.stats(utils.midpoints(sam.redz, log=True)))

In [None]:
# fig, title = plot_integrated_vary(vals, mm_arr=[15,], qq_arr=[18,], zz_arr=[18,], ff_arr=[0,4,9,14,19,24,29,34,38], vary='ff')

# Compare initial vs final redshift and strain around the suspicious grid point.
mm, qq, zz, ff = 15, 18, 15, 28 # weird spot
temp=calc_strain_amp(vals)
hs_init, hs_final =temp[0], temp[1]
print('hs_init:', utils.stats(hs_init), f", {hs_init.shape=}")
# fixed: this line previously reported hs_init.shape under the hs_final label
print('hs_final:', utils.stats(hs_final), f", {hs_final.shape=}")
rz_init = utils.midpoints(sam.redz)
# reduce the final redshifts from bin edges to bin centers along the first three axes
rz_final = utils.midpoints(vals['redz_final'], axis=(0,), log=True)
rz_final = utils.midpoints(rz_final, axis=(1,), log=True)
rz_final = utils.midpoints(rz_final, axis=(2,), log=True)
print('rz_final:', utils.stats(rz_final), rz_final.shape)
# NOTE: the loop reuses (and overwrites) the global `zz`
for zz in (14,15,16):
    print('zz=%d, z_init=%.2f, z_final=%.2f, hs_init=%.2e, hs_final=%.2e' 
          % (zz, rz_init[zz], rz_final[mm,qq,zz,ff], hs_init[mm,qq,zz,ff], hs_final[mm,qq,zz,ff]))
# print(holo.utils.stats(hs[mm,qq,zz,ff]))

In [None]:
# final redshift
fobs_orb_cents=vals['fobs_orb_cents']
hard = vals['hard']
redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo)

mtot = utils.midpoints(sam.mtot)
mrat = utils.midpoints(sam.mrat)
mchirp = utils.chirp_mass_mtmr(mtot[:,np.newaxis], mrat[np.newaxis,:])

redz_init = utils.midpoints(sam.redz)
dcom = holo.cosmo.comoving_distance(redz_init).cgs.value

frst_orb = utils.frst_from_fobs(fobs_orb_cents[np.newaxis,:], redz_init[:,np.newaxis])



mm=15
qq=18
zz=15
ff=28
hs15 = holo.utils.gw_strain_source(mchirp[mm,qq], dcom[zz], frst_orb[zz,ff])
print(hs15)

mm=15
qq=18
zz=16
ff=28
hs16 = holo.utils.gw_strain_source(mchirp[mm,qq], dcom[zz], frst_orb[zz,ff])
print(hs16)

In [None]:
# Inspect the final redshifts on the bin edges surrounding the weird spot.
rz = redz_final
print(redz_final[mm,qq,15,28])
print(redz_final[mm+1,qq,15,28])
# np.mean takes ONE array-like; passing three scalars positionally made the
# 2nd and 3rd be interpreted as `axis` and `dtype`, which raises.
# NOTE(review): the third term uses the current `zz`, not 15 -- confirm this is intended.
print(np.mean([rz[mm,qq,15,ff], rz[mm+1,qq,15,ff], rz[mm,qq+1,zz,ff]]))
print(redz_final[mm,qq,16,28])
print(redz_final[mm+1,qq,16,28])
print(redz_final[mm,qq,17,28])
print(redz_final[mm+1,qq,17,28])

In [None]:
# `edges` was only bound in a later cell; take it from `vals` so this cell
# works on a fresh Restart & Run All.
edges = vals['edges']
print('%.2e %.2e' % (edges[0][mm]/MSOL, edges[0][mm+1]/MSOL))
print('%.2e %.2e' % (edges[2][15], edges[2][16]))

In [None]:
print(np.max(edges[2]))

Check the final-redshift hs calculation against the bin-edge calculation in `holo.gravwaves`.

In [None]:

redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo)


hs_final = holo.gravwaves.strain_amp_from_bin_edges_redz(vals['edges'], vals['redz_final'])
hs_init = anis.strain_amp_at_bin_centers_redz(vals['edges'], redz=None)
hs_final_anis = anis.strain_amp_at_bin_centers_redz(vals['edges'], redz=vals['redz_final'])
print(np.all(hs_final == hs_final_anis)) # check, I am calculating hs correctly
rzfinal_cents = utils.midpoints(redz_final, log=True, axis=0)
rzfinal_cents = utils.midpoints(rzfinal_cents, log=True, axis=1)
rzfinal_cents = utils.midpoints(rzfinal_cents, log=True, axis=2)

mm, qq, zz, ff = 15, 18, 15, 28 # weird input
for zz in (14,15,16):
    print('zz=%d, z_init=%.2f, z_final_edge=%.2f, z_final_cents=%.2f, hs_init=%.2e, hs_final=%.2e' 
          % (zz, rz_init[zz], rz_final[mm,qq,zz,ff], rzfinal_cents[mm,qq,zz,ff], hs_init[mm,qq,zz,ff], hs_final[mm,qq,zz,ff]))
    
print('low ff')
mm, qq, zz, ff = 15, 18, 15, 18 # weird input
for zz in (14,15,16):
    print('zz=%d, z_init=%.2f, z_final=%.2f, z_final_cents=%.2f,  hs_init=%.2e, hs_final=%.2e' 
          % (zz, rz_init[zz], rz_final[mm,qq,zz,ff], rzfinal_cents[mm,qq,zz,ff], hs_init[mm,qq,zz,ff], hs_final[mm,qq,zz,ff]))
    
print('lower zz')
mm, qq, zz, ff = 15, 18, 15, 28 # weird input
for zz in (4,5,6):
    print('zz=%d, z_init=%.2f, z_final=%.2f,  z_final_cents=%.2f, hs_init=%.2e, hs_final=%.2e' 
          % (zz, rz_init[zz], rz_final[mm,qq,zz,ff], rzfinal_cents[mm,qq,zz,ff], hs_init[mm,qq,zz,ff], hs_final[mm,qq,zz,ff]))

In [None]:
redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo)
print(sam._redz_final)

In [None]:
import kalepy as kale

In [None]:
fobs_orb_edges=vals['fobs_orb_edges']
edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo)

hs_final = holo.gravwaves.strain_amp_from_bin_edges_redz(edges, redz_final)

redz=redz_final

for dd in range(3):
    redz = np.moveaxis(redz, dd, 0)
    redz = kale.utils.midpoints(redz, axis=0)
    redz = np.moveaxis(redz, 0, dd)


rzfinal_cents = utils.midpoints(redz_final, log=False, axis=0)
rzfinal_cents = utils.midpoints(rzfinal_cents, log=False, axis=1)
rzfinal_cents = utils.midpoints(rzfinal_cents, log=False, axis=2)

mm, qq, zz, ff = 15, 18, 15, 28 # weird input
for zz in (14,15,16):
    print('zz=%d, z_init=%.2f, z_final_edge=%.2f, z_f_cents_kale=%.2f, z_f_cents_utils=%.2f, hs_init=%.2e, hs_final=%.2e' 
          % (zz, rz_init[zz], rz_final[mm,qq,zz,ff], redz[mm,qq,zz,ff], rzfinal_cents[mm,qq,zz,ff],
             hs_init[mm,qq,zz,ff], hs_final[mm,qq,zz,ff]))

# intermediates of char_strain_sq_from_bin_edges_redz()

In [None]:
import kalepy as kale
# Inspect the midpoints helper used in the next cell.  The previous bare
# `kale.midpoints()` call (no arguments) could only raise.
help(kale.utils.midpoints)


In [None]:
def char_strain_sq_from_bin_edges_redz(edges, redz):
    """Characteristic strain squared at bin centers, using per-bin final redshifts.

    Notebook copy of the gravwaves routine, kept here to inspect its intermediates.

    Parameters
    ----------
    edges : list of four 1D arrays
        Bin edges; the last entry should be observer-frame orbital frequencies.
        The first two are used as total mass and mass ratio.
    redz : ndarray
        Redshifts on a 4D grid, expected shape (M, Q, Z, Fc) where M, Q, Z are
        edges and Fc is frequency centers (per the comments below).

    Returns
    -------
    hc2 : ndarray
        `hs**2 * fc/df`, on the grid of bin centers.

    """
    assert len(edges) == 4
    assert np.all([np.ndim(ee) == 1 for ee in edges])

    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    df = np.diff(foo)                 #: frequency bin widths
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    # redshifts are defined across 4D grid, shape (M, Q, Z, Fc)
    #    where M, Q, Z are edges and Fc is frequency centers
    # find midpoints of redshifts in M, Q, Z dimensions, to end up with (M-1, Q-1, Z-1, Fc)
    for dd in range(3):
        redz = np.moveaxis(redz, dd, 0)
        redz = kale.utils.midpoints(redz, axis=0)
        redz = np.moveaxis(redz, 0, dd)

    # ---- calculate GW strain ----
    mt = kale.utils.midpoints(edges[0])
    mr = kale.utils.midpoints(edges[1])
    # rz = kale.utils.midpoints(edges[2])
    mc = utils.chirp_mass_mtmr(mt[:, np.newaxis], mr[np.newaxis, :])
    mc = mc[:, :, np.newaxis, np.newaxis]
    # bins with non-positive redshift get infinite distance
    dc = +np.inf * np.ones_like(redz)
    sel = (redz > 0.0)
    # use `holo.cosmo` (as the rest of the notebook does); bare `cosmo` is not
    # imported in the visible cells
    dc[sel] = holo.cosmo.comoving_distance(redz[sel]).cgs.value

    # convert from observer-frame to rest-frame; still using frequency-bin centers
    fr = utils.frst_from_fobs(fc[np.newaxis, np.newaxis, np.newaxis, :], redz)

    hs = utils.gw_strain_source(mc, dc, fr)
    hc2 = (hs ** 2) * (fc / df)
    return hc2

# Orange spike, cont.

## GW, shape 20

In [None]:
# GW-only hardening on the 20-bin SAM grid
sam_20=holo.sam.Semi_Analytic_Model(shape=20)
vals_20GW = sam_model(sam=sam_20, hard=holo.hardening.Hard_GW())
# report the model just built (previously printed the stale `vals` entry)
print(vals_20GW['hard_name'])

In [None]:
fig = plot_dnum_dpar(vals_20GW, mm_arr=[2,7], qq_arr=[17,], zz_arr=[17], ff_arr=[2,32])
fig = plot_strain_amp(vals_20GW, mm_arr=[6,], qq_arr=[17,], zz_arr=[15], ff_arr=[1,8,])
fig = plot_number(vals_20GW,  mm_arr=[7], qq_arr=[17,], zz_arr=[17], ff_arr=[2,32])
fig = plot_integrated(vals_20GW, mm_arr=[7,15], qq_arr=[17,], zz_arr=[17,], ff_arr=[32])

In [None]:
fig = plot_integrated(vals_20GW, mm_arr=[7,15], qq_arr=[17,], zz_arr=[17,], ff_arr=[2, 32])

In [None]:
fig = plot_integrated(vals_20GW, mm_arr=[7,15], qq_arr=[17,], zz_arr=[5,17,], ff_arr=[32,])

## GW, shape full

In [None]:
sam_full = holo.sam.Semi_Analytic_Model(shape=None)
vals_fullGW = sam_model(sam=sam_full, hard=holo.hardening.Hard_GW())

In [None]:
fig = plot_dnum_dpar(vals_fullGW, mm_arr=[2,7], qq_arr=[17,], zz_arr=[65], ff_arr=[2,32])
fig = plot_strain_amp(vals_fullGW, mm_arr=[6,], qq_arr=[17,], zz_arr=[65], ff_arr=[1,8,])
fig = plot_number(vals_fullGW,  mm_arr=[7], qq_arr=[17,], zz_arr=[65], ff_arr=[2,32])
fig = plot_integrated(vals_fullGW, mm_arr=[7, -15], qq_arr=[17,], zz_arr=[65,], ff_arr=[32])

In [None]:
fig = plot_integrated(vals_fullGW, mm_arr=[7, -15], qq_arr=[17,], zz_arr=[55,85], ff_arr=[32])

## FT, shape full

In [None]:
sam_full = holo.sam.Semi_Analytic_Model(shape=None)
vals_fullFT = sam_model(sam=sam_full, hard=holo.hardening.Fixed_Time_2PL_SAM(sam_full, 3*GYR))

In [None]:
fig = plot_dnum_dpar(vals_fullFT, mm_arr=[2,7], qq_arr=[17,], zz_arr=[65], ff_arr=[2,32])
fig = plot_strain_amp(vals_fullFT, mm_arr=[6,], qq_arr=[17,], zz_arr=[65], ff_arr=[1,8,])
fig = plot_number(vals_fullFT,  mm_arr=[7], qq_arr=[17,], zz_arr=[65], ff_arr=[2,32])
fig = plot_integrated(vals_fullFT, mm_arr=[7, -15], qq_arr=[17,], zz_arr=[65,], ff_arr=[32])

In [None]:
fig = plot_integrated(vals_fullFT, mm_arr=[7, -15], qq_arr=[17,], zz_arr=[55,85], ff_arr=[32])

In [None]:
fig = plot_integrated(vals_fullFT, mm_arr=[7, -15], qq_arr=[17,], zz_arr=[55,85], ff_arr=[32])

# Looking for the orange spike in different models/shapes

In [None]:
sam = holo.sam.Semi_Analytic_Model(shape=20, mmbulge=holo.relations.MMBulge_MM2013(scatter_dex=0))
vals = sam_model(sam=sam, hard=holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR))


In [None]:
fig = plot_integrated(vals, mm_arr=[7,15], qq_arr=[17,], zz_arr=[17,], ff_arr=[32])
fig.text(0, 0.98, 'scatter_dex=0', fontsize=18)

In [None]:
print(vals.keys())

## shape 20

In [None]:
# Fixed-time hardening on a shape-20 grid; scan redshift bins while varying frequency.
shape=20
sam = holo.sam.Semi_Analytic_Model(shape=shape)
vals = sam_model(sam=sam, hard=holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR))


fig = plot_integrated(vals, mm_arr=[7,15], qq_arr=[17,], zz_arr=[17,], ff_arr=[32])


# every 10th frequency bin, bounded by the actual number of frequency centers
# (the hard-coded np.arange(0,41,10) could include an out-of-range index)
ff_arr = np.arange(0, len(vals['fobs_gw_cents']), 10)
for zz in range(0, shape, shape//6):
    print(f"{zz=}")
    # mm_arr entries are used as array indices, so they must be ints (was 15.)
    fig, title = plot_integrated_vary(vals, mm_arr=[15], qq_arr=[17,], zz_arr=[zz,], ff_arr=ff_arr, vary='ff')

# ax = fig.axes[0]
# ax.legend(ncols=3, bbox_to_anchor=(0,0), loc='upper left', bbox_transform=ax.transAxes)

In [None]:
# Fixed-time hardening on a shape-30 grid.
shape=30
sam = holo.sam.Semi_Analytic_Model(shape=shape)
vals = sam_model(sam=sam, hard=holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR))

# NOTE(review): the loop variables `mm` and `qq` are only printed; the plot call
# below uses fixed index arrays, so every iteration draws the same figure.
# Presumably `qq_arr=[qq,]` (and/or an mm-dependent argument) was intended -- TODO confirm.
for mm in range(0,shape,int(shape/6)):
    print(f"{mm=}")
    for qq in range(0,shape, int(shape/6)):
        print(f"{qq=}")
        # second return value of plot_integrated_vary is the title string, despite the name `leg`
        fig, leg = plot_integrated_vary(vals, mm_arr=np.arange(19), qq_arr=[10,], zz_arr=[17,], ff_arr=[32], vary='mm')

# ax = fig.axes[0]
# ax.legend(ncols=3, bbox_to_anchor=(0,0), loc='upper left', bbox_transform=ax.transAxes)


In [None]:
# Quick check of how NaN behaves in a comparison (NaN > 0 evaluates to False).
x = np.nan
# NOTE(review): the comma makes this print two values, `x>0` and `0`;
# if `x > 0.0` was intended, replace the comma with a decimal point -- TODO confirm.
print( x>0,0)