# MARL with Communication Protocols

## Introduction

In this notebook, we'll explore how to implement communication protocols in Multi-Agent Reinforcement Learning (MARL) systems. Effective communication between agents can significantly improve their ability to collaborate and solve complex tasks. In the context of app modernization, this could represent different components of the modernization process sharing information and coordinating their actions.

## Setup

First, let's import the necessary libraries and set up our environment.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
# Debugging aid: reports the forward operation behind any autograd error (slows training)
torch.autograd.set_detect_anomaly(True)

## Implementing Agents with Communication

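We'll create a simple environment in which each agent observes its own state and the latest messages from the other agents. Agents exchange continuous message vectors through a shared channel. In this implementation the channel detaches each message before storing it, so gradients do not flow from a receiver back to the sender.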

In [92]:
# Hyperparameters
num_agents = 3
state_dim = 10
hidden_dim = 64
action_dim = 5
message_dim = 8

class CommunicatingAgent(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, message_dim, num_agents):
        super(CommunicatingAgent, self).__init__()
        input_dim = state_dim + (num_agents - 1) * message_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.message_head = nn.Linear(hidden_dim, message_dim)
        
    def forward(self, state, messages):
        # Flatten received messages and concatenate with state
        x = torch.cat([state, messages.flatten(1)], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        action = self.action_head(x)
        new_message = self.message_head(x)
        return action, new_message

class CommunicationChannel:
    def __init__(self, num_agents, message_dim):
        self.num_agents = num_agents
        self.message_dim = message_dim
        self.reset()
        
    def reset(self):
        self.messages = torch.zeros(self.num_agents, self.message_dim).detach()

    def send(self, agent_id, message):
        # Store a detached copy, so no gradient flows back to the sender
        self.messages[agent_id] = message.detach().reshape(-1)

    def receive(self, agent_id):
        # Return the other agents' messages, shape (num_agents - 1, message_dim)
        return torch.cat([self.messages[:agent_id], self.messages[agent_id+1:]], dim=0).detach()

class Environment:
    def __init__(self, num_agents, state_dim, action_dim):
        self.num_agents = num_agents
        self.state_dim = state_dim
        self.action_dim = action_dim

    def reset(self):
        return torch.randn(self.num_agents, self.state_dim)

    def step(self, actions):
        # Simplified environment dynamics
        next_states = torch.randn(self.num_agents, self.state_dim)
        rewards = torch.sum(actions, dim=-1)  # Reward is sum of action values
        done = False
        return next_states, rewards, done

# Initialize agents and environment
agents = [CommunicatingAgent(state_dim, hidden_dim, action_dim, message_dim, num_agents) for _ in range(num_agents)]
comm_channel = CommunicationChannel(num_agents, message_dim)
env = Environment(num_agents, state_dim, action_dim)
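
The input size `state_dim + (num_agents - 1) * message_dim` follows from how each agent's input is assembled: its own state, plus one message from every other agent. The standalone sketch below reproduces that assembly with plain tensors so the shapes can be checked; `all_messages` and `agent_input` are illustrative names, not part of the classes above.

```python
import torch

# Same hyperparameter values as the cell above
num_agents, state_dim, message_dim = 3, 10, 8

# One message slot per agent, initialised to zeros
all_messages = torch.zeros(num_agents, message_dim)
all_messages[1] = torch.ones(message_dim)  # pretend agent 1 has sent something

agent_id = 0
# Drop the agent's own row, keeping the other agents' messages in order
received = torch.cat([all_messages[:agent_id], all_messages[agent_id + 1:]], dim=0)

state = torch.randn(1, state_dim)         # batch containing one state
flat = received.unsqueeze(0).flatten(1)   # (1, (num_agents - 1) * message_dim)
agent_input = torch.cat([state, flat], dim=1)

print(received.shape)     # torch.Size([2, 8])
print(agent_input.shape)  # torch.Size([1, 26])
```

Because the messages are flattened in a fixed order, the first layer's width is tied to `num_agents`; changing the number of agents means rebuilding the network.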

## Training Loop

Now, let's implement a training loop where agents learn to communicate and collaborate.

In [None]:
def train_agents(num_episodes, max_steps):
    optimizers = [optim.Adam(agent.parameters(), lr=0.001) for agent in agents]
    episode_rewards = []

    for episode in range(num_episodes):
        states = env.reset()
        comm_channel.reset()
        episode_reward = 0

        for step in range(max_steps):
            actions = []
            messages = []

            for i, agent in enumerate(agents):
                # Each agent sees its own state plus the other agents' latest messages.
                # Agents act in order, so later agents already see this step's earlier messages.
                received_messages = comm_channel.receive(i)
                action, message = agent(states[i].unsqueeze(0), received_messages.unsqueeze(0))
                actions.append(action.squeeze(0))
                messages.append(message.squeeze(0))
                comm_channel.send(i, message.squeeze(0))

            actions = torch.stack(actions)  # Shape: (num_agents, action_dim)
            next_states, rewards, done = env.step(actions)
            episode_reward += rewards.sum().item()

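            # Placeholder objective: penalize each agent's squared action magnitude
            # (it is not derived from the environment reward).
            # Run the backward pass before any optimizer step: the stacked actions
            # share one graph, and an in-place weight update between backward calls
            # can invalidate it.
            for optimizer in optimizers:
                optimizer.zero_grad()
            loss = (actions ** 2).sum()  # Per-agent graphs are disjoint, so one backward suffices
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()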

            states = next_states.detach()  # Detach next_states to avoid graph retention

            if done:
                break

        episode_rewards.append(episode_reward)
        if episode % 100 == 0:
            print(f"Episode {episode}, Avg Reward: {np.mean(episode_rewards[-100:]):.2f}")

    return episode_rewards

# Train the agents
num_episodes = 1000
max_steps = 50
rewards = train_agents(num_episodes, max_steps)

# Plot the learning curve
plt.plot(rewards)
plt.title("Learning Curve")
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.show()
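
The channel above stores detached messages, so a receiver's loss never produces a gradient for the sender. The sketch below is a minimal, hypothetical two-agent comparison (a single `sender` layer and a single `receiver` layer, not the `CommunicatingAgent` class) showing what changes when the message is left attached, which is the idea behind learning communication by backpropagation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
state_dim, message_dim, action_dim = 10, 8, 5

# Hypothetical two-agent setup: a sender that only emits a message,
# and a receiver that acts on its own state plus that message.
sender = nn.Linear(state_dim, message_dim)
receiver = nn.Linear(state_dim + message_dim, action_dim)

sender_state = torch.randn(1, state_dim)
receiver_state = torch.randn(1, state_dim)

def sender_grad_norm(detach_message):
    """Return the norm of the sender's weight gradient under the receiver's loss."""
    sender.zero_grad()
    receiver.zero_grad()
    message = sender(sender_state)
    if detach_message:
        message = message.detach()  # what the channel above does
    action = receiver(torch.cat([receiver_state, message], dim=1))
    loss = (action ** 2).sum()      # same placeholder loss as the training loop
    loss.backward()
    grad = sender.weight.grad
    return 0.0 if grad is None else grad.norm().item()

print(sender_grad_norm(detach_message=True))   # 0.0: the sender gets no learning signal
print(sender_grad_norm(detach_message=False))  # positive: the gradient reaches the sender
```

Keeping messages attached means the agents' graphs are no longer independent, so a training loop built this way would typically accumulate one joint loss and call `backward()` once per update.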

## Analyzing Communication

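Let's inspect the messages the trained agents emit by plotting each agent's message vector over time. Because the training loss above depends only on actions and the channel detaches messages, the message head itself receives no gradient; messages change during training only through the shared hidden layers.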

In [None]:
def visualize_messages(num_steps=10):
    states = env.reset()
    comm_channel.reset()
    message_history = []

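    # States are held fixed here, so changes over time come from message feedback alone
    with torch.no_grad():  # Inference only, no graph needed
        for _ in range(num_steps):
            step_messages = []
            for i, agent in enumerate(agents):
                received_messages = comm_channel.receive(i)
                # Add a batch dimension, matching the call in the training loop
                _, message = agent(states[i].unsqueeze(0), received_messages.unsqueeze(0))
                message = message.squeeze(0)
                step_messages.append(message.numpy())
                comm_channel.send(i, message)
            message_history.append(step_messages)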

    message_history = np.array(message_history)

    fig, axes = plt.subplots(num_agents, 1, figsize=(10, 5*num_agents))
    for i in range(num_agents):
        im = axes[i].imshow(message_history[:, i, :].T, aspect='auto', cmap='viridis')
        axes[i].set_title(f"Agent {i+1} Messages")
        axes[i].set_xlabel("Time Step")
        axes[i].set_ylabel("Message Dimension")
        plt.colorbar(im, ax=axes[i])
    plt.tight_layout()
    plt.show()

visualize_messages()
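
A heatmap can be complemented with a numeric summary. The sketch below is an illustration under assumptions: it uses random stand-in data with the same `(num_steps, num_agents, message_dim)` layout as `message_history`, plus one hypothetical constant dimension, and flags message dimensions whose value never changes over time.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for message_history, shape (num_steps, num_agents, message_dim)
history = rng.normal(size=(10, 3, 8))
history[:, 0, 2] = 0.5  # hypothetical dimension that agent 0 never varies

# Variance over time for each agent and message dimension
variance_over_time = history.var(axis=0)  # shape (num_agents, message_dim)

# Dimensions with (near-)zero variance carry no time-varying information
inactive = np.argwhere(variance_over_time < 1e-8)

print(variance_over_time.shape)  # (3, 8)
print(inactive)                  # [[0 2]]
```

Each row of `inactive` is an `(agent, dimension)` pair; on real data the threshold `1e-8` is an arbitrary choice and would need tuning.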

## Conclusion

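In this notebook, we implemented a simple MARL system in which agents exchange message vectors through a shared channel, trained it on a toy objective, and visualized the resulting messages. With a task-driven reward and a channel that passes gradients between agents, the same structure could support learned communication and collaboration. This approach has several potential applications in app modernization: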

1. Coordinating different aspects of the modernization process (e.g., code refactoring, database migration, and infrastructure updates)
2. Sharing information about dependencies and potential conflicts during the modernization process
3. Collaborative decision-making for architectural choices in the modernized application

Future work could involve:
- Implementing more sophisticated communication protocols (e.g., attention mechanisms)
- Applying this approach to specific app modernization tasks
- Analyzing the emergent communication patterns in the context of software engineering practices
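
As one possible reading of the attention idea in the list above, the sketch below pools received messages with scaled dot-product attention instead of flattening them. `AttentionAggregator` and `key_dim` are hypothetical names introduced for illustration; nothing here is part of the notebook's agents.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionAggregator(nn.Module):
    """Hypothetical module: weights received messages by relevance to the agent's state."""

    def __init__(self, state_dim, message_dim, key_dim=16):
        super().__init__()
        self.query = nn.Linear(state_dim, key_dim)
        self.key = nn.Linear(message_dim, key_dim)
        self.scale = key_dim ** 0.5

    def forward(self, state, messages):
        # state: (batch, state_dim); messages: (batch, num_senders, message_dim)
        q = self.query(state).unsqueeze(1)           # (batch, 1, key_dim)
        k = self.key(messages)                       # (batch, num_senders, key_dim)
        scores = (q * k).sum(dim=-1) / self.scale    # (batch, num_senders)
        weights = F.softmax(scores, dim=-1)          # attention over senders
        pooled = (weights.unsqueeze(-1) * messages).sum(dim=1)  # (batch, message_dim)
        return pooled, weights

aggregator = AttentionAggregator(state_dim=10, message_dim=8)
pooled, weights = aggregator(torch.randn(4, 10), torch.randn(4, 2, 8))
print(pooled.shape)         # torch.Size([4, 8])
print(weights.sum(dim=-1))  # each row of weights sums to 1
```

With pooling, an agent's input size would be `state_dim + message_dim` regardless of how many agents are sending, at the cost of losing the fixed per-sender ordering that flattening preserves.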

## References

1. Foerster, J., et al. (2016). Learning to communicate with deep multi-agent reinforcement learning. NeurIPS.
2. Sukhbaatar, S., Szlam, A., & Fergus, R. (2016). Learning multiagent communication with backpropagation. NeurIPS.