# GW150914 Parameter Estimation with Jim
This notebook demonstrates parameter estimation for the GW150914 event using the Jim library and the IMRPhenomPv2 waveform model. The workflow includes data fetching, PSD estimation, prior and transform setup, likelihood construction, and sampling.

In [None]:
# Import required libraries
import time
import jax
import jax.numpy as jnp
from jimgw.core.jim import Jim
from jimgw.core.prior import (
    CombinePrior,
    UniformPrior,
    CosinePrior,
    SinePrior,
    PowerLawPrior,
    UniformSpherePrior,
)
from jimgw.core.single_event.detector import get_H1, get_L1
from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD
from jimgw.core.single_event.data import Data
from jimgw.core.single_event.waveform import RippleIMRPhenomPv2
from jimgw.core.transforms import BoundToUnbound
from jimgw.core.single_event.transforms import (
    SkyFrameToDetectorFrameSkyPositionTransform,
    SphereSpinToCartesianSpinTransform,
    MassRatioToSymmetricMassRatioTransform,
    DistanceToSNRWeightedDistanceTransform,
    GeocentricArrivalTimeToDetectorArrivalTimeTransform,
    GeocentricArrivalPhaseToDetectorArrivalPhaseTransform,
)
# Switch JAX to 64-bit (double) precision; without this flag JAX uses 32-bit
# floats by default.
jax.config.update("jax_enable_x64", True)

## 1. Fetch Data and Estimate PSD
We fetch a 4 s segment centered on GW150914 for analysis and a 4096 s segment, also centered on the event, for PSD estimation. The frequency integration bounds of the likelihood are set to 20 Hz and 1024 Hz to avoid anti-aliasing artifacts.

In [None]:
# Wall-clock reference taken before any data is fetched.
total_time_start = time.time()

# GPS time of GW150914. The analysis window spans 2 s either side of it and
# the PSD window spans 2048 s either side.
gps = 1126259462.4
start, end = gps - 2, gps + 2
psd_start, psd_end = gps - 2048, gps + 2048

# Frequency bounds later handed to the likelihood.
fmin = 20.0
fmax = 1024

# Attach the analysis strain and a PSD estimate to each detector.
ifos = [get_H1(), get_L1()]
for detector in ifos:
    strain = Data.from_gwosc(detector.name, start, end)
    detector.set_data(strain)

    psd_strain = Data.from_gwosc(detector.name, psd_start, psd_end)
    # PSD segment length in samples: chosen to equal the number of samples in
    # the analysis segment (duration x sampling frequency).
    nperseg = strain.duration * strain.sampling_frequency
    detector.set_psd(psd_strain.to_psd(nperseg=nperseg))

## 2. Set Up Waveform Model
We use the RippleIMRPhenomPv2 waveform model with a reference frequency of 20 Hz.

In [None]:
# IMRPhenomPv2 waveform model (ripple implementation) with the reference
# frequency f_ref set to 20.
waveform = RippleIMRPhenomPv2(f_ref=20)

## 3. Define Priors
We define priors for mass, spin, inclination, distance, time, phase, polarization, and sky location.

In [None]:
# --- Mass priors: uniform in chirp mass and in mass ratio -------------------
M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"])
q_prior = UniformPrior(q_min, q_max, parameter_names=["q"])

# --- Spin and inclination priors --------------------------------------------
# The spin magnitude is capped at 0.99 so that the prior support matches the
# [0, 0.99] interval used by the BoundToUnbound transforms for s1_mag and
# s2_mag. With a larger cap, prior draws above 0.99 would fall outside the
# domain of those transforms.
s1_prior = UniformSpherePrior(parameter_names=["s1"], max_mag=0.99)
s2_prior = UniformSpherePrior(parameter_names=["s2"], max_mag=0.99)
iota_prior = SinePrior(parameter_names=["iota"])

# --- Extrinsic priors -------------------------------------------------------
# Luminosity distance: power law with index 2.0 on [1.0, 2000.0].
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
# Coalescence time offset (t_c); the likelihood is given gps as trigger_time.
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
dec_prior = CosinePrior(parameter_names=["dec"])

# Joint prior. The order of the components is kept fixed because it
# determines the order of `prior.parameter_names`.
prior = CombinePrior(
    [
        Mc_prior,
        q_prior,
        s1_prior,
        s2_prior,
        iota_prior,
        dL_prior,
        t_c_prior,
        phase_c_prior,
        psi_prior,
        ra_prior,
        dec_prior,
    ]
)

## 4. Define Transforms
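We set up two lists of transforms. The sample transforms reparameterize the prior parameters into the space the sampler explores (detector-frame quantities and unbounded versions of the bounded parameters). The likelihood transforms convert the prior parameters into the ones passed to the likelihood (symmetric mass ratio and Cartesian spin components).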

In [None]:
# Parameters that are mapped from a finite interval onto an unbounded one,
# given as (parameter name, lower bound, upper bound). The order here is the
# order in which the BoundToUnbound transforms are appended.
bounded_parameters = [
    ("M_c", M_c_min, M_c_max),
    ("q", q_min, q_max),
    ("s1_phi", 0.0, 2 * jnp.pi),
    ("s2_phi", 0.0, 2 * jnp.pi),
    ("iota", 0.0, jnp.pi),
    ("s1_theta", 0.0, jnp.pi),
    ("s2_theta", 0.0, jnp.pi),
    ("s1_mag", 0.0, 0.99),
    ("s2_mag", 0.0, 0.99),
    ("phase_det", 0.0, 2 * jnp.pi),
    ("psi", 0.0, jnp.pi),
    ("zenith", 0.0, jnp.pi),
    ("azimuth", 0.0, 2 * jnp.pi),
]

# Sample transforms: the four detector-frame reparameterizations come first,
# followed by one BoundToUnbound per entry of `bounded_parameters`, each
# mapping `<name>` to `<name>_unbounded`.
sample_transforms = [
    DistanceToSNRWeightedDistanceTransform(
        gps_time=gps,
        ifos=ifos,
        dL_min=dL_prior.xmin,
        dL_max=dL_prior.xmax,
    ),
    GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(
        gps_time=gps,
        ifo=ifos[0],
    ),
    GeocentricArrivalTimeToDetectorArrivalTimeTransform(
        tc_min=t_c_prior.xmin,
        tc_max=t_c_prior.xmax,
        gps_time=gps,
        ifo=ifos[0],
    ),
    SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
    *(
        BoundToUnbound(
            name_mapping=([name], [f"{name}_unbounded"]),
            original_lower_bound=lower,
            original_upper_bound=upper,
        )
        for name, lower, upper in bounded_parameters
    ),
]

# Likelihood transforms: mass ratio -> symmetric mass ratio, and spherical ->
# Cartesian spin components for each of the two spins.
likelihood_transforms = [
    MassRatioToSymmetricMassRatioTransform,
    *(SphereSpinToCartesianSpinTransform(spin) for spin in ("s1", "s2")),
]

## 5. Construct Likelihood and Run Sampler
We construct the frequency-domain likelihood and run the Jim sampler.

In [None]:
# Frequency-domain likelihood over both detectors, restricted to [fmin, fmax].
likelihood = BaseTransientLikelihoodFD(
    ifos, waveform=waveform, trigger_time=gps, f_min=fmin, f_max=fmax
)

# Sampler settings, collected in one place and forwarded to Jim as keyword
# arguments.
sampler_settings = dict(
    # Chain layout and loop counts
    n_chains=500,
    n_local_steps=100,
    n_global_steps=1000,
    n_training_loops=20,
    n_production_loops=10,
    n_epochs=20,
    # Local (MALA) proposal
    mala_step_size=2e-3,
    # Rational-quadratic spline flow and its training
    rq_spline_hidden_units=[128, 128],
    rq_spline_n_bins=10,
    rq_spline_n_layers=8,
    learning_rate=1e-3,
    batch_size=10000,
    n_max_examples=30000,
    n_NFproposal_batch_size=100,
    # Thinning and history
    local_thinning=1,
    global_thinning=100,
    history_window=200,
    # Tempering
    n_temperatures=0,
    max_temperature=20.0,
    n_tempered_steps=10,
    verbose=True,
)

jim = Jim(
    likelihood,
    prior,
    sample_transforms=sample_transforms,
    likelihood_transforms=likelihood_transforms,
    **sampler_settings,
)

jim.sample()
print("Done!")

## 6. Analyze Results
We extract the log-probabilities and samples, and plot the posterior distributions using corner plots.

In [None]:
import corner
import numpy as np

# Log-probabilities stored by the sampler under "log_prob_production".
logprob = jim.sampler.resources["log_prob_production"].data
print(jnp.mean(logprob))

# Posterior samples, keyed by parameter name.
chains = jim.get_samples()
parameter_names = jim.prior.parameter_names

# One column per prior parameter (in prior order), keeping every 10th sample
# to lighten the plot.
samples = np.stack([chains[key] for key in parameter_names]).T[::10]

# Label each axis with its parameter name so the panels can be identified.
fig = corner.corner(samples, labels=parameter_names)
fig.savefig("test")