In [None]:
import iklp

In [None]:
from iklp import gig, hyperparams, mercer_op, periodic, state, util, vi

In [None]:
import jax.random

# Call the project helper with a fixed PRNG key (seed 0) and display the result.
_prng_key = jax.random.PRNGKey(0)
hyperparams.random_periodic_kernel_hyperparams(_prng_key)

In [None]:
import jax.numpy as jnp

from bngif.iklp import build_Psi, build_X

from utils.plotting import iplot

In [None]:
# Reference — efax (exponential-family distributions for JAX): https://github.com/NeilGirdhar/efax

In [None]:
import jax
from numpy.random import default_rng, randn

# Build the project's Psi / X matrices for a long signal and display their shapes.
M, P = 2048, 30
rng = default_rng(0)  # seeded so Restart & Run All reproduces the same `a`
a = rng.standard_normal(P)
x = jnp.linspace(1.0, 6.0, M)

Psi = build_Psi(M, a)
X = build_X(x, P)

Psi.shape, X.shape

In [None]:
# Infinite Kernel Linear Prediction (IKLP) Variational Inference Implementation (Yoshii & Goto 2013)
# pip install jax jaxlib
# pip install scipy

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
import scipy.special as sp


def build_X_matrix(x, p):
    """
    Build the lagged design matrix X of an order-p AR model (Eq. (3) structure).

    X has shape (N, p) with X[n, j] = x[n - j - 1] when n > j and 0 otherwise,
    so (X @ a)[n] = sum_j a[j] * x[n - j - 1]. Orders p >= N are supported:
    columns whose lag reaches past the start of x are all zero.

    Args:
        x: 1-D signal of length N.
        p: AR order (number of columns), p >= 0.

    Returns:
        jnp array of shape (N, p).
    """
    x = jnp.asarray(x)
    N = x.shape[0]
    if p == 0:
        return jnp.zeros((N, 0))
    # Prepend p zeros once; column m-1 is then the length-N window of the padded
    # signal starting at offset p - m, i.e. x delayed by m samples. This stays
    # length N even when m >= N (the original per-column padding did not).
    padded = jnp.concatenate([jnp.zeros(p), x])
    return jnp.stack([padded[p - m : p - m + N] for m in range(1, p + 1)], axis=1)


def _gamma_prior_moments(shape, rate):
    """Return (E[v], E[1/v]) for v ~ Gamma(shape, rate) (rate parameterisation).

    E[1/v] = rate / (shape - 1) exists only for shape > 1; otherwise the
    reciprocal of the mean is used as a finite stand-in.
    """
    mean = shape / rate
    inv_mean = rate / (shape - 1.0) if shape > 1 else 1.0 / mean
    return mean, inv_mean


def _gig_moments(gamma, rho, tau):
    """Return (E[y], E[1/y]) for y ~ GIG(gamma, rho, tau), elementwise.

    E[y]   = sqrt(tau/rho) * K_{gamma+1}(z) / K_gamma(z)
    E[1/y] = sqrt(rho/tau) * K_{gamma-1}(z) / K_gamma(z),   z = sqrt(rho*tau)

    The exponentially scaled Bessel function kve(v, z) = kv(v, z) * exp(z) is
    used: the exp(z) factor cancels in both ratios, but unlike kv it does not
    underflow to 0 for large z (which would give 0/0 = NaN).
    """
    rho = jnp.asarray(rho)
    tau = jnp.asarray(tau)
    z = jax.device_get(jnp.sqrt(rho * tau))  # host array for SciPy
    k_base = sp.kve(gamma, z)
    k_plus = sp.kve(gamma + 1.0, z)
    k_minus = sp.kve(gamma - 1.0, z)
    mean = jnp.sqrt(tau / rho) * jnp.asarray(k_plus / k_base)
    inv_mean = jnp.sqrt(rho / tau) * jnp.asarray(k_minus / k_base)
    return mean, inv_mean


def _weighted_kernel_sum(weights, K_list):
    """Return sum_i weights[i] * K_list[i]."""
    return sum(w * K for w, K in zip(weights, K_list))


def _init_expectations(I, alpha, aw, bw, ae, be):
    """Initial (E[θ], E[1/θ], E[ν_w], E[1/ν_w], E[ν_e], E[1/ν_e]) from the priors.

    Priors (rate parameterisation, matching the log-prior terms of the ELBO):
    θ_i ~ Gamma(alpha/I, alpha), ν_w ~ Gamma(aw, bw), ν_e ~ Gamma(ae, be).
    """
    theta_mean, theta_inv_mean = _gamma_prior_moments(alpha / I, alpha)
    E_nu_w, E_inv_nu_w = _gamma_prior_moments(aw, bw)
    E_nu_e, E_inv_nu_e = _gamma_prior_moments(ae, be)
    return (
        jnp.full((I,), theta_mean),
        jnp.full((I,), theta_inv_mean),
        E_nu_w,
        E_inv_nu_w,
        E_nu_e,
        E_inv_nu_e,
    )


def _update_expectations(exps, tr_vals, tr_I, uKu_vals, uu, alpha, aw, bw, ae, be):
    """One mean-field step (Eq. (27)): GIG posteriors for θ, ν_w, ν_e and their moments.

    `exps` holds the current expectations; every posterior parameter is computed
    from these old values before any of them is replaced.
    """
    E_theta, E_inv_theta, E_nu_w, E_inv_nu_w, E_nu_e, E_inv_nu_e = exps
    I = E_theta.shape[0]
    # Auxiliary weights of the quadratic-form bound: 1 / E[1/(ν_w θ_i)] and 1 / E[1/ν_e],
    # using E[1/(ν_w θ_i)] = E[1/ν_w] E[1/θ_i] under the factorised q.
    c_theta = 1.0 / (E_inv_nu_w * E_inv_theta)
    c_e = 1.0 / E_inv_nu_e
    # q(θ_i) = GIG(alpha/I, ρ_i, τ_i)
    rho_theta = 2 * alpha + E_nu_w * tr_vals
    tau_theta = E_inv_nu_w * c_theta**2 * uKu_vals
    # q(ν_w) = GIG(aw, ρ_w, τ_w)
    rho_nu_w = 2 * bw + jnp.sum(E_theta * tr_vals)
    tau_nu_w = jnp.sum(E_inv_theta * c_theta**2 * uKu_vals)
    # q(ν_e) = GIG(ae, ρ_e, τ_e)
    rho_nu_e = 2 * be + tr_I
    tau_nu_e = c_e**2 * uu
    return (
        *_gig_moments(alpha / I, rho_theta, tau_theta),
        *_gig_moments(aw, rho_nu_w, tau_nu_w),
        *_gig_moments(ae, rho_nu_e, tau_nu_e),
    )


def _solve_ar_map(A, b, lam):
    """MAP AR coefficients for the residual e = x + X a under a ~ N(0, lam^{-1} I).

    With A = X^T W X and b = X^T W x, minimising e^T W e + lam * a^T a gives
    (A + lam I) a = -b; the minus sign comes from e = x + X a.
    """
    p = b.shape[0]
    if p == 0:
        return jnp.zeros((0,))
    L_small = jnp.linalg.cholesky(A + lam * jnp.eye(p))
    y = solve_triangular(L_small, b, lower=True)
    return -solve_triangular(L_small.T, y, lower=False)


def _elbo_proxy(N, logdet_Omega, exps, tr_vals, tr_I, uKu_vals, uu, a, p, lam, alpha, aw, bw, ae, be):
    """Monitoring objective: bounded E[log p(x|θ,a,ν)] plus log-prior terms.

    Entropy terms of q are omitted, so this is a convergence monitor rather
    than a true ELBO.
    """
    E_theta, E_inv_theta, E_nu_w, E_inv_nu_w, E_nu_e, E_inv_nu_e = exps
    I = E_theta.shape[0]
    term1 = -0.5 * (
        N * jnp.log(2 * jnp.pi)
        + logdet_Omega
        + jnp.sum(E_nu_w * E_theta * tr_vals)
        + E_nu_e * tr_I
        + jnp.sum(E_inv_nu_w * E_inv_theta * uKu_vals)
        + E_inv_nu_e * uu
    )
    # θ prior Gamma(alpha/I, alpha); ν_w prior Gamma(aw, bw); ν_e prior Gamma(ae, be)
    term2 = jnp.sum((alpha / I - 1.0) * jnp.log(E_theta + 1e-12) - alpha * E_theta)
    term3 = (aw - 1.0) * jnp.log(E_nu_w + 1e-12) - bw * E_nu_w
    term4 = (ae - 1.0) * jnp.log(E_nu_e + 1e-12) - be * E_nu_e
    # a prior N(0, lam^{-1} I): lam is a precision
    term5 = -0.5 * (lam * (a @ a) - p * jnp.log(lam + 1e-12))
    return term1 + term2 + term3 + term4 + term5


def _pack_result(a, exps, elbo_history):
    """Bundle the final AR coefficients, expectations and ELBO trace into a dict."""
    E_theta, E_inv_theta, E_nu_w, E_inv_nu_w, E_nu_e, E_inv_nu_e = exps
    return {
        "a": a,
        "E_theta": E_theta,
        "E_inv_theta": E_inv_theta,
        "E_nu_w": E_nu_w,
        "E_inv_nu_w": E_inv_nu_w,
        "E_nu_e": E_nu_e,
        "E_inv_nu_e": E_inv_nu_e,
        "ELBO_history": jnp.array(elbo_history),
    }


def iklp_vi_naive(
    x, a, K_list, p, aw, bw, ae, be, alpha, max_iter=100, tol=1e-6, lam=1e-6
):
    """
    Naive IKLP Variational Inference.

    Uses explicit matrix inverses as in the paper (Equations (15)-(27)); this is
    the straightforward reference implementation.

    Args:
        x: signal of length N.
        a: initial AR coefficients, shape (p,).
        K_list: list of I kernel matrices, each N×N.
        p: AR order.
        aw, bw: shape/rate of the Gamma prior on ν_w.
        ae, be: shape/rate of the Gamma prior on ν_e.
        alpha: concentration; θ_i ~ Gamma(alpha/I, alpha).
        max_iter: maximum number of VI iterations.
        tol: stop when the ELBO proxy changes by less than this.
        lam: precision of the Gaussian prior a ~ N(0, lam^{-1} I).

    Returns:
        dict with keys "a", "E_theta", "E_inv_theta", "E_nu_w", "E_inv_nu_w",
        "E_nu_e", "E_inv_nu_e", "ELBO_history".
    """
    x = jnp.asarray(x)
    N = x.shape[0]
    I_N = jnp.eye(N)
    X = build_X_matrix(x, p)  # depends only on x and p: build once
    exps = _init_expectations(len(K_list), alpha, aw, bw, ae, be)
    elbo_history = []
    last_elbo = -jnp.inf
    for _ in range(max_iter):
        E_theta, E_inv_theta, E_nu_w, E_inv_nu_w, E_nu_e, E_inv_nu_e = exps
        # Eq. (24): Ω = E[ν_w] Σ_i E[θ_i] K_i + E[ν_e] I
        Omega = E_nu_w * _weighted_kernel_sum(E_theta, K_list) + E_nu_e * I_N
        Omega_inv = jnp.linalg.inv(Omega)  # explicit inverse (naive)
        # Eq. (25): S = Σ_i E[1/(ν_w θ_i)]^{-1} K_i + E[1/ν_e]^{-1} I
        S = (
            _weighted_kernel_sum(1.0 / (E_inv_nu_w * E_inv_theta), K_list)
            + (1.0 / E_inv_nu_e) * I_N
        )
        S_inv = jnp.linalg.inv(S)
        # Residual e = Ψ x = x + X a (Eq. (3)) and u = S^{-1} e
        e = x + X @ a
        u = S_inv @ e
        tr_vals = jnp.array([jnp.trace(Omega_inv @ K) for K in K_list])
        uKu_vals = jnp.array([u @ (K @ u) for K in K_list])
        tr_I = jnp.trace(Omega_inv)
        uu = u @ u
        exps = _update_expectations(
            exps, tr_vals, tr_I, uKu_vals, uu, alpha, aw, bw, ae, be
        )
        # AR update with the current Ω: (X^T Ω^{-1} X + lam I) a = -X^T Ω^{-1} x
        a = _solve_ar_map(X.T @ (Omega_inv @ X), X.T @ (Omega_inv @ x), lam)
        # slogdet avoids the overflow of log(det(Ω)) for large N
        logdet_Omega = jnp.linalg.slogdet(Omega)[1]
        L_elbo = _elbo_proxy(
            N, logdet_Omega, exps, tr_vals, tr_I, uKu_vals, uu, a, p, lam,
            alpha, aw, bw, ae, be,
        )
        elbo_history.append(L_elbo)
        if jnp.abs(L_elbo - last_elbo) < tol:
            break
        last_elbo = L_elbo
    return _pack_result(a, exps, elbo_history)


def iklp_vi_efficient(
    x, a, K_list, p, aw, bw, ae, be, alpha, max_iter=100, tol=1e-6, lam=1e-6
):
    """
    Efficient IKLP Variational Inference.

    Same updates as `iklp_vi_naive`, but Ω and S are handled through their
    Cholesky factors (so both must be symmetric positive definite). S is never
    inverted. Ω^{-1} is formed once per iteration from its Cholesky factor for
    the trace terms, so each tr(Ω^{-1} K_i) costs O(N^2). The overall cost is
    still O(N^3) per iteration.

    Arguments and return value are identical to `iklp_vi_naive`.
    """
    x = jnp.asarray(x)
    N = x.shape[0]
    I_N = jnp.eye(N)
    X = build_X_matrix(x, p)  # depends only on x and p: build once
    exps = _init_expectations(len(K_list), alpha, aw, bw, ae, be)
    elbo_history = []
    last_elbo = -jnp.inf
    for _ in range(max_iter):
        E_theta, E_inv_theta, E_nu_w, E_inv_nu_w, E_nu_e, E_inv_nu_e = exps
        Omega = E_nu_w * _weighted_kernel_sum(E_theta, K_list) + E_nu_e * I_N
        L = jnp.linalg.cholesky(Omega)  # lower-triangular, Ω = L L^T
        S = (
            _weighted_kernel_sum(1.0 / (E_inv_nu_w * E_inv_theta), K_list)
            + (1.0 / E_inv_nu_e) * I_N
        )
        Ls = jnp.linalg.cholesky(S)
        # Solve S u = e by forward/backward substitution
        e = x + X @ a
        u = solve_triangular(Ls.T, solve_triangular(Ls, e, lower=True), lower=False)
        # Ω^{-1} = L^{-T} L^{-1}; tr(Ω^{-1} K) = Σ (Ω^{-1} ∘ K^T) is O(N^2) per kernel
        L_inv = solve_triangular(L, I_N, lower=True)
        Omega_inv = L_inv.T @ L_inv
        tr_vals = jnp.array([jnp.sum(Omega_inv * K.T) for K in K_list])
        uKu_vals = jnp.array([u @ (K @ u) for K in K_list])
        tr_I = jnp.trace(Omega_inv)
        uu = u @ u
        exps = _update_expectations(
            exps, tr_vals, tr_I, uKu_vals, uu, alpha, aw, bw, ae, be
        )
        # Whitened weighted least squares: z = L^{-1} X, y = L^{-1} x
        z = L_inv @ X
        y_vec = L_inv @ x
        a = _solve_ar_map(z.T @ z, z.T @ y_vec, lam)
        logdet_Omega = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))  # log|Ω|
        L_elbo = _elbo_proxy(
            N, logdet_Omega, exps, tr_vals, tr_I, uKu_vals, uu, a, p, lam,
            alpha, aw, bw, ae, be,
        )
        elbo_history.append(L_elbo)
        if jnp.abs(L_elbo - last_elbo) < tol:
            break
        last_elbo = L_elbo
    return _pack_result(a, exps, elbo_history)

In [None]:
# Small end-to-end run of
# iklp_vi_naive(x, a, K_list, p, aw, bw, ae, be, alpha, max_iter=100, tol=1e-6)

from math import sqrt
from numpy.random import default_rng, randn

M = 20
P = 4
_lambda = 0.1  # variance used to draw the initial AR coefficients
SEED = 0

rng = default_rng(SEED)  # seeded so Restart & Run All reproduces the run


def randK(size=None):
    """Return a random symmetric positive semi-definite size×size matrix R Rᵀ (size defaults to M)."""
    n = M if size is None else size
    R = rng.standard_normal((n, n))
    return R @ R.T


x = rng.standard_normal(M)
a = rng.standard_normal(P) * sqrt(_lambda)
K_list = [randK() for _ in range(4)]
p = P
aw = 1.0
bw = 1.0
ae = 1.0
be = 1.0
alpha = 1.0
max_iter = 20
tol = 1e-6

# Run the naive version
result_naive = iklp_vi_naive(
    x, a, K_list, p, aw, bw, ae, be, alpha, max_iter=max_iter, tol=tol
)

# NOTE(review): the last two ELBO values are dropped from the plot; the reason
# is not recorded here — confirm this is intentional.
iplot(
    result_naive["ELBO_history"][:-2],
    title="Naive IKLP Variational Inference ELBO",
)


In [None]:
# Final AR coefficients, variational expectations and ELBO trace from the naive run.
result_naive

In [None]:
# Documentation of the modified Bessel function of the second kind used for the
# GIG moments (plain `help` instead of IPython-only `??`, so the notebook also
# runs when exported as a .py script).
help(sp.kv)