# Single effective lengthscale

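Given a sparse mixture { (k_j, ℓ_j) } with weights w_j = θ_j / Σ_i θ_i, define the per-component correlation time as the integral of the normalized kernel, τ_j = ∫_0^∞ k_j(t)/k_j(0) dt, which for the kernel families used here evaluates to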

$$
\tau_j = c_{k_j}\,\ell_j, \quad
c_{\mathrm{m12}}=1,\; c_{\mathrm{m32}}=\frac{2}{\sqrt{3}},\;
c_{\mathrm{m52}}=\frac{8}{3\sqrt{5}},\; c_{\mathrm{rbf}}=\sqrt{\frac{\pi}{2}}.
$$

Use the weighted geometric mean of the τ_j (with weights w_j) as a single scalar statistic:

$$
\ell_{\mathrm{eff}} \equiv \exp\!\Big(\sum_j w_j \log \tau_j\Big).
$$

With M points and m samples per lengthscale: Δt = ℓ_eff / m, T = (M-1)Δt, domain [-L, L] with L = T/2.

# Weighted rank r (shared across kernels; Hilbert on [-L, L])

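Fix a spectral tolerance ε (e.g., ε=10^{-2}). For each kernel j, solve S_j(ω_c)/S_j(0)=ε for the cutoff ω_{c,j}, where S_j is the kernel's power spectral density. Closed forms (angular frequency; ℓ stands for the component's own lengthscale ℓ_j):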

$$
\begin{aligned}
\text{m12: } & \omega_c = \tfrac{1}{\ell}\sqrt{\varepsilon^{-1}-1},\\
\text{m32: } & \omega_c = \tfrac{\sqrt{3}}{\ell}\sqrt{\varepsilon^{-1/2}-1},\\
\text{m52: } & \omega_c = \tfrac{\sqrt{5}}{\ell}\sqrt{\varepsilon^{-1/3}-1},\\
\text{rbf: } & \omega_c = \sqrt{2\ln(1/\varepsilon)}\,/\,\ell.
\end{aligned}
$$

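On [-L, L] with Dirichlet boundary conditions, the Laplacian eigenfrequencies are \( s_n = n\pi/(2L) \). The per-kernel mode count is the smallest index whose eigenfrequency reaches the cutoff, \( s_{r_j} \ge \omega_{c,j} \):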

$$
r_j = \Big\lceil \frac{2L}{\pi}\,\omega_{c,j} \Big\rceil.
$$

Aggregate to a single rank with mixture weights and a small safety factor γ (e.g., γ=1.1):

$$
r = \left\lceil \gamma \sum_j w_j\, r_j \right\rceil
\quad \text{(or use a weighted quantile if you prefer robustness).}
$$

# Hilbert kernel construction (for reference)

Use sine basis on [-L, L]:

$$
\phi_n(x)=\sqrt{\tfrac{1}{L}}\;\sin\!\Big(\tfrac{n\pi}{2L}\,(x+L)\Big),\quad
s_n=\tfrac{n\pi}{2L},\ n=1,\dots,r,
$$

with diagonal weights from the base PSD:

$$
\lambda_n = S\!\big(s_n\big),\qquad
k(x,x') \approx \sum_{n=1}^r \lambda_n\,\phi_n(x)\,\phi_n(x').
$$


In [123]:
# Notebook parameters (defaults).
seed = 0  # seed of the master PRNG key

noise_floor_db = -40.0  # NOTE(review): not referenced in the visible cells

jax_enable_x64 = True  # value passed to jax.config for "jax_enable_x64"
jax_platform_name = "gpu"  # value passed to jax.config for "jax_platform_name"
batch_size = 1  # NOTE(review): not referenced in the visible cells
num_metrics_samples = 1  # NOTE(review): not referenced in the visible cells

# The number of mixture components is I = N_kernels * N_ell
N_kernels = 2
N_ell = 100

alpha_scale = 1.0  # multiplies the value returned by solve_for_alpha(I)

P = 0  # Kernel identification: no_filter mode

# Validate explicitly: a bare `assert` is stripped under `python -O` and
# fails without any explanation.
if N_kernels > 4:
    raise ValueError(f"N_kernels must be at most 4, got {N_kernels}")

In [125]:
import jax
import jax.numpy as jnp

# Apply the JAX settings immediately after importing jax: whichever code sets
# these config flags first wins, so do it before anything else touches JAX.
for _flag, _value in (
    ("jax_enable_x64", jax_enable_x64),
    ("jax_platform_name", jax_platform_name),
):
    jax.config.update(_flag, _value)

In [126]:
import gpjax as gpx
import jax
import jax.numpy as jnp
import numpy as np
import scipy
from matplotlib import pyplot as plt

from iklp.hyperparams import (
    active_components,
    pi_kappa_hyperparameters,
)
from iklp.mercer import psd_svd
from iklp.mercer_op import sample_parts
from iklp.metrics import (
    StateMetrics,
    compute_metrics_power_distribution,
    compute_power_distibution,
)
from iklp.run import print_progress, vi_run_criterion
from iklp.state import (
    compute_expectations,
    sample_x_from_z,
    sample_z_from_prior,
)
from utils.jax import maybe32, vk
from iklp.hyperparams import solve_for_alpha


In [128]:
# Root PRNG key; later cells split sub-keys off it.
master_key = jax.random.PRNGKey(seed)

# Total number of mixture components.
I = N_kernels * N_ell

# Prior settings.
pi, kappa = 0.95, 1.0
alpha = alpha_scale * solve_for_alpha(I)

# Only the leading (component) dimension of this array is non-empty: (I, 0, 0).
_empty_basis = jnp.empty((I, 0, 0))
h = pi_kappa_hyperparameters(
    _empty_basis,
    alpha=maybe32(alpha),
    pi=pi,
    kappa=kappa,
    P=P,
)

In [129]:
# Split a fresh sub-key off the master key and draw z from the prior.
master_key, key = jax.random.split(master_key)
z = sample_z_from_prior(key, h)

# Summary of the draw, printed as "<label> <value>" pairs.
_theta_total = z.theta.sum()
for _label, _value in (
    ("nu_w", z.nu_w),
    ("nu_e", z.nu_e),
    ("sum(theta) = ", _theta_total),
    ("total power = ", z.nu_w * _theta_total + z.nu_e),
    ("pitchedness = ", z.nu_w / (z.nu_w + z.nu_e)),
    ("I_eff =", active_components(z.theta)),
):
    print(_label, _value)

nu_w 0.29892501804909066
nu_e 0.0001801058656661295
sum(theta) =  0.4750795458991217
total power =  0.1421932676983149
pitchedness =  0.9993978509518364
I_eff = 5.409748869170837


In [130]:
# Inspect the sampled mixture weights theta. Normalized, they are the weights
# w_j = theta_j / sum_i theta_i that enter the representative (effective)
# lengthscale defined in the notes at the top of this notebook.
z.theta

Array([4.98975922e-008, 4.09088452e-033, 2.20860400e-113, 3.07347916e-068,
       2.03314003e-121, 4.93529516e-084, 8.65480606e-038, 8.46471723e-085,
       1.89553860e-065, 5.48951229e-311, 2.12491478e-020, 3.23789189e-205,
       3.77557890e-063, 2.04030481e-312, 1.82051419e-133, 6.65943898e-038,
       4.59994264e-047, 1.24633107e-009, 5.87802473e-018, 8.54796692e-050,
       1.67506363e-106, 1.06229641e-223, 1.64179370e-017, 1.47369270e-135,
       2.49461070e-057, 1.77956840e-024, 2.17654496e-080, 2.69213199e-088,
       2.08882879e-099, 1.03764990e-196, 0.00000000e+000, 3.84533267e-014,
       9.21607979e-084, 1.56271709e-013, 3.84657627e-038, 9.25630247e-134,
       5.85199185e-045, 1.22022841e-056, 1.93275347e-028, 2.69605261e-043,
       2.18292691e-064, 1.30685310e-001, 2.83703679e-030, 7.30556584e-016,
       4.58960804e-203, 2.22486492e-019, 6.34348097e-066, 1.12461350e-106,
       6.52281570e-018, 1.29394221e-267, 5.93940950e-011, 6.86080163e-008,
       8.25326823e-093, 4

In [None]:
# Candidate kernel families, each built for one input dimension.
_kernel_classes = (
    gpx.kernels.Matern12,
    gpx.kernels.Matern32,
    gpx.kernels.Matern52,
    gpx.kernels.RBF,
)
kernels = [_make(n_dims=1) for _make in _kernel_classes]

print([k.name for k in kernels])

# 100 log-spaced lengthscales between 1e-1 and 1e1.
ells = jnp.logspace(-1, 1, num=100)

# 200 evenly spaced inputs on [-3, 3], as a column vector.
t = jnp.linspace(-3, 3.0, num=200)[:, None]



['Matérn12', 'Matérn32', 'Matérn52', 'RBF']


['Matérn12', 'Matérn32', 'Matérn52', 'RBF']

In [None]:
# From here on there is one component per kernel in `kernels`
# (this rebinds the global I set earlier from N_kernels * N_ell).
I = len(kernels)

# Dense Gram matrix of every kernel on the grid t, stacked on a new leading axis.
_gram_matrices = [kernel.gram(t).to_dense() for kernel in kernels]
K = jnp.stack(_gram_matrices, axis=0)
Phi = psd_svd(K)


Array([ 0.1       ,  0.10476158,  0.10974988,  0.1149757 ,  0.12045035,
        0.12618569,  0.13219411,  0.13848864,  0.14508288,  0.15199111,
        0.15922828,  0.16681005,  0.17475284,  0.18307383,  0.19179103,
        0.2009233 ,  0.21049041,  0.22051307,  0.23101297,  0.24201283,
        0.25353645,  0.26560878,  0.27825594,  0.29150531,  0.30538555,
        0.31992671,  0.33516027,  0.35111917,  0.36783798,  0.38535286,
        0.40370173,  0.42292429,  0.44306215,  0.46415888,  0.48626016,
        0.5094138 ,  0.53366992,  0.55908102,  0.58570208,  0.61359073,
        0.64280731,  0.67341507,  0.70548023,  0.7390722 ,  0.77426368,
        0.81113083,  0.84975344,  0.89021509,  0.93260335,  0.97700996,
        1.02353102,  1.07226722,  1.12332403,  1.17681195,  1.23284674,
        1.29154967,  1.35304777,  1.41747416,  1.48496826,  1.55567614,
        1.62975083,  1.70735265,  1.78864953,  1.87381742,  1.96304065,
        2.05651231,  2.15443469,  2.25701972,  2.36448941,  2.47

In [None]:
# Plotting and callback setup
CMAP = plt.get_cmap("coolwarm")

# Metrics snapshots recorded by the on_metrics() callback, in call order.
collected_metrics = []

# WARNING: any function used in the on_metrics() callback needs to be jitted,
# otherwise it triggers recompilation at every iteration.
compute_expectations, compute_metrics_power_distribution = (
    jax.jit(_fn)
    for _fn in (compute_expectations, compute_metrics_power_distribution)
)


def on_metrics(metrics: StateMetrics):
    """Report progress for *metrics* and record it in ``collected_metrics``.

    Args:
        metrics: the metrics object passed to this callback.
    """
    print_progress(metrics)
    # The module-level list is only mutated in place, never rebound, so no
    # `global` declaration is required.
    collected_metrics.append(metrics)


In [None]:
# Prior hyperparameters for this experiment
pi, kappa = 0.95, 1.0
alpha = 0.1  # solve_for_alpha(I) => ensure one component dominates

h = pi_kappa_hyperparameters(
    Phi,
    alpha=maybe32(alpha),
    pi=pi,
    kappa=kappa,
    P=0,
)

# Draw z from the prior, then x given z
z = sample_z_from_prior(vk(), h)
x = sample_x_from_z(vk(), z, h)

# Mean square of the sampled x, compared below against nu_w + nu_e
power = jnp.mean(x**2)
_nu_total = z.nu_w + z.nu_e

for _label, _value in (
    ("nu_w", z.nu_w),
    ("nu_e", z.nu_e),
    ("sum(theta) = ", z.theta.sum()),
    ("pitchedness = ", z.nu_w / _nu_total),
    ("I_eff =", active_components(z.theta)),
    ("power(x)/(nu_w + nu_e) = ", power / _nu_total),
):
    print(_label, _value)

# Show sampled timeseries x
plt.figure()
plt.plot(x)
# plt.plot(x) with a single argument puts the sample index on the horizontal
# axis and the values of x on the vertical axis; label them accordingly.
plt.xlabel("sample index")
plt.ylabel("x")
plt.title("Sampled x")
plt.show()


# Show power waterfall plot
plt.figure()

# Start from an empty history so that re-running this cell does not mix in
# snapshots recorded by an earlier run.
collected_metrics.clear()

# on_metrics() prints progress and records every metrics snapshot in
# collected_metrics; the snapshots are plotted below.
state, metrics = vi_run_criterion(vk(), x, h, callback=on_metrics)

# One line per recorded snapshot, coloured by its position in the run.
for i, ms in enumerate(collected_metrics):
    power_distribution = compute_metrics_power_distribution(ms)
    plt.plot(
        power_distribution,
        linewidth=1.0,
        color=CMAP(i / len(collected_metrics)),
    )

# Distribution from the metrics returned by the run, and the one computed
# from the sampled z (ground truth).
inferred_power_distribution = compute_metrics_power_distribution(metrics)

true_power_distribution = compute_power_distibution(z)

plt.stem(
    true_power_distribution,
    linefmt="r-",
    markerfmt="ro",
    label="true power distribution",
    basefmt=" ",
)

# Tick 0 is labelled "noise", followed by one tick per kernel index.
labels = ["noise"] + [str(i) for i in range(I)]

plt.xticks(ticks=np.arange(I + 1), labels=labels)
plt.xlabel("kernel index $i$")
plt.ylabel("relative power")
plt.title("Power distribution waterfall through VI and ground truth")
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.legend(loc="best")
plt.show()

# Calculate performance scores

# exp(KL(true || inferred)): equals 1.0 when the two distributions coincide.
_kl_divergence = scipy.stats.entropy(
    pk=true_power_distribution,
    qk=inferred_power_distribution,
)
score = np.exp(_kl_divergence)

print(
    "Score(DKL) of inferred power distribution (lower is better, 1.0 is perfect): ",
    score,
)

# Both distributions live on the same support 0..I.
_support = np.arange(I + 1)
wasserstein = scipy.stats.wasserstein_distance(
    u_values=_support,
    v_values=_support,
    u_weights=true_power_distribution,
    v_weights=inferred_power_distribution,
)  # symmetric

print(
    "Score(Wasserstein) between true and inferred power distribution (lower is better, 0.0 is perfect): ",
    wasserstein,
)


In [None]:
from iklp.state import compute_auxiliaries

# Auxiliary quantities of the state returned by vi_run_criterion
aux = compute_auxiliaries(state)
op = aux.Omega

# Draw one (signal, noise) pair from the operator
signal, noise = sample_parts(op, vk())

# Overlay the observed x with the sampled parts
for _series, _name in ((x, "x"), (signal, "signal"), (noise, "noise")):
    plt.plot(_series, label=_name)
plt.title("Sampled signal and noise parts | (E(z),)")
plt.legend()


In [None]:
plt.plot(x, label="x")
plt.plot(metrics.signals.T, label="signal")

# Residual of x after removing the signals. It used to be computed but never
# drawn, although the title announces both parts; plot it with the same
# orientation as the signals.
# NOTE(review): assumes x broadcasts against metrics.signals as intended - confirm shapes.
noises = x - metrics.signals
plt.plot(noises.T, label="noise")

plt.title("Sampled signal and noise parts | (E(z), x)")
plt.legend()