In [1]:
import torch
import torch.nn as nn
from torch.distributions import Categorical

In [2]:
# Shared modules used by `train`: mean-squared error for every critic loss and a
# row-wise softmax for turning logits into probabilities when logging policies.
mse = nn.MSELoss(reduction="mean")
sf = nn.Softmax(dim=1)

### Hyperparameters for agents and mediator

In [3]:
num_of_actions = 2
MEDIATOR_ACTION = 2
hidden_size = 16
lr_actor = 1e-3
lr_critic = 1e-2
lr_lambda = 1e-2

NUM_OF_SEEDS = 10
num_of_agents = 3
PG_MULTIPLIER = 5 if num_of_agents > 10 else 2
NORM_CONST = 1 / (PG_MULTIPLIER - 1)

P_CONSTRAINT = True
P_COEF = 0.99

## Training

In [4]:
batch_size = 128
n_episodes = 20_000

In [5]:
def train(n:int, π_m, R, π_a, π_total):

  class NormalizingLayer(nn.Module):
    def forward(self, x):
        return x / num_of_agents * 2 - 1

  torch.manual_seed(n)

  # Agents nets
  agents_actors = [nn.Linear(1, num_of_actions + 1) for _ in range(num_of_agents)]
  agents_critic = [nn.Linear(1, 1) for _ in range(num_of_agents)]
  
  # Optimizers (agents)
  opts_actor = [torch.optim.Adam(actor.parameters(), lr_actor, weight_decay=1e-3) for actor in agents_actors]
  opts_critic = [torch.optim.Adam(critic.parameters(), lr_critic, weight_decay=1e-3) for critic in agents_critic]

  # Mediator nets
  mediator_actor = nn.Sequential(NormalizingLayer(), nn.Linear(1, hidden_size), nn.ReLU(), nn.Linear(hidden_size, num_of_actions))
  mediator_critic = nn.Sequential(NormalizingLayer(), nn.Linear(1, 2 * hidden_size), nn.ReLU(), nn.Linear(2 * hidden_size, 1))

  # Optimizers (mediator)
  opt_mediator_actor = torch.optim.Adam(mediator_actor.parameters(), lr_actor, weight_decay=1e-3)
  opt_mediator_critic = torch.optim.Adam(mediator_critic.parameters(), lr_critic, weight_decay=1e-3)

  # Hyperparameters
  entropy_coef_start = 0.1
  entropy_coef = entropy_coef_start
  entropy_coef_finish = 0.001
  entropy_coef_step = 1 / ((entropy_coef_start / entropy_coef_finish) ** (1 / n_episodes))

  log_lambda_P = torch.zeros(1)

  # Training 
  dummy_state = torch.zeros((batch_size, 1)) 

  for episode in range(n_episodes):
    act = torch.empty((batch_size, num_of_agents)) 

    # Agents act
    for i, actor in enumerate(agents_actors):
      with torch.no_grad():
        l = actor(dummy_state)
      C = Categorical(logits=l)
      act[:, i] = C.sample()

    # Who is in coalition?
    coalition = torch.zeros((batch_size, num_of_agents))
    coal_1, coal_2 = torch.where(act == MEDIATOR_ACTION)
    coalition[coal_1, coal_2] = 1

    coal_sum = torch.sum(coalition, dim=1, keepdim=True)
    coal_state = coal_sum[coal_1].view(-1, 1)

    # Take action for coalition
    with torch.no_grad():
      mediator_logit = mediator_actor(coal_state)
    mediator_C = Categorical(logits=mediator_logit)

    # Swap "mediator" action on picked action
    act_final = torch.clone(act)
    act_final[coal_1, coal_2] = mediator_C.sample().float()

    # Compute reward
    reward = (- act_final + torch.sum(act_final, dim=1, keepdim=True) * (PG_MULTIPLIER / num_of_agents)) * NORM_CONST  # maximum reward is 1

    # Actor update
    for critic, actor, opt_critic, opt_actor, r, action in zip(agents_critic, agents_actors, opts_critic, opts_actor, reward.T, act.T):
      V = critic(dummy_state).squeeze(1)
      adv_pg = r - V.detach()
      adv = mse(r, V)
      logits = actor(dummy_state)
      C = Categorical(logits=logits)

      pg_loss = - adv_pg * C.log_prob(action)
      pg_loss = pg_loss.mean() - entropy_coef * C.entropy().mean()

      opt_critic.zero_grad()
      opt_actor.zero_grad()

      adv.backward()
      pg_loss.backward()

      opt_actor.step()
      opt_critic.step()

    # Compute sum of rewards of coalition
    total_reward = reward * coalition
    total_reward = total_reward.sum(dim=1)
    uniq_id_1 = torch.unique(coal_1)

    # Critic loss
    R_total = total_reward[uniq_id_1].flatten()
    V_total = mediator_critic(coal_sum[uniq_id_1]).flatten()
    adv_total = mse(V_total, R_total)


    # Log probabilities for policy gradients
    mediator_logit = mediator_actor(coal_state)
    mediator_C = Categorical(logits=mediator_logit)

    tot_rew = total_reward[coal_1].view(-1, 1)


    with torch.no_grad():
      tot_adv = tot_rew - mediator_critic(coal_state)

      if P_CONSTRAINT:
        mask = (1 < coal_state) & (coal_state < num_of_agents)

        p_reward = reward * (1 - coalition)
        pun_adv = p_reward.sum(dim=1)[coal_1].view(-1, 1)[mask] / (num_of_agents - coal_sum)[coal_1].view(-1, 1)[mask]
        pun_adv -= mediator_critic(coal_state + 1)[mask] / (coal_state[mask] + 1)

        tot_adv[mask] -= log_lambda_P.exp() * pun_adv

        # update lambda
        if pun_adv.shape[0] > 1:
          pun_adv[pun_adv < 0] *= P_COEF
          log_lambda_P += lr_lambda * pun_adv.mean()
        log_lambda_P = torch.clip(log_lambda_P, min=-4, max=4)

    # Actor's loss
    pg_total = - tot_adv.flatten() * mediator_C.log_prob(act_final[coal_1, coal_2])
    pg_total = pg_total.mean() - entropy_coef * mediator_C.entropy().mean()


    # Backward
    opt_mediator_actor.zero_grad()
    opt_mediator_critic.zero_grad()

    adv_total.backward()
    pg_total.backward()

    torch.nn.utils.clip_grad_value_(mediator_actor.parameters(), 0.1)
    torch.nn.utils.clip_grad_value_(mediator_critic.parameters(), 0.1)

    opt_mediator_actor.step()
    opt_mediator_critic.step()

    # schedule entropy
    entropy_coef = max(entropy_coef * entropy_coef_step, entropy_coef_finish)

  # Kind of logging
  R.append(reward.mean().item())

  with torch.no_grad():
    π_s = torch.empty(num_of_agents, num_of_actions + 1)
    for i, actor in enumerate(agents_actors):
      π_s[i] = sf(actor(torch.zeros(1, 1)))
    π_a.append(π_s.mean(dim=0)[2])

  π_m.append(act_final[coal_1, coal_2].mean())

  test_state = torch.arange(1, num_of_agents + 1).unsqueeze(1).float()

  with torch.no_grad():
    π_total.append(sf(mediator_actor(test_state.float())))

## Saving

In [6]:
def save_dict(pi_med, mean_rew, pi_agents, pi_total, filename):
  result_dict = {}
  result_dict["mediator_c"] = pi_med
  result_dict["mean_reward"] = mean_rew
  result_dict["prob_m"] = pi_agents
  result_dict['policies'] = pi_total

  torch.save(result_dict, filename)

## Run for all seeds

In [7]:
experiment_configs = [('3_naive', 3, False), ('3_p_constraint', 3, True), ('10_p_constraint', 10, True), ('25_p_constraint', 25, True)]

In [None]:
for conf in experiment_configs:
  name, num_of_agents, P_CONSTRAINT = conf
  PG_MULTIPLIER = 5 if num_of_agents > 10 else 2
  NORM_CONST = 1 / (PG_MULTIPLIER - 1)

  pi_med = []
  mean_rew = []
  pi_agents = []
  pi_total = []

  for i in range(NUM_OF_SEEDS):
    train(i, pi_med, mean_rew, pi_agents, pi_total)
    print(f"run {i + 1}")

  save_dict(pi_med, mean_rew, pi_agents, pi_total, f'./probs_{name}.pt')