In [18]:
import jax
import jax.numpy as jnp
from jax import export
from jax._src import compilation_cache as cc
from jax._src.lib import xla_client
from jax.lib import xla_bridge

import numpy as np
import tinygp
import logging
import pickle
from pathlib import Path

from tinygp import GaussianProcess, kernels
from tinygp.kernels import quasisep

print(f"jax version: {jax.__version__}")
# Report which backend JAX picked for this kernel (e.g. cpu or gpu).
print(f"cpu/gpu: {jax.default_backend()}")
print(f"tinygp version: {tinygp.__version__}")

# Optional: JAX persistent compilation cache (currently disabled).
# cc.initialize_cache("./cache_min_example")
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 1)


# Toy light curve: integer times 0..999, a sine signal, and unit errors.
t = np.arange(0, 1000, 1)
y = np.sin(t)
yerr = np.ones(1000)

# Initial hyperparameters as float32 JAX scalars. The scale is stored as a
# base-10 log, so np.log10(100) corresponds to a scale of 100.
theta_init = {
    "log_drw_scale": jnp.array(np.log10(100), dtype=jnp.float32),
}

# Abstract (shape, dtype) description of theta; used as the export signature.
theta_shapes = jax.tree_util.tree_map(
    lambda leaf: jax.ShapeDtypeStruct(shape=leaf.shape, dtype=leaf.dtype),
    theta_init,
)


def build_exp_gp_kernels(theta, t, y, yerr):
    """Build a Gaussian process with a single (dense) exponential kernel.

    Parameters
    ----------
    theta : dict
        Hyperparameters. ``theta["log_drw_scale"]`` is the base-10 log of the
        kernel scale (the kernel is built with ``10**theta["log_drw_scale"]``).
    t : array
        Input coordinates of the observations.
    y : array
        Observed values. Only their mean is used here, as the constant GP mean.
    yerr : array or scalar
        Passed unchanged as tinygp's ``diag`` argument.
        NOTE(review): tinygp adds ``diag`` to the covariance diagonal (a
        variance). If ``yerr`` is a 1-sigma uncertainty this should probably
        be ``yerr**2``. That is identical for the unit errors used here, so
        confirm before changing, and change ``build_exp_gp_quasi`` too.

    Returns
    -------
    GaussianProcess
        GP over ``t`` with a ``kernels.Exp`` kernel.
    """
    exp_kernel = kernels.Exp(scale=10 ** theta["log_drw_scale"])
    # jnp.mean rather than np.mean: this is called from jitted code, where
    # ``y`` is a JAX tracer.
    return GaussianProcess(exp_kernel, t, diag=yerr, mean=jnp.mean(y))

def build_exp_gp_quasi(theta, t, y, yerr):
    """Build a Gaussian process with a single quasiseparable exponential kernel.

    Same model as ``build_exp_gp_kernels``, but uses ``quasisep.Exp`` so that
    tinygp can use its quasiseparable solver.

    Parameters
    ----------
    theta : dict
        Hyperparameters. ``theta["log_drw_scale"]`` is the base-10 log of the
        kernel scale (the kernel is built with ``10**theta["log_drw_scale"]``).
    t : array
        Input coordinates of the observations.
        NOTE(review): quasiseparable kernels are generally expected to receive
        sorted inputs; ``t`` is sorted in this notebook. Confirm for other data.
    y : array
        Observed values. Only their mean is used here, as the constant GP mean.
    yerr : array or scalar
        Passed unchanged as tinygp's ``diag`` argument.
        NOTE(review): ``diag`` is added to the covariance diagonal (a
        variance). If ``yerr`` is a 1-sigma uncertainty this should probably
        be ``yerr**2``. That is identical for the unit errors used here, so
        confirm before changing, and change ``build_exp_gp_kernels`` too.

    Returns
    -------
    GaussianProcess
        GP over ``t`` with a ``quasisep.Exp`` kernel.
    """
    exp_kernel = quasisep.Exp(scale=10 ** theta["log_drw_scale"])
    # jnp.mean rather than np.mean: this is called from jitted code, where
    # ``y`` is a JAX tracer.
    return GaussianProcess(exp_kernel, t, diag=yerr, mean=jnp.mean(y))

@jax.jit
def neg_log_likelihood_kernels(theta, t, y, yerr):
    """Jitted negative log-likelihood of ``y`` under the dense exponential GP.

    Rebuilds the GP from ``theta`` via ``build_exp_gp_kernels`` on every call
    and returns ``-log p(y)``, negated so it can be minimised.
    """
    process = build_exp_gp_kernels(theta, t, y, yerr)
    log_like = process.log_probability(y)
    return -log_like

@jax.jit
def neg_log_likelihood_quasi(theta, t, y, yerr):
    """Jitted negative log-likelihood of ``y`` under the quasiseparable exponential GP.

    Rebuilds the GP from ``theta`` via ``build_exp_gp_quasi`` on every call
    and returns ``-log p(y)``, negated so it can be minimised.
    """
    process = build_exp_gp_quasi(theta, t, y, yerr)
    log_like = process.log_probability(y)
    return -log_like


logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.debug("test")

neg_log_likelihood_kernels(theta_init, t, y, yerr)
neg_log_likelihood_quasi(theta_init, t, y, yerr)


out_file_name = f'/epyc/users/ncaplar/github/JaxPeriodDrwFit/Dev/test.pkl'

if Path(out_file_name).exists():
    with open(out_file_name, 'rb') as f:
        serialized = pickle.load(f)
else:

    exported = export.export(jax.jit(neg_log_likelihood_kernels))(
        theta_shapes,  # Shape and dtype for theta
        jax.ShapeDtypeStruct(shape=t.shape, dtype=t.dtype),                    # Shape and dtype for t
        jax.ShapeDtypeStruct(shape=y.shape, dtype=y.dtype),                    # Shape and dtype for y
        jax.ShapeDtypeStruct(shape=yerr.shape, dtype=yerr.dtype)               # Shape and dtype for yerr
)
    serialized = exported.serialize()

if not Path(out_file_name).exists():
    with open(out_file_name, 'wb') as f:
        pickle.dump(serialized, f)

rehydrated_exp = export.deserialize(serialized)

%timeit
res1 = jax.jit(neg_log_likelihood_kernels)(theta_init, t, y, yerr)

%timeit
res2 = rehydrated_exp.call(theta_init, t, y, yerr)


jax version: 0.4.34
cpu/gpu: cpu
tinygp version: 0.3.0


In [16]:
# Negative log-likelihood from the live jitted function (set in the timing cell).
res1

Array(1229.804, dtype=float32)

In [17]:
# Same quantity from the deserialized export; should match res1.
res2

Array(1229.804, dtype=float32)

In [None]:
# Next, let's try running over multiple cores.