In [1]:
import jax.numpy as jnp
from pymdp.jax.agent import Agent

### Set up generative model and a sequence of observations. The A tensors, B tensors and observations are specified in such a way that  only later observations ($o_{t > 1}$) help disambiguate hidden states at earlier time points. This will demonstrate the importance of "smoothing" or retrospective inference

In [19]:
num_states = [3, 2]
num_obs = [3]

A_tensor = jnp.stack([jnp.array([[0.5, 0.5, 0.], 
                                [0.0,  0.0,  1.], 
                                [0.5, 0.5, 0.]]
                            ), jnp.array([[1./3, 1./3, 1./3], 
                                            [1./3, 1./3, 1./3], 
                                            [1./3, 1./3, 1./3]]
                            )], axis=-1)

A = [ jnp.broadcast_to(A_tensor, (2, 3, 3, 2)) ]

# create two B matrices, one for each action
B_1 = jnp.broadcast_to(jnp.array([[0.0, 0.75, 0.0],
                                [0.0, 0.25, 1.0],
                                [1.0, 0.0, 0.0]]
            ), (2, 3, 3))

B_2 = jnp.broadcast_to(jnp.array([[0.0, 0.25, 0.0],
                                [0.0, 0.75, 0.0],
                                [1.0, 0.0, 1.0]]
            ), (2, 3, 3))

B_uncontrollable = jnp.expand_dims(
    jnp.broadcast_to(
        jnp.array([[1.0, 0.0], [0.0, 1.0]]), (2, 2, 2)
    ), 
    -1
)

B = [jnp.stack([B_1, B_2], axis=-1), B_uncontrollable]

# create a policy-dependent sequence of B matrices

policy_1 = jnp.array([ [0, 0],
                        [1, 0],
                        [1, 0] ]
                    )

policy_2 = jnp.array([ [1, 0],
                        [1, 0],
                        [1, 0] ]
                    )

policy_3 = jnp.array([ [1, 0],
                        [0, 0],
                        [1, 0] ]
                    )

all_policies = [policy_1, policy_2, policy_3]
n_policies = len(all_policies)
all_policies = list(jnp.stack(all_policies).transpose(2, 0, 1)) # `n_factors` lists, each with matrix of shape `(n_policies, n_time_steps)`

# for the single modality, a sequence over time of observations (one hot vectors)
obs = [jnp.broadcast_to(jnp.array([[1., 0., 0.], # observation 0 is ambiguous with respect to hidden state_1 and hidden_state 2
                                    [0., 1., 0.],  # observation 1 yields certain inference over hidden_state_1 = 2
                                    [0., 0., 1.], # observation 2 is ambiguous with respect to hidden state_1 and hidden_state 2
                                    [1., 0., 0.]])[:, None], (4, 2, 3) )] # observation 0 is ambiguous with respect to hidden state_1 and hidden_state 2

C = [jnp.ones((2,3))] # flat preferences
D = [jnp.ones((2, 3)) / 3., jnp.ones((2, 2)) / 2.] # flat prior
E = jnp.ones((2,n_policies))/n_policies


### Construct the `Agent`

In [23]:
pA = None
pB = None

agents = Agent(
        A=A,
        B=B,
        C=C,
        D=D,
        E=E,
        pA=None,
        pB=None,
        policy_len=3,
        control_fac_idx=None,
        policies=None,
        gamma=16.0,
        alpha=16.0,
        use_utility=True,
        action_selection="deterministic",
        sampling_mode="full",
        inference_algo="ovf",
        num_iter=16,
        learn_A=False,
        learn_B=False)


### Using `obs` and `policies`, pass in the arguments `outcomes`, `past_actions`, `empirical_prior` and `qs_hist` to `agent.infer_states(...)`

In [25]:
beliefs = agents.infer_states(outcomes, past_actions, empirical_prior, qs_hist, mask=None)


ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (13 of them) had size 2, e.g. axis 0 of args[0].A[0] of type float32[2,3,3,2];
  * one axis had size 4: axis 0 of args[1][0] of type float32[4,2,3]