# MADDPG

### Imports

In [1]:
import numpy as np
from pettingzoo.atari import boxing_v2
import numpy as np
import random 
import torch
import torch.nn as nn
import torch.optim as optim
from supersuit import pad_observations_v0, pad_action_space_v0, resize_v1, normalize_obs_v0, frame_skip_v0, dtype_v0
from pettingzoo.utils import aec_to_parallel
import matplotlib.pyplot as plt
from collections import defaultdict

## Critic and Actor Class

In [2]:
class Critic(nn.Module):
    def __init__(self, obs_dim, action_dim):
        super(Critic, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, obs, actions):
        # Ensure observations are flattened
        obs = obs.flatten(start_dim=1) if len(obs.shape) > 2 else obs
        actions = actions.flatten(start_dim=1) if len(actions.shape) > 2 else actions

        # Debugging shapes
        #print(f"obs shape: {obs.shape}, actions shape: {actions.shape}")

        x = torch.cat([obs, actions], dim=-1)
        return self.fc(x)
    
class Actor(nn.Module):
    def __init__(self, obs_dim, action_dim):
        super(Actor, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(obs_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh(),
        )

    def forward(self, obs):
        return self.fc(obs)

## MADDPG Agent

In [3]:
class MADDPGAgent:
    def __init__(self, obs_dim, action_dim, lr_actor=1e-3, lr_critic=1e-3):
        self.actor = Actor(obs_dim, action_dim)
        self.critic = Critic(obs_dim * 2, action_dim * 2)
        self.actor_target = Actor(obs_dim, action_dim)
        self.critic_target = Critic(obs_dim * 2, action_dim * 2)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)

        self.update_target(1.0)

    def update_target(self, tau):
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

## MADDPG Wrapper

In [4]:
class MADDPG:
    def __init__(self, obs_dim, action_dim, n_agents, gamma=0.95, tau=0.01):
        self.agents = [MADDPGAgent(obs_dim, action_dim) for _ in range(n_agents)]
        self.gamma = gamma
        self.tau = tau

    def update(self, replay_buffer, batch_size):
        policy_loss = 0
        value_loss = 0

        for agent_idx, agent in enumerate(self.agents):
            obs, actions, rewards, next_obs, dones = replay_buffer.sample(batch_size)
            obs = torch.tensor(obs, dtype=torch.float32).flatten(start_dim=2)  # Flatten observations
            actions = torch.tensor(actions, dtype=torch.float32).flatten(start_dim=2)  # Flatten actions

            rewards = torch.tensor(rewards[:, agent_idx], dtype=torch.float32).unsqueeze(-1)
            next_obs = torch.tensor(next_obs, dtype=torch.float32).flatten(start_dim=2)
            dones = torch.tensor(dones[:, agent_idx], dtype=torch.float32).unsqueeze(-1)

            # Critic update
            with torch.no_grad():
                next_actions = torch.cat([
                    ag.actor_target(next_obs[:, i]) for i, ag in enumerate(self.agents)
                ], dim=-1)
                target_q = rewards + self.gamma * (1 - dones) * agent.critic_target(next_obs, next_actions)

            q_value = agent.critic(obs, actions)
            critic_loss = nn.MSELoss()(q_value, target_q)
            agent.critic_optimizer.zero_grad()
            critic_loss.backward()
            agent.critic_optimizer.step()
            value_loss += critic_loss.item()

            # Actor update
            predicted_actions = torch.cat([
                agent.actor(obs[:, i]) if i == agent_idx else actions[:, i] for i, _ in enumerate(self.agents)
            ], dim=-1)
            actor_loss = -agent.critic(obs, predicted_actions).mean()
            agent.actor_optimizer.zero_grad()
            actor_loss.backward()
            agent.actor_optimizer.step()
            policy_loss += actor_loss.item()

            # Update targets
            agent.update_target(self.tau)

        return policy_loss / len(self.agents), value_loss / len(self.agents)

## Buffer

In [5]:
class ReplayBuffer:
    def __init__(self, max_size, obs_dim, action_dim, n_agents):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.obs = np.zeros((max_size, n_agents, obs_dim))
        self.actions = np.zeros((max_size, n_agents, action_dim))
        self.rewards = np.zeros((max_size, n_agents))
        self.next_obs = np.zeros((max_size, n_agents, obs_dim))
        self.dones = np.zeros((max_size, n_agents))

    def store(self, obs, actions, rewards, next_obs, dones):
        self.obs[self.ptr] = obs
        self.actions[self.ptr] = actions
        self.rewards[self.ptr] = rewards
        self.next_obs[self.ptr] = next_obs
        self.dones[self.ptr] = dones
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        idx = np.random.choice(self.size, batch_size, replace=False)
        return (
            self.obs[idx],
            self.actions[idx],
            self.rewards[idx],
            self.next_obs[idx],
            self.dones[idx],
        )

## Environment setup

In [8]:
# Environment setup
env = boxing_v2.env(render_mode="rgb_array")
env.reset(seed=42)
env = pad_observations_v0(env)
env = pad_action_space_v0(env)
env = resize_v1(env, 84, 84)
env = dtype_v0(env, dtype="float32")
env = normalize_obs_v0(env, env_min=0, env_max=1)
parallel_env = aec_to_parallel(env)

obs_dim = 84 * 84
action_dim = parallel_env.action_space("first_0").n
n_agents = 2
maddpg = MADDPG(obs_dim, action_dim, n_agents)
replay_buffer = ReplayBuffer(100000, obs_dim, action_dim, n_agents)

## Training loop

In [9]:
# Training loop
num_episodes = 10
batch_size = 64

max_steps_per_episode = 100  # Set the maximum steps per episode
for episode in range(num_episodes):
    obs = parallel_env.reset()
    if isinstance(obs, tuple):
        obs = obs[0]  # Extract observations if returned as a tuple

    done = {agent: False for agent in parallel_env.agents}
    episode_reward = defaultdict(float)  # Store cumulative reward for each agent

    step_count = 0  # Initialize step counter

    while not all(done.values()) and step_count < max_steps_per_episode:
        actions = {}
        for agent in parallel_env.agents:
            # Preprocess observation
            obs_preprocessed = torch.tensor(obs[agent], dtype=torch.float32)
            if len(obs_preprocessed.shape) > 2:  # Ensure grayscale
                obs_preprocessed = obs_preprocessed.mean(axis=-1)  # Convert RGB to grayscale
            obs_preprocessed = obs_preprocessed.flatten().unsqueeze(0)  # Flatten and add batch dim

            # Get continuous action from Actor
            continuous_action = maddpg.agents[int(agent.split('_')[1])].actor(obs_preprocessed).detach().numpy()

            # Convert continuous action to discrete action
            discrete_action = np.argmax(continuous_action)  # Take the action with the highest probability
            actions[agent] = discrete_action  # Store the discrete action for the agent

        # One-hot encode actions for storage
        actions_one_hot = np.zeros((len(parallel_env.agents), action_dim))
        for idx, agent in enumerate(parallel_env.agents):
            actions_one_hot[idx, actions[agent]] = 1

        # Step the environment
        step_output = parallel_env.step(actions)

        if isinstance(step_output, tuple):  # Handle cases where step returns a tuple
            next_obs, rewards, dones, truncations, infos = step_output
            dones = {agent: dones[agent] or truncations[agent] for agent in dones}
        else:
            next_obs, rewards, dones, infos = step_output

        # Accumulate rewards
        for agent, reward in rewards.items():
            episode_reward[agent] += reward

        # Store data in the replay buffer
        obs_array = []
        next_obs_array = []

        for agent in parallel_env.agents:
            # Preprocess observations for storage
            obs_processed = obs[agent].mean(axis=-1).flatten() if len(obs[agent].shape) > 2 else obs[agent].flatten()
            next_obs_processed = next_obs[agent].mean(axis=-1).flatten() if len(next_obs[agent].shape) > 2 else next_obs[agent].flatten()
            obs_array.append(obs_processed)
            next_obs_array.append(next_obs_processed)

        replay_buffer.store(
            np.array(obs_array),
            actions_one_hot,  # Use one-hot encoded actions
            np.array([rewards[agent] for agent in parallel_env.agents]),
            np.array(next_obs_array),
            np.array([dones[agent] for agent in parallel_env.agents]),
        )

        obs = next_obs

        if replay_buffer.size >= batch_size:
            policy_loss, value_loss = maddpg.update(replay_buffer, batch_size)
            #print(f"Policy Loss: {policy_loss}, Value Loss: {value_loss}")

        step_count += 1  # Increment step count

    # Print cumulative reward for the episode
    total_reward = sum(episode_reward.values())
    print(f"Episode {episode + 1}/{num_episodes} completed with total reward: {total_reward}")


Episode 1/10 completed with total reward: 0.0
Episode 2/10 completed with total reward: 0.0
Episode 3/10 completed with total reward: 0.0
Episode 4/10 completed with total reward: 0.0
Episode 5/10 completed with total reward: 0.0
Episode 6/10 completed with total reward: 0.0
Episode 7/10 completed with total reward: 0.0
Episode 8/10 completed with total reward: 0.0
Episode 9/10 completed with total reward: 0.0
Episode 10/10 completed with total reward: 0.0
