In [None]:
from Modified_lux3_wrapper.modified_wrappers_20250228_01 import ModifiedLuxAIS3GymEnv
import numpy as np
from Modified_stablebaseline3_PPO.modified_ppo_20250228_01 import PPO
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import os
import copy
from GreedyLRScheduler import GreedyLR
from luxai_s3.wrappers import LuxAIS3GymEnv
import gc
gc.enable()

In [None]:
# Global PyTorch / NumPy runtime settings for the whole notebook.
# Let float32 matmuls use the 'medium' precision setting.
torch.set_float32_matmul_precision('medium')
# TorchDynamo settings; torch.compile is applied to the mlp_extractor modules in later cells.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.cache_size_limit = 128
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Wider line width so printed arrays wrap less.
np.set_printoptions(linewidth=200)
# NOTE(review): presumably this must be set before the first CUDA allocation to have any effect,
# and torch is already imported at this point — confirm it is actually honoured.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"
torch.backends.cudnn.benchmark = True
# Cap this process at 0.8 of the device memory, then release cached blocks.
torch.cuda.set_per_process_memory_fraction(0.8)
torch.cuda.empty_cache()

In [None]:
# Build the project's modified Lux S3 env and a PPO object around it.
# The PPO object is only used below as the source of the initial policy network.
init_env = ModifiedLuxAIS3GymEnv(numpy_output=True)
init_ppo = PPO("MultiInputPolicy", init_env, verbose=1)

In [None]:
# Two self-play opponents: model_1 starts as an independent deep copy of model_0,
# so both begin with identical weights but are updated separately.
model_0 = init_ppo.policy
model_1 = copy.deepcopy(model_0)

In [None]:
# Display the policy object's repr as the cell output.
model_0

In [None]:
# Show which device the policy reports; TrainPPO.prepare_model_input builds its tensors on "cuda".
model_0.device

In [None]:
def point_gain_reward_func(reward_score) -> float:
    """Turn a per-step change in team points into a shaping reward.

    Args:
        reward_score: Change in the team's points since the previous step.

    Returns:
        ``reward_score * 20`` when the change is strictly positive, otherwise
        a flat penalty of ``-1``.
    """
    if reward_score > 0.0:
        return reward_score * 20
    return -1

def match_won_reward_func(match_won) -> float:
    """Reward for winning a single match.

    Args:
        match_won: Truthy when the player just won a match.

    Returns:
        ``5000.0`` on a win, ``0.0`` otherwise.
    """
    if not match_won:
        return 0.0
    return 5000.0

def match_lost_reward_func(match_lost) -> float:
    """Penalty for losing a single match.

    Args:
        match_lost: Truthy when the player just lost a match.

    Returns:
        ``-3000.0`` on a loss, ``0.0`` otherwise.
    """
    if not match_lost:
        return 0.0
    return -3000.0

def game_won_reward_func(game_won) -> float:
    """Reward for winning the whole game.

    Args:
        game_won: Truthy when the player just won the game.

    Returns:
        ``1000000000.0`` on a game win, ``0.0`` otherwise.
    """
    if not game_won:
        return 0.0
    return 1_000_000_000.0

def game_lost_reward_func(game_lost) -> float:
    """Penalty for losing the whole game.

    Args:
        game_lost: Truthy when the player just lost the game.

    Returns:
        ``-1000000000.0`` on a game loss, ``0.0`` otherwise.
    """
    if not game_lost:
        return 0.0
    return -1_000_000_000.0

def map_reveal_reward_func(map_reveal_score):
    """Scale an exploration score into a shaping reward.

    Args:
        map_reveal_score: Exploration score for the step; the training loop
            passes the number of newly explored tiles.

    Returns:
        ``map_reveal_score`` multiplied by 10.
    """
    reward_per_unit_score = 10
    return map_reveal_score * reward_per_unit_score

def attack_reward_func(actions, sap_range, enemy_unit_mask) -> float:
    """Score a team's attack actions; the result is zero or a negative penalty.

    Only rows whose action code is ``>= 5`` are considered.

    Args:
        actions: Iterable of per-unit rows indexed as ``(action_code, dx, dy)``.
        sap_range: Largest allowed ``max(|dx|, |dy|)`` for an attack.
        enemy_unit_mask: Array-like with a ``sum()`` method; a zero sum is
            treated as "no enemy unit present".

    Returns:
        Accumulated penalty: ``-5.0`` per attack made while the enemy mask is
        empty, ``-0.5`` per attack aimed beyond ``sap_range`` otherwise.
    """
    penalty_total = 0.0

    for unit_action in actions:
        code, offset_x, offset_y = unit_action[0], unit_action[1], unit_action[2]

        # Non-attack actions contribute nothing.
        if not code >= 5:
            continue

        # Attacking with no enemy in the mask is the heavier penalty.
        if enemy_unit_mask.sum() == 0:
            penalty_total -= 5.0
            continue

        reach = max(abs(offset_x), abs(offset_y))
        if reach > sap_range:
            penalty_total -= 0.5

    return penalty_total

def next_position_calculator(action_num, unit_positions):
    """Return the position a unit would occupy after applying an action code.

    Action codes: 0 stay, 1 up (y - 1), 2 right (x + 1), 3 down (y + 1),
    4 left (x - 1). Any other code leaves the position unchanged.

    Args:
        action_num: Action code for the unit.
        unit_positions: Indexable ``(x, y)`` position of the unit.

    Returns:
        A new ``(x, y)`` tuple for codes 1-4; otherwise the ``unit_positions``
        object itself, unmodified.
    """
    if action_num == 1:
        next_position = (unit_positions[0], unit_positions[1] - 1)
    elif action_num == 2:
        next_position = (unit_positions[0] + 1, unit_positions[1])
    elif action_num == 3:
        next_position = (unit_positions[0], unit_positions[1] + 1)
    elif action_num == 4:
        next_position = (unit_positions[0] - 1, unit_positions[1])
    else:
        next_position = unit_positions

    return next_position


def movement_reward_func(actions, obs, team_id, map_size=24) -> float:
    """Score a team's movement choices for one step.

    Per unit slot ``i`` the score accumulates:

    * ``-0.25`` for a non-stay action on a slot positioned at ``(-1, -1)``
      (no unit there);
    * ``-0.25`` when ``dx``/``dy`` are non-zero although the action code is not 5;
    * ``-0.25`` for a non-stay action on a unit with non-negative coordinates
      and energy ``<= 0``;
    * for slots that hold a unit: ``-0.5`` if the resulting position leaves the
      ``map_size`` x ``map_size`` grid, else ``+2.0``.

    Empty slots (position ``(-1, -1)``) are skipped by the bounds check. They
    used to collect the ``-0.5`` out-of-map penalty on every step no matter what
    action was chosen, which the policy could never avoid.

    Args:
        actions: Iterable of per-unit rows indexed as ``(action_code, dx, dy)``.
        obs: Observation mapping; ``obs["units"]["position"][team_id][i]`` and
            ``obs["units"]["energy"][team_id][i]`` are read for every row.
        team_id: Index of the acting team inside the observation arrays.
        map_size: Side length of the square map; valid coordinates are
            ``0 .. map_size - 1``. Defaults to 24, the previously hard-coded size.

    Returns:
        The accumulated movement score.
    """
    movement_score = 0.0
    last_valid_index = map_size - 1

    for i, action in enumerate(actions):
        action_num, dx, dy = action[0], action[1], action[2]
        unit_positions = obs["units"]["position"][team_id][i]
        unit_energy = obs["units"]["energy"][team_id][i]

        # (-1, -1) marks a slot without a unit; element-wise test also works for lists/tuples.
        unit_missing = unit_positions[0] == -1 and unit_positions[1] == -1

        # Penalise trying to move a unit that doesn't exist.
        if unit_missing and action_num != 0:
            movement_score -= 0.25

        # Penalise non-zero dx/dy when the action is not an attack (code 5).
        if action_num != 5 and (dx != 0 or dy != 0):
            movement_score -= 0.25

        # Penalise acting with a real unit that has no energy left.
        if unit_positions[0] >= 0 and unit_positions[1] >= 0:
            if unit_energy <= 0 and action_num != 0:
                movement_score -= 0.25

        # An empty slot has no position to validate.
        if unit_missing:
            continue

        # Penalise walking off the map, reward staying on it.
        next_position = next_position_calculator(action_num, unit_positions)
        off_map = (
            next_position[0] < 0
            or next_position[1] < 0
            or next_position[0] > last_valid_index
            or next_position[1] > last_valid_index
        )
        if off_map:
            movement_score -= 0.5
        else:
            movement_score += 2.0

    return movement_score

def relic_discovery_reward_func(relic_discovery_reward) -> float:
    """Scale a relic-discovery score into a shaping reward.

    Args:
        relic_discovery_reward: Raw relic-discovery score for the step.

    Returns:
        ``relic_discovery_reward`` multiplied by 100.
    """
    discovery_scale = 100
    return relic_discovery_reward * discovery_scale

In [None]:
class TrainPPO:
    """Online self-play trainer for two policy networks on the Lux AI S3 env.

    Both policies act in the same ``LuxAIS3GymEnv`` game. After every
    environment step each policy gets one gradient update built from a one-step
    TD advantage, and when a game has a victor the winner's parameters are
    copied into the loser.

    The policies must be callable on the prepared observation dict (returning
    the sampled actions first) and expose ``extract_features``,
    ``mlp_extractor``, ``_get_action_dist_from_latent``, ``value_net`` and
    ``predict_values``.

    NOTE(review): ``gae_lambda`` is stored but not used; the advantage is a
    plain one-step TD error.
    NOTE(review): the shaping reward of an iteration is computed from the
    observation available *before* the action is applied, so score and
    exploration gains reflect the previous transition — confirm this is intended.
    """

    def __init__(
        self,
        model_0,
        model_1,
        num_games=1000,
        learning_rate=5e-4,
        weight_decay=0.01,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range=0.2,
        clip_range_vf=None,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
    ):
        """Store hyper-parameters and build optimizers, schedulers and the env.

        Args:
            model_0: Policy playing as ``player_0``.
            model_1: Policy playing as ``player_1``.
            num_games: Number of games played by :meth:`train`.
            learning_rate: Initial AdamW learning rate for both policies.
            weight_decay: AdamW weight decay for both policies.
            gamma: Discount factor used in the TD target.
            gae_lambda: Stored only; not used by the update.
            clip_range: Clipping range for the probability ratio.
            clip_range_vf: Optional clipping range for the value update;
                ``None`` disables value clipping.
            ent_coef: Weight of the entropy term in the loss.
            vf_coef: Weight of the value term in the loss.
            max_grad_norm: Gradient-norm clipping threshold.
        """
        self.model_0 = model_0
        self.model_1 = model_1
        self.num_games = num_games
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm

        # NOTE(review): fused=True presumably needs the parameters on a supported (CUDA) device — confirm.
        self.optimizer_0 = AdamW(self.model_0.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, fused=True)
        self.optimizer_1 = AdamW(self.model_1.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, fused=True)

        # NOTE(review): max_lr is fixed at 5e-4 regardless of ``learning_rate``.
        self.scheduler_0 = GreedyLR(self.optimizer_0, cooldown=3, min_lr=1e-7, max_lr=5e-4)
        self.scheduler_1 = GreedyLR(self.optimizer_1, cooldown=3, min_lr=1e-7, max_lr=5e-4)

        self.env = LuxAIS3GymEnv(numpy_output=True)

        # Replaces the sub-module on the models that were passed in (mutates the caller's objects).
        self.model_0.mlp_extractor = torch.compile(self.model_0.mlp_extractor)
        self.model_1.mlp_extractor = torch.compile(self.model_1.mlp_extractor)

    def train(self):
        """Play ``self.num_games`` games, updating both policies after every step.

        A game ends when one player has won three matches, or at match step
        100 of the fifth match. If there is a victor, its parameters are copied
        into the other model before the next game starts.
        """

        for game in range(1, self.num_games + 1):
            print("="*15 + f" Game {game} Started " + "="*15)

            obs_all, info = self.env.reset()
            self.env_cfg = info['params']

            game_ended = False

            player_0_previous_score = 0.0
            player_1_previous_score = 0.0

            first_spawn = False

            # Per-game memory that is also fed to the models via prepare_model_input.
            self.spawn_location = np.array([[-1, -1], [-1, -1]], dtype=np.int32)
            self.map_explored_status = np.zeros((2, 24, 24), dtype=bool)

            player_0_previous_map_explored_status_score = self.map_explored_status[0].sum()
            player_1_previous_map_explored_status_score = self.map_explored_status[1].sum()

            player_0_match_won_num = 0
            player_1_match_won_num = 0

            victor = None

            game_start = True

            match_number = 1

            while not game_ended:

                player_0_match_won = False
                player_0_match_lost = False
                player_1_match_won = False
                player_1_match_lost = False

                player_0_game_won = False
                player_0_game_lost = False
                player_1_game_won = False
                player_1_game_lost = False

                player_0_current_score = obs_all['player_0']['team_points'][0]
                player_1_current_score = obs_all['player_1']['team_points'][1]

                player_0_reward_score = player_0_current_score - player_0_previous_score
                player_1_reward_score = player_1_current_score - player_1_previous_score

                player_0_previous_score = player_0_current_score
                player_1_previous_score = player_1_current_score

                current_match_step = obs_all["player_0"]["match_steps"]

                # Decide the match at its last step; equal scores count for nobody.
                if current_match_step == 100:
                    if player_0_current_score > player_1_current_score:
                        player_0_match_won = True
                        player_1_match_lost = True
                        player_0_match_won_num += 1
                    elif player_0_current_score < player_1_current_score:
                        player_0_match_lost = True
                        player_1_match_won = True
                        player_1_match_won_num += 1

                if player_0_match_won_num >= 3:
                    game_ended = True
                    print("Player 0 won the game.")
                    victor = "player_0"
                    player_0_game_won = True
                    player_1_game_lost = True

                if player_1_match_won_num >= 3:
                    game_ended = True
                    print("Player 1 won the game.")
                    victor = "player_1"
                    player_0_game_lost = True
                    player_1_game_won = True

                player_0_unit_positions = np.array(obs_all['player_0']["units"]["position"][0])
                player_1_unit_positions = np.array(obs_all['player_1']["units"]["position"][1])

                player_0_unit_mask = np.array(obs_all['player_0']["units_mask"][0])
                player_1_unit_mask = np.array(obs_all['player_1']["units_mask"][1])

                player_0_available_unit_ids = np.where(player_0_unit_mask)[0]
                player_1_available_unit_ids = np.where(player_1_unit_mask)[0]

                # Remember where each team's first unit appeared as its spawn location.
                # NOTE(review): assumes player_1 has a unit whenever player_0 does.
                if player_0_available_unit_ids.shape[0] != 0 and not first_spawn:
                    player_0_first_unit_id = player_0_available_unit_ids[0]
                    player_0_first_unit_pos = player_0_unit_positions[player_0_first_unit_id]
                    self.spawn_location[0] = (player_0_first_unit_pos[0], player_0_first_unit_pos[1])
                    player_1_first_unit_id = player_1_available_unit_ids[0]
                    player_1_first_unit_pos = player_1_unit_positions[player_1_first_unit_id]
                    self.spawn_location[1] = (player_1_first_unit_pos[0], player_1_first_unit_pos[1])
                    first_spawn = True

                player_0_map_features = obs_all['player_0']['map_features']
                player_1_map_features = obs_all['player_1']['map_features']

                player_0_current_map_tile_type = player_0_map_features['tile_type'].T
                player_1_current_map_tile_type = player_1_map_features['tile_type'].T

                # Tiles whose type is not -1 count as explored from now on.
                self.map_explored_status[0][player_0_current_map_tile_type != -1] = True
                self.map_explored_status[1][player_1_current_map_tile_type != -1] = True

                player_0_current_map_explored_status_score = self.map_explored_status[0].sum()
                player_1_current_map_explored_status_score = self.map_explored_status[1].sum()

                player_0_map_explored_status_reward = player_0_current_map_explored_status_score - player_0_previous_map_explored_status_score
                player_1_map_explored_status_reward = player_1_current_map_explored_status_score - player_1_previous_map_explored_status_score

                player_0_previous_map_explored_status_score = player_0_current_map_explored_status_score
                player_1_previous_map_explored_status_score = player_1_current_map_explored_status_score

                ### Reward calculation
                player_0_relic_point_reward = point_gain_reward_func(player_0_reward_score)
                player_1_relic_point_reward = point_gain_reward_func(player_1_reward_score)

                player_0_match_won_reward = match_won_reward_func(player_0_match_won)
                player_0_match_lost_reward = match_lost_reward_func(player_0_match_lost)
                player_1_match_won_reward = match_won_reward_func(player_1_match_won)
                player_1_match_lost_reward = match_lost_reward_func(player_1_match_lost)

                player_0_game_won_reward = game_won_reward_func(player_0_game_won)
                player_0_game_lost_reward = game_lost_reward_func(player_0_game_lost)
                player_1_game_won_reward = game_won_reward_func(player_1_game_won)
                player_1_game_lost_reward = game_lost_reward_func(player_1_game_lost)

                player_0_map_reveal_reward = map_reveal_reward_func(player_0_map_explored_status_reward)
                player_1_map_reveal_reward = map_reveal_reward_func(player_1_map_explored_status_reward)

                ### model input
                # Only needed on the first iteration; afterwards the inputs built
                # right after env.step are reused.
                if game_start:
                    player_0_model_input = self.prepare_model_input(obs_all["player_0"], 0)
                    player_1_model_input = self.prepare_model_input(obs_all["player_1"], 1)
                    game_start = False

                # Sample actions without tracking gradients.
                with torch.no_grad():
                    player_0_action_distribution, _, _ = self.model_0(player_0_model_input)
                    player_1_action_distribution, _, _ = self.model_1(player_1_model_input)

                player_0_action = self._decode_actions(player_0_action_distribution)
                player_1_action = self._decode_actions(player_1_action_distribution)

                print(player_0_action)
                print(obs_all["player_0"]["map_features"]["tile_type"].T)

                # Host NumPy copies for the reward helpers and for the env, which
                # was created with numpy_output=True.
                player_0_action_np = player_0_action.detach().cpu().numpy()
                player_1_action_np = player_1_action.detach().cpu().numpy()

                player_0_attack_reward = attack_reward_func(player_0_action_np, self.env_cfg["unit_sap_range"], player_1_unit_mask)
                player_1_attack_reward = attack_reward_func(player_1_action_np, self.env_cfg["unit_sap_range"], player_0_unit_mask)

                player_0_movement_reward = movement_reward_func(player_0_action_np, obs_all["player_0"], 0)
                player_1_movement_reward = movement_reward_func(player_1_action_np, obs_all["player_1"], 1)

                player_0_reward = player_0_relic_point_reward + player_0_match_won_reward + player_0_match_lost_reward + player_0_game_won_reward + player_0_game_lost_reward + player_0_map_reveal_reward + player_0_attack_reward + player_0_movement_reward
                player_1_reward = player_1_relic_point_reward + player_1_match_won_reward + player_1_match_lost_reward + player_1_game_won_reward + player_1_game_lost_reward + player_1_map_reveal_reward + player_1_attack_reward + player_1_movement_reward

                # Re-evaluate the sampled actions with gradients enabled.
                player_0_log_prob, player_0_value, player_0_entropy = self._evaluate_actions(
                    self.model_0, player_0_model_input, player_0_action_distribution
                )
                player_1_log_prob, player_1_value, player_1_entropy = self._evaluate_actions(
                    self.model_1, player_1_model_input, player_1_action_distribution
                )

                obs_all, _, _, _, _ = self.env.step({
                    "player_0": player_0_action_np,
                    "player_1": player_1_action_np
                })

                player_0_model_input = self.prepare_model_input(obs_all["player_0"], 0)
                player_1_model_input = self.prepare_model_input(obs_all["player_1"], 1)

                with torch.no_grad():
                    # Bootstrap value of the state reached by the step.
                    player_0_new_value = self.model_0.predict_values(player_0_model_input)  # type: ignore[arg-type]
                    player_1_new_value = self.model_1.predict_values(player_1_model_input)

                player_0_loss = self._compute_loss(
                    player_0_reward, player_0_log_prob, player_0_value, player_0_new_value, player_0_entropy
                )
                player_1_loss = self._compute_loss(
                    player_1_reward, player_1_log_prob, player_1_value, player_1_new_value, player_1_entropy
                )

                self._apply_update(self.model_0, self.optimizer_0, self.scheduler_0, player_0_loss)
                self._apply_update(self.model_1, self.optimizer_1, self.scheduler_1, player_1_loss)

                # current_match_step still refers to the observation from before this step.
                if match_number >= 5 and current_match_step == 100:
                    game_ended = True
                    print("Game ended.")

                if current_match_step == 100:
                    match_number += 1

            if victor == "player_0":
                self.synchronize_models(self.model_0, self.model_1)
            elif victor == "player_1":
                self.synchronize_models(self.model_1, self.model_0)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _decode_actions(raw_actions):
        """Turn a flat sampled action tensor into per-unit ``(code, dx, dy)`` rows.

        The input is reshaped to ``(-1, 16, 3)``, copied and squeezed; columns 1
        and 2 are then shifted by -7. The input tensor itself is not modified.
        NOTE(review): presumably this maps a 0..14 categorical index onto a
        -7..7 offset — confirm against the action space.
        """
        actions = raw_actions.reshape(-1, 16, 3).clone().squeeze()
        actions[:, 1:3] -= 7
        return actions

    @staticmethod
    def _evaluate_actions(model, model_input, actions):
        """Return ``(log_prob, value, entropy)`` of ``actions`` under ``model``, with gradients."""
        features = model.extract_features(model_input)
        latent_pi, latent_vf = model.mlp_extractor(features)
        distribution = model._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        value = model.value_net(latent_vf)
        entropy = distribution.entropy()
        return log_prob, value, entropy

    def _compute_loss(self, reward, log_prob, value, new_value, entropy):
        """Build the clipped-surrogate loss for a single transition.

        Args:
            reward: Scalar shaping reward for the transition.
            log_prob: Log-probability of the taken action (requires grad).
            value: Value prediction for the current state (requires grad).
            new_value: Value prediction for the next state (treated as constant).
            entropy: Entropy of the action distribution (requires grad).

        Returns:
            ``policy_loss + ent_coef * entropy_loss + vf_coef * value_loss``.
        """
        value = value.reshape(-1)
        # One-step TD target and advantage; both are constants for the update.
        value_target = (float(reward) + self.gamma * new_value).detach().reshape(-1)
        advantage = value_target - value.detach()

        # The data comes from the current policy, so the ratio equals 1 in value,
        # but it carries the gradient of log_prob into the surrogate objective.
        log_prob = log_prob.reshape(-1)
        ratio = torch.exp(log_prob - log_prob.detach())
        policy_loss_1 = advantage * ratio
        policy_loss_2 = advantage * torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
        policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

        if self.clip_range_vf is None:
            values_pred = value
        else:
            old_value = value.detach()
            values_pred = old_value + torch.clamp(value - old_value, -self.clip_range_vf, self.clip_range_vf)

        # Regress the current-state value towards the TD target.
        value_loss = F.mse_loss(values_pred, value_target)

        # Negative entropy: minimising the loss favours exploration.
        entropy_loss = -torch.mean(entropy)

        return policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

    def _apply_update(self, model, optimizer, scheduler, loss):
        """Backpropagate ``loss``, clip gradients, step the optimizer and the scheduler."""
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
        optimizer.step()
        scheduler.step(loss.item())

    def prepare_model_input(self, obs, my_team_id):
        """Convert one player's observation into the dict of tensors fed to a policy.

        Reads ``self.spawn_location``, ``self.map_explored_status`` and
        ``self.env_cfg``, which :meth:`train` sets at the start of each game.

        Args:
            obs: Observation mapping of a single player as returned by the env.
            my_team_id: Team index (0 or 1) of that player; the other index is
                treated as the enemy.

        Returns:
            Dict of ``torch.int32`` tensors on the ``"cuda"`` device, each with a
            leading batch axis of size 1.
        """
        enemy_team_id = 1 - my_team_id

        model_input = {
            "enemy_energies": obs["units"]["energy"][enemy_team_id],
            "enemy_positions": obs["units"]["position"][enemy_team_id],
            "enemy_spawn_location": self.spawn_location[enemy_team_id],
            "enemy_visible_mask": obs["units_mask"][enemy_team_id],
            "map_explored_status": self.map_explored_status[my_team_id],
            "map_features_energy": obs["map_features"]["energy"],
            "map_features_tile_type": obs["map_features"]["tile_type"],
            "match_steps": np.array([obs["match_steps"]]),
            "my_spawn_location": self.spawn_location[my_team_id],
            "relic_nodes": obs["relic_nodes"],
            "relic_nodes_mask": obs["relic_nodes_mask"],
            "sensor_mask": obs["sensor_mask"],
            "steps": np.array([obs["steps"]]),
            "team_id": np.array([my_team_id]),
            "team_points": obs["team_points"],
            "team_wins": obs["team_wins"],
            "unit_active_mask": obs["units_mask"][my_team_id],
            "unit_energies": obs["units"]["energy"][my_team_id],
            "unit_move_cost": np.array([self.env_cfg["unit_move_cost"]]),
            "unit_positions": obs["units"]["position"][my_team_id],
            "unit_sap_cost": np.array([self.env_cfg["unit_sap_cost"]]),
            "unit_sap_range": np.array([self.env_cfg["unit_sap_range"]]),
            "unit_sensor_range": np.array([self.env_cfg["unit_sensor_range"]]),
        }

        # Add the batch axis and move everything to the GPU as int32.
        model_input = {k: torch.tensor(np.expand_dims(v, axis=0), dtype=torch.int32, device="cuda") for k, v in model_input.items()}

        return model_input

    def synchronize_models(self, winner_model, loser_model):
        """Copy every parameter of ``winner_model`` into ``loser_model`` in place.

        Parameters are paired by iteration order; only ``parameters()`` are
        copied (buffers and optimizer state are left untouched).
        """
        with torch.no_grad():
            for p1, p2 in zip(winner_model.parameters(), loser_model.parameters()):
                p2.data.copy_(p1.data)



        


In [None]:
# Build the self-play trainer with default hyper-parameters; this also wraps both
# models' mlp_extractor in torch.compile (see TrainPPO.__init__).
trainer = TrainPPO(model_0, model_1)

In [None]:
# Run the full self-play training loop (num_games games; prints progress every step).
trainer.train()

In [None]:
# NOTE(review): `model` is not assigned in any visible cell — presumably a PPO object such as
# `init_ppo`; confirm, otherwise this cell raises NameError on a fresh kernel.
model.policy.parameters()

In [None]:
# Run the policy's feature extractor on `obs` and display the result.
# NOTE(review): `obs` is only defined in a later cell and `model` in no visible cell, so this
# depends on out-of-order execution — confirm and reorder.
obs_tensor = model.policy.extract_features(obs)
obs_tensor

In [None]:
# Shape of the extracted feature tensor.
obs_tensor.shape

In [None]:
# Random stand-in observation for smoke-testing the policy.
# Keys match those built by TrainPPO.prepare_model_input; every array has a leading batch axis of 1.
# np.random.randint's upper bound is exclusive, and no seed is set, so values differ on every run.
obs = {
    "enemy_energies": np.random.randint(-800, 401, size=(1, 16,), dtype=np.int32),
    "enemy_positions": np.random.randint(-1, 24, size=(1, 16, 2), dtype=np.int32),
    "enemy_spawn_location": np.random.randint(-1, 24, size=(1, 2,), dtype=np.int32),
    "enemy_visible_mask": np.random.randint(0, 2, size=(1, 16,), dtype=np.int32),
    "map_explored_status": np.random.randint(0, 2, size=(1, 24, 24), dtype=np.int32),
    "map_features_energy": np.random.randint(-7, 10, size=(1, 24, 24), dtype=np.int32),
    "map_features_tile_type": np.random.randint(-1, 3, size=(1, 24, 24), dtype=np.int32),
    "match_steps": np.random.randint(0, 101, size=(1, 1,), dtype=np.int32),
    "my_spawn_location": np.random.randint(-1, 24, size=(1, 2,), dtype=np.int32),
    "relic_nodes": np.random.randint(-1, 24, size=(1, 6, 2), dtype=np.int32),
    "relic_nodes_mask": np.random.randint(0, 2, size=(1, 6,), dtype=np.int32),
    "sensor_mask": np.random.randint(0, 2, size=(1, 24, 24), dtype=np.int32),
    "steps": np.random.randint(0, 506, size=(1, 1,), dtype=np.int32),
    "team_id": np.random.randint(0, 2, size=(1, 1,), dtype=np.int32),
    "team_points": np.random.randint(0, 2501, size=(1, 2,), dtype=np.int32),
    "team_wins": np.random.randint(0, 4, size=(1, 2,), dtype=np.int32),
    "unit_active_mask": np.random.randint(0, 2, size=(1, 16,), dtype=np.int32),
    "unit_energies": np.random.randint(-800, 401, size=(1, 16,), dtype=np.int32),
    "unit_move_cost": np.random.randint(1, 6, size=(1, 1, ), dtype=np.int32),
    "unit_positions": np.random.randint(-1, 24, size=(1, 16, 2), dtype=np.int32),
    "unit_sap_cost": np.random.randint(30, 51, size=(1, 1, ), dtype=np.int32),
    "unit_sap_range": np.random.randint(3, 8, size=(1, 1, ), dtype=np.int32),
    "unit_sensor_range": np.random.randint(2, 5, size=(1, 1, ), dtype=np.int32),
}

In [None]:
# Move the random observation to the GPU as float32 tensors (rebinds `obs`, so this cell is
# not idempotent with respect to the NumPy dict defined above).
# NOTE(review): TrainPPO.prepare_model_input feeds int32 tensors instead — confirm which dtype
# the policy expects.
obs = {k: torch.tensor(v, dtype=torch.float32, device="cuda") for k, v in obs.items()}

# Convert observation to tensor and check shape
# NOTE(review): `model` is not assigned in any visible cell — confirm where it comes from.
obs_tensor = model.policy.extract_features(obs)
print(f"Extracted Feature Shape: {obs_tensor.shape}")  # Expected: (batch_size, 2464)

In [None]:
# Single forward pass without gradient tracking; the three outputs are unpacked as
# (sampled actions, value estimate, log-probability).
with torch.no_grad():
    action_distribution, value, log = model.policy.forward(obs)

In [None]:
# Shape of the first forward output (the sampled actions).
action_distribution.shape

In [None]:
# Raw sampled actions as returned by the policy.
action_distribution

In [None]:
# Shape of the second forward output (the value estimate).
value.shape

In [None]:
# Value estimate for the random observation.
value

In [None]:
# Shape of the third forward output (the log-probability).
log.shape

In [None]:
# Log-probability returned by the forward pass.
log

In [None]:
# View the flat action output as (batch, 16, 3) — the same reshape TrainPPO.train applies
# before shifting columns 1 and 2.
actions = action_distribution.reshape(-1, 16, 3)
actions

In [None]:
# Shape after the reshape above.
actions.shape

In [None]:
# Replace the policy's mlp_extractor with a torch.compile-wrapped version.
# Re-running this cell wraps the already-wrapped module again.
model.policy.mlp_extractor = torch.compile(model.policy.mlp_extractor)

In [None]:
# Same forward pass as before, now going through the compiled mlp_extractor.
with torch.no_grad():
    action_distribution, value, log = model.policy.forward(obs)

In [None]:
# Reshape the new action output to (batch, 16, 3) for inspection.
actions = action_distribution.reshape(-1, 16, 3)
actions

In [None]:
# NOTE(review): unpinned upgrade placed at the end of the notebook, after luxai_s3 has already
# been imported — consider moving it to the top and pinning a version for reproducibility.
%pip install --upgrade luxai-s3