In [1]:
from tinygp import GaussianProcess
from tinygp.kernels import quasisep

import jax
from jax import export

import numpy as np

In [2]:
def build_gp_drw(theta, t, y, yerr):
    """Build a Gaussian Process model with the damped random walk kernel.

    Parameters
    ----------
    theta : array-like, (2,)
        Array of float values representing the parameters for the kernel.
    t : array-like
        Time domain data for the Gaussian Process.
    y : array-like
        Observations corresponding to the time domain data.
    yerr : array-like
        Uncertainties (errors) associated with the observations.

    Returns
    -------
    gp : GaussianProcess
        Gaussian Process model with the damped random walk kernel.

    Example
    -------
    theta = [0.5, -1.0, 0.8, -2.0]
    gp = build_gp(theta, t, y, yerr)
    """

    # assert len(theta) == 2

    log_drw_scale = theta[0]
    log_drw_amp = theta[1]

    exp_kernel = quasisep.Exp(
        scale=10**log_drw_scale, sigma=10**log_drw_amp)

    kernel = exp_kernel
 
    return GaussianProcess(kernel, t, diag=yerr, mean=np.mean(y), assume_sorted=True)

In [3]:
theta_init = np.array([np.log10(100), 2.0])
t_single = np.linspace(0, 10, 100)
y_single = np.sin(t_single)
yerr_single = 0.1 * np.ones_like(y_single)

In [4]:
jaxpr = jax.make_jaxpr(jax.jit(build_gp_drw))(theta_init, t_single, y_single, yerr_single)
print(jaxpr.pretty_print(source_info=True))

{ lambda ; a:f32[2] b:f32[100] c:f32[100] d:f32[100]. let
    e:f32[] f:f32[] g:f32[] h:f32[100] i:f32[100] j:f32[100,1] k:f32[100,1] l:f32[100,1,1]
      m:f32[100] n:f32[100,1] o:f32[100,1] p:f32[100,1,1] = pjit[                            # /tmp/ipykernel_64954/3436479300.py:1:8 (<module>)
      name=build_gp_drw
      jaxpr={ lambda ; q:f32[2] r:f32[100] s:f32[100] t:f32[100]. let
          u:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] q             # /tmp/ipykernel_64954/1532116811.py:28:20 (build_gp_drw)
          v:f32[] = squeeze[dimensions=(0,)] u                                               # /tmp/ipykernel_64954/1532116811.py:28:20 (build_gp_drw)
          w:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] q             # /tmp/ipykernel_64954/1532116811.py:29:18 (build_gp_drw)
          x:f32[] = squeeze[dimensions=(0,)] w                                               # /tmp/ipykernel_64954/1532116811.py:29:18 (build_gp_drw)
        