In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import kalepy as kale

import holodeck as holo
# import holodeck.anisotropy as anis
from holodeck import detstats, plot, utils
from holodeck.constants import YR, MSOL, GYR, SPLC, PC, MPC

# Functions

In [None]:
def strain_amp_at_bin_centers_redz(edges, redz=None):
    """Calculate GW strain amplitude at bin centers, using final or initial redshifts.

    Parameters
    ----------
    edges : sequence of four 1D arrays
        Bin edges, in the order used below: total mass (`edges[0]`), mass ratio
        (`edges[1]`), redshift (`edges[2]`), and frequency (`edges[-1]`).
        The frequencies should be observer-frame orbital frequencies.
    redz : ndarray or None
        If given, redshifts defined on a 4D grid of shape (M, Q, Z, Fc), where M, Q, Z
        are numbers of *edges* and Fc the number of frequency-bin *centers*.  They are
        averaged to bin centers along the first three axes.  Entries that are not
        positive are assigned an infinite comoving distance.
        If None, the midpoints of the redshift edges (`edges[2]`) are used instead.

    Returns
    -------
    hs : ndarray
        Result of `utils.gw_strain_source(mc, dc, fr)`; expected to broadcast to
        (M-1, Q-1, Z-1, Fc) -- TODO confirm against `utils.gw_strain_source`.

    """
    assert len(edges) == 4
    assert np.all([np.ndim(ee) == 1 for ee in edges])

    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    # redshifts are defined across 4D grid, shape (M, Q, Z, Fc)
    #    where M, Q, Z are edges and Fc is frequency centers
    # find midpoints of redshifts in M, Q, Z dimensions, to end up with (M-1, Q-1, Z-1, Fc)
    if redz is not None:
        for dd in range(3):
            # bring axis `dd` to the front, average neighbouring edges, then restore the axis order
            redz = np.moveaxis(redz, dd, 0)
            redz = kale.utils.midpoints(redz, axis=0)
            redz = np.moveaxis(redz, 0, dd)
        # non-positive redshifts keep an infinite distance; only valid entries are converted
        dc = +np.inf * np.ones_like(redz)
        sel = (redz > 0.0)
        dc[sel] = holo.cosmo.comoving_distance(redz[sel]).cgs.value
    else:
        # initial redshifts: centers of the redshift bins, shaped to broadcast as (1, 1, Z-1, 1)
        redz = kale.utils.midpoints(edges[2])[np.newaxis,np.newaxis,:,np.newaxis]
        dc = holo.cosmo.comoving_distance(redz).cgs.value

    # ---- calculate GW strain ----
    mt = kale.utils.midpoints(edges[0])
    mr = kale.utils.midpoints(edges[1])
    mc = utils.chirp_mass_mtmr(mt[:, np.newaxis], mr[np.newaxis, :])
    mc = mc[:, :, np.newaxis, np.newaxis]

    # convert from observer-frame to rest-frame; still using frequency-bin centers
    fr = utils.frst_from_fobs(fc[np.newaxis, np.newaxis, np.newaxis, :], redz)

    hs = utils.gw_strain_source(mc, dc, fr)
    return hs


def strain_amp_at_bin_edges_redz(edges, redz=None):
    """Calculate GW strain amplitude at mass / mass-ratio / redshift bin edges.

    Frequencies are still evaluated at frequency-bin centers.

    Parameters
    ----------
    edges : sequence of four 1D arrays
        Bin edges in the order (total mass, mass ratio, redshift, frequency); the
        frequencies should be observer-frame orbital frequencies.
    redz : ndarray or None
        If given, redshifts used as-is (no averaging to bin centers); entries that are
        not positive get an infinite comoving distance.  If None, the redshift edges
        `edges[2]` are used.

    Returns
    -------
    ndarray
        Result of `utils.gw_strain_source` evaluated on the edge grid.

    """
    assert len(edges) == 4
    assert np.all([np.ndim(ee) == 1 for ee in edges])

    mtot_edges, mrat_edges, redz_edges, fobs_orb_edges = edges
    # observer-frame orbital frequencies, evaluated at the frequency-bin centers
    fobs_cents = kale.utils.midpoints(fobs_orb_edges)

    if redz is None:
        # initial redshifts: the redshift edges, shaped to broadcast along axis 2
        redz = redz_edges[None, None, :, None]
        dcom = holo.cosmo.comoving_distance(redz).cgs.value
    else:
        # only positive redshifts are converted; the rest stay at infinite distance
        positive = (redz > 0.0)
        dcom = +np.inf * np.ones_like(redz)
        dcom[positive] = holo.cosmo.comoving_distance(redz[positive]).cgs.value

    # ---- calculate GW strain ----
    mchirp = utils.chirp_mass_mtmr(mtot_edges[:, None], mrat_edges[None, :])
    mchirp = mchirp[:, :, None, None]

    # observer-frame -> rest-frame frequencies, at frequency-bin centers
    frst = utils.frst_from_fobs(fobs_cents[None, None, None, :], redz)

    return utils.gw_strain_source(mchirp, dcom, frst)

    
def sam_model(sam, hard,
        dur=16.03*YR, cad=0.2*YR, use_redz=True):
    """Evaluate a SAM with a hardening model on a frequency grid and collect the results.

    Parameters
    ----------
    sam : semi-analytic model instance
        Its `mtot`, `mrat` and `redz` attributes are used as the first three bin edges.
    hard : `holo.hardening.Fixed_Time_2PL_SAM` or `holo.hardening.Hard_GW`
        Hardening model; any other type is rejected.
    dur, cad : float
        Passed to `utils.nyquist_freqs` and `utils.nyquist_freqs_edges` to build the GW
        frequency bin centers and edges (presumably observing duration and cadence).
    use_redz : bool
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    vals : dict
        Keys: 'hard', 'sam', 'edges', 'number', 'diff_num', 'redz_final',
        'hsi_cents', 'hsi_edges' (strains with initial redshifts),
        'hsf_cents', 'hsf_edges' (strains with final redshifts),
        'fobs_gw_cents', 'fobs_gw_edges', 'fobs_orb_cents', 'fobs_orb_edges', 'hard_name'.

    Raises
    ------
    TypeError
        If `hard` is not one of the two supported hardening classes.

    """
    # validate the hardening model first, so a bad argument fails before any heavy computation
    if isinstance(hard, holo.hardening.Fixed_Time_2PL_SAM):
        hard_name = 'Fixed Time'
    elif isinstance(hard, holo.hardening.Hard_GW):
        hard_name = 'GW Only'
    else:
        raise TypeError("'hard' must be an instance of 'Fixed_Time_2PL_SAM' or 'Hard_GW'")

    fobs_gw_cents = utils.nyquist_freqs(dur,cad)
    fobs_gw_edges = utils.nyquist_freqs_edges(dur,cad)
    # orbital frequencies are half of the GW frequencies
    fobs_orb_cents = fobs_gw_cents/2.0
    fobs_orb_edges = fobs_gw_edges/2.0

    redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, holo.cosmo)
    edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
    number = holo.sam_cython.integrate_differential_number_3dx1d(edges, diff_num)

    # strains with final (evolved) redshifts, then with initial redshifts (redz=None)
    hsf_cents = strain_amp_at_bin_centers_redz(edges, redz_final)
    hsf_edges = strain_amp_at_bin_edges_redz(edges, redz_final)
    hsi_cents = strain_amp_at_bin_centers_redz(edges)
    hsi_edges = strain_amp_at_bin_edges_redz(edges)

    vals = {
        'hard':hard, 'sam':sam, 'edges':edges, 'number': number, 'diff_num':diff_num, 'redz_final':redz_final,
        'hsi_cents':hsi_cents, 'hsi_edges':hsi_edges, 'hsf_cents':hsf_cents, 'hsf_edges':hsf_edges,
        'fobs_gw_cents':fobs_gw_cents, 'fobs_gw_edges':fobs_gw_edges,
        'fobs_orb_cents':fobs_orb_cents, 'fobs_orb_edges':fobs_orb_edges, 'hard_name':hard_name
    }
    return vals

def integrate_mm(dnum, edges):
    """Integrate `dnum` (dN/dlogM) along axis 0 over log10 of the mass edges `edges[0]`."""
    log_mtot = np.log10(edges[0])
    return utils.trapz(dnum, log_mtot, axis=0, cumsum=False)
def integrate_qq(dnum, edges):
    """Integrate `dnum` (dN/dq) along axis 1 over the mass-ratio edges `edges[1]`."""
    mrat = edges[1]
    return utils.trapz(dnum, mrat, axis=1, cumsum=False)
def integrate_zz(dnum, edges):
    """Integrate `dnum` (dN/dz) along axis 2 over the redshift edges `edges[2]`."""
    redz = edges[2]
    return utils.trapz(dnum, redz, axis=2, cumsum=False)
def integrate_ff(dnum, fobs_gw_edges):
    """Multiply `dnum` (dN/dln f) by the natural-log widths of the GW frequency bins.

    The widths broadcast against the trailing axis of `dnum`.
    """
    dlnf = np.diff(np.log(fobs_gw_edges))
    return dnum * dlnf

# Shape 5, 6, 7

In [None]:
# Build a SAM with shape=(5,6,7) and a Fixed_Time_2PL_SAM hardening model (second argument 3*GYR),
# then run the `sam_model` helper to get the frequency grids, numbers and strains.
sam = holo.sam.Semi_Analytic_Model(shape=(5,6,7))
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR)
vals = sam_model(sam, hard)

# observer-frame orbital-frequency bin centers, reused by the comparison cells that follow
fobs_orb_cents=vals['fobs_orb_cents']

In [None]:
# Evaluate the same model with both implementations so they can be compared:
#   `holo.sam_cython.dynamic_binary_number_at_fobs`  -> labelled 'cython'
#   `sam.dynamic_binary_number_at_fobs`              -> labelled 'python'
# Note the two calls take their arguments, and return their values, in different orders.
redz_final_cy, diff_num_cy = holo.sam_cython.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, holo.cosmo)
# `temp` (first return value) is not used in the cells shown here
temp, diff_num_py, redz_final_py = sam.dynamic_binary_number_at_fobs(hard, fobs_orb_cents)
# summary statistics of the final redshifts from each implementation
print('cython:', utils.stats(redz_final_cy))
print('python:', utils.stats(redz_final_py))

## Compare redz

In [None]:

# Scatter the flattened final redshifts against their flat index for both implementations;
# the 'x' (python) and '+' (cython) markers coincide wherever the two results agree.
fig, ax = plot.figax(xlabel='index', ylabel='redz_final', xscale='linear', yscale='linear')
ax.scatter(np.arange(redz_final_py.size), redz_final_py.flatten(), label='python', marker='x', alpha=0.5)
ax.scatter(np.arange(redz_final_cy.size), redz_final_cy.flatten(), label='cython', marker='+', alpha=0.5)
# title: hardening-model name and SAM grid shape
ax.set_title('%s, %s' % (vals['hard_name'], str(sam.shape)))
ax.legend()

## Compare diff_num

In [None]:
# summary statistics of the differential number from each implementation
print('cython:', utils.stats(diff_num_cy))
print('python:', utils.stats(diff_num_py))

# Scatter the flattened differential numbers against their flat index (log y-axis);
# the 'x' (python) and '+' (cython) markers coincide wherever the two results agree.
fig, ax = plot.figax(xlabel='index', ylabel='diff_num', xscale='linear', yscale='log')
ax.scatter(np.arange(diff_num_py.size), diff_num_py.flatten(), label='python', marker='x', alpha=0.5)
ax.scatter(np.arange(diff_num_cy.size), diff_num_cy.flatten(), label='cython', marker='+', alpha=0.5)
# title: hardening-model name and SAM grid shape
ax.set_title('%s, %s' % (vals['hard_name'], str(sam.shape)))
ax.legend()

# Shape 20, 20, 20

In [None]:
# Repeat the setup with shape=20 (the section header describes this as a 20, 20, 20 grid).
# NOTE: this rebinds `sam`, `hard`, `vals` and `fobs_orb_cents` from the (5,6,7) section above.
sam = holo.sam.Semi_Analytic_Model(shape=20)
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 3*GYR)
vals = sam_model(sam, hard)

# observer-frame orbital-frequency bin centers, reused by the comparison cells that follow
fobs_orb_cents=vals['fobs_orb_cents']

In [None]:
# Evaluate the shape=20 model with both implementations so they can be compared:
#   `holo.sam_cython.dynamic_binary_number_at_fobs`  -> labelled 'cython'
#   `sam.dynamic_binary_number_at_fobs`              -> labelled 'python'
# These results overwrite the `*_cy` / `*_py` arrays from the (5,6,7) section.
redz_final_cy, diff_num_cy = holo.sam_cython.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, holo.cosmo)
# `temp` (first return value) is not used in the cells shown here
temp, diff_num_py, redz_final_py = sam.dynamic_binary_number_at_fobs(hard, fobs_orb_cents)
# summary statistics of the final redshifts from each implementation
print('cython:', utils.stats(redz_final_cy))
print('python:', utils.stats(redz_final_py))

## Compare redz

In [None]:

# Scatter the flattened final redshifts against their flat index for both implementations;
# the 'x' (python) and '+' (cython) markers coincide wherever the two results agree.
fig, ax = plot.figax(xlabel='index', ylabel='redz_final', xscale='linear', yscale='linear')
ax.scatter(np.arange(redz_final_py.size), redz_final_py.flatten(), label='python', marker='x', alpha=0.5)
ax.scatter(np.arange(redz_final_cy.size), redz_final_cy.flatten(), label='cython', marker='+', alpha=0.5)
# title: hardening-model name and SAM grid shape
ax.set_title('%s, %s' % (vals['hard_name'], str(sam.shape)))
ax.legend()

## Compare diff_num

In [None]:
# summary statistics of the differential number from each implementation
print('cython:', utils.stats(diff_num_cy))
print('python:', utils.stats(diff_num_py))

# Scatter the flattened differential numbers against their flat index (log y-axis);
# the 'x' (python) and '+' (cython) markers coincide wherever the two results agree.
fig, ax = plot.figax(xlabel='index', ylabel='diff_num', xscale='linear', yscale='log')
ax.scatter(np.arange(diff_num_py.size), diff_num_py.flatten(), label='python', marker='x', alpha=0.5)
ax.scatter(np.arange(diff_num_cy.size), diff_num_cy.flatten(), label='cython', marker='+', alpha=0.5)
# title: hardening-model name and SAM grid shape
ax.set_title('%s, %s' % (vals['hard_name'], str(sam.shape)))
ax.legend()

# Weird Spot

In [None]:
# Indices (mass, mass-ratio, redshift, frequency) of the bin where the results look odd.
mm, qq, zz, ff = 15, 18, 15, 28 # weird spot

fobs_orb_cents = vals['fobs_orb_cents']
fobs_gw_cents = vals['fobs_gw_cents']
# NOTE: `redz_final` / `diff_num` are not used below in this cell; the comparison relies on
# `redz_final_cy` / `redz_final_py` computed in the previous section.
redz_final, diff_num = holo.sam_cython.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo)

# Strains using the final redshifts from each implementation, and using initial redshifts.
hs_final_cy = holo.gravwaves.strain_amp_from_bin_edges_redz(vals['edges'], redz_final_cy)
hs_final_py = holo.gravwaves.strain_amp_from_bin_edges_redz(vals['edges'], redz_final_py)
hs_init = strain_amp_at_bin_centers_redz(vals['edges'], redz=None)

# Report each array's own shape (previously all three lines printed `hs_init.shape`).
print('hs_init:', utils.stats(hs_init), f", {hs_init.shape=}")
print('hs_final_cy:', utils.stats(hs_final_cy), f", {hs_final_cy.shape=}")
print('hs_final_py:', utils.stats(hs_final_py), f", {hs_final_py.shape=}")

# Initial redshifts at redshift-bin centers.
rz_init = kale.utils.midpoints(sam.redz, axis=0)
# Final redshifts averaged to bin centers along the first three (M, Q, Z) axes,
# so they can be indexed with the same bin indices as the strain arrays.
rzf_cy = redz_final_cy
rzf_py = redz_final_py
for dd in range(3):
    rzf_cy = np.moveaxis(rzf_cy, dd, 0)
    rzf_cy = kale.utils.midpoints(rzf_cy, axis=0)
    rzf_cy = np.moveaxis(rzf_cy, 0, dd)

    rzf_py = np.moveaxis(rzf_py, dd, 0)
    rzf_py = kale.utils.midpoints(rzf_py, axis=0)
    rzf_py = np.moveaxis(rzf_py, 0, dd)

# Header: bin-center mass, mass ratio, and orbital / GW frequencies of the selected bin.
print('M(mm=%d)=%.2e M_sol, q(qq=%d)=%.2e, f_obs,orb(ff=%d)=%.2f/yr, f_obs,gw(ff=%d)=%.2f/yr' 
      % (mm, utils.midpoints(vals['edges'][0])[mm]/MSOL, 
         qq, utils.midpoints(vals['edges'][1])[qq],
         ff, fobs_orb_cents[ff]*YR,
         ff, fobs_gw_cents[ff]*YR))
# Redshift bins around the selected one; note this loop rebinds the global `zz`.
for zz in (14,15,16):
    print('zz=%d, z_init=%.2f, z_fin,cy=%.2f, z_fin,py=%.2f, hs_init=%.2e, hs_fin,cy=%.2e, hs_fin,py=%.2e' 
          % (zz, rz_init[zz], rzf_cy[mm,qq,zz,ff], rzf_py[mm,qq,zz,ff], 
          hs_init[mm,qq,zz,ff], hs_final_cy[mm,qq,zz,ff], hs_final_py[mm,qq,zz,ff]))

## Print info
Note that the cython and python `hs_fin` values do not match: there are places where cython gives a nonzero value but python gives 0.

In [None]:
# Same mass / mass-ratio bin as before; `ff` and `zz` are rebound by the loops below.
mm, qq, zz, ff = 15, 18, 15, 25 # weird spot
for ff in [0,4,9,14,19,24,29,34,38]:
    # header for this frequency bin: bin-center mass, mass ratio, orbital and GW frequency
    print(
        '\nM(mm=%d)=%.2e M_sol, q(qq=%d)=%.2e, f_obs,orb(ff=%d)=%.2f/yr, f_obs,gw(ff=%d)=%.2f/yr'
        % (
            mm, utils.midpoints(vals['edges'][0])[mm]/MSOL,
            qq, utils.midpoints(vals['edges'][1])[qq],
            ff, fobs_orb_cents[ff]*YR,
            ff, fobs_gw_cents[ff]*YR,
        )
    )
    # initial vs final redshifts and the corresponding strains, for three neighbouring redshift bins
    for zz in (14,15,16):
        print(
            'zz=%d, z_init=%.2f, z_fin,cy=%.2f, z_fin,py=%.2f, hs_init=%.2e, hs_fin,cy=%.2e, hs_fin,py=%.2e'
            % (
                zz, rz_init[zz], rzf_cy[mm,qq,zz,ff], rzf_py[mm,qq,zz,ff],
                hs_init[mm,qq,zz,ff], hs_final_cy[mm,qq,zz,ff], hs_final_py[mm,qq,zz,ff],
            )
        )

## Plot weird spot

In [None]:
def plot_strain_vs_z(
    vals, mm_arr=[mm,], qq_arr=[qq,], zz_arr=[zz,], ff_arr=[ff,],
    all_hs=np.array([hs_init, hs_final_cy, hs_final_py]),
    labels_hs=np.array(['init z', 'cython final z', 'python final z']),
    linestyles = np.array(['--', '-', '-.'])):
    """Plot strain amplitude against redshift-bin centers for selected bins.

    NOTE: the default arguments are evaluated once, when this cell is run, from the
    notebook globals (`mm`, `qq`, `zz`, `ff`, `hs_init`, `hs_final_cy`, `hs_final_py`);
    re-run this cell after changing those globals, or pass the arguments explicitly.

    Parameters
    ----------
    vals : dict
        Must provide 'fobs_gw_cents', 'edges', 'hard_name' and 'sam' (as returned by `sam_model`).
    mm_arr, qq_arr, ff_arr : sequences of int
        Mass, mass-ratio and frequency bin indices to plot; one color per entry of `ff_arr`.
    zz_arr : sequence of int
        Iterated over, but the index is not used to slice the data: every curve spans the
        full redshift axis, so each extra entry re-draws the same curves.
    all_hs : sequence of 4D arrays
        Strain arrays indexed as [mass, mass-ratio, redshift, frequency].
    labels_hs, linestyles : sequences of str
        Legend label and line style for each entry of `all_hs`.

    Returns
    -------
    fig : the created figure
    title : str
        The figure title, '<hard_name>, <sam.shape>'.

    """
    fobs_gw_cents = vals['fobs_gw_cents']
    cents_mm = utils.midpoints(vals['edges'][0])
    cents_qq = utils.midpoints(vals['edges'][1])
    cents_zz = utils.midpoints(vals['edges'][2])    # x-axis: redshift-bin centers
    colors = cm.rainbow_r(np.linspace(0, 1, len(ff_arr)))

    fig, ax = holo.plot.figax(
        xlabel='$z$',
        ylabel='$h_s$',
    )

    for mi, mm in enumerate(mm_arr):
        for qi, qq in enumerate(qq_arr):
            for zi, zz in enumerate(zz_arr):
                for fi, ff in enumerate(ff_arr):
                    color = colors[fi]
                    for yi, yy in enumerate(all_hs):
                        # name each strain array in the legend only once (very first bin combination)
                        if mi == 0 and qi == 0 and zi == 0 and fi == 0:
                            label_hs = labels_hs[yi]
                        else:
                            label_hs = ''
                        # the first strain array also carries the bin description for this color
                        if yi == 0:
                            # raw string: `\ ` and `\odot` are LaTeX, not Python escape sequences
                            label_hs = label_hs + (r', $M=%.2e\ M_\odot$, $q=%.2f$, $f=%.2f$/yr'
                            % (cents_mm[mm]/MSOL, cents_qq[qq], fobs_gw_cents[ff]*YR))
                        ax.plot(cents_zz, yy[mm,qq,:,ff], label=label_hs, linestyle=linestyles[yi], alpha=0.65,
                            color=color)

    fig.legend(ncols=1, bbox_to_anchor=(.9,.9), loc='upper left', fontsize=12)
    title='%s, %s' % (str(vals['hard_name']), str(vals['sam'].shape))
    fig.suptitle(title)
    # actually call tight_layout (the bare attribute access did nothing)
    fig.tight_layout()
    return fig, title


In [None]:
# Reset the bin indices of the weird spot (earlier loops rebound `zz` and `ff`).
mm, qq, zz, ff = 15, 18, 15, 28 # weird spot

# Strain vs redshift at a set of frequency bins, comparing all three strain arrays:
# cython final redshifts, python final redshifts, and initial redshifts.
fig,title = plot_strain_vs_z(
    vals, mm_arr=[mm,], qq_arr=[qq,], zz_arr=[zz,], ff_arr=[0,4,9,14,19,24,29,34,38],
    all_hs=np.array([ hs_final_cy, hs_final_py, hs_init,]), 
    labels_hs=np.array(['cython final z', 'python final z', 'init z', ]), 
    linestyles = np.array([ '--', '-', ':',]))

In [None]:
# Same bins, comparing only the initial-redshift strain with the python final-redshift strain.
fig,title = plot_strain_vs_z(
    vals, mm_arr=[mm,], qq_arr=[qq,], zz_arr=[zz,], ff_arr=[0,4,9,14,19,24,29,34,38],
    all_hs=np.array([hs_init, hs_final_py]), 
    labels_hs=np.array(['init z', 'python final z']), 
    linestyles = np.array(['--', '-']))

In [None]:
# Same bins, comparing the cython and python final-redshift strains directly.
fig,title = plot_strain_vs_z(
    vals, mm_arr=[mm,], qq_arr=[qq,], zz_arr=[zz,], ff_arr=[0,4,9,14,19,24,29,34,38],
    all_hs=np.array([hs_final_cy, hs_final_py]), 
    labels_hs=np.array(['cython final z', 'python final z']), 
    linestyles = np.array(['--', '-']))