In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import tinygp

from tinygp import GaussianProcess, kernels
from tinygp.kernels import quasisep

# Report the library versions this notebook was executed with.
print("jax version: " + jax.__version__)
print("tinygp version: " + tinygp.__version__)

# NOTE(review): `xla_bridge` is no longer used in this cell; the import is kept
# only in case a later cell references it -- drop it once confirmed unused.
from jax.lib import xla_bridge
# `jax.default_backend()` is the public way to ask which platform JAX will run
# on, instead of going through the backend object from `xla_bridge`.
print("cpu/gpu: " + jax.default_backend())

# Use the public compilation-cache module rather than the private `jax._src`
# package; the alias `cc` and the cache directory are unchanged.
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("./cache_min_example")
# Lower the compile-time threshold (in seconds, per the option name) below
# which compiled programs are not written to the persistent cache.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 1)

import logging


# Synthetic data set: integer sample positions 0..999, a sine curve evaluated
# at those positions, and a unit value for every point's `yerr`.
t = np.arange(1000)
y = np.sin(t)
yerr = np.ones(t.size)

# Initial parameters; the kernel scale is stored as its base-10 logarithm.
theta_init = dict(log_drw_scale=np.log10(100))

def build_exp_gp_kernels(theta, t, y, yerr):
    """Build a Gaussian process whose only kernel is ``tinygp.kernels.Exp``.

    Parameters
    ----------
    theta : dict
        Must contain ``"log_drw_scale"``, the base-10 logarithm of the
        kernel scale.
    t : array-like
        Input coordinates handed to ``GaussianProcess``.
    y : array-like
        Observed values; their mean is used as the constant GP mean.
    yerr : array-like
        Per-point measurement uncertainties (standard deviations).

    Returns
    -------
    GaussianProcess
        The process built from the exponential kernel, ``t``, the squared
        uncertainties on the diagonal, and the mean of ``y``.
    """
    exp_kernel = kernels.Exp(scale=10 ** theta["log_drw_scale"])
    # `diag` is added to the diagonal of the covariance matrix, so it has to be
    # a variance: square the uncertainties rather than passing them directly.
    # `jnp` (not `np`) keeps both reductions explicit JAX ops when this runs
    # under `jax.jit` with traced inputs.
    return GaussianProcess(
        exp_kernel, t, diag=jnp.square(yerr), mean=jnp.mean(y)
    )

def build_exp_gp_quasi(theta, t, y, yerr):
    """Build a Gaussian process whose only kernel is the quasiseparable
    ``tinygp.kernels.quasisep.Exp``.

    Parameters
    ----------
    theta : dict
        Must contain ``"log_drw_scale"``, the base-10 logarithm of the
        kernel scale.
    t : array-like
        Input coordinates handed to ``GaussianProcess``.
    y : array-like
        Observed values; their mean is used as the constant GP mean.
    yerr : array-like
        Per-point measurement uncertainties (standard deviations).

    Returns
    -------
    GaussianProcess
        The process built from the quasiseparable exponential kernel, ``t``,
        the squared uncertainties on the diagonal, and the mean of ``y``.
    """
    exp_kernel = quasisep.Exp(scale=10 ** theta["log_drw_scale"])
    # `diag` is added to the diagonal of the covariance matrix, so it has to be
    # a variance: square the uncertainties rather than passing them directly.
    # `jnp` (not `np`) keeps both reductions explicit JAX ops when this runs
    # under `jax.jit` with traced inputs.
    return GaussianProcess(
        exp_kernel, t, diag=jnp.square(yerr), mean=jnp.mean(y)
    )

@jax.jit
def neg_log_likelihood_kernels(theta, t, y, yerr):
    """Negative log-probability of ``y`` under the GP from
    ``build_exp_gp_kernels``.

    All four arguments are forwarded unchanged to the builder; the function
    is compiled with ``jax.jit``.
    """
    log_prob = build_exp_gp_kernels(theta, t, y, yerr).log_probability(y)
    return -log_prob

@jax.jit
def neg_log_likelihood_quasi(theta, t, y, yerr):
    """Negative log-probability of ``y`` under the GP from
    ``build_exp_gp_quasi``.

    All four arguments are forwarded unchanged to the builder; the function
    is compiled with ``jax.jit``.
    """
    log_prob = build_exp_gp_quasi(theta, t, y, yerr).log_probability(y)
    return -log_prob


# Turn on DEBUG-level logging on the root logger so that JAX's own debug
# records (tracing / compilation messages) become visible in the cell output.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Keep this as the module-level `logging.debug(...)`: per the stdlib logging
# docs, the module-level helpers install a default handler when the root
# logger has none, and that handler is what prints the records that follow.
logging.debug("test")

# Call both jitted functions once so each is traced and compiled. The first
# result is discarded; only the value of the last expression is displayed.
neg_log_likelihood_kernels(theta_init, t, y, yerr)
neg_log_likelihood_quasi(theta_init, t, y, yerr)


jax version: 0.4.18
tinygp version: 0.2.4


Initialized persistent compilation cache at ./cache
DEBUG:root:test
DEBUG:jax._src.dispatch:Finished tracing + transforming _power for pjit in 0.0004904270172119141 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.00036597251892089844 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _mean for pjit in 0.0012814998626708984 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00031113624572753906 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming absolute for pjit in 0.0001735687255859375 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.0003426074981689453 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00015974044799804688 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming true_divide for pjit in 0.00034308433532714844 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.0001838207244873047 s

cpu/gpu: gpu


DEBUG:jax._src.dispatch:Finished tracing + transforming ravel for pjit in 0.00022077560424804688 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming ravel for pjit in 0.00012683868408203125 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming append for pjit in 0.0023343563079833984 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00017714500427246094 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming true_divide for pjit in 0.00029850006103515625 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00018143653869628906 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming matmul for pjit in 0.0006918907165527344 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0002510547637939453 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.00033974647521972656 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming dot for pjit in 0.000

Array(1229.8038, dtype=float32)