In [1]:
import torch
import torch.nn as nn
from torch.distributions import Categorical
from typing import Dict

In [2]:
# Shared, stateless modules used throughout the notebook:
#   mse - mean-squared-error loss, used in train() for the agent and mediator critics.
#   sf  - softmax over dim 1, used in log_pd() to turn actor logits into probabilities.
mse = nn.MSELoss()
sf = nn.Softmax(dim=1)

## Hyperparameters

In [3]:
# Number of agents; train() builds one actor and one critic per agent.
num_of_agents = 2
# Size of the mediator's policy output; each agent actor gets NUM_OF_ACTIONS + 1 logits.
# 3 if with mediator, 2 if solo. NOTE: the experiment loop overwrites this per configuration.
NUM_OF_ACTIONS = 2
# Action index that train() treats as "join the mediator's coalition".
MEDIATOR_ACTION = 3

# Hidden width of the mediator actor; the mediator critic uses 2 * hidden_size.
hidden_size = 16

# Adam learning rates for actors / critics, and the step size used to update
# the log-multipliers of the IC term in train().
lr_actor = 1e-3
lr_critic = 1e-3
lr_lambda = 1e-3

# Number of training runs per configuration (seeds 0 .. NUM_OF_SEEDS - 1).
NUM_OF_SEEDS = 50
# False if naive mediator, True if IC-constrained mediator.
# NOTE: the experiment loop overwrites this per configuration.
IC_CONSTRAINT = True
# Factor applied to negative IC advantages before the multiplier update in train().
IC_COEF = 2

## Logging function

In [4]:
def make_empty_dict():
  """
  Make dictionary for further logging.

  For every agent index `i` in `range(num_of_agents)` the dictionary holds
  three tensors with one row per seed:

    * ``med_agent_{i}_solo`` - shape ``(NUM_OF_SEEDS, NUM_OF_ACTIONS)``
    * ``med_agent_{i}_full`` - shape ``(NUM_OF_SEEDS, NUM_OF_ACTIONS)``
    * ``agent_{i}_solo``     - shape ``(NUM_OF_SEEDS, NUM_OF_ACTIONS + 1)``

  The tensors are filled with NaN rather than left uninitialised: `log_pd`
  skips the mediator entries when no mediator is in play, and those rows
  would otherwise be saved to disk containing arbitrary memory contents.

  Returns:
    dict mapping the keys above to float tensors filled with NaN.
  """
  # (key prefix, key suffix, number of logged probabilities per row)
  layout = (
    ("med_agent", "solo", NUM_OF_ACTIONS),
    ("med_agent", "full", NUM_OF_ACTIONS),
    ("agent", "solo", NUM_OF_ACTIONS + 1),
  )

  prob_dist = {}
  for prefix, suffix, width in layout:
    for i in range(num_of_agents):
      # NaN marks "never logged", so stale rows are distinguishable from real policies.
      prob_dist[f"{prefix}_{i}_{suffix}"] = torch.full((NUM_OF_SEEDS, width), float("nan"))

  return prob_dist

In [5]:
def log_pd(dictionary, n: int, med_actor, ag_actors):
  """
  Store the current agent and mediator policies in row `n` of `dictionary`.

  Args:
    dictionary: logging dict whose tensors are written in place
      (keys ``agent_{i}_solo`` and ``med_agent_{i}_solo`` / ``med_agent_{i}_full``).
    n: row index to write.
    med_actor: mediator policy module, evaluated on (1, 3) probe inputs.
    ag_actors: iterable of agent policy modules, each evaluated on a (1, 1) zero input.

  The mediator entries are only written when MEDIATOR_ACTION == NUM_OF_ACTIONS.
  """
  # Probe inputs for the mediator. They follow the layout of `coal_state` in
  # train(): two coalition-membership flags followed by the agent index.
  mediator_probes = {
    "med_agent_0_solo": [1., 0., 0.],
    "med_agent_1_solo": [0., 1., 1.],
    "med_agent_0_full": [1., 1., 0.],
    "med_agent_1_full": [1., 1., 1.],
  }
  agent_input = torch.zeros((1, 1))

  with torch.no_grad():
    # Agents' own policies on the constant dummy state.
    for agent_id, agent_actor in enumerate(ag_actors):
      agent_probs = sf(agent_actor(agent_input))
      dictionary[f"agent_{agent_id}_solo"][n] = agent_probs.squeeze(0)

    # Nothing more to log unless the mediator is part of the game.
    if NUM_OF_ACTIONS != MEDIATOR_ACTION:
      return

    for key, features in mediator_probes.items():
      probe = torch.tensor([features])
      dictionary[key][n] = sf(med_actor(probe)).squeeze(0)

## Sacrifice table

In [7]:
# Joint payoff table, one row per joint action and one column per agent.
# train() looks rows up with index = action_of_agent_0 * table_width + action_of_agent_1.
# NOTE(review): the rows for agent 0 playing action 2 are all zero, while agent 1
# playing action 2 yields [5, 0]; the table is not symmetric - confirm this is intended.
reward_table = torch.tensor(
  [
    [1, 1], [3, 0], [5, 0],  # agent 0 plays 0; agent 1 plays 0 / 1 / 2
    [0, 3], [2, 2], [5, 0],  # agent 0 plays 1
    [0, 0], [0, 0], [0, 0],  # agent 0 plays 2
  ],
  dtype=torch.float,
)
# Number of actions per agent in the table (row stride of the flat index).
table_width = 3

## Train function

In [11]:
def train(n: int, prob_dist: Dict[str, torch.Tensor]) -> None:
  """
  Run one complete training run with seed `n` and log the resulting policies.

  Fresh agent actors/critics and a mediator actor/critic are created, trained
  for `n_episodes` steps on batches of `batch_size` sampled joint actions, and
  finally `log_pd` writes the learned policies into row `n` of `prob_dist`.

  Module-level names read at call time: `n_episodes`, `batch_size`,
  `num_of_agents`, `NUM_OF_ACTIONS`, `MEDIATOR_ACTION`, `hidden_size`,
  `lr_actor`, `lr_critic`, `lr_lambda`, `IC_CONSTRAINT`, `IC_COEF`,
  `reward_table`, `table_width`, `mse` and `log_pd`.

  Args:
    n: seed passed to `torch.manual_seed`; also the row index handed to `log_pd`.
    prob_dist: logging dictionary, updated in place by `log_pd`.
  """
  entropy_coef_start = 5e-1
  entropy_coef = entropy_coef_start
  entropy_coef_finish = 1e-2
  # Per-episode multiplicative factor: applying it n_episodes times takes the
  # coefficient from entropy_coef_start down to entropy_coef_finish.
  entropy_coef_step = 1 / ((entropy_coef_start / entropy_coef_finish) ** (1 / n_episodes))

  # Log of the per-agent multiplier on the IC term (exp() is taken where it is used).
  log_lambda_IC = torch.zeros(num_of_agents)

  # Seed before the networks are created so their initialisation is reproducible too.
  torch.manual_seed(n)

  # Init agents
  # Each agent sees only a constant dummy input; the actor outputs
  # NUM_OF_ACTIONS + 1 logits and the critic a single value.
  agents_actors = [nn.Linear(1, NUM_OF_ACTIONS + 1) for _ in range(num_of_agents)]
  agents_critic = [nn.Linear(1, 1) for _ in range(num_of_agents)]

  # Init mediator
  # Actor input: coalition flags (num_of_agents) plus the index of the agent it
  # acts for (see coal_state below). Critic input: coalition flags only; it
  # outputs one value per agent.
  mediator_actor = nn.Sequential(nn.Linear(num_of_agents + 1, hidden_size), nn.ReLU(), nn.Linear(hidden_size, NUM_OF_ACTIONS))
  mediator_critic = nn.Sequential(nn.Linear(num_of_agents, 2 * hidden_size), nn.ReLU(), nn.Linear(2 * hidden_size, num_of_agents))

  # Agents optimizers
  opts_actor = [torch.optim.Adam(actor.parameters(), lr_actor, weight_decay=1e-3) for actor in agents_actors]
  opts_critic = [torch.optim.Adam(critic.parameters(), lr_critic, weight_decay=1e-3) for critic in agents_critic]

  # Mediator optimizers
  opt_mediator_actor = torch.optim.Adam(mediator_actor.parameters(), lr_actor, weight_decay=1e-3)
  opt_mediator_critic = torch.optim.Adam(mediator_critic.parameters(), lr_critic, weight_decay=1e-3)

  # Training
  # Constant (all-zero) observation fed to every agent network.
  dummy_state = torch.zeros((batch_size, 1))

  for episode in range(n_episodes):
    # act[b, i] = action sampled by agent i in batch element b.
    act = torch.zeros(batch_size, num_of_agents).long()

    # Agents act
    with torch.no_grad():
      for i, actor in enumerate(agents_actors):
        l = actor(dummy_state)
        C = Categorical(logits=l)
        act[:, i] = C.sample()

    # Who is in coalition?
    # coal_1 / coal_2 are the batch indices / agent indices of every entry
    # where the agent picked MEDIATOR_ACTION; coalition is the 0/1 membership matrix.
    coalition = torch.zeros_like(act, dtype=torch.float)
    coal_1, coal_2 = torch.where(act == MEDIATOR_ACTION)
    coalition[coal_1, coal_2] = 1

    # One row per (batch element, agent) coalition entry: that element's
    # membership flags followed by the agent index.
    # NOTE(review): concatenates a float tensor with an integer tensor; this
    # relies on torch.cat handling mixed dtypes - confirm on the PyTorch version in use.
    coal_state = torch.cat([coalition[coal_1], coal_2.unsqueeze(1)], dim=1)

    # act_final holds the actions actually played (mediator choices substituted in).
    act_final = torch.clone(act)

    if coal_1.numel():
      # Take action for coalition
      with torch.no_grad():
        mediator_logit = mediator_actor(coal_state)
      mediator_C = Categorical(logits=mediator_logit)

      # Swap "mediator" action on picked action
      act_final[coal_1, coal_2] = mediator_C.sample()

    # Compute reward
    # Flat row index into reward_table: agent 0's action * table_width + agent 1's action.
    reward_ids = act_final[:, 1] + act_final[:, 0] * table_width
    reward_ids = reward_ids.view(-1).long()
    # reward[b, i] = payoff of agent i in batch element b.
    reward = reward_table[reward_ids]

    # Actor update
    # Note: `action` comes from `act`, i.e. the agent's own choice (which may be
    # MEDIATOR_ACTION), not from the mediator-substituted `act_final`.
    for critic, actor, opt_critic, opt_actor, r, action in zip(agents_critic, agents_actors, opts_critic, opts_actor, reward.T, act.T):
      V = critic(dummy_state).squeeze(1)
      # Advantage for the policy gradient; detached so it does not train the critic.
      adv_pg = r - V.detach()
      # Despite the name, `adv` is the critic's MSE loss.
      adv = mse(r, V)
      logits = actor(dummy_state)
      C = Categorical(logits=logits)

      # Policy-gradient loss with an entropy bonus.
      pg_loss = - adv_pg * C.log_prob(action)
      pg_loss = pg_loss.mean() - entropy_coef * C.entropy().mean()

      opt_critic.zero_grad()
      opt_actor.zero_grad()

      adv.backward()
      pg_loss.backward()

      opt_actor.step()
      opt_critic.step()

    # Critic loss
    # MSE between every agent's reward and the mediator critic's per-agent
    # prediction for the observed coalition; backpropagated at the end of the episode.
    adv_total = mse(reward, mediator_critic(coalition))
    if coal_1.numel():

      # Sum of rewards for each coalition
      total_reward = reward * coalition
      total_reward = total_reward.sum(dim=1)

      # Computing policy gradients
      # Recomputed with gradients enabled (the sampling pass above ran under no_grad).
      mediator_logit = mediator_actor(coal_state)
      mediator_C = Categorical(logits=mediator_logit)

      tot_rew = total_reward[coal_1].view(-1, 1)

      with torch.no_grad():
        # Per coalition entry: summed coalition reward minus the summed critic
        # values of the coalition members.
        tot_adv = tot_rew.view(-1) - (mediator_critic(coalition) * coalition).sum(dim=1)[coal_1]

        if IC_CONSTRAINT:
          mask = 1 < torch.sum(coalition, dim=1) # masking on coalitions with more than one participant

          for idx in range(num_of_agents):
            if mask.any():
              # Same coalitions with agent idx removed.
              counterfactual_coal = torch.clone(coalition[mask])
              counterfactual_coal[:, idx] = 0
              # Agent idx's realised reward minus the critic's estimate of its
              # value under the counterfactual coalition.
              ic_adv = reward[mask, idx] - mediator_critic(counterfactual_coal)[:, idx]

              # Coalition entries that belong to agent idx in batch elements
              # where all num_of_agents agents joined.
              # NOTE(review): these are expected to line up one-to-one with the
              # `mask` rows; that holds for num_of_agents == 2 - confirm before
              # changing the number of agents.
              new_ids = torch.logical_and(torch.sum(coalition[coal_1], dim=1) == num_of_agents, coal_2 == idx)
              tot_adv[new_ids] += log_lambda_IC[idx].exp() * ic_adv

              # Weighting
              # Applied after ic_adv was added to tot_adv, so it only affects
              # the multiplier update below.
              ic_adv[ic_adv < 0] *= IC_COEF

              # Lambda update
              # A negative mean ic_adv increases the log-multiplier.
              log_lambda_IC[idx] -= lr_lambda * ic_adv.mean()

          # Keep the log-multipliers within [-4, 4].
          log_lambda_IC = torch.clip(log_lambda_IC, min=-4, max=4)

      # Actor loss
      # Log-probability of the action the mediator actually sampled for each entry.
      pg_total = - tot_adv.flatten() * mediator_C.log_prob(act_final[coal_1, coal_2])
      pg_total = pg_total.mean() - entropy_coef * mediator_C.entropy().mean()

      # Update mediator actor
      opt_mediator_actor.zero_grad()
      pg_total.backward()
      # Clamp each gradient element to [-0.1, 0.1] before stepping.
      torch.nn.utils.clip_grad_value_(mediator_actor.parameters(), 0.1)
      opt_mediator_actor.step()

    # Update mediator critic
    opt_mediator_critic.zero_grad()
    adv_total.backward()
    torch.nn.utils.clip_grad_value_(mediator_critic.parameters(), 0.1)
    opt_mediator_critic.step()

    # schedule entropy
    # Geometric decay, never dropping below entropy_coef_finish.
    entropy_coef = max(entropy_coef * entropy_coef_step, entropy_coef_finish)

  # Store the final policies of this run in row n of the logging dictionary.
  log_pd(prob_dist, n, mediator_actor, agents_actors)

## Experiments

In [6]:
# Number of joint actions sampled per training step in train().
batch_size = 128
# Number of training steps per run; also sets the length of the entropy-coefficient decay.
n_episodes = 10_000

In [12]:
# One tuple per experiment: (name used in the output file, NUM_OF_ACTIONS, IC_CONSTRAINT).
experiment_configs =  [('no_mediator', 2, False), ('naive', 3, False), ('ic_mediator', 3, True)]

In [None]:
# Run every configuration. The tuple is unpacked straight into the module-level
# globals NUM_OF_ACTIONS and IC_CONSTRAINT, which is how make_empty_dict(),
# train() and log_pd() pick up the active configuration.
for conf in experiment_configs:
  name, NUM_OF_ACTIONS, IC_CONSTRAINT = conf
  # Fresh logging dictionary per configuration (tensor widths depend on NUM_OF_ACTIONS).
  prob_dist = make_empty_dict()
  # i is both the RNG seed and the row of prob_dist that the run fills.
  for i in range(NUM_OF_SEEDS):
    train(i, prob_dist)
  # One file of logged policies per configuration, written to the working directory.
  torch.save(prob_dist, f"./probs_{name}.pt")