Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
!pip install meltingpot

In [None]:
#@title Environment configuration & build function
"""Configuration for a Coordination in the matrix (grouped) environment."""

import copy
import random
import re
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from meltingpot.utils.substrates import builder
from meltingpot.utils.substrates import colors
from meltingpot.utils.substrates import game_object_utils
from meltingpot.utils.substrates import shapes
from ml_collections import config_dict

# Alias of the prefab-config type from `game_object_utils`; used as the return
# annotation of `create_prefabs`.
PrefabConfig = game_object_utils.PrefabConfig

# The number of resources must match the (square) size of the matrix.
NUM_RESOURCES = 3

# Every resource has a base RGBA color and a brighter highlight RGBA color.
# The `*_COLOR_DATA` pairs are (base, highlight); `create_resource_prefab`
# uses element 0 for the "*" palette entry and element 1 for the "#" entry.

# This color is red.
RESOURCE1_COLOR = (150, 0, 0, 255)
RESOURCE1_HIGHLIGHT_COLOR = (200, 0, 0, 255)
RESOURCE1_COLOR_DATA = (RESOURCE1_COLOR, RESOURCE1_HIGHLIGHT_COLOR)
# This color is green.
RESOURCE2_COLOR = (0, 150, 0, 255)
RESOURCE2_HIGHLIGHT_COLOR = (0, 200, 0, 255)
RESOURCE2_COLOR_DATA = (RESOURCE2_COLOR, RESOURCE2_HIGHLIGHT_COLOR)
# This color is blue.
RESOURCE3_COLOR = (0, 0, 150, 255)
RESOURCE3_HIGHLIGHT_COLOR = (0, 0, 200, 255)
RESOURCE3_COLOR_DATA = (RESOURCE3_COLOR, RESOURCE3_HIGHLIGHT_COLOR)

# Name of the layout to build.
MAP = "normal"  # Can also be "walled" and "celled"

# `ASCII_MAP` holds the layout string of the selected map. `NORMAL_MAP` records
# whether the "normal" layout is in use; `create_avatar_object` reads it to
# decide whether avatars spawn on the shared spawn points ('P') or on the
# team-specific ones ('R' / 'B') that only the other two layouts contain.
if MAP == "normal":
  ASCII_MAP = """
  WWWWWWWWWWWWWWWWWWWWW
  WPPPP  aa   aa  PPPPW
  WPPPP  aa   aa  PPPPW
  WP                 PW
  WP11 22 33 11 22 33PW
  WP11 22 33 11 22 33PW
  WP                 PW
  WP22 33 11 22 33 11PW
  WP22 33 11 22 33 11PW
  WP                 PW
  WP33 22 11 33 22 11PW
  WP33 22 11 33 22 11PW
  WP                 PW
  WPPPP  aa   aa  PPPPW
  WPPPP  aa   aa  PPPPW
  WWWWWWWWWWWWWWWWWWWWW
  """
  NORMAL_MAP = True

elif MAP == "walled":
  ASCII_MAP = """
  WWWWWWWWWWWWWWWWWWWWW
  WRRR   aa   aa   RRRW
  WRRR   aa   aa   RRRW
  WR11 22 33 11 22 33RW
  WR11 22 33 11 22 33RW
  WWWWWWWWWWWWWWWWWWWWW
  WB33 22 11 33 22 11BW
  WB33 22 11 33 22 11BW
  WBBB   aa   aa   BBBW
  WBBB   aa   aa   BBBW
  WWWWWWWWWWWWWWWWWWWWW
  """
  NORMAL_MAP = False

elif MAP == "celled":
  ASCII_MAP = """
    WWWWWWWWWWWWWWWWWWWWW
    W R  W  B W  R W BR W
    W a1 W 3a W 2a W a1 W
    W a2 W 2a W 3a W a3 W
    W a3 W 1a W 1a W a2 W
    W  B W R  W B  W BR W
    WWWWWWWWWWWWWWWWWWWWW
    W B  W  R W  B W R  W
    W 3a W 1a W a1 W a2 W
    W 2a W 2a W a3 W a3 W
    W 1a W 3a W a2 W a1 W
    W  R W B  W R  W  B W
    WWWWWWWWWWWWWWWWWWWWW
    """
  NORMAL_MAP = False

else:
  # Fail fast: without this branch an unknown name would leave ASCII_MAP and
  # NORMAL_MAP undefined and only surface later as a NameError.
  raise ValueError(
      f"Unknown MAP: {MAP!r}. Expected 'normal', 'walled' or 'celled'.")


# 8x8 ascii sprite that `create_avatar_objects` passes as the `shape` of each
# player's ready-to-interact marker. Characters are resolved through a palette
# such as `HEAD_BAND_PALETTE` below.
HEAD_BAND = """
xxxxxxxx
xxxxxxxx
xx####xx
x@xxxx@x
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
"""

# RGBA color for each character used in `HEAD_BAND`. "x" has alpha 0, i.e. it
# is fully transparent; "#" and "@" are light greys with a low alpha value.
HEAD_BAND_PALETTE = {
    "#": (204, 203, 200, 55),
    "@": (171, 170, 167, 50),
    "x": (0, 0, 0, 0),
}


# Custom avatars with individual 'faces'.
def alter_avatar(avatar: str, target: int) -> str:
  """Rewrite one row of an ascii avatar sprite to give it a distinct 'face'.

  The row that follows the `target`-th newline of `avatar` is rebuilt as
  "O" + the row's 2nd to 7th characters + "x"; every other row is kept as is.
  The slicing assumes sprite rows that are 8 characters wide.

  Args:
    avatar: string defining the avatar sprite, rows separated by newlines.
    target: index (into the newlines of `avatar`) of the row to alter.

  Returns:
    The altered avatar string.

  Raises:
    IndexError: if `avatar` has no newline at index `target` or `target + 1`.
  """
  newline_positions = [pos for pos, char in enumerate(avatar) if char == "\n"]
  row_break = newline_positions[target]
  next_row_break = newline_positions[target + 1]

  # O is dark grey
  # & is darker shade of the agent color
  # * is the agent color
  # # is white
  untouched_head = avatar[:row_break]
  kept_pixels = avatar[row_break + 2:row_break + 8]
  untouched_tail = avatar[next_row_break:]
  return "".join((untouched_head, "\nO", kept_pixels, "x", untouched_tail))


# Names of the resource prefabs; they match the keys that `create_prefabs`
# registers for the three resource classes.
_resource_names = [
    "resource_class1",
    "resource_class2",
    "resource_class3",
]

# `prefab` determines which prefab game object to use for each `char` in the
# ascii map.
CHAR_PREFAB_MAP = {
    # A randomly chosen resource from `_resource_names` will be placed at each
    # location in the ascii map there is an 'a'.
    "a": {"type": "choice", "list": _resource_names},
    "P": "spawn_point",
    "R": "red_spawn_point",
    "B": "blue_spawn_point",
    "W": "wall",
    "1": _resource_names[0],
    "2": _resource_names[1],
    "3": _resource_names[2],
}

# Only the length of this list is used in this file: it sets the maximum value
# of the "move" action in the Avatar component's `actionSpec`.
_COMPASS = ["N", "E", "S", "W"]

# Prefab for wall tiles; `create_prefabs` registers it under the name "wall",
# which the char prefab maps assign to the 'W' character.
WALL = {
    "name": "wall",
    "components": [
        {
            # Single state: the wall is always present on the upperPhysical
            # layer and drawn with the "Wall" sprite.
            "component": "StateManager",
            "kwargs": {
                "initialState": "wall",
                "stateConfigs": [{
                    "state": "wall",
                    "layer": "upperPhysical",
                    "sprite": "Wall",
                }],
            }
        },
        {
            "component": "Transform",
        },
        {
            # Sprite definition: the `shapes.WALL` ascii shape colored with
            # four shades of grey.
            "component": "Appearance",
            "kwargs": {
                "renderMode": "ascii_shape",
                "spriteNames": ["Wall"],
                "spriteShapes": [shapes.WALL],
                "palettes": [{"*": (95, 95, 95, 255),
                              "&": (100, 100, 100, 255),
                              "@": (109, 109, 109, 255),
                              "#": (152, 152, 152, 255)}],
                "noRotates": [False]
            }
        },
        {
            # NOTE(review): presumably stops the "gameInteraction" beam at the
            # wall -- confirm against the BeamBlocker Lua component.
            "component": "BeamBlocker",
            "kwargs": {
                "beamType": "gameInteraction"
            }
        },
    ]
}

# Prefab for a generic spawn point. Objects in this state belong to the
# "spawnPoints" group, which is the default `spawnGroup` that
# `create_avatar_object` gives to avatars.
SPAWN_POINT = {
    "name": "spawnPoint",
    "components": [
        {
            "component": "StateManager",
            "kwargs": {
                "initialState": "spawnPoint",
                "stateConfigs": [{
                    "state": "spawnPoint",
                    "layer": "alternateLogic",
                    "groups": ["spawnPoints"]
                }],
            }
        },
        {
            "component": "Transform",
        },
    ]
}

# Prefab for a spawn point reserved for dummy avatars: `create_avatar_object`
# sets `spawnGroup` to "dummySpawnPoints" whenever `dummy` is True.
DUMMY_SPAWN_POINT = {
    "name": "dummySpawnPoint",
    "components": [
        {
            "component": "StateManager",
            "kwargs": {
                "initialState": "dummySpawnPoint",
                "stateConfigs": [{
                    "state": "dummySpawnPoint",
                    "layer": "alternateLogic",
                    "groups": ["dummySpawnPoints"]
                }],
            }
        },
        {
            "component": "Transform",
        },
    ]
}

# PLAYER_COLOR_PALETTES is a list with each entry specifying the color to use
# for the player at the corresponding index.
NUM_PLAYERS_UPPER_BOUND = 32
# Built with a comprehension so that no loop variable is left behind in the
# module namespace. `colors.palette` must provide at least
# NUM_PLAYERS_UPPER_BOUND entries.
PLAYER_COLOR_PALETTES = [
    shapes.get_palette(colors.palette[idx])
    for idx in range(NUM_PLAYERS_UPPER_BOUND)
]

# Primitive action components.
# Every action assigns a value to each of the three action dimensions declared
# in the Avatar component's `actionSpec` (see `create_avatar_object`):
# "move" in [0, 4], "turn" in [-1, 1] and "interact" in [0, 1].
NOOP       = {"move": 0, "turn":  0, "interact": 0}
FORWARD    = {"move": 1, "turn":  0, "interact": 0}
STEP_RIGHT = {"move": 2, "turn":  0, "interact": 0}
BACKWARD   = {"move": 3, "turn":  0, "interact": 0}
STEP_LEFT  = {"move": 4, "turn":  0, "interact": 0}
TURN_LEFT  = {"move": 0, "turn": -1, "interact": 0}
TURN_RIGHT = {"move": 0, "turn":  1, "interact": 0}
INTERACT   = {"move": 0, "turn":  0, "interact": 1}

# The discrete action set exposed through `config.action_set`. The position of
# an action in this tuple differs from the definition order above (BACKWARD
# and STEP_LEFT come before STEP_RIGHT).
ACTION_SET = (
    NOOP,
    FORWARD,
    BACKWARD,
    STEP_LEFT,
    STEP_RIGHT,
    TURN_LEFT,
    TURN_RIGHT,
    INTERACT,
)


def get_sprite_data(color: str, face: Optional[int] = None) -> dict[str, Any]:
  """Assemble the sprite specification for a single avatar.

  Args:
    color: name of the avatar color scheme. One of "red", "blue", "green",
      "red_closer", "blue_closer", "red_closest" or "blue_closest".
    face: optional row index handed to `alter_avatar` for each of the four
      directional sprites; when None the unmodified `shapes.CUTE_AVATAR` is
      used.

  Returns:
    A dict with the keys "shape", "palette" and "noRotate".

  Raises:
    ValueError: if `color` is not one of the known names.
  """
  if face is None:
    shape = shapes.CUTE_AVATAR
  else:
    directional_sprites = (
        shapes.CUTE_AVATAR_N,
        shapes.CUTE_AVATAR_E,
        shapes.CUTE_AVATAR_S,
        shapes.CUTE_AVATAR_W,
    )
    shape = [alter_avatar(sprite, face) for sprite in directional_sprites]

  # Base RGB used to derive the palette of every supported color name.
  rgb_by_color_name = (
      ("red", (150, 0, 0)),           # Group 1 avatars.
      ("blue", (0, 0, 150)),          # Group 2 avatars.
      ("green", (0, 150, 0)),
      ("red_closer", (150, 0, 50)),   # Correlated red and blue.
      ("blue_closer", (50, 0, 150)),
      ("red_closest", (150, 0, 100)),  # Two shades of purple.
      ("blue_closest", (100, 0, 150)),
  )
  for color_name, rgb in rgb_by_color_name:
    if color == color_name:
      return {
          "shape": shape,
          "palette": shapes.get_palette(rgb),
          "noRotate": True,
      }
  raise ValueError(f"Unknown color: {color}")


def create_scene() -> dict[str, Any]:
  """Builds the specification of the global scene object.

  The scene has four components: a single-state StateManager, a Transform,
  the `TheMatrix` payoff definition and a stochastic episode-ending rule.

  Returns:
    Dict of the scene specifications.
  """
  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "scene",
          "stateConfigs": [{"state": "scene"}],
      },
  }

  transform = {"component": "Transform"}

  # Identity payoff matrix: only matching resource classes score.
  payoff_matrix = [
      # 1  2  3
      [1, 0, 0],  # 1
      [0, 1, 0],  # 2
      [0, 0, 1],  # 3
  ]
  indicator_intervals = [
      (0.0, 0.2),  # red
      (0.2, 0.4),  # yellow
      (0.4, 0.6),  # green
      (0.6, 0.8),  # blue
      (0.8, 1.0),  # violet
  ]
  the_matrix = {
      "component": "TheMatrix",
      "kwargs": {
          # Prevent interaction before both interactors have collected
          # at least one resource.
          "disallowUnreadyInteractions": True,
          "matrix": payoff_matrix,
          "resultIndicatorColorIntervals": indicator_intervals,
      },
  }

  episode_ending = {
      "component": "StochasticIntervalEpisodeEnding",
      "kwargs": {
          "minimumFramesPerEpisode": 1000,
          "intervalLength": 100,  # Set equal to unroll length.
          "probabilityTerminationPerInterval": 0.2,
      },
  }

  return {
      "name": "scene",
      "components": [state_manager, transform, the_matrix, episode_ending],
  }


def create_resource_prefab(resource_id, color_data) -> dict[str, Any]:
  """Builds the prefab of one collectable resource class.

  Args:
    resource_id: The id (number) of the resource class; it is also embedded in
      the prefab, state and sprite names.
    color_data: Pair of RGBA colors; element 0 colors the "*" pixels of the
      sprite and element 1 the "#" pixels.
  Returns:
    Dict of the resource prefab specifications.
  """
  live_state = f"resource_class{resource_id}"
  wait_state = f"{live_state}_wait"
  sprite_name = f"{live_state}_sprite"

  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": live_state,
          "stateConfigs": [
              {"state": wait_state, "groups": ["resourceWaits"]},
              {
                  "state": live_state,
                  "layer": "lowerPhysical",
                  "sprite": sprite_name,
              },
          ],
      },
  }

  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": [sprite_name],
          "spriteShapes": [shapes.BUTTON],
          "palettes": [{
              "*": color_data[0],
              "#": color_data[1],
              "x": (0, 0, 0, 0),
          }],
          "noRotates": [False],
      },
  }

  resource = {
      "component": "Resource",
      "kwargs": {
          "resourceClass": resource_id,
          "visibleType": live_state,
          "waitState": wait_state,
          "regenerationRate": 0.04,
          "regenerationDelay": 10,
      },
  }

  destroyable = {
      "component": "Destroyable",
      "kwargs": {
          "waitState": wait_state,
          # It is possible to destroy resources but takes concerted
          # effort to do so by zapping them `initialHealth` times.
          "initialHealth": 3,
      },
  }

  return {
      "name": live_state,
      "components": [
          state_manager,
          {"component": "Transform"},
          appearance,
          resource,
          destroyable,
      ],
  }


def create_spawn_point_prefab(team: str) -> dict[str, Any]:
  """Builds the prefab of a spawn point that belongs to one team.

  The spawn point is a member of both the generic "spawnPoints" group and the
  team group "<team>SpawnPoints", and it is rendered invisibly.

  Args:
    team: The team name.
  Returns:
    Dict of the spawn-point prefab specifications.
  """
  spawn_state = "playerSpawnPoint"
  team_group = f"{team}SpawnPoints"

  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": spawn_state,
          "stateConfigs": [{
              "state": spawn_state,
              "layer": "alternateLogic",
              "groups": ["spawnPoints", team_group],
          }],
      },
  }
  invisible_appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "invisible",
          "spriteNames": [],
          "spriteRGBColors": [],
      },
  }
  return {
      "name": f"{team}_spawn_point",
      "components": [
          state_manager,
          {"component": "Transform"},
          invisible_appearance,
      ],
  }


def create_prefabs() -> PrefabConfig:
  """Create the prefabs.

  Prefabs are a dictionary mapping names to template game objects that can
  be cloned and placed in multiple locations according to an ascii map.

  Returns:
    The prefab config: the wall, the three kinds of spawn point, the dummy
    spawn point and one prefab per resource class.
  """
  prefabs = {
      "wall": WALL,
      "spawn_point": SPAWN_POINT,
      "dummy_spawn_point": DUMMY_SPAWN_POINT,
      "red_spawn_point": create_spawn_point_prefab("red"),
      "blue_spawn_point": create_spawn_point_prefab("blue"),
  }
  # Resource classes are numbered from 1.
  resource_colors = (
      RESOURCE1_COLOR_DATA,
      RESOURCE2_COLOR_DATA,
      RESOURCE3_COLOR_DATA,
  )
  for resource_id, color_data in enumerate(resource_colors, start=1):
    prefabs[f"resource_class{resource_id}"] = create_resource_prefab(
        resource_id, color_data)
  return prefabs


def create_avatar_object(player_idx: int,
                         sprite_data: Dict[str, Any],
                         dummy: bool = False,
                         dummy_without_hat: bool = False,
                         color: str = "red",
                         evaluation: bool = False,
                         bad_apple: bool = False) -> Dict[str, Any]:
  """Create an avatar object given sprite data.

  Args:
    player_idx: The (0-based) index of the player.
    sprite_data: The sprite specification; its "shape", "palette" and
      "noRotate" entries are used.
    dummy: Whether the avatar is a dummy (i.e. cannot act). Dummies spawn on
      the "dummySpawnPoints" group and get a `DisallowMovement` component.
    dummy_without_hat: For dummies only: when True the
      `InitializeAsReadyToInteract` component is left out.
    color: The color of the sprite. "red" and "blue" avatars use their team
      spawn group when not evaluating and the map is not the normal one.
    evaluation: Whether to use the evaluation logic.
    bad_apple: Whether to use the 'bad apple' logic (i.e. longer freeze). Both
      settings currently use the same freeze duration.

  Returns:
    The avatar specification.
  """
  # Lua is 1-indexed.
  lua_index = player_idx + 1

  # The generic "spawnPoints" group is the standard that is used during
  # evaluation; team groups only apply to red/blue on the non-normal maps, and
  # dummies always override the choice.
  uses_team_spawn = (
      color in ("red", "blue") and not evaluation and not NORMAL_MAP)
  if dummy:
    spawn_group = "dummySpawnPoints"
  elif uses_team_spawn:
    spawn_group = "redSpawnPoints" if color == "red" else "blueSpawnPoints"
  else:
    spawn_group = "spawnPoints"

  # Interaction settings shared by regular and 'bad apple' avatars.
  beam_color = [200, 200, 200]
  unready_reward = 0
  reward_multiplier = 1
  if bad_apple:
    # Change this to vary how bad it is to interact with the less desirable
    # agent.
    freeze = 16
  else:
    freeze = 16

  # Setup the self vs other sprite mapping.
  source_sprite_self = f"Avatar{lua_index}"
  live_state_name = f"player{lua_index}"

  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": live_state_name,
          "stateConfigs": [
              {
                  "state": live_state_name,
                  "layer": "upperPhysical",
                  "sprite": source_sprite_self,
                  "contact": "avatar",
                  "groups": ["players"],
              },
              {
                  "state": "playerWait",
                  "groups": ["playerWaits"],
              },
          ],
      },
  }

  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": [source_sprite_self],
          "spriteShapes": [sprite_data["shape"]],
          "palettes": [sprite_data["palette"]],
          "noRotates": [sprite_data["noRotate"]],
      },
  }

  avatar = {
      "component": "Avatar",
      "kwargs": {
          "index": lua_index,
          "aliveState": live_state_name,
          "waitState": "playerWait",
          "speed": 1.0,
          "spawnGroup": spawn_group,
          "actionOrder": ["move", "turn", "interact"],
          "actionSpec": {
              "move": {"default": 0, "min": 0, "max": len(_COMPASS)},
              "turn": {"default": 0, "min": -1, "max": 1},
              "interact": {"default": 0, "min": 0, "max": 1},
          },
          "view": {
              "left": 5,
              "right": 5,
              "forward": 9,
              "backward": 1,
              "centered": False,
          },
          # The following kwarg makes it possible to get rewarded even
          # on frames when an avatar is "dead". It is needed for in the
          # matrix games in order to correctly handle the case of two
          # players getting hit simultaneously by the same beam.
          "skipWaitStateRewards": False,
      },
  }

  zapper = {
      "component": "GameInteractionZapper",
      "kwargs": {
          "cooldownTime": 2,
          "beamLength": 2,
          "beamRadius": 0,
          "beamColor": beam_color,
          "framesTillRespawn": 50,
          "numResources": NUM_RESOURCES,
          "endEpisodeOnFirstInteraction": False,
          # Reset both players' inventories after each interaction.
          "reset_winner_inventory": True,
          "reset_loser_inventory": True,
          # Both players get removed after each interaction.
          "losingPlayerDies": True,
          "winningPlayerDies": True,
          # `freezeOnInteraction` is the number of frames to display the
          # interaction result indicator, freeze, and delay delivering
          # all results of interacting.
          "freezeOnInteraction": freeze,
          "rewardFromZappingUnreadyPlayer": unready_reward,
          "rewardMultiplier": reward_multiplier,
      },
  }

  ready_to_shoot = {
      "component": "ReadyToShootObservation",
      "kwargs": {
          "zapperComponent": "GameInteractionZapper",
      },
  }

  inventory_observer = {
      "component": "InventoryObserver",
      "kwargs": {},
  }

  taste = {
      "component": "Taste",
      "kwargs": {
          "mostTastyResourceClass": -1,  # -1 indicates no preference.
          # No resource is most tasty when mostTastyResourceClass == -1.
          "mostTastyReward": 0.1,
      },
  }

  interaction_taste = {
      "component": "InteractionTaste",
      "kwargs": {
          "mostTastyResourceClass": -1,  # -1 indicates no preference.
          "zeroDefaultInteractionReward": False,
          "extraReward": 1.0,
      },
  }

  components = [
      state_manager,
      {"component": "Transform"},
      appearance,
      avatar,
      zapper,
      ready_to_shoot,
      inventory_observer,
      taste,
      interaction_taste,
  ]

  # Dummy avatars have some extra components.
  if dummy:
    components.append({
        "component": "DisallowMovement",
        "kwargs": {},
    })
    if not dummy_without_hat:
      components.append({
          "component": "InitializeAsReadyToInteract",
          "kwargs": {
              "playerIndex": lua_index,
          },
      })

  return {
      "name": "avatar",
      "components": components,
  }


def get_indicator_color_palette(
    color_rgba: Tuple[int, ...]) -> Dict[str, Sequence[int]]:
  """Create a color palette for the indicator color.

  The palette is a deep copy of `shapes.GOLD_CROWN_PALETTE` in which "#" is
  replaced by `color_rgba` and "@" by a slightly darker, half transparent
  variant of it. The shared gold-crown palette itself is never modified.

  Args:
    color_rgba: RGBA of the original color; the last element is treated as the
      alpha channel.
  Returns:
    A color palette for the indicator color.
  """
  indicator_palette = copy.deepcopy(shapes.GOLD_CROWN_PALETTE)
  indicator_palette["#"] = color_rgba
  # Darken every channel except alpha by 10%.
  slightly_darker_color = [round(value * 0.9) for value in color_rgba[:-1]]
  slightly_darker_color.append(150)  # Add a half transparent alpha channel.
  # Note: this entry is a list, unlike the tuple stored under "#".
  indicator_palette["@"] = slightly_darker_color
  return indicator_palette


def create_ready_to_interact_marker(
    player_idx: int,
    shape: str = shapes.BRONZE_CAP,
    palette: Dict[str, Tuple[int, ...]] = shapes.SILVER_CROWN_PALETTE,
    color_cycle: Tuple[Tuple[int, ...], ...] = (
        (139, 0, 0, 255),  # red
        (253, 184, 1, 255),  # yellow
        (0, 102, 0, 255),  # green
        (2, 71, 254, 255),  # blue
        (127, 0, 255, 255),  # violet
    ),
) -> Dict[str, Any]:
  """Create a ready-to-interact marker overlay specification.

  The marker has one "Ready" sprite plus five result-indicator sprites, all
  drawn with the same `shape`. The first five entries of `color_cycle` provide
  the colors of the result indicators.

  Args:
    player_idx: (0-based) index of the player the marker is attached to.
    shape: ascii shape of the marker.
    palette: color palette of the "Ready" sprite.
    color_cycle: RGBA colors of the five result indicators.
  Returns:
    A ready-to-interact marker overlay specification.
  """
  # Lua is 1-indexed.
  lua_idx = player_idx + 1

  result_indices = range(1, 6)
  result_sprite_names = [f"ResultIndicatorColor{i}" for i in result_indices]

  # Use `overlay` layer for ready and nonready states, both
  # are used for live avatars and are always connected.
  state_configs = [
      {"state": "ready", "layer": "overlay", "sprite": "Ready"},
      {"state": "notReady", "layer": "overlay"},
  ]
  # Result indication colors.
  state_configs.extend(
      {
          "state": f"resultIndicatorColor{i}",
          "layer": "overlay",
          "sprite": f"ResultIndicatorColor{i}",
      }
      for i in result_indices
  )
  # Invisible inactive overlay type.
  state_configs.append(
      {"state": "avatarMarkingWait", "groups": ["avatarMarkingWaits"]})

  # Colors are in rainbow order (more or less).
  result_palettes = [
      get_indicator_color_palette(color_cycle[i - 1]) for i in result_indices
  ]
  num_sprites = 1 + len(result_sprite_names)

  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "avatarMarkingWait",
          "stateConfigs": state_configs,
      },
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": ["Ready"] + result_sprite_names,
          "spriteShapes": [shape] * num_sprites,
          "palettes": [palette] + result_palettes,
          "noRotates": [True] * num_sprites,
      },
  }
  avatar_connector = {
      "component": "AvatarConnector",
      "kwargs": {
          "playerIndex": lua_idx,
          "aliveState": "notReady",  # state `notReady` is invisible.
          "waitState": "avatarMarkingWait",
      },
  }
  marker = {
      "component": "ReadyToInteractMarker",
      "kwargs": {
          "playerIndex": lua_idx,
      },
  }

  return {
      "name": "avatarReadyToInteractMarker",
      "components": [
          state_manager,
          {"component": "Transform"},
          appearance,
          avatar_connector,
          marker,
      ],
  }


def create_avatar_objects(num_players: int,
                          preferences: Sequence[Mapping],
                          evaluation: bool = False) -> List[Dict[str, Any]]:
  """Returns the avatar and marker objects for `num_players` players.

  For every player two objects are appended, in this order: the avatar itself
  and its ready-to-interact marker (a head band with a greyscale color cycle).
  The returned list therefore has `2 * num_players` entries.

  Args:
    num_players: number of players.
    preferences: one mapping per player. The keys read are "color", "dummy",
      "dummy_without_hat", "face" and "bad_apple".
    evaluation: whether to use evaluation or not.
  Returns:
    List of avatar specifications.
  """
  # Greyscale ramp used for the result indicators of every head band.
  greyscale_cycle = (
      (5, 5, 5, 255),
      (100, 100, 100, 255),
      (150, 150, 150, 255),
      (200, 200, 200, 255),
      (255, 255, 255, 255),
  )

  avatar_objects = []
  for player_idx in range(num_players):
    player_prefs = preferences[player_idx]
    color = player_prefs.get("color")
    is_dummy = player_prefs.get("dummy", False)
    hatless_dummy = player_prefs.get("dummy_without_hat", False)
    face = player_prefs.get("face", None)
    is_bad_apple = player_prefs.get("bad_apple", False)

    avatar = create_avatar_object(
        player_idx,
        sprite_data=get_sprite_data(color, face),
        dummy=is_dummy,
        dummy_without_hat=hatless_dummy,
        color=color,
        evaluation=evaluation,
        bad_apple=is_bad_apple,
    )
    marker = create_ready_to_interact_marker(
        player_idx,
        shape=HEAD_BAND,
        palette=HEAD_BAND_PALETTE,
        color_cycle=greyscale_cycle,
    )
    avatar_objects.extend((avatar, marker))

  return avatar_objects


def create_random_char_prefab_map(
    resource_chars: str,
    resource_names: List[str],
    randomization: str = "full",
    spawn_point: str = "P",
    red_spawn_point: str = "R",
    blue_spawn_point: str = "B",
) -> Dict[str, Any]:
  """Creates a randomized character prefab map.

  Args:
    resource_chars: string of characters which are resources.
    resource_names: list of resource names; must be as long as
      `resource_chars`.
    randomization: randomization type. If "full", every resource character
      maps to a random choice over all resource names. If "grouped", the names
      are shuffled once and each character is bound to one of them.
    spawn_point: the character of the general spawn point.
    red_spawn_point: the character of the red spawn point.
    blue_spawn_point: the character of the blue spawn point.

  Returns:
    The char prefab map. Entries for the resource characters are applied last,
    so they win over the fixed entries on a character clash.

  Raises:
    ValueError: if `randomization` is neither "full" nor "grouped".
  """
  assert len(resource_chars) == len(resource_names), (
      "Must have as many names as chars.")

  resource_entries = {}
  if randomization == "full":
    for char in resource_chars:
      resource_entries[char] = {"type": "choice", "list": resource_names}
  elif randomization == "grouped":
    # Shuffle a copy so the caller's list keeps its order.
    shuffled_names = resource_names[:]
    random.shuffle(shuffled_names)
    for position, char in enumerate(resource_chars):
      resource_entries[char] = shuffled_names[position]
  else:
    raise ValueError(f"Unknown randomization {randomization}.")

  char_prefab_map = {
      # 'a' is always random.
      "a": {"type": "choice", "list": resource_names},
      spawn_point: "spawn_point",
      red_spawn_point: "red_spawn_point",
      blue_spawn_point: "blue_spawn_point",
      "W": "wall",
  }
  char_prefab_map.update(resource_entries)
  return char_prefab_map


def get_config() -> config_dict.ConfigDict:
  """Default configuration.

  The layout entries are taken from the module-level `ASCII_MAP` and
  `CHAR_PREFAB_MAP` as they are bound at the time of the call.

  Returns:
    Default config, with `evaluation` switched off and the "grouped"
    randomization mode selected.
  """
  config = config_dict.ConfigDict()

  # `build_environment` only randomizes the char prefab map when this is False.
  config.evaluation = False

  # Specify the number of players to participate in each episode (optional).
  config.recommended_num_players = 8

  config.randomization_mode = "grouped"  # choose from {"full", "grouped"}.

  # Action set configuration.
  config.action_set = ACTION_SET
  # Observation format configuration.
  config.individual_observation_names = [
      "RGB",
      "INVENTORY",
      "READY_TO_SHOOT",
  ]
  config.global_observation_names = [
      "WORLD.RGB",
  ]

  # Allow overriding of the layout and episode length.
  config.layout = config_dict.ConfigDict()
  config.layout.ascii_map = ASCII_MAP
  config.layout.char_prefab_map = CHAR_PREFAB_MAP
  config.layout.maxEpisodeLengthFrames = 5000

  return config


def build_environment(
    preferences: Sequence[Mapping],
    config: config_dict.ConfigDict,
    env_seed: int = 3,
) -> Any:
  """Build the substrate for the given player preferences.

  The number of players equals `len(preferences)`. Outside of evaluation the
  char prefab map of `config.layout` is ignored and a randomized one is
  created instead (see `create_random_char_prefab_map`).

  Note that the assembled settings are also stored on the passed-in config as
  `config.lab2d_settings`.

  Args:
    preferences: player preferences, one mapping per player.
    config: config, e.g. from `get_config` or `get_eval_config`.
    env_seed: seed forwarded to `builder.builder`. Defaults to 3, the value
      that used to be hard-coded.
  Returns:
    The environment created by `builder.builder` (the example cell of this
    notebook calls `.reset()` on it).
  """
  num_players = len(preferences)

  if config.evaluation:
    char_prefab_map = config.layout.char_prefab_map
  else:
    char_prefab_map = create_random_char_prefab_map(
        resource_chars="123",
        resource_names=_resource_names,
        randomization=config.randomization_mode,
        red_spawn_point="R",
        blue_spawn_point="B",
        spawn_point="P",
    )

  # Build the rest of the substrate definition.
  config.lab2d_settings = {
      "levelName": "the_matrix",
      "levelDirectory": "meltingpot/lua/levels",
      "numPlayers": num_players,
      # Define upper bound of episode length since episodes end stochastically.
      "maxEpisodeLengthFrames": config.layout.maxEpisodeLengthFrames,
      "spriteSize": 8,
      "topology": "BOUNDED",  # Choose from ["BOUNDED", "TORUS"],
      "simulation": {
          "map": config.layout.ascii_map,
          "gameObjects": create_avatar_objects(num_players=num_players,
                                               preferences=preferences,
                                               evaluation=config.evaluation),
          "scene": create_scene(),
          "prefabs": create_prefabs(),
          "charPrefabMap": char_prefab_map,
      }
  }

  return builder.builder(config.lab2d_settings, {}, env_seed=env_seed)


In [None]:
# Define a laboratory style evaluation map here.
EVAL_MAP = """
  aa     P
D        P
  aa 11  P
  aa 22  P
  aa 33  P
D        P
  aa     P
"""

_resource_names = [
    "resource_class1",
    "resource_class2",
    "resource_class3",
]

# `prefab` determines which prefab game object to use for each `char` in the
# evaluation ascii map.
# This map has its own name on purpose: `get_config` reads the module-level
# `CHAR_PREFAB_MAP` when it is called, so rebinding that name here would
# silently drop the 'R' and 'B' spawn-point entries from the default config.
EVAL_CHAR_PREFAB_MAP = {
    # A randomly chosen resource from `_resource_names` will be placed at each
    # location in the ascii map there is an 'a'.
    "a": {"type": "choice", "list": _resource_names},
    "P": "spawn_point",
    "D": "dummy_spawn_point",
    "W": "wall",
    "1": _resource_names[0],
    "2": _resource_names[1],
    "3": _resource_names[2],
}


def get_eval_config() -> config_dict.ConfigDict:
  """Evaluation configuration.

  Starts from `get_config()`, switches `evaluation` on and replaces the whole
  layout with the evaluation map, its char prefab map and a shorter maximum
  episode length.

  Returns:
    Config with the evaluation specification.
  """
  config = get_config()
  config.evaluation = True
  # Override the map layout settings.
  config.layout = config_dict.ConfigDict()
  config.layout.ascii_map = EVAL_MAP
  config.layout.char_prefab_map = EVAL_CHAR_PREFAB_MAP
  config.layout.maxEpisodeLengthFrames = 1000
  return config


In [None]:
# Specify population parameters.
# The length of the preferences map determines the population size.
# Every population below lists its eight "red" entries first and its eight
# "blue" entries second.

# Eight plain red players followed by eight plain blue players.
default_preferences = [
    {"color": team} for team in ("red", "blue") for _ in range(8)
]

# Per team: one dummy in the team color, one green dummy, then six regular
# players.
validation_preferences = [
    entry
    for team in ("red", "blue")
    for entry in (
        [{"color": team, "dummy": True}, {"color": "green", "dummy": True}]
        + [{"color": team} for _ in range(6)]
    )
]

# Individuated sprites: red faces count up from 0 to 7, blue faces count down
# from 7 to 0. The first player of each team is the 'bad apple'.
preferences_individuation = (
    [{"color": "red", "bad_apple": True, "face": 0}]
    + [{"color": "red", "face": face} for face in range(1, 8)]
    + [{"color": "blue", "bad_apple": True, "face": 7}]
    + [{"color": "blue", "face": face} for face in range(6, -1, -1)]
)

# Evaluation variant: the first four players of each team are dummies and the
# second dummy of each team keeps the unmodified sprite (face None).
preferences_individuation_eval = (
    [{"color": "red", "dummy": True, "face": face}
     for face in (0, None, 2, 3)]
    + [{"color": "red", "face": face} for face in (4, 5, 6, 7)]
    + [{"color": "blue", "dummy": True, "face": face}
       for face in (7, None, 5, 4)]
    + [{"color": "blue", "face": face} for face in (3, 2, 1, 0)]
)

In [None]:
#@title Example of a default configuration

# Build the environment for the default population (16 players) with the
# default, non-evaluation config.
env = build_environment(preferences=default_preferences, config=get_config())

# Reset the environment and read the global "WORLD.RGB" observation of the
# first timestep; as the last expression of the cell it becomes the cell
# output.
env.reset().observation["WORLD.RGB"]