In [1]:
%load_ext jupyter_black
%matplotlib inline

In [2]:
import math

import matplotlib.pyplot as plt
import arviz as az

import numpy as np
import jax.random as jr
import jax.numpy as jnp
import jax.tree_util as jtu

from jax import nn, lax, vmap

from pymdp.jax.agent import Agent as AIFAgent

from pybefit.inference import run_nuts, run_svi, default_dict_nuts

from pybefit.inference import NumpyroModel, NumpyroGuide
from pybefit.inference import RegularisedHorseshoe, RegularisedHorseshoePosterior
from pybefit.inference.numpyro.likelihoods import pymdp_likelihood as likelihood

from numpyro.infer.autoguide import AutoNormal, AutoMultivariateNormal
from numpyro.infer import Predictive

from pymdp.utils import random_A_matrix, random_B_matrix
from equinox import Module, field

# Root JAX PRNG key for the notebook's own random draws; later cells derive
# fresh keys from it with ``jr.split`` instead of reusing it directly.
# (NUTS does not use this key: it takes its own ``seed`` from the options dict.)
seed_key = jr.PRNGKey(101)

In [3]:
# define an agent and environment here
batch_size = 10
num_obs = [3, 3]
num_states = [3, 3]
num_controls = [2, 2]
num_blocks = 25
num_timesteps = 5

A_np = random_A_matrix(num_obs=num_obs, num_states=num_states)
B_np = random_B_matrix(num_states=num_states, num_controls=num_controls)
A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))
B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))
C = [jnp.zeros((batch_size, no)) for no in num_obs]
D = [jnp.ones((batch_size, ns)) / ns for ns in num_states]
E = jnp.ones((batch_size, 4)) / 4

pA = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))
pB = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))


class TestEnv(Module):
    """Toy environment that emits uniformly random observations.

    ``step`` is wrapped in ``vmap``: its positional arguments and the ``key``
    keyword are mapped over their leading (batch) axis, so one call returns
    one observation per modality for every agent in the batch.
    """

    # Number of outcomes per modality. Kept as a tuple because static fields
    # become part of the pytree structure, which JAX may need to hash, and a
    # list is unhashable.
    num_obs: tuple[int, ...] = field(static=True)

    def __init__(self, num_obs):
        self.num_obs = tuple(num_obs)

    @vmap
    def step(self, *args, **kwargs):
        # return a list of random observations for each agent or parallel realization (each entry in batch_dim)
        # Positional arguments (e.g. actions) are ignored: observations do not
        # depend on behaviour in this test environment.
        key = kwargs["key"]
        # One sub-key per modality. Reusing ``key`` for every modality makes all
        # modalities with the same number of outcomes emit identical values.
        keys = jr.split(key, len(self.num_obs))
        return [jr.randint(k, (), 0, no) for k, no in zip(keys, self.num_obs)]


# Batch of agents built directly from the random generative model above
# (not referenced again below), plus the toy environment the likelihood uses.
agents = AIFAgent(
    A,
    B,
    C,
    D,
    E,
    pA,
    pB,
    use_param_info_gain=True,
    inference_algo="mmp",
)
task = TestEnv(num_obs)

In [4]:
# One vector of ``num_params`` unconstrained latent values per simulated agent,
# with a regularised-horseshoe prior over the (agents x params) latent matrix.
num_params, num_agents = 3, batch_size
prior = RegularisedHorseshoe(num_params, num_agents, backend="numpyro")


def transform(z):
    """Map unconstrained latent variables to a batch of active-inference agents.

    Args:
        z: array of shape ``(num_agents, 3)``. Column 0 sets the first
            likelihood modality, column 1 scales the outcome preferences of
            the second modality, and column 2 sets the prior over the second
            hidden-state factor.

    Returns:
        An ``AIFAgent`` whose ``A[0]``, ``C`` and ``D`` depend on ``z``; ``B``,
        ``E``, ``pA``, ``pB`` and ``A[1:]`` are taken from the notebook globals.
        The global ``A`` list itself is left untouched.
    """
    na, n_params = z.shape

    assert n_params == 3, f"transform expects 3 parameters per agent, got {n_params}"

    a = nn.sigmoid(z[..., 0])  # element of the likelihood matrix, in (0, 1)
    lam = jnp.exp(z[..., 1])  # outcome preference scale, > 0
    d = nn.sigmoid(z[..., 2])  # prior state probability, in (0, 1)

    # No preferences for modality 0; modality 1 prefers outcome 1 and avoids
    # outcome 2, with strength lam.
    C = [
        jnp.zeros((na, num_obs[0])),
        jnp.expand_dims(lam, -1) * jnp.array([0.0, 1.0, -1.0]),
    ]

    # Uniform prior over factor 0; factor 1 puts mass d / 1 - d on its first
    # two states and none on the third.
    D = [
        jnp.ones((na, num_states[0])) / num_states[0],
        jnp.stack([d, 1 - d, jnp.zeros(na)], -1),
    ]

    # Parameterised likelihood for modality 0.
    # NOTE(review): a1 is broadcast along both hidden-state axes, so this
    # modality has the same outcome distribution in every state — confirm
    # this is intended.
    a1 = jnp.stack([a, 1 - a, jnp.zeros(na)], -1)
    A0 = jnp.broadcast_to(
        jnp.expand_dims(a1, (-1, -2)), (na, num_obs[0], *num_states)
    )

    # Build a new list rather than assigning into the global ``A``. In-place
    # assignment would overwrite the notebook's model on every call and, under
    # NUTS tracing, leak traced arrays into module-level state.
    A_z = [A0, *A[1:]]

    # return the aif agent class
    return AIFAgent(
        A_z, B, C, D, E, pA, pB, use_param_info_gain=True, inference_algo="mmp"
    )


# Smoke test of ``transform``: push a small random latent matrix through it to
# check that all array shapes line up before building the probabilistic model.
# NOTE(review): this ``z`` is NOT the ground truth of the simulated data set.
# The data generated below come from an independent draw of the model
# (``Predictive``), not from this matrix.
key, _key = jr.split(seed_key)
z = jr.normal(_key, shape=(num_agents, num_params)) / 10

transform(z);

In [5]:
# Options forwarded to the likelihood: the environment plus the size of the
# simulated experiment (blocks, trials per block, number of agents).
opts_task = dict(
    task=task,
    num_blocks=num_blocks,
    num_trials=num_timesteps,
    num_agents=num_agents,
)
# The prior and the transform take no extra options.
opts_model = dict(prior={}, transform={}, likelihood=opts_task)

model = NumpyroModel(prior, transform, likelihood, opts=opts_model)

# Draw one synthetic data set (latents and behaviour) from the full generative
# model; it is used as the "experimental" data in the inference cells.
key, _key = jr.split(key)
pred = Predictive(model, num_samples=1)
samples = pred(_key)

In [6]:
# perform inference using no-u-turn sampler
# opts sampling dictionary can be used to specify various parameters
# either for the NUTS kernel or MCMC sampler
model = NumpyroModel(prior, transform, likelihood, opts=opts_model)

# Use the single prior-predictive draw as the "measured" behaviour to fit.
experimental_data = {
    "multiactions": samples["multiactions"][0],
    "outcomes": jtu.tree_map(lambda x: x[0], samples["outcomes"]),
}

# Copy the library defaults instead of binding ``default_dict_nuts`` itself:
# setting keys on that binding would mutate pybefit's module-level defaults
# for every later caller. Only top-level keys are overridden, so a shallow
# copy is enough.
# NOTE: 1 warmup / 1 sample is a smoke-test setting (the recorded run reports
# an acceptance probability of 0.00); increase both for a usable posterior.
opts_sampling = {**default_dict_nuts, "num_warmup": 1, "num_samples": 1}
print(opts_sampling)

nuts_samples, mcmc = run_nuts(model, experimental_data, opts=opts_sampling)

{'seed': 0, 'num_samples': 1, 'num_warmup': 1, 'sampler_kwargs': {'kernel': {}, 'mcmc': {}}}


sample: 100%|██████████| 2/2 [00:29<00:00, 14.77s/it, 1 steps of size 2.34e+00. acc. prob=0.00]


In [13]:
# Exploratory check: wrap a generic numpyro AutoNormal guide in NumpyroGuide
# and run it once on the simulated data.
# FIXME: the recorded output of this cell is ``KeyError: 'num_agents'``.
# Presumably NumpyroGuide (or the model it traces) expects options that are
# not supplied on this path — TODO investigate before enabling the SVI lines below.
from pybefit.inference.methods import default_dict_numpyro_svi
# NOTE(review): AutoNormal is already imported in the top import cell.
from numpyro.infer.autoguide import AutoNormal

model = NumpyroModel(prior, transform, likelihood, opts=opts_model)

test_guide = AutoNormal(model)

posterior = NumpyroGuide(test_guide)

# Same data as in the NUTS cell: the first (only) prior-predictive draw.
experimental_data = {
    "multiactions": samples["multiactions"][0],
    "outcomes": jtu.tree_map(lambda x: x[0], samples["outcomes"]),
}

from numpyro import handlers

# Run the guide once under a fixed seed handler so its sample sites receive
# PRNG keys outside of an inference loop.
with handlers.seed(rng_seed=0):
    posterior(data=experimental_data)


# NOTE(review): if re-enabled, this would overwrite ``samples`` (the
# prior-predictive draw holding the ground-truth latents), and ``opts_svi``
# would alias, not copy, the library default dict.
# # perform inference using stochastic variational inference
# opts_svi = default_dict_numpyro_svi
# print(opts_svi)

# samples, svi, results = run_svi(model, posterior, experimental_data, opts=opts_svi)

KeyError: 'num_agents'

In [None]:
# Evaluate the NUTS fit with ArviZ: a PSIS-LOO estimate of out-of-sample
# predictive accuracy, then trace plots of the hierarchical latents.
# NOTE(review): with the smoke-test setting of a single posterior sample,
# these diagnostics carry no real information.
idata = az.from_numpyro(mcmc)
loo_result = az.loo(idata)
print(loo_result)

az.plot_trace(idata, var_names=("mu", "sigma", "z"));

NameError: name 'mcmc' is not defined

In [None]:
# Parameter recovery for NUTS. The ground truth is the latent ``z`` of the
# single prior-predictive draw in ``samples`` that produced ``experimental_data``.
# The random ``z`` passed to ``transform`` earlier is unrelated to the data.
# Assumes ``samples["z"]`` is (1, num_agents, num_params) and
# ``nuts_samples["z"]`` is (num_samples, num_agents, num_params) — TODO confirm.
true_z = samples["z"][0]
nuts_mean_z = nuts_samples["z"].mean(0)

fig, ax = plt.subplots(figsize=(16, 5))
labels = [r"$f^{-1}(a)$", r"$f^{-1}(\lambda)$", r"$f^{-1}(d)$"]
for i, label in enumerate(labels):
    ax.scatter(true_z[:, i], nuts_mean_z[:, i], label=label)

# Identity line: points on it are perfectly recovered.
bounds = (true_z.min(), true_z.max())
ax.plot(bounds, bounds, "k--")
ax.set_xlabel("true value")
ax.set_ylabel("posterior mean")
ax.legend(title="parameter");

: 

In [None]:
from pybefit.inference.methods import default_dict_numpyro_svi

posterior = NumpyroGuide(
    RegularisedHorseshoePosterior(num_params, num_agents, backend="numpyro")
)

# perform inference using stochastic variational inference
opts_svi = default_dict_numpyro_svi
print(opts_svi)

samples, svi, results = run_svi(model, posterior, measurments, opts=opts_svi)

: 

In [None]:
plt.figure(figsize=(16, 5))
for i in range(3):
    plt.scatter(z[:, i], samples["z"].mean(0)[:, i], label=i)

plt.plot((z.min(), z.max()), (z.min(), z.max()), "k--")
plt.ylabel("posterior mean")
plt.xlabel("true value")
plt.legend(title="parameter id")

: 