# Foraging Task

### Overview
This notebook demonstrates single- and multi-agent planning in a cooperative foraging task using sophisticated active inference and Theory of Mind (ToM) capabilities.

### Task Description
The foraging task occurs in a 3×3 grid environment where agents must eat apples that spawn in the orchard at a set rate.

### Key Features Demonstrated
Note that all agents plan via sophisticated inference, and this notebook uses the `another_works_for_tom` branch of pymdp.

1. **Single non-ToM Agent** - uses the `pymdp` rollout function.

2. **Optimized Single non-ToM Agent** - uses our custom rollout function with planning tree recycling for improved performance.

3. **Multiple non-ToM Agents** - uses our custom rollout function with planning tree recycling that also supports multiple agents in one shared environment.

4. **Single ToM Agent with Multiple non-ToM Agents** - one agent planning with theory of mind capabilities alongside two agents without theory of mind capabilities. Uses our custom rollout function which accommodates both planning strategies and multiple agents in one shared environment. 

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu

import matplotlib.pyplot as plt

from tom.envs import ForagingEnv
from tom.models import ForagingModel, ForagingAgent
from tom.planning.visualize import visualize_plan_tree, visualize_beliefs, visualize_env, plot_plan_tree

from tom.planning.si import si_policy_search as si_policy_search_nontom
from tom.planning.si_tom import si_policy_search_tom, ToMify

from pymdp.envs import rollout as rollout_pymdp
from tom.planning.rollout_deprecated import rollout as rollout_optimized
from tom.planning.rollout_tom_deprecated import rollout as rollout_tom

# Single Non-ToM Agent

Initialize the environment

In [None]:
grid_size = 3
num_agents = 1
apple_spawn_rate = 0.1
initial_positions = jnp.array([4]) # can set to None and it will be initialised randomly

env = ForagingEnv(apple_spawn_rate, num_agents, grid_size, initial_positions) 
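The initial positions are given as single integers rather than coordinates. A minimal sketch of how such indices could map onto the grid, assuming cells are numbered row by row (row-major); that numbering and the helper names `to_row_col` and `to_index` are assumptions for illustration, not part of `ForagingEnv`:

```python
# Hypothetical helpers: assume cells are numbered row by row (row-major).
def to_row_col(index: int, grid_size: int) -> tuple:
    """Convert a flat cell index into a (row, col) pair."""
    return divmod(index, grid_size)


def to_index(row: int, col: int, grid_size: int) -> int:
    """Convert a (row, col) pair back into a flat cell index."""
    return row * grid_size + col


grid_size = 3
for position in [4, 7, 5, 1]:
    # Under the row-major assumption, index 4 is the centre cell of a 3x3 grid.
    print(position, "->", to_row_col(position, grid_size))
```

Under this assumption, `initial_positions = jnp.array([4])` would place the single agent in the centre of the grid.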

Initialize the agent

In [None]:
gamma = 1.0

model = ForagingModel(apple_spawn_rate=apple_spawn_rate)
agents = ForagingAgent(model, gamma=gamma, batch_size=num_agents)

Run the agent using non-ToM sophisticated inference planning (`si_policy_search_nontom`) and the `pymdp` rollout function.

In [None]:
horizon=3
max_nodes = 5000
max_branching = 6
policy_prune_threshold = 1 / 32
observation_prune_threshold = 1 / 32
entropy_stop_threshold = 0.5
efe_stop_threshold = 5
kl_threshold=1e-2
prune_penalty = 512

# set up the policy search function
tree_search_nontom = si_policy_search_nontom(
        horizon=horizon,
        max_nodes=max_nodes,
        max_branching=max_branching,
        policy_prune_threshold=policy_prune_threshold,
        observation_prune_threshold=observation_prune_threshold,
        entropy_stop_threshold=entropy_stop_threshold,
        efe_stop_threshold=efe_stop_threshold,
        kl_threshold=kl_threshold,
        prune_penalty=prune_penalty,
        gamma=gamma
    )
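The prune thresholds above limit how far the tree search branches. As a rough illustration of what a probability threshold such as `1 / 32` can do, the sketch below drops branches whose probability falls below the threshold and renormalizes the rest. This is not the `tom.planning.si` implementation: the function `prune_and_renormalize`, the use of `>=`, and the example probabilities are all assumptions.

```python
def prune_and_renormalize(probs, threshold):
    """Zero out entries below `threshold`, then rescale the rest to sum to 1.

    Hypothetical helper for illustration only.
    """
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    if total == 0.0:
        raise ValueError("every branch was pruned; lower the threshold")
    return [p / total for p in kept]


# Made-up probabilities over four candidate policies.
policy_probs = [0.5, 0.3, 0.18, 0.02]
pruned = prune_and_renormalize(policy_probs, threshold=1 / 32)

# The last policy (0.02 < 1/32 = 0.03125) is removed from the search.
print(pruned)
```

A larger threshold prunes more aggressively, which shrinks the tree at the cost of ignoring unlikely branches.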

In [None]:
T = 10
key = jr.PRNGKey(1)
last, info_single_nontom, env = rollout_pymdp(agents, env, T, key, policy_search=tree_search_nontom)

In [None]:
visualize_env(info_single_nontom, model=model, save_as_gif=False, gif_filename="foraging_single_nontom.gif")

In [None]:
visualize_plan_tree(info_single_nontom, time_idx=4, agent_idx=0, model=model, min_prob=0.0, depth=4, fig_size = (5,5))

In [None]:
# visualize_beliefs(info_single_nontom, model=model)

# Optimized Single Non-ToM Agent

The environment, agent, and planning algorithm setup are the same as above. We now run the agent using our custom rollout function, which recycles the planning tree for better performance.
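One common way to recycle a planning tree is to keep the subtree that matches the action actually taken and the observation actually received, and use it as the root at the next time step instead of planning from scratch. The sketch below shows that idea with a toy tree. The `PlanNode` class and `recycle` function are hypothetical, and the notebook does not state that the custom rollout works this way.

```python
from dataclasses import dataclass, field


@dataclass
class PlanNode:
    """Hypothetical planning-tree node; not the tree structure used by `tom`."""
    label: str
    # Children are keyed by the (action, observation) pair that leads to them.
    children: dict = field(default_factory=dict)


def recycle(root, action, observation):
    """Reuse the subtree matching what happened, or start a fresh root."""
    child = root.children.get((action, observation))
    return child if child is not None else PlanNode(label="fresh root")


# Build a tiny two-level tree by hand.
root = PlanNode("t=0")
root.children[("right", "apple")] = PlanNode("t=1 after right/apple")
root.children[("right", "no apple")] = PlanNode("t=1 after right/no apple")
root.children[("right", "apple")].children[("up", "apple")] = PlanNode("t=2")

# The agent moved right and saw an apple: keep that branch and its children.
new_root = recycle(root, "right", "apple")
print(new_root.label, "with", len(new_root.children), "already expanded child")
```

The saving comes from the nodes below the reused branch, which were expanded at the previous step and do not need to be recomputed.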

In [None]:
T = 10
key = jr.PRNGKey(1)
last, info_optimized_single_nontom, env = rollout_optimized(agents, env, T, key, policy_search=tree_search_nontom)

In [None]:
visualize_env(info_optimized_single_nontom, model=model, save_as_gif=False, gif_filename="foraging_optimized_single_nontom.gif")

In [None]:
visualize_plan_tree(info_optimized_single_nontom, time_idx=0, agent_idx=0, model=model, fig_size = (5,5))

# Multiple non-ToM Agents

Initialize the environment

In [None]:
grid_size = 3
apple_spawn_rate = 0.1
num_agents = 3
initial_positions = jnp.array([7, 5, 1]) # can set to None and it will be initialised randomly

env = ForagingEnv(apple_spawn_rate, num_agents, grid_size, initial_positions) 

Initialize the agents

In [None]:
gamma = 1.0

model = ForagingModel(apple_spawn_rate=apple_spawn_rate)
agents = ForagingAgent(model, gamma=gamma, batch_size=num_agents)

The non-ToM planning algorithm setup is the same as above. We now run the agents using our custom rollout function, which, in addition to recycling the planning tree for better performance, allows multiple agents to share one environment.

In [None]:
T = 10
key = jr.PRNGKey(1)
last, info_optimized_multi_nontom, env = rollout_optimized(agents, env, T, key, policy_search=tree_search_nontom)

In [None]:
visualize_env(info_optimized_multi_nontom, model=model, save_as_gif=False, gif_filename="foraging_optimized_multi_nontom.gif")

In [None]:
visualize_plan_tree(info_optimized_multi_nontom, time_idx=0, agent_idx=0, model=model, depth=4, fig_size = (5,5))

# Single ToM Agent with Multiple non-ToM Agents

Initialize the environment

In [None]:
grid_size = 3
apple_spawn_rate = 0.1
num_agents = 3
initial_positions = jnp.array([7, 5, 1]) # can set to None and it will be initialised randomly

env = ForagingEnv(apple_spawn_rate, num_agents, grid_size, initial_positions) 

Initialize the agents

In [None]:
agent0 = ForagingAgent(model, gamma=gamma, batch_size=1)
focal_agent = ToMify(agent0,
                     self_states=[0, 1],
                     world_states=[2, 3, 4, 5, 6, 7],
                     observation_mappings=jnp.array([[[0,1,2],[3,-1,-1],[4,-1,-1]]]))

# Note: observations from the environment include the locations of the other agents, so the observation mappings indicate which observation refers to which agent (-1 marks an invalid entry)

other_agents = ForagingAgent(model, gamma=1.0, batch_size=2)
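To make the observation mapping concrete, the sketch below reads each row as "the observation indices that concern agent *i*", with `-1` as padding. This is one interpretation of the comment above, not the logic inside `ToMify`. The helper `split_observations` and the placeholder observation names are hypothetical, and the leading batch axis of the notebook's array is dropped for readability.

```python
INVALID = -1

# Same values as the notebook's observation_mappings, without the batch axis.
observation_mapping = [
    [0, 1, 2],              # focal agent: observations 0, 1 and 2
    [3, INVALID, INVALID],  # other agent 1: observation 3 only
    [4, INVALID, INVALID],  # other agent 2: observation 4 only
]


def split_observations(observations, mapping):
    """Group observations by agent, skipping padded (-1) entries."""
    return [
        [observations[i] for i in row if i != INVALID]
        for row in mapping
    ]


# Placeholder observation values, one per observation index.
observations = ["obs_0", "obs_1", "obs_2", "obs_3", "obs_4"]
per_agent = split_observations(observations, observation_mapping)

for agent_idx, agent_obs in enumerate(per_agent):
    print("agent", agent_idx, "->", agent_obs)
```

Padding with `-1` keeps the mapping rectangular, so it can be stored as a single array even though agents have different numbers of associated observations.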

Running the agent using the sophisticated inference planning with theory of mind capabilities (`si_policy_search_tom`) and the custom rollout function which allows non-ToM and ToM agents to interact within the same environment.

In [None]:
tree_search_tom = si_policy_search_tom(
            horizon=horizon,
            max_nodes=10_000,
            max_branching=max_branching,
            policy_prune_threshold=policy_prune_threshold,
            observation_prune_threshold=observation_prune_threshold,
            entropy_stop_threshold=entropy_stop_threshold,
            efe_stop_threshold=efe_stop_threshold,
            kl_threshold=kl_threshold,
            prune_penalty=prune_penalty,
            gamma=gamma,
            other_agent_policy_search=tree_search_nontom
        )

In [None]:
key = jr.PRNGKey(1)
T=10
last, info_tom, env = rollout_tom(focal_agent,
            other_agents,
            env,
            T,
            key,
            other_agent_policy_search=tree_search_nontom,
            focal_agent_tom_policy_search=tree_search_tom,
)

In [None]:
visualize_env(info_tom, model=model, save_as_gif=False, gif_filename="foraging_tom.gif")

Plan tree of the focal (ToM) agent

In [None]:
visualize_plan_tree(info_tom, time_idx=0, agent_idx=0, model=model, depth=6, min_prob=0.0, root_idx=0, fig_size = (5,5))

In [None]:
# helper function to print out the beliefs of focal agent at tree nodes

tree = jtu.tree_map(lambda x: x[0,0], info_tom["tree"])

def print_qs(tree, node_idx):
    states = jtu.tree_map(lambda x : [i[0].item() for i in jnp.argmax(x, axis=-1)[0]], tree[node_idx]["qs"])
    print("location", states[0])
    print("rewards:", states[1])
    print("apples:")
    print( " focal:  ", jnp.round(tree[node_idx]["qs"][2][0, 0, 0], 2), jnp.round(tree[node_idx]["qs"][3][0, 0, 0], 2), jnp.round(tree[node_idx]["qs"][4][0, 0, 0], 2))
    print( " other 1:", jnp.round(tree[node_idx]["qs"][2][0, 1, 0], 2), jnp.round(tree[node_idx]["qs"][3][0, 1, 0], 2), jnp.round(tree[node_idx]["qs"][4][0, 1, 0], 2))
    print( " other 2:", jnp.round(tree[node_idx]["qs"][2][0, 2, 0], 2), jnp.round(tree[node_idx]["qs"][3][0, 2, 0], 2), jnp.round(tree[node_idx]["qs"][4][0, 2, 0], 2))
    print()
    print( " focal:  ", jnp.round(tree[node_idx]["qs"][5][0, 0, 0], 2), jnp.round(tree[node_idx]["qs"][6][0, 0, 0], 2), jnp.round(tree[node_idx]["qs"][7][0, 0, 0], 2))
    print( " other 1:", jnp.round(tree[node_idx]["qs"][5][0, 1, 0], 2), jnp.round(tree[node_idx]["qs"][6][0, 1, 0], 2), jnp.round(tree[node_idx]["qs"][7][0, 1, 0], 2))
    print( " other 2:", jnp.round(tree[node_idx]["qs"][5][0, 2, 0], 2), jnp.round(tree[node_idx]["qs"][6][0, 2, 0], 2), jnp.round(tree[node_idx]["qs"][7][0, 2, 0], 2))

Plan tree of the other (non-ToM) agent 1

In [None]:
other_tree1 = jtu.tree_map(lambda x: x[0, 0], info_tom["other_tree"])
_ = plot_plan_tree(other_tree1, model=model, max_depth=6)

Plan tree that the focal agent imagined for other agent 1 (note that focal_other_tree has dims (num_tom_agents, num_timesteps, num_other_agents, ...))

In [None]:
focal_other_tree1 = jtu.tree_map(lambda x: x[0, 0, 0], info_tom["focal_other_tree"])
_ = plot_plan_tree(focal_other_tree1, model=model, max_depth=6)
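The `jtu.tree_map(lambda x: x[0, 0, 0], ...)` pattern above applies the same leading-axis index to every array in a pytree. A small self-contained sketch with dummy data shows the effect on shapes. The dictionary keys and array sizes here are invented for illustration and do not reflect the real contents of `info_tom["focal_other_tree"]`.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

# Dummy pytree whose leaves share the leading axes
# (num_tom_agents, num_timesteps, num_other_agents, ...).
dummy_tree = {
    "beliefs": jnp.zeros((1, 10, 2, 7, 3)),
    "values": jnp.zeros((1, 10, 2, 7)),
}

# Select ToM agent 0, time step 0, other agent 0 in every leaf at once.
sliced = jtu.tree_map(lambda x: x[0, 0, 0], dummy_tree)

# The three leading axes are gone; the remaining axes are untouched.
print(jtu.tree_map(lambda x: x.shape, sliced))
```

Changing the second index, for example `x[0, 3, 0]`, would select the tree imagined at time step 3 instead, assuming the axis order stated above.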