In [None]:
from magent2.environments import battle_v4
import os
import cv2
import torch
# Import the Q-network architecture for each team (weights are loaded in a later cell)
from torch_model import QNetwork as RedQNetwork
from torch_model_modified import QNetwork as BlueQNetwork

In [None]:
# Environment, video-output, and network setup
env = battle_v4.env(map_size=45, render_mode="rgb_array")

# Recorded videos are written under this directory at this frame rate.
vid_dir = "video"
os.makedirs(vid_dir, exist_ok=True)
fps = 35


def _build_q_network(network_cls, agent_name):
    """Instantiate ``network_cls`` from ``agent_name``'s observation shape and action count."""
    obs_shape = env.observation_space(agent_name).shape
    n_actions = env.action_space(agent_name).n
    return network_cls(obs_shape, n_actions)


# One Q-network per team, sized from the spaces of that team's first agent.
red_network = _build_q_network(RedQNetwork, "red_0")
blue_network = _build_q_network(BlueQNetwork, "blue_0")

In [None]:
# Restore the saved parameters of each team's network, mapping tensors to CPU.
red_state_dict = torch.load("red.pt", weights_only=True, map_location="cpu")
red_network.load_state_dict(red_state_dict)

blue_state_dict = torch.load("blue.pth", weights_only=True, map_location="cpu")
# Left as the final expression so the notebook displays the key-matching result.
blue_network.load_state_dict(blue_state_dict)

<All keys matched successfully>

In [8]:
# Battle simulation
# Roll out one battle in which every living agent acts greedily with respect to
# its team's Q-network, capturing one rendered frame per round of turns.

# Inference only: put both networks in evaluation mode so that any
# train-time-only layers (dropout, batch-norm, ...) behave deterministically.
red_network.eval()
blue_network.eval()

frames = []
env.reset()
frames.append(env.render())  # initial board state

# Agents that have already taken their turn in the current round.
# Frame capture is keyed on round boundaries rather than on one fixed agent
# (previously "red_0"), so recording continues after that agent is eliminated.
# Assumes each agent is yielded at most once per round -- TODO confirm for
# other PettingZoo environments if this cell is reused.
agents_seen_this_round = set()

for agent in env.agent_iter():
    # Seeing an agent for the second time means a new round has begun:
    # snapshot the board once, then start tracking the new round.
    if agent in agents_seen_this_round:
        frames.append(env.render())
        agents_seen_this_round.clear()
    agents_seen_this_round.add(agent)

    observation, reward, termination, truncation, info = env.last()

    if termination or truncation:
        action = None  # agent is finished; it must be stepped with None
    else:
        # Agent names look like "<team>_<index>"; any non-red agent is
        # driven by the blue network.
        team = agent.split("_")[0]
        network = red_network if team == "red" else blue_network

        # Move the last observation axis to the front and add a batch axis
        # before feeding the network.
        obs_tensor = (
            torch.tensor(observation, dtype=torch.float32)
            .permute(2, 0, 1)
            .unsqueeze(0)
        )

        with torch.no_grad():
            q_values = network(obs_tensor)
        action = torch.argmax(q_values, dim=1).item()  # greedy action

    env.step(action)

# Save video
# Encode the recorded RGB frames into an MP4 file under `vid_dir`, then close
# the environment (also on failure, so its resources are not leaked).
video_path = os.path.join(vid_dir, "red_vs_blue_battle.mp4")

try:
    if not frames:
        # Without this guard, frames[0] would fail with an opaque IndexError.
        raise RuntimeError("No frames were recorded; there is nothing to save.")

    height, width, _ = frames[0].shape
    out = cv2.VideoWriter(
        video_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height),  # OpenCV expects (width, height)
    )
    try:
        # VideoWriter does not raise when the codec or path is unusable;
        # it must be checked explicitly or the frames are silently dropped.
        if not out.isOpened():
            raise RuntimeError(f"Could not open video writer for {video_path!r}")

        for frame in frames:
            # Frames are converted from RGB to the BGR order OpenCV writes.
            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            out.write(frame_bgr)
    finally:
        out.release()

    print("Done recording battle between pretrained agents")
finally:
    env.close()

Done recording battle between pretrained agents
