In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from collections import deque

In [2]:
def divide_to_batch(tensor, batch_size):
    return tensor.unsqueeze(0).repeat(batch_size, 1, 1, 1)

In [37]:
# Define the agent network with CNN
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the agent network with CNN and agent ID input
class AgentNetwork(nn.Module):
    def __init__(self, observation_shape, action_dim, n_agents):
        super(AgentNetwork, self).__init__()
        # observation_shape is (H, W, C)
        self.conv1 = nn.Conv2d(observation_shape[2], 16, kernel_size=3, stride=1, padding=1) # (16, H, W)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) # (32, H, W)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) # (32, H, W)

        # Calculate the flattened size after convolutions
        flat_size = 32 * observation_shape[0] * observation_shape[1] # 32 * H * W

        # Add a linear layer to process the agent ID
        self.fc_agent_id = nn.Linear(n_agents, 32) 

        self.fc1 = nn.Linear(flat_size + 32, 128)  # Concatenate conv output with agent ID embedding
        self.fc2 = nn.Linear(128, action_dim)

    def forward(self, obs, agent_id):
        # Add a batch dimension
        if len(obs.shape) == 3:
            obs = obs.unsqueeze(0) 
        if len(agent_id.shape) == 1:
            agent_id = agent_id.unsqueeze(0)
            
        x = obs.permute(0, 3, 1, 2)  # Convert to (B, C, H, W)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.flatten(start_dim=1)  # Flatten all dimensions except batch

        # Process agent ID
        agent_id_embedding = F.relu(self.fc_agent_id(agent_id))

        # Concatenate the flattened convolutional output with the agent ID embedding
        x = torch.cat((x, agent_id_embedding), dim=1)

        x = F.relu(self.fc1(x))
        q_values = self.fc2(x)
        return q_values

n_agents = 5  # Example: 5 agents
agent_id = 2   # Example: Agent with ID 2 (0-indexed)
agent_id_one_hot = F.one_hot(torch.tensor(agent_id), num_classes=n_agents).float()

obs = torch.randn(13, 13, 5)
test = AgentNetwork(obs.shape, 21, n_agents)
test(obs, agent_id_one_hot)

tensor([[ 0.0876,  0.0105, -0.0041,  0.0459,  0.1077, -0.0681,  0.0394, -0.0490,
         -0.0374,  0.0253,  0.0263,  0.0250, -0.0725, -0.0289, -0.0417,  0.0821,
          0.0226,  0.0354,  0.0285,  0.0574, -0.0257]],
       grad_fn=<AddmmBackward0>)

In [31]:
# Define the HyperNetwork with CNN for dynamic weight generation
class HyperNetwork(nn.Module):
    def __init__(self, input_shape, output_dim, hidden_dim):
        super().__init__()
        # CNN layers
        # Input shape (H, W, C)
        self.conv1 = nn.Conv2d(input_shape[2], 32, kernel_size=3, padding=1) #(B, 32, H, W)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1) # (B, 16, H, W)
        
        # FC layers
        flat_size = 16 * input_shape[0] * input_shape[1]
        self.fc1 = nn.Linear(flat_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
        # Initialize weights
        self.apply(self._init_weights)
    
    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, state):
        # Add batch dimension if not present
        if len(state.shape) == 3:
            state = state.unsqueeze(0)  # (1, H, W, C)
        
        # Convert (B, H, W, C) to (B, C, H, W) for CNN
        x = state.permute(0, 3, 1, 2)
        
        # Apply convolutions
        x = F.relu(self.conv1(x))  # (B, 32, H, W)
        x = F.relu(self.conv2(x))  # (B, 16, H, W)
        
        # Flatten all dimensions except batch
        x = x.flatten(start_dim=1)  # (B, 16*H*W)
        
        # Apply fully connected layers
        x = F.relu(self.fc1(x))  # (B, hidden_dim)
        weights = self.fc2(x)  # (B, output_dim)
        
        return weights
    

state = torch.randn(45, 45, 5)
test_hyper = HyperNetwork(state.shape, 1, 64)
state_batch = torch.randn(5, 45, 45, 5)
test_hyper(state_batch).shape

torch.Size([5, 1])

In [32]:
# Define the Mixing Network
class MixingNetwork(nn.Module):
    def __init__(self, state_dim, num_agents, mixing_dim):
        super(MixingNetwork, self).__init__()
        self.num_agents = num_agents
        self.mixing_dim = mixing_dim

        # Hypernetworks for weights and biases
        self.hyper_w1 = HyperNetwork(state_dim, num_agents * mixing_dim, 64)
        self.hyper_b1 = HyperNetwork(state_dim, mixing_dim, 64)
        self.hyper_w2 = HyperNetwork(state_dim, mixing_dim, 64)
        self.hyper_b2 = HyperNetwork(state_dim, 1, 64)

    def forward(self, agent_qs, states):
        # Add batch dimension if not present
        if len(agent_qs.shape) == 2:
            agent_qs = agent_qs.unsqueeze(0)  # (1, num_agents, action_dim)
        else:
            batch_size = agent_qs.size(0)
        if len(states.shape) == 3:
            states = states.unsqueeze(0)  # (1, H, W, C)


        # Compute the max value across the entire batch
        qs_max_idx = agent_qs.argmax(dim=-1)  # Argmax of action on each agent 
        agent_qs = agent_qs.gather(1, qs_max_idx.unsqueeze(-1))  # Caculate max value on each agent

        agent_qs = agent_qs.view(batch_size, 1, self.num_agents)  # (batch_size, 1, num_agents)

        # First layer weights and biases
        w1 = torch.abs(self.hyper_w1(states)) 
        w1 = w1.view(batch_size, self.num_agents, self.mixing_dim)  # (batch_size, num_agents, mixing_dim)
        b1 = self.hyper_b1(states) # (batch_size, mixing_dim)
        b1 = b1.view(batch_size, 1, self.mixing_dim)  # (batch_size, 1, mixing_dim)


        # Compute first layer output
        hidden = F.elu(torch.bmm(agent_qs, w1) + b1)  # (batch_size, 1, mixing_dim)
        # Second layer weights and biases
        w2 = torch.abs(self.hyper_w2(states)) 
        w2 = w2.view(batch_size, self.mixing_dim, 1)  # (batch_size, mixing_dim, 1)
        b2 = self.hyper_b2(states)  
        b2 = b2.view(batch_size, 1, 1)  # (batch_size, 1, 1)

        # Compute final output
        q_tot = torch.bmm(hidden, w2) + b2  # (batch_size, 1, 1)
        # Remove unnecessary dimensions
        q_tot = q_tot.squeeze(-1)  # (batch_size, 1)
        
        # If input was single sample, remove batch dimension
        if len(agent_qs.shape) == 2:
            q_tot = q_tot.squeeze(0)  # (1)
            
        return q_tot

agent_qs = torch.randn(5, 81, 21)
test_mix_net = MixingNetwork(state.shape, 81, 2)
test_mix_net(agent_qs, state_batch)

tensor([[-0.1747],
        [-1.3147],
        [ 0.4179],
        [-0.7584],
        [-0.4372]], grad_fn=<SqueezeBackward1>)

In [7]:
# Putting it all together
class QMIX(nn.Module):
    def __init__(self, obs_shape, action_dim, state_dim, num_agents, mixing_dim):
        super(QMIX, self).__init__()
        self.num_agents = num_agents

        # Create agent networks
        self.agent_networks = nn.ModuleList([AgentNetwork(obs_shape, action_dim) for _ in range(num_agents)])

        # Mixing network
        self.mixing_network = MixingNetwork(state_dim, num_agents, mixing_dim)

    def forward(self, total_obs, states):
        if len(total_obs.shape) == 4:
            total_obs = total_obs.unsqueeze(0) # (B, num_agents, H, W, C)
            
        # Forward pass for each agent
        agent_qs = torch.stack([agent(total_obs[:,i]) for i, agent in enumerate(self.agent_networks)], dim=1)  # [batch_size, num_agents, action_dim]

        # Mix Q-values
        q_tot = self.mixing_network(agent_qs, states)
        return q_tot


In [33]:
from magent2.environments import battle_v4

env = battle_v4.env(map_size=45,max_cycles=300, render_mode="human")
env.reset()

In [20]:
env.reset()
for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    if not termination and not truncation:
        action = np.random.randint(0,21)
    else:
        action = None
    env.step(action)
    
print(env.metadata)
env.close()

{'render_modes': ['human', 'rgb_array'], 'name': 'battle_v4', 'render_fps': 5, 'is_parallelizable': True}


In [42]:
# Define the environment parameters
NUM_AGENTS = 81
obs_shape = env.observation_space("red_0").shape  # (Height, Width, C)
state_dim = env.state_space.shape  # Dimension of the global state
action_dim = env.action_space("red_0").n  # Number of discrete actions for each agent


# Hyperparameters
MIXING_DIM = 32  # Dimension of the mixing network
NUM_EPISODES = 10
BATCH_SIZE = 16
GAMMA = 0.99
LR = 1e-3
WEIGHT_DECAY = 0.001
MAX_REPLAY_BUFFER_SIZE = 5000

policy_net = AgentNetwork(obs_shape, action_dim, NUM_AGENTS)
target_net = AgentNetwork(obs_shape, action_dim, NUM_AGENTS)
target_net.load_state_dict(policy_net.state_dict())

def one_hot_encode(agent_id):
    return F.one_hot(torch.tensor(agent_id), num_classes=NUM_AGENTS).float()


In [28]:
memory = []
turn = "red"
initial_state = np.transpose(env.state())

for episode in range(NUM_EPISODES):
    
    env.reset()
    
    for agent in env.agent_iter():
        
        agent_team = agent.split("_")[0]
        agent_id = agent.split("_")[1]
        
        observation, reward, termination, truncation, _ = env.last()
        done = termination or truncation
        
        if done:
            action = None 
            env.step(action)
        else:
            if agent_team == "blue":
                

In [12]:
# # Training Loop
# for i_episode in range(episode, NUM_EPISODES):
#     # Reset environment for the new episode
#     env.reset()
#     episode_reward = 0
#     running_loss = 0.0

#     # Collect experiences and interact with the environment
#     for agent in env.agent_iter():
#         agent_handle = agent.split("_")

#         observation, reward, termination, truncation, info = env.last()
#         done = termination or truncation

#         # Perform actions for each agent
#         if done:
#             action = None  # Agent is dead (or finished its part)
#             env.step(action)
#             total_obs[agent_handle[1]] = torch.tensor(0)
#         else:
#             if agent_handle[0] == "blue": 
#                 # Select action using the current policy (based on the QMIX model)
#                 action = policy()
#                 env.step(action)
#                 next_observation, reward, termination, truncation, _ = env.last()
#                 # Store the experience in replay buffer
#                 replay_buffer.append((observation, action, reward, done, state))
                
#                 # Take the action in the environment
#                 env.step(action)

#                 # Calculate the next observation and the total reward for the episode
#                 next_observation, reward, termination, truncation, info = env.last()
                
#                 # Store transition
#                 replay_buffer.append((total_obs[agent], action, reward, done, env.state()))
#                 total_obs[agent] = next_observation

#                 episode_reward += reward

#                 # Training step
#                 if len(replay_buffer) >= BATCH_SIZE:
#                     batch = np.random.choice(len(replay_buffer), BATCH_SIZE, replace=False)
#                     transitions = [replay_buffer[idx] for idx in batch]
#                     obs_batch, action_batch, reward_batch, done_batch, state_batch = zip(*transitions)

#                     obs_batch = torch.tensor(obs_batch, dtype=torch.float32)
#                     action_batch = torch.tensor(action_batch, dtype=torch.long)
#                     reward_batch = torch.tensor(reward_batch, dtype=torch.float32)
#                     done_batch = torch.tensor(done_batch, dtype=torch.float32)
#                     state_batch = torch.tensor(state_batch, dtype=torch.float32)

#                     # Forward pass
#                     q_values, global_q = model(obs_batch, state_batch)

#                     # Compute targets
#                     with torch.no_grad():
#                         target_q_values, _ = model(obs_batch, state_batch)
#                         target_values = reward_batch + (1 - done_batch) * GAMMA * target_q_values.max(dim=-1)[0]

#                     # Loss calculation
#                     predicted_q_values = q_values.gather(1, action_batch.unsqueeze(-1)).squeeze(-1)
#                     loss = nn.MSELoss()(predicted_q_values, target_values)

#                     # Optimize model
#                     optimizer.zero_grad()
#                     loss.backward()
#                     optimizer.step()
#                     running_loss += loss.item()
#                 else:
#                     action = np.random.randint(action_dim)
#                     env.step(action)

#     # Log statistics
#     print(f"Episode {i_episode}/{NUM_EPISODES}, Reward: {episode_reward}, Loss: {running_loss}")
