# Active Inference Demo: T-Maze Environment
This demo notebook provides a full walk-through of how to build a POMDP agent's generative model and perform the active inference routine (inversion of the generative model) using the `Agent()` class of `pymdp`. We build the generative model from the ground up, directly encoding our own A, B, and C matrices.

### Imports

First, import `pymdp` and the modules we'll need.

In [8]:
import os
import sys
import pathlib
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

path = pathlib.Path(os.getcwd())
module_path = str(path.parent) + '/'
sys.path.append(module_path)

from pymdp.agent import Agent
from pymdp.core import utils
from pymdp.core.maths import softmax
from pymdp.distributions import Categorical, Dirichlet
import copy

### Define model dimensions and intuitive labels for the different model components

Define the observation modalities, the hidden state factors, and the actions. We define them in terms of their dimensionalities and give each one an intuitive label that we can use later on to validate that our agent is acting in agreement with how we designed its generative model.

We assume the agent "represents" its environment (read: its generative model) using two latent variables that are factorized. In practice this means we can neatly segregate them in the generative model as well as in the agent's variational posterior (that comes later).

These two hidden state factors are `GAME_STATE` and `PLAYING_VS_SAMPLING`.

The first factor is a binary variable representing whether the state of the world is one that yields rewards with high probability (`GAME_STATE = 0`) or one that yields punishments (inverse rewards) with high probability (`GAME_STATE = 1`). You can think of this as basically the 'pay-off' structure of e.g. a slot machine or some game that the agent is currently playing. Crucially, the agent doesn't _know_ what the `GAME_STATE` is. They will have to infer it.

The second factor is a ternary (3-valued) variable representing ....

The observations are divided into 3 modalities:

The first obs modality...

the second obs modality ...

the third obs modality ...


N.B. It is useful to define the total number of modalities and factors.

In [9]:
obs_names = ["GAME_STATE_OBS", "REWARD", "ACTION_SELF_OBS"]
state_names = ["GAME_STATE", "PLAYING_VS_SAMPLING"]
action_names = ["NULL", "PLAY_SAMPLE_INIT"]

num_obs = [3, 3, 3]
num_states = [2, 3]
num_modalities = len(num_obs)
num_factors = len(num_states)

Setting up the observation likelihood matrix (`A`): the first main component of the generative model

In [10]:
A = utils.obj_array_zeros([[o] + num_states for o in num_obs])
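
For intuition about what this container looks like, here is a NumPy-only sketch. It assumes that `utils.obj_array_zeros` returns a NumPy object array holding one zero-filled array per requested shape; the helper `make_obj_array_zeros` is a hypothetical stand-in written for this illustration, not the `pymdp` implementation.

```python
import numpy as np

obs_names = ["GAME_STATE_OBS", "REWARD", "ACTION_SELF_OBS"]
num_obs = [3, 3, 3]
num_states = [2, 3]

def make_obj_array_zeros(shape_list):
    """Hypothetical stand-in: one zero-filled array per shape, stored in an object array."""
    arr = np.empty(len(shape_list), dtype=object)
    for i, shape in enumerate(shape_list):
        arr[i] = np.zeros(shape)
    return arr

# Each modality gets an array of shape (num_obs[m], num_states[0], num_states[1]).
A_sketch = make_obj_array_zeros([[o] + num_states for o in num_obs])
shapes = {name: A_sketch[m].shape for m, name in enumerate(obs_names)}
print(shapes)
```

The leading axis of each array indexes observations; the remaining axes index the two hidden state factors in order.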

First modality : describe it

In [11]:
A[0][:, :, 0] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0]
A[0][:, :, 1] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0]
A[0][:, :, 2] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]])

Second modality: describe it

In [12]:
A[1][2, :, 0] = np.ones(num_states[0])
# bandit statistics: mapping between the reward state (first hidden state factor) and rewards (Good vs Bad)
A[1][0:2, :, 1] = softmax(np.eye(num_obs[1] - 1))
A[1][2, :, 2] = np.ones(num_states[0])

Third modality: describe it

In [13]:
A[2][0,:,0] = 1.0
A[2][1,:,1] = 1.0
A[2][2,:,2] = 1.0
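
A useful sanity check at this point is that every likelihood array is a valid conditional distribution: for each combination of hidden states, the probabilities over observations should sum to 1. The sketch below rebuilds the three arrays in plain NumPy and checks this. The `softmax_columns` helper is a local stand-in and assumes that the `softmax` imported above normalises over the first axis; for the symmetric input used here the choice of axis makes no difference.

```python
import numpy as np

num_obs = [3, 3, 3]
num_states = [2, 3]

def softmax_columns(x):
    """Column-wise softmax (local stand-in, assumed to match the imported softmax)."""
    e = np.exp(x - x.max(axis=0))
    return e / e.sum(axis=0)

A = np.empty(len(num_obs), dtype=object)
for m, no in enumerate(num_obs):
    A[m] = np.zeros([no] + num_states)

# First modality: uniform in the first two states of the second factor, informative in the third.
A[0][:, :, 0] = 1.0 / num_obs[0]
A[0][:, :, 1] = 1.0 / num_obs[0]
A[0][:, :, 2] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]])

# Second modality: third outcome unless the second factor is in state 1.
A[1][2, :, 0] = 1.0
A[1][0:2, :, 1] = softmax_columns(np.eye(num_obs[1] - 1))
A[1][2, :, 2] = 1.0

# Third modality: identity mapping from the second factor to observations.
for k in range(num_states[1]):
    A[2][k, :, k] = 1.0

# Summing over the observation axis should give 1 for every hidden state combination.
column_sums = [A[m].sum(axis=0) for m in range(len(num_obs))]
all_normalised = all(np.allclose(cs, 1.0) for cs in column_sums)
print(all_normalised)
```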

### (Controllable-) Transition Dynamics

Importantly, some hidden state factors are _controllable_ by the agent, meaning that the probability of being in state $i$ at $t+1$ isn't merely a function of the state at $t$, but also of actions (or from the agent's perspective, _control states_). So each transition likelihood now encodes conditional probability distributions over states at $t+1$, where the conditioning variables are both the states at $t$ _and_ the actions at $t$. This extra conditioning on actions is encoded via an optional third dimension to each factor-specific `B` matrix.

For example, in our case the second hidden state factor (`PLAYING_VS_SAMPLING`) is under the control of the agent, which means the corresponding transition likelihoods `B[1]` are index-able by both previous state and action.
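
In symbols, and assuming the indexing convention used by the code in this notebook (next state on the first axis, previous state on the second, action on the third), the entries of a controllable factor's transition array read:

$$
B^{f}_{ijk} = P\left(s^{f}_{t+1} = i \mid s^{f}_{t} = j,\; u^{f}_{t} = k\right), \qquad \sum_{i} B^{f}_{ijk} = 1 \quad \text{for all } j, k.
$$

Here $u^{f}_{t}$ denotes the control state (action) for factor $f$; the symbol $u$ and the superscript $f$ are notation introduced for this illustration only.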

In [14]:
control_fac_idx = [1]
B = utils.obj_array(num_factors)
for f, ns in enumerate(num_states):
    B[f] = np.eye(ns)
    if f in control_fac_idx:
        # Each action deterministically selects the next state, i.e. B[f][action, :, action] = 1.0.
        # The reshape/tile/transpose below builds this without an explicit loop over actions.
        B[f] = B[f].reshape(ns, ns, 1)
        B[f] = np.tile(B[f], (1, 1, ns))
        B[f] = B[f].transpose(1, 2, 0)
    else:
        B[f] = B[f].reshape(ns, ns, 1)
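
The reshape/tile/transpose sequence for the controllable factor can be hard to read. The NumPy-only sketch below compares it with an explicit loop over actions, to show that both produce the same array: taking action `a` moves the factor to state `a` regardless of the previous state. The names `B_tiled` and `B_loop` are introduced here for illustration.

```python
import numpy as np

ns = 3  # number of states (and actions) of the controllable factor

# Construction used in the cell above.
B_tiled = np.eye(ns).reshape(ns, ns, 1)
B_tiled = np.tile(B_tiled, (1, 1, ns))
B_tiled = B_tiled.transpose(1, 2, 0)

# Explicit alternative: loop over actions and fill in each slice.
B_loop = np.zeros((ns, ns, ns))
for action in range(ns):
    B_loop[action, :, action] = 1.0

same = np.array_equal(B_tiled, B_loop)
print(same)
```

The loop version is longer but makes the indexing order (next state, previous state, action) explicit.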


C matrix / utility

In [15]:
C = utils.obj_array_zeros(num_obs)
C[1][0] = 1.0  # put a 'reward' over the first observation
C[1][1] = -5.0  # put a 'punishment' over the second observation
# this implies that C[1][2] is 'neutral'
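
To see what these numbers express, one can pass them through a softmax. The sketch below assumes that the entries of `C` are treated as unnormalised log-preferences over observations; the `softmax` defined here is a local helper rather than the one imported from `pymdp`.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D vector (local helper)."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

C_reward = np.array([1.0, -5.0, 0.0])  # same values as C[1] above
preference_dist = softmax(C_reward)
print(preference_dist.round(3))
```

Under that assumption the rewarding outcome receives most of the preference mass, the neutral outcome a smaller share, and the punishing outcome almost none.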

Set up our agent

In [16]:
agent = Agent(A=A, B=B, C=C, control_fac_idx=control_fac_idx)

Set up our simulation

In [17]:
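# simulation length, initial observation and initial state
T = 5  # number of timesteps
o = [2, 2, 0]  # initial observation (one index per modality)
s = [0, 0]  # initial true hidden state (one index per factor)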
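
For intuition about what state inference has to do with this initial observation, the sketch below computes an exact Bayesian posterior in plain NumPy. It assumes a flat prior over hidden states and multiplies the likelihoods of the three modalities; the agent itself uses a variational scheme with its own prior, so its numbers need not match exactly. The helper `softmax_columns` and the variable names are introduced for this illustration.

```python
import numpy as np

num_obs = [3, 3, 3]
num_states = [2, 3]

def softmax_columns(x):
    """Column-wise softmax (local stand-in for the imported softmax)."""
    e = np.exp(x - x.max(axis=0))
    return e / e.sum(axis=0)

# Rebuild the three likelihood arrays defined earlier.
A = [np.zeros([no] + num_states) for no in num_obs]
A[0][:, :, 0] = 1.0 / num_obs[0]
A[0][:, :, 1] = 1.0 / num_obs[0]
A[0][:, :, 2] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]])
A[1][2, :, 0] = 1.0
A[1][0:2, :, 1] = softmax_columns(np.eye(num_obs[1] - 1))
A[1][2, :, 2] = 1.0
for k in range(num_states[1]):
    A[2][k, :, k] = 1.0

o = [2, 2, 0]  # initial observation from the cell above

# Likelihood of the observation under every (GAME_STATE, PLAYING_VS_SAMPLING) pair.
joint = np.ones(num_states)
for m, obs in enumerate(o):
    joint *= A[m][obs]

# With a flat prior, the posterior is the normalised likelihood.
joint /= joint.sum()
q_game_state = joint.sum(axis=1)
q_play_vs_sample = joint.sum(axis=0)
print(q_game_state, q_play_vs_sample)
```

With this observation the third modality pins down the second factor, while the first modality is uninformative about `GAME_STATE` in that state, so the posterior over `GAME_STATE` stays flat.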

Generative process: note that the generative process doesn't have to be described by A and B matrices; it can just be the arbitrary 'rules of the game' that you 'write in' as a modeller. Here, however, we reuse the same transition/likelihood matrices to keep the sampling process straightforward.

In [18]:
# transition/observation matrices characterising the generative process
A_gp = copy.deepcopy(A)
B_gp = copy.deepcopy(B)
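
A single step of such a generative process can be sketched in plain NumPy: sample the next state of each factor from the relevant column of its transition array, then sample an observation given the new state. The function `sample_categorical`, the array `A_self`, and the chosen action are hypothetical stand-ins for illustration; only the third observation modality is included to keep the sketch short.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_categorical(probs, rng):
    """Draw one index from a categorical distribution (local stand-in for a sampling utility)."""
    return int(rng.choice(len(probs), p=probs))

# Transition arrays: identity for GAME_STATE, "action picks the next state" for PLAYING_VS_SAMPLING.
B_gp = [np.eye(2).reshape(2, 2, 1), np.zeros((3, 3, 3))]
for a in range(3):
    B_gp[1][a, :, a] = 1.0

# Likelihood for the ACTION_SELF_OBS modality: it reveals the second factor exactly.
A_self = np.zeros((3, 2, 3))
for k in range(3):
    A_self[k, :, k] = 1.0

s = [0, 0]
action = [0, 2]  # hypothetical action: null for factor 0, move factor 1 to state 2

# Sample the next state of each factor, then the observation it generates.
s_next = [sample_categorical(B_gp[f][:, s[f], action[f]], rng) for f in range(2)]
o_self = sample_categorical(A_self[:, s_next[0], s_next[1]], rng)
print(s_next, o_self)
```

Because both the transitions and this likelihood are deterministic, the sampled state and observation here do not depend on the random seed.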

Run simulation

In [20]:
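for t in range(T):

    # report the current observation for each modality
    for g in range(num_modalities):
        print(f"{t}: Observation {obs_names[g]}: {o[g]}")

    # update beliefs about hidden states given the observation
    qx = agent.infer_states(o)

    for f in range(num_factors):
        print(f"{t}: Beliefs about {state_names[f]}: {qx[f].values.round(3)}")

    # evaluate policies and sample an action (one entry per hidden state factor)
    agent.infer_policies()
    action = agent.sample_action()

    # generative process: sample the next true state of each factor ...
    for f, s_i in enumerate(s):
        s[f] = utils.sample(B_gp[f][:, s_i, action[f]])

    # ... then sample a new observation for each modality given that state
    for g, _ in enumerate(o):
        o[g] = utils.sample(A_gp[g][:, s[0], s[1]])

    print(np.argmax(s))
    print(f"{t}: Action: {action} / State: {s}")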


0: Observation GAME_STATE_OBS: 0
0: Observation REWARD: 2
0: Observation ACTION_SELF_OBS: 2
0: Beliefs about GAME_STATE: [[1.]
 [0.]]
0: Beliefs about PLAYING_VS_SAMPLING: [[0.]
 [0.]
 [1.]]
0
0: Action: [0 0] / State: [0, 0]
1: Observation GAME_STATE_OBS: 1
1: Observation REWARD: 2
1: Observation ACTION_SELF_OBS: 0
1: Beliefs about GAME_STATE: [[1.]
 [0.]]
1: Beliefs about PLAYING_VS_SAMPLING: [[1.]
 [0.]
 [0.]]
0
1: Action: [0 0] / State: [0, 0]
2: Observation GAME_STATE_OBS: 0
2: Observation REWARD: 2
2: Observation ACTION_SELF_OBS: 0
2: Beliefs about GAME_STATE: [[1.]
 [0.]]
2: Beliefs about PLAYING_VS_SAMPLING: [[1.]
 [0.]
 [0.]]
0
2: Action: [0 0] / State: [0, 0]
3: Observation GAME_STATE_OBS: 2
3: Observation REWARD: 2
3: Observation ACTION_SELF_OBS: 0
3: Beliefs about GAME_STATE: [[1.]
 [0.]]
3: Beliefs about PLAYING_VS_SAMPLING: [[1.]
 [0.]
 [0.]]
0
3: Action: [0 0] / State: [0, 0]
4: Observation GAME_STATE_OBS: 2
4: Observation REWARD: 2
4: Observation ACTION_SELF_OBS: 0
4: B