### QMIX Training

In [2]:

from mpe2 import simple_spread_v3
from pettingzoo.utils import aec_to_parallel
from failure_api.communication_models import ProbabilisticModel
from failure_api.wrappers import CommunicationWrapper

def make_env(num_agents=3, use_failure_api=False, failure_prob=0.5, seed=42, max_cycles=25):
    env = simple_spread_v3.env(N=num_agents, max_cycles=max_cycles)
    env.reset(seed=seed)

    if use_failure_api:
        model = ProbabilisticModel(agent_ids=env.possible_agents, failure_prob=failure_prob)
        env = CommunicationWrapper(env, failure_models=[model])
        
        
    return aec_to_parallel(env)


## Models

| Component       | Description                                                      |
| --------------- | ---------------------------------------------------------------- |
| `AgentQNetwork` | Shared network for all agents to estimate individual Q-values    |
| `MixerNetwork`  | Combines all agent Qs into a global Q value conditioned on state |
| `QMIXPolicy`    | Combines agents + mixer + action selection (ε-greedy)            |


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from tqdm.notebook import trange


class AgentQNetwork(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, act_dim)
        )

    def forward(self, obs):
        return self.net(obs)  
    
class MixerNetwork(nn.Module):
    def __init__(self, num_agents, state_dim, embed_dim=32):
        super().__init__()
        self.state_dim = state_dim
        self.num_agents = num_agents

        self.hyper_w1 = nn.Linear(state_dim, num_agents * embed_dim)
        self.hyper_w2 = nn.Linear(state_dim, embed_dim)
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1)
        )
        
    def forward(self, agent_qs, state):
        """
        agent_qs: (batch, num_agents)
        state: (batch, state_dim)
        """
        bs = agent_qs.size(0)
        
        w1 = self.hyper_w1(state).view(bs, self.num_agents, -1)
        b1 = self.hyper_b1(state).view(bs, 1, -1)   
        hidden = F.elu(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)
        
        w2 = self.hyper_w2(state).view(bs, -1, 1)  
        b2 = self.hyper_b2(state).view(bs, 1, 1) 
        
        q_total = torch.bmm(hidden, w2) + b2
        return q_total.squeeze(-1).squeeze(-1) 

## QMIX Policy

Shared AgentQNetwork

Global MixerNetwork

Epsilon-greedy exploration

Centralized training + decentralized execution

In [4]:


class QMIXPolicy:
    def __init__(self, obs_dim, act_dim, state_dim, num_agents, device="cpu",
                 gamma=0.99, lr=5e-4, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay=5000):
        self.num_agents = num_agents
        self.act_dim = act_dim
        self.device = device

        # Shared agent network
        self.agent_net = AgentQNetwork(obs_dim, act_dim).to(device)
        self.target_agent_net = AgentQNetwork(obs_dim, act_dim).to(device)
        self.target_agent_net.load_state_dict(self.agent_net.state_dict())

        # Mixer
        self.mixer_net = MixerNetwork(num_agents, state_dim).to(device)
        self.target_mixer_net = MixerNetwork(num_agents, state_dim).to(device)
        self.target_mixer_net.load_state_dict(self.mixer_net.state_dict())
        self.agent_q_net = AgentQNetwork(obs_dim, act_dim)  # example
        self.mixing_net = MixerNetwork(num_agents, state_dim)  # example

        # Optimizer
        self.optimizer = torch.optim.Adam(
            list(self.agent_net.parameters()) + list(self.mixer_net.parameters()),
            lr=lr
        )

        # Epsilon-greedy
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.total_steps = 0

        self.gamma = gamma
        
    def state_dict(self):
        return {
            "agent_q_net": self.agent_q_net.state_dict(),
            "mixing_net": self.mixing_net.state_dict(),
        }
    
    def load_state_dict(self, state_dict):
        self.agent_q_net.load_state_dict(state_dict["agent_q_net"])
        self.mixing_net.load_state_dict(state_dict["mixing_net"])

    def select_actions(self, obs_batch, explore=True):
        """
        obs_batch: Dict[str, np.array], each obs shape = (obs_dim,)
        """
        self.total_steps += 1
        epsilon = self._epsilon()

        actions = {}
        for agent_id, obs in obs_batch.items():
            obs_tensor = torch.tensor(obs, dtype=torch.float32).to(self.device).unsqueeze(0)
            q_values = self.agent_net(obs_tensor)
            if explore and np.random.rand() < epsilon:
                action = np.random.randint(self.act_dim)
            else:
                action = torch.argmax(q_values).item()
            actions[agent_id] = action
        return actions

    def _epsilon(self):
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
               np.exp(-1.0 * self.total_steps / self.epsilon_decay)
    
    
    def update(self, batch):
        """
        batch: dict with keys: obs, actions, rewards, next_obs, states, next_states, dones
        Shapes:
          obs:      (B, num_agents, obs_dim)
          actions:  (B, num_agents)
          rewards:  (B,)
          next_obs: (B, num_agents, obs_dim)
          states:   (B, state_dim)
          dones:    (B,)
        """
        obs = torch.tensor(batch["obs"], dtype=torch.float32).to(self.device)
        actions = torch.tensor(batch["actions"], dtype=torch.long).to(self.device)
        rewards = torch.tensor(batch["rewards"], dtype=torch.float32).to(self.device)
        next_obs = torch.tensor(batch["next_obs"], dtype=torch.float32).to(self.device)
        dones = torch.tensor(batch["dones"], dtype=torch.float32).to(self.device)
        states = torch.tensor(batch["states"], dtype=torch.float32).to(self.device)
        next_states = torch.tensor(batch["next_states"], dtype=torch.float32).to(self.device)

        B = obs.shape[0]
        agent_qs = self.agent_net(obs.view(-1, obs.shape[-1]))  # (B*num_agents, act_dim)
        agent_qs = agent_qs.view(B, self.num_agents, self.act_dim)
        chosen_qs = torch.gather(agent_qs, dim=2, index=actions.unsqueeze(-1)).squeeze(-1)

        # Target Q
        with torch.no_grad():
            next_agent_qs = self.target_agent_net(next_obs.view(-1, next_obs.shape[-1]))
            next_agent_qs = next_agent_qs.view(B, self.num_agents, self.act_dim)
            max_next_qs = next_agent_qs.max(dim=2)[0]  # (B, num_agents)

        # Mixer
        q_total = self.mixer_net(chosen_qs, states)
        next_q_total = self.target_mixer_net(max_next_qs, next_states)

        targets = rewards + self.gamma * (1 - dones) * next_q_total

        loss = nn.MSELoss()(q_total, targets.detach())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target(self, tau=0.01):
        for target_param, param in zip(self.target_agent_net.parameters(), self.agent_net.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

        for target_param, param in zip(self.target_mixer_net.parameters(), self.mixer_net.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

## Replay Buffer

In [5]:
# replay_buffer.py


class ReplayBuffer:
    def __init__(self, capacity, num_agents, obs_dim, state_dim):
        self.capacity = capacity
        self.buffer = []
        self.num_agents = num_agents
        self.obs_dim = obs_dim
        self.state_dim = state_dim

    def add(self, obs, actions, reward, next_obs, state, next_state, done):
        """
        obs, next_obs: dict of obs arrays per agent
        actions: dict of int
        """
        obs_arr = np.array([obs[a] for a in sorted(obs)])
        next_obs_arr = np.array([next_obs[a] for a in sorted(next_obs)])
        action_arr = np.array([actions[a] for a in sorted(actions)])

        transition = {
            "obs": obs_arr,
            "actions": action_arr,
            "rewards": reward,
            "next_obs": next_obs_arr,
            "states": state,
            "next_states": next_state,
            "dones": done
        }

        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append(transition)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        return {
            "obs": np.stack([x["obs"] for x in batch]),
            "actions": np.stack([x["actions"] for x in batch]),
            "rewards": np.array([x["rewards"] for x in batch]),
            "next_obs": np.stack([x["next_obs"] for x in batch]),
            "states": np.stack([x["states"] for x in batch]),
            "next_states": np.stack([x["next_states"] for x in batch]),
            "dones": np.array([x["dones"] for x in batch])
        }

    def __len__(self):
        return len(self.buffer)

## Train

In [6]:



def get_state(obs_dict, agent_ids, obs_dim):
    """
    Returns concatenated observation vector across all agents.
    Missing agents are padded with zeros (if terminated).
    """
    return np.concatenate([
        obs_dict.get(a, np.zeros(obs_dim)) for a in sorted(agent_ids)
    ])

def train_qmix(use_failure_api=False, episodes=1000, batch_size=32, buffer_capacity=10000):
    env = make_env(num_agents=3, use_failure_api=use_failure_api)
    agent_ids = env.possible_agents
    obs_sample = env.reset()[0]
    obs_dim = obs_sample[agent_ids[0]].shape[0]
    act_dim = env.action_space(agent_ids[0]).n
    state_dim = obs_dim * len(agent_ids)

    policy = QMIXPolicy(obs_dim, act_dim, state_dim, len(agent_ids))
    buffer = ReplayBuffer(buffer_capacity, len(agent_ids), obs_dim, state_dim)

    rewards_log = []

    for ep in trange(episodes, desc="Training"):
        obs_dict, _ = env.reset()
        total_reward = 0
        done = False
        steps = 0

        while not done:
            actions = policy.select_actions(obs_dict, explore=True)
            next_obs_dict, reward_dict, terminations, truncations, infos = env.step(actions)

            reward = sum(reward_dict.values())
            done = all(list(terminations.values()) + list(truncations.values()))
            if all(a in obs_dict and a in next_obs_dict for a in agent_ids):
                state = get_state(obs_dict, agent_ids, obs_dim)
                next_state = get_state(next_obs_dict, agent_ids, obs_dim)
            
                buffer.add(obs_dict, actions, reward, next_obs_dict, state, next_state, done)
                obs_dict = next_obs_dict
                total_reward += reward
                steps += 1

            if len(buffer) > batch_size:
                batch = buffer.sample(batch_size)
                loss = policy.update(batch)
                policy.update_target()

        rewards_log.append((ep, total_reward, steps))
        for i, x in enumerate(buffer.buffer):
            try:
                np.stack([x["next_obs"] for x in buffer.buffer])
            except Exception as e:
                print(f"⚠️ Problem at index {i}: {e}")
                break

        if ep % 100 == 0:
            avg_reward = np.mean([r[1] for r in rewards_log[-100:]])
            torch.save(policy.state_dict(), r"C:\Users\koste\venv\Bachelor_Thesis\QMIX_Training\qmix_ep{ep}.pt")
            print(f"[Episode {ep}] Avg Reward: {avg_reward:.2f} Steps: {steps}")

    return rewards_log



In [7]:
log_baseline = train_qmix(use_failure_api=False, episodes=500)
log_failure = train_qmix(use_failure_api=True, episodes=500)


Training:   0%|          | 0/500 [00:00<?, ?it/s]

[Episode 0] Avg Reward: -72.51 Steps: 25
[Episode 100] Avg Reward: -100.20 Steps: 25
[Episode 200] Avg Reward: -143.57 Steps: 25
[Episode 300] Avg Reward: -174.81 Steps: 25
[Episode 400] Avg Reward: -185.78 Steps: 25


KeyboardInterrupt: 