In [None]:
import gpjax as gpx
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt

from iklp.hyperparams import pi_kappa_hyperparameters, solve_for_alpha
from iklp.mercer import psd_svd
from iklp.run import CriterionState, print_progress, vi_run_criterion
from iklp.state import (
    LatentVars,
    compute_expectations,
    sample_x_from_z,
    sample_z_from_prior,
)
from utils.jax import maybe32, vk


In [None]:

# Candidate kernels; a later cell stacks one Gram matrix per entry of this list.
kernels = [
    gpx.kernels.Matern12(),
    gpx.kernels.Matern32(),
    gpx.kernels.Matern52(),
    gpx.kernels.RBF(),
]

# Number of kernels (reused by later cells).
I = len(kernels)

# 2x2 grid of axes: one panel for each of the four kernels above.
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(7, 6), tight_layout=True)

# Input grid: 200 evenly spaced points on [-3, 3], shaped as a (200, 1) column.
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)

meanf = gpx.mean_functions.Zero()

# strict=False: zip stops silently at the shorter of (kernels, axes) instead of raising.
for k, ax in zip(kernels, axes.ravel(), strict=False):
    prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
    rv = prior(x)
    # Ten draws per kernel; vk() supplies the sampling key
    # (presumably a fresh PRNG key on each call -- confirm in utils.jax).
    y = rv.sample(key=vk(), sample_shape=(10,))
    # Transposed so that each draw becomes its own line (assumes y is (10, 200) -- TODO confirm).
    ax.plot(x, y.T, alpha=0.7)
    ax.set_title(k.name)



In [None]:

# Dense Gram matrix of every kernel on the grid x, stacked along a new leading axis
# (so K[i] belongs to kernels[i]).
K = jnp.stack([k.gram(x).to_dense() for k in kernels], axis=0)


# Factor each Gram matrix; the check below tests that K[i] == Phi[i] @ Phi[i].T.
Phi = psd_svd(K)

print("Phi shape", Phi.shape)
print(
    # swapaxes(1, 2) transposes the two trailing axes, leaving the kernel axis in place.
    "SVD decomposition allclose()?", jnp.allclose(K, Phi @ Phi.swapaxes(1, 2))
)



In [None]:
# alpha is obtained from the number of kernels I and forwarded to the hyperparameter builder.
alpha = solve_for_alpha(I)

# pi and kappa are passed straight to pi_kappa_hyperparameters; their meaning is defined there.
pi = 0.95
kappa = 1.0


# Hyperparameters built from the kernel factors Phi computed in the previous cell.
h = pi_kappa_hyperparameters(Phi, alpha=maybe32(alpha), pi=pi, kappa=kappa, P=0)

# Draw latent variables from the prior described by h.
z = sample_z_from_prior(vk(), h)

print("nu_w", z.nu_w)
print("nu_e", z.nu_e)
print("sum = ", z.nu_w + z.nu_e)
# Ratio nu_w / (nu_w + nu_e), reported under the label "pitchedness".
print("pitchedness = ", z.nu_w / (z.nu_w + z.nu_e))

# NOTE: this rebinds the global `x`. Earlier cells use `x` for the input grid; from here on
# `x` is the sampled signal. Re-running the Gram-matrix cell after this one would therefore
# evaluate the kernels on the sample instead of the grid.
x = sample_x_from_z(vk(), z, h)
plt.plot(x)
plt.xlabel("x")
plt.title("Sampled x")
plt.show()

# Mean square of the sample, printed relative to nu_w + nu_e.
power = jnp.mean(x**2)
print("power(x)/(nu_w + nu_e) = ", power / (z.nu_w + z.nu_e))


def compute_power_distibution(z):
    """Relative power of the noise term and of each kernel.

    Args:
        z: object exposing ``nu_e``, ``nu_w`` and ``theta``.

    Returns:
        Array whose first entry is the share of ``nu_e`` and whose remaining
        entries are the shares of ``nu_w * theta``; the entries sum to one.
    """
    noise_power = jnp.array([z.nu_e])
    kernel_power = z.nu_w * z.theta
    unnormalized = jnp.concatenate((noise_power, kernel_power))
    total = jnp.sum(unnormalized)
    return unnormalized / total  # (I+1,)


def compute_state_power_distribution(state):
    """Normalized power distribution evaluated at the expectations of ``state``.

    The ``theta``, ``nu_w`` and ``nu_e`` fields returned by
    ``compute_expectations(state)`` are packed into a ``LatentVars`` (fourth
    field set to ``None``) and handed to ``compute_power_distibution``.
    """
    expectations = compute_expectations(state)
    expected_latents = LatentVars(
        expectations.theta,
        expectations.nu_w,
        expectations.nu_e,
        None,
    )
    return compute_power_distibution(expected_latents)


# Power shares of the latent variables z sampled from the prior above (index 0 is the noise term).
true_power_distribution = compute_power_distibution(z)

# Red stems with the baseline hidden (basefmt=" ").
plt.stem(
    true_power_distribution,
    linefmt="r-",
    markerfmt="ro",
    label="true power distribution",
    basefmt=" ",
)
plt.legend()
plt.xlabel("kernel index $i$")
plt.ylabel("relative power")
plt.title("True power distribution across kernels and noise")
plt.grid(True, which="both", linestyle="--", linewidth=0.5)


In [None]:

CMAP = plt.get_cmap("coolwarm")


def onstate(cs: CriterionState):
    """VI callback: report progress and overlay the current power estimate.

    Prints progress for ``cs``, prints the index of the largest expected
    ``theta``, and draws the normalized power distribution implied by the
    current expectations onto the active matplotlib figure.

    Args:
        cs: criterion state; ``cs.i`` and ``cs.state`` are read.
    """
    print_progress(cs)

    E = compute_expectations(cs.state)
    print(f"argmax_i(theta) = {np.argmax(E.theta)}")

    # Reuse the expectations computed above rather than deriving them a second
    # time through compute_state_power_distribution.
    z = LatentVars(E.theta, E.nu_w, E.nu_e, None)
    power_distribution = compute_power_distibution(z)

    # cs.i / 100 is the position in the colormap, so line colors sweep through
    # "coolwarm" as cs.i grows (assumes cs.i is the iteration counter -- confirm).
    color = CMAP(cs.i / 100)
    plt.plot(power_distribution, linewidth=1.0, color=color)



In [None]:


@jax.jit
def vi_run_criterion_callback(*args, **kwargs):
    """Jitted ``vi_run_criterion`` with ``onstate`` installed as the callback.

    All positional and keyword arguments are forwarded unchanged. Callers must
    not pass ``callback`` themselves; doing so trips the assertion.
    """
    assert "callback" not in kwargs
    return vi_run_criterion(*args, callback=onstate, **kwargs)

In [None]:

# Use global plot: open one figure that both the VI callback and the stem plot draw into.
plt.figure()

# Run VI on the sampled signal x; the installed callback plots onto the figure above
# (presumably once per iteration -- depends on how vi_run_criterion invokes it).
cs = vi_run_criterion_callback(vk(), x, h)

# Ground truth from the latent variables z that generated x.
true_power_distribution = compute_power_distibution(z)

plt.stem(
    true_power_distribution,
    linefmt="r-",
    markerfmt="ro",
    label="true power distribution",
    basefmt=" ",
)
# Without this call the label given to the stem plot above is never displayed.
plt.legend()

# Tick 0 is the noise term, ticks 1..I are the kernels (labelled 0..I-1).
I = len(kernels)
labels = ["noise"] + [str(i) for i in range(I)]

plt.xticks(ticks=np.arange(I + 1), labels=labels)
plt.xlabel("kernel index $i$")
plt.ylabel("relative power")
plt.title("True power distribution across kernels and noise")
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.show()



In [None]:
from iklp.mercer_op import *
from iklp.state import compute_auxiliaries

# Take the state from the CriterionState returned by the VI run. A bare `state`
# is not bound at notebook scope, so it raises NameError on a fresh kernel.
aux = compute_auxiliaries(cs.state)

op = aux.Omega


In [None]:
# Draw a (signal, noise) pair from op and overlay both on the observed x.
signal, noise = sample_parts(op, vk())

for _series, _label in ((x, "x"), (signal, "signal"), (noise, "noise")):
    plt.plot(_series, label=_label)
plt.legend()


In [None]:
# Draw a (signal, noise) pair from op given the observation x and overlay both on x.
signal, noise = sample_parts_given_observation(op, x, vk())

for _series, _label in ((x, "x"), (signal, "signal"), (noise, "noise")):
    plt.plot(_series, label=_label)
plt.legend()