# Extrinsic marginalization with precession and higher modes

This tutorial illustrates the `cogwheel` objects that compute the likelihood marginalized over extrinsic parameters.

If you just want to run parameter estimation, you don't need to deal with these objects; see the [`factorized_phm.ipynb`](https://github.com/jroulet/cogwheel/blob/main/tutorials/factorized_phm.ipynb) tutorial instead.

The relevant subpackage is `cogwheel.likelihood.marginalization`.

Reference: https://arxiv.org/abs/2404.02435

In [None]:
# Ensure only one CPU is used:
# (set before numpy/cogwheel are imported, since numerical libraries
# typically read this variable when they initialize their thread pools)
import os
os.environ["OMP_NUM_THREADS"] = "1"

import sys
sys.path.append('..')  # Presumably makes the local cogwheel checkout importable from `tutorials/`

import warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")  # Silence LAL's stdio-redirection warning

import lal
lal.swig_redirect_standard_output_error(False)  # Turn off LAL's redirection of standard output/error


# NOTE(review): `Path`, `sampling`, `likelihood` and `gw_plotting` are not
# used in the cells of this notebook; kept for convenience.
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from cogwheel import posterior
from cogwheel import sampling
from cogwheel import data
from cogwheel import likelihood
from cogwheel import gw_plotting

# IPython magic: interactive matplotlib figures (provided by `ipympl`).
%matplotlib widget

In [None]:
eventname = 'GW190412'  # Event to analyze

In [None]:
# Automatically instantiate a Posterior object.
# Positional arguments: event name, the event's chirp mass from cogwheel's
# metadata table (presumably used as a reference value), waveform
# approximant, and prior class name.
post = posterior.Posterior.from_event(
    eventname,
    data.EVENTS_METADATA['mchirp'][eventname],
    'IMRPhenomXPHM',  # Waveform model with precession and higher modes
    'CartesianIntrinsicIASPrior')  # Prior over intrinsic parameters only

## `MarginalizedExtrinsicLikelihood`

Computes the matched-filter time series `d_h` and covariances `h_h` that are the inputs to the marginalization.

In [None]:
# Since we used an "Intrinsic" prior, it figured out that we want a MarginalizedExtrinsicLikelihood:
like = post.likelihood
like.__class__

In [None]:
# Inputs to the marginalization, evaluated at the likelihood's reference
# parameters `like.par_dic_0` (note: `_get_dh_hh_timeshift` is a private method).
d_h, h_h, timeshift = like._get_dh_hh_timeshift(like.par_dic_0)

# We apply a small time shift to align the waveform to the relative binning reference
times = like._times - timeshift

In [None]:
d_h.shape  # modes, polarizations, times, detectors

In [None]:
h_h.shape  # mode pairs, polarizations, polarizations', detectors

In [None]:
times.shape  # Presumably matches the "times" axis of `d_h`

## `CoherentScoreHM`
Computes the marginalized likelihood given the time series and covariances.

Generates extrinsic-parameter samples.

In [None]:
# The object that performs the extrinsic-parameter marginalization:
coherent_score = post.likelihood.coherent_score
coherent_score.__class__

In [None]:
coherent_score.min_n_effective  # Convergence criterion: required effective number of samples

In [None]:
2**coherent_score.log2n_qmc  # Samples per proposal (attribute stores the base-2 log)

In [None]:
# Compute the marginalized likelihood:
marg_info = coherent_score.get_marginalization_info(d_h, h_h, times)
marg_info.lnl_marginalized  # Log of the marginalized likelihood

In [None]:
# Generate extrinsic parameter samples (10 draws, displayed as a table)
pd.DataFrame(coherent_score.gen_samples_from_marg_info(marg_info, num=10))

## `MarginalizationInfoHM`
Contains data products associated with a single likelihood marginalization.

In [None]:
marg_info.__class__

In [None]:
marg_info.lnl_marginalized  # Final answer: log of the marginalized likelihood

In [None]:
marg_info.n_effective  # Effective number of samples achieved (compare with coherent_score.min_n_effective)

In [None]:
len(marg_info.proposals)  # The importance-sampling integral required this many adaptations to converge

In [None]:
# Proposals live on an upsampled time grid, matching the resolution of the
# sky dictionary (see the `SkyDictionary` section below). Recover that grid
# by resampling a dummy (all-zeros) time series defined on `times`; the
# resampled values themselves are discarded.
_, upsampled_times = coherent_score.sky_dict.resample_timeseries(
    np.zeros_like(times), times)

# Kept as the cell's last expression so the notebook actually displays it:
marg_info.proposals[0].shape  # detectors, upsampled times

In [None]:
# Plot the per-detector time-of-arrival proposal from one adaptation step.
j_proposal = 0  # Index into `marg_info.proposals`
fig, ax = plt.subplots()
# `proposals[j]` is (detectors, upsampled times); transposing makes each
# detector one plotted line. `list(...)` yields one label per line even if
# `detector_names` is a string such as 'HLV' rather than a sequence of names.
ax.plot(upsampled_times, marg_info.proposals[j_proposal].T,
        label=list(coherent_score.sky_dict.detector_names))
ax.legend()
ax.set_ylim(0)
ax.set_xlabel('Time (s)')
ax.set_ylabel(
    rf'Detector time of arrival proposal, $P_d^{{({j_proposal})}}(\tau)$')
# Fit the long y-label inside the figure; returns None, so the cell does
# not echo a stray `Text(...)` repr as its output.
fig.tight_layout()

## `SkyDictionary`
Maps discrete time delays to arrival directions.

Computes the prior of the time delays.

In [None]:
sky_dict = coherent_score.sky_dict
sky_dict.__class__

In [None]:
pd.DataFrame(sky_dict.sky_samples)  # Quasi Monte Carlo sequence of isotropic samples in the sky

In [None]:
sky_dict.f_sampling  # The time axis is discretized at this resolution (Hz) := 1/Delta

In [None]:
sky_dict.detector_names  # Detector order; the delay keys below are relative to the first one

In [None]:
list(sky_dict.delays2inds_map.keys())[:10]  # Keys are (HL delay, HV delay) in units of Delta

In [None]:
sky_dict.delays2inds_map[10, 20]  # Indices of samples with these time delays

In [None]:
# `delays` has shape (2, 1): rows are (HL, HV) delays in units of Delta,
# columns are queries (here a single query, matching the key (10, 20) above).
sky_dict.get_sky_inds_and_prior(delays=np.array(([10], [20])))  # Sky sample index, prior, and whether the delays requested are physical

## `LookupTable`
Marginalizes the likelihood over distance.

In [None]:
coherent_score.lookup_table.__class__

In [None]:
coherent_score.lookup_table.d_luminosity_prior_name  # Name of the luminosity-distance prior

In [None]:
coherent_score.lookup_table.marginalized_params  # Parameters integrated out by the lookup table

In [None]:
coherent_score.lookup_table.d_luminosity_max  # Mpc; presumably the upper bound of the distance prior

In [None]:
# Distance-marginalized log likelihood for the first entries of
# `marg_info.d_h` / `marg_info.h_h` (presumably the first retained
# extrinsic sample — TODO confirm against MarginalizationInfoHM):
coherent_score.lookup_table.lnlike_marginalized(marg_info.d_h[0], marg_info.h_h[0])