In [1]:
from diag.util import DummyProgress
from diag.model import CVFold, InfParams, ModelParams
import unittest

import jax
from jax import numpy as jnp
from jax.scipy import stats as st

from diag import Model, LogTransform, compare


class _GaussianVarianceModel(Model):
    r"""Test model: Gaussian observations with a fixed mean and unknown variance

    The single parameter is the variance :math:`\sigma^2`, which is given a
    Gamma(``prior_shape``, ``prior_rate``) prior.

    :math:`\sigma^2` is strictly positive, so for inference it is moved onto
    the whole real line with a :class:`LogTransform`.
    """

    name = "Gaussian variance model"

    def __init__(
        self,
        y: jnp.DeviceArray,
        mean: float = 0.0,
        prior_shape: float = 2.0,
        prior_rate: float = 2.0,
    ) -> None:
        """Store the data, the assumed mean and the prior hyperparameters

        :param y: observations
        :param mean: fixed mean assumed for every observation
        :param prior_shape: shape of the Gamma prior on the variance
        :param prior_rate: rate of the Gamma prior on the variance
        """
        self.sigma_sq_transform = LogTransform()
        self.prior_shape, self.prior_rate = prior_shape, prior_rate
        self.mean = mean
        self.y = y

    def log_likelihood(self, cv_fold, model_params: ModelParams):
        """Joint log likelihood of every element of ``self.y``

        NOTE(review): ``cv_fold`` is accepted but never read, so no
        observation is left out of the sum. This looks like a placeholder
        (see the remark in :meth:`cv_folds`) -- confirm before relying on
        cross-validation output from this model.
        """
        sigma = jnp.sqrt(model_params["sigma_sq"])
        pointwise = st.norm.logpdf(self.y, loc=self.mean, scale=sigma)
        return jnp.sum(pointwise)

    def log_prior(self, model_params: ModelParams):
        """Log density of the Gamma prior evaluated at ``sigma_sq``"""
        # scipy-style gamma takes a scale, i.e. the reciprocal of the rate
        prior_scale = 1.0 / self.prior_rate
        return st.gamma.logpdf(
            model_params["sigma_sq"], a=self.prior_shape, scale=prior_scale
        )

    @classmethod
    def generate(
        cls, N: int, mean: float, sigma_sq: float, rng_key: jnp.DeviceArray
    ) -> jnp.DeviceArray:
        """Draw ``N`` values with the given mean and variance

        :param N: number of draws
        :param mean: location added to every draw
        :param sigma_sq: variance; its square root scales the standard normals
        :param rng_key: JAX PRNG key used for the draws
        """
        std_normals = jax.random.normal(key=rng_key, shape=(N,))
        return mean + jnp.sqrt(sigma_sq) * std_normals

    def initial_value(self) -> ModelParams:
        """Starting point in model (constrained) space"""
        return dict(sigma_sq=1.0)

    def log_cond_pred(self, model_params: ModelParams, cv_fold: CVFold):
        """Log predictive density of ``self.y[cv_fold]`` given the parameters"""
        held_out = self.y[cv_fold]
        sigma = jnp.sqrt(model_params["sigma_sq"])
        return st.norm.logpdf(held_out, loc=self.mean, scale=sigma)

    def to_inference_params(self, model_params: ModelParams) -> InfParams:
        """Map model parameters to inference space by calling the transform"""
        return {"sigma_sq": self.sigma_sq_transform(model_params["sigma_sq"])}

    def to_model_params(self, inf_params: InfParams) -> ModelParams:
        """Map inference parameters back with the transform's ``to_constrained``"""
        sigma_sq = self.sigma_sq_transform.to_constrained(inf_params["sigma_sq"])
        return {"sigma_sq": sigma_sq}

    def log_det(self, model_params: ModelParams) -> jnp.DeviceArray:
        """Delegate to the transform's ``log_det`` at the model-space ``sigma_sq``"""
        sigma_sq = model_params["sigma_sq"]
        return self.sigma_sq_transform.log_det(sigma_sq)

    def cv_folds(self):
        """Number of folds: one per element of ``self.y``"""
        # will yield nonsense CV values of course
        return len(self.y)


In [2]:
# Simulate 100 observations with mean 0 and variance 10, then fit the same
# data under three different assumed means.
gen_key = jax.random.PRNGKey(seed=42)
y = _GaussianVarianceModel.generate(rng_key=gen_key, N=100, mean=0, sigma_sq=10)
# assumed means, in order: 0.0 (good), -10.0 (bad), 50.0 (awful)
m1, m2, m3 = [_GaussianVarianceModel(y, mean=mu) for mu in (0.0, -10.0, 50.0)]
# one fresh DummyProgress per run, as progress sink for inference
m1_post, m2_post, m3_post = [
    model.inference(out=DummyProgress()) for model in (m1, m2, m3)
]



In [3]:
# Posterior summary for m1 (assumed mean 0.0, matching the simulated data)
print(m1_post)

Gaussian variance model inference summary

16,000 draws from 2,000 iterations on 8 chains with seed 42

Parameter      Mean  (SE)      1%    5%    25%    Median    75%    95%    99%
-----------  ------  ------  ----  ----  -----  --------  -----  -----  -----
sigma_sq       8.17  (0.93)   6.3  6.78   7.52       8.1   8.76    9.8  10.65


In [4]:
# Posterior summary for m2 (assumed mean -10.0, the "bad" model)
print(m2_post)

Gaussian variance model inference summary

16,000 draws from 2,000 iterations on 8 chains with seed 42

Parameter      Mean  (SE)       1%     5%    25%    Median    75%    95%    99%
-----------  ------  ------  -----  -----  -----  --------  -----  -----  -----
sigma_sq      42.12  (2.86)  35.99  37.59  40.15     42.02  44.04  47.03  49.22


In [5]:
# Posterior summary for m3 (assumed mean 50.0, the "awful" model)
print(m3_post)

Gaussian variance model inference summary

16,000 draws from 2,000 iterations on 8 chains with seed 42

Parameter      Mean  (SE)        1%      5%     25%    Median     75%     95%     99%
-----------  ------  ------  ------  ------  ------  --------  ------  ------  ------
sigma_sq     238.71  (7.55)  221.65  226.57  233.59    238.54  243.79  251.42  256.56


In [6]:
# Run cross-validation from each fitted posterior, in model order.
m1_cv, m2_cv, m3_cv = [
    posterior.cross_validate() for posterior in (m1_post, m2_post, m3_post)
]

In [7]:
# Print the three cross-validation summaries one after another.
print(m1_cv, m2_cv, m3_cv, sep="\n")

Cross-validation summary

    elpd = -2.6101 (se 0.0167)

Calculated from 100 folds (8 chains per fold, 800 total)

Average acceptance rate 84.0% (min 81.4%, max 86.5%)

Divergent chain count: 0
Cross-validation summary

    elpd = -4.1115 (se 0.0179)

Calculated from 100 folds (8 chains per fold, 800 total)

Average acceptance rate 76.6% (min 73.6%, max 80.2%)

Divergent chain count: 0
Cross-validation summary

    elpd = -8.9120 (se 0.0151)

Calculated from 100 folds (8 chains per fold, 800 total)

Average acceptance rate 80.1% (min 77.5%, max 82.6%)

Divergent chain count: 0


In [8]:
# Compare the three CV results. Models passed positionally are labelled
# model0, model1, ... in the resulting table.
cmp_res = compare(m1_cv, m2_cv, m3_cv)
cmp_res

LOO Cross Validation Comparison
Model        elpd    elpd diff    elpd se
-------  --------  -----------  ---------
model0   -2.61008      0        0.016704
model1   -4.11147      1.50139  0.0178733
model2   -8.91202      6.30194  0.0150503

In [9]:
# Same comparison, but models passed as keyword arguments are labelled with
# the keyword name in the table. NOTE: this rebinds cmp_res from the
# previous cell.
cmp_res = compare(m1_cv, bad_model=m2_cv, awful_model=m3_cv)
cmp_res

LOO Cross Validation Comparison
Model            elpd    elpd diff    elpd se
-----------  --------  -----------  ---------
model0       -2.61008      0        0.016704
bad_model    -4.11147      1.50139  0.0178733
awful_model  -8.91202      6.30194  0.0150503