In [1]:
pip install magent2

Collecting magent2
  Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting pettingzoo>=1.23.1 (from magent2)
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting gymnasium>=0.28.0 (from pettingzoo>=1.23.1->magent2)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.28.0->pettingzoo>=1.23.1->magent2)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pettingzoo-1.24.3-py3-none-any.whl (847 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m847.8/847.8 kB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [

In [2]:
pip install pettingzoo==1.22.0

Collecting pettingzoo==1.22.0
  Downloading PettingZoo-1.22.0-py3-none-any.whl.metadata (5.0 kB)
Downloading PettingZoo-1.22.0-py3-none-any.whl (823 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.4/823.4 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pettingzoo
  Attempting uninstall: pettingzoo
    Found existing installation: pettingzoo 1.24.3
    Uninstalling pettingzoo-1.24.3:
      Successfully uninstalled pettingzoo-1.24.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
magent2 0.3.3 requires pettingzoo>=1.23.1, but you have pettingzoo 1.22.0 which is incompatible.[0m[31m
[0mSuccessfully installed pettingzoo-1.22.0


In [6]:
import torch
import torch.nn as nn
import os
import cv2
import numpy as np
from magent2.environments import battle_v4
from torch_model import Qnetwork
# Define the Q-networks (same as training)
class RedQNetwork(nn.Module):
    """Dueling Q-network used to act for the red team.

    A two-layer convolutional trunk feeds two fully connected heads: a
    scalar state-value head V(s) and a per-action advantage head A(s, a).
    They are combined as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).

    Args:
        observation_shape: observation shape in (H, W, C) order. C is read
            from the last axis; H and W are consumed by a probe forward pass
            to size the heads.
        action_shape: number of discrete actions (width of the advantage head).
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        # The trunk keeps the channel count; each unpadded 3x3 conv trims
        # 2 pixels from H and W.
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )

        # Run a random (H, W, C) tensor, converted to (1, C, H, W), through the
        # trunk to learn how many features the heads will receive.
        probe = torch.randn(observation_shape).permute(2, 0, 1).unsqueeze(0)
        n_features = self.cnn(probe).numel()

        def make_head(out_features):
            # Both heads share the same hidden layout: n_features -> 120 -> out.
            return nn.Sequential(
                nn.Linear(n_features, 120),
                nn.ReLU(),
                nn.Linear(120, out_features),
            )

        self.fc_value = make_head(1)
        self.fc_advantage = make_head(action_shape)

    def forward(self, x):
        """Map a batch of (N, H, W, C) observations to (N, action_shape) Q-values."""
        channels_first = x.permute(0, 3, 1, 2)
        features = self.cnn(channels_first)
        features = features.reshape(features.size(0), -1)
        value = self.fc_value(features)
        advantage = self.fc_advantage(features)
        # Subtracting the mean advantage makes the V/A split identifiable.
        baseline = advantage.mean(dim=1, keepdim=True)
        return value + advantage - baseline
# Initialize environment and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if __name__ == "__main__":
    # Initialize environment
    env = battle_v4.env(map_size=45, render_mode="rgb_array")
    vid_dir = "video"
    os.makedirs(vid_dir, exist_ok=True)
    fps = 35

    env.reset()
    frames = []

    # Load the pretrained model
    model_path = "red_agent_dueling_noisy_ddqn_per_final.pth"  # Adjust path as needed
    sample_observation = env.observation_spaces[env.agents[0]].shape
    state_space = sample_observation  # Dynamic observation shape
    action_space = env.action_spaces[env.agents[0]].n

    red_q_network = RedQNetwork(state_space, action_space).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    red_q_network.load_state_dict(checkpoint)
    red_q_network.eval()

    blue_q_network = BlueQNetwork(state_space, action_space).to(device)
    blue_q_network.load_state_dict(torch.load("red.pt", weights_only=True, map_location="cpu"))
    blue_q_network.eval()

    for agent in env.agent_iter():
        observation, reward, termination, truncation, info = env.last()

        if termination or truncation:
            action = None  # This agent has died
        else:
            if agent.startswith("red"):
                obs = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(device)
                with torch.no_grad():
                    red_q_values = red_q_network(obs)
                action = int(torch.argmax(red_q_values, dim=1).item())
            else:
                observation = (
                    torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(device)
                )
                with torch.no_grad():
                   blue_q_values = blue_q_network(observation)
                action = torch.argmax(blue_q_values, dim=1).numpy()[0]
        env.step(action)

        if agent == "red_12" or agent == "red_77" or agent == "red_37" or agent == "red_7":
            frames.append(env.render())

    # Save the video
    height, width, _ = frames[0].shape
    out = cv2.VideoWriter(
        os.path.join(vid_dir, f"pretrained_agents.mp4"),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height),
    )
    for frame in frames:
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
    out.release()
    print("Done recording pretrained agents")

    env.close()

Using device: cpu


  checkpoint = torch.load(model_path, map_location=device)


Done recording pretrained agents
