In [2]:
import jax.numpy as jnp
import jax.random as jr
from gpjax.kernel import RBF
from gpjax.parameters import Parameter
from gpjax.transforms import Identity
from jax.scipy.linalg import solve_triangular

key = jr.PRNGKey(123)



In [3]:
N = 100
M = 20

In [9]:
X = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(N,)).sort().reshape(-1, 1)
y = jnp.sin(X)

In [37]:
features = Parameter(jr.normal(key, shape=(20, 1)), transform=Identity())
denom = jnp.ones_like(jnp.array([1.0]))

In [38]:
omega = features.value / denom

In [39]:
assert omega.shape == (M, 1)

In [40]:
cos_freqs = jnp.cos(X.dot(omega.T))  # TODO: Can possible do away with the tranpose
sin_freqs = jnp.sin(X.dot(omega.T))
phi = jnp.hstack((cos_freqs, sin_freqs))

In [41]:
assert phi.shape == (N, 2*M)

In [42]:
gram = jnp.matmul(phi, phi.T)

In [43]:
assert gram.shape == (N, N)

In [44]:
l_var = 1.0
k_var = 1.0

In [46]:
A = (k_var / M) * jnp.matmul(phi.T, phi) + l_var * jnp.eye(M * 2)

In [48]:
assert A.shape == (2*M, 2*M)

In [49]:
Rt = jnp.linalg.cholesky(A)

In [52]:
RtiPhit = solve_triangular(Rt, phi.T)

In [57]:
RtiPhity = jnp.matmul(RtiPhit, y.ravel())

In [60]:
RtiPhity.shape

(40,)

In [62]:
term1 = (jnp.sum(y**2) - jnp.sum(RtiPhity**2) * k_var / M) * 0.5 / l_var
term2 = jnp.sum(jnp.log(jnp.diag(Rt.T))) + (N * 0.5 - M) * jnp.log(l_var) + (N * 0.5 * jnp.log(2 * jnp.pi))
tot = term1 + term2

In [64]:
tot.reshape()

DeviceArray(-415.56983189, dtype=float64)

# Kernel

In [4]:
spectral_density = jr.normal

In [6]:
features = spectral_density(key, (M, 1))

In [8]:
lengthscales = 1.
omega = features/lengthscales

In [13]:
cos_freqs = jnp.cos(X.dot(omega.T)) 
sin_freqs = jnp.sin(X.dot(omega.T))
phi = jnp.hstack((cos_freqs, sin_freqs))

In [15]:
phi.shape

(100, 40)

In [18]:
(1/M)*jnp.matmul(phi, phi.T)

DeviceArray([[ 1.        ,  0.99977234,  0.97923822, ...,  0.00318022,
               0.00375161,  0.00488347],
             [ 0.99977234,  1.        ,  0.9832944 , ...,  0.00232103,
               0.00290595,  0.00407733],
             [ 0.97923822,  0.9832944 ,  1.        , ..., -0.00479182,
              -0.00431083, -0.00325729],
             ...,
             [ 0.00318022,  0.00232103, -0.00479182, ...,  1.        ,
               0.99989517,  0.99902319],
             [ 0.00375161,  0.00290595, -0.00431083, ...,  0.99989517,
               1.        ,  0.9995581 ],
             [ 0.00488347,  0.00407733, -0.00325729, ...,  0.99902319,
               0.9995581 ,  1.        ]], dtype=float64)