# Distributions API Tutorial

In this tutorial, we introduce the Distributions API, which makes it easier to build and work with complex active inference models by letting you set up a generative model with meaningful names instead of numerical indices.

The API allows you to:
- Give semantic names to tensor dimensions and values (e.g., "left", "right" instead of 0, 1)
- Avoid indexing errors by working with named elements
- Build complex models more intuitively using descriptive labels and structures

## Tutorial Structure

1. Basic Example: A simple grid navigation task built from a structured description using labels, in which the agent has to travel to a goal location.
2. A More Advanced Example: A simple foraging task built from a structured description using labels, in which the agent has to search for apples and eat them while they spawn at a set rate.

---

## Example 1: Grid World Navigation

Let's start with a simple example: an agent moving in a 1D grid world. Our agent can be in one of four positions: "left", "centre_left", "centre_right", or "right", and can take the actions "move_right" or "move_left". We show how to set up a single agent as well as three agents running in a batch.

In [1]:
import jax.tree_util as jtu
from jax import numpy as jnp
from jax import random as jr
from pymdp.agent import Agent
from pymdp.distribution import compile_model
from pymdp.envs.env import Env
from pymdp.envs import rollout

In [2]:
positions = ["left", "centre_left", "centre_right", "right"]
actions = ["move_left", "move_right"]

model_description = {
    "observations": {
        "position_obs": {
            "elements": positions, 
            "depends_on": ["position"] # we specify that the observation depends on the "position" state factor
        },
    },
    "controls": {
        "movement": {"elements": actions} # we specify the available actions
    },
    "states": {
        "position": {
            "elements": positions, 
            "depends_on": ["position"],  # our current position depends on previous position...
            "controlled_by": ["movement"]  # ...and the movement action taken
        },
    },
}

# compile the model structure from the description
model = compile_model(model_description)
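
The shapes of the tensors created from the description follow from its `depends_on` and `controlled_by` entries. The helper below, `tensor_shapes`, is a hypothetical sketch (not a pymdp function) that derives them from the axis conventions used throughout this tutorial: an observation tensor is indexed as (observation, *dependencies), and a transition tensor as (next state, *dependencies at the previous timestep, *controls).

```python
def tensor_shapes(description):
    """Hypothetical helper: derive A and B tensor shapes from a model description.

    Assumes the axis conventions used in this tutorial:
      A[obs]:   (obs, *depends_on states)
      B[state]: (next state, *depends_on states at t-1, *controlled_by controls)
    and that a factor's size is len(elements) or its 'size' entry.
    """
    def size(spec):
        return len(spec["elements"]) if "elements" in spec else spec["size"]

    states = description["states"]
    controls = description["controls"]

    A_shapes = {
        name: (size(spec),) + tuple(size(states[s]) for s in spec["depends_on"])
        for name, spec in description["observations"].items()
    }
    B_shapes = {
        name: (size(spec),)
        + tuple(size(states[s]) for s in spec["depends_on"])
        + tuple(size(controls[c]) for c in spec.get("controlled_by", []))
        for name, spec in states.items()
    }
    return A_shapes, B_shapes


positions = ["left", "centre_left", "centre_right", "right"]
actions = ["move_left", "move_right"]
model_description = {
    "observations": {"position_obs": {"elements": positions, "depends_on": ["position"]}},
    "controls": {"movement": {"elements": actions}},
    "states": {
        "position": {"elements": positions, "depends_on": ["position"], "controlled_by": ["movement"]},
    },
}

A_shapes, B_shapes = tensor_shapes(model_description)
print(A_shapes)  # {'position_obs': (4, 4)}
print(B_shapes)  # {'position': (4, 4, 2)}
```

This is why `model.B["position"]` is indexed with three labels (to, from, action) in the next cell.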

We have built the structure of the generative model from the description, but its tensors are still empty. Next, we fill them in by indexing with the labels we provided.

In [3]:
# fill in the likelihood (A) tensor
# the observations have an identical mapping to the states (i.e., the agent will perfectly observe its position)
model.A["position_obs"]["left", "left"] = 1.0
model.A["position_obs"]["centre_left", "centre_left"] = 1.0
model.A["position_obs"]["centre_right", "centre_right"] = 1.0
model.A["position_obs"]["right", "right"] = 1.0
# model.A["position_obs"].data = jnp.eye(len(positions)) # you could also use the .data attribute to set the identity mapping directly

# fill in the transition model (B) tensor
# note that it's specified as ["to", "from", "action"]
# moving right
model.B["position"]["centre_left", "left", "move_right"] = 1.0     
model.B["position"]["centre_right", "centre_left", "move_right"] = 1.0  
model.B["position"]["right", "centre_right", "move_right"] = 1.0    
model.B["position"]["right", "right", "move_right"] = 1.0           

# moving left  
model.B["position"]["left", "left", "move_left"] = 1.0              
model.B["position"]["left", "centre_left", "move_left"] = 1.0       
model.B["position"]["centre_left", "centre_right", "move_left"] = 1.0  
model.B["position"]["centre_right", "right", "move_left"] = 1.0    

# set preferences (C) tensor - prefer to be at "centre_left"
model.C["position_obs"]["centre_left"] = 1.0

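
The eight `model.B["position"]` assignments above describe one shift matrix per action, with the agent staying put when it would walk off either end of the grid. As a cross-check, this plain NumPy sketch (independent of pymdp) builds the same array and verifies that every (from, action) column sums to one, which is a quick way to spot a missing entry.

```python
import numpy as np

positions = ["left", "centre_left", "centre_right", "right"]
actions = ["move_left", "move_right"]
n = len(positions)

# B[to, from, action], same axis order as model.B["position"] above
B = np.zeros((n, n, len(actions)))
for frm in range(n):
    # moving past either edge leaves the agent where it is
    B[min(frm + 1, n - 1), frm, actions.index("move_right")] = 1.0
    B[max(frm - 1, 0), frm, actions.index("move_left")] = 1.0

# every (from, action) column should be a probability distribution over "to"
assert np.allclose(B.sum(axis=0), 1.0)
print(B[:, positions.index("right"), actions.index("move_right")])  # [0. 0. 0. 1.]
```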
Now, let's create the agent (via the `Agent` class), have it infer its position from an observation, and select an action according to its goal.

In [4]:
batch_size = 1
gamma = 10 # high policy precision gives near-deterministic behaviour; make it smaller for more stochastic behaviour

# create agent
agent = Agent(**model, batch_size=batch_size, gamma=gamma)

# set up the initial observation to be "left" (index 0), with shape (batch_size, time)
observation = jnp.zeros((batch_size, 1), dtype=jnp.int32)

# get the prior
qs_init = jtu.tree_map(lambda x: jnp.expand_dims(x, 1), agent.D) # qs needs a time dimension too

# print initial beliefs, goal, and action chosen
qs = agent.infer_states([observation], qs_init)
print(f"Current belief about position: {positions[jnp.argmax(qs[0][0])]}")
qs = [jnp.squeeze(q, 1) for q in qs]

print(f"Goal position: {positions[jnp.argmax(agent.C[0])]}")

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
print(f"Action chosen: {actions[action[0][0]]}")

Current belief about position: left
Goal position: centre_left
Action chosen: move_right
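
For a single state factor and a single timestep, the belief update performed by `infer_states` should coincide with Bayes' rule: the posterior over positions is proportional to the likelihood of the observed outcome times the prior. The sketch below works this through in NumPy, assuming a uniform prior and a slightly noisy identity likelihood (so the numbers are not trivially 0 or 1); neither is taken from the model above.

```python
import numpy as np

positions = ["left", "centre_left", "centre_right", "right"]

# assumed: slightly noisy identity likelihood A[obs, state]; each column sums to 1
eps = 0.05
A = np.full((4, 4), eps / 3)
np.fill_diagonal(A, 1.0 - eps)

# assumed: uniform prior over positions
D = np.full(4, 0.25)

obs = positions.index("left")
posterior = A[obs] * D        # likelihood of the observed outcome under each state, times the prior
posterior /= posterior.sum()  # normalise to a probability distribution
print(positions[int(np.argmax(posterior))], round(float(posterior[0]), 3))  # left 0.95
```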


We can also run multiple trials in parallel, each with a different initial observation (i.e., different initial position), via the batching feature.

In [5]:
batch_size = 3 # running 3 trials in parallel
gamma = 10 # high policy precision gives near-deterministic behaviour; make it smaller for more stochastic behaviour

# create agent
agent = Agent(**model, batch_size=batch_size, gamma=gamma)

# set up initial observations to be "left", "centre_right", and "right"
observation = jnp.array([[0], [2], [3]])

# get the prior
qs_init = jtu.tree_map(lambda x: jnp.expand_dims(x, 1), agent.D) # qs needs a time dimension too

# print goal and initial beliefs
qs = agent.infer_states([observation], qs_init)
for a in range(batch_size):
    print(f"Agent {a}'s current belief about position: {positions[jnp.argmax(qs[0][a])]}")
qs = [jnp.squeeze(q, 1) for q in qs]

# agent.C[0] has a leading batch dimension; all agents share the same preferences, so read the first row
print(f"\nGoal position for all agents: {positions[jnp.argmax(agent.C[0][0])]}\n")

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
for a in range(batch_size): 
    print(f"Agent {a}'s action chosen: {actions[action[a][0]]}")

Agent 0's current belief about position: left
Agent 1's current belief about position: centre_right
Agent 2's current belief about position: right

Goal position for all agents: centre_left

Agent 0's action chosen: move_right
Agent 1's action chosen: move_left
Agent 2's action chosen: move_left
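
With batching, arrays such as beliefs and preferences carry a leading batch axis (this is how `qs[0][a]` picks out agent `a` above). When decoding labels, take the argmax over the last axis to get one index per agent; a plain `argmax` flattens the array first, so its result is only a valid position index if the maximum happens to lie in the first row. A small NumPy illustration, using a made-up `(3, 4)` preference array:

```python
import numpy as np

positions = ["left", "centre_left", "centre_right", "right"]

# made-up (batch, num_states) array: every agent prefers "centre_left"
C = np.tile(np.array([0.0, 1.0, 0.0, 0.0]), (3, 1))

# one label per batch element: argmax over the last axis
per_agent = [positions[i] for i in np.argmax(C, axis=-1)]
print(per_agent)  # ['centre_left', 'centre_left', 'centre_left']

# flattened argmax: an index into the 3 * 4 = 12 flattened entries
print(int(np.argmax(C)))  # 1, a valid position only because the maximum sits in row 0
```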


---

## Example 2: Apple Foraging Task

Now let's look at a slightly more complex example: an apple foraging task. Here, we have a 1x3 grid with a "left", "centre", and "right" cell. Each cell is an orchard in which an apple can grow at a set rate (1/3 per timestep). The agent's objective is to find and eat apples, since eating an apple yields a reward. The agent has two control factors: it can "stay", "move_left", or "move_right", and it can either "eat" or do nothing ("noop").
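
A spawn rate of 1/3 means that, assuming an apple stays in its cell until it is eaten, an empty orchard cell grows an apple within k timesteps with probability 1 - (2/3)^k, and the waiting time until a spawn is geometric with a mean of 1/(1/3) = 3 timesteps. The arithmetic, using exact fractions:

```python
from fractions import Fraction

# per-timestep probability that an empty orchard cell grows an apple
spawn_rate = Fraction(1, 3)


def p_apple_within(k, p=spawn_rate):
    """Probability that an empty cell has grown an apple within k timesteps.

    Assumes spawns are independent across timesteps and that an apple is not removed
    until the agent eats it.
    """
    return 1 - (1 - p) ** k


# the waiting time for a spawn is geometric, with mean 1 / p
expected_wait = 1 / spawn_rate

print(p_apple_within(1), p_apple_within(2), p_apple_within(3), expected_wait)  # 1/3 5/9 19/27 3
```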


In [6]:
num_locations = 3 # "left", "centre", and "right"; you could list these names under 'elements', but we use a number here to show the 'size' key
item_list = ["orchard", "apple"]

model_description = {
    "observations": {
        "location_obs": {"size": num_locations, # if you want to use numbers instead of strings, you can use the 'size' key instead of 'elements'
                         "depends_on": ["location_state"],
        },
        "item_obs": {"elements": item_list, # "elements" key for strings
                     "depends_on": ["location_state", "left_state", "centre_state", "right_state"],
        },
        "reward_obs": {"elements": ["no_reward", "reward"],
                       "depends_on": ["reward_state"],
        },
    },
    "controls": {
        "move": {"elements": ["stay", "move_left", "move_right"],
        },
        "eat": {"elements": ["noop", "eat"], # noop = no-operation
        # note that if you cannot control a state, you still need to add 
        # an action for it (e.g., with elements: ["null"]) for the model to be initialised 
        # with the correct dimensions
        },
    },
    "states": {
        "location_state": {"size": num_locations,
                           "depends_on": ["location_state"],
                           "controlled_by": ["move"],
        },
        "reward_state": {"elements": ["no_reward", "reward"],
                            # with more than one dependency, list the state factor itself (at the previous timestep) first,
                            # then the other dependencies in the order they are specified (you can skip over some state factors)
                            "depends_on": ["reward_state", "location_state", 
                                           "left_state", "centre_state", "right_state"],
                            "controlled_by": ["eat"],
        },
        "left_state": {"elements": item_list,
                        "depends_on": ["left_state", "location_state"], 
                        "controlled_by": ["eat"],
        },
        "centre_state": {"elements": item_list,
                        "depends_on": ["centre_state", "location_state"],
                        "controlled_by": ["eat"],
        },
        "right_state": {"elements": item_list,
                        "depends_on": ["right_state", "location_state"],
                        "controlled_by": ["eat"],
        },
    },
}

model = compile_model(model_description)
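
Because the model structure is determined entirely by the names in `depends_on` and `controlled_by`, a typo in one of them is easy to make. The function below is a hypothetical pre-flight check (not part of pymdp) that looks for undeclared factor names and for states whose first dependency is not the state itself, following the convention noted in the description above. `compile_model` may already catch some of these cases; the sketch only shows what to look for.

```python
def validate_description(description):
    """Hypothetical sanity checks for a model description dict (not part of pymdp).

    Returns a list of problem strings; an empty list means no problems were found.
    """
    problems = []
    states = description["states"]
    controls = description["controls"]

    for name, spec in description["observations"].items():
        for dep in spec["depends_on"]:
            if dep not in states:
                problems.append(f"observation {name!r} depends on unknown state {dep!r}")

    for name, spec in states.items():
        deps = spec.get("depends_on", [])
        # convention used in this tutorial: a state's own previous value comes first
        if deps and deps[0] != name:
            problems.append(f"state {name!r} should list itself as its first dependency")
        for dep in deps:
            if dep not in states:
                problems.append(f"state {name!r} depends on unknown state {dep!r}")
        for ctrl in spec.get("controlled_by", []):
            if ctrl not in controls:
                problems.append(f"state {name!r} is controlled by unknown control {ctrl!r}")
    return problems


good = {
    "observations": {"loc_obs": {"size": 3, "depends_on": ["loc"]}},
    "controls": {"move": {"elements": ["stay", "move_left", "move_right"]}},
    "states": {"loc": {"size": 3, "depends_on": ["loc"], "controlled_by": ["move"]}},
}
bad = {
    "observations": {"loc_obs": {"size": 3, "depends_on": ["location"]}},  # typo: "location"
    "controls": {"move": {"elements": ["stay"]}},
    "states": {"loc": {"size": 3, "depends_on": ["loc"], "controlled_by": ["mov"]}},  # typo: "mov"
}
print(validate_description(good))  # []
print(len(validate_description(bad)))  # 2
```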


As before, the compiled model structure is empty, so we now fill in its tensors by indexing with the labels we provided (and with integer indices for the factors defined by 'size').

In [7]:
'''
SPECIFY THE A TENSOR
'''
# identity mapping for the observations regarding location and reward
model.A["location_obs"].data = jnp.eye(len(model.A["location_obs"].data))
model.A["reward_obs"].data = jnp.eye(len(model.A["reward_obs"].data))

# the agent observes the item (apple or orchard) in the cell it currently occupies
# axes: (item_obs, location_state, left_state, centre_state, right_state)
model.A["item_obs"]["apple", 0, "apple", :, :] = 1.0
model.A["item_obs"]["apple", 1, :, "apple", :] = 1.0
model.A["item_obs"]["apple", 2, :, :, "apple"] = 1.0
model.A["item_obs"]["orchard", 0, "orchard", :, :] = 1.0
model.A["item_obs"]["orchard", 1, :, "orchard", :] = 1.0
model.A["item_obs"]["orchard", 2, :, :, "orchard"] = 1.0
model.A["item_obs"].data = model.A["item_obs"].data + 1e-3 # add a small amount of noise to the observations

'''
SPECIFY THE B TENSOR
'''

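# for moving between locations; entries are (to, from, action)
# note that movement wraps around: "move_left" from the left cell leads to the right cell, and vice versa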
valid_transitions = [
    # from 0 (left)
    (0, 0, "stay"), # from left to left, stay
    (1, 0, "move_right"), # from left to centre, move right
    (2, 0, "move_left"), # from left to right, move left

    # from 1 (centre)
    (0, 1, "move_left"), # from centre to left, move left
    (1, 1, "stay"), # from centre to centre, stay
    (2, 1, "move_right"), # from centre to right, move right

    # from 2 (right)
    (0, 2, "move_right"), # from right to left, move right
    (1, 2, "move_left"), # from right to centre, move left
    (2, 2, "stay"), # from right to right, stay
]

for to_state, from_state, action in valid_transitions:
    model.B["location_state"][to_state, from_state, action] = 1.0

# again, remember the reward states will be set as ["to", "from", ...dependencies..., "action"]
# if the agent sees an apple and does not eat the apple (i.e., noop), it does not get a reward
model.B["reward_state"]["no_reward", "no_reward", 0, "apple", :, :, "noop"] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 1, :, "apple", :, "noop"] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "apple", "noop"] = 1.0

# if the agent sees an orchard, it does not get a reward regardless of its actions
model.B["reward_state"]["no_reward", "no_reward", 0, "orchard", :, :, :] = 1.0 
model.B["reward_state"]["no_reward", "no_reward", 1, :, "orchard", :, :] = 1.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "orchard", :] = 1.0

# from a reward state, there will always be no reward in the next timestep regardless of the action
model.B["reward_state"]["no_reward", "reward", 0, :, :, :, :] = 1.0
model.B["reward_state"]["no_reward", "reward", 1, :, :, :, :] = 1.0
model.B["reward_state"]["no_reward", "reward", 2, :, :, :, :] = 1.0

# if the agent sees an orchard and eats, it will not get a reward
model.B["reward_state"]["reward", "no_reward", 0, "orchard", :, :, "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 1, :, "orchard", :, "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 2, :, :, "orchard", "eat"] = 0.0

# if the agent sees an apple and eats it, it transitions to the reward state with certainty
model.B["reward_state"]["no_reward", "no_reward", 0, "apple", :, :, "eat"] = 0.0 
model.B["reward_state"]["no_reward", "no_reward", 1, :, "apple", :, "eat"] = 0.0
model.B["reward_state"]["no_reward", "no_reward", 2, :, :, "apple", "eat"] = 0.0
model.B["reward_state"]["reward", "no_reward", 0, "apple", :, :, "eat"] = 1.0 
model.B["reward_state"]["reward", "no_reward", 1, :, "apple", :, "eat"] = 1.0
model.B["reward_state"]["reward", "no_reward", 2, :, :, "apple", "eat"] = 1.0

apple_spawn_locations = ["left_state", "centre_state", "right_state"]
apple_spawn_rate = 1/3
for i, state in enumerate(apple_spawn_locations):
    # an empty orchard cell spawns an apple with probability apple_spawn_rate, wherever the agent is
    model.B[state]["orchard", "orchard", :, :] = 1.0 - apple_spawn_rate # no spawn
    model.B[state]["apple", "orchard", :, :] = apple_spawn_rate # spawn
    # by default, an apple stays in its cell (the agent is elsewhere, or does not eat)
    model.B[state]["apple", "apple", :, :] = 1.0
    # if the agent is in this cell (location i) and eats, the apple is consumed and the cell reverts to an orchard
    model.B[state]["apple", "apple", i, "eat"] = 0.0
    model.B[state]["orchard", "apple", i, "eat"] = 1.0
    model.B[state].data = model.B[state].data + 1e-3 # add a small amount of noise to the transitions

'''
SPECIFY THE C TENSOR. 
'''
model.C["reward_obs"]["reward"] = 1.0

'''
NORMALISE THE TENSORS
'''

model.A["location_obs"].normalize()
model.A["item_obs"].normalize()
model.A["reward_obs"].normalize()

model.B["location_state"].normalize()
model.B["reward_state"].normalize()
model.B["left_state"].normalize()
model.B["centre_state"].normalize()
model.B["right_state"].normalize()
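
Adding `1e-3` before normalising turns deterministic columns into nearly deterministic ones and guarantees that no column is all zeros (an all-zero column could not be normalised). Assuming `.normalize()` divides each column by its sum over the first axis, so that every column becomes a distribution over outcomes or next states, the effect on two example columns looks like this in NumPy:

```python
import numpy as np

# axis 0: outcomes; axis 1: two example conditioning configurations
# column 0 is deterministic; column 1 is a hypothetical combination that was never filled in
A = np.array([[1.0, 0.0],
              [0.0, 0.0]])

smoothed = A + 1e-3
normalised = smoothed / smoothed.sum(axis=0, keepdims=True)  # each column now sums to 1

print(np.round(normalised, 4))  # column 0 is about [0.999, 0.001]; column 1 is [0.5, 0.5]
```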

In [8]:
batch_size = 1
gamma = 1.0

agent = Agent(**model, batch_size=batch_size, learn_A=False, learn_B=False, gamma=gamma, sampling_mode="full")