In [1]:
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import nn, vmap, random, lax
from typing import List, Optional
from jaxtyping import Array
from jax import random as jr
import matplotlib.pyplot as plt
import numpy as np

from pymdp.envs import GridWorldEnv
from pymdp.jax.control import construct_policies
from pymdp.jax.agent import Agent as AIFAgent


### Grid world generative model

In [2]:
num_rows, num_columns = 7, 7
num_states = [num_rows*num_columns] # number of states equals the number of grid locations
num_obs = [num_rows*num_columns]    # number of observations equals the number of grid locations (fully observable)

# number of agents: every model array below is replicated along a leading batch axis of this size
n_batches = 3

# construct A arrays: one per modality, shape (n_batches, num_obs, num_states)
A = [jnp.broadcast_to(jnp.eye(num_states[0]), (n_batches,) + (num_obs[0], num_states[0]))] # fully observable (identity observation matrix)

# construct B arrays: one per factor, shape (n_batches, num_states, num_states, num_controls).
# Easy way to get the generative model parameters: extract them from one of the pre-made GridWorldEnv classes.
grid_world = GridWorldEnv(shape=[num_rows, num_columns])
B = [jnp.broadcast_to(jnp.array(grid_world.get_transition_dist()), (n_batches,) + (num_states[0], num_states[0], grid_world.n_control))]
num_controls = [grid_world.n_control] # number of control states equals the number of actions

# create mapping from grid-world coordinates (row, col) to linearly indexed states (C / row-major order)
grid = np.arange(grid_world.n_states).reshape(grid_world.shape)
coord_to_idx_map = {coord: int(grid[coord]) for coord in np.ndindex(grid.shape)}

# construct C arrays
desired_position = (num_rows - 1, num_columns - 1) # lower-right corner
desired_state_id = coord_to_idx_map[desired_position]
# A[0] has a leading batch axis, so take the first batch element's column for this state
# (all batch elements are identical copies). argmax picks the most likely observation,
# in case the state -> observation mapping is ever not one-to-one.
desired_obs_id = jnp.argmax(A[0][0, :, desired_state_id])
C = [jnp.broadcast_to(nn.one_hot(desired_obs_id, num_obs[0]), (n_batches, num_obs[0]))]

# construct D arrays
starting_position = (num_rows // 2, num_columns // 2) # middle
# starting_position = (0, 0) # upper left corner
starting_state_id = coord_to_idx_map[starting_position]
# same batch-aware column lookup as for the desired observation above
starting_obs_id = jnp.argmax(A[0][0, :, starting_state_id])
D = [jnp.broadcast_to(nn.one_hot(starting_state_id, num_states[0]), (n_batches, num_states[0]))]

2024-06-12 23:14:09.315056: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


### Planning parameters

In [3]:
# Depth limit handed to the MCTS search as `max_depth`.
max_depth = 12

# Length of each candidate policy. With a horizon of 1, each policy is presumably a single
# action for the model's one hidden-state factor (num_states has a single entry).
planning_horizon = 1
policy_matrix = construct_policies(num_states, num_controls, policy_len=planning_horizon)

### Initialize an `Agent()`

In [4]:
# Create the active-inference agent from the batched generative model (A, B, C, D).
# No E, pA or pB is supplied (presumably a policy prior and priors over A/B).
# NOTE(review): si_policy reads agent.E, so Agent presumably fills in a default when E=None.
# Per the flags, policy scoring uses utility but not the state-information-gain term.
agent = AIFAgent(
    A, B, C, D,
    E=None, pA=None, pB=None,
    policies=policy_matrix,
    policy_len=planning_horizon,
    use_states_info_gain=False,
    use_utility=True,
)

### MCTS-based policy search

In [9]:
import mctx
from tmp_mcts import make_aif_recurrent_fn

def si_policy(rng_key, agent, beliefs, num_simulations: int = 4096, search_depth: Optional[int] = None):
    """Run a Gumbel-MuZero MCTS from the agent's current beliefs and return action weights.

    Args:
        rng_key: JAX PRNG key consumed by ``mctx.gumbel_muzero_policy``.
        agent: the pymdp agent. It is passed to mctx as the ``params`` argument,
            ``log(agent.E)`` supplies the root prior logits, and
            ``agent.batch_size`` sizes the root value vector.
        beliefs: state beliefs used as the root ``embedding``.
        num_simulations: number of MCTS simulations (default keeps the original 4096).
        search_depth: maximum tree depth. ``None`` falls back to the notebook-level
            ``max_depth`` planning parameter, which matches the original behaviour.

    Returns:
        ``action_weights`` of the mctx policy output.
    """
    if search_depth is None:
        # Backward-compatible default: read the notebook-level planning parameter.
        search_depth = max_depth

    root = mctx.RootFnOutput(
        prior_logits=jnp.log(agent.E),
        value=jnp.zeros((agent.batch_size,)),  # one (zero) root value per batch element
        embedding=beliefs,
    )

    recurrent_fn = make_aif_recurrent_fn()

    policy_output = mctx.gumbel_muzero_policy(
        agent,
        rng_key,
        root,
        recurrent_fn,
        num_simulations=num_simulations,
        max_depth=search_depth,
    )

    return policy_output.action_weights

### Run active inference

In [10]:
T = 8 # needed if you start further away from the goal (e.g. in upper left corner)

# Observations: list of len (num_modalities), each element of shape (n_batches, 1).
# The trailing axis is a trivial time dimension that agent.infer_states indexes.
obs_idx = [jnp.broadcast_to(starting_obs_id, (n_batches, 1))]
state = jnp.broadcast_to(starting_state_id, (n_batches,))

prior = agent.D
key = jr.PRNGKey(101)
batch_to_track = 1
actions = None

# Grid-world dynamics, vmapped over the batch axis. The next state is the argmax of
# B[0][batch][:, s, a] (assumes axis 0 of B indexes next states). The observation is the
# argmax of A[0][batch][:, s].
env_step = vmap(lambda b, s, a: jnp.argmax(b[:, s, a]), in_axes=(0, 0, 0))
env_observe = vmap(lambda a, s: jnp.argmax(a[:, s]), in_axes=(0, 0))

for t in range(T):
    tracked_coords = np.unravel_index(state[batch_to_track], grid_world.shape)
    print('Grid position for agent {} at time {}: {}'.format(batch_to_track+1, t, tracked_coords))

    # Perception: state posterior, then drop the time axis to form the MCTS root embedding.
    beliefs = agent.infer_states(obs_idx, prior)
    embedings = jtu.tree_map(lambda x: x.squeeze(1), beliefs)

    # Planning: a dedicated subkey for the tree search.
    key, _key = jr.split(key)
    q_pi = si_policy(_key, agent, embedings)
    print(q_pi)

    # Action selection: one key per batch element, plus a fresh carry key for the next step.
    keys = jr.split(key, n_batches + 1)
    batch_keys = keys[:n_batches]
    key = keys[n_batches]
    actions = agent.sample_action(q_pi, rng_key=batch_keys)
    prior, _ = agent.update_empirical_prior(actions, beliefs)

    # Environment step; re-add the trivial time dimension expected by agent.infer_states.
    state = env_step(B[0], state, actions)
    next_obs = env_observe(A[0], state)
    obs_idx = [jnp.expand_dims(next_obs, -1)]

Grid position for agent 2 at time 0: (3, 3)
[[0.0000000e+00 1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 1.9621882e-31 1.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 1.0000000e+00 0.0000000e+00 0.0000000e+00]]
Grid position for agent 2 at time 1: (4, 3)
[[0.2 0.2 0.2 0.2 0.2]
 [0.  0.  0.  0.  1. ]
 [0.  0.  1.  0.  0. ]]
Grid position for agent 2 at time 2: (4, 3)
[[0.        0.        1.        0.        0.       ]
 [0.        1.        0.        0.        0.       ]
 [0.        0.4717898 0.5282102 0.        0.       ]]
Grid position for agent 2 at time 3: (4, 4)
[[2.0000000e-01 2.0000000e-01 2.0000000e-01 2.0000000e-01 2.0000000e-01]
 [0.0000000e+00 3.6696893e-01 6.3303107e-01 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 1.0000000e+00 4.7584586e-10 0.0000000e+00 3.0332685e-25]]
Grid position for agent 2 at time 4: (5, 4)
[[2.0000000e-01 2.0000000e-01 2.0000000e-01 2.0000000e-01 2.0000000e-01]
 [0.0000000e+00 4.7349367e-01 5.26506