In [1]:
import json
import time

import flax
import flax.serialization
from luxai_s3.params import EnvParams
from luxai_s3.state import EnvState, serialize_env_actions, serialize_env_states
import jax
import jax.numpy as jnp

from luxai_s3.env import LuxAIS3Env

# from luxai_s3.wrappers import RecordEpisode

# Create the environment
env = LuxAIS3Env(auto_reset=False)
env_params = EnvParams(map_type=0, max_steps_in_match=50)

# Initialize a random key
key = jax.random.key(0)

# Reset the environment
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key, params=env_params)
# Take a random action
key, subkey = jax.random.split(key)

env_params

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


EnvParams(max_steps_in_match=50, map_type=0, map_width=24, map_height=24, num_teams=2, match_count_per_episode=5, max_units=16, init_unit_energy=100, min_unit_energy=0, max_unit_energy=400, unit_move_cost=2, spawn_rate=5, unit_sap_cost=10, unit_sap_range=4, unit_sap_dropoff_factor=0.5, unit_energy_void_factor=0.125, max_energy_nodes=6, max_energy_per_tile=20, min_energy_per_tile=-20, max_relic_nodes=6, relic_config_size=5, fog_of_war=True, unit_sensor_range=2, nebula_tile_vision_reduction=1, nebula_tile_energy_reduction=0, nebula_tile_drift_speed=-0.05, energy_node_drift_speed=0.02, energy_node_drift_magnitude=5)

In [2]:
from argparse import Namespace
from dataclasses import dataclass
from typing import Any, Dict
import numpy as np
import networkx as nx
import numpy as np
from numpy.typing import ArrayLike, NDArray


def direction_to(src, target):
    ds = target - src
    dx = ds[0]
    dy = ds[1]
    if dx == 0 and dy == 0:
        return 0
    if abs(dx) > abs(dy):
        if dx > 0:
            return 2
        else:
            return 4
    else:
        if dy > 0:
            return 3
        else:
            return 1

In [3]:
def manhattan_distance(start_node: ArrayLike, end_node: ArrayLike) -> int:
    return abs(start_node[0] - end_node[0]) + abs(start_node[0] - end_node[0])

In [4]:
@dataclass(frozen=True)
class Coordinate:
    x: int
    y: int

    def __getitem__(self, index: int):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        else:
            raise IndexError("Index out of range. Use 0 for 'x' and 1 for 'y'.")


class Graph:
    def __init__(self, map_width: int, map_height: int):
        self.map: nx.DiGraph = nx.grid_2d_graph(
            map_width, map_height, periodic=False, create_using=nx.DiGraph()
        )

In [5]:
G = Graph(24, 24)
obs_player_0 = obs["player_0"]

In [11]:
def update_nodes_from_obs(G: nx.DiGraph, obs: Dict[str, Any]):
    visible_tiles_indices = jax.numpy.indices(obs["sensor_mask"].shape)[
        :, obs["sensor_mask"]
    ]
    visible_energies = obs["map_features"]["energy"][obs["sensor_mask"]]
    visible_type_tiles = obs["map_features"]["tile_type"][obs["sensor_mask"]]
    for i in range(visible_tiles_indices.shape[1]):
        node_location = tuple(visible_tiles_indices[:, i].tolist())
        symmetric_node_location = tuple(np.array([23, 23]) - np.array(node_location))
        cost = visible_energies[i]
        G.add_node(node_location, energy=cost, type=visible_type_tiles[i])
        G.add_node(symmetric_node_location, energy=cost, type=visible_type_tiles[i])


update_nodes_from_obs(G.map, obs_to_dict(obs_player_0))

In [12]:
def remove_edges_from_asteroides(G: nx.DiGraph, obs: Namespace):
    for n in G.nodes:
        if G.nodes[n].get("type", 0) == 2:
            for neighbor in G.neighbors(n):
                G.remove_edge(n, neighbor)


remove_edges_from_asteroides(G.map, obs_player_0)

G.map.nodes.data()

NodeDataView({(0, 0): {}, (0, 1): {}, (0, 2): {}, (0, 3): {}, (0, 4): {}, (0, 5): {}, (0, 6): {}, (0, 7): {}, (0, 8): {}, (0, 9): {}, (0, 10): {}, (0, 11): {}, (0, 12): {}, (0, 13): {}, (0, 14): {}, (0, 15): {}, (0, 16): {}, (0, 17): {}, (0, 18): {}, (0, 19): {}, (0, 20): {}, (0, 21): {}, (0, 22): {}, (0, 23): {}, (1, 0): {}, (1, 1): {}, (1, 2): {}, (1, 3): {}, (1, 4): {}, (1, 5): {}, (1, 6): {}, (1, 7): {}, (1, 8): {}, (1, 9): {}, (1, 10): {}, (1, 11): {}, (1, 12): {}, (1, 13): {}, (1, 14): {}, (1, 15): {}, (1, 16): {}, (1, 17): {}, (1, 18): {}, (1, 19): {}, (1, 20): {}, (1, 21): {}, (1, 22): {}, (1, 23): {}, (2, 0): {}, (2, 1): {}, (2, 2): {}, (2, 3): {}, (2, 4): {}, (2, 5): {}, (2, 6): {}, (2, 7): {}, (2, 8): {}, (2, 9): {}, (2, 10): {}, (2, 11): {}, (2, 12): {}, (2, 13): {}, (2, 14): {}, (2, 15): {}, (2, 16): {}, (2, 17): {}, (2, 18): {}, (2, 19): {}, (2, 20): {}, (2, 21): {}, (2, 22): {}, (2, 23): {}, (3, 0): {}, (3, 1): {}, (3, 2): {}, (3, 3): {}, (3, 4): {}, (3, 5): {}, (3, 6): 

In [13]:
def update_edges_cost_from_energy_nodes(G: nx.DiGraph):
    for node_in, node_out in G.edges:
        # Get energy values of the two nodes
        energy_node_in = G.nodes[node_in].get("energy", 0)
        energy_node_out = G.nodes[node_out].get("energy", 0)

        # Calculate edge cost as |energy_u - energy_v| + 1000
        cost = energy_node_in - energy_node_out + 1000
        G.edges[node_in, node_out]["cost"] = cost


update_edges_cost_from_energy_nodes(G.map)

In [14]:
def get_closest_path(
    G: nx.DiGraph, start_node: ArrayLike, end_node: ArrayLike
) -> ArrayLike:
    return nx.astar_path(
        G, start_node, end_node, heuristic=manhattan_distance, weight="cost"
    )


list_nodes = get_closest_path(G.map, (0, 0), (16, 16))
direction_to(jnp.array(list_nodes[0]), jnp.array(list_nodes[1]))

2

In [33]:
class Agent:
    def __init__(self, player: str, env_cfg: Dict[str, Any]) -> None:
        np.random.seed(0)

        self.player = player
        self.opp_player = "player_1" if self.player == "player_0" else "player_0"
        self.team_id = 0 if self.player == "player_0" else 1
        self.opp_team_id = 1 if self.team_id == 0 else 0
        self.env_cfg = env_cfg

        self.map = Graph(env_cfg["map_width"], env_cfg["map_height"])
        print(self.map)

        self.unit_explore_locations = dict()

    def act(self, step: int, obs: Dict[str, Any], remainingOverageTime: int = 60):
        """implement this function to decide what actions to send to each available unit.

        step is the current timestep number of the game starting from 0 going up to max_steps_in_match * match_count_per_episode - 1.
        """
        update_nodes_from_obs(G.map, obs)

        unit_mask = np.array(obs["units_mask"][self.team_id])
        unit_positions = np.array(
            obs["units"]["position"][self.team_id]
        )  # shape (max_units, 2)
        available_unit_ids = np.where(unit_mask)[0]

        # Act randomly
        actions = np.zeros((self.env_cfg["max_units"], 3), dtype=int)

        for unit_id in available_unit_ids:
            unit_pos = unit_positions[unit_id]

            if unit_id not in self.unit_explore_locations:
                self.unit_explore_locations[unit_id] = (
                    np.random.randint(0, self.env_cfg["map_width"]),
                    np.random.randint(0, self.env_cfg["map_height"]),
                )

            print(self.unit_explore_locations[unit_id])
            list_nodes = get_closest_path(G.map, unit_pos, self.unit_explore_locations[unit_id])
            print(list_nodes)
            actions[unit_id] = [
                direction_to(jnp.array(list_nodes[0]), jnp.array(list_nodes[1])),
                0,
                0,
            ]
        return actions

In [34]:
from luxai_s3.state import EnvObs


def env_params_to_dict(env_params: EnvParams) -> Dict[str, Any]:
    return {
        "map_width": env_params.map_width,
        "map_height": env_params.map_height,
        "max_steps_in_match": env_params.max_steps_in_match,
        "max_units": env_params.max_units,
    }


def obs_to_dict(obs: EnvObs) -> Dict[str, Any]:
    return {
        "units": {
            "position": obs.units.position,
            "energy": obs.units.energy,
        },
        "units_mask": obs.units_mask,
        "sensor_mask": obs.sensor_mask,
        "map_features": {
            "energy": obs.map_features.energy,
            "tile_type": obs.map_features.tile_type,
        },
        "relic_nodes": obs.relic_nodes,
        "relic_nodes_mask": obs.relic_nodes_mask,
        "team_points": obs.team_points,
        "team_wins": obs.team_wins,
        "steps": obs.steps,
        "match_steps": obs.match_steps,
    }


Agent("player_0", env_params_to_dict(env_params)).act(0, obs_to_dict(obs["player_0"]))

<__main__.Graph object at 0x7ff23c5fb010>
(12, 15)


NodeNotFound: Source [9 7] is not in G

In [21]:
from typing import OrderedDict


agent_0 = Agent("player_0", env_params_to_dict(env_params))
agent_1 = Agent("player_1", env_params_to_dict(env_params))

for _ in range(19):
    key, subkey = jax.random.split(key)
    action = OrderedDict(
        {
            "player_0": agent_0.act(0, obs_to_dict(obs["player_0"])),
            "player_1": agent_1.act(0, obs_to_dict(obs["player_1"])),
        }
    )
    obs, state, reward, terminated, truncated, info = env.step(
        subkey, state, action, params=env_params
    )