In [21]:
# Imports
# Python
import xarray as xr

# Maths
import numpy as np
import pandas as pd
import scipy.stats as stats

# Bayes
import pymc as pm
import arviz as az

from numpy.random import default_rng
# Notebook-level NumPy random generator. It is created without a seed, so any
# draws taken from it differ between kernel restarts.
# NOTE(review): none of the cells visible here reference `rng` — confirm it is
# still needed, or seed it and route the simulation through it.
rng = default_rng()

In [14]:
# --- Simulation settings -----------------------------------------------------
num_subjects = 20
beta = 1                # NOTE(review): not used by any cell visible here
sigma = 0.1             # NOTE(review): not used by any cell visible here
obs_sigma_sigma = 0.1   # scale of the per-subject observation noise

# Seed the simulation so that Restart & Run All regenerates the same data set.
sim_seed = 20240101
sim_rng = np.random.default_rng(sim_seed)

# Per-subject true value, number of repeated measurements (randint's `high`
# is exclusive, so 4..7), and measurement noise (half-normal, scaled).
xs = stats.norm.rvs(loc=0, scale=1, size=num_subjects, random_state=sim_rng)
ns = stats.randint.rvs(low=4, high=8, size=num_subjects, random_state=sim_rng)
# `np.inf` rather than `np.infty`: the latter alias was removed in NumPy 2.0.
x_sigma = obs_sigma_sigma * stats.truncnorm.rvs(
    a=0, b=np.inf, size=num_subjects, random_state=sim_rng
)

# Collect one frame per subject and concatenate once; concatenating inside the
# loop re-copies the accumulated frame on every iteration.
subject_frames = []
for this_s, (this_xs, this_n, this_x_sigma) in enumerate(zip(xs, ns, x_sigma)):
    x = stats.norm.rvs(
        loc=this_xs, scale=this_x_sigma, size=this_n, random_state=sim_rng
    )
    subject_frames.append(pd.DataFrame({"s": this_s, "x": x}))

# ignore_index gives every observation its own row label instead of restarting
# the index at 0 for each subject.
data = pd.concat(subject_frames, ignore_index=True)

data.head()

Unnamed: 0,s,x
0,0,1.36717
1,0,1.054977
2,0,1.275224
3,0,1.294788
4,0,1.598036


In [15]:
def simple_model(x, s, use_mu0):
    """Build a two-level normal model for repeated measurements per subject.

    Parameters
    ----------
    x : array-like
        Observed measurements, one entry per data point.
    s : array-like
        Subject label of each measurement; factorised (sorted) into integer
        codes that index the per-subject parameters.
    use_mu0 : bool
        If true, the population mean ``mu0`` gets a Normal(0, 1) prior;
        otherwise it is registered as constant data fixed at 0.0.

    Returns
    -------
    pm.Model
        Model with coords ``subject`` and ``points``, per-subject latent means
        ``xs``, per-subject noise scales ``x_sigma`` and likelihood ``obs_x``.
    """
    subject_codes, subject_labels = pd.factorize(s, sort=True)
    model_coords = {"subject": subject_labels, "points": np.arange(len(x))}

    model = pm.Model(coords=model_coords)
    with model:
        # Subject index of every observation, stored alongside the model.
        s_idx = pm.Data("s_idx", subject_codes, dims="points", mutable=False)

        # Population mean: either inferred or pinned to zero.
        mu0 = (
            pm.Normal('mu0', mu=0, sigma=1)
            if use_mu0
            else pm.Data('mu0', 0.0, mutable=False)
        )

        # Between-subject spread and per-subject measurement noise.
        xs_sigma = pm.HalfNormal("xs_sigma", sigma=1.0)
        x_sigma = pm.HalfNormal("x_sigma", sigma=0.5, dims="subject")

        # Latent true value of each subject.
        xs = pm.Normal('xs', mu=mu0, sigma=xs_sigma, dims="subject")

        # Each observation is drawn around its own subject's latent value.
        pm.Normal(
            'obs_x',
            mu=xs[s_idx],
            sigma=x_sigma[s_idx],
            observed=x,
            dims="points",
        )
    return model

In [16]:
# Fit both model variants; results are keyed by whether mu0 was sampled
# (True) or fixed at zero (False).
idata = dict()
mdl = dict()
for use_m0 in [False, True]:
    mdl[use_m0] = simple_model(data.x, data.s, use_m0)
    with mdl[use_m0]:
        idata[use_m0] = pm.sample(
            2000,
            tune=3000,
            cores=3,
            target_accept=0.95,
            return_inferencedata=True,
            # The model-comparison cells below read (and extend) the
            # `log_likelihood` group. Newer PyMC releases only store it when
            # asked, so request it explicitly instead of relying on the default.
            idata_kwargs={"log_likelihood": True},
        )


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 3 jobs)
NUTS: [xs_sigma, x_sigma, xs]


Sampling 3 chains for 3_000 tune and 2_000 draw iterations (9_000 + 6_000 draws total) took 449 seconds.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 3 jobs)
NUTS: [mu0, xs_sigma, x_sigma, xs]


Sampling 3 chains for 3_000 tune and 2_000 draw iterations (9_000 + 6_000 draws total) took 429 seconds.


In [17]:
# Rank the two fitted models held in `idata` (keys: False = mu0 fixed,
# True = mu0 sampled) using ArviZ's model comparison on the stored
# log-likelihood.
# NOTE(review): the recorded output has warning=True for both rows, so the
# LOO estimates should be treated with caution.
az.compare(idata)



Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale
True,0,121.7848,31.421257,0.0,1.0,9.97123,0.0,True,log
False,1,121.176977,31.798863,0.607823,0.0,9.855146,0.496719,True,log


In [70]:
def calc_ll(idata, mu0):
    """Compute a per-subject log-density for every posterior draw.

    For each chain, draw and subject the result is the sum of

    * the log-density of the subject's latent value ``xs`` under
      ``Normal(mu0, xs_sigma)``, and
    * the summed log-density of that subject's observations under
      ``Normal(xs, x_sigma)``.

    Parameters
    ----------
    idata : InferenceData-like
        Must expose ``observed_data["obs_x"]``, ``constant_data["s_idx"]`` and
        the posterior variables ``xs``, ``x_sigma`` and ``xs_sigma``.
    mu0 : xarray.DataArray
        Population mean per draw; broadcast against the posterior ``xs``.

    Returns
    -------
    xarray.DataArray
        Array with the same coords and dims as the posterior ``xs``.

    Notes
    -----
    Assumes the posterior ``xs`` is laid out as (chain, draw, subject) and
    that the values in ``s_idx`` are positions along the subject axis — the
    same assumptions the positional indexing has always relied on.
    Prints one ``Chain: <label>`` line per chain as a progress indicator.
    """
    x = idata.observed_data["obs_x"]
    s_idx = idata.constant_data["s_idx"]
    xs = idata.posterior["xs"]
    x_sigma = idata.posterior["x_sigma"]
    xs_sigma = idata.posterior["xs_sigma"]

    # Log-density of each subject's latent value given the population terms.
    mu0 = mu0.broadcast_like(xs)
    xs_sigma = xs_sigma.broadcast_like(xs)
    ll_xs = stats.norm(mu0, xs_sigma).logpdf(xs)

    # Plain arrays for vectorised evaluation of the observation term.
    obs = np.asarray(x.values)
    subj_of_point = np.asarray(s_idx.values).astype(np.intp)
    xs_vals = xs.values
    # Align x_sigma to xs's dimension order before dropping to NumPy.
    sigma_vals = x_sigma.transpose(*xs.dims).values
    n_subjects = xs_vals.shape[-1]

    # Boolean mask of the observations belonging to each subject.
    subject_masks = [subj_of_point == s for s in range(n_subjects)]

    ll_x = np.zeros_like(ll_xs)
    # Index chains by position; the coordinate label is only used for display.
    for chain_pos, chain_label in enumerate(idata.posterior.coords["chain"]):
        print(f'Chain: {chain_label.item()}')
        # (draw, points): log-density of every observation for every draw,
        # evaluated with the parameters of the observation's own subject.
        point_ll = stats.norm.logpdf(
            obs,
            loc=xs_vals[chain_pos][:, subj_of_point],
            scale=sigma_vals[chain_pos][:, subj_of_point],
        )
        for s, mask in enumerate(subject_masks):
            ll_x[chain_pos, :, s] = point_ll[:, mask].sum(axis=1)

    ll = ll_xs + ll_x
    return xr.DataArray(ll, coords=xs.coords, dims=xs.dims)

In [71]:
# Population mean used by each model, keyed like `idata`.
mu0 = dict()
# mu0 sampled: take the posterior draws directly.
mu0[True] = idata[True].posterior.mu0
# mu0 fixed: repeat the stored constant in the shape of the sampled draws.
mu0[False] = xr.full_like(mu0[True], idata[False].constant_data.mu0.item())

ll_list = dict()
# Store the per-subject log-density next to the existing log-likelihood
# variables so it can be selected by name ("xs") in model comparison.
for use_mu0 in (False, True):
    ll = calc_ll(idata[use_mu0], mu0[use_mu0])
    idata[use_mu0].log_likelihood["xs"] = ll

Chain: 0
Chain: 1
Chain: 2
Chain: 0
Chain: 1
Chain: 2


In [80]:
# Repeat the model comparison, this time on the log-likelihood variable named
# "xs" (one entry per subject) rather than the default one.
# NOTE(review): the recorded output has warning=True for both rows, so the
# LOO estimates should be treated with caution.
az.compare(idata, var_name="xs")



Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale
True,0,52.398086,74.876416,0.0,0.664726,17.457878,0.0,True,log
False,1,49.132634,77.773201,3.265452,0.335274,16.878461,5.487503,True,log
