# Counterfactual Multi-agent Policy Gradients (COMA) for the ma-gym env
Author: Christina Kouridi

We will implement [Counterfactual Multi-Agent Policy Gradients (COMA)](https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/17193/16614) to play in the `Combat` environment from ma-gym. Please read [the wiki of ma-gym](https://github.com/koulanurag/ma-gym/wiki/) for additional information on this environment.

## Environment overview

### Environment Description
![Combat](https://github.com/koulanurag/ma-gym/raw/master/static/gif/Combat-v0.gif)


The Combat environment simulates a battle between two opposing teams in a `20×20` grid. Each team consists of `m = 10` agents, and their initial positions are sampled uniformly in a `5×5` square around the team center, which is picked uniformly in the grid.

#### Action Space
At each time step, an agent can perform one of the following actions:
*   move one cell in one of four directions;
*   attack another agent by specifying its ID `j` (there are `m` attack actions, each corresponding to one enemy agent);
*   do nothing.
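
As a quick check of the list above, four moves, `m` attack actions and one no-op give `m + 5` discrete actions per agent. The sketch below only does this arithmetic. The helper name `count_actions` is hypothetical, and nothing is assumed about how ma-gym orders the actions internally.

```python
def count_actions(n_opponents: int) -> int:
    """Number of discrete actions per agent, following the list above."""
    n_moves = 4              # one cell in each of four directions
    n_attacks = n_opponents  # one attack action per enemy ID
    n_noop = 1               # do nothing
    return n_moves + n_attacks + n_noop


print(count_actions(10))  # 15 for the m = 10 setting
```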

#### Transition Dynamics
If agent A attacks agent B, then B's health is reduced by 1, but only if B is inside the firing range of A (its surrounding `3×3` area). Agents need one time step of cooling down after an attack, during which they cannot attack. All agents start with three health points and die when their health reaches 0. A team wins if all agents in the other team die. The simulation ends when one team wins, or when neither team wins within 40 time steps (a draw).
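
The rules above can be sketched in a few lines of plain Python. This is an illustration only, not the ma-gym implementation: the `Unit` class and the helpers `in_firing_range`, `attack` and `cool_down` are hypothetical names, and the sketch assumes that every executed attack triggers the cooldown, whether or not it hits.

```python
from dataclasses import dataclass


@dataclass
class Unit:
    row: int
    col: int
    health: int = 3    # all agents start with three health points
    cooldown: int = 0  # time steps left before the unit may attack again


def in_firing_range(attacker: Unit, target: Unit) -> bool:
    """True if the target lies in the 3x3 area centred on the attacker."""
    return (abs(attacker.row - target.row) <= 1
            and abs(attacker.col - target.col) <= 1)


def attack(attacker: Unit, target: Unit) -> bool:
    """Try one attack; return True if the target lost a health point."""
    if attacker.health <= 0 or attacker.cooldown > 0:
        return False  # dead or still cooling down
    attacker.cooldown = 1  # assumption: every executed attack triggers the cooldown
    if target.health > 0 and in_firing_range(attacker, target):
        target.health -= 1
        return True
    return False


def cool_down(unit: Unit) -> None:
    """Let one time step pass without attacking."""
    unit.cooldown = max(0, unit.cooldown - 1)


a, b = Unit(5, 5), Unit(6, 6)
print(attack(a, b), b.health)  # True 2
print(attack(a, b), b.health)  # False 2 (a is cooling down)
cool_down(a)
print(attack(a, b), b.health)  # True 1
```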

#### Observation Space
When given as input to a model, each agent is represented by a set of one-hot binary vectors `{i, t, l, h, c}` encoding its unique ID, team ID, location, health points, and cooldown. A model controlling an agent also sees other agents in its visual range (`3×3` surrounding area).

#### Reward Settings
The model gets a reward of −1 if the team loses or draws at the end of the game. In addition, it gets a reward of −0.1 times the total health points of the enemy team, which encourages it to attack enemy bots.
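
Read literally, the description above gives the following end-of-game reward. This is a sketch under that reading: the function name `terminal_reward` is hypothetical, both terms are assumed to be applied once at the end of the game, and the actual ma-gym implementation may assign rewards differently.

```python
def terminal_reward(team_won: bool, enemy_health_total: int) -> float:
    """End-of-game reward for the controlled team, as described above."""
    reward = 0.0 if team_won else -1.0   # -1 for a loss or a draw
    reward += -0.1 * enemy_health_total  # penalty for remaining enemy health
    return reward


# A win with every enemy dead carries no penalty.
print(terminal_reward(team_won=True, enemy_health_total=0))  # 0.0
# A draw against ten untouched enemies (3 health points each).
print(round(terminal_reward(team_won=False, enemy_health_total=30), 2))  # -4.0
```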

#### Enemy Settings
The model controls one team during training, and the other team consists of bots that follow a hardcoded policy. The bot policy is to attack the nearest enemy agent if it is within its firing range. If not, it approaches the nearest visible enemy agent within the visual range. An agent is visible to all bots if it is inside the visual range of any individual bot. This shared vision gives an advantage to the bot team.
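
The bot policy described above can be sketched as follows. This is an illustration, not the ma-gym code: `grid_distance` and `bot_action` are hypothetical helpers, "nearest" is assumed to mean Chebyshev distance on the grid, ties are broken by the lower ID, and a bot with no visible enemy is assumed to do nothing, since the description does not cover that case.

```python
def grid_distance(a, b):
    """Chebyshev distance, so the 3x3 firing range is distance <= 1."""
    return max(abs(a[0] - b[0]), abs(a[1] - b[1]))


def bot_action(bot_pos, visible_enemies):
    """Pick an action for one bot.

    visible_enemies maps enemy ID -> (row, col) and is assumed to be
    already pooled over the visual ranges of all bots (shared vision).
    """
    if not visible_enemies:
        return ("noop", None)  # assumption: the description does not cover this case

    # Nearest visible enemy; ties are broken by the lower ID (assumption).
    target_id = min(visible_enemies,
                    key=lambda i: (grid_distance(bot_pos, visible_enemies[i]), i))
    target = visible_enemies[target_id]

    if grid_distance(bot_pos, target) <= 1:
        return ("attack", target_id)

    # Otherwise step one cell along the axis with the larger gap.
    d_row, d_col = target[0] - bot_pos[0], target[1] - bot_pos[1]
    if abs(d_row) >= abs(d_col):
        return ("move", (1 if d_row > 0 else -1, 0))
    return ("move", (0, 1 if d_col > 0 else -1))


print(bot_action((5, 5), {3: (5, 6)}))             # ('attack', 3)
print(bot_action((5, 5), {3: (5, 8), 4: (9, 5)}))  # ('move', (0, 1))
```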

## Environment Set-up
The following command will download the required scripts and set up the environment. 

Execute the following block until you see a WARNING and a button asking to **RESTART RUNTIME**. Click the button and wait for the runtime to restart, **then proceed to the next block of code.**

In [None]:
!rm -rf /content/ma-gym  
!git clone https://github.com/koulanurag/ma-gym.git
%cd /content/ma-gym 
!pip install -q -e .

Run the following block of code to continue environment set-up AFTER you have restarted the RUNTIME from the previous block.

In [2]:
!apt-get install -y xvfb python-opengl x11-utils > /dev/null 2>&1
!pip install pyvirtualdisplay > /dev/null 2>&1
!apt-get install x11-utils
!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install -U gym[atari] > /dev/null 2>&1

Reading package lists... Done
Building dependency tree       
Reading state information... Done
x11-utils is already the newest version (7.7+3build1).
The following package was automatically installed and is no longer required:
  libnvidia-common-440
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 33 not upgraded.
Collecting setuptools
  Downloading https://files.pythonhosted.org/packages/41/fa/60888a1d591db07bc9c17dce2bcfb9f00ac507c0a23ecb827e76feb8f816/setuptools-49.1.0-py3-none-any.whl (789kB)
     |████████████████████████████████| 798kB 5.1MB/s 
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
Installing collected packages: setuptools
  Found existing installation: setuptools 47.3.1
    Uninstalling setuptools-47.3.1:
      Successfully uninstalled setuptools-47.3.1
Successfully installed setuptools-49.1.0


In [3]:
%matplotlib inline
import gym
import ma_gym
from ma_gym.wrappers import Monitor
from ma_gym.envs.combat.combat import Combat
import matplotlib.pyplot as plt
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay

from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

"""
Utility functions to enable video recording of gym environment and displaying it
To enable video, just do "env = wrap_env(env)"
"""

def show_video():
  mp4list = glob.glob('video/*.mp4')
  if len(mp4list) > 0:
    mp4 = mp4list[0]
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
    

def wrap_env(env):
  env = Monitor(env, './video', force=True)
  return env

#### Example of Playing Combat Using a Random Policy

In [4]:
#################################################################
#####Here we have changed n_agents=20, n_opponents=20 to 10######
##################Updated on 19/20/2020##########################
env = wrap_env(Combat(grid_shape=(20, 20), n_agents=10, n_opponents=10))
done_n = [False for _ in range(env.n_agents)]
ep_reward = 0

obs_n = env.reset()
while not all(done_n):
    obs_n, reward_n, done_n, info = env.step(env.action_space.sample())
    ep_reward += sum(reward_n)
    env.render()
env.close()
# render() is not needed during training; skipping it makes training faster.
# The render and video code is provided here only to demonstrate how to debug and analyse an episode.
show_video()



## Code


### Imports

In [7]:
from collections import namedtuple
import datetime as dt

import gym
import ma_gym

import matplotlib.pyplot as plt

import numpy as np

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

### Model Arguments

In [8]:
Episode = namedtuple('Episode', ['obs', 'state', 'actions', 'actions_onehot', 'rewards', 'obs_next', 
                                 'state_next','terminated'])

# most values align with those given in the original paper
class Args:
    def __init__(self, n_actions=None, n_agents=None, state_shape=None,
                 obs_shape=None, seed=123, rnn_hidden_dim=64, critic_dim=128,
                 lr_actor=.001, lr_critic=.001, epsilon=0.6, anneal_epsilon=.0003, min_epsilon=.02,
                 td_lambda=0.9, grad_norm_clip=5.0, gamma=0.99, target_update_cycle=10,
                 log_every=50, n_episodes=2_000, evaluate=False):

        self.n_actions = n_actions
        self.n_agents = n_agents
        self.state_shape = state_shape
        self.obs_shape = obs_shape
        self.gamma = gamma
        self.evaluate = evaluate

        self.grad_norm_clip = grad_norm_clip

        self.cuda = torch.cuda.is_available()
        self.device = 'cuda' if self.cuda else 'cpu'

        self.seed = seed

        self.rnn_hidden_dim = rnn_hidden_dim
        self.critic_dim = critic_dim
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic

        self.epsilon = epsilon
        self.anneal_epsilon = anneal_epsilon
        self.min_epsilon = min_epsilon

        self.td_lambda = td_lambda

        self.n_episodes = n_episodes
        self.log_every = log_every
        self.target_update_cycle = target_update_cycle
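
To see what the default exploration settings in `Args` imply, the sketch below replays the annealing rule used later in the training loop (subtract `anneal_epsilon` while `epsilon > min_epsilon`) in plain Python. The helper name `epsilon_schedule` is hypothetical; the default values are copied from `Args`.

```python
def epsilon_schedule(n_episodes, epsilon=0.6, anneal_epsilon=0.0003, min_epsilon=0.02):
    """Return the epsilon used in each episode under linear annealing."""
    values = []
    for _ in range(n_episodes):
        if epsilon > min_epsilon:
            epsilon -= anneal_epsilon
        values.append(epsilon)
    return values


schedule = epsilon_schedule(2_000)
# First episode (1-indexed) at which epsilon is at or below min_epsilon.
first_floor = next(i for i, e in enumerate(schedule, start=1) if e <= 0.02)
print(first_floor)            # 1934
print(f'{schedule[-1]:.4f}')  # 0.0198
```

With these defaults, exploration is annealed over almost the whole 2,000-episode run. Because the check happens before the subtraction, the final value can end up slightly below `min_epsilon`.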

### Critic and Actor network definitions 

In [9]:
class ComaCritic(nn.Module):
    def __init__(self, input_shape, args):
        super(ComaCritic, self).__init__()

        self.fc1 = nn.Linear(input_shape, args.critic_dim)
        self.fc2 = nn.Linear(args.critic_dim, args.critic_dim)
        self.fc3 = nn.Linear(args.critic_dim, args.n_actions)

    def forward(self, inputs: torch.tensor) -> torch.tensor:
        # Note: using leaky RELU to mitigate vanishing gradients
        x = F.leaky_relu(self.fc1(inputs))
        x = F.leaky_relu(self.fc2(x))
        q = self.fc3(x)
        return q

class PolicyRnn(nn.Module):
    def __init__(self, input_shape: int, args):
        super(PolicyRnn, self).__init__()
        self.rnn_hidden_dim = args.rnn_hidden_dim

        self.fc1 = nn.Linear(input_shape, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
        self.fc2 = nn.Linear(self.rnn_hidden_dim, args.n_actions)

    def init_hidden(self) -> torch.tensor:
        # Note zero initialisation here
        return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_()

    def forward(self, obs: torch.tensor, hidden_state: torch.tensor) -> (torch.tensor, torch.tensor):
        h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)
        x = F.leaky_relu(self.fc1(obs))
        h = self.rnn(x, h_in)
        q = self.fc2(h)
        return q, h

### COMA learner

In [10]:
class ComaAgent:
    def __init__(self, args):
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape

        self.policy = Coma(args)

        self.args = args

    def act(self, obs: np.ndarray, last_action: np.ndarray, agent_num: int, epsilon: float,
            evaluate: bool = False) -> int:

        # Generate actor inputs >> concatenation of observation, previous action, agent id (one hot)
        agent_ids = np.zeros(self.n_agents)
        agent_ids[agent_num] = 1.

        inputs = np.hstack((obs, last_action, agent_ids))
        inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0).to(self.args.device)

        # Select hidden state for relevant agent
        hidden_state = self.policy.eval_hidden[agent_num, :].to(self.args.device)

        # Get Q values and update hidden state for agent
        q_value, self.policy.eval_hidden[agent_num, :] = self.policy.actor.forward(inputs, hidden_state)
        return self._sample_action(q_value.detach(), epsilon, evaluate, self.n_actions)

    def train(self, episode: Episode, train_step: int, epsilon: float = None) -> float:
        return self.policy.learn(episode=episode, train_step=train_step, epsilon=epsilon)

    @staticmethod
    def _sample_action(q_values, epsilon, evaluate, n_actions):
        prob = torch.nn.functional.softmax(q_values, dim=-1)  # convert the actor outputs (logits) into action probabilities

        if evaluate:  # in evaluate mode, act greedily and return the most probable action
            return torch.argmax(prob).cpu().item()
        else:  # otherwise mix in a uniform distribution over n_actions with weight epsilon, then sample
            prob = ((1 - epsilon) * prob + torch.ones_like(prob) * epsilon / n_actions)
            return Categorical(prob).sample().cpu().item()

class Coma:
    def __init__(self, args):
        self.args = args
        actor_input_shape, critic_input_shape = self._get_input_shapes(self.args)

        self.actor = PolicyRnn(actor_input_shape, args).to(self.args.device)

        self.online_critic = ComaCritic(critic_input_shape, self.args).to(self.args.device)
        self.target_critic = ComaCritic(critic_input_shape, self.args).to(self.args.device)
        self.target_critic.load_state_dict(self.online_critic.state_dict())

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.lr_actor)
        self.critic_optimizer = torch.optim.Adam(self.online_critic.parameters(), lr=args.lr_critic)

        self.eval_hidden = None

    @staticmethod
    def _get_input_shapes(args) -> (int, int):
        # Actor input >> observation + previous action + agent id
        actor_input_shape = args.obs_shape
        actor_input_shape += args.n_actions
        actor_input_shape += args.n_agents

        # Critic input >> state + agent observation + agent id + other agents' actions + all agents' previous actions
        critic_input_shape = args.state_shape
        critic_input_shape += args.obs_shape
        critic_input_shape += args.n_agents
        critic_input_shape += args.n_actions * args.n_agents * 2

        return actor_input_shape, critic_input_shape

    def init_hidden(self):
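        # expand() returns a view whose rows share memory; clone() gives each agent its own row
        # so that the per-agent in-place updates in ComaAgent.act() do not alias each other
        self.eval_hidden = self.actor.init_hidden().expand(self.args.n_agents, -1).clone()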

    def learn(self, episode: Episode, train_step: int, epsilon: float) -> float:
        self.init_hidden()

        q_values = self._train_critic(episode, train_step)
        action_prob = self._get_action_prob(episode, epsilon)
    
        q_taken = torch.gather(q_values, dim=2, index=episode.actions).squeeze(2)
        # get probabilities of actual actions taken
        pi_taken = torch.gather(action_prob, dim=2, index=episode.actions).squeeze(2)
        log_pi_taken = torch.log(pi_taken)

        # counterfactual baseline
        baseline = (q_values * action_prob).sum(dim=2, keepdim=True).squeeze(2).detach()
        advantage = (q_taken - baseline).detach()

        # policy loss using reinforce;  negative sign as we want to ascend
        loss = - (advantage * log_pi_taken).sum()

        self.actor_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.grad_norm_clip)  # clip gradients
        self.actor_optimizer.step()

        return loss.item()

    def _train_critic(self, episode: Episode, train_step: int):
        # Create an actions_next tensor by concatenating the action history, offset by 1, with a zero action
        # for the episode final timestep
        actions_next_offset = episode.actions[1:]
        padded_next_action = torch.zeros(*actions_next_offset[-1].shape, dtype=torch.long,
                                         device=self.args.device).unsqueeze(0)
        episode_actions_next = torch.cat((actions_next_offset, padded_next_action), dim=0)

        q_evals, q_next_target = self._get_q_values(episode)
        q_values = q_evals.clone()

        q_evals = torch.gather(q_evals, dim=2, index=episode.actions).squeeze(2)
        q_next_target = torch.gather(q_next_target, dim=2, index=episode_actions_next).squeeze(2)
        targets = td_lambda_target(episode, q_next_target.cpu(), self.args).to(self.args.device)

        td_error = targets.detach() - q_evals

        # the value loss is a simple L2 loss
        loss = (td_error ** 2).sum()

        self.critic_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online_critic.parameters(), self.args.grad_norm_clip)
        self.critic_optimizer.step()

        if train_step and not train_step % self.args.target_update_cycle:
            self.target_critic.load_state_dict(self.online_critic.state_dict())

        return q_values

    def _get_critic_inputs(self, episode: Episode, transition_idx: int):
        # Replicate the episode action histories for each agent
        episode_actions_onehot_repeated = episode.actions_onehot[transition_idx].view((1, -1)).repeat(
            self.args.n_agents, 1)

        # If the first transition in the episode, create a zero action vector for the previous action
        if transition_idx == 0:
            episode_actions_onehot_last_repeated = torch.zeros_like(episode_actions_onehot_repeated).to(
                self.args.device)
        else:
            episode_actions_onehot_last_repeated = episode.actions_onehot[transition_idx - 1].view((1, -1)).repeat(
                self.args.n_agents, 1)

        # If the last transition in the episode, then create a zero action vector
        if transition_idx != episode.obs.shape[0] - 1:
            episode_actions_onehot_next = episode.actions_onehot[transition_idx + 1]
        else:
            episode_actions_onehot_next = torch.zeros(*episode.actions_onehot[0].shape).to(self.args.device)

        episode_actions_onehot_next_repeated = episode_actions_onehot_next.view((1, -1)).repeat(self.args.n_agents, 1)

        episode_state_expanded = episode.state[transition_idx].unsqueeze(0).expand(self.args.n_agents, -1)
        episode_state_next_expanded = episode.state_next[transition_idx].unsqueeze(0).expand(self.args.n_agents, -1)

        # Invert identity matrix to marginalise out the current agent's actions
        action_mask = torch.tensor(1) - torch.eye(self.args.n_agents)
        # Note: .repeat_interleave() mirrors numpy's .repeat() behaviour (repeats elements of the tensor)
        action_mask = action_mask.view(-1).repeat_interleave(self.args.n_actions).view(
            self.args.n_agents, -1).to(self.args.device)

        inputs, inputs_next = [], []

        inputs.append(episode.obs[transition_idx])
        inputs.append(episode_state_expanded)
        inputs.append(episode_actions_onehot_last_repeated)
        inputs.append(episode_actions_onehot_repeated * action_mask.unsqueeze(0))
        inputs.append(torch.eye(self.args.n_agents).to(self.args.device))
        inputs = torch.cat([X.reshape(self.args.n_agents, -1) for X in inputs], dim=1)

        inputs_next.append(episode.obs_next[transition_idx])
        inputs_next.append(episode_state_next_expanded)
        # At the next step, the "previous actions" are the actions taken at the current step
        inputs_next.append(episode_actions_onehot_repeated)
        inputs_next.append(episode_actions_onehot_next_repeated * action_mask.unsqueeze(0))
        inputs_next.append(torch.eye(self.args.n_agents).to(self.args.device))
        inputs_next = torch.cat([X.reshape(self.args.n_agents, -1) for X in inputs_next], dim=1)

        return inputs.to(self.args.device), inputs_next.to(self.args.device)

    def _get_q_values(self, episode: Episode) -> (torch.tensor, torch.tensor):
        q_evals, q_targets = [], []

        for transition_idx in range(episode.obs.shape[0]):
            inputs, inputs_next = self._get_critic_inputs(episode, transition_idx)

            # Online net generates q values against the current state, target net against next state
            q_eval = self.online_critic(inputs).view(self.args.n_agents, -1)
            q_target = self.target_critic(inputs_next).view(self.args.n_agents, -1)

            q_evals.append(q_eval)
            q_targets.append(q_target)

        q_evals_s = torch.stack(q_evals, dim=0)
        q_targets_s = torch.stack(q_targets, dim=0)

        return q_evals_s.to(self.args.device), q_targets_s.to(self.args.device)

    def _get_actor_inputs(self, episode: Episode, transition_idx: int):
        inputs = list()

        inputs.append(episode.obs[transition_idx])
        if transition_idx == 0:
            inputs.append(torch.zeros_like(episode.actions_onehot[transition_idx]).to(self.args.device))
        else:
            inputs.append(episode.actions_onehot[transition_idx - 1])
        inputs.append(torch.eye(self.args.n_agents).to(self.args.device))
        return torch.cat([X.reshape(self.args.n_agents, -1) for X in inputs], dim=1)

    def _get_action_prob(self, episode: Episode, epsilon: float) -> torch.tensor:
        transitions_action_prob = []

        for transition_idx in range(episode.obs.shape[0]):
            inputs = self._get_actor_inputs(episode, transition_idx)
            outputs, self.eval_hidden = self.actor(inputs, self.eval_hidden)

            outputs = outputs.view(self.args.n_agents, -1)
            transition_action_prob = torch.nn.functional.softmax(outputs, dim=-1)
            transitions_action_prob.append(transition_action_prob)

        transitions_action_prob_s = torch.stack(transitions_action_prob, dim=0).cpu()
        action_probs = ((1 - epsilon) * transitions_action_prob_s) + \
                       torch.ones_like(transitions_action_prob_s) * epsilon / self.args.n_actions
        return action_probs.to(self.args.device)


def td_lambda_target(episode: Episode, q_targets: torch.tensor, args) -> torch.tensor:
    episode_len = episode.obs.shape[0]

    terminated = ~episode.terminated.repeat(1, args.n_agents).cpu()
    reward_repeated = episode.rewards.repeat((1, args.n_agents)).cpu()  # expand central episode reward for each agent

    n_step_return = torch.zeros((episode_len, args.n_agents, episode_len))
    for transition_idx in range(episode_len-1, -1, -1):  # stepping backwards through episode
        # First n_step_return update initialised with the q_target estimate at that timestep
        n_step_return[transition_idx, :, 0] = reward_repeated[transition_idx] + \
                                              args.gamma * q_targets[transition_idx] * terminated[transition_idx]
        for n in range(1, episode_len - transition_idx):  # and then discounts this by gamma at each preceding timestep
            n_step_return[transition_idx, :, n] = reward_repeated[transition_idx] + \
                                                  args.gamma * n_step_return[transition_idx+1, :, n-1]

    lambda_return = torch.zeros((episode_len, args.n_agents))
    for transition_idx in range(episode_len):
        returns = torch.zeros(args.n_agents)
        for n in range(1, episode_len - transition_idx):
            returns += args.td_lambda**(n-1) * n_step_return[transition_idx, :, n-1]
        lambda_return[transition_idx] = (1-args.td_lambda)*returns + args.td_lambda**(episode_len-transition_idx-1) *\
            n_step_return[transition_idx, :, episode_len-transition_idx-1]
    return lambda_return
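
The forward-view computation in `td_lambda_target` can be cross-checked against the equivalent recursive form of the TD(λ) return. The sketch below is a simplified single-agent version in plain Python, not a drop-in replacement: `lambda_returns` is a hypothetical helper, `q_next[t]` stands for the target critic's estimate for the next state at step `t`, and `not_done[t]` is assumed to be 0 only at the terminal step.

```python
def lambda_returns(rewards, q_next, not_done, gamma=0.99, td_lambda=0.9):
    """TD(lambda) targets for one agent, computed backwards through the episode."""
    episode_len = len(rewards)
    returns = [0.0] * episode_len
    for t in reversed(range(episode_len)):
        bootstrap = gamma * q_next[t] * not_done[t]  # one-step bootstrap term
        if t == episode_len - 1:
            # Last step: only the one-step return is available.
            returns[t] = rewards[t] + bootstrap
        else:
            # Blend the one-step bootstrap with the lambda-return of the next step.
            returns[t] = (rewards[t] + (1 - td_lambda) * bootstrap
                          + gamma * td_lambda * returns[t + 1])
    return returns


rewards, q_next, not_done = [1.0, 0.0, -1.0], [0.5, 0.5, 0.0], [1, 1, 0]
# td_lambda = 0 reduces to one-step TD targets.
print(lambda_returns(rewards, q_next, not_done, gamma=1.0, td_lambda=0.0))  # [1.5, 0.5, -1.0]
# td_lambda = 1 reduces to the Monte Carlo return.
print(lambda_returns(rewards, q_next, not_done, gamma=1.0, td_lambda=1.0))  # [0.0, -1.0, -1.0]
```

The recursion needs one pass over the episode, whereas the forward view stores an `episode_len × episode_len` table of n-step returns per agent.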

### Training

In [None]:
ENV_NAME = 'Combat-v0'
COMBAT_AGENTS = 10

env = gym.make(ENV_NAME, grid_shape=(20, 20), n_agents=COMBAT_AGENTS, n_opponents=COMBAT_AGENTS)

n_obs = env.observation_space[0].shape[0]
n_actions = env.action_space[0].n
n_agents = env.n_agents

ARGS = Args(
    n_agents=n_agents,
    n_actions=n_actions,
    state_shape=n_obs * n_agents,  # could also incorporate action history
    obs_shape=n_obs
)
agents = ComaAgent(ARGS)

print('\n')
print(f'Starting env {ENV_NAME} | Action space: {env.action_space} | Obs space: {env.observation_space}')
print(f'Using device {"CUDA" if ARGS.cuda else "CPU"}')
print('\n')

episode_rewards = []
epsilon = 0 if ARGS.evaluate else ARGS.epsilon

for episode_idx in range(1, ARGS.n_episodes + 1):
    agents.policy.init_hidden()

    if epsilon > ARGS.min_epsilon:
        epsilon -= ARGS.anneal_epsilon

    current_obs_n = env.reset()

    done_n = [False for a in range(env.n_agents)]
    done = all(done_n)

    ep_reward = 0
    ep_step = 0

    last_actions_c = np.zeros((env.n_agents, ARGS.n_actions))

    obs_h, state_h, actions_h, actions_onehot_h, rewards_h, obs_next_h, state_next_h, \
        terminated_h = [], [], [], [], [], [], [], []

    while not done:
        actions_c, actions_onehot_c = [], []

        for agent_id in range(env.n_agents):
            action_c = agents.act(obs=current_obs_n[agent_id], last_action=last_actions_c[agent_id],
                                  agent_num=agent_id, epsilon=epsilon, evaluate=False)

            action_onehot_c = np.zeros(env.action_space[0].n)
            action_onehot_c[action_c] = 1
            last_actions_c[agent_id] = action_onehot_c

            actions_c.append(action_c)
            actions_onehot_c.append(action_onehot_c)

        next_obs_n, reward_n, done_n, _ = env.step(actions_c)

        # if not episode_idx % ARGS.log_every:
        #     env.render()

        done = all(done_n)

        state = []
        for obs in current_obs_n:
            state.extend(obs)

        next_state = []
        for next_obs in next_obs_n:
            next_state.extend(next_obs)

        obs_h.append(current_obs_n)
        obs_next_h.append(next_obs_n)

        state_h.append(state)
        state_next_h.append(next_state)

        actions_h.append(np.reshape(actions_c, [n_agents, 1]))
        actions_onehot_h.append(actions_onehot_c)

        rewards_h.append([sum(reward_n)])
        terminated_h.append([done])

        current_obs_n = next_obs_n

        ep_reward += sum(reward_n)
        ep_step += 1

    episode = Episode(
        obs=torch.tensor(obs_h, dtype=torch.float, device=ARGS.device),
        state=torch.tensor(state_h, dtype=torch.float, device=ARGS.device),
        actions=torch.tensor(actions_h, dtype=torch.long, device=ARGS.device),
        actions_onehot=torch.tensor(actions_onehot_h, dtype=torch.float, device=ARGS.device),
        rewards=torch.tensor(rewards_h, dtype=torch.float, device=ARGS.device),
        obs_next=torch.tensor(obs_next_h, dtype=torch.float, device=ARGS.device),
        state_next=torch.tensor(state_next_h, dtype=torch.float, device=ARGS.device),
        terminated=torch.tensor(terminated_h, dtype=torch.bool, device=ARGS.device)
    )

    loss = agents.train(episode, episode_idx, epsilon=epsilon)

    episode_rewards.append(ep_reward)

    if not episode_idx % ARGS.log_every:
        time.sleep(0.1)  # pause to show env final state
        print(f'On episode {episode_idx:,d} // '
              f'Epsilon: {epsilon:.2f} // '
              f'Mean reward: {np.mean(episode_rewards[-ARGS.log_every:]):.1f} // '
              f'Min reward {np.min(episode_rewards[-ARGS.log_every:]):.1f} // '
              f'Max reward {np.max(episode_rewards[-ARGS.log_every:]):.1f}')
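
The sizes passed to `Args` in this cell determine the input widths of the actor and the critic, following the arithmetic in `Coma._get_input_shapes`. The sketch below repeats that arithmetic on its own. The helper name `input_shapes` is hypothetical, and `obs_shape=150` and `n_actions=15` are illustrative assumptions; the real values are read from `env.observation_space` and `env.action_space`.

```python
def input_shapes(obs_shape, n_actions, n_agents, state_shape=None):
    """Actor and critic input widths, mirroring Coma._get_input_shapes."""
    if state_shape is None:
        # As in the training cell: the state is the concatenation of all observations.
        state_shape = obs_shape * n_agents
    # Actor: observation + previous action (one-hot) + agent ID (one-hot)
    actor = obs_shape + n_actions + n_agents
    # Critic: state + observation + agent ID + current and previous joint actions
    critic = state_shape + obs_shape + n_agents + 2 * n_actions * n_agents
    return actor, critic


print(input_shapes(obs_shape=150, n_actions=15, n_agents=10))  # (175, 1960)
```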

## Analysis on Performance

Counterfactual Multi-agent Policy Gradients (COMA) is a multi-agent actor-critic method for cooperative tasks. It uses a centralised critic to estimate the Q-function and decentralised actors to optimise the agents' policies, both of which are explicitly modelled with deep neural networks. To address the challenge of multi-agent credit assignment, it uses a novel counterfactual baseline that marginalises out a single agent's action, while keeping the other agents' actions fixed. 
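
For a single agent at a single time step, the counterfactual baseline is the policy-weighted average of the critic's Q-values over that agent's own actions, with the other agents' actions held fixed. The advantage is the Q-value of the action actually taken minus this baseline. The sketch below shows the calculation with made-up numbers in plain Python; `counterfactual_advantage` is a hypothetical helper, not part of the code above.

```python
def counterfactual_advantage(q_values, policy, action_taken):
    """Advantage of one agent's action against the counterfactual baseline.

    q_values[a] is the critic's estimate when this agent takes action a
    and all other agents keep the actions they actually took.
    policy[a] is the probability this agent's policy assigns to action a.
    """
    baseline = sum(p * q for p, q in zip(policy, q_values))
    return q_values[action_taken] - baseline


q_values = [1.0, 3.0, 2.0]   # illustrative numbers only
policy = [0.5, 0.25, 0.25]
print(counterfactual_advantage(q_values, policy, action_taken=1))  # 1.25
print(counterfactual_advantage(q_values, policy, action_taken=0))  # -0.75
```

The baseline depends only on the agent's policy and the other agents' actions, so the advantage averages to zero under the agent's own policy.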

It is expected that COMA would work well on this task due to the following reasons: 

*   To achieve good performance on the Combat task, cooperation among agents is required. Each agent can only attack one enemy at a time, but an agent can be attacked by more than one enemy. Agents in the team that we control therefore need to learn to surround and attack enemies, while ideally keeping some distance from other enemies that could attack them. COMA is theoretically suited to this, as it facilitates cooperation.

*   COMA's counterfactual baseline results in appropriate assignment of credit to agents' actions, which facilitates learning. In typical cooperative settings, joint actions generate only global rewards, making it difficult for each agent to deduce its own contribution to the team’s success. The counterfactual baseline aims to deduce the contribution of each agent's action, and therefore encourages individual agents to sacrifice for the greater good (e.g. not running away from an enemy if there is a prospect of another agent in the team killing that enemy once it is surrounded).


The following reasons may explain the suboptimal performance of our implementation:

*   COMA works best when the true global state is passed to the centralised critic, so that joint actions are evaluated with full observability of the environment. Unfortunately, the Combat environment does not return a global state, only a local observation for each agent consisting of a 5x5 field of view. We approximate the global state by concatenating all agent observations, but this is a crude approximation because the fields of view overlap. This results in more myopic and suboptimal behaviour than having the full environment state.

*   Due to parameter sharing, all agents could have been processed in parallel, with each agent for each episode and time step occupying one batch entry. However, our code uses neither batch learning nor a buffer from which to sample experience.

*   The exploration-exploitation trade-off is challenging in environments like Combat that contain many agents. A more tailored exploration strategy for multi-agent settings would likely result in better performance.

*   As many enemies can attack one agent, but one agent can only attack one enemy at a time, rewards are negative more often than not. Positive rewards are therefore scarce, which poses a challenge for RL, whose central premise is reward maximisation.
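
For reference on the exploration point above, the scheme currently used in `_sample_action` and `_get_action_prob` mixes the softmax policy with a uniform distribution weighted by epsilon. The sketch below reproduces that mixing in plain Python, without torch; `softmax` and `epsilon_mixed_policy` are hypothetical stand-ins for the tensor operations in the code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def epsilon_mixed_policy(logits, epsilon):
    """Mix the softmax policy with a uniform distribution of weight epsilon."""
    n_actions = len(logits)
    return [(1 - epsilon) * p + epsilon / n_actions for p in softmax(logits)]


probs = epsilon_mixed_policy([2.0, 0.0, -1.0], epsilon=0.6)
print(round(sum(probs), 6))   # 1.0
print(min(probs) >= 0.6 / 3)  # True: every action keeps at least epsilon / n_actions
```

Every action keeps a probability of at least `epsilon / n_actions`, but the exploration is applied independently per agent and does not coordinate exploration across the team.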

