Skip to content

Commit

Permalink
Merge pull request #924 from rezunli96:correlated-q
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 472021181
Change-Id: I92980e520aa4c5924e0518ae99ab3b1bded84db1
  • Loading branch information
lanctot committed Sep 4, 2022
2 parents 97fbfe4 + 40812a4 commit 6b616ee
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions open_spiel/python/algorithms/tabular_multiagent_qlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from open_spiel.python import rl_agent
from open_spiel.python import rl_tools
from open_spiel.python.algorithms.jpsro import _mgcce
from open_spiel.python.algorithms.jpsro import _mgce
from open_spiel.python.algorithms.stackelberg_lp import solve_stackelberg
import pyspiel

Expand Down Expand Up @@ -107,16 +106,25 @@ def __init__(self, is_cce=False):
self._is_cce = is_cce

def __call__(self, payoffs_array):
size = len(payoffs_array)
assert size > 0
mixture, _ = (_mgcce(payoffs_array, [1] * size, ignore_repeats=True) if
self._is_cce else _mgce(payoffs_array, [1] * size,
ignore_repeats=True))
num_players = len(payoffs_array)
assert num_players > 0
num_strategies_per_player = payoffs_array.shape[1:]
mixture, _ = (
_mgcce( # pylint: disable=g-long-ternary
payoffs_array,
[np.ones([ns], dtype=np.int32) for ns in num_strategies_per_player],
ignore_repeats=True)
if self._is_cce else _mgcce(
payoffs_array,
[np.ones([ns], dtype=np.int32) for ns in num_strategies_per_player],
ignore_repeats=True))
mixtures, values = [], []
for n in range(size):
for n in range(num_players):
values.append(np.sum(payoffs_array[n] * mixture))
mixtures.append(
np.sum(mixture, axis=tuple([n_ for n_ in range(size) if n_ != n])))
np.sum(
mixture,
axis=tuple([n_ for n_ in range(num_players) if n_ != n])))
return mixtures, values


Expand Down

0 comments on commit 6b616ee

Please sign in to comment.