# JaxMARL Example Usage for Overcooked v1 and v2

I'm running overcooked v1 and v2 from jaxmarl here with random actions and visualising the results. Examples for more JaxMARL environments can be found [here](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/tutorials).

* `actions`, `obs`, `rewards`, `dones` are dictionaries keyed by agent name; this allows for differing action and observation spaces. As agents can terminate asynchronously, `dones` contains a special `"__all__"` key which signifies whether an episode has terminated.
* `state` represents the internal state of the environment and contains all the information needed to transition the environment given a set of actions. These variables are not held within the environment class because JAX transformations require pure functions. The state contains the `grid` with shape (timesteps, height, width, 3 channels). The 3 channels are (1) static items like wall, agent, ingredient, button etc., (2) dynamic items like plate (plate goes from plate to things on it) or cooked (whether something is cooked or uncooked), and (3) extra info.
* `info` is a dictionary containing pertinent information, the exact content varies environment to environment.

This is what the symbols in the visualisation mean: 
* triangles - agents, with direction
* yellow circles - onions 
* pot - lid position indicates if it's actively cooking or not; green bars on pots show cooking progress
* white circles - plates 
* green rectangle - delivery area where completed dishes must be placed (ultimate goal)


## Overcooked v1


In [None]:
import jax 
from jaxmarl import make
from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer
from jaxmarl.environments.overcooked import overcooked_layouts, layout_grid_to_dict

# Episode length and PRNG setup
max_steps = 10
key = jax.random.PRNGKey(0)
key, key_r, key_a = jax.random.split(key, 3)

# Pick one of the built-in layouts. Classic options: cramped_room,
# asymm_advantages, coord_ring, forced_coord, counter_circuit.
layout = overcooked_layouts["cramped_room"]

# Alternatively, describe a custom layout as a text grid and convert it:
# custom_layout_grid = """
# WWOWW
# WA  W
# B P X
# W  AW
# WWOWW
# """
# layout = layout_grid_to_dict(custom_layout_grid)

# Build the environment and obtain the initial observation and state
env = make('overcooked', layout=layout, max_steps=max_steps)
obs, state = env.reset(key_r)
print('list of agents in environment', env.agents)

# Roll out `max_steps` random-action steps. Each state is recorded *before*
# the corresponding step is taken, so the state returned by the final step is
# not part of `state_seq`.
state_seq = []
for _ in range(max_steps):
    state_seq.append(state)

    # Fresh keys for this step: `key_s` drives the environment transition,
    # `key_a` is fanned out into one sampling key per agent.
    key, key_s, key_a = jax.random.split(key, 3)
    key_a = jax.random.split(key_a, env.num_agents)

    # One independently sampled random action per agent, keyed by agent name
    actions = {
        agent: env.action_space(agent).sample(agent_key)
        for agent, agent_key in zip(env.agents, key_a)
    }

    # Advance the environment by one step
    obs, state, rewards, dones, infos = env.step(key_s, state, actions)

# viz = OvercookedVisualizer()

# # Or save an animation
# viz.animate(state_seq, agent_view_size=5, filename='overcooked_v1.gif')

In [None]:
# Human-readable names for the channels of an Overcooked v1 observation.
# The position of a name in this list is the channel index it describes.
# NOTE(review): this list is maintained by hand and is not derived from the
# environment -- confirm it against the JaxMARL Overcooked observation docs.
observation_labels = [
    "Agent 0 Position",
    "Agent 1 Position",
    "Agent 0 Orientation North",
    "Agent 0 Orientation South",
    "Agent 0 Orientation East",
    "Agent 0 Orientation West",
    "Agent 1 Orientation North",
    "Agent 1 Orientation South",
    "Agent 1 Orientation East",
    "Agent 1 Orientation West",
    "Pot Locations",
    "Counter Locations",
    "Onion Pile Locations",
    "Tomato Pile Locations (not used in this environment)",
    "Plate Pile Locations",
    "Delivery Locations",
    "Onions in Pot (0-3)",
    "Tomatoes in Pot (not used in this environment)",
    "Onions in Soup (0 or 3)",
    "Tomatoes in Soup (not used in this environment)",
    "Pot Cooking Time Remaining (19-1)",
    "Soup Ready",
    "Plate Locations (Variable)",
    "Onion Locations (Variable)",
    "Tomato Locations (not used in this environment)",
    "Urgency (≤40 steps remaining)",
]

# Print agent 0's observation one channel at a time. `obs` is whatever the
# most recently executed rollout cell left behind; its last axis is treated
# as the channel axis.
agent_0_obs = obs['agent_0']
num_channels = agent_0_obs.shape[2]

# A hand-written label list can drift from the environment. Say so up front
# rather than crash with an IndexError or silently print misleading names.
if num_channels != len(observation_labels):
    print(
        f"Warning: observation has {num_channels} channels but "
        f"{len(observation_labels)} labels are defined; labels may be misaligned."
    )

for i in range(num_channels):
    label = observation_labels[i] if i < len(observation_labels) else "Unlabelled"
    print(f"\n{label} (Channel {i}):")
    print(agent_0_obs[:, :, i])


## Overcooked v2

[note: i dont see a button though?]

In [None]:
import jax 
from jaxmarl import make
from jaxmarl.viz.overcooked_v2_visualizer import OvercookedV2Visualizer
from jaxmarl.environments.overcooked_v2 import overcooked_v2_layouts
import time
import jax.tree_util as tree_util
import jax.numpy as jnp

# Episode length and PRNG setup
max_steps = 10
key = jax.random.PRNGKey(0)
key, key_r, key_a = jax.random.split(key, 3)

# Get one of the classic layouts
layout = overcooked_v2_layouts["cramped_room_v2"]

# Instantiate environment and obtain the initial observation and state
env = make('overcooked_v2', layout=layout, max_steps=max_steps)

obs, state = env.reset(key_r)
print('list of agents in environment', env.agents)

# Roll out `max_steps` random-action steps. Each state is recorded *before*
# the corresponding step is taken, so the state returned by the final step is
# not part of `state_seq`.
state_seq = []
# The counter is printed below, so it gets a real name: `_` is reserved by
# convention for unused values and is also IPython's "last output" variable.
for step in range(max_steps):
    state_seq.append(state)

    # Fresh keys for this step: `key_s` drives the environment transition,
    # `key_a` is fanned out into one sampling key per agent.
    key, key_s, key_a = jax.random.split(key, 3)
    key_a = jax.random.split(key_a, env.num_agents)

    # One independently sampled random action per agent, keyed by agent name
    actions = {agent: env.action_space(agent).sample(key_a[i]) for i, agent in enumerate(env.agents)}

    # Step environment
    obs, state, rewards, dones, infos = env.step(key_s, state, actions)

    # Dump both agents' observations for this step, one leading-axis slice
    # at a time.
    print(f"\n{step}th step")
    print(f"obs agent 0 shape: {obs['agent_0'].shape}")
    print(f"obs agent 1 shape: {obs['agent_1'].shape}")
    print("obs agent 0:")
    for obs_slice in obs['agent_0']:
        print(f"\n {obs_slice}")
    print("obs agent 1:")
    for obs_slice in obs['agent_1']:
        print(f"\n {obs_slice}")


viz2 = OvercookedV2Visualizer()

# The v2 visualiser needs the per-step states stacked leaf-wise into a single
# batched state (it uses JAX vmap), so combine matching leaves along a new
# leading axis.
stacked_state_seq = tree_util.tree_map(lambda *args: jnp.stack(args), *state_seq)


In [None]:
# Save an animation
# viz2.animate(stacked_state_seq, agent_view_size=5, filename='overcooked_v2.gif')

In [None]:
# Human-readable names for the channels of an Overcooked v2 observation.
# The position of a name in this list is the channel index it describes; the
# group comments give the index range each group occupies *in this list*.
# NOTE(review): this list is maintained by hand and is not derived from the
# environment -- confirm the channel count and order against the JaxMARL
# Overcooked v2 observation code for the chosen layout.
observation_labels = [
    # Agent layers (9 entries: 0-8)
    "Agent Position",
    "Agent Direction: UP",
    "Agent Direction: DOWN",
    "Agent Direction: RIGHT",
    "Agent Direction: LEFT",
    "Agent Inventory: Plate Bit",
    "Agent Inventory: Cooked Bit",
    "Agent Inventory: Ingredient 0",
    "Agent Inventory: Ingredient 1",

    # Other-agents layers (9 entries: 9-17)
    "Other Agents Position",
    "Other Agents Direction: UP",
    "Other Agents Direction: DOWN",
    "Other Agents Direction: RIGHT",
    "Other Agents Direction: LEFT",
    "Other Agents Inventory: Plate Bit",
    "Other Agents Inventory: Cooked Bit",
    "Other Agents Inventory: Ingredient 0",
    "Other Agents Inventory: Ingredient 1",

    # Static object layers (6 entries: 18-23)
    "Static: Walls",
    "Static: Goals (Delivery Areas)",
    "Static: Pots",
    "Static: Recipe Indicators",
    "Static: Button Recipe Indicators",
    "Static: Plate Piles",

    # Ingredient pile layers (2 entries: 24-25)
    "Ingredient Pile: Ingredient 0",
    "Ingredient Pile: Ingredient 1",

    # Ingredients-on-grid layers (4 entries: 26-29)
    "Grid Items: Plate Bit",
    "Grid Items: Cooked Bit",
    "Grid Items: Ingredient 0",
    "Grid Items: Ingredient 1",

    # Recipe layers (3 entries: 30-32)
    "Recipe: Plate Bit",
    "Recipe: Cooked Bit",
    "Recipe Type",

    # Extra layers (2 entries: 33-34)
    "Pot Cooking Timers",
    "Successful Delivery Indicator",
]

# Print agent 0's observation one channel at a time. `obs` is whatever the
# most recently executed rollout cell left behind; its last axis is treated
# as the channel axis.
agent_0_obs = obs['agent_0']
num_channels = agent_0_obs.shape[2]

# A hand-written label list can drift from the environment. Say so up front
# rather than crash with an IndexError or silently print misleading names.
if num_channels != len(observation_labels):
    print(
        f"Warning: observation has {num_channels} channels but "
        f"{len(observation_labels)} labels are defined; labels may be misaligned."
    )

for i in range(num_channels):
    label = observation_labels[i] if i < len(observation_labels) else "Unlabelled"
    print(f"\n{label} (Channel {i}):")
    print(agent_0_obs[:, :, i])
