<a href="https://colab.research.google.com/github/longislander/coursera/blob/main/chinese_english_ppo_sb3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PPO (Stable-Baselines3)

## Changes in Version 8 (V8)

**1. Simulate the Competition's Data Transmission Format in the Training Environment:**

```json
// T is the number of teams (default is 2)
// N is the max number of units per team
// W, H are the width and height of the map
// R is the max number of relic nodes
{
  "obs": {
    "units": {
      "position": Array(T, N, 2),
      "energy": Array(T, N, 1)
    },
    // Indicates whether the unit exists and is visible to you. units_mask[t][i] shows if team t's unit i can be seen and exists.
    "units_mask": Array(T, N),
    // Indicates whether the tile is visible to the unit for that team
    "sensor_mask": Array(W, H),
    "map_features": {
        // Amount of energy on the tile
        "energy": Array(W, H),
        // Type of the tile. 0 is empty, 1 is a nebula tile, 2 is asteroid
        "tile_type": Array(W, H)
    },
    // Indicates whether the relic node exists and is visible to you.
    "relic_nodes_mask": Array(R),
    // Position of the relic nodes.
    "relic_nodes": Array(R, 2),
    // Points scored by each team in the current match
    "team_points": Array(T),
    // Number of wins each team has in the current game/episode
    "team_wins": Array(T),
    // Number of steps taken in the current game/episode
    "steps": int,
    // Number of steps taken in the current match
    "match_steps": int
  },
  "remainingOverageTime": int, // Total amount of overage time your bot can draw on whenever a turn exceeds 2s
  "player": str, // Your player id
  "info": {
    "env_cfg": dict // Some of the game's visible parameters
  }
}
```
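
As a rough illustration of how a bot might read this payload, the sketch below pulls out the units that are actually visible for one team. It uses plain nested lists in place of arrays; the helper name `visible_units` and the sample values are assumptions made for this example, not part of the competition kit.

```python
def visible_units(obs, team_id):
    """Return [(unit_id, (x, y), energy), ...] for units whose units_mask entry is set."""
    positions = obs["units"]["position"][team_id]   # shape (N, 2)
    energies = obs["units"]["energy"][team_id]      # shape (N, 1)
    mask = obs["units_mask"][team_id]               # shape (N,)
    result = []
    for unit_id, is_visible in enumerate(mask):
        if is_visible:
            x, y = positions[unit_id]
            result.append((unit_id, (x, y), energies[unit_id][0]))
    return result


# Tiny hand-made observation: T = 2 teams, N = 2 unit slots per team.
sample_obs = {
    "units": {
        "position": [[[0, 0], [-1, -1]], [[-1, -1], [-1, -1]]],
        "energy": [[[100], [-1]], [[-1], [-1]]],
    },
    "units_mask": [[True, False], [False, False]],
}

print(visible_units(sample_obs, team_id=0))
```

Filtering on `units_mask` first avoids treating the placeholder entries of hidden or missing units as real positions.
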

**2. Gradually Increase the Number of Agents (Following the Competition Rules)**

**3. Calculate the Field of View Separately for Each Agent ID**

In the competition, the returned JSON contains a single field-of-view observation for the whole team. To compute a separate field of view for each unit and generate a separate prediction per unit, the architecture is as follows:

![ref](https://my-typora-p1.oss-cn-beijing.aliyuncs.com/typoraImgs/image-20250211233804718.png)
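
A minimal sketch of the per-unit idea, assuming a square sensor window (Chebyshev distance) and ignoring nebula effects. The names `unit_vision` and `team_vision` are invented for this illustration, and positions are plain `(x, y)` tuples rather than the notebook's unit dicts.

```python
def unit_vision(x, y, sensor_range, size):
    """Set of (x, y) tiles within Chebyshev distance sensor_range, clipped to the map."""
    tiles = set()
    for dy in range(-sensor_range, sensor_range + 1):
        for dx in range(-sensor_range, sensor_range + 1):
            nx, ny = x + dx, y + dy
            if 0 <= nx < size and 0 <= ny < size:
                tiles.add((nx, ny))
    return tiles


def team_vision(units, sensor_range, size):
    """Union of every unit's vision: what the team-level sensor mask represents."""
    seen = set()
    for x, y in units:
        seen |= unit_vision(x, y, sensor_range, size)
    return seen


units = [(0, 0), (5, 5)]
per_unit = [unit_vision(x, y, 2, 24) for x, y in units]
print(len(per_unit[0]), len(per_unit[1]), len(team_vision(units, 2, 24)))
```

Each unit gets its own, smaller view for its prediction, while the union reproduces the team-wide view that the competition JSON reports.
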

**4. Reward Function Optimization**

```markdown
1. Each unit calculates its `unit_reward` independently.

2. If a movement action would take the unit out of bounds or onto an Asteroid tile, the action is deemed invalid, and `unit_reward` decreases by **0.2**.

3. **Sap action:**

   - Check whether the `relic_nodes_mask` in the unit’s local observation contains a relic.
   - If a relic is present, count the number of enemy units within the unit's 8-neighbor region:
     - If the count is **≥2**, the sap reward is **+1.0 × the number of enemy units**; otherwise, `unit_reward` decreases by **2.0**.
   - If no relic is visible, `unit_reward` also decreases by **2.0**.

4. **Non-sap actions:**

   - After a successful movement, check whether the unit is located at a potential point within any relic's configuration:

     - If it is the **first visit** to this potential point, `unit_reward` increases by **+2.0**, and it is marked as `visited`.
     - If the potential point has **not yet contributed to the team’s score**, increase `self.score` by 1, add **+5.0** to `unit_reward`, and mark it as `team_points_space`.
     - If the unit is **already on a `team_points_space`**, it receives a **+5.0** reward each turn.

   - If the unit is on an **energy node** (`energy == Global.MAX_ENERGY_PER_TILE`), `unit_reward` increases by **+0.2**.

   - If the unit is on a **Nebula** (`tile_type == 1`), `unit_reward` decreases by **0.2**.

   - If the unit moves and overlaps with an enemy unit **while having higher energy than the enemy**, each enemy unit that meets this condition grants a **+1.0** reward.

5. **Global exploration reward:** Each newly discovered tile within the **combined vision** of all friendly units grants **+0.1** reward per tile.

6. At the **end of each step**, the final reward is calculated as **(point reward × 0.3) + (rule-based reward × 0.7)**.
```
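
A few of the rules above can be sketched as small pure functions. This is only an illustration under stated assumptions: the function names are hypothetical, asteroid tiles are passed in as a set of `(x, y)` tuples, and how the point reward and the rule-based reward are accumulated before blending is not shown here.

```python
POINT_WEIGHT = 0.3  # weight of the point reward (item 6)
RULE_WEIGHT = 0.7   # weight of the rule-based reward (item 6)


def sap_reward(relic_visible, enemies_in_8_neighborhood):
    """Item 3: reward a sap only when a relic is visible and >= 2 enemies are adjacent."""
    if relic_visible and enemies_in_8_neighborhood >= 2:
        return 1.0 * enemies_in_8_neighborhood
    return -2.0


def move_penalty(target, size, asteroid_tiles):
    """Item 2: -0.2 if the target tile is off the map or an asteroid, else 0.0."""
    x, y = target
    off_map = not (0 <= x < size and 0 <= y < size)
    if off_map or target in asteroid_tiles:
        return -0.2
    return 0.0


def step_reward(point_reward, rule_reward):
    """Item 6: blend the two reward terms into the final per-step reward."""
    return POINT_WEIGHT * point_reward + RULE_WEIGHT * rule_reward


print(sap_reward(True, 3), sap_reward(True, 1), sap_reward(False, 4))
print(move_penalty((24, 0), 24, set()), move_penalty((3, 3), 24, {(3, 3)}))
print(round(step_reward(1.0, 2.0), 2))
```

Keeping each rule in its own function makes it easier to check individual terms when the trained behavior does not match the intended incentives.
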

**5. Load the Original Model into the `step` Process for Adversarial Training (To Be Completed)**

## Existing Issues
The behavior of the trained agent does not yet match what the reward function was designed to encourage.

This is intended to be the final version of the notebook. It serves as a good starting framework, although its results still need improvement.

In [None]:
! pip install stable-baselines3



In [None]:
! mkdir agent
! cp -r /kaggle/input/lux-ai-season-3/lux agent

## base.py

Game constants and some useful helper functions.
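
To make the step and symmetry helpers in this file concrete, here is a small standalone sketch. It copies three of the helpers with the constants inlined (the values match `Global.MAX_STEPS_IN_MATCH` and `Global.SPACE_SIZE`) so that it runs without the rest of the file.

```python
MAX_STEPS_IN_MATCH = 100  # same value as Global.MAX_STEPS_IN_MATCH
SPACE_SIZE = 24           # same value as Global.SPACE_SIZE


def get_match_step(step):
    # Each match spans MAX_STEPS_IN_MATCH + 1 global steps (0..100).
    return step % (MAX_STEPS_IN_MATCH + 1)


def get_match_number(step):
    return step // (MAX_STEPS_IN_MATCH + 1)


def get_opposite(x, y):
    # Mirror a point across the anti-diagonal of the map.
    return SPACE_SIZE - y - 1, SPACE_SIZE - x - 1


print(get_match_step(100), get_match_number(100))  # 100 0
print(get_match_step(101), get_match_number(101))  # 0 1
print(get_opposite(0, 0))                          # (23, 23)
print(get_opposite(2, 5))                          # (18, 21)
```
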

In [None]:
#%%writefile agent/base.py

from enum import IntEnum


class Global:

    # Game related constants:

    SPACE_SIZE = 24
    MAX_UNITS = 16
    RELIC_REWARD_RANGE = 2
    MAX_STEPS_IN_MATCH = 100
    MAX_ENERGY_PER_TILE = 20
    MAX_RELIC_NODES = 6
    LAST_MATCH_STEP_WHEN_RELIC_CAN_APPEAR = 50
    LAST_MATCH_WHEN_RELIC_CAN_APPEAR = 2

    # We will find the exact value of these constants during the game
    UNIT_MOVE_COST = 1  # OPTIONS: list(range(1, 6))
    UNIT_SAP_COST = 30  # OPTIONS: list(range(30, 51))
    UNIT_SAP_RANGE = 3  # OPTIONS: list(range(3, 8))
    UNIT_SENSOR_RANGE = 2  # OPTIONS: [1, 2, 3, 4]
    OBSTACLE_MOVEMENT_PERIOD = 20  # OPTIONS: 6.67, 10, 20, 40
    OBSTACLE_MOVEMENT_DIRECTION = (0, 0)  # OPTIONS: [(1, -1), (-1, 1)]

    # We will NOT find the exact value of these constants during the game
    NEBULA_ENERGY_REDUCTION = 5  # OPTIONS: [0, 1, 2, 3, 5, 25]

    # Exploration flags:

    ALL_RELICS_FOUND = False
    ALL_REWARDS_FOUND = False
    OBSTACLE_MOVEMENT_PERIOD_FOUND = False
    OBSTACLE_MOVEMENT_DIRECTION_FOUND = False

    # Game logs:

    # REWARD_RESULTS: [{"nodes": Set[Node], "points": int}, ...]
    # A history of reward events, where each entry contains:
    # - "nodes": A set of nodes where our ships were located.
    # - "points": The number of points scored at that location.
    # This data will help identify which nodes yield points.
    REWARD_RESULTS = []

    # OBSTACLES_MOVEMENT_STATUS: list of bool
    # A history log of obstacle (asteroids and nebulae) movement events.
    # - `True`: The ships' sensors detected a change in the obstacles' positions at this step.
    # - `False`: The sensors did not detect any changes.
    # This information will be used to determine the speed and direction of obstacle movement.
    OBSTACLES_MOVEMENT_STATUS = []

    # Others:

    # The energy on the unknown tiles will be used in the pathfinding
    HIDDEN_NODE_ENERGY = 0


SPACE_SIZE = Global.SPACE_SIZE


class NodeType(IntEnum):
    unknown = -1
    empty = 0
    nebula = 1
    asteroid = 2

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name


_DIRECTIONS = [
    (0, 0),  # center
    (0, -1),  # up
    (1, 0),  # right
    (0, 1),  # down
    (-1, 0),  # left
    (0, 0),  # sap
]


class ActionType(IntEnum):
    center = 0
    up = 1
    right = 2
    down = 3
    left = 4
    sap = 5

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    @classmethod
    def from_coordinates(cls, current_position, next_position):
        dx = next_position[0] - current_position[0]
        dy = next_position[1] - current_position[1]

        if dx < 0:
            return ActionType.left
        elif dx > 0:
            return ActionType.right
        elif dy < 0:
            return ActionType.up
        elif dy > 0:
            return ActionType.down
        else:
            return ActionType.center

    def to_direction(self):
        return _DIRECTIONS[self]


def get_match_step(step: int) -> int:
    return step % (Global.MAX_STEPS_IN_MATCH + 1)


def get_match_number(step: int) -> int:
    return step // (Global.MAX_STEPS_IN_MATCH + 1)


# def warp_int(x):
#     if x >= SPACE_SIZE:
#         x -= SPACE_SIZE
#     elif x < 0:
#         x += SPACE_SIZE
#     return x


# def warp_point(x, y) -> tuple:
#     return warp_int(x), warp_int(y)


def get_opposite(x, y) -> tuple:
    # Returns the point mirrored across the anti-diagonal (the axis that swaps the two spawn corners)
    return SPACE_SIZE - y - 1, SPACE_SIZE - x - 1


def is_upper_sector(x, y) -> bool:
    return SPACE_SIZE - x - 1 >= y


def is_lower_sector(x, y) -> bool:
    return SPACE_SIZE - x - 1 <= y


def is_team_sector(team_id, x, y) -> bool:
    return is_upper_sector(x, y) if team_id == 0 else is_lower_sector(x, y)


## ppo_game_env.py
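
The environment in this file exposes a flat dictionary observation (keys such as `units_position` and `env_cfg_map_width`), while the competition payload is nested. The sketch below shows one way such a flattening could be done generically. `flatten_observation` is a hypothetical helper written for this illustration and is not the notebook's `get_obs`, which builds the flat dictionary by hand.

```python
def flatten_observation(observation):
    """Flatten the nested competition payload into underscore-joined keys."""
    flat = {}

    def walk(node, prefix):
        for key, value in node.items():
            name = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                walk(value, name)      # e.g. units -> units_position, units_energy
            else:
                flat[name] = value

    walk(observation["obs"], "")       # the "obs" prefix itself is dropped
    walk(observation["info"], "")      # env_cfg -> env_cfg_<parameter>
    flat["remainingOverageTime"] = observation["remainingOverageTime"]
    return flat                        # "player" is intentionally left out


nested = {
    "obs": {
        "units": {"position": [[0, 0]], "energy": [[100]]},
        "units_mask": [[1]],
        "map_features": {"tile_type": [[0]], "energy": [[0]]},
        "steps": 0,
    },
    "remainingOverageTime": 60,
    "player": "player_0",
    "info": {"env_cfg": {"map_width": 24}},
}

print(sorted(flatten_observation(nested)))
```
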

In [None]:
#%%writefile agent/ppo_game_env.py
import sys
import gym
from gym import spaces
import numpy as np

# Import global constants and helper functions from base
#from base import Global, ActionType, SPACE_SIZE, get_opposite

# Constants: number of teams, max units per team, max relic nodes
NUM_TEAMS = 2
MAX_UNITS = Global.MAX_UNITS
MAX_RELIC_NODES = Global.MAX_RELIC_NODES

class PPOGameEnv(gym.Env):
    """
    PPOGameEnv simulates the real competition environment as closely as possible
    and meets the following requirements:

    1. Observation calculation:
       - Each friendly unit has its own sensor mask (computed by
         compute_unit_vision(unit)), and get_unit_obs(unit) builds a local
         observation dict in the fixed competition format.
       - The global observation returned to the agent uses the union (logical OR)
         of all friendly units' sensor masks, keeping the fixed format of the
         competition obs.

    2. Reward function:
       step() updates the environment state according to the actions and returns
       (observation, reward, done, info). The reward logic is:
         1. Each unit's unit_reward is computed separately.
         2. A move that would leave the map or enter an Asteroid tile is invalid:
            unit_reward -0.2.
         3. Sap action:
            - Check whether relic_nodes_mask in the unit's local obs shows a relic.
            - If so, count the enemy units in the unit's 8-neighborhood; if the
              count is >= 2, the sap reward is +1.0 × the number of enemy units,
              otherwise unit_reward -2.0.
            - If no relic is visible, unit_reward -2.0 as well.
         4. Non-sap actions:
            - After a successful move, check whether the unit is on a potential
              point of any relic configuration:
                * First visit to the potential point: unit_reward +2.0, and the
                  point is marked as visited.
                * If the point has not yet yielded a team point: self.score += 1,
                  unit_reward +5.0, and the point is marked in team_points_space.
                * If the unit is already on a team_points_space tile:
                  unit_reward +5.0 each turn.
            - Unit on an energy node (energy == Global.MAX_ENERGY_PER_TILE):
              unit_reward +0.2.
            - Unit on a Nebula tile (tile_type == 1): unit_reward -0.2.
            - If the unit overlaps enemy units after moving, +1.0 for each enemy
              whose energy is lower than the unit's.
         5. Global exploration reward: +0.1 for each newly discovered tile in the
            combined vision of all friendly units.
         6. At the end of each step, the final reward is
            point reward * 0.3 + rule-based reward * 0.7.

    3. Enemy unit policy:
       - Enemy units take no actions after spawning. Their positions change only
         when the environment rolls as a whole every 20 steps (shift right by
         1 tile), so they are passive opponents. This is mainly for early
         debugging; a more active opponent can be introduced later.
    """

    def __init__(self):
        super(PPOGameEnv, self).__init__()

        # Action space: each unit decides independently (action values 0-5)
        self.action_space = spaces.MultiDiscrete([len(ActionType)] * MAX_UNITS)

        # The observation space is unchanged from the previous version
        self.observation_space = spaces.Dict({
            "units_position": spaces.Box(
                low=-1,  # -1 marks a unit that is absent or not visible (see get_obs)
                high=SPACE_SIZE - 1,
                shape=(NUM_TEAMS, MAX_UNITS, 2),
                dtype=np.int32
            ),
            "units_energy": spaces.Box(
                low=-1,  # -1 marks a unit that is absent or not visible (see get_obs)
                high=400,  # unit energy cap
                shape=(NUM_TEAMS, MAX_UNITS, 1),
                dtype=np.int32
            ),
            "units_mask": spaces.Box(
                low=0,
                high=1,
                shape=(NUM_TEAMS, MAX_UNITS),
                dtype=np.int8
            ),
            "sensor_mask": spaces.Box(
                low=0,
                high=1,
                shape=(SPACE_SIZE, SPACE_SIZE),
                dtype=np.int8
            ),
            "map_features_tile_type": spaces.Box(
                low=-1,
                high=2,
                shape=(SPACE_SIZE, SPACE_SIZE),
                dtype=np.int8
            ),
            "map_features_energy": spaces.Box(
                low=-1,
                high=Global.MAX_ENERGY_PER_TILE,
                shape=(SPACE_SIZE, SPACE_SIZE),
                dtype=np.int8
            ),
            "relic_nodes_mask": spaces.Box(
                low=0,
                high=1,
                shape=(MAX_RELIC_NODES,),
                dtype=np.int8
            ),
            "relic_nodes": spaces.Box(
                low=-1,
                high=SPACE_SIZE - 1,
                shape=(MAX_RELIC_NODES, 2),
                dtype=np.int32
            ),
            "team_points": spaces.Box(
                low=0,
                high=1000,
                shape=(NUM_TEAMS,),
                dtype=np.int32
            ),
            "team_wins": spaces.Box(
                low=0,
                high=1000,
                shape=(NUM_TEAMS,),
                dtype=np.int32
            ),
            "steps": spaces.Box(
                low=0, high=Global.MAX_STEPS_IN_MATCH, shape=(1,), dtype=np.int32
            ),
            "match_steps": spaces.Box(
                low=0, high=Global.MAX_STEPS_IN_MATCH, shape=(1,), dtype=np.int32
            ),
            "remainingOverageTime": spaces.Box(
                low=0, high=1000, shape=(1,), dtype=np.int32
            ),
            "env_cfg_map_width": spaces.Box(
                low=0, high=SPACE_SIZE, shape=(1,), dtype=np.int32
            ),
            "env_cfg_map_height": spaces.Box(
                low=0, high=SPACE_SIZE, shape=(1,), dtype=np.int32
            ),
            "env_cfg_max_steps_in_match": spaces.Box(
                low=0, high=Global.MAX_STEPS_IN_MATCH, shape=(1,), dtype=np.int32
            ),
            "env_cfg_unit_move_cost": spaces.Box(
                low=0, high=100, shape=(1,), dtype=np.int32
            ),
            "env_cfg_unit_sap_cost": spaces.Box(
                low=0, high=100, shape=(1,), dtype=np.int32
            ),
            "env_cfg_unit_sap_range": spaces.Box(
                low=0, high=100, shape=(1,), dtype=np.int32
            )
        })

        self.max_steps = Global.MAX_STEPS_IN_MATCH
        self.current_step = 0

        # Full-map state: tile map, relic markers, energy map
        self.tile_map = None     # -1 unknown, 0 empty, 1 nebula, 2 asteroid
        self.relic_map = None    # relic presence flag; 1 means a relic is present
        self.energy_map = None   # energy value of each tile

        # Unit state: lists of friendly and enemy units; each unit is a dict
        # {"x": int, "y": int, "energy": int}
        self.team_units = []    # friendly units
        self.enemy_units = []   # enemy units

        # Spawn points: friendly units spawn in the top-left corner, enemy units in the bottom-right corner
        self.team_spawn = (0, 0)
        self.enemy_spawn = (SPACE_SIZE - 1, SPACE_SIZE - 1)

        # Exploration record: full-map boolean array of tiles already seen in the
        # team's combined vision (each tile is recorded only once)
        self.visited = None

        # Team score (friendly team)
        self.score = 0

        # Subset of the simulated environment's parameters (env_cfg)
        self.env_cfg = {
            "map_width": SPACE_SIZE,
            "map_height": SPACE_SIZE,
            "max_steps_in_match": Global.MAX_STEPS_IN_MATCH,
            "unit_move_cost": Global.UNIT_MOVE_COST,
            "unit_sap_cost": Global.UNIT_SAP_COST if hasattr(Global, "UNIT_SAP_COST") else 30,
            "unit_sap_range": Global.UNIT_SAP_RANGE,
        }

        # New: state used for relic-configuration rewards
        self.relic_configurations = []   # list of (center_x, center_y, mask(5x5 bool))
        self.potential_visited = None      # full-map record, shape (SPACE_SIZE, SPACE_SIZE)
        self.team_points_space = None      # full-map record of tiles that have already contributed a team point
        self._init_state()

    def _init_state(self):
        """Initialize the full-map state, units, and records."""
        num_tiles = SPACE_SIZE * SPACE_SIZE

        # Initialize tile_map: a random subset of tiles becomes Nebula (1) or Asteroid (2)
        self.tile_map = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=np.int8)
        num_nebula = int(num_tiles * 0.1)
        num_asteroid = int(num_tiles * 0.1)
        indices = np.random.choice(num_tiles, num_nebula + num_asteroid, replace=False)
        flat_tiles = self.tile_map.flatten()
        flat_tiles[indices[:num_nebula]] = 1
        flat_tiles[indices[num_nebula:]] = 2
        self.tile_map = flat_tiles.reshape((SPACE_SIZE, SPACE_SIZE))

        # Initialize relic_map: pick 3 random positions and set them to 1 (relic present)
        self.relic_map = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=np.int8)
        relic_indices = np.random.choice(num_tiles, 3, replace=False)
        flat_relic = self.relic_map.flatten()
        flat_relic[relic_indices] = 1
        self.relic_map = flat_relic.reshape((SPACE_SIZE, SPACE_SIZE))

        # Initialize energy_map: place 2 random energy nodes with value MAX_ENERGY_PER_TILE; all other tiles are 0
        self.energy_map = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=np.int8)
        num_energy_nodes = 2
        indices_energy = np.random.choice(num_tiles, num_energy_nodes, replace=False)
        flat_energy = self.energy_map.flatten()
        for idx in indices_energy:
            flat_energy[idx] = Global.MAX_ENERGY_PER_TILE
        self.energy_map = flat_energy.reshape((SPACE_SIZE, SPACE_SIZE))

        # Initialize friendly units: start with 1 unit at team_spawn
        self.team_units = []
        spawn_x, spawn_y = self.team_spawn
        self.team_units.append({"x": spawn_x, "y": spawn_y, "energy": 100})

        # Initialize enemy units: start with 1 unit at enemy_spawn
        self.enemy_units = []
        spawn_x_e, spawn_y_e = self.enemy_spawn
        self.enemy_units.append({"x": spawn_x_e, "y": spawn_y_e, "energy": 100})

        # Initialize the exploration record: a full-map array marking the tiles seen in the friendly units' combined vision
        self.visited = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=bool)
        union_mask = self.get_global_sensor_mask()
        self.visited = union_mask.copy()

        # Initialize the team score
        self.score = 0

        # New: initialize the relic configurations and the potential-point records
        self.relic_configurations = []
        relic_coords = np.argwhere(self.relic_map == 1)
        for (y, x) in relic_coords:
            # Build a 5x5 mask with 5 randomly chosen cells set to True
            mask = np.zeros((5,5), dtype=bool)
            indices = np.random.choice(25, 5, replace=False)
            mask_flat = mask.flatten()
            mask_flat[indices] = True
            mask = mask_flat.reshape((5,5))
            self.relic_configurations.append((x, y, mask))
        self.potential_visited = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=bool)
        self.team_points_space = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=bool)

        self.current_step = 0

    def compute_unit_vision(self, unit):
        """
        Compute the given unit's own sensor mask from its position.
        The mask covers the unit's sensor range (Chebyshev distance), with a
        reduced contribution on Nebula tiles. There is no wraparound: only
        tiles inside the map are considered.
        Returns a boolean matrix of shape (SPACE_SIZE, SPACE_SIZE).
        """
        sensor_range = Global.UNIT_SENSOR_RANGE
        nebula_reduction = 2
        vision = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=np.float32)
        x, y = unit["x"], unit["y"]
        for dy in range(-sensor_range, sensor_range + 1):
            for dx in range(-sensor_range, sensor_range + 1):
                new_x = x + dx
                new_y = y + dy
                if not (0 <= new_x < SPACE_SIZE and 0 <= new_y < SPACE_SIZE):
                    continue
                contrib = sensor_range + 1 - max(abs(dx), abs(dy))
                if self.tile_map[new_y, new_x] == 1:
                    contrib -= nebula_reduction
                vision[new_y, new_x] += contrib
        return vision > 0

    def get_global_sensor_mask(self):
        """
        Return the union (logical OR) of all friendly units' sensor masks.
        """
        mask = np.zeros((SPACE_SIZE, SPACE_SIZE), dtype=bool)
        for unit in self.team_units:
            mask |= self.compute_unit_vision(unit)
        return mask

    def get_unit_obs(self, unit):
        """
        Build a local observation dict from the given unit's own sensor mask.
        The format matches the fixed JSON format returned by the competition.
        Only the area this unit itself can see is used for filtering.
        """
        sensor_mask = self.compute_unit_vision(unit)
        map_tile_type = np.where(sensor_mask, self.tile_map, -1)
        map_energy = np.where(sensor_mask, self.energy_map, -1)
        map_features = {"tile_type": map_tile_type, "energy": map_energy}
        sensor_mask_int = sensor_mask.astype(np.int8)

        # Build unit info, filtering friendly and enemy units separately with this unit's sensor mask
        units_position = np.full((NUM_TEAMS, MAX_UNITS, 2), -1, dtype=np.int32)
        units_energy = np.full((NUM_TEAMS, MAX_UNITS, 1), -1, dtype=np.int32)
        units_mask = np.zeros((NUM_TEAMS, MAX_UNITS), dtype=np.int8)
        for i, u in enumerate(self.team_units):
            ux, uy = u["x"], u["y"]
            if sensor_mask[uy, ux]:
                units_position[0, i] = np.array([ux, uy])
                units_energy[0, i] = u["energy"]
                units_mask[0, i] = 1
        for i, u in enumerate(self.enemy_units):
            ux, uy = u["x"], u["y"]
            if sensor_mask[uy, ux]:
                units_position[1, i] = np.array([ux, uy])
                units_energy[1, i] = u["energy"]
                units_mask[1, i] = 1
        units = {"position": units_position, "energy": units_energy}

        # Build relic_nodes info: only relic coordinates inside sensor_mask are exposed
        relic_coords = np.argwhere(self.relic_map == 1)
        relic_nodes = np.full((MAX_RELIC_NODES, 2), -1, dtype=np.int32)
        relic_nodes_mask = np.zeros(MAX_RELIC_NODES, dtype=np.int8)
        idx = 0
        for (ry, rx) in relic_coords:
            if idx >= MAX_RELIC_NODES:
                break
            if sensor_mask[ry, rx]:
                relic_nodes[idx] = np.array([rx, ry])
                relic_nodes_mask[idx] = 1
            else:
                relic_nodes[idx] = np.array([-1, -1])
                relic_nodes_mask[idx] = 0
            idx += 1

        team_points = np.array([self.score, 0], dtype=np.int32)
        team_wins = np.array([0, 0], dtype=np.int32)
        steps = self.current_step
        match_steps = self.current_step

        obs = {
            "units": units,
            "units_mask": units_mask,
            "sensor_mask": sensor_mask_int,
            "map_features": map_features,
            "relic_nodes_mask": relic_nodes_mask,
            "relic_nodes": relic_nodes,
            "team_points": team_points,
            "team_wins": team_wins,
            "steps": steps,
            "match_steps": match_steps
        }
        observation = {
            "obs": obs,
            "remainingOverageTime": 60,
            "player": "player_0",
            "info": {"env_cfg": self.env_cfg}
        }
        return observation

    def get_obs(self):
        """
        Return the flattened global observation dict; its keys match observation_space exactly.
        """
        sensor_mask = self.get_global_sensor_mask()
        sensor_mask_int = sensor_mask.astype(np.int8)

        map_features_tile_type = np.where(sensor_mask, self.tile_map, -1)
        map_features_energy = np.where(sensor_mask, self.energy_map, -1)

        units_position = np.full((NUM_TEAMS, MAX_UNITS, 2), -1, dtype=np.int32)
        units_energy = np.full((NUM_TEAMS, MAX_UNITS, 1), -1, dtype=np.int32)
        units_mask = np.zeros((NUM_TEAMS, MAX_UNITS), dtype=np.int8)

        for i, unit in enumerate(self.team_units):
            ux, uy = unit["x"], unit["y"]
            if sensor_mask[uy, ux]:
                units_position[0, i] = np.array([ux, uy])
                units_energy[0, i] = unit["energy"]
                units_mask[0, i] = 1
        for i, unit in enumerate(self.enemy_units):
            ux, uy = unit["x"], unit["y"]
            if sensor_mask[uy, ux]:
                units_position[1, i] = np.array([ux, uy])
                units_energy[1, i] = unit["energy"]
                units_mask[1, i] = 1

        relic_coords = np.argwhere(self.relic_map == 1)
        relic_nodes = np.full((MAX_RELIC_NODES, 2), -1, dtype=np.int32)
        relic_nodes_mask = np.zeros((MAX_RELIC_NODES,), dtype=np.int8)
        idx = 0
        for (ry, rx) in relic_coords:
            if idx >= MAX_RELIC_NODES:
                break
            if sensor_mask[ry, rx]:
                relic_nodes[idx] = np.array([rx, ry])
                relic_nodes_mask[idx] = 1
            else:
                relic_nodes[idx] = np.array([-1, -1])
                relic_nodes_mask[idx] = 0
            idx += 1

        team_points = np.array([self.score, 0], dtype=np.int32)
        team_wins = np.array([0, 0], dtype=np.int32)
        steps = np.array([self.current_step], dtype=np.int32)
        match_steps = np.array([self.current_step], dtype=np.int32)
        remainingOverageTime = np.array([60], dtype=np.int32)

        env_cfg_map_width = np.array([self.env_cfg["map_width"]], dtype=np.int32)
        env_cfg_map_height = np.array([self.env_cfg["map_height"]], dtype=np.int32)
        env_cfg_max_steps_in_match = np.array([self.env_cfg["max_steps_in_match"]], dtype=np.int32)
        env_cfg_unit_move_cost = np.array([self.env_cfg["unit_move_cost"]], dtype=np.int32)
        env_cfg_unit_sap_cost = np.array([self.env_cfg["unit_sap_cost"]], dtype=np.int32)
        env_cfg_unit_sap_range = np.array([self.env_cfg["unit_sap_range"]], dtype=np.int32)

        flat_obs = {
            "units_position": units_position,
            "units_energy": units_energy,
            "units_mask": units_mask,
            "sensor_mask": sensor_mask_int,
            "map_features_tile_type": map_features_tile_type,
            "map_features_energy": map_features_energy,
            "relic_nodes_mask": relic_nodes_mask,
            "relic_nodes": relic_nodes,
            "team_points": team_points,
            "team_wins": team_wins,
            "steps": steps,
            "match_steps": match_steps,
            "remainingOverageTime": remainingOverageTime,
            "env_cfg_map_width": env_cfg_map_width,
            "env_cfg_map_height": env_cfg_map_height,
            "env_cfg_max_steps_in_match": env_cfg_max_steps_in_match,
            "env_cfg_unit_move_cost": env_cfg_unit_move_cost,
            "env_cfg_unit_sap_cost": env_cfg_unit_sap_cost,
            "env_cfg_unit_sap_range": env_cfg_unit_sap_range
        }

        return flat_obs

    def reset(self):
        """
        Reset the environment state and return the initial flattened observation.
        """
        self._init_state()
        return self.get_obs()

    def _spawn_unit(self, team):
        """Spawn a new unit (team 0 = friendly, team 1 = enemy) with 100 energy at that team's spawn point."""
        if team == 0:
            spawn_x, spawn_y = self.team_spawn
            self.team_units.append({"x": spawn_x, "y": spawn_y, "energy": 100})
        elif team == 1:
            spawn_x, spawn_y = self.enemy_spawn
            self.enemy_units.append({"x": spawn_x, "y": spawn_y, "energy": 100})

    def step(self, actions):
        """
        Update the environment state for the given actions and return (observation, reward, done, info).

        Reward logic:
          1. unit_reward is computed separately for each unit.
          2. A move that would leave the map or enter an Asteroid tile is invalid: the unit stays put and unit_reward -= 0.2.
          3. Sap action:
             - Check whether any relic is visible in relic_nodes_mask of the unit's local obs.
             - If so, count the enemy units in the unit's 8-neighborhood. With 2 or more enemies, unit_reward += 1.0 per enemy; otherwise unit_reward -= 1.0.
             - If no relic is visible, unit_reward -= 1.0 as well.
          4. Non-sap actions:
             - After the move, check whether the unit stands on a potential point of any relic configuration:
                 * first visit to that potential point: unit_reward += 2.0 and the point is marked visited;
                 * the point has not yet yielded a team point: self.score += 1, unit_reward += 5.0, and the point is marked in team_points_space;
                 * the point is already in team_points_space: self.score += 1 and unit_reward += 5.0 on every step spent there.
             - Unit on an energy node (energy == Global.MAX_ENERGY_PER_TILE): unit_reward += 0.2.
             - Unit on a Nebula tile (tile_type == 1): unit_reward -= 0.2.
             - Unit shares a tile with enemy units after moving: unit_reward += 1.0 for each enemy with less energy than the unit.
          5. Global exploration reward: +0.1 for each tile newly seen in the combined field of view of all friendly units.
          6. The step reward is 0.25 * (sum of the rewards above) + 0.75 * (increase in self.score during the step).
          7. A new unit spawns for each team every 3 steps (up to MAX_UNITS). Every 20 steps the tile, relic, and energy maps
             roll one column to the right, and enemy units shift right by 1 unless that would leave the map.
        """
        prev_score = self.score

        self.current_step += 1
        total_reward = 0.0

        # Process each friendly unit
        for idx, unit in enumerate(self.team_units):
            unit_reward = 0.0
            act = actions[idx]
            action_enum = ActionType(act)
            # print(f"Unit {idx} action: {action_enum}",file=sys.stderr)

            # Get the unit's local obs
            unit_obs = self.get_unit_obs(unit)

            # Sap action
            if action_enum == ActionType.sap:
                # Check whether a relic is visible in the local obs
                if np.any(unit_obs["obs"]["relic_nodes_mask"] == 1):
                    # Count the enemy units in the unit's 8-neighborhood
                    enemy_count = 0
                    for dy in [-1, 0, 1]:
                        for dx in [-1, 0, 1]:
                            if dx == 0 and dy == 0:
                                continue
                            nx = unit["x"] + dx
                            ny = unit["y"] + dy
                            if not (0 <= nx < SPACE_SIZE and 0 <= ny < SPACE_SIZE):
                                continue
                            for enemy in self.enemy_units:
                                if enemy["x"] == nx and enemy["y"] == ny:
                                    enemy_count += 1
                    if enemy_count >= 2:
                        unit_reward += 1.0 * enemy_count
                    else:
                        unit_reward -= 1.0
                else:
                    unit_reward -= 1.0
                # A sap action does not change the unit's position
            else:
                # Work out the movement direction
                if action_enum in [ActionType.up, ActionType.right, ActionType.down, ActionType.left]:
                    dx, dy = action_enum.to_direction()
                else:
                    dx, dy = (0, 0)
                new_x = unit["x"] + dx
                new_y = unit["y"] + dy
                # Check map bounds and obstacles
                if not (0 <= new_x < SPACE_SIZE and 0 <= new_y < SPACE_SIZE):
                    unit_reward -= 0.2  # Out of bounds: stay in place
                    new_x, new_y = unit["x"], unit["y"]
                elif self.tile_map[new_y, new_x] == 2:
                    unit_reward -= 0.2  # Blocked by an Asteroid: stay in place
                    new_x, new_y = unit["x"], unit["y"]
                else:
                    # Move succeeded
                    unit["x"], unit["y"] = new_x, new_y

                # Refresh the local obs after moving
                unit_obs = self.get_unit_obs(unit)

                # Relic configuration reward: go through every relic configuration and check whether the unit lies inside it
                for (rx, ry, mask) in self.relic_configurations:
                    # A relic configuration covers the area (rx, ry) ± 2,
                    # i.e. x in [rx-2, rx+2] and y in [ry-2, ry+2]
                    if rx - 2 <= unit["x"] <= rx + 2 and ry - 2 <= unit["y"] <= ry + 2:
                        # Index of the unit's tile inside the 5x5 configuration mask
                        ix = unit["x"] - rx + 2
                        iy = unit["y"] - ry + 2
                        # Guard against indices outside the mask
                        if 0 <= ix < 5 and 0 <= iy < 5:
                            if mask[iy, ix]:
                                # First visit to this potential point: +2.0
                                if not self.potential_visited[unit["y"], unit["x"]]:
                                    unit_reward += 2.0
                                    self.potential_visited[unit["y"], unit["x"]] = True
                                # The point has not yielded a team point yet: add one and reward +5.0
                                if not self.team_points_space[unit["y"], unit["x"]]:
                                    self.score += 1
                                    unit_reward += 5.0
                                    self.team_points_space[unit["y"], unit["x"]] = True
                                else:
                                    # Already in team_points_space: +5.0 on every step
                                    self.score += 1
                                    unit_reward += 5.0
                # Energy node reward
                if unit_obs["obs"]["map_features"]["energy"][unit["y"], unit["x"]] == Global.MAX_ENERGY_PER_TILE:
                    unit_reward += 0.2
                # Nebula penalty
                if unit_obs["obs"]["map_features"]["tile_type"][unit["y"], unit["x"]] == 1:
                    unit_reward -= 0.2
                # Attack reward: +1.0 for each enemy unit on the same tile with less energy than this unit
                for enemy in self.enemy_units:
                    if enemy["x"] == unit["x"] and enemy["y"] == unit["y"]:
                        if enemy["energy"] < unit["energy"]:
                            unit_reward += 1.0
            total_reward += unit_reward
            # print("################################",file=sys.stderr)
            # print("step:",self.current_step)
            # print("")
            # print(total_reward,file=sys.stderr)

        # Global exploration reward: tiles newly seen in the combined field of view of all friendly units
        union_mask = self.get_global_sensor_mask()
        new_tiles = union_mask & (~self.visited)
        num_new = np.sum(new_tiles)
        if num_new > 0:
            total_reward += 0.1 * num_new
        self.visited[new_tiles] = True

        # Spawn a new unit for each team every 3 steps (while below MAX_UNITS)
        if self.current_step % 3 == 0:
            if len(self.team_units) < MAX_UNITS:
                self._spawn_unit(team=0)
            if len(self.enemy_units) < MAX_UNITS:
                self._spawn_unit(team=1)

        # Every 20 steps, shift the tile, relic, and energy maps and the enemy units one column to the right
        if self.current_step % 20 == 0:
            # np.roll wraps the map columns around, so no map data is lost; enemy units are clamped at the border instead
            self.tile_map = np.roll(self.tile_map, shift=1, axis=1)
            self.relic_map = np.roll(self.relic_map, shift=1, axis=1)
            self.energy_map = np.roll(self.energy_map, shift=1, axis=1)
            for enemy in self.enemy_units:
                new_ex = enemy["x"] + 1
                if new_ex >= SPACE_SIZE:
                    new_ex = enemy["x"]  # Already at the right edge: stay in place
                enemy["x"] = new_ex

        # Increase in self.score over this step
        score_increase = self.score - prev_score

        # Blend the rewards: total_reward * 0.25 + score_increase * 0.75
        final_reward = total_reward * 0.25 + score_increase * 0.75

        # done = self.current_step >= self.max_steps
        done = self.current_step >= 500
        info = {"score": self.score, "step": self.current_step}
        return self.get_obs(), final_reward, done, info

    def render(self, mode='human'):
        display = self.tile_map.astype(str).copy()
        for unit in self.team_units:
            display[unit["y"], unit["x"]] = 'A'
        print("Step:", self.current_step)
        print(display)
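
The sap branch of `step` can be read as a pure function of the unit position, the enemy positions, and whether a relic is visible. The sketch below restates that rule outside the class so the numbers are easy to check. `sap_reward` and its parameters are names introduced here for illustration, and the default map size of 24 is an assumption standing in for `SPACE_SIZE`.

```python
def sap_reward(unit_xy, enemy_xys, relic_visible, space_size=24):
    """Reward for one sap action, following the rule used in step()."""
    if not relic_visible:
        return -1.0
    ux, uy = unit_xy
    enemy_count = 0
    for ex, ey in enemy_xys:
        on_map = 0 <= ex < space_size and 0 <= ey < space_size
        # 8-neighborhood: Chebyshev distance exactly 1 (the unit's own tile is excluded)
        adjacent = max(abs(ex - ux), abs(ey - uy)) == 1
        if on_map and adjacent:
            enemy_count += 1
    return 1.0 * enemy_count if enemy_count >= 2 else -1.0


print(sap_reward((5, 5), [(4, 5), (6, 6), (9, 9)], relic_visible=True))
```

Keeping the rule in a standalone function like this would also make it straightforward to unit-test changes to the reward values.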


## train.py

In [None]:
#%%writefile agent/train.py

from stable_baselines3 import PPO
#from ppo_game_env import PPOGameEnv

# Create the environment instance
env = PPOGameEnv()

# Initialize the PPO model with MultiInputPolicy, which handles the dict observation space
model = PPO("MultiInputPolicy", env, verbose=1)

# # Train for 500,000 time steps (adjust as needed)
# model.learn(total_timesteps=500000)

# # Save the trained model
# model.save("/kaggle/working/agent/ppo_game_env_model")

# # Test: load the model and run one episode
# # obs = env.reset()
# # done = False
# # while not done:
# #     action, _ = model.predict(obs)
# #     obs, reward, done, info = env.step(action)
# #     env.render()


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




In [None]:
print(model.policy)

MultiInputActorCriticPolicy(
  (features_extractor): CombinedExtractor(
    (extractors): ModuleDict(
      (env_cfg_map_height): Flatten(start_dim=1, end_dim=-1)
      (env_cfg_map_width): Flatten(start_dim=1, end_dim=-1)
      (env_cfg_max_steps_in_match): Flatten(start_dim=1, end_dim=-1)
      (env_cfg_unit_move_cost): Flatten(start_dim=1, end_dim=-1)
      (env_cfg_unit_sap_cost): Flatten(start_dim=1, end_dim=-1)
      (env_cfg_unit_sap_range): Flatten(start_dim=1, end_dim=-1)
      (map_features_energy): Flatten(start_dim=1, end_dim=-1)
      (map_features_tile_type): Flatten(start_dim=1, end_dim=-1)
      (match_steps): Flatten(start_dim=1, end_dim=-1)
      (relic_nodes): Flatten(start_dim=1, end_dim=-1)
      (relic_nodes_mask): Flatten(start_dim=1, end_dim=-1)
      (remainingOverageTime): Flatten(start_dim=1, end_dim=-1)
      (sensor_mask): Flatten(start_dim=1, end_dim=-1)
      (steps): Flatten(start_dim=1, end_dim=-1)
      (team_points): Flatten(start_dim=1, end_dim=-1)
 

## agent.py

In [None]:
%%writefile agent/agent.py
import sys
import numpy as np
from stable_baselines3 import PPO

def transform_obs(comp_obs, env_cfg=None):
    """
    Convert the JSON observation returned by the competition engine into the flat observation format used for training.
    The competition observation (comp_obs) has the following structure:

      {
        "obs": {
            "units": {"position": Array(T, N, 2), "energy": Array(T, N, 1)},
            "units_mask": Array(T, N),
            "sensor_mask": Array(W, H),
            "map_features": {"energy": Array(W, H), "tile_type": Array(W, H)},
            "relic_nodes_mask": Array(R),
            "relic_nodes": Array(R, 2),
            "team_points": Array(T),
            "team_wins": Array(T),
            "steps": int,
            "match_steps": int
        },
        "remainingOverageTime": int,
        "player": str,
        "info": {"env_cfg": dict}
      }

    The function builds the following flat dictionary (same format as returned by PPOGameEnv.get_obs()):
      {
        "units_position": (T, N, 2),
        "units_energy": (T, N, 1),
        "units_mask": (T, N),
        "sensor_mask": (W, H),
        "map_features_tile_type": (W, H),
        "map_features_energy": (W, H),
        "relic_nodes_mask": (R,),
        "relic_nodes": (R, 2),
        "team_points": (T,),
        "team_wins": (T,),
        "steps": (1,),
        "match_steps": (1,),
        "remainingOverageTime": (1,),
        "env_cfg_map_width": (1,),
        "env_cfg_map_height": (1,),
        "env_cfg_max_steps_in_match": (1,),
        "env_cfg_unit_move_cost": (1,),
        "env_cfg_unit_sap_cost": (1,),
        "env_cfg_unit_sap_range": (1,)
      }
    """
    # If comp_obs has an "obs" key, use its contents; otherwise use comp_obs directly
    if "obs" in comp_obs:
        base_obs = comp_obs["obs"]
    else:
        base_obs = comp_obs


    flat_obs = {}

    # Unit positions and energy
    if "units" in base_obs:
        flat_obs["units_position"] = np.array(base_obs["units"]["position"], dtype=np.int32)
        flat_obs["units_energy"] = np.array(base_obs["units"]["energy"], dtype=np.int32)
        # If units_energy has shape (NUM_TEAMS, MAX_UNITS), add a trailing dimension
        if flat_obs["units_energy"].ndim == 2:
            flat_obs["units_energy"] = np.expand_dims(flat_obs["units_energy"], axis=-1)
    else:
        flat_obs["units_position"] = np.array(base_obs["units_position"], dtype=np.int32)
        flat_obs["units_energy"] = np.array(base_obs["units_energy"], dtype=np.int32)
        if flat_obs["units_energy"].ndim == 2:
            flat_obs["units_energy"] = np.expand_dims(flat_obs["units_energy"], axis=-1)

    # units_mask (defaults to all zeros when missing)
    if "units_mask" in base_obs:
        flat_obs["units_mask"] = np.array(base_obs["units_mask"], dtype=np.int8)
    else:
        flat_obs["units_mask"] = np.zeros(flat_obs["units_position"].shape[:2], dtype=np.int8)

    # sensor_mask: if a 3D array is returned, OR over the first axis to get a single global mask
    sensor_mask_arr = np.array(base_obs["sensor_mask"], dtype=np.int8)
    if sensor_mask_arr.ndim == 3:
        sensor_mask = np.any(sensor_mask_arr, axis=0).astype(np.int8)
    else:
        sensor_mask = sensor_mask_arr
    flat_obs["sensor_mask"] = sensor_mask

    # map_features (tile_type and energy)
    if "map_features" in base_obs:
        mf = base_obs["map_features"]
        flat_obs["map_features_tile_type"] = np.array(mf["tile_type"], dtype=np.int8)
        flat_obs["map_features_energy"] = np.array(mf["energy"], dtype=np.int8)
    else:
        flat_obs["map_features_tile_type"] = np.array(base_obs["map_features_tile_type"], dtype=np.int8)
        flat_obs["map_features_energy"] = np.array(base_obs["map_features_energy"], dtype=np.int8)

    # Relic node information
    if "relic_nodes_mask" in base_obs:
        flat_obs["relic_nodes_mask"] = np.array(base_obs["relic_nodes_mask"], dtype=np.int8)
    else:
        max_relic = env_cfg.get("max_relic_nodes", 6) if env_cfg is not None else 6
        flat_obs["relic_nodes_mask"] = np.zeros((max_relic,), dtype=np.int8)
    if "relic_nodes" in base_obs:
        flat_obs["relic_nodes"] = np.array(base_obs["relic_nodes"], dtype=np.int32)
    else:
        max_relic = env_cfg.get("max_relic_nodes", 6) if env_cfg is not None else 6
        flat_obs["relic_nodes"] = np.full((max_relic, 2), -1, dtype=np.int32)

    # Team points and wins
    if "team_points" in base_obs:
        flat_obs["team_points"] = np.array(base_obs["team_points"], dtype=np.int32)
    else:
        flat_obs["team_points"] = np.zeros(2, dtype=np.int32)
    if "team_wins" in base_obs:
        flat_obs["team_wins"] = np.array(base_obs["team_wins"], dtype=np.int32)
    else:
        flat_obs["team_wins"] = np.zeros(2, dtype=np.int32)

    # Step counters
    if "steps" in base_obs:
        flat_obs["steps"] = np.array([base_obs["steps"]], dtype=np.int32)
    else:
        flat_obs["steps"] = np.array([0], dtype=np.int32)
    if "match_steps" in base_obs:
        flat_obs["match_steps"] = np.array([base_obs["match_steps"]], dtype=np.int32)
    else:
        flat_obs["match_steps"] = np.array([0], dtype=np.int32)

    # Note: remainingOverageTime is not handled here;
    # Agent.act adds it from the argument it receives

    # Fill in the environment configuration values
    if env_cfg is not None:
        flat_obs["env_cfg_map_width"] = np.array([env_cfg["map_width"]], dtype=np.int32)
        flat_obs["env_cfg_map_height"] = np.array([env_cfg["map_height"]], dtype=np.int32)
        flat_obs["env_cfg_max_steps_in_match"] = np.array([env_cfg["max_steps_in_match"]], dtype=np.int32)
        flat_obs["env_cfg_unit_move_cost"] = np.array([env_cfg["unit_move_cost"]], dtype=np.int32)
        flat_obs["env_cfg_unit_sap_cost"] = np.array([env_cfg["unit_sap_cost"]], dtype=np.int32)
        flat_obs["env_cfg_unit_sap_range"] = np.array([env_cfg["unit_sap_range"]], dtype=np.int32)
    else:
        flat_obs["env_cfg_map_width"] = np.array([0], dtype=np.int32)
        flat_obs["env_cfg_map_height"] = np.array([0], dtype=np.int32)
        flat_obs["env_cfg_max_steps_in_match"] = np.array([0], dtype=np.int32)
        flat_obs["env_cfg_unit_move_cost"] = np.array([0], dtype=np.int32)
        flat_obs["env_cfg_unit_sap_cost"] = np.array([0], dtype=np.int32)
        flat_obs["env_cfg_unit_sap_range"] = np.array([0], dtype=np.int32)

    return flat_obs

class Agent():
    def __init__(self, player: str, env_cfg) -> None:
        self.player = player
        self.opp_player = "player_1" if self.player == "player_0" else "player_0"
        self.team_id = 0 if self.player == "player_0" else 1
        self.opp_team_id = 1 if self.team_id == 0 else 0
        np.random.seed(0)
        self.env_cfg = env_cfg

        # If env_cfg has no "max_units" entry, fall back to the default of 16
        if "max_units" not in self.env_cfg:
            self.env_cfg["max_units"] = 16

        # Load the trained PPO model (make sure the model file path is correct)
        self.model = PPO.load("ppo_game_env_model")

    def act(self, step: int, obs, remainingOverageTime: int = 60):
        """
        Choose an action for each unit from the competition observation and the current step.
        Returns a numpy array of shape (max_units, 3); each row is [action type, delta_x, delta_y],
        where delta_x and delta_y are fixed to 0 for non-sap actions.
        """
        flat_obs = transform_obs(obs, self.env_cfg)
        # Add remainingOverageTime manually (taken from the argument)
        flat_obs["remainingOverageTime"] = np.array([remainingOverageTime], dtype=np.int32)

        # Predict actions with the model (deterministic mode)
        action, _ = self.model.predict(flat_obs, deterministic=True)
        # Make sure action is a numpy array with dtype np.int32
        action = np.array(action, dtype=np.int32)

        max_units = self.env_cfg["max_units"]
        actions = np.zeros((max_units, 3), dtype=np.int32)
        for i, a in enumerate(action):
            actions[i, 0] = int(a)
            actions[i, 1] = 0  # For a sap action, the target offset could be filled in here
            actions[i, 2] = 0
        return actions



Writing agent/agent.py
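
`Agent.act` leaves `delta_x` and `delta_y` at 0 for every unit, so a sap action always targets the unit's own tile. One possible extension, sketched here under stated assumptions, picks the nearest visible enemy within `unit_sap_range` and uses its offset as the target. The helper `sap_offset` is hypothetical, treating negative coordinates as "not visible" is an assumption about the observation format, and measuring range as the larger of |dx| and |dy| should be checked against the engine.

```python
def sap_offset(unit_xy, enemy_xys, sap_range):
    """Return (dx, dy) to the closest enemy within sap_range, or (0, 0) if none."""
    ux, uy = unit_xy
    best = None
    for ex, ey in enemy_xys:
        if ex < 0 or ey < 0:
            continue  # assumed convention: negative coordinates mark unseen units
        dx, dy = ex - ux, ey - uy
        dist = max(abs(dx), abs(dy))
        if dist <= sap_range and (best is None or dist < best[0]):
            best = (dist, dx, dy)
    return (best[1], best[2]) if best else (0, 0)


print(sap_offset((3, 3), [(-1, -1), (6, 4), (4, 2)], sap_range=4))
```

Inside `act`, the returned pair would go into `actions[i, 1]` and `actions[i, 2]` for units whose action type is sap.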


## main.py

In [None]:
%%writefile agent/main.py

import json
from typing import Dict
import sys
from argparse import Namespace

import numpy as np

from agent import Agent
# from lux.config import EnvConfig
from lux.kit import from_json
### DO NOT REMOVE THE FOLLOWING CODE ###
agent_dict = dict() # store potentially multiple dictionaries as kaggle imports code directly
agent_prev_obs = dict()
def agent_fn(observation, configurations):
    """
    agent definition for kaggle submission.
    """
    global agent_dict
    obs = observation.obs
    if isinstance(obs, str):
        obs = json.loads(obs)
    step = observation.step
    player = observation.player
    remainingOverageTime = observation.remainingOverageTime
    if step == 0:
        agent_dict[player] = Agent(player, configurations["env_cfg"])
    agent = agent_dict[player]
    actions = agent.act(step, from_json(obs), remainingOverageTime)
    return dict(action=actions.tolist())
if __name__ == "__main__":

    def read_input():
        """
        Reads input from stdin
        """
        try:
            return input()
        except EOFError as eof:
            raise SystemExit(eof)
    step = 0
    player_id = 0
    env_cfg = None
    i = 0
    while True:
        inputs = read_input()
        raw_input = json.loads(inputs)
        observation = Namespace(**dict(step=raw_input["step"], obs=raw_input["obs"], remainingOverageTime=raw_input["remainingOverageTime"], player=raw_input["player"], info=raw_input["info"]))
        if i == 0:
            env_cfg = raw_input["info"]["env_cfg"]
            player_id = raw_input["player"]
        i += 1
        actions = agent_fn(observation, dict(env_cfg=env_cfg))
        # send actions to engine
        print(json.dumps(actions))


Writing agent/main.py
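
The `__main__` loop reads one JSON object per line from stdin and wraps it in a `Namespace` before calling `agent_fn`. The sketch below isolates that parsing step so it can be exercised without the engine. `parse_turn` is a name introduced here, and the sample payload is made up for illustration.

```python
import json
from argparse import Namespace


def parse_turn(line):
    """Turn one engine input line into the Namespace that agent_fn expects."""
    raw = json.loads(line)
    return Namespace(
        step=raw["step"],
        obs=raw["obs"],
        remainingOverageTime=raw["remainingOverageTime"],
        player=raw["player"],
        info=raw["info"],
    )


# Made-up payload with the same top-level keys the loop reads.
sample = json.dumps({
    "step": 0,
    "obs": {"steps": 0, "match_steps": 0},
    "remainingOverageTime": 60,
    "player": "player_0",
    "info": {"env_cfg": {"max_units": 16}},
})
turn = parse_turn(sample)
print(turn.player, turn.info["env_cfg"]["max_units"])
```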


# Test run

In [None]:
!pip install --upgrade luxai-s3

Collecting luxai-s3
  Downloading luxai_s3-0.2.1-py3-none-any.whl.metadata (253 bytes)
Collecting gymnax==0.0.8 (from luxai-s3)
  Downloading gymnax-0.0.8-py3-none-any.whl.metadata (19 kB)
Collecting tyro (from luxai-s3)
  Downloading tyro-0.9.16-py3-none-any.whl.metadata (9.4 kB)
Collecting gym>=0.26 (from gymnax==0.0.8->luxai-s3)
  Downloading gym-0.26.2.tar.gz (721 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 721.7/721.7 kB 10.5 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting shtab>=1.5.6 (from tyro->luxai-s3)
  Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)
Downloading luxai_s3-0.2.1-py3-none-any.whl (35 kB)
Downloading gymnax-0.0.8-py3-none-any.whl (96 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 96.3/96.3 kB 5.8 MB/s et

In [None]:
!python agent/train.py

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 500      |
|    ep_rew_mean     | 72.5     |
| time/              |          |
|    fps             | 59       |
|    iterations      | 1        |
|    time_elapsed    | 34       |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 500         |
|    ep_rew_mean          | -11         |
| time/                   |             |
|    fps                  | 56          |
|    iterations           | 2           |
|    time_elapsed         | 72          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.019829348 |
|    clip_fraction        | 0.27        |
|    clip_range           | 0.2      

In [None]:
!luxai-s3 agent/main.py agent/main.py --seed 21 --output=replay.html

player_1, agent/main.py:   th_object = th.load(file_content, map_location=device)
player_0, agent/main.py:   th_object = th.load(file_content, map_location=device)
Time Elapsed:  24.59803533554077
Rewards:  {'player_0': array(1, dtype=int32), 'player_1': array(4, dtype=int32)}

In [None]:
import IPython # load the HTML replay
IPython.display.HTML(filename='replay.html')

# Create a submission

In [None]:
!cd agent && tar -czf submission.tar.gz *
!mv agent/submission.tar.gz .

In [None]:
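## Game, time step, turn
 - One game consists of 5 matches. (1 game = 5 matches)
 - Each match has 100 time steps. (time step = turn)
 - (Speculation: turns are not alternated. Both teams choose their actions independently, then the game engine resolves them and advances to the next turn.
   In other words, you do not choose your action after seeing what the opponent does this turn; you choose it from the result of the opponent's action one turn earlier.
   For example, when you pick a target tile and sap it, the attack can miss if the enemy unit happens to move to another tile, because movement is resolved before sapping in the step order listed later in these notes. You can therefore also choose actions by predicting what the opponent will do.)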

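## Goal of the game
 - The goal is to find the right balance between exploration and exploitation.
 - Explore more during the first 1-2 matches, then play the best-known actions during the remaining matches 3-5.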

## Map
 - A 2D grid of 24 x 24 tiles.
 - The map is fixed across all 5 matches (the terrain generated at random for the first match stays through the fifth).

# Empty Tiles
 - Units and nodes can move onto these tiles.

# Asteroid Tiles
 - Cannot be crossed or occupied.
 - Can drift around the map while staying symmetric.
 - Can move on top of a unit, but the unit is not removed. Asteroid tiles cannot overlap each other.

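# Nebula Tiles
 - Can be crossed.
 - Can drift around the map while staying symmetric.
 - Vision reduction/blocking: a unit may even be unable to see itself, but its movement is not blocked.
 - The vision reduction value is chosen at random from 0-3.
 - Energy drain: reduces unit energy and interrupts one action turn.
 - The energy drain value is chosen at random (0-3 according to these notes; the parameter ranges listed later give [0, 0, 10, 25]).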

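# Energy Nodes
 - Can drift around the map while staying symmetric.
 - Generated randomly and symmetrically; each node is assigned a different random function.
 - Each node's function depends on distance: the energy value of a tile is the sum, over all energy nodes, of that node's function evaluated at the distance between the node and the tile.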
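
The summation described above can be written down directly. In this sketch each node carries its own function of distance; the two functions used here are placeholders chosen for illustration, not the engine's actual formulas, Euclidean distance is an assumption, and `tile_energy` is a hypothetical helper.

```python
import math


def tile_energy(tile_xy, nodes):
    """Sum each node's distance function evaluated at the tile."""
    tx, ty = tile_xy
    total = 0.0
    for (nx, ny), fn in nodes:
        dist = math.hypot(tx - nx, ty - ny)  # assumed distance metric
        total += fn(dist)
    return total


# Placeholder per-node functions (assumptions, not taken from the engine).
nodes = [
    ((0, 0), lambda d: 10.0 - d),
    ((6, 8), lambda d: 4.0 * math.cos(d)),
]
print(tile_energy((0, 0), nodes))
```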

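# Relic Nodes
 - Only certain tiles near a relic node yield team points (some nearby tiles yield none).
 - A tile that falls in range of several relic nodes does not yield extra points; it counts as one tile.
 - Which tiles yield points is decided at random within the 5x5 area centered on the relic node.
 - Several units can stand on one tile, but each tile yields only 1 point, so stacking is pointless (and not recommended, since stacked units are more exposed to enemy attacks).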
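
Because each point tile counts once no matter how many units stand on it, the points gained in a step reduce to a set intersection. The sketch below illustrates that reading of the notes; `team_points_this_step` and the sample coordinates are made up for illustration.

```python
def team_points_this_step(unit_xys, point_tiles):
    """Count distinct point-yielding tiles occupied by at least one unit."""
    occupied = set(unit_xys)
    return len(occupied & set(point_tiles))


point_tiles = {(10, 10), (11, 12), (12, 9)}    # hypothetical tiles inside a relic's 5x5 area
units = [(10, 10), (10, 10), (12, 9), (3, 3)]  # two units stacked on (10, 10)
print(team_points_this_step(units, point_tiles))  # stacked units still count once
```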

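## Units
 - Each side has at most 16 units, with unique IDs 0-15.
 - Units are ships with 5 move actions (center, up, right, down, left).
 - Units can attack (sap) within a limited range.
 - Multiple friendly units can share a single tile.
 - Each unit has an energy value (health?), starting at 100 and capped at 400.
 - Entering an energy field recharges the unit.
 - Nebula tiles reduce energy but never below 0; only enemy attacks can push energy below 0 (destroying the unit).
 - Units spawn only at the two opposite ends of the map (one end per team).
 - Every move except staying in place (center) costs 2 energy (this is unit_move_cost, which the parameter ranges in these notes randomize between 1 and 5).
 - Trying to move off the edge of the map fails but still costs energy.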

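# Attack actions (sap actions)
 - Pick one tile within sap_range (4 tiles) as the target and attack it.
 - Each enemy unit on the target tile loses sap_cost energy (10, though the value can change at game start).
 - The attacker also spends sap_cost energy (so the attack only pays off when two or more enemy units are on the target tile).
 - Enemy units not on the target tile but on one of the 8 surrounding tiles each lose sap_cost * sap_dropoff_factor (10 * 0.5) energy.
 - If the enemy on the target tile moves away, the attack can miss.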
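
The damage rule above lends itself to a small worked example. This is a sketch of the rule as the notes state it, not the engine's code: `apply_sap` and the `enemies` layout are hypothetical, and the engine may round the splash damage differently.

```python
def apply_sap(target_xy, enemies, sap_cost=10, dropoff=0.5):
    """Return new enemy energies after one sap aimed at target_xy.

    enemies maps a unit id to ((x, y), energy).
    """
    tx, ty = target_xy
    result = {}
    for uid, ((x, y), energy) in enemies.items():
        dist = max(abs(x - tx), abs(y - ty))
        if dist == 0:
            energy -= sap_cost              # direct hit on the target tile
        elif dist == 1:
            energy -= sap_cost * dropoff    # splash on the 8 surrounding tiles
        result[uid] = energy
    return result


enemies = {0: ((5, 5), 100), 1: ((5, 5), 40), 2: ((6, 4), 100), 3: ((8, 8), 100)}
print(apply_sap((5, 5), enemies))
```

With the default values the attacker spends 10 energy and removes 10 + 10 + 5 = 25 from the enemy side in this example, which matches the note that sapping pays off mainly against stacked units.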

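## Observable range (vision)
 - Each unit's vision range (unit_sensor_range) extends 2 to 4 tiles from its position, horizontally, vertically, and diagonally.
 - Every tile on the map has a vision power value; a separate vision power map is computed.
 - The vision power at the unit's own tile is 3, drops by 1 per tile of distance, and reaches 0 at the third tile.
 - Nebula tiles reduce vision power.
 - You cannot see units or nodes on a tile covered by a nebula, but you can sometimes see beyond it.
   (For example, if the nebula vision reduction is 2 and a nebula tile lies within 1 tile of a unit, the unit cannot see units or nodes on that tile,
   but it can still see a tile 2 tiles away as long as that tile is not covered by a nebula.)
 - When the vision power of several friendly units overlaps on a tile, the values add up.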
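
A vision power map along the lines described above might be computed as follows. This is a sketch of the notes, not of the engine: the helper name, the default `base_power` of 3, the nebula reduction of 2, and the rule that a tile is visible when its power is above 0 are all assumptions.

```python
def vision_power_map(unit_xys, nebula_xys, size=24, base_power=3, nebula_reduction=2):
    """Build a size x size grid of vision power following the notes above."""
    power = [[0] * size for _ in range(size)]
    for ux, uy in unit_xys:
        for dy in range(-base_power + 1, base_power):
            for dx in range(-base_power + 1, base_power):
                x, y = ux + dx, uy + dy
                if 0 <= x < size and 0 <= y < size:
                    # Power falls by 1 per tile of distance; overlapping units add up.
                    power[y][x] += base_power - max(abs(dx), abs(dy))
    for nx, ny in nebula_xys:
        power[ny][nx] -= nebula_reduction  # nebula tiles lower the power on their own tile
    return power


grid = vision_power_map([(5, 5), (6, 5)], nebula_xys=[(7, 5)])
print(grid[5][5], grid[5][7], grid[5][9])
```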

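## Collisions / energy void fields
 - At the end of a turn, if units from both teams occupy the same tile, the side with the larger total energy on that tile survives and all units of the other side die.
 - If the totals are equal (a tie), both sides die and are removed.
 - Each unit projects an energy void field onto the 4 tiles adjacent to it (up, down, left, right); enemy units there lose energy.
 - An energy void field map is computed for each team every turn.
 - The energy each unit loses is (value of the energy void field map at its tile) / (total number of units on that tile). In other words, the drain is diluted when units are stacked.
 - Units that die in a collision do not contribute to the energy void field.
 - (In the end, this is a mechanism meant to favor spreading out over stacking on one tile.)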
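
The two rules above, collision resolution and void drain sharing, can be sketched as small functions. Both helpers are hypothetical and follow the wording of these notes rather than the engine source.

```python
def resolve_collision(team0_energies, team1_energies):
    """Return the surviving (team0, team1) energy lists for one tile."""
    if not team0_energies or not team1_energies:
        return team0_energies, team1_energies  # only one team present: no collision
    total0, total1 = sum(team0_energies), sum(team1_energies)
    if total0 > total1:
        return team0_energies, []
    if total1 > total0:
        return [], team1_energies
    return [], []  # equal totals: every unit on the tile is removed


def void_drain(void_value, units_on_tile):
    """Energy lost by each unit: the tile's void value shared among the units on it."""
    return void_value / units_on_tile


print(resolve_collision([60, 50], [100]))
print(void_drain(12.0, 3))
```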

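## Win conditions
 - Best of 5 (1 game = 5 matches).
 - To win a match, a team needs more relic points than the opponent. If points are tied, the team with more total energy wins. If that is also tied, the winner is chosen at random.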
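
The tie-break order above can be expressed as a short function. `match_winner` and `game_winner` are names introduced here, and the sketch only mirrors the rules as written in these notes.

```python
import random


def match_winner(points, energy, rng=random):
    """Return 0 or 1 for the winning team of one match."""
    if points[0] != points[1]:
        return 0 if points[0] > points[1] else 1
    if energy[0] != energy[1]:
        return 0 if energy[0] > energy[1] else 1
    return rng.choice([0, 1])  # full tie: random winner


def game_winner(match_results):
    """Team with at least 3 match wins out of 5, or None if undecided."""
    for team in (0, 1):
        if match_results.count(team) >= 3:
            return team
    return None


print(match_winner((12, 12), (340, 295)))
print(game_winner([0, 1, 1, 0, 1]))
```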

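## Order of operations at every time step of a match
 1. Move all units that have enough energy.
 2. Execute the sap actions of all units that have enough energy.
 3. Resolve collisions and apply energy void fields.
 4. Update each unit's energy (accounting for energy fields and nebula tiles).
 5. Spawn units and remove units whose energy is depleted.
 6. Compute each team's vision range and sensor mask, then mask the observations accordingly.
 7. Move environment objects (asteroid/nebula tiles, energy nodes, etc.).
 8. Update each team's total points.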



## Ranges of the randomized parameters
 - The randomized parameters differ from game to game; once chosen, they stay fixed for all 5 matches.

env_params_ranges = dict(
    map_type=[1],
    unit_move_cost=list(range(1, 6)), # list(range(x, y)) = [x, x+1, x+2, ... , y-1]
    unit_sensor_range=list(range(2, 5)),
    nebula_tile_vision_reduction=list(range(0,4)),
    nebula_tile_energy_reduction=[0, 0, 10, 25],
    unit_sap_cost=list(range(30, 51)),
    unit_sap_range=list(range(3, 8)),
    unit_sap_dropoff_factor=[0.25, 0.5, 1],
    unit_energy_void_factor=[0.0625, 0.125, 0.25, 0.375],
    # map randomizations
    nebula_tile_drift_speed=[-0.05, -0.025, 0.025, 0.05],
    energy_node_drift_speed=[0.01, 0.02, 0.03, 0.04, 0.05],
    energy_node_drift_magnitude=list(range(3, 6))
)
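
To make the per-game randomization concrete, the sketch below draws one value per parameter from a subset of the ranges listed above. Uniform sampling and the helper name `sample_params` are assumptions; the engine's own sampling code is not shown in these notes.

```python
import random

# Subset of the ranges listed above, repeated so the example is self-contained.
env_params_ranges = dict(
    unit_move_cost=list(range(1, 6)),
    unit_sensor_range=list(range(2, 5)),
    unit_sap_cost=list(range(30, 51)),
    unit_sap_range=list(range(3, 8)),
    unit_sap_dropoff_factor=[0.25, 0.5, 1],
)


def sample_params(ranges, seed=None):
    """Pick one value per parameter, uniformly at random (an assumption)."""
    rng = random.Random(seed)
    return {name: rng.choice(values) for name, values in ranges.items()}


params = sample_params(env_params_ranges, seed=0)
print(params)
```

A training environment could call a helper like this once per episode so the policy sees varied parameters instead of a single fixed configuration.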


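## Given the above, randomized parameters that must be inferred in every game
 - Unit move cost (energy deducted per move)
 - Unit sensor range
 - Nebula tile vision reduction
 - Nebula tile energy reduction
 - Unit sap cost
 - Unit sap range
 - Unit sap dropoff factor (splash damage)
 - Energy void factor
 - Nebula tile drift speed
 - Energy node drift speed
 - Energy node drift magnitude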

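## Other facts to find out (requires analyzing the game engine code)
 - How many of each of the 3 tile types and 2 node types are there?
 - Where are the energy nodes and relic nodes? The relic node positions matter most.
 - How much energy does an energy node recharge?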
 -



## How much should be left to learning, and how much should be controlled by rules?
 -
 -
 -

## Enemy position tracking function
 - How can we tell how many enemy units are stacked on one tile?
 -

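## Recharge decision function
 - Compare the remaining energy with the distance to an energy field and the energy spent getting there, then decide whether to go recharge or to sacrifice the unit.
 - A sacrificed unit can still be useful: use it as a scout, spend its remaining energy sapping nearby enemies, or park it near a relic node to keep scoring team points until it dies.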
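
One way to phrase the recharge decision is as a reachability check. The sketch below is a minimal version under stated assumptions: `should_recharge` is a hypothetical helper, the path length is the Manhattan distance with obstacles ignored, and `reserve` is an illustrative safety margin.

```python
def should_recharge(energy, unit_xy, field_xy, move_cost, reserve=0):
    """True if the unit can reach the energy field with at least `reserve` energy left."""
    ux, uy = unit_xy
    fx, fy = field_xy
    steps = abs(fx - ux) + abs(fy - uy)  # 4-directional moves, ignoring obstacles
    return energy - steps * move_cost >= reserve


print(should_recharge(20, (2, 2), (6, 5), move_cost=2))
print(should_recharge(10, (2, 2), (6, 5), move_cost=2))
```

When the check fails, the unit would fall through to one of the sacrifice roles listed above.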

## Role rotation strategy
 - Each unit watches its own energy; once it drops to a certain level it goes on a sacrificial attack, and another unit

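## Strategy: attack at the enemy spawn and defend our own spawn
 - Enemy units spawn with 100 energy, so we must send units with more than that; but if the enemy notices, it will also converge there, so this has to be judged carefully.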