# Proof of concept: online R-hat

In [1]:
import blackjax
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from collections import namedtuple
from blackjax.diagnostics import potential_scale_reduction
import matplotlib.pyplot as plt
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

In [None]:
# Model parameters: regression coefficients `beta` and the variance parameter
# `sigsq` (held on the unconstrained scale; see `sigsq_t` in logprob_fn's cell).
Theta = namedtuple("Theta", "beta sigsq")

In [None]:
# Running-moment accumulator for a stream of scalars. Despite the name, this is
# the "shifted data" algorithm rather than Welford's update: every observation
# is centred on a fixed shift K before accumulation, which limits catastrophic
# cancellation in the variance when K is close to the data's mean.
#   K   -- fixed shift applied to every observation
#   Ex  -- running sum of (x - K)
#   Ex2 -- running sum of (x - K)**2
#   n   -- number of observations accumulated so far
WelfordState = namedtuple("WelfordState", ["K", "Ex", "Ex2", "n"])

def welford_init(K) -> WelfordState:
  """Return an empty accumulator that shifts every observation by ``K``."""
  return WelfordState(K, 0., 0., 0)

def welford_add(state: WelfordState, x: float) -> WelfordState:
  """Return a new accumulator with observation ``x`` added (inputs are not mutated)."""
  shifted = x - state.K
  return WelfordState(state.K, state.Ex + shifted, state.Ex2 + shifted**2, state.n + 1)

def welford_mean(state: WelfordState) -> float:
  """Return the running mean of the observations added so far.

  Requires ``state.n >= 1``; with no observations the division by ``n`` fails
  (ZeroDivisionError for Python scalars, non-finite values for arrays).
  """
  return state.K + state.Ex / state.n

def welford_var(state: WelfordState) -> float:
  """Return the unbiased (ddof=1) sample variance of the observations so far.

  Requires ``state.n >= 2`` because of the ``n - 1`` denominator.
  """
  return (state.Ex2 - state.Ex**2 / state.n) / (state.n - 1)

In [None]:
# Pair of a sampler `state` and a `welford_state` accumulator. Not used by the
# cells shown here; presumably intended for tracking online R-hat statistics
# alongside the chain -- confirm once it is wired in.
ExtendedState = namedtuple("ExtendedState", "state welford_state")

In [None]:
# Simulated linear-regression data: N rows of p standard-normal covariates and
# responses y = X @ beta0 + Normal(0, sqrt(sigsq0)) noise.
N = 100
beta0 = jnp.array([1.0, 2.0, 3.0, 4.0])
sigsq0 = jnp.array(2.0)
p = beta0.shape[0]
y_key, X_key = jax.random.split(jax.random.PRNGKey(0))
X = tfd.Normal(loc=0, scale=1).sample((N, p), seed=X_key)
y = X @ beta0 + tfd.Normal(loc=0, scale=jnp.sqrt(sigsq0)).sample((N,), seed=y_key)

In [None]:
# Exp bijector mapping the unconstrained sampler coordinate to the positive
# variance: sigsq = exp(theta.sigsq), i.e. the sampler works on log-variance.
sigsq_t = tfb.Exp()

beta_prior = tfd.MultivariateNormalDiag(loc=jnp.zeros(p), scale_diag=jnp.ones(p))
sigsq_prior = tfd.Gamma(concentration=1.0, rate=1.0)

def logprob_fn(theta: Theta):
  """Unnormalised log posterior density on the unconstrained parameter space.

  ``theta.beta`` is used as-is. ``theta.sigsq`` is the unconstrained value,
  which ``sigsq_t`` maps to the positive variance. The Gamma prior is evaluated
  on that variance, and the log-det-Jacobian of the transform is added so the
  density is correct for the coordinate actually being sampled.
  """
  sigsq = sigsq_t.forward(theta.sigsq)
  sigsq_ldj = sigsq_t.forward_log_det_jacobian(theta.sigsq)
  # The prior is defined on the positive variance, not on the unconstrained value.
  lprior = beta_prior.log_prob(theta.beta) + sigsq_prior.log_prob(sigsq)
  lhood = tfd.Normal(loc=X @ theta.beta, scale=jnp.sqrt(sigsq)).log_prob(y).sum()
  # Change of variables: p(theta.sigsq) = p(sigsq) * |d sigsq / d theta.sigsq|.
  return lprior + lhood + sigsq_ldj

In [None]:
num_samples = 500
warmup_iter = 500
num_chains = 5
warmup_key, sampling_key, init_key, subs_key = jax.random.split(jax.random.PRNGKey(0), 4)

# Random initialisation in the *unconstrained* parameter space: sigsq here is
# the pre-Exp value, so any real-valued standard-normal draw is a valid start.
def make_initial_pos(key):
  """Return a Theta with standard-normal beta (length p) and scalar sigsq."""
  beta_key, sigsq_key = jax.random.split(key)
  beta = jax.random.normal(key=beta_key, shape=(p,))
  sigsq = jax.random.normal(key=sigsq_key)
  return Theta(beta=beta, sigsq=sigsq)

In [None]:
%%time

# Window adaptation for NUTS from a random initial position. The returned
# `kernel` is later called as kernel(key, state); `info[0].position` is used
# below as the warmup trajectory.
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    logprob_fn,
    num_steps=warmup_iter,
    progress_bar=True,
)
final_warmup_state, kernel, info = warmup.run(
    warmup_key,
    make_initial_pos(init_key),
)

In [None]:
# Choose chain starting points from the second half of the warmup trajectory.
# Sampling WITHOUT replacement guarantees each chain starts from a distinct
# warmup draw; duplicated starts would understate the between-chain variance
# that R-hat is built on. Requires num_chains <= warmup_iter // 2.
idxs = jax.random.choice(
    subs_key,
    a=jnp.arange(warmup_iter // 2, warmup_iter),
    shape=(num_chains,),
    replace=False,
)
initial_positions = Theta(
    beta=info[0].position.beta[idxs],
    sigsq=info[0].position.sigsq[idxs],
)
# `pos` (not `p`) so the global problem dimension `p` is not shadowed.
initial_states = jax.vmap(lambda pos: blackjax.nuts.init(pos, logprob_fn))(initial_positions)

In [None]:
# Sense check: display logprob_fn at the final warmup position and at each of
# the chosen chain starting positions (vmapped over chains).
[
    logprob_fn(final_warmup_state.position),
    jax.vmap(logprob_fn)(initial_positions),
]

In [None]:
# Side-by-side warmup traces: beta (one line per coefficient) and sigsq
# (plotted on the unconstrained scale the sampler uses).
for panel, (trace, heading) in enumerate(
    [(info[0].position.beta, 'beta warmup'), (info[0].position.sigsq, 'sigsq warmup')],
    start=1,
):
    plt.subplot(1, 2, panel)
    plt.plot(trace)
    plt.title(heading)
plt.tight_layout()

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run ``num_samples`` transitions of ``kernel`` starting from ``initial_state``.

    Args:
        rng_key: PRNG key; split into one key per transition.
        kernel: transition function called as ``kernel(key, state)`` and
            returning ``(new_state, info)``; ``info`` is discarded.
        initial_state: state the chain starts from.
        num_samples: number of transitions; must be a concrete Python int
            because it sets how many keys ``jax.random.split`` produces.

    Returns:
        The states after each transition, stacked along a new leading axis of
        length ``num_samples``.
    """

    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        # Per-draw diagnostics could be computed here. Note that
        # state.position.sigsq is on the log scale, so the variance is
        # exp(state.position.sigsq).
        return state, state  # (carry, per-step output)

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

In [None]:
%%time

sampling_keys = jax.random.split(sampling_key, num_chains)

# One chain per key/initial state (batched on axis 0); kernel and num_samples
# are shared across chains (in_axes=None).
states = jax.vmap(
    inference_loop,
    in_axes=(0, None, 0, None),
)(sampling_keys, kernel, initial_states, num_samples)
# JAX dispatches asynchronously; block on a result so %%time covers sampling.
_ = states.position.sigsq[0, 0].block_until_ready()

In [None]:
# Posterior summaries pooled over chains (axis 0) and draws (axis 1). sigsq is
# mapped back to the variance scale before averaging.
sigsq = sigsq_t.forward(states.position.sigsq)
print(f"E[sigsq|y] = {sigsq.mean(axis=(0,1)).round(3)}")
print(f"E[beta|y]  = {states.position.beta.mean(axis=(0,1)).round(3)}")
# Convergence diagnostics; sigsq's R-hat is computed on the unconstrained scale.
print(f"beta rhat  = {potential_scale_reduction(states.position.beta).round(3)}")
print(f"sigsq rhat = {potential_scale_reduction(states.position.sigsq).round(3)}")

In [None]:
# Replicate R-hat by hand for sigsq (univariate, so simpler than beta), using
# the notation of Vats et al.: m chains, each with n draws.
m, n = states.position.sigsq.shape[:2]
print(f"n = {n}, m = {m}")
# W: mean of the per-chain sample variances (ddof=1, over the draw axis).
chain_sample_variances = states.position.sigsq.var(axis=1, ddof=1)
print(f"s_m^2 = {chain_sample_variances}")
W = chain_sample_variances.mean()
print(f"W = {W}")

In [None]:
# B: n times the sample variance (ddof=1) of the per-chain means.
chain_means = states.position.sigsq.mean(axis=1)
print(f"chain_means = {chain_means}")
B = chain_means.var(ddof=1) * n
print(f"B = {B}")