In [None]:
from Modified_lux3_wrapper.modified_wrappers_20250228_01 import ModifiedLuxAIS3GymEnv
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.optim import AdamW
import os
# import copy
# from GreedyLRScheduler import GreedyLR
# from luxai_s3.wrappers import LuxAIS3GymEnv
import gc
gc.enable()
# from stable_baselines3.common.buffers import DictRolloutBuffer
# from tqdm.notebook import tqdm
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import MultiInputActorCriticPolicy
from gymnasium import spaces

In [None]:
# Global numeric / runtime configuration for this notebook.
# Permit reduced-precision internal math (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')
# torch.compile (dynamo) settings: trace scalar-producing ops such as .item() instead of
# graph-breaking, and allow more cached compiled variants per function.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.cache_size_limit = 128
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
np.set_printoptions(linewidth=200)
# NOTE(review): the CUDA caching allocator reads this when it initializes, so it only takes
# effect if nothing in this kernel has allocated CUDA memory yet.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"
torch.backends.cudnn.benchmark = True
# Requires a CUDA device; limits this process's caching allocator to 80% of GPU memory.
torch.cuda.set_per_process_memory_fraction(0.8)
torch.cuda.empty_cache()

In [None]:
# Baseline PPO with the stock "MultiInputPolicy".
# NOTE(review): `learning_rate`, `env` and `model` are all re-bound by the policy_kwargs
# cell further down, so this model appears unused afterwards.
learning_rate = 1e-6
env = ModifiedLuxAIS3GymEnv(numpy_output=True)
model = PPO("MultiInputPolicy", env, verbose=1, learning_rate=learning_rate, ent_coef=0.015, vf_coef=0.75, clip_range_vf=0.15, clip_range=0.2, n_steps=505, batch_size=101, max_grad_norm=0.5)

In [None]:
def normalize_observation(obs: dict, obs_space: spaces.Dict) -> dict:
    """
    Min-max normalize the bounded numeric Box entries of an observation dict.

    For every key of ``obs_space``:
    - numeric Box spaces whose bounds are exactly [0, 1] are returned unchanged;
    - other numeric Box spaces are cast to float32 and mapped to
      ``(value - low) / (high - low + 1e-8)``. Elements whose ``low`` or ``high`` is not
      finite are passed through as float32 instead of becoming NaN (``inf - inf``);
    - every other space (Discrete, MultiBinary, ...) is returned unchanged.

    Args:
        obs: mapping key -> torch tensor; must contain every key of ``obs_space``. Values
            of Box keys that get normalized must be torch tensors (``.device``/``.to``).
        obs_space: the Dict observation space describing ``obs``.

    Returns:
        A new dict with the keys of ``obs_space``.
    """
    norm_obs = {}
    for key, space in obs_space.spaces.items():
        value = obs[key]
        is_numeric_box = isinstance(space, spaces.Box) and np.issubdtype(space.dtype, np.number)
        if not is_numeric_box:
            # Discrete / MultiBinary / non-numeric spaces: copy through.
            norm_obs[key] = value
            continue
        if (space.low == 0).all() and (space.high == 1).all():
            # Already in [0, 1].
            norm_obs[key] = value
            continue

        low = torch.as_tensor(space.low, device=value.device, dtype=torch.float32)
        high = torch.as_tensor(space.high, device=value.device, dtype=torch.float32)
        value_f = value.to(dtype=torch.float32)
        finite = torch.isfinite(low) & torch.isfinite(high)
        if bool(finite.all()):
            norm_obs[key] = (value_f - low) / (high - low + 1e-8)
        else:
            # Use harmless bounds on unbounded elements so no NaN is ever computed,
            # then keep the raw value there.
            safe_low = torch.where(finite, low, torch.zeros_like(low))
            safe_span = torch.where(finite, high - low, torch.ones_like(low))
            scaled = (value_f - safe_low) / (safe_span + 1e-8)
            norm_obs[key] = torch.where(finite, scaled, value_f)
    return norm_obs

In [None]:
class CustomFeatureExtractor(BaseFeaturesExtractor):
    """
    Feature extractor for the Lux S3 Dict observation.

    - The four 24x24 grid features are stacked as channels and run through a small CNN.
    - Every observation key (the grid keys included) is also flattened.
    - The CNN output and all flattened keys are concatenated into one vector.

    NOTE: the grid keys intentionally appear in both the CNN path and the flattened path;
    changing that would change ``features_dim`` and invalidate saved checkpoints.
    """

    def __init__(self, observation_space: spaces.Dict, features_dim: int = 0):
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim)

        # Keys stacked as CNN channels; each must be a (24, 24) map.
        self.grid_features = ["map_explored_status", "map_features_energy", "map_features_tile_type", "sensor_mask"]
        # Every key is flattened, in observation-space order.
        self.features = list(observation_space.keys())

        # CNN for 24x24 grid features (expects input shape [batch, channels, 24, 24]).
        cnn = nn.Sequential(
            nn.Conv2d(len(self.grid_features), 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.SiLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.Flatten(),
            nn.Dropout(0.1),
        )

        # Probe the output width with the *uncompiled* module, in eval mode and without
        # autograd. This avoids triggering torch.compile at construction time and keeps
        # the dummy zero batch from updating the BatchNorm running statistics.
        was_training = cnn.training
        cnn.eval()
        with torch.no_grad():
            dummy_input = torch.zeros((1, len(self.grid_features), 24, 24))
            cnn_output_dim = cnn(dummy_input).shape[1]
        cnn.train(was_training)

        self.cnn_extractor = torch.compile(cnn)

        # One Flatten per key; int(np.prod(...)) keeps the running total an exact int.
        self.extractors = nn.ModuleDict()
        flatten_dim = 0
        for key in self.features:
            self.extractors[key] = nn.Flatten()
            flatten_dim += int(np.prod(observation_space.spaces[key].shape))

        # Real output width (overrides whatever ``features_dim`` was passed in).
        self._features_dim = cnn_output_dim + flatten_dim

    def forward(self, observations):
        """
        Return ``cat([cnn(stacked grid keys), flatten(every key)], dim=1)``.

        Assumes each grid entry is a (batch, 24, 24) tensor — TODO confirm against the env.
        """
        grid_stack = torch.stack([observations[key] for key in self.grid_features], dim=1).float()
        grid_features = self.cnn_extractor(grid_stack)

        flat_features = torch.cat([self.extractors[key](observations[key]) for key in self.features], dim=1)

        return torch.cat([grid_features, flat_features], dim=1)

In [None]:
# Policy architecture passed to PPO.
policy_kwargs = dict(
    features_extractor_class=CustomFeatureExtractor,
    # NOTE(review): CustomFeatureExtractor.__init__ overwrites self._features_dim with the
    # width it computes, so this value only reaches BaseFeaturesExtractor's constructor.
    features_extractor_kwargs=dict(features_dim=20897),
    activation_fn=nn.SiLU,
    # net_arch=dict(pi=[8192, 4096, 2048, 1024], vf=[8192, 4096, 2048, 1024, 512, 256, 128, 64]),
    net_arch=dict(pi=[4096, 2048, 1024], vf=[4096, 2048, 1024, 512, 256, 128]),
    # net_arch=dict(pi=[128, 64], vf=[128, 64]),
)
env = ModifiedLuxAIS3GymEnv(numpy_output=True)
learning_rate = 6e-4
# NOTE(review): `save_dir` and `load_models` are not arguments of upstream stable-baselines3
# PPO; this presumably relies on a patched PPO — confirm.
# With a single env, n_steps == batch_size == 505 means one minibatch per epoch in stock SB3.
model = PPO(
    "MultiInputPolicy", env, policy_kwargs=policy_kwargs, verbose=0, learning_rate=learning_rate, ent_coef=0.04, vf_coef=0.75, clip_range_vf=0.2, clip_range=0.3, n_steps=505, batch_size=505,
    max_grad_norm=0.5, n_epochs=15, save_dir="saved_policies/", tensorboard_log="logs/", gamma=0.99, target_kl=None, gae_lambda=0.95, load_models=True
)

In [None]:
model.policy.parameters()

In [None]:
total_params = sum(p.numel() for p in model.policy.parameters())
print(f"Number of parameters: {total_params}")

In [None]:
import copy
policy1 = copy.deepcopy(model.policy)

In [None]:
for name, param in policy1.named_parameters():
    print(name, param.dtype)

In [None]:
# Get observation space
obs_space = model.policy.observation_space

# Create dummy inputs
dummy_input = {
    key: torch.tensor(np.zeros(space.shape, dtype=np.float32)).to("cuda", dtype=torch.float32).unsqueeze(0) for key, space in obs_space.spaces.items()
}

# If the model needs a single tensor, flatten and concatenate everything
flat_input = torch.cat([v.flatten() for v in dummy_input.values()]).unsqueeze(0)  # Add batch dim

In [None]:
dummy_output = policy1(dummy_input)
print(type(dummy_output))

In [None]:
torch.save(policy1.state_dict(), "policy1_after_quantization.pth")

In [None]:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [None]:
# Export ONNX
torch.onnx.export(policy1, (dummy_input,), "model.onnx")

In [None]:
dummy_input = torch.randn(1, model.policy.observation_space.shape)

In [None]:
model.policy.observation_space

In [None]:
torch.onnx.export(policy1, )

In [None]:
# NOTE(review): `model.policy_2` is only created by the deepcopy cell further down. On a
# fresh kernel this cell therefore needs the PPO object to already provide `policy_2` — confirm.
torch.save(model.policy.state_dict(), f"policy1.pth")
torch.save(model.policy_2.state_dict(), f"policy2.pth")

In [None]:
model.policy

In [None]:
# Constructor annotations of the policy class PPO instantiates.
model.policy_class.__init__.__annotations__

In [None]:
# Long training run: 5,050,000 timesteps (10,000 rollouts of n_steps=505 with one env).
learn_results = model.learn(total_timesteps=5050000, progress_bar=True)

In [None]:
import copy
# Second policy initialized as an exact copy of the first.
model.policy_2 = copy.deepcopy(model.policy)
model.policy_2

In [None]:
# Restore both policies' weights from the saved_policies_20250307_06 checkpoint files.
model.policy.load_state_dict(torch.load("saved_policies_20250307_06/policy_200.pth"))
model.policy_2.load_state_dict(torch.load("saved_policies_20250307_06/policy_2_200.pth"))

In [None]:
torch.save(model.policy.state_dict(), "saved_policies/ppo_policy_20250306_01.pth")
torch.save(model.policy_2.state_dict(), "saved_policies/ppo_policy_2_20250306_01.pth")

In [None]:
class CustomActivation(nn.Module):
    """
    SiLU non-linearity followed by dropout.

    Args:
        dropout_prob: drop probability handed to ``nn.Dropout`` (default 0.1).
    """

    def __init__(self, dropout_prob=0.1):
        super().__init__()
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        """Return ``dropout(silu(x))``."""
        return self.dropout(self.silu(x))

In [None]:
# Number of parallel environments (adjust based on CPU cores)
NUM_ENVS = 2

def make_env():
    return ModifiedLuxAIS3GymEnv(numpy_output=True)  # Use your custom environment

env = SubprocVecEnv([lambda: make_env() for _ in range(NUM_ENVS)])

In [None]:
model.action_space

In [None]:
from stable_baselines3.common.torch_layers import MlpExtractor

class CustomMlpExtractor(MlpExtractor):
    """
    Actor/critic MLP trunks with Linear -> SiLU -> LayerNorm -> Dropout(0.2) per layer.

    - policy_net: feature_dim -> 2048 -> 1024 -> 512 -> 256   (latent_dim_pi = 256)
    - value_net:  feature_dim -> 2048 -> 1024 -> 512 -> 64    (latent_dim_vf = 64)

    NOTE: a different ``CustomMlpExtractor`` is defined later in this notebook and shadows
    this one once that cell runs.
    """

    POLICY_DIMS = (2048, 1024, 512, 256)
    VALUE_DIMS = (2048, 1024, 512, 64)

    def __init__(self, feature_dim):
        # MlpExtractor requires ``activation_fn``. An empty ``net_arch`` makes it build empty
        # trunks instead of large throw-away layers that are replaced right below.
        super().__init__(feature_dim, net_arch=[], activation_fn=nn.SiLU)

        self.policy_net = self._dense_stack(feature_dim, self.POLICY_DIMS)
        self.value_net = self._dense_stack(feature_dim, self.VALUE_DIMS)

        # SB3's ActorCriticPolicy sizes its heads from these, so they must match the
        # trunks actually built above.
        self.latent_dim_pi = self.POLICY_DIMS[-1]
        self.latent_dim_vf = self.VALUE_DIMS[-1]

    @staticmethod
    def _dense_stack(in_dim, dims, dropout=0.2):
        """Build Linear -> SiLU -> LayerNorm -> Dropout for each width in ``dims``."""
        layers = []
        for out_dim in dims:
            layers += [nn.Linear(in_dim, out_dim), nn.SiLU(), nn.LayerNorm(out_dim), nn.Dropout(dropout)]
            in_dim = out_dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(policy_latent, value_latent)`` for features ``x``."""
        return self.policy_net(x), self.value_net(x)

In [None]:
model.policy.mlp_extractor.policy_net[2] = nn.Dropout(0.2)
model.policy.mlp_extractor.policy_net[3] = nn.Linear(4096, 1024)
model.policy.mlp_extractor.policy_net[4] = nn.SiLU()
model.policy.mlp_extractor.policy_net[5] = nn.Dropout(0.2)
model.policy.mlp_extractor.policy_net[6] = nn.Linear(1024, 512)
model.policy.mlp_extractor.policy_net[7] = nn.SiLU()
model.policy.mlp_extractor.policy_net.add_module("8", nn.Dropout(0.2))
model.policy.mlp_extractor.policy_net.add_module("9", nn.Linear(512, 256))
model.policy.mlp_extractor.policy_net.add_module("10", nn.SiLU())
model.policy.mlp_extractor.policy_net.add_module("11", nn.Dropout(0.2))
model.policy.mlp_extractor.value_net[2] = nn.Dropout(0.2)
model.policy.mlp_extractor.value_net[3] = nn.Linear(4096, 1024)
model.policy.mlp_extractor.value_net[4] = nn.SiLU()
model.policy.mlp_extractor.value_net[5] = nn.Dropout(0.2)
model.policy.mlp_extractor.value_net[6] = nn.Linear(1024, 512)
model.policy.mlp_extractor.value_net[7] = nn.SiLU()
model.policy.mlp_extractor.value_net.add_module("8", nn.Dropout(0.2))
model.policy.mlp_extractor.value_net.add_module("9", nn.Linear(512, 128))
model.policy.mlp_extractor.value_net.add_module("10", nn.SiLU())
model.policy.mlp_extractor.value_net.add_module("11", nn.Dropout(0.2))
model.policy.mlp_extractor.value_net.add_module("12", nn.Linear(128, 32))
model.policy.mlp_extractor.value_net.add_module("13", nn.SiLU())
model.policy.mlp_extractor.value_net.add_module("14", nn.Dropout(0.2))

In [None]:
model.policy

In [None]:
model.policy_2.mlp_extractor.policy_net[2] = nn.Dropout(0.2)
model.policy_2.mlp_extractor.policy_net[3] = nn.Linear(4096, 1024)
model.policy_2.mlp_extractor.policy_net[4] = nn.SiLU()
model.policy_2.mlp_extractor.policy_net[5] = nn.Dropout(0.2)
model.policy_2.mlp_extractor.policy_net[6] = nn.Linear(1024, 512)
model.policy_2.mlp_extractor.policy_net[7] = nn.SiLU()
model.policy_2.mlp_extractor.policy_net.add_module("8", nn.Dropout(0.2))
model.policy_2.mlp_extractor.policy_net.add_module("9", nn.Linear(512, 256))
model.policy_2.mlp_extractor.policy_net.add_module("10", nn.SiLU())
model.policy_2.mlp_extractor.policy_net.add_module("11", nn.Dropout(0.2))
model.policy_2.mlp_extractor.value_net[2] = nn.Dropout(0.2)
model.policy_2.mlp_extractor.value_net[3] = nn.Linear(4096, 1024)
model.policy_2.mlp_extractor.value_net[4] = nn.SiLU()
model.policy_2.mlp_extractor.value_net[5] = nn.Dropout(0.2)
model.policy_2.mlp_extractor.value_net[6] = nn.Linear(1024, 512)
model.policy_2.mlp_extractor.value_net[7] = nn.SiLU()
model.policy_2.mlp_extractor.value_net.add_module("8", nn.Dropout(0.2))
model.policy_2.mlp_extractor.value_net.add_module("9", nn.Linear(512, 128))
model.policy_2.mlp_extractor.value_net.add_module("10", nn.SiLU())
model.policy_2.mlp_extractor.value_net.add_module("11", nn.Dropout(0.2))
model.policy_2.mlp_extractor.value_net.add_module("12", nn.Linear(128, 32))
model.policy_2.mlp_extractor.value_net.add_module("13", nn.SiLU())
model.policy_2.mlp_extractor.value_net.add_module("14", nn.Dropout(0.2))

In [None]:
model.policy.features_dim

In [None]:
model.policy

In [None]:
class CustomFeatureExtractor(BaseFeaturesExtractor):
    """
    Feature extractor for the Lux S3 Dict observation (this definition replaces the
    earlier ``CustomFeatureExtractor`` once the cell runs).

    - Observations are min-max normalized with ``normalize_observation`` against the
      observation space this extractor was built with.
    - The four grid features are stacked as channels and run through a small CNN.
    - Every other key whose space is 1-D or 2-D is flattened; other ranks are ignored.
    - The CNN output and the flattened keys are concatenated.
    """

    GRID_FEATURES = ("map_explored_status", "map_features_energy", "map_features_tile_type", "sensor_mask")
    CNN_CHANNELS = 32  # out_channels of the last Conv2d

    def __init__(self, observation_space: spaces.Dict, features_dim: "int | None" = None):
        # ``features_dim`` is accepted for backward compatibility but ignored: the real width
        # is computed from ``observation_space``. (The previous default read the global
        # ``model`` at class-definition time.)
        grid_features = list(self.GRID_FEATURES)
        vector_features = [
            key for key, space in observation_space.spaces.items()
            if key not in grid_features and len(space.shape) in (1, 2)
        ]
        grid_h, grid_w = observation_space.spaces[grid_features[0]].shape
        # Both convolutions use kernel 3 / stride 1 / padding 1, so H x W is preserved.
        cnn_output_dim = self.CNN_CHANNELS * grid_h * grid_w
        vector_dim = sum(int(np.prod(observation_space.spaces[key].shape)) for key in vector_features)
        total_dim = cnn_output_dim + vector_dim

        super(CustomFeatureExtractor, self).__init__(observation_space, total_dim)

        self.grid_features = grid_features
        self.vector_features = vector_features
        # forward() normalizes against this space rather than a global ``model``.
        self._normalization_space = observation_space

        # CNN for grid features (expects input shape [batch, channels, H, W]).
        self.cnn_extractor = nn.Sequential(
            nn.Conv2d(len(grid_features), 16, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            nn.Dropout(0.2),
            nn.Conv2d(16, self.CNN_CHANNELS, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            nn.Flatten(),
        )

        # One Flatten per non-grid key, in observation-space order.
        self.extractors = nn.ModuleDict({key: nn.Flatten() for key in vector_features})

        self._features_dim = total_dim

    def forward(self, observations):
        """
        Normalize, run grid keys through the CNN, flatten the other keys, and concatenate.
        """
        observations = normalize_observation(observations, self._normalization_space)

        grid_stack = torch.stack([observations[key] for key in self.grid_features], dim=1).float()
        grid_features = self.cnn_extractor(grid_stack)

        vector_features = torch.cat([self.extractors[key](observations[key]) for key in self.vector_features], dim=1)

        return torch.cat([grid_features, vector_features], dim=1)
    

class CustomMlpExtractor(nn.ModuleDict):
    """
    Shared reduction MLP followed by separate actor and critic trunks.

    - feature_reduction: feature_dim -> *hidden_dims -> reduced_dim, every layer being
      Linear -> SiLU -> LayerNorm -> Dropout(dropout).
    - policy_net / value_net: reduced_dim -> *head_dims. Every layer except the last is
      Linear -> SiLU -> LayerNorm -> Dropout; the last is Linear -> SiLU.

    The defaults reproduce the original fixed architecture, so module indices (and
    state_dict keys) are unchanged. ``latent_dim_pi`` / ``latent_dim_vf`` follow the
    stable-baselines3 MlpExtractor contract. This definition replaces the earlier
    ``CustomMlpExtractor`` once the cell runs.
    """

    def __init__(self, feature_dim, reduced_dim=1024, hidden_dims=(8192, 4096, 2048),
                 head_dims=(512, 256, 128), dropout=0.2):
        super().__init__()
        self.feature_reduction = self._stack(feature_dim, (*hidden_dims, reduced_dim), dropout, norm_last=True)
        self.policy_net = self._stack(reduced_dim, head_dims, dropout, norm_last=False)
        self.value_net = self._stack(reduced_dim, head_dims, dropout, norm_last=False)
        self.latent_dim_pi = head_dims[-1]
        self.latent_dim_vf = head_dims[-1]

    @staticmethod
    def _stack(in_dim, dims, dropout, norm_last):
        """Build Linear -> SiLU -> LayerNorm -> Dropout per width; the last layer stops after
        SiLU unless ``norm_last`` is True."""
        layers = []
        for i, out_dim in enumerate(dims):
            layers += [nn.Linear(in_dim, out_dim), nn.SiLU()]
            if norm_last or i < len(dims) - 1:
                layers += [nn.LayerNorm(out_dim), nn.Dropout(dropout)]
            in_dim = out_dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(policy_latent, value_latent)``; the reduction is computed once."""
        x = self.feature_reduction(x)
        latent_pi = self.policy_net(x)
        latent_vf = self.value_net(x)
        return latent_pi, latent_vf

    def forward_actor(self, x):
        """Return the policy latent for features ``x``."""
        x = self.feature_reduction(x)
        return self.policy_net(x)

    def forward_critic(self, x):
        """Return the value latent for features ``x``."""
        x = self.feature_reduction(x)
        return self.value_net(x)


class CustomMultiInputPolicy(MultiInputActorCriticPolicy):
    """
    Actor-critic policy for 16 units. Each unit picks one of 6 action types plus a (dx, dy)
    pair of 15-way categoricals; the pair is only used when the action type is 5 (sap).

    - Features come from ``CustomFeatureExtractor`` and go through ``CustomMlpExtractor``.
    - ``action_net`` emits 16*6 action-type logits followed by 16*2*15 dx/dy logits.
    - ``forward`` returns (env_actions[batch, 16, 3], values[batch, 1], log_prob[batch]).

    NOTE(review): only ``forward`` applies the logit temperature and the sap-only dx/dy
    masking. The inherited ``evaluate_actions`` / ``get_distribution`` use SB3's default
    action distribution on the raw ``action_net`` output. The log-probs used in PPO updates
    may therefore not match the ones produced here unless the PPO in use overrides them —
    confirm.
    """

    # Divisor applied to the action_net output before building the distributions.
    LOGIT_TEMPERATURE = 20.0
    NUM_UNITS = 16
    NUM_ACTION_TYPES = 6
    NUM_OFFSETS = 15
    # Action type whose units also use the sampled (dx, dy).
    SAP_ACTION = 5

    def __init__(self, *args, **kwargs):
        super(CustomMultiInputPolicy, self).__init__(
            *args, **kwargs,
            features_extractor_class=CustomFeatureExtractor,
        )
        self.mlp_extractor = CustomMlpExtractor(self.features_extractor._features_dim)

        n_logits = self.NUM_UNITS * self.NUM_ACTION_TYPES + self.NUM_UNITS * 2 * self.NUM_OFFSETS
        # Both heads consume the 128-wide latents produced by CustomMlpExtractor.
        self.action_net = nn.Sequential(
            nn.Linear(128, 256),
            nn.SiLU(),
            nn.LayerNorm(256),
            nn.Dropout(0.2),
            nn.Linear(256, n_logits),
        )
        self.value_net = nn.Sequential(
            nn.Linear(128, 64),
            nn.SiLU(),
            nn.LayerNorm(64),
            nn.Dropout(0.2),
            nn.Linear(64, 1),
        )

        # super().__init__ built self.optimizer over the modules replaced above. Rebuild it
        # with the same class, kwargs and learning rate so the new parameters are trained.
        lr = self.optimizer.param_groups[0]["lr"]
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr, **self.optimizer_kwargs)

    def forward(self, obs, *args, **kwargs):
        """
        Sample (or pick the most likely) actions for a batch of observations.

        Args:
            obs: Dict observation batch accepted by the features extractor.
            deterministic: optional, positional or keyword (SB3 convention). When True,
                the argmax action type / dx / dy is taken instead of sampling.

        Returns:
            env_actions: int64 tensor (batch, 16, 3) = [action_type, dx_index, dy_index];
                dx/dy are 0 unless action_type == SAP_ACTION.
            values: critic output, (batch, 1).
            log_prob: (batch,) sum over units of the action-type log-prob plus, for sap
                units only, the dx and dy log-probs.
        """
        deterministic = kwargs.get("deterministic", args[0] if args else False)

        features = self.features_extractor(obs)
        policy_features = self.mlp_extractor.forward_actor(features)
        value_features = self.mlp_extractor.forward_critic(features)

        logits = self.action_net(policy_features) / self.LOGIT_TEMPERATURE
        n_type_logits = self.NUM_UNITS * self.NUM_ACTION_TYPES
        type_logits = logits[:, :n_type_logits].view(-1, self.NUM_UNITS, self.NUM_ACTION_TYPES)
        offset_logits = logits[:, n_type_logits:].view(-1, self.NUM_UNITS, 2, self.NUM_OFFSETS)
        dx_logits = offset_logits[:, :, 0, :]
        dy_logits = offset_logits[:, :, 1, :]

        # ``logits=`` normalizes with log_softmax, which avoids the -inf log-probs that
        # softmax -> probs -> log can produce when a probability underflows to 0.
        type_dist = torch.distributions.Categorical(logits=type_logits)
        dx_dist = torch.distributions.Categorical(logits=dx_logits)
        dy_dist = torch.distributions.Categorical(logits=dy_logits)

        if deterministic:
            action_types = type_logits.argmax(dim=-1)
            dx = dx_logits.argmax(dim=-1)
            dy = dy_logits.argmax(dim=-1)
        else:
            action_types = type_dist.sample()
            dx = dx_dist.sample()
            dy = dy_dist.sample()

        # Only sap units carry an offset; the others get 0 (no host sync from nonzero()).
        sap_mask = action_types == self.SAP_ACTION
        no_offset = torch.zeros_like(dx)
        env_actions = torch.stack(
            [action_types, torch.where(sap_mask, dx, no_offset), torch.where(sap_mask, dy, no_offset)],
            dim=-1,
        )

        type_log_prob = type_dist.log_prob(action_types)  # (batch, 16)
        offset_log_prob = torch.where(
            sap_mask,
            dx_dist.log_prob(dx) + dy_dist.log_prob(dy),
            torch.zeros_like(type_log_prob),
        )
        total_log_prob = type_log_prob + offset_log_prob  # per unit

        return env_actions, self.value_net(value_features), total_log_prob.sum(dim=-1)

In [None]:
# Two independently initialized policies, placed on the same device as the PPO model
# (instead of a hard-coded "cuda") so rollout tensors and weights agree.
custom_policy = CustomMultiInputPolicy(model.observation_space, model.action_space, model.lr_schedule).to(model.device)
custom_policy_2 = CustomMultiInputPolicy(model.observation_space, model.action_space, model.lr_schedule).to(model.device)

In [None]:
from torch.optim.lr_scheduler import _LRScheduler

class GreedyLR(_LRScheduler):
    """
    Loss-driven "greedy" learning-rate scheduler.

    Call ``step(metrics)`` with the monitored loss (lower is better). Each call:

    - optionally replaces ``metrics`` by the mean of the last ``window`` values
      (``smooth=True``);
    - counts consecutive improving ("good") and non-improving ("bad") calls, relative to
      the best value seen so far;
    - computes the new LR from the first param group's LR:
        * ``lr / factor`` while the warmup counter is below ``warmup``,
        * else ``lr * factor`` while the cooldown counter is below ``cooldown``,
        * else ``lr / factor`` after ``patience`` good calls (resets the cooldown counter),
        * else ``lr * factor`` after ``patience`` bad calls (resets the warmup counter),
        * else unchanged;
    - clamps to ``[min_lr, max_lr]`` and writes the result to every param group;
    - if ``reset`` is set, clears all history whenever ``last_epoch % reset == 0``.

    With ``factor < 1``, dividing increases the LR and multiplying decreases it;
    ``factor == 1`` only clamps.

    Args:
        optimizer: wrapped optimizer.
        factor: multiplicative step (see above).
        patience: consecutive good/bad calls needed before the LR changes.
        cooldown: forced "multiply" calls at the start and after each patience increase.
        warmup: forced "divide" calls at the start and after each patience decrease.
        min_lr, max_lr: clamp bounds for the LR.
        smooth: average ``metrics`` over the last ``window`` values.
        window: moving-average length used when ``smooth`` is True.
        reset: period (in calls) after which history is cleared; ``None`` disables it.
    """

    def __init__(self, optimizer, factor=0.1, patience=10, cooldown=0, warmup=0,
                 min_lr=0, max_lr=10, smooth=False, window=5, reset=None):
        self.factor = factor
        self.patience = patience
        self.cooldown = cooldown
        self.warmup = warmup
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.smooth = smooth
        self.window = window
        self.reset = reset

        self.best_loss = float('inf')
        self.warmup_counter = 0
        self.cooldown_counter = 0
        self.num_good_epochs = 0
        self.num_bad_epochs = 0
        self.loss_window = []

        super().__init__(optimizer)

    def get_lr(self):
        """Current LR of every param group (LRs are set directly in ``step``)."""
        return [group['lr'] for group in self.optimizer.param_groups]

    def _reset_history(self):
        """Forget the best loss, all counters and the smoothing window."""
        self.best_loss = float('inf')
        self.warmup_counter = 0
        self.cooldown_counter = 0
        self.num_good_epochs = 0
        self.num_bad_epochs = 0
        self.loss_window = []

    def step(self, metrics=None):
        """
        Update the LR from ``metrics`` and return ``[new_lr]``.

        With ``metrics=None`` (as the base class does once during ``__init__``), the LR is
        left unchanged and the current LRs are returned.
        """
        if metrics is None:
            current = self.get_lr()
            # The base-class step() normally records this; keep get_last_lr() working.
            self._last_lr = list(current)
            return current

        # Accept Python numbers, numpy scalars and 0-d tensors alike.
        metrics = float(metrics)
        current_lr = self.get_lr()[0]

        if self.smooth:
            self.loss_window.append(metrics)
            if len(self.loss_window) > self.window:
                self.loss_window.pop(0)
            metrics = sum(self.loss_window) / len(self.loss_window)

        if metrics < self.best_loss:
            self.best_loss = metrics
            self.num_good_epochs += 1
            self.num_bad_epochs = 0
        else:
            self.num_good_epochs = 0
            self.num_bad_epochs += 1

        if self.warmup_counter < self.warmup:
            self.warmup_counter += 1
            new_lr = min(current_lr / self.factor, self.max_lr)
        elif self.cooldown_counter < self.cooldown:
            self.cooldown_counter += 1
            new_lr = max(current_lr * self.factor, self.min_lr)
        elif self.num_good_epochs >= self.patience:
            new_lr = min(current_lr / self.factor, self.max_lr)
            self.cooldown_counter = 0
        elif self.num_bad_epochs >= self.patience:
            new_lr = max(current_lr * self.factor, self.min_lr)
            self.warmup_counter = 0
        else:
            new_lr = current_lr

        new_lr = max(self.min_lr, min(new_lr, self.max_lr))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

        if self.reset and self.last_epoch % self.reset == 0:
            self._reset_history()

        self.last_epoch += 1
        return [new_lr]

In [None]:
max_lr = 5e-3
min_lr = 1e-12


def _attach_adam_and_greedy_lr(policy):
    """
    Give ``policy`` a fused Adam optimizer at the global ``learning_rate`` plus a GreedyLR
    scheduler. With factor=1, GreedyLR only clamps the LR to [min_lr, max_lr].
    """
    policy.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate, fused=True)  # type: ignore[call-arg]
    policy.scheduler = GreedyLR(
        policy.optimizer,
        factor=1, patience=10, cooldown=3, warmup=0,
        min_lr=min_lr, max_lr=max_lr,
        smooth=False, window=5, reset=None,
    )


for _policy in (custom_policy, custom_policy_2):
    _attach_adam_and_greedy_lr(_policy)

In [None]:
# Module.train() already switches every submodule recursively, so one call per policy
# puts all children in training mode.
for _policy in (custom_policy, custom_policy_2):
    _policy.train()

In [None]:
model.policy = custom_policy
model.policy_2 = custom_policy_2

In [None]:
model.policy

In [None]:
model.policy.optimizer

In [None]:
learn_results = model.learn(total_timesteps=10000, progress_bar=True)

In [None]:
model.policy.optimizer

In [None]:
model.policy.optimizer.param_groups[0]["lr"]

In [None]:
for k, v in model.observation_space.items():
    print(k, v)
    print(v.shape)

In [None]:
len(model.observation_space["team_id"].shape)

In [None]:
len(model.observation_space["team_id"].shape)

In [None]:
torch.as_tensor(np.array(model.observation_space["map_explored_status"]))

In [None]:
dir(model.observation_space["map_explored_status"])

In [None]:
model.observation_space["map_explored_status"]

In [None]:
# Synthetic batch-of-one observation. Each key maps to (low, high_exclusive, shape); values
# are drawn with np.random.randint as int32, in this key order, then converted to float32.
OBS_SPEC = {
    "enemy_energies": (-800, 401, (16,)),
    "enemy_positions": (-1, 24, (16, 2)),
    "enemy_spawn_location": (-1, 24, (2,)),
    "enemy_visible_mask": (0, 2, (16,)),
    "map_explored_status": (0, 2, (24, 24)),
    "map_features_energy": (-7, 10, (24, 24)),
    "map_features_tile_type": (-1, 3, (24, 24)),
    "match_steps": (0, 101, (1,)),
    "my_spawn_location": (-1, 24, (2,)),
    "relic_nodes": (-1, 24, (6, 2)),
    "relic_nodes_mask": (0, 2, (6,)),
    "sensor_mask": (0, 2, (24, 24)),
    "steps": (0, 506, (1,)),
    "team_id": (0, 2, (1,)),
    "team_points": (0, 2501, (2,)),
    "team_wins": (0, 4, (2,)),
    "unit_active_mask": (0, 2, (16,)),
    "unit_energies": (-800, 401, (16,)),
    "unit_move_cost": (1, 6, (1,)),
    "unit_positions": (-1, 24, (16, 2)),
    "unit_sap_cost": (30, 51, (1,)),
    "unit_sap_range": (3, 8, (1,)),
    "unit_sensor_range": (2, 5, (1,)),
}
# Fall back to CPU so this cell also runs on machines without a GPU.
OBS_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

obs = {
    key: torch.tensor(
        np.random.randint(low, high, size=(1, *shape), dtype=np.int32),
        dtype=torch.float32,
        device=OBS_DEVICE,
    )
    for key, (low, high, shape) in OBS_SPEC.items()
}

In [None]:
# NOTE(review): `test_model2` is not defined anywhere in this notebook, so on a fresh kernel
# this cell (and the `test_model2.action_space` cell below) raises NameError. Presumably it
# referred to a policy from a deleted cell — confirm and replace.
test_output = test_model2(obs)

In [None]:
test_output

In [None]:
test_output[2].sum(dim=-1)

In [None]:
len(test_output)

In [None]:
test_output[0]

In [None]:
test_output[1]

In [None]:
test_output[2]

In [None]:
test_output

In [None]:
test_output

In [None]:
len(test_output)

In [None]:
test_output[0].shape

In [None]:
test_output[1].shape

In [None]:
test_output[2].shape

In [None]:
# Same synthetic `obs` through the currently installed policy. If that is the
# CustomMultiInputPolicy assigned earlier, this is (actions, values, log_prob).
test_output2 = model.policy(obs)
test_output2

In [None]:
test_model2.action_space

In [None]:
model.policy.action_space

In [None]:
prac_module_dict = nn.ModuleDict(
    {
        "policy_net": nn.Sequential(
            nn.Linear(2466, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        ),
        "value_net": nn.Sequential(
            nn.Linear(2466, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        ),
        "something_else": nn.Linear(55, 6969)
    }
)

In [None]:
prac_module_dict

In [None]:
model.policy.features_extractor.extractors.enemy_energies = prac_module_dict

In [None]:
model.policy

In [None]:
prac_module_dict

In [None]:
new_mlp = nn.Sequential(
    nn.Linear(2466, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU()
)

In [None]:
model.policy.mlp_extractor.policy_net = new_mlp
model.policy.mlp_extractor.value_net = new_mlp

In [None]:
model.policy

In [None]:
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch as th
from gymnasium import spaces

class CustomExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Dict):
        super().__init__(observation_space, features_dim=512)

        # Define CNN for grid-based inputs
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

    def forward(self, observations):
        map_features = self.cnn(observations["map_features_tile_type"].unsqueeze(1))
        return th.cat([map_features, observations["unit_positions"].flatten(1)], dim=1)


# Replace the feature extractor
model.policy.features_extractor = CustomExtractor(model.policy.observation_space)


In [None]:
model.policy

In [None]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
from gymnasium import spaces
import numpy as np


class CustomCNNExtractor(BaseFeaturesExtractor):
    """
    Feature extractor that combines three encoders:

    * a CNN over spatial map observations stacked as channels,
    * an MLP over concatenated non-spatial observations,
    * multi-head self-attention over per-unit feature vectors, mean-pooled across units.

    All layer input sizes are derived from ``observation_space``, so shapes agree with the data.
    The earlier version reshaped a 1-channel map into 3 channels, fed a 2-wide ``team_points`` into
    ``Linear(50, ...)``, reshaped (16, 2) positions as (16, 3), and dropped the MLP output.

    :param observation_space: Dict observation space containing every key used below.
    :param features_dim: width of the returned feature vector.
    :param map_keys: keys of same-shaped (H, W) map observations stacked as CNN channels.
        Assumed default — TODO confirm which three maps were intended.
    :param non_spatial_keys: keys flattened and concatenated as the MLP input.
    :param unit_key: key of the per-unit observation, shaped (num_units, unit_feature_dim).
    :param embed_dim: attention embedding width (must be divisible by ``num_heads``).
    :param num_heads: number of attention heads.
    """
    def __init__(
        self,
        observation_space: spaces.Dict,
        features_dim=128,
        map_keys=("map_features_tile_type", "map_features_energy", "sensor_mask"),
        non_spatial_keys=("team_points",),
        unit_key="unit_positions",
        embed_dim=64,
        num_heads=4,
    ):
        super().__init__(observation_space, features_dim)
        self.map_keys = tuple(map_keys)
        self.non_spatial_keys = tuple(non_spatial_keys)
        self.unit_key = unit_key

        # CNN for 2D map-like inputs, one channel per map key
        map_height, map_width = observation_space[self.map_keys[0]].shape
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=len(self.map_keys), out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Compute CNN output size dynamically
        with th.no_grad():
            dummy_input = th.zeros(1, len(self.map_keys), map_height, map_width)
            cnn_out_size = self.cnn(dummy_input).shape[1]

        # MLP for non-spatial inputs
        non_spatial_size = sum(int(np.prod(observation_space[k].shape)) for k in self.non_spatial_keys)
        self.mlp = nn.Sequential(
            nn.Linear(non_spatial_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # Project each unit's raw features to embed_dim, then self-attend across units (batch-first).
        unit_feature_dim = observation_space[unit_key].shape[-1]
        self.unit_embedding = nn.Linear(unit_feature_dim, embed_dim)
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

        # Final feature size: CNN + MLP + pooled attention
        self.final_linear = nn.Linear(cnn_out_size + 64 + embed_dim, features_dim)

    def forward(self, observations):
        map_input = th.stack([observations[k].float() for k in self.map_keys], dim=1)  # (batch, C, H, W)
        non_spatial_input = th.cat(
            [observations[k].float().flatten(1) for k in self.non_spatial_keys], dim=1
        )
        unit_features = self.unit_embedding(observations[self.unit_key].float())  # (batch, units, embed)

        map_features = self.cnn(map_input)
        non_spatial_features = self.mlp(non_spatial_input)
        # NOTE(review): absent units are not masked out of attention (no key_padding_mask) — consider it.
        attn_out, _ = self.attention(unit_features, unit_features, unit_features)
        attn_out = attn_out.mean(dim=1)  # Pool across units

        combined = th.cat([map_features, non_spatial_features, attn_out], dim=1)
        return self.final_linear(combined)


class CustomActorCriticPolicy(ActorCriticPolicy):
    """
    ActorCriticPolicy that defaults to ``CustomCNNExtractor`` with a 128-wide feature vector.

    The extractor settings are applied as defaults, so callers (e.g. PPO ``policy_kwargs``) can
    override them. Passing them explicitly alongside ``**kwargs`` raised a "multiple values for
    keyword argument" TypeError whenever a caller also supplied them.
    """
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        kwargs.setdefault("features_extractor_class", CustomCNNExtractor)
        kwargs.setdefault("features_extractor_kwargs", {"features_dim": 128})
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            **kwargs
        )


# Example Model Usage
# env = YourEnvironment()
# model = PPO(CustomActorCriticPolicy, env, verbose=1, ent_coef=0.015, vf_coef=0.75, clip_range_vf=0.15, n_steps=505, batch_size=505)


In [None]:
# Inspect the actor's features extractor (the trailing "." made this cell a SyntaxError).
model.policy.pi_features_extractor

In [None]:
learn_results = model.learn(total_timesteps=2000000, progress_bar=True)

In [None]:
# Number of parallel environments (adjust based on CPU cores)
NUM_ENVS = 8

def make_env():
    return ModifiedLuxAIS3GymEnv(numpy_output=True)  # Use your custom environment

env = SubprocVecEnv([lambda: make_env() for _ in range(NUM_ENVS)])
model = PPO("MultiInputPolicy", env, verbose=1, n_steps=2048 * NUM_ENVS)

In [None]:
learn_results = model.learn(total_timesteps=2000000, progress_bar=True)

In [None]:
# Number of parallel environments (adjust based on CPU cores)
NUM_ENVS = 8

def make_env():
    return ModifiedLuxAIS3GymEnv(numpy_output=True)  # Use your custom environment

env = SubprocVecEnv([lambda: make_env() for _ in range(NUM_ENVS)])
model = PPO("MultiInputPolicy", env, verbose=1, n_steps=2048 * NUM_ENVS)

In [None]:
learn_results = model.learn(total_timesteps=2000000, progress_bar=True)

In [None]:
# Scratch: check numpy/JAX scalar and boolean behavior.
temp_zeros = np.zeros(2, dtype=np.int32)
temp_zeros

In [None]:
temp_zeros[0]

In [None]:
# NOTE(review): imports belong in the top imports cell so Restart & Run All works.
import jax

In [None]:
jax

In [None]:
import jax.numpy as jnp

In [None]:
jnp.bool(False) == False

In [None]:
# No trailing comma: with one, this bound a 1-tuple instead of the array (so `.at` later failed).
team_points = jnp.zeros(shape=(2,), dtype=jnp.int32)
team_points

In [None]:
type(jnp.where(True, 3, -1) != -1)

In [None]:
# NOTE(review): requires team_points to be a JAX array; a tuple (e.g. from a trailing comma) has no `.at`.
team_points.at[0]

In [None]:
# NOTE(review): unpacking (obs, info) assumes `env` is a single gymnasium-style env; a SubprocVecEnv's
# reset() returns only the observations.
obs_all, info = env.reset()

In [None]:
# All-zero actions for both players: 16 units x (action, dx, dy).
action0 = np.zeros((16, 3), dtype=np.int8)
action1 = np.zeros((16, 3), dtype=np.int8)

In [None]:
env.step({
    "player_0": action0,
    "player_1": action1
})

In [None]:
# NOTE(review): passes a bare string as the action — almost certainly not a valid action; confirm intent.
env.step("player_0")

In [None]:
from Modified_lux3_wrapper.modified_wrappers_20250228_01 import ModifiedLuxAIS3GymEnv
import numpy as np
from Modified_stablebaseline3_PPO.modified_ppo_20250228_01 import PPO
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import os
import copy
from GreedyLRScheduler import GreedyLR
from luxai_s3.wrappers import LuxAIS3GymEnv
import gc
gc.enable()
from stable_baselines3.common.buffers import DictRolloutBuffer

In [None]:
# Runtime settings for the self-play section (same as the top config cell, but 'medium' matmul precision).
torch.set_float32_matmul_precision('medium')
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.cache_size_limit = 128
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
np.set_printoptions(linewidth=200)
# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA caching allocator is first initialized;
# if CUDA was already used in this kernel, this assignment likely has no effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"
torch.backends.cudnn.benchmark = True
# Cap this process at 80% of GPU memory, then release unused cached blocks.
torch.cuda.set_per_process_memory_fraction(0.8)
torch.cuda.empty_cache()

In [None]:
# Build a PPO model only to obtain an initialized policy network for self-play.
init_env = ModifiedLuxAIS3GymEnv(numpy_output=True)
init_ppo = PPO("MultiInputPolicy", init_env, verbose=1)

In [None]:
# model_1 is an independent deep copy: both agents start identical but train separately.
model_0 = init_ppo.policy
model_1 = copy.deepcopy(model_0)

In [None]:
model_0

In [None]:
model_0.device

In [None]:
# Scratch: a 1000-step Dict rollout buffer shaped after the policy's spaces.
temp_buffer = DictRolloutBuffer(1000, model_0.observation_space, model_0.action_space, model_0.device)
temp_buffer

In [None]:
temp_buffer.reset()

In [None]:
# NOTE(review): SB3 rollout buffers assert they are full before get(); this freshly reset, empty
# buffer is expected to fail here — confirm.
for rollout_data in temp_buffer.get(1):
    print(rollout_data)

In [None]:
def point_gain_reward_func(reward_score) -> float:
    """Reward for the team-point change this step.

    A positive gain is scaled by 20; any step without a gain (zero or negative) costs a flat -1.
    """
    if reward_score > 0.0:
        return reward_score * 20
    return -1

def match_won_reward_func(match_won) -> float:
    """Bonus of 5000 when a match was won this step, otherwise 0."""
    if match_won:
        return 5000.0
    return 0.0

def match_lost_reward_func(match_lost) -> float:
    """Penalty of -3000 when a match was lost this step, otherwise 0."""
    if match_lost:
        return -3000.0
    return 0.0

def game_won_reward_func(game_won) -> float:
    """Terminal bonus of 1e9 when the game was won, otherwise 0."""
    if game_won:
        return 1e9
    return 0.0

def game_lost_reward_func(game_lost) -> float:
    """Terminal penalty of -1e9 when the game was lost, otherwise 0."""
    if game_lost:
        return -1e9
    return 0.0

def map_reveal_reward_func(map_reveal_score):
    """Scale a map-reveal score (number of newly explored tiles) by 10."""
    reveal_bonus_per_tile = 10
    return map_reveal_score * reveal_bonus_per_tile

def attack_reward_func(actions, sap_range, enemy_unit_mask) -> float:
    """Shaping penalty for sap actions (action code >= 5).

    Each sap costs -5.0 when ``enemy_unit_mask`` marks no enemy units. Otherwise it costs -0.5
    when the target offset ``max(|dx|, |dy|)`` exceeds ``sap_range``. Non-sap actions cost nothing.

    :param actions: per-unit actions, each indexable as ``(action_num, dx, dy)``.
    :param sap_range: maximum Chebyshev distance of a sap.
    :param enemy_unit_mask: array-like whose ``.sum()`` is 0 when no enemy is marked.
    :return: summed (non-positive) attack score.
    """
    total = 0.0
    for action in actions:
        code, dx, dy = action[0], action[1], action[2]
        if not code >= 5:
            continue
        if enemy_unit_mask.sum() == 0:
            total -= 5.0  # sapping with no enemy to hit
        elif max(abs(dx), abs(dy)) > sap_range:
            total -= 0.5  # target outside the sap range
    return total

def next_position_calculator(action_num, unit_positions):
    """Position a unit would move to for a movement action.

    Codes: 1 up (y - 1), 2 right (x + 1), 3 down (y + 1), 4 left (x - 1). Any other code (0 stay,
    5 sap, ...) returns ``unit_positions`` itself, unchanged.

    :return: ``(x, y)`` tuple for a move, otherwise the original ``unit_positions`` object.
    """
    shifts = (
        (1, lambda x, y: (x, y - 1)),
        (2, lambda x, y: (x + 1, y)),
        (3, lambda x, y: (x, y + 1)),
        (4, lambda x, y: (x - 1, y)),
    )
    for code, shift in shifts:
        if action_num == code:
            return shift(unit_positions[0], unit_positions[1])
    return unit_positions

def movement_reward_func(actions, obs, team_id, map_size=24) -> float:
    """
    Shaping reward for one team's unit actions.

    For unit ``i`` (action ``actions[i] = (action_num, dx, dy)``, position/energy from ``obs["units"]``):

    * -0.25 for a non-stay action on a unit that does not exist (position ``(-1, -1)``);
    * -0.25 for a non-zero ``dx``/``dy`` on any action other than sap (5);
    * for units on the map only: -0.25 for a non-stay action with energy <= 0, then -0.5 if the
      move would leave the ``map_size`` x ``map_size`` board, +2.0 otherwise.

    The board check is limited to units on the map. Absent units sit at ``(-1, -1)`` and were
    previously penalized as "off the map" even when they just stayed, which contradicted the
    first rule.

    :param actions: per-unit actions, indexable as ``actions[i][0..2]``.
    :param obs: raw player observation providing ``obs["units"]["position"]`` and ``["energy"]``.
    :param team_id: index of the team whose units are scored.
    :param map_size: board width/height (default 24, the previous hard-coded bound).
    :return: summed movement score.
    """
    movement_score = 0.0
    team_positions = obs["units"]["position"][team_id]
    team_energies = obs["units"]["energy"][team_id]

    for i, action in enumerate(actions):
        action_num, dx, dy = action[0], action[1], action[2]
        unit_positions = team_positions[i]
        unit_energy = team_energies[i]

        # give penalty if try to move unit that doesn't exist
        if (unit_positions == (-1, -1)).sum() == 2 and action_num != 0:
            movement_score -= 0.25

        # give penalty if dx or dy is not 0 when not attacking
        if action_num != 5 and (dx != 0 or dy != 0):
            movement_score -= 0.25

        # remaining checks only apply to units that are on the map
        if unit_positions[0] < 0 or unit_positions[1] < 0:
            continue

        # give penalty if try to move unit that has no energy
        if unit_energy <= 0 and action_num != 0:
            movement_score -= 0.25

        # give penalty if try to move unit out of map, reward otherwise
        next_x, next_y = next_position_calculator(action_num, unit_positions)
        if next_x < 0 or next_y < 0 or next_x >= map_size or next_y >= map_size:
            movement_score -= 0.5
        else:
            movement_score += 2.0

    return movement_score

def relic_discovery_reward_func(relic_discovery_reward) -> float:
    """Scale a relic-discovery score by 100."""
    per_discovery_bonus = 100
    return relic_discovery_reward * per_discovery_bonus

In [None]:
class TrainPPO:
    def __init__(
        self,
        model_0,
        model_1,
        num_games=1000,
        learning_rate=5e-4,
        weight_decay=0.01,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range=0.2,
        clip_range_vf=None,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
    ):
        """
        Set up two self-play policies, each with its own AdamW optimizer, GreedyLR scheduler and
        CUDA rollout buffers, plus a fresh Lux AI S3 environment.

        :param model_0: policy controlling "player_0".
        :param model_1: policy controlling "player_1".
        :param num_games: number of games ``train`` plays.
        :param learning_rate: initial AdamW learning rate.
        :param weight_decay: AdamW weight decay.
        :param gamma: discount factor.
        :param gae_lambda: GAE lambda (stored on the instance).
        :param clip_range: PPO policy clip range.
        :param clip_range_vf: optional value-function clip range (None disables it).
        :param ent_coef: entropy coefficient in the loss.
        :param vf_coef: value-loss coefficient in the loss.
        :param max_grad_norm: gradient-norm clipping threshold.
        """
        # Self-play opponents.
        self.model_0, self.model_1 = model_0, model_1

        # Hyperparameters.
        self.num_games = num_games
        self.learning_rate, self.weight_decay = learning_rate, weight_decay
        self.gamma, self.gae_lambda = gamma, gae_lambda
        self.clip_range, self.clip_range_vf = clip_range, clip_range_vf
        self.ent_coef, self.vf_coef = ent_coef, vf_coef
        self.max_grad_norm = max_grad_norm

        def make_optimizer(model):
            # fused=True selects PyTorch's fused AdamW implementation
            return AdamW(model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, fused=True)

        def make_scheduler(optimizer):
            return GreedyLR(optimizer, cooldown=3, min_lr=1e-7, max_lr=5e-4)

        def make_buffer(model, size):
            return DictRolloutBuffer(size, model.observation_space, model.action_space, device="cuda")

        self.optimizer_0, self.optimizer_1 = make_optimizer(self.model_0), make_optimizer(self.model_1)
        self.scheduler_0, self.scheduler_1 = make_scheduler(self.optimizer_0), make_scheduler(self.optimizer_1)

        # 10-step and 101-step buffers per player.
        self.step_rollout_buffer_0, self.step_rollout_buffer_1 = make_buffer(self.model_0, 10), make_buffer(self.model_1, 10)
        self.match_rollout_buffer_0, self.match_rollout_buffer_1 = make_buffer(self.model_0, 101), make_buffer(self.model_1, 101)

        self.env = LuxAIS3GymEnv(numpy_output=True)

        # Only the MLP extractor of each policy is compiled.
        for model in (self.model_0, self.model_1):
            model.mlp_extractor = torch.compile(model.mlp_extractor)

    def train(self):
        """
        Play ``self.num_games`` self-play games: ``model_0`` controls "player_0", ``model_1`` "player_1".

        Each environment step:
          1. updates per-player score / match / game results, the first-seen spawn location and the
             explored-tile map;
          2. samples actions from both policies and computes each player's shaped reward;
          3. steps the environment and runs one single-sample PPO update per player (``_ppo_update``).

        After a game with a winner, the loser's parameters are overwritten with the winner's.
        """
        players = ("player_0", "player_1")
        models = (self.model_0, self.model_1)
        optimizers = (self.optimizer_0, self.optimizer_1)
        schedulers = (self.scheduler_0, self.scheduler_1)

        for game in range(1, self.num_games + 1):
            print("="*15 + f" Game {game} Started " + "="*15)

            obs_all, info = self.env.reset()
            self.env_cfg = info['params']

            self.spawn_location = np.array([[-1, -1], [-1, -1]], dtype=np.int32)
            self.map_explored_status = np.zeros((2, 24, 24), dtype=bool)

            previous_score = [0.0, 0.0]
            previous_explored = [self.map_explored_status[t].sum() for t in (0, 1)]
            match_won_num = [0, 0]
            spawn_recorded = [False, False]
            victor = None
            match_number = 1
            game_ended = False
            model_inputs = None  # built on the first step, after spawn/exploration bookkeeping

            while not game_ended:
                current_match_step = obs_all["player_0"]["match_steps"]

                # --- score, match and game results ---
                current_score = [obs_all[p]["team_points"][t] for t, p in enumerate(players)]
                point_gain = [current_score[t] - previous_score[t] for t in (0, 1)]
                previous_score = current_score

                match_won = [False, False]
                match_lost = [False, False]
                if current_match_step == 100 and current_score[0] != current_score[1]:
                    winner = 0 if current_score[0] > current_score[1] else 1
                    match_won[winner] = True
                    match_lost[1 - winner] = True
                    match_won_num[winner] += 1

                game_won = [False, False]
                game_lost = [False, False]
                for t, p in enumerate(players):
                    if match_won_num[t] >= 3:
                        game_ended = True
                        print(f"Player {t} won the game.")
                        victor = p
                        game_won[t] = True
                        game_lost[1 - t] = True

                # --- spawn location, enemy visibility and map exploration ---
                visible_enemy_masks = []
                explored_gain = []
                for t, p in enumerate(players):
                    own_unit_mask = np.array(obs_all[p]["units_mask"][t])
                    # Enemy units as seen by this player (not the enemy's own, privileged mask).
                    visible_enemy_masks.append(np.array(obs_all[p]["units_mask"][1 - t]))

                    available_unit_ids = np.flatnonzero(own_unit_mask)
                    if not spawn_recorded[t] and available_unit_ids.size > 0:
                        # The first unit seen for a team marks its spawn location.
                        unit_positions = np.array(obs_all[p]["units"]["position"][t])
                        self.spawn_location[t] = unit_positions[available_unit_ids[0]]
                        spawn_recorded[t] = True

                    # NOTE(review): tile_type is transposed here, but prepare_model_input passes the
                    # untransposed map alongside map_explored_status — confirm the intended axis order.
                    tile_type = obs_all[p]["map_features"]["tile_type"].T
                    self.map_explored_status[t][tile_type != -1] = True
                    explored = self.map_explored_status[t].sum()
                    explored_gain.append(explored - previous_explored[t])
                    previous_explored[t] = explored

                if model_inputs is None:
                    model_inputs = [self.prepare_model_input(obs_all[p], t) for t, p in enumerate(players)]

                # --- act: forward() returns (actions, values, log_prob) for sampled actions ---
                with torch.no_grad():
                    rollouts = [model(model_input) for model, model_input in zip(models, model_inputs)]
                env_actions = [self._to_env_action(actions) for actions, _, _ in rollouts]

                # --- shaped rewards ---
                rewards = []
                for t, p in enumerate(players):
                    reward = (
                        point_gain_reward_func(point_gain[t])
                        + match_won_reward_func(match_won[t])
                        + match_lost_reward_func(match_lost[t])
                        + game_won_reward_func(game_won[t])
                        + game_lost_reward_func(game_lost[t])
                        + map_reveal_reward_func(explored_gain[t])
                        + attack_reward_func(env_actions[t], self.env_cfg["unit_sap_range"], visible_enemy_masks[t])
                        + movement_reward_func(env_actions[t], obs_all[p], t)
                    )
                    rewards.append(float(reward))

                # --- step the environment, then one PPO update per player ---
                obs_all, _, _, _, _ = self.env.step({p: env_actions[t] for t, p in enumerate(players)})
                next_model_inputs = [self.prepare_model_input(obs_all[p], t) for t, p in enumerate(players)]

                for t in (0, 1):
                    actions, old_values, old_log_prob = rollouts[t]
                    self._ppo_update(
                        models[t], optimizers[t], schedulers[t],
                        model_inputs[t], actions, old_values, old_log_prob,
                        rewards[t], next_model_inputs[t],
                    )
                model_inputs = next_model_inputs

                if match_number >= 5 and current_match_step == 100:
                    game_ended = True
                    print("Game ended.")

                if current_match_step == 100:
                    match_number += 1

            if victor == "player_0":
                self.synchronize_models(self.model_0, self.model_1)
            elif victor == "player_1":
                self.synchronize_models(self.model_1, self.model_0)

            torch.cuda.empty_cache()
            gc.collect()

        return

    @staticmethod
    def _to_env_action(actions, num_units=16):
        """
        Convert a sampled policy action tensor into the ``(num_units, 3)`` numpy array for ``env.step``.

        Columns are (action, dx, dy); dx/dy are shifted by -7 (presumably the policy encodes offsets
        -7..7 as 0..14 — confirm against the action space). ``.copy()`` is required because ``.numpy()``
        of a CPU tensor shares memory, and the raw actions are reused for the log-prob.
        """
        env_action = actions.detach().reshape(-1, num_units, 3)[0].cpu().numpy().copy()
        env_action[:, 1:] -= 7
        return env_action

    def _ppo_update(self, model, optimizer, scheduler, model_input, actions, old_values, old_log_prob,
                    reward, next_model_input):
        """
        One single-transition PPO update for one player, with loss terms shaped like SB3's PPO.

        Target: one-step TD return ``reward + gamma * V(next_obs)``. Advantage: target minus the value
        recorded when the action was sampled. The loss combines:

        * the clipped surrogate on ``exp(log_prob - old_log_prob)``;
        * the (optionally clipped) value MSE against the target;
        * the entropy bonus ``-mean(entropy)``.

        :return: scalar loss value.
        """
        values, log_prob, entropy = model.evaluate_actions(model_input, actions)
        values = values.flatten()
        old_values = old_values.flatten()

        with torch.no_grad():
            next_values = model.predict_values(next_model_input).flatten()
        # NOTE(review): no terminal masking — the bootstrap also runs across match/game boundaries.
        returns = reward + self.gamma * next_values
        advantages = returns - old_values

        # The ratio carries the policy gradient (grad log_prob), even though its value is ~1 here.
        ratio = torch.exp(log_prob - old_log_prob.flatten())
        policy_loss = -torch.min(
            advantages * ratio,
            advantages * torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range),
        ).mean()

        if self.clip_range_vf is None:
            values_pred = values
        else:
            values_pred = old_values + torch.clamp(values - old_values, -self.clip_range_vf, self.clip_range_vf)
        value_loss = F.mse_loss(returns, values_pred)

        # Minimizing -entropy maximizes entropy; fall back to -log_prob when entropy is unavailable.
        if entropy is None:
            entropy_loss = -torch.mean(-log_prob)
        else:
            entropy_loss = -torch.mean(entropy)

        loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
        optimizer.step()
        scheduler.step(loss.item())
        return loss.item()


    def prepare_model_input(self, obs, my_team_id):
        enemy_team_id = 1 - my_team_id

        self.spawn_location = np.array([[-1, -1], [-1, -1]], dtype=np.int32)

        

        model_input = {
            "enemy_energies": obs["units"]["energy"][enemy_team_id],
            "enemy_positions": obs["units"]["position"][enemy_team_id],
            "enemy_spawn_location": self.spawn_location[enemy_team_id],
            "enemy_visible_mask": obs["units_mask"][enemy_team_id],
            "map_explored_status": self.map_explored_status[my_team_id],
            "map_features_energy": obs["map_features"]["energy"],
            "map_features_tile_type": obs["map_features"]["tile_type"],
            "match_steps": np.array([obs["match_steps"]]),
            "my_spawn_location": self.spawn_location[my_team_id],
            "relic_nodes": obs["relic_nodes"],
            "relic_nodes_mask": obs["relic_nodes_mask"],
            "sensor_mask": obs["sensor_mask"],
            "steps": np.array([obs["steps"]]),
            "team_id": np.array([my_team_id]),
            "team_points": obs["team_points"],
            "team_wins": obs["team_wins"],
            "unit_active_mask": obs["units_mask"][my_team_id],
            "unit_energies": obs["units"]["energy"][my_team_id],
            "unit_move_cost": np.array([self.env_cfg["unit_move_cost"]]),
            "unit_positions": obs["units"]["position"][my_team_id],
            "unit_sap_cost": np.array([self.env_cfg["unit_sap_cost"]]),
            "unit_sap_range": np.array([self.env_cfg["unit_sap_range"]]),
            "unit_sensor_range": np.array([self.env_cfg["unit_sensor_range"]]),
        }

        model_input = {k: torch.tensor(np.expand_dims(v, axis=0), dtype=torch.int32, device="cuda") for k, v in model_input.items()}

        return model_input
    
    def synchronize_models(self, winner_model, loser_model):
        with torch.no_grad():
            for p1, p2 in zip(winner_model.parameters(), loser_model.parameters()):
                p2.data.copy_(p1.data)



        


In [None]:
# Self-play: model_0 and model_1 train against each other; after a decided game the loser copies the winner.
trainer = TrainPPO(model_0, model_1)

In [None]:
trainer.train()

In [None]:
# Returns a generator; wrap in list(...) to actually inspect the parameters.
model.policy.parameters()

In [None]:
# NOTE(review): `obs` is only defined in a later cell — this fails on a fresh kernel unless run after it.
obs_tensor = model.policy.extract_features(obs)
obs_tensor

In [None]:
obs_tensor.shape

In [None]:
# Random single-sample observation with the same keys TrainPPO.prepare_model_input builds.
# NOTE(review): unseeded — values differ on every run.
obs = {
    "enemy_energies": np.random.randint(-800, 401, size=(1, 16,), dtype=np.int32),
    "enemy_positions": np.random.randint(-1, 24, size=(1, 16, 2), dtype=np.int32),
    "enemy_spawn_location": np.random.randint(-1, 24, size=(1, 2,), dtype=np.int32),
    "enemy_visible_mask": np.random.randint(0, 2, size=(1, 16,), dtype=np.int32),
    "map_explored_status": np.random.randint(0, 2, size=(1, 24, 24), dtype=np.int32),
    "map_features_energy": np.random.randint(-7, 10, size=(1, 24, 24), dtype=np.int32),
    "map_features_tile_type": np.random.randint(-1, 3, size=(1, 24, 24), dtype=np.int32),
    "match_steps": np.random.randint(0, 101, size=(1, 1,), dtype=np.int32),
    "my_spawn_location": np.random.randint(-1, 24, size=(1, 2,), dtype=np.int32),
    "relic_nodes": np.random.randint(-1, 24, size=(1, 6, 2), dtype=np.int32),
    "relic_nodes_mask": np.random.randint(0, 2, size=(1, 6,), dtype=np.int32),
    "sensor_mask": np.random.randint(0, 2, size=(1, 24, 24), dtype=np.int32),
    "steps": np.random.randint(0, 506, size=(1, 1,), dtype=np.int32),
    "team_id": np.random.randint(0, 2, size=(1, 1,), dtype=np.int32),
    "team_points": np.random.randint(0, 2501, size=(1, 2,), dtype=np.int32),
    "team_wins": np.random.randint(0, 4, size=(1, 2,), dtype=np.int32),
    "unit_active_mask": np.random.randint(0, 2, size=(1, 16,), dtype=np.int32),
    "unit_energies": np.random.randint(-800, 401, size=(1, 16,), dtype=np.int32),
    "unit_move_cost": np.random.randint(1, 6, size=(1, 1, ), dtype=np.int32),
    "unit_positions": np.random.randint(-1, 24, size=(1, 16, 2), dtype=np.int32),
    "unit_sap_cost": np.random.randint(30, 51, size=(1, 1, ), dtype=np.int32),
    "unit_sap_range": np.random.randint(3, 8, size=(1, 1, ), dtype=np.int32),
    "unit_sensor_range": np.random.randint(2, 5, size=(1, 1, ), dtype=np.int32),
}

In [None]:
# Cast every entry to a float32 CUDA tensor. NOTE(review): rebinds `obs`, so re-running re-wraps tensors.
obs = {k: torch.tensor(v, dtype=torch.float32, device="cuda") for k, v in obs.items()}

# Convert observation to tensor and check shape
# NOTE(review): the replacement MLP heads in this notebook take 2466 inputs, not 2464 — confirm which is right.
obs_tensor = model.policy.extract_features(obs)
print(f"Extracted Feature Shape: {obs_tensor.shape}")  # Expected: (batch_size, 2464)

In [None]:
# SB3's ActorCriticPolicy.forward returns (actions, values, log_prob): `action_distribution` holds sampled actions.
with torch.no_grad():
    action_distribution, value, log = model.policy.forward(obs)

In [None]:
action_distribution.shape

In [None]:
action_distribution

In [None]:
value.shape

In [None]:
value

In [None]:
log.shape

In [None]:
log

In [None]:
# Regroup the flat action vector as (batch, 16 units, 3) — the layout TrainPPO uses for env actions.
actions = action_distribution.reshape(-1, 16, 3)
actions

In [None]:
actions.shape

In [None]:
# Compile the MLP extractor; the cells below repeat the forward pass through the compiled module.
model.policy.mlp_extractor = torch.compile(model.policy.mlp_extractor)

In [None]:
with torch.no_grad():
    action_distribution, value, log = model.policy.forward(obs)

In [None]:
actions = action_distribution.reshape(-1, 16, 3)
actions

In [None]:
# NOTE(review): unpinned upgrade at the end of the notebook. Pin a version (%pip install luxai-s3==<version>)
# and move it to the top so the code above runs against the same package; restart the kernel afterwards.
%pip install --upgrade luxai-s3