# Foraging Task

### Overview
This notebook demonstrates multi-agent planning with and without Theory of Mind (ToM) in a cooperative foraging task using sophisticated active inference. These experiments are published in: ["Theory of Mind Using Active Inference: A Framework for Multi-Agent Cooperation"](https://arxiv.org/abs/2508.00401). Note that in the paper, locations are numbered 1–9, but in the codebase, they start at 0 instead of 1 (i.e., 0–8).
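
The offset between the paper's numbering and the codebase's numbering can be made explicit with a pair of small helpers. This is only an illustrative sketch: `paper_to_code` and `code_to_paper` are hypothetical names and are not part of the codebase.

```python
NUM_LOCATIONS = 9  # the paper numbers locations 1-9; the codebase uses 0-8


def paper_to_code(location: int) -> int:
    """Convert a 1-based paper location (1-9) to a 0-based code index (0-8)."""
    if not 1 <= location <= NUM_LOCATIONS:
        raise ValueError(f"paper location must be in 1..{NUM_LOCATIONS}, got {location}")
    return location - 1


def code_to_paper(index: int) -> int:
    """Convert a 0-based code index (0-8) to a 1-based paper location (1-9)."""
    if not 0 <= index < NUM_LOCATIONS:
        raise ValueError(f"code index must be in 0..{NUM_LOCATIONS - 1}, got {index}")
    return index + 1


# Code indices 7 and 5 correspond to locations 8 and 6 in the paper's numbering.
paper_locations = [code_to_paper(i) for i in (7, 5)]
```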

### Task Description
The foraging task takes place in a 3×3 grid environment where agents must collect apples that spawn at a rate of 0.25. Agents must coordinate their actions to efficiently gather resources whilst avoiding redundant efforts.
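
One possible reading of the spawn rate is sketched below. It assumes that, at every time step, each cell without an apple independently gains one with probability 0.25. That is an assumption made for illustration; the actual spawning rule is defined by `ForagingEnv` and may differ. The helper `spawn_apples` is hypothetical.

```python
import random


def spawn_apples(apples, spawn_rate, rng):
    """Return a new apple map; ASSUMPTION: each empty cell independently
    gains an apple with probability `spawn_rate`, and existing apples stay."""
    return [has_apple or rng.random() < spawn_rate for has_apple in apples]


rng = random.Random(0)      # seeded for reproducibility
apples = [False] * 9        # 3x3 grid flattened to 9 cells, initially empty
for _ in range(3):          # simulate three time steps
    apples = spawn_apples(apples, 0.25, rng)
```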

### Experiments
All agents use sophisticated inference planning with the `another_works_for_tom` branch of pymdp.

1. **Two non-ToM Agents** - Two agents planning without Theory of Mind capabilities. Uses our custom rollout function with planning tree recycling that supports multiple agents in a shared environment.

2. **One ToM Agent with One non-ToM Agent** - One agent with Theory of Mind capabilities cooperating with one agent without Theory of Mind capabilities. Uses our custom rollout function, which accommodates both non-ToM and ToM planning strategies and multiple agents in one shared environment.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu

import matplotlib.pyplot as plt

from tom.envs import ForagingEnv
from tom.models import ForagingModel, ForagingAgent
from tom.planning.visualize import visualize_plan_tree, visualize_env

from tom.planning.si import si_policy_search as si_policy_search_nontom
from tom.planning.si_tom import si_policy_search_tom, ToMify

from tom.planning.rollout import rollout as rollout_nontom
from tom.planning.rollout_tom import rollout as rollout_tom

# Two non-ToM Agents

Initialize the environment

In [None]:
grid_size = 3
apple_spawn_rate = 0.25
num_agents = 2
initial_positions = jnp.array([7, 5])

env = ForagingEnv(apple_spawn_rate, num_agents, grid_size, initial_positions) 

Initialize the agents

In [None]:
gamma = 1.0

model = ForagingModel(apple_spawn_rate=apple_spawn_rate)
agents = ForagingAgent(model, gamma=gamma, batch_size=num_agents)

Set up the non-ToM planning algorithm

In [None]:
horizon = 3
max_nodes = 5000
max_branching = 6
policy_prune_threshold = 1 / 8
observation_prune_threshold = 1 / 8
entropy_stop_threshold = 0.5
efe_stop_threshold = 10
kl_threshold = 1e-2
prune_penalty = 512

# set up the policy search function
tree_search_nontom = si_policy_search_nontom(
    horizon=horizon,
    max_nodes=max_nodes,
    max_branching=max_branching,
    policy_prune_threshold=policy_prune_threshold,
    observation_prune_threshold=observation_prune_threshold,
    entropy_stop_threshold=entropy_stop_threshold,
    efe_stop_threshold=efe_stop_threshold,
    kl_threshold=kl_threshold,
    prune_penalty=prune_penalty,
    gamma=gamma,
)
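
To give a feel for what a threshold such as `policy_prune_threshold = 1 / 8` might do, here is a minimal sketch of probability-based policy pruning. It assumes the threshold is a minimum policy probability and that policy probabilities come from a softmax over negative expected free energy scaled by `gamma`. Both are assumptions for illustration; the rule actually used inside `si_policy_search` may differ. The helper names and the example values are hypothetical.

```python
import math


def policy_posterior(efe, gamma=1.0):
    """Softmax over negative expected free energy (lower EFE -> higher probability)."""
    logits = [-gamma * g for g in efe]
    max_logit = max(logits)                       # subtract the max for numerical stability
    exps = [math.exp(l - max_logit) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prune_policies(efe, threshold=1 / 8, gamma=1.0):
    """ASSUMPTION: keep only the policies whose probability is at least `threshold`."""
    q = policy_posterior(efe, gamma)
    return [i for i, p in enumerate(q) if p >= threshold]


efe = [1.0, 1.5, 4.0, 6.0]   # made-up expected free energies for four policies
kept = prune_policies(efe)   # indices of the policies that survive pruning
```

With these made-up values, the two low-EFE policies carry almost all of the probability mass, so only they would be expanded further in the tree.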

Run the agents with sophisticated inference planning without Theory of Mind (`si_policy_search_nontom`), using the custom rollout function that lets multiple agents interact within the same environment.

In [None]:
T = 2
key = jr.PRNGKey(1)
last, info_nontom, env = rollout_nontom(agents, env, T, key, policy_search=tree_search_nontom)

In [None]:
visualize_env(info_nontom, model=model, save_as_gif=False, gif_filename="foraging_nontom.gif")

Plan tree of the red non-ToM agent

In [None]:
visualize_plan_tree(info_nontom, time_idx=0, agent_idx=0, model=model, depth=4, fig_size=(5,5))

Plan tree of the purple non-ToM agent

In [None]:
visualize_plan_tree(info_nontom, time_idx=0, agent_idx=1, model=model, depth=4, fig_size=(5,5))

# One ToM Agent with One non-ToM Agent

Initialize the environment

In [None]:
grid_size = 3
apple_spawn_rate = 0.25
num_agents = 2
initial_positions = jnp.array([7, 5]) 

env = ForagingEnv(apple_spawn_rate, num_agents, grid_size, initial_positions) 

Initialize the agents

In [None]:
model = ForagingModel(apple_spawn_rate=apple_spawn_rate)

agent0 = ForagingAgent(model, gamma=gamma, batch_size=1)
focal_agent = ToMify(agent0,
                     self_states=[0, 1],
                     world_states=[2, 3, 4, 5, 6, 7],
                     # observation_mappings has shape (focal_batch, num_agents, num_modalities).
                     # For each observation modality of an agent, it sets which of the received
                     # observation modalities should be used.
                     # E.g., with 3 received observation modalities, these map to
                     # [0, 1, 2] for agent 0 (focal) and [0, -1, -1] for agent 1 (other),
                     # i.e., only the other agent's own location (0th modality) is mapped;
                     # -1 stands for "nothing".
                     observation_mappings=jnp.array([[[0, 1, 2], [0, -1, -1]]]))

other_agents = ForagingAgent(model, gamma=gamma, batch_size=1)
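
The role of `observation_mappings` can be illustrated with a small stand-alone sketch. It assumes that each entry is an index into the list of received observation modalities and that `-1` means no observation is routed to that modality. This is an interpretation of the comment in the cell above, not the actual `ToMify` implementation. The helper `route_observations` and the placeholder observation values are hypothetical.

```python
NO_OBS = -1  # ASSUMPTION: -1 marks a modality that receives no observation


def route_observations(received, mapping):
    """For each agent (row of `mapping`), pick the received observation
    for every modality, or None where the mapping says NO_OBS."""
    return [
        [None if idx == NO_OBS else received[idx] for idx in row]
        for row in mapping
    ]


received = ["obs_0", "obs_1", "obs_2"]   # placeholder observations, one per modality
mapping = [[0, 1, 2], [0, -1, -1]]       # one row per agent: focal, then other
routed = route_observations(received, mapping)
```

Under this reading, the focal agent's model receives all three modalities, while the model of the other agent receives only its 0th modality.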

Run the agents with sophisticated inference planning with Theory of Mind (`si_policy_search_tom`), using the custom rollout function that lets non-ToM and ToM agents interact within the same environment.

In [None]:
tree_search_tom = si_policy_search_tom(
    horizon=horizon,
    max_nodes=max_nodes,
    max_branching=max_branching,
    policy_prune_threshold=policy_prune_threshold,
    observation_prune_threshold=observation_prune_threshold,
    entropy_stop_threshold=entropy_stop_threshold,
    efe_stop_threshold=efe_stop_threshold,
    kl_threshold=kl_threshold,
    prune_penalty=prune_penalty,
    gamma=gamma,
    other_agent_policy_search=tree_search_nontom,
)

In [None]:
key = jr.PRNGKey(1)
T = 2
last, info_tom, env = rollout_tom(
    focal_agent,
    other_agents,
    env,
    T,
    key,
    other_agent_policy_search=tree_search_nontom,
    focal_agent_tom_policy_search=tree_search_tom,
)

In [None]:
visualize_env(info_tom, model=model, save_as_gif=False, gif_filename="foraging_tom.gif")

Plan tree of the red focal (ToM) agent

In [None]:
visualize_plan_tree(info_tom, time_idx=2, agent_idx=0, model=model, depth=8, root_idx=0, fig_size=(10,10))

Plan tree of the purple non-ToM agent

In [None]:
visualize_plan_tree(info_tom, time_idx=0, agent_idx=1, plotting_other_intom=True, model=model, depth=8, fig_size=(10,10))