In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes, so edits to module source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax.random as jr
import jax.tree_util as jtu
import jax.numpy as jnp

from tom.models import OvercookedModel, OvercookedAgent
from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer

from jaxmarl.environments.overcooked import overcooked_layouts, layout_grid_to_dict
from tom.envs.env_ocv1 import OvercookedV1Env

from tom.planning.rollout_tom import rollout as rollout_tom
from tom.planning.si_tom import si_policy_search_tom, ToMify
from tom.planning.si import si_policy_search as si_policy_search_nontom

# Overcooked_v1 with 1 ToM agent and 1 non-ToM agent

Initialise the environment

- W = wall
- O = onion pile
- P = pot
- A = agent
- B = plate pile
- X = delivery station
- (space) = empty cell

In [None]:
# Text map of the kitchen, using the legend from the markdown cell above.
# Space characters inside a row are meaningful (they are empty cells), so the
# rows must be kept exactly as written.
custom_layout_grid = """
WWPWW
OA  W
W  AB
WWXWW
"""
# Convert the text grid into the layout object that is handed to both the
# environment and the model below.
layout = layout_grid_to_dict(custom_layout_grid)

# Alternatively, use one of the pre-set layouts instead of the custom grid:
# layout = overcooked_layouts["cramped_room"] # options: cramped_room, asymm_advantages, coord_ring, forced_coord, counter_circuit

# Passed to the environment constructor.
num_agents = 2
# Passed to the environment constructor and to `rollout_tom`.
timesteps = 3

In [None]:
# Build the environment from the agent count, layout and timestep count
# defined in the previous cell, with no initial inventory specified.
env = OvercookedV1Env(num_agents, layout, timesteps, initiate_inventory=None)
# Alternative with an explicit starting inventory.
# NOTE(review): presumably one entry per agent - confirm against OvercookedV1Env.
# env = OvercookedV1Env(num_agents, layout, timesteps, initiate_inventory=["onion", "empty"])

Initialise pymdp agents

In [None]:
# Build the agents' model from the same layout that was given to the environment.
model = OvercookedModel(env_layout=layout)

In [None]:
# Base agent that is wrapped by ToMify below.
agent0 = OvercookedAgent(model, batch_size = 1)
# The index list 0..len(model.A)-1, repeated twice and nested inside one
# extra leading list level.
# NOTE(review): presumably one identity observation mapping per agent with a
# leading batch axis - confirm against ToMify's `observation_mappings`.
obs_mapping = jnp.array([[list(range(len(model.A.keys()))), list(range(len(model.A.keys())))]])

# Wrap agent0 for theory-of-mind planning: indices 0-2 of model.B are passed
# as `self_states`, and every remaining index (3 .. len(model.B)-1) as
# `world_states`.
focal_agent = ToMify(
    agent0,
    self_states=[0, 1, 2],
    world_states=list(range(3,len(model.B.keys()))),
    observation_mappings=obs_mapping,
    batch_size=1
)

# Second agent: a plain OvercookedAgent on the same model, not wrapped by ToMify.
other_agent = OvercookedAgent(model, batch_size = 1)

In [None]:
# Display the ToMify-wrapped agent as the cell output for inspection.
focal_agent

Do the rollout using sophisticated active inference (1 ToM agent and 1 non-ToM agent)

In [None]:
# Settings for the non-ToM policy search. Each value is forwarded unchanged,
# under the same name, to `si_policy_search_nontom` below.
# NOTE: the next cell reassigns these same notebook-level names (with
# horizon=5), so the values that are live afterwards depend on which of the
# two cells ran last.
horizon=3
max_nodes = 20000
# Number of entries in the "actions" list attached to model.B[0].
max_branching = len(model.B[0].batch["actions"])
policy_prune_threshold = 0.0
observation_prune_threshold = 1/64
entropy_stop_threshold = 0.5
efe_stop_threshold = 1000
kl_threshold=-1
prune_penalty = 512
gamma = 8.0
topk_obsspace = 1

# Set up the policy search function that is handed to `rollout_tom` as
# `other_agent_policy_search` and to the ToM search in the next cell.
tree_search_nontom = si_policy_search_nontom(
        horizon=horizon,
        max_nodes=max_nodes,
        max_branching=max_branching,
        policy_prune_threshold=policy_prune_threshold,
        observation_prune_threshold=observation_prune_threshold,
        entropy_stop_threshold=entropy_stop_threshold,
        efe_stop_threshold=efe_stop_threshold,
        kl_threshold=kl_threshold,
        prune_penalty=prune_penalty,
        gamma=gamma,
        topk_obsspace=topk_obsspace
    )

In [None]:
# Settings for the ToM policy search. These reassign the notebook-level names
# set in the previous cell; only `horizon` takes a different value (5 instead
# of 3).
horizon=5
max_nodes = 20000
# Number of entries in the "actions" list attached to model.B[0].
max_branching = len(model.B[0].batch["actions"])
policy_prune_threshold = 0.0
observation_prune_threshold = 1/64
entropy_stop_threshold = 0.5
efe_stop_threshold = 1000
kl_threshold=-1
prune_penalty = 512
gamma = 8.0
# NOTE(review): assigned here but not passed to `si_policy_search_tom` below,
# unlike the non-ToM search - confirm whether that is intentional.
topk_obsspace = 1

# Set up the ToM policy search function. It is given the non-ToM search built
# in the previous cell as `other_agent_policy_search`.
tree_search_tom = si_policy_search_tom(
            horizon=horizon,
            max_nodes=max_nodes,
            max_branching=max_branching,
            policy_prune_threshold=policy_prune_threshold,
            observation_prune_threshold=observation_prune_threshold,
            entropy_stop_threshold=entropy_stop_threshold,
            efe_stop_threshold=efe_stop_threshold,
            kl_threshold=kl_threshold,
            prune_penalty=prune_penalty,
            gamma=gamma,
            other_agent_policy_search=tree_search_nontom
        )

In [None]:
# PRNG key with a fixed seed, handed to the rollout.
key = jr.PRNGKey(1)
# Run the rollout with the ToM focal agent and the plain other agent, each
# with its own policy search. `info_tom` is read by later cells through its
# "env_state", "qs" and "action" entries.
# NOTE: `env` is rebound to the environment returned by the rollout, so
# re-running this cell passes that returned environment back in rather than
# the one constructed earlier; re-create `env` first if a fresh start is
# intended.
last, info_tom, env = rollout_tom(
    focal_agent,
    other_agent,
    env,
    timesteps,
    key,
    other_agent_policy_search=tree_search_nontom,
    focal_agent_tom_policy_search=tree_search_tom,
)

Adding visualisations and printing information

In [None]:
def convert_State_sequence(info_State):
    """Split a time-stacked state pytree into a list with one state per timestep.

    The number of timesteps is taken from the first axis of ``info_State.time``.
    For every timestep ``t`` each leaf of the pytree is sliced as follows:

    * 1-D leaves are indexed directly with ``leaf[t]``;
    * all other leaves are indexed on their second axis, ``leaf[:, t, ...]``,
      which keeps the leading axis.

    Args:
        info_State: Pytree with a ``time`` attribute whose leaves are arrays.
            Assumes every non-1-D leaf carries time on axis 1 - TODO confirm
            against the rollout that produces it.

    Returns:
        A list of pytrees with the same structure as ``info_State``, one per
        timestep, in chronological order.
    """

    def _state_at(step):
        # Select the slice belonging to `step` from every leaf of the pytree.
        def _select(leaf):
            if leaf.ndim == 1:
                return leaf[step]
            return leaf[:, step, ...]

        return jtu.tree_map(_select, info_State)

    total_steps = info_State.time.shape[0]
    return [_state_at(step) for step in range(total_steps)]

In [None]:
# Split the stacked environment states recorded by the rollout into a list
# with one state per timestep.
state_seq = convert_State_sequence(info_tom["env_state"])
viz = OvercookedVisualizer()
# Animate the state sequence; the output filename and frame rate are passed
# explicitly.
viz.animate(state_seq, agent_view_size=5, filename='overcooked_v1_1tom_test.gif', fps=3.0)

In [None]:
import jax.numpy as jnp

def print_beliefs(model, qs, belief_idx):
    """Print each agent's most likely value of one state factor at every timestep.

    Args:
        model: Object whose ``B`` mapping is keyed by state-factor name. The
            entry for the selected factor must expose ``batch[<factor name>]``
            holding the labels of that factor's values.
        qs: Sequence of belief arrays, one per state factor. ``qs[belief_idx]``
            is iterated over agents, and each per-agent array is indexed as
            ``[time, batch, value]``.
        belief_idx: Position of the state factor, both in ``model.B``'s keys
            and in ``qs``.

    Only batch element 0 is reported. The most likely value is computed on that
    same batch-0 slice, so the printed label and probability always belong
    together, even if the batch axis holds more than one element.
    """
    state_factor_name = list(model.B.keys())[belief_idx]
    labels = model.B[state_factor_name].batch[state_factor_name]

    for agent_idx, agent_beliefs in enumerate(qs[belief_idx]):
        print(f"AGENT {agent_idx} {state_factor_name} BELIEFS:")

        for t in range(agent_beliefs.shape[0]):
            # Restrict to batch element 0 *before* the argmax. Taking the
            # argmax over the whole (batch, value) slice returns a flattened
            # index, which is only a valid value index when the batch axis
            # has size 1.
            belief_t = agent_beliefs[t, 0]
            most_likely_idx = int(jnp.argmax(belief_t))
            probability = float(belief_t[most_likely_idx])
            print(f"  time {t}: {labels[most_likely_idx]} (prob={probability:.3f})")
        print()

def print_actions(actions, model=None):
    """Print the name of the action each agent took at every timestep.

    Args:
        actions: Iterable with one array per agent. Each array is indexed by
            timestep; a timestep entry is either a scalar action index or an
            array whose first element is the action index.
        model: Object whose first ``B`` entry exposes ``batch['actions']``,
            the list of action names. Optional for backward compatibility:
            when omitted, the notebook-level ``model`` variable is used, as
            the original implementation always did.

    Raises:
        NameError: If ``model`` is omitted and no notebook-level ``model``
            variable exists.
    """
    if model is None:
        # Legacy call style `print_actions(actions)`: fall back to the
        # notebook-global model instead of depending on it implicitly.
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError(
                "name 'model' is not defined; pass it explicitly: "
                "print_actions(actions, model)"
            ) from None

    action_names = model.B[list(model.B.keys())[0]].batch['actions']

    for agent_idx, agent_actions in enumerate(actions):
        print(f"AGENT {agent_idx} ACTIONS:")

        for t in range(agent_actions.shape[0]):
            step_action = agent_actions[t]
            # A timestep entry may be a scalar or carry an extra leading axis;
            # in the latter case its first element is the action index.
            action_idx = step_action[0] if step_action.ndim > 0 else step_action
            print(f"  time {t}: {action_names[action_idx]}")
        print()


In [None]:
# Print both agents' beliefs about the first state factor (index 0) over the rollout.
print_beliefs(model,info_tom["qs"], belief_idx=0)


In [None]:
# Print the actions recorded for each agent during the rollout.
print_actions(info_tom["action"])

In [None]:
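# Optional: visualise the plan tree (uncomment to run).
# NOTE: `info_nontom` is not defined anywhere in this notebook; the rollout above
# returns `info_tom`, so substitute the appropriate info object before running.
# from tom.planning.visualize import visualize_plan_tree
# root_idx = None
# visualize_plan_tree(info_nontom, time_idx=0, agent_idx=1, model=model, min_prob=0.0, depth=4, fig_size = (8,10), root_idx=root_idx)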