In [None]:
import glob
import importlib
from posteriordb import PosteriorDatabase
import os
from time import time
import pangolin as pg
import jax
import numpy as np
from jax import numpy as jnp
import numpyro
import warnings

# Number of draws per run: passed as `niter` to pg.sample, as both num_warmup
# and num_samples to the NumPyro MCMC run, and as the sample count given to wass1.
niter = 1000

def aggregate_arrays(original_dict):
    """Merge indexed entries such as ``"beta[1]"``, ``"beta[2]"`` into one array.

    Keys containing ``'['`` are grouped under the text preceding the first
    bracket; the grouped values are stacked in the order they were encountered
    and transposed. Keys without a bracket are converted with ``np.array`` and
    kept under their own name.

    Parameters
    ----------
    original_dict : dict
        Mapping from variable name to a sequence of values.

    Returns
    -------
    dict
        Grouped arrays first (in first-seen order), followed by the
        un-indexed entries.
    """
    grouped = {}
    plain = {}

    for name, draws in original_dict.items():
        base, bracket, _ = name.partition('[')
        if bracket:
            # Indexed component: collect it under its base name.
            grouped.setdefault(base, []).append(draws)
        else:
            plain[name] = np.array(draws)

    # Stacking gives (components, ...); transposing reverses the axes.
    result = {base: np.array(columns).T for base, columns in grouped.items()}
    result.update(plain)
    return result

def flatten_samps(samps,nsamps):
    """Stack every variable in a pytree of samples into one ``(nsamps, D)`` matrix.

    Each leaf is reshaped to ``(nsamps, -1)`` on its own and the pieces are
    concatenated along axis 1, so row ``i`` of the result contains sample ``i``
    of every variable. (Raveling the whole pytree first and reshaping afterwards
    would interleave samples of different variables whenever the pytree has
    more than one leaf.)

    Parameters
    ----------
    samps : pytree of array-like
        Leaves hold the samples of one variable with the sample index on axis 0.
    nsamps : int
        Number of samples each leaf must contain.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nsamps, D)`` where ``D`` is the total number of scalar
        components per sample; columns follow jax's pytree leaf order. An empty
        pytree yields shape ``(nsamps, 0)``.

    Raises
    ------
    ValueError
        If a leaf does not have ``nsamps`` entries along its leading axis.
    """
    leaves = [np.asarray(leaf) for leaf in jax.tree_util.tree_leaves(samps)]
    if not leaves:
        return np.zeros((nsamps, 0))

    columns = []
    for leaf in leaves:
        if leaf.ndim == 0 or leaf.shape[0] != nsamps:
            raise ValueError(
                f"expected {nsamps} samples along axis 0, got leaf of shape {leaf.shape}"
            )
        # Explicit width (rather than -1) keeps zero-sample inputs unambiguous.
        width = int(np.prod(leaf.shape[1:]))
        columns.append(leaf.reshape(nsamps, width))
    return np.concatenate(columns, axis=1)

def wass1(samps,true_samps,nsampsA,nsampsB):
    """Compare two sample sets by matching sorted samples coordinate-wise.

    Both sample sets are flattened to ``(n, D)`` matrices, every column is
    sorted independently, and the result is the mean over rows of the summed
    absolute differences between the matched order statistics.

    Parameters
    ----------
    samps, true_samps : dict
        Mappings with identical keys; for every key the two arrays must agree
        in all dimensions except the leading (sample) one.
    nsampsA, nsampsB : int
        Number of samples in ``samps`` and ``true_samps`` respectively.

    Returns
    -------
    float
        Mean over matched rows of the L1 distance between sorted samples.

    Notes
    -----
    When the two sets differ in size, the larger one is thinned *after*
    sorting by picking evenly spaced order statistics. For sizes that are
    exact multiples this is the same as taking every k-th row; for other
    ratios it still yields exactly as many rows as the smaller set, so the
    subtraction below cannot fail on mismatched shapes.
    """
    assert samps.keys() == true_samps.keys()
    for key in samps.keys():
        assert samps[key].shape[1:] == true_samps[key].shape[1:]

    A = flatten_samps(samps,nsampsA)
    B = flatten_samps(true_samps,nsampsB)
    assert A.ndim==2
    assert B.ndim==2

    A1 = np.sort(A,axis=0)
    B1 = np.sort(B,axis=0)

    # subsample AFTER sorting
    rows_a, rows_b = A1.shape[0], B1.shape[0]
    if rows_a > rows_b:
        A1 = A1[(np.arange(rows_b) * rows_a) // rows_b, :]
    elif rows_b > rows_a:
        B1 = B1[(np.arange(rows_a) * rows_b) // rows_a, :]

    return np.mean(np.sum(np.abs(A1-B1),axis=1))

# Edit this to point at wherever you put posteriordb.
#pdb_path = os.path.join(os.getcwd(), "posteriordb-master/posterior_database") 
pdb_path = os.path.join(os.getcwd(), "posteriordb-old/posterior_database")
my_pdb = PosteriorDatabase(pdb_path)

# Model files (matched by base name) that are excluded from the comparison.
skip = ['dogs.py']


def _errors_per_reference_set(samps, reference_sets, nsamps):
    """Return ``wass1(samps, ref, ...)`` for every reference draw set.

    The number of reference draws is read from the first array of each set
    instead of being hard-coded.
    """
    return [
        wass1(samps, ref, nsamps, len(next(iter(ref.values()))))
        for ref in reference_sets
    ]


# sorted(): glob returns files in arbitrary order; sorting makes runs repeatable.
for file in sorted(glob.glob('models/*.py')):
    print(file)
    if os.path.basename(file) in skip:
        continue

    # "models/foo.py" -> "models.foo", independent of the path separator in use.
    module_name = os.path.splitext(file)[0].replace(os.sep, '.').replace('/', '.')
    module = importlib.import_module(module_name)
    posterior = my_pdb.posterior(module.posterior_name)
    values = posterior.data.values()

    ### get reference samples
    true_samps = posterior.reference_draws()
    # collect string array names ("beta[1]", "beta[2]", etc.) into single arrays
    true_samps = [aggregate_arrays(t) for t in true_samps]

    print(f"Testing posterior {module.posterior_name}")

    vardict, given, vals = module.getmodel(values)

    ### OPTION A: do inference using pangolin
    samps_pangolin = pg.sample(vardict, given, vals, niter=niter)

    errors = _errors_per_reference_set(samps_pangolin, true_samps, niter)
    print(f"pangolin errors by fold {errors}")

    ### OPTION B: alternatively, compile to a "plain" numpyro model ###
    # first, flatten everything
    flat_vars, vars_pytree = jax.tree_util.tree_flatten(vardict)
    flat_given, _ = jax.tree_util.tree_flatten(given)
    flat_vals, _ = jax.tree_util.tree_flatten(vals)
    # now get the numpyro model
    model, var_to_name = pg.inference.numpyro.model.get_model_flat(flat_vars, flat_given, flat_vals)
    # model is now a normal numpyro model, you can do normal numpyro stuff with it
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(
        kernel,
        num_warmup=niter,
        num_samples=niter,
        progress_bar=False,
    )
    key = jax.random.PRNGKey(0)
    with warnings.catch_warnings(action="ignore", category=FutureWarning):  # type: ignore
        mcmc.run(key)
    latent_samples = mcmc.get_samples()
    # Map the sampled sites back onto the structure of the original variables.
    latent_samples_flat = [latent_samples[var_to_name[v]] for v in flat_vars]
    samps_numpyro = jax.tree_util.tree_unflatten(vars_pytree, latent_samples_flat)

    errors = _errors_per_reference_set(samps_numpyro, true_samps, niter)
    print(f"numpyro errors by fold {errors}")


In [1]:
import glob
import importlib
from posteriordb import PosteriorDatabase
import os
from time import time
import pangolin as pg
import jax
import numpy as np
from jax import numpy as jnp
import numpyro
import warnings


# Demonstrates a current limitation: slicing with bounds that depend on a
# loop index raises a ValueError.
# Seeded generator so the constant vector is the same on every run.
a = pg.makerv(np.random.default_rng(0).standard_normal(10))

# Create the slot in this cell so the cell does not depend on a later one.
x = pg.slot()
try:
    with pg.Loop(5) as i:
        x[i] = a[i:(i+5)]
except ValueError as err:
    # The failure is the point of this cell: show it without aborting Run All.
    print(f"ValueError: {err}")

#pg.print_upstream(x)

ValueError: Slices with loop start/stop/step not currently supported

In [3]:
# Display the repr of `a`, the value built with pg.makerv from a length-10 random vector.
a

OperatorRV(Constant([ 0.70375101  0.53388627 ... -0.19841396
  0.62321776]))

In [7]:
# Table of pairwise sums: entry (r, c) is r + c for r in 0..4 and c in 0..2,
# built by broadcasting a column vector against a row vector.
tmp = np.arange(5)[:, np.newaxis] + np.arange(3)
print(tmp)

[[0 1 2]
 [1 2 3]
 [2 3 4]
 [3 4 5]
 [4 5 6]]


In [15]:
# Create a slot and assign into it inside a loop of length 3: x[i] is set to i + 3.
# NOTE(review): pg.slot / pg.Loop semantics are pangolin-specific; the recorded
# outputs of the following cells (pg.E(x) -> [3., 4., 5.]) match a length-3 vector.
x = pg.slot()
with pg.Loop(3) as i:
    x[i] = i+3

In [17]:
# Call pg.E on x (presumably its expectation - confirm against pangolin docs);
# the recorded output is Array([3., 4., 5.], dtype=float32).
pg.E(x)

Array([3., 4., 5.], dtype=float32)

In [18]:
# Print the "shape | statement" table for x; the recorded output shows the loop
# compiled to a VMap of `add` over [0 1 2] and the constant 3.
pg.print_upstream(x)

shape | statement
----- | ---------
(3,)  | a = [0 1 2]
()    | b = 3
(3,)  | c = VMap(add,(0, '∅'),3)(a,b)
