# Active Inference Demo: Constructing a basic generative model from the "ground up"
This demo notebook provides a full walk-through of how to build a POMDP agent's generative model and perform active inference (inversion of the generative model) using the `Agent()` class of `pymdp`. We build the generative model from the "ground up", directly encoding our own A, B, and C matrices.

### Imports

First, import `pymdp` and the modules we'll need.

In [None]:
import os
import sys
import pathlib
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import copy

path = pathlib.Path(os.getcwd())
module_path = str(path.parent) + '/'
sys.path.append(module_path)

from pymdp.agent import Agent
from pymdp import utils
from pymdp.maths import softmax

## The world (as represented by the agent's generative model)

### Hidden states

We assume the agent "represents" (this should make you think: generative _model_, not _process_) its environment using two latent variables that are statistically independent of one another. We can thus represent them using two _hidden state factors_.

We refer to these two hidden state factors as `GAME_STATE` and `PLAYING_VS_SAMPLING`.

#### 1. `GAME_STATE`
The first factor is a binary variable representing the 'reward structure' that characterises the world. It has two possible values or levels: one level that will lead to rewards with high probability (`GAME_STATE = 0`, a state/level we will call `HIGH_REW`), and another level that will lead to "punishments" (e.g. losing money) with high probability (`GAME_STATE = 1`, a state/level we will call `LOW_REW`). You can think of this hidden state factor as describing the 'pay-off' structure of e.g. a two-armed bandit or slot-machine with two different settings - one where you're more likely to win (`HIGH_REW`), and one where you're more likely to lose (`LOW_REW`). Crucially, the agent doesn't _know_ what the `GAME_STATE` actually is. It will have to infer it by actively furnishing itself with observations.

#### 2. `PLAYING_VS_SAMPLING`

The second factor is a ternary (3-valued) variable representing the decision-state or 'sampling state' of the agent itself. The first state/level of this hidden state factor is just the starting or initial state of the agent (`PLAYING_VS_SAMPLING = 0`, a state that we can call `START`); the second state/level is the state the agent occupies when "playing" the multi-armed bandit or slot machine (`PLAYING_VS_SAMPLING = 1`, a state that we can call `PLAYING`); and the third state/level of this factor is a "sampling state" (`PLAYING_VS_SAMPLING = 2`, a state that we can call `SAMPLING`). This is a decision-state that the agent occupies when it is "sampling" data in order to _find out_ the level of the first hidden state factor - the `GAME_STATE`. 



In [None]:
factor_names = ["GAME_STATE", "PLAYING_VS_SAMPLING"]
num_factors = len(factor_names) # this is the total number of hidden state factors

HIGH_REW, LOW_REW = 0, 1 # let's assign the indices names so that when we build the A matrices, things will be more 'semantically' obvious
START, PLAYING, SAMPLING = 0, 1, 2 # let's assign the indices names so that when we build the A matrices, things will be more 'semantically' obvious

num_states = [len([HIGH_REW,LOW_REW]), len([START, PLAYING, SAMPLING])] # this is a list of the dimensionalities of each hidden state factor 
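Because the two factors are treated as independent, the agent's full hidden state is one level from each factor. As a minimal sketch (using only the standard library, with the labels copied from the cell above; the enumeration is purely illustrative and not something `pymdp` requires), the joint state space has 2 x 3 = 6 configurations:

```python
from itertools import product

# Labels copied from the definitions above; the joint enumeration is illustrative only
game_states = ["HIGH_REW", "LOW_REW"]
sampling_states = ["START", "PLAYING", "SAMPLING"]

# Every combination of one level per hidden state factor
joint_states = list(product(game_states, sampling_states))

for idx, (game, sampling) in enumerate(joint_states):
    print(idx, game, sampling)

print("number of joint states:", len(joint_states))
```

Keeping the factors separate means the model stores a 2-vector and a 3-vector of beliefs rather than a single 6-vector.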

### Observations 

The agent's observations are divided into 3 modalities. You can think of these as 3 independent sources of information that the agent has access to. You could think of this in direct perceptual terms - e.g. 3 different sensory organs like eyes, ears, & nose, that give you qualitatively different kinds of information. Or you can think of it more abstractly - like getting your news from 3 different media sources (online news articles, Twitter feed, and Instagram).

#### 1. Observations of the game state - `GAME_STATE_OBS`
The first observation modality is the `GAME_STATE_OBS` modality, and corresponds to observations that give the agent information about the `GAME_STATE`. There are three possible outcomes within this modality: `HIGH_REW_EVIDENCE` (`GAME_STATE_OBS = 0`), `LOW_REW_EVIDENCE` (`GAME_STATE_OBS = 1`), and `NO_EVIDENCE` (`GAME_STATE_OBS = 2`). So the first outcome can be described as lending evidence to the idea that the `GAME_STATE` is `HIGH_REW`; the second outcome can be described as lending evidence to the idea that the `GAME_STATE` is `LOW_REW`; and the third outcome within this modality doesn't tell the agent one way or another whether the `GAME_STATE` is `HIGH_REW` or `LOW_REW`. 

#### 2. Reward observations - `GAME_OUTCOME`
The second observation modality is the `GAME_OUTCOME` modality, and corresponds to arbitrary observations that are functions of the `GAME_STATE`. We call the first outcome level of this modality `REWARD` (`GAME_OUTCOME = 0`), which gives you a hint about how we'll set up the C matrix (the agent's "utility function" over outcomes). We call the second outcome level of this modality `PUN` (`GAME_OUTCOME = 1`), and the third outcome level `NEUTRAL` (`GAME_OUTCOME = 2`). By design, we will set up the `A` matrix such that the `REWARD` outcome is (expected to be) more likely when the `GAME_STATE` is `HIGH_REW` (`0`) and the agent is in the `PLAYING` state, and that the `PUN` outcome is (expected to be) more likely when the `GAME_STATE` is `LOW_REW` (`1`) and the agent is in the `PLAYING` state. The `NEUTRAL` outcome is not expected to occur when the agent is playing the game, but will be expected to occur when the agent is in the `START` or `SAMPLING` state. This `NEUTRAL` outcome within the `GAME_OUTCOME` modality is thus a meaningless or 'null' observation that the agent gets when it's not actually playing the game (because an observation has to be sampled nonetheless from _all_ modalities).

#### 3. "Proprioceptive" or self-state observations - `ACTION_SELF_OBS`
The third observation modality is the `ACTION_SELF_OBS` modality, and corresponds to the agent observing what level of the `PLAYING_VS_SAMPLING` state it is currently in. These observations are direct, 'unambiguous' mappings to the true `PLAYING_VS_SAMPLING` state, and simply allow the agent to "know" whether it's playing the game, sampling information to learn about the game state, or sitting at the `START` state. The levels of this outcome are thus simply `START_O`, `PLAY_O`, and `SAMPLE_O`, where the `_O` suffix distinguishes them from their corresponding hidden states, for which they provide direct evidence.

#### Note about the arbitrariness of 'labelling' observations, before defining the `A` and `C` matrices.

There is a bit of a circularity here, in that we're "pre-empting" what the A matrix (likelihood mapping) should look like, by giving these observations labels that imply particular roles or meanings. An observation per se doesn't mean _anything_, it's just some discrete index that distinguishes it from another observation. It's only through its probabilistic relationship to hidden states (encoded in the `A` matrix, as we'll see below) that we endow an observation with meaning. For example: by already labelling `GAME_STATE_OBS=0` as `HIGH_REW_EVIDENCE`, we're giving a hint about how we're going to structure the `A` matrix for the `GAME_STATE_OBS` modality.

In [None]:
modality_names = ["GAME_STATE_OBS", "GAME_OUTCOME", "ACTION_SELF_OBS"]
num_modalities = len(modality_names)

HIGH_REW_EVIDENCE, LOW_REW_EVIDENCE, NO_EVIDENCE  = 0, 1, 2
REWARD, PUN, NEUTRAL = 0, 1, 2
START_O, PLAY_O, SAMPLE_O = 0, 1, 2

num_obs = [len([HIGH_REW_EVIDENCE, LOW_REW_EVIDENCE, NO_EVIDENCE]), len([REWARD, PUN, NEUTRAL]), len([START_O, PLAY_O, SAMPLE_O])]

Set up the observation likelihood matrix - the first main component of the generative model.

In [None]:
# one zero-filled array per modality, with shape (num_obs[m], num_states[0], num_states[1])
A = utils.obj_array_zeros([[o] + num_states for o in num_obs])
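To make the shapes concrete, here is a NumPy-only sketch of the container this call is expected to produce: an object array holding one zero-filled array per modality. The name `A_sketch` is hypothetical and the loop is a stand-in for `utils.obj_array_zeros`, not its actual implementation.

```python
import numpy as np

# Dimensionalities copied from the cells above
num_states = [2, 3]     # GAME_STATE, PLAYING_VS_SAMPLING
num_obs = [3, 3, 3]     # GAME_STATE_OBS, GAME_OUTCOME, ACTION_SELF_OBS

# Hypothetical stand-in for utils.obj_array_zeros: one array per modality,
# with the observation dimension first and one dimension per hidden state factor
A_sketch = np.empty(len(num_obs), dtype=object)
for m, o in enumerate(num_obs):
    A_sketch[m] = np.zeros([o] + num_states)

shapes = [A_sketch[m].shape for m in range(len(num_obs))]
print(shapes)
```

Each `A_sketch[m][:, i, j]` is then the slot for the distribution over outcomes in modality `m`, given `GAME_STATE = i` and `PLAYING_VS_SAMPLING = j`.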

Set up the **first** modality's likelihood mapping, corresponding to how `"GAME_STATE_OBS"` observations (i.e. `modality_names[0]`) are related to hidden states.

In [None]:
A[0][NO_EVIDENCE, :, START] = 1.0 # they always get the 'no evidence' observation in the START STATE
A[0][NO_EVIDENCE, :, PLAYING] = 1.0 # they always get the 'no evidence' observation in the PLAYING STATE

# the agent expects to see the HIGH_REW_EVIDENCE observation with 80% probability, if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
A[0][HIGH_REW_EVIDENCE, HIGH_REW, SAMPLING] = 0.8
# the agent expects to see the LOW_REW_EVIDENCE observation with 20% probability, if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
A[0][LOW_REW_EVIDENCE, HIGH_REW, SAMPLING] = 0.2

# the agent expects to see the LOW_REW_EVIDENCE observation with 80% probability, if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
A[0][LOW_REW_EVIDENCE, LOW_REW, SAMPLING] = 0.8
# the agent expects to see the HIGH_REW_EVIDENCE observation with 20% probability, if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
A[0][HIGH_REW_EVIDENCE, LOW_REW, SAMPLING] = 0.2

# quick way to do this (only the NO_EVIDENCE row, index 2, is set to 1.0 for START and PLAYING)
# A[0][2, :, 0] = 1.0
# A[0][2, :, 1] = 1.0
# A[0][:, :, 2] = np.array([[0.8, 0.2], [0.2, 0.8], [0.0, 0.0]])
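Each column of a likelihood array should be a proper probability distribution over outcomes, i.e. it should sum to 1 for every combination of hidden states. The following NumPy-only sketch rebuilds the first modality under the name `A0` (a hypothetical local copy, not the notebook's `A[0]`) and checks that property:

```python
import numpy as np

HIGH_REW, LOW_REW = 0, 1
START, PLAYING, SAMPLING = 0, 1, 2
HIGH_REW_EVIDENCE, LOW_REW_EVIDENCE, NO_EVIDENCE = 0, 1, 2

# Local copy of the first modality's likelihood: (outcomes, GAME_STATE, PLAYING_VS_SAMPLING)
A0 = np.zeros((3, 2, 3))
A0[NO_EVIDENCE, :, START] = 1.0
A0[NO_EVIDENCE, :, PLAYING] = 1.0
# rows are outcomes, columns are GAME_STATE levels
A0[:, :, SAMPLING] = np.array([[0.8, 0.2],
                               [0.2, 0.8],
                               [0.0, 0.0]])

# Summing over the outcome axis should give 1 for every hidden state combination
column_sums = A0.sum(axis=0)
print(column_sums)
assert np.allclose(column_sums, 1.0)
```

A check like this catches the most common mistake when filling in likelihoods by hand: leaving a hidden state combination with no outcome assigned, or assigning more than 100% in total.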

Set up the **second** modality's likelihood mapping, corresponding to how `"GAME_OUTCOME"` observations (i.e. `modality_names[1]`) are related to hidden states.

In [None]:
A[1][NEUTRAL, :, START] = 1.0 # regardless of the game state, if you're at the START, you see the 'neutral' outcome

A[1][NEUTRAL, :, SAMPLING] = 1.0 # regardless of the game state, if you're in the SAMPLING state, you see the 'neutral' outcome

# this is the distribution that maps from the "GAME_STATE" to the "GAME_OUTCOME" observation , in the case that "GAME_STATE" is `HIGH_REW`
HIGH_REW_MAPPING = softmax(np.array([1.0, 0])) 

# this is the distribution that maps from the "GAME_STATE" to the "GAME_OUTCOME" observation , in the case that "GAME_STATE" is `LOW_REW`
LOW_REW_MAPPING = softmax(np.array([0.0, 1.0]))

# fill out the A matrix using the reward probabilities
A[1][REWARD, HIGH_REW, PLAYING] = HIGH_REW_MAPPING[0]
A[1][PUN, HIGH_REW, PLAYING] = HIGH_REW_MAPPING[1]

A[1][REWARD, LOW_REW, PLAYING] = LOW_REW_MAPPING[0]
A[1][PUN, LOW_REW, PLAYING] = LOW_REW_MAPPING[1]


# quick way to do this
# A[1][2, :, 0] = np.ones(num_states[0])
# A[1][0:2, :, 1] = softmax(np.eye(num_obs[1] - 1)) # relationship of game state to reward observations (mapping between reward-state (first hidden state factor) and rewards (Good vs Bad))
# A[1][2, :, 2] = np.ones(num_states[0])
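It can help to see the actual numbers that the softmax produces here. The sketch below uses a hand-written `softmax_sketch` (a hypothetical helper standing in for `pymdp.maths.softmax`, assumed to compute the standard normalised exponential) to show that the reward probabilities come out at roughly 73% vs. 27%:

```python
import numpy as np

def softmax_sketch(x):
    """Standard normalised exponential; a stand-in for pymdp.maths.softmax on a 1-D input."""
    x = np.asarray(x, dtype=float)
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()

high_rew_mapping = softmax_sketch([1.0, 0.0])  # P(REWARD), P(PUN) when GAME_STATE is HIGH_REW
low_rew_mapping = softmax_sketch([0.0, 1.0])   # P(REWARD), P(PUN) when GAME_STATE is LOW_REW

print(high_rew_mapping.round(3))  # [0.731 0.269]
print(low_rew_mapping.round(3))   # [0.269 0.731]
```

So under these settings, playing in the `HIGH_REW` game state is expected to yield `REWARD` about 73% of the time, which is noisier than the 80% reliability of the evidence in the `SAMPLING` state.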

Set up the **third** modality's likelihood mapping, corresponding to how `"ACTION_SELF_OBS"` observations (i.e. `modality_names[2]`) are related to hidden states.

In [None]:
A[2][START_O,:,START] = 1.0
A[2][PLAY_O,:,PLAYING] = 1.0
A[2][SAMPLE_O,:,SAMPLING] = 1.0

# quick way to do this
# modality_idx, factor_idx = 2, 1  # third modality, second hidden state factor (PLAYING_VS_SAMPLING)
# for sampling_state_i in range(num_states[factor_idx]):
#     A[modality_idx][sampling_state_i,:,sampling_state_i] = 1.0

### Control state factors

The 'control state' factors are the agent's representation of the control states (or actions) that _it believes_ can influence the dynamics of the hidden states - i.e. hidden state factors that are under the influence of control states are 'controllable'. In practice, we often encode _every_ hidden state factor as being under the influence of control states, but the 'uncontrollable' hidden state factors are driven by a trivially 1-dimensional control state or action-affordance. This trivial action simply 'maintains the default environmental dynamics as they are' i.e. does nothing. This will become more clear when we set up the transition model (the `B` matrices) below.

#### 1. `NULL`
This reflects the agent's lack of ability to influence the `GAME_STATE` using policies or actions. The dimensionality of this control factor is 1, and there is only one action along this control factor: `NULL_ACTION` or "don't do anything to the environment". This just means that the transition dynamics along the `GAME_STATE` hidden state factor have their own, uncontrollable dynamics that are not conditioned on this `NULL` control state - or rather, _always_ conditioned on an unchanging, 1-dimensional `NULL_ACTION`.

#### 2. `PLAYING_VS_SAMPLING_CONTROL`
This is a control factor that reflects the agent's ability to move itself between the `START`, `PLAYING` and `SAMPLING` states of the `PLAYING_VS_SAMPLING` hidden state factor. The levels/values of this control factor are `START_ACTION`, `PLAY_ACTION`, and `SAMPLE_ACTION`. When we describe the `B` matrices below, we will set up the transition dynamics of the `PLAYING_VS_SAMPLING` hidden state factor, such that they are totally determined by the value of the `PLAYING_VS_SAMPLING_CONTROL` factor. 

In [None]:
control_names = ["NULL", "PLAYING_VS_SAMPLING_CONTROL"]
num_control_factors = len(control_names) # this is the total number of control state factors (including the trivial NULL one)

NULL_ACTION = 0
START_ACTION, PLAY_ACTION, SAMPLE_ACTION = 0, 1, 2

num_control = [len([NULL_ACTION]), len([START_ACTION, PLAY_ACTION, SAMPLE_ACTION])] # this is a list of the dimensionalities of each control state factor 

### (Controllable-) Transition Dynamics

Importantly, some hidden state factors are _controllable_ by the agent, meaning that the probability of being in state $i$ at $t+1$ isn't merely a function of the state at $t$, but also of actions (or from the generative model's perspective, _control states_). So each transition likelihood or `B` matrix encodes conditional probability distributions over states at $t+1$, where the conditioning variables are both the states at $t$ _and_ the actions at $t$. This extra conditioning on control states is encoded by a third, trailing dimension on each factor-specific `B` matrix. So they are technically `B` "tensors" or an array of action-conditioned `B` matrices.

For example, in our case the 2nd hidden state factor (`PLAYING_VS_SAMPLING`) is under the control of the agent, which means the corresponding transition likelihoods `B[1]` are index-able by both previous state and action.
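One way to write down this indexing convention, as a sketch in our own notation (the superscript $f$ for the factor index and the symbol $u$ for control states are notational assumptions, not taken from the `pymdp` documentation):

$$
B^{f}_{ijk} = P\left(s^{f}_{t+1} = i \mid s^{f}_{t} = j,\; u^{f}_{t} = k\right), \qquad \sum_{i} B^{f}_{ijk} = 1 \;\; \text{for every } j, k
$$

In code, this corresponds to `B[f][i, j, k]`, so each column `B[f][:, j, k]` is a probability distribution over next states.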

In [None]:
control_fac_idx = [1] # this is the (non-trivial) controllable factor, where there will be a >1-dimensional control state along this factor
B = utils.obj_array(num_factors)

p_stoch = 0.0

# we cannot influence factor zero, so set up the 'default' dynamics -
# each state maps to itself at the next timestep with probability 1 - p_stoch, and switches to the
# other state with probability p_stoch. With p_stoch = 0.0 (as set above) the reward state never changes

B[0] = np.zeros((num_states[0], num_states[0], num_control[0])) 
B[0][HIGH_REW, HIGH_REW, NULL_ACTION] = 1.0 - p_stoch
B[0][LOW_REW, HIGH_REW, NULL_ACTION] = p_stoch

B[0][LOW_REW, LOW_REW, NULL_ACTION] = 1.0 - p_stoch
B[0][HIGH_REW, LOW_REW, NULL_ACTION] = p_stoch

# setup our controllable factor.
B[1] = np.zeros((num_states[1], num_states[1], num_control[1]))
B[1][START, :, START_ACTION] = 1.0 
B[1][PLAYING, :, PLAY_ACTION] = 1.0
B[1][SAMPLING, :, SAMPLE_ACTION] = 1.0
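To see what "totally determined by the action" means in practice, the following NumPy-only sketch rebuilds the controllable factor under the hypothetical name `B1` and pushes a belief about the current state through the slice for one action. The matrix-vector product used here is the standard way to propagate a categorical belief through a transition matrix; it illustrates the idea and is not a copy of `pymdp` internals.

```python
import numpy as np

START, PLAYING, SAMPLING = 0, 1, 2
START_ACTION, PLAY_ACTION, SAMPLE_ACTION = 0, 1, 2

# Local copy of the controllable factor: (next state, current state, action)
B1 = np.zeros((3, 3, 3))
B1[START, :, START_ACTION] = 1.0
B1[PLAYING, :, PLAY_ACTION] = 1.0
B1[SAMPLING, :, SAMPLE_ACTION] = 1.0

# Agent is certain it is at START, then takes SAMPLE_ACTION
belief_now = np.array([1.0, 0.0, 0.0])
belief_next = B1[:, :, SAMPLE_ACTION] @ belief_now
print(belief_next)  # [0. 0. 1.]

# The same action leads to SAMPLING from any starting belief
uncertain_now = np.array([0.2, 0.5, 0.3])
uncertain_next = B1[:, :, SAMPLE_ACTION] @ uncertain_now
print(uncertain_next)  # [0. 0. 1.]
```

Because every column of each action slice puts all its mass on one row, the next state depends only on the chosen action and not on where the agent currently is.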

### Prior preferences

Now we parameterise the C vector, which encodes the agent's prior beliefs about (i.e. preferences over) observations. This will be used in the expression of the prior over actions, which is technically a softmax function of the negative expected free energy of each action. These prior beliefs play the role of the exponentiated reward function in reinforcement learning treatments.


In [None]:
C = utils.obj_array_zeros([num_ob for num_ob in num_obs])
C[1][REWARD] = 1.0  # make the observation we've a priori named `REWARD` actually desirable, by building a high prior expectation of encountering it 
C[1][PUN] = -1.0    # make the observation we've a priori named `PUN` actually aversive, by building a low prior expectation of encountering it

# the above code implies the following for the `NEUTRAL` observation:
C[1][NEUTRAL] = 0.0 # we don't need to write this - but it's basically just saying that observing `NEUTRAL` is in between reward and punishment
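One way to read these numbers is as unnormalised log-preferences: passing them through a softmax gives a probability distribution over outcomes that the agent "expects" a priori. The sketch below does this with plain NumPy. Treating `C` this way is an assumption about how the values are interpreted, and `preferred_outcomes` is a hypothetical name used only for illustration.

```python
import numpy as np

REWARD, PUN, NEUTRAL = 0, 1, 2

# Preferences for the GAME_OUTCOME modality, copied from the cell above
C_game_outcome = np.zeros(3)
C_game_outcome[REWARD] = 1.0
C_game_outcome[PUN] = -1.0

# Assumption: C is read as unnormalised log-probabilities, so softmax normalises it
exp_C = np.exp(C_game_outcome - C_game_outcome.max())
preferred_outcomes = exp_C / exp_C.sum()

print(preferred_outcomes.round(3))  # [0.665 0.09  0.245]
```

Only the differences between entries matter under this reading: adding the same constant to every entry of `C[1]` leaves the resulting distribution unchanged.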

### Initialise an instance of the `Agent()` class:

All you have to do is call `Agent(generative_model_params...)` where `generative_model_params` are your A, B, C's... and whatever parameters of the generative model you want to specify

In [None]:
agent = Agent(A=A, B=B, C=C, control_fac_idx=control_fac_idx)

### Generative process:
An important note: the generative process doesn't have to be described by A and B matrices - it can just be the arbitrary 'rules of the game' that you 'write in' as a modeller. But here we just use the same transition/likelihood matrices to make the sampling process straightforward.

In [None]:
# transition/observation matrices characterising the generative process
A_gp = copy.deepcopy(A)
B_gp = copy.deepcopy(B)
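Using matrices for the generative process means that producing the next state or observation amounts to drawing an index from one column of `B_gp` or `A_gp`. A minimal sketch of such a categorical draw is given below; `sample_categorical` is a hypothetical helper written with NumPy's random `Generator`, standing in for the role that `utils.sample` plays in the simulation loop, and is not its actual implementation.

```python
import numpy as np

def sample_categorical(probabilities, rng):
    """Draw one index from a categorical distribution (hypothetical stand-in for utils.sample)."""
    return int(rng.choice(len(probabilities), p=probabilities))

rng = np.random.default_rng(0)

# Column of the GAME_STATE_OBS likelihood for (HIGH_REW, SAMPLING), as defined earlier:
# P(HIGH_REW_EVIDENCE), P(LOW_REW_EVIDENCE), P(NO_EVIDENCE)
likelihood_column = np.array([0.8, 0.2, 0.0])

draws = [sample_categorical(likelihood_column, rng) for _ in range(1000)]
frequencies = np.bincount(draws, minlength=3) / len(draws)
print(frequencies)
```

Over many draws the empirical frequencies should approach the column's probabilities, and an outcome with probability 0 is never drawn.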

Initialise the simulation

In [None]:
# initial state
T = 20 # number of timesteps in the simulation
observation = [NO_EVIDENCE, NEUTRAL, START_O] # initial observation
state = [HIGH_REW, START] # initial (true) state

Create some string names for the state, observation, and action indices to help with print statements

In [None]:
states_idx_names = [ ["HIGH_REW", "LOW_REW"], \
                 ["START", "PLAYING", "SAMPLING"]]

obs_idx_names = [ ["HIGH_REW_EV", "LOW_REW_EV", "NO_EV"], \
                 ["REWARD", "PUN", "NEUTRAL"], \
                 ["START", "PLAYING", "SAMPLING"] ]

action_idx_names = [ ["NULL"], ["MOVE TO START", "PLAY", "SAMPLE"] ]

Run simulation

In [None]:
for t in range(T):
    
    print(f"\nTime {t}:")
    
    print(f"State: {[(factor_names[f], states_idx_names[f][state[f]]) for f in range(num_factors)]}")
    print(f"Observations: {[(modality_names[g], obs_idx_names[g][observation[g]]) for g in range(num_modalities)]}")
    
    # update agent
    belief_state = agent.infer_states(observation)
    agent.infer_policies()
    action = agent.sample_action()
    
    # update environment
    for f, s in enumerate(state):
        state[f] = utils.sample(B_gp[f][:, s, int(action[f])])

    for g, _ in enumerate(observation):
        observation[g] = utils.sample(A_gp[g][:, state[0], state[1]])

    print(f"Beliefs: {[(factor_names[f], belief_state[f].round(3).T) for f in range(num_factors)]}")
    print(f"Action: {[(control_names[a], action_idx_names[a][int(action[a])]) for a in range(num_control_factors)]}")
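To build some intuition for the `Beliefs` printed in the loop, here is a simplified sketch of how evidence about the `GAME_STATE` accumulates while the agent is in the `SAMPLING` state. It applies exact Bayes' rule by hand, assuming a uniform starting prior and a game state that never changes (as with `p_stoch = 0.0`). This is an illustration of the underlying logic only: the agent's own `infer_states` routine uses its own inference scheme, and its printed numbers need not match these exactly.

```python
import numpy as np

# P(HIGH_REW_EVIDENCE | GAME_STATE) while SAMPLING, for GAME_STATE = HIGH_REW, LOW_REW
likelihood_high_ev = np.array([0.8, 0.2])

# Assumption: the agent starts out with a uniform belief over the game state
belief = np.array([0.5, 0.5])
history = []

# Observe HIGH_REW_EVIDENCE twice in a row
for _ in range(2):
    belief = likelihood_high_ev * belief  # likelihood times prior
    belief = belief / belief.sum()        # normalise to get the posterior
    history.append(belief.copy())
    print(belief.round(3))
# prints [0.8 0.2] and then [0.941 0.059]
```

Two consistent pieces of 80%-reliable evidence already push the belief in `HIGH_REW` above 94%, which is why a short run of sampling can be enough before committing to play.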