# Collision Avoidance Task

### Overview
This notebook demonstrates multi-agent planning in a cooperative collision avoidance task using sophisticated inference (an active inference planning scheme), with and without Theory of Mind (ToM) capabilities.

### Task Description
The collision avoidance task takes place in a 3×3 grid environment. Agents begin at opposing corners of the grid and must each traverse to the opposite corner without occupying the same location as the other agent (i.e., they must avoid colliding with one another).
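To make the layout concrete, the sketch below converts between flat location indices and grid coordinates and checks for collisions. It assumes row-major indexing (location 0 is the top-left corner and 8 the bottom-right, so `initial_positions = [0, 8]` are opposing corners) and a five-move action set; both assumptions, and all helper names, are illustrative and not the internals of `CollisionAvoidanceEnv`.

```python
# ASSUMPTION: location indices 0..8 are row-major on the 3x3 grid,
# and the action set below is hypothetical, for illustration only.
GRID_SIZE = 3
MOVES = {"stay": (0, 0), "up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}


def to_coords(loc, grid_size=GRID_SIZE):
    """Flat location index -> (row, col)."""
    return divmod(loc, grid_size)


def to_index(row, col, grid_size=GRID_SIZE):
    """(row, col) -> flat location index."""
    return row * grid_size + col


def move(loc, action, grid_size=GRID_SIZE):
    """Apply a move; a move into a wall leaves the agent where it is."""
    row, col = to_coords(loc, grid_size)
    d_row, d_col = MOVES[action]
    row = min(max(row + d_row, 0), grid_size - 1)
    col = min(max(col + d_col, 0), grid_size - 1)
    return to_index(row, col, grid_size)


def collided(locations):
    """True if two or more agents occupy the same location."""
    return len(set(locations)) < len(locations)


starts, goals = [0, 8], [8, 0]          # each agent heads for the opposite corner
print([to_coords(s) for s in starts])   # [(0, 0), (2, 2)]
next_locs = [move(0, "right"), move(8, "up")]
print(next_locs, collided(next_locs))   # [1, 5] False
```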

### Key Features Demonstrated
Note that all agents plan via sophisticated inference, and we use the `another_works_for_tom` branch of pymdp. We do not include single-agent scenarios here, as this task requires more than one agent. We also omit the experiment with the pymdp rollout function, as it cannot accommodate multiple agents in one shared environment.

1. **Optimized Multiple non-ToM Agents** - uses our custom rollout function with planning tree recycling that supports multiple agents in one shared environment.

2. **Single ToM Agent with Multiple non-ToM Agents** - one agent planning with theory of mind capabilities alongside an agent without theory of mind capabilities. Uses our custom rollout function, which accommodates both planning strategies, planning tree recycling, and multiple agents in one shared environment.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import equinox as eqx

import matplotlib.pyplot as plt

from tom.envs import CollisionAvoidanceEnv
from tom.models import CollisionAvoidanceModel, CollisionAvoidanceAgent
from tom.planning.visualize import visualize_plan_tree, visualize_beliefs, visualize_env

from tom.planning.si import si_policy_search as si_policy_search_nontom
from tom.planning.si_tom import si_policy_search_tom, ToMify

from pymdp.envs import rollout as rollout_pymdp
from tom.planning.rollout_deprecated import rollout as rollout_optimized
from tom.planning.rollout_tom_deprecated import rollout as rollout_tom

# Optimized Multiple non-ToM Agents

Initialize the environment

In [None]:
grid_size = 3
num_agents = 2
initial_positions = jnp.array([0, 8]) # can set to None and it will be initialised randomly

env = CollisionAvoidanceEnv(num_agents, grid_size, initial_positions) 

Initialize the agents

In [None]:
gamma = 1.0

model_agent0 = CollisionAvoidanceModel(agent_idx=0)
model_agent1 = CollisionAvoidanceModel(agent_idx=1)

In [None]:
agent0 = CollisionAvoidanceAgent(model_agent0, gamma=gamma, batch_size=1)
agent1 = CollisionAvoidanceAgent(model_agent1, gamma=gamma, batch_size=1)

In [None]:
agents = jtu.tree_map(lambda x,y: jnp.concatenate([x,y], axis=0), agent0, agent1)

# policies should remain of shape (9, 1, 2) - the same for all agents in the batch; the line above results in (18, 1, 2)
agents = eqx.tree_at(lambda x: x.policies, agents, agent0.policies)

# batch_size is a static field, so we update it with object.__setattr__ to bypass Equinox's immutability
object.__setattr__(agents, 'batch_size', 2)
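The cell above stacks the two single-agent models leaf by leaf. A NumPy sketch with plain dicts standing in for the Equinox modules (the field names and shapes are hypothetical) shows why the shared `policies` array has to be restored afterwards: concatenating along axis 0 doubles the leading axis of every leaf, including arrays that have no batch axis. `jtu.tree_map` applies the same concatenation to every array leaf of an arbitrary pytree.

```python
import numpy as np

# Toy stand-ins for two single-agent models (batch_size=1). "A" mimics a
# per-agent parameter with a leading batch axis; "policies" mimics an array
# that is shared by all agents and has no batch axis. Both are hypothetical.
agent0 = {"A": np.zeros((1, 9, 9)), "policies": np.zeros((9, 1, 2), dtype=int)}
agent1 = {"A": np.ones((1, 9, 9)), "policies": np.zeros((9, 1, 2), dtype=int)}

# Concatenating every leaf along axis 0 stacks the agents into one batch...
agents = {k: np.concatenate([agent0[k], agent1[k]], axis=0) for k in agent0}
print(agents["A"].shape)          # (2, 9, 9): one slice per agent, as intended
print(agents["policies"].shape)   # (18, 1, 2): policies were stacked too

# ...so the shared field is restored from a single agent afterwards.
agents["policies"] = agent0["policies"]
print(agents["policies"].shape)   # (9, 1, 2)
```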

We first set up the non-ToM planning algorithm. We then run the agents using our custom rollout function, which, in addition to recycling the planning tree for better performance, allows multiple agents to act in one shared environment.
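Planning tree recycling can be pictured as follows: after an agent acts and observes, the part of its search tree below the branch that actually occurred becomes the root for the next time step, so nodes expanded earlier need not be recomputed. The sketch assumes a hypothetical dict-based tree in which action and observation nodes alternate; it illustrates the general idea of tree reuse, not the data structure used by `rollout_optimized`.

```python
# Hypothetical tree layout: action nodes and observation nodes alternate,
# and each node keeps its children in a dict keyed by action or observation.
def new_root():
    return {"children": {}}


def recycle_tree(tree, action, observation):
    """Return the subtree reached by the action taken and the observation received.

    If that branch was pruned or never expanded, fall back to an empty root,
    so the next search starts from scratch.
    """
    action_node = tree["children"].get(action)
    if action_node is None:
        return new_root()
    obs_node = action_node["children"].get(observation)
    if obs_node is None:
        return new_root()
    return obs_node


# Tree expanded at time t: action 1 -> observation 4 -> action 2 was already evaluated.
tree = {"children": {
    1: {"children": {4: {"children": {2: {"children": {}, "G": -0.3}}}}},  # "G": cached value (hypothetical field)
    0: {"children": {}},
}}

subtree = recycle_tree(tree, action=1, observation=4)
print(list(subtree["children"]))                    # [2]: work for time t+1 is kept
print(recycle_tree(tree, action=0, observation=3))  # {'children': {}}: nothing to reuse
```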

In [None]:
horizon = 3
max_nodes = 5000
max_branching = 10
policy_prune_threshold = 1 / 8
observation_prune_threshold = 1 / 8
entropy_stop_threshold = 0.5
efe_stop_threshold = 5
kl_threshold = 1e-2
prune_penalty = 512

# set up the policy search function
tree_search_nontom = si_policy_search_nontom(
        horizon=horizon,
        max_nodes=max_nodes,
        max_branching=max_branching,
        policy_prune_threshold=policy_prune_threshold,
        observation_prune_threshold=observation_prune_threshold,
        entropy_stop_threshold=entropy_stop_threshold,
        efe_stop_threshold=efe_stop_threshold,
        kl_threshold=kl_threshold,
        prune_penalty=prune_penalty,
        gamma=gamma
    )
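One plausible reading of `gamma` and `policy_prune_threshold` (an assumption, not taken from the `si_policy_search` source): expected free energies G are turned into action probabilities with a softmax at precision `gamma`, and branches whose probability falls below the threshold are not expanded further. The expected free energy values below are made up.

```python
import numpy as np


def policy_posterior(G, gamma=1.0):
    """q(u) = softmax(-gamma * G): lower expected free energy -> higher probability."""
    logits = -gamma * np.asarray(G, dtype=float)
    logits -= logits.max()  # subtract the max for numerical stability
    q = np.exp(logits)
    return q / q.sum()


def kept_actions(G, gamma=1.0, prune_threshold=1 / 8):
    """Indices of actions whose posterior probability clears the (assumed) pruning threshold."""
    q = policy_posterior(G, gamma)
    return np.flatnonzero(q >= prune_threshold).tolist()


G = [2.0, 2.5, 5.0, 6.0]        # made-up expected free energies for four actions
print(policy_posterior(G))      # approximately [0.597, 0.362, 0.030, 0.011]
print(kept_actions(G))          # [0, 1]: only these branches would be expanded
```

A larger `gamma` sharpens the distribution, so fewer branches survive the same threshold.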

In [None]:
T = 3
key = jr.PRNGKey(1)
last, info_optimized_multi_nontom, env = rollout_optimized(agents, env, T, key, policy_search=tree_search_nontom)

In [None]:
visualize_env(info_optimized_multi_nontom, model=model_agent0, save_as_gif=False, gif_filename="collisionavoidance_optimized_multi_nontom.gif")

In [None]:
visualize_plan_tree(info_optimized_multi_nontom, time_idx=0, agent_idx=0, model=model_agent0, min_prob=0.0, depth=4, fig_size=(5, 5))

In [None]:
visualize_beliefs(info_optimized_multi_nontom, model=model_agent0)

# Single ToM Agent with Multiple non-ToM Agents

Initialize the agents

In [None]:
import equinox as eqx

agent0 = CollisionAvoidanceAgent(model_agent0, gamma=gamma, batch_size=1)
focal_agent = ToMify(agent0,
                     self_states=[0],
                     world_states=[],
                     # observation_mappings has shape (focal_batch, num_agents, num_modalities): for each
                     # observation modality of each modeled agent, it gives the index of the received
                     # observation modality to use (-1 means none). The environment's observations include
                     # the other agent's location, so the 2 received modalities map to [0, 1] for agent 0
                     # (focal) and to [1, none] for agent 1 (other)
                     observation_mappings=jnp.array([[[0, 1], [1, -1]]]),
                     # state_mappings has shape (focal_batch, num_other_agents, num_state_factors): for each
                     # other agent, entry j gives the index of that agent's state factor to which focal state
                     # factor j maps (-1 means no mapping). Here focal factor 0 has no mapping, while focal
                     # factor 1 maps to the other agent's factor 0
                     state_mappings=[[[-1, 0]]])

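# Set the preferences (C) the focal agent attributes to each modeled agent explicitly; otherwise ToMify
# copies the focal agent's own model, including its own preferences, for the other agent. This reuses
# the batched `agents` from the previous section (its C stacks agent 0's and agent 1's preferences);
# c[None, ...] adds a leading axis of size 1 for the focal batch
focal_agent = eqx.tree_at(lambda x: x.agent_models.C, focal_agent, [c[None, ...] for c in agents.C])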

other_agents = CollisionAvoidanceAgent(model_agent1, gamma=gamma, batch_size=1)
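The mapping arguments are easier to read with the arrays laid out. The sketch below reproduces the notebook's `observation_mappings` and `state_mappings` with made-up observation values and shows one way to interpret them, together with the extra leading axis that `c[None, ...]` adds to the stacked preferences. The routing function and the toy preference array are illustrative assumptions; `ToMify` may implement the lookup differently.

```python
import numpy as np

# Observations as received by the focal agent: modality 0 is its own location,
# modality 1 is the other agent's location (values are made up).
obs = np.array([1, 5])

# Same layout as in the notebook: (focal_batch, num_agents, num_modalities); -1 = no source.
observation_mappings = np.array([[[0, 1], [1, -1]]])


def route_observations(obs, mappings):
    """For each modeled agent, look up the received modality each of its modalities maps to.

    Hypothetical helper: invalid entries (-1) become -1, and a boolean mask marks valid ones.
    """
    valid = mappings >= 0
    gathered = np.where(valid, obs[np.clip(mappings, 0, None)], -1)
    return gathered, valid


routed, valid = route_observations(obs, observation_mappings)
print(routed[0].tolist())  # [[1, 5], [5, -1]]: focal model, then modeled other agent

# state_mappings: entry j gives the other agent's factor that focal factor j maps to.
state_mappings = np.array([[[-1, 0]]])
pairs = [(j, int(k)) for j, k in enumerate(state_mappings[0, 0]) if k >= 0]
print(pairs)  # [(1, 0)]: focal belief about the other's location <-> the other's own location

# Stacked preferences for 2 agents (toy shape: one 9-way location modality), plus a focal-batch axis.
C_stacked = np.zeros((2, 9))
print(C_stacked[None, ...].shape)  # (1, 2, 9)
```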

Next, we run the agents: the focal agent plans with sophisticated inference with theory of mind capabilities (`si_policy_search_tom`), the other agent plans with the non-ToM search, and the custom rollout function lets ToM and non-ToM agents interact within the same environment.
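Why theory of mind helps here can be seen in a one-step caricature: the focal agent predicts where the other agent is likely to move from the other agent's goal (which is why the other agent's preferences are set explicitly above), and then penalizes its own moves by the predicted collision probability. This is a hypothetical greedy toy with Manhattan-distance costs, not `si_policy_search_tom`, which searches a multi-step tree using expected free energy. Row-major indexing and the five-move action set are assumptions; both agents start facing each other on the middle row.

```python
import numpy as np

GRID_SIZE = 3
# ASSUMED action set and row-major location indexing (illustration only).
ACTIONS = ["stay", "up", "down", "left", "right"]
MOVES = {"stay": (0, 0), "up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}


def move(loc, action):
    """Apply a move, clipping at the walls."""
    row, col = divmod(loc, GRID_SIZE)
    d_row, d_col = MOVES[action]
    row = min(max(row + d_row, 0), GRID_SIZE - 1)
    col = min(max(col + d_col, 0), GRID_SIZE - 1)
    return row * GRID_SIZE + col


def distance(a, b):
    """Manhattan distance between two flat location indices."""
    (r1, c1), (r2, c2) = divmod(a, GRID_SIZE), divmod(b, GRID_SIZE)
    return abs(r1 - r2) + abs(c1 - c2)


def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()


def predict_other(other_loc, other_goal, gamma=1.0):
    """Distribution over the other agent's next location, assuming it softly approaches its goal."""
    G = np.array([distance(move(other_loc, a), other_goal) for a in ACTIONS], dtype=float)
    q = softmax(-gamma * G)
    p_next = np.zeros(GRID_SIZE * GRID_SIZE)
    for a, prob in zip(ACTIONS, q):
        p_next[move(other_loc, a)] += prob
    return p_next


def choose_action(my_loc, my_goal, p_other_next, collision_penalty=10.0):
    """Pick the action minimizing distance-to-goal plus expected collision cost."""
    costs = [distance(move(my_loc, a), my_goal) + collision_penalty * p_other_next[move(my_loc, a)]
             for a in ACTIONS]
    return ACTIONS[int(np.argmin(costs))]


# Focal agent at 3 heading for 5; other agent at 5 heading for 3.
p_other = predict_other(other_loc=5, other_goal=3)
print(choose_action(3, 5, p_other))                          # stay: avoids the likely clash at 4
print(choose_action(3, 5, np.zeros(GRID_SIZE * GRID_SIZE)))  # right: ignoring the other agent
```

Without the prediction, the focal agent steps into the center cell that the other agent most likely enters; with it, the focal agent waits one step.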

In [None]:
max_nodes = 10000

tree_search_tom = si_policy_search_tom(
            horizon=horizon,
            max_nodes=max_nodes,
            max_branching=max_branching,
            policy_prune_threshold=policy_prune_threshold,
            observation_prune_threshold=observation_prune_threshold,
            entropy_stop_threshold=entropy_stop_threshold,
            efe_stop_threshold=efe_stop_threshold,
            kl_threshold=kl_threshold,
            prune_penalty=prune_penalty,
            gamma=gamma,
        )

In [None]:
key = jr.PRNGKey(1)
T=3
last, info_tom, env = rollout_tom(focal_agent,
            other_agents,
            env,
            T,
            key,
            other_agent_policy_search=tree_search_nontom,
            focal_agent_tom_policy_search=tree_search_tom,
)

In [None]:
visualize_env(info_tom, model=model_agent0, save_as_gif=False, gif_filename="collisionavoidance_tom.gif")

In [None]:
visualize_plan_tree(info_tom, time_idx=0, agent_idx=0, model=model_agent0, min_prob=0.0, depth=4, fig_size=(5, 5))

In [None]:
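# helper function to print the focal agent's beliefs at a node of its planning tree:
# its beliefs about both agents' locations, and its model of the other agent's beliefs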

tree = jtu.tree_map(lambda x: x[0,0], info_tom["tree"])

def print_qs(tree, node_idx):
    print("focal agent:")
    print("  location focal: \n", jnp.round(tree[node_idx]["qs"][0][0, 0, 0], 2))
    print("  location other: \n", jnp.round(tree[node_idx]["qs"][1][0, 0, 0], 2))
    print("other agent:")
    print("  location other (self): \n", jnp.round(tree[node_idx]["qs"][0][0, 1, 0], 2))
    print("  location focal: \n", jnp.round(tree[node_idx]["qs"][1][0, 1, 0], 2))