# Active Inference Demo: T-Maze Environment
This demo notebook provides a full walk-through of how to build a POMDP agent's generative model and perform active inference (inversion of the generative model) using the `Agent()` class of `pymdp`. We build the generative model from the ground up, directly encoding our own A, B, and C matrices.

### Imports

First, import `pymdp` and the modules we'll need.

In [1]:
import os
import sys
import pathlib
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

path = pathlib.Path(os.getcwd())
module_path = str(path.parent) + '/'
sys.path.append(module_path)

from pymdp.agent import Agent
from pymdp.core import utils
from pymdp.core.maths import softmax
from pymdp.distributions import Categorical, Dirichlet
import copy

## The world (as represented by the agent's generative model)

### Hidden states

We assume the agent "represents" (this should make you think: generative _model_, not _process_) its environment using two latent variables that are statistically independent of one another. We can thus represent them using two _hidden state factors_.

We refer to these two hidden state factors as `GAME_STATE` and `PLAYING_VS_SAMPLING`.

The first factor is a binary variable representing some 'reward structure' that characterises the world. It has two possible values or levels: one level that will lead to rewards with high probability (`GAME_STATE = 0`, a state/level we will call `HIGH_REW`), and another level that will lead to "punishments" (e.g. losing money) with high probability (`GAME_STATE = 1`, a state/level we will call `LOW_REW`). You can think of this hidden state factor as describing the 'pay-off' structure of e.g. a two-armed bandit or slot-machine with two different settings - one where you're more likely to win (`HIGH_REW`), and one where you're more likely to lose (`LOW_REW`). Crucially, the agent doesn't _know_ what the `GAME_STATE` actually is. It will have to infer it by actively furnishing itself with observations.

The second factor is a ternary (3-valued) variable representing the decision-state or 'sampling state' of the agent itself. The first state/level of this hidden state factor is just the starting or initial state of the agent (`PLAYING_VS_SAMPLING = 0`, a state that we can call `START`); the second state/level is the state the agent occupies when "playing" the multi-armed bandit or slot machine (`PLAYING_VS_SAMPLING = 1`, a state that we can call `PLAYING`); and the third state/level of this factor is a "sampling state" (`PLAYING_VS_SAMPLING = 2`, a state that we can call `SAMPLING`). This is a decision-state that the agent occupies when it is "sampling" data in order to _find out_ the level of the first hidden state factor - the `GAME_STATE`. 



In [2]:
factor_names = ["GAME_STATE", "PLAYING_VS_SAMPLING"]
num_factors = len(factor_names) # this is the total number of hidden state factors

# let's give the state indices names, so that when we build the A matrices things will be more 'semantically' obvious
HIGH_REW, LOW_REW = 0, 1
START, PLAYING, SAMPLING = 0, 1, 2

num_states = [len([HIGH_REW,LOW_REW]), len([START, PLAYING, SAMPLING])] # this is a list of the dimensionalities of each hidden state factor 
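
To make the independence assumption concrete, here is a minimal sketch using plain NumPy rather than `pymdp`. The belief vectors `q_game_state` and `q_sampling_state` are hypothetical values chosen for illustration. Because the two factors are treated as independent, a joint belief over every combination of levels is just the outer product of the two marginal beliefs.

```python
import numpy as np

# Index names copied from the cell above
HIGH_REW, LOW_REW = 0, 1
START, PLAYING, SAMPLING = 0, 1, 2

# Hypothetical beliefs, chosen only for illustration
q_game_state = np.array([0.5, 0.5])            # no idea which GAME_STATE holds
q_sampling_state = np.array([1.0, 0.0, 0.0])   # certain of being in START

# Independence means the joint belief is the outer product of the marginals
q_joint = np.outer(q_game_state, q_sampling_state)

print(q_joint.shape)             # (2, 3): one entry per combination of levels
print(q_joint[HIGH_REW, START])  # 0.5
```

Storing two vectors of lengths 2 and 3 instead of one 2 x 3 table is what the factorisation buys us; the saving grows quickly as more factors are added.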

### Observations 

The observations are divided into 3 modalities. You can think of these as 3 independent sources of information that the agent has access to. You could think of this in direct perceptual terms - e.g. 3 different sensory organs like eyes, ears, & nose, that give you qualitatively different kinds of information. Or you can think of it more abstractly - like getting your news from 3 different media sources (online news articles, Twitter feed, and Instagram).

#### 1. Observations of the game state
The first observation modality is the `GAME_STATE_OBS` modality, and corresponds to observations that give the agent information about the `GAME_STATE`. There are three possible outcomes within this modality: `HIGH_REW_EVIDENCE` (`GAME_STATE_OBS = 0`), `EQUAL_EVIDENCE` (`GAME_STATE_OBS = 1`), and `LOW_REW_EVIDENCE` (`GAME_STATE_OBS = 2`). The first outcome lends evidence to the idea that the `GAME_STATE` is `HIGH_REW`; the third outcome lends evidence to the idea that the `GAME_STATE` is `LOW_REW`; and the second outcome doesn't tell the agent one way or another whether the `GAME_STATE` is `HIGH_REW` or `LOW_REW`. There is a bit of a circularity here, in that we're "pre-empting" what the `A` matrix (likelihood mapping) will look like by giving these observations these labels. Of course, an observation per se doesn't do _anything_ - it's only its probabilistic relationship to hidden states (encoded in the `A` matrix, as we'll see below) that gives an observation its semantic 'essence.' By already labelling `GAME_STATE_OBS=0` as `HIGH_REW_EVIDENCE`, we're giving a hint about how we're going to structure the `A` matrix corresponding to this `GAME_STATE_OBS` modality.

#### 2. Reward observations
The second observation modality is the `GAME_OUTCOME` modality, and corresponds to arbitrary observations that are functions of the `GAME_STATE`. We call the first outcome level of this modality `REWARD` (`GAME_OUTCOME = 0`), which once again 'pre-empts' the way we're going to set up the `A` matrix (the probabilistic mapping between hidden states and observations) and the C matrix (the agent's "utility function" over outcomes). We call the second outcome level of this modality `PUN` (`GAME_OUTCOME = 1`), and the third outcome level `NEUTRAL` (`GAME_OUTCOME = 2`). By design, we will set up the `A` matrix such that the `REWARD` outcome is (expected to be) more likely when the `GAME_STATE` is `HIGH_REW` (`0`) and when the agent is in the `PLAYING` state, and that the `PUN` outcome is (expected to be) more likely when the `GAME_STATE` is `LOW_REW` (`1`) and the agent is in the `PLAYING` state. The `NEUTRAL` outcome is not expected to occur when the agent is playing the game, but will be expected to occur when the agent is in the `SAMPLING` state. This `NEUTRAL` outcome within the `GAME_OUTCOME` modality is thus a meaningless or 'null' observation that the agent gets when it's not actually playing the game (because an observation has to be sampled nonetheless from _all_ modalities).

#### 3. "Proprioceptive" or self-state observations
The third observation modality is the `ACTION_SELF_OBS` modality, and corresponds to the agent observing what level of the `PLAYING_VS_SAMPLING` state it's in. These observations are direct, 'unambiguous' mappings from the true `PLAYING_VS_SAMPLING` state, and simply allow the agent to "know" whether it's playing the game or sampling (or at the initial starting state). The levels of this outcome are thus simply `START_O`, `PLAY_O`, and `SAMPLE_O`, with the `_O` suffix appended to these semantic labels to differentiate them from the hidden states that they are direct observations of.

In [3]:
modality_names = ["GAME_STATE_OBS", "GAME_OUTCOME", "ACTION_SELF_OBS"]
num_modalities = len(modality_names)

HIGH_REW_EVIDENCE, EQUAL_EVIDENCE, LOW_REW_EVIDENCE = 0, 1, 2
REWARD, PUN, NEUTRAL = 0, 1, 2
START_O, PLAY_O, SAMPLE_O = 0, 1, 2

num_obs = [len([HIGH_REW_EVIDENCE, EQUAL_EVIDENCE, LOW_REW_EVIDENCE]), len([REWARD, PUN, NEUTRAL]), len([START_O, PLAY_O, SAMPLE_O])]
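
Since an observation has to be sampled from _all_ modalities at every time step, a full observation is a list with one outcome index per modality. The sketch below shows that structure, together with the equivalent one-hot vectors. The `one_hot` helper is a hypothetical name written for this illustration, not a `pymdp` function.

```python
import numpy as np

num_obs = [3, 3, 3]  # number of outcomes per modality, as in the cell above

def one_hot(index, size):
    """Return a length-`size` vector with 1.0 at `index` and 0.0 elsewhere."""
    vec = np.zeros(size)
    vec[index] = 1.0
    return vec

# One outcome index per modality: EQUAL_EVIDENCE (1), NEUTRAL (2), START_O (0)
obs_indices = [1, 2, 0]
obs_one_hot = [one_hot(i, n) for i, n in zip(obs_indices, num_obs)]

for vec in obs_one_hot:
    print(vec)
```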

### Control state factors

In [None]:
control_names = ["NULL", "PLAYING_VS_SAMPLING"]
num_control_factors = len(control_names) # this is the total number of controllable hidden state factors

NULL_ACTION = 0
START_ACTION, PLAY_ACTION, SAMPLE_ACTION = 0, 1, 2

num_control = [len([NULL_ACTION]), len([START_ACTION, PLAY_ACTION, SAMPLE_ACTION])] # this is a list of the dimensionalities of each control state factor
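
With one `NULL` action for the first factor and three actions for the second, the agent has three distinct one-step action combinations to choose from. The following sketch enumerates them with the standard library. It illustrates the counting only and is not meant to reproduce how `pymdp` builds its policies internally.

```python
import itertools

num_control = [1, 3]  # as in the cell above: one NULL action, three PLAYING_VS_SAMPLING actions

# Every combination of one action per control factor
one_step_policies = list(itertools.product(*[range(n) for n in num_control]))
print(one_step_policies)  # [(0, 0), (0, 1), (0, 2)]
```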

Setting up the observation likelihood (`A`) matrix - the first main component of the generative model

In [10]:
# one likelihood array per modality, with shape (num_obs[m], num_states[0], num_states[1])
A = utils.obj_array_zeros([[o] + num_states for o in num_obs])

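First modality (`GAME_STATE_OBS`): in the `START` and `PLAYING` states every outcome is equally likely, so observations in this modality carry no information about the `GAME_STATE`. In the `SAMPLING` state, the evidence points to the true `GAME_STATE` with 80% probability.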

In [None]:
A[0][:,:,  START] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0]
A[0][:, :, PLAYING] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0]

# the agent expects to see the HIGH_REW_EVIDENCE observation with 80% probability, if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
A[0][HIGH_REW_EVIDENCE, HIGH_REW, SAMPLING] = 0.8 
# the agent expects to see the LOW_REW_EVIDENCE observation with 20% probability, if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
A[0][LOW_REW_EVIDENCE, HIGH_REW, SAMPLING] = 0.2

# the agent expects to see the LOW_REW_EVIDENCE observation with 80% probability, if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
A[0][LOW_REW_EVIDENCE, LOW_REW, SAMPLING] = 0.8
# the agent expects to see the HIGH_REW_EVIDENCE observation with 20% probability, if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
A[0][HIGH_REW_EVIDENCE, LOW_REW, SAMPLING] = 0.2

# quick way to do it
# A[0][:, :, 0] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0]
# A[0][:, :, 1] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0]
# A[0][:, :, 2] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]])
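
To see what the 80/20 likelihood buys the agent, here is a small worked example of exact Bayesian updating on the `GAME_STATE`, assuming the agent knows it is in the `SAMPLING` state and starts from a flat prior. The `bayes_update` helper is a hypothetical name for this sketch; the `Agent` class performs its own (variational) inference, so treat this only as intuition for the numbers.

```python
import numpy as np

HIGH_REW, LOW_REW = 0, 1
HIGH_REW_EVIDENCE, EQUAL_EVIDENCE, LOW_REW_EVIDENCE = 0, 1, 2

# Slice of A[0] for the SAMPLING state: rows are outcomes, columns are GAME_STATE levels
A_sampling = np.array([[0.8, 0.2],
                       [0.0, 0.0],
                       [0.2, 0.8]])

def bayes_update(prior, likelihood_matrix, obs_index):
    """Posterior over hidden states after seeing one outcome (exact Bayes rule)."""
    unnormalised = likelihood_matrix[obs_index, :] * prior
    return unnormalised / unnormalised.sum()

prior = np.array([0.5, 0.5])  # assumed flat prior over GAME_STATE
posterior_1 = bayes_update(prior, A_sampling, HIGH_REW_EVIDENCE)
posterior_2 = bayes_update(posterior_1, A_sampling, HIGH_REW_EVIDENCE)

print(posterior_1.round(3))  # [0.8 0.2]
print(posterior_2.round(3))  # [0.941 0.059]
```

Each consistent sample sharpens the belief further, which is why visiting the `SAMPLING` state is informative.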

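Second modality (`GAME_OUTCOME`): the agent expects the `NEUTRAL` outcome with certainty in the `START` and `SAMPLING` states. In the `PLAYING` state, it expects `REWARD` or `PUN` with probabilities that depend on the `GAME_STATE`.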

In [12]:
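# outside the PLAYING state, the agent always expects the NEUTRAL outcome
A[1][NEUTRAL, :, START] = np.ones(num_states[0])
# bandit statistics: mapping between the GAME_STATE (first hidden state factor) and the REWARD / PUN outcomes (rows 0:2) when PLAYING
A[1][0:2, :, PLAYING] = softmax(np.eye(num_obs[1] - 1))
A[1][NEUTRAL, :, SAMPLING] = np.ones(num_states[0])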
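
It may help to see the actual pay-off probabilities that `softmax(np.eye(2))` produces. The sketch below uses a local column-wise softmax, `softmax_columns`, which is a hypothetical stand-in assumed to behave like the imported `softmax` for this input.

```python
import numpy as np

def softmax_columns(x):
    """Column-wise softmax; assumed to mirror the imported `softmax` for this input."""
    e = np.exp(x - x.max(axis=0))
    return e / e.sum(axis=0)

# Rows: REWARD, PUN. Columns: HIGH_REW, LOW_REW.
payoff = softmax_columns(np.eye(2))
print(payoff.round(3))
# [[0.731 0.269]
#  [0.269 0.731]]
```

So under these settings the agent expects a reward about 73% of the time when playing in the `HIGH_REW` state, and a punishment about 73% of the time in the `LOW_REW` state.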

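Third modality (`ACTION_SELF_OBS`): an unambiguous (identity) mapping from the `PLAYING_VS_SAMPLING` state to the observation with the same index, regardless of the `GAME_STATE`.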

In [13]:
A[2][0,:,0] = 1.0
A[2][1,:,1] = 1.0
A[2][2,:,2] = 1.0
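
A useful sanity check at this point is that every likelihood array defines a proper conditional distribution: for each combination of hidden states, the probabilities over outcomes should sum to 1. The sketch below rebuilds the three arrays with plain NumPy (so it runs without `pymdp`) and checks this. The `is_normalised` helper is a hypothetical name, and the softmax values are written out by hand.

```python
import numpy as np

num_states = [2, 3]
num_obs = [3, 3, 3]
START, PLAYING, SAMPLING = 0, 1, 2

A_check = [np.zeros([o] + num_states) for o in num_obs]

# GAME_STATE_OBS: uniform unless SAMPLING
A_check[0][:, :, START] = 1.0 / num_obs[0]
A_check[0][:, :, PLAYING] = 1.0 / num_obs[0]
A_check[0][:, :, SAMPLING] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]])

# GAME_OUTCOME: NEUTRAL unless PLAYING; softmax(np.eye(2)) written out by hand
e = np.exp(1.0)
A_check[1][2, :, START] = 1.0
A_check[1][0:2, :, PLAYING] = np.array([[e, 1.0], [1.0, e]]) / (e + 1.0)
A_check[1][2, :, SAMPLING] = 1.0

# ACTION_SELF_OBS: identity mapping from PLAYING_VS_SAMPLING
for level in range(num_states[1]):
    A_check[2][level, :, level] = 1.0

def is_normalised(likelihood):
    """True if the outcome axis (axis 0) sums to 1 for every hidden-state combination."""
    return bool(np.allclose(likelihood.sum(axis=0), 1.0))

print([is_normalised(a) for a in A_check])  # [True, True, True]
```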

### (Controllable-) Transition Dynamics

Importantly, some hidden state factors are _controllable_ by the agent, meaning that the probability of being in state $i$ at $t+1$ isn't merely a function of the state at $t$, but also of actions (or from the agent's perspective, _control states_). So now each transition likelihood encodes conditional probability distributions over states at $t+1$, where the conditioning variables are both the states at $t$ _and_ the actions at $t$. This extra conditioning on actions is encoded via an optional third dimension to each factor-specific `B` matrix.

For example, in our case the second hidden state factor (`PLAYING_VS_SAMPLING`) is under the control of the agent, which means the corresponding transition likelihoods `B[1]` are index-able by both previous state and action.

In [14]:
control_fac_idx = [1]
B = utils.obj_array(num_factors)
for f, ns in enumerate(num_states):
    B[f] = np.eye(ns)
    if f in control_fac_idx:
        # each action deterministically moves the agent to the state with the same index,
        # i.e. the result satisfies B[f][action, :, action] = 1.0 for every action
        B[f] = B[f].reshape(ns, ns, 1)
        B[f] = np.tile(B[f], (1, 1, ns))
        B[f] = B[f].transpose(1, 2, 0)
    else:
        B[f] = B[f].reshape(ns, ns, 1)
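
The reshape/tile/transpose steps used for the controllable factor can be hard to read. The sketch below builds the same array with an explicit loop over actions, checks that the two constructions agree, and then uses one action slice to predict the next state. The variable names `B_explicit`, `B_tiled`, `q_now`, and `q_next` are introduced here for illustration only.

```python
import numpy as np

ns = 3  # levels of PLAYING_VS_SAMPLING; there is one action per level
SAMPLING = 2

# Explicit construction: taking `action` leads to state `action` from any previous state
B_explicit = np.zeros((ns, ns, ns))
for action in range(ns):
    B_explicit[action, :, action] = 1.0

# The reshape / tile / transpose construction, written as one expression
B_tiled = np.tile(np.eye(ns).reshape(ns, ns, 1), (1, 1, ns)).transpose(1, 2, 0)
print(np.array_equal(B_explicit, B_tiled))  # True

# Predicted belief at t+1 is B[:, :, action] applied to the belief at t
q_now = np.array([1.0, 0.0, 0.0])            # certain of being in START
q_next = B_explicit[:, :, SAMPLING] @ q_now
print(q_next)  # [0. 0. 1.]
```

Indexing is `B[next_state, previous_state, action]`, so each action slice `B[:, :, action]` is itself a column-stochastic transition matrix.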


C matrix / utility

In [15]:
C = utils.obj_array_zeros(num_obs)
C[1][REWARD] = 1.0  # put a 'reward' over the first outcome of the GAME_OUTCOME modality
C[1][PUN] = -5.0  # put a 'punishment' over the second outcome of the GAME_OUTCOME modality
# C[1][NEUTRAL] stays at 0, so the third outcome is 'neutral'
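
One way to read these numbers, offered here as an assumption rather than a statement about `pymdp` internals, is as unnormalised log-preferences over `GAME_OUTCOME`. Under that reading a softmax turns them into a preferred outcome distribution, and the expected log-preference of `PLAYING` can be compared across the two game states. The predicted outcome distributions below reuse the hand-written values of `softmax(np.eye(2))`.

```python
import numpy as np

REWARD, PUN, NEUTRAL = 0, 1, 2
C_outcome = np.array([1.0, -5.0, 0.0])  # values of C[1] from the cell above

# Assumption: C holds unnormalised log-preferences, so a log-softmax normalises them
log_pref = C_outcome - np.log(np.exp(C_outcome).sum())
pref = np.exp(log_pref)
print(pref.round(3))  # roughly 0.730 for REWARD, 0.002 for PUN, 0.268 for NEUTRAL

# Predicted GAME_OUTCOME distributions when PLAYING, for each GAME_STATE
e = np.exp(1.0)
predicted_if_high = np.array([e, 1.0, 0.0]) / (e + 1.0)
predicted_if_low = np.array([1.0, e, 0.0]) / (e + 1.0)

# Expected log-preference of PLAYING in each case
eu_high = predicted_if_high @ log_pref
eu_low = predicted_if_low @ log_pref
print(eu_high > eu_low)  # True
```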

Set up our agent

In [16]:
agent = Agent(A=A, B=B, C=C, control_fac_idx=control_fac_idx)

Set up our simulation

In [17]:
# initial state
T = 5
o = [2, 2, 0]
s = [0, 0]

Generative process. Note that the generative process doesn't have to be described by A and B matrices: it can just be the arbitrary 'rules of the game' that you 'write in' as a modeller. Here, though, we simply reuse the agent's transition and likelihood matrices to make the sampling process straightforward.

In [18]:
# transition/observation matrices characterising the generative process
A_gp = copy.deepcopy(A)
B_gp = copy.deepcopy(B)
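
For contrast, here is a sketch of what a rule-based generative process could look like if it were written as ordinary code instead of matrices. The function `step_environment` is hypothetical and not part of `pymdp`; its probabilities are chosen to roughly match the `A` matrices in this notebook, and it uses only the standard library.

```python
import random

HIGH_REW, LOW_REW = 0, 1
START, PLAYING, SAMPLING = 0, 1, 2
HIGH_REW_EVIDENCE, EQUAL_EVIDENCE, LOW_REW_EVIDENCE = 0, 1, 2
REWARD, PUN, NEUTRAL = 0, 1, 2

def step_environment(game_state, action, rng):
    """Hypothetical 'rules of the game': returns (new_state, observation)."""
    sampling_state = action  # each action moves the agent to the state with the same index

    # GAME_STATE_OBS: informative only while SAMPLING
    if sampling_state == SAMPLING:
        correct = HIGH_REW_EVIDENCE if game_state == HIGH_REW else LOW_REW_EVIDENCE
        wrong = LOW_REW_EVIDENCE if game_state == HIGH_REW else HIGH_REW_EVIDENCE
        game_state_obs = correct if rng.random() < 0.8 else wrong
    else:
        game_state_obs = rng.randrange(3)

    # GAME_OUTCOME: REWARD or PUN only while PLAYING, otherwise NEUTRAL
    if sampling_state == PLAYING:
        p_reward = 0.731 if game_state == HIGH_REW else 0.269
        game_outcome = REWARD if rng.random() < p_reward else PUN
    else:
        game_outcome = NEUTRAL

    # ACTION_SELF_OBS mirrors the PLAYING_VS_SAMPLING state exactly
    return [game_state, sampling_state], [game_state_obs, game_outcome, sampling_state]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
state, obs = step_environment(HIGH_REW, SAMPLING, rng)
print(state, obs)
```

The agent never sees this function; it only receives the observation list and has to invert its own generative model to explain it.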

Run simulation

In [20]:
for t in range(T):

    for g in range(num_modalities):
        print(f"{t}: Observation {modality_names[g]}: {o[g]}")

    qx = agent.infer_states(o)

    for f in range(num_factors):
        print(f"{t}: Beliefs about {factor_names[f]}: {qx[f].values.round(3)}")

    agent.infer_policies()
    action = agent.sample_action()

    for f, s_i in enumerate(s):
        s[f] = utils.sample(B_gp[f][:, s_i, action[f]])

    for g, _ in enumerate(o):
        o[g] = utils.sample(A_gp[g][:, s[0], s[1]])
    
    print(np.argmax(s))
    print(f"{t}: Action: {action} / State: {s}")


0: Observation GAME_STATE_OBS: 0
0: Observation REWARD: 2
0: Observation ACTION_SELF_OBS: 2
0: Beliefs about GAME_STATE: [[1.]
 [0.]]
0: Beliefs about PLAYING_VS_SAMPLING: [[0.]
 [0.]
 [1.]]
0
0: Action: [0 0] / State: [0, 0]
1: Observation GAME_STATE_OBS: 1
1: Observation REWARD: 2
1: Observation ACTION_SELF_OBS: 0
1: Beliefs about GAME_STATE: [[1.]
 [0.]]
1: Beliefs about PLAYING_VS_SAMPLING: [[1.]
 [0.]
 [0.]]
0
1: Action: [0 0] / State: [0, 0]
2: Observation GAME_STATE_OBS: 0
2: Observation REWARD: 2
2: Observation ACTION_SELF_OBS: 0
2: Beliefs about GAME_STATE: [[1.]
 [0.]]
2: Beliefs about PLAYING_VS_SAMPLING: [[1.]
 [0.]
 [0.]]
0
2: Action: [0 0] / State: [0, 0]
3: Observation GAME_STATE_OBS: 2
3: Observation REWARD: 2
3: Observation ACTION_SELF_OBS: 0
3: Beliefs about GAME_STATE: [[1.]
 [0.]]
3: Beliefs about PLAYING_VS_SAMPLING: [[1.]
 [0.]
 [0.]]
0
3: Action: [0 0] / State: [0, 0]
4: Observation GAME_STATE_OBS: 2
4: Observation REWARD: 2
4: Observation ACTION_SELF_OBS: 0
4: B