# Predator-Prey Pursuit: A Multi-Agent Reinforcement Learning Approach

## Introduction

In this notebook, we'll implement a classic Multi-Agent Reinforcement Learning (MARL) scenario: Predator-Prey Pursuit. This problem involves both cooperative and competitive elements, making it an intuitive starting point for exploring MARL techniques.

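Our scenario consists of multiple predator agents working together to capture prey agents in a grid world environment. Prey move uniformly at random; only the predators learn. We'll use Independent Q-Learning as our MARL approach: each predator trains its own Q-network and treats the other agents as part of the environment.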

## Environment Setup

First, let's set up our environment. We'll use a simple grid world where predators and prey can move in four directions: up, down, left, and right.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple

class PredatorPreyEnv:
    def __init__(self, grid_size: int, n_predators: int, n_prey: int):
        self.grid_size = grid_size
        self.n_predators = n_predators
        self.n_prey = n_prey
        self.initial_n_prey = n_prey  # Store the initial number of prey
        self.reset()

    def reset(self) -> np.ndarray:
        self.predator_positions = np.random.randint(0, self.grid_size, size=(self.n_predators, 2))
        self.prey_positions = np.random.randint(0, self.grid_size, size=(self.initial_n_prey, 2))
        return self._get_observations()

    def step(self, predator_actions: List[int]) -> Tuple[np.ndarray, List[float], bool, dict]:
        # Move predators
        for i, action in enumerate(predator_actions):
            self._move_agent(self.predator_positions[i], action)

        # Move prey (random movement)
        for i in range(len(self.prey_positions)):
            self._move_agent(self.prey_positions[i], np.random.randint(0, 4))

        # Check for captures: a prey is caught when any predator shares its cell
        rewards = [0.0] * self.n_predators
        captured_prey = []
        for i, prey_pos in enumerate(self.prey_positions):
            if any(np.all(prey_pos == pred_pos) for pred_pos in self.predator_positions):
                # Shared team reward: 1.0 on any step with at least one capture
                rewards = [1.0] * self.n_predators
                captured_prey.append(i)

        # Remove captured prey
        self.prey_positions = np.delete(self.prey_positions, captured_prey, axis=0)

        done = len(self.prey_positions) == 0
        return self._get_observations(), rewards, done, {}

    def _move_agent(self, position: np.ndarray, action: int):
        if action == 0:  # Up
            position[0] = max(0, position[0] - 1)
        elif action == 1:  # Down
            position[0] = min(self.grid_size - 1, position[0] + 1)
        elif action == 2:  # Left
            position[1] = max(0, position[1] - 1)
        elif action == 3:  # Right
            position[1] = min(self.grid_size - 1, position[1] + 1)

    def _get_observations(self) -> np.ndarray:
        obs = np.zeros((self.n_predators, 2 * (self.n_predators + self.initial_n_prey)))
        for i in range(self.n_predators):
            obs[i, :2*self.n_predators] = self.predator_positions.flatten()
            prey_obs = self.prey_positions.flatten()
            obs[i, 2*self.n_predators:2*(self.n_predators + len(self.prey_positions))] = prey_obs
            # The remaining elements are already zero-padded
        return obs

    def render(self):
        grid = np.zeros((self.grid_size, self.grid_size, 3))
        for pos in self.predator_positions:
            grid[pos[0], pos[1]] = [1, 0, 0]  # Red for predators
        for pos in self.prey_positions:
            grid[pos[0], pos[1]] = [0, 1, 0]  # Green for prey
        plt.imshow(grid)
        plt.show()

# Test the environment
env = PredatorPreyEnv(grid_size=10, n_predators=3, n_prey=2)
obs = env.reset()
env.render()
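
The layout of each observation row is easy to miss in `_get_observations`, so here is a minimal, NumPy-free sketch of it. The helper name `build_observation` and its `max_prey` parameter are hypothetical and used only for illustration; the layout itself (all predator coordinates, then all remaining prey coordinates, then zero padding) follows the environment code above.

```python
from typing import List, Tuple

Position = Tuple[int, int]  # (row, col), matching the environment's convention


def build_observation(predators: List[Position],
                      prey: List[Position],
                      max_prey: int) -> List[float]:
    """Flatten positions into [predator coords..., prey coords..., zero padding].

    Hypothetical helper that mirrors the row layout of `_get_observations`.
    """
    obs = [0.0] * (2 * (len(predators) + max_prey))
    flat = [float(c) for pos in predators + prey for c in pos]
    obs[:len(flat)] = flat  # slots of captured prey stay at zero
    return obs


# Three predators, one prey left out of an initial two
example_obs = build_observation([(1, 2), (3, 4), (5, 6)], [(7, 8)], max_prey=2)
```

Two consequences of this layout are worth keeping in mind: every predator receives the same vector, so nothing in the observation tells an agent which predator it is, and the zero padding of a captured prey looks the same as a live prey at cell (0, 0).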

## Agent Implementation

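Now that we have our environment, let's implement our predator agents using Independent Q-Learning. Each agent approximates its Q-function with a small feedforward neural network (two hidden layers of 64 units), and stabilizes training with an experience replay buffer and a periodically synchronized target network.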
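
For reference, the update that the agent's `update` method performs on each sampled minibatch can be written as follows. The notation is ours and introduced only for this summary: $\theta$ denotes the online network's parameters, $\bar\theta$ the target network's parameters, $d \in \{0, 1\}$ the `done` flag, and $B$ the batch size.

$$
y = r + \gamma \,(1 - d)\, \max_{a'} Q_{\bar\theta}(s', a'),
\qquad
\mathcal{L}(\theta) = \frac{1}{B} \sum_{(s, a, r, s', d)} \bigl( Q_{\theta}(s, a) - y \bigr)^2
$$

With the hyperparameters used in this notebook, $\gamma = 0.99$, $B = 32$, and $\bar\theta$ is overwritten with $\theta$ every 100 updates.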

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class PredatorAgent:
    def __init__(self, state_dim, action_dim):
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters())
        self.memory = deque(maxlen=10000)
        self.batch_size = 32
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.update_target_every = 100
        self.update_counter = 0

    def get_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, 3)
        with torch.no_grad():
            q_values = self.q_network(torch.FloatTensor(state))
            return q_values.argmax().item()

    def update(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

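        # Stack the NumPy rows into one array first; building a tensor directly
        # from a tuple of arrays is slow and triggers a PyTorch warning.
        # (np is imported in the environment cell above.)
        states = torch.tensor(np.array(states), dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.int64)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        # Q(s, a) for the actions that were actually taken
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Bootstrapped target from the target network; no gradient is needed here
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = nn.MSELoss()(current_q_values, target_q_values)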
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_counter += 1
        if self.update_counter % self.update_target_every == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# Initialize environment and agents
env = PredatorPreyEnv(grid_size=10, n_predators=3, n_prey=2)
state_dim = 2 * (env.n_predators + env.n_prey)
action_dim = 4
agents = [PredatorAgent(state_dim, action_dim) for _ in range(env.n_predators)]
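
Note that `epsilon` is decayed inside `update`, so it shrinks once per learning update rather than once per episode. The sketch below counts how many decays it takes to reach the floor with the values set in `PredatorAgent`; the function name `updates_until_floor` is hypothetical and not part of the notebook's classes.

```python
def updates_until_floor(epsilon: float, decay: float, floor: float) -> int:
    """Count multiplicative decay steps until epsilon is at or below the floor."""
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must lie strictly between 0 and 1")
    steps = 0
    while epsilon > floor:
        epsilon *= decay
        steps += 1
    return steps


# Values taken from PredatorAgent: epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01
n_updates = updates_until_floor(1.0, 0.995, 0.01)
print(n_updates)
```

Under these settings each agent stops exploring after fewer than a thousand updates, which at up to 100 steps per episode is on the order of ten episodes. If that turns out to be too aggressive, one option is to move the decay to the end of each episode or to use a decay factor closer to 1.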

## Training Loop

Now that we have our environment and agents set up, let's implement the training loop.

In [None]:
n_episodes = 1000
max_steps = 100
episode_rewards = []

for episode in range(n_episodes):
    state = env.reset()
    episode_reward = 0
    
    for step in range(max_steps):
        actions = [agent.get_action(s) for agent, s in zip(agents, state)]
        next_state, rewards, done, _ = env.step(actions)
        episode_reward += sum(rewards)
        
        for i, agent in enumerate(agents):
            agent.update(state[i], actions[i], rewards[i], next_state[i], done)
        
        state = next_state
        
        if done:
            break
    
    episode_rewards.append(episode_reward)
    
    if episode % 100 == 0:
        print(f"Episode {episode}, Avg Reward: {np.mean(episode_rewards[-100:]):.2f}")

# Plot the learning curve
plt.plot(episode_rewards)
plt.title("Learning Curve")
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.show()
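
Per-episode rewards are noisy, so the raw curve can be hard to read. A trailing average is one simple way to smooth it; the helper below is a plain-Python sketch (the name `trailing_mean` is our own), and the window of 100 in the usage note simply mirrors the averaging window of the progress printout in the training loop.

```python
from typing import List, Sequence


def trailing_mean(values: Sequence[float], window: int) -> List[float]:
    """Average of the last `window` values at each index (shorter at the start)."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        smoothed.append(running / min(i + 1, window))
    return smoothed


demo_curve = trailing_mean([1, 2, 3, 4, 5], window=2)
```

In the notebook, something like `plt.plot(trailing_mean(episode_rewards, 100))` could be drawn on top of the raw curve.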

## Evaluation

Let's evaluate our trained agents by running a few episodes and visualizing their behavior.

In [None]:
def evaluate(env, agents, n_episodes=5):
    for episode in range(n_episodes):
        state = env.reset()
        done = False
        step = 0
        
        while not done and step < 100:
            env.render()
            actions = [agent.get_action(s) for agent, s in zip(agents, state)]
            state, _, done, _ = env.step(actions)
            step += 1
        
        print(f"Episode {episode + 1} finished in {step} steps")

evaluate(env, agents)
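
Because `get_action` is epsilon-greedy, the evaluation above still takes a random action with probability `epsilon`. If a purely greedy evaluation is wanted, one possible approach is to zero `epsilon` temporarily and restore it afterwards. The context manager below is a sketch under that assumption; `greedy` and `_StubAgent` are hypothetical names, and the stub only stands in for `PredatorAgent` so the example runs on its own.

```python
from contextlib import contextmanager


@contextmanager
def greedy(agents):
    """Temporarily set epsilon to 0 on every agent, then restore the old values."""
    saved = [agent.epsilon for agent in agents]
    for agent in agents:
        agent.epsilon = 0.0
    try:
        yield agents
    finally:
        for agent, epsilon in zip(agents, saved):
            agent.epsilon = epsilon


class _StubAgent:
    """Stand-in exposing only the attribute that `greedy` touches."""

    def __init__(self, epsilon: float):
        self.epsilon = epsilon


demo_agents = [_StubAgent(0.01), _StubAgent(0.2)]
with greedy(demo_agents):
    inside = [agent.epsilon for agent in demo_agents]
after = [agent.epsilon for agent in demo_agents]
```

With the notebook's objects this would read `with greedy(agents): evaluate(env, agents)`.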

## Conclusion

In this notebook, we've implemented a basic Predator-Prey Pursuit scenario using Independent Q-Learning, with each predator training its own Q-network. Because every predator receives the same reward whenever any prey is captured, the agents are trained toward a shared goal in a grid world environment.

Some observations and potential improvements:

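1. The predators are trained to pursue and capture prey, but their strategy might not be optimal: prey move randomly, so some captures happen by chance, and the learning curve alone does not separate learned pursuit from luck.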
2. We could improve coordination by implementing more advanced MARL techniques like QMIX or MADDPG.
3. The environment is relatively simple. We could increase complexity by adding obstacles or giving prey more sophisticated evasion strategies.
4. We're using a simple feedforward neural network. Recurrent neural networks (RNNs) might perform better by allowing agents to remember past observations.
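
As an illustration of the third point, a simple evasion rule could replace the random prey move with the action that maximizes the Manhattan distance to the nearest predator. This is only a sketch of one possible heuristic, not part of the environment above; `evasive_action` and `MOVES` are hypothetical names, the action encoding is copied from `_move_agent`, and at least one predator is assumed.

```python
from typing import List, Tuple

Position = Tuple[int, int]  # (row, col)

# Same action encoding as the environment: 0=up, 1=down, 2=left, 3=right
MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}


def evasive_action(prey: Position, predators: List[Position], grid_size: int) -> int:
    """Pick the move whose resulting cell is farthest from the nearest predator."""

    def clip(value: int) -> int:
        return max(0, min(grid_size - 1, value))

    def after(action: int) -> Position:
        d_row, d_col = MOVES[action]
        return (clip(prey[0] + d_row), clip(prey[1] + d_col))

    def nearest_predator_distance(pos: Position) -> int:
        return min(abs(pos[0] - p[0]) + abs(pos[1] - p[1]) for p in predators)

    # Ties are broken in favor of the lowest action index
    return max(MOVES, key=lambda action: nearest_predator_distance(after(action)))


# Prey in the top-left corner with a predator directly to its right
corner_escape = evasive_action((0, 0), [(0, 1)], grid_size=10)
```

In the corner example, moving up or left is blocked by the wall and moving right lands on the predator, so the rule chooses to move down.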

In the next notebooks, we'll explore more complex scenarios and advanced MARL techniques to address some of these limitations.

## References

1. Tan, M. (1993). Multi-agent reinforcement learning: Independent vs. cooperative agents. In Proceedings of the Tenth International Conference on Machine Learning (pp. 330-337).
2. Tampuu, A., Matiisen, T., Kodelja, D., Kuzovkin, I., Korjus, K., Aru, J., ... & Vicente, R. (2017). Multiagent cooperation and competition with deep reinforcement learning. PLoS ONE, 12(4), e0172395.