In [1]:
import numpy as np
import matplotlib.pyplot as plt
from corner import corner

import jax
# jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
import numpyro

from ripplegw.waveforms.IMRPhenomD import gen_IMRPhenomD

In [None]:
# @jax.jit
def signal_model(f, chirp_mass):
    eta, chi1, chi2, dl, tc, phic = 0.25, 0, 0, 1_000, 0, 0
    fref = 20.0
    theta = jnp.array([chirp_mass, eta, chi1, chi2, dl, tc, phic])
    return jnp.real(gen_IMRPhenomD(f, theta, fref)) * 1e24

In [3]:
def noise_model(mean = 0, variance = 1):
    return numpyro.distributions.Normal(loc = mean, scale = variance**0.5)

In [None]:
fmin = 20
fmax = 2_000
N = 1_000
chirp_mass = 10
mean = 0
variance = 1

f = jnp.linspace(20, 2_000, 1_000)
signal = signal_model(f, 10)
noise = noise_model(mean, variance).sample(jax.random.PRNGKey(0), (N,))

plt.plot(f, noise, label = 'noise')
plt.plot(f, signal + noise, label = 'signal + noise')
plt.plot(f, signal, label = 'signal')
plt.legend()
plt.semilogx()

In [6]:
def model(f, data):
    chirp_mass = numpyro.sample(
        'chirp_mass', numpyro.distributions.Uniform(9, 11),
    )
    signal = signal_model(f, chirp_mass)
    # variance = 
    print(data.shape, signal.shape)
    with numpyro.plate('N', f.size):
        numpyro.sample('noise', noise_model(), obs = data - signal)
    noise = data - signal
    log_lkl = noise_model().log_prob(noise).sum()
    print(log_lkl)
    numpyro.factor('log_lkl', log_lkl)

In [None]:
nuts = numpyro.infer.NUTS(
    model,
    init_strategy = numpyro.infer.init_to_value(
        values = dict(chirp_mass = chirp_mass),
    ),
)
mcmc = numpyro.infer.MCMC(nuts, num_warmup = 10_000, num_samples = 10_000)
mcmc.run(
    jax.random.PRNGKey(1),
    f = f, data = signal + noise,
    # init_params = jnp.array([chirp_mass]),
)
mcmc.print_summary()

(1000,) (1000,)
-1396.1462
(1000,) (1000,)
Traced<ConcreteArray(-1396.14599609375, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array(-1396.146, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fe606cf6270>, in_tracers=(Traced<ShapedArray(float32[1000]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x7fe5f6719f80; to 'JaxprTracer' at 0x7fe5f671a020>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1000]. let b:f32[] = reduce_sum[axes=(0,)] a in (b,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False,), 'name': '_reduce_sum', 'keep_unused': False, 'inline': True}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x555e9444a720>, nam