In [1]:
import numpy as np 
import jax.numpy as jnp
import jax.tree_util as jtu

from jax import nn, vmap
from pymdp.jax.algos import update_variational_filtering
from pymdp import utils

In [2]:
# Three model configurations, each run in turn by the next cell. For configuration i:
#   num_states_list[i][f]      number of states of hidden-state factor f
#   num_controls_list[i][f]    number of control states (actions) of factor f
#   num_obs_list[i][m]         number of outcome levels of observation modality m
#   A_dependencies_list[i][m]  the hidden-state factors that modality m depends on (listed in ascending order)
num_states_list = [
                    [2, 2, 5],
                    [2, 2, 2],
                    [4, 4]
]

num_controls_list = [
                    [2, 1, 3],
                    [2, 1, 2],
                    [1, 3]
]

num_obs_list = [
                [5, 10],
                [4, 3, 2],
                [5, 2, 6, 3]
]

A_dependencies_list = [
                    [[0, 1], [1, 2]],
                    [[0], [1], [2]],
                    [[0,1], [1], [0], [1]]
]

batch_dim, T = 13, 4 # batch dimension (e.g. number of agents, parallel realizations, etc.) and time steps
n_policies = 3 # number of random policies evaluated in parallel (via vmap)

# Seed numpy's global RNG so that Restart & Run All reproduces the results. The observation and
# policy sampling in the next cell calls np.random directly; presumably the pymdp.utils random_*
# helpers draw from the same global RNG as well -- TODO confirm.
SEED = 0
np.random.seed(SEED)

# Sanity checks: the per-configuration lists are consumed with zip(), which would silently
# truncate on a length mismatch, so make sure they line up.
assert len(num_states_list) == len(num_controls_list) == len(num_obs_list) == len(A_dependencies_list), \
    "every configuration list must have the same number of configurations"
assert all(len(n_ctrl) == len(n_states) for n_states, n_ctrl in zip(num_states_list, num_controls_list)), \
    "each configuration needs exactly one control dimension per hidden-state factor"
assert all(len(deps) == len(n_obs) for n_obs, deps in zip(num_obs_list, A_dependencies_list)), \
    "each configuration needs exactly one dependency list per observation modality"
assert all(
    0 <= factor < len(n_states)
    for n_states, deps_per_modality in zip(num_states_list, A_dependencies_list)
    for deps in deps_per_modality
    for factor in deps
), "A dependencies must refer to existing hidden-state factors"

In [3]:
def add_batch_dim(arrays):
    """Return `arrays` as a list whose arrays are broadcast to a new leading axis of size `batch_dim`.

    Every batch member receives identical values (e.g. the same generative model for all agents).
    `arrays` may be any iterable of arrays; it is converted with `list(...)` first so the result is
    always a plain list pytree.
    """
    return jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_dim,) + x.shape), list(arrays))


def expand_A_to_full(A_reduced_numpy, num_obs, num_states, A_dependencies):
    """Build the dense equivalent of a sparse (reduced) A.

    Modality `m`'s reduced array is assumed to have shape
    `(num_obs[m], *[num_states[f] for f in A_dependencies[m]])` (required by the reshape below).
    The returned array for `m` has shape `(num_obs[m], *num_states)`, with the reduced values
    repeated along every factor that modality `m` does not depend on.
    """
    A_full_numpy = []
    for m, no in enumerate(num_obs):
        deps = list(A_dependencies[m])
        # the reshape below lays the dependent factor axes out in ascending factor order, so an unsorted
        # dependency list would misplace the axes (silently, whenever the sizes happen to allow the reshape)
        assert deps == sorted(deps), f"A_dependencies[{m}] must be sorted in ascending order, got {deps}"
        other_factors = set(range(len(num_states))) - set(deps)  # factors that modality `m` does not depend on

        # insert singleton axes for `other_factors`, then tile the reduced A along those axes
        expanded_dims = [no] + [1 if f in other_factors else ns for (f, ns) in enumerate(num_states)]
        tile_dims = [1] + [ns if f in other_factors else 1 for (f, ns) in enumerate(num_states)]
        A_full_numpy.append(np.tile(A_reduced_numpy[m].reshape(expanded_dims), tile_dims))
    return A_full_numpy


for (num_states, A_dependencies, num_controls, num_obs) in zip(num_states_list, A_dependencies_list, num_controls_list, num_obs_list):

    # random generative model, identical for every batch member
    A_reduced_numpy = utils.random_A_matrix(num_obs, num_states, A_factor_list=A_dependencies)
    A_reduced = add_batch_dim(A_reduced_numpy)

    # dense equivalent of A_reduced; not used by the sparse filtering call below, kept for comparison
    A_full_numpy = expand_A_to_full(A_reduced_numpy, num_obs, num_states, A_dependencies)
    A_full = add_batch_dim(A_full_numpy)

    B_numpy = utils.random_B_matrix(num_states, num_controls)
    B = add_batch_dim(B_numpy)

    prior_numpy = utils.random_single_categorical(num_states)
    prior = add_batch_dim(prior_numpy)

    # one random observation sequence per modality, shared by all batch members: shape (T, batch_dim, num_obs[m])
    obs_seq = []
    for n_obs in num_obs:
        obs_ints = np.random.randint(0, high=n_obs, size=(T, 1))
        obs_seq.append(jnp.broadcast_to(nn.one_hot(obs_ints, num_classes=n_obs), (T, batch_dim, n_obs)))

    # random policies: for each factor, an (n_policies, T-1) array of action indices (one per transition)
    policies = [
        jnp.array(np.random.randint(0, high=n_controls, size=(n_policies, T - 1)))
        for n_controls in num_controls
    ]

    def test_sparse(action_sequence):
        """Run variational filtering with the sparse A under a single policy.

        `action_sequence` holds one (T-1,) array of action indices per factor (vmap strips the policy axis).
        Returns the `(qs, ps, qss)` triple produced by `update_variational_filtering`.
        """
        # select the B slice for each step's action (indexing the last axis of B[f]) and move the resulting
        # time axis to the front: (..., T-1) -> (T-1, ...). Unlike a fixed transpose(3, 0, 1, 2),
        # moveaxis does not assume that B[f] is 4-D.
        B_policy = jtu.tree_map(lambda b, a_idx: jnp.moveaxis(b[..., a_idx], -1, 0), B, action_sequence)
        qs, ps, qss = update_variational_filtering(obs_seq, A_reduced, B_policy, prior, A_dependencies)
        return qs, ps, qss

    # vmap maps over the leading (policy) axis of every array in `policies`
    qs_pi, ps_pi, qss_pi = vmap(test_sparse)(policies)

    for qs, ps, qss in zip(qs_pi, ps_pi, qss_pi):
        print(qs.shape, ps.shape, qss.shape)

# Note: for each hidden-state factor f, qs_pi[f] and ps_pi[f] have shape (n_policies, batch_dim, num_states[f]),
# and qss_pi[f] has shape (n_policies, T, batch_dim, num_states[f], num_states[f]) -- see the printed shapes.
# Only the results of the last configuration in the loop survive this cell; the cells below inspect those.

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(3, 13, 2) (3, 13, 2) (3, 4, 13, 2, 2)
(3, 13, 2) (3, 13, 2) (3, 4, 13, 2, 2)
(3, 13, 5) (3, 13, 5) (3, 4, 13, 5, 5)
(3, 13, 2) (3, 13, 2) (3, 4, 13, 2, 2)
(3, 13, 2) (3, 13, 2) (3, 4, 13, 2, 2)
(3, 13, 2) (3, 13, 2) (3, 4, 13, 2, 2)
(3, 13, 4) (3, 13, 4) (3, 4, 13, 4, 4)
(3, 13, 4) (3, 13, 4) (3, 4, 13, 4, 4)


In [4]:
# posterior for hidden-state factor 0, under policy 0, for batch member 0 (last configuration run in the loop above)
qs_pi[0][0][0]

Array([0.10571534, 0.03540028, 0.4963476 , 0.36253685], dtype=float32)

In [5]:
# qss for hidden-state factor 0, policy 0, time step 0, batch member 0: a num_states[0] x num_states[0] matrix.
# NOTE(review): each row of the displayed matrix sums to ~1, so this looks like a conditional (transition-style)
# posterior rather than a joint one -- confirm against update_variational_filtering.
qss_pi[0][0][0][0]

Array([[0.6461534 , 0.09652454, 0.22822888, 0.02909314],
       [0.33564523, 0.26329154, 0.30335486, 0.09770837],
       [0.4735609 , 0.18010727, 0.23158638, 0.11474543],
       [0.50991637, 0.20105273, 0.18791321, 0.10111766]], dtype=float32)