In [None]:
# parameters, export
# Run configuration. Presumably overridden per run by an external driver
# (e.g. papermill parameter injection) and exported afterwards — confirm.
# Rd: shape parameter passed to `lf_relaxation_open_phase` to generate the target data.
Rd = 0.3
# centered / normalized: flags forwarded to the kernel through the `hyper` dict.
centered = True
# iteration: not read by the visible cells; presumably a run index recorded on export.
iteration = 1
# kernel: kernel spec string consumed by `instantiate_kernel` and `build_theta`.
kernel = "tack:2"
normalized = False
# seed: seeds both the numpy Generator (prior draws, sampler) and the JAX PRNG key.
seed = 4283955834

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Configure JAX before the project/dynesty imports below, so the flags are in
# effect before any JAX arrays are created.
# 64-bit floats: presumably for numerical stability of the GP linear algebra.
jax.config.update("jax_enable_x64", True)
# Do not log every jit compilation.
jax.config.update("jax_log_compiles", False)
# Pin computation to the CPU backend.
jax.config.update("jax_platform_name", "cpu")

from gfm.lf import lf_relaxation_open_phase

from dynesty import NestedSampler

from gfm.kernel import instantiate_kernel, build_theta
from utils import time_this


In [None]:
# Target signal generated by the LF open-phase relaxation helper.
# `tc` is also reused as the kernel hyperparameter `T` when the model is built.
tc = 6.0
N = 256  # sample count requested from the generator

d = lf_relaxation_open_phase(Rd, tc, N)
t, du, u = (d[key] for key in ("t", "du", "u"))

plt.title(f"Data to fit: Rd={Rd}")
plt.plot(t, du, label="du")
plt.plot(t, u, label="u")
plt.legend()


In [None]:
# build generative model
from scipy.special import ndtri
from tinygp.gp import GaussianProcess

# Fixed (non-sampled) kernel settings shared by every GP built below:
# the period-like hyperparameter `T`, the normalization/centering flags from
# the parameters cell, and the centering point (mean of the time grid).
hyper = dict(
    T=tc,
    normalized=normalized,
    centered=centered,
    center=t.mean(),
)


def build_gp(theta):
    """Construct the tinygp ``GaussianProcess`` for the parameter dict ``theta``.

    The covariance comes from the notebook-level ``kernel`` spec and ``hyper``
    settings and is evaluated on the data grid ``t``. ``theta["sigma_noise"]``
    is squared and passed as ``diag``, i.e. added to the covariance diagonal.
    """
    covariance = instantiate_kernel(kernel, theta, hyper)
    noise_variance = theta["sigma_noise"] ** 2
    return GaussianProcess(kernel=covariance, X=t, diag=noise_variance)


@jax.jit
def loglikelihood(x):
    """Log-likelihood of the observed ``du`` under the GP defined by ``x``.

    ``x`` is the flat parameter vector produced by ``ptform``; ``build_theta``
    maps it to the named parameter dict for the configured ``kernel``.

    Returns ``-inf`` instead of NaN. A NaN here typically means the covariance
    factorisation failed (ill-conditioned matrix). dynesty simply rejects
    ``-inf`` points but treats NaN as an invalid likelihood and may abort the run.
    """
    theta = build_theta(x, kernel)
    gp = build_gp(theta)
    logp = gp.log_probability(du)
    # jnp.where keeps this jit-compatible (no Python-level branching on traced values).
    return jnp.where(jnp.isnan(logp), -jnp.inf, logp)


def ptform(u):
    """Prior transform for dynesty: unit cube -> parameter space.

    Each coordinate is mapped through the standard-normal inverse CDF and then
    exponentiated base 10, giving a log10-normal(0, 1) prior on every
    parameter. Works elementwise on scalars and arrays.
    """
    log10_value = ndtri(u)
    return 10.0**log10_value


In [None]:
# smoke test
# Draw one parameter vector from the prior, check that the GP can be built and
# sampled, and evaluate the (jitted) log-likelihood once.
rng = np.random.default_rng(seed)

# 100 uniforms is assumed to exceed what build_theta consumes for any kernel — TODO confirm.
x = ptform(rng.uniform(size=100))
theta = build_theta(x, kernel)
ndim = len(theta)  # number of sampled parameters; reused by the nested sampler

# Near-zero observation noise so the prior draws show the kernel's own structure.
theta_noiseless = theta.copy()
theta_noiseless["sigma_noise"] = 1e-6

s = build_gp(theta_noiseless).sample(jax.random.PRNGKey(seed), shape=(3,))

plt.title(f"kernel: {kernel}, centered: {centered}, normalized: {normalized}")
# Label only the first draw so the legend has one entry rather than one per line.
lines = plt.plot(t, s.T)
lines[0].set_label("sample from GP prior")
plt.legend()

loglikelihood(x)

In [None]:
# Set up the nested sampler: `nlive` live points with random-walk proposals,
# driven by the same numpy Generator as the smoke test.
nlive = 500

sampler = NestedSampler(
    loglikelihood,
    ptform,
    ndim,
    nlive=nlive,
    sample="rwalk",
    rstate=rng,
)

# Time the full run; it stops after at most one million likelihood calls.
with time_this() as elapsed:
    sampler.run_nested(print_progress=False, maxcall=1_000_000)


In [None]:
# Results object of the finished run; `summary()` prints a short text summary.
res = sampler.results

res.summary()


In [None]:
# Overlay the posterior predictive of a few equally-weighted posterior draws.
xs = res.samples_equal()[:5]

for i, x in enumerate(xs):
    theta = build_theta(x, kernel)

    print(theta)

    gp = build_gp(theta)
    mu, var = gp.predict(du, t, return_var=True)
    # Predictive band for new observations: the latent variance and the noise
    # variance add. Adding the standard deviations would overstate the band.
    std = jnp.sqrt(var + theta["sigma_noise"] ** 2)

    plt.fill_between(
        t,
        mu - 1.96 * std,
        mu + 1.96 * std,
        alpha=0.2,
    )
    # One legend entry for all draws; underscore labels are hidden from the legend.
    plt.plot(t, mu, label="GP posterior mean" if i == 0 else "_nolegend_")

plt.plot(t, du, label="data")

plt.title(f"kernel: {kernel}, centered: {centered}, normalized: {normalized}")
plt.legend()


In [None]:
from dynesty import plotting as dyplot

# Corner plot of the posterior samples in `res`, with 5%/50%/95% quantile markers.
# Parameter labels come from the keys of the last `theta` built above.
# Plotting is best-effort: report a failure instead of aborting the notebook.
try:
    fig, ax = dyplot.cornerplot(
        res,
        quantiles=[0.05, 0.5, 0.95],
        labels=list(map(str, theta)),
        verbose=True,
    )
except Exception as e:
    print(f"Could not make corner plot: {e}")


In [None]:
# export
# Summary statistics of the run. Presumably collected by tooling that reads
# cells tagged "# export" — confirm before changing the assignment form.
# Final log-evidence estimate and its uncertainty (last entries of the run's trace).
logz = res.logz[-1]
logzerr = res.logzerr[-1]

# Dimensionality of the sampled parameter space and dynesty's final information estimate.
ndim = res.samples.shape[1]
information = res.information[-1]

# Cost of the run: number of iterations, total likelihood calls (per-iteration
# counts summed), and the wall-clock time recorded by `time_this` around `run_nested`.
niter = res.niter
ncall = res.ncall.sum()
walltime = elapsed.walltime