Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
!pip install meltingpot

In [None]:
#@title Environment configuration & build function
"""Configuration for creating a Boat Race environment."""
from collections.abc import Mapping, MutableMapping, Sequence

import enum

from meltingpot.utils.substrates import builder
from meltingpot.utils.substrates import colors
from meltingpot.utils.substrates import game_object_utils
from meltingpot.utils.substrates import shapes
from ml_collections import config_dict as configdict

import numpy as np

from typing import Any


# Tunable parameters.
# The number of races.
NUM_RACES = 8 # @param {type: "integer"}
# If non-zero, an additional discount applied on every step.
CROWN_BETA = 0.0 # @param {type: "number"}

# These should be treated as constants. These are the values used for the
# published article.
NUM_PLAYERS = 6
# Passed to the scene's RaceManager component as `raceStartTime`.
PARTNER_DURATION = 75
# Passed to the scene's RaceManager component as `raceDuration`.
RACE_DURATION = 225
# After 4 straight rows you get a crown, but you need to roughly keep a ratio
# of 2:1 to maintain it.
CROWN_TURN_OFF_THRESHOLD = 0.3
CROWN_TURN_ON_THRESHOLD = 0.6
# A higher alpha discounts older observations faster.
CROWN_ALPHA = 0.2
# For efficient learning steps, this value should align with the unroll length
# of your agent.
EARLY_EXIT_CHECK_INTERVAL = 100

# ASCII layout of the level. Each character is translated into one or more
# prefabs through CHAR_PREFAB_MAP; blank cells have no entry in that mapping.
# The rows are data: keep them exactly as written.
ASCII_MAP = r"""
WWWWWWWWWWWWWWWWWWWWWWWWWW
W                        W
W                        W
W                        W
W      RRRRRRRRRRRR      W
W      RRRRRRRRRRRR      W
W      RRRRRRRRRRRR      W
W      RRRRRRRRRRRR      W
W                        W
W      S  SS  SS  S      W
W      S%%SS%%SS%%S      W
W      S  SS  SS  S      W
~~~~~~~~gg~~gg~~gg~~~~~~~~
~~~~~~~~{{~~{{~~{{~~~~~~~~
~~~~~~~~AA~~AA~~AA~~~~~~~~
~~~~~~~~{{~~{{~~{{~~~~~~~~
~~~~~~~~{{~~{{~~{{~~~~~~~~
~~~~~~~~AA~~AA~~AA~~~~~~~~
~~~~~~~~{{~~{{~~{{~~~~~~~~
~~~~~~~~{{~~{{~~{{~~~~~~~~
~~~~~~~~AA~~AA~~AA~~~~~~~~
~~~~~~~~{{~~{{~~{{~~~~~~~~
~~~~~~~~{{~~{{~~{{~~~~~~~~
~~~~~~~~AA~~AA~~AA~~~~~~~~
~~~~~~~~/Y~~/Y~~/Y~~~~~~~~
~~~~~~~p;:qp;:qp;:q~~~~~~~
W      SLJSSLJSSLJS      W
W      S--SS--SS--S      W
W      S  SS  SS  S      W
W                        W
W      OOOOOOOOOOOO      W
W      OOOOOOOOOOOO      W
W      OOOOOOOOOOOO      W
W      OOOOOOOOOOOO      W
W                        W
W    ________________    W
W    ________________    W
WWWWWWWWWWWWWWWWWWWWWWWWWW
"""

def _stack(*prefab_names):
  """Returns a map entry that places every named prefab on the same cell."""
  return {"type": "all", "list": list(prefab_names)}


# Lookup table from each `char` of the ASCII map to the prefab, or stack of
# prefabs, instantiated at that position.
CHAR_PREFAB_MAP = {
    "_": "spawn_point",
    "W": "wall",
    "S": "semaphore",
    "A": _stack("water_background", "single_apple"),
    "R": "respawning_apple_north",
    "O": "respawning_apple_south",
    "%": "barrier_north",
    "-": "barrier_south",
    "~": "water_blocking",
    "{": "water_background",
    "g": _stack("goal_north", "water_background"),
    "/": _stack("boat_FL", "water_background"),
    "Y": _stack("boat_FR", "water_background"),
    "L": "boat_RL",
    "J": "boat_RR",
    "p": _stack("oar_L", "water_blocking"),
    "q": _stack("oar_R", "water_blocking"),
    ";": _stack("seat_L", "goal_south", "water_background"),
    ":": _stack("seat_R", "goal_south", "water_background"),
}

# Compass headings; the avatar's `move` action range is sized from this list.
_COMPASS = list("NESW")

# Alias for the dictionary configuration describing one game object.
Prefab = Mapping[str, Any]

@enum.unique
class BankSide(enum.Enum):
  """The two river banks, each identified by a one-letter string value."""

  NORTH = "N"
  SOUTH = "S"


def get_scene(early_exit_check_interval: int,
              early_exit_on_any: bool) -> Prefab:
  """Builds the global `scene` object that manages metrics and early exit.

  Args:
    early_exit_check_interval: The interval between checking for early exit.
    early_exit_on_any: Whether to early exit on any timestep. Ignores
      `early_exit_check_interval` if True.

  Returns:
    The scene object.
  """
  # Scene states are declared by name only. "semaphore_red" is a temporary
  # state at end game.
  scene_states = (
      "ForceEmbark",
      "partnerChoice",
      "semaphore_yellow",
      "semaphore_green",
      "boatRace",
      "semaphore_red",
  )
  state_manager = {
      "component": "StateManager",
      "kwargs": {
          # Set to "ForceEmbark" to start the boats half-filled.
          "initialState": "partnerChoice",
          "stateConfigs": [{"state": name} for name in scene_states],
      },
  }
  race_manager = {
      "component": "RaceManager",
      "kwargs": {
          "raceStartTime": PARTNER_DURATION,
          "raceDuration": RACE_DURATION,
      },
  }
  episode_manager = {
      "component": "EpisodeManager",
      "kwargs": {
          "checkInterval": early_exit_check_interval,
          "earlyExitOnAny": early_exit_on_any,
      },
  }
  # Both metrics are read from variables of the GlobalRaceTracker component.
  race_start_metric = {
      "name": "RACE_START",
      "type": "tensor.Int32Tensor",
      "shape": (NUM_PLAYERS // 2, 2),
      "component": "GlobalRaceTracker",
      "variable": "raceStart",
  }
  strokes_metric = {
      "name": "STROKES",
      "type": "tensor.Int32Tensor",
      "shape": (NUM_PLAYERS,),
      "component": "GlobalRaceTracker",
      "variable": "strokes",
  }
  metric_reporter = {
      "component": "GlobalMetricReporter",
      "kwargs": {"metrics": [race_start_metric, strokes_metric]},
  }
  race_tracker = {
      "component": "GlobalRaceTracker",
      "kwargs": {"numPlayers": NUM_PLAYERS},
  }
  return {
      "name": "scene",
      "components": [
          state_manager,
          {"component": "Transform"},
          race_manager,
          episode_manager,
          metric_reporter,
          race_tracker,
      ],
  }


# Wall prefab: a single "wall" state on the "upperPhysical" layer, drawn with
# a four-shade gray palette.
WALL = {
    "name": "wall",
    "components": [
        {
            "component": "StateManager",
            "kwargs": {
                "initialState": "wall",
                "stateConfigs": [
                    {
                        "state": "wall",
                        "layer": "upperPhysical",
                        "sprite": "Wall",
                    },
                ],
            },
        },
        {"component": "Transform"},
        {
            "component": "Appearance",
            "kwargs": {
                "renderMode": "ascii_shape",
                "spriteNames": ["Wall"],
                "spriteShapes": [shapes.WALL],
                "palettes": [
                    {
                        "*": (95, 95, 95, 255),
                        "&": (100, 100, 100, 255),
                        "@": (109, 109, 109, 255),
                        "#": (152, 152, 152, 255),
                    },
                ],
                "noRotates": [True],
            },
        },
        # One BeamBlocker component per beam type.
        *(
            {"component": "BeamBlocker", "kwargs": {"beamType": beam_type}}
            for beam_type in ("gift", "zap")
        ),
    ],
}

# Sprite-less marker prefab. Its only state belongs to the "spawnPoints"
# group, which is the `spawnGroup` named by the avatar prefab.
SPAWN_POINT = {
    "name": "spawnPoint",
    "components": [
        {
            "component": "StateManager",
            "kwargs": {
                "initialState": "spawnPoint",
                "stateConfigs": [
                    {
                        "state": "spawnPoint",
                        "layer": "logic",
                        "groups": ["spawnPoints"],
                    },
                ],
            },
        },
        {"component": "Transform"},
    ],
}


# Apple prefab without a regrow component. The `Edible` component is
# configured with "apple" as its live state and "appleWait" as its wait state.
SINGLE_APPLE = {
    "name": "single_apple",
    "components": [
        {
            "component": "StateManager",
            "kwargs": {
                "initialState": "apple",
                "stateConfigs": [
                    {
                        "state": "apple",
                        "layer": "singleAppleLayer",
                        "sprite": "apple",
                    },
                    {"state": "appleWait", "layer": "logic"},
                ],
            },
        },
        {"component": "Transform"},
        {
            "component": "Appearance",
            "kwargs": {
                "renderMode": "ascii_shape",
                "spriteNames": ["apple"],
                "spriteShapes": [shapes.HD_APPLE],
                "palettes": [shapes.get_palette((40, 180, 40, 255))],
                "noRotates": [False],
            },
        },
        {
            "component": "Edible",
            "kwargs": {
                "liveState": "apple",
                "waitState": "appleWait",
                "rewardForEating": 1.0,
            },
        },
    ],
}


def get_respawning_apple(bank_side: BankSide) -> Prefab:
  """Builds the prefab for an apple that regrows at a fixed rate.

  Args:
    bank_side: The side of the river the apple is on. If north, it will start
      active, otherwise, it will be deactivated.

  Returns:
    A prefab config for the respawning apple.
  """
  # Only apples on the north bank start out live; all others start paused.
  if bank_side == BankSide.NORTH:
    starting_state = "apple"
  else:
    starting_state = "applePause"

  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": starting_state,
          "stateConfigs": [
              {"state": "apple", "layer": "superOverlay", "sprite": "apple"},
              {"state": "appleWait", "layer": "logic"},
              {"state": "applePause", "layer": "logic"},
          ],
      },
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": ["apple"],
          "spriteShapes": [shapes.HD_APPLE],
          "palettes": [shapes.get_palette((40, 180, 40, 255))],
          "noRotates": [False],
      },
  }
  edible = {
      "component": "Edible",
      "kwargs": {
          "liveState": "apple",
          "waitState": "appleWait",
          "rewardForEating": 1.0,
      },
  }
  regrow = {
      "component": "FixedRateRegrow",
      "kwargs": {
          "liveState": "apple",
          "waitState": "appleWait",
          "regrowRate": 0.1,
      },
  }
  return {
      "name": "apple",
      "components": [
          state_manager,
          {"component": "Transform"},
          appearance,
          edible,
          regrow,
      ],
  }


# Semaphore prefab with three states ("red", "yellow", "green"). Each state
# uses the sprite of the same name and belongs to the "semaphore" group; all
# three sprites share the coin shape and differ only in palette.
SEMAPHORE = {
    "name": "semaphore",
    "components": [
        {
            "component": "StateManager",
            "kwargs": {
                "initialState": "red",
                "stateConfigs": [
                    {
                        "state": light,
                        "layer": "upperPhysical",
                        "sprite": light,
                        "groups": ["semaphore"],
                    }
                    for light in ("red", "yellow", "green")
                ],
            },
        },
        {"component": "Transform"},
        {
            "component": "Appearance",
            "kwargs": {
                "renderMode": "ascii_shape",
                "spriteNames": ["red", "yellow", "green"],
                "spriteShapes": [shapes.COIN, shapes.COIN, shapes.COIN],
                "palettes": [
                    shapes.RED_COIN_PALETTE,
                    shapes.COIN_PALETTE,
                    shapes.GREEN_COIN_PALETTE,
                ],
                "noRotates": [False, False, False],
            },
        },
    ],
}


def get_barrier(bank_side: BankSide = BankSide.NORTH) -> Prefab:
  """Builds a barrier prefab for the given bank.

  Args:
    bank_side: The side where the barrier will be placed. Barriers on the north
      start off, the ones on the south start on.

  Returns:
    A prefab config for the barrier prefab.
  """
  if bank_side == BankSide.NORTH:
    starting_state = "off"
  else:
    starting_state = "on"

  # The "on" state sits on the "upperPhysical" layer and the "off" state on
  # "superOverlay".
  on_state = {
      "state": "on",
      "layer": "upperPhysical",
      "sprite": "barrierOn",
      "groups": ["barrier"],
  }
  off_state = {
      "state": "off",
      "layer": "superOverlay",
      "sprite": "barrierOff",
      "groups": ["barrier"],
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": ["barrierOn", "barrierOff"],
          "spriteShapes": [shapes.BARRIER_ON, shapes.BARRIER_OFF],
          "palettes": [shapes.GRAY_PALETTE, shapes.GRAY_PALETTE],
          "noRotates": [False, False],
      },
  }
  return {
      "name": "barrier",
      "components": [
          {
              "component": "StateManager",
              "kwargs": {
                  "initialState": starting_state,
                  "stateConfigs": [on_state, off_state],
              },
          },
          {"component": "Transform"},
          appearance,
      ],
  }


def get_water(layer: str) -> Prefab:
  """Builds an animated water prefab whose states all live on `layer`.

  Args:
    layer: The layer to place the water.

  Returns:
    A prefab config for the water prefab, named `water_<layer>`.
  """
  # The four animation frames double as state names and sprite names.
  frame_names = ["water_{}".format(frame) for frame in range(1, 5)]
  frame_shapes = [
      shapes.WATER_1,
      shapes.WATER_2,
      shapes.WATER_3,
      shapes.WATER_4,
  ]

  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "water_1",
          "stateConfigs": [
              {
                  "state": frame_name,
                  "layer": layer,
                  "sprite": frame_name,
                  "groups": ["water"],
              }
              for frame_name in frame_names
          ],
      },
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": list(frame_names),
          "spriteShapes": frame_shapes,
          "palettes": [shapes.WATER_PALETTE for _ in frame_names],
      },
  }
  animation = {
      "component": "Animation",
      "kwargs": {
          "states": list(frame_names),
          "gameFramesPerAnimationFrame": 2,
          "loop": True,
          "randomStartFrame": True,
          "group": "water",
      },
  }
  return {
      "name": "water_{}".format(layer),
      "components": [
          state_manager,
          {"component": "Transform"},
          appearance,
          animation,
      ],
  }


def get_goal(bank_side: BankSide = BankSide.NORTH) -> Prefab:
  """Builds a water goal prefab for one bank.

  Args:
    bank_side: The side of the water goal.

  Returns:
    A prefab config for the water goal prefab.
  """
  # The goal starts in "goalNonBlocking" (layer "logic"); "goalBlocking" puts
  # it on the "upperPhysical" layer instead.
  goal_states = [
      {"state": "goalNonBlocking", "layer": "logic"},
      {"state": "goalBlocking", "layer": "upperPhysical"},
  ]
  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "goalNonBlocking",
          "stateConfigs": goal_states,
      },
  }
  # The component receives the enum's string value, not the enum member.
  water_goal = {
      "component": "WaterGoal",
      "kwargs": {"bank_side": bank_side.value},
  }
  return {
      "name": "water_goal",
      "components": [state_manager, {"component": "Transform"}, water_goal],
  }


def get_boat(front: bool, left: bool) -> Prefab:
  """Builds the prefab for one of the four pieces of a boat.

  Args:
    front: Whether the piece is front or rear.
    left: Whether the piece is left or right.

  Returns:
    A prefab config for the boat piece prefab.
  """
  # Two-letter piece code: F(ront)/R(ear) followed by L(eft)/R(ight).
  row_code = "F" if front else "R"
  side_code = "L" if left else "R"
  piece = row_code + side_code
  piece_shapes = {
      "FL": shapes.BOAT_FRONT_L,
      "FR": shapes.BOAT_FRONT_R,
      "RL": shapes.BOAT_REAR_L,
      "RR": shapes.BOAT_REAR_R,
  }
  sprite_name = f"Boat{piece}"

  # Both states show the same sprite; they differ only in their layer.
  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "boat",
          "stateConfigs": [
              {
                  "state": "boat",
                  "layer": "lowerPhysical",
                  "sprite": sprite_name,
                  "groups": ["boat"],
              },
              {
                  "state": "boatFull",
                  "layer": "overlay",
                  "sprite": sprite_name,
                  "groups": ["boat"],
              },
          ],
      },
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": [sprite_name],
          "spriteShapes": [piece_shapes[piece]],
          "palettes": [shapes.BOAT_PALETTE],
          "noRotates": [False],
      },
  }
  return {
      "name": f"boat_{piece}",
      "components": [state_manager, {"component": "Transform"}, appearance],
  }


def get_seat(left: bool) -> Prefab:
  """Builds a seat prefab. Left seats contain the BoatManager component.

  Args:
    left: Whether the seat is left or right.

  Returns:
    A prefab config for the seat prefab.
  """
  side = "L" if left else "R"
  seat_shapes = {
      "L": shapes.BOAT_SEAT_L,
      "R": shapes.BOAT_SEAT_R,
  }
  sprite_name = f"Seat{side}"

  # All three states show the same sprite.
  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "seat",
          "stateConfigs": [
              {
                  "state": "seat",
                  "layer": "lowerPhysical",
                  "sprite": sprite_name,
                  "groups": ["seat", "boat"],
              },
              {
                  "state": "seatTaken",
                  "layer": "overlay",
                  "sprite": sprite_name,
                  "contact": "boat",
              },
              {
                  "state": "seatUsed",
                  "layer": "lowerPhysical",
                  "sprite": sprite_name,
              },
          ],
      },
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": [sprite_name],
          "spriteShapes": [seat_shapes[side]],
          "palettes": [shapes.BOAT_PALETTE],
          "noRotates": [False],
      },
  }
  components = [
      state_manager,
      {"component": "Transform"},
      appearance,
      {"component": "Seat", "kwargs": {}},
  ]
  # Only the left seat carries the BoatManager component.
  if left:
    components.append({
        "component": "BoatManager",
        "kwargs": {"flailEffectiveness": 0.1},
    })
  return {"name": f"seat_{side}", "components": components}


def get_oar(left: bool) -> Prefab:
  """Get an oar prefab.

  The oar has three states, all on the "overlay" layer: `oarDown`,
  `oarUp_row` and `oarUp_flail`. Both raised states are drawn with the same
  raised-oar shape. Four extra tinted raised-oar sprites (green and red, for
  the left and right side) are registered through `AdditionalSprites`.

  Args:
    left: Whether the oar is left or right.

  Returns:
    A prefab config for the oar prefab.
  """
  suffix = "L" if left else "R"
  # Shapes in the same order as the sprite names: down, up (row), up (flail).
  shape = {
      "L": [shapes.OAR_DOWN_L, shapes.OAR_UP_L, shapes.OAR_UP_L],
      "R": [shapes.OAR_DOWN_R, shapes.OAR_UP_R, shapes.OAR_UP_R],
  }
  sprite_names = [
      f"OarDown{suffix}",
      f"OarUp{suffix}Row",
      f"OarUp{suffix}Flail",
  ]
  # The tinted sprites are the same for both oars, regardless of `left`.
  tinted_sprite_names = [
      "OarUpL_green",
      "OarUpR_green",
      "OarUpL_red",
      "OarUpR_red",
  ]
  return {
      "name": f"oar_{suffix}",
      "components": [
          {
              "component": "StateManager",
              "kwargs": {
                  "initialState": "oarDown",
                  "stateConfigs": [
                      {"state": "oarDown",
                       "layer": "overlay",
                       "sprite": sprite_names[0],
                       "groups": ["oar", "boat"]},

                      {"state": "oarUp_row",
                       "layer": "overlay",
                       "sprite": sprite_names[1],
                       "groups": ["oar", "boat"]},

                      {"state": "oarUp_flail",
                       "layer": "overlay",
                       "sprite": sprite_names[2],
                       "groups": ["oar", "boat"]},
                  ]
              }
          },
          {"component": "Transform",},
          {
              "component": "Appearance",
              "kwargs": {
                  "renderMode": "ascii_shape",
                  "spriteNames": sprite_names,
                  "spriteShapes": shape[suffix],
                  "palettes": [shapes.GRAY_PALETTE] * len(sprite_names),
                  "noRotates": [False] * len(sprite_names)
              }
          },
          {
              "component": "AdditionalSprites",
              "kwargs": {
                  "renderMode": "ascii_shape",
                  "customSpriteNames": tinted_sprite_names,
                  "customSpriteShapes": [
                      shapes.OAR_UP_L,
                      shapes.OAR_UP_R,
                      shapes.OAR_UP_L,
                      shapes.OAR_UP_R,
                  ],
                  "customPalettes": [
                      shapes.get_palette((0, 255, 0)),
                      shapes.get_palette((0, 255, 0)),
                      shapes.get_palette((255, 0, 0)),
                      shapes.get_palette((255, 0, 0)),
                  ],
                  # One flag per custom sprite, so this list stays the same
                  # length as the three parallel lists above.
                  "customNoRotates": [False] * len(tinted_sprite_names),
              }
          },
      ]
  }


def get_avatar(custom_sprite_map: Mapping[str, str],
               crown_alpha: float,
               crown_beta: float):
  """Builds the avatar prefab with a custom sprite map.

  Args:
    custom_sprite_map: A custom sprite map. Used for implementing perceptual
      interventions.
    crown_alpha: The exponential decay parameter for crown.
    crown_beta: The linear decay parameter for crown.

  Returns:
    A prefab config for the avatar prefab.
  """

  def visible_state(state_name, layer):
    # Every state that draws the avatar shares the same sprite, contact name
    # and group; only the state name and layer vary.
    return {
        "state": state_name,
        "layer": layer,
        "sprite": "Avatar",
        "contact": "avatar",
        "groups": ["players"],
    }

  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "player",
          "stateConfigs": [
              visible_state("player", "upperPhysical"),
              {"state": "playerWait", "groups": ["playerWaits"]},
              visible_state("rowing", "superOverlay"),
              visible_state("landed", "upperPhysical"),
          ],
      },
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": ["Avatar"],
          "spriteShapes": [shapes.HD_AVATAR_W_BADGE],
          "palettes": [shapes.get_palette(colors.palette[0])],
          "noRotates": [True],
      },
  }
  action_spec = {
      "move": {"default": 0, "min": 0, "max": len(_COMPASS)},
      "turn": {"default": 0, "min": -1, "max": 1},
      "row": {"default": 0, "min": 0, "max": 1},
      "flail": {"default": 0, "min": 0, "max": 1},
  }
  avatar_component = {
      "component": "Avatar",
      "kwargs": {
          "index": -1,  # player index to be overwritten.
          "aliveState": "player",
          "waitState": "playerWait",
          "spawnGroup": "spawnPoints",
          "actionOrder": ["move", "turn", "row", "flail"],
          "actionSpec": action_spec,
          "view": {
              "left": 6,
              "right": 6,
              "forward": 10,
              "backward": 2,
              "centered": False,
          },
          "spriteMap": custom_sprite_map,
      },
  }
  rowing = {
      "component": "Rowing",
      "kwargs": {"cooldownTime": 2},
  }
  crown = {
      "component": "Crown",
      "kwargs": {
          "turnOnThreshold": CROWN_TURN_ON_THRESHOLD,
          "turnOffThreshold": CROWN_TURN_OFF_THRESHOLD,
          "alpha": crown_alpha,
          "beta": crown_beta,
      },
  }

  # One badge-only sprite per player; sprite names use 1-based indices and
  # each sprite merges the base avatar palette with that player's badge
  # colors.
  player_indices = range(NUM_PLAYERS)
  badge_sprites = {
      "component": "AdditionalSprites",
      "kwargs": {
          "renderMode": "ascii_shape",
          "customSpriteNames": [
              "AvatarJustBadge{}".format(player + 1)
              for player in player_indices
          ],
          "customSpriteShapes": [shapes.JUST_BADGE] * NUM_PLAYERS,
          "customPalettes": [
              dict(**shapes.get_palette(colors.palette[0]),
                   **SHORTS_PALETTE[player])
              for player in player_indices
          ],
          "customNoRotates": [True] * NUM_PLAYERS,
      },
  }
  return {
      "name": "avatar",
      "components": [
          state_manager,
          {"component": "Transform"},
          appearance,
          avatar_component,
          rowing,
          crown,
          {"component": "StrokesTracker"},
          badge_sprites,
      ],
  }


def create_colored_avatar_overlay(player_idx: int) -> Prefab:
  """Builds the `crown_overlay` object connected to one player's avatar.

  Args:
    player_idx: Player index, zero-indexed.

  Returns:
    A colored avatar overlay object.
  """
  # NOTE(review): "crownOff" uses the layer name "logical" while the other
  # prefabs in this file use "logic" -- confirm this is intentional.
  crown_states = [
      {"state": "crownOff", "layer": "logical", "groups": ["crowns"]},
      {
          "state": "crownOn",
          "layer": "superCrownOverlay",
          "sprite": "CrownOnInvisible",
          "groups": ["crowns"],
      },
      {
          "state": "crownOnLanded",
          "layer": "superCrownOverlay",
          "sprite": "CrownOnInvisible",
          "groups": ["crowns"],
      },
      {"state": "crownWait", "groups": ["crownWaits"]},
  ]
  state_manager = {
      "component": "StateManager",
      "kwargs": {
          "initialState": "crownWait",
          "stateConfigs": crown_states,
      },
  }
  appearance = {
      "component": "Appearance",
      "kwargs": {
          "renderMode": "ascii_shape",
          "spriteNames": ["CrownOnInvisible"],
          "spriteShapes": [shapes.AVATAR_DEFAULT],
          "palettes": [shapes.INVISIBLE_PALETTE],
          "noRotates": [True],
      },
  }
  connector = {
      "component": "AvatarConnector",
      "kwargs": {
          # Lua is 1-indexed.
          "playerIndex": player_idx + 1,
          "aliveState": "crownOff",
          "waitState": "crownWait",
      },
  }
  visible_crown = {
      "component": "AdditionalSprites",
      "kwargs": {
          "renderMode": "ascii_shape",
          "customSpriteNames": ["CrownOnVisible"],
          "customSpriteShapes": [shapes.HD_CROWN],
          "customPalettes": [shapes.CROWN_PALETTE],
          "customNoRotates": [True],
      },
  }
  return {
      "name": "crown_overlay",
      "components": [
          state_manager,
          {"component": "Transform"},
          appearance,
          connector,
          visible_crown,
      ],
  }


def create_avatar_overlays(num_players: int) -> Sequence[Prefab]:
  """Builds one avatar overlay object per player.

  Args:
    num_players: How many players need an overlay. Player indices passed to
      `create_colored_avatar_overlay` are zero-based, `0..num_players - 1`.

  Returns:
    A list with `num_players` overlay objects, ordered by player index.
  """
  return [
      create_colored_avatar_overlay(player_idx)
      for player_idx in range(num_players)
  ]


# Colors for the badge (referred to as SHORTS in the code).
GRAY = (50, 50, 50, 255)
WHITE = (255, 255, 255, 255)

# One palette per player index. Each row lists the colors assigned to the
# palette keys "a", "b", "c" and "d", in that order.
SHORTS_PALETTE = [
    dict(zip("abcd", slot_colors))
    for slot_colors in (
        (GRAY, GRAY, GRAY, GRAY),
        (WHITE, GRAY, GRAY, GRAY),
        (WHITE, WHITE, GRAY, GRAY),
        (WHITE, WHITE, WHITE, GRAY),
        (WHITE, WHITE, WHITE, WHITE),
        (GRAY, WHITE, GRAY, WHITE),
    )
]

# PREFABS is a dictionary mapping names to template game objects that can
# be cloned and placed in multiple locations according to an ascii map.
# Bank-specific prefabs (goals, barriers, respawning apples) are all
# parameterized with a `BankSide` member so both banks are built the same way.
PREFABS = {
    "wall": WALL,
    "spawn_point": SPAWN_POINT,
    "water_blocking": get_water("upperPhysical"),
    "water_background": get_water("background"),
    "goal_north": get_goal(bank_side=BankSide.NORTH),
    "goal_south": get_goal(bank_side=BankSide.SOUTH),
    "barrier_north": get_barrier(bank_side=BankSide.NORTH),
    "barrier_south": get_barrier(bank_side=BankSide.SOUTH),
    "single_apple": SINGLE_APPLE,
    "respawning_apple_north": get_respawning_apple(bank_side=BankSide.NORTH),
    "respawning_apple_south": get_respawning_apple(bank_side=BankSide.SOUTH),
    "semaphore": SEMAPHORE,
    "boat_FL": get_boat(front=True, left=True),
    "boat_FR": get_boat(front=True, left=False),
    "boat_RL": get_boat(front=False, left=True),
    "boat_RR": get_boat(front=False, left=False),
    "seat_L": get_seat(left=True),
    "seat_R": get_seat(left=False),
    "oar_L": get_oar(left=True),
    "oar_R": get_oar(left=False),
}

# Avatar base colors.
PURPLE = (145, 30, 180)
TEAL = (30, 180, 145)

# PLAYER_COLOR_PALETTES is a list with each entry specifying the color to use
# for the player at the corresponding index: six palettes alternating between
# purple and teal, each produced by its own `shapes.get_palette` call.
# These correspond to the persistent agent colors, but are meaningless for the
# human player. They will be overridden by the environment builder if avatar
# configurations are specified.
PLAYER_COLOR_PALETTES = [
    shapes.get_palette(base_color) for base_color in (PURPLE, TEAL) * 3
]

# Primitive action components.
# Every action sets all four channels ("move", "turn", "row", "flail"); each
# non-trivial action is NOOP with exactly one channel overridden.
NOOP = {"move": 0, "turn": 0, "row": 0, "flail": 0}
FORWARD = {**NOOP, "move": 1}
STEP_RIGHT = {**NOOP, "move": 2}
BACKWARD = {**NOOP, "move": 3}
STEP_LEFT = {**NOOP, "move": 4}
TURN_LEFT = {**NOOP, "turn": -1}
TURN_RIGHT = {**NOOP, "turn": 1}
ROW = {**NOOP, "row": 1}
FLAIL = {**NOOP, "flail": 1}

# The discrete action set, in index order.
ACTION_SET = (
    NOOP,
    FORWARD,
    BACKWARD,
    STEP_LEFT,
    STEP_RIGHT,
    TURN_LEFT,
    TURN_RIGHT,
    ROW,
    FLAIL,
)


def get_config(
    player_palettes: Sequence[MutableMapping[str, Any]] = PLAYER_COLOR_PALETTES,
    num_races: int = NUM_RACES,
    crown_alpha: float = CROWN_ALPHA,
    crown_beta: float = CROWN_BETA,
    early_exit_on_any: bool = False) -> configdict.ConfigDict:
  """Builds the default configuration for training on the boat_race level.

  Note that every palette in `player_palettes` is updated IN PLACE with the
  `SHORTS_PALETTE` entry of the same index, so the caller's palettes (including
  the module-level default) carry the badge colors after this call.

  Args:
    player_palettes: One mutable palette per player, indexed by player.
    num_races: Number of races; the maximum episode length is
      `num_races * (PARTNER_DURATION + RACE_DURATION)` frames.
    crown_alpha: Forwarded to `get_avatar`.
    crown_beta: Forwarded to `get_avatar`.
    early_exit_on_any: Forwarded to `get_scene`.

  Returns:
    A `ConfigDict` with `num_races`, `early_exit_check_interval` and the
    `lab2d_settings` for the substrate.
  """
  config = configdict.ConfigDict()
  config.num_races = num_races
  config.early_exit_check_interval = EARLY_EXIT_CHECK_INTERVAL

  overlay_objects = create_avatar_overlays(NUM_PLAYERS)

  # Add the per-player badge colors to the palettes handed in by the caller.
  for player_idx, palette in enumerate(player_palettes):
    palette.update(SHORTS_PALETTE[player_idx])

  scene = get_scene(
      config.get_oneway_ref("early_exit_check_interval").get(),
      early_exit_on_any=early_exit_on_any)
  avatar_prefab = get_avatar({}, crown_alpha, crown_beta)
  simulation = {
      "map": ASCII_MAP,
      "scene": scene,
      "prefabs": {**PREFABS, "avatar": avatar_prefab},
      "gameObjects": overlay_objects,
      "charPrefabMap": CHAR_PREFAB_MAP,
      "playerPalettes": player_palettes,
  }

  # Expressed through a reference to `num_races` rather than the plain
  # argument value.
  max_episode_frames = (
      config.get_ref("num_races") * (PARTNER_DURATION + RACE_DURATION))

  # Lua script configuration.
  config.lab2d_settings = {
      "levelName": "boat_race",
      "levelDirectory": "meltingpot/lua/levels",
      "numPlayers": NUM_PLAYERS,
      "maxEpisodeLengthFrames": max_episode_frames,
      "spriteSize": 16,
      "simulation": simulation,
  }

  return config


def get_perceptual_intervention_sprite_map(
    crown: bool,
    just_badges: bool = False,
    oars: bool = False,
) -> Mapping[str, Any]:
  """Generates the naive learner sprite overrides.

  Args:
    crown: If true, map 'CrownOnInvisible' to 'CrownOnVisible'.
    just_badges: If true, map each 'Avatar<k>' sprite to 'AvatarJustBadge<k>'
      for k in 1..NUM_PLAYERS.
    oars: If true, map 'OarUpLRow' and 'OarUpRRow' to their '_green' variants.

  Returns:
    A dict from original sprite name to replacement sprite name, holding the
    badge entries first, then the oar entries, then the crown entry.
  """
  sprite_map = {}

  if just_badges:
    for player_number in range(1, NUM_PLAYERS + 1):
      sprite_map[f'Avatar{player_number}'] = f'AvatarJustBadge{player_number}'

  if oars:
    sprite_map['OarUpLRow'] = 'OarUpL_green'
    sprite_map['OarUpRRow'] = 'OarUpR_green'

  if crown:
    sprite_map['CrownOnInvisible'] = 'CrownOnVisible'

  return sprite_map


def build_environment(avatar_configs: Sequence[Mapping[str, Any]] = tuple(),
                      num_races: int = NUM_RACES,
                      crown_alpha: float = CROWN_ALPHA,
                      crown_beta: float = CROWN_BETA,
                      early_exit_on_any: bool = False,
                      seed: int | None = None):
  """Builds the environment.

  Avatar configurations are optional. If present, they must be the same length
  as the number of players. An avatar config may have an 'avatar_color' key
  with a tuple of (color_R, color_G, color_B) where each color channel is an
  int in the range 0-255. Custom colors are used only when the first config
  has that key, in which case every config must have it; otherwise the default
  `PLAYER_COLOR_PALETTES` are used.

  Perceptual interventions are activated on a player by having a 'spriteMap'
  key. The value can be any sprite map, but the convenience function
  `get_perceptual_intervention_sprite_map` is provided to simplify this.

  Args:
    avatar_configs: A list of avatar configurations, one per player, or an
      empty sequence for the defaults.
    num_races: Forwarded to `get_config`.
    crown_alpha: Forwarded to `get_config`.
    crown_beta: Forwarded to `get_config`.
    early_exit_on_any: Forwarded to `get_config`.
    seed: Passed to the substrate builder as `env_seed`.

  Returns:
    The meltingpot environment.

  Raises:
    ValueError: If `avatar_configs` is non-empty and its length differs from
      `NUM_PLAYERS`.
  """
  # Validate up front, regardless of which keys the configs carry: the
  # 'spriteMap' pass below pairs configs with avatars one-to-one.
  if avatar_configs and len(avatar_configs) != NUM_PLAYERS:
    raise ValueError(
        f'Expected {NUM_PLAYERS} avatar configs (one per player), '
        f'got {len(avatar_configs)}.')

  player_palettes = PLAYER_COLOR_PALETTES
  if avatar_configs and 'avatar_color' in avatar_configs[0]:
    player_palettes = [
        shapes.get_palette(avatar_config['avatar_color'])
        for avatar_config in avatar_configs
    ]

  # get_config adds the badge colors to `player_palettes` in place, so the
  # same list is reused when building the avatars below.
  config: configdict.ConfigDict = get_config(
      player_palettes=player_palettes,
      num_races=num_races,
      crown_alpha=crown_alpha,
      crown_beta=crown_beta,
      early_exit_on_any=early_exit_on_any,
  )
  avatar_objects = game_object_utils.build_avatar_objects(
      NUM_PLAYERS,
      config.lab2d_settings.simulation.prefabs,
      player_palettes)
  # We don't need the prefab anymore and we don't want automatic avatars built.
  del config.lab2d_settings.simulation.prefabs['avatar']

  # Apply the ordered (slot-based) sprite map overrides, if any. With no
  # avatar configs the zip is empty and nothing is overridden.
  for avatar_object, avatar_config in zip(avatar_objects, avatar_configs):
    if 'spriteMap' in avatar_config:
      game_object_utils.get_first_named_component(
          avatar_object, "Avatar"
      )["kwargs"]["spriteMap"] = avatar_config["spriteMap"]

  config.lab2d_settings.simulation.gameObjects += avatar_objects
  return builder.builder(config.lab2d_settings, {}, env_seed=seed)


In [None]:
#@title Example of a default configuration (3 players of each color, no perceptual interventions)

# With an empty `avatar_configs`, build_environment keeps the default
# PLAYER_COLOR_PALETTES and applies no per-player 'spriteMap' overrides.
env = build_environment(avatar_configs=[])

# Reset the episode and read the 'WORLD.RGB' entry of the first timestep's
# observation. As the last expression in the cell it is presumably what the
# notebook renders as the cell output.
env.reset().observation['WORLD.RGB']

In [None]:
#@title Example of a custom configuration with first player being purple & with crown $\beta = 0.02$ intervention, rest are teal without intervention

# One config per player: everyone starts teal, then slot 0 is recolored purple.
avatar_configs = [{'avatar_color': TEAL} for _ in range(NUM_PLAYERS)]
avatar_configs[0]['avatar_color'] = PURPLE
# crown=True yields the {'CrownOnInvisible': 'CrownOnVisible'} sprite map,
# attached to slot 0 only.
avatar_configs[0]['spriteMap'] = get_perceptual_intervention_sprite_map(crown=True)

# crown_beta is forwarded to get_config; seed is handed to the builder as
# env_seed.
env = build_environment(avatar_configs=avatar_configs, crown_beta=0.02, seed=28)

# Keep the first timestep in `ts`; the cells below overwrite it after stepping.
ts = env.reset()
ts.observation['WORLD.RGB']

In [None]:
#@title Manually move players 1 & 3 into their boat
# In the action table of this file, move=4 is STEP_LEFT and move=2 is
# STEP_RIGHT. Only players 1 and 3 are given an action; the step is repeated
# 80 times and `ts` ends up holding the last timestep.
for _ in range(80):
  ts = env.step({'1.move': 4, '3.move': 2})

In [None]:
#@title Players 1 & 3 should be boarded at the boat
# Join the 'RGB' observations of players 1 and 3 along axis 1 (presumably the
# image width, i.e. the two views side by side — confirm).
np.concatenate([ts.observation['1.RGB'], ts.observation['3.RGB']], axis=1)

In [None]:
#@title Manually row for both players (should get a crown)
# row=1 matches the ROW action in this file's action table. Both players issue
# it for 40 consecutive steps; `ts` keeps the last timestep.
for _ in range(40):
  ts = env.step({'1.row': 1, '3.row': 1})

In [None]:
#@title Players 1 & 3 should have crowns only from player 1's perspective
# Same comparison as before: the 'RGB' observations of players 1 and 3 joined
# along axis 1. Only avatar config slot 0 was given the crown sprite map when
# the environment was built.
np.concatenate([ts.observation['1.RGB'], ts.observation['3.RGB']], axis=1)

In [None]:
#@title Do nothing for a while (crown should decay)
# Step 150 times with an empty action dict, i.e. without specifying an action
# for any player (presumably treated as no-ops by the environment — confirm).
for _ in range(150):
  ts = env.step({})

In [None]:
#@title Players 1 & 3 should not have crowns anymore
# Join the latest 'RGB' observations of players 1 and 3 along axis 1 for a
# final visual check.
np.concatenate([ts.observation['1.RGB'], ts.observation['3.RGB']], axis=1)