In [1]:
from jax import config, numpy as jnp
# Set JAX to use 64-bit precision
config.update("jax_enable_x64", True)

# Import modules from jimgw
from jimgw.core.single_event.detector import H1, L1, V1
from jimgw.core.single_event.waveform import RippleIMRPhenomPv2
from jimgw.core.single_event.gps_times import greenwich_mean_sidereal_time as compute_gmst
from jimgw.core.prior import (
    CombinePrior,
    UniformPrior,
    CosinePrior,
    SinePrior,
    PowerLawPrior,
    UniformSpherePrior,
    RayleighPrior,
    # SimpleConstrainedPrior,
)
from jimgw.core.transforms import PeriodicTransform, BoundToUnbound
from jimgw.core.single_event.likelihood import (
    TransientLikelihoodFD,
    HeterodynedTransientLikelihoodFD,
)
from jimgw.core.single_event.transforms import (
    SkyFrameToDetectorFrameSkyPositionTransform,
    MassRatioToSymmetricMassRatioTransform,
    DistanceToSNRWeightedDistanceTransform,
    GeocentricArrivalTimeToDetectorArrivalTimeTransform,
    GeocentricArrivalPhaseToDetectorArrivalPhaseTransform,
    SphereSpinToCartesianSpinTransform,
    SpinAnglesToCartesianSpinTransform,
)

##################################################
################# Configuration ##################
##################################################

gps_time = 1187008882.4

f_min = 20.0
f_max = 1024.0
duration = 4.0
sampling_frequency = 2048.0

# initialize waveform
PhenomPv2 = RippleIMRPhenomPv2(f_ref=50.0)

##################################################
################### Injection ####################
##################################################

injection_parameters = {
    "M_c": 30.0,
    "eta": 0.2394761120859283,
    "s1_x": 0.2382837295002566,
    "s1_y": -0.22687038916583377,
    "s1_z": 0.1326624241517266,
    "s2_x": 0.2786812656038679,
    "s2_y": -0.2852388855713494,
    "s2_z": -0.15189958676694681,
    "ra": 3.44616,
    "dec": -0.408084,
    "psi": 2.8466991142089473,
    "d_L": 1300.46265427215251,
    "iota": 2.641220540787931,
    "phase_c": 5.18747203625618,
    "t_c": 1187008882.4298131 - gps_time,
}
injection_parameters.update({"gmst": compute_gmst(gps_time), "trigger_time": gps_time})
injection_parameters = {key: jnp.array(value) for key, value in injection_parameters.items()}

ifos = [H1, L1, V1]
for ifo in ifos:
    ifo.load_and_set_psd()
    ifo.frequency_bounds = (f_min, f_max)
    ifo.inject_signal(
        duration=duration,
        sampling_frequency=sampling_frequency,
        epoch=0.0,
        waveform_model=PhenomPv2,
        parameters=injection_parameters,
    )




SWIGLAL standard output/error redirection is enabled in IPython.
This may lead to performance penalties. To disable locally, use:

with lal.no_swig_redirect_standard_output_error():
    ...

To disable globally, use:

lal.swig_redirect_standard_output_error(False)

Note however that this will likely lead to error messages from
LAL functions being either misdirected or lost when called from
Jupyter notebooks.


import lal

  from lal import LIGOTimeGPS


Grabbing GWTC-2 PSD for H1
For detector H1, the injected signal has:
  - Optimal SNR: 14.6734
  - Match filtered SNR: 14.9123-1.3346j
Grabbing GWTC-2 PSD for L1
For detector L1, the injected signal has:
  - Optimal SNR: 15.3082
  - Match filtered SNR: 15.8173+0.1457j
Grabbing GWTC-2 PSD for V1
For detector V1, the injected signal has:
  - Optimal SNR: 2.2786
  - Match filtered SNR: 1.4483+1.7300j


In [2]:
import jimgw

jimgw.__file__

'/users/hin-wai.leong/src/kaze_jim/src/jimgw/__init__.py'

In [3]:
DistanceToSNRWeightedDistanceTransform?

[31mInit signature:[39m
DistanceToSNRWeightedDistanceTransform(
    gps_time: jaxtyping.Float,
    ifos: Sequence[jimgw.core.single_event.detector.GroundBased2G],
    dL_min: jaxtyping.Float,
    dL_max: jaxtyping.Float,
)
[31mDocstring:[39m     
Transform the luminosity distance to network SNR weighted distance

Parameters
----------
name_mapping : tuple[list[str], list[str]]
        The name mapping between the input and output dictionary.
[31mFile:[39m           ~/src/kaze_jim/src/jimgw/core/single_event/transforms.py
[31mType:[39m           ABCMeta
[31mSubclasses:[39m     

In [None]:
# Mass priors
Mc_lower, Mc_upper = 20.0, 40.0
q_min, q_max = 0.125, 1.0
dL_upper = 5000.0
priors = CombinePrior([
    UniformPrior(Mc_lower, Mc_upper, parameter_names=["M_c"]),
    UniformPrior(q_min, q_max, parameter_names=["q"]),
    UniformSpherePrior(parameter_names=["s1"], max_mag=0.99),
    UniformSpherePrior(parameter_names=["s2"], max_mag=0.99),
    SinePrior(parameter_names=["iota"]),
    PowerLawPrior(1.0, dL_upper, 2.0, parameter_names=["d_L"]),
    UniformPrior(-0.1, 0.1, parameter_names=["t_c"]),
    UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]),
    UniformPrior(0.0, jnp.pi, parameter_names=["psi"]),
    UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]),
    CosinePrior(parameter_names=["dec"]),
    RayleighPrior(1.0, parameter_names=["periodic_1"]),
    RayleighPrior(1.0, parameter_names=["periodic_2"]),
    RayleighPrior(1.0, parameter_names=["periodic_3"]),
    RayleighPrior(1.0, parameter_names=["periodic_4"]),
    RayleighPrior(1.0, parameter_names=["periodic_5"]),
])

# -------------------------------
# Define sample and likelihood transforms
# -------------------------------
sample_transforms = [
    # Transformations for luminosity distance
    DistanceToSNRWeightedDistanceTransform(dL_min=1.0, dL_max=dL_upper, gps_time=gps_time, ifos=ifos),
    # BoundToUnbound(name_mapping=(["d_L"], ["d_L_unbounded"]), original_lower_bound=0.0, original_upper_bound=dL_upper),

    # Transformations for phase
    GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps_time, ifo=ifos[0]),
    PeriodicTransform(name_mapping=(["periodic_4", "phase_det"], ["phase_det_x", "phase_det_y"]), xmin=0.0, xmax=2 * jnp.pi),
    # PeriodicTransform(name_mapping=(["periodic_4", "phase_c"], ["phase_c_x", "phase_c_y"]), xmin=0.0, xmax=2 * jnp.pi),

    # Transformations for time
    GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=-0.1, tc_max=0.1, gps_time=gps_time, ifo=ifos[0]),
    # BoundToUnbound(name_mapping=(["t_c"], ["t_c_unbounded"]), original_lower_bound=-0.1, original_upper_bound=0.1),

    # Transformations for sky position
    SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps_time, ifos=ifos),
    BoundToUnbound(name_mapping=(["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    PeriodicTransform(name_mapping=(["periodic_3", "azimuth"], ["azimuth_x", "azimuth_y"]), xmin=0.0, xmax=2 * jnp.pi),

    # Transformations for polarization angle
    PeriodicTransform(name_mapping=(["periodic_5", "psi"], ["psi_base_x", "psi_base_y"]), xmin=0.0, xmax=jnp.pi),

    # Transformations for masses
    BoundToUnbound(name_mapping=(["M_c"], ["M_c_unbounded"]), original_lower_bound=Mc_lower, original_upper_bound=Mc_upper),
    BoundToUnbound(name_mapping=(["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),

    # Transformations for spins
    BoundToUnbound(name_mapping=(["iota"], ["iota_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    PeriodicTransform(name_mapping=(["periodic_1", "s1_phi"], ["s1_phi_base_x", "s1_phi_base_y"]), xmin=0.0, xmax=2 * jnp.pi),
    PeriodicTransform(name_mapping=(["periodic_2", "s2_phi"], ["s2_phi_base_x", "s2_phi_base_y"]), xmin=0.0, xmax=2 * jnp.pi),
    BoundToUnbound(name_mapping=(["s1_theta"], ["s1_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    BoundToUnbound(name_mapping=(["s2_theta"], ["s2_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    BoundToUnbound(name_mapping=(["s1_mag"], ["s1_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99),
    BoundToUnbound(name_mapping=(["s2_mag"], ["s2_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99),
    # BoundToUnbound(name_mapping=(["theta_jn"], ["theta_jn_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    # PeriodicTransform(name_mapping=(["periodic_1", "phi_jl"], ["phi_jl_x", "phi_jl_y"]), xmin=0.0, xmax=2 * jnp.pi),
    # BoundToUnbound(name_mapping=(["tilt_1"], ["tilt_1_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    # BoundToUnbound(name_mapping=(["tilt_2"], ["tilt_2_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    # PeriodicTransform(name_mapping=(["periodic_2", "phi_12"], ["phi_12_x", "phi_12_y"]), xmin=0.0, xmax=2 * jnp.pi),
    # BoundToUnbound(name_mapping=(["a_1"], ["a_1_unbounded"]), original_lower_bound=0.0, original_upper_bound=1.0),
    # BoundToUnbound(name_mapping=(["a_2"], ["a_2_unbounded"]), original_lower_bound=0.0, original_upper_bound=1.0),
]

# Likelihood transforms
likelihood_transforms = [
    SphereSpinToCartesianSpinTransform("s1"),
    SphereSpinToCartesianSpinTransform("s2"),
    # SpinAnglesToCartesianSpinTransform(freq_ref=ref_freq),
    MassRatioToSymmetricMassRatioTransform,
]

likelihood_original = TransientLikelihoodFD(
    detectors=ifos,
    waveform=PhenomPv2,
    trigger_time=gps_time,
    f_min=f_min,
    f_max=f_max,
)

print("Original maximum likelihood:", likelihood_original.evaluate(injection_parameters))

likelihood_heterodyned = HeterodynedTransientLikelihoodFD(
    ifos,
    waveform=PhenomPv2,
    n_bins=1_000,
    trigger_time=gps_time,
    f_min=f_min,
    f_max=f_max,
    n_steps=100,
    popsize =2,
    prior=priors,
    sample_transforms=sample_transforms,
    likelihood_transforms=likelihood_transforms
)

Original maximum likelihood: 236.82753416900218
Starting Stage 1 optimization
Stage 1 parameters: {'eta', 'dec', 'd_L', 't_c', 'iota', 'M_c', 'ra'}
Using Adam optimization
Stage 1 optimization completed.
Optimised parameters: {'periodic_2': Array(2.95550434, dtype=float64), 'periodic_1': Array(1.2408085, dtype=float64), 'iota': Array(1.67043364, dtype=float64), 'M_c': Array(37.63893566, dtype=float64), 'psi': Array(1.40257024, dtype=float64), 'periodic_5': Array(3.46782923, dtype=float64), 'periodic_3': Array(0.66135783, dtype=float64), 'ra': Array(0.08690178, dtype=float64), 'dec': Array(-1.32952173, dtype=float64), 't_c': Array(0.05441922, dtype=float64), 'periodic_4': Array(1.23660465, dtype=float64), 'phase_c': Array(0.2798947, dtype=float64), 'd_L': Array(2684.42480481, dtype=float64), 's1_x': Array(-0.08142322, dtype=float64), 's1_y': Array(0.07743395, dtype=float64), 's1_z': Array(0.21448354, dtype=float64), 's2_x': Array(-0.15366876, dtype=float64), 's2_y': Array(-0.08708085, d

In [5]:
params_2, params_1 = likelihood_heterodyned.maximize_likelihood(
    prior=priors,
    likelihood_transforms=likelihood_transforms,
    sample_transforms=sample_transforms,
    popsize=10,
    n_steps=500,
    return_stage1=True,
)

for params in [params_1, params_2]:
    print("Optimized parameters:")
    for key, value in params.items():
        inj_value = injection_parameters.get(key, None)
        if inj_value is not None:
            frac_diff = 1 - value / inj_value
            print(f"{key}: {frac_diff:.3f}")

Starting Stage 1 optimization
Stage 1 parameters: {'eta', 'dec', 'd_L', 't_c', 'iota', 'M_c', 'ra'}
Using Adam optimization
Stage 1 optimization completed.
Optimised parameters: {'periodic_2': Array(3.97338101, dtype=float64), 'periodic_1': Array(2.02971166, dtype=float64), 'iota': Array(2.39944902, dtype=float64), 'M_c': Array(33.52709524, dtype=float64), 'psi': Array(2.77034538, dtype=float64), 'periodic_5': Array(0.97630761, dtype=float64), 'periodic_3': Array(0.4156608, dtype=float64), 'ra': Array(4.38663398, dtype=float64), 'dec': Array(-0.7212903, dtype=float64), 't_c': Array(-0.01090023, dtype=float64), 'periodic_4': Array(2.02148848, dtype=float64), 'phase_c': Array(1.16063998, dtype=float64), 'd_L': Array(3420.53642654, dtype=float64), 's1_x': Array(-0.77583622, dtype=float64), 's1_y': Array(0.10838007, dtype=float64), 's1_z': Array(-0.33153537, dtype=float64), 's2_x': Array(-0.5177537, dtype=float64), 's2_y': Array(0.10910674, dtype=float64), 's2_z': Array(0.17817189, dtype=f

In [6]:
injection_parameters

{'M_c': Array(30., dtype=float64, weak_type=True),
 'eta': Array(0.23947611, dtype=float64, weak_type=True),
 's1_x': Array(0.23828373, dtype=float64, weak_type=True),
 's1_y': Array(-0.22687039, dtype=float64, weak_type=True),
 's1_z': Array(0.13266242, dtype=float64, weak_type=True),
 's2_x': Array(0.27868127, dtype=float64, weak_type=True),
 's2_y': Array(-0.28523889, dtype=float64, weak_type=True),
 's2_z': Array(-0.15189959, dtype=float64, weak_type=True),
 'ra': Array(3.44616, dtype=float64, weak_type=True),
 'dec': Array(-0.408084, dtype=float64, weak_type=True),
 'psi': Array(2.84669911, dtype=float64, weak_type=True),
 'd_L': Array(1300.46265427, dtype=float64, weak_type=True),
 'iota': Array(2.64122054, dtype=float64, weak_type=True),
 'phase_c': Array(5.18747204, dtype=float64, weak_type=True),
 't_c': Array(0.02981305, dtype=float64, weak_type=True),
 'gmst': Array(40566.97324959, dtype=float64),
 'trigger_time': Array(1.18700888e+09, dtype=float64, weak_type=True)}

In [7]:
for params in [params_1, params_2]:
    print("Optimized parameters:")
    for key, value in params.items():
        inj_value = injection_parameters.get(key, None)
        if inj_value is not None:
            frac_diff = 1 - value / inj_value
            print(f"{key}: {frac_diff:.3f}")

Optimized parameters:
iota: 0.092
M_c: -0.118
psi: 0.027
ra: -0.273
dec: -0.768
t_c: 1.366
phase_c: 0.776
d_L: -1.630
s1_x: 4.256
s1_y: 1.478
s1_z: 3.499
s2_x: 2.858
s2_y: 1.383
s2_z: 2.173
eta: -0.010
Optimized parameters:
iota: 0.160
M_c: -0.128
psi: 0.036
ra: -0.279
dec: -0.400
t_c: 1.300
phase_c: 0.375
d_L: -1.633
s1_x: -1.984
s1_y: 0.305
s1_z: 0.881
s2_x: 0.773
s2_y: -0.193
s2_z: 0.746
eta: 0.034
