# Socially Optimal Restrictions for Continuous-Action Games
AAAI 2023 Main Track, Paper ID: 6770

## Setup and Definitions

### Imports

In [None]:
import numpy as np
import random
import copy
import matplotlib.pyplot as plt
from collections import deque
import math
import collections
from IPython.display import HTML, display
from matplotlib.ticker import PercentFormatter

### Utilities

In [None]:
def progress(value, max=100):
  """Build an HTML ``<progress>`` bar for display in a notebook.

  Args:
    value: Current progress value; also rendered as the element's fallback text.
    max: Value that corresponds to a full bar.

  Returns:
    An ``HTML`` display object wrapping the ``<progress>`` element.
  """
  # The attributes of an HTML tag are separated by whitespace only; a comma
  # after the ``max`` attribute would be parsed as a bogus extra attribute.
  return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 100%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

In [None]:
class NoEquilibriumFoundException(Exception):
  """Signals that no Nash equilibrium could be determined.

  Accepts the same positional arguments as ``Exception``.
  """

class NoBestResponseFoundException(Exception):
  """Signals that a utility function has no (unique, finite) best response.

  Accepts the same positional arguments as ``Exception``.
  """

class NoOptimumFoundException(Exception):
  """Signals that a social optimum could not be computed.

  Accepts the same positional arguments as ``Exception``.
  """

# Record returned by a restriction solver: the game, the best restriction found
# (with its equilibrium and social utility), the same three values for the
# initial restriction, and a dict with additional information.
RestrictionSolverResult = collections.namedtuple(
  'RestrictionSolverResult',
  [
    'game',
    'optimal_restriction',
    'optimal_nash_equilibrium',
    'optimal_social_utility',
    'initial_restriction',
    'initial_nash_equilibrium',
    'initial_social_utility',
    'info',
  ],
)

# Record describing a failed solver run: the game, the exception and its arguments.
RestrictionSolverException = collections.namedtuple(
  'RestrictionSolverException',
  ['game', 'exception', 'args'],
)

In [None]:
class IntervalUnion:
  """A union of real intervals, stored as a list of ``(a, b)`` tuples.

  The methods assume that ``intervals`` is sorted in ascending order and that
  the intervals do not overlap; this is not validated. Membership tests treat
  both end points of an interval as included.
  """

  def __init__(self, intervals=None):
    """Create the union; without an argument it covers the whole real line.

    Args:
      intervals: List of ``(a, b)`` tuples. The list is stored as-is (not
        copied), so later in-place changes are visible to the caller.
    """
    # Build the default per instance: a list used as a default argument would
    # be shared by all instances and mutated by insert()/remove().
    if intervals is None:
      intervals = [(-np.inf, np.inf)]

    assert isinstance(intervals, list)
    assert all(isinstance(interval, tuple) for interval in intervals)

    self.intervals = intervals

  def __str__(self):
    intervals = ' '.join(f'[{a}, {b})' for a, b in self.intervals) if self.intervals else '()'
    return f'<IntervalUnion {intervals}>'

  def __repr__(self):
    return self.__str__()

  def __bool__(self):
    """An interval union is truthy iff it holds at least one interval."""
    return bool(self.intervals)

  def __contains__(self, x):
    """Return whether ``a <= x <= b`` holds for one of the intervals."""
    for (a, b) in self.intervals:
      if x < a:
        # Intervals are assumed to be sorted, so no later interval can match
        return False
      elif x <= b:
        return True

    return False

  def __len__(self):
    """Return the number of intervals (not the covered length, see ``size``)."""
    return len(self.intervals)

  def __eq__(self, other):
    """Two unions are equal iff they hold the same intervals in the same order."""
    if not isinstance(other, IntervalUnion):
      return NotImplemented

    # zip() stops at the shorter list, so the lengths must be compared first;
    # otherwise a union would equal every union it is a prefix of.
    if len(self.intervals) != len(other.intervals):
      return False

    return all(a == x and b == y for (a, b), (x, y) in zip(self.intervals, other.intervals))

  def __hash__(self) -> int:
    # NOTE: the hash depends on the current intervals, so a union must not be
    # modified in place while it is stored in a set or used as a dict key.
    return hash(tuple(self.intervals))

  def clone(self):
    """Return an independent copy of this union."""
    return IntervalUnion(copy.deepcopy(self.intervals))

  def last_interval_before_or_within(self, x):
    """Scan in ascending order for the first interval that does not end before x.

    Returns:
      ``(index, (a, b), within)`` where ``within`` is True if ``a <= x <= b``
      and False if the interval starts after ``x``. If every interval ends
      before ``x``, ``(None, (None, None), False)`` is returned.
    """
    for i, (a, b) in enumerate(self.intervals):
      if x < a:
        return i, (a, b), False
      elif x <= b:
        return i, (a, b), True

    return None, (None, None), False

  def first_interval_after_or_within(self, x):
    """Scan in descending order for the first interval that does not start after x.

    Returns:
      ``(index, (a, b), within)`` where ``within`` is True if ``a <= x < b``
      and False if ``b <= x``. If every interval starts after ``x``,
      ``(None, (None, None), False)`` is returned.
    """
    for i, (a, b) in reversed(list(enumerate(self.intervals))):
      if x >= b:
        return i, (a, b), False
      elif x >= a:
        return i, (a, b), True

    return None, (None, None), False

  def insert(self, x, y):
    """Add the interval ``(x, y)`` in place, merging it with intervals it touches.

    Nothing happens if ``x >= y``.
    """
    if x >= y:
      return

    i, (a, b), v = self.last_interval_before_or_within(x)
    j, (c, d), w = self.first_interval_after_or_within(y)

    if i is None:
      # x lies behind all existing intervals
      self.intervals.append((x, y))
    elif j is None:
      # y lies in front of all existing intervals
      self.intervals.insert(0, (x, y))
    else:
      # Replace intervals i..j by one merged interval. If (x, y) falls into a
      # gap, j == i - 1 and the slice is empty, i.e. this is a pure insertion.
      self.intervals[i:j+1] = [(x if a is None else min(a, x), y if d is None else max(d, y))]

  def remove(self, x, y):
    """Remove the range ``(x, y)`` from the union in place.

    Args:
      x: Lower end of the removed range; ``None`` means the lower bound of the union.
      y: Upper end of the removed range; ``None`` means the upper bound of the union.

    Nothing happens if the union is empty or if ``x >= y``.
    """
    if not self.intervals:
      return

    if x is None:
      x = self.intervals[0][0]

    if y is None:
      y = self.intervals[-1][1]

    if x >= y:
      return

    i, (a, b), v = self.last_interval_before_or_within(x)
    j, (c, d), w = self.first_interval_after_or_within(y)

    if i is not None and j is not None:
      if v and (a < x):
        # A piece (a, x) of interval i survives on the left ...
        if w and (d > y):
          # ... and a piece (y, d) of interval j survives on the right
          self.intervals[i:j+1] = [(a, x), (y, d)]
        else:
          self.intervals[i:j+1] = [(a, x)]
      else:
        if w:
          # Only the right piece of interval j survives
          self.intervals[i:j+1] = [(y, d)]
        else:
          # Intervals i..j are covered completely (empty slice if in a gap)
          self.intervals[i:j+1] = []

  def clone_and_remove(self, x, y):
    """Return a copy of this union with the range ``(x, y)`` removed."""
    new_interval_union = IntervalUnion(copy.deepcopy(self.intervals))
    new_interval_union.remove(x, y)
    return new_interval_union

  def ndarray(self, step=1.0):
    """Return the points ``a, a + step, ...`` (below ``b``) of all intervals as one array."""
    if not self.intervals:
      # np.concatenate() rejects an empty sequence of arrays
      return np.array([])

    return np.concatenate([np.arange(a, b, step) for a, b in self.intervals])

  @property
  def complement(self):
    """The gaps of this union, including the unbounded ones on either side.

    NOTE(review): a non-empty union yields a plain list of tuples, whereas an
    empty union yields an ``IntervalUnion``; kept as-is for existing callers.
    """
    if not self.intervals:
      return IntervalUnion()
    else:
      intervals = [(-np.inf, self.intervals[0][0])] if self.intervals[0][0] != -np.inf else []

      for i in range(1, len(self.intervals)):
        intervals.append((self.intervals[i-1][1], self.intervals[i][0]))

      if self.intervals[-1][1] != np.inf:
        intervals.append((self.intervals[-1][1], np.inf))

      return intervals

  @property
  def inner_complement(self):
    """The gaps between consecutive intervals.

    NOTE(review): same mixed return type as ``complement`` (list of tuples,
    or an ``IntervalUnion`` for an empty union).
    """
    if not self.intervals:
      return IntervalUnion()
    else:
      return [(self.intervals[i-1][1], self.intervals[i][0]) for i in range(1, len(self.intervals))]

  @property
  def size(self):
    """Total length of all intervals (0.0 if empty, infinite if unbounded)."""
    if not self.intervals:
      return 0.0
    elif not self.has_lower_bound() or not self.has_upper_bound():
      return np.inf
    else:
      return sum(b - a for a, b in self.intervals)

  def has_lower_bound(self):
    """Return True if the union is empty or its first interval starts at a finite value."""
    return (not self.intervals) or (not math.isinf(self.intervals[0][0]))

  def has_upper_bound(self):
    """Return True if the union is empty or its last interval ends at a finite value."""
    return (not self.intervals) or (not math.isinf(self.intervals[-1][1]))

  def upper_bound(self):
    """Return the end of the last interval, or None for an empty union."""
    return None if not self.intervals else self.intervals[-1][1]

  def lower_bound(self):
    """Return the start of the first interval, or None for an empty union."""
    return None if not self.intervals else self.intervals[0][0]

  def outer_bounds(self):
    """Return ``[lower_bound, upper_bound]``, or ``[]`` for an empty union."""
    return [] if not self.intervals else [self.intervals[0][0], self.intervals[-1][1]]

  def nearest_elements(self, x):
    """Return the element(s) of the union that are closest to ``x``.

    Returns:
      ``[x]`` if x lies inside an interval, the nearest interval end point if
      it does not, both neighbouring end points if x is exactly in the middle
      of a gap, and ``[]`` for an empty union.
    """
    if not self.intervals:
      return []

    for i, (a, b) in enumerate(self.intervals):
      if x < a:
        if i == 0:
          return [a]

        # x lies in the gap between the previous interval and this one
        previous_end = self.intervals[i-1][1]
        distance_left, distance_right = x - previous_end, a - x
        return ([previous_end] if distance_left <= distance_right else []) + ([a] if distance_left >= distance_right else [])
      elif x < b:
        return [x]

    return [self.intervals[-1][1]]

  def sample(self):
    """Draw a point uniformly at random from a bounded union (None if empty)."""
    assert self.has_lower_bound() and self.has_upper_bound()

    if not self.intervals:
      return None

    # Draw an offset into the total length and walk it through the intervals
    x = random.uniform(0.0, self.size)
    for a, b in self.intervals:
      if x > b - a:
        x -= b - a
      else:
        return a + x

    # Only reached if rounding errors leave a remainder after the last interval
    return self.intervals[-1][1]

In [None]:
class UtilityFunction:
  """Base class of all utility functions: it only records the owning player."""

  def __init__(self, player):
    # Identifier of the player this utility belongs to; subclasses in this
    # file use 0 and 1 for the two players and None for a social utility.
    self.player = player

class QuadraticTwoPlayerUtility(UtilityFunction):
  """Quadratic utility of a joint action ``(x0, x1)`` of two players.

  With ``coeffs = [c0, ..., c5]`` the utility is
  ``c0*x0^2 + c1*x1^2 + c2*x0*x1 + c3*x0 + c4*x1 + c5``, where ``x0`` is the
  action of player 0 and ``x1`` the action of player 1.
  """

  def __init__(self, player, coeffs):
    """Store the owning player (0, 1 or None) and the six coefficients."""
    super().__init__(player)

    self.coeffs = np.array(coeffs)

  def __call__(self, x):
    """Evaluate the utility at the joint action ``x = (x0, x1)``."""
    return self.coeffs[0] * x[0] ** 2 + self.coeffs[1] * x[1] ** 2 + self.coeffs[2] * x[0] * x[1] + self.coeffs[3] * x[0] + self.coeffs[4] * x[1] + self.coeffs[5]

  def best_response(self, x, action_space):
    """Return the first best response of this player to the opponent's action ``x``.

    Raises:
      NoBestResponseFoundException: If there is no best response.
    """
    best_responses = self.best_responses(x, action_space)

    if not best_responses:
      raise NoBestResponseFoundException()
    else:
      return best_responses[0]

  def best_responses(self, x, action_space):
    """Return all best responses of this player to the opponent's action ``x``.

    Args:
      x: Action of the other player; must be contained in ``action_space``.
      action_space: Allowed actions; needs ``has_lower_bound``,
        ``has_upper_bound``, ``lower_bound``, ``upper_bound``,
        ``outer_bounds`` and ``nearest_elements``.

    Returns:
      List of the own actions that maximize the utility.

    Raises:
      NoBestResponseFoundException: If the utility is constant in the own
        action, or if the maximum lies at a bound that does not exist.
      ValueError: If ``self.player`` is neither 0 nor 1.
    """
    assert self.player is not None
    assert x in action_space, f'{x} is not allowed in {action_space}'

    # Rename the coefficients such that the own action t has the utility
    # a*t^2 + (c*x + d)*t + (terms that do not depend on t)
    if self.player == 0:
      a, b, c, d, e, f = self.coeffs
    elif self.player == 1:
      b, a, c, e, d, f = self.coeffs
    else:
      raise ValueError(f'Best responses are only defined for player 0 or 1, not {self.player}')

    if a == 0:
      if c * x + d == 0:
        # Function is constant in the player: Any response is a best response
        raise NoBestResponseFoundException('Constant function!')
      elif c * x + d > 0:
        # Function is linear with positive slope: Maximum is upper bound if it exists
        if not action_space.has_upper_bound():
          raise NoBestResponseFoundException()
        else:
          return [action_space.upper_bound()]
      else:
        # Function is linear with negative slope: Maximum is lower bound if it exists
        if not action_space.has_lower_bound():
          raise NoBestResponseFoundException()
        else:
          return [action_space.lower_bound()]
    elif a > 0:
      # Function is convex in the player: Maximum is one of the outer bounds
      if not action_space.has_lower_bound() or not action_space.has_upper_bound():
        raise NoBestResponseFoundException()
      else:
        candidates = action_space.outer_bounds()
    else:
      # Function is concave in the player: Maximum is in the middle or close to it
      candidates = action_space.nearest_elements((c * x + d) / (-2 * a))

    # The utility expects the joint action as (action of player 0, action of
    # player 1), so the own candidate goes into the slot of the own player.
    if self.player == 0:
      candidate_values = [self((candidate, x)) for candidate in candidates]
    else:
      candidate_values = [self((x, candidate)) for candidate in candidates]

    maximum_value = max(candidate_values)

    return [candidate for candidate, value in zip(candidates, candidate_values) if value == maximum_value]

  def social_optimum(self, action_space):
    """Return the utility at the stationary point of a social utility function.

    Raises:
      NoOptimumFoundException: If the stationary point is not unique
        (``4ab == c^2``).
      NotImplementedError: If the stationary point is not in ``action_space``.
    """
    assert self.player is None, 'The social optimum is only defined for a social utility function'

    a, b, c, d, e, f = self.coeffs
    if 4*a*b == c**2:
      raise NoOptimumFoundException()

    # Solution of the first-order conditions 2ax + cy + d = 0, cx + 2by + e = 0.
    # NOTE(review): it is not checked that this point is a maximum.
    x, y = (c*e - 2*b*d) / (4*a*b - c**2), (c*d - 2*a*e) / (4*a*b - c**2)

    if x in action_space and y in action_space:
      return self((x, y))
    else:
      raise NotImplementedError()

  def __add__(self, other):
    """Add the coefficients; the player is kept only if both operands share it."""
    return QuadraticTwoPlayerUtility(None if self.player != other.player else self.player, self.coeffs + other.coeffs)

  def __str__(self):
    return f'<QuadraticTwoPlayerUtility {self.coeffs[0]}x^2 + {self.coeffs[1]}y^2 + {self.coeffs[2]}xy + {self.coeffs[3]}x + {self.coeffs[4]}y + {self.coeffs[5]}>'

  def __repr__(self) -> str:
    return self.__str__()

In [None]:
class NormalFormGame:
  """A game given by a shared action space and one utility function per player."""

  def __init__(self, action_space, utilities):
    self.action_space, self.utilities = action_space, utilities

  @property
  def number_of_players(self):
    """Number of players, i.e. the number of utility functions."""
    player_count = len(self.utilities)
    return player_count

  def __str__(self):
    n, space, utils = self.number_of_players, self.action_space, self.utilities
    return f'<NormalFormGame n={n}, A={space}, u={utils}>'

  def __repr__(self) -> str:
    # Delegate to __str__ so that subclasses overriding it are covered as well
    return str(self)

class GovernedNormalFormGame(NormalFormGame):
  """A normal-form game that additionally carries a social utility function."""

  def __init__(self, action_space, utilities, social_utility):
    super().__init__(action_space, utilities)
    self.social_utility = social_utility

  def social_optimum(self, action_space=None):
    """Return the social optimum of the social utility over an action space.

    Args:
      action_space: Action space to optimize over; ``None`` selects the
        action space of the game.
    """
    # Compare against None explicitly: an empty restriction is falsy and
    # would otherwise silently be replaced by the full action space.
    if action_space is None:
      action_space = self.action_space

    return self.social_utility.social_optimum(action_space)

  def __str__(self):
    return f'<GovernedNormalFormGame n={self.number_of_players}, A={self.action_space}, u={self.utilities}, social_utility={self.social_utility}>'

class GovernedNormalFormGameWithOracle(GovernedNormalFormGame):
  """A governed game bundled with its own Nash equilibrium oracle."""

  def __init__(self, action_space, utilities, social_utility, oracle):
    super().__init__(action_space, utilities, social_utility)
    # Callable that the solver uses when no oracle is passed explicitly
    self.oracle = oracle

  def __str__(self):
    n, space = self.number_of_players, self.action_space
    utils, social = self.utilities, self.social_utility
    return f'<GovernedNormalFormGameWithOracle n={n}, A={space}, u={utils}, social_utility={social}>'

In [None]:
def is_equilibrium(utilities, x, y, action_space):
  """Check whether ``(x, y)`` is a Nash equilibrium of a two-player game.

  ``x`` must be (numerically) close to one of player 0's best responses to
  ``y``, and ``y`` close to one of player 1's best responses to ``x``.
  """
  x_is_best_response = np.any(np.isclose(x, np.array(utilities[0].best_responses(y, action_space))))
  if not x_is_best_response:
    # Player 1 is not consulted if player 0 already wants to deviate
    return x_is_best_response

  return np.any(np.isclose(y, np.array(utilities[1].best_responses(x, action_space))))

In [None]:
def hill_climbing_nash_equilibria(game: NormalFormGame, action_space: IntervalUnion, number_of_samples=10, number_of_steps=10, decimals=None):
  """Search for Nash equilibria by iterating best responses from random starts.

  Args:
    game: Two-player game whose utilities provide ``best_responses``.
    action_space: Restriction to sample from and to respond within.
    number_of_samples: Number of random starting actions for player 0.
    number_of_steps: Number of alternating best-response rounds.
    decimals: If given, the coordinates of the result are rounded to this
      many decimals.

  Returns:
    Set of the joint actions that pass ``is_equilibrium`` after the last round.

  Raises:
    NoEquilibriumFoundException: If a best response does not exist.
  """
  respond_0 = game.utilities[0].best_responses
  respond_1 = game.utilities[1].best_responses

  try:
    starts = { action_space.sample() for _ in range(number_of_samples) }

    # Pair every start with all best responses of the second player
    profiles = set()
    for x in starts:
      for y in respond_1(x, action_space):
        profiles.add((x, y))

    for _ in range(number_of_steps):
      # First player answers the second one ...
      after_first = set()
      for (_, y) in profiles:
        for brx in respond_0(y, action_space):
          after_first.add((brx, y))

      # ... then the second player answers the first one
      profiles = set()
      for (x, _) in after_first:
        for bry in respond_1(x, action_space):
          profiles.add((x, bry))

    equilibria = set()
    for (x, y) in profiles:
      if is_equilibrium(game.utilities, x, y, action_space):
        equilibria.add((round(x, decimals), round(y, decimals)) if decimals is not None else (x, y))

    return equilibria
  except NoBestResponseFoundException as e:
    raise NoEquilibriumFoundException(e)

def worst_hill_climbing_nash_equilibrium(game: GovernedNormalFormGame, action_space: IntervalUnion, decimals=None):
  """Return the hill-climbing equilibrium with the lowest social utility.

  Raises:
    NoEquilibriumFoundException: If hill climbing finds no equilibrium.
  """
  candidates = hill_climbing_nash_equilibria(game, action_space, decimals=decimals)
  if candidates:
    return min(candidates, key=game.social_utility)

  raise NoEquilibriumFoundException()

In [None]:
def quadratic_utility_nash_equilibria(game: NormalFormGame, action_space: IntervalUnion, decimals=None):
  """Return the Nash equilibria of a game with two quadratic utilities.

  The closed-form solution of the first-order conditions is used if both
  utilities are concave in the action of their own player, the conditions
  have a unique solution, and that solution is allowed by ``action_space``.
  In every other case the equilibria are searched by hill climbing.

  Args:
    game: Game with exactly two ``QuadraticTwoPlayerUtility`` utilities.
    action_space: Restriction in which the equilibria must lie.
    decimals: If given, the coordinates of the result are rounded to this
      many decimals.

  Returns:
    Set of joint actions ``(x, y)``.
  """
  assert all(isinstance(u, QuadraticTwoPlayerUtility) for u in game.utilities)

  u_1, u_2 = game.utilities
  a_1, _, c_1, d_1, _, _ = u_1.coeffs
  _, b_2, c_2, _, e_2, _ = u_2.coeffs

  # Determinant of the first-order conditions
  #   2*a_1*x + c_1*y + d_1 = 0   and   c_2*x + 2*b_2*y + e_2 = 0.
  # If it vanishes there is no unique solution, and dividing by it would
  # produce inf/nan values that could be mistaken for an equilibrium.
  determinant = 4*a_1*b_2 - c_1*c_2

  if a_1 < 0 and b_2 < 0 and determinant != 0:
    x = (c_1*e_2 - 2*d_1*b_2) / determinant
    y = (c_2*d_1 - 2*e_2*a_1) / determinant

    if x in action_space and y in action_space:
      return { (round(x, decimals), round(y, decimals)) if decimals is not None else (x, y) }

  # The analytical solution is not applicable
  return hill_climbing_nash_equilibria(game, action_space, decimals=decimals)

def worst_quadratic_utility_nash_equilibrium(game: GovernedNormalFormGame, action_space: IntervalUnion, decimals=None):
  """Return the quadratic-game equilibrium with the lowest social utility.

  Raises:
    NoEquilibriumFoundException: If no equilibrium is found.
  """
  candidates = quadratic_utility_nash_equilibria(game, action_space, decimals=decimals)
  if candidates:
    return min(candidates, key=game.social_utility)

  raise NoEquilibriumFoundException()

In [None]:
def absolute_improvement(result):
  """Return how much the optimal social utility exceeds the initial one."""
  utility_after = result.optimal_social_utility
  utility_before = result.initial_social_utility
  return utility_after - utility_before

def relative_improvement(result):
  """Return the absolute improvement relative to the initial social utility.

  Returns:
    ``absolute_improvement(result) / |initial_social_utility|``, or infinity
    if the initial social utility is zero.
  """
  initial_social_utility = result.initial_social_utility

  if initial_social_utility == 0:
    # np.inf is the portable spelling; the np.Inf alias was removed in NumPy 2.0
    return np.inf

  return absolute_improvement(result) / abs(initial_social_utility)

def degree_of_restriction(result):
  """Return the share of the initial restriction's size that was removed."""
  remaining_share = result.optimal_restriction.size / result.initial_restriction.size
  return 1.0 - remaining_share

### Algorithm

In [None]:
class IntervalUnionRestrictionSolver:
  """Searches for a restriction of the action space with maximal social utility.

  Starting from the full action space, the solver repeatedly removes an
  epsilon-neighbourhood around every action that occurs in the equilibrium of
  an already explored restriction, and asks an oracle for the equilibrium of
  the resulting restriction.
  """

  def __init__(self, *, epsilon=0.1, decimals=None, timeout_steps=None, timeout_explored_restrictions=None):
    """Configure the solver.

    Args:
      epsilon: Half-width of the range removed around a relevant action.
      decimals: Number of decimals used for rounding; derived from
        ``epsilon`` if None.
      timeout_steps: Stop after this many candidate restrictions (None: no limit).
      timeout_explored_restrictions: Stop after this many explored
        restrictions (None: no limit).
    """
    assert epsilon > 0

    self.epsilon = epsilon
    # Compare against None explicitly so that decimals=0 is respected
    self.decimals = decimals if decimals is not None else math.ceil(-math.log(self.epsilon, 10))
    self.timeout_steps = timeout_steps
    self.timeout_explored_restrictions = timeout_explored_restrictions

  def solve(self, game: "GovernedNormalFormGame", *, nash_equilibrium_oracle=None) -> "RestrictionSolverResult":
    """Find the best restriction of ``game.action_space``.

    Args:
      game: Game with ``action_space`` and ``social_utility``.
      nash_equilibrium_oracle: Callable ``(game, restriction, decimals=...)``
        that returns an equilibrium or raises NoEquilibriumFoundException.
        If None, the oracle of the game (if it has one) or the generic
        oracle of the solver is used.

    Returns:
      RestrictionSolverResult with the optimal and the initial restriction,
      their equilibria and social utilities, and the number of explored
      restrictions as ``info['number_of_oracle_calls']``.

    Raises:
      NoEquilibriumFoundException: If the unrestricted game has no equilibrium.
    """
    # Decide which oracle function to use
    if nash_equilibrium_oracle is None:
      if isinstance(game, GovernedNormalFormGameWithOracle):
        nash_equilibrium_oracle = game.oracle
      else:
        nash_equilibrium_oracle = self._nash_equilibrium_oracle

    # Keep track of explored restrictions to avoid double work
    explored_restrictions, current_step = set(), 0

    # Initialize optimum with current restriction (i.e., full action_space)
    initial_restriction = game.action_space
    initial_equilibrium = nash_equilibrium_oracle(game, initial_restriction, decimals=self.decimals)
    initial_social_utility = np.round(game.social_utility(initial_equilibrium), decimals=self.decimals)
    explored_restrictions.add(initial_restriction)

    # The equilibrium is tracked together with its restriction: asking a
    # randomized oracle again at the end could return a different equilibrium
    # (or none at all) that does not match the reported social utility.
    optimal_restriction = initial_restriction
    optimal_equilibrium = initial_equilibrium
    optimal_social_utility = initial_social_utility

    # Maintain all open (unexplored) restrictions; pop() takes the most
    # recently added one, i.e. the search proceeds depth-first
    restriction_queue = deque([(initial_restriction, initial_equilibrium)])
    while restriction_queue:
      current_restriction, current_equilibrium = restriction_queue.pop()

      for relevant_action in self._relevant_actions(current_equilibrium):
        current_step += 1

        new_restriction = current_restriction.clone_and_remove(round(relevant_action - self.epsilon, self.decimals), round(relevant_action + self.epsilon, self.decimals))

        if not new_restriction or new_restriction in explored_restrictions:
          continue

        explored_restrictions.add(new_restriction)

        try:
          new_equilibrium = nash_equilibrium_oracle(game, new_restriction, decimals=self.decimals)
        except NoEquilibriumFoundException:
          # New restriction does not have an equilibrium, so we cannot use it for further restrictions
          continue

        restriction_queue.append((new_restriction, new_equilibrium))

        # Update optimum if new_restriction is better (ties: prefer the larger restriction)
        new_social_utility = np.round(game.social_utility(new_equilibrium), decimals=self.decimals)
        if (new_social_utility > optimal_social_utility) or (new_social_utility == optimal_social_utility and new_restriction.size > optimal_restriction.size):
          optimal_restriction, optimal_equilibrium, optimal_social_utility = new_restriction, new_equilibrium, new_social_utility

      # Check if one of the timeout conditions is met
      if (self.timeout_steps is not None and current_step >= self.timeout_steps) or (self.timeout_explored_restrictions is not None and len(explored_restrictions) >= self.timeout_explored_restrictions):
        break

    return RestrictionSolverResult(game,
                                   optimal_restriction, optimal_equilibrium, optimal_social_utility,
                                   initial_restriction, initial_equilibrium, initial_social_utility,
                                   { 'number_of_oracle_calls': len(explored_restrictions) }
    )

  # Generic solver for restricted Nash Equilibrium (only used if no specialized solver is available)
  def _nash_equilibrium_oracle(self, game: "NormalFormGame", restriction: "IntervalUnion", decimals=None) -> tuple:
    """Placeholder with the signature that ``solve`` expects from an oracle."""
    raise NotImplementedError()

  def _relevant_actions(self, joint_action):
    """Return the set of individual actions that occur in the given joint action(s)."""
    # joint_action can either be one joint action or a list of joint actions
    return set(sum(joint_action, ())) if isinstance(joint_action, list) else set(joint_action)

## Experiments

### Parameterized Cournot Game (CG)

In [None]:
# Solve one Cournot game per value of lambda and collect the solver results.
results = []
# NOTE(review): only epsilon is passed to the solver below, so the solver derives
# its own rounding precision; `decimals` is only used to round the lambda grid.
epsilon, decimals = 0.1, 3
solver = IntervalUnionRestrictionSolver(epsilon=epsilon)
progress_bar = display(progress(0, 100), display_id=True)
lambda_min, lambda_max = 10.0, 200.0
# np.arange excludes the stop value, so lambda_max itself is not part of the grid
lambdas = list(np.round(np.arange(lambda_min, lambda_max, 1.0), decimals=decimals))

print(f'Solving {len(lambdas)} Cournot games...')
for i, lambda_ in enumerate(lambdas):
  progress_bar.update(progress(i, len(lambdas)))

  # u_1(x, y) = -x^2 - xy + lambda*x and u_2(x, y) = -y^2 - xy + lambda*y
  u_1 = QuadraticTwoPlayerUtility(0, [-1.0, 0.0, -1.0, lambda_, 0.0, 0.0])
  u_2 = QuadraticTwoPlayerUtility(1, [0.0, -1.0, -1.0, 0.0, lambda_, 0.0])

  # Both players choose from [0, lambda]; the social utility is the sum of both utilities
  a = IntervalUnion([(0.0, lambda_)])
  g = GovernedNormalFormGame(a, [u_1, u_2], u_1 + u_2)

  results.append(solver.solve(g, nash_equilibrium_oracle=worst_hill_climbing_nash_equilibrium))

progress_bar.update(progress(len(lambdas), len(lambdas)))  
  
print('Done!')

In [None]:
# Plot the social utility before and after restriction over lambda (left axis)
# and the relative improvement in percent (right axis).
X = lambdas
fig, ax1 = plt.subplots(figsize=(8, 4))
plt.xlabel('$\\lambda$')

ax1.set_ylabel('MESU')

Y = [result.initial_social_utility for result in results]
ax1.plot(X, Y, label='Unrestricted MESU')

Y = [result.optimal_social_utility for result in results]
ax1.plot(X, Y, label='Restricted MESU')

# Second y-axis that shares the x-axis with ax1
ax2 = ax1.twinx()
ax2.set_ylabel('$\\Delta(R^*)$')
ax2.set_ylim([0.0, 30.0])
ax2.yaxis.set_major_formatter(PercentFormatter())
# relative_improvement returns a fraction; scale it to percent for the formatter
Y = [100.0 * relative_improvement(result) for result in results]
ax2.plot(X, Y, color='g', label='Relative improvement')

fig.legend()

In [None]:
# Plot the degree of restriction in percent (left axis) and the number of
# oracle calls (right axis) over lambda.
X = lambdas
fig, ax1 = plt.subplots(figsize=(8, 4))

plt.xlabel('$\\lambda$')
ax1.set_ylabel('$\\mathfrak{r}(R^*)$')
ax1.set_ylim([20.0, 30.0])
ax1.yaxis.set_major_formatter(PercentFormatter(decimals=0))
# degree_of_restriction returns a fraction; scale it to percent for the formatter
Y = [100 * degree_of_restriction(result) for result in results]
ax1.plot(X, Y, label='Degree of restriction')

# Second y-axis that shares the x-axis with ax1
ax2 = ax1.twinx()
ax2.set_ylabel('# oracle calls')
Y = [result.info['number_of_oracle_calls'] for result in results]
ax2.plot(X, Y, color='g', label='Number of oracle calls')

fig.legend()

### Parameterized Continuous Braess Paradox (BP)

In [None]:
# Solve one Braess game per value of b and collect the solver results.
# This rebinds `results`, `epsilon`, `decimals` and `solver` from the Cournot cells.
results = []
# NOTE(review): only epsilon is passed to the solver below, so the solver derives
# its own rounding precision; `decimals` is only used to round the grid of b values.
epsilon, decimals = 0.0001, 5
solver = IntervalUnionRestrictionSolver(epsilon=epsilon)
progress_bar = display(progress(0, 100), display_id=True)
b_min, b_max, b_step = 4.0, 18.0, 0.1
bs = list(np.round(np.arange(b_min, b_max, b_step), decimals=decimals))
# Parameter tuples (a, b, c, d): only b varies, a = 0, c = 4 and d = 0 are fixed
params = [(0.0, b, 4.0, 0.0) for b in bs]

print(f'Solving {len(params)} Braess games...')
for i, [a, b, c, d] in enumerate(params):
  progress_bar.update(progress(i, len(params)))
  
  # The two utilities use the same coefficients with the roles of x and y swapped
  u_1 = QuadraticTwoPlayerUtility(0, [-a - c, 0.0, 0.0, 2*a + b - c - 1, -c, 4*c + d + 1])
  u_2 = QuadraticTwoPlayerUtility(1, [0.0, -a - c, 0.0, -c, 2*a + b - c - 1, 4*c + d + 1])

  # `a` is rebound from the game parameter to the action space [0, 1] here;
  # the parameter is unpacked again at the start of the next iteration
  a = IntervalUnion([(0.0, 1.0)])
  g = GovernedNormalFormGame(a, [u_1, u_2], u_1 + u_2)

  results.append(solver.solve(g, nash_equilibrium_oracle=worst_hill_climbing_nash_equilibrium))

progress_bar.update(progress(len(params), len(params)))  
  
print('Done!')

In [None]:
# Plot the social utility before and after restriction over b (left axis)
# and the relative improvement in percent (right axis).
X = np.array([b for a, b, c, d in params])
fig, ax1 = plt.subplots(figsize=(8, 4))
plt.xlabel('$b$')

ax1.set_ylabel('MESU')

Y = [result.initial_social_utility for result in results]
ax1.plot(X, Y, label='Unrestricted MESU')

Y = [result.optimal_social_utility for result in results]
ax1.plot(X, Y, label='Restricted MESU')

# Second y-axis that shares the x-axis with ax1
ax2 = ax1.twinx()
ax2.set_ylabel('$\\Delta(R^*)$')
ax2.yaxis.set_major_formatter(PercentFormatter())
# relative_improvement returns a fraction; scale it to percent for the formatter
Y = [100.0 * relative_improvement(result) for result in results]
ax2.plot(X, Y, color='g', label='Relative improvement')

fig.legend()

In [None]:
# Plot the degree of restriction in percent (left axis) and the number of
# oracle calls (right axis) over b.
X = np.array([b for a, b, c, d in params])
fig, ax1 = plt.subplots(figsize=(8, 4))
plt.xlabel('$b$')

ax1.set_ylabel('$\\mathfrak{r}(R^*)$')
ax1.yaxis.set_major_formatter(PercentFormatter(decimals=0))
# degree_of_restriction returns a fraction; scale it to percent for the formatter
Y = [100 * degree_of_restriction(result) for result in results]
ax1.plot(X, Y, label='Degree of restriction')

# Second y-axis that shares the x-axis with ax1
ax2 = ax1.twinx()
ax2.set_ylabel('# oracle calls')
Y = [result.info['number_of_oracle_calls'] for result in results]
ax2.plot(X, Y, color='g', label='Number of oracle calls')
# Let the axis of the oracle-call counts start at zero
ax2.set_ylim(bottom=0)

fig.legend()