# Testing ToM planning algorithms with the Englishmen Scenario in the Overcooked (v1) task

Karl described a scenario where two polite Englishmen approaching a door at the same time would each defer to the other, resulting in neither actually going through. It's a coordination failure caused by mutual courtesy.

We're generalising this to the Overcooked environment where two agents go for the same resource (an onion) at a central location. There is an environmental constraint that agents can't occupy the same location. By varying the agents' ToM capabilities, we test three conditions:

2 non-ToM Agents: Both agents select the action to move to the centre location to access the onion pile, resulting in a deadlock and task failure.

1 non-ToM + 1 ToM Agent: The ToM agent predicts the non-ToM agent will select the action to move to the centre, so it selects the action to stay instead, resulting in task completion.

2 ToM Agents: Both agents predict the other will select the action to move to the centre, so they both select the action to stay, mirroring the Englishmen scenario and resulting in neither agent moving and task failure.

Note: This notebook uses the clipped model, which gives a simpler scenario to experiment with.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax.random as jr
import jax.tree_util as jtu
import jax.numpy as jnp

# from tom.models import OvercookedModel, OvercookedAgent
from tom.models.model_ocv1_clipped import OvercookedModel, OvercookedAgent

from jaxmarl.environments.overcooked import overcooked_layouts, layout_grid_to_dict
from tom.envs.env_ocv1_clipped import OvercookedV1Env

from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer

from tom.planning.si import si_policy_search as si_policy_search_nontom
from tom.planning.rollout import rollout as rollout_nontom

from tom.planning.si_tom import si_policy_search_tom, ToMify
from tom.planning.rollout_2tom import rollout as rollout_tom

from tom.planning.visualize import visualize_plan_tree


Initialise the environment

- W = wall
- O = onion pile
- P = pot
- A = agent
- B = plate pile
- X = delivery station
- a blank space = empty cell

In [None]:
# ASCII map of the kitchen, one character per grid cell (see the legend above).
# The two agents (A) start on either side of the single empty cell directly
# below the onion pile (O), which is the contested centre location.
custom_layout_grid = """
WWOWW
WA AW
P   B
WWXWW
"""
layout = layout_grid_to_dict(custom_layout_grid)

# # or if you want to use a pre-set layout
# layout = overcooked_layouts["cramped_room"] # options: cramped_room, asymm_advantages, coord_ring, forced_coord, counter_circuit

# Number of agents (matches the two A cells in the grid) and the number of
# timesteps passed to the environment and to every rollout below.
num_agents = 2
timesteps = 3

In [None]:
# Build the clipped Overcooked v1 environment for the custom layout.
# NOTE(review): initiate_facingdir=[2, 3] presumably gives one starting facing
# direction per agent — confirm the direction encoding in OvercookedV1Env.
env = OvercookedV1Env(
    num_agents, layout, timesteps, 
    initiate_inventory=None, initiate_facingdir=[2, 3]
    )
# env = OvercookedV1Env(num_agents, layout, timesteps, initiate_inventory=["onion", "empty"])

In [None]:
# Reset the environment with a fixed PRNG key; returns the initial observation
# and environment state.
key = jr.PRNGKey(1)
obs, state = env.reset(key)

Initialise the pymdp agents' generative model

In [None]:
# Generative model built from the same layout dict the environment uses; every
# agent in the three conditions below is constructed from this one model.
model = OvercookedModel(env_layout=layout)

Initialise the non-ToM and ToM planning algorithms

In [None]:
# Planning hyperparameters passed to the non-ToM and ToM tree searches below.
horizon = 2
max_nodes = 20000
# Branch over every available action. The factor is looked up by its first key,
# the same way print_actions does, instead of by the integer 0, so this cell
# does not depend on model.B also accepting positional indices.
max_branching = len(model.B[list(model.B.keys())[0]].batch["actions"])
policy_prune_threshold = 0.0
observation_prune_threshold = 0.0
entropy_stop_threshold = 0.5
efe_stop_threshold = 10
kl_threshold = -1
prune_penalty = 512
gamma = 8.0
# Only passed to the non-ToM search.
topk_obsspace = 1

In [None]:
# set up the policy search function
# Settings common to both planners, gathered once so the two policy searches
# are always configured with the same values.
_shared_search_kwargs = dict(
    horizon=horizon,
    max_nodes=max_nodes,
    max_branching=max_branching,
    policy_prune_threshold=policy_prune_threshold,
    observation_prune_threshold=observation_prune_threshold,
    entropy_stop_threshold=entropy_stop_threshold,
    efe_stop_threshold=efe_stop_threshold,
    kl_threshold=kl_threshold,
    prune_penalty=prune_penalty,
    gamma=gamma,
)

# Non-ToM planner: the shared settings plus topk_obsspace.
tree_search_nontom = si_policy_search_nontom(
    **_shared_search_kwargs,
    topk_obsspace=topk_obsspace,
)

# ToM planner: the shared settings, with the non-ToM search handed over as
# other_agent_policy_search.
tree_search_tom = si_policy_search_tom(
    **_shared_search_kwargs,
    other_agent_policy_search=tree_search_nontom,
)

Helper functions to print the agents' actions and to convert the recorded environment state sequence for visualisation


In [None]:
def print_actions(actions):
    """Print the name of the action each agent took at every timestep.

    Action names come from the ``'actions'`` batch labels of the first factor
    in the notebook-level ``model.B``.

    Args:
        actions: Iterable with one array of action indices per agent. Each
            array is indexed by timestep on its first axis; when a timestep's
            entry is itself an array, only its first element is printed.
    """
    first_factor = list(model.B.keys())[0]
    names = model.B[first_factor].batch['actions']

    for agent_id, per_step in enumerate(actions):
        print(f"AGENT {agent_id} ACTIONS:")

        for step in range(per_step.shape[0]):
            entry = per_step[step]
            # Entries may carry an extra axis; use its first element then.
            if entry.ndim > 0:
                entry = entry[0]
            print(f"  time {step}: {names[entry]}")
        print()

def convert_State_sequence(info_State):
    """Split a time-stacked state pytree into a list of per-timestep states.

    Args:
        info_State: Pytree with a ``time`` attribute whose first axis length
            gives the number of timesteps. One-dimensional leaves are indexed
            by timestep on axis 0; all other leaves are indexed by timestep on
            axis 1, keeping axis 0 and any trailing axes.

    Returns:
        A list with one pytree per timestep, each with the same structure as
        ``info_State``.
    """
    def _state_at(step):
        def _pick(leaf):
            if leaf.ndim == 1:
                return leaf[step]
            return leaf[:, step, ...]

        return jtu.tree_map(_pick, info_State)

    return [_state_at(step) for step in range(info_State.time.shape[0])]

## 2 non-ToM agents

In [None]:
# One batched agent object covering both non-ToM agents (batch_size = num_agents).
agents_2nontoms = OvercookedAgent(model, batch_size = num_agents)

In [None]:
# Same seed as the other conditions.
key = jr.PRNGKey(1)

# Roll out both non-ToM agents for `timesteps` steps, planning with the non-ToM
# tree search; info_nontom holds the data the cells below print and plot.
last, info_nontom, env_final = rollout_nontom(
    agents_2nontoms, env, timesteps, key, 
    policy_search=tree_search_nontom
    )

In [None]:
# Actions chosen by each agent at every timestep of the non-ToM rollout.
print_actions(info_nontom["action"])

In [None]:
# Unstack the recorded environment states into one state per timestep and
# write them out as a GIF at 1 frame per second.
state_seq = convert_State_sequence(info_nontom["env_state"])
viz = OvercookedVisualizer()
viz.animate(state_seq, agent_view_size=5, filename='ovc1_englishmen_nontom.gif', fps=1.0)

In [None]:
# Show the observations recorded during the non-ToM rollout as the cell output.
info_nontom["observation"]

In [None]:
# Plan tree for agent_idx=0 at time_idx=0, rooted at node index 0 (non-ToM rollout).
root_idx = 0
visualize_plan_tree(info_nontom, time_idx=0, agent_idx=0, model=model, min_prob=0.1, depth=4, fig_size = (8,10), root_idx=root_idx)

In [None]:
# Plan tree for agent_idx=1 at time_idx=0, rooted at node index 0 (non-ToM rollout).
root_idx = 0
visualize_plan_tree(info_nontom, time_idx=0, agent_idx=1, model=model, min_prob=0.1, depth=2, fig_size = (8,5), root_idx=root_idx)

## 1 non-ToM agent, 1 ToM agent

In [None]:
# Single (batch_size = 1) base agent that ToMify wraps below.
# NOTE(review): the same assignment is repeated in the ToMify cell further down.
agent_1nontom = OvercookedAgent(model, batch_size = 1)

In [None]:
# Observation mapping passed to ToMify as observation_mappings; shape (1, 2, 5).
# NOTE(review): the two identical rows presumably map the focal agent's
# observation indices onto each modelled agent — confirm against ToMify.
obs_mapping = jnp.array([[[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]])# TODO add the remaining self state factors obs for the other agent: facing location and carrying state

In [None]:
# State mapping passed to ToMify as state_mappings.
# NOTE(review): -1 presumably marks an unmapped entry and the final 0 maps that
# entry to factor 0 — confirm the convention against ToMify.
state_mapping = [[[-1, -1, -1, -1, 0]]]
# state_mapping = None

In [None]:
# Base agent that is wrapped into the ToM (focal) agent.
agent_1nontom = OvercookedAgent(model, batch_size = 1)

# Wrap the base agent so it can plan with a model of the other agent.
# State factors 0-2 are declared as self states and no world states are given.
focal_agent = ToMify(
    agent_1nontom,
    self_states=[0, 1, 2],
    # world_states=list(range(3,len(model.B.keys()))),
    world_states=[],
    observation_mappings=obs_mapping,
    state_mappings=state_mapping,
    batch_size=1
)

# The second agent stays a plain non-ToM agent.
other_agent = OvercookedAgent(model, batch_size = 1)

In [None]:
# Same seed as the other conditions.
key = jr.PRNGKey(1)

# Roll out one ToM (focal) agent against one non-ToM agent.
# The returned environment is bound to env_final, as in the other rollout
# cells. Rebinding `env` here would make the later 2-ToM rollout start from
# the environment object returned by this rollout rather than the one created
# at the top of the notebook.
last, info_1tom, env_final = rollout_tom(
    focal_agent,
    other_agent,
    env,
    timesteps,
    key,
    other_agent_policy_search=tree_search_nontom,
    focal_agent_tom_policy_search=tree_search_tom,
)

In [None]:
# Unstack the recorded environment states into one state per timestep and
# write them out as a GIF at 1 frame per second.
state_seq = convert_State_sequence(info_1tom["env_state"])
viz = OvercookedVisualizer()
viz.animate(state_seq, agent_view_size=5, filename='ovc1_englishmen_1tom.gif', fps=1.0)

In [None]:
# Actions chosen by each agent at every timestep of the 1-ToM rollout.
print_actions(info_1tom["action"])

In [None]:
# Plan tree for agent_idx=0 at time_idx=0, rooted at node index 0 (1-ToM rollout).
root_idx = 0
visualize_plan_tree(info_1tom, time_idx=0, agent_idx=0, model=model, min_prob=0.0, depth=2, fig_size = (8,5), root_idx=root_idx)

In [None]:
# Same tree for agent_idx=0, re-rooted at node index 4 and drawn four levels deep.
root_idx = 4
visualize_plan_tree(info_1tom, time_idx=0, agent_idx=0, model=model, min_prob=0.0, depth=4, fig_size = (8,10), root_idx=root_idx)

In [None]:
# Same tree for agent_idx=0, re-rooted at node index 99.
# NOTE(review): 99 looks like a node picked while debugging — confirm it exists
# in the tree when the planning settings change.
root_idx = 99
visualize_plan_tree(info_1tom, time_idx=0, agent_idx=0, model=model, min_prob=0.0, depth=4, fig_size = (8,10), root_idx=root_idx)

In [None]:
# Plan tree for agent_idx=1 at time_idx=0, rooted at node index 0 (1-ToM rollout).
root_idx = 0
visualize_plan_tree(info_1tom, time_idx=0, agent_idx=1, model=model, min_prob=0.0, depth=2, fig_size = (8,5), root_idx=root_idx)

## 2 ToM agents

In [None]:
# Indices of every observation modality in the model, used as an identity row.
modality_ids = list(range(len(model.A.keys())))
# Two identical rows, then stacked twice along a new leading axis:
# final shape (2, 2, number of modalities).
obs_mapping = jnp.repeat(
    jnp.array([modality_ids, modality_ids])[None, ...],
    2,
    axis=0,
)

In [None]:
# State mapping for the 2-ToM case: the single-agent mapping listed twice,
# once per focal agent.
state_mapping = [[[-1, -1, -1, -1, 0]], [[-1, -1, -1, -1, 0]]]
# state_mapping = jnp.repeat(state_mapping[None, ...], 2, axis=0)


In [None]:
# Wrap the base agent as a batch of num_agents ToM agents.
# NOTE(review): agent_1nontom was created with batch_size = 1 while ToMify is
# asked for batch_size=num_agents — confirm ToMify re-batches the wrapped agent.
focal_agents = ToMify(
    agent_1nontom,
    self_states=[0, 1, 2],
    # world_states=list(range(3,len(model.B.keys()))),
    world_states=[],
    observation_mappings=obs_mapping,
    state_mappings=state_mapping,
    batch_size=num_agents
)


In [None]:
# Same seed as the other conditions.
key = jr.PRNGKey(1)
# Roll out the batched ToM agents; other_agents=None because no separate
# non-ToM agent takes part in this condition.
last, info_2toms, env_final = rollout_tom(
    focal_agents, 
    other_agents=None,
    env=env, 
    num_timesteps=timesteps, 
    rng_key=key, 
    focal_agent_tom_policy_search=tree_search_tom,
    other_agent_policy_search=tree_search_nontom)


In [None]:
# Unstack the recorded environment states into one state per timestep and
# write them out as a GIF at 1 frame per second.
state_seq = convert_State_sequence(info_2toms["env_state"])
viz = OvercookedVisualizer()
viz.animate(state_seq, agent_view_size=5, filename='ovc1_englishmen_2toms.gif', fps=1.0)

In [None]:
# Actions chosen by each agent at every timestep of the 2-ToM rollout.
print_actions(info_2toms["action"])

In [None]:
# Plan tree for agent_idx=0 at time_idx=0, rooted at node index 0 (2-ToM rollout).
root_idx = 0
visualize_plan_tree(info_2toms, time_idx=0, agent_idx=0, model=model, min_prob=0.0, depth=2, fig_size = (8,5), root_idx=root_idx)

In [None]:
# Same tree for agent_idx=0, re-rooted at node index 4.
root_idx = 4
visualize_plan_tree(info_2toms, time_idx=0, agent_idx=0, model=model, min_prob=0.0, depth=2, fig_size = (8,5), root_idx=root_idx)

In [None]:
# Plan tree for agent_idx=1 at time_idx=0, rooted at node index 0 (2-ToM rollout).
root_idx = 0
visualize_plan_tree(info_2toms, time_idx=0, agent_idx=1, model=model, min_prob=0.0, depth=2, fig_size = (8,5), root_idx=root_idx)

In [None]:
# Same tree for agent_idx=1, re-rooted at node index 4.
root_idx = 4
visualize_plan_tree(info_2toms, time_idx=0, agent_idx=1, model=model, min_prob=0.0, depth=2, fig_size = (8,5), root_idx=root_idx)

# DEBUG ZONE

In [None]:
import jax.numpy as jnp

def print_beliefs(model, qs, belief_idx):
    """Print each agent's most likely value of one state factor over time.

    Args:
        model: Object whose ``B`` mapping is keyed by state-factor name; each
            entry exposes a ``batch`` mapping that holds the factor's labels
            under the factor's own name.
        qs: Sequence of belief arrays, one per state factor. The entry at
            ``belief_idx`` is iterated over agents on its first axis and read
            as ``[t, 0, state]`` within each agent — assumed layout
            (agent, time, 1, state); confirm against the rollout output.
        belief_idx: Position of the state factor in ``model.B`` and ``qs``.
    """
    state_factor_name = list(model.B.keys())[belief_idx]
    labels = model.B[state_factor_name].batch[state_factor_name]

    for agent_idx, agent_beliefs in enumerate(qs[belief_idx]):
        print(f"AGENT {agent_idx} {state_factor_name} BELIEFS:")

        for t in range(agent_beliefs.shape[0]):
            # Take the argmax over the same slice the probability is read
            # from. An argmax over all of agent_beliefs[t] yields a flattened
            # index, which is only a valid state index when the middle axis
            # has length 1.
            belief_t = agent_beliefs[t, 0]
            most_likely_idx = int(jnp.argmax(belief_t))
            probability = float(belief_t[most_likely_idx])
            print(f"  time {t}: {labels[most_likely_idx]} (prob={probability:.3f})")
        print()




In [None]:
# NOTE(review): `info_tom` is not defined anywhere in this notebook; use one of
# the rollout results (info_nontom, info_1tom, info_2toms) when re-enabling.
# print_beliefs(model, info_tom["qs"], belief_idx=0)
