In [1]:
# jax import 
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

import bilby
import numpy as np

C = 299792458.0
G = 6.67408 * 1e-11
Mo = 1.989 * 1e30
Gamma = 0.5772156649015329
Pi = jnp.pi
MTSUN_SI = 4.925491025543576e-06

In [2]:
num = 5
luminosity_distance = 100
test = jnp.array([luminosity_distance]).reshape(-1) * jnp.ones(num)
print(test)

Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

[100. 100. 100. 100. 100.]




In [15]:
####################################################
#                                                  #
#            psd array finder from bilby           #
#                                                  #
####################################################
def power_spectral_density(psd):
    """
    psd array finder from bilby

    Parameters
    ----------
    psd : str
        name of the psd
        e.g. 'aLIGO_O4_high_psd.txt'

    Returns
    -------
    psd_array : bilby.gw.detector.psd.PowerSpectralDensity object
    """
    return bilby.gw.detector.PowerSpectralDensity(psd_file=psd)

####################################################
#                                                  #
#            asd array finder from bilby           #
#                                                  #
####################################################
def amplitude_spectral_density(asd):
    """
    asd array finder from bilby

    Parameters
    ----------
    asd : str
        name of the asd
        e.g. 'aLIGO_O4_high_asd.txt'

    Returns
    -------
    psd_array : bilby.gw.detector.psd.PowerSpectralDensity object
    """
    return bilby.gw.detector.PowerSpectralDensity(asd_file=asd)

####################################################
#                                                  #
#                   Chirp time                     #
#                                                  #
####################################################
@jit
def findchirp_chirptime(m1, m2, fmin):
    """
    Time taken from f_min to f_lso (last stable orbit). 3.5PN in fourier phase considered.
    -----------------
    Input parameters
    -----------------
    m1         : component mass of BBH, m1>m2, unit(Mo)
    m2         : component mass of BBH, m1>m2, unit(Mo)
    fmin       : minimum frequency cut-off for the analysis, unit(s)
    -----------------
    Return values
    -----------------
    chirp_time : Time taken from f_min to f_lso (frequency at last stable orbit), unit(s)
    """
    # variables used to compute chirp time
    m = m1 + m2
    eta = m1 * m2 / m / m
    c0T = c2T = c3T = c4T = c5T = c6T = c6LogT = c7T = 0.0

    c7T = Pi * (
        14809.0 * eta * eta / 378.0 - 75703.0 * eta / 756.0 - 15419335.0 / 127008.0
    )

    c6T = (
        Gamma * 6848.0 / 105.0
        - 10052469856691.0 / 23471078400.0
        + Pi * Pi * 128.0 / 3.0
        + eta * (3147553127.0 / 3048192.0 - Pi * Pi * 451.0 / 12.0)
        - eta * eta * 15211.0 / 1728.0
        + eta * eta * eta * 25565.0 / 1296.0
        + eta * eta * eta * 25565.0 / 1296.0
        + jnp.log(4.0) * 6848.0 / 105.0
    )
    c6LogT = 6848.0 / 105.0

    c5T = 13.0 * Pi * eta / 3.0 - 7729.0 * Pi / 252.0

    c4T = 3058673.0 / 508032.0 + eta * (5429.0 / 504.0 + eta * 617.0 / 72.0)
    c3T = -32.0 * Pi / 5.0
    c2T = 743.0 / 252.0 + eta * 11.0 / 3.0
    c0T = 5.0 * m * MTSUN_SI / (256.0 * eta)

    # This is the PN parameter v evaluated at the lower freq. cutoff
    xT = pow(Pi * m * MTSUN_SI * fmin, 1.0 / 3.0)
    x2T = xT * xT
    x3T = xT * x2T
    x4T = x2T * x2T
    x5T = x2T * x3T
    x6T = x3T * x3T
    x7T = x3T * x4T
    x8T = x4T * x4T

    # Computes the chirp time as tC = t(v_low)
    # tC = t(v_low) - t(v_upper) would be more
    # correct, but the difference is negligble.
    return (
        c0T
        * (
            1
            + c2T * x2T
            + c3T * x3T
            + c4T * x4T
            + c5T * x5T
            + (c6T + c6LogT * jnp.log(xT)) * x6T
            + c7T * x7T
        )
        / x8T
    )

In [16]:
npool = 4
sampling_frequency = 2048
approximant = "IMRPhenomXPHM"
f_min = 20.
duration=4
idx_tracker=0
detectors = ["L1", "H1", "V1"]
ifos = bilby.gw.detector.InterferometerList(detectors)
psd_file = [False, False, False]
psds = dict()
psds["L1"] = "aLIGO_O4_high_asd.txt"
psds["H1"] = "aLIGO_O4_high_asd.txt"
psds["V1"] = "AdV_asd.txt"

psds_arrays = dict()
i = 0  # iterator
for det in detectors:
    # either provided psd or what's available in bilby
    if type(psds[det]) == str and psds[det][-3:] == "txt":
        if psd_file[i]:
            psds_arrays[det] = power_spectral_density(psds[det])
        else:
            psds_arrays[det] = amplitude_spectral_density(psds[det])
    i += 1  # iterator wrt detectors

In [113]:
def buffer(
    mass_1,
    mass_2,
    luminosity_distance,
    theta_jn,
    psi,
    phase,
    geocent_time,
    ra,
    dec,
    a_1,
    a_2,
    tilt_1,
    tilt_2,
    phi_12,
    phi_jl,
):
    results=noise_weighted_inner_product(
        mass_1,
        mass_2,
        luminosity_distance,
        theta_jn,
        psi,
        phase,
        geocent_time,
        ra,
        dec,
        a_1,
        a_2,
        tilt_1,
        tilt_2,
        phi_12,
        phi_jl
    )
    return results

def polas_fn(
        mass_1,
        mass_2,
        luminosity_distance,
        theta_jn,
        psi,
        phase,
        geocent_time,
        ra,
        dec,
        a_1,
        a_2,
        tilt_1,
        tilt_2,
        phi_12,
        phi_jl
    ):

    parameters = {
        "mass_1": mass_1,
        "mass_2": mass_2,
        "luminosity_distance": luminosity_distance,
        "theta_jn": theta_jn,
        "psi": psi,
        "phase": phase,
        "geocent_time": geocent_time,
        "ra": ra,
        "dec": dec,
        "a_1": a_1,
        "a_2": a_2,
        "tilt_1": tilt_1,
        "tilt_2": tilt_2,
        "phi_12": phi_12,
        "phi_jl": phi_jl,
    }
    
    polas = waveform_generator.frequency_domain_strain(parameters=parameters)
    return polas

def noise_weighted_inner_product(
    mass_1,
    mass_2,
    luminosity_distance,
    theta_jn,
    psi,
    phase,
    geocent_time,
    ra,
    dec,
    a_1,
    a_2,
    tilt_1,
    tilt_2,
    phi_12,
    phi_jl,
    # list_of_detectors,
    # psds_arrays,
    # approximant,
    # f_min,
    # sampling_frequency,
    # duration,
    # idx_tracker,
    # ifos,
):

    list_of_detectors = detectors    


    # SNRs_list = []
    # NetSNR = 0.0
    # for i in range(len(list_of_detectors)):
    #     # make an ifo object to get the antenna pattern
    #     Fp = ifos[i].antenna_response(
    #         parameters["ra"],
    #         parameters["dec"],
    #         parameters["geocent_time"],
    #         parameters["psi"],
    #         "plus",
    #     )
    #     Fc = ifos[i].antenna_response(
    #         parameters["ra"],
    #         parameters["dec"],
    #         parameters["geocent_time"],
    #         parameters["psi"],
    #         "cross",
    #     )

    # return mass_1*mass_2


In [114]:
buffer(
    mass_1=30.0,
    mass_2=30.0,
    luminosity_distance=300.0,
    theta_jn=0.4,
    psi=0.1,
    phase=1.2,
    geocent_time=1249852157.0,
    ra=1.375,
    dec=-1.2108,
    a_1=0.4,
    a_2=0.3,
    tilt_1=0.0,
    tilt_2=0.0,
    phi_12=0.0,
    phi_jl=0.0,
    # list_of_detectors=detectors,
    # psds_arrays=psds_arrays,
    # approximant=approximant,
    # f_min=f_min,
    # sampling_frequency=sampling_frequency,
    # duration=4.0,
    # idx_tracker=0,
    # ifos=ifos,
)

30.0


In [115]:
# vmap test
mass_1_list = []
mass_2_list = []
luminosity_distance_list = []
theta_jn_list = []
psi_list = []
phase_list = []
geocent_time_list = []
ra_list = []
dec_list = []
a_1_list = []
a_2_list = []
tilt_1_list = []
tilt_2_list = []
phi_12_list = []
phi_jl_list = []
# detectors_list = []
# psds_arrays_list = []
# approximant_list = []
# f_min_list = []
# sampling_frequency_list = []
# duration_list = []
# idx_tracker_list = []
# ifos_list = []

# for i in range(3):
#     mass_1_list.append([30.0])
#     mass_2_list.append([30.0])
#     luminosity_distance_list.append([300.0])
#     theta_jn_list.append([0.4])
#     psi_list.append([0.1])
#     phase_list.append([1.2])
#     geocent_time_list.append([1249852157.0])
#     ra_list.append([1.375])
#     dec_list.append([-1.2108])
#     a_1_list.append([0.4])
#     a_2_list.append([0.3])
#     tilt_1_list.append([0.0])
#     tilt_2_list.append([0.0])
#     phi_12_list.append([0.0])
#     phi_jl_list.append([0.0])
    # detectors_list.append(detectors)
    # psds_arrays_list.append(psds_arrays)
    # approximant_list.append(approximant)
    # f_min_list.append(f_min)
    # sampling_frequency_list.append(sampling_frequency)
    # duration_list.append(4.0)
    # idx_tracker_list.append(0)
    # ifos_list.append(ifos)

for i in range(3):
    mass_1_list.append(30.0)
    mass_2_list.append(30.0)
    luminosity_distance_list.append(300.0)
    theta_jn_list.append(0.4)
    psi_list.append(0.1)
    phase_list.append(1.2)
    geocent_time_list.append(1249852157.0)
    ra_list.append(1.375)
    dec_list.append(-1.2108)
    a_1_list.append(0.4)
    a_2_list.append(0.3)
    tilt_1_list.append(0.0)
    tilt_2_list.append(0.0)
    phi_12_list.append(0.0)
    phi_jl_list.append(0.0)

In [116]:
mass_1_list = jnp.array(mass_1_list)
mass_2_list = jnp.array(mass_2_list)
luminosity_distance_list = jnp.array(luminosity_distance_list)
theta_jn_list = jnp.array(theta_jn_list)
psi_list = jnp.array(psi_list)
phase_list = jnp.array(phase_list)
geocent_time_list = jnp.array(geocent_time_list)
ra_list = jnp.array(ra_list)
dec_list = jnp.array(dec_list)
a_1_list = jnp.array(a_1_list)
a_2_list = jnp.array(a_2_list)
tilt_1_list = jnp.array(tilt_1_list)
tilt_2_list = jnp.array(tilt_2_list)
phi_12_list = jnp.array(phi_12_list)
phi_jl_list = jnp.array(phi_jl_list)

In [122]:
mass_1_list

Array([30., 30., 30.], dtype=float32)

In [118]:
noise_weighted_inner_product_vmap = vmap(buffer)
test = noise_weighted_inner_product_vmap(
    mass_1_list,
    mass_2_list,
    luminosity_distance_list,
    theta_jn_list,
    psi_list,
    phase_list,
    geocent_time_list,
    ra_list,
    dec_list,
    a_1_list,
    a_2_list,
    tilt_1_list,
    tilt_2_list,
    phi_12_list,
    phi_jl_list,
    )

Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([30., 30., 30.], dtype=float32)
  batch_dim = 0


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<BatchTrace(level=1/0)> with
  val = Array([ True,  True,  True], dtype=bool)
  batch_dim = 0
The problem arose with the `bool` function. 
This BatchTracer with object id 11827945344 was created on line:
  /var/folders/ws/0948zvwd7g795j2l3fryghjw0000gp/T/ipykernel_20301/2723540240.py:73 (polas_fn)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [61]:
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  print(x)
  return jnp.array(output)

#print(convolve(x, w))

# Suppose we would like to apply this function to a batch of
# weights w to a batch of vectors x.
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

# use vmap instead of looping
auto_batch_convolve = vmap(convolve)

print(auto_batch_convolve(xs, ws))

Traced<ShapedArray(int32[5])>with<BatchTrace(level=1/0)> with
  val = Array([[0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4]], dtype=int32)
  batch_dim = 0
[[11. 20. 29.]
 [11. 20. 29.]]


In [110]:
bilby.core.utils.logger.disabled = True
mass_1=30.0
mass_2=30.0
luminosity_distance=300.0
theta_jn=0.4
psi=0.1
phase=1.2
geocent_time=1249852157.0
ra=1.375
dec=-1.2108
a_1=0.4
a_2=0.3
tilt_1=0.0
tilt_2=0.0
phi_12=0.0
phi_jl=0.0
parameters = {
    "mass_1": mass_1,
    "mass_2": mass_2,
    "luminosity_distance": luminosity_distance,
    "theta_jn": theta_jn,
    "psi": psi,
    "phase": phase,
    "geocent_time": geocent_time,
    "ra": ra,
    "dec": dec,
    "a_1": a_1,
    "a_2": a_2,
    "tilt_1": tilt_1,
    "tilt_2": tilt_2,
    "phi_12": phi_12,
    "phi_jl": phi_jl,
}

waveform_arguments = dict(
    waveform_approximant=approximant,
    reference_frequency=20.,
    minimum_frequency=f_min,
)

waveform_generator = bilby.gw.WaveformGenerator(
    duration=duration,
    sampling_frequency=sampling_frequency,
    frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
    waveform_arguments=waveform_arguments,
)
polas = waveform_generator.frequency_domain_strain(parameters=parameters)

In [124]:
def fp_fn(
        ra,
        dec,
        geocent_time,
        psi,):
    
    Fp = ifos[i].antenna_response(
            ra,
            dec,
            geocent_time,
            psi,
            "plus",
        )
    return Fp

In [128]:
vmap_fp_fn = vmap(fp_fn)
print(vmap_fp_fn(ra_list, dec_list, geocent_time_list, psi_list))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([1.375, 1.375, 1.375], dtype=float32)
  batch_dim = 0
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
This BatchTracer with object id 11850120256 was created on line:
  /var/folders/ws/0948zvwd7g795j2l3fryghjw0000gp/T/ipykernel_20301/2574265894.py:2 (<module>)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [9]:
# Vectorize the function
vmap_findchirp_chirptime = vmap(findchirp_chirptime)

# Example usage
size = 1000
# Create a random key
key = random.PRNGKey(0)
# Split the key for generating two independent sets of random numbers
key_m1, key_m2 = random.split(key, 2)
# Generate random values for m1 and m2 in the range [10, 100]
m1_batch = random.uniform(key_m1, (size,), minval=10, maxval=100)
m2_batch = random.uniform(key_m2, (size,), minval=10, maxval=100)
# generate fmin_batch with value 20 and size=size
fmin_batch = jnp.ones(size)*f_min

chirp_times = vmap_findchirp_chirptime(m1_batch, m2_batch, fmin_batch)

NameError: name 'findchirp_chirptime' is not defined

In [10]:
import numpy as np
dectectorList = np.array(detectors) * np.ones(
        (5, len(detectors)), dtype=object
    )

In [11]:
dectectorList

array([['L1', 'H1', 'V1'],
       ['L1', 'H1', 'V1'],
       ['L1', 'H1', 'V1'],
       ['L1', 'H1', 'V1'],
       ['L1', 'H1', 'V1']], dtype=object)

In [40]:
def compute_bilby_snr_jax(
    mass_1,
    mass_2,
    luminosity_distance=100.0,
    theta_jn=0.0,
    psi=0.0,
    phase=0.0,
    geocent_time=1249852157.0,
    ra=0.0,
    dec=0.0,
    a_1=0.0,
    a_2=0.0,
    tilt_1=0.0,
    tilt_2=0.0,
    phi_12=0.0,
    phi_jl=0.0,
    psd_with_time=False,
    verbose=True,
    jsonFile=False,
):

    ##################################  
    # get the psds for the detectors #
    ##################################
    psds_arrays = dict()
    i = 0  # iterator
    for det in detectors:
        # either provided psd or what's available in bilby
        if type(psds[det]) == str and psds[det][-3:] == "txt":
            if psd_file[i]:
                psds_arrays[det] = power_spectral_density(psds[det])
            else:
                psds_arrays[det] = amplitude_spectral_density(psds[det])
        i += 1  # iterator wrt detectors

    # reshape(-1) is so that either a float value is given or the input is an numpy array
    # jnp.ones is multipled to make sure everything is of same length
    mass_1, mass_2 = jnp.array([mass_1]).reshape(-1), jnp.array([mass_2]).reshape(-1)
    num = len(mass_1)
    # reshaping other parameters
    (
        luminosity_distance,
        theta_jn,
        psi,
        phase,
        geocent_time,
        ra,
        dec,
        a_1,
        a_2,
        tilt_1,
        tilt_2,
        phi_12,
        phi_jl,
    ) = (
        jnp.array([luminosity_distance]).reshape(-1) * jnp.ones(num),
        jnp.array([theta_jn]).reshape(-1) * jnp.ones(num),
        jnp.array([psi]).reshape(-1) * jnp.ones(num),
        jnp.array([phase]).reshape(-1) * jnp.ones(num),
        jnp.array([geocent_time]).reshape(-1) * jnp.ones(num),
        jnp.array([ra]).reshape(-1) * jnp.ones(num),
        jnp.array([dec]).reshape(-1) * jnp.ones(num),
        jnp.array([a_1]).reshape(-1) * jnp.ones(num),
        jnp.array([a_2]).reshape(-1) * jnp.ones(num),
        jnp.array([tilt_1]).reshape(-1) * jnp.ones(num),
        jnp.array([tilt_2]).reshape(-1) * jnp.ones(num),
        jnp.array([phi_12]).reshape(-1) * jnp.ones(num),
        jnp.array([phi_jl]).reshape(-1) * jnp.ones(num),
    )

    iter_ = []
    SNRs_list = []
    SNRs_dict = {}
    # time duration calculation for each of the mass combination
    safety = 1.2
    approx_duration = safety * findchirp_chirptime(mass_1, mass_2, f_min)
    duration = jnp.ceil(approx_duration + 2.0)

    size1 = len(mass_1)
    idx_tracker = jnp.arange(size1)  # to keep track of index

    # dectectorList is a list of size (size1, len(detectors)) is created with each row having same value. not an numpy array
    dectectorList = []
    psds_arrays_list = []
    approximant_list = []
    f_min_list = []
    sampling_frequency_list = []
    ifos_list = []
    for i in range(size1):
        dectectorList.append(detectors)
        psds_arrays_list.append(psds_arrays)
        approximant_list.append(approximant)
        f_min_list.append(f_min)
        sampling_frequency_list.append(sampling_frequency)
        ifos_list.append(ifos)

    noise_weighted_inner_product_vmap = vmap(noise_weighted_inner_product)
    # vmap 
    snr = noise_weighted_inner_product_vmap(
        mass_1,
        mass_2,
        luminosity_distance,
        theta_jn,
        psi,
        phase,
        geocent_time,
        ra,
        dec,
        a_1,
        a_2,
        tilt_1,
        tilt_2,
        phi_12,
        phi_jl,
        dectectorList,
        psds_arrays_list,
        approximant_list,
        f_min_list,
        sampling_frequency_list,
        duration,
        idx_tracker,
        ifos,
    )

    return duration

In [41]:
# Example usage
size = 1000
# Create a random key
key = random.PRNGKey(0)
# Split the key for generating two independent sets of random numbers
key_m1, key_m2 = random.split(key, 2)
# Generate random values for m1 and m2 in the range [10, 100]
m1_batch = random.uniform(key_m1, (size,), minval=10, maxval=100)
m2_batch = random.uniform(key_m2, (size,), minval=10, maxval=100)
# generate fmin_batch with value 20 and size=size
fmin_batch = jnp.ones(size)*f_min
test = compute_bilby_snr_jax(mass_1=m1_batch, mass_2=m2_batch)

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

In [18]:
test

Array([3., 3., 5., 3., 3., 3., 3., 3., 2., 3., 3., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 4., 3., 3., 3., 3., 3., 6.,
       3., 3., 3., 3., 4., 3., 3., 3., 3., 4., 4., 3., 3., 4., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 5., 3., 2., 3., 3., 3., 3.,
       4., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 4., 3., 3., 3.,
       3., 3., 3., 6., 4., 4., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 4.,
       3., 2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 6., 3., 6., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 2., 3., 3., 3., 3., 4., 3., 3., 3., 3., 3.,
       3., 3., 3., 4., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 4.,
       3., 3., 3., 6., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 5., 3.,
       3., 3., 3., 3., 3., 3., 3., 5., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
       4., 3., 3., 3., 2., 3., 4., 3., 3., 3., 3., 5., 3., 3., 3., 3., 3.,
       3., 3., 3., 4., 3.