In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import tinygp

from tinygp import GaussianProcess, kernels
from tinygp.kernels import quasisep

# Report the library versions and the active backend so the run is reproducible.
print("jax version: " + jax.__version__)
print("tinygp version: " + tinygp.__version__)

# NOTE(review): import kept only in case other cells still reference
# `xla_bridge`; this cell no longer needs it.
from jax.lib import xla_bridge
# jax.default_backend() is the public accessor for the default platform name
# ("cpu", "gpu", "tpu"), so the backend object does not have to be reached
# through jax.lib.
print("cpu/gpu: " + str(jax.default_backend()))

# Use the public module path for the persistent compilation cache instead of
# the private `jax._src` package, which carries no stability guarantee.
# NOTE(review): newer JAX releases configure the cache directory through
# jax.config instead of initialize_cache() -- confirm against the installed
# version before upgrading.
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("./cache_min_example")
# Also persist programs that compile quickly (threshold lowered to 0.1 s).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.1)

import logging


# Synthetic data set: 1000 unit-spaced sample points, a sine evaluated at
# those points, and an array of ones used as the per-point `yerr`.
t = np.arange(0, 1000, 1)
y = np.sin(t)
yerr = np.ones(t.shape)

# Initial parameters; the kernel scale is stored as its base-10 logarithm.
theta_init = dict(log_drw_scale=np.log10(100))

def build_exp_gp_kernels(theta, t, y, yerr):
    """Build a Gaussian process with a single ``tinygp.kernels.Exp`` kernel.

    Parameters
    ----------
    theta : dict
        Must contain ``"log_drw_scale"``, the base-10 logarithm of the
        kernel scale.
    t : array
        Input coordinates passed to the Gaussian process.
    y : array
        Observed values; their mean is used as the constant GP mean.
    yerr : array or scalar
        Measurement uncertainty of ``y`` (assumed to be a standard
        deviation -- confirm against the caller).

    Returns
    -------
    GaussianProcess
        The process built from the kernel, ``t``, the noise term and the mean.
    """
    exp_kernel = kernels.Exp(scale=10 ** theta["log_drw_scale"])
    # `diag` is added to the diagonal of the covariance matrix, so it has to be
    # a variance: square the uncertainty instead of passing it through as-is.
    # jnp.mean keeps the mean computation inside JAX, since `y` is a tracer
    # whenever this builder runs under jax.jit.
    return GaussianProcess(exp_kernel, t, diag=yerr**2, mean=jnp.mean(y))

def build_exp_gp_quasi(theta, t, y, yerr):
    """Build a Gaussian process with a single ``tinygp.kernels.quasisep.Exp`` kernel.

    Same model as the plain exponential kernel, but expressed with the
    quasiseparable kernel class.

    Parameters
    ----------
    theta : dict
        Must contain ``"log_drw_scale"``, the base-10 logarithm of the
        kernel scale.
    t : array
        Input coordinates passed to the Gaussian process.
    y : array
        Observed values; their mean is used as the constant GP mean.
    yerr : array or scalar
        Measurement uncertainty of ``y`` (assumed to be a standard
        deviation -- confirm against the caller).

    Returns
    -------
    GaussianProcess
        The process built from the kernel, ``t``, the noise term and the mean.
    """
    exp_kernel = quasisep.Exp(scale=10 ** theta["log_drw_scale"])
    # `diag` is added to the diagonal of the covariance matrix, so it has to be
    # a variance: square the uncertainty instead of passing it through as-is.
    # jnp.mean keeps the mean computation inside JAX, since `y` is a tracer
    # whenever this builder runs under jax.jit.
    return GaussianProcess(exp_kernel, t, diag=yerr**2, mean=jnp.mean(y))

@jax.jit
def neg_log_likelihood_kernels(theta, t, y, yerr):
    """Return the negated log-probability of ``y`` under the ``kernels.Exp`` GP.

    The process is built by ``build_exp_gp_kernels(theta, t, y, yerr)``; the
    whole function is compiled with ``jax.jit``.
    """
    log_prob = build_exp_gp_kernels(theta, t, y, yerr).log_probability(y)
    return -log_prob

@jax.jit
def neg_log_likelihood_quasi(theta, t, y, yerr):
    """Return the negated log-probability of ``y`` under the ``quasisep.Exp`` GP.

    The process is built by ``build_exp_gp_quasi(theta, t, y, yerr)``; the
    whole function is compiled with ``jax.jit``.
    """
    log_prob = build_exp_gp_quasi(theta, t, y, yerr).log_probability(y)
    return -log_prob


# Lower the root logger's threshold to DEBUG so that debug records from JAX
# (e.g. the jax._src.dispatch tracing messages) are emitted.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Sanity-check record. NOTE(review): this goes through the module-level
# logging.debug(), which presumably also installs the default root handler when
# none is configured -- keep it ahead of the jitted calls; verify before
# replacing it with logger.debug().
logging.debug("test")

# Trigger tracing/compilation of both jitted objectives with the initial
# parameters. Only the value of the last expression is shown as cell output.
neg_log_likelihood_kernels(theta_init, t, y, yerr)
neg_log_likelihood_quasi(theta_init, t, y, yerr)


jax version: 0.4.16
tinygp version: 0.2.4


Initialized persistent compilation cache at ./cache_min_example
DEBUG:root:test
DEBUG:jax._src.dispatch:Finished tracing + transforming _power for pjit in 0.0006608963012695312 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.0005075931549072266 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _mean for pjit in 0.0019757747650146484 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.0003695487976074219 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming absolute for pjit in 0.00025582313537597656 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.0004565715789794922 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00023412704467773438 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming true_divide for pjit in 0.0004050731658935547 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00026249885

cpu/gpu: gpu


DEBUG:jax._src.dispatch:Finished tracing + transforming ravel for pjit in 0.00021576881408691406 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming ravel for pjit in 0.0001246929168701172 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming append for pjit in 0.0027832984924316406 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00017523765563964844 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming true_divide for pjit in 0.0002715587615966797 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming <lambda> for pjit in 0.00017786026000976562 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming matmul for pjit in 0.0011882781982421875 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.00024080276489257812 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.00032711029052734375 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming dot for pjit in 0.0003

Array(1229.8038, dtype=float32)

In [2]:
%pip install -q tinygp

Note: you may need to restart the kernel to use updated packages.
