# Parameter estimation with `cogwheel`

In this exercise we will learn to:

* Estimate the chirp mass of a signal from its frequency evolution
* Use the software `cogwheel` to:
    - Generate a synthetic event ("injection") in Gaussian noise
    - Plot whitened data as a spectrogram or time series
    - Find a good-fit waveform
    - Sample the posterior distribution using nested sampling
    - Corner-plot the posterior distribution

For reference, here are `cogwheel`'s [documentation](https://cogwheel.readthedocs.io/en/latest/index.html) and [source code (GitHub)](https://github.com/jroulet/cogwheel). You can find additional tutorials in the GitHub repository.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cogwheel.cosmology
import cogwheel.data
import cogwheel.posterior
import cogwheel.sampling
import cogwheel.gw_plotting
import cogwheel.gw_utils
import cogwheel.gw_prior  # used below to draw the injection parameters

import lal

%matplotlib widget

# Only needed on Apple M4 with numpy<2.3 (https://github.com/numpy/numpy/issues/28687);
# harmless on other systems
import warnings
warnings.simplefilter("ignore", RuntimeWarning)

## Simulate a new detected event

To simulate a detection where we don't know the parameters in advance, we will draw source parameters at random, generate the associated signal, and inject it into Gaussian noise colored according to example LIGO-Virgo PSDs.

This is done in the cell below (it's not super important to follow every step).

In [None]:
def _generate_injection_dic(seed):
    aux_prior = cogwheel.gw_prior.IASPrior(
        f_ref=100.0,
        mchirp_range=(10, 50),
        detector_pair='HL',
        tgps=0,
        ref_det_name='H',
        f_avg=100.0,
        d_hat_max=100.0,
        dt0=10,
    )
    injection_dic = dict(
        aux_prior.generate_random_samples(1, seed=seed).loc[0, aux_prior.standard_params])
    return injection_dic


def get_event_data(seed=0):
    """Return an EventData with a secret injection in it."""
    injection_dic = _generate_injection_dic(seed)

    asd_funcs = list(cogwheel.data.ASDS)

    eventname = f'GW{seed}'
    event_data = cogwheel.data.EventData.gaussian_noise(
        eventname, duration=128.0, detector_names='HLV', fmax=512.0,
        asd_funcs=asd_funcs, tgps=0.0, seed=seed)

    event_data.inject_signal(par_dic=injection_dic, approximant='IMRPhenomXODE')

    return event_data


event_data = get_event_data()

Let's plot a spectrogram of the whitened data:

In [None]:
xlim = (-30, 30)  # Edit this
event_data.specgram(xlim=xlim);

A spectrogram is a time-frequency representation of the data.
The horizontal axis is time, the vertical axis is frequency, and the color encodes power.
You can think of the columns of a spectrogram as short-time PSDs.
The panels of the plot show different detectors.
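
To make the "columns as short-time PSDs" picture concrete, here is a minimal NumPy sketch of a spectrogram: split the data into overlapping segments, apply a Hann window, and take the squared magnitude of each segment's FFT. It runs on a synthetic linear chirp rather than on `event_data`, and it only illustrates the idea; we do not assume `event_data.specgram` is implemented this way (in particular, that method works on whitened data). The helper name `simple_spectrogram` and the segment settings are our own choices.

```python
import numpy as np

def simple_spectrogram(x, fs, nperseg=256, hop=128):
    """Hann-windowed short-time power spectrum.

    Returns (freqs, seg_times, power); each column of `power` is the periodogram
    of one segment, i.e. a short-time PSD estimate up to a constant factor.
    """
    window = np.hanning(nperseg)
    starts = np.arange(0, len(x) - nperseg + 1, hop)
    segments = np.stack([x[s:s + nperseg] * window for s in starts])
    power = (np.abs(np.fft.rfft(segments, axis=1)) ** 2).T  # shape (n_freqs, n_segments)
    freqs = np.fft.rfftfreq(nperseg, d=1 / fs)
    seg_times = (starts + nperseg / 2) / fs  # center of each segment (s)
    return freqs, seg_times, power

# Synthetic linear chirp: instantaneous frequency f0 + k * t, from 50 Hz to 250 Hz in 4 s
fs_demo = 1024
t_demo = np.arange(0, 4, 1 / fs_demo)
f0_demo, k_demo = 50.0, 50.0  # start frequency (Hz) and sweep rate (Hz/s)
chirp_demo = np.sin(2 * np.pi * (f0_demo * t_demo + 0.5 * k_demo * t_demo**2))

freqs_demo, seg_times_demo, power_demo = simple_spectrogram(chirp_demo, fs_demo)
peak_freqs_demo = freqs_demo[np.argmax(power_demo, axis=0)]  # loudest frequency per column
```

The peak frequency of each column tracks the chirp's instantaneous frequency; `plt.pcolormesh(seg_times_demo, freqs_demo, power_demo, shading='auto')` displays the result.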

* Find the event and estimate the merger time (hint: adjust `xlim` to best display the event)

In [None]:
t_merger_guess = ...  # Replace ... with your estimate (in seconds, on the same time axis as the plots)

`cogwheel` uses an algorithm called relative binning to speed up likelihood computations.
This method requires a reference waveform that is a good fit to the signal.
We will obtain this well-fitting waveform by maximizing the likelihood, but this requires an initial guess for the chirp mass.

Can we guess the mass of the signal by inspecting the data?
Recall that, at leading (Newtonian) order, the time to merger depends on the chirp mass and the frequency as
$$
    f^{-8/3} = \frac{(8\pi)^{8/3}}{5} \left(\frac{G \mathcal{M}}{c^3}\right)^{5/3} (t_\mathrm{merger} - t)
$$

* Which stays longer in the detector frequency band: the signal from a heavy binary or from a light one?
* From the spectrogram, estimate the time to merger as a function of frequency and from there guess the (detector frame, i.e. redshifted) chirp mass.
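
The relation above can be turned into a small calculator. The sketch below inverts it for the time to merger, and for the chirp mass given a time to merger read off the spectrogram. It keeps only the leading (Newtonian) order, so it is a rough guide that degrades close to merger. The function names are our own, and the constant `MTSUN_SI` is hard-coded to approximately the value of LAL's `lal.MTSUN_SI` ($G M_\odot / c^3$ in seconds) so that the snippet runs without LAL.

```python
import numpy as np

MTSUN_SI = 4.925491025543576e-06  # G M_sun / c^3 in seconds (approximately lal.MTSUN_SI)

def time_to_merger(mchirp, f):
    """Leading-order time to merger (s) at GW frequency f (Hz), chirp mass mchirp (M_sun)."""
    mchirp_sec = np.asarray(mchirp, dtype=float) * MTSUN_SI  # G Mc / c^3 in seconds
    return 5 / ((8 * np.pi * np.asarray(f, dtype=float)) ** (8 / 3) * mchirp_sec ** (5 / 3))


def mchirp_from_time_to_merger(tau, f):
    """Invert time_to_merger: chirp mass (M_sun) from time to merger tau (s) at frequency f (Hz)."""
    mchirp_sec = (5 / (tau * (8 * np.pi * np.asarray(f, dtype=float)) ** (8 / 3))) ** (3 / 5)
    return mchirp_sec / MTSUN_SI
```

For example, if the track in the spectrogram passes a frequency `f` roughly `tau` seconds before merger, `mchirp_from_time_to_merger(tau, f)` gives a (detector-frame) chirp-mass guess; comparing `time_to_merger` for different chirp masses at a fixed frequency lets you check your answer to the first question.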

We can also plot the whitened data as a time series:

In [None]:
whitened_td = event_data.get_whitened_td()

In [None]:
fig, axs = plt.subplots(len(event_data.detector_names), sharex=True, sharey=True)
for det_name, ax, wht_strain in zip(event_data.detector_names, axs, whitened_td):
    ax.plot(event_data.times - event_data.tcoarse, wht_strain)
    ax.grid()
    ax.set_title(det_name, loc='left')

ax.set_xlim(xlim)
fig.supxlabel('Time (s)')
fig.supylabel('Whitened data')
fig.suptitle(event_data.eventname);

In [None]:
# The same but just one detector
i_det = 0
det_name = event_data.detector_names[i_det]
wht_strain = whitened_td[i_det]
plt.figure()
plt.plot(event_data.times - event_data.tcoarse, wht_strain)
plt.grid()
plt.title(det_name, loc='left')

plt.xlim(xlim)
plt.xlabel('Time (s)')
plt.ylabel('Whitened data')
plt.title(event_data.eventname);

We can get another crude estimate of the chirp mass as follows (inspired by Sec. 2 of https://arxiv.org/pdf/1608.01940):
* Estimate the frequency of the wave as a function of time
* Estimate the slope of $f^{-8/3}(t)$.
* Estimate the chirp mass
> *Note:* in the code we work in units of solar masses and hertz, so in practice
> \begin{align}
      \texttt{mchirp} &= \frac{\mathcal{M}}{M_\odot} \\
      \texttt{freq} &= \frac{f}{\rm Hz}
  \end{align}
> A handy conversion factor is
> $$
     \frac{G M_\odot}{c^3 \, \text{Hz}} \approx 4.93 \cdot 10^{-6},
  $$
> provided in LAL (the LSC Algorithm Library) as `lal.MTSUN_SI`:


In [None]:
lal.MTSUN_SI

### Example solution (try it yourself first; many strategies are possible!)

Read off the ascending zero crossings of the whitened data by eye:

In [None]:
i_det = 0
xlim = 1.45, 1.63

det_name = event_data.detector_names[i_det]
wht_strain = whitened_td[i_det]

plt.figure()
plt.plot(event_data.times - event_data.tcoarse, wht_strain)
plt.grid()
plt.title(det_name, loc='left')

plt.xlim(xlim)
plt.xlabel('Time (s)')
plt.ylabel('Whitened data')
plt.title(event_data.eventname)

ascending_zero_crossings = 1.513, 1.53, 1.545, 1.555, 1.565, 1.573, 1.581, 1.588, 1.594
plt.scatter(ascending_zero_crossings, [0] * len(ascending_zero_crossings), c='r', marker='x')

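Estimate the frequency from each interval between consecutive ascending crossings (one cycle), assigned to the interval's midpoint: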

In [None]:
periods = np.diff(ascending_zero_crossings)
times = np.add(ascending_zero_crossings[:-1], ascending_zero_crossings[1:]) / 2
frequencies = 1 / periods
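
Reading crossings by eye works, but they can also be located programmatically by finding sign changes and interpolating linearly between samples. The sketch below (the helper name `ascending_zero_crossings` is ours) checks the idea on a clean sinusoid. On real whitened data, noise adds spurious crossings, so you would restrict it to a short window around the chirp and sanity-check the result against the plot.

```python
import numpy as np

def ascending_zero_crossings(t, x):
    """Times at which x crosses zero from below, linearly interpolated between samples."""
    t = np.asarray(t, dtype=float)
    x = np.asarray(x, dtype=float)
    i = np.flatnonzero((x[:-1] < 0) & (x[1:] >= 0))  # last sample before each upward crossing
    # Linear interpolation between (t[i], x[i]) and (t[i + 1], x[i + 1])
    return t[i] - x[i] * (t[i + 1] - t[i]) / (x[i + 1] - x[i])

# Sanity check on a clean 40 Hz sinusoid sampled at 1024 Hz
fs_zc_demo = 1024
t_zc_demo = np.arange(0, 0.5, 1 / fs_zc_demo)
x_zc_demo = np.sin(2 * np.pi * 40 * t_zc_demo + 0.3)
crossings_demo = ascending_zero_crossings(t_zc_demo, x_zc_demo)
freqs_zc_demo = 1 / np.diff(crossings_demo)  # one period between consecutive crossings
```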

Linear fit $f^{-8/3}$ vs $t$:

In [None]:
y = frequencies**(-8/3)
x = times

# Perform linear fit
slope, const = np.polyfit(x, y, deg=1)

# Generate fitted line for plotting
x_fit = np.linspace(x.min(), x.max(), 500)
y_fit = slope * x_fit + const

# Plot
plt.figure()
plt.scatter(x, y, label='Data')
plt.plot(x_fit, y_fit, 'r--', label=rf'Fit: $y = {slope:.3g} x + {const:.3g}$')
plt.xlabel("Time (s)")
plt.ylabel("Frequency$^{-8/3}$")
plt.legend()
plt.show()

In [None]:
# At leading order: slope == -(8*np.pi)**(8/3) / 5 * (lal.MTSUN_SI * mchirp)**(5/3)
mchirp_guess = (-slope / (8*np.pi)**(8/3) * 5)**(3/5) / lal.MTSUN_SI
print(mchirp_guess)
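
Before trusting `mchirp_guess`, it can help to check that the recipe recovers a known answer. The sketch below generates a noise-free, leading-order frequency track for hypothetical "true" values of the chirp mass and merger time, applies the same straight-line fit of $f^{-8/3}$ against $t$, and inverts the slope. Because the synthetic data follow the Newtonian relation exactly, the recovery is exact here; real signals deviate from it near merger, so the method only gives a crude estimate. The intercept also yields a merger-time estimate: the time at which the fitted line reaches $f^{-8/3} = 0$. `MTSUN_SI` is hard-coded to approximately `lal.MTSUN_SI`, and `newtonian_frequency` is our own helper.

```python
import numpy as np

MTSUN_SI = 4.925491025543576e-06  # G M_sun / c^3 in seconds (approximately lal.MTSUN_SI)

def newtonian_frequency(t, t_merger, mchirp):
    """Leading-order GW frequency (Hz) at times t < t_merger, for chirp mass mchirp (M_sun)."""
    coeff = (8 * np.pi) ** (8 / 3) / 5 * (mchirp * MTSUN_SI) ** (5 / 3)
    return (coeff * (t_merger - t)) ** (-3 / 8)

# Hypothetical "true" values, used only to test the fitting recipe
mchirp_true, t_merger_true = 20.0, 1.6
t_fit_demo = np.linspace(1.50, 1.58, 9)
f_fit_demo = newtonian_frequency(t_fit_demo, t_merger_true, mchirp_true)

# Same recipe as above: straight-line fit of f^{-8/3} vs t, then invert the slope
slope_demo, const_demo = np.polyfit(t_fit_demo, f_fit_demo ** (-8 / 3), deg=1)
mchirp_recovered = (-slope_demo / (8 * np.pi) ** (8 / 3) * 5) ** (3 / 5) / MTSUN_SI
t_merger_recovered = -const_demo / slope_demo  # where the fitted line reaches f^{-8/3} = 0
```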

## Find a well-fitting waveform
Armed with our guess (in a real analysis this would come from the search pipeline), we can now run a fast likelihood maximization to find our reference waveform:

In [None]:
approximant = 'IMRPhenomXAS'
prior_class = 'IntrinsicAlignedSpinIASPrior'

posterior = cogwheel.posterior.Posterior.from_event(
    event_data,
    mchirp_guess,
    approximant,
    prior_class,
    ref_wf_finder_kwargs={
        'f_ref': 100.0,  # Just so it matches the injection and it makes sense to compare parameters
        'time_range': (t_merger_guess - 0.1, t_merger_guess + 0.1)  # Edit if needed
    }
)

Now we have a (crude) best-fit waveform, which is stored in `posterior.likelihood.par_dic_0`.
Only the most important parameters have been optimized and some approximations have been made, but if everything went well we should have a decent fit to the data.

In [None]:
posterior.likelihood.par_dic_0

* How does this fit compare to your initial guess? Recall $\mathcal{M} = (m_1 m_2)^{3/5} / (m_1 + m_2)^{1/5}$
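
A small helper makes this comparison quick. It is a direct transcription of the chirp-mass formula; the name `chirp_mass` is ours, not a `cogwheel` function.

```python
import numpy as np

def chirp_mass(m1, m2):
    """Chirp mass (m1 m2)^{3/5} / (m1 + m2)^{1/5}, in the same units as m1 and m2."""
    m1, m2 = np.asarray(m1, dtype=float), np.asarray(m2, dtype=float)
    return (m1 * m2) ** (3 / 5) / (m1 + m2) ** (1 / 5)
```

For instance, `chirp_mass(posterior.likelihood.par_dic_0['m1'], posterior.likelihood.par_dic_0['m2'])` gives the chirp mass of the reference waveform, assuming `par_dic_0` stores the detector-frame `m1` and `m2` in solar masses.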

We can inspect whether we have a decent fit:

In [None]:
posterior.likelihood.plot_whitened_wf(posterior.likelihood.par_dic_0, trng=xlim);

## Sample the posterior

### `cogwheel` basics
The `posterior` object contains a `posterior.prior` and a `posterior.likelihood`. The posterior and the prior are distributions over the space of parameters.
We will use two systems of coordinates: **sampled parameters** $\vartheta$ and **standard parameters** $\theta$.
* **Sampled parameters** are intended to remove correlations from the posterior. We express the posterior in these coordinates to ease sampling.
* **Standard parameters** are of astrophysical interest, and are understood by waveform modeling libraries (and fellow astrophysicists). The likelihood class uses standard parameters, so that we don't need to redefine it if we want to switch sampling coordinates.

The transformations between them are provided by the `posterior.prior` object:
* $\texttt{prior.sampled\_params} = \vartheta$ names
* $\texttt{prior.standard\_params} = \theta$ names
* $\texttt{prior.transform} : \vartheta \to \theta$
* $\texttt{prior.inverse\_transform} : \theta \to \vartheta$
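
As a toy illustration of what a sampled-to-standard map looks like, the sketch below converts between (`mchirp`, `lnq`) and (`m1`, `m2`), with $q = m_2/m_1 \le 1$. The chirp mass is usually much better measured than the individual masses, so coordinates like these tend to make the posterior less correlated than $(m_1, m_2)$. The functions `toy_transform` and `toy_inverse_transform` are hypothetical and are not `cogwheel`'s implementation; a real prior class also has to account for the Jacobian of the map when expressing densities in the new coordinates.

```python
import numpy as np

# Hypothetical toy map between sampled (mchirp, lnq) and standard (m1, m2) parameters,
# in the spirit of prior.transform / prior.inverse_transform (NOT cogwheel's implementation).
def toy_transform(mchirp, lnq):
    """Sampled (mchirp, lnq) -> standard (m1, m2), with q = m2 / m1 <= 1 (lnq <= 0)."""
    q = np.exp(lnq)
    m1 = mchirp * (1 + q) ** (1 / 5) / q ** (3 / 5)
    return {'m1': m1, 'm2': q * m1}


def toy_inverse_transform(m1, m2):
    """Standard (m1, m2) -> sampled (mchirp, lnq)."""
    mchirp = (m1 * m2) ** (3 / 5) / (m1 + m2) ** (1 / 5)
    return {'mchirp': mchirp, 'lnq': np.log(m2 / m1)}
```

A round trip such as `toy_inverse_transform(**toy_transform(25.0, np.log(0.5)))` returns the original sampled values.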

The methods `posterior.lnposterior`, `posterior.prior.lnprior` and `posterior.likelihood.lnlike` are related by the following, where $\theta(\vartheta)$ is given by `prior.transform`:
$$
    \mathcal{P}(\vartheta \mid d) \propto \pi(\vartheta)\, \mathcal{L}\big(d \mid \theta(\vartheta)\big).
$$

### Extrinsic-parameter marginalization

Extrinsic parameters have a known functional form in the prior and likelihood, which allows us to marginalize over them semianalytically.
This reduces the dimensionality of the parameter space and speeds up sampling. To use this feature, we simply choose a prior for the intrinsic parameters (`'IntrinsicAlignedSpinIASPrior'`), and `cogwheel` understands that we want to use a marginalized likelihood $\overline{\mathcal{L}}(\theta_\mathrm{int} \mid d)$.
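
To see why a known functional form helps, consider a toy case (an assumption for illustration, not `cogwheel`'s actual scheme): if an overall phase $\phi$ enters the log-likelihood ratio as $\ln\mathcal{L}(\phi) = \mathrm{Re}\left(z\, e^{-i\phi}\right) - \langle h|h\rangle/2$, with complex overlap $z = \langle d|h\rangle$, then averaging the likelihood over a uniform prior in $\phi$ gives $I_0(|z|)\, e^{-\langle h|h\rangle/2}$ in closed form, where $I_0$ is the modified Bessel function of the first kind. The sketch compares this with brute-force numerical averaging; the values of `z` and `h_h` are arbitrary.

```python
import numpy as np

# Toy model (assumed for illustration): ln L(phi) = Re(z exp(-i phi)) - h_h / 2,
# with complex overlap z = <d|h> and h_h = <h|h>. Values are arbitrary.
z = 3.0 + 4.0j
h_h = 20.0

# Brute force: average the likelihood over a uniform prior on phi in [0, 2 pi)
phis = np.linspace(0, 2 * np.pi, 2000, endpoint=False)
lnl_phi = np.real(z * np.exp(-1j * phis)) - h_h / 2
marg_numeric = np.mean(np.exp(lnl_phi))

# Closed form: (1 / 2 pi) * integral of exp(|z| cos(phi - arg z)) dphi = I0(|z|)
marg_analytic = np.i0(np.abs(z)) * np.exp(-h_h / 2)
```

The closed form replaces a one-dimensional integral with a single function evaluation, which is the kind of saving that makes marginalizing extrinsic parameters worthwhile.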

* Think of all the parameters that characterize a GW source. What do you expect the standard parameters to be in this case?
* Check the `posterior.prior` object: what are the sampled parameters and standard parameters? Does this make sense?

In [None]:
sampler = cogwheel.sampling.Nautilus(posterior)
# These trade off quality and speed:
sampler.run_kwargs['n_live'] = 1000
sampler.run_kwargs['n_eff'] = 2000

In [None]:
%%time
rundir = sampler.get_rundir(parentdir='pe_runs')
sampler.run(rundir)  # Will take a bit

Let's load the samples. They are in the form of a dataframe (table) where the columns are parameters of the source, and the rows are samples from the posterior.

> **Note:** Some samplers like [nautilus](https://nautilus-sampler.readthedocs.io/en/latest/) produce *weighted* posterior samples. Notice that the dataframe contains a column `"weights"`. For example, to make a histogram you would do
> ```python
> plt.hist(samples['m1'], weights=samples['weights'])
> ```
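
Weighted samples also call for weighted summary statistics. A minimal sketch follows: `np.average(values, weights=weights)` handles the mean, while the helpers `weighted_quantile` (interpolating a midpoint-rule CDF; other conventions exist) and `effective_sample_size` (Kish's formula $(\sum w)^2 / \sum w^2$) are our own, not `cogwheel` functions.

```python
import numpy as np

def weighted_quantile(values, weights, q):
    """Quantile(s) q in [0, 1] of weighted samples, interpolating a midpoint-rule CDF."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(values)
    values, weights = values[order], weights[order]
    cdf = (np.cumsum(weights) - 0.5 * weights) / np.sum(weights)
    return np.interp(q, cdf, values)


def effective_sample_size(weights):
    """Kish effective sample size, (sum w)^2 / sum w^2."""
    weights = np.asarray(weights, dtype=float)
    return np.sum(weights) ** 2 / np.sum(weights ** 2)
```

For example, `weighted_quantile(samples['m1'], samples['weights'], [0.05, 0.5, 0.95])` gives a median and 90% interval for `m1`.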

In [None]:
samples = pd.read_feather(rundir/'samples.feather')
samples

In [None]:
list(samples)

DataFrames make it easy to add columns. Let's add some derived quantities:

In [None]:
def add_derived_quantities(samples):
    """
    Add columns in place to a dataframe of samples.

    Includes redshift, mtot, m1_source, m2_source, mtot_source, chieff, q.
    """
    samples['redshift'] = cogwheel.cosmology.z_of_d_luminosity(samples['d_luminosity'])

    samples['mtot'] = samples['m1'] + samples['m2']

    for mass_key in 'm1', 'm2', 'mtot':
        samples[f'{mass_key}_source'] = samples[mass_key] / (1 + samples['redshift'])

    samples['chieff'] = cogwheel.gw_utils.chieff(**samples[['m1', 'm2', 's1z', 's2z']])

    samples['q'] = samples['m2'] / samples['m1']


add_derived_quantities(samples)
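
For reference, the effective spin is conventionally defined as the mass-weighted average of the spin components along the orbital angular momentum, $\chi_\mathrm{eff} = (m_1 s_{1z} + m_2 s_{2z}) / (m_1 + m_2)$. The sketch below writes this out (the name `chi_eff_sketch` is ours); we assume `cogwheel.gw_utils.chieff` implements the same standard definition, which you can check by comparing the two on `samples`.

```python
def chi_eff_sketch(m1, m2, s1z, s2z):
    """Effective spin: mass-weighted average of the spin components along the orbital angular momentum.

    Works elementwise on floats, NumPy arrays or pandas Series.
    """
    return (m1 * s1z + m2 * s2z) / (m1 + m2)
```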

In [None]:
plot_params = [
    # 'mchirp',
    # 'lnq',
    # 'chieff',
    # 'cumchidiff',
    # 'weights',
    'm1',
    'm2',
    's1z',
    's2z',
    # 'l1',
    # 'l2',
    # 'f_ref',
    'd_luminosity',
    'ra',
    'dec',
    # 'lon',
    'phi_ref',
    'psi',
    'iota',
    't_geocenter',
    # 'lnl_marginalized',
    'lnl',
    # 'h_h',
    # 'n_effective',
    # 'n_qmc',
]

If you haven't done so before, it is worth spending some time studying a corner plot of the posterior.

In [None]:
# Note: cogwheel.gw_plotting.CornerPlot understands the weights automatically
corner_plot = cogwheel.gw_plotting.CornerPlot(
    samples, params=plot_params, tail_probability=1e-4
)
corner_plot.plot(title=event_data.eventname)

# Let's also reveal the true (injected) value
corner_plot.scatter_points(event_data.injection['par_dic'], colors=['C3'], s=150,
                           zorder=2, marker='+', adjust_lims=True);

A corner plot shows 1D histograms on the diagonal, and 2D histograms of every pair of parameters off the diagonal.

* What parameters are correlated? Can you think of reasons why these correlations arise?
* Look at the log-likelihood ratio of the samples. Can you estimate the SNR of the signal?
* Are any parameters discrepant with the injection? Is this expected?
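
One way to check your SNR estimate: for a log-likelihood ratio of the form $\ln\mathcal{L} = \langle d|h\rangle - \langle h|h\rangle/2$, maximizing over the signal amplitude gives $\ln\mathcal{L}_\mathrm{max} = \rho^2/2$, with $\rho$ the matched-filter SNR of the best-fit template. The sketch below applies this rough relation, assuming the `lnl` column holds that log-likelihood ratio; the helper name is ours.

```python
import numpy as np

def snr_from_max_lnl(lnl):
    """Rough matched-filter SNR of the best fit, using max(lnL) ≈ SNR^2 / 2."""
    return float(np.sqrt(2 * np.max(lnl)))
```

Usage: `snr_from_max_lnl(samples['lnl'])`.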

One subtlety is that the injection was made with the waveform model (approximant) `IMRPhenomXODE`, and the parameter inference with `IMRPhenomXAS` (which parameters should have the same physical meaning in both cases, and which should not?).


## Extra


`IMRPhenomXAS` only describes the quadrupole $(\ell, |m|) = (2, 2)$ radiation and assumes the spins are aligned with the orbital angular momentum.
This simplified some computations and reduced the dimensionality of the parameter space, but we can get more accurate results with a more complete model.

Let's now use `IMRPhenomXODE`, a more refined model that includes precession and higher-order harmonics, to do the inference.

In [None]:
posterior_xode = cogwheel.posterior.Posterior.from_event(
    event_data,
    mchirp_guess,
    approximant='IMRPhenomXODE',
    prior_class='IntrinsicIASPrior',
    ref_wf_finder_kwargs={
        'f_ref': 100.0,  # Just so it matches the injection and it makes sense to compare parameters
        'time_range': (t_merger_guess - 0.1, t_merger_guess + 0.1)  # Edit if needed
    }
)

* Note that we use a different prior class because now we have generic spins. Which are the sampled and standard parameters?

In [None]:
# Check that the fit makes sense
posterior_xode.likelihood.plot_whitened_wf(posterior_xode.likelihood.par_dic_0)
plt.xlim(xlim);

In [None]:
sampler_xode = cogwheel.sampling.Nautilus(posterior_xode)
sampler_xode.run_kwargs['n_live'] = 1000
sampler_xode.run_kwargs['n_eff'] = 2000

In [None]:
rundir_xode = sampler_xode.get_rundir(parentdir='pe_runs')

In [None]:
%%time
sampler_xode.run(rundir_xode)  # Will take a while

In [None]:
samples_xode = pd.read_feather(rundir_xode/'samples.feather')
add_derived_quantities(samples_xode)

In [None]:
corner_plot = cogwheel.gw_plotting.CornerPlot(samples_xode, params=plot_params, tail_probability=1e-4)
corner_plot.plot(title=event_data.eventname)

corner_plot.scatter_points(event_data.injection['par_dic'], colors=['C3'], s=150,
                           zorder=2, marker='+', adjust_lims=True);

Let's compare the two runs:

In [None]:
multi_corner_plot = cogwheel.gw_plotting.MultiCornerPlot(
    {'IMRPhenomXAS': samples,
     'IMRPhenomXODE': samples_xode},
    params=plot_params, tail_probability=1e-4)
multi_corner_plot.plot(title=event_data.eventname)

multi_corner_plot.scatter_points(event_data.injection['par_dic'], colors=['k'], s=150,
                                 zorder=2, marker='+', adjust_lims=True);

The difference between these two runs is that IMRPhenomXODE includes the effects of misaligned spins (precession) and harmonic modes $(\ell, |m|)$ other than the quadrupole $(2, 2)$.
* Which parameters are affected the most, and the least? Does this make sense to you?
* Which model best fits the data?

In [None]:
sampler_xode.load_evidence()

In [None]:
sampler.load_evidence()
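
Evidences are usually compared through their ratio, the Bayes factor. The sketch below converts two natural-log evidences into $\log_{10}$ of the Bayes factor. We do not assume a particular output format for `load_evidence`, so read off the ln-evidence value it reports for each run and pass those in. The comparison is only meaningful if both runs analyze the same data with the same likelihood normalization. The function name is ours.

```python
import math

def log10_bayes_factor(ln_evidence_a, ln_evidence_b):
    """log10 of the Bayes factor Z_a / Z_b, given natural-log evidences ln Z_a and ln Z_b."""
    return (ln_evidence_a - ln_evidence_b) / math.log(10)
```

A positive value favors model `a`; a value of 1 means the data favor `a` over `b` by a factor of 10.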

In [None]:
best_fit = samples_xode.loc[samples_xode['lnl'].idxmax()]  # sample with the highest lnl
fig = sampler_xode.posterior.likelihood.plot_whitened_wf(best_fit, trng=(-.4, .2))
fig.suptitle(event_data.eventname)
plt.xlim(xlim)

## Extra 2

Now you can infer the parameters of a real event!

This is how you would obtain the `EventData` object in `cogwheel`:

```python
eventname = 'GW190412'
filenames, detector_names, tgps = cogwheel.data.download_timeseries(eventname)
event_data = cogwheel.data.EventData.from_timeseries(
    filenames, eventname, detector_names, tgps)
```

GW190412 [https://arxiv.org/pdf/2004.08342] is a real-world example of how including higher modes affects parameter inference. Can you reproduce the conclusion of Fig. 4 of that paper?