In [1]:
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from trl import PPOTrainer, PPOConfig
import luxai_s3
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
from luxai_s3.params import EnvParams
import numpy as np
from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model
import os
from accelerate import infer_auto_device_map
import gc
import copy
gc.enable()

#from stable_baselines3 import PPO
#import gymnasium as gym
#import gym

In [2]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
os.environ["FLASH_ATTENTION"] = "1"
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.cache_size_limit = 64
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
np.set_printoptions(linewidth=200)
# Configure CUDA memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,garbage_collection_threshold:0.8"
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = False

# Enable gradient checkpointing
os.environ["PYTORCH_ATTENTION_USE_MEMORY_EFFICIENT_ATTENTION"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [3]:
# Raw GSM8K ('main' config), used in the next cells for quick inspection of the train split.
# get_gsm8k_questions() below loads the dataset again on its own.
data = load_dataset('openai/gsm8k', 'main')

In [4]:
# Train split (question / answer pairs), displayed for inspection.
train_data = data['train']
train_data

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [5]:
# Number of training examples.
train_data.num_rows

7473

In [6]:
# Column schema of the train split.
train_data.features

{'question': Value(dtype='string', id=None),
 'answer': Value(dtype='string', id=None)}

In [7]:
# First example: the final numeric answer follows the '####' marker (parsed by extract_hash_answer).
train_data[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}

In [8]:
# Load and prep dataset

# System prompt asking the model to put its chain of thought in <reasoning> and its final
# answer in <answer>; extract_xml_answer() parses the <answer> section back out.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Completion template in the same XML layout; fill with .format(reasoning=..., answer=...).
# NOTE(review): not referenced anywhere else in the visible notebook — confirm it is still needed.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

In [9]:
def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` tag and the following ``</answer>``.

    If no ``<answer>`` tag is present the whole text is used; if no closing tag follows,
    everything after the last ``<answer>`` is returned.
    """
    _, _, after_open_tag = text.rpartition("<answer>")
    inner, _, _ = after_open_tag.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Zero-shot chat prompts: each GSM8K question is preceded only by the SYSTEM_PROMPT format instruction.
def get_gsm8k_questions(split = "train") -> Dataset:
    """Load a GSM8K split and convert each row into a chat-style training example.

    Each output row has ``'prompt'`` (system format instruction + user question, as a list of
    ``{'role', 'content'}`` messages) and ``'answer'`` (the text after ``####`` in the reference
    solution, or None if the marker is missing).
    """
    def to_chat_example(row):
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': row['question']},
            ],
            'answer': extract_hash_answer(row['answer']),
        }

    split_rows = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    return split_rows.map(to_chat_example) # type: ignore

dataset = get_gsm8k_questions()

In [10]:
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reuse EOS as the padding token (presumably because the tokenizer defines no dedicated pad token — verify).
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization with bf16 compute; passed to from_pretrained in Agent.__init__.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,  # nested ("double") quantization of the quantization constants, saves memory
    bnb_4bit_quant_storage="bfloat16"  # per transformers docs: storage dtype of the packed 4-bit weights (not an on/off switch)
)

# Model config; its max_position_embeddings is overridden in the next cell, and the result is
# passed as config= to from_pretrained in Agent.__init__.
autoconfig = AutoConfig.from_pretrained(model_name)

In [11]:
# Override the config's max_position_embeddings before the model is loaded in Agent.__init__.
# NOTE(review): the game-rules + game-state prompt assembled later in this notebook appears to be
# far longer than 1000 tokens — confirm this cap is intended. This mutates `autoconfig` in place,
# so re-run the previous cell to undo it.
autoconfig.max_position_embeddings = 1000

In [12]:
#from lux.utils import direction_to
#import sys
import numpy as np

# System message requesting the <reasoning>/<answer> XML layout. Its content is identical to
# SYSTEM_PROMPT defined earlier in the notebook; keep the two in sync if either changes.
answer_format = dict(
    content='\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n',
    role='system',
)

# System message describing the Lux AI S3 rules and the expected per-unit answer format.
# It is combined with answer_format and the per-step game-state message built by
# Agent.prep_llm_input to form the model prompt.
game_rules = {
    'content': """
!!!GAME RULES!!!
Environment
Two teams compete against each other on a 2D map in a best of 5 match sequence (called a game) with each match lasting 100 time steps. Both teams have a pool of units they can control to gain points around the map while also trying to prevent the other team from doing the same.
A core objective of this game is a balanced strategy of exploration and exploitation. It is recommended to explore more in the first match or two before leveraging gained knowledge about the map and opponent behavior to win the latter matches.
Map
The map is a randomly generated 2D grid of size 24x24. There are several core features that make up the map: Unknown Tiles, Empty Tiles, Asteroid Tiles, Nebula Tiles, Energy Tiles, Relic Nodes, and Relic Fragments. Notably, in a game, the map is never regenerated completely between matches. Whatever is the state of the map at the end of one match is what is used for the next match.
Unknown Tiles
These are tiles that are not visible. They can be any type of tile but are not visible to you until a unit is within sensor range of that tile.
Empty Tiles
These are empty tiles in space without anything special about them. Units and tiles can be placed/move onto these tiles.
Asteroid Tiles
Asteroid tiles are impassable tiles that block anything from moving/spawning onto them. These tiles might move around over time during the map in a symmetric fashion. Sometimes asteroid tiles might move on top of existing units. In the game the unit is not removed as a result of this and can still take actions and move around provided there is a non-asteroid tile adjacent to it.
Nebula Tiles
Nebula tiles are passable tiles with a number of features. These tiles might move around over time during the map in a symmetric fashion.
Vision Reduction: Nebula tiles can reduce/block vision of units. Because of vision reduction it is even possible for a unit to be unable to see itself while still being able to move! See Vision section below for more details on how team vision is determined. All nebula tiles have the same vision reduction value called params.nebula_tile_vision_reduction which is randomized from 0 to 3.
Energy Reduction: Nebula tiles can reduce the energy of units that end their turn on them. All nebula tiles have the same energy reduction value called params.nebula_tile_energy_reduction.
In Map Tile Types, empty tiles are represented by 0, asteroid tiles by 1, nebula tiles by 2, and unknown tiles by -1.
Energy Tiles
Energy tiles are mysterious objects that emit energy fields which can be harvested by units. These tiles might move around over time during the map in a symmetric fashion. In code, what actually occurs in each game is energy tiles are randomly generated on the map symmetrically and a random function is generated for each tile. Each energy tile's function is a function of distance. The energy value of a tile on a map is determined to be the sum of the energy tile functions applied to the distance between that tile and each energy tile.
Relic Nodes
Relic nodes are objects in space that enable ships to go near it to gain team points. These relic nodes however are ancient and thus fragmented. As a result, only certain tiles near the relic nodes when a friendly ship is on it will gain points. The tiles that yield points are always hidden and can only be discovered by trial and error by moving around the relic nodes. Relic node positions themselves can be observed if within sensor range. The tiles around relic nodes can overlap with tiles of other relic nodes but will not yield extra points if that occurs and is treated as one tile.
In code, a random 5x5 configuration / mask centered on the relic node is generated indicating which tiles yield points and which don't. Multiple ships can stack on one tile but will only yield at most one point per tile. Note that ship stacking can be risky due to the sapping action.
Units
Units in the game are ships that can move one tile in 5 directions (center, up, right, down, left) and perform a ranged energy sapping action. Units can overlap with other friendly units if they move onto the same tile. Units have an energy property which determines whether they can perform actions and start with 100 energy and can have a max of 400 energy. Energy is recharged via the energy field of the map. They always spawn on one of the two corners of the map depending on which team they are on.
Note that nebula tiles and energy fields can modify the energy of a unit when it is on that tile. However they can never reduce the energy of a unit below 0, only opposing units can do that which will then remove the unit from the game to be respawned at a later timestep. Unit IDs range from 0 to params.max_units - 1 for each team, and are recycled when units are spawned in if a previous one was removed.
Move Actions
All move actions except moving center cost params.unit_move_cost energy to perform. Moving center is always free (a zero action). Attempting to move off the edge of the map results in no movement occurring but energy is still consumed. Units cannot move onto tiles with an impassable feature like an asteroid tile.
Sap Actions
The sap action lets a unit target a specific tile on the map within a range called params.unit_sap_range and reduces the energy of each opposition unit on the target tile by params.unit_sap_cost while also costing unit_sap_cost energy to use. Moreover, any opposition units on the 8 adjacent tiles to the target tile are also sapped and their energy is reduced by params.unit_sap_cost * params.unit_sap_dropoff_factor.
Sap actions are submitted to the game engine / environment as a delta x and delta y value relative to the unit's current position. The delta x and delta y value magnitudes must both be <= params.unit_sap_range, so the sap range is a square around the unit.
Generally sap actions are risky since a single miss means your ships lose energy while the opponent does not. The area of effect can mitigate this risk somewhat depending on game parameters. Sap actions can however prove very valuable when opposition ships are heavily stacked and get hit as sapping the stacked tile hits every ship on the tile.
Vision
A team's vision is the combined vision of all units on that team. Team vision is essentially a boolean mask / matrix over the 2D map indicating whether that tile's information is visible to the team. In this game, you can think of each unit having an "eye in the sky" satellite that is capturing information about the units surroundings, but this satellite has reduced accuracy the farther away the tile is from the unit.
To determine which map tiles are visible to a team, we compute a vision power value for each tile on the map. For each unit on a team, we check each tile within the unit's sensor range and add 1 + params.unit_sensor_range - min(dx, dy) to the vision power map at tile (x+dx, y+dy) where (x,y) is the unit's position and (dx, dy) is the offset from the unit's position and abs(dx) <= params.unit_sensor_range and abs(dy) <= params.unit_sensor_range.
Nebula tiles have a vision reduction value of params.nebula_tile_vision_reduction. This number is reduced from every tile's vision power if that tile is a nebula tile.
When a unit is near a nebula tile, it can't see details about some nebula tiles, but it can see tiles beyond nebula tiles.
When a unit is inside a nebula tile, if the nebula vision reduction is powerful enough, the unit cannot even see itself or any other nebula tiles.
Unit vision can overlap and increase the vision power linearly, which can help handle the situations like above when you cannot see anything.
Collisions / Energy Void Fields
In close quarters, units can impact each other in two ways, via direct collisions or by being adjacent to each other and sapping energy via their energy void fields.
In the event that two or more units from opposing teams occupy the same tile at the end of a turn, the team with the highest aggregate energy among its units on that tile survive, while the units of the opposing teams are removed from the game. If it is a tie, all units are removed from the game.
Furthermore, each unit generates an "energy void" field around itself that affects all cardinally (up, right, down, left) adjacent opposition units. To determine how exactly each unit is affected by these energy void fields, we compute a 2D map for each team indicating the energy void strength at each tile. A unit contributes to tiles adjacent to itself an energy void strength equal to the total amount of energy the unit has at the start of the turn multiplied by params.unit_energy_void_factor rounded down. After an energy void map is computed for each team, a unit's energy is reduced by the energy void strength of the tile it is on divided by the total number of units on that tile. Note that units removed due to collisions do not contribute to the energy void field.
The energy void fields generally encourage stacking units to better spread out energy sapped by energy void fields of opposition units.
Win Conditions
To win the game, the team must have won the most matches out of the 5 match sequence.
To win a match, the team must have gained more relic points than the other team at the end of the match. If the relic points scores are tied, then the match winner is decided by who has more total unit energy. If that is also tied then the winner is chosen at random.
Match Resolution Order
At each time step of a match, we run the following steps in order:
1. Move all units that have enough energy to move
2. Execute the sap actions of all units that have enough energy to do so
3. Resolve collisions and apply energy void fields
4. Update the energy of all units based on their position (energy fields and nebula tiles)
5. Spawn units for all teams. Remove units that have less than 0 energy.
6. Determine the team vision / sensor masks for all teams and mask out observations accordingly
7. Environment objects like asteroids/nebula tiles/energy tiles move around in space
8. Compute new team points
Note that each match runs for params.max_steps_in_match steps and you take that many actions that affect the game. However, you will actually receive params.max_steps_in_match + 1 frames of observations since the very first frame will either be empty or the previous match's final observation (actions on these observations will not do anything).
Game Parameters
The full set of game parameters can be found in the luxai_s3 codebase (luxai_s3.params.EnvParams).
Randomized Game Parameters / Map Generation
There are a number of randomized game parameters which can modify and even disable/enable certain game mechanics. None of these game parameters are changed between matches in a game. The majority of these parameters are also not given to the teams themselves and must be discovered through exploration.
env_params_ranges = dict(
    map_type=[1],
    unit_move_cost=list(range(1, 6)), # list(range(x, y)) = [x, x+1, x+2, ... , y-1]
    unit_sensor_range=list(range(2, 5)),
    nebula_tile_vision_reduction=list(range(0,4)),
    nebula_tile_energy_reduction=[0, 0, 10, 25],
    unit_sap_cost=list(range(30, 51)),
    unit_sap_range=list(range(3, 8)),
    unit_sap_dropoff_factor=[0.25, 0.5, 1],
    unit_energy_void_factor=[0.0625, 0.125, 0.25, 0.375],
    # map randomizations
    nebula_tile_drift_speed=[-0.05, -0.025, 0.025, 0.05],
    energy_tile_drift_speed=[0.01, 0.02, 0.03, 0.04, 0.05],
    energy_tile_drift_magnitude=list(range(3, 6))
)
These parameter ranges (and other parameters) are subject to change in the beta phase of this competition as we gather feedback and data.
There are 6 actions that can be taken by a unit in this game: 0 = center, 1 = up, 2 = right, 3 = down, 4 = left, 5 = sap.
So your answer should be in this format:
Unit 0: action(from 0 to 5)
Unit 1: action(from 0 to 5)
Unit 2: action(from 0 to 5)
Unit 3: action(from 0 to 5)
Unit 4: action(from 0 to 5)
Unit 5: action(from 0 to 5)
Unit 6: action(from 0 to 5)
Unit 7: action(from 0 to 5)
Unit 8: action(from 0 to 5)
Unit 9: action(from 0 to 5)
Unit 10: action(from 0 to 5)
Unit 11: action(from 0 to 5)
Unit 12: action(from 0 to 5)
Unit 13: action(from 0 to 5)
Unit 14: action(from 0 to 5)
Unit 15: action(from 0 to 5)
However, if you choose to sap(5), you should provide the direction of the sap, which is a pair of integers (dx, dy) where dx and dy are the relative coordinates of the target tile from the unit's current position. The magnitudes of dx and dy must be less than or equal to the unit's sap range. For example, if unit 3 is at (5, 5) and you want to sap the tile at (7, 7), your answer for unit 3 should be 5, 2, 2.
Also, you can only take actions for the units that are available to you in the current timestep. If you take an action for a unit that is not available to you, the game engine will ignore that action.
Additionally, you can only take one action per unit per timestep. If you take multiple actions for a single unit in a timestep, the game engine will ignore all but the first action.
So, below is an example of a valid answer:
Unit 0: 1 # move up
Unit 1: 2 # move right
Unit 2: 5, 2, 2 # sap at (2, 2) relative to unit 2's current position
Unit 3: 0 # center
Unit 4: 5, 1, 1 # sap at (1, 1) relative to unit 4's current position
Unit 5: 5, -1, -2 # sap at (-1, -2) relative to unit 5's current position
Unit 6: 5, -2, 2 # sap at (-2, 2) relative to unit 6's current position
Unit 7: 5, 0, 0 # sap at (0, 0) relative to unit 7's current position
Unit 8: 4 # move left
Unit 9: 0 # center
Unit 10: 3 # move down
Unit 11: 2 # move right
Unit 12: 1 # move up
Unit 13: 0 # center
Unit 14: 5, -4, 5 # sap at (-4, 5) relative to unit 14's current position
Unit 15: 5, 3, -3 # sap at (3, -3) relative to unit 15's current position
""",
    'role': 'system',
}


# Helper functions
def manhattan_distance(pos1, pos2):
    """L1 (taxicab) distance between two points given as indexable (x, y) pairs."""
    dx = pos1[0] - pos2[0]
    dy = pos1[1] - pos2[1]
    return abs(dx) + abs(dy)

def absolute_distance(pos1, pos2):
    """Chebyshev (L-infinity) distance: the larger of the |dx| and |dy| offsets.

    This matches the square-shaped sap / sensor ranges described in the game rules.
    """
    dx = abs(pos1[0] - pos2[0])
    dy = abs(pos1[1] - pos2[1])
    return dx if dx >= dy else dy

def find_opposite_corner_coords(array, row, col):
    """Reflect (row, col) through the centre of a 2-D grid.

    For a corner cell this yields the diagonally opposite corner; for any other cell it
    yields the point-symmetric cell (a 180-degree rotation about the grid centre).

    :param array: 2D list or NumPy array; only its dimensions are used
    :param row: Row index of the given point
    :param col: Column index of the given point
    :return: (row', col') - the reflected coordinates
    """
    height = len(array)
    width = 0 if height == 0 else len(array[0])
    return (height - 1 - row, width - 1 - col)


class Agent():
    """Lux AI S3 agent that turns each observation into an LLM chat prompt.

    Work in progress: ``act`` records spawn / exploration bookkeeping and stores the rendered
    game-state message in ``self.llm_input``, but still returns all-zero actions. The causal LM
    is loaded in ``__init__`` and is not queried yet.
    """

    # Upper bound on the rendered width of one map row. Passing it explicitly to
    # np.array2string keeps every row on a single prompt line regardless of the global
    # np.set_printoptions(linewidth=...) setting (which could otherwise wrap rows mid-line).
    _ROW_MAX_LINE_WIDTH = 1_000_000

    def __init__(self, player: str, env_cfg) -> None:
        """Initialise per-player state and load the quantized causal LM.

        Args:
            player: ``"player_0"`` maps to team 0; any other value is treated as team 1
                (``"player_1"``).
            env_cfg: environment parameter mapping (``info['params']`` from ``env.reset()``);
                ``"map_height"`` / ``"map_width"`` are read here and ``"max_units"`` in ``act``.
        """
        self.player = player
        self.enemy_player = "player_1" if self.player == "player_0" else "player_0"
        self.team_id = 0 if self.player == "player_0" else 1
        self.enemy_team_id = 1 if self.team_id == 0 else 0
        #np.random.seed(0)
        self.env_cfg = env_cfg
        #self.min_unit_sap_dropoff_factor = 1
        #self.min_sap_power = self.unit_sap_cost * self.min_unit_sap_dropoff_factor
        self.map_height = env_cfg["map_height"]
        self.map_width = env_cfg["map_width"]
        self.my_spawn_location = None
        self.enemy_spawn_location = None
        self.first_spawn = False
        self.llm_input = None

        # 1 for every tile whose type has been observed at least once (updated in act()).
        self.map_explored_status = np.zeros((self.map_height, self.map_width), dtype=int)

        # Uses the module-level model_name, bnb_config and autoconfig defined in earlier cells.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            device_map="auto",  # Let Accelerate handle device placement
            #device_map={"0": "14GiB", "cpu": "64GiB"},  # Let Accelerate handle device placement
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            config=autoconfig,
            attn_implementation="flash_attention_2",
            #use_cache=False,  # Disable KV cache during training
            low_cpu_mem_usage=True
        )

        # Enable memory efficient features
        self.model.gradient_checkpointing_enable()
        self.model.enable_input_require_grads()

        #self.model = torch.compile(self.model)

        # device_map = infer_auto_device_map(model, max_memory={0: "14GiB", "cpu": "64GiB"})

        # model = AutoModelForCausalLM.from_pretrained(
        #     model_name,
        #     trust_remote_code=True,
        #     #device_map="auto",  # Let Accelerate handle device placement
        #     device_map=device_map,  # Let Accelerate handle device placement
        #     quantization_config=bnb_config,
        #     torch_dtype=torch.bfloat16,
        #     config=autoconfig,
        #     attn_implementation="flash_attention_2"
        # )

        # peft_config = LoraConfig(
        #     r=8,
        #     lora_alpha=16,
        #     lora_dropout=0.1,
        #     target_modules=["q_proj", "v_proj"],
        #     bias="none",
        #     task_type="CAUSAL_LM"
        # )

        # self.model = get_peft_model(model, peft_config)#.to('cpu')
        # self.model.gradient_checkpointing_enable()

        # self.relic_node_positions = []
        # self.discovered_relic_nodes_ids = set()
        # self.unit_explore_locations = dict()

    @staticmethod
    def _xy(pos) -> str:
        """Format an indexable pair as ``"a, b"``."""
        return f"{pos[0]}, {pos[1]}"

    @staticmethod
    def _indexed_lines(subject: str, field: str, values, render=format) -> str:
        """Return one line per entry along the first axis of ``values``.

        Each line is a newline followed by ``"{subject} {i} {field}: {render(value)}."``.
        """
        return "".join(
            f"\n{subject} {i} {field}: {render(value)}." for i, value in enumerate(values)
        )

    @classmethod
    def _render_row(cls, row) -> str:
        """Render one numpy row as space-separated values with the brackets removed.

        Output matches ``str(row)`` whenever that fits on one line, but never wraps.
        """
        text = np.array2string(
            np.asarray(row),
            max_line_width=cls._ROW_MAX_LINE_WIDTH,
            threshold=cls._ROW_MAX_LINE_WIDTH,
        )
        return text.replace("[", "").replace("]", "")

    @classmethod
    def _grid_lines(cls, subject: str, grid) -> str:
        """Return one ``"{subject} row {i}: ..."`` line per row of a 2-D grid."""
        return "".join(
            f"\n{subject} row {i}: {cls._render_row(row)}." for i, row in enumerate(grid)
        )

    def prep_llm_input(self, env_cfg, obs):
        """Render the environment config and one observation as a user chat message.

        Args:
            env_cfg: mapping with the keys listed in ``config_fields`` below.
            obs: this player's observation, indexed as ``obs['units']['position'|'energy'][team]``,
                ``obs['units_mask'][team]``, ``obs['sensor_mask']``,
                ``obs['map_features']['energy'|'tile_type']``, ``obs['relic_nodes']``,
                ``obs['relic_nodes_mask']``, ``obs['team_points'|'team_wins'][team]``,
                ``obs['steps']`` and ``obs['match_steps']``.

        Returns:
            dict: ``{'content': <game-state text>, 'role': 'user'}``.
        """
        me, enemy = self.team_id, self.enemy_team_id
        units = obs['units']

        config_fields = (
            ("Maximum possible number of units for each team", 'max_units'),
            ("Number of matches per game", 'match_count_per_episode'),
            ("Number of steps per match", 'max_steps_in_match'),
            ("Map height", 'map_height'),
            ("Map width", 'map_width'),
            ("Number of teams", 'num_teams'),
            ("Unit move energy cost", 'unit_move_cost'),
            ("Unit sap energy cost", 'unit_sap_cost'),
            ("Unit sap range", 'unit_sap_range'),
            ("Unit sensor range", 'unit_sensor_range'),
        )
        env_cfg_lines = "".join(f"\n{label}: {env_cfg[key]}." for label, key in config_fields)

        if self.enemy_spawn_location is None:
            enemy_spawn = "\nEnemy spawn location: not yet discovered."
        else:
            enemy_spawn = f"\nEnemy spawn location: {self._xy(self.enemy_spawn_location)}."

        parts = [
            "\n!!!GAME STATE INFORMATION!!!",
            "\nENVIRONMENT CONFIGURATION:",
            env_cfg_lines,
            "\nOBSERVATION:",
            "\nUnit position: -1, -1 means the unit is not spawned yet or not visible.",
            "\nUnit Positions:",
            self._indexed_lines("My unit", "position", units['position'][me], self._xy),
            self._indexed_lines("Enemy unit", "position", units['position'][enemy], self._xy),
            "\nUnit Energys:",
            self._indexed_lines("My unit", "energy", units['energy'][me]),
            self._indexed_lines("Enemy unit", "energy", units['energy'][enemy]),
            "\nUnit Visibility:",
            self._indexed_lines("My unit", "visibility", obs['units_mask'][me]),
            self._indexed_lines("Enemy unit", "visibility", obs['units_mask'][enemy]),
            "\nSensor Mask:",
            self._grid_lines("Sensor mask", obs['sensor_mask']),
            "\nMap Energys:",
            self._grid_lines("Map energy", obs['map_features']['energy']),
            "\nMap Tile Types:",
            self._grid_lines("Map tile type", obs['map_features']['tile_type']),
            "\nRelic Node positions:",
            "\nRelic node position: -1, -1 means the relic node is not yet discoverd.",
            self._indexed_lines("Relic node", "position", obs['relic_nodes'], self._xy),
            "\nRelic Node Visibility:",
            self._indexed_lines("Relic node", "visibility", obs['relic_nodes_mask']),
            f"\nMy current point for this match is: {obs['team_points'][me]}.",
            f"\nEnemy current point for this match is: {obs['team_points'][enemy]}.",
            f"\nI have won {obs['team_wins'][me]} matches.",
            f"\nEnemy has won {obs['team_wins'][enemy]} matches.",
            f"\nThis is step {obs['steps']} of the game.",
            f"\nThis is step {obs['match_steps']} of the match.",
            enemy_spawn,
        ]
        return {'content': "".join(parts), 'role': 'user'}

    def act(self, step: int, obs, remainingOverageTime: int = 60):
        """Update bookkeeping for this observation and return the actions for every unit.

        ``step`` is the current timestep of the game, from 0 up to
        max_steps_in_match * match_count_per_episode - 1 (currently unused, as is
        ``remainingOverageTime``).

        Side effects: marks observed tiles in ``self.map_explored_status``, records spawn
        locations the first time one of our units is visible, and stores the rendered
        prompt in ``self.llm_input``.

        Returns:
            np.ndarray of shape (max_units, 3), dtype int — currently all zeros.
            Presumably each row is [action_id, sap_dx, sap_dy]; confirm against the env.
        """
        unit_positions = np.array(obs["units"]["position"][self.team_id]) # shape (max_units, 2)
        unit_mask = np.array(obs["units_mask"][self.team_id]) # shape (max_units, )

        # Tile type -1 means "not currently visible"; everything else counts as explored.
        current_map_tile_type = obs['map_features']['tile_type']
        self.map_explored_status[current_map_tile_type != -1] = 1

        # ids of units you can control at this timestep
        available_unit_ids = np.where(unit_mask)[0]

        # The first time any of our units is visible, take the lowest-id unit's position as our
        # spawn point and assume the enemy spawns at its reflection through the map centre.
        if not self.first_spawn and available_unit_ids.shape[0] > 0:
            first_unit_pos = unit_positions[available_unit_ids[0]]
            self.my_spawn_location = (first_unit_pos[0], first_unit_pos[1])
            self.enemy_spawn_location = find_opposite_corner_coords(
                self.map_explored_status, first_unit_pos[0], first_unit_pos[1]
            )
            self.first_spawn = True

        self.llm_input = self.prep_llm_input(self.env_cfg, obs)

        # TODO: decode per-unit actions from the LLM response; for now every unit stays put.
        actions = np.zeros((self.env_cfg["max_units"], 3), dtype=int)
        return actions

In [13]:
# Lux AI S3 gym environment with numpy observations, wrapped in RecordEpisode
# (presumably records replays; no save location is configured here — verify).
env = RecordEpisode(
    LuxAIS3GymEnv(numpy_output=True)
)



In [14]:
# Start a new game; obs_all is keyed by player id and info['params'] is passed to Agent as env_cfg.
obs_all, info = env.reset()

In [15]:
# Only player_0 is driven by an Agent for now; constructing it loads the LLM.
agent0 = Agent("player_0", info['params'])
#agent1 = Agent("player_1", info['params'])

In [16]:
# One act() call on the initial observation; this also populates agent0.llm_input.
actions0 = agent0.act(obs_all['player_0']['steps'], obs_all['player_0'])
#actions1 = agent1.act(obs_all['player_1']['steps'], obs_all['player_1'])

In [17]:
temp_dataset = Dataset.from_list([{'prompt': [answer_format, game_rules, agent0.llm_input]}])
temp_dataset[0]

{'prompt': [{'content': '\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n',
   'role': 'system'},
  {'content': '\n!!!GAME RULES!!!\nEnvironment\nTwo teams compete against each other on a 2D map in a best of 5 match sequence (called a game) with each match lasting 100 time steps. Both teams have a pool of units they can control to gain points around the map while also trying to prevent the other team from doing the same.\nA core objective of this game is a balanced strategy of exploration and exploitation. It is recommended to explore more in the first match or two before leveraging gained knowledge about the map and opponent behavior to win the latter matches.\nMap\nThe map is a randomly generated 2D grid of size 24x24. There are several core features that make up the map: Unknown Tiles, Empty Tiles, Asteroid Tiles, Nebula Tiles, Energy Tiles, Relic Nodes, and Relic Fragments. Notably, in a game, the map is never regenerated completely bet

In [18]:
agent0.llm_input

{'content': '\n!!!GAME STATE INFORMATION!!!\nENVIRONMENT CONFIGURATION:\nMaximum possible number of units for each team: 16.\nNumber of matches per game: 5.\nNumber of steps per match: 100.\nMap height: 24.\nMap width: 24.\nNumber of teams: 2.\nUnit move energy cost: 4.\nUnit sap energy cost: 45.\nUnit sap range: 3.\nUnit sensor range: 2.\nOBSERVATION:\nUnit position: -1, -1 means the unit is not spawned yet or not visible.\nUnit Positions:\nMy unit 0 position: -1, -1.\nMy unit 1 position: -1, -1.\nMy unit 2 position: -1, -1.\nMy unit 3 position: -1, -1.\nMy unit 4 position: -1, -1.\nMy unit 5 position: -1, -1.\nMy unit 6 position: -1, -1.\nMy unit 7 position: -1, -1.\nMy unit 8 position: -1, -1.\nMy unit 9 position: -1, -1.\nMy unit 10 position: -1, -1.\nMy unit 11 position: -1, -1.\nMy unit 12 position: -1, -1.\nMy unit 13 position: -1, -1.\nMy unit 14 position: -1, -1.\nMy unit 15 position: -1, -1.\nEnemy unit 0 position: -1, -1.\nEnemy unit 1 position: -1, -1.\nEnemy unit 2 positio

In [19]:
previous_score = 0.0

In [20]:
current_score = obs_all['player_0']['team_points'][0]
current_score

0

In [21]:
# Game-point reward: change in team_points[0] (presumably player_0's team) since `previous_score`.
# NOTE(review): `current_score` was read once from the env.reset() observation and nothing here
# refreshes it as the game advances, so `reward_score` stays a constant (0.0 as displayed) for
# any reward function that reads this global.
reward_score = current_score - previous_score
reward_score

0.0

In [22]:
# Reward functions
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    r"""Reward completions that follow the response template exactly.

    A completion scores 0.5 when its text, starting at the first character, is
    ``<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n`` and 0.0 otherwise.

    ``re.DOTALL`` lets the reasoning and answer bodies span several lines. Without it, ``.``
    stops at newlines, so multi-line reasoning (or the multi-line per-unit action list) could
    never match and the reward would always be 0.0.

    Args:
        completions: chat-style completions; ``completion[0]["content"]`` is the generated text.
        **kwargs: ignored.

    Returns:
        One score per completion.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]

    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Loosely reward completions shaped like ``<reasoning>...</reasoning> <answer>...</answer>``.

    A completion scores 0.5 when its text starts with a reasoning block followed, after optional
    whitespace, by an answer block; anything may come after the answer block. Otherwise it
    scores 0.0.

    ``re.DOTALL`` lets both blocks contain newlines. Without it, any multi-line reasoning or
    answer failed to match.

    Args:
        completions: chat-style completions; ``completion[0]["content"]`` is the generated text.
        **kwargs: ignored.

    Returns:
        One score per completion.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]

    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Partial-credit format score for the reasoning/answer tag layout.

    Adds 0.125 for each of the four tag markers ``<reasoning>\\n``, ``\\n</reasoning>\\n``,
    ``\\n<answer>\\n`` and ``\\n</answer>`` that appears exactly once (max 0.5). It then
    subtracts 0.001 per character of stray text:
      * before ``<reasoning>\\n``;
      * after ``\\n</answer>\\n``, counted only when that closing marker exists;
      * after ``\\n</answer>``, where a single trailing character (the final newline) is free.

    Text after ``</answer>\\n`` is therefore counted by both of the last two checks, as in the
    original scoring.

    Args:
        text: the raw completion text.

    Returns:
        The score as a float.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
        # Penalise any preamble before the opening reasoning tag.
        count -= len(text.split("<reasoning>\n")[0]) * 0.001
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # Only measure trailing text when the closing marker is present. Otherwise split()
        # returns the whole completion, and an unterminated answer would be penalised by its
        # total length, ending up far below a completion with no answer tag at all.
        if "\n</answer>\n" in text:
            count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        # Allow one trailing character (the newline). Clamp at zero so a missing newline
        # does not turn the penalty into a bonus.
        count -= max(0, len(text.split("\n</answer>")[-1]) - 1) * 0.001

    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's tag layout with ``count_xml``.

    Args:
        completions: chat-style completions; ``completion[0]["content"]`` is the generated text.
        **kwargs: ignored.

    Returns:
        ``count_xml`` of each completion's text, in input order.
    """
    scores = []
    for completion in completions:
        text = completion[0]["content"]
        scores.append(count_xml(text))
    return scores

def answer_format_reward_func(completions, **kwargs) -> list[float]:
    """Score how closely each completion's answer follows the per-unit action format.

    Every line of the extracted answer should read ``Unit <id>: <action> # <description>``.
    Here <action> is 0-5 (5 optionally followed by ``, dx, dy``) and <description> is one of
    the phrases accepted by ``answer_pattern`` below.

    Scoring, per completion:
      * +0.5/16 for each line that matches the pattern (16 valid lines -> +0.5);
      * -0.1/16 for a matching line whose unit id is 16 or more;
      * -0.1, applied once, when the answer does not have exactly 16 lines (one per unit).

    Args:
        completions: chat-style completions; ``completion[0]["content"]`` is the generated text.
        **kwargs: ignored.

    Returns:
        One float score per completion.
    """
    responses = [completion[0]["content"] for completion in completions]
    answers = [extract_xml_answer(r) for r in responses]
    # Group 1 captures the unit id so it can be range-checked without re-splitting the line.
    answer_pattern = re.compile(r"^Unit (\d+): (0|1|2|3|4|5(, -?\d+, -?\d+)?) # (move up|move right|move down|move left|center|sap at \(-?\d+, -?\d+\) relative to unit \d+'s current position)$")
    num_units = 16

    scores = []
    for answer in answers:
        actions = answer.split("\n")
        answer_score = 0.0
        for action in actions:
            match = answer_pattern.match(action)
            if match:
                answer_score += 0.5 / num_units
                # \d+ cannot match a sign, so only the upper bound needs checking.
                if int(match.group(1)) >= num_units:
                    answer_score -= 0.1 / num_units
        # The line-count check applies to the whole answer. Comparing each line's character
        # length to 16 would penalise nearly every line regardless of its content.
        if len(actions) != num_units:
            answer_score -= 0.1
        scores.append(answer_score)

    return scores

def point_gain_reward_func(completions, **kwargs) -> list[float]:
    """Give every completion the same game-point reward.

    Reads the module-level ``reward_score`` when called and repeats it once per completion.
    The completion text itself is not inspected.

    NOTE(review): because the value does not depend on the completion, this reward cannot tell
    sampled responses apart. Wire in a per-completion environment rollout if that is intended.

    Args:
        completions: the batch of completions (only its length/iteration is used).
        **kwargs: ignored.

    Returns:
        A list with ``reward_score`` repeated once per completion.
    """
    shared_reward = reward_score
    return [shared_reward for _unused in completions]

In [23]:
dataset

Dataset({
    features: ['question', 'answer', 'prompt'],
    num_rows: 7473
})

In [24]:
dataset[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': '72',
 'prompt': [{'content': '\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n',
   'role': 'system'},
  {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
   'role': 'user'}]}

In [25]:
temp_dataset = Dataset.from_list([{'prompt': [answer_format, game_rules, agent0.llm_input]}])
temp_dataset

Dataset({
    features: ['prompt'],
    num_rows: 1
})

In [26]:
temp_dataset[0]

{'prompt': [{'content': '\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n',
   'role': 'system'},
  {'content': '\n!!!GAME RULES!!!\nEnvironment\nTwo teams compete against each other on a 2D map in a best of 5 match sequence (called a game) with each match lasting 100 time steps. Both teams have a pool of units they can control to gain points around the map while also trying to prevent the other team from doing the same.\nA core objective of this game is a balanced strategy of exploration and exploitation. It is recommended to explore more in the first match or two before leveraging gained knowledge about the map and opponent behavior to win the latter matches.\nMap\nThe map is a randomly generated 2D grid of size 24x24. There are several core features that make up the map: Unknown Tiles, Empty Tiles, Asteroid Tiles, Nebula Tiles, Energy Tiles, Relic Nodes, and Relic Fragments. Notably, in a game, the map is never regenerated completely bet

In [27]:
# Output location and run name for this PPO experiment.
output_dir="outputs/DeepSeek-R1-Distill-Qwen-1.5B-PPO"
run_name="DeepSeek-R1-Distill-Qwen-1.5B-PPO-20250211_02"

# PPO hyper-parameters for fine-tuning the agent's policy model.
training_args = PPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    # NOTE(review): the first training log shows 'episode': 8 (= per-device batch 1 x grad
    # accumulation 8), so the trainer appears to derive its own batch size — confirm whether
    # this explicit value is used at all.
    batch_size=1,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    # With warmup the first scheduler step runs at lr 0.0, consistent with 'lr': 0.0 in the
    # first training log.
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=1,
    bf16=True,
    gradient_accumulation_steps=8,
    num_sample_generations=0,
    max_grad_norm=0.1,
    num_train_epochs=1,
    save_steps=100,
    log_on_each_node=False,
    report_to="none",
    num_ppo_epochs=1,
    cliprange=0.2,
    vf_coef=1.0,
    kl_coef=0.01,
    prediction_loss_only=True,
    gradient_checkpointing=True,
    #reward_model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    optim="adamw_torch_fused",
    #use_cpu=True,
    # NOTE(review): the training progress bar shows 935 iterations, so the PPO trainer in use
    # does not appear to stop after max_steps — confirm how it computes its step count.
    max_steps=1,
    #eval_steps=1,
    #eval_accumulation_steps=8,
    #accelerator_config={"num_processes": 8},
    per_device_train_batch_size=1,
    #per_device_eval_batch_size=1,
    # Empty the accelerator cache every step (trades speed for lower fragmentation/OOM risk).
    torch_empty_cache_steps=1,
    #torch_compile=True,
    #torch_compile_mode="default"
)

In [28]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [29]:
# Estimate a GPU/CPU placement for the policy model under the memory caps below.
# NOTE(review): the result is only displayed here (it puts the whole model on device 0);
# nothing in this section passes it to a model load, so it has no effect on training.
device_map = infer_auto_device_map(
    agent0.model,
    max_memory={0: "13GiB", 'cpu': "64GiB"},  # Adjust based on your GPU memory
    # NOTE(review): "DeepSeekBlock" does not look like a transformers layer class. For a
    # Qwen-based distill (per the run name) the decoder layer is likely "Qwen2DecoderLayer".
    # Confirm against agent0.model; otherwise this list prevents no splits.
    no_split_module_classes=["DeepSeekBlock"]  # Adjust based on model architecture
)
device_map

OrderedDict([('', 0)])

In [30]:
agent0.model.config.use_cache = False

In [31]:
def tokenize_fn(examples):
    """Tokenize a batch of GSM8K rows for the PPO trainer.

    Only the ``question`` column is encoded. Every sequence is padded/truncated to exactly
    512 tokens so all rows share one length.

    Args:
        examples: a batched ``datasets`` row dict; must contain a ``"question"`` list.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the batch
        (``input_ids`` / ``attention_mask`` as PyTorch tensors for an HF tokenizer).
    """
    questions = examples["question"]
    encoding_options = dict(
        padding="max_length",   # uniform length across the batch
        truncation=True,        # cut anything longer than max_length
        max_length=512,
        return_tensors="pt",
    )
    return tokenizer(questions, **encoding_options)

In [32]:
tokenized_dataset = train_data.map(tokenize_fn, batched=True)

In [33]:
tokenized_dataset = tokenized_dataset.remove_columns(['question', 'answer'])

In [34]:
tokenized_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 7473
})

In [35]:
unknown_input = None

In [36]:
from transformers import AutoModelForSequenceClassification

# Candidate standalone reward model: GPT-2 with a sequence-classification head.
# NOTE(review): this `reward_model` is not what PPOTrainer receives later (it is given
# reward_model=agent0.model), so this load currently only costs memory.
# NOTE(review): per the warning printed below, `score.weight` is newly initialised, so the
# head's outputs are untrained.
reward_model_name = "gpt2"
reward_model = AutoModelForSequenceClassification.from_pretrained(
    reward_model_name,
    # NOTE(review): gpt2 is a built-in transformers architecture; remote code should not be needed.
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_cache=False,
    low_cpu_mem_usage=True,
)
# NOTE(review): GPT-2 defines no pad token by default. Batched sequence classification
# typically needs reward_model.config.pad_token_id set; confirm before using batch > 1.
#reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
# Gradient checkpointing / input grads only matter if gradients flow through this model.
reward_model.gradient_checkpointing_enable()
reward_model.enable_input_require_grads()

#reward_model = torch.compile(reward_model)

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [37]:
def score_fn(something):
    """Collapse a hidden-state tensor to one value per position; meant to be patched in as ``.score``.

    The raw input is also stored in the module-level ``unknown_input`` so it can be inspected
    after a run is interrupted.

    Args:
        something: a tensor with at least three dims; it is averaged over dim 2.
            NOTE(review): presumably (batch, seq_len, hidden) hidden states — confirm against
            the trainer that calls ``.score``.

    Returns:
        ``something.mean(dim=2)``. For a (batch, seq_len, hidden) input this is
        (batch, seq_len), not (batch, 1).

    Raises:
        ValueError: if the reduced tensor contains NaN or +/-Inf.
    """
    global unknown_input
    unknown_input = something  # keep a handle for post-mortem debugging

    reduced = something.mean(dim=2)
    # isfinite is False exactly for NaN and +/-Inf entries.
    if not torch.isfinite(reduced).all():
        raise ValueError("score_fn produced NaN or Inf values!")

    return reduced

In [38]:
agent0.model.score = score_fn

In [39]:
from Modified_PPO_Trainer.ppo_trainer64 import PPOTrainer

In [40]:
# Build the PPO trainer. `PPOTrainer` here is the class imported from
# Modified_PPO_Trainer.ppo_trainer64 in the previous cell, shadowing trl.PPOTrainer.
# NOTE(review): the same agent0.model object serves as policy, value model and reward model.
# Since agent0.model.score was replaced by score_fn, value/reward presumably come from the
# mean of the policy's own hidden states. The policy can then move its own reward; confirm
# this is intended.
# NOTE(review): train_dataset is tokenized GSM8K questions, not the LuxAI game prompt built
# in temp_dataset; confirm this is the intended training data.
trainer = PPOTrainer(
    model=agent0.model,
    value_model=agent0.model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset,
    # NOTE(review): with ref_model=None the KL reference must be produced by the trainer itself
    # (e.g. adapter-disabled PEFT base or an internal copy); confirm in ppo_trainer64.
    ref_model=None,
    reward_model=agent0.model,
)

In [41]:
trainer.train()

===training policy===


  0%|          | 0/935 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


torch.Size([8, 512])




{'eps': 0, 'objective/kl': 0.12809604406356812, 'objective/entropy': 36.1016845703125, 'objective/non_score_reward': -0.0012809602776542306, 'objective/rlhf_reward': 0.011170211248099804, 'objective/scores': 0.012451171875, 'policy/approxkl_avg': 0.0013699254486709833, 'policy/clipfrac_avg': 0.004716981202363968, 'loss/policy_avg': -0.000984787940979004, 'loss/value_avg': 0.0006945948116481304, 'val/clipfrac_avg': 0.0, 'policy/entropy_avg': 0.6881699562072754, 'val/ratio': 0.9974040389060974, 'val/ratio_var': 2.181750096497126e-05, 'val/num_eos_tokens': 0, 'lr': 0.0, 'episode': 8, 'epoch': 0.0}
outputs/DeepSeek-R1-Distill-Qwen-1.5B-PPO
outputs/DeepSeek-R1-Distill-Qwen-1.5B-PPO_2
torch.Size([8, 512])
{'eps': 0, 'objective/kl': -0.10009439289569855, 'objective/entropy': 34.07035827636719, 'objective/non_score_reward': 0.0010009438265115023, 'objective/rlhf_reward': 0.017968717962503433, 'objective/scores': 0.0169677734375, 'policy/approxkl_avg': 0.0009927114006131887, 'policy/clipfrac_av

KeyboardInterrupt: 

In [None]:
unknown_input.shape

In [None]:
unknown_input

In [None]:
test_test = unknown_input.mean(dim=2)
test_test.shape

In [None]:
#reward_model.score = score_fn