# Setup

We use the [OpenSpiel](https://github.com/deepmind/open_spiel) library for this setting. OpenSpiel is a collection of environments and algorithms for research in general reinforcement learning and search/planning in games.

## Imports

Import OpenSpiel and other auxiliary libraries.

In [None]:
"""Useful imports"""

!pip install --upgrade open_spiel

In [None]:

import dataclasses
import math
import re
from typing import Dict, List, Optional, Tuple


import datetime
from matplotlib import animation
from matplotlib import cm
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import time

from IPython.display import HTML

from open_spiel.python import policy
from open_spiel.python import policy as policy_std
from open_spiel.python.mfg import distribution as distribution_std
from open_spiel.python.mfg import value as value_std
from open_spiel.python.mfg.algorithms import best_response_value
from open_spiel.python.mfg.algorithms import boltzmann_policy_iteration
from open_spiel.python.mfg.algorithms import distribution
from open_spiel.python.mfg.algorithms import fictitious_play
from open_spiel.python.mfg.algorithms import fixed_point
from open_spiel.python.mfg.algorithms import greedy_policy
from open_spiel.python.mfg.algorithms import mirror_descent
from open_spiel.python.mfg.algorithms import munchausen_mirror_descent
from open_spiel.python.mfg.algorithms import nash_conv
from open_spiel.python.mfg.algorithms import policy_value
from open_spiel.python.mfg.games import factory
import pyspiel

## Forbidden states

In [None]:
# Four-rooms layout, one text line per grid row: '#' marks a wall (forbidden
# state) and ' ' a free cell. The rooms are connected through the gaps in the
# inner walls.
forbidden_states_grid = """\
#############
#     #     #
#     #     #
#           #
#     #     #
#     #     #
### ##### ###
#     #     #
#     #     #
#           #
#     #     #
#     #     #
#############""".split('\n')

def grid_to_forbidden_states(grid):
  """Converts a grid into string representation of forbidden states.

  Args:
    grid: Rows of the grid (any iterable of strings). '#' character denotes a
      forbidden state. All rows should have the same number of columns, i.e.
      cells.

  Returns:
    String representation of forbidden states in the form of x (column) and y
    (row) pairs joined by ';' inside brackets, e.g. '[1|1;0|2]'. Cells are
    listed row by row, left to right. An empty grid gives '[]'.

  Raises:
    ValueError: If the rows do not all have the same number of columns.
  """
  rows = list(grid)
  if not rows:
    return '[]'
  num_cols = len(rows[0])
  forbidden_states = []
  for y, row in enumerate(rows):
    # Raise instead of `assert` so the check still runs under `python -O`.
    if len(row) != num_cols:
      raise ValueError(
          f'Number of columns should be {num_cols}; row {y} has {len(row)}.')
    forbidden_states.extend(
        f'{x}|{y}' for x, cell in enumerate(row) if cell == '#')
  return '[' + ';'.join(forbidden_states) + ']'

# Walls in the '[x|y;...]' string form expected by the game's
# 'forbidden_states' parameter.
FOUR_ROOMS_FORBIDDEN_STATES = grid_to_forbidden_states(forbidden_states_grid)
# Plotting overlay with the same shape as the grid: NaN on walls, 0 elsewhere.
# Adding it to a distribution slice turns the walls into NaN (missing) cells
# while leaving every other value unchanged.
forbidden_states_indicator = np.array(
    [[math.nan if cell == '#' else 0 for cell in row]
     for row in forbidden_states_grid])

four_rooms_default_setting = {
    'forbidden_states': FOUR_ROOMS_FORBIDDEN_STATES,
    'horizon': 41,
    'initial_distribution': '[1|1]',
    'initial_distribution_value': '[1.0]',
    # Derived from the grid so the two cannot drift apart; assumes the grid is
    # square (13x13 here).
    'size': len(forbidden_states_grid),
    'only_distribution_reward': True,
}

## Helper methods for visualization

The state representation and the distribution differ from game to game, and OpenSpiel does not provide any built-in visualization capabilities. We therefore define some basic methods for displaying the two-dimensional grid and the distribution of our game.

In [None]:
"""Helper methods for visualization. These are game specific."""


def decode_distribution(game: pyspiel.Game,
                        dist: Dict[str, float],
                        nans: bool = True) -> np.ndarray:
  """Turns a crowd-modelling 2D distribution dictionary into a dense array.

  Time comes first so that the temporal evolution can be sliced directly.

  Args:
    game: The game; its 'horizon' and 'size' parameters fix the output shape.
    dist: Mapping from state strings to probability mass. Only keys that are
      exactly of the form '(x, y, t)' are used; all other keys are skipped.
    nans: NOTE(review): accepted but never read by this function.

  Returns:
    Array of shape (horizon, size, size) indexed as [t, y, x]; cells with no
    matching key stay 0.
  """
  parameters = game.get_parameters()
  horizon, side = parameters['horizon'], parameters['size']
  grid = np.zeros((horizon, side, side))
  state_pattern = re.compile(r'\((?P<x>\d+),\s*(?P<y>\d+),\s*(?P<t>\d+)\)')

  for state_str, mass in dist.items():
    match = state_pattern.fullmatch(state_str)
    if match is None:
      continue
    t, y, x = (int(match.group(field)) for field in ('t', 'y', 'x'))
    grid[t, y, x] = mass

  return grid


def get_policy_distribution(game: pyspiel.Game,
                            policy: policy_std.Policy) -> np.ndarray:
  """Computes the distribution induced by `policy` as a dense array.

  The distribution comes from `distribution.DistributionPolicy` and is
  converted to an array with `decode_distribution`.
  """
  return decode_distribution(
      game, distribution.DistributionPolicy(game, policy).distribution)


def animate_distributions(dists: np.ndarray,
                          fixed_cbar: bool = False) -> animation.FuncAnimation:
  """Builds an animation with one heatmap frame per leading-axis entry.

  Args:
    dists: Batched distributions; frame `i` shows `dists[i, ...]`.
    fixed_cbar: If true, all frames share one colour scale spanning the min
      and max over all of `dists`; otherwise the scale is left to seaborn.

  Returns:
    A function animation.
  """
  vmin, vmax = (np.min(dists), np.max(dists)) if fixed_cbar else (None, None)

  fig, (ax, cbar_ax) = plt.subplots(
      1, 2,
      gridspec_kw={'width_ratios': (0.9, 0.05), 'wspace': 0.2},
      figsize=(7, 5))

  def draw_frame(frame_idx):
    # Redraw from scratch; the colour bar goes into its dedicated axis.
    ax.cla()
    sns.heatmap(dists[frame_idx, ...], square=True, cmap=plt.cm.viridis,
                linecolor='white', linewidths=0.1, ax=ax, cbar=True,
                cbar_ax=cbar_ax, vmin=vmin, vmax=vmax)

  anim = animation.FuncAnimation(
      fig=fig, func=draw_frame, frames=dists.shape[0], interval=50,
      blit=False)
  # This prevents plot output at each frame.
  plt.close()
  return anim


@dataclasses.dataclass
class RunResult:
  """Holds the result of running an algorithm.

  Attributes:
    policy: The policy returned by the algorithm after the last iteration.
    dists: An np.ndarray stacking the full distribution of the policy at every
      evaluated iteration, the initial policy first. As filled by
      `run_algorithm`, its shape is (num_iterations + 1, horizon, size, size),
      indexed as [iteration, t, y, x].
    nash_convs: Nash Conv metric at every evaluated iteration, the initial
      policy included (num_iterations + 1 entries when filled by
      `run_algorithm`).
    last_dist: The distribution of the final policy, shape
      (horizon, size, size), indexed as [t, y, x].
  """
  policy: policy_std.Policy
  dists: np.ndarray
  nash_convs: np.ndarray
  last_dist: np.ndarray



def _evaluate_policy(game, eval_policy):
  """Returns the Nash Conv of `eval_policy` and the distribution it induces."""
  return (nash_conv.NashConv(game, eval_policy).nash_conv(),
          get_policy_distribution(game, eval_policy))


def run_algorithm(game: pyspiel.Game, algo, num_iterations: int,
                  learning_rate=None, init_policy=None):
  """Runs the algorithm for specified number of iterations.

  The policy is evaluated once before the first iteration and again after
  every iteration. Each evaluation records the Nash Conv and the full
  distribution induced by the policy.

  Args:
    game: An MFG.
    algo: Algorithm to use; must provide `iteration()` and `get_policy()`.
    num_iterations: Number of iterations.
    learning_rate: If not None, each step calls
      `algo.iteration(learning_rate=learning_rate)`; otherwise
      `algo.iteration()` is called without arguments.
    init_policy: Optional policy evaluated at iteration 0 instead of
      `algo.get_policy()`. NOTE(review): it is not passed to `algo`, so it only
      changes the iteration-0 entries, not the iterations that follow.

  Returns:
    A RunResult with the final policy, the stacked distributions and Nash Conv
    values (num_iterations + 1 entries each, iteration 0 first), and the
    distribution of the final policy.
  """
  nash_convs = []
  dists = []

  # Iteration 0: evaluate the initial policy before any update.
  start_time = time.time()
  current_policy = (
      init_policy if init_policy is not None else algo.get_policy())
  nash_conv_value, dist = _evaluate_policy(game, current_policy)
  nash_convs.append(nash_conv_value)
  dists.append(dist)
  print("Done iteration = 0, \ttime = ", time.time() - start_time,
        "\tnash_conv = ", nash_convs[-1])

  for i in range(num_iterations):
    start_time = time.time()
    if learning_rate is not None:
      algo.iteration(learning_rate=learning_rate)
    else:
      algo.iteration()
    current_policy = algo.get_policy()
    nash_conv_value, dist = _evaluate_policy(game, current_policy)
    nash_convs.append(nash_conv_value)
    dists.append(dist)
    # Report only every second iteration to keep the output short.
    if (i + 1) % 2 == 0:
      print("Done iteration = ", i + 1, "\ttime = ", time.time() - start_time,
            "\tnash_conv = ", nash_convs[-1])

  return RunResult(
      policy=current_policy,
      dists=np.stack(dists),
      nash_convs=np.array(nash_convs),
      last_dist=dist)


def display_result(result: RunResult):
  """Plots a run's Nash Conv curve and returns its distribution animation.

  Args:
    result: The run to display.

  Returns:
    An IPython HTML object embedding the JavaScript animation of
    `result.dists`.
  """
  sns.set(rc={'figure.figsize': (10, 6)})
  _, nash_conv_ax = plt.subplots()
  nash_conv_ax.plot(result.nash_convs)
  nash_conv_ax.set_xlabel('iteration')
  nash_conv_ax.set_ylabel('Nash Conv')
  animation_html = animate_distributions(result.dists).to_jshtml()
  return HTML(animation_html)

In [None]:
# Exploitability
# Comparison of exploitability.
ft_size = 20


def display_exploitability(results: Dict[str, RunResult],
                           font_size: int = ft_size,
                           figsize: Tuple[float, float] = (15, 8)):
  """Plots the exploitability (Nash Conv) curves of several runs on a log scale.

  Args:
    results: Mapping from display name to run result; each run's `nash_convs`
      becomes one line of the plot.
    font_size: Font size of the axis labels, tick labels and legend.
    figsize: Figure size in inches.

  Returns:
    The matplotlib figure containing the plot.
  """
  nash_conv_df = pd.DataFrame.from_dict(
      {name: result.nash_convs for name, result in results.items()})

  # Set the theme first and only then create the figure, with an explicit
  # size. Taking plt.gcf() before updating 'figure.figsize' would draw on a
  # figure created with the previous size (visible on the first call).
  sns.set_theme(style="whitegrid", rc={'figure.figsize': figsize})
  fig, ax = plt.subplots(figsize=figsize)

  sns.lineplot(data=nash_conv_df, palette="tab10", linewidth=2.5, ax=ax)
  ax.set_yscale('log')
  ax.set_xlabel('iterations', fontsize=font_size)
  ax.set_ylabel('exploitability', fontsize=font_size)
  ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0,
            fontsize=font_size)
  # Resize the existing tick labels rather than replacing them with raw tick
  # values, which would drop the log-scale formatting.
  ax.tick_params(axis='both', which='both', labelsize=font_size)
  fig.tight_layout()
  return fig
# Usage:
# display_exploitability(results)

In [None]:
# Usage:
# n_steps = game.get_parameters()['horizon']
# steps = range(0,n_steps,2)
# fig_distributions = display_distribution_at_steps(results, steps, size=2)
ft_size = 20
def display_distribution_at_steps(results, steps, size=4, forbidden_states_indicator=None):
  """Plots each run's final-policy distribution at the given time steps.

  The grid has one row per run and one column per step. Each panel has its own
  colorbar spanning the min and max of the distribution at that step. These
  are computed before the forbidden-state overlay is added.

  Args:
    results: Mapping from display name to RunResult; `last_dist[step]` of each
      run is plotted.
    steps: Time steps (indices into `last_dist`), one per column.
    size: Side length in inches of each panel.
    forbidden_states_indicator: Optional array added to every distribution
      slice; cells that become NaN are drawn in grey.

  Returns:
    The matplotlib figure.
  """
  import copy

  steps = list(steps)
  num_steps = len(steps)
  num_results = len(results)
  # squeeze=False keeps `axs` two-dimensional even with one run or one step.
  fig, axs = plt.subplots(
      num_results,
      num_steps,
      sharex='col',
      sharey='row',
      squeeze=False,
      figsize=(num_steps * size, num_results * size))
  # Copy the colormap so that setting its "bad" colour does not mutate the
  # globally shared viridis instance used by other plots.
  cmap = copy.copy(plt.cm.viridis)
  cmap.set_bad('grey', 1.)
  for row, (name, result) in enumerate(results.items()):
    for col, step in enumerate(steps):
      d = result.last_dist[step]
      minval = round(np.amin(d), 3)
      maxval = round(np.amax(d), 3)
      if forbidden_states_indicator is not None:
        d = d + forbidden_states_indicator
      masked_array = np.ma.array(d, mask=np.isnan(d))
      ax = axs[row][col]
      ax.axis('off')
      # The first column also carries the run name.
      title = f'{name}\n{step}' if col == 0 else str(step)
      ax.set_title(title, size=ft_size)
      im = ax.imshow(
          masked_array,
          interpolation='nearest',
          cmap=cmap, vmin=minval, vmax=maxval)
      ticks = [round(minval + k * (maxval - minval) / 4.0, 3) for k in range(5)]
      cbar = plt.colorbar(im, ax=ax, fraction=0.046, ticks=ticks)
      cbar.ax.tick_params(labelsize=ft_size)

  fig.tight_layout()
  return fig


# Run algos

In [None]:
# Noise levels to run; uncomment entries to add more settings.
settings = dict(
    # with_large_noise=dict(noise_intensity=1.0),
    # with_medium_noise=dict(noise_intensity=0.5),
    with_small_noise=dict(noise_intensity=0.1),
    # with_no_noise=dict(noise_intensity=0.0),
)

In [None]:
num_iterations = 300

# Underlying OpenSpiel game, and the experiment name used for saved figures.
game_name = 'mfg_crowd_modelling_2d'
game_name_setting = 'mfg_crowd_modelling_2d_four_rooms_exploration'
# No warm start: run_algorithm evaluates each algorithm's own initial policy
# at iteration 0.
init_policy = None
# Learning rate of Online Mirror Descent; also used in its progress message.
md_lr = 0.05

setting_results = {}

for (sk, sv) in settings.items():
  noise_intensity = sv.get("noise_intensity")
  print("\n\n\n Setting {}: noise_intensity={}\n\n\n".format(sk, noise_intensity))

  # Rebuild the game for this setting; only the noise level changes.
  four_rooms_default_setting["noise_intensity"] = noise_intensity
  game = pyspiel.load_game(game_name, four_rooms_default_setting)
  print("start_time = ", datetime.datetime.now())
  start_time = time.time()
  print("start_time = ", start_time)

  fp = fictitious_play.FictitiousPlay(game)
  fp_result = run_algorithm(game, fp, num_iterations, init_policy=init_policy)
  print("FP DONE, time = ", time.time() - start_time)
  start_time = time.time()
  md = mirror_descent.MirrorDescent(game, lr=md_lr)
  md_result = run_algorithm(game, md, num_iterations, init_policy=init_policy)
  print("OMD LR {} DONE, time = ".format(md_lr), time.time() - start_time)
  # start_time = time.time()
  # munchausen_md = munchausen_mirror_descent.MunchausenMirrorDescent(game, lr=0.1)
  # munchausen_md_result = run_algorithm(game, munchausen_md, num_iterations, init_policy=init_policy)
  # print("MOMD DONE, time = ", time.time() - start_time)
  start_time = time.time()
  fixedp = fixed_point.FixedPoint(game)
  fixedp_result = run_algorithm(game, fixedp, num_iterations, init_policy=init_policy)
  print("FixedP DONE, time = ", time.time() - start_time)
  start_time = time.time()
  # Fictitious Play with a constant learning rate; reported under the label
  # 'Damped Fixed Point' in `results` below.
  fpd = fictitious_play.FictitiousPlay(game, lr=0.01)
  fpd_result = run_algorithm(game, fpd, num_iterations, init_policy=init_policy)
  print("Damped FP DONE, time = ", time.time() - start_time)
  start_time = time.time()
  fixedp_softmax = fixed_point.FixedPoint(game, temperature=0.1)
  fixedp_softmax_result = run_algorithm(game, fixedp_softmax, num_iterations, init_policy=init_policy)
  print("FixedP softmax DONE, time = ", time.time() - start_time)
  start_time = time.time()
  fpsoft = fictitious_play.FictitiousPlay(game, temperature=0.1)
  fpsoft_result = run_algorithm(game, fpsoft, num_iterations, init_policy=init_policy)
  print("FP softmax DONE, time = ", time.time() - start_time)
  start_time = time.time()
  bpi = boltzmann_policy_iteration.BoltzmannPolicyIteration(game, lr=0.1)
  bpi_result = run_algorithm(game, bpi, num_iterations, init_policy=init_policy)
  print("BPI DONE, time = ", time.time() - start_time)

  results = {
    'Fictitious Play': fp_result,
    'Online Mirror Descent': md_result,
    # 'Munchausen OMD': munchausen_md_result,
    'Fixed Point': fixedp_result,
    'Damped Fixed Point': fpd_result,
    'Softmax Fixed Point': fixedp_softmax_result,
    'Softmax Fictitious Play': fpsoft_result,
    'Boltzmann Policy Iteration': bpi_result,
  }
  setting_results[sk] = results





# Plots

## Save data

In [None]:
from colabtools import fileedit


# # Downloading the results
# np.savez('/tmp/{}-setting_results.npz'.format(game_name_setting), setting_results=setting_results)
# # %download_file /tmp/setting_results.npz
# fileedit.download_file('/tmp/{}-setting_results.npz'.format(game_name_setting), ephemeral=True)

## Exploitability

It may be necessary to run this cell twice to get the intended figure size: on the first run, the figure can come out smaller than expected. This happens when the figure is created before its size is set inside `display_exploitability` (for instance, when the figure is obtained with `plt.gcf()` before `figure.figsize` is updated).

In [None]:



# Plotting the results: one exploitability figure per setting, saved as a PDF
# and downloaded.
for sk, results in setting_results.items():
  print("\n\n\n Setting {}\n\n\n".format(sk))
  s_sk = settings[sk]
  fig_exploitabilities = display_exploitability(results)
  exploitabilities_path = '/tmp/{}-noise{}_exploitabilities.pdf'.format(
      game_name_setting, s_sk.get("noise_intensity"))
  fig_exploitabilities.savefig(exploitabilities_path)
  fileedit.download_file(exploitabilities_path, ephemeral=True)
  plt.show()

## Distributions

The plotting function does not currently take parameters for the colorbar range, so each panel's colorbar spans the smallest and largest values of the distribution at that step. Beware that if there is a forbidden state, the smallest value is always 0, because there is no mass on forbidden states.

In [None]:
# Plotting the results: the distribution every 5 time steps over the whole
# horizon (0, 5, ..., 40 for the default horizon of 41), one figure per
# setting, saved as a PDF and downloaded.
distribution_steps = range(0, four_rooms_default_setting['horizon'], 5)
for (sk, results) in setting_results.items():
  print("\n\n\n Setting {}\n\n\n".format(sk))
  s_sk = settings[sk]
  distributions_path = '/tmp/{}-noise{}_distributions.pdf'.format(
      game_name_setting, s_sk.get("noise_intensity"))
  fig_distributions = display_distribution_at_steps(
      results, distribution_steps, 5, forbidden_states_indicator)
  fig_distributions.savefig(distributions_path)
  fileedit.download_file(distributions_path, ephemeral=True)
  plt.show()