In [11]:
!pip install git+https://github.com/Farama-Foundation/MAgent2

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-bhtctz4k
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-bhtctz4k
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pettingzoo>=1.23.1 (from magent2==0.3.3)
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting gymnasium>=0.28.0 (from pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.28.0->pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading pe

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from collections import defaultdict
import random
import os
from tqdm import tqdm

SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [13]:
def one_hot_encode(agent_id, num_classes=81, device="cpu"):
  # Ensure agent_id is a tensor on the correct device
  if not isinstance(agent_id, torch.Tensor):
      agent_id = torch.tensor(agent_id, device=device)
  else:
      agent_id = agent_id.to(device)

  return F.one_hot(agent_id.long(), num_classes=num_classes).float()

# Agent Network

In [14]:
# Define the agent network with CNN and agent ID input
class AgentNetwork(nn.Module):
    def __init__(self, observation_shape, action_dim, n_agents):
        super(AgentNetwork, self).__init__()
        # observation_shape is (H, W, C)
        self.conv1 = nn.Conv2d(observation_shape[2], 16, kernel_size=3, stride=1, padding=1) # (16, H, W)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) # (32, H, W)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) # (32, H, W)

        # Calculate the flattened size after convolutions
        flat_size = 32 * observation_shape[0] * observation_shape[1] # 32 * H * W

        # Add a linear layer to process the agent ID
        self.fc_agent_id = nn.Linear(n_agents, 32)

        self.fc1 = nn.Linear(flat_size + 32, 128)  # Concatenate conv output with agent ID embedding
        self.fc2 = nn.Linear(128, action_dim)

    def forward(self, obs, agent_id):
        agent_id = one_hot_encode(agent_id, device=agent_id.device)
        # Add a batch dimension
        if len(obs.shape) == 3:
            obs = obs.unsqueeze(0)
        x = torch.fliplr(obs).permute(0,3,1,2) # flip left-right because blue agent observe identically with red agent
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.flatten(start_dim=1)  # Flatten all dimensions except batch

        # Process agent ID
        agent_id_embedding = F.relu(self.fc_agent_id(agent_id))

        # Concatenate the flattened convolutional output with the agent ID embedding
        x = torch.cat((x, agent_id_embedding), dim=1)

        x = F.relu(self.fc1(x))
        q_values = self.fc2(x)
        return q_values

# n_agents = 81  # Example: 5 agents
# agent_id = torch.randint(0, n_agents, (2,)).to(device)  # Example: Agent with ID 2 (0-indexed)
# print(agent_id)
# obs = torch.randn(2, 13, 13, 5).to(device)
# test = AgentNetwork((13,13,5), 21, n_agents).to(device)
# test(obs, agent_id).shape

# Hyper Network

In [15]:
# Define the HyperNetwork with CNN for dynamic weight generation
class HyperNetwork(nn.Module):
    def __init__(self, input_shape, output_dim, hidden_dim):
        super().__init__()
        # CNN layers
        # Input shape (H, W, C)
        self.conv1 = nn.Conv2d(input_shape[2], 32, kernel_size=3, padding=1) #(B, 32, H, W)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1) # (B, 16, H, W)

        # FC layers
        flat_size = 16 * input_shape[0] * input_shape[1]
        self.fc1 = nn.Linear(flat_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Initialize weights
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, state):
        # Add batch dimension if not present
        if len(state.shape) == 3:
            state = state.unsqueeze(0)  # (1, H, W, C)

        # Convert (B, H, W, C) to (B, C, H, W) for CNN
        x = state.permute(0, 3, 1, 2)

        # Apply convolutions
        x = F.relu(self.conv1(x))  # (B, 32, H, W)
        x = F.relu(self.conv2(x))  # (B, 16, H, W)

        # Flatten all dimensions except batch
        x = x.flatten(start_dim=1)  # (B, 16*H*W)

        # Apply fully connected layers
        x = F.relu(self.fc1(x))  # (B, hidden_dim)
        weights = self.fc2(x)  # (B, output_dim)

        return weights


# state = torch.randn(45, 45, 5)
# test_hyper = HyperNetwork(state.shape, 1, 64)
# state_batch = torch.randn(5, 45, 45, 5)
# test_hyper(state_batch).shape

# Mixing Network

In [16]:
# Define the Mixing Network
class MixingNetwork(nn.Module):
    def __init__(self, state_dim, num_agents, mixing_dim):
        super(MixingNetwork, self).__init__()
        self.num_agents = num_agents
        self.mixing_dim = mixing_dim

        # Hypernetworks for weights and biases
        self.hyper_w1 = HyperNetwork(state_dim, num_agents * mixing_dim, 64)
        self.hyper_b1 = HyperNetwork(state_dim, mixing_dim, 64)
        self.hyper_w2 = HyperNetwork(state_dim, mixing_dim, 64)
        self.hyper_b2 = HyperNetwork(state_dim, 1, 64)

    def forward(self, agent_qs, states):
        # Add batch dimension if not present
        if len(agent_qs.shape) == 1:
            agent_qs = agent_qs.unsqueeze(0)  # (1, num_agents, action_dim)
        if len(states.shape) == 3:
            states = states.unsqueeze(0)  # (1, H, W, C)

        batch_size = agent_qs.size(0)

        agent_qs = agent_qs.view(batch_size, 1, self.num_agents)  # (batch_size, 1, num_agents)

        # First layer weights and biases
        w1 = torch.abs(self.hyper_w1(states))
        w1 = w1.view(batch_size, self.num_agents, self.mixing_dim)  # (batch_size, num_agents, mixing_dim)
        b1 = self.hyper_b1(states) # (batch_size, mixing_dim)
        b1 = b1.view(batch_size, 1, self.mixing_dim)  # (batch_size, 1, mixing_dim)


        # Compute first layer output
        hidden = F.elu(torch.bmm(agent_qs, w1) + b1)  # (batch_size, 1, mixing_dim)
        # Second layer weights and biases
        w2 = torch.abs(self.hyper_w2(states))
        w2 = w2.view(batch_size, self.mixing_dim, 1)  # (batch_size, mixing_dim, 1)
        b2 = self.hyper_b2(states)
        b2 = b2.view(batch_size, 1, 1)  # (batch_size, 1, 1)

        # Compute final output
        q_tot = torch.bmm(hidden, w2) + b2  # (batch_size, 1, 1)
        # Remove unnecessary dimensions
        q_tot = q_tot.squeeze(-1)  # (batch_size, 1)

        # If input was single sample, remove batch dimension
        if len(agent_qs.shape) == 2:
            q_tot = q_tot.squeeze(0)  # (1)

        return q_tot

# agent_qs = torch.randn(5, 81)
# test_mix_net = MixingNetwork(state.shape, 81, 2)
# test_mix_net(agent_qs, state_batch)

# Prioritized Multi-Agent Replay Buffer

In [93]:
class SumTree:
    """
    A SumTree data structure for storing replay priorities.
    Based on the implementation found in the paper "Prioritized Experience Replay" (Schaul et al., 2015).
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0
        self.size = 0

    def add(self, priority, data):
        tree_index = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_index, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.size < self.capacity:
            self.size += 1

    def update(self, tree_index, priority):
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        while tree_index != 0:
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def get_leaf(self, v):
        parent_index = 0
        while True:
            left_child_index = 2 * parent_index + 1
            right_child_index = left_child_index + 1
            if left_child_index >= len(self.tree):
                leaf_index = parent_index
                break
            else:
                if v <= self.tree[left_child_index]:
                    parent_index = left_child_index
                else:
                    v -= self.tree[left_child_index]
                    parent_index = right_child_index
        data_index = leaf_index - self.capacity + 1
        return leaf_index, self.tree[leaf_index], self.data[data_index]

    def total_priority(self):
        return self.tree[0]

class PrioritizedMultiAgentReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta_start=0.4, beta_frames=100000):
        self.capacity = capacity
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1  # For beta calculation
        self.beta = beta_start

        self.sum_tree = SumTree(capacity)
        self.max_priority = 1.0
        self.epsilon = 1e-6

        # Temporary storage for the current turn's experience
        self.temp_experience = defaultdict(lambda: {'obs': [], 'action': [], 'reward': [], 'next_obs': [],
                                                    'done': []})
        self.global_state = None
        self.next_global_state = None
        self.turn_ended = False

    def start_of_turn(self, global_state):
        """
        Signal the start of a new turn and store the initial global state.
        """
        self.global_state = global_state
        self.next_global_state = None
        self.turn_ended = False
        self.temp_experience.clear()  # Clear any leftover data from a previous turn

    def store_experience(self, agent_id, obs, action, reward, next_obs, done):
        """
        Temporarily store experiences for a single agent during a turn.
        """
        if self.turn_ended:
            raise ValueError("Cannot store experience after the turn has ended. Call start_of_turn() first.")

        self.temp_experience[agent_id]['obs'].append(obs)
        self.temp_experience[agent_id]['action'].append(action)
        self.temp_experience[agent_id]['reward'].append(reward)
        self.temp_experience[agent_id]['next_obs'].append(next_obs)
        self.temp_experience[agent_id]['done'].append(done)

    def end_of_turn(self, next_global_state):
        """
        Signal the end of the turn, store the final global state, and push the complete experience to the buffer.
        """
        if self.turn_ended:
            raise ValueError("Turn has already ended. Call start_of_turn() to begin a new turn.")

        self.next_global_state = next_global_state
        self.turn_ended = True

        if not self.temp_experience:
            return

        # Combine all agent experiences into a single dictionary
        full_turn_experience = {
            'agents': {
                agent_id: {
                    'obs': np.array(data['obs']),
                    'action': np.array(data['action']),
                    'reward': np.array(data['reward']),
                    'next_obs': np.array(data['next_obs']),
                    'done': np.array(data['done']),
                }
                for agent_id, data in self.temp_experience.items()
            },
            'global_state': self.global_state,
            'next_global_state': self.next_global_state
        }

        # Add full turn experience to the buffer
        self.sum_tree.add(self.max_priority, full_turn_experience)

        # Reset temporary storage (global states are reset in start_of_turn)
        #self.temp_experience.clear()

    def update_last_reward(self, agent_id, new_reward):
        """
        Update the reward of the last experience for the given agent.
        """
        if agent_id not in self.temp_experience or len(self.temp_experience[agent_id]['reward']) == 0:
            return

        # Get the index of the last experience of this agent
        last_idx = len(self.temp_experience[agent_id]['reward']) - 1
        # Update the last reward for the agent
        self.temp_experience[agent_id]['reward'][last_idx] = new_reward

    def sample(self, batch_size):
        """
        Sample a batch of full-turn experiences.
        """
        if self.sum_tree.size < batch_size:
            return None, None, None

        segment = self.sum_tree.total_priority() / batch_size
        priorities = []
        batch = []
        indices = []

        self.beta = min(1.0, self.beta_start + self.frame * (1.0 - self.beta_start) / self.beta_frames)
        self.frame += 1

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            v = random.uniform(a, b)
            index, priority, data = self.sum_tree.get_leaf(v)

            if priority == 0 or data is None:
                continue

            priorities.append(priority)
            batch.append(data)
            indices.append(index)

        if len(batch) < batch_size:
            return None, None, None

        sampling_probabilities = np.array(priorities) / self.sum_tree.total_priority()
        is_weights = np.power(self.sum_tree.size * sampling_probabilities, -self.beta)
        is_weights /= is_weights.max()

        return batch, indices, is_weights

    def update_priorities(self, indices, td_errors):
        """
        Update the priorities of sampled full-turn experiences.
        """
        for index, td_error in zip(indices, td_errors):
            priority = (abs(td_error) + self.epsilon) ** self.alpha
            self.sum_tree.update(index, priority)
            self.max_priority = max(self.max_priority, priority)

    def sample_experience(self, batch_size, num_agents=81, device="cpu"):
        batch, indices, is_weights = self.sample(batch_size)

        # Initialize arrays with correct dimensions
        obs = np.zeros((batch_size, num_agents, 13, 13, 5), dtype=np.float32)
        next_obs = np.zeros((batch_size, num_agents, 13, 13, 5), dtype=np.float32)
        action = np.zeros((batch_size, num_agents), dtype=np.int64)
        reward = np.zeros((batch_size, num_agents), dtype=np.float32)
        done = np.zeros((batch_size, num_agents), dtype=np.float32)
        global_state = np.zeros((batch_size, *batch[0]['global_state'].shape), dtype=np.float32)
        next_global_state = np.zeros((batch_size, *batch[0]['next_global_state'].shape), dtype=np.float32)

        # Default values for dead agents
        default_obs = np.zeros((13, 13, 5), dtype=np.float32)
        default_next_obs = np.zeros((13, 13, 5), dtype=np.float32)
        default_action = 6
        default_reward = 0.0
        default_done = 1.0

        # Populate the arrays
        for b_idx, turn in enumerate(batch):
            global_state[b_idx] = turn['global_state']
            next_global_state[b_idx] = turn['next_global_state']

            for agent_id in range(num_agents):
                if agent_id in turn['agents']:
                    agent_data = turn['agents'][agent_id]
                    obs[b_idx, agent_id] = agent_data['obs']  # Shape matches (13, 13, 5)
                    next_obs[b_idx, agent_id] = agent_data['next_obs']
                    action[b_idx, agent_id] = agent_data['action'].item()
                    reward[b_idx, agent_id] = agent_data['reward'].item()
                    done[b_idx, agent_id] = agent_data['done'].item()
                else:
                    # Assign default values for missing (dead) agents
                    obs[b_idx, agent_id] = default_obs
                    next_obs[b_idx, agent_id] = default_next_obs
                    action[b_idx, agent_id] = default_action
                    reward[b_idx, agent_id] = default_reward
                    done[b_idx, agent_id] = default_done

        # Convert to PyTorch tensors
        obs_batch = torch.from_numpy(obs).float().to(device)
        next_obs_batch = torch.from_numpy(next_obs).float().to(device)
        action_batch = torch.from_numpy(action).long().to(device)
        reward_batch = torch.from_numpy(reward).float().to(device)
        done_batch = torch.from_numpy(done).float().to(device)
        global_state_batch = torch.from_numpy(global_state).float().to(device)
        next_global_state_batch = torch.from_numpy(next_global_state).float().to(device)
        is_weights = torch.from_numpy(is_weights).float().to(device)

        return obs_batch, action_batch, reward_batch, next_obs_batch, global_state_batch, next_global_state_batch, done_batch, is_weights


    def __len__(self):
        return self.sum_tree.size

# Hyperparameters

In [97]:
# Hyperparameters
NUM_AGENTS = 81
MIXING_DIM = 32  # Dimension of the mixing network
NUM_EPISODES = 10
BATCH_SIZE = 64
GAMMA = 0.99
LR = 5e-4
WEIGHT_DECAY = 0.0001
MAX_REPLAY_BUFFER_SIZE = 5000
EPS_START = 1
EPS_END = 0.1
EPS_DECAY = 10
TAU = 0.01
GRADIENT_CLIPPING = 10

# Funtions

## Policy

In [19]:
def linear_epsilon(steps_done):
    return max(EPS_END, EPS_START - (EPS_START - EPS_END) * (steps_done / EPS_DECAY))

def policy(observation, q_network, team, id):
    global steps_done
    id = torch.tensor([id]).to(device)
    sample = random.random()
    if sample < linear_epsilon(steps_done):
        return env.action_space("red_0").sample()
    else:
        observation = (
            torch.Tensor(observation).to(device)
        )
        if team == "red":
            observation = torch.fliplr(observation)
        with torch.no_grad():
            q_values = q_network(observation, id)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

## Save model

In [20]:
def save_model(policy_net, path="models/qmix.pt"):
    if not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    torch.save(policy_net.state_dict(), path)


# Initialize

In [98]:
from magent2.environments import battle_v4

env = battle_v4.env(map_size=45,max_cycles=200, step_reward = 0.01, attack_penalty=0, attack_opponent_reward=1 ,render_mode="rgb_array")
env.reset()

# Define the environment parameters
obs_shape = env.observation_space("red_0").shape  # (Height, Width, C)
state_dim = env.state_space.shape  # Dimension of the global state
action_dim = env.action_space("red_0").n  # Number of discrete actions for each agent


policy_net = AgentNetwork(obs_shape, action_dim, NUM_AGENTS).to(device)
mixing_net = MixingNetwork(state_dim, NUM_AGENTS, MIXING_DIM).to(device)
target_net = AgentNetwork(obs_shape, action_dim, NUM_AGENTS).to(device)
target_net.load_state_dict(policy_net.state_dict())
red_policy_net = AgentNetwork(obs_shape, action_dim, NUM_AGENTS).to(device) # for self-play


agent_optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True, weight_decay=WEIGHT_DECAY)
mixing_optimizer = optim.AdamW(mixing_net.parameters(), lr=LR, amsgrad=True, weight_decay=WEIGHT_DECAY)


buffer = PrioritizedMultiAgentReplayBuffer(MAX_REPLAY_BUFFER_SIZE)

running_loss = 0.0
steps_done = 0

# Optimizer

In [99]:
def optimize_model():
    global running_loss

    # Sample a batch from the prioritized multi-agent replay buffer
    batch, indices, is_weights = buffer.sample(BATCH_SIZE)

    # If the buffer doesn't have enough samples yet, return
    if batch is None:
        return
    id_batch = torch.arange(81).repeat(BATCH_SIZE, 1).to(device)
    obs_batch, action_batch, reward_batch, next_obs_batch, global_state_batch, next_global_state_batch, done_batch, is_weights = buffer.sample_experience(batch_size=BATCH_SIZE, device=device)
    # Compute individual Q-values for current observation
    individual_q_values = torch.stack([policy_net(obs_batch[:, i], id_batch[:, i]) for i in range(NUM_AGENTS)], dim=1)  # [BATCH_SIZE, num_agents, action_dim]
    # Compute Q-values of action that agent has taken
    chosen_q_values = torch.gather(individual_q_values, 2, action_batch.unsqueeze(-1)).squeeze(-1) # [BATCH_SIZE, num_agents]

    # Compute individual Q-values for current observation
    next_individual_q_values = torch.stack([target_net(next_obs_batch[:, i], id_batch[:, i]) for i in range(NUM_AGENTS)], dim=1)  # [BATCH_SIZE, num_agents, action_dim]
    target_q_values = next_individual_q_values .max(dim=-1)[0]  # [BATCH_SIZE, num_agents]

    # Compute Q_tot for the current state
    q_tot = mixing_net(chosen_q_values, global_state_batch).squeeze(1)  # [BATCH_SIZE]

    # Mask for non-terminal states
    # non_final_mask = (1 - done_batch).float()  # [BATCH_SIZE, num_agents]

    # Compute Q_tot for next state using mixing network
    next_q_tot = mixing_net(target_q_values, next_global_state_batch).squeeze(1)  # [BATCH_SIZE]

    # Compute expected Q_tot
    reward_tot = reward_batch.sum(dim=1)  # Sum rewards across agents [BATCH_SIZE]
    expected_q_tot = reward_tot + GAMMA * next_q_tot  # Discounted target Q_tot [BATCH_SIZE]

    # Compute TD errors
    td_errors = (expected_q_tot - q_tot).detach().cpu().numpy()

    # Compute Huber loss weighted by importance-sampling weights
    criterion = nn.SmoothL1Loss(reduction='none')
    loss_per_sample = criterion(q_tot, expected_q_tot)
    loss = (loss_per_sample * is_weights).mean()  # Weighted loss

    # Optimize the model
    agent_optimizer.zero_grad()
    mixing_optimizer.zero_grad()
    loss.backward()

    # Gradient clipping for stability
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), GRADIENT_CLIPPING)
    torch.nn.utils.clip_grad_value_(mixing_net.parameters(), GRADIENT_CLIPPING)

    # Update all optimizers
    agent_optimizer.step()
    mixing_optimizer.step()

    # Update priorities in the replay buffer
    buffer.update_priorities(indices, td_errors)

    # Track running loss
    running_loss += loss.item()

    return loss.item()


# Training loop

In [96]:
# Training loop
for episode in range(NUM_EPISODES):
    print(f"\nStarting Episode {episode + 1}...")
    prev_team = "red"
    cur_team = "red"
    env.reset()
    episode_reward = 0
    running_loss = 0.0
    steps_done += 1
    total_steps = 0  # Total steps for this episode

    # Reset temporary storage at the beginning of each episode
    buffer.temp_experience.clear()

    # Use tqdm to track total steps in this episode
    with tqdm(total=400, desc=f"Steps in Episode {episode + 1}", leave=True) as pbar:
        for agent in env.agent_iter():
            agent_team = agent.split("_")[0]
            agent_id = int(agent.split("_")[1])

            prev_team = cur_team
            cur_team = agent_team
            if cur_team != prev_team:
                total_steps += 1
                pbar.update(1)  # Update progress bar for each step
                state = np.transpose(env.state(), (1, 0, 2))
                if cur_team == "blue":  # Red turn changes to blue turn
                    buffer.start_of_turn(global_state=state)
                else:  # Blue turn changes to red turn
                    buffer.end_of_turn(next_global_state=state)

            observation, reward, termination, truncation, _ = env.last()
            done = termination or truncation

            if done: # This agent is dead or truncated
                action = None
                env.step(action)
            elif agent_team == "blue":
                episode_reward += reward
                buffer.update_last_reward(agent_id, reward)  # Update reward for the last agent's action

                action = policy(observation, policy_net, agent_team, agent_id)
                env.step(action)

                next_global_state = np.transpose(env.state(), (1, 0, 2))
                try:
                    next_observation = env.observe(agent)
                except:
                    next_observation = np.zeros_like(observation)
                    print("I think code never reach this line!")

                # Store the experience
                buffer.store_experience(agent_id, observation, action, reward, next_observation, done)

                # Perform one step of the optimization (on the policy network)
                if len(buffer) % 32 == 0:
                    optimize_model()
                    # Update of the target network's weights
                    with torch.no_grad():
                        target_net_state_dict = target_net.state_dict()
                        policy_net_state_dict = policy_net.state_dict()
                        for key in policy_net_state_dict:
                            target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
                        target_net.load_state_dict(target_net_state_dict)
            elif agent_team == "red":
                # Red team actions
                action = policy(observation, red_policy_net, agent_team, agent_id)
                env.step(action)

        pbar.refresh()  # Ensure the progress bar reflects the final step count

    # Synchronize red team's policy network with blue's every 4 episodes
    if episode % 4 == 0:
        red_policy_net.load_state_dict(policy_net.state_dict())

    # Summary output for each episode
    print(f"Episode {episode + 1}: Total Reward = {episode_reward:.2f}, Total Steps = {total_steps}, Average Loss = {running_loss:.4f}, Epsilon = {linear_epsilon(steps_done):.4f}")

    # Save the model periodically
    save_model(policy_net, path=f"models/blue_qmix_{episode+1}.pt")

env.close()



Starting Episode 1...


Steps in Episode 1: 100%|██████████| 400/400 [02:34<00:00,  2.60it/s]


Episode 1: Total Reward = -673.90, Total Steps = 400, Average Loss = 10.7728, Epsilon = 0.9800

Starting Episode 2...


Steps in Episode 2: 100%|██████████| 400/400 [03:09<00:00,  2.11it/s]


Episode 2: Total Reward = -661.90, Total Steps = 400, Average Loss = 1.1449, Epsilon = 0.9600

Starting Episode 3...


Steps in Episode 3: 100%|██████████| 400/400 [03:15<00:00,  2.05it/s]


Episode 3: Total Reward = -632.20, Total Steps = 400, Average Loss = 0.7084, Epsilon = 0.9400

Starting Episode 4...


Steps in Episode 4: 100%|██████████| 400/400 [03:13<00:00,  2.06it/s]


Episode 4: Total Reward = -643.15, Total Steps = 400, Average Loss = 0.4979, Epsilon = 0.9200

Starting Episode 5...


Steps in Episode 5: 100%|██████████| 400/400 [03:42<00:00,  1.80it/s]


Episode 5: Total Reward = -631.44, Total Steps = 400, Average Loss = 1.0501, Epsilon = 0.9000

Starting Episode 6...


Steps in Episode 6: 100%|██████████| 400/400 [03:11<00:00,  2.09it/s]


Episode 6: Total Reward = -570.81, Total Steps = 400, Average Loss = 1.0771, Epsilon = 0.8800

Starting Episode 7...


Steps in Episode 7: 100%|██████████| 400/400 [03:15<00:00,  2.05it/s]


Episode 7: Total Reward = -619.37, Total Steps = 400, Average Loss = 1.6092, Epsilon = 0.8600

Starting Episode 8...


Steps in Episode 8: 100%|██████████| 400/400 [03:13<00:00,  2.06it/s]


Episode 8: Total Reward = -556.01, Total Steps = 400, Average Loss = 1.3393, Epsilon = 0.8400

Starting Episode 9...


Steps in Episode 9: 100%|██████████| 400/400 [03:44<00:00,  1.79it/s]


Episode 9: Total Reward = -555.66, Total Steps = 400, Average Loss = 0.9240, Epsilon = 0.8200

Starting Episode 10...


Steps in Episode 10: 100%|██████████| 400/400 [03:10<00:00,  2.10it/s]


Episode 10: Total Reward = -427.87, Total Steps = 400, Average Loss = 1.8110, Epsilon = 0.8000

Starting Episode 11...


Steps in Episode 11: 100%|██████████| 400/400 [03:17<00:00,  2.03it/s]


Episode 11: Total Reward = -445.93, Total Steps = 400, Average Loss = 2.1280, Epsilon = 0.7800

Starting Episode 12...


Steps in Episode 12: 100%|██████████| 400/400 [03:16<00:00,  2.04it/s]


Episode 12: Total Reward = -448.88, Total Steps = 400, Average Loss = 1.8592, Epsilon = 0.7600

Starting Episode 13...


Steps in Episode 13: 100%|██████████| 400/400 [03:38<00:00,  1.83it/s]


Episode 13: Total Reward = -413.50, Total Steps = 400, Average Loss = 2.5072, Epsilon = 0.7400

Starting Episode 14...


Steps in Episode 14: 100%|██████████| 400/400 [03:13<00:00,  2.07it/s]


Episode 14: Total Reward = -393.46, Total Steps = 400, Average Loss = 2.7165, Epsilon = 0.7200

Starting Episode 15...


Steps in Episode 15: 100%|██████████| 400/400 [03:18<00:00,  2.01it/s]


Episode 15: Total Reward = -398.37, Total Steps = 400, Average Loss = 2.4371, Epsilon = 0.7000

Starting Episode 16...


Steps in Episode 16: 100%|██████████| 400/400 [03:22<00:00,  1.98it/s]


Episode 16: Total Reward = -420.10, Total Steps = 400, Average Loss = 2.0405, Epsilon = 0.6800

Starting Episode 17...


Steps in Episode 17: 100%|██████████| 400/400 [03:46<00:00,  1.77it/s]


Episode 17: Total Reward = -315.09, Total Steps = 400, Average Loss = 4.2130, Epsilon = 0.6600

Starting Episode 18...


Steps in Episode 18: 100%|██████████| 400/400 [03:23<00:00,  1.97it/s]


Episode 18: Total Reward = -440.40, Total Steps = 400, Average Loss = 3.8423, Epsilon = 0.6400

Starting Episode 19...


Steps in Episode 19: 100%|██████████| 400/400 [03:22<00:00,  1.97it/s]


Episode 19: Total Reward = -390.06, Total Steps = 400, Average Loss = 2.5021, Epsilon = 0.6200

Starting Episode 20...


Steps in Episode 20: 100%|██████████| 400/400 [03:28<00:00,  1.92it/s]


Episode 20: Total Reward = -359.54, Total Steps = 400, Average Loss = 2.3930, Epsilon = 0.6000

Starting Episode 21...


Steps in Episode 21: 100%|██████████| 400/400 [03:57<00:00,  1.68it/s]


Episode 21: Total Reward = -384.45, Total Steps = 400, Average Loss = 3.3478, Epsilon = 0.5800

Starting Episode 22...


Steps in Episode 22: 100%|██████████| 400/400 [03:19<00:00,  2.01it/s]


Episode 22: Total Reward = -252.57, Total Steps = 400, Average Loss = 2.3445, Epsilon = 0.5600

Starting Episode 23...


Steps in Episode 23: 100%|██████████| 400/400 [03:16<00:00,  2.03it/s]


Episode 23: Total Reward = -297.74, Total Steps = 400, Average Loss = 3.5577, Epsilon = 0.5400

Starting Episode 24...


Steps in Episode 24: 100%|██████████| 400/400 [03:20<00:00,  2.00it/s]


Episode 24: Total Reward = -290.75, Total Steps = 400, Average Loss = 3.2840, Epsilon = 0.5200

Starting Episode 25...


Steps in Episode 25:  96%|█████████▋| 385/400 [03:39<00:08,  1.75it/s]


KeyboardInterrupt: 