# VI (variational inference)

Possible optimizations:

- Solve for $a$ with a conjugate-gradient (CG) solver, warm-started from the previous result
- Use COLA annotations


In [None]:
import jax
import jax.numpy as jnp

from bngif.iklp import build_Psi, build_X
from bngif.gig import GIG, Gamma

from utils.plotting import iplot


In [None]:
# --- Problem sizes -----------------------------------------------------------
M = 20  # length of x; each K[i] is M x M
P = 4  # length of a (presumably the AR order -- confirm)
I = 5  # number of K matrices
_lambda = 0.1  # variance scale of a (a is drawn with std sqrt(_lambda))

# --- Randomness: one split, one subkey per sampled quantity --------------------
key = jax.random.PRNGKey(0)
key, k1, k2, k3 = jax.random.split(key, 4)

x = jax.random.normal(k1, shape=(M,))
a = jnp.sqrt(_lambda) * jax.random.normal(k2, shape=(P,))

# K[i] = K_root[i]^T @ K_root[i]: a stack of I symmetric PSD Gram matrices.
K_root = jax.random.normal(k3, shape=(I, M, M))
K = jnp.swapaxes(K_root, 1, 2) @ K_root  # (I, M, M)

# --- Gamma hyperparameters ---------------------------------------------------
a_w, b_w = 1.0, 1.0
a_e, b_e = 1.0, 1.0
alpha = 1.0

In [None]:
# Sanity check that every K[i] is numerically positive definite: a failed JAX
# Cholesky factorization shows up as NaNs in the factor.
assert not jnp.isnan(jnp.linalg.cholesky(K)).any()

In [None]:
def _make_gamma_factors():
    """Build a fresh dict of Gamma factors keyed by latent variable.

    Keys: ``nu_w`` -> Gamma(a_w, b_w), ``nu_e`` -> Gamma(a_e, b_e) and
    ``theta`` -> Gamma(alpha / I, alpha) replicated over the I components.
    Hyperparameters are read from the notebook globals at call time.
    """
    return {
        "nu_w": Gamma(a_w, b_w),
        "nu_e": Gamma(a_e, b_e),
        "theta": Gamma(jnp.ones(I) * alpha / I, alpha),
    }


# Prior p(z) and variational factors q(z); q is initialised equal to the prior.
zp = _make_gamma_factors()
zq = _make_gamma_factors()


def is_gig(x):
    """Return True for GIG instances so jax.tree.map treats each one as a leaf.

    The Gamma factors above are GIG-parametrized, so without this predicate the
    tree map would descend into their internal arrays.
    """
    return isinstance(x, GIG)


# Entropy of every prior factor; kept for inspection (the cell displays zp).
tree = jax.tree.map(lambda d: d.entropy(), zp, is_leaf=is_gig)

zp

In [None]:
# moments() of each prior factor; is_leaf stops the traversal at every GIG so
# the method is called on whole distributions rather than on their arrays.
jax.tree.map(lambda dist: dist.moments(), zp, is_leaf=is_gig)

In [None]:
from jax.scipy.special import gammaln  # log-Gamma, for the Gamma normaliser


def E_log_p_under_q(alpha, beta, q: GIG):
    """Expected Gamma log-density under a GIG factor.

    Evaluates E_q[log Gamma(z | alpha, beta)] for q ~ GIG using

        log Gamma(z | alpha, beta) = alpha*log(beta) - gammaln(alpha)
                                     + (alpha - 1)*log(z) - beta*z,

    i.e. beta enters as a rate.

    Args:
        alpha: Gamma shape.
        beta: Gamma rate.
        q: GIG factor whose ``moments()`` yields a 3-tuple
            (E[z], <unused here>, E[log z]).

    Returns:
        The expectation, shaped like the broadcast of the inputs.
    """
    expected_z, _unused_moment, expected_log_z = q.moments()

    log_normaliser = alpha * jnp.log(beta) - gammaln(alpha)
    return log_normaliser + (alpha - 1) * expected_log_z - beta * expected_z


In [None]:
# KL[q || p] for nu_w from its pieces: E_q[log q] - E_q[log p] = -H[q] - E_q[log p].
-(E_log_p_under_q(a_w, b_w, zq["nu_w"]) + zq["nu_w"].entropy())

In [None]:
def KL_test(p, a, b):
    """Return GIG(p, a, b).KL_from_gamma against a fixed Gamma(1/2, 1/3) prior.

    Used to spot-check that the KL is non-negative over many (p, a, b) values;
    the inputs are expected to broadcast together (the driver below passes a
    scalar p with vector a and b).
    """
    test_q = GIG(p, a, b)
    # Named a0/b0 so they do not shadow the notebook-level a_w/b_w.
    a0 = 1 / 2.0
    b0 = 1 / 3.0
    return test_q.KL_from_gamma(a0, b0)


KL_test = jax.jit(KL_test)

# Draw fresh subkeys: k2 and k3 were already consumed for `a` and `K_root` in
# the setup cell, and reusing a JAX PRNG key produces correlated samples.
key, k_a, k_b = jax.random.split(key, 3)

# Sample 1000 random (a, b) pairs with p fixed at 1.0 and check that the KL is
# non-negative for each, in one broadcasted call.
# NOTE: this rebinds the notebook-level `a` from the setup cell.
p = 1.0
a = jax.random.uniform(k_a, shape=(1000,), minval=0.01, maxval=10)
b = jax.random.uniform(k_b, shape=(1000,), minval=0.01, maxval=10)
kl = KL_test(p, a, b)
iplot(kl.sort())

# NaN >= 0 is False, so this also fails if any KL evaluates to NaN.
assert jnp.all(kl >= 0.0)

In [None]:
# A fresh set of Gamma factors with the same parameters as the priors.
q = dict(
    nu_w=Gamma(a_w, b_w),
    nu_e=Gamma(a_e, b_e),
    theta=Gamma(jnp.ones(I) * alpha / I, alpha),
)

# Each nu factor against its own prior: mathematically these KLs are 0.
(q["nu_w"].KL_from_gamma(a_w, b_w), q["nu_e"].KL_from_gamma(a_e, b_e))


In [None]:
# KL of the theta factor against its Gamma(alpha / I, alpha) prior
# (presumably one value per component, since theta holds I shape entries).
q["theta"].KL_from_gamma(alpha / I, alpha)

In [None]:
# Per-component check: rebuild a GIG for each entry t of theta's p parameter
# and compute its KL against the theta prior.
# NOTE(review): a and b are passed whole, so if they are per-component arrays
# each KL broadcasts to a vector -- confirm they are scalars here.
list(
    map(
        lambda t: GIG(t, q["theta"].a, q["theta"].b).KL_from_gamma(alpha / I, alpha),
        q["theta"].p,
    )
)  # ok

In [None]:
# Self-KL of a Gamma whose first parameter is alpha / (1/50) = 50 * alpha; the
# KL of a distribution with itself should come out ~0.
# A dedicated name is used instead of rebinding the component count `I`: the
# factor-building cells call jnp.ones(I), which fails once I is a float.
inv_I_test = 1 / 50

Gamma(alpha / inv_I_test, alpha).KL_from_gamma(alpha / inv_I_test, alpha)

In [None]:
# Compare our GIG-based Gamma entropy with scipy's reference entropy for a very
# small first parameter (alpha / 400).
# A dedicated name is used instead of rebinding the component count `I`, which
# the factor-building cells pass to jnp.ones(I).
large_I_test = 400

gamma_test = Gamma(alpha / large_I_test, alpha)  # build once, evaluate both ways
print(gamma_test.entropy(), gamma_test.to_scipy().entropy())

Problems:

$\mathbb{E}[1/x]$ does not exist (it diverges) for $x \sim \mathrm{Gamma}(a, b)$ with $a \le 1$.

Our GIG `entropy()` disagrees with scipy's implementation (`to_scipy().entropy()`).

These problems are caused by $b \approx 0$ when Gamma is parametrized as a GIG.

It seems we have to initialize the $q(\cdot)$ factors differently than as plain Gammas.


In [None]:
# Moments of the nu_w prior factor. With a_w = 1.0 the shape is <= 1, so
# E[1/x] is not finite for this Gamma; see the problems noted above.
Gamma(a_w, b_w).moments()

In [None]:
I = jnp.arange(100) + 1  # sweep the component count I = 1..100

q = {
    "nu_w": Gamma(a_w, b_w),
    "nu_e": Gamma(a_e, b_e),
    "theta": Gamma(alpha / I, alpha),
}


def KL_from_gammas(q, *, include_nu=False):
    """Calculate KL[q|p] = E_q[log q] - E_q[log p] of q against the Gamma priors.

    By default only the theta term is returned:
    ``q["theta"].KL_from_gamma(alpha / I, alpha)``. With ``include_nu=True`` the
    nu_w and nu_e terms (presumably scalars) are added as well and broadcast
    over theta's entries.

    Note: the prior hyperparameters (alpha, I, a_w, b_w, a_e, b_e) are read from
    the notebook globals at call time.
    """
    kl = q["theta"].KL_from_gamma(alpha / I, alpha)
    if include_nu:
        kl = q["nu_w"].KL_from_gamma(a_w, b_w) + q["nu_e"].KL_from_gamma(a_e, b_e) + kl
    return kl


# Theta KL as a function of the component count I.
iplot(KL_from_gammas(q))
