In [1]:
import jax.numpy as jnp
import numpy as np

from ripple.waveforms import IMRPhenomD
from ripple import ms_to_Mc_eta

In [2]:
np.__version__

'1.24.4'

In [3]:
# Get a frequency domain waveform
# source parameters

m1_msun = 20.0 # In solar masses
m2_msun = 19.0
chi1 = 0.5 # Dimensionless spin
chi2 = -0.5
tc = 0.0 # Time of coalescence in seconds
phic = 0.0 # Time of coalescence
dist_mpc = 440 # Distance to source in Mpc
inclination = 0.0 # Inclination Angle

# The PhenomD waveform model is parameterized with the chirp mass and symmetric mass ratio
Mc, eta = ms_to_Mc_eta(jnp.array([m1_msun, m2_msun]))

In [4]:
# These are the parametrs that go into the waveform generator
# Note that JAX does not give index errors, so if you pass in the
# the wrong array it will behave strangely
theta_ripple = jnp.array([Mc, eta, chi1, chi2, dist_mpc, tc, phic, inclination])

# Now we need to generate the frequency grid
f_l = 24
f_u = 512
del_f = 0.01
fs = jnp.arange(f_l, f_u, del_f)
f_ref = f_l

# And finally lets generate the waveform!
hp_ripple, hc_ripple = IMRPhenomD.gen_IMRPhenomD_hphc(fs, theta_ripple, f_ref)

In [6]:
# set the GW parameters
mass_1 = jnp.array([5, 10.,50.,200.])
ratio = jnp.array([1, 0.8,0.5,0.2])
mass_2 = mass_1 * ratio
Mc, eta = ms_to_Mc_eta(jnp.array([mass_1, mass_2]))
chi1 = jnp.array([0.1, 0.2, 0.3, 0.4])
chi2 = jnp.array([0.1, 0.2, 0.3, 0.4])
tc = jnp.array([0.0, 0.0, 0.0, 0.0])
phic = jnp.array([0.0, 0.0, 0.0, 0.0])
dist_mpc = np.array([1000, 2000, 3000, 4000])
inclination = jnp.array([0.0, 0.0, 0.0, 0.0])

In [7]:
Mc.shape, eta.shape, chi1.shape, chi2.shape, dist_mpc.shape, tc.shape, phic.shape, inclination.shape

((4,), (4,), (4,), (4,), (4,), (4,), (4,), (4,))

In [8]:
theta_ripple = jnp.array([Mc, eta, chi1, chi2, dist_mpc, tc, phic, inclination]).T

# Now we need to generate the frequency grid
f_l = 20
f_u = 1024
duration = 4
del_f = 1.0/duration
fs = jnp.arange(f_l, f_u, del_f)
f_ref = f_l

# And finally lets generate the waveform!
# hp_ripple, hc_ripple = IMRPhenomD.gen_IMRPhenomD_hphc(fs, theta_ripple, f_ref)

In [9]:
theta_ripple.shape

(4, 8)

In [10]:
jnp.zeros((3, 2))

Array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)

In [29]:
# Note that we have not internally jitted the functions since this would
# introduce an annoying overhead each time the user evaluated the function with a different length frequency array
# We therefore recommend that the user jit the function themselves to accelerate evaluations. For example:

import jax

@jax.jit
def waveform(theta):
    hp,hc = [], []
    for i in range(theta.shape[1]):

        hf = IMRPhenomD.gen_IMRPhenomD_hphc(fs, theta[i, :], f_ref)
        hp.append(hf[0])
        hc.append(hf[1])

    return jnp.array(hp), jnp.array(hc)

In [14]:
IMRPhenomD.gen_IMRPhenomD_hphc?

[0;31mSignature:[0m
[0mIMRPhenomD[0m[0;34m.[0m[0mgen_IMRPhenomD_hphc[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mf[0m[0;34m:[0m [0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mparams[0m[0;34m:[0m [0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mf_ref[0m[0;34m:[0m [0mfloat[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Generate PhenomD frequency domain waveform following 1508.07253.
vars array contains both intrinsic and extrinsic variables
theta = [Mchirp, eta, chi1, chi2, D, tc, phic]
Mchirp: Chirp mass of the system [solar masses]
eta: Symmetric mass ratio [between 0.0 and 0.25]
chi1: Dimensionless aligned spin of the primary object [between -1 and 1]
chi2: Dimensionless aligned spin of the secondary object [between -1 and 1]
D: Luminosity distance to source [Mpc]
tc: Time of coalesence. This only appears as an overall linear in f contribution to the pha

In [12]:
hf_arr = waveform(theta_ripple)

In [4]:
from gwsnr import GWSNR
from gwsnr.njit_functions import (
    # get_interpolated_snr,
    findchirp_chirptime,
    # antenna_response,
    # antenna_response_array,
)

snr_mine = GWSNR(snr_type='inner_product')

psds not given. Choosing bilby's default psds

Chosen GWSNR initialization parameters:

npool:  4
snr type:  inner_product
waveform approximant:  IMRPhenomD
sampling frequency:  2048.0
minimum frequency (fmin):  20.0
mtot=mass1+mass2
min(mtot):  2.0
max(mtot) (with the given fmin=20.0): 184.98599853446768
detectors:  ['L1', 'H1', 'V1']
psds:  [PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple/lib/python3.11/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt'), PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple/lib/python3.11/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt'), PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ripple/lib/python3.11/site-packages/bilby/gw/detector/noise_curves/AdV_asd.txt')]


In [16]:
size_ = 100000
mass_1=10*np.ones(size_)
mass_2=10*np.ones(size_)
luminosity_distance=100.0*np.ones(size_)
theta_jn=0.0*np.ones(size_)
psi=0.0*np.ones(size_)
phase=0.0*np.ones(size_)
geocent_time=1246527224.169434*np.ones(size_)
ra=0.0*np.ones(size_)
dec=0.0*np.ones(size_)
a_1=0.0*np.ones(size_)
a_2=0.0*np.ones(size_)
tilt_1=0.0*np.ones(size_)
tilt_2=0.0*np.ones(size_)
phi_12=0.0*np.ones(size_)
phi_jl=0.0*np.ones(size_)
phic=0.0*np.ones(size_)
gw_param_dict=False
output_jsonfile=False

In [17]:
# if gw_param_dict is given, then use that
if gw_param_dict is not False:
    mass_1 = gw_param_dict["mass_1"]
    mass_2 = gw_param_dict["mass_2"]
    luminosity_distance = gw_param_dict["luminosity_distance"]
    theta_jn = gw_param_dict["theta_jn"]
    psi = gw_param_dict["psi"]
    phase = gw_param_dict["phase"]
    geocent_time = gw_param_dict["geocent_time"]
    ra = gw_param_dict["ra"]
    dec = gw_param_dict["dec"]
    # a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl exist in the dictionary
    # if exists, then use that, else pass
    if "a_1" and "a_2" in gw_param_dict:
        a_1 = gw_param_dict["a_1"]
        a_2 = gw_param_dict["a_2"]
    if "tilt_1" and "tilt_2" and "phi_12" and "phi_jl" in gw_param_dict:
        tilt_1 = gw_param_dict["tilt_1"]
        tilt_2 = gw_param_dict["tilt_2"]
        phi_12 = gw_param_dict["phi_12"]
        phi_jl = gw_param_dict["phi_jl"]

npool = snr_mine.npool
sampling_frequency = snr_mine.sampling_frequency
detectors = snr_mine.detector_list.copy()
detector_tensor = np.array(snr_mine.detector_tensor_list.copy())
approximant = snr_mine.waveform_approximant
f_min = snr_mine.f_min
num_det = np.arange(len(detectors), dtype=int)

# get the psds for the required detectors
psd_dict = {detectors[i]: snr_mine.psds_list[i] for i in num_det}

# reshape(-1) is so that either a float value is given or the input is an numpy array
# make sure all parameters are of same length
mass_1, mass_2 = np.array([mass_1]).reshape(-1), np.array([mass_2]).reshape(-1)
num = len(mass_1)
(
    mass_1,
    mass_2,
    luminosity_distance,
    theta_jn,
    psi,
    phase,
    geocent_time,
    ra,
    dec,
    a_1,
    a_2,
    tilt_1,
    tilt_2,
    phi_12,
    phi_jl,
) = np.broadcast_arrays(
    mass_1,
    mass_2,
    luminosity_distance,
    theta_jn,
    psi,
    phase,
    geocent_time,
    ra,
    dec,
    a_1,
    a_2,
    tilt_1,
    tilt_2,
    phi_12,
    phi_jl,
)

#############################################
# setting up parameters for multiprocessing #
#############################################
mtot = mass_1 + mass_2
idx = (mtot >= snr_mine.mtot_min) & (mtot <= snr_mine.mtot_max)
size1 = np.sum(idx)
iterations = np.arange(size1)  # to keep track of index

dectector_arr = np.array(detectors) * np.ones(
    (size1, len(detectors)), dtype=object
)
psds_dict_list = np.array([np.full(size1, psd_dict, dtype=object)]).T
# IMPORTANT: time duration calculation for each of the mass combination
safety = 1.2
approx_duration = safety * findchirp_chirptime(mass_1[idx], mass_2[idx], f_min)
duration = np.ceil(approx_duration + 2.0)
if snr_mine.duration_max:
    duration[duration > snr_mine.duration_max] = snr_mine.duration_max  # IMRPheonomXPHM has maximum duration of 371s


# for JAX input
Mc, eta = ms_to_Mc_eta(jnp.array([mass_1, mass_2]))

input_arguments = jnp.array(
    [
        Mc[idx],
        eta[idx],
        a_1[idx],
        a_2[idx],
        luminosity_distance[idx],
        geocent_time[idx],
        phase[idx],
        theta_jn[idx],
    ]).T

# len_ = len(duration)
f_l = float(snr_mine.f_min)
f_u = float(snr_mine.sampling_frequency/2.0)
del_f = np.array(1.0/duration[idx])
# fs = jnp.arange(f_l, f_u, del_f)
f_ref = f_l


In [7]:
del_f

array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])

In [8]:
from numba import njit

@njit
def create_fs(f_l, f_u, del_f):
  len_ = del_f.shape[0]
  fs = []
  for i in range(len_):
    del_f_ = del_f[i]
    fs.append(np.arange(f_l, f_u, del_f_))
  # fs = np.array(fs)

  return fs

# import jax

# @jax.jit
# def create_fs(f_l, f_u, del_f):
#   len_ = del_f.shape[0]
#   fs = []
#   for i in range(len_):
#     del_f_ = del_f[i]
#     fs.append(jnp.arange(f_l, f_u, del_f_))
#   # fs = np.array(fs)

#   return fs

# from numba import njit

# # @njit
# def create_fs(f_l, f_u, del_f):
#     len_ = len(del_f)
#     fs = []
#     for i in range(len_):
#         fs.append(np.arange(f_l, f_u, del_f[i]))
#     fs = np.array(fs)

#     return fs

# @njit
# def create_fs(f_lu, del_f):
#     len_ = len(del_f)
#     fs = []
#     for i in range(len_):
#         fs.append(np.arange(f_l[i], f_u[i], del_f[i]))
#     fs = np.array(fs)

#     return fs

In [18]:
fs  = create_fs(f_l, f_u, del_f)

In [10]:
# np.array(fs, dtype=object).shape, 
print(f_ref)

20.0


In [None]:
print(f_ref)

In [19]:
hp, hc = waveform(fs, input_arguments, f_ref)

In [38]:
hp[0]

Array([ 1.3929977e-22+2.1274368e-24j, -7.1213591e-23+1.1877783e-22j,
       -6.6714131e-23-1.2042913e-22j, ...,  2.0062255e-25-4.5555302e-25j,
        1.4770957e-25-4.7478057e-25j,  9.9274362e-27-4.9658153e-25j],      dtype=complex64)

In [39]:
p_array = jnp.array(psds_dict_list[0][0]['L1'].get_amplitude_spectral_density_array(fs[0]))
p_array

Array([4.3745991e-23, 4.5669871e-23, 4.7783558e-23, ..., 4.6202514e-24,
       4.6204916e-24, 4.6207318e-24], dtype=float32)

In [48]:
noise_weighted_inner_prod(hc[0], hc[0], p_array, duration[0])

Array(0.+0.j, dtype=complex64)

* repeat the process with numpy and cross check with jnp

In [42]:
signal1 = hp[0]
signal2 = hp[0]
psd = p_array
nwip_arr = jnp.conj(signal1) * signal2 / psd

In [44]:
jnp.sum(nwip_arr)

Array(0.+0.j, dtype=complex64)

In [46]:
import jax

@jax.jit
def waveform(fs, theta, f_ref):
    hp,hc = [], []
    for i in range(theta.shape[1]):
        hf = IMRPhenomD.gen_IMRPhenomD_hphc(fs[i], theta[i, :], f_ref)
        hp.append(hf[0])
        hc.append(hf[1])

    return jnp.array(hp), jnp.array(hc)

@jax.jit
def noise_weighted_inner_prod(signal1, signal2, psd, duration):
    nwip_arr = jnp.conj(signal1) * signal2 / psd
    return 4 / duration * jnp.sum(nwip_arr)

In [30]:
IMRPhenomD.gen_IMRPhenomD_hphc?

[0;31mSignature:[0m
[0mIMRPhenomD[0m[0;34m.[0m[0mgen_IMRPhenomD_hphc[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mf[0m[0;34m:[0m [0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mparams[0m[0;34m:[0m [0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mf_ref[0m[0;34m:[0m [0mfloat[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Generate PhenomD frequency domain waveform following 1508.07253.
vars array contains both intrinsic and extrinsic variables
theta = [Mchirp, eta, chi1, chi2, D, tc, phic]
Mchirp: Chirp mass of the system [solar masses]
eta: Symmetric mass ratio [between 0.0 and 0.25]
chi1: Dimensionless aligned spin of the primary object [between -1 and 1]
chi2: Dimensionless aligned spin of the secondary object [between -1 and 1]
D: Luminosity distance to source [Mpc]
tc: Time of coalesence. This only appears as an overall linear in f contribution to the pha

In [None]:
input_arguments = np.array(
    [
        mass_1[idx],
        mass_2[idx],
        luminosity_distance[idx],
        theta_jn[idx],
        psi[idx],
        phase[idx],
        ra[idx],
        dec[idx],
        geocent_time[idx],
        a_1[idx],
        a_2[idx],
        tilt_1[idx],
        tilt_2[idx],
        phi_12[idx],
        phi_jl[idx],
        np.full(size1, approximant),
        np.full(size1, f_min),
        duration,
        np.full(size1, sampling_frequency),
        iterations,
    ],
    dtype=object,
).T

input_arguments = np.concatenate(
    (input_arguments, psds_dict_list, dectector_arr), axis=1
)

# np.shape(hp_inner_hp) = (len(num_det), size1)
hp_inner_hp = np.zeros((len(num_det), size1), dtype=np.complex128)
hc_inner_hc = np.zeros((len(num_det), size1), dtype=np.complex128)
with Pool(processes=npool) as pool:
    # call the same function with different data in parallel
    # imap->retain order in the list, while map->doesn't
    if snr_mine.multiprocessing_verbose:
        for result in tqdm(
            pool.imap_unordered(noise_weighted_inner_prod, input_arguments),
            total=len(input_arguments),
            ncols=100,
        ):
            # but, np.shape(hp_inner_hp_i) = (size1, len(num_det))
            hp_inner_hp_i, hc_inner_hc_i, iter_i = result
            hp_inner_hp[:, iter_i] = hp_inner_hp_i
            hc_inner_hc[:, iter_i] = hc_inner_hc_i
    else:
        # with map, without tqdm
        for result in pool.map(noise_weighted_inner_prod, input_arguments):
            hp_inner_hp_i, hc_inner_hc_i, iter_i = result
            hp_inner_hp[:, iter_i] = hp_inner_hp_i
            hc_inner_hc[:, iter_i] = hc_inner_hc_i

# get polarization tensor
# np.shape(Fp) = (size1, len(num_det))
Fp, Fc = antenna_response_array(
    ra[idx], dec[idx], geocent_time[idx], psi[idx], detector_tensor
)
snrs_sq = abs((Fp**2) * hp_inner_hp + (Fc**2) * hc_inner_hc)
snr = np.sqrt(snrs_sq)
snr_effective = np.sqrt(np.sum(snrs_sq, axis=0))

# organizing the snr dictionary
optimal_snr = dict()
for j, det in enumerate(detectors):
    snr_buffer = np.zeros(num)
    snr_buffer[idx] = snr[j]
    optimal_snr[det] = snr_buffer
snr_buffer = np.zeros(num)
snr_buffer[idx] = snr_effective
optimal_snr["optimal_snr_net"] = snr_buffer

# Save as JSON file
if output_jsonfile:
    output_filename = (
        output_jsonfile if isinstance(output_jsonfile, str) else "snr.json"
    )
    save_json(output_filename, optimal_snr)

return optimal_snr