### Imports

In [1]:
import jax.tree_util as jtu
from jax import numpy as jnp, random as jr
from jax import nn, vmap
from equinox import tree_at

import numpy as np

from pymdp.envs import GridWorld, rollout
from pymdp import control
from pymdp.agent import Agent


### Grid world generative model

In [2]:
# Dimensions of the grid world
grid_shape = (7, 7)

# How many agents are simulated in parallel (the batch dimension)
batch_size = 5

# Every agent begins on the centre cell of the 7x7 grid
env = GridWorld(
    shape=grid_shape,
    initial_position=(3, 3),
    include_stay=False,
    batch_size=batch_size,
)

# Goal cell (bottom right corner) and its linear state index
desired_state = (6, 6)
desired_state_id = env.coords_to_index(shape=grid_shape, coord=desired_state)

# Dimension lists read off the environment's parameters:
#   num_obs      - observation dimension per modality
#   num_states   - state dimension per hidden-state factor
#   num_controls - number of actions per hidden-state factor
num_obs = [likelihood.shape[1] for likelihood in env.params['A']]
num_states = [transition.shape[-2] for transition in env.params['B']]
num_controls = [transition.shape[-1] for transition in env.params['B']]

### Planning and inductive inference parameters

In [3]:
# One-step policies, plus the threshold and depth handed to the agent for inductive inference
planning_horizon = 1
inductive_threshold = 0.1
inductive_depth = 7
policy_matrix = control.construct_policies(
    num_states, num_controls, policy_len=planning_horizon
)

# Goal states for inductive planning: a list with one entry per hidden-state factor,
# each a one-hot vector on the desired state, broadcast to (n_batches, num_states[f])
goal_one_hot = nn.one_hot(desired_state_id, num_states[0])
H = [jnp.broadcast_to(goal_one_hot, (batch_size, num_states[0]))]

### Initialize an `Agent()`

In [4]:
# Build the agent's generative model from the environment's own (generative process) parameters
A = env.params['A']
B = env.params['B']
D = env.params['D']

# Preferred outcomes: one entry per observation modality, each of shape (n_batches, num_obs[m]).
# The vector is sized with num_obs (not num_states) because C is defined over observations.
# Using the goal's *state* index as the preferred *observation* index assumes observation i
# reports state i -- TODO confirm if the environment's A is ever changed.
goal_preference = nn.one_hot(desired_state_id, num_obs[0])
C = [jnp.repeat(goal_preference[None, :], batch_size, axis=0)]

# Utility and inductive inference are switched on; both information-gain terms are switched off
agent = Agent(
    A, B, C, D,
    batch_size=batch_size,
    policies=policy_matrix,
    policy_len=planning_horizon,
    inductive_depth=inductive_depth,
    inductive_threshold=inductive_threshold,
    H=H,
    use_utility=True,
    use_states_info_gain=False,
    use_param_info_gain=False,
    use_inductive=True,
)

### Run active inference

In [5]:
# Number of timesteps to simulate
T = 7

# Run agent and environment together; `env` is rebound to the environment returned by rollout
rollout_key = jr.PRNGKey(0)
last, info, env = rollout(agent, env, num_timesteps=T, rng_key=rollout_key)

In [6]:
# Print the grid coordinates visited by one agent of the batch (0-based index, shown 1-based)
agent_id_to_track = 1
tracked_states = info['env'].state[0][agent_id_to_track]  # this agent's state index at each timestep
for t in range(T):
    coords = env.index_to_coords(shape=grid_shape, idx=tracked_states[t])
    print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {coords}")

Grid position for agent 2 at time 0: (3, 3)
Grid position for agent 2 at time 1: (3, 4)
Grid position for agent 2 at time 2: (3, 5)
Grid position for agent 2 at time 3: (3, 6)
Grid position for agent 2 at time 4: (4, 6)
Grid position for agent 2 at time 5: (5, 6)
Grid position for agent 2 at time 6: (6, 6)


### Now the agent starts further from the goal and thus needs more timesteps to reach it

In [7]:
# grid_shape and batch_size are intentionally not redefined here: the values used to build
# `env` are reused, so the new prior always has the same batch size as the environment.

# Linear index of the upper left corner, the new starting cell
upper_left_initial_state = env.coords_to_index(shape=grid_shape, coord=(0, 0))

# One-hot initial-state prior on that cell, repeated once per agent: (n_batches, num_states[0])
initial_state_prior = [
    jnp.repeat(nn.one_hot(upper_left_initial_state, num_states[0])[None, :], batch_size, axis=0)
]

# Replace the environment's D parameter with the new prior and rebind `env` to the result
env = tree_at(lambda x: x.params['D'], env, initial_state_prior)

# start in the upper left corner this time (one PRNG key per agent in the batch)
_, env = env.reset(jr.split(jr.PRNGKey(0), batch_size))

### Increase inductive planning depth in order to compute the needed inductive planning matrix 

In [8]:
# Same one-step policies and threshold as before, with inductive_depth raised to 14
planning_horizon = 1
inductive_threshold = 0.1
inductive_depth = 14
policy_matrix = control.construct_policies(
    num_states, num_controls, policy_len=planning_horizon
)

# Goal states for inductive planning: a list with one entry per hidden-state factor,
# each a one-hot vector on the desired state, broadcast to (n_batches, num_states[f])
goal_one_hot = nn.one_hot(desired_state_id, num_states[0])
H = [jnp.broadcast_to(goal_one_hot, (batch_size, num_states[0]))]

In [9]:
# Build the agent's generative model from the environment's own (generative process) parameters
A = env.params['A']
B = env.params['B']

# Preferred outcomes: one entry per observation modality, each of shape (n_batches, num_obs[m]).
# The vector is sized with num_obs (not num_states) because C is defined over observations.
# Using the goal's *state* index as the preferred *observation* index assumes observation i
# reports state i -- TODO confirm if the environment's A is ever changed.
goal_preference = nn.one_hot(desired_state_id, num_obs[0])
C = [jnp.repeat(goal_preference[None, :], batch_size, axis=0)]

# Initial-state prior: reuse the upper-left one-hot prior that was built for the environment,
# so the agent's initial beliefs and the environment's start state come from a single definition.
D = initial_state_prior

# Utility and inductive inference are switched on; both information-gain terms are switched off
agent = Agent(
    A, B, C, D,
    batch_size=batch_size,
    policies=policy_matrix,
    policy_len=planning_horizon,
    inductive_depth=inductive_depth,
    inductive_threshold=inductive_threshold,
    H=H,
    use_utility=True,
    use_states_info_gain=False,
    use_param_info_gain=False,
    use_inductive=True,
)

### Run active inference

In [10]:
# Number of timesteps to simulate for the longer route
T = 14

# Run agent and environment together; `env` is rebound to the environment returned by rollout
rollout_key = jr.PRNGKey(0)
last, info, env = rollout(agent, env, num_timesteps=T, rng_key=rollout_key)

In [11]:
# Print the grid coordinates visited by one agent of the batch (0-based index, shown 1-based)
agent_id_to_track = 1
tracked_states = info['env'].state[0][agent_id_to_track]  # this agent's state index at each timestep
for t in range(T):
    coords = env.index_to_coords(shape=grid_shape, idx=tracked_states[t])
    print(f"Grid position for agent {agent_id_to_track+1} at time {t}: {coords}")

Grid position for agent 2 at time 0: (0, 0)
Grid position for agent 2 at time 1: (0, 1)
Grid position for agent 2 at time 2: (0, 2)
Grid position for agent 2 at time 3: (0, 3)
Grid position for agent 2 at time 4: (0, 4)
Grid position for agent 2 at time 5: (0, 5)
Grid position for agent 2 at time 6: (0, 6)
Grid position for agent 2 at time 7: (1, 6)
Grid position for agent 2 at time 8: (2, 6)
Grid position for agent 2 at time 9: (3, 6)
Grid position for agent 2 at time 10: (4, 6)
Grid position for agent 2 at time 11: (5, 6)
Grid position for agent 2 at time 12: (6, 6)
Grid position for agent 2 at time 13: (6, 6)
