In [None]:
import os
os.environ["OMP_NUM_THREADS"] = "1"

%matplotlib notebook

import sys
sys.path.append('..')

from cogwheel.fast_marginalization import adaptive, marginalized_extrinsic, intrinsic_prior

import numpy as np
import scipy
import pandas as pd
from cProfile import Profile
from pstats import Stats

import cogwheel
from cogwheel import gw_prior
from cogwheel.posterior import Posterior
from cogwheel import likelihood
from cogwheel import utils
from cogwheel import sampling
from cogwheel import gw_plotting
from cogwheel import data

from importlib import reload

import lal
import lalsimulation

import matplotlib.pyplot as plt
import multiprocessing
from pathlib import Path

In [None]:
# SET ME: parent directory of the parameter-estimation runs; it is passed to
# `utils.get_eventdir` when locating run directories (e.g. 'pe').
parentdir = NotImplemented

In [None]:
# Per-event metadata indexed by event name; its 'mchirp' column holds the
# chirp-mass guesses used to set up each run.
metadata_csv = data.DATADIR / 'events_metadata.csv'
metadata = pd.read_csv(metadata_csv, index_col=0)

In [None]:
def run_pymultinest(eventname, log2nfast=10, n_live_points=512,
                    evidence_tolerance=.5):
    """
    Run PyMultiNest on one event and postprocess the saved samples.

    An auxiliary posterior is built first, using the ``LinearFreeIASPrior``
    prior and the ``MarginalizedDistanceLikelihood`` likelihood. From it,
    the ``MarginalizedExtrinsicLikelihood`` and an ``IntrinsicIASPrior``
    with the same init parameters are constructed; these form the posterior
    that is actually sampled.

    After sampling, ``samples.feather`` is read and passed through:
    - the sampled prior's ``transform_samples``,
    - the likelihood's ``postprocess_samples``,
    - the auxiliary prior's ``inverse_transform_samples``.
    Presumably each of these adds columns in place (TODO confirm). The
    file is then overwritten.

    Parameters
    ----------
    eventname : str
        Event to analyze. It must be in ``metadata.index``; its
        ``'mchirp'`` entry is used as the chirp-mass guess.
    log2nfast : int
        Forwarded to ``MarginalizedExtrinsicLikelihood.from_aux_posterior``
        (presumably log2 of the number of fast samples).
    n_live_points : int
        Number of MultiNest live points.
    evidence_tolerance : float
        MultiNest ``evidence_tolerance`` stopping parameter.

    Returns
    -------
    The run directory returned by ``sampler.get_rundir(parentdir)``.

    Raises
    ------
    ValueError
        If the module-level ``parentdir`` is still ``NotImplemented``.
    """
    # Use the notebook-wide `parentdir` (also read by `plot_samples`) so
    # runs are written where they are later looked up. Check it before
    # doing any expensive setup.
    if parentdir is NotImplemented:
        raise ValueError('Set the module-level `parentdir` before running.')

    mchirp_guess = metadata['mchirp'][eventname]

    lookup_table = likelihood.LookupTable()
    aux_posterior = Posterior.from_event(
        eventname, mchirp_guess,
        approximant='IMRPhenomXPHM',
        prior_class='LinearFreeIASPrior',
        likelihood_class=likelihood.MarginalizedDistanceLikelihood,
        likelihood_kwargs={'lookup_table': lookup_table})

    like = marginalized_extrinsic.MarginalizedExtrinsicLikelihood.from_aux_posterior(
        aux_posterior, log2nfast)
    # Same init parameters as the auxiliary prior. To symmetrize in lnq,
    # merge ``dict(symmetrize_lnq=True)`` into these kwargs.
    prior = gw_prior.prior_registry['IntrinsicIASPrior'](
        **aux_posterior.prior.get_init_dict())
    post = Posterior(prior, like)

    sampler = sampling.PyMultiNest(post)
    sampler.run_kwargs['n_live_points'] = n_live_points
    sampler.run_kwargs['evidence_tolerance'] = evidence_tolerance

    rundir = sampler.get_rundir(parentdir)
    sampler.run(rundir)

    # Enrich the saved samples and write them back to the same file.
    samples_fname = rundir/'samples.feather'
    samples = pd.read_feather(samples_fname)
    sampler.posterior.prior.transform_samples(samples)
    sampler.posterior.likelihood.postprocess_samples(samples)
    aux_posterior.prior.inverse_transform_samples(samples)
    samples.to_feather(samples_fname)

    return rundir

In [None]:
# Events that have both a metadata row and a ``<eventname>.npz`` file in DATADIR.
npz_stems = [npz_file.name.removesuffix('.npz')
             for npz_file in data.DATADIR.glob('*.npz')]
eventnames = metadata.index.intersection(npz_stems)

In [None]:
# Leave one core free, but use at least one worker: `Pool(0)` raises
# ValueError on single-core machines. Never start more workers than events.
# OMP_NUM_THREADS=1 is set in the first cell, presumably so that each
# worker stays single-threaded.
ncores = max(1, min(len(eventnames), multiprocessing.cpu_count() - 1))

with multiprocessing.Pool(ncores) as pool:
    # chunksize=1: runs may take very different times, so hand out one
    # event at a time instead of pre-batching them per worker.
    pool.map(run_pymultinest, eventnames, chunksize=1)

In [None]:
# Event inspected in the Timing section below (the first available event).
eventname = eventnames[0]

In [None]:
def plot_samples(eventname):
    """
    Summarize and plot the latest run of an event.

    Prints the total profiled time (in minutes) and the number of samples,
    then draws a corner plot of the ``adaptive.LinearFreeIASPrior`` sampled
    parameters.

    The latest run is the ``run_<k>`` directory with the largest integer
    ``k`` under ``utils.get_eventdir(parentdir, 'IntrinsicIASPrior',
    eventname)``.

    Parameters
    ----------
    eventname : str
        Name of the event to plot.

    Raises
    ------
    ValueError
        If the module-level ``parentdir`` is still ``NotImplemented``.
    FileNotFoundError
        If the event directory contains no ``run_*`` directory.
    """
    if parentdir is NotImplemented:
        raise ValueError('Set the module-level `parentdir` before plotting.')

    eventdir = utils.get_eventdir(parentdir, 'IntrinsicIASPrior', eventname)
    rundirs = list(eventdir.glob('run_*'))
    if not rundirs:
        raise FileNotFoundError(f'No run_* directories in {eventdir}')

    def run_index(path):
        # Order numerically by run index. A plain lexicographic sort puts
        # run_10 before run_2 and would return a stale run once there are
        # 10 or more. Non-numeric suffixes sort first, by name.
        suffix = path.name.removeprefix('run_')
        is_numeric = suffix.isdecimal()
        return is_numeric, int(suffix) if is_numeric else 0, path.name

    rundir = max(rundirs, key=run_index)

    stats = Stats(str(rundir/'profiling')).strip_dirs().sort_stats('tottime')
    print(f'{stats.total_tt / 60:.2f} minutes')

    samples = pd.read_feather(rundir/'samples.feather')
    print(len(samples), 'samples')
    gw_plotting.CornerPlot(
        samples[adaptive.LinearFreeIASPrior.sampled_params]
    ).plot(title=eventname)

In [None]:
# Use a dedicated loop variable so this loop does not overwrite the
# `eventname` selected above for the Timing section.
for name in eventnames:
    plot_samples(name)

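Timing: inspect the profiling output saved with a sampler run, then profile individual likelihood and sampler evaluations.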

In [None]:
# `rundir` otherwise exists only as a local variable inside the functions
# above, so this cell would raise NameError on a fresh kernel. Locate the
# latest run of `eventname` explicitly.
def _run_sort_key(path):
    """Order ``run_<k>`` directories by integer ``k`` (non-numeric suffixes first)."""
    suffix = path.name.removeprefix('run_')
    is_numeric = suffix.isdecimal()
    return is_numeric, int(suffix) if is_numeric else 0, path.name

eventdir = utils.get_eventdir(parentdir, 'IntrinsicIASPrior', eventname)
rundir = max(eventdir.glob('run_*'), key=_run_sort_key, default=None)
if rundir is None:
    raise FileNotFoundError(f'No run_* directories in {eventdir}')

stats = Stats(str(rundir/'profiling')).strip_dirs().sort_stats('tottime')

print(f'{stats.total_tt / 60:.2f} minutes')
stats.print_stats();

In [None]:
# Reload the sampler saved in the run directory; the cells below use its
# posterior (prior and likelihood).
sampler_json = rundir / 'Sampler.json'
sampler = utils.read_json(sampler_json)

In [None]:
# Equivalent to ``slice(2**11)``; presumably restricts the likelihood to
# the first 2**11 fast-parameter samples (TODO confirm).
sampler.posterior.likelihood.fast_parameter_slice = slice(None, 2**11, None)

In [None]:
def gen_random(prior, rng=None):
    """
    Return ``prior.transform`` evaluated at a uniformly random cube point.

    The cube point is ``prior.cubemin + U(0, prior.cubesize)``, drawn
    elementwise.

    Parameters
    ----------
    prior
        Object with ``cubemin`` and ``cubesize`` (array-likes) and a
        ``transform(*cube_coordinates)`` method.
    rng : numpy.random.Generator, optional
        Source of randomness, e.g. ``np.random.default_rng(0)`` for
        reproducible draws. Defaults to numpy's global random state
        (``np.random``).

    Returns
    -------
    Whatever ``prior.transform`` returns. Below, it is used as the
    parameter dictionary passed to ``lnlike``.
    """
    if rng is None:
        rng = np.random  # Module-level functions share the global state.
    cube = prior.cubemin + rng.uniform(0, prior.cubesize)
    return prior.transform(*cube)

# Profile 100 likelihood evaluations at random prior points. The prior
# transform inside `gen_random` is included in the profile.
n_lnlike_calls = 100
posterior = sampler.posterior

with Profile() as p:
    for _ in range(n_lnlike_calls):
        par_dic = gen_random(posterior.prior)
        posterior.likelihood.lnlike(par_dic)

p.print_stats('tottime')

In [None]:
# Shortcut to the prior of the reloaded sampler's posterior, used by the
# profiling cell below.
prior = sampler.posterior.prior

In [None]:
# Profile 100 calls to the sampler's (private) `_get_lnprobs` at uniformly
# random points of the prior cube.
n_lnprob_calls = 100

with Profile() as p:
    for _ in range(n_lnprob_calls):
        cube_point = prior.cubemin + np.random.uniform(0, prior.cubesize)
        sampler._get_lnprobs(*cube_point)

p.print_stats('tottime')