# Synthetic Data Generation with GWKokab

Welcome to this Python notebook, in which we learn how to use GWKokab to generate high-quality synthetic data. Let's get started!

In this tutorial we generate data from two models: `Wysocki2019MassModel` (eq. (7) of [Wysocki et al.](https://journals.aps.org/prd/abstract/10.1103/PhysRevD.100.043012)) for the primary and secondary masses of the binary system, and a truncated normal distribution for the orbital eccentricity $\epsilon$. The models are defined as follows:

$$
    p(m_1,m_2\mid\alpha_m,m_{\text{min}},m_{\text{max}})\propto
    \frac{m_1^{-\alpha_m}}{m_1-m_{\text{min}}},\qquad m_{\text{min}}\le m_2\le m_1\le m_{\text{max}},
$$

$$
    \mathcal{N}_{[a,b]}(\epsilon\mid\mu,\sigma^2)\propto\exp\left(-\frac{1}{2}\left(\frac{\epsilon-\mu}{\sigma}\right)^2\right)\mathbb{1}_{[a,b]}(\epsilon)
$$

where the chosen values are $\alpha_m=-1$, $m_{\text{min}}=10M_\odot$, $m_{\text{max}}=50M_\odot$, $a=0$, $b=0.05$, $\mu=0$ and $\sigma=0.05$.
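As a quick check on the formulas, here is a minimal standard-library sketch that evaluates both densities with the parameter values above. It is not GWKokab's implementation: it assumes the support $m_{\text{min}}\le m_2\le m_1\le m_{\text{max}}$ for the mass model, and the function names are hypothetical. Note that $1/(m_1-m_{\text{min}})$ is exactly the density of $m_2$ uniform on $[m_{\text{min}},m_1]$, so the marginal distribution of $m_1$ is the pure power law $m_1^{-\alpha_m}$.

```python
import math


def std_normal_cdf(t):
    """Standard normal CDF expressed with the error function."""
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))


def wysocki2019_log_density(m1, m2, alpha_m=-1.0, mmin=10.0, mmax=50.0):
    """Unnormalized log p(m1, m2) of the power-law mass model (sketch)."""
    # Zero density outside mmin <= m2 <= m1 <= mmax; m1 == mmin is excluded
    # as well, because 1 / (m1 - mmin) diverges there.
    if not (mmin <= m2 <= m1 <= mmax) or m1 <= mmin:
        return -math.inf
    # log of m1^(-alpha_m) / (m1 - mmin)
    return -alpha_m * math.log(m1) - math.log(m1 - mmin)


def truncnorm_log_density(eps, mu=0.0, sigma=0.05, a=0.0, b=0.05):
    """Normalized log density of N(mu, sigma^2) truncated to [a, b] (sketch)."""
    if not (a <= eps <= b):
        return -math.inf
    z = (eps - mu) / sigma
    log_pdf = -0.5 * z * z - math.log(sigma * math.sqrt(2.0 * math.pi))
    # Renormalize by the probability mass the untruncated normal puts on [a, b].
    mass = std_normal_cdf((b - mu) / sigma) - std_normal_cdf((a - mu) / sigma)
    return log_pdf - math.log(mass)


print(wysocki2019_log_density(20.0, 15.0))  # approximately log(20 / 10) = log 2
print(wysocki2019_log_density(15.0, 20.0))  # -inf: m2 > m1 is outside the support
print(truncnorm_log_density(0.02))
```

The mass density is left unnormalized, since only ratios matter for sampling, while the truncated normal is divided by the probability mass it keeps on $[a,b]$.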

## Environment Variables

At its core, GWKokab uses JAX for fast computation. JAX's behavior on the accelerator you are using is configured through environment variables, which must be set before JAX is imported. Here are the environment variables that you can set:

In [1]:
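```python
import os


# Set these before JAX (or anything that imports it, e.g. gwkokab, numpyro) is imported.
os.environ["NPROC"] = "4"
os.environ["intra_op_parallelism_threads"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # 0 = show all C++ (TF/XLA) log messages
os.environ["OPENBLAS_NUM_THREADS"] = "1"  # limit OpenBLAS to a single thread
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate GPU memory on demand
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # make only the first GPU visible
```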
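JAX, XLA and CUDA read variables such as `XLA_PYTHON_CLIENT_ALLOCATOR` and `CUDA_VISIBLE_DEVICES` when the backend initializes, so the cell above has to run before `jax` is imported, directly or through `gwkokab` or `numpyro`. Assigning to `os.environ` also overwrites anything exported in your shell. If you would rather treat the values as defaults, a minimal sketch using `os.environ.setdefault` follows; the helper `apply_env_defaults` and the chosen subset of variables are illustrative, not part of GWKokab.

```python
import os

# Illustrative defaults; setdefault only writes keys that are not already set,
# so values exported in the shell take precedence.
JAX_ENV_DEFAULTS = {
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",  # allocate GPU memory on demand
    "CUDA_VISIBLE_DEVICES": "0",  # expose only the first GPU
    "OPENBLAS_NUM_THREADS": "1",  # single-threaded OpenBLAS
}


def apply_env_defaults(defaults):
    """Set each variable only if it is missing; return the values now in effect."""
    for key, value in defaults.items():
        os.environ.setdefault(key, value)
    return {key: os.environ[key] for key in defaults}


# Run this before the first `import jax`.
print(apply_env_defaults(JAX_ENV_DEFAULTS))
```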

## Necessary Imports

Let's start by importing the necessary libraries.

In [2]:
```python
from jax import vmap
from jaxtyping import Array, Bool
from numpyro.distributions import TruncatedNormal

from gwkokab.errors import banana_error_m1_m2
from gwkokab.models import Wysocki2019MassModel
from gwkokab.parameters import ECCENTRICITY, PRIMARY_MASS_SOURCE, SECONDARY_MASS_SOURCE
from gwkokab.population import error_magazine, popfactory, popmodel_magazine
from gwkokab.vts.neuralvt import load_model
```




## Constant Parameters

People make typos all the time. To avoid mistyped parameter names, GWKokab predefines the physical parameters in the `gwkokab.parameters` module. These objects carry more than just the parameter names, as we will see later.

In [3]:
```python
m1_source = PRIMARY_MASS_SOURCE().name
m2_source = SECONDARY_MASS_SOURCE().name
ecc = ECCENTRICITY().name

print(m1_source, m2_source, ecc)
```

```
mass_1_source mass_2_source ecc
```


## How to define models and errors?

Since the release of version 0.0.2, GWKokab defines models and errors with registration decorators, which makes both easier to set up. There are two ways to use them: pass the parameter names and the model directly to `register`, as in the next cell, or define a function that returns the model and put the decorator on top of it. The second style is preferred for complex models.

In [4]:
```python
# Joint model for (m1, m2): power law with alpha_m = -1 on [10, 50].
popmodel_magazine.register(
    (m1_source, m2_source),
    Wysocki2019MassModel(alpha_m=-1.0, mmin=10.0, mmax=50.0),
)

# Eccentricity: N(0, 0.05^2) truncated to [0, 0.05].
popmodel_magazine.register(
    ecc,
    TruncatedNormal(scale=0.05, loc=0.0, low=0.0, high=0.05, validate_args=True),
)
```

```
<numpyro.distributions.truncated.TwoSidedTruncatedDistribution at 0x7f345e765090>
```
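Both registration styles described above, calling `register` with the model or using it as a decorator, follow a common Python pattern: a single method that stores the object immediately when it is given one and otherwise returns a decorator. The toy registry below is a hypothetical sketch of that pattern; `ToyMagazine` is not GWKokab's class, and GWKokab's internals may differ. Like the cell above, the direct form returns the registered object, which is why the notebook echoes the `TwoSidedTruncatedDistribution`.

```python
class ToyMagazine:
    """Hypothetical registry illustrating the two registration styles."""

    def __init__(self):
        self.entries = {}

    def register(self, params, model=None):
        # Direct form: register(params, model) stores and returns the model.
        if model is not None:
            self.entries[params] = model
            return model

        # Decorator form: register(params) returns a decorator that stores
        # the decorated function and hands it back unchanged.
        def decorator(fn):
            self.entries[params] = fn
            return fn

        return decorator


magazine = ToyMagazine()
magazine.register(("mass_1_source", "mass_2_source"), "power-law mass model")


@magazine.register("ecc")
def ecc_model():
    return "truncated normal"


print(magazine.entries)
```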

Error functions are registered in the same way; here we use the decorator style. Each error function takes, in order, a row of data, the number of error samples to generate, and a JAX `PRNGKey`.

In [5]:
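```python
@error_magazine.register((m1_source, m2_source))
def m1m2_error_fn(x, size, key):
    # Banana-shaped errors in (m1, m2), scaled in chirp mass and symmetric mass ratio.
    return banana_error_m1_m2(x, size, key, scale_Mc=1.0, scale_eta=1.0)


@error_magazine.register(ecc)
def ecc_error_fn(x, size, key):
    # Shift the true eccentricity by `size` draws of non-negative noise in [0, 0.06].
    return x + TruncatedNormal(loc=0, scale=0.06, low=0.0, high=0.06).sample(
        key=key, sample_shape=(size,)
    )
```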
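`ecc_error_fn` adds noise from a normal with `loc=0` and `scale=0.06` truncated to $[0, 0.06]$, so every noisy eccentricity lies between the true value and the true value plus 0.06. The standard-library sketch below reproduces that logic with inverse-CDF sampling of the truncated normal, a technique chosen here for illustration (NumPyro's `TruncatedNormal` may sample differently), with `random.Random` standing in for a JAX `PRNGKey`. `sample_truncated_normal` and `ecc_error_sketch` are hypothetical names.

```python
import random
from statistics import NormalDist


def sample_truncated_normal(loc, scale, low, high, size, rng):
    """Inverse-CDF sampling from a normal truncated to [low, high]."""
    dist = NormalDist(loc, scale)
    # Map [low, high] to its CDF interval, draw uniformly inside it,
    # then map back through the inverse CDF.
    u_low, u_high = dist.cdf(low), dist.cdf(high)
    return [dist.inv_cdf(rng.uniform(u_low, u_high)) for _ in range(size)]


def ecc_error_sketch(x, size, rng):
    """Analogue of ecc_error_fn: shift x by non-negative noise in [0, 0.06]."""
    noise = sample_truncated_normal(0.0, 0.06, 0.0, 0.06, size, rng)
    return [x + n for n in noise]


rng = random.Random(0)
observed = ecc_error_sketch(0.02, size=5, rng=rng)
print(observed)
```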

## VT Sensitivity and flexibility for users

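Users can supply their own `logVT` function, the logarithm of the sensitive volume-time (VT). This gives the flexibility to choose any parameters for the data model and the VT sensitivity; the only thing that has to change is the `logVT` function. Here we download a pre-trained neural VT model and load it.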

In [6]:
```python
!wget -c https://raw.githubusercontent.com/gwkokab/asset-store/main/neural_vts/neural_vt_1_200_1000.eqx
```




In [7]:
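```python
# load_model returns a pair; the second element is the log-VT function for a single row.
_, logVT = load_model("neural_vt_1_200_1000.eqx")
# vmap vectorizes logVT over the leading (batch) axis of its input.
logVT = vmap(logVT)
```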
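Since `log_VT_fn` is a function of the `VT_params`, any callable with that signature should be able to stand in for the neural VT. As a hypothetical illustration only (not a calibrated sensitivity), the sketch below uses the rough inspiral-only scaling $VT\propto\mathcal{M}^{5/2}$: the horizon distance scales as $\mathcal{M}^{5/6}$ with chirp mass $\mathcal{M}$, and volume scales as distance cubed. The `batched` helper plays the role that `jax.vmap` plays in the cell above, mapping a per-row function over a batch of rows.

```python
import math


def chirp_mass(m1, m2):
    """Chirp mass (m1 m2)^(3/5) / (m1 + m2)^(1/5)."""
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def toy_log_vt(row):
    """Hypothetical, uncalibrated log VT for one (m1, m2) row.

    Uses VT ~ Mc^(5/2) and drops the overall constant.
    """
    m1, m2 = row
    return 2.5 * math.log(chirp_mass(m1, m2))


def batched(fn):
    """Pure-Python analogue of jax.vmap for a 1-D batch of rows."""
    return lambda rows: [fn(row) for row in rows]


log_vt = batched(toy_log_vt)
print(log_vt([(20.0, 15.0), (40.0, 30.0)]))
```

In practice you would write such a function with `jax.numpy` so that `vmap` and JIT compilation apply to it directly.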

## Population Generation

Before generating the population, we pass the essential settings to `gwkokab.population.popfactory`:

In [8]:
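```python
popfactory.analysis_time = 248  # observing (analysis) time
popfactory.rate = 100  # merger rate
popfactory.log_VT_fn = logVT  # batched log-VT function loaded above
popfactory.VT_params = [m1_source, m2_source]  # parameters passed to logVT, in order
popfactory.error_size = 2000  # number of error samples generated per event
```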

Often we want to impose an overall constraint on the population, for example that the primary mass is at least as large as the secondary mass. We can do this with a function that takes an array of samples and returns a boolean mask that is `True` for the samples satisfying the constraint.

In [9]:
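```python
def constraint(x: Array) -> Bool[Array, "..."]:
    # Columns of x: primary mass, secondary mass, eccentricity.
    m1 = x[..., 0]
    m2 = x[..., 1]
    e = x[..., 2]  # local name avoids shadowing the global `ecc`
    mask = m2 <= m1  # the primary is the heavier component
    mask &= m2 > 0.0
    mask &= m1 > 0.0
    mask &= e >= 0.0  # eccentricity lies in [0, 1]
    mask &= e <= 1.0
    return mask


popfactory.constraint = constraint
```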
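The constraint does not filter anything itself: it returns one boolean per sample, and presumably `popfactory` keeps only the samples where the mask is `True`. A pure-Python analogue of the same checks, together with that filtering step, is sketched below; `constraint_mask` and `apply_mask` are hypothetical names, and rows are assumed to be ordered as (m1, m2, ecc).

```python
def constraint_mask(rows):
    """Pure-Python analogue of `constraint`: one bool per (m1, m2, ecc) row."""
    # 0 < m2 <= m1 also implies m1 > 0, matching the separate check above.
    return [0.0 < m2 <= m1 and 0.0 <= e <= 1.0 for m1, m2, e in rows]


def apply_mask(rows, mask):
    """Keep only the rows whose mask entry is True."""
    return [row for row, keep in zip(rows, mask) if keep]


rows = [
    (30.0, 20.0, 0.01),  # valid
    (20.0, 30.0, 0.01),  # m2 > m1: rejected
    (30.0, 20.0, 1.50),  # ecc > 1: rejected
]
mask = constraint_mask(rows)
print(mask)  # [True, False, False]
print(apply_mask(rows, mask))  # [(30.0, 20.0, 0.01)]
```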

With everything set up, we can now generate the synthetic data.

In [10]:
```python
popfactory.produce()
```