In [None]:
from bngif.gig import moments
from jax import grad, jit

# Rebind the imported `moments` to its JIT-compiled wrapper, then evaluate it once;
# the value of the last expression is what the cell displays.
# NOTE(review): the meaning of the three arguments is not visible in this
# notebook (presumably GIG parameters p, a, b) — confirm against bngif.gig.
moments = jit(moments)
moments(1.0, 3.0, 5.0)

In [None]:
# Same call as in the previous cell, repeated against the already-jitted `moments`.
moments(1.0, 3.0, 5.0)

In [None]:
# test_gig_moments.py
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
import scipy.stats as sp

# Short alias for TFP's math namespace; the helpers below call tfm.log_bessel_kve.
tfm = tfp.math
# --- our helpers --------------------------------------


def _log_Kv(v, z):
    """Return log K_v(z), recovered from TFP's scaled kernel as log kve - |z|."""
    scaled_log_k = tfm.log_bessel_kve(v, z)
    return scaled_log_k - jnp.abs(z)


def _Kv_ratio(v, z):
    """Return K_{v+1}(z) / K_v(z), formed as the exponential of a log-difference."""
    log_numerator = tfm.log_bessel_kve(v + 1, z)
    log_denominator = tfm.log_bessel_kve(v, z)
    return jnp.exp(log_numerator - log_denominator)


@jax.jit
def _dlogK_dv(v, z, eps=1e-4):
    """Central finite-difference estimate of d/dv log K_v(z).

    ``eps`` is the half-width of the stencil: the order is perturbed to
    ``v + eps`` and ``v - eps`` and the difference is divided by ``2 * eps``.
    """
    log_k_above = _log_Kv(v + eps, z)
    log_k_below = _log_Kv(v - eps, z)
    return (log_k_above - log_k_below) / (2.0 * eps)


@jax.jit
def gig_moments(p, a, b):
    """Return the tuple (E[X], E[1/X], E[log X]) for parameters (p, a, b).

    With z = sqrt(a*b) and r = K_{p+1}(z) / K_p(z), the code evaluates:
      * E[X]     = sqrt(b/a) * r
      * E[1/X]   = sqrt(a/b) * r - 2p/b
      * E[log X] = (log b - log a)/2 + d/dp log K_p(z)  (finite difference)
    """
    z = jnp.sqrt(a * b)
    bessel_ratio = _Kv_ratio(p, z)
    half_log_b_over_a = 0.5 * (jnp.log(b) - jnp.log(a))

    first_moment = jnp.sqrt(b / a) * bessel_ratio
    inverse_moment = jnp.sqrt(a / b) * bessel_ratio - 2.0 * p / b
    log_moment = half_log_b_over_a + _dlogK_dv(p, z)
    return first_moment, inverse_moment, log_moment


# -----------------------------------------------------------------------------


# ---------------------------------------------------------------------
# 1)   Compare against SciPy’s implementation   (a == b  ⇔ SciPy case)
# ---------------------------------------------------------------------
def test_against_scipy_geninvgauss():
    """Compare gig_moments(p, b, b) with scipy.stats.geninvgauss(p, b).

    Twenty (p, b) pairs are drawn from a generator seeded with 0. The mean
    comes from SciPy directly; E[1/X] and E[log X] come from ``.expect``.

    NOTE(review): tolerances as tight as 1e-11 presuppose float64 arithmetic
    in JAX — confirm that x64 mode is enabled before this runs.
    """
    rng = np.random.default_rng(0)
    n_cases = 20
    for _case in range(n_cases):
        # Draw order (p first, then b) fixes which pairs the seed produces.
        p = rng.uniform(1.2, 5.0)
        b = rng.uniform(0.1, 6.0)

        mean_jax, inv_jax, log_jax = gig_moments(p, b, b)

        scipy_rv = sp.geninvgauss(p, b)
        mean_ref = scipy_rv.mean()
        inv_ref = scipy_rv.expect(lambda x: 1.0 / x)
        log_ref = scipy_rv.expect(np.log)

        summary = (
            f"p={p:.2f}, b={b:.2f}  |  "
            f"mean={mean_jax:.4f}, inv={inv_jax:.4f}, log={log_jax:.4f}  |  "
            f"ref_mean={mean_ref:.4f}, ref_inv={inv_ref:.4f}, ref_log={log_ref:.4f}"
        )
        print(summary)

        # (computed, reference, rtol, atol) — checked in the same order as before.
        comparisons = (
            (mean_jax, mean_ref, 1e-11, 1e-13),
            (inv_jax, inv_ref, 1e-8, 1e-13),
            (log_jax, log_ref, 1e-7, 1e-11),
        )
        for computed, reference, rtol, atol in comparisons:
            assert np.allclose(computed, reference, rtol=rtol, atol=atol)


# ---------------------------------------------------------------------
# 2)   Inverse-Gaussian special case  (p = −½,  b = 1/μ,  scale = μ)
#       E[X]  should be  μ
# ---------------------------------------------------------------------
def test_inverse_gaussian_mean():
    """Inverse-Gaussian special case: p = -1/2, b = 1/mu, rescaled by mu.

    The rescaled mean is compared with ``scipy.stats.invgauss(mu).mean()``.
    """
    mu = 2.0
    p = -0.5
    b = 1.0 / mu

    # gig_moments(p, b, b) is the unit-scale variable Y; X = mu * Y.
    moments_y = gig_moments(p, b, b)
    mean_y = moments_y[0]
    mean_x = mu * float(mean_y)

    scipy_mean = sp.invgauss(mu).mean()

    summary = (
        f"p={p:.2f}, b={b:.2f}  |  "
        f"mean_y={mean_y:.4f}, mean_x={mean_x:.4f}  |  "
        f"ref_mean={scipy_mean:.4f}"
    )
    print(summary)

    assert np.allclose(mean_x, scipy_mean, rtol=1e-11, atol=1e-13)


# ---------------------------------------------------------------------
# 3)   Gamma limit:  b → 0⁺  ⇒  GIG(p,a,b) →  Gamma(k=p, θ=2/a)
# ---------------------------------------------------------------------
def test_gamma_limit_b_to_zero():
    """Gamma limit: with b = 1e-7 the GIG mean should match Gamma(k=p, scale=2/a)."""
    a = 3.4
    p = 4.0  # plays the role of the Gamma shape k
    b_small = 1e-7

    mean_gig = gig_moments(p, a, b_small)[0]
    mean_gamma = sp.gamma.mean(a=p, scale=2.0 / a)

    # The limit is only approached, not reached, hence the loose 1e-4 relative tolerance.
    summary = (
        f"p={p:.2f}, a={a:.2f}  |  mean_gig={mean_gig:.4f}, mean_gamma={mean_gamma:.4f}"
    )
    print(summary)

    assert np.allclose(mean_gig, mean_gamma, rtol=1e-4)


# ---------------------------------------------------------------------
# 4)   Reciprocal property:   If X~GIG(p,a,b)  ⇒  1/X ~ GIG(−p,b,a)
# ---------------------------------------------------------------------
def test_reciprocal_identity():
    """Reciprocal property: E[1/X] under (p, a, b) must equal E[X] under (-p, b, a)."""
    p = 1.8
    a = 2.0
    b = 0.7

    forward = gig_moments(p, a, b)
    mean_x, mean_invx = forward[0], forward[1]

    # Swap a <-> b and negate p to get the moments of the reciprocal variable.
    mean_x_recip = gig_moments(-p, b, a)[0]

    summary = (
        f"p={p:.2f}, a={a:.2f}, b={b:.2f}  |  "
        f"mean_x={mean_x:.4f}, mean_invx={mean_invx:.4f}  |  "
        f"mean_x_recip={mean_x_recip:.4f}"
    )
    print(summary)

    assert np.allclose(mean_invx, mean_x_recip, rtol=1e-12, atol=1e-14)


# Run the four checks in order; each prints its comparison line(s) and raises
# AssertionError on the first mismatch, which stops the cell.
test_against_scipy_geninvgauss()
test_inverse_gaussian_mean()
test_gamma_limit_b_to_zero()
test_reciprocal_identity()

In [None]:
# gig.py  ──────────────────────────────────────────────────────────────
from __future__ import annotations
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfm = tfp.math  # alias; the helpers below call tfm.log_bessel_kve


# ---------------------------------------------------------------------
#   Core special-function helpers (all JAX, all GPU/TPU-safe)
# ---------------------------------------------------------------------
def _log_Kv(v: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray:
    """Return log K_v(z), undoing the scaling of TFP's kernel: log kve - |z|."""
    scaled_log_k = tfm.log_bessel_kve(v, z)
    return scaled_log_k - jnp.abs(z)


def _Kv_ratio(v: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray:
    """Return K_{v+1}(z) / K_v(z), computed as exp(log-numerator - log-denominator)."""
    log_numerator = tfm.log_bessel_kve(v + 1, z)
    log_denominator = tfm.log_bessel_kve(v, z)
    return jnp.exp(log_numerator - log_denominator)


def _dlogK_dv(v: jnp.ndarray, z: jnp.ndarray, eps: float = 1e-4) -> jnp.ndarray:
    """Central finite-difference estimate of the derivative of log K_v(z) in v.

    The order is perturbed by ``+eps`` and ``-eps`` and the difference of the
    two log-Bessel values is divided by ``2 * eps``.
    """
    log_k_above = _log_Kv(v + eps, z)
    log_k_below = _log_Kv(v - eps, z)
    return (log_k_above - log_k_below) / (2 * eps)


# ---------------------------------------------------------------------
#   Public: closed-form expectations for any broadcast-shape p,a,b
# ---------------------------------------------------------------------
@jax.jit
def gig_moments(
    p: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    r"""Return (E[X], E[1/X], E[log X]) for GIG(p, a, b).

    With z = sqrt(a*b) and r = K_{p+1}(z) / K_p(z), the code evaluates:
      * E[X]     = sqrt(b/a) * r
      * E[1/X]   = sqrt(a/b) * r - 2p/b
      * E[log X] = (log b - log a)/2 + d/dp log K_p(z)  (finite difference)
    """
    z = jnp.sqrt(a * b)
    bessel_ratio = _Kv_ratio(p, z)
    half_log_b_over_a = 0.5 * (jnp.log(b) - jnp.log(a))

    first_moment = jnp.sqrt(b / a) * bessel_ratio
    inverse_moment = jnp.sqrt(a / b) * bessel_ratio - 2.0 * p / b
    log_moment = half_log_b_over_a + _dlogK_dv(p, z)
    return first_moment, inverse_moment, log_moment


# ---------------------------------------------------------------------
#   Differential entropy  H(X)  =  −E[log f(X)]
# ---------------------------------------------------------------------
@jax.jit
def gig_entropy(p: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""Differential entropy H(X) = -E[log f(X)] of X ~ GIG(p, a, b).

    Assembled from the log normaliser
        log Z = log 2 + log K_p(z) - (p/2) (log a - log b),   z = sqrt(a*b),
    and the three expectations returned by ``gig_moments``:
        H = log Z - (p - 1) E[log X] + (a E[X] + b E[1/X]) / 2.
    """
    z = jnp.sqrt(a * b)
    log_a_minus_log_b = jnp.log(a) - jnp.log(b)
    log_Z = jnp.log(2.0) + _log_Kv(p, z) - 0.5 * p * log_a_minus_log_b

    mean_x, mean_invx, mean_logx = gig_moments(p, a, b)
    exponent_term = 0.5 * (a * mean_x + b * mean_invx)
    return log_Z - (p - 1.0) * mean_logx + exponent_term


# ---------------------------------------------------------------------
#   Optional convenience: make each GIG a small PyTree leaf
# ---------------------------------------------------------------------
@jax.tree_util.register_pytree_node_class
@dataclass
class GIG:
    """Container for the parameters (p, a, b) of a GIG distribution.

    Registered as a JAX PyTree node whose children are ``(p, a, b)`` and whose
    auxiliary data is ``None``, so instances can be nested inside other
    pytrees and passed through ``jax.jit`` / ``jax.tree_util.tree_map``.
    """

    p: jnp.ndarray  # shape-broadcastable
    a: jnp.ndarray  # >0
    b: jnp.ndarray  # >0

    # — PyTree interface —
    def tree_flatten(self):
        """Return ``((p, a, b), None)``: the three children and no aux data."""
        return ((self.p, self.a, self.b), None)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from the children produced by ``tree_flatten``."""
        return cls(*children)

    # — lightweight methods —
    def moments(self):
        """Return ``gig_moments(p, a, b)``, i.e. (E[X], E[1/X], E[log X])."""
        return gig_moments(self.p, self.a, self.b)

    def entropy(self):
        """Return ``gig_entropy(p, a, b)``."""
        return gig_entropy(self.p, self.a, self.b)

    def to_scipy(self):
        """Return the equivalent frozen ``scipy.stats.geninvgauss`` distribution.

        The mapping is ``geninvgauss(p, sqrt(a*b), scale=sqrt(b/a))``.

        The imports are local so this method does not depend on ``sp`` having
        been imported by some other cell/module. The parameters are converted
        to NumPy float64 before being handed to SciPy, so SciPy never receives
        JAX arrays and the square roots are taken in double precision.
        Not usable on traced values (i.e. inside ``jax.jit``).
        """
        import numpy as np
        import scipy.stats as sp

        p_val = np.asarray(self.p, dtype=float)
        a_val = np.asarray(self.a, dtype=float)
        b_val = np.asarray(self.b, dtype=float)

        b_scipy = np.sqrt(a_val * b_val)
        scale_val = np.sqrt(b_val / a_val)

        return sp.geninvgauss(p_val, b_scipy, scale=scale_val)


In [None]:
import jax
import jax.numpy as jnp

# A nested pytree of independent GIG variables
# Nested pytree of independent GIG variables: two scalar-parameter entries and
# a one-element list whose GIG has a 30-point grid of p values.
z = dict(
    nu_w=GIG(2.0, 3.0, 4.0),
    nu_e=GIG(2.0, 3.0, 4.0),
    theta=[GIG(jnp.linspace(-0.4, 0.4, 30), 1.0, 3.0)],
)


In [None]:
import jax
import jax.numpy as jnp

# A nested pytree of independent GIG variables
# Nested pytree of independent GIG variables. For "theta", p is a 20-point grid
# and b is a (2, 1) column, so the two broadcast against each other.
z = dict(
    nu_w=GIG(2.0, 3.0, 4.0),
    nu_e=GIG(2.0, 3.0, 4.0),
    theta=[GIG(jnp.linspace(-0.4, 0.4, 20), 1.0, jnp.array([[2.0], [3.0]]))],
)

def is_gig(x):
    """Leaf predicate for tree_map: true for GIG instances."""
    return isinstance(x, GIG)


@jax.jit
def func(t):
    """Map every GIG in the pytree ``t`` to its entropy (runs inside JIT)."""
    # is_leaf stops tree_map from descending into each GIG's (p, a, b)
    # children, so the callback receives whole GIG objects.
    return jax.tree_util.tree_map(
        lambda d: d.entropy(),
        t,
        is_leaf=is_gig,
    )


tree_entropy = func(z)

tree_entropy["theta"]

In [None]:
# append right after the `entropy` method inside the GIG class
import scipy.stats as sp  # put at top of file with other imports

# -----------------------------------------------------------------
#  Demo: convert a GIG to a frozen scipy.stats.geninvgauss distribution
#  (via GIG.to_scipy) and cross-check its mean and entropy
# -----------------------------------------------------------------


# Scalar-parameter instance; `gig_scalar` and `rv` are reused by the cells below.
gig_scalar = GIG(p=1.0, a=100.0, b=40.0)

# Get the frozen SciPy object
rv = gig_scalar.to_scipy()

# SciPy's mean and variance for the frozen distribution
print(rv.mean(), rv.var())
# consistency check: the first element of moments() is E[X] and should match rv.mean()
mx, _, _ = gig_scalar.moments()
print(mx)

# Entropy from SciPy followed by the JAX implementation, for side-by-side comparison
print(rv.entropy())
print(gig_scalar.entropy())

In [None]:
def scope(n_samples=100_000, seed=0):
    """Monte-Carlo estimates of (E[X], E[1/X], E[log X]) from the global ``rv``.

    Parameters
    ----------
    n_samples : int
        Number of variates drawn from ``rv``.
    seed : int, numpy Generator/RandomState, or None
        Forwarded to ``rv.rvs(random_state=...)``. The default fixes the
        stream so re-running the cell reproduces the same numbers; pass
        ``None`` for a fresh, non-deterministic draw.

    Returns
    -------
    tuple
        Sample means of ``x``, ``1/x`` and ``log(x)``.
    """
    # Local import keeps the function independent of which cells ran earlier.
    import numpy as np

    x = rv.rvs(size=n_samples, random_state=seed)
    # np.log keeps the samples in NumPy's dtype, so all three statistics are
    # computed the same way (no round-trip through a JAX array).
    return x.mean(), (1 / x).mean(), np.log(x).mean()


# Sampled estimates of (mean of x, mean of 1/x, mean of log x) ...
print(scope())

# ... next to the values returned by the JAX implementation (shown as the cell output).
gig_scalar.moments()

In [None]:
# Entropy reported by the frozen SciPy distribution (displayed as the cell output).
rv.entropy()

In [None]:
# (Leftover `%debug` post-mortem call removed: it opens an interactive debugger
# whenever a traceback is pending, which stalls a non-interactive Run All.)