# JAX GIG (generalized inverse Gaussian) implementation

Yoshii+ (2013) uses the same parametrization as [Wikipedia](https://en.wikipedia.org/wiki/Generalized_inverse_Gaussian_distribution).

There is a reference implementation in SciPy, [`scipy.stats.geninvgauss`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.geninvgauss.html), which we test against.

- Note: its differential-entropy method `entropy()` actually uses numerical quadrature; we use the analytical formula from Wikipedia.


In [None]:
from bngif.gig import GIG
import numpy as np
import jax
import jax.numpy as jnp
import scipy.stats as sp
from numpy.testing import assert_allclose

In [None]:
# ---------------------------------------------------------------------
# 1)   Compare our closed-form moments to SciPy
# ---------------------------------------------------------------------
def test_against_scipy_geninvgauss():
    """Compare ``GIG.moments()`` with SciPy for 20 random parameter triples.

    The three values returned by ``moments()`` are checked, in order, against
    ``rv.mean()``, ``rv.expect(1 / x)`` and ``rv.expect(log)`` of the object
    returned by ``GIG.to_scipy()``.

    Raises:
        AssertionError: if any value is outside its tolerance.
    """
    rng = np.random.default_rng(0)
    n_trials = 20

    def reciprocal(x):
        return 1.0 / x

    for _ in range(n_trials):
        # Draw order (p, then a, then b) fixes the random stream.
        p = float(rng.uniform(1.2, 5.0))  # order kept above 1
        a = float(rng.uniform(0.1, 6.0))
        b = float(rng.uniform(0.1, 6.0))
        dist = GIG(p=p, a=a, b=b)

        mean_jax, inv_jax, log_jax = dist.moments()

        # SciPy reference obtained through .to_scipy()
        ref = dist.to_scipy()

        # (our value, SciPy value, rtol, atol); all references are evaluated
        # before the first assertion runs.
        checks = (
            (mean_jax, ref.mean(), 1e-11, 1e-13),
            (inv_jax, ref.expect(reciprocal), 1e-8, 1e-13),
            (log_jax, ref.expect(np.log), 1e-7, 1e-11),
        )
        for actual, desired, rtol, atol in checks:
            assert_allclose(actual, desired, rtol=rtol, atol=atol)


# Run the check as soon as the cell executes; a failed assertion raises here.
test_against_scipy_geninvgauss()


In [None]:
# Three scalar parameters, each wrapped as a 0-d NumPy array, passed
# positionally to GIG in the order (gamma, rho, tau).
gamma, rho, tau = (np.array(value) for value in (1 / 4, 0.984, 1.6484))

g = GIG(gamma, rho, tau)

g.moments()  # OK

In [None]:
# q: SciPy counterpart of the GIG object `g`.
rv = g.to_scipy()

# p: SciPy gamma distribution with shape a and scale 1/b.
a = np.array(1)
b = np.array(1)
gamma = sp.gamma(a, scale=1 / b)

k1 = rv.entropy()  # -E_q[log q]
k2 = rv.expect(gamma.logpdf)  # E_q[log p]

# KL(q || p) = E_q[log q] - E_q[log p]
d_kl = -(k1 + k2)

d_kl

In [None]:
# ---------------------------------------------------------------------
# 2)   Inverse-Gaussian special case   (p = −½,  b = 1/μ,  scale = μ)
#       E[X] should equal μ
# ---------------------------------------------------------------------
def test_inverse_gaussian_mean():
    """Check the inverse-Gaussian special case of the GIG mean.

    Builds ``Y ~ GIG(p=-1/2, a=1/mu, b=1/mu)``, rescales its mean by ``mu``
    and compares the result with ``scipy.stats.invgauss(mu).mean()``.

    Raises:
        AssertionError: if the two means differ beyond tolerance.
    """
    mu = 2.0
    order = -0.5
    inv_mu = 1.0 / mu

    unit_scale = GIG(p=order, a=inv_mu, b=inv_mu)  # Y ~ GIG scale=1
    mean_y, *_unused = unit_scale.moments()

    # X = mu * Y, so E[X] = mu * E[Y].
    mean_x = mu * float(mean_y)

    expected = sp.invgauss(mu).mean()
    assert_allclose(mean_x, expected, rtol=1e-11, atol=1e-13)


# Run the check as soon as the cell executes; a failed assertion raises here.
test_inverse_gaussian_mean()

In [None]:
# ---------------------------------------------------------------------
# 3)   Gamma limit:  b → 0⁺  ⇒  GIG(p,a,b) →  Gamma(k=p, θ=2/a)
# ---------------------------------------------------------------------
from bngif.gig import Gamma


def test_gamma_subclass():
    """Check ``Gamma`` mean/entropy against SciPy and that gradients are finite.

    For 10 random ``(shape, rate)`` pairs the mean and entropy of
    ``Gamma(shape, rate, eps=1e-10)`` are compared with the object returned
    by ``to_scipy()``. The entropy of ``Gamma(shape, rate)`` (built without
    the explicit ``eps``) is then differentiated with respect to each
    parameter and the derivatives must be finite.

    Raises:
        AssertionError: if a comparison fails or a gradient is NaN/Inf.
    """
    rng = np.random.default_rng(321)

    def entropy_of(shape, rate):
        return Gamma(shape, rate).entropy()

    for _ in range(10):
        k = rng.uniform(1.2, 6.0)  # shape, kept above 1
        rate = rng.uniform(0.4, 5.0)  # λ
        dist = Gamma(shape=k, rate=rate, eps=1e-10)

        # Only the first moment is compared; the other two are ignored.
        mean, _inv, _log = dist.moments()
        H = dist.entropy()

        ref = dist.to_scipy()
        assert_allclose(mean, ref.mean(), rtol=1e-8)
        assert_allclose(H, ref.entropy(), rtol=1e-8)

        # Differentiate with respect to one argument at a time; the other is
        # passed through untouched.
        dH_dshape = jax.grad(entropy_of, argnums=0)(k, rate)
        dH_drate = jax.grad(entropy_of, argnums=1)(k, rate)
        assert np.isfinite(dH_dshape)
        assert np.isfinite(dH_drate)


# Run the check as soon as the cell executes; a failed assertion raises here.
test_gamma_subclass()

In [None]:
# ---------------------------------------------------------------------
# 4)   Reciprocal property:  X~GIG(p,a,b)  ⇒  1/X ~ GIG(−p,b,a)
# ---------------------------------------------------------------------
def test_reciprocal_identity():
    """Check that E[1/X] under GIG(p, a, b) equals E[X] under GIG(-p, b, a).

    Raises:
        AssertionError: if the two values differ beyond tolerance.
    """
    p, a, b = 1.8, 2.0, 0.7
    forward = GIG(p=p, a=a, b=b)
    mirrored = GIG(p=-p, a=b, b=a)  # order negated, a and b swapped

    # Second entry of moments() is the inverse moment used by this identity.
    _mean_x, mean_invx, _ = forward.moments()
    mean_x_recip, *_rest = mirrored.moments()

    assert_allclose(mean_invx, mean_x_recip, rtol=1e-12, atol=1e-14)


# Run the check as soon as the cell executes; a failed assertion raises here.
test_reciprocal_identity()

In [None]:
# ---------------------------------------------------------------------
# 5)   Entropy: closed-form vs SciPy numeric integration
#      + jit- and grad- friendliness
# ---------------------------------------------------------------------
from bngif.gig import entropy as gig_entropy


def test_entropy_and_autodiff():
    """Check the GIG entropy against SciPy and exercise ``jit`` / ``grad``.

    For five random ``(p, a, b)`` triples this asserts that

    * ``GIG.entropy()`` agrees with ``GIG.to_scipy().entropy()``,
    * the jitted module-level ``gig_entropy(p, a, b)`` agrees with the method,
    * ``jax.grad`` of ``gig_entropy`` with respect to ``a`` and to ``b`` is
      finite (the order ``p`` is not differentiated).

    Raises:
        AssertionError: if a comparison fails or a gradient is NaN/Inf.
    """
    rng = np.random.default_rng(42)

    # The jitted wrapper does not depend on the loop variables, so build it
    # once instead of re-wrapping the function on every iteration.
    entropy_jit = jax.jit(gig_entropy)

    # use a handful of random parameter triples that keep SciPy quadrature cheap
    for _ in range(5):
        p = float(rng.uniform(0.4, 3.0))  # finite entropy, away from tails
        a = float(rng.uniform(0.2, 4.0))
        b = float(rng.uniform(0.2, 4.0))
        g = GIG(p=p, a=a, b=b)

        # --- reference via SciPy numeric integration -----------------
        entropy_ref = g.to_scipy().entropy()

        # --- our closed form ----------------------------------------
        entropy_val = g.entropy()
        assert_allclose(entropy_val, entropy_ref, rtol=1e-6, atol=1e-7)

        # --- jit works ---------------------------------------------
        assert_allclose(entropy_jit(p, a, b), entropy_val, rtol=1e-9, atol=1e-9)

        # --- gradients wrt a and b exist and are finite ------------
        grad_a = jax.grad(lambda aa: gig_entropy(p, aa, b))(a)
        grad_b = jax.grad(lambda bb: gig_entropy(p, a, bb))(b)

        for gval in (grad_a, grad_b):
            assert np.isfinite(gval), "NaN/Inf gradient detected"


# Run the check as soon as the cell executes; a failed assertion raises here.
test_entropy_and_autodiff()

Plot the PDF on the three scales $x$, $1/x$, $\log x$. Note that if $x \sim GIG(p,a,b)$ then $1/x \sim GIG(-p,b,a)$; because $a = b$ in the example below, this reduces to a simple negation of the order $p$.


In [None]:
from utils.plotting import iplot

# SciPy counterpart of GIG(p=1, a=1, b=1), as returned by to_scipy().
gig = GIG(p=1.0, a=1.0, b=1.0).to_scipy()

# 100 evenly spaced x values between the 1% and 99% quantiles, plus the same
# grid mapped onto the reciprocal and logarithmic scales.
x_lo, x_hi = gig.ppf(0.01), gig.ppf(0.99)
x = np.linspace(x_lo, x_hi, num=100)
y = np.reciprocal(x)
z = np.log(x)


def pdfx(x):
    """Return the density of the module-level ``gig`` distribution at ``x``."""
    density = gig.pdf(x)
    return density


def pdfy(y):
    """Return the density of ``Y = 1/X`` at ``y``, where ``X`` follows ``gig``."""
    x_of_y = 1 / y
    jacobian = 1 / y**2  # |dx/dy| for the change of variables x = 1/y
    return gig.pdf(x_of_y) * jacobian


def pdfz(z):
    """Return the density of ``Z = log X`` at ``z``, where ``X`` follows ``gig``."""
    x_of_z = np.exp(z)
    # dx/dz = exp(z) is the Jacobian of the change of variables x = exp(z).
    return gig.pdf(x_of_z) * x_of_z


# One plot per scale: each density is evaluated on its own grid (x, y = 1/x,
# z = log x) so the curves show the same distribution after a change of variables.
iplot(x, pdfx(x), title="GIG PDF", xlabel="x", ylabel="PDF")
iplot(y, pdfy(y), title="GIG PDF (reciprocal scale)", xlabel="y = 1/x", ylabel="PDF")
iplot(z, pdfz(z), title="GIG PDF (log scale)", xlabel="z = log(x)", ylabel="PDF")


In [None]:
import jax
import jax.numpy as jnp

# `I` and `alpha` are reused by later cells, so their names are kept as-is.
I = 400
alpha = 1.0

# Draw I samples from the SciPy counterpart of Gamma(alpha / I, alpha).
# NOTE(review): the positional arguments are presumably (shape, rate) - confirm.
theta_draws = Gamma(alpha / I, alpha).to_scipy().rvs(I)

iplot(
    theta_draws,
    title="Gamma process prior",
    xlabel="index i",
    ylabel="theta_i",
)

In [None]:
# Parameters of the two scalar Gamma entries.
a_w, b_w = 1.0, 1.0
a_e, b_e = 1.0, 1.0

# Dict of Gamma objects: two scalar entries plus a length-I vector entry whose
# first parameter is alpha / I in every component.
z = dict(
    nu_w=Gamma(a_w, b_w),
    nu_e=Gamma(a_e, b_e),
    theta=Gamma(jnp.ones(I) * alpha / I, alpha),
)

In [None]:
def is_gig(x):
    """Return True when ``x`` is a ``GIG`` instance.

    Intended as an ``is_leaf`` predicate for ``jax.tree_util.tree_map`` so
    that ``GIG`` objects are handed to the mapped function whole.
    """
    # A named function instead of a lambda bound to a name (PEP 8).
    return isinstance(x, GIG)


@jax.jit
def entropies(t):
    """Map ``.entropy()`` over every ``GIG`` object in the pytree ``t``.

    Returns a pytree with the same container structure as ``t`` in which each
    ``GIG`` object is replaced by the result of its ``entropy()`` method.
    """

    def entropy_of(d):
        return d.entropy()

    # is_leaf stops tree_map from unflattening GIG objects into their fields.
    return jax.tree_util.tree_map(entropy_of, t, is_leaf=is_gig)

# Evaluate the jitted map on `z`, the dict of Gamma objects built above.
entropies(z)