# Environment

In [37]:
# !git clone https://github.com/cuongtv312/marl-delivery.git marl_delivery

In [38]:
try:
    from env import Environment
except:
    from marl_delivery.env import Environment
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.distributions import Categorical
import numpy as np
import random
import os
from sklearn.calibration import LabelEncoder # For action conversion
import matplotlib.pyplot as plt

# Single seed value shared by every random number generator seeded below.
SEED = 42
# Run on CUDA when torch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch (default generator and all CUDA devices), NumPy's global RNG and
# Python's `random` module with the same value.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [39]:
def convert_observation(state, persistent_packages, current_robot_idx):
    """
    Build one robot's observation as a stack of six 2D feature planes.

    Planes (index -> content):
        0  the map grid, copied as-is
        1  urgency in [0, 1] of released 'waiting' packages, written at their
           start cell (left empty while the robot is carrying something)
        2  start cells of released 'waiting' packages (left empty while the
           robot is carrying something)
        3  cells occupied by the other robots
        4  cell occupied by this robot
        5  target cell of the package this robot carries (left empty while the
           robot carries nothing)

    Args:
        state (dict): Raw environment state with keys "map", "robots" and
            "time_step". Each entry of state["robots"] is a 3-tuple
            (row + 1, col + 1, carried_package_id), i.e. 1-indexed positions;
            a carried id of 0 means "not carrying".
        persistent_packages (dict): pkg_id -> dict with keys 'status',
            'start_pos', 'target_pos', 'start_time', 'deadline'.
            Positions are 0-indexed (row, col).
        current_robot_idx (int): Index into state["robots"] of the robot the
            observation is built for.

    Returns:
        np.ndarray: float32 array of shape (6, n_rows, n_cols). When
        `current_robot_idx` is out of range only plane 0 is filled.
    """
    grid_map = np.array(state["map"])
    height, width = grid_map.shape
    planes = np.zeros((6, height, width), dtype=np.float32)

    # Plane 0: map.
    planes[0] = grid_map

    # time_step may arrive wrapped in an array; unwrap its first element.
    now = state["time_step"]
    if isinstance(now, np.ndarray):
        now = now[0]

    robots = state["robots"]
    if current_robot_idx < 0 or current_robot_idx >= len(robots):
        # Unknown robot: hand back the map-only observation.
        return planes

    def inside(row, col):
        """True when (row, col) is a valid cell of the grid."""
        return 0 <= row < height and 0 <= col < width

    me = robots[current_robot_idx]
    carried_id = me[2]  # 0 means the robot's hands are free

    # Planes 1 and 2: only meaningful for a robot that can still pick something up.
    if carried_id == 0:
        for pkg in persistent_packages.values():
            if pkg['status'] != 'waiting':
                continue
            pick_row, pick_col = pkg['start_pos']
            released_at = pkg['start_time']
            due_at = pkg['deadline']

            # Packages whose start time lies in the future are not visible yet.
            if not (now >= released_at):
                continue

            if due_at > released_at:
                # Fraction of the delivery window already used, clamped to [0, 1].
                level = min(1.0, max(0.0, (now - released_at) / (due_at - released_at)))
            elif due_at == released_at:
                # Zero-length window: already due the moment it is released.
                level = 1.0
            else:
                # Deadline before start time: treated as no urgency.
                level = 0

            if inside(pick_row, pick_col):
                # Several packages may share a cell; keep the most urgent one.
                planes[1, pick_row, pick_col] = max(planes[1, pick_row, pick_col], level)
                planes[2, pick_row, pick_col] = 1.0

    # Plane 3: every robot except the observing one (env positions are 1-indexed).
    for idx, other in enumerate(robots):
        if idx == current_robot_idx:
            continue
        other_row, other_col, _ = other
        row0, col0 = int(other_row) - 1, int(other_col) - 1
        if inside(row0, col0):
            planes[3, row0, col0] = 1.0

    # Plane 4: the observing robot itself.
    my_row, my_col, _ = me
    row0, col0 = int(my_row) - 1, int(my_col) - 1
    if inside(row0, col0):
        planes[4, row0, col0] = 1.0

    # Plane 5: where the carried package has to go. A carried id that is not
    # tracked in persistent_packages is silently ignored.
    if carried_id != 0 and carried_id in persistent_packages:
        drop_row, drop_col = persistent_packages[carried_id]['target_pos']
        if inside(drop_row, drop_col):
            planes[5, drop_row, drop_col] = 1.0

    return planes

In [40]:
def convert_state(state_dict, persistent_packages, state_tensor_shape):
    """
    Convert the global environment state into a multi-channel spatial tensor.

    All package information is taken from `persistent_packages`; a `packages`
    key in `state_dict` (if present) is ignored.

    Channel layout (a channel is written only when its index is smaller than
    the requested number of channels):
        0  map, centre-cropped / centre-padded to (n_rows, n_cols)
        1  robot positions
        2  positions of robots that are carrying a package
        3  start positions of released 'waiting' packages
        4  target positions of released 'waiting' packages
        5  target positions of 'in_transit' packages
        6  urgency in [0, 1] of released 'waiting' packages, at their start cell

    Args:
        state_dict (dict): Raw environment state. Keys read here: "map",
            "time_step" and (optionally) "robots". Robot entries are
            (row + 1, col + 1, carried_package_id) with 1-indexed positions.
            "time_step" may be a plain number or a NumPy array, in which case
            its first element is used (same convention as
            `convert_observation` and `reward_shaping`).
        persistent_packages (dict): pkg_id -> dict with keys 'start_pos',
            'target_pos', 'status' ('waiting' / 'in_transit'), 'start_time'
            and 'deadline'. Positions are 0-indexed (row, col).
        state_tensor_shape (tuple): (num_channels, n_rows, n_cols) of the output.

    Returns:
        np.ndarray: float32 array of shape `state_tensor_shape`.
    """
    num_channels_out, n_rows, n_cols = state_tensor_shape

    spatial_tensor = np.zeros((num_channels_out, n_rows, n_cols), dtype=np.float32)

    CH_IDX_MAP_OBSTACLES = 0
    CH_IDX_ROBOT_POSITIONS = 1
    CH_IDX_ROBOT_CARRYING_STATUS = 2
    CH_IDX_PKG_WAITING_START_POS = 3
    CH_IDX_PKG_WAITING_TARGET_POS = 4
    CH_IDX_PKG_IN_TRANSIT_TARGET_POS = 5
    CH_IDX_PKG_WAITING_URGENCY = 6

    def in_bounds(pos):
        """True when the 0-indexed (row, col) pair lies inside the output grid."""
        return 0 <= pos[0] < n_rows and 0 <= pos[1] < n_cols

    # --- Channel: map (centre-crop if the source is larger, centre-pad if smaller) ---
    if CH_IDX_MAP_OBSTACLES < num_channels_out:
        game_map_from_state = np.array(state_dict["map"])
        map_rows_src, map_cols_src = game_map_from_state.shape

        src_r_start = (map_rows_src - n_rows) // 2 if map_rows_src > n_rows else 0
        src_c_start = (map_cols_src - n_cols) // 2 if map_cols_src > n_cols else 0

        rows_to_copy_from_src = min(map_rows_src, n_rows)
        cols_to_copy_from_src = min(map_cols_src, n_cols)

        map_section_to_copy = game_map_from_state[
            src_r_start : src_r_start + rows_to_copy_from_src,
            src_c_start : src_c_start + cols_to_copy_from_src
        ]

        target_r_offset = (n_rows - map_section_to_copy.shape[0]) // 2
        target_c_offset = (n_cols - map_section_to_copy.shape[1]) // 2

        spatial_tensor[
            CH_IDX_MAP_OBSTACLES,
            target_r_offset : target_r_offset + map_section_to_copy.shape[0],
            target_c_offset : target_c_offset + map_section_to_copy.shape[1]
        ] = map_section_to_copy
    # NOTE(review): robot and package coordinates below are NOT shifted by the
    # crop/pad offsets above, so all channels line up only when the map shape
    # equals (n_rows, n_cols) -- confirm callers always pass matching shapes.

    # --- Current time: unwrap array-valued time steps to a scalar ---
    current_time = state_dict["time_step"]
    if isinstance(current_time, np.ndarray):
        current_time = current_time.flat[0]

    # --- Channels: robot positions and carrying status ---
    if 'robots' in state_dict and state_dict['robots'] is not None:
        for r_data in state_dict['robots']:
            # r_data: (row_1idx, col_1idx, carried_package_id)
            r_idx, c_idx = int(r_data[0]) - 1, int(r_data[1]) - 1
            carried_pkg_id = r_data[2]

            if in_bounds((r_idx, c_idx)):
                if CH_IDX_ROBOT_POSITIONS < num_channels_out:
                    spatial_tensor[CH_IDX_ROBOT_POSITIONS, r_idx, c_idx] = 1.0

                if carried_pkg_id != 0 and CH_IDX_ROBOT_CARRYING_STATUS < num_channels_out:
                    spatial_tensor[CH_IDX_ROBOT_CARRYING_STATUS, r_idx, c_idx] = 1.0

    # --- Package channels (driven entirely by persistent_packages) ---
    for pkg_data in persistent_packages.values():
        start_pos = pkg_data['start_pos']
        target_pos = pkg_data['target_pos']
        status = pkg_data['status']
        pkg_start_time = pkg_data['start_time']
        pkg_deadline = pkg_data['deadline']

        # Packages whose start time has not been reached yet are invisible.
        if not (current_time >= pkg_start_time):
            continue

        if status == 'waiting':
            if CH_IDX_PKG_WAITING_START_POS < num_channels_out and in_bounds(start_pos):
                spatial_tensor[CH_IDX_PKG_WAITING_START_POS, start_pos[0], start_pos[1]] = 1.0

            if CH_IDX_PKG_WAITING_URGENCY < num_channels_out and in_bounds(start_pos):
                urgency = 0.0
                if pkg_deadline > pkg_start_time:
                    # Fraction of the delivery window already elapsed, clamped to [0, 1].
                    urgency = min(1.0, max(0.0, (current_time - pkg_start_time) / (pkg_deadline - pkg_start_time)))
                elif pkg_deadline == pkg_start_time:
                    # Zero-length window: due as soon as the package is released.
                    urgency = 1.0
                # Several packages may share a start cell; keep the most urgent value.
                spatial_tensor[CH_IDX_PKG_WAITING_URGENCY, start_pos[0], start_pos[1]] = \
                    max(spatial_tensor[CH_IDX_PKG_WAITING_URGENCY, start_pos[0], start_pos[1]], urgency)

            if CH_IDX_PKG_WAITING_TARGET_POS < num_channels_out and in_bounds(target_pos):
                spatial_tensor[CH_IDX_PKG_WAITING_TARGET_POS, target_pos[0], target_pos[1]] = 1.0

        elif status == 'in_transit':
            if CH_IDX_PKG_IN_TRANSIT_TARGET_POS < num_channels_out and in_bounds(target_pos):
                spatial_tensor[CH_IDX_PKG_IN_TRANSIT_TARGET_POS, target_pos[0], target_pos[1]] = 1.0

    return spatial_tensor

# Hyperparameters

In [41]:
# --- MAPPO Hyperparameters ---
# Task / environment configuration.
ACTION_DIM = 15  # Total discrete actions for an agent (5 move symbols x 3 package ops, decoded in MAPPOTrainer.collect_rollouts)
NUM_AGENTS = 5
MAP_FILE = "map1.txt"
N_PACKAGES = 20
MOVE_COST = -0.01 # Adjusted for PPO, rewards should be reasonably scaled (reward_shaping defines its own local MOVE_COST with the same value)
DELIVERY_REWARD = 10
DELAY_REWARD = 1 # Or 0, depending on reward shaping strategy
MAX_TIME_STEPS_PER_EPISODE = 1000 # Max steps for one episode in one env

# Data-collection schedule.
NUM_ENVS = 1  # Number of parallel environments
ROLLOUT_STEPS = 1024 # Number of steps to collect data for before an update
TOTAL_TIMESTEPS = 1_000_000 # Total timesteps for training

# PPO specific
LR_ACTOR = 5e-5 # Adam learning rate for the actor (see MAPPOTrainer.__init__)
LR_CRITIC = 5e-5 # Adam learning rate for the critic (see MAPPOTrainer.__init__)
GAMMA = 0.99
GAE_LAMBDA = 0.95
CLIP_EPS = 0.2
NUM_EPOCHS = 10 # Number of epochs to train on collected data
MINIBATCH_SIZE = 64 # Minibatch size for PPO updates
ENTROPY_COEF = 0.01
VALUE_LOSS_COEF = 0.5
MAX_GRAD_NORM = 0.5
WEIGHT_DECAY = 1e-3 # Passed as weight_decay to both Adam optimizers

# Actor and Critic Networks for MAPPO

In [42]:
class ActorNetwork(nn.Module):
    """
    Convolutional policy network mapping a local observation to action logits.

    Architecture: three (3x3 conv, stride 1, padding 1 -> BatchNorm -> ReLU)
    stages with 32, 64 and 64 channels, followed by a 256-unit fully connected
    layer and a linear head with `action_dim` outputs.

    Args:
        obs_shape: (C, H, W) of a single observation.
        action_dim: Number of discrete actions (size of the logits vector).
    """

    def __init__(self, obs_shape, action_dim):
        super().__init__()
        in_channels, height, width = obs_shape[0], obs_shape[1], obs_shape[2]
        # kernel 3 / stride 1 / padding 1 leaves the spatial size untouched.
        same_size = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(in_channels, 32, **same_size)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, **same_size)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, **same_size)
        self.bn3 = nn.BatchNorm2d(64)

        # H and W are preserved by every conv, so the flattened feature map
        # has H * W cells times the 64 channels of the last conv.
        self.flattened_size = height * width * 64

        self.fc1 = nn.Linear(self.flattened_size, 256)
        self.actor_head = nn.Linear(256, action_dim)

    def forward(self, obs):
        """Return unnormalised action logits of shape (batch, action_dim) for obs of shape (batch, C, H, W)."""
        features = obs
        for conv, norm in (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        ):
            features = F.relu(norm(conv(features)))
        flat = features.reshape(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.actor_head(hidden)

# Critic Network for MAPPO
class CriticNetwork(nn.Module):
    """
    Convolutional value network mapping a global state to a single scalar value.

    Architecture: three (3x3 conv, stride 1, padding 1 -> BatchNorm -> ReLU)
    stages with 32, 64 and 64 channels, followed by a 256-unit fully connected
    layer and a linear head with one output.

    Args:
        global_state_shape: (C_global, H, W) of a single global state.
    """

    def __init__(self, global_state_shape):
        super().__init__()
        in_channels = global_state_shape[0]
        height, width = global_state_shape[1], global_state_shape[2]
        # kernel 3 / stride 1 / padding 1 leaves the spatial size untouched.
        same_size = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(in_channels, 32, **same_size)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, **same_size)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, **same_size)
        self.bn3 = nn.BatchNorm2d(64)

        # Flattened size of the last feature map: H * W cells x 64 channels.
        self.flattened_size = height * width * 64

        self.fc1 = nn.Linear(self.flattened_size, 256)
        self.critic_head = nn.Linear(256, 1)  # one value per state

    def forward(self, global_state):
        """Return state values of shape (batch, 1) for global_state of shape (batch, C_global, H, W)."""
        features = global_state
        for conv, norm in (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        ):
            features = F.relu(norm(conv(features)))
        flat = features.reshape(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.critic_head(hidden)

In [43]:
def save_mappo_model(actor, critic, path_prefix="models/mappo"):
    """
    Save the actor and critic state dicts to disk.

    Two files are written: "<path_prefix>_actor.pt" and
    "<path_prefix>_critic.pt". The directory part of `path_prefix` is created
    when it does not exist yet.

    Args:
        actor: Module whose `state_dict()` is stored in the actor file.
        critic: Module whose `state_dict()` is stored in the critic file.
        path_prefix (str): Path prefix shared by both files. May be a bare
            name without a directory (e.g. "mappo"), in which case the files
            go to the current working directory.
    """
    target_dir = os.path.dirname(path_prefix)
    # dirname is "" for a bare prefix; os.makedirs("") raises, so only create
    # a directory when there actually is one. exist_ok avoids a crash when the
    # directory already exists or is created concurrently.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    torch.save(actor.state_dict(), f"{path_prefix}_actor.pt")
    torch.save(critic.state_dict(), f"{path_prefix}_critic.pt")
    print(f"MAPPO models saved with prefix {path_prefix}")


In [44]:
def load_mappo_model(actor, critic, path_prefix="models/mappo"):
    """
    Restore actor and critic weights from "<path_prefix>_actor.pt" and
    "<path_prefix>_critic.pt".

    Nothing is loaded unless both files exist. Tensors are mapped onto the
    module-level `device` while loading.

    Args:
        actor: Module receiving the actor state dict.
        critic: Module receiving the critic state dict.
        path_prefix (str): Path prefix shared by both checkpoint files.

    Returns:
        bool: True when both state dicts were loaded, False when at least one
        file is missing.
    """
    checkpoint_files = {
        "actor": f"{path_prefix}_actor.pt",
        "critic": f"{path_prefix}_critic.pt",
    }

    # Guard clause: bail out before touching either network.
    if not all(os.path.exists(path) for path in checkpoint_files.values()):
        print(f"Could not find MAPPO models at prefix {path_prefix}")
        return False

    actor_weights = torch.load(checkpoint_files["actor"], map_location=device)
    actor.load_state_dict(actor_weights)
    critic_weights = torch.load(checkpoint_files["critic"], map_location=device)
    critic.load_state_dict(critic_weights)
    print(f"MAPPO models loaded from prefix {path_prefix}")
    return True

# Reward Shaping

In [45]:
def reward_shaping(
    prev_env_state,
    current_env_state,
    actions_taken,
    persistent_packages_before_action,
    num_agents
):
    """
    Compute the shaped team reward for one environment transition.

    A per-agent reward is built from the change of each robot's carried
    package id (pickup / drop events), the requested package operation and
    whether the robot moved; the per-agent rewards are then summed.

    Args:
        prev_env_state (dict): State before the step. 'robots' is required;
            'time_step' is optional and defaults to the current time step - 1.
        current_env_state (dict): State after the step, with keys 'robots'
            and 'time_step'.
        actions_taken: Per-agent actions; element [1] of each action is the
            package operation (0: none, 1: pick up, 2: drop), convertible
            with int().
        persistent_packages_before_action (dict): pkg_id -> package record
            (keys used: 'status', 'start_time', 'target_pos', 'deadline') as
            tracked BEFORE the step was applied. Positions are 0-indexed.
        num_agents (int): Number of robots to score.

    Returns:
        numpy float: Sum of the shaped rewards of all agents (a single
        scalar, not a per-agent array).
    """
    # --- Constants ---
    SHAPING_SUCCESSFUL_PICKUP_BONUS = 50
    SHAPING_SUCCESSFUL_DELIVERY_BONUS = 1000
    SHAPING_LATE_DELIVERY_PENALTY = -900
    SHAPING_WRONG_DROP_LOCATION_PENALTY = -10 # never happens
    SHAPING_WASTED_PICKUP_PENALTY = -0.01
    SHAPING_WASTED_DROP_PENALTY = 0
    SHAPING_FAILED_INTENDED_PICKUP_PENALTY = -0.01
    SHAPING_FAILED_INTENDED_DROP_PENALTY = -0.01
    SHAPING_STAY_PENALTY = -0.01
    MOVE_COST = -0.01

    # Every agent starts from the per-step cost.
    individual_rewards = np.array([MOVE_COST] * num_agents)

    current_time_from_env = current_env_state['time_step']
    if isinstance(current_time_from_env, np.ndarray):
        current_time_from_env = current_time_from_env[0]

    time_at_prev_state = prev_env_state.get('time_step', current_time_from_env - 1)
    if isinstance(time_at_prev_state, np.ndarray):
        time_at_prev_state = time_at_prev_state[0]

    # Whether any released package was waiting before the step. This does not
    # depend on the agent, so it is evaluated once instead of once per agent.
    waiting_packages_exist = any(
        pkg['status'] == 'waiting' and pkg['start_time'] <= time_at_prev_state
        for pkg in persistent_packages_before_action.values()
    )

    for i in range(num_agents):
        agent_action = actions_taken[i]
        package_op = int(agent_action[1])  # 0: None, 1: Pick, 2: Drop

        prev_robot_info = prev_env_state['robots'][i]
        current_robot_info = current_env_state['robots'][i]

        # Environment positions are 1-indexed; package positions are 0-indexed.
        robot_prev_pos_0idx = (prev_robot_info[0] - 1, prev_robot_info[1] - 1)
        robot_current_pos_0idx = (current_robot_info[0] - 1, current_robot_info[1] - 1)

        prev_carrying_id = prev_robot_info[2]
        current_carrying_id = current_robot_info[2]

        if robot_prev_pos_0idx == robot_current_pos_0idx:
            if waiting_packages_exist:
                # Idling while work is available costs extra.
                individual_rewards[i] += SHAPING_STAY_PENALTY
            else:
                # Nothing to do: cancel the step cost so that waiting is free.
                individual_rewards[i] -= SHAPING_STAY_PENALTY

        # --- Events derived from the change in carried package id ---
        if prev_carrying_id == 0 and current_carrying_id != 0:
            # The robot picked a package up during this step.
            individual_rewards[i] += SHAPING_SUCCESSFUL_PICKUP_BONUS

        elif prev_carrying_id != 0 and current_carrying_id == 0:
            # The robot dropped the package it was carrying.
            dropped_pkg_id = prev_carrying_id
            if dropped_pkg_id in persistent_packages_before_action:
                pkg_info = persistent_packages_before_action[dropped_pkg_id]
                pkg_target_pos_0idx = pkg_info['target_pos']
                pkg_deadline = pkg_info['deadline']

                if robot_current_pos_0idx == pkg_target_pos_0idx:
                    # Dropped at the package's target cell.
                    individual_rewards[i] += SHAPING_SUCCESSFUL_DELIVERY_BONUS
                    if current_time_from_env > pkg_deadline:
                        # Delivered, but after the deadline.
                        individual_rewards[i] += SHAPING_LATE_DELIVERY_PENALTY
                else:
                    # Dropped somewhere other than the target cell.
                    individual_rewards[i] += SHAPING_WRONG_DROP_LOCATION_PENALTY
            # A dropped id that is not tracked is ignored (should not happen).

        # --- Requested package ops that did not change the carrying state ---
        else:
            if package_op == 1:  # tried to pick up
                if prev_carrying_id != 0:
                    # Pick-up requested while already carrying a package.
                    individual_rewards[i] += SHAPING_WASTED_PICKUP_PENALTY
                elif prev_carrying_id == 0 and current_carrying_id == 0:
                    # Pick-up requested with free hands, but nothing was picked up.
                    individual_rewards[i] += SHAPING_FAILED_INTENDED_PICKUP_PENALTY

            elif package_op == 2:  # tried to drop
                if prev_carrying_id == 0:
                    # Drop requested while carrying nothing.
                    individual_rewards[i] += SHAPING_WASTED_DROP_PENALTY
                elif prev_carrying_id != 0 and current_carrying_id != 0:
                    # Drop requested while carrying, but the package is still held.
                    individual_rewards[i] += SHAPING_FAILED_INTENDED_DROP_PENALTY

    return individual_rewards.sum()

# Initialize

In [46]:
class VectorizedEnv:
    """
    Minimal synchronous container for several independent environment instances.

    Args:
        env_cls: Callable used to build each environment.
        num_envs (int): Number of environments to create.
        **env_kwargs: Keyword arguments forwarded to `env_cls`. When a 'seed'
            that is not None is given, environment k receives `seed + k` so
            that every instance is seeded differently.
    """

    def __init__(self, env_cls, num_envs, **env_kwargs):
        first_seed = env_kwargs.get('seed', None)

        def kwargs_for(position):
            """Copy of env_kwargs with the seed shifted by the env's position."""
            params = env_kwargs.copy()
            if first_seed is not None:
                params['seed'] = first_seed + position
            return params

        self.envs = [env_cls(**kwargs_for(position)) for position in range(num_envs)]
        self.num_envs = num_envs

    def reset(self):
        """Reset every environment and return the list of their reset results."""
        initial_states = []
        for env in self.envs:
            initial_states.append(env.reset())
        return initial_states

    def step(self, actions):
        """
        Step environment k with actions[k].

        Each env's step() result is expected to be a 4-tuple; the results are
        transposed into four parallel lists.

        Returns:
            tuple: (next_states, rewards, dones, infos), each a list with one
            entry per environment.
        """
        per_env_results = [env.step(action) for env, action in zip(self.envs, actions)]
        next_states, rewards, dones, infos = map(list, zip(*per_env_results))
        return next_states, rewards, dones, infos

    def render(self):
        """Call render_pygame() on every environment."""
        for env in self.envs:
            env.render_pygame()

In [47]:
class MAPPOTrainer:
    def __init__(self, vec_env, num_agents, action_dim, obs_shape, global_state_shape):
        """
        Set up networks, optimizers, action decoders and per-env package trackers.

        Args:
            vec_env: Vectorized environment; its `num_envs` attribute is read here.
            num_agents (int): Number of agents per environment.
            action_dim (int): Size of each agent's discrete action space.
            obs_shape (tuple): (C, H, W) of one local observation.
            global_state_shape (tuple): (C_global, H, W) of one global state.
        """
        # Environment handle and problem dimensions.
        self.vec_env = vec_env
        self.num_envs = vec_env.num_envs
        self.num_agents = num_agents
        self.action_dim = action_dim
        self.obs_shape = obs_shape
        self.global_state_shape = global_state_shape

        # Networks are built actor first, then critic; each gets its own Adam optimizer.
        self.actor = ActorNetwork(obs_shape, action_dim).to(device)
        self.critic = CriticNetwork(global_state_shape).to(device)
        self.actor_optimizer = optim.Adam(
            self.actor.parameters(), lr=LR_ACTOR, weight_decay=WEIGHT_DECAY
        )
        self.critic_optimizer = optim.Adam(
            self.critic.parameters(), lr=LR_CRITIC, weight_decay=WEIGHT_DECAY
        )

        # Encoders used to turn integer actions back into the environment's
        # (move, package_op) symbols. fit() returns the encoder itself.
        self.le_move = LabelEncoder().fit(['S', 'L', 'R', 'U', 'D'])
        self.le_pkg_op = LabelEncoder().fit(['0', '1', '2'])  # 0: None, 1: Pickup, 2: Drop
        self.NUM_MOVE_ACTIONS = len(self.le_move.classes_)
        self.NUM_PKG_OPS = len(self.le_pkg_op.classes_)

        # One independent package tracker (pkg_id -> record) per environment.
        self.persistent_packages_list = [dict() for _ in range(self.num_envs)]


    def _update_persistent_packages_for_env(self, env_idx, current_env_state_dict):
        # This is a simplified version of the DQNTrainer's method, adapted for one env
        current_persistent_packages = self.persistent_packages_list[env_idx]
        
        if 'packages' in current_env_state_dict and current_env_state_dict['packages'] is not None:
            for pkg_tuple in current_env_state_dict['packages']:
                pkg_id = pkg_tuple[0]
                if pkg_id not in current_persistent_packages:
                    current_persistent_packages[pkg_id] = {
                        'id': pkg_id,
                        'start_pos': (pkg_tuple[1] - 1, pkg_tuple[2] - 1),
                        'target_pos': (pkg_tuple[3] - 1, pkg_tuple[4] - 1),
                        'start_time': pkg_tuple[5],
                        'deadline': pkg_tuple[6],
                        'status': 'waiting'
                    }

        current_carried_pkg_ids_set = set()
        if 'robots' in current_env_state_dict and current_env_state_dict['robots'] is not None:
            for r_data in current_env_state_dict['robots']:
                carried_id = r_data[2]
                if carried_id != 0:
                    current_carried_pkg_ids_set.add(carried_id)

        packages_to_remove = []
        for pkg_id, pkg_data in list(current_persistent_packages.items()):
            if pkg_id in current_carried_pkg_ids_set:
                current_persistent_packages[pkg_id]['status'] = 'in_transit'
            else:
                if pkg_data['status'] == 'in_transit':
                    packages_to_remove.append(pkg_id)
        
        for pkg_id_to_remove in packages_to_remove:
            if pkg_id_to_remove in current_persistent_packages:
                del current_persistent_packages[pkg_id_to_remove]
        self.persistent_packages_list[env_idx] = current_persistent_packages


    def _get_actions_and_values(self, current_local_obs_b_a_c_h_w, current_global_states_b_c_h_w):
        """
        Sample one action per agent and evaluate the critic on the global states.

        Args:
            current_local_obs_b_a_c_h_w: Local observations; reshaped here to
                (num_envs * num_agents, C, H, W) before being fed to the actor.
            current_global_states_b_c_h_w: Global states passed to the critic
                unchanged (apart from the device move).

        Returns:
            tuple: (actions, log_probs, values) where actions and log_probs are
            reshaped to (num_envs, num_agents) and values is the critic output
            with its last dimension squeezed.
        """
        # Both inputs must live on the same device as the networks.
        local_obs = current_local_obs_b_a_c_h_w.to(device)
        global_states = current_global_states_b_c_h_w.to(device)

        n_envs, n_agents = self.num_envs, self.num_agents
        channels, height, width = self.obs_shape[0], self.obs_shape[1], self.obs_shape[2]

        # The actor is shared by all agents: fold the env and agent axes into one batch axis.
        flat_obs = local_obs.reshape(n_envs * n_agents, channels, height, width)
        policy = Categorical(logits=self.actor(flat_obs))
        sampled_actions = policy.sample()
        sampled_log_probs = policy.log_prob(sampled_actions)

        state_values = self.critic(global_states)

        return (
            sampled_actions.reshape(n_envs, n_agents),
            sampled_log_probs.reshape(n_envs, n_agents),
            state_values.squeeze(-1),
        )

    def collect_rollouts(self, current_env_states_list, current_local_obs_list, current_global_states_list):
        # Buffers to store trajectory data
        mb_obs = torch.zeros((ROLLOUT_STEPS, self.num_envs, self.num_agents, *self.obs_shape), device=device)
        mb_global_states = torch.zeros((ROLLOUT_STEPS, self.num_envs, *self.global_state_shape), device=device)
        mb_actions = torch.zeros((ROLLOUT_STEPS, self.num_envs, self.num_agents), dtype=torch.long, device=device)
        mb_log_probs = torch.zeros((ROLLOUT_STEPS, self.num_envs, self.num_agents), device=device)
        mb_rewards = torch.zeros((ROLLOUT_STEPS, self.num_envs), device=device)
        mb_dones = torch.zeros((ROLLOUT_STEPS, self.num_envs), dtype=torch.bool, device=device)
        mb_values = torch.zeros((ROLLOUT_STEPS, self.num_envs), device=device)

        # Move initial obs/states to device
        current_local_obs_list = current_local_obs_list.to(device)
        current_global_states_list = current_global_states_list.to(device)

        for step in range(ROLLOUT_STEPS):
            # Render the environment
            self.vec_env.render()
            
            mb_obs[step] = current_local_obs_list
            mb_global_states[step] = current_global_states_list

            with torch.no_grad():
                actions_int_ne_na, log_probs_ne_na, values_ne = self._get_actions_and_values(
                    current_local_obs_list, 
                    current_global_states_list
                )
            
            mb_actions[step] = actions_int_ne_na
            mb_log_probs[step] = log_probs_ne_na
            mb_values[step] = values_ne

            # Convert integer actions to environment compatible actions
            env_actions_batch = []
            for env_idx in range(self.num_envs):
                env_agent_actions = []
                for agent_idx in range(self.num_agents):
                    int_act = actions_int_ne_na[env_idx, agent_idx].item()
                    move_idx = int_act % self.NUM_MOVE_ACTIONS
                    pkg_op_idx = int_act // self.NUM_MOVE_ACTIONS
                    if pkg_op_idx >= self.NUM_PKG_OPS: pkg_op_idx = 0 # Safety clamp
                    
                    move_str = self.le_move.inverse_transform([move_idx])[0]
                    pkg_op_str = self.le_pkg_op.inverse_transform([pkg_op_idx])[0]
                    env_agent_actions.append((move_str, pkg_op_str))
                env_actions_batch.append(env_agent_actions)

            next_env_states_list, global_rewards_ne, dones_ne, _ = self.vec_env.step(env_actions_batch)
            
            # use reward shaping here
            reshaped_global_rewards_ne = [reward_shaping(current_env_states_list[env_idx], 
                                                        next_env_states_list[env_idx], 
                                                        env_actions_batch[env_idx], 
                                                        self.persistent_packages_list[env_idx],
                                                        self.num_agents) for env_idx in range(self.num_envs)]
            
            mb_rewards[step] = torch.tensor(reshaped_global_rewards_ne, dtype=torch.float32, device=device)
            mb_dones[step] = torch.tensor(dones_ne, dtype=torch.bool, device=device)

            # Prepare next observations and states
            next_local_obs_list = torch.zeros_like(current_local_obs_list, device=device)
            next_global_states_list = torch.zeros_like(current_global_states_list, device=device)

            for env_idx in range(self.num_envs):
                if dones_ne[env_idx]:
                    # --- Reset environment if done ---
                    reset_state = self.vec_env.envs[env_idx].reset()
                    self._update_persistent_packages_for_env(env_idx, reset_state)
                    current_persistent_packages = self.persistent_packages_list[env_idx]
                    next_env_states_list[env_idx] = reset_state
                    # Update global state and local obs after reset
                    next_global_states_list[env_idx] = torch.from_numpy(convert_state(
                        reset_state, 
                        current_persistent_packages, 
                        self.global_state_shape
                    )).to(device)
                    for agent_idx in range(self.num_agents):
                        next_local_obs_list[env_idx, agent_idx] = torch.from_numpy(
                            convert_observation(reset_state, current_persistent_packages, agent_idx)
                        ).float().to(device)
                else:
                    self._update_persistent_packages_for_env(env_idx, next_env_states_list[env_idx])
                    current_persistent_packages = self.persistent_packages_list[env_idx]
                    next_global_states_list[env_idx] = torch.from_numpy(convert_state(
                        next_env_states_list[env_idx], 
                        current_persistent_packages, 
                        self.global_state_shape
                    )).to(device)
                    for agent_idx in range(self.num_agents):
                        next_local_obs_list[env_idx, agent_idx] = torch.from_numpy(
                            convert_observation(next_env_states_list[env_idx], current_persistent_packages, agent_idx)
                        ).float().to(device)
            
            current_env_states_list = next_env_states_list
            current_local_obs_list = next_local_obs_list
            current_global_states_list = next_global_states_list
        
        # Calculate advantages using GAE
        advantages = torch.zeros_like(mb_rewards, device=device)
        last_gae_lambda = 0
        with torch.no_grad():
            # Get value of the last state in the rollout
            next_value_ne = self.critic(current_global_states_list).squeeze(-1) # (NUM_ENVS)

        for t in reversed(range(ROLLOUT_STEPS)):
            next_non_terminal = 1.0 - mb_dones[t].float()
            next_values_step = next_value_ne if t == ROLLOUT_STEPS - 1 else mb_values[t+1]
            
            delta = mb_rewards[t] + GAMMA * next_values_step * next_non_terminal - mb_values[t]
            advantages[t] = last_gae_lambda = delta + GAMMA * GAE_LAMBDA * next_non_terminal * last_gae_lambda
        
        returns = advantages + mb_values

        # Flatten the batch for training
        b_obs = mb_obs.reshape(-1, *self.obs_shape)
        b_global_states = mb_global_states.reshape(ROLLOUT_STEPS * self.num_envs, *self.global_state_shape)
        b_actions = mb_actions.reshape(-1)
        b_log_probs = mb_log_probs.reshape(-1)
        
        b_advantages = advantages.reshape(ROLLOUT_STEPS * self.num_envs, 1).repeat(1, self.num_agents).reshape(-1)
        b_returns_critic = returns.reshape(-1)

        return (b_obs, b_global_states, b_actions, 
           b_log_probs, b_advantages, b_returns_critic,
           current_env_states_list, current_local_obs_list, current_global_states_list,
           mb_rewards)

    def update_ppo(self, b_obs, b_global_states, b_actions, b_log_probs_old, b_advantages, b_returns_critic):
        """
        Run the PPO optimisation phase on one flattened rollout batch.

        The actor and the critic are trained in separate minibatch loops with
        their own optimizers, for NUM_EPOCHS passes over the data. The actor
        uses the clipped surrogate objective plus an entropy bonus; the critic
        is regressed onto the supplied returns with an MSE loss.

        Args:
            b_obs: Per-agent observations; indexed with sample ids in
                [0, ROLLOUT_STEPS * num_envs * num_agents).
            b_global_states: Global states fed to the critic; indexed with
                sample ids in [0, ROLLOUT_STEPS * num_envs).
            b_actions: Actions taken, aligned with `b_obs`.
            b_log_probs_old: Log-probabilities of `b_actions` recorded at
                collection time, aligned with `b_obs`.
            b_advantages: Advantage estimates, aligned with `b_obs`.
            b_returns_critic: Value targets, aligned with `b_global_states`.

        Returns:
            tuple[float, float, float]: (actor_loss, critic_loss, entropy) of
            the last minibatch processed in the last epoch. A component is
            NaN if its loop never ran (e.g. NUM_EPOCHS == 0).
        """
        # Ensure all tensors are on the correct device
        b_obs = b_obs.to(device)
        b_global_states = b_global_states.to(device)
        b_actions = b_actions.to(device)
        b_log_probs_old = b_log_probs_old.to(device)
        b_advantages = b_advantages.to(device)
        b_returns_critic = b_returns_critic.to(device)

        num_samples_actor = ROLLOUT_STEPS * self.num_envs * self.num_agents
        num_samples_critic = ROLLOUT_STEPS * self.num_envs

        # The critic sees one sample per (step, env) while the actor sees one per
        # (step, env, agent), so its minibatch is scaled down by num_agents.
        # Clamp to 1: a zero step would make range() raise ValueError.
        if self.num_agents > 0:
            critic_minibatch_size = max(1, MINIBATCH_SIZE // self.num_agents)
        else:
            critic_minibatch_size = MINIBATCH_SIZE

        actor_batch_indices = np.arange(num_samples_actor)
        critic_batch_indices = np.arange(num_samples_critic)

        # Last-seen statistics; stay None if the corresponding loop never executes.
        actor_loss = None
        critic_loss = None
        entropy = None

        for epoch in range(NUM_EPOCHS):
            np.random.shuffle(actor_batch_indices)
            np.random.shuffle(critic_batch_indices)

            # Actor update
            for start in range(0, num_samples_actor, MINIBATCH_SIZE):
                end = start + MINIBATCH_SIZE
                mb_indices = actor_batch_indices[start:end]

                mb_obs_slice = b_obs[mb_indices]
                mb_actions_slice = b_actions[mb_indices]
                mb_log_probs_old_slice = b_log_probs_old[mb_indices]
                mb_advantages_slice = b_advantages[mb_indices]

                # Normalize advantages (optional but often helpful).
                # Skipped for a single-sample trailing minibatch: the sample std of
                # one element is NaN and would propagate into the loss.
                if mb_advantages_slice.numel() > 1:
                    mb_advantages_slice = (
                        mb_advantages_slice - mb_advantages_slice.mean()
                    ) / (mb_advantages_slice.std() + 1e-8)

                action_logits = self.actor(mb_obs_slice)
                dist = Categorical(logits=action_logits)
                new_log_probs = dist.log_prob(mb_actions_slice)
                entropy = dist.entropy().mean()

                log_ratio = new_log_probs - mb_log_probs_old_slice
                ratio = torch.exp(log_ratio)

                # Clipped surrogate objective, written as a loss (hence the negation
                # and the element-wise max instead of min).
                pg_loss1 = -mb_advantages_slice * ratio
                pg_loss2 = -mb_advantages_slice * torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS)
                actor_loss = torch.max(pg_loss1, pg_loss2).mean()

                total_actor_loss = actor_loss - ENTROPY_COEF * entropy

                self.actor_optimizer.zero_grad()
                total_actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), MAX_GRAD_NORM)
                self.actor_optimizer.step()

            # Critic update
            for start in range(0, num_samples_critic, critic_minibatch_size):
                end = start + critic_minibatch_size
                mb_indices = critic_batch_indices[start:end]

                mb_global_states_slice = b_global_states[mb_indices]
                mb_returns_critic_slice = b_returns_critic[mb_indices]

                new_values = self.critic(mb_global_states_slice).squeeze(-1)
                critic_loss = F.mse_loss(new_values, mb_returns_critic_slice)

                total_critic_loss = VALUE_LOSS_COEF * critic_loss

                self.critic_optimizer.zero_grad()
                total_critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), MAX_GRAD_NORM)
                self.critic_optimizer.step()

        # Report NaN rather than raising on an unbound name when nothing was trained.
        return tuple(
            float("nan") if stat is None else stat.item()
            for stat in (actor_loss, critic_loss, entropy)
        )

In [None]:
import pygame


print(f"Using device: {device}")

# Keyword arguments shared by every Environment constructed in this cell.
_env_kwargs = dict(
    map_file=MAP_FILE,
    n_robots=NUM_AGENTS,
    n_packages=N_PACKAGES,
    move_cost=MOVE_COST,
    delivery_reward=DELIVERY_REWARD,
    delay_reward=DELAY_REWARD,
    seed=SEED,
    max_time_steps=MAX_TIME_STEPS_PER_EPISODE,
)

# NOTE(review): VectorizedEnv is expected to derive per-env seeds from SEED
# (SEED, SEED+1, ...) — confirm against its definition.
vec_env = VectorizedEnv(Environment, num_envs=NUM_ENVS, **_env_kwargs)

# A throwaway single environment, used only to read the grid dimensions.
_temp_env = Environment(**_env_kwargs)
_grid_dims = (_temp_env.n_rows, _temp_env.n_cols)

# 6 channels per agent observation, 7 channels for the global (critic) state.
OBS_SHAPE = (6, *_grid_dims)
GLOBAL_STATE_SHAPE = (7, *_grid_dims)

print(f"Obs shape: {OBS_SHAPE}, Global state shape: {GLOBAL_STATE_SHAPE}")

trainer = MAPPOTrainer(vec_env, NUM_AGENTS, ACTION_DIM, OBS_SHAPE, GLOBAL_STATE_SHAPE)

# Load existing model if available
# load_mappo_model(trainer.actor, trainer.critic) # Uncomment to load

# One entry is appended to each history per PPO update.
episode_rewards_history, actor_loss_history = [], []
critic_loss_history, entropy_history = [], []

print("Starting MAPPO training...")

# Initial reset: one raw state dict per environment, plus CPU buffers that
# will hold the encoded observations / global states.
current_env_states_list = vec_env.reset()
current_local_obs_list = torch.zeros((NUM_ENVS, NUM_AGENTS, *OBS_SHAPE), device="cpu")
current_global_states_list = torch.zeros((NUM_ENVS, *GLOBAL_STATE_SHAPE), device="cpu")

for env_idx in range(NUM_ENVS):
    _env_state = current_env_states_list[env_idx]

    # Package tracking must be refreshed before the state is encoded.
    trainer._update_persistent_packages_for_env(env_idx, _env_state)
    current_persistent_packages = trainer.persistent_packages_list[env_idx]

    _global_state_np = convert_state(_env_state, current_persistent_packages, GLOBAL_STATE_SHAPE)
    current_global_states_list[env_idx] = torch.from_numpy(_global_state_np)

    for agent_idx in range(NUM_AGENTS):
        _agent_obs_np = convert_observation(_env_state, current_persistent_packages, agent_idx)
        current_local_obs_list[env_idx, agent_idx] = torch.from_numpy(_agent_obs_np).float()

# Each update consumes ROLLOUT_STEPS transitions from every environment.
_steps_per_update = ROLLOUT_STEPS * NUM_ENVS
num_updates = TOTAL_TIMESTEPS // _steps_per_update
total_steps_done = 0

try:
    for update_num in range(1, num_updates + 1):

        # Gather a fresh rollout; the three `current_*` values are carried over
        # so the next rollout resumes where this one stopped.
        (
            b_obs,
            b_global_states,
            b_actions,
            b_log_probs_old,
            b_advantages,
            b_returns_critic,
            current_env_states_list,
            current_local_obs_list,
            current_global_states_list,
            mb_rewards,
        ) = trainer.collect_rollouts(
            current_env_states_list, current_local_obs_list, current_global_states_list
        )

        actor_loss, critic_loss, entropy = trainer.update_ppo(
            b_obs, b_global_states, b_actions, b_log_probs_old, b_advantages, b_returns_critic
        )

        total_steps_done += ROLLOUT_STEPS * NUM_ENVS

        # Mean of the rewards stored for this rollout, used as the progress signal.
        mean_reward = mb_rewards.mean().item()
        print(f"Mean reward this rollout: {mean_reward:.4f}")

        # Record this update's statistics in their respective histories.
        for _history, _value in (
            (episode_rewards_history, mean_reward),
            (actor_loss_history, actor_loss),
            (critic_loss_history, critic_loss),
            (entropy_history, entropy),
        ):
            _history.append(_value)

        # Summary line every 10 updates.
        if update_num % 10 == 0:
            print(f"Update {update_num}/{num_updates} | Timesteps: {total_steps_done}/{TOTAL_TIMESTEPS}")
            print(f"  Actor Loss: {actor_loss:.4f} | Critic Loss: {critic_loss:.4f} | Entropy: {entropy:.4f}")

        # Checkpoint every 100 updates.
        if update_num % 100 == 0:
            print(f"Saving checkpoint at update {update_num}...")
            save_mappo_model(trainer.actor, trainer.critic, path_prefix=f"models/mappo_update{update_num}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
except Exception as e:
    # Report the failure but still fall through to the final save and plots.
    print(f"\nAn error occurred during training: {e}")
    import traceback
    traceback.print_exc()
finally:
    print("Saving final model...")
    pygame.quit()
    save_mappo_model(trainer.actor, trainer.critic, path_prefix="models/mappo_final")
    print("\nTraining loop finished or was interrupted.")

    # Plotting: one panel per tracked history, laid out in a single row.
    _panels = (
        (episode_rewards_history, 'Mean Reward per Rollout', 'Mean Reward'),
        (actor_loss_history, 'Actor Loss per Update', 'Actor Loss'),
        (critic_loss_history, 'Critic Loss per Update', 'Critic Loss'),
        (entropy_history, 'Policy Entropy per Update', 'Entropy'),
    )

    plt.figure(figsize=(24, 5))
    for _panel_idx, (_series, _title, _y_label) in enumerate(_panels, start=1):
        plt.subplot(1, 4, _panel_idx)
        plt.plot(_series)
        plt.title(_title)
        plt.xlabel('Update Number')
        plt.ylabel(_y_label)
        plt.grid(True)

    plt.tight_layout()
    plt.show()


Using device: cuda
Obs shape: (6, 10, 10), Global state shape: (7, 10, 10)
Starting MAPPO training...
Mean reward this rollout: 1.0246
Mean reward this rollout: 1.2832
Mean reward this rollout: 2.6415
Mean reward this rollout: 1.2137
Mean reward this rollout: 1.3955
Mean reward this rollout: 1.9008
Mean reward this rollout: 1.8547
Mean reward this rollout: 2.5677
Mean reward this rollout: 1.9353
Mean reward this rollout: 1.8513
Update 10/976 | Timesteps: 10240/1000000
  Actor Loss: -0.1250 | Critic Loss: 163.0067 | Entropy: 2.5576
Mean reward this rollout: 0.8160
Mean reward this rollout: 1.3811
Mean reward this rollout: 1.2000
Mean reward this rollout: 1.8940
Mean reward this rollout: 1.6370
Mean reward this rollout: 0.8175
Mean reward this rollout: 1.4640
Mean reward this rollout: 2.2389
Mean reward this rollout: 1.2959
Mean reward this rollout: 1.3548
Update 20/976 | Timesteps: 20480/1000000
  Actor Loss: -0.1439 | Critic Loss: 432.3106 | Entropy: 2.5523
Mean reward this rollout: 0.