In [None]:
import os

import numpy as np
import scipy
import scipy.special
import torch
from chainconsumer import ChainConsumer

from dingo.gw.ASD_dataset.noise_dataset import ASDDataset
from dingo.core.models import PosteriorModel
from dingo.gw.inference import injection
from dingo.gw.inference.gw_samplers import GWSamplerGNPE, GWSampler


In [None]:
# --- Configuration: edit these for your own machine --------------------------
# Only needed if you have a custom LAL_DATA_PATH. `setdefault` leaves any
# LAL_DATA_PATH you have already exported untouched instead of overwriting it
# with this machine-specific path.
os.environ.setdefault(
    "LAL_DATA_PATH",
    "/home/local/nihargupte/dingo-devel/venv/lib/python3.9/site-packages/lalsimulation/",
)

# Root directory of the tutorial's trained networks and datasets.
TUTORIAL_DIR = "/home/local/nihargupte/dingo-devel/tutorials/03_aligned_spin"
# Insert your time network here (used below as the initial sampler for GNPE).
time_network_path = os.path.join(
    TUTORIAL_DIR, "train_dir_SEOBNRv4HM_ROM_O1_2048_lr_time", "model_latest.pt"
)
# Insert your main (GNPE) network here.
gnpe_network_path = os.path.join(
    TUTORIAL_DIR, "train_dir_SEOBNRv4HM_ROM_O1_2048_lr", "model_latest.pt"
)

# CUDA device index; after set_device, device="cuda" refers to this GPU.
GPU_INDEX = 5
# Seed the global numpy/torch RNGs so that re-running the notebook reproduces
# the same ASD draw, noise and samples (assumes dingo draws from these global
# RNGs -- TODO confirm).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Load the networks -------------------------------------------------------
torch.cuda.set_device(GPU_INDEX)
time_pm = PosteriorModel(
    device="cuda",
    model_filename=time_network_path,
    load_training_info=False,
)
main_pm = PosteriorModel(
    device="cuda",
    model_filename=gnpe_network_path,
    load_training_info=False,
)

# Injection generator built from the main posterior model.
injection_generator = injection.Injection.from_posterior_model(main_pm)

# --- Noise ASD ---------------------------------------------------------------
# If you have a specific HDF5 ASD dataset generated through
# dingo.gw.ASD_dataset.generate_dataset, you can put it here. Otherwise you can
# always use the ASD dataset the network was trained with, accessible through
# file_name = main_pm.metadata["train_settings"]["training"]["stage_0"]["asd_dataset_path"]
file_name = os.path.join(TUTORIAL_DIR, "datasets", "ASDs_new", "1024_1", "asds_O1.hdf5")
asd_dataset = ASDDataset(file_name=file_name)
asd = asd_dataset.sample_random_asds()
injection_generator.asd = asd
# Don't pass whitened data to the GW sampler.
injection_generator.whiten = False

# --- Injection parameters ----------------------------------------------------
# Insert whichever parameters you want here. This example uses aligned spins,
# hence chi_1 and chi_2; for a precessing model use a_1, a_2, etc. instead.
intrinsic_parameters = {
    "mass_1": 40.0,
    "mass_2": 31.6,
    "chi_1": 0,
    "chi_2": 0,
}

extrinsic_parameters = {
    "phase": 0,  # rad
    "theta_jn": np.pi / 3,  # rad
    "psi": 0,  # rad
    "ra": 0,  # rad
    "dec": 0,  # rad
    "geocent_time": 0.0,  # s
    "luminosity_distance": 600,  # Mpc
}

theta = {**intrinsic_parameters, **extrinsic_parameters}

# This is the object we want to do inference on.
strain_data = injection_generator.injection(theta)


In [None]:
# --- Inference ---------------------------------------------------------------
# Keep the injected parameters; they are used as the truth values in the plot.
theta = strain_data["parameters"].copy()

NUM_GNPE_ITERATIONS = 30
NUM_SAMPLES = 50_000
BATCH_SIZE = 10_000

# The time network provides the initial sampler that GNPE starts from; the
# main network then runs NUM_GNPE_ITERATIONS GNPE iterations.
init_sampler = GWSampler(model=time_pm)
sampler = GWSamplerGNPE(
    model=main_pm,
    init_sampler=init_sampler,
    num_iterations=NUM_GNPE_ITERATIONS,
)
sampler.context = strain_data
sampler.run_sampler(num_samples=NUM_SAMPLES, batch_size=BATCH_SIZE)

In [None]:
%matplotlib inline

N = 1
c = ChainConsumer()
c.add_chain(sampler.samples, name='dingo')
c.configure(
    linestyles=["-"] * N,
    linewidths=[1.5] * N,
    sigmas=[np.sqrt(2) * scipy.special.erfinv(x) for x in [0.5, 0.9]],
    shade=[False] + [True] * (N - 1),
    shade_alpha=0.3,
    bar_shade=False,
    label_font_size=10,
    tick_font_size=10,
    usetex=False,
    legend_kwargs={"fontsize": 30},
    kde=0.7 # NOTE something to beware of is the KDE, sometimes because it doesn't understand the periodic boundary condition of ra and dec it can give weird results
            # in this case just remove the KDE by removing `kde=0.7`
)

c.plotter.plot(truth=theta)