### Imports

In [1]:
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import nn, vmap, random, lax
from typing import List, Optional
from jaxtyping import Array
from jax import random as jr
import matplotlib.pyplot as plt
import numpy as np

from pymdp.envs import GridWorldEnv
from pymdp.jax import control as j_control

### Grid world generative model

In [2]:
num_rows, num_columns = 7, 7
num_states = [num_rows * num_columns]  # number of states equals the number of grid locations
num_obs = [num_rows * num_columns]     # number of observations equals the number of grid locations (fully observable)

# A (likelihood): identity matrix, so every grid location is observed unambiguously.
A = [jnp.eye(num_states[0])]
A_dependencies = [[0]]

# B (transitions): the easiest way to get these is to extract them from a pre-made
# GridWorldEnv rather than building them by hand.
# NOTE(review): assumed layout is B[0][next_state, state, action] -- confirm against GridWorldEnv.
grid_world = GridWorldEnv(shape=[num_rows, num_columns])
B = [jnp.array(grid_world.get_transition_dist())]
B_dependencies = [[0]]
num_controls = [grid_world.n_control]  # number of control states equals the number of actions
assert grid_world.n_states == num_states[0], "GridWorldEnv state count does not match num_states"

# Map (row, col) grid coordinates to linearly-indexed (row-major) state ids.
grid = np.arange(grid_world.n_states).reshape(grid_world.shape)
coord_to_idx_map = {coord: int(grid[coord]) for coord in np.ndindex(grid.shape)}

# C (preferences): one-hot preference over the observation emitted at the goal location.
# Derived from the grid size so the goal stays in the corner if the grid is resized.
desired_position = (num_rows - 1, num_columns - 1)  # lower-right corner
desired_state_id = coord_to_idx_map[desired_position]
# Go through A rather than assuming obs id == state id, in case the state->observation
# mapping is ever made non-identity.
desired_obs_id = jnp.argmax(A[0][:, desired_state_id])
C = [nn.one_hot(desired_obs_id, num_obs[0])]

# D (initial state prior): the agent starts with certainty at the starting location.
starting_position = (num_rows // 2, num_columns // 2)  # middle
# starting_position = (0, 0) # upper left corner (needs a longer run, see T below)
starting_state_id = coord_to_idx_map[starting_position]
# Same reasoning as for desired_obs_id: look the observation up through A.
starting_obs_id = jnp.argmax(A[0][:, starting_state_id])
D = [nn.one_hot(starting_state_id, num_states[0])]

### Planning parameters

In [3]:
# Policy depth. With planning_horizon = 1 every policy is a single action; the
# longer-range pull toward the goal is presumably supplied by the inductive term
# (the I matrix below) rather than by deep policies.
planning_horizon = 1

# Inductive-planning hyperparameters, passed positionally to generate_I_matrix.
# NOTE(review): inductive_depth (7) is set independently of planning_horizon (1);
# they are NOT equal. TODO: ask Tim or Tommaso how these two depths should relate.
inductive_threshold = 0.1
inductive_depth = 7

policy_matrix = j_control.construct_policies(num_states, num_controls, policy_len=planning_horizon)

# Inductive planning goal states: a one-hot over the desired grid location.
H = [nn.one_hot(desired_state_id, num_states[0])]

I = j_control.generate_I_matrix(H, B, inductive_threshold, inductive_depth)

### Run active inference

In [4]:
# T = 14 # needed if you start further away from the goal (e.g. in upper left corner)
T = 7  # can get away with fewer timesteps if you start closer to the goal (e.g. in the middle)

# Policy-inference hyperparameters passed to update_posterior_policies_inductive.
policy_precision = 16.0    # `gamma` argument (presumably precision of the policy softmax)
inductive_epsilon = 1e-3

qs_init = [nn.one_hot(starting_state_id, num_states[0])]  # same as D; not used by the loop below
obs = nn.one_hot(starting_obs_id, num_obs[0])
state = starting_state_id

for t in range(T):
    # Cast to plain ints so the printout reads "(3, 3)" regardless of NumPy's scalar repr.
    position = tuple(int(i) for i in np.unravel_index(int(state), grid_world.shape))
    print('Grid position at time {}: {}'.format(t, position))

    # Update posterior beliefs over states: the A row for the observed outcome,
    # normalised (Bayes' rule with a flat prior). With identity A this is a one-hot.
    likelihood = A[0][jnp.argmax(obs), :]
    qs = [likelihood / likelihood.sum()]

    # Evaluate Q(pi) and negative EFE using the inductive planning algorithm.
    q_pi, neg_efe = j_control.update_posterior_policies_inductive(
        policy_matrix, qs, A, B, C, A_dependencies, B_dependencies, I,
        gamma=policy_precision, use_utility=True, use_inductive=True,
        inductive_epsilon=inductive_epsilon,
    )

    # Select the most probable policy and execute its first action. q_pi is over
    # policies (rows of policy_matrix), not actions; the two indices only coincide
    # for planning_horizon == 1 with a single control factor.
    # Assumes policy_matrix is (num_policies, policy_len, num_factors) as built by construct_policies.
    best_policy = jnp.argmax(q_pi)
    action = policy_matrix[best_policy, 0, 0]  # first timestep, first (only) control factor

    # Use the action to affect the (deterministic) environment.
    state = jnp.argmax(B[0][:, state, action])
    obs = nn.one_hot(jnp.argmax(A[0][:, state]), num_obs[0])


Grid position at time 0: (3, 3)
Grid position at time 1: (3, 4)
Grid position at time 2: (3, 5)
Grid position at time 3: (3, 6)
Grid position at time 4: (4, 6)
Grid position at time 5: (5, 6)
Grid position at time 6: (6, 6)
