## Supported waveforms:

* IMRPhenomXAS (aligned spin): theta = [Mchirp, eta, chi1, chi2, D, tc, phic]
* IMRPhenomD (aligned spin): theta = [Mchirp, eta, chi1, chi2, D, tc, phic]
* IMRPhenomPv2 (Still finalizing sampling checks): not known
* TaylorF2 with tidal effects : [Mchirp, eta, chi1, chi2, lambda1, lambda2, D, tc, phic]
  * lambda1: Dimensionless tidal deformability of the primary object [between 0 and 5000]
  * lambda2: Dimensionless tidal deformability of the secondary object [between 0 and 5000]
* IMRPhenomD_NRTidalv2, verified for the low spin regime (chi1, chi2 < 0.05), further testing is required for higher spins: [Mchirp, eta, chi1, chi2, lambda1, lambda2, D, tc, phic, inclination]

## Parameters:

* Mchirp: Chirp mass of the system [solar masses]
* eta: Symmetric mass ratio [between 0.0 and 0.25]
* chi1: Dimensionless aligned spin of the primary object [between -1 and 1]
* chi2: Dimensionless aligned spin of the secondary object [between -1 and 1]
* lambda1: Dimensionless tidal deformability of the primary object [between 0 and 5000]
* lambda2: Dimensionless tidal deformability of the secondary object [between 0 and 5000]
* D: Luminosity distance to source [Mpc]
* tc: Time of coalesence. This only appears as an overall linear in f contribution to the phase
* phic: Phase of coalesence
* inclination: Inclination angle of the binary [between 0 and PI]

In [1]:
import jax.numpy as jnp
import numpy as np
import bilby
#from ripple.waveforms import IMRPhenomD
from ripple import ms_to_Mc_eta
from jax import vmap
from jax import jit
from gwsnr import noise_weighted_inner_product

In [2]:
from gwsnr import RippleInnerProduct
#from gwsnr import findchirp_chirptime_jax

In [3]:
import numpy as np
size = 10

gw_param_dict = {
    "mass_1": np.random.uniform(10,50,size),
    "mass_2": np.random.uniform(10,50,size),
    "luminosity_distance": 440*np.ones(size),
    "theta_jn": 0.0*np.ones(size),
    "psi": 0.659*np.ones(size),
    "phase": 0.0*np.ones(size),
    "geocent_time": 0.*np.ones(size),
    "ra": 1.375*np.ones(size),
    "dec": -1.2108*np.ones(size),
    "a_1": 0.5*np.ones(size),
    "a_2": -0.5*np.ones(size),
    "tilt_1": 0.0*np.ones(size),
    "tilt_2": 0.0*np.ones(size),
    "phi_12": 0.0*np.ones(size),
    "phi_jl": 0.0*np.ones(size),
}

idx = (gw_param_dict['mass_1'] < gw_param_dict['mass_2'])
gw_param_dict['mass_1'][idx], gw_param_dict['mass_2'][idx] = gw_param_dict['mass_2'][idx], gw_param_dict['mass_1'][idx]

In [4]:
test = RippleInnerProduct('IMRPhenomD')

In [6]:
from gwsnr import GWSNR

gwsnr = GWSNR(npool=4, snr_type='inner_product_jax', gwsnr_verbose=False, ifos=['L1', 'H1', 'V1'], multiprocessing_verbose=True)

In [10]:
# size=100000, 4m 18.2s
ans = test.noise_weighted_inner_product_jax(gw_param_dict, gwsnr.psds_list, gwsnr.detector_list, multiprocessing_verbose=True)

100%|██████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 257.85it/s]


In [8]:
# size=100000, time=50.7s
ans2 = gwsnr.snr(gw_param_dict=gw_param_dict)
ans2

solving SNR with inner product JAX


100%|███████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  3.06it/s]


{'L1': array([66.04323701, 59.20393426, 42.74319364, 41.98536111, 28.78254717,
        43.76257075, 55.9290411 , 49.92542936, 61.99139003, 38.79745521]),
 'H1': array([87.19084117, 78.16153572, 56.42992646, 55.42942953, 37.99896269,
        57.77571676, 73.83799402, 65.91197493, 81.84155845, 51.22072915]),
 'V1': array([53.5713521 , 48.05895552, 34.95451127, 34.23742383, 22.95812389,
        35.76835162, 45.43122714, 40.76172455, 50.30093562, 31.6269057 ]),
 'optimal_snr_net': array([121.79425973, 109.1970453 ,  78.95020621,  77.5073764 ,
         52.90965545,  80.82432198, 103.17026412,  92.18706593,
        114.32916185,  71.61750337])}

In [9]:
gwsnr = GWSNR(npool=4, snr_type='inner_product', gwsnr_verbose=False, ifos=['L1', 'H1', 'V1'], multiprocessing_verbose=True)
ans3 = gwsnr.snr(gw_param_dict=gw_param_dict)
ans3

solving SNR with inner product


100%|███████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  3.04it/s]


{'L1': array([66.04311361, 59.20373333, 42.74315353, 41.98536447, 28.78253705,
        43.76249461, 55.92884469, 49.92547396, 61.991195  , 38.79744863]),
 'H1': array([87.19067826, 78.16127045, 56.42987351, 55.42943397, 37.99894932,
        57.77561623, 73.83773471, 65.9120338 , 81.84130097, 51.22072047]),
 'V1': array([53.57124489, 48.05878205, 34.95447256, 34.23742383, 22.95811577,
        35.76827529, 45.4310542 , 40.76176521, 50.30077473, 31.62689902]),
 'optimal_snr_net': array([121.79402904, 109.19667015,  78.95012951,  77.50738139,
         52.90963682,  80.82417511, 103.16989591,  92.18715015,
        114.328801  ,  71.61749065])}

In [35]:
hp = ans[0][0]
hc = ans[0][1]
fs = ans[0][2]
fsize_arr = ans[0][3]
fmin = ans[0][4]
duration = ans[0][5]
psd_object = ans[0][7][0]

hp = np.array(hp[:fsize_arr], dtype=np.complex128)
hc = np.array(hc[:fsize_arr], dtype=np.complex128)
# find the index of 20Hz or nearby
# set all elements to zero below this index
fs =  fs[:fsize_arr]
idx = np.abs( fs - fmin).argmin()
hp[:idx] = 0.0 + 0.0j
hc[:idx] = 0.0 + 0.0j

psd_array = psd_object.get_power_spectral_density_array(fs)

In [37]:
signal1 = jnp.array(hp, dtype=np.complex128)
signal2 = jnp.array(hc, dtype=np.complex128)
psd = jnp.array(psd_array)
# duration
np.conj(signal1) * signal2 /psd

  signal1 = jnp.array(hp, dtype=np.complex128)
  signal2 = jnp.array(hc, dtype=np.complex128)


Array([ 0. +0.j,  0. +0.j,  0. +0.j, ..., nan+nanj, nan+nanj, nan+nanj],      dtype=complex64)

In [34]:
signal1 = hp
signal2 = hc
psd = psd_array
# duration
np.conj(signal1) * signal2 /psd


array([0.+0.00000000e+00j, 0.+0.00000000e+00j, 0.+0.00000000e+00j, ...,
       0.-6.13156915e-10j, 0.-6.09602254e-10j, 0.-6.06050285e-10j])

In [27]:
from numba import njit
from jax import jit

@jit
def noise_weighted_inner_product(
    signal1, signal2, psd, duration,
):
    nwip_arr = np.conj(signal1) * signal2 / psd
    return 4 / duration * np.sum(nwip_arr)

In [72]:
siz2 = 2
hp = np.array([[1,2,3],[4,5,6]])
hc = np.array([[7,8,9],[10,11,12]])
fs = np.array([[13,14,15],[16,17,18]])
iter = np.arange(size)
list_ = []
for i in range(size):
    list_.append([hp[0], hc[0], fs[0], iter[0]])

In [76]:
np.array(list_, dtype=object)[0]

array([array([1, 2, 3]), array([7, 8, 9]), array([13, 14, 15]), 0],
      dtype=object)

In [42]:
input_arguments[0].shape, hp.shape

((3, 2), (2, 3))

In [18]:
test2 = np.concatenate((ans[0], ans[1]), axis=1)

In [21]:
test2[0].shape

(18434,)

In [None]:
# set up psd
psd_dict = psd_dict.copy()
for key, value in psd_dict.items(): # key is the detector name
    if not isinstance(value, bilby.gw.detector.PowerSpectralDensity):
        raise ValueError(f"Expected bilby.gw.detector.PowerSpectralDensity object for {key} in psd_dict")
    # compute the power spectral density array according to the frequency array
    psd_list = []
    fs_ = []
    for i in range(size):
        fs_.append(fs[i][:fsize_arr[i]])  # avoid np.nan paddings
        psd_list.append(psd_dict[key].get_power_spectral_density_array(fs_[i]))
    psd_dict[key] = np.array(psd_list, dtype=object)
    fs = np.array(fs_, dtype=object)

# set up h+,hx for inner_product
hp_, hc_ = [], []
for i in range(size):
    # remove the np.nan padding
    hp_.append(np.array(hp[i][:fsize_arr[i]], dtype=np.complex128))
    hc_.append(np.array(hc[i][:fsize_arr[i]], dtype=np.complex128))
    # find the index of 20Hz or nearby
    # set all elements to zero below this index
    idx = np.abs(fs[i] - self.f_l).argmin()
    hp_[i][0:idx] = 0.0 + 0.0j
    hc_[i][0:idx] = 0.0 + 0.0j
# each row don't have the same length, so keep the array type to object
hp = np.array(hp_, dtype=object)
hc = np.array(hc_, dtype=object)
#del hp_, hc_  # free up memory


# compute the noise weighted inner product
hp_inner_hp_list = []
hc_inner_hc_list = []
for det, psd_list in psd_dict.items():
    hp_inner_hp = []
    hc_inner_hc = []
    for i in range(size):  # loop over parameters
        psd_ = np.array(psd_list[i])
        duration_ = duration[i]
        idx2 = (psd_ != 0.0) & (psd_ != np.inf)  # this is necessary to avoid problem in np.sum 
        # it's cumbersome to use jax jitting noise_weighted_inner_product
        # so, I will use numba jitted one instead
        hp_inner_hp.append(self.vmap_noise_weighted_inner_product(
            hp[i][idx2],
            hp[i][idx2],
            psd_[idx2],
            duration_,
        ))

        hc_inner_hc.append(noise_weighted_inner_product(
            hc[i][idx2],
            hc[i][idx2],
            psd_[idx2],
            duration_,
        ))

    hp_inner_hp_list.append(hp_inner_hp)
    hc_inner_hc_list.append(hc_inner_hc)

hp_inner_hp = np.array(hp_inner_hp_list)
hc_inner_hc = np.array(hc_inner_hc_list)


In [9]:
ans = test.noise_weighted_inner_prod(gw_param_dict, psd_dict, duration=None, verbose=False, duration_min=2, duration_max=64)

setup time:  0.10905790328979492
duration calculation time:  0.05954384803771973
setup time:  1.8899898529052734
waveform generation time:  1.2709009647369385
psd generation time:  10.304682970046997
setup time:  24.447003841400146
inner product time:  29.904842853546143
setup time:  0.0164792537689209


* convert inner product to vmap
* psd generation time and setup time to multiprocessing

In [None]:
def _helper(hp,hc,fsize_arr,fs):
    # set up h+,hx for inner_product
    hp_, hc_ = [], []
    for i in range(size):
        # remove the np.nan padding
        hp_.append(np.array(hp[i][:fsize_arr[i]], dtype=np.complex128))
        hc_.append(np.array(hc[i][:fsize_arr[i]], dtype=np.complex128))
        # find the index of 20Hz or nearby
        # set all elements to zero below this index
        idx = np.abs(fs[i] - self.f_l).argmin()
        hp_[i][0:idx] = 0.0 + 0.0j
        hc_[i][0:idx] = 0.0 + 0.0j
    # each row don't have the same length, so keep the array type to object
    hp = np.array(hp_, dtype=object)
    hc = np.array(hc_, dtype=object)

    return hp,hc

In [7]:
hp_inner_hp, hc_inner_hc, fs, psd_dict, fsize_arr, del_f  = test.noise_weighted_inner_prod(gw_param_dict, psd_dict, duration=None, verbose=False, duration_min=2, duration_max=64)

setup time:  0.10808801651000977
duration calculation time:  0.06410765647888184
setup time:  0.7802941799163818
waveform generation time:  1.5824799537658691
psd generation time:  2.86651611328125
setup time:  9.871968984603882
inner product time:  6.127278089523315
setup time:  0.007072925567626953


In [12]:
df = min(del_f)*0.1
flist = np.arange(0.0, 1024 + df, df)
len(flist)

92161

* rewrite psd generator from bilby
  * generate once for the highest array length and pick points
* write a njit fn for 'set up h+,hx for inner_product': for loop is the problem here
* write a jax.jit+jax.vmap for inner product

In [11]:
%prun hp_inner_hp, hc_inner_hc  = test.noise_weighted_inner_prod(gw_param_dict, psd_dict, duration=None, verbose=False, duration_min=2, duration_max=64)

 

         24091631 function calls (23984225 primitive calls) in 16.979 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   380702    3.535    0.000    3.535    0.000 {built-in method numpy.array}
        1    1.935    1.935   16.419   16.419 ripple_class.py:127(noise_weighted_inner_prod)
   120000    1.426    0.000    1.426    0.000 njit_functions.py:282(noise_weighted_inner_product)
   119966    1.346    0.000    2.171    0.000 dispatch.py:84(apply_primitive)
    38325    0.985    0.000    0.985    0.000 {built-in method numpy.core._multiarray_umath.interp}
       20    0.952    0.048    0.956    0.048 compiler.py:240(backend_compile)
    80000    0.581    0.000    5.025    0.000 lax_numpy.py:11306(_attempt_rewriting_take_via_slice)
        1    0.571    0.571   16.991   16.991 <string>:1(<module>)
  2755882    0.418    0.000    0.503    0.000 config.py:293(value)
369647/331322    0.279    0.000    1.530    0.000 {built-in m

In [17]:
hp_inner_hp

array([[3528.21411469+0.j, 6749.38881416+0.j],
       [3528.21411469+0.j, 6749.38881416+0.j],
       [2045.37526927+0.j, 4030.01160398+0.j]])

In [19]:
hp_inner_hp[2]

array([2045.37526927+0.j, 4030.01160398+0.j])

In [91]:
hp_inner_hp

array([[3528.21411469+0.j, 6749.38881416+0.j]])

In [32]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.plot(fs, np.real(hc), label='hc')
plt.plot(fs, np.real(hp), label='hp')

plt.xlim(20, 100)
plt.ylim(-1e-22, 1e-22)
plt.legend()
plt.show()

NameError: name 'fs' is not defined

<Figure size 1000x400 with 0 Axes>

In [161]:
from gwsnr import antenna_response_array
# get polarization tensor
# np.shape(Fp) = (size1, len(num_det))
detector_tensor = np.array(gwsnr.detector_tensor_list.copy())
Fp, Fc = antenna_response_array(
    gw_param_dict['ra'], gw_param_dict['dec'], gw_param_dict['geocent_time'], gw_param_dict['psi'], detector_tensor
)
snrs_sq = abs((Fp**2) * hp_inner_hp + (Fc**2) * hc_inner_hc)
snr = np.sqrt(snrs_sq)
snr_effective = np.sqrt(np.sum(snrs_sq, axis=0))

In [162]:
snr_effective

array([33.20856331,  0.        ])

In [1]:
import jax
# Test to see if JAX is using the CPU
print("JAX devices:", jax.devices())

JAX devices: [CpuDevice(id=0)]


In [18]:
from gwsnr import GWSNR

snr = GWSNR(snr_type='inner_product_jax')

psds not given. Choosing bilby's default psds

Chosen GWSNR initialization parameters:

npool:  4
snr type:  inner_product_jax
waveform approximant:  IMRPhenomD
sampling frequency:  2048.0
minimum frequency (fmin):  20.0
mtot=mass1+mass2
min(mtot):  2.0
max(mtot) (with the given fmin=20.0): 184.98599853446768
detectors:  ['L1', 'H1', 'V1']
psds:  [PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple1/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt'), PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple1/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt'), PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple1/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/AdV_asd.txt')]


In [24]:
import numpy as np
size = 1000

gw_param_dict = {
    "mass_1": np.random.uniform(10,50,size),
    "mass_2": np.random.uniform(10,50,size),
    "luminosity_distance": 440*np.ones(size),
    "theta_jn": 0.0*np.ones(size),
    "psi": 0.659*np.ones(size),
    "phase": 0.0*np.ones(size),
    "geocent_time": 0.*np.ones(size),
    "ra": 1.375*np.ones(size),
    "dec": -1.2108*np.ones(size),
    "a_1": 0.5*np.ones(size),
    "a_2": -0.5*np.ones(size),
    "tilt_1": 0.0*np.ones(size),
    "tilt_2": 0.0*np.ones(size),
    "phi_12": 0.0*np.ones(size),
    "phi_jl": 0.0*np.ones(size),
}

idx = (gw_param_dict['mass_1'] < gw_param_dict['mass_2'])
gw_param_dict['mass_1'][idx], gw_param_dict['mass_2'][idx] = gw_param_dict['mass_2'][idx], gw_param_dict['mass_1'][idx]

In [25]:
# size=20000, time=17.8s 
# size=100000, with for loop 10000X10, 1m 12s
# size=100000, without for loop, inf
for i in range(100):
    result_jax = snr.snr(gw_param_dict=gw_param_dict)

solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inner product JAX
solving SNR with inn

In [6]:
#result_jax['optimal_snr_net']

In [13]:
from gwsnr import GWSNR

snr = GWSNR(snr_type='inner_product')

psds not given. Choosing bilby's default psds

Chosen GWSNR initialization parameters:

npool:  4
snr type:  inner_product
waveform approximant:  IMRPhenomD
sampling frequency:  2048.0
minimum frequency (fmin):  20.0
mtot=mass1+mass2
min(mtot):  2.0
max(mtot) (with the given fmin=20.0): 184.98599853446768
detectors:  ['L1', 'H1', 'V1']
psds:  [PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple1/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt'), PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple1/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt'), PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple1/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/AdV_asd.txt')]


In [None]:
# size=100000, with for loop, 1m 1.0s
result_bilby = snr.compute_bilby_snr(gw_param_dict=gw_param_dict)

100%|█████████████████████████████████████████████████████| 100000/100000 [01:00<00:00, 1653.18it/s]


In [15]:
result_bilby['optimal_snr_net']

array([ 77.31933101,  90.77171028,  91.74700759,  66.3794602 ,
        67.0895219 , 114.40239431,  78.43606522,  70.857959  ,
        69.07294291,  50.76252381, 100.73894353,  80.06975415,
        59.16920444,  80.29692442,  93.14666081,  67.86330662,
        76.38751393, 117.37656027, 106.72251131,  63.42591677,
        54.69385599,  83.79585122, 101.90275619,  72.37090661,
        51.5468909 , 120.97855645,  90.46829213, 103.87783013,
       101.15667436,  59.26054712,  87.30334376,  59.18665181,
        54.91392879, 106.0821094 ,  66.80762741,  83.31542595,
        99.33707259,  80.03583806, 118.83693245,  82.40572857,
        66.93347258,  98.93170276, 107.62912022, 100.57317852,
        63.04273233,  62.60780267,  95.47219284,  56.40215422,
        97.25588862,  76.33749737,  60.7396575 ,  97.42932947,
        65.5618559 ,  64.25346638, 120.80583188,  55.12563456,
       107.09186324,  60.82863658,  92.84145045,  74.86881584,
        90.73816493,  93.91424515,  83.77514304, 103.86

In [196]:
psd_dict = snr[-1]
psd_list = psd_dict['H1']
idx2 = (psd_list != 0.0) & (psd_list != np.inf)

In [197]:
idx2

Array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]], dtype=bool)

In [171]:
duration = 16.0*np.ones(size)
del_f = 1.0/duration
fs = []
for df  in del_f:
    fs.append(np.arange(test.f_l, test.f_u, df))
# fs = jnp.array(fs)

In [53]:
import numpy as np

# Example: Including upper limit in the array
start = 0
stop = 5
step = 1.2

# Adjust stop to include the upper limit
arr = np.arange(start, stop + step, step)
print(arr)

[0.  1.2 2.4 3.6 4.8 6. ]


In [57]:
int((stop-start)/step)+1

5

In [58]:
len(arr)

6