In [1]:
pip install magent2

Collecting magent2
  Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting pettingzoo>=1.23.1 (from magent2)
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting gymnasium>=0.28.0 (from pettingzoo>=1.23.1->magent2)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.28.0->pettingzoo>=1.23.1->magent2)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pettingzoo-1.24.3-py3-none-any.whl (847 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m847.8/847.8 kB[0m [31m37.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [

In [2]:
pip install pettingzoo==1.22.0

Collecting pettingzoo==1.22.0
  Downloading PettingZoo-1.22.0-py3-none-any.whl.metadata (5.0 kB)
Downloading PettingZoo-1.22.0-py3-none-any.whl (823 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.4/823.4 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pettingzoo
  Attempting uninstall: pettingzoo
    Found existing installation: pettingzoo 1.24.3
    Uninstalling pettingzoo-1.24.3:
      Successfully uninstalled pettingzoo-1.24.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
magent2 0.3.3 requires pettingzoo>=1.23.1, but you have pettingzoo 1.22.0 which is incompatible.[0m[31m
[0mSuccessfully installed pettingzoo-1.22.0


In [3]:
import torch
import torch.nn as nn
import os
import cv2
import numpy as np
from magent2.environments import battle_v4

# Define the Q-networks (same as training)
class BlueQNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1).unsqueeze(0)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]

        self.fc_value = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 1)
        )

        self.fc_advantage = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, action_shape)
        )

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        x = self.cnn(x)
        x = x.reshape(x.size(0), -1)
        value = self.fc_value(x)
        advantage = self.fc_advantage(x)
        return value + advantage - advantage.mean(dim=1, keepdim=True)


In [4]:

class RedQNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        x = self.cnn(x)
        if len(x.shape) == 3:
            batchsize = 1
        else:
            batchsize = x.shape[0]
        x = x.reshape(batchsize, -1)
        return self.network(x)


In [6]:
# Initialize environment and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if __name__ == "__main__":
    # Initialize environment
    env = battle_v4.env(map_size=45, render_mode="rgb_array")
    vid_dir = "video"
    os.makedirs(vid_dir, exist_ok=True)
    fps = 35

    env.reset()
    frames = []

    # Load the pretrained model
    model_path = "blue_agent_dueling_ddqn_per_best.pth"  # Adjust path as needed
    sample_observation = env.observation_spaces[env.agents[0]].shape
    state_space = sample_observation  # Dynamic observation shape
    action_space = env.action_spaces[env.agents[0]].n

    blue_q_network = BlueQNetwork(state_space, action_space).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    blue_q_network.load_state_dict(checkpoint)
    blue_q_network.eval()

    red_q_network = RedQNetwork(state_space, action_space).to(device)
    red_q_network.load_state_dict(torch.load("red.pt", weights_only=True, map_location="cpu"))
    red_q_network.eval()

    for agent in env.agent_iter():
        observation, reward, termination, truncation, info = env.last()

        if termination or truncation:
            action = None  # This agent has died
        else:
            if agent.startswith("red"):
                action = env.action_space(agent).sample()
            else:
                obs = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(device)
                with torch.no_grad():
                    blue_q_values = blue_q_network(obs)
                action = int(torch.argmax(blue_q_values, dim=1).item())
        env.step(action)

        if agent == "blue_12" or agent == "blue_77" or agent == "blue_37" or agent == "blue_7":
            frames.append(env.render())

    # Save the video
    height, width, _ = frames[0].shape
    out = cv2.VideoWriter(
        os.path.join(vid_dir, f"pretrained_agents.mp4"),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height),
    )
    for frame in frames:
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
    out.release()
    print("Done recording pretrained agents")

    env.close()

Using device: cpu


  checkpoint = torch.load(model_path, map_location=device)


Done recording pretrained agents


In [7]:
from magent2.environments import battle_v4
import torch
import numpy as np

try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x, *args, **kwargs: x  # Fallback: tqdm becomes a no-op


def eval():
    max_cycles = 300
    env = battle_v4.env(map_size=45, max_cycles=max_cycles)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def random_policy(env, agent, obs):
        return env.action_space(agent).sample()

    red_q_network = RedQNetwork(state_space, action_space).to(device)
    red_q_network.load_state_dict(torch.load("red.pt", weights_only=True, map_location="cpu"))

    blue_q_network = BlueQNetwork(state_space, action_space).to(device)
    blue_q_network.load_state_dict(torch.load("blue_agent_dueling_ddqn_per_best.pth", map_location=device))


    def pretrain_policy(env, agent, obs):
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.no_grad():
            red_q_values = red_q_network(observation)
        return torch.argmax(red_q_values, dim=1).cpu().numpy()[0]

    def my_model_pretrain_policy(env, agent, obs):
        observation = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
          blue_q_values = blue_q_network(observation)
        return int(torch.argmax(blue_q_values, dim=1).item())
    def run_eval(env, red_policy, blue_policy, n_episode: int = 100):
        red_win, blue_win = [], []
        red_tot_rw, blue_tot_rw = [], []
        n_agent_each_team = len(env.env.action_spaces) // 2

        for _ in tqdm(range(n_episode)):
            env.reset()
            n_kill = {"red": 0, "blue": 0}
            red_reward, blue_reward = 0, 0

            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                agent_team = agent.split("_")[0]

                n_kill[agent_team] += (
                    reward > 4.5
                )  # This assumes default reward settups
                if agent_team == "red":
                    red_reward += reward
                else:
                    blue_reward += reward

                if termination or truncation:
                    action = None  # this agent has died
                else:
                    if agent_team == "red":
                        action = red_policy(env, agent, observation)
                    else:
                        action = blue_policy(env, agent, observation)

                env.step(action)

            who_wins = "red" if n_kill["red"] >= n_kill["blue"] + 5 else "draw"
            who_wins = "blue" if n_kill["red"] + 5 <= n_kill["blue"] else who_wins
            red_win.append(who_wins == "red")
            blue_win.append(who_wins == "blue")

            red_tot_rw.append(red_reward / n_agent_each_team)
            blue_tot_rw.append(blue_reward / n_agent_each_team)

        return {
            "winrate_red": np.mean(red_win),
            "winrate_blue": np.mean(blue_win),
            "average_rewards_red": np.mean(red_tot_rw),
            "average_rewards_blue": np.mean(blue_tot_rw),
        }

    print("=" * 20)
    print("Eval with random policy")
    print(
        run_eval(
            env=env, red_policy=random_policy, blue_policy=my_model_pretrain_policy, n_episode=30
        )
    )
    print("=" * 20)

    print("Eval with trained policy")
    print(
        run_eval(
            env=env, red_policy=pretrain_policy, blue_policy=my_model_pretrain_policy, n_episode=30
        )
    )
    print("=" * 20)


if __name__ == "__main__":
    eval()

  blue_q_network.load_state_dict(torch.load("blue_agent_dueling_ddqn_per_best.pth", map_location=device))


Eval with random policy


100%|██████████| 30/30 [02:45<00:00,  5.51s/it]


{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': -2.2230926709843666, 'average_rewards_blue': 4.429499981608702}
Eval with trained policy


100%|██████████| 30/30 [01:42<00:00,  3.41s/it]

{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': 0.9988353812085933, 'average_rewards_blue': 4.834024665303866}



