In [None]:
import numpy as np

import matplotlib.pyplot as plt

from glow import lenses
from glow import waveform
from glow import tools

import glow.mismatch as mm
import glow.physical_units as pu

# Cosmology used by glow.mismatch. An empty dict keeps the library default
# (Planck18, per the original note); add keys here to override it.
cosmology = dict()
mm.initialize_cosmology(**cosmology)



In [None]:
# Source and detector configuration.
# Select the detector here; its preset fixes the source-frame total mass,
# the source redshift and the observation time. Add an entry to
# DETECTOR_PRESETS to study another detector instead of commenting code in/out.
detector = 'LISA'

DETECTOR_PRESETS = {
    # LISA's Tobs is converted from years to seconds via pu.yr_to_s;
    # the ET value is presumably already in seconds -- TODO confirm.
    'LISA': {'z_src': 5, 'Mtot': 1e6, 'Tobs': 0.13*pu.yr_to_s},
    'ET':   {'z_src': 1, 'Mtot': 1000, 'Tobs': 210},
}
preset = DETECTOR_PRESETS[detector]

z_src_waveform = preset['z_src']
Mtot = preset['Mtot']                        # total mass before the (1+z) factor
Mtot_detector = Mtot*(1 + z_src_waveform)    # redshifted (detector-frame) total mass
q = 1                                        # mass ratio mass1/mass2 (see params below)
spin = 0                                     # z-component of both spins
inc = 0                                      # inclination
Tobs = preset['Tobs']

# Keys follow PyCBC's get_fd_waveform parameter names.
params_source = {
    'approximant'    : "IMRPhenomXHM",
    'mass1'          : Mtot_detector * q/(1. + q),
    'mass2'          : Mtot_detector * 1/(1. + q),
    'spin1z'         : spin,
    'spin2z'         : spin,
    'redshift'       : z_src_waveform,
    'inclination'    : inc,
    'long_asc_nodes' : 0,
    # Larger of f0_obs(Mtot_detector, Tobs) (presumably the frequency the signal
    # has Tobs before merger) and the lower edge of the detector band.
    'f_lower'        : np.amax([waveform.f0_obs(Mtot_detector, Tobs, units='s'),
                                waveform.f_bounds_detector(detector)[0]]),
    'delta_f'        : 1/Tobs,                            # resolution set by the observation time
    'f_final'        : 5*waveform.f_isco(Mtot_detector),  # 5x the ISCO frequency
}

# The waveform is generated as soon as the WaveformFD object is initialized
# from these parameters.
h_fd = waveform.WaveformFD(params_source)


In [None]:
# Attach the detector noise PSD to the waveform; the SNR is then read from h_fd.
psd = waveform.get_psd_from_file(detector)
h_fd.load_psd(psd)

snr = h_fd.snr
print('snr of the signal:', snr)

In [None]:
# !! Sanity check: the signal curve should meet the noise curve on the
# low-frequency (left-hand) side of the plot.
freqs = h_fd.sample_frequencies

fig, ax = plt.subplots()
ax.loglog(freqs, freqs*np.abs(h_fd.p), label='signal')
ax.loglog(freqs, np.sqrt(freqs*np.abs(h_fd.psd_grid)), label='noise')
ax.set_xlabel('$f$ [Hz]')
ax.set_ylabel('Characteristic strain')
ax.legend()
ax.grid()
plt.show()

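Critical curve: the set of $(M_{Lz}, y)$ at which the mismatch between the lensed and unlensed waveforms equals $1/\mathrm{SNR}^2$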

In [None]:
# Lens model used by Methods 1 and 3 below: glow's SIS lens potential
# (presumably a singular isothermal sphere with default parameters -- confirm).
Psi_SIS = lenses.Psi_SIS()

Single-point evaluation: mismatch for one impact parameter $y$ and one lens mass $M_{Lz}$

In [None]:
# Lens the waveform for a single (y, MLz) pair and compare it with the unlensed one.
y = 60       # presumably the dimensionless impact parameter
MLz = 1e4    # redshifted lens mass (units presumably M_sun -- confirm)

h_fd_lensed = waveform.get_lensed_fd_from_Psi(
    h_fd, Psi_SIS, y, MLz,
    p_prec_t=dict(tmax=1e7, Nt=5000),
)
mismatch = mm.mismatch(h_fd_lensed, h_fd, only_plus=True, optimized=False)

print('The mismatch is: ', mismatch)

Method 1: critical curve from a root finder using the lens potential $\psi$

In [None]:

# For each lens mass, root-find the critical impact parameter y_cr using the
# SIS lens potential. With return_mm=True the call also returns mm x SNR^2 at y_cr.
MLzs = np.geomspace(1e2, 1e11, 20)  # LISA
kwargs_mm = dict(optimized=False)
p_prec_t = dict(tmax=1e9)
kwargs_lensing = dict(SL=False, p_prec_t=p_prec_t)

y_crits, mmSNR2_s = mm.get_y_crit_curve_opt(
    h_fd, Psi_SIS, MLzs,
    1, 100,          # presumably the y search bracket -- confirm
    s=1,
    robust=False,
    n_iter=3,
    return_mm=True,
    rtol=1e-3,
    kwargs_mm=kwargs_mm,
    kwargs_lensing=kwargs_lensing,
)



In [None]:
# mismatch x SNR^2 evaluated at the critical y found by the root finder
# (presumably one value per lens mass in MLzs).
print('Values of mm x SNR2 at the critical curve: ', mmSNR2_s)

In [None]:
# Critical impact parameter vs redshifted lens mass (Method 1).
fig, ax = plt.subplots(1, figsize=(7, 4.5))
ax.plot(MLzs, y_crits)
ax.set_ylabel(r'$y_{cr}$', fontsize=14)
# Raw string: '\,' and '\o' are invalid escape sequences in a plain literal
# (DeprecationWarning, SyntaxWarning since Python 3.12).
ax.set_xlabel(r'$M_{Lz}\,[M_\odot]$', fontsize=14)
ax.set_title('Method 1: root finder with the lens potential')
ax.set_yscale('linear')
ax.set_xscale('log')
ax.grid()

Method 2: critical curve from a root finder using the analytic amplification factor $F(w)$

In [None]:
from glow import freq_domain_c

# Amplification factor handed to the F(w)-based root finder.
# NOTE(review): despite the name, this is glow's analytic *point-lens* F(w),
# whereas Methods 1 and 3 use the SIS potential Psi_SIS -- confirm intended lens.
Fw_analytic_SIS = freq_domain_c.Fw_AnalyticPointLens_C
MLzs = np.geomspace(1e2, 1e11, 20)  # LISA
kwargs_mm = dict(optimized=False)

y_crits, mmSNR2_s = mm.get_y_crit_curve_Fw_opt(
    h_fd, Fw_analytic_SIS, MLzs,
    1, 200,          # presumably the y search bracket; Method 1 uses 1, 100
    s=1,
    robust=False,
    n_iter=3,
    return_mm=True,
    rtol=1e-3,
    kwargs_mm=kwargs_mm,
)

In [None]:
# mismatch x SNR^2 at the critical y found with the analytic F(w) root finder
# (presumably one value per lens mass in MLzs).
print('Values of mm x SNR2 at the critical curve: ', mmSNR2_s)

In [None]:
# Critical impact parameter vs redshifted lens mass (Method 2).
fig, ax = plt.subplots(1, figsize=(7, 4.5))
ax.plot(MLzs, y_crits)
ax.set_ylabel(r'$y_{cr}$', fontsize=14)
# Raw string: '\,' and '\o' are invalid escape sequences in a plain literal
# (DeprecationWarning, SyntaxWarning since Python 3.12).
ax.set_xlabel(r'$M_{Lz}\,[M_\odot]$', fontsize=14)
ax.set_title('Method 2: root finder with analytic F(w)')
ax.set_yscale('linear')
ax.set_xscale('log')
ax.grid()

Method 3: critical curve from a grid evaluation of the mismatch

In [None]:
# Axes of the mismatch grid to be interpolated (LISA ranges):
# 20 log-spaced lens masses and 10 log-spaced impact parameters.
MLzs_grid = np.geomspace(1e2, 1e10, num=20)
ys = np.geomspace(2, 200, num=10)

In [None]:
# Interpolation scale per axis, and the grid axes in the same (MLz, y) order.
interp_scale = ['log', 'log']
grid_basis = [MLzs_grid, ys]
# Default 'xy' indexing: XX and YY have shape (len(ys), len(MLzs_grid)).
XX, YY = np.meshgrid(grid_basis[0], grid_basis[1])

In [None]:
# Lensed waveforms at every grid node. Arguments are passed as (y, MLz), the same
# order as the single-point get_lensed_fd_from_Psi(h_fd, Psi_SIS, y, MLz, ...) call.
h_lensed_grid=waveform.get_lensed_fd_from_Psi_vec(h_fd, Psi_SIS, YY, XX) 

In [None]:
# Mismatch of each lensed grid waveform against the unlensed h_fd, plus an
# interpolating function over (MLz, y) built with the per-axis interp_scale.
mismatch_fun_SIS, mismatch_grid_SIS = mm.get_mismatch_fun(
    grid_basis, h_lensed_grid, h_fd,
    scale=interp_scale,
    scale_grid='log',
)

In [None]:
def _mismatch_at(x, y):
    """Interpolated mismatch at a single point (x=MLz, y=impact parameter)."""
    return mismatch_fun_SIS((x, y))


# Element-wise version, so the interpolant can be evaluated on meshgrid arrays.
fun_2d = np.vectorize(_mismatch_at)


In [None]:
# Check that the interpolant reproduces the grid: black dots are the computed
# mismatches, the wireframe is the interpolant on a finer mesh.
# Every axis shows the log10 of the quantity.
xx = np.geomspace(MLzs_grid[0], MLzs_grid[-1], num=100)
yy = np.geomspace(ys[0], ys[-1], num=10)
X, Y = np.meshgrid(xx, yy, indexing='ij')

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.scatter(np.log10(XX), np.log10(YY), np.log10(mismatch_grid_SIS),
           s=10, c='k', label='grid')
ax.plot_wireframe(np.log10(X), np.log10(Y), np.log10(fun_2d(X, Y)),
                  rstride=3, cstride=3, alpha=0.4, color='m', label='interpolant')

ax.set_xlabel(r"$\log_{10}{M_{\rm lz}}$", fontsize=14)
ax.set_ylabel(r"$\log_{10}{y}$", fontsize=14)
ax.set_zlabel(r"$\log_{10}{\mathcal{M}}$", fontsize=14)

ax.legend()
ax.view_init(10, 60)
plt.show()

In [None]:
# Mismatch contour plot over (MLz, y) from the interpolant, with the
# mismatch = 1/SNR^2 contour (red, dotted), the grid nodes and the root-finder curve.
xx = np.geomspace(MLzs_grid[0], MLzs_grid[-1], num=100)
# Log-spaced so the sampling is uniform on the log-scaled y axis.
yy = np.geomspace(ys[0], ys[-1], num=100)
X, Y = np.meshgrid(xx, yy, indexing='ij')

fig, ax = plt.subplots(figsize=(7, 5))

CS = ax.contourf(X, Y, np.log10(fun_2d(X, Y)))
CS2 = ax.contour(CS, levels=[np.log10(1/snr**2)], colors='red', linestyles=['dotted'])
ax.scatter(XX, YY, s=1, c='k')
# NOTE(review): MLzs / y_crits come from whichever root-finder cell ran last
# (Method 2 when run top to bottom).
ax.plot(MLzs, y_crits)
ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel(r"$M_{\rm lz}$", fontsize=14)
ax.set_ylabel(r'y', fontsize=14)
ax.grid(True, which="both", alpha=0.3)
ax.set_title(r"$M_{{\rm BBH}}^D= {:s}\,M_\odot\,(z_s={:s})$".format(tools.latex_float(Mtot_detector), tools.latex_float(z_src_waveform)), fontsize=14)
# Limit the view to the interpolation grid, where the contours are defined
# (the root-finder MLzs run to 1e11, past the grid's upper edge).
ax.set_xlim(MLzs_grid[0], MLzs_grid[-1])
ax.set_ylim(np.amin(ys), np.amax(ys))

# A plain str fmt is applied as `fmt % level` and would raise TypeError for a
# string with no format specifier, so use a callable returning a fixed label.
ax.clabel(CS2, fmt=lambda _level: '1/SNR$^2$', inline=True, fontsize=14)
cbar = fig.colorbar(CS)
cbar.add_lines(CS2)


In [None]:
# Store the mismatch grid together with its axes and interpolation scales so the
# interpolant can be rebuilt later without recomputing the lensed waveforms.
import os

import pandas as pd

mm_bank_dir = 'mm_bank'
# pandas does not create missing parent directories when writing.
os.makedirs(mm_bank_dir, exist_ok=True)

# 'MLz' must be the axis the grid was computed on (MLzs_grid), not the
# root-finder masses MLzs, which have the same length but a different range.
dict_save = {'SNR': snr, 'MLz': [MLzs_grid], 'y': [ys], 'mm_grid': [mismatch_grid_SIS], 'scale': [interp_scale]}
to_save = pd.DataFrame.from_dict(data=dict_save)
tags = ['mismatch_', '{:s}_'.format(detector), "Mbbh{:.0e}_".format(Mtot), "zsrc{:.0f}".format(int(z_src_waveform))]
filename_wrt = ''.join(tags)
# Pickle format despite the '.dat' extension: read back with pd.read_pickle,
# and only from trusted files (unpickling can execute arbitrary code).
to_save.to_pickle(os.path.join(mm_bank_dir, filename_wrt + '.dat'))