In [None]:
path_to_cogwheel = '..'
import sys
sys.path.append(path_to_cogwheel)

import lal
lal.swig_redirect_standard_output_error(False)  # Otherwise LAL may run slowly in notebooks

# Making your own prior

`cogwheel` comes with a set of default priors but also allows you to define new ones. In this tutorial we will make a prior that is flat in detector-frame chirp mass and mass ratio, and otherwise follows the "LVC prior".

For distributions with many parameters, a prior can be defined modularly as a combination of "subpriors". Schematically, for a prior
$$p(x, y) = p(x) p(y|x)$$
we can define $p(x)$ and $p(y|x)$ separately and generate $p(x, y)$ automatically from them. This makes subpriors easy to reuse, e.g. if we want to try different priors for $x$. Since in this example we only want to change the prior for the masses, we will define a subprior with $m_1, m_2$ as standard parameters and then combine it with other predefined priors for the remaining variables.

Here is the full code for the subprior, and a detailed explanation follows:

In [None]:
import numpy as np
from cogwheel import prior


class UniformMchirpQPrior(prior.Prior):
    """
    Prior for the detector-frame masses that is uniform in
    detector-frame chirp mass and mass ratio q < 1, with an
    optional lower bound in the secondary detector-frame mass m2.
    Sampled parameters are (mchirp, q) and standard parameters are
    (m1, m2).
    """
    standard_params = ['m1', 'm2']
    range_dic = {'mchirp': None,
                 'q': None}
    
    def __init__(self, *, mchirp_range, q_min=0.05, m2_min=0, **kwargs):
        """
        Parameters
        ----------
        mchirp_range: (float, float)
            Range of detector-frame chirp mass to explore (Msun).
        
        q_min: float
            Minimum mass ratio to explore.
        
        m2_min: float
            Minimum detector-frame secondary mass allowed (Msun).
        """
        self.range_dic = {'mchirp': mchirp_range,
                          'q': (q_min, 1)}
        self.m2_min = m2_min

        super().__init__(mchirp_range=mchirp_range, q_min=q_min, m2_min=m2_min,
                         **kwargs)
    
    def transform(self, mchirp, q):
        """Sampled params to standard params."""
        m1 = mchirp * (1 + q)**.2 / q**.6
        return {'m1': m1,
                'm2': q * m1}
    
    def inverse_transform(self, m1, m2):
        """Standard params to sampled params."""
        return {'mchirp': (m1*m2)**.6 / (m1+m2)**.2,
                'q': m2 / m1}
    
    def get_init_dict(self):
        """Return dict with keyword arguments to reproduce instance."""
        return {'mchirp_range': self.range_dic['mchirp'],
                'q_min': self.range_dic['q'][0],
                'm2_min': self.m2_min}
    
    def lnprior(self, mchirp, q):
        """Natural log of prior density in the space of sampled params."""
        m2 = self.transform(mchirp, q)['m2']
        if m2 < self.m2_min:
            return - np.inf
        return 0

## Explanation
Priors take the form of classes that inherit from `prior.Prior`:

    class UniformMchirpQPrior(prior.Prior):

A prior class can use any coordinate system for sampling, but it should also provide a transformation to "standard" coordinates (`.standard_params`). This way, likelihood objects can be combined with priors without having to be redefined every time we change the coordinate system.
The convention for which parameters are "standard" in `cogwheel` is defined in `waveform.WaveformGenerator.params`, namely:
`d_luminosity, dec, f_ref, iota, l1, l2, m1, m2, phi_ref, psi, ra, s1x_n, s1y_n, s1z, s2x_n, s2y_n, s2z, t_geocenter` (or a subset of them, in the case of marginalized likelihoods).

In the present example, the relevant standard parameters that our prior should take care of are the detector-frame masses `m1, m2`:

        standard_params = ['m1', 'm2']

However, we will use a more convenient coordinate system to parameterize these variables, namely, chirp mass and mass ratio ("sampled parameters").
To implement this, we define an attribute `range_dic` with the sampled parameter names and their ranges. In this case, the range is not known in advance, since different events may require different ranges in order to sample efficiently around the region of posterior support. Thus, we leave it as `None`:

        range_dic = {'mchirp': None,
                     'q': None}

The `range_dic` will have to be overwritten with the actual ranges when instantiating the class (a new instance of the class is created every time we analyze an event). We'll do this in the `__init__`, see below.

> *Note:* if the range had been known in advance, say $\psi \in (0, \pi)$, we would have done `range_dic = {'psi': (0, np.pi)}` instead, and `range_dic` would have remained a class attribute (shared by all instances).
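
The difference between a class attribute (shared by all instances) and an instance attribute matters below, when we set the ranges in `__init__`. The plain-Python sketch that follows uses two hypothetical toy classes, unrelated to `cogwheel`, to contrast mutating the class-level dictionary with assigning a new one on the instance:

```python
class SharedRanges:
    range_dic = {'mchirp': None}  # class attribute: a single dict shared by all instances

    def __init__(self, mchirp_range):
        # `self.range_dic` resolves to the class attribute, so this mutates the shared dict
        self.range_dic['mchirp'] = mchirp_range


class OwnRanges:
    range_dic = {'mchirp': None}  # class-level default, left untouched

    def __init__(self, mchirp_range):
        # Assigning to `self.range_dic` creates an instance attribute that shadows the class one
        self.range_dic = {'mchirp': mchirp_range}


a = SharedRanges((20, 30))
b = SharedRanges((5, 6))
print(a.range_dic)  # {'mchirp': (5, 6)}: creating `b` silently changed `a`

c = OwnRanges((20, 30))
d = OwnRanges((5, 6))
print(c.range_dic)  # {'mchirp': (20, 30)}
print(OwnRanges.range_dic)  # {'mchirp': None}
```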

Once we have defined the standard and sampled parameters, we should define methods `.transform` and `.inverse_transform` to go between them. To avoid confusion when combining priors, these must return dictionaries where the keys are parameter names:

        def transform(self, mchirp, q):
            """Sampled params to standard params."""
            m1 = mchirp * (1 + q)**.2 / q**.6
            return {'m1': m1,
                    'm2': q * m1}

        def inverse_transform(self, m1, m2):
            """Standard params to sampled params."""
            return {'mchirp': (m1*m2)**.6 / (m1+m2)**.2,
                    'q': m2 / m1}
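
As a quick sanity check of these formulas, the standalone sketch below copies the two methods as plain functions (so it runs without `cogwheel`) and verifies that they invert each other; the masses are arbitrary illustrative values. Because both functions return dictionaries keyed by parameter name, the output of one can be unpacked directly into the other with `**`:

```python
import math


def transform(mchirp, q):
    """Sampled params (mchirp, q) to standard params (m1, m2)."""
    m1 = mchirp * (1 + q)**.2 / q**.6
    return {'m1': m1, 'm2': q * m1}


def inverse_transform(m1, m2):
    """Standard params (m1, m2) to sampled params (mchirp, q)."""
    return {'mchirp': (m1*m2)**.6 / (m1+m2)**.2, 'q': m2 / m1}


standard_par_dic = {'m1': 36.0, 'm2': 29.0}  # arbitrary detector-frame masses (Msun)
sampled_par_dic = inverse_transform(**standard_par_dic)
recovered_par_dic = transform(**sampled_par_dic)

for par, value in standard_par_dic.items():
    assert math.isclose(recovered_par_dic[par], value)
print(sampled_par_dic)
```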

Additionally, we define a method `lnprior` that takes the sampled parameters and returns the natural log of the prior density in that space. In this example, we take the density to be uniform in $\mathcal{M}, q$, provided that $m_2$ exceeds some value $m_2^{\rm min}$:

        def lnprior(self, mchirp, q):
            """Natural log of prior density in the space of sampled params."""
            m2 = self.transform(mchirp, q)['m2']
            if m2 < self.m2_min:
                return - np.inf
            return 0

> *Note:* Even if the prior does not depend on all the sampled parameters, `lnprior` must still have all of them in its signature.
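
One reason for this requirement, assuming (as this sketch does) that the framework calls `lnprior` with the sampled parameters unpacked by keyword from a dictionary: a function whose signature omits a parameter raises `TypeError` when it receives that keyword. The standalone sketch below reproduces our `lnprior` as a plain function, with a hypothetical fixed cutoff `M2_MIN` in place of `self.m2_min`:

```python
import math

M2_MIN = 5.0  # hypothetical cutoff (Msun), stands in for self.m2_min


def transform(mchirp, q):
    m1 = mchirp * (1 + q)**.2 / q**.6
    return {'m1': m1, 'm2': q * m1}


def lnprior(mchirp, q):
    """Flat in (mchirp, q), zero density (log = -inf) where m2 < M2_MIN."""
    m2 = transform(mchirp, q)['m2']
    if m2 < M2_MIN:
        return -math.inf
    return 0.


def lnprior_without_q(mchirp):
    """Same density, but `q` is (wrongly) missing from the signature."""
    return 0.


sampled_par_dic = {'mchirp': 30., 'q': .8}
print(lnprior(**sampled_par_dic))  # 0.0 (m2 is about 31 Msun, above the cutoff)
print(lnprior(mchirp=3., q=.2))    # -inf (m2 is about 1.6 Msun, below the cutoff)

try:
    lnprior_without_q(**sampled_par_dic)
    raised_type_error = False
except TypeError as err:
    raised_type_error = True
    print('TypeError:', err)
```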

Finally, let us implement the sampled parameter ranges and the cutoff `m2_min`. Since in this case these vary event by event, they will need to be passed when instantiating the class, so we must override the `__init__` method. 

        def __init__(self, *, mchirp_range, q_min=0.05, m2_min=0, **kwargs):
            self.range_dic = {'mchirp': mchirp_range,
                              'q': (q_min, 1)}
            self.m2_min = m2_min

            super().__init__(mchirp_range=mchirp_range, q_min=q_min, m2_min=m2_min,
                             **kwargs)

A few comments about overriding `Prior.__init__` are in order:

* It is recommended to make the parameters of `__init__` keyword-only (note the `*` after `self` in the signature), because we will later combine priors and would likely lose track of the overall order of the parameters.

* The `__init__` method of `Prior` subclasses must accept `**kwargs`; this makes it easy to combine them later.

* It is important that we have overridden the `range_dic` attribute. Note that we did
    ```
        self.range_dic = {'mchirp': mchirp_range,
                          'q': (q_min, 1)}
    ```
    and not 
    ```
        self.range_dic['mchirp'] = mchirp_range
        self.range_dic['q'] = (q_min, 1)
    ```
    In the latter case, `self.range_dic` would still refer to the class attribute, so these assignments would modify the dictionary shared by all instances: if we constructed multiple instances at once (e.g. for different events), they would wrongly share the same ranges. Assigning a new dictionary to `self.range_dic` instead creates an instance attribute, as desired.

* `Prior` subclasses must call `super().__init__` forwarding all arguments by keyword. This facilitates multiple inheritance. It is important that `super().__init__` is called after setting `range_dic`.

* Whenever we override `Prior.__init__`, we must also override the method `.get_init_dict`, which returns a dictionary with keyword arguments to reproduce the class instance:

        def get_init_dict(self):
            """Return dict with keyword arguments to reproduce instance."""
            return {'mchirp_range': self.range_dic['mchirp'],
                    'q_min': self.range_dic['q'][0],
                    'm2_min': self.m2_min}
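
To see how these rules fit together, here is a self-contained sketch in plain Python. `ToyPrior` is a hypothetical stand-in for `prior.Prior`, not its actual implementation; for illustration only, it checks that the ranges are already set when its `__init__` runs, which is one way the ordering requirement can matter. The subclass mirrors `UniformMchirpQPrior.__init__` and `get_init_dict`, and the last lines rebuild an equivalent instance from `get_init_dict()`:

```python
class ToyPrior:
    """Hypothetical stand-in for prior.Prior (not cogwheel's code)."""
    range_dic = {}

    def __init__(self, **kwargs):
        # Accepts (and here ignores) any keyword arguments forwarded by subclasses.
        # Illustrative check: ranges must be set before this runs.
        if any(rng is None for rng in self.range_dic.values()):
            raise ValueError('Set `range_dic` before calling super().__init__')

    def get_init_dict(self):
        return {}


class ToyMchirpQPrior(ToyPrior):
    range_dic = {'mchirp': None, 'q': None}

    def __init__(self, *, mchirp_range, q_min=0.05, m2_min=0, **kwargs):
        self.range_dic = {'mchirp': mchirp_range, 'q': (q_min, 1)}  # instance attribute
        self.m2_min = m2_min
        super().__init__(mchirp_range=mchirp_range, q_min=q_min,
                         m2_min=m2_min, **kwargs)

    def get_init_dict(self):
        return {'mchirp_range': self.range_dic['mchirp'],
                'q_min': self.range_dic['q'][0],
                'm2_min': self.m2_min}


original = ToyMchirpQPrior(mchirp_range=(20, 40), m2_min=5)
clone = type(original)(**original.get_init_dict())
print(clone.range_dic, clone.m2_min)  # {'mchirp': (20, 40), 'q': (0.05, 1)} 5

try:
    ToyMchirpQPrior((20, 40))  # positional call is rejected: parameters are keyword-only
    rejected_positional = False
except TypeError:
    rejected_positional = True
```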

This concludes the definition of our subprior.


## Summary

In order to define a new prior, make a subclass of `prior.Prior` and define the following attributes and methods:

* `standard_params`
* `range_dic`
* `transform`
* `inverse_transform`
* `lnprior`

Additionally, if you require event-dependent information in order to specify your prior, override the following methods making sure to follow the guidelines above:

* `__init__`
* `get_init_dict`


## Combining priors

Now that we have defined our custom prior for the masses, we can combine it with priors for other parameters in order to make a full prior for all 18 standard parameters. This is easily done by subclassing `prior.CombinedPrior` (you can take a look at `cogwheel/gw_prior/combined.py` to find examples).

In [None]:
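# Brings RegisteredPriorMixin, RelativeBinningLikelihood and the built-in subpriors used below into scope
from cogwheel.gw_prior.combined import *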


class UniformMchirpQIsotropicSpinPrior(RegisteredPriorMixin,
                                       prior.CombinedPrior):
    """
    Prior for the full parameter space that is flat in detector
    frame chirp mass and mass ratio, and otherwise follows the
    LVC prior.
    """
    default_likelihood_class = RelativeBinningLikelihood

    prior_classes = [
        UniformMchirpQPrior,
        FixedReferenceFrequencyPrior,
        IsotropicSpinsAlignedComponentsPrior,
        UniformPolarizationPrior,
        IsotropicSpinsInplaneComponentsIsotropicInclinationSkyLocationPrior,
        UniformTimePrior,
        UniformPhasePrior,
        UniformLuminosityVolumePrior,
        ZeroTidalDeformabilityPrior]

We just defined two attributes, `prior_classes` and (optionally) `default_likelihood_class`.
Based on the `prior_classes`, the `CombinedPrior` automatically constructs all the attributes and methods necessary for a prior (`standard_params`, `range_dic`, `transform`, `inverse_transform`, `lnprior`). For example:

In [None]:
UniformMchirpQIsotropicSpinPrior.range_dic

In [None]:
UniformMchirpQIsotropicSpinPrior.standard_params

The `default_likelihood_class` attribute should be a subclass of `likelihood.CBCLikelihood` whose `.params` attribute matches `.standard_params`. This is useful since marginalized likelihood classes have different standard parameters (e.g., if the distance is marginalized then it is not a parameter) and therefore they have to be used with a different prior class.
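
To build intuition for what `CombinedPrior` automates, the sketch below combines two independent toy subpriors in plain Python: it concatenates `standard_params`, merges `range_dic`, sums the log-densities and merges the outputs of `transform`. The classes are hypothetical and skip conditioned-on parameters and bounds checks; this is not `cogwheel`'s implementation:

```python
import math


class ToyMassPrior:
    """Toy subprior: sampled (mchirp, q) -> standard (m1, m2), flat density."""
    standard_params = ['m1', 'm2']
    range_dic = {'mchirp': (20., 40.), 'q': (.1, 1.)}

    def transform(self, mchirp, q):
        m1 = mchirp * (1 + q)**.2 / q**.6
        return {'m1': m1, 'm2': q * m1}

    def lnprior(self, mchirp, q):
        return 0.


class ToyDistancePrior:
    """Toy subprior: identity transform, density proportional to d_luminosity**2."""
    standard_params = ['d_luminosity']
    range_dic = {'d_luminosity': (100., 1000.)}

    def transform(self, d_luminosity):
        return {'d_luminosity': d_luminosity}

    def lnprior(self, d_luminosity):
        return 2 * math.log(d_luminosity)  # unnormalized


class ToyCombinedPrior:
    """Combine independent subpriors: ln p(x, y) = ln p(x) + ln p(y)."""
    def __init__(self, subpriors):
        self.subpriors = subpriors
        self.standard_params = [par for sub in subpriors for par in sub.standard_params]
        self.range_dic = {par: rng for sub in subpriors for par, rng in sub.range_dic.items()}

    @staticmethod
    def _select(sub, sampled_par_dic):
        """Pick out the sampled parameters that belong to one subprior."""
        return {par: sampled_par_dic[par] for par in sub.range_dic}

    def lnprior(self, **sampled_par_dic):
        return sum(sub.lnprior(**self._select(sub, sampled_par_dic))
                   for sub in self.subpriors)

    def transform(self, **sampled_par_dic):
        standard_par_dic = {}
        for sub in self.subpriors:
            standard_par_dic.update(sub.transform(**self._select(sub, sampled_par_dic)))
        return standard_par_dic


combined = ToyCombinedPrior([ToyMassPrior(), ToyDistancePrior()])
print(combined.standard_params)  # ['m1', 'm2', 'd_luminosity']
print(list(combined.range_dic))  # ['mchirp', 'q', 'd_luminosity']
```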

Finally, note that we inherited from `RegisteredPriorMixin`. This accomplishes two things that allow this prior to be used with `posterior.Posterior.from_event`:

* It registers the prior in `cogwheel.gw_prior.combined.prior_registry`

    ```
    assert 'UniformMchirpQIsotropicSpinPrior' in prior_registry
    ```

* It adds a method `from_reference_waveform_finder`.

Indeed, we can now do

In [None]:
from cogwheel import data
from cogwheel.posterior import Posterior


eventname = 'GW150914'
post = Posterior.from_event(eventname,
                            data.EVENTS_METADATA['mchirp'][eventname],
                            'IMRPhenomXPHM',
                            'UniformMchirpQIsotropicSpinPrior',
                            prior_kwargs={'m2_min': 5})

and use `post` to sample the posterior distribution.

Note that the following should always hold:

In [None]:
assert set(post.prior.standard_params) == set(post.likelihood.params)

In [None]:
standard_par_dic = post.likelihood.par_dic_0  # For example
sampled_par_dic = post.prior.inverse_transform(**standard_par_dic)
standard_par_dic_2 = post.prior.transform(**sampled_par_dic)

for par in post.prior.standard_params:
    assert np.isclose(standard_par_dic[par],
                      standard_par_dic_2[par])

For quick-and-dirty usage, you can define your custom prior classes in a notebook like we just did.
For a tidier experience, you can write the bits of code in a new module, e.g. `gw_prior/custom.py` and import it.

## Advanced usage

### Conditioned-on parameters

Our `UniformMchirpQPrior` example was of the form $p(\mathcal{M}, q)$, not conditioned on any other parameters. If either the prior or the transformation to standard parameters depends on other parameters, you can handle this by defining a `.conditioned_on` attribute, which must be a list of standard parameters (it is an empty list by default).

For example, take a look at the code for `cogwheel.gw_prior.spin.UniformEffectiveSpinPrior`. It is conditioned on `m1, m2`, since the transform from sampled parameters `chieff, cumchidiff` to standard parameters `s1z, s2z` depends on `m1, m2`.

In that case, the signatures of the `lnprior`, `transform` and `inverse_transform` methods should be extended by appending the conditioned-on parameters at the end. That is, the signatures of `lnprior` and `transform` are always of the form `(*sampled_params, *conditioned_on)`, and that of `inverse_transform` is of the form `(*standard_params, *conditioned_on)`.

The `CombinedPrior` class can handle this case as long as all subpriors in `prior_classes` are only conditioned on standard parameters already covered by the preceding subpriors (as in an autoregressive flow).
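
The sketch below illustrates this autoregressive chaining in plain Python. `ToySpinPrior` is a hypothetical subprior conditioned on `m1, m2`, with made-up sampled parameters `chieff` (the mass-weighted aligned spin) and `dchi` $= s_{1z} - s_{2z}$; it is not the parametrization used by `UniformEffectiveSpinPrior`. The chaining function is likewise a toy, not `cogwheel`'s code: it processes subpriors in order and fails if a subprior needs a standard parameter that no preceding subprior has produced:

```python
import math


class ToyMassPrior:
    standard_params = ['m1', 'm2']
    range_dic = {'mchirp': (20., 40.), 'q': (.1, 1.)}
    conditioned_on = []

    def transform(self, mchirp, q):
        m1 = mchirp * (1 + q)**.2 / q**.6
        return {'m1': m1, 'm2': q * m1}

    def lnprior(self, mchirp, q):
        return 0.


class ToySpinPrior:
    standard_params = ['s1z', 's2z']
    range_dic = {'chieff': (-1., 1.), 'dchi': (-2., 2.)}
    conditioned_on = ['m1', 'm2']

    # Signature follows (*sampled_params, *conditioned_on)
    def transform(self, chieff, dchi, m1, m2):
        mtot = m1 + m2
        return {'s1z': chieff + m2 / mtot * dchi,
                's2z': chieff - m1 / mtot * dchi}

    # Signature follows (*standard_params, *conditioned_on)
    def inverse_transform(self, s1z, s2z, m1, m2):
        return {'chieff': (m1*s1z + m2*s2z) / (m1 + m2),
                'dchi': s1z - s2z}

    def lnprior(self, chieff, dchi, m1, m2):
        spins = self.transform(chieff, dchi, m1, m2)
        if any(abs(s) > 1 for s in spins.values()):
            return -math.inf  # outside the physical range |s| <= 1
        return 0.


def chained_transform_and_lnprior(subpriors, sampled_par_dic):
    """Apply subpriors in order, feeding earlier outputs as conditioned-on inputs."""
    standard_par_dic = {}
    lnprior = 0.
    for sub in subpriors:
        missing = [par for par in sub.conditioned_on if par not in standard_par_dic]
        if missing:
            raise ValueError(f'{type(sub).__name__} needs {missing} from a preceding subprior')
        kwargs = {par: sampled_par_dic[par] for par in sub.range_dic}
        kwargs.update({par: standard_par_dic[par] for par in sub.conditioned_on})
        lnprior += sub.lnprior(**kwargs)
        standard_par_dic.update(sub.transform(**kwargs))
    return standard_par_dic, lnprior


sampled = {'mchirp': 30., 'q': .8, 'chieff': .2, 'dchi': .1}
standard, lnp = chained_transform_and_lnprior([ToyMassPrior(), ToySpinPrior()], sampled)
print(sorted(standard), lnp)  # ['m1', 'm2', 's1z', 's2z'] 0.0

try:  # wrong order: the spin subprior comes before the masses exist
    chained_transform_and_lnprior([ToySpinPrior(), ToyMassPrior()], sampled)
    wrong_order_ok = True
except ValueError as err:
    wrong_order_ok = False
    print(err)
```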

### Common cases

Some common use cases are implemented for convenience:

#### Fixing a parameter
Use `prior.FixedPrior` if you would like to fix a standard parameter rather than sample it.

Usage: subclass `prior.FixedPrior` instead of `prior.Prior`, and define an attribute `standard_par_dic`. This will automatically generate `standard_params`, `range_dic`, `lnprior`, `transform` and `inverse_transform`.

See e.g. the code for `gw_prior.miscellaneous.ZeroTidalDeformabilityPrior` or `gw_prior.miscellaneous.FixedReferenceFrequencyPrior` for examples in which you know or don't know the value of the fixed parameters in advance, respectively.
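
The idea can be sketched in plain Python with a hypothetical `ToyFixedPrior` base class (not `cogwheel`'s `FixedPrior`): there are no sampled parameters, `transform` always returns the fixed values, and the log-density is constant. The first subclass knows its values in advance; the second, in the spirit of `FixedReferenceFrequencyPrior`, receives its value at instantiation (the `f_ref` keyword here is an assumption for illustration):

```python
class ToyFixedPrior:
    """Hypothetical sketch of a fixed-parameter prior (not cogwheel's FixedPrior)."""
    standard_par_dic = {}
    range_dic = {}  # nothing is sampled

    @property
    def standard_params(self):
        return list(self.standard_par_dic)

    def transform(self):
        return dict(self.standard_par_dic)  # copy, so callers cannot mutate the attribute

    def inverse_transform(self, **standard_par_dic):
        return {}  # no sampled parameters to recover

    def lnprior(self):
        return 0.


class ToyZeroTides(ToyFixedPrior):
    standard_par_dic = {'l1': 0, 'l2': 0}  # known in advance: class attribute


class ToyFixedFref(ToyFixedPrior):
    def __init__(self, *, f_ref):
        self.standard_par_dic = {'f_ref': f_ref}  # known per event: instance attribute


print(ToyZeroTides().transform())               # {'l1': 0, 'l2': 0}
print(ToyFixedFref(f_ref=50.).standard_params)  # ['f_ref']
```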

#### Uniform prior
Use `prior.UniformPriorMixin` if your prior has uniform density in the space of sampled parameters.

Usage: subclass `prior.UniformPriorMixin` and `prior.Prior` and define `standard_params`, `range_dic`, `transform` and `inverse_transform`. This will automatically generate `lnprior`.

See e.g. `gw_prior.extrinsic.IsotropicInclinationPrior` for an example.

#### Identity transform

Use `prior.IdentityTransformMixin` if a sampled parameter is also a standard parameter.

Usage: subclass `prior.IdentityTransformMixin` and `prior.Prior` and define `range_dic` and `lnprior`. This will automatically generate `standard_params`, `transform` and `inverse_transform`.

`IdentityTransformMixin` can be used together with `UniformPriorMixin`, see `gw_prior.extrinsic.UniformPolarizationPrior` for an example.
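
The sketch below shows, with hypothetical toy mixins rather than `cogwheel`'s own, how each shortcut can fill in the missing pieces from `range_dic` alone. Two simplifications to keep in mind: the toy methods take `**kwargs` instead of the explicit signatures `cogwheel` expects, and the toy `lnprior` returns the normalized constant $-\ln(\text{volume of the range box})$, whereas the library may use a different constant:

```python
import math


class ToyIdentityTransformMixin:
    """Sampled parameters double as standard parameters."""
    @property
    def standard_params(self):
        return list(self.range_dic)

    def transform(self, **par_dic):
        return dict(par_dic)

    def inverse_transform(self, **par_dic):
        return dict(par_dic)


class ToyUniformPriorMixin:
    """Constant density over the box defined by range_dic (bounds are not checked)."""
    def lnprior(self, **sampled_par_dic):
        return -sum(math.log(hi - lo) for lo, hi in self.range_dic.values())


class ToyUniformPolarization(ToyIdentityTransformMixin, ToyUniformPriorMixin):
    range_dic = {'psi': (0, math.pi)}


toy = ToyUniformPolarization()
print(toy.standard_params)    # ['psi']
print(toy.transform(psi=1.))  # {'psi': 1.0}
print(toy.lnprior(psi=1.))    # -log(pi), about -1.1447
```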

### Folded parameters

If you know in advance that a sampled parameter may exhibit an approximate discrete symmetry, you can ease the sampling by [folding](https://arxiv.org/pdf/2207.03508.pdf#section*.15) the distribution along that parameter (or parameters).
Define an attribute `folded_reflected_params` and/or `folded_shifted_params` containing a list of the sampled parameters along which you want to fold the distribution. `folded_reflected_params` are reflected about the center of the parameter range, e.g. for a range $(0, 1)$ the points $(0, 0.5, 1)$ are mapped to $(0, 0.5, 0)$. `folded_shifted_params` are rigidly shifted by half the range, e.g. $(0, 0.5^-, 0.5^+, 1) \to (0, 0.5, 0, 0.5)$.
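
The two folding maps described above can be written as plain functions. This is an illustrative sketch, not `cogwheel`'s code; the hypothetical helper `unfolded_images` lists the two points of the full range that land on a given folded value, which, as we understand the linked reference, are the points the sampler effectively combines when it explores the folded half-range:

```python
def fold_reflected(x, lo, hi):
    """Reflect about the center of (lo, hi): the upper half maps onto the lower half."""
    mid = (lo + hi) / 2
    return mid - abs(x - mid)


def fold_shifted(x, lo, hi):
    """Shift the upper half of (lo, hi) rigidly down onto the lower half."""
    mid = (lo + hi) / 2
    return x if x < mid else x - (mid - lo)


def unfolded_images(y, lo, hi, kind):
    """Both points in (lo, hi) that fold onto y, which lies in the lower half."""
    mid = (lo + hi) / 2
    if kind == 'reflected':
        return (y, 2 * mid - y)
    if kind == 'shifted':
        return (y, y + (mid - lo))
    raise ValueError(f'Unknown kind: {kind}')


print([fold_reflected(x, 0, 1) for x in (0, .5, 1)])       # [0.0, 0.5, 0.0]
print([fold_shifted(x, 0, 1) for x in (0, .25, .75, 1)])   # [0, 0.25, 0.25, 0.5]
```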

### Periodic/reflected parameters

Some samplers allow periodic/reflecting boundary conditions. Define an attribute `periodic_params` and/or `reflected_params` with a list of sampled parameters in which you want to apply those boundary conditions.
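
For intuition about what these boundary conditions mean, the sketch below maps an out-of-range proposal back into `(lo, hi)`, either periodically (e.g. for an angle) or by mirror reflection at the edges. It is a generic illustration; how a particular sampler implements these conditions is up to that sampler:

```python
import math


def wrap_periodic(x, lo, hi):
    """Map x into [lo, hi), treating the parameter as periodic with period hi - lo."""
    return lo + (x - lo) % (hi - lo)


def wrap_reflecting(x, lo, hi):
    """Map x into [lo, hi] by mirroring at the boundaries (a triangle wave)."""
    width = hi - lo
    y = (x - lo) % (2 * width)
    if y > width:
        y = 2 * width - y
    return lo + y


print(wrap_periodic(2 * math.pi + 1, 0, 2 * math.pi))  # about 1.0
print(wrap_reflecting(-.25, 0, 1))  # 0.25
print(wrap_reflecting(1.25, 0, 1))  # 0.75
```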