In [None]:
import os
# Restrict JAX to the CPU backend. Set before `import jax` so the setting is
# already in place when JAX initialises its backends.
os.environ["JAX_PLATFORMS"] = "cpu"
import jax
import jax.numpy as jnp

import numpy as np

# Enable double precision in JAX. The tight tolerances used later in this
# notebook (e.g. atol=1e-13) are only meaningful in float64.
jax.config.update("jax_enable_x64", True)

# Base number of random samples used by the comparison cells below.
N_SAMPLES = 1000

This notebook compares every inner function of `SkyFrameToDetectorFrameSkyPositionTransform` in `Jim` with the corresponding function in `bilby`.

The following test compares `theta_phi_to_ra_dec` in `Jim` with the function of the same name in `bilby`.
See [bilby's lines](<https://git.ligo.org/lscsoft/bilby/-/blob/c6bcb81649b7ebf97ae6e1fd689e8712fe028eb0/bilby/core/utils/conversion.py#:~:text=def-,theta_phi_to_ra_dec,-(theta%2C>)

In [None]:
# Test theta_phi_to_ra_dec (used in the inverse transform) against bilby.
from jimgw.single_event.utils import theta_phi_to_ra_dec as jim_theta_phi_to_ra_dec
from bilby.core.utils import theta_phi_to_ra_dec as bilby_theta_phi_to_ra_dec

key = jax.random.PRNGKey(42)

# Running sums of the absolute differences; divided by N_SAMPLES below.
total_diff_ra = 0
total_diff_dec = 0

for _ in range(N_SAMPLES):
    key, subkey = jax.random.split(key)
    subkeys = jax.random.split(subkey, 3)
    theta = jax.random.uniform(subkeys[0], (1,), minval=0, maxval=jnp.pi)
    # phi is an azimuthal angle: sample its full range [0, 2*pi).
    phi = jax.random.uniform(subkeys[1], (1,), minval=0, maxval=2 * jnp.pi)
    gmst = jax.random.uniform(subkeys[2], (1,), minval=0, maxval=2 * jnp.pi)

    jim_ra, jim_dec = jim_theta_phi_to_ra_dec(theta, phi, gmst)
    bilby_ra, bilby_dec = bilby_theta_phi_to_ra_dec(theta, phi, gmst)
    # Wrap bilby's RA into [0, 2*pi) so it is comparable with Jim's.
    bilby_ra = bilby_ra % (2 * jnp.pi)

    # RA is periodic: measure the difference on the circle, so values just
    # above 0 and just below 2*pi are treated as close.
    diff_ra = jnp.abs((jim_ra - bilby_ra + jnp.pi) % (2 * jnp.pi) - jnp.pi)
    diff_dec = jnp.abs(jim_dec - bilby_dec)
    total_diff_ra += diff_ra
    total_diff_dec += diff_dec

    assert jnp.allclose(
        diff_ra, 0.0, atol=1e-5
    ), f"jim_ra: {jim_ra}, bilby_ra: {bilby_ra}"
    assert jnp.allclose(
        jim_dec, bilby_dec, atol=1e-5
    ), f"jim_dec: {jim_dec}, bilby_dec: {bilby_dec}"

mean_ra_diff = total_diff_ra / N_SAMPLES
mean_dec_diff = total_diff_dec / N_SAMPLES
print("Mean difference in RA: ", mean_ra_diff)
print("Mean difference in DEC: ", mean_dec_diff)

The following test compares `angle_rotation` in `Jim` with `zenith_azimuth_to_theta_phi` in `bilby` (implemented in `bilby_cython`). See [bilby's lines](https://git.ligo.org/colm.talbot/bilby-cython/-/blob/main/bilby_cython/geometry.pyx?ref_type=heads#:~:text=zenith_azimuth_to_theta_phi)

In [None]:
# Test Jim's angle_rotation against bilby_cython's zenith_azimuth_to_theta_phi.
from jimgw.single_event.utils import angle_rotation as jim_angle_rotation
from jimgw.single_event.utils import euler_rotation
from bilby_cython.geometry import zenith_azimuth_to_theta_phi as bilby_angle_rotation
from bilby_cython.geometry import rotation_matrix_from_delta

# Seeded generator so the sampled inputs are reproducible across kernel restarts.
rng = np.random.default_rng(42)

# Running sums of the absolute differences; divided by N_SAMPLES below.
total_diff_theta = 0
total_diff_phi = 0

for _ in range(N_SAMPLES):
    zenith = rng.uniform(0, np.pi)
    azimuth = rng.uniform(0, 2 * np.pi)
    delta_x = rng.uniform(0, 1, size=3)

    # Ensure the rotation matrices are the same.
    jim_rot = euler_rotation(delta_x)
    bilby_rot = rotation_matrix_from_delta(delta_x)

    assert jnp.allclose(
        jim_rot, bilby_rot
    ), f"jim_rot: {jim_rot}, bilby_rot: {bilby_rot}"

    jim_theta, jim_phi = jim_angle_rotation(zenith, azimuth, jim_rot)
    bilby_out = bilby_angle_rotation(zenith, azimuth, delta_x)
    bilby_theta, bilby_phi = bilby_out[0], bilby_out[1]

    diff_theta = jnp.abs(jim_theta - bilby_theta)
    # phi is periodic: measure the difference on the circle, so values just
    # above 0 and just below 2*pi are treated as close.
    diff_phi = jnp.abs((jim_phi - bilby_phi + jnp.pi) % (2 * jnp.pi) - jnp.pi)

    total_diff_theta += diff_theta
    total_diff_phi += diff_phi

    assert jnp.allclose(
        jim_theta, bilby_theta
    ), f"jim_theta: {jim_theta}, bilby_theta: {bilby_theta}"
    # Same bound as jnp.allclose's defaults (atol=1e-8, rtol=1e-5), applied to
    # the wrap-aware difference.
    assert diff_phi <= 1e-8 + 1e-5 * jnp.abs(
        bilby_phi
    ), f"jim_phi: {jim_phi}, bilby_phi: {bilby_phi}"

mean_diff_theta = total_diff_theta / N_SAMPLES
mean_diff_phi = total_diff_phi / N_SAMPLES
print("Mean difference in theta: ", mean_diff_theta)
print("Mean difference in phi: ", mean_diff_phi)

The following compares `delta_x`, the difference between two detectors' vertex positions, as computed in `Jim` and in `bilby`.

In [None]:
# Compare the detector separation vector delta_x (first vertex minus second vertex)
# between Jim's detector presets and bilby's InterferometerList, for each detector pair.
from itertools import combinations
from bilby.gw.detector import InterferometerList
from jimgw.single_event.detector import detector_preset

HLV = ["H1", "L1", "V1"]
for pair in combinations(HLV, 2):
    jim_first, jim_second = (detector_preset[name] for name in pair)
    delta_x_j = jim_first.vertex - jim_second.vertex

    bilby_pair = InterferometerList(pair)
    delta_x_b = bilby_pair[0].vertex - bilby_pair[1].vertex

    abs_diff = jnp.abs(delta_x_j - delta_x_b)
    print(f"Difference in delta_x for {pair}: {abs_diff}")

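The following compares the Greenwich mean sidereal time (`gmst`) computed by `Jim` with that computed by `bilby`. Astropy's GMST and GAST are also compared against `bilby`'s GMST. See [bilby's lines](https://git.ligo.org/colm.talbot/bilby-cython/-/blob/main/bilby_cython/time.pyx?ref_type=heads#:~:text=greenwich_mean_sidereal_time)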

In [None]:
# Compare the Greenwich mean sidereal time (GMST) from astropy and Jim, and astropy's
# Greenwich apparent sidereal time (GAST), against bilby's GMST at random GPS times.
from jimgw.gps_times import greenwich_mean_sidereal_time as jim_gmst
from bilby_cython.time import greenwich_mean_sidereal_time
from astropy.time import Time

n_times = 2 * N_SAMPLES
gps_times = jax.random.uniform(
    jax.random.PRNGKey(42), (n_times,), minval=1, maxval=2e9 + 1234.5678
)

# Running sums of wrap-aware absolute differences w.r.t. bilby's GMST.
total_diff_astropy_gmst = 0
total_diff_astropy_gast = 0
total_diff_jim_gmst = 0
for time in gps_times:
    gps_time = Time(time, format="gps")
    gmst_astropy = gps_time.sidereal_time("mean", "greenwich").rad % (2 * np.pi)
    gast_astropy = gps_time.sidereal_time("apparent", "greenwich").rad % (2 * np.pi)
    gmst_jim = jim_gmst(time) % (2 * np.pi)
    gmst_b = greenwich_mean_sidereal_time(time) % (2 * np.pi)
    # Sidereal times are angles: measure differences on the circle so values just
    # above 0 and just below 2*pi are treated as close.
    total_diff_astropy_gmst += jnp.abs(
        (gmst_astropy - gmst_b + np.pi) % (2 * np.pi) - np.pi
    )
    total_diff_astropy_gast += jnp.abs(
        (gast_astropy - gmst_b + np.pi) % (2 * np.pi) - np.pi
    )
    total_diff_jim_gmst += jnp.abs((gmst_jim - gmst_b + np.pi) % (2 * np.pi) - np.pi)

# Average over the number of GPS times actually sampled (2 * N_SAMPLES).
mean_diff = total_diff_astropy_gmst / n_times
print("Mean difference in GMST: ", mean_diff)
mean_diff = total_diff_astropy_gast / n_times
print("Mean difference in GAST: ", mean_diff)
mean_diff = total_diff_jim_gmst / n_times
print("Mean difference in new Jim GMST: ", mean_diff)

The following compares the inverse of `SkyFrameToDetectorFrameSkyPositionTransform` in `Jim` with `zenith_azimuth_to_ra_dec` in `bilby`. See [bilby's lines](https://git.ligo.org/lscsoft/bilby/-/blob/c6bcb81649b7ebf97ae6e1fd689e8712fe028eb0/bilby/gw/utils.py#:~:text=zenith_azimuth_to_ra_dec)

In [None]:
# Test the inverse transform, (zenith, azimuth) -> (ra, dec), against bilby.
from jimgw.single_event.transforms import SkyFrameToDetectorFrameSkyPositionTransform
from jimgw.single_event.detector import H1, L1

from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_zenith_azimuth_to_ra_dec
from bilby.gw.detector import InterferometerList

key = jax.random.PRNGKey(42)

gps_time = 1126259642.413
jim_ifos = [H1, L1]

ifo_names = ["H1", "L1"]
bilby_ifos = InterferometerList(ifo_names)

# The constructor arguments do not change between iterations, so build the
# transform once instead of once per sample.
jim_transform = SkyFrameToDetectorFrameSkyPositionTransform(
    gps_time=gps_time, ifos=jim_ifos
)

# Running sums of the absolute differences; divided by N_SAMPLES below.
total_diff_dec = 0
total_diff_ra = 0

for _ in range(N_SAMPLES):
    key, subkey = jax.random.split(key)
    subkeys = jax.random.split(subkey, 2)
    zenith = jax.random.uniform(subkeys[0], (1,), minval=0, maxval=jnp.pi)
    azimuth = jax.random.uniform(subkeys[1], (1,), minval=0, maxval=2 * jnp.pi)

    jim_outputs, _ = jim_transform.inverse(dict(zenith=zenith, azimuth=azimuth))
    bilby_ra, bilby_dec = bilby_zenith_azimuth_to_ra_dec(
        zenith[0], azimuth[0], gps_time, bilby_ifos
    )
    jim_ra = jim_outputs["ra"]
    jim_dec = jim_outputs["dec"]

    # RA is periodic: measure the difference on the circle, so values just
    # above 0 and just below 2*pi are treated as close.
    diff_ra = jnp.abs((jim_ra - bilby_ra + jnp.pi) % (2 * jnp.pi) - jnp.pi)
    diff_dec = jnp.abs(jim_dec - bilby_dec)
    total_diff_ra += diff_ra
    total_diff_dec += diff_dec

mean_ra_diff = total_diff_ra / N_SAMPLES
mean_dec_diff = total_diff_dec / N_SAMPLES
print("Mean difference in RA: ", mean_ra_diff)
print("Mean difference in DEC: ", mean_dec_diff)

As seen above, the remaining error in `SkyFrameToDetectorFrameSkyPositionTransform` comes from the `gmst` calculation. `Jim` and `bilby` used different algorithms to compute `gmst`, which introduced an error of order `1e-5` in `ra`.

**Update on 2025/05/06:**
With the newly implemented GMST algorithm, the difference is now far below `1e-5`.
In fact, it can now be exactly zero.

# Ensure the new and old implementations of the angle rotation are equivalent

In [None]:
from jimgw.single_event.utils import euler_rotation

# Random sky positions and separation vectors shared by the equivalence check below.
key = jax.random.PRNGKey(123)
key, subkey = jax.random.split(key)
angles = jax.random.uniform(key, (2, N_SAMPLES), minval=0, maxval=jnp.pi)
zenith = angles[0]  # in [0, pi)
azimuth = 2.0 * angles[1]  # rescaled to [0, 2*pi)
delta_x = jax.random.uniform(subkey, (N_SAMPLES, 3), minval=0, maxval=1)


def old_angle_rotation(zenith, azimuth, rotation):
    """Reference (old) angle rotation, kept as written to check the new version.

    Builds the unit vector (sin(zenith) cos(azimuth), sin(zenith) sin(azimuth),
    cos(zenith)), applies the 3x3 matrix ``rotation`` to it with the products
    written out element by element, and converts the result back to angles.

    Parameters
    ----------
    zenith, azimuth : angles in radians (scalars or arrays of matching shape).
    rotation : 3x3 matrix, indexed as ``rotation[row][col]``.

    Returns
    -------
    theta : arccos of the z-component of the rotated vector.
    phi : atan2(y, x) of the rotated vector, wrapped into [0, 2*pi).
    """
    sin_azimuth = jnp.sin(azimuth)
    cos_azimuth = jnp.cos(azimuth)
    sin_zenith = jnp.sin(zenith)
    cos_zenith = jnp.cos(zenith)

    # z-component of rotation @ unit_vector (third row of the matrix).
    theta = jnp.acos(
        rotation[2][0] * sin_zenith * cos_azimuth
        + rotation[2][1] * sin_zenith * sin_azimuth
        + rotation[2][2] * cos_zenith
    )
    # atan2(y, x) of the rotated vector lies in (-pi, pi]; adding 2*pi and
    # taking fmod maps it into [0, 2*pi).
    phi = jnp.fmod(
        jnp.atan2(
            rotation[1][0] * sin_zenith * cos_azimuth
            + rotation[1][1] * sin_zenith * sin_azimuth
            + rotation[1][2] * cos_zenith,
            rotation[0][0] * sin_zenith * cos_azimuth
            + rotation[0][1] * sin_zenith * sin_azimuth
            + rotation[0][2] * cos_zenith,
        )
        + 2 * jnp.pi,
        2 * jnp.pi,
    )
    return theta, phi


def new_angle_rotation(zenith, azimuth, rotation):
    """Rotate the direction (zenith, azimuth) by ``rotation`` and return (theta, phi).

    Vectorised counterpart of ``old_angle_rotation``: the unit vector is built
    once and multiplied by the 3x3 rotation matrix via ``einsum``, so ``zenith``
    and ``azimuth`` may carry extra (batch) dimensions.

    Returns
    -------
    theta : arccos of the z-component of the rotated vector.
    phi : atan2(y, x) of the rotated vector, wrapped into [0, 2*pi).
    """
    sin_zen = jnp.sin(zenith)
    unit_vec = jnp.array(
        [sin_zen * jnp.cos(azimuth), sin_zen * jnp.sin(azimuth), jnp.cos(zenith)]
    )
    x_rot, y_rot, z_rot = jnp.einsum("ij,j...->i...", rotation, unit_vec)

    full_turn = 2 * jnp.pi
    theta = jnp.acos(z_rot)
    # Shift atan2's (-pi, pi] output into [0, 2*pi).
    phi = jnp.fmod(jnp.atan2(y_rot, x_rot) + full_turn, full_turn)
    return theta, phi

In [None]:
# Use much more stringent tolerances than jnp.allclose's defaults to ensure equivalence.
atol = 1e-13  # Default: 1e-8
rtol = 5e-15  # Default: 1e-5

# Per-rotation maxima of the absolute and (unsigned) fractional differences.
max_diff = []
frac_diff = []
# Check every sampled delta_x (N_SAMPLES rotations), each against all sky positions.
for delta_x_i in delta_x:
    rotation_mat = euler_rotation(delta_x_i)

    old_theta_phi = jnp.array(old_angle_rotation(zenith, azimuth, rotation_mat))
    new_theta_phi = jnp.array(new_angle_rotation(zenith, azimuth, rotation_mat))

    abs_diff = jnp.abs(old_theta_phi - new_theta_phi)
    # Same per-element bound that jnp.allclose(old, new, rtol, atol) checks.
    threshold = atol + rtol * jnp.abs(new_theta_phi)

    max_diff.append(jnp.max(abs_diff))
    # Magnitude of the relative difference; entries whose old value is exactly
    # zero are excluded (set to 0) to avoid dividing by zero. Such entries are
    # still covered by the absolute-difference check above.
    rel_diff = jnp.where(
        old_theta_phi != 0, jnp.abs(1 - new_theta_phi / old_theta_phi), 0.0
    )
    frac_diff.append(jnp.max(rel_diff))

    assert jnp.allclose(old_theta_phi, new_theta_phi, rtol=rtol, atol=atol), (
        f"Max. abs. diff. - threshold: {jnp.max(abs_diff - threshold):.3e} > 0!; \n"
        + f"old_theta_phi: {old_theta_phi}, \nnew_theta_phi: {new_theta_phi}"
    )

print("Max absolute difference: ", jnp.max(jnp.array(max_diff)))
print("Max fractional difference: ", jnp.max(jnp.array(frac_diff)))

The new implementation agrees with the old one to machine precision.