In [None]:
import numpy as np
import matplotlib.pyplot as plt
import legwork as lw
import astropy.units as u
import tqdm
from astropy.cosmology import Planck18, z_at_value
from scipy.integrate import trapz, cumtrapz
from schwimmbad import MultiPool
from utils import get_LISA_norm, get_LISA_norm_circular, dTmerger_df_circ, get_t_evol_LISA
from scipy.interpolate import interp1d

In [None]:
def ligo_rate(m1):
    """Linearly interpolate a tabulated primary-mass rate density.

    The table pairs primary-mass values (column 0; used as Msun elsewhere in
    this notebook) with a differential rate (column 1; callers attach
    Gpc^-3 yr^-1 Msun^-1). NOTE(review): record where these points come from.

    Parameters
    ----------
    m1 : float or array_like
        Primary mass(es) to evaluate. Values outside the tabulated range make
        ``interp1d`` raise ValueError.

    Returns
    -------
    numpy.ndarray
        Interpolated rate value(s) with the shape of ``m1``.
    """
    table = np.array([[3.705799151343708, 0.001087789470121345],
                      [4.384724186704389, 0.00984816875074369],
                      [5.063649222065067, 0.06979974252228799],
                      [5.827439886845831, 0.41173514594201527],
                      [6.506364922206512, 1.3579705933006465],
                      [6.845827439886847, 2.148948034692836],
                      [7.77934936350778, 2.7449738151212433],
                      [8.543140028288544, 2.6218307403757986],
                      [9.561527581329564, 2.0525434471508692],
                      [11.173974540311175, 1.2388629239937763],
                      [12.701555869872706, 0.7828664968878465],
                      [14.398868458274404, 0.4947116747780942],
                      [16.859971711456865, 0.2895969742197884],
                      [19.66053748231967, 0.17748817964452962],
                      [22.206506364922213, 0.12773570001722281],
                      [24.837340876944843, 0.10389898279212807],
                      [27.722772277227726, 0.1087789470121345],
                      [30.183875530410184, 0.13070104796093673],
                      [32.729844413012735, 0.16441704701060267],
                      [34.85148514851486, 0.16695189854274867],
                      [37.397454031117405, 0.12107555776371784],
                      [39.26449787835927, 0.08010405199404155],
                      [41.30127298444131, 0.049851062445855264],
                      [43.592644978783596, 0.029631988560550687],
                      [45.629420084865636, 0.018440841322693136],
                      [48.0905233380481, 0.011832859313068754],
                      [50.891089108910904, 0.007949361111716631],
                      [53.77652050919379, 0.005764973856945108],
                      [57.25601131541727, 0.0043438393396653925],
                      [61.923620933521946, 0.0032730313574784275],
                      [66.67609618104669, 0.0024851284269805634],
                      [70.66478076379069, 0.002068305171949823],
                      [74.82319660537483, 0.0016952583040389245],
                      [78.72701555869875, 0.0013476220436441713],
                      [81.27298444130128, 0.0010389898279212807]])

    mass_points, rate_points = table.T
    return interp1d(mass_points, rate_points)(m1)


def get_LIGO_rate_single_e(m1, ecc):
    """Rate density at ``m1`` when all sources share one eccentricity.

    ``ecc`` is accepted only for signature parity with the other rate helpers
    and is not used: the full interpolated rate is returned unchanged.

    Returns
    -------
    astropy.units.Quantity
        ``ligo_rate(m1)`` in Gpc^-3 yr^-1 Msun^-1.
    """
    density_unit = u.Gpc**(-3) * u.yr**(-1) * u.Msun**(-1)
    return np.array(ligo_rate(m1)) * density_unit

def get_LIGO_rate_uniform_e(m1, ecc, ecc_grid):
    """Rate density at ``m1`` shared equally among all eccentricity grid points.

    ``ecc`` is not used; only the number of points in ``ecc_grid`` matters.

    Returns
    -------
    astropy.units.Quantity
        ``ligo_rate(m1) / len(ecc_grid)`` in Gpc^-3 yr^-1 Msun^-1.
    """
    n_ecc_points = len(ecc_grid)
    share = ligo_rate(m1) / n_ecc_points
    return np.array(share) * (u.Gpc**(-3) * u.yr**(-1) * u.Msun**(-1))
        
    
def get_LIGO_rate_iso_dyn(m1, e, ecc_grid, frac_iso, e_iso_max=1e-6):
    """Split the rate density at ``m1`` between isolated and dynamical channels.

    Eccentricity grid points below ``e_iso_max`` form the isolated channel and
    share ``frac_iso`` of the rate equally. The remaining grid points share
    ``1 - frac_iso`` equally.

    Parameters
    ----------
    m1 : float or array_like
        Primary mass(es) passed to ``ligo_rate``.
    e : float or array_like
        Eccentricity of each evaluation point (same shape as ``m1``).
    ecc_grid : array_like
        Full eccentricity grid; used only to count the points in each channel.
    frac_iso : float
        Fraction of the rate assigned to the isolated channel.
    e_iso_max : float, optional
        Channel threshold. The default 1e-6 is the value that was previously
        hard-coded.

    Returns
    -------
    numpy.ndarray
        Unitless rate per grid point. Callers attach Gpc^-3 yr^-1 Msun^-1.
    """
    ecc_grid = np.asarray(ecc_grid)
    rate = ligo_rate(m1)
    n_iso = np.count_nonzero(ecc_grid < e_iso_max)
    # Bug fix: the original divided the dynamical share by
    # len(ecc_grid < 1e-6), which is the *total* grid size (the length of a
    # boolean mask). That under-weighted every dynamical grid point.
    n_dyn = ecc_grid.size - n_iso
    # max(..., 1) only avoids a division-by-zero warning in the branch that
    # np.where discards when a channel has no grid points.
    rate_iso = rate * frac_iso / max(n_iso, 1)
    rate_dyn = rate * (1 - frac_iso) / max(n_dyn, 1)
    return np.where(np.asarray(e) < e_iso_max, rate_iso, rate_dyn)




## First look at the circular (e = 0), equal-mass case

In [None]:
# Grid for the circular, equal-mass case.
# NOTE(review): n_grid is not used in this cell; later cells read it.
n_grid = 25

# 50 log-spaced orbital frequencies, ordered from 0.1 Hz DOWN to 1e-5 Hz.
f = np.logspace(-1, -5, 50) * u.Hz

# Primary masses 5..79 in steps of 1; mass_bins holds the bin edges.
masses = np.arange(5, 80, 1)
# delta_m is computed before units are attached, so it is a unitless scalar.
delta_m = masses[1] - masses[0]

mass_bins = masses - 0.5 * delta_m
mass_bins = np.append(mass_bins, masses[-1] + 0.5 * delta_m)
masses = masses * u.Msun
mass_bins = mass_bins * u.Msun
m_c = lw.utils.chirp_mass(masses, masses)
# 2-D grids: rows index mass, columns index frequency (meshgrid 'xy' order).
F, MASS = np.meshgrid(f, masses)

MC = lw.utils.chirp_mass(MASS, MASS)

# Interpolated rate density on the grid, reshaped back to (mass, frequency).
RATE = ligo_rate(MASS.flatten().value)
RATE = RATE.reshape(MC.shape) * u.Gpc**(-3) * u.yr**(-1) * u.Msun**(-1)

In [None]:
# Rate density over the circular (frequency, chirp mass) grid.
plt.scatter(F, MC, c=np.log10(RATE.value))
plt.colorbar(label=r'log$_{10}$(rate density / Gpc$^{-3}$ yr$^{-1}$ M$_{\odot}^{-1}$)')
plt.xscale('log')
plt.xlabel('orbital frequency [Hz]')
plt.ylabel('chirp mass [Msun]');

In [None]:
# Circular inspiral time from each grid frequency to merger (legwork).
t_merge = lw.evol.get_t_merge_circ(f_orb_i=F, m_1=MASS, m_2=MASS)

In [None]:
# SNR of every (mass, frequency) grid point for a circular source placed at a
# reference distance of 8 Mpc.
# NOTE(review): n_proc=36 is hard-coded; match it to the available cores.
source = lw.source.Source(m_1=MASS.flatten(),
                          m_2=MASS.flatten(),
                          ecc=np.zeros(len(F.flatten())),
                          f_orb=F.flatten(),
                          dist=8 * np.ones(len(F.flatten())) * u.Mpc,
                          interpolate_g=False,
                          n_proc=36)
snr = source.get_snr(approximate_R=True, verbose=True)
# Horizon distance for SNR = 7, obtained by rescaling the 8 Mpc SNR.
# This assumes SNR scales as 1 / distance.
D_h = snr/7 * 8 * u.Mpc
# Points with a negligible horizon keep a placeholder redshift of 1e-8 and are
# not passed to z_at_value.
redshift = np.ones(len(D_h)) * 1e-8
redshift[D_h > 0.0001 * u.Mpc] = z_at_value(Planck18.luminosity_distance, D_h[D_h > 0.0001 * u.Mpc])
# Comoving volume inside the horizon, reshaped back to the (mass, frequency) grid.
horizon_comoving_volume = Planck18.comoving_volume(z=redshift)
horizon_comoving_volume = horizon_comoving_volume.reshape(RATE.shape)
D_h = D_h.reshape(RATE.shape)

In [None]:
# Circular merger time across the (frequency, chirp mass) grid.
plt.scatter(F, MC, c=np.log10(t_merge.to(u.yr).value))
plt.colorbar(label='log$_{10}$(merger time/yr)')
plt.xlabel('orbital frequency [Hz]')
plt.ylabel('chirp mass [Msun]')
plt.xscale('log')

In [None]:
# Comoving volume inside the SNR = 7 horizon for each grid point.
plt.scatter(F, MASS, c=np.log10(horizon_comoving_volume.to(u.Gpc**(3)).value))
# Bug fix: the colour is log10 of a volume in Gpc^3; the old label said Gpc^-3.
plt.colorbar(label=r'log$_{10}$(horizon volume / Gpc$^{3}$)')
plt.xscale('log')
plt.xlabel('orbital frequency [Hz]')
plt.ylabel('primary mass [Msun]');

In [None]:
# legwork fn_dot for harmonic n=2 with e = 0 on the flattened grid
# (frequency derivative of that harmonic -- confirm against the legwork docs).
f_dot = lw.utils.fn_dot(m_c = MC.flatten(), e = np.zeros(len(MC.flatten())), n=2, f_orb=F.flatten())

In [None]:
# Restore the (mass, frequency) grid shape.
f_dot = f_dot.reshape(F.shape)

In [None]:
# Inspect the rate-density grid.
RATE

In [None]:
# Expected number of sources per unit primary mass: integrate
# RATE / f_dot * V_horizon over frequency for each mass row.
# f is ordered high -> low, so -f increases along the integration axis.
N_per_mass = np.zeros(len(masses)) * u.Msun**(-1)
for ii, m in enumerate(masses):
    N_per_mass[ii] = trapz(RATE[ii,:] / f_dot[ii,:] * horizon_comoving_volume[ii,:], -f)

# Redshift and luminosity distance that enclose a comoving volume of 0.5 Gpc^3.
# NOTE(review): later cells in this notebook reassign z_lim and d_lim before
# using them, so these values are not used downstream.
z_lim = z_at_value(Planck18.comoving_volume, 0.5 * u.Gpc**3)
d_lim = Planck18.luminosity_distance(z=z_lim)

In [None]:
# Total expected number of sources: integrate N_per_mass over primary mass.
trapz(N_per_mass, masses)

In [None]:
# Expected sources per unit primary mass.
plt.plot(masses, N_per_mass)


In [None]:
# Monte Carlo realisations of the circular population. For each
# (mass, frequency) cell the expected number of sources is
#   dN = RATE / f_dot * delta_m * V_horizon * (f[jj] - f[jj+1]),
# and stochastic rounding of dN gives the integer number drawn.
N_REALISATIONS = 30
dat = []
for _ in tqdm.tqdm(range(N_REALISATIONS)):
    m_samp = []
    f_samp = []
    d_samp = []
    for ii in range(len(masses)):
        for jj in range(len(f) - 1):
            # Everything in this cell is evaluated at grid frequency f[jj].
            # f is decreasing, so the cell spans [f[jj+1], f[jj]].
            n_per_Hz = (RATE[ii, jj] / f_dot[ii, jj] * delta_m * u.Msun
                        * horizon_comoving_volume[ii, jj].to(u.Gpc**3))
            n_samp = n_per_Hz * (f[jj] - f[jj+1])
            # Stochastic rounding: keep the integer part, and add one more
            # source with probability equal to the fractional part.
            n_int = int(n_samp)
            n_float = n_samp - n_int
            if np.random.uniform(0, 1) < n_float:
                n_int += 1

            if n_int >= 1:
                # power(3) has density 3x^2 on [0, 1]: uniform in Euclidean volume.
                d_samp.extend(np.random.power(3, n_int) * D_h[ii, jj].value)
                m_samp.extend(np.random.uniform(mass_bins[ii].value, mass_bins[ii+1].value, n_int))
                # Bug fix: the original tagged these sources with f[jj+1], even
                # though RATE, f_dot and D_h above were evaluated at f[jj].
                f_samp.extend(np.ones(n_int) * f[jj].value)
    d_samp = np.array(d_samp)
    m_samp = np.array(m_samp)
    f_samp = np.array(f_samp)
    dat.append([m_samp, f_samp, d_samp])
                
                


In [None]:
# Evaluate every sampled source at its drawn distance, keep those with
# SNR > 7, and overplot all realisations in (frequency, chirp mass).
n_obs = []
for d in dat:
    m_samp, f_samp, D_samp = d
    source = lw.source.Source(m_1=m_samp*u.Msun,
                              m_2=m_samp*u.Msun,
                              ecc=np.zeros(len(m_samp)),
                              f_orb=f_samp*u.Hz,
                              dist=D_samp*u.Mpc,
                              interpolate_g=False,
                              n_proc=1)
    snr = source.get_snr(approximate_R=True, verbose=False)
    detectable_mask = snr > 7
    # Bug fix: count only the detectable sources. len(detectable_mask)
    # counted every sampled source regardless of its SNR.
    n_obs.append(np.count_nonzero(detectable_mask))
    plt.scatter(f_samp[detectable_mask],
                lw.utils.chirp_mass(np.array(m_samp[detectable_mask]), np.array(m_samp[detectable_mask])),
                c=D_samp[detectable_mask], vmin=10, vmax=200)
plt.colorbar(label='distance [Mpc]')
plt.xscale('log')
plt.xlim(1e-4, 1e-1)
plt.xlabel('frequency [Hz]')
plt.ylabel('chirp mass [Msun]')
plt.title(f'Number per LISA observation is {np.round(np.mean(n_obs), 2)} pm {np.round(np.std(n_obs), 1)}', size=20)

## Next: Monte Carlo sampling over eccentricity with equal masses (q = 1)

In [None]:
# Grid for the eccentric, equal-mass (q = 1) case.
n_grid_f = 100
n_grid_e = 10
n_grid_mass = 15

# Bug fix: use n_grid_f (defined above, previously unused) instead of the
# circular-case n_grid. f is ordered from 0.1 Hz DOWN to 1e-4 Hz.
f = np.logspace(-1, -4, n_grid_f) * u.Hz
# NOTE: despite its name, f_bins holds bin *widths* f[i] - f[i+1], not edges.
f_bins = f[:-1] - f[1:]

masses = np.linspace(5, 80, n_grid_mass)
m_bin_widths = masses[1] - masses[0]
mass_bins = masses - 0.5 * m_bin_widths
mass_bins = np.append(mass_bins, masses[-1] + 0.5 * m_bin_widths)
masses = masses * u.Msun
mass_bins = mass_bins * u.Msun

# Eccentricities at the 10 Hz reference frequency (e_10 in the plots below).
ecc = np.logspace(-8, -3, n_grid_e)


ecc_bins = ecc[1:] - ecc[:-1]

m_c = lw.utils.chirp_mass(masses, masses)
# 3-D grids indexed as (mass, frequency, eccentricity) -- meshgrid 'xy' order.
F, MASS, ECC = np.meshgrid(f, masses, ecc)

MC = lw.utils.chirp_mass(MASS, MASS)

# Rate density per grid point: uniform in eccentricity, and isolated/dynamical
# mixtures with 10%, 50% and 90% isolated fraction (the latter three unitless).
RATE = get_LIGO_rate_uniform_e(MASS, ECC, ecc)
RATE_iso_10 = get_LIGO_rate_iso_dyn(MASS, ECC, ecc, 0.1)
RATE_iso_50 = get_LIGO_rate_iso_dyn(MASS, ECC, ecc, 0.5)
RATE_iso_90 = get_LIGO_rate_iso_dyn(MASS, ECC, ecc, 0.9)

In [None]:
# Sanity check: lowest orbital frequency on the eccentric grid.
print(F.flatten().min())

In [None]:
# Evaluate utils.get_t_evol_LISA for every grid point with 2 worker processes.
# The plots below label the result as the time to f = 10 Hz.
# NOTE(review): the argument order (m1, m2, e, f_orb) is taken from the zip
# below -- confirm against utils.get_t_evol_LISA.
with MultiPool(processes=2) as pool:
    T_LISA = np.array(list(pool.map(get_t_evol_LISA, zip(MASS.flatten(), MASS.flatten(), ECC.flatten(), F.flatten()))))


In [None]:
# Attach seconds. This assumes get_t_evol_LISA returns plain numbers in
# seconds -- confirm in utils.
T_LISA = T_LISA * u.s
print(np.shape(F))

In [None]:
# Evolve each binary from f_orb = 10 Hz with eccentricity ECC over t_evol =
# T_LISA, to get its eccentricity and frequency at the grid frequency.
# n_step=2 returns two timesteps per binary; the next cell keeps the last one.
# NOTE(review): later cells use -T_LISA and abs(T_LISA), which suggests T_LISA
# is negative (backwards evolution) -- confirm with the legwork evol_ecc docs.
a_evol, e_evol, f_evol = lw.evol.evol_ecc(
    m_1=MASS.flatten(), m_2=MASS.flatten(), f_orb_i=10*u.Hz, ecc_i=ECC.flatten(), t_evol = T_LISA,
            t_before=0.01*u.yr, output_vars=["a", "ecc", "f_orb"], avoid_merger=False, n_step=2)

In [None]:
# Keep the final timestep and restore the (mass, frequency, eccentricity) grid shape.
f_evol = f_evol.reshape((len(masses), len(f), len(ecc), 2))[:,:,:,1]
e_evol = e_evol.reshape((len(masses), len(f), len(ecc), 2))[:,:,:,1]

In [None]:
# SNR of each grid point at a reference distance of 8 Mpc, using the evolved
# eccentricity e_evol at the grid frequency F (f_evol is not used here).
source = lw.source.Source(m_1=MASS.flatten(),
                          m_2=MASS.flatten(),
                          ecc=e_evol.flatten(),
                          f_orb=F.flatten(),
                          dist=8 * np.ones(len(F.flatten())) * u.Mpc,
                          interpolate_g=False,
                          n_proc=2)
snr = source.get_snr(approximate_R=True, verbose=True)
# Horizon distance for SNR = 7, assuming SNR scales as 1 / distance.
D_h = snr/7 * 8 * u.Mpc
# Placeholder redshift 1e-8 for points with a negligible horizon.
redshift = np.ones(len(D_h)) * 1e-8
redshift[D_h > 0.0001 * u.Mpc] = z_at_value(Planck18.luminosity_distance, D_h[D_h > 0.0001 * u.Mpc])
horizon_comoving_volume = Planck18.comoving_volume(z=redshift)
horizon_comoving_volume = horizon_comoving_volume.reshape(MASS.shape)
D_h = D_h.reshape(MASS.shape)

In [None]:
# |T_LISA| versus frequency for the ecc[0] slice, coloured by chirp mass.
plt.scatter(F[:,:,0], abs((T_LISA.reshape(F.shape)[:,:,0]).to(u.Myr)), c=MC[:,:,0].value)
plt.colorbar(label=r'chirp mass [M$_{\odot}$]')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('frequency')
plt.ylabel('time to f=10 Hz')
plt.title(r'$\log_{10}(e_{10}) =$'+str(np.log10(ecc[0])), size=20)

In [None]:
# Check that the frequency grid and horizon-distance grid shapes agree.
np.shape(F), np.shape(D_h)

In [None]:
# Horizon distance versus frequency for the ecc[0] slice, coloured by chirp mass.
plt.scatter(F[:,:,0], D_h[:,:,0], c=MC[:,:,0].value)
plt.colorbar(label=r'chirp mass [M$_{\odot}$]')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('frequency')
plt.ylabel('horizon distance')
plt.title(r'$\log_{10}(e_{10}) =$'+str(np.log10(ecc[0])), size=20)

In [None]:
# Horizon distance versus frequency for the ecc[4] slice, coloured by chirp mass.
plt.scatter(F[:,:,4], D_h[:,:,4], c=MC[:,:,4].value)
plt.colorbar(label=r'chirp mass [M$_{\odot}$]')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('frequency')
plt.ylabel('horizon distance')
plt.title(r'$\log_{10}(e_{10}) =$'+str(np.log10(ecc[4])), size=20)

In [None]:
# |T_LISA| versus frequency for the lowest-mass row, coloured by e_10.
plt.scatter(F[0,:,:], abs((T_LISA.reshape(F.shape)[0,:,:]).to(u.Myr)), c=np.log10(ECC[0,:,:]))
plt.colorbar(label=r'$\log_{10}(e_{10})$')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('frequency')
plt.ylabel('time to f=10 Hz')
# Bug fix: masses[0] is a Quantity whose repr already contains "solMass",
# which produced "5.0 solMass Msun". Format the bare value instead.
plt.title(f'mass = {np.round(masses[0].value, 2)} Msun', size=20)

In [None]:
# Horizon distance versus frequency for mass row 10, coloured by e_10.
plt.scatter(F[10,:,:], D_h[10,:,:], c=np.log10(ECC[10,:,:]))
plt.colorbar(label=r'$\log_{10}(e_{10})$')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('frequency')
# Bug fix: this panel shows D_h (Mpc); the old label said "time to f=10 Hz".
plt.ylabel('horizon distance [Mpc]')
plt.title(f'mass = {np.round(masses[10].value, 2)} Msun', size=20)

In [None]:
# |T_LISA| versus frequency for the lowest-mass row, coloured by the evolved
# eccentricity at the grid frequency.
plt.scatter(F[0,:,:], abs((T_LISA.reshape(F.shape)[0,:,:]).to(u.Myr)), c=np.log10(e_evol[0,:,:]))
# Bug fix: the colour is log10(e_evol), not the 10 Hz eccentricity e_10.
plt.colorbar(label=r'$\log_{10}(e_{\rm evol})$')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('frequency')
plt.ylabel('time to f=10 Hz')
plt.title(f'mass = {np.round(masses[0].value, 2)} Msun', size=20)

In [None]:
# Horizon distance versus frequency for mass row 10, coloured by the linear
# evolved eccentricity (clipped to [0, 1]).
plt.scatter(F[10,:,:], D_h[10,:,:], c=e_evol[10,:,:], vmin=0, vmax=1)
# Bug fix: the colour is linear e_evol; the old label said log10(e_10).
plt.colorbar(label=r'$e_{\rm evol}$')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('frequency')
plt.ylabel('horizon distance [Mpc]')
plt.title(f'mass = {np.round(masses[10].value, 2)} Msun', size=20)

In [None]:
# Whole grid: colour is log10 of the horizon volume in Gpc^3; marker size scales with e_evol.
plt.scatter(F, MC, c=np.log10(horizon_comoving_volume.to(u.Gpc**3).value), s=100*e_evol)
plt.xscale('log')
plt.colorbar()

In [None]:
# Shape check for the grids used by the sampling cells below.
np.shape(F), len(masses), len(ecc), len(f), np.shape(e_evol)

In [None]:
def build_pop(dat):
    """Draw one Monte Carlo realisation of eccentric sources inside the horizon.

    Parameters
    ----------
    dat : sequence
        ``(masses, ecc, f, RATE, T_LISA, horizon_comoving_volume, D_h)``.
        The last four are on the (mass, frequency, eccentricity) grid.
        ``RATE`` must carry units convertible to Gpc^-3 yr^-1 Msun^-1, and
        ``T_LISA`` units of time.

    Returns
    -------
    list
        ``[m_samp, f_samp, e_samp, D_samp]``: primary masses [Msun], orbital
        frequencies [Hz], eccentricities (from ``e_evol``) and distances [Mpc].

    Notes
    -----
    Reads the notebook globals ``mass_bins`` and ``e_evol``.
    NOTE(review): RATE * |T_LISA| looks like a cumulative count (every source
    between this frequency and 10 Hz) rather than a per-bin count. Summing it
    over the frequency grid may double count -- confirm.
    """
    masses, ecc, f, RATE, T_LISA, horizon_comoving_volume, D_h = dat
    grid_shape = horizon_comoving_volume.shape
    rate_grid = RATE.reshape(grid_shape)
    t_grid = T_LISA.reshape(grid_shape)
    m_samp = []
    e_samp = []
    f_samp = []
    D_samp = []
    for ii in range(len(masses)):
        # Counts are per unit mass, so multiply by this mass bin's width.
        dm = mass_bins[ii+1] - mass_bins[ii]
        for jj in range(len(ecc) - 1):
            for kk in range(len(f) - 1):
                idx = (ii, kk+1, jj+1)
                n_per_Gpc3 = (rate_grid[idx] * -1 * t_grid[idx]).to(u.Gpc**(-3)/u.Msun)
                # Bug fix: without dm, n_samp kept units of Msun^-1 and
                # int(n_samp) raised because it was not dimensionless.
                n_samp = n_per_Gpc3 * dm * horizon_comoving_volume[idx].to(u.Gpc**3)
                d_lim = D_h[idx]
                # Stochastic rounding of the expected count.
                n_int = int(n_samp)
                n_float = n_samp - n_int
                if np.random.uniform(0, 1) < n_float:
                    n_int += 1
                if n_int >= 1:
                    # Uniform in Euclidean volume out to d_lim.
                    d_samp = np.random.power(3, n_int) * d_lim
                    ind_keep, = np.where(d_samp < D_h[idx])
                    n_keep = len(ind_keep)
                    if n_keep > 0:
                        m_samp.extend(np.random.uniform(mass_bins[ii].value, mass_bins[ii+1].value, n_keep))
                        e_samp.extend(np.ones(n_keep) * e_evol[idx])
                        # Bug fix: draw between the frequency-grid edges f[kk+1]
                        # and f[kk]. The original drew between two bin *widths*
                        # (f_bins) indexed by the eccentricity counter jj.
                        f_samp.extend(np.random.uniform(f[kk+1].value, f[kk].value, n_keep))
                        D_samp.extend(d_samp[ind_keep].to(u.Mpc).value)
    return [m_samp, f_samp, e_samp, D_samp]

In [None]:
# Argument tuples for repeated build_pop calls (100 realisations per model).
dat_uniform_in = [[masses, ecc, f, RATE, T_LISA, horizon_comoving_volume, D_h]
                  for _ in range(100)]

# Bug fix: `RA` was undefined (NameError). Given the list name, the 90%
# isolated rate is intended. get_LIGO_rate_iso_dyn returns a unitless array,
# so attach the rate-density units that build_pop's .to(...) expects.
RATE_iso_90_density = RATE_iso_90 * u.Gpc**(-3) * u.yr**(-1) * u.Msun**(-1)
dat_iso_90 = [[masses, ecc, f, RATE_iso_90_density, T_LISA, horizon_comoving_volume, D_h]
              for _ in range(100)]

In [None]:
# Monte Carlo realisations for the uniform-in-eccentricity population.
# Expected counts are RATE * |T_LISA| in a 1 Gpc^3 comoving volume. Sources
# are placed uniformly in Euclidean volume out to the luminosity distance
# enclosing that volume, and kept if they lie inside the horizon D_h.
# Loop-invariant (previously recomputed with z_at_value in the innermost loop).
z_lim = z_at_value(Planck18.comoving_volume, 1 * u.Gpc**3)
d_lim = Planck18.luminosity_distance(z=z_lim)
rate_grid = RATE.reshape(F.shape)
t_grid = T_LISA.reshape(F.shape)

dat = []
for _ in tqdm.tqdm(range(10)):
    m_samp = []
    e_samp = []
    f_samp = []
    D_samp = []
    for ii in range(len(masses)):
        # Bug fix: counts are per unit mass, so multiply by the bin width.
        # The original multiplied by the lower bin edge mass_bins[ii].
        dm = mass_bins[ii+1] - mass_bins[ii]
        for jj in range(len(ecc) - 1):
            for kk in range(len(f) - 1):
                idx = (ii, kk+1, jj+1)
                n_per_Gpc3 = (rate_grid[idx] * -1 * t_grid[idx]).to(u.Gpc**(-3)/u.Msun)
                n_samp = n_per_Gpc3 * 1 * u.Gpc**3 * dm
                # Stochastic rounding of the expected count.
                n_int = int(n_samp)
                n_float = n_samp - n_int
                if np.random.uniform(0, 1) < n_float:
                    n_int += 1
                if n_int >= 1:
                    d_samp = np.random.power(3, n_int) * d_lim
                    ind_keep, = np.where(d_samp < D_h[idx])
                    n_keep = len(ind_keep)
                    if n_keep > 0:
                        m_samp.extend(np.random.uniform(mass_bins[ii].value, mass_bins[ii+1].value, n_keep))
                        e_samp.extend(np.ones(n_keep) * e_evol[idx])
                        # Bug fix: draw between frequency-grid edges, not between
                        # bin widths (f_bins) indexed by the ecc counter jj.
                        f_samp.extend(np.random.uniform(f[kk+1].value, f[kk].value, n_keep))
                        D_samp.extend(d_samp[ind_keep].to(u.Mpc).value)
    dat.append([m_samp, f_samp, e_samp, D_samp])
            

In [None]:
# Monte Carlo realisations for the 10% isolated / 90% dynamical mixture.
# Sampling scheme is identical to the uniform-eccentricity cell.
# Loop-invariant (previously recomputed with z_at_value in the innermost loop).
z_lim = z_at_value(Planck18.comoving_volume, 1 * u.Gpc**3)
d_lim = Planck18.luminosity_distance(z=z_lim)
# get_LIGO_rate_iso_dyn is unitless, so attach the rate-density units here.
rate_grid = RATE_iso_10.reshape(F.shape) * u.Gpc**(-3) * u.Msun**(-1) * u.yr**(-1)
t_grid = T_LISA.reshape(F.shape)

dat_iso_10 = []
for _ in tqdm.tqdm(range(10)):
    m_samp = []
    e_samp = []
    f_samp = []
    D_samp = []
    for ii in range(len(masses)):
        # Bug fix: multiply by the mass-bin width, not the lower edge mass_bins[ii].
        dm = mass_bins[ii+1] - mass_bins[ii]
        for jj in range(len(ecc) - 1):
            for kk in range(len(f) - 1):
                idx = (ii, kk+1, jj+1)
                n_per_Gpc3 = (rate_grid[idx] * -1 * t_grid[idx]).to(u.Gpc**(-3)/u.Msun)
                n_samp = n_per_Gpc3 * 1 * u.Gpc**3 * dm
                # Stochastic rounding of the expected count.
                n_int = int(n_samp)
                n_float = n_samp - n_int
                if np.random.uniform(0, 1) < n_float:
                    n_int += 1
                if n_int >= 1:
                    d_samp = np.random.power(3, n_int) * d_lim
                    ind_keep, = np.where(d_samp < D_h[idx])
                    n_keep = len(ind_keep)
                    if n_keep > 0:
                        m_samp.extend(np.random.uniform(mass_bins[ii].value, mass_bins[ii+1].value, n_keep))
                        e_samp.extend(np.ones(n_keep) * e_evol[idx])
                        # Bug fix: draw between frequency-grid edges, not between
                        # bin widths (f_bins) indexed by the ecc counter jj.
                        f_samp.extend(np.random.uniform(f[kk+1].value, f[kk].value, n_keep))
                        D_samp.extend(d_samp[ind_keep].to(u.Mpc).value)
    dat_iso_10.append([m_samp, f_samp, e_samp, D_samp])
            

In [None]:
# Monte Carlo realisations for the 50% isolated / 50% dynamical mixture.
# Sampling scheme is identical to the uniform-eccentricity cell.
# Loop-invariant (previously recomputed with z_at_value in the innermost loop).
z_lim = z_at_value(Planck18.comoving_volume, 1 * u.Gpc**3)
d_lim = Planck18.luminosity_distance(z=z_lim)
# get_LIGO_rate_iso_dyn is unitless, so attach the rate-density units here.
rate_grid = RATE_iso_50.reshape(F.shape) * u.Gpc**(-3) * u.Msun**(-1) * u.yr**(-1)
t_grid = T_LISA.reshape(F.shape)

dat_iso_50 = []
for _ in tqdm.tqdm(range(10)):
    m_samp = []
    e_samp = []
    f_samp = []
    D_samp = []
    for ii in range(len(masses)):
        # Bug fix: multiply by the mass-bin width, not the lower edge mass_bins[ii].
        dm = mass_bins[ii+1] - mass_bins[ii]
        for jj in range(len(ecc) - 1):
            for kk in range(len(f) - 1):
                idx = (ii, kk+1, jj+1)
                n_per_Gpc3 = (rate_grid[idx] * -1 * t_grid[idx]).to(u.Gpc**(-3)/u.Msun)
                n_samp = n_per_Gpc3 * 1 * u.Gpc**3 * dm
                # Stochastic rounding of the expected count.
                n_int = int(n_samp)
                n_float = n_samp - n_int
                if np.random.uniform(0, 1) < n_float:
                    n_int += 1
                if n_int >= 1:
                    d_samp = np.random.power(3, n_int) * d_lim
                    ind_keep, = np.where(d_samp < D_h[idx])
                    n_keep = len(ind_keep)
                    if n_keep > 0:
                        m_samp.extend(np.random.uniform(mass_bins[ii].value, mass_bins[ii+1].value, n_keep))
                        e_samp.extend(np.ones(n_keep) * e_evol[idx])
                        # Bug fix: draw between frequency-grid edges, not between
                        # bin widths (f_bins) indexed by the ecc counter jj.
                        f_samp.extend(np.random.uniform(f[kk+1].value, f[kk].value, n_keep))
                        D_samp.extend(d_samp[ind_keep].to(u.Mpc).value)
    dat_iso_50.append([m_samp, f_samp, e_samp, D_samp])
            

In [None]:
# Monte Carlo realisations for the 90% isolated / 10% dynamical mixture.
# Sampling scheme is identical to the uniform-eccentricity cell.
# Loop-invariant (previously recomputed with z_at_value in the innermost loop).
z_lim = z_at_value(Planck18.comoving_volume, 1 * u.Gpc**3)
d_lim = Planck18.luminosity_distance(z=z_lim)
# get_LIGO_rate_iso_dyn is unitless, so attach the rate-density units here.
rate_grid = RATE_iso_90.reshape(F.shape) * u.Gpc**(-3) * u.Msun**(-1) * u.yr**(-1)
t_grid = T_LISA.reshape(F.shape)

dat_iso_90 = []
for _ in tqdm.tqdm(range(10)):
    m_samp = []
    e_samp = []
    f_samp = []
    D_samp = []
    for ii in range(len(masses)):
        # Bug fix: multiply by the mass-bin width, not the lower edge mass_bins[ii].
        dm = mass_bins[ii+1] - mass_bins[ii]
        for jj in range(len(ecc) - 1):
            for kk in range(len(f) - 1):
                idx = (ii, kk+1, jj+1)
                n_per_Gpc3 = (rate_grid[idx] * -1 * t_grid[idx]).to(u.Gpc**(-3)/u.Msun)
                n_samp = n_per_Gpc3 * 1 * u.Gpc**3 * dm
                # Stochastic rounding of the expected count.
                n_int = int(n_samp)
                n_float = n_samp - n_int
                if np.random.uniform(0, 1) < n_float:
                    n_int += 1
                if n_int >= 1:
                    d_samp = np.random.power(3, n_int) * d_lim
                    ind_keep, = np.where(d_samp < D_h[idx])
                    n_keep = len(ind_keep)
                    if n_keep > 0:
                        m_samp.extend(np.random.uniform(mass_bins[ii].value, mass_bins[ii+1].value, n_keep))
                        e_samp.extend(np.ones(n_keep) * e_evol[idx])
                        # Bug fix: draw between frequency-grid edges, not between
                        # bin widths (f_bins) indexed by the ecc counter jj.
                        f_samp.extend(np.random.uniform(f[kk+1].value, f[kk].value, n_keep))
                        D_samp.extend(d_samp[ind_keep].to(u.Mpc).value)
    dat_iso_90.append([m_samp, f_samp, e_samp, D_samp])
            

In [None]:
# The four population models compared below: uniform-e, then 10%, 50% and 90% isolated.
dat_list = [dat, dat_iso_10, dat_iso_50, dat_iso_90]


In [None]:
# One panel per population model. Each panel shows the sampled sources in
# (frequency, chirp mass), coloured by log10 eccentricity. The title gives the
# mean +/- std number of sources across realisations.
fig, axes = plt.subplots(1, len(dat_list), figsize=(20, 4))
for ii, (dlist, ax) in enumerate(zip(dat_list, axes)):
    n_obs = []
    for d in dlist:
        m_samp, f_samp, e_samp, D_samp = d
        n_obs.append(len(m_samp))
        c = ax.scatter(f_samp, lw.utils.chirp_mass(np.array(m_samp), np.array(m_samp)), c=np.log10(e_samp), vmin=-4, vmax=0)
    ax.set_xscale('log')
    ax.set_xlim(1e-4, 1e-1)
    ax.set_xlabel('frequency [Hz]')
    if ii == 0:
        ax.set_ylabel('chirp mass [Msun]')
    # Raw string: '\p' is an invalid escape sequence in a normal string literal
    # (SyntaxWarning on Python 3.12+). The rendered text is unchanged.
    ax.set_title(r'N$_{\rm{LISA}}$=' + str(np.round(np.mean(n_obs), 2)) + r' $\pm$ ' + str(np.round(np.std(n_obs), 2)), size=20)
fig.colorbar(c, ax=axes[-1], label=r'$\log_{10}(e)$')

In [None]:
# Rate density for a single-eccentricity population: the undivided rate at
# every grid eccentricity (get_LIGO_rate_single_e ignores ECC).
RATE_single_ecc = get_LIGO_rate_single_e(MASS, ECC)
RATE_single_ecc

In [None]:
# Per-eccentricity Monte Carlo. For each realisation, build one population per
# eccentricity grid point (skipping the first) using the single-e rate, so
# dat_ecc[i][j] holds the sources of realisation i at ecc[j+1].
# Loop-invariant (previously recomputed with z_at_value in the innermost loop).
z_lim = z_at_value(Planck18.comoving_volume, 1 * u.Gpc**3)
d_lim = Planck18.luminosity_distance(z=z_lim)
rate_grid = RATE_single_ecc.reshape(F.shape)
t_grid = T_LISA.reshape(F.shape)

dat_ecc = []
for _ in tqdm.tqdm(range(10)):
    dat_samp = []
    for jj in range(len(ecc) - 1):
        m_samp = []
        e_samp = []
        f_samp = []
        D_samp = []
        for ii in range(len(masses)):
            # Bug fix: multiply by the mass-bin width, not the lower edge mass_bins[ii].
            dm = mass_bins[ii+1] - mass_bins[ii]
            for kk in range(len(f) - 1):
                idx = (ii, kk+1, jj+1)
                n_per_Gpc3 = (rate_grid[idx] * -1 * t_grid[idx]).to(u.Gpc**(-3)/u.Msun)
                n_samp = n_per_Gpc3 * 1 * u.Gpc**3 * dm
                # Stochastic rounding of the expected count.
                n_int = int(n_samp)
                n_float = n_samp - n_int
                if np.random.uniform(0, 1) < n_float:
                    n_int += 1
                if n_int >= 1:
                    d_samp = np.random.power(3, n_int) * d_lim
                    ind_keep, = np.where(d_samp < D_h[idx])
                    n_keep = len(ind_keep)
                    if n_keep > 0:
                        m_samp.extend(np.random.uniform(mass_bins[ii].value, mass_bins[ii+1].value, n_keep))
                        e_samp.extend(np.ones(n_keep) * e_evol[idx])
                        # Bug fix: draw between frequency-grid edges, not between
                        # bin widths (f_bins) indexed by the ecc counter jj.
                        f_samp.extend(np.random.uniform(f[kk+1].value, f[kk].value, n_keep))
                        D_samp.extend(d_samp[ind_keep].to(u.Mpc).value)
        dat_samp.append([m_samp, f_samp, e_samp, D_samp])
    dat_ecc.append(dat_samp)
            

In [None]:
# Shape check of the evolved-eccentricity grid.
np.shape(e_evol)

In [None]:
# Number of sampled sources per eccentricity slice in each realisation:
# n_ecc[i][j] counts the sources of realisation i at ecc[j+1].
n_ecc = [
    [len(e_samp) for (_, _, e_samp, _) in realisation]
    for realisation in dat_ecc
]

In [None]:
# Rows: realisations; columns: eccentricity grid points ecc[1:].
n_ecc = np.array(n_ecc)
print(np.shape(n_ecc))

In [None]:
# Mean and standard deviation across realisations of the count at each
# eccentricity grid point after the first.
ecc_columns = range(len(ecc) - 1)
mean = [np.mean(np.array(n_ecc[:, col])) for col in ecc_columns]
std = [np.std(np.array(n_ecc[:, col])) for col in ecc_columns]

In [None]:
# Mean number of sampled sources versus eccentricity at 10 Hz. Error bars show
# the realisation-to-realisation standard deviation.
plt.errorbar(ecc[1:], mean, yerr=std)
plt.yscale('log')
plt.xscale('log')
plt.xlabel(r'$e_{10}$')
plt.ylabel('number of sampled sources');

In [None]:
# Second eccentric grid: 5 linearly spaced eccentricities in [0, 1e-4].
# NOTE(review): the evolution/SNR cells above (T_LISA, e_evol, D_h,
# horizon_comoving_volume) derive from F, MASS and ECC. Re-run them after
# redefining this grid.
n_grid_f = 100
n_grid_e = 5
n_grid_mass = 15

# Bug fix: use n_grid_f (defined above, previously unused) instead of the
# circular-case n_grid. f is ordered from 0.1 Hz DOWN to 1e-4 Hz.
f = np.logspace(-1, -4, n_grid_f) * u.Hz
# NOTE: despite its name, f_bins holds bin *widths* f[i] - f[i+1], not edges.
f_bins = f[:-1] - f[1:]

masses = np.linspace(5, 80, n_grid_mass)
m_bin_widths = masses[1] - masses[0]
mass_bins = masses - 0.5 * m_bin_widths
mass_bins = np.append(mass_bins, masses[-1] + 0.5 * m_bin_widths)
masses = masses * u.Msun
mass_bins = mass_bins * u.Msun

ecc = np.linspace(0, 0.0001, n_grid_e)
ecc_bins = ecc[1:] - ecc[:-1]

m_c = lw.utils.chirp_mass(masses, masses)
# 3-D grids indexed as (mass, frequency, eccentricity) -- meshgrid 'xy' order.
F, MASS, ECC = np.meshgrid(f, masses, ecc)

MC = lw.utils.chirp_mass(MASS, MASS)

RATE = get_LIGO_rate_uniform_e(MASS, ECC, ecc)
RATE_iso_10 = get_LIGO_rate_iso_dyn(MASS, ECC, ecc, 0.1)
RATE_iso_50 = get_LIGO_rate_iso_dyn(MASS, ECC, ecc, 0.5)
RATE_iso_90 = get_LIGO_rate_iso_dyn(MASS, ECC, ecc, 0.9)

In [None]:
# Per-source evolution tracks. For each binary in dat_out, keep the part of its
# track below 0.3 Hz, compute the comoving volume inside its SNR = 7 horizon,
# and store the track with rate weights.
# FIXME(review): dat_out, M1, M2, E and nproc are not defined anywhere in the
# visible notebook, so this cell fails on a fresh kernel.
# NOTE(review): this cell reassigns the global D_h used by the sampling cells above.
V_c = []
LIGO_rate_uniform = []
LIGO_rate_iso_dyn_50 = []
LIGO_rate_iso_dyn_80 = []
times = []
ecc_evols = []
f_orb_evols = []
LISA_norms = []
m1_evols = []
m2_evols = []


for d, m1, m2, e in tqdm.tqdm(zip(dat_out, M1, M2, E), total=len(M1)):
    f_orb_evol, ecc_evol, timesteps, LISA_norm = d
    # Keep only track points with orbital frequency below 0.3 Hz.
    f_mask = f_orb_evol < 0.3 * u.Hz
    source = lw.source.Source(m_1=m1 * np.ones(len(f_orb_evol[f_mask])) * u.Msun,
                              m_2=m2 * np.ones(len(f_orb_evol[f_mask])) * u.Msun,
                              ecc=ecc_evol[f_mask],
                              f_orb=f_orb_evol[f_mask],
                              dist=8 * np.ones(len(f_orb_evol[f_mask])) * u.Mpc,
                              interpolate_g=False,
                              n_proc=nproc)
    snr = source.get_snr(approximate_R=True, verbose=False)
    # SNR = 7 horizon from the 8 Mpc reference SNR (assumes SNR ~ 1 / distance).
    D_h = snr/7 * 8 * u.Mpc
    redshift = np.ones(len(D_h)) * 1e-8
    redshift[D_h > 0.0001 * u.Mpc] = z_at_value(Planck18.luminosity_distance, D_h[D_h > 0.0001 * u.Mpc])
    V_c.append(Planck18.comoving_volume(z=redshift))

    LISA_norms.append(LISA_norm[f_mask].to(u.yr/u.Hz))
    # Negated timesteps, in years.
    times.append(-timesteps[f_mask].to(u.yr))
    ecc_evols.append(ecc_evol[f_mask])
    f_orb_evols.append(f_orb_evol[f_mask])
    m1_evols.append(m1 * np.ones(len(f_orb_evol[f_mask])))
    m2_evols.append(m2 * np.ones(len(f_orb_evol[f_mask])))
    # FIXME(review): these calls do not match the helpers defined at the top.
    # get_LIGO_rate_uniform_e takes (m1, ecc, ecc_grid), and delta_m is a
    # scalar, so delta_m[0] raises. get_LIGO_rate_iso_dyn also requires ecc_grid.
    LIGO_rate_uniform.append(get_LIGO_rate_uniform_e(m1, delta_m[0]))
    LIGO_rate_iso_dyn_50.append(get_LIGO_rate_iso_dyn(m1, e, frac_iso=0.5))
    LIGO_rate_iso_dyn_80.append(get_LIGO_rate_iso_dyn(m1, e, frac_iso=0.8))


In [None]:
def get_m_lo(m1, mass_grid):
    """Return the mass-grid point immediately below ``m1``.

    The result is used as the lower edge when drawing primary masses uniformly
    in ``[m_lo, m1)``.

    Parameters
    ----------
    m1 : float
        Upper mass, typically itself a grid point.
    mass_grid : numpy.ndarray
        Mass grid, compared element-wise against ``m1``.

    Returns
    -------
    scalar
        The largest grid value strictly below ``m1``.

    Raises
    ------
    ValueError
        If no grid value lies below ``m1``.
    """
    lower = mass_grid[mass_grid < m1]
    if lower.size == 0:
        raise ValueError(f"no grid mass below m1={m1}")
    # Bug fix: the original returned min(...), the *smallest* grid mass.
    # Every draw then started at the bottom of the grid instead of at the
    # neighbouring grid point, so mass ranges overlapped between grid points.
    return lower.max()
    

In [None]:
# For each primary-mass grid point, draw sources separately below and above
# 1 mHz, then keep those with SNR > 7 at their drawn distance.
# NOTE(review): depends on M1, M2, mass1_grid and nproc, which are not defined
# anywhere in the visible notebook, and on the per-source lists (V_c, times,
# f_orb_evols) filled by the previous cell.


def _sample_band(ind, t_start, m1, rate_m1, f_orb, v_c, t_evol):
    """Draw sources in one frequency band of a track and flag those with SNR > 7.

    ``ind`` selects the band's entries of the track arrays. ``t_start`` is the
    lower edge (in yr) of the uniform time draw.
    Returns ``(m1_sample, m2_sample, f_sample, keep_mask)``.
    """
    t_max = max(t_evol[ind])
    v_max = max(v_c[ind].to(u.Mpc**3))
    n_sample = int(rate_m1.to(u.Mpc**(-3)/u.yr) * v_max * t_max)
    t_f_interp = interp1d(t_evol[ind], f_orb[ind])

    m_lo = get_m_lo(m1, mass1_grid)
    m1_sample = np.random.uniform(m_lo, m1, n_sample)
    # Secondary mass uniform between 5 Msun and the drawn primary.
    m2_sample = np.random.uniform(5 * np.ones(n_sample), m1_sample)
    t_sample = np.random.uniform(t_start, t_max.value, n_sample)
    f_sample = t_f_interp(t_sample)
    # Distances uniform in Euclidean volume inside the sphere of volume v_max:
    # r_max = (3 V / (4 pi))^(1/3).
    # Bug fix: for the above-mHz band, the original overwrote this draw with a
    # copy-pasted one sized n_sample_lo and scaled by the sub-mHz volume.
    d = np.random.power(3, n_sample) * (3/(4 * np.pi) * v_max.value)**(1/3) * u.Mpc
    source = lw.source.Source(m_1=m1_sample * u.Msun,
                              m_2=m2_sample * u.Msun,
                              ecc=np.zeros(len(m1_sample)),
                              f_orb=f_sample * u.Hz,
                              # Bug fix: evaluate the SNR at the drawn distances.
                              # The original used a fixed 8 Mpc and ignored d.
                              dist=d,
                              interpolate_g=False,
                              n_proc=nproc)
    snr = source.get_snr(approximate_R=True, verbose=False)
    return m1_sample, m2_sample, f_sample, snr > 7


print(n_grid)
m1_keep = []
m2_keep = []
f_keep = []
for ii, (m1, m2) in enumerate(zip(M1, M2)):
    print(m1, m2)
    # Bug fix: this used to be assigned to `ligo_rate`, which shadowed the
    # module-level ligo_rate() that get_LIGO_rate_uniform_e calls. Every later
    # iteration (and every later rate call) then failed.
    # FIXME(review): get_LIGO_rate_uniform_e takes (m1, ecc, ecc_grid) and
    # delta_m is a scalar, so this call raises as written. It also returns a
    # density per Msun, which the .to(u.Mpc**(-3)/u.yr) conversion rejects.
    rate_m1 = get_LIGO_rate_uniform_e(m1, delta_m[0]) / n_grid
    # Named f_orb (not f) so the global frequency grid f is not clobbered.
    f_orb = f_orb_evols[ii]
    v_c = V_c[ii]
    t_evol = times[ii]

    # Split the track at 1 mHz.
    ind_mhz, = np.where(f_orb < 0.001 * u.Hz)
    ind_mhz_gtr, = np.where(f_orb >= 0.001 * u.Hz)

    bands = [
        (ind_mhz, min(t_evol[ind_mhz]).value, "below"),
        (ind_mhz_gtr, 0, "above"),
    ]
    for ind, t_start, side in bands:
        m1_sample, m2_sample, f_sample, keep_mask = _sample_band(ind, t_start, m1, rate_m1, f_orb, v_c, t_evol)
        m1_keep.extend(m1_sample[keep_mask])
        m2_keep.extend(m2_sample[keep_mask])
        f_keep.extend(f_sample[keep_mask])
        print(f"sample {len(m1_sample)} mergers {side} 1 mHz and keep {len(keep_mask[keep_mask])} of them")
    print()


In [None]:
# Chirp masses of all sources kept with SNR > 7.
m_c = lw.utils.chirp_mass(m1_keep * u.Msun, m2_keep * u.Msun)
print(len(m_c))

In [None]:
# NOTE(review): duplicate of the previous cell; safe to delete.
m_c = lw.utils.chirp_mass(m1_keep * u.Msun, m2_keep * u.Msun)
print(len(m_c))

In [None]:
# Kept sources in (orbital frequency [Hz], chirp mass) space.
plt.scatter(f_keep, m_c)
plt.xscale('log')

In [None]:
# NOTE(review): duplicate of an earlier cell; safe to delete.
m_c = lw.utils.chirp_mass(m1_keep * u.Msun, m2_keep * u.Msun)
print(len(m_c))

In [None]:
# NOTE(review): duplicate of an earlier plotting cell; safe to delete.
plt.scatter(f_keep, m_c)
plt.xscale('log')