In [1]:
!pip install git+https://github.com/Farama-Foundation/MAgent2

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-wn48zx32
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-wn48zx32
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pettingzoo>=1.23.1 (from magent2==0.3.3)
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting gymnasium>=0.28.0 (from pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.28.0->pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading pe

# QNetwork

In [2]:
import torch
import torch.nn as nn

class QNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        if len(x.shape) == 3:
            batchsize = 1
            x = x.unsqueeze(0)
        else:
            batchsize = x.shape[0]
        x = torch.fliplr(x).permute(0,3,1,2) # flip left-right because blue agent observe identically with red agent
        x = self.cnn(x)
        x = x.reshape(batchsize, -1)
        return self.network(x)

test = QNetwork((13,13,5), 21)
test_obs = torch.rand((13,13,5))
test(test_obs)

tensor([[ 0.0414,  0.0612, -0.0647, -0.1421,  0.0311, -0.0413,  0.0350,  0.0335,
          0.0120,  0.0380,  0.1036, -0.0682,  0.0391, -0.0420, -0.1298, -0.0986,
          0.0551, -0.0406,  0.1169,  0.0486, -0.0643]],
       grad_fn=<AddmmBackward0>)

# Import libs

In [3]:
from collections import defaultdict, deque
import random
import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from magent2.environments import battle_v4

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

torch.manual_seed(42)
np.random.seed(42)

# Plot

In [4]:
def plot_metrics(episode_rewards, episode_losses, show_result=False):
    plt.figure(1)
    plt.clf()
    if show_result:
        plt.title('Result')
    else:
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Value')

    rewards_t = torch.tensor(episode_rewards, dtype=torch.float)
    losses_t = torch.tensor(episode_losses, dtype=torch.float)

    plt.plot(rewards_t.numpy(), label='Reward')
    plt.plot(losses_t.numpy(), label='Loss')

    if len(rewards_t) >= 5:
        rewards_means = rewards_t.unfold(0, 5, 1).mean(1).view(-1)
        rewards_means = torch.cat((torch.zeros(4), rewards_means))
        plt.plot(rewards_means.numpy(), label='Reward (mean)')

    if len(losses_t) >= 5:
        losses_means = losses_t.unfold(0, 5, 1).mean(1).view(-1)
        losses_means = torch.cat((torch.zeros(4), losses_means))
        plt.plot(losses_means.numpy(), label='Loss (mean)')

    plt.legend()
    plt.pause(0.001)
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

# Replay Buffer

In [5]:
class MultiAgentReplayBuffer:
    def __init__(self, capacity, observation_shape, action_shape):
        self.capacity = capacity
        self.observation_shape = observation_shape
        self.action_shape = action_shape

        # Use a defaultdict to automatically create deques for new agents
        self.buffers = defaultdict(lambda: {
            'obs': deque(maxlen=capacity),
            'action': deque(maxlen=capacity),
            'reward': deque(maxlen=capacity),
            'next_obs': deque(maxlen=capacity),
            'done': deque(maxlen=capacity),
        })

    def push(self, agent_id, obs, action, reward, next_obs, done):
        self.buffers[agent_id]['obs'].append(obs)
        self.buffers[agent_id]['action'].append(action)
        self.buffers[agent_id]['reward'].append(reward)
        self.buffers[agent_id]['next_obs'].append(next_obs)
        self.buffers[agent_id]['done'].append(done)

    def sample(self, batch_size):
        all_agent_ids = list(self.buffers.keys())
        if not all_agent_ids:
            return None  # No agents in the buffer

        # Check if we have enough data to sample
        total_transitions = sum(len(self.buffers[agent_id]['obs']) for agent_id in all_agent_ids)
        if total_transitions < batch_size:
            return None

        # Collect transitions from all agents into a single list
        all_transitions = []
        for agent_id in all_agent_ids:
            agent_buffer = self.buffers[agent_id]
            for i in range(len(agent_buffer['obs'])):
                all_transitions.append({
                    'obs': agent_buffer['obs'][i],
                    'action': agent_buffer['action'][i],
                    'reward': agent_buffer['reward'][i],
                    'next_obs': agent_buffer['next_obs'][i],
                    'done': agent_buffer['done'][i]
                })

        # Sample indices from the combined transitions
        indices = np.random.choice(len(all_transitions), batch_size, replace=False)

        # Extract the sampled transitions
        obs_batch = np.array([all_transitions[i]['obs'] for i in indices])
        action_batch = np.array([all_transitions[i]['action'] for i in indices])
        reward_batch = np.array([all_transitions[i]['reward'] for i in indices])
        next_obs_batch = np.array([all_transitions[i]['next_obs'] for i in indices])
        done_batch = np.array([all_transitions[i]['done'] for i in indices])

        return {
            'obs': obs_batch,
            'action': action_batch,
            'reward': reward_batch,
            'next_obs': next_obs_batch,
            'done': done_batch
        }

    def update_last_reward(self, agent_id, new_reward):
        if agent_id not in self.buffers:
            return
        self.buffers[agent_id]['reward'][-1] = new_reward

    def __len__(self):
        return sum(len(self.buffers[agent_id]['obs']) for agent_id in self.buffers)

    def clear(self, agent_id=None):
        if agent_id:
            self.buffers[agent_id]['obs'].clear()
            self.buffers[agent_id]['action'].clear()
            self.buffers[agent_id]['reward'].clear()
            self.buffers[agent_id]['next_obs'].clear()
            self.buffers[agent_id]['done'].clear()
        else:
            for agent_id in self.buffers:
                self.clear(agent_id)

# Initialize

In [7]:
env = battle_v4.env(map_size=45, minimap_mode=False, step_reward=0.01,
                        dead_penalty=-2, attack_penalty=-0.1, attack_opponent_reward=2,
                        max_cycles=300, extra_features=False, render_mode="rgb_array")

BATCH_SIZE = 128
GAMMA = 0.9
EPS_START = 1
EPS_END = 0.1
EPS_DECAY = 50
TAU = 0.005
LR = 1e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

observation_shape = env.observation_space("blue_0").shape
action_shape = env.action_space("blue_0").n

# Initialize networks
policy_net = QNetwork(observation_shape, action_shape).to(device)
red_policy_net = QNetwork(observation_shape, action_shape).to(device) # for self-play
target_net = QNetwork(observation_shape, action_shape).to(device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)

# Load pretrained
pretrained_net = QNetwork(observation_shape, action_shape).to(device)
pretrained_net.load_state_dict(torch.load("models/red.pt", map_location=device, weights_only=True))

try:
    checkpoint = torch.load("models/blue.pt", map_location=device, weights_only=True)
    policy_net.load_state_dict(checkpoint["policy_net_state_dict"])
    target_net.load_state_dict(checkpoint["target_net_state_dict"])
    red_policy_net.load_state_dict(checkpoint["policy_net_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    episode = checkpoint["episode"]
    print(f"Start with episode: {episode}")
except Exception as e:
    print(f"No model found!")
    episode = 0

buffer = MultiAgentReplayBuffer(10000, observation_shape, action_shape)
steps_done = episode
episode_rewards = []
episode_losses = []
running_loss = 0.0
num_episodes = 60

cuda
No model found!


## Save model

In [8]:
def save_model(i_episode, policy_net, target_net, optimizer, episode_rewards, episode_losses, path):
    torch.save({
        'episode': i_episode,
        'policy_net_state_dict': policy_net.state_dict(),
        'target_net_state_dict': target_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'episode_rewards': episode_rewards,
        'episode_losses': episode_losses,
    }, path)

## Greedy Policy

In [9]:
def linear_epsilon(steps_done):
    return max(EPS_END, EPS_START - (EPS_START - EPS_END) * (steps_done / EPS_DECAY))

def policy(observation, q_network):
    global steps_done
    sample = random.random()
    if sample < linear_epsilon(steps_done):
        return env.action_space("red_0").sample()
    else:
        observation = (
            torch.Tensor(observation).to(device)
        )
        with torch.no_grad():
            q_values = q_network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

# Optimize

In [10]:
def optimize_model():
    global running_loss

    batch = buffer.sample(BATCH_SIZE)

    # Handle cases where the buffer doesn't have enough samples yet
    if batch is None:
        return

    # Unpack the batch
    state_batch = torch.from_numpy(batch['obs']).float().to(device)
    action_batch = torch.from_numpy(batch['action']).long().to(device)
    reward_batch = torch.from_numpy(batch['reward']).float().to(device)
    next_state_batch = torch.from_numpy(batch['next_obs']).float().to(device)
    done_batch = torch.from_numpy(batch['done']).float().to(device)

    # Reshape action_batch to (BATCH_SIZE, 1) for gather()
    action_batch = action_batch.unsqueeze(1)
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    non_final_mask = (done_batch == 0).squeeze()  # Create a mask for non-terminal states

    # Only compute for non-terminal states
    if non_final_mask.any():
        next_state_values[non_final_mask] = target_net(next_state_batch[non_final_mask]).max(1).values.detach()

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    running_loss += loss.item()

    return loss.item()

# Training loop

In [None]:
for i_episode in range(episode, num_episodes):
    env.reset()
    episode_reward = 0
    running_loss = 0.0
    steps_done += 1

    for agent in env.agent_iter():

        observation, reward, termination, truncation, info = env.last()
        done = termination or truncation
        episode_reward += reward

        if done:
            action = None  # Agent is dead
            env.step(action)
        else:
            agent_handle = agent.split("_")
            agent_id = agent_handle[1]
            agent_team = agent_handle[0]
            if agent_team == "blue":

                buffer.update_last_reward(agent_id, reward) # update reward of last agent's action (bad environment!)

                action = policy(observation, policy_net)
                env.step(action)

                try:
                    next_observation = env.observe(agent)
                    agent_done = False
                except:
                    next_observation = None
                    agent_done = True

                reward = 0 # Wait for next time to be selected to get reward

                # Store the transition in buffer
                buffer.push(agent_id, observation, action, reward, next_observation, agent_done)

                # Perform one step of the optimization (on the policy network)
                optimize_model()

                # Soft update of the target network's weights
                # θ′ ← τ θ + (1 −τ )θ′
                target_net_state_dict = target_net.state_dict()
                policy_net_state_dict = policy_net.state_dict()
                for key in policy_net_state_dict:
                    target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
                target_net.load_state_dict(target_net_state_dict)

            else:
                # red agent
                action = policy(observation, red_policy_net)
                env.step(action)

        # Periodically update the red agent's policy with the blue agent's learned policy
        if i_episode % 4 == 0 and i_episode < 24:
            # Copy all weights and biases from the blue agent's policy network to the red agent's
            red_policy_net.load_state_dict(policy_net.state_dict())
        elif i_episode == 24: # more complex (pretrained) opponent
            red_policy_net.load_state_dict(pretrained_net.state_dict())



    # Add these lines at the end of each episode
    episode_rewards.append(episode_reward)
    episode_losses.append(running_loss)

    print(f'Episode {i_episode + 1}/{num_episodes}')
    print(f'Total Reward of previous episode: {episode_reward:.2f}')
    print(f'Average Loss: {running_loss:.4f}')
    print(f'Epsilon: {linear_epsilon(steps_done)}')
    print('-' * 40)
    save_model(i_episode, policy_net, target_net, optimizer, episode_rewards, episode_losses, path=f"models/blue_{i_episode}.pt")

plot_metrics(episode_rewards, episode_losses, show_result=True)
plt.ioff()
plt.show()

Episode 1/60
Total Reward of previous episode: -306.10
Average Loss: 15.0050
Epsilon: 0.982
----------------------------------------
Episode 2/60
Total Reward of previous episode: -389.30
Average Loss: 29.2992
Epsilon: 0.964
----------------------------------------
Episode 3/60
Total Reward of previous episode: -256.00
Average Loss: 43.3634
Epsilon: 0.946
----------------------------------------
Episode 4/60
Total Reward of previous episode: -334.86
Average Loss: 54.0438
Epsilon: 0.9279999999999999
----------------------------------------
Episode 5/60
Total Reward of previous episode: -278.60
Average Loss: 56.2159
Epsilon: 0.91
----------------------------------------
Episode 6/60
Total Reward of previous episode: -263.70
Average Loss: 56.8774
Epsilon: 0.892
----------------------------------------
Episode 7/60
Total Reward of previous episode: -254.90
Average Loss: 57.6357
Epsilon: 0.874
----------------------------------------
Episode 8/60
Total Reward of previous episode: -259.40
Av