Commit

Merge pull request #983 from deepmind:lanctot-patch-46
PiperOrigin-RevId: 499011506
Change-Id: I2e390f4b794edaa306d7a192e4a04a97edc55b69
lanctot committed Jan 2, 2023
2 parents 55011e6 + e4e18dc commit 6d1dde5
Showing 9 changed files with 162 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yml
@@ -76,7 +76,7 @@ jobs:
source ./open_spiel/scripts/python_extra_deps.sh
${CI_PYBIN} -m pip install --upgrade $OPEN_SPIEL_PYTHON_JAX_DEPS $OPEN_SPIEL_PYTHON_PYTORCH_DEPS $OPEN_SPIEL_PYTHON_TENSORFLOW_DEPS $OPEN_SPIEL_PYTHON_MISC_DEPS
${CI_PYBIN} -m pip install twine
${CI_PYBIN} -m pip install cibuildwheel==2.5.0
${CI_PYBIN} -m pip install cibuildwheel==2.11.1
- name: Build sdist
run: |
pipx run build --sdist
5 changes: 1 addition & 4 deletions open_spiel/integration_tests/api_test.py
@@ -579,10 +579,7 @@ def _assert_is_perfect_recall_recursive(state, current_history,
for s, a in current_history
if s.current_player() == current_player]

if not all([
np.array_equal(x, y)
for x, y in zip(expected_infosets_history, infosets_history)
]):
if infosets_history != expected_infosets_history:
raise ValueError("The history as tensor in the same infoset "
"are different:\n"
"History: {!r}\n".format(state.history()))
1 change: 1 addition & 0 deletions open_spiel/python/CMakeLists.txt
@@ -198,6 +198,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS}
algorithms/mcts_agent_test.py
algorithms/mcts_test.py
algorithms/minimax_test.py
algorithms/nfg_utils_test.py
algorithms/noisy_policy_test.py
algorithms/outcome_sampling_mccfr_test.py
algorithms/policy_aggregator_joint_test.py
4 changes: 2 additions & 2 deletions open_spiel/python/algorithms/alpha_zero/model_test.py
@@ -105,8 +105,8 @@ def test_model_learns_optimal(self, model_type):
train_inputs = list(solved.values())
print("states:", len(train_inputs))
losses = []
policy_loss_goal = 0.1
value_loss_goal = 0.1
policy_loss_goal = 0.12
value_loss_goal = 0.12
for i in range(500):
loss = model.update(train_inputs)
print(i, loss)
82 changes: 82 additions & 0 deletions open_spiel/python/algorithms/nfg_utils.py
@@ -0,0 +1,82 @@
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Some helpers for normal-form games."""

import collections
import numpy as np


class StrategyAverager(object):
"""A helper class for averaging strategies for players."""

def __init__(self, num_players, action_space_shapes, window_size=None):
"""Initialize the average strategy helper object.
Args:
num_players (int): the number of players in the game,
action_space_shapes: an vector of n integers, where each element
represents the size of player i's actions space,
window_size (int or None): if None, computes the players' average
strategies over the entire sequence, otherwise computes the average
strategy over a finite-sized window of the k last entries.
"""
self._num_players = num_players
self._action_space_shapes = action_space_shapes
self._window_size = window_size
self._num = 0
if self._window_size is None:
self._sum_meta_strategies = [
np.zeros(action_space_shapes[p]) for p in range(num_players)
]
else:
self._window = collections.deque(maxlen=self._window_size)

def append(self, meta_strategies):
"""Append the meta-strategies to the averaged sequence.
Args:
meta_strategies: a list of strategies, one per player.
"""
if self._window_size is None:
for p in range(self._num_players):
self._sum_meta_strategies[p] += meta_strategies[p]
else:
self._window.append(meta_strategies)
self._num += 1

def average_strategies(self):
"""Return each player's average strategy.
Returns:
The averaged strategies, as a list containing one strategy per player.
"""

if self._window_size is None:
avg_meta_strategies = [
np.copy(x) for x in self._sum_meta_strategies
]
num_strategies = self._num
else:
avg_meta_strategies = [
np.zeros(self._action_space_shapes[p])
for p in range(self._num_players)
]
for i in range(len(self._window)):
for p in range(self._num_players):
avg_meta_strategies[p] += self._window[i][p]
num_strategies = len(self._window)
for p in range(self._num_players):
avg_meta_strategies[p] /= num_strategies
return avg_meta_strategies
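
The new helper is self-contained. As a rough usage sketch (not part of the commit; the numbers are illustrative and mirror the unit tests below), a windowed averager over two 2-action players could be driven like this:

import numpy as np
from open_spiel.python.algorithms import nfg_utils

# Two players with two actions each; only the last 5 appended profiles count.
averager = nfg_utils.StrategyAverager(2, [2, 2], window_size=5)
averager.append([np.array([1.0, 0.0]), np.array([0.0, 1.0])])
averager.append([np.array([0.5, 0.5]), np.array([0.5, 0.5])])

avg = averager.average_strategies()  # one averaged strategy per player
print(avg[0])  # [0.75 0.25]
print(avg[1])  # [0.25 0.75]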
58 changes: 58 additions & 0 deletions open_spiel/python/algorithms/nfg_utils_test.py
@@ -0,0 +1,58 @@
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import numpy as np
from open_spiel.python.algorithms import nfg_utils


class NfgUtilsTest(absltest.TestCase):

def test_strategy_averager_len_smaller_than_window(self):
averager = nfg_utils.StrategyAverager(2, [2, 2], window_size=50)
averager.append([np.array([1.0, 0.0]), np.array([0.0, 1.0])])
averager.append([np.array([0.0, 1.0]), np.array([1.0, 0.0])])
avg_strategies = averager.average_strategies()
self.assertLen(avg_strategies, 2)
self.assertAlmostEqual(avg_strategies[0][0], 0.5)
self.assertAlmostEqual(avg_strategies[0][1], 0.5)
self.assertAlmostEqual(avg_strategies[1][0], 0.5)
self.assertAlmostEqual(avg_strategies[1][1], 0.5)

def test_strategy_averager(self):
first_action_strat = np.array([1.0, 0.0])
second_action_strat = np.array([0.0, 1.0])
averager_full = nfg_utils.StrategyAverager(2, [2, 2])
averager_window5 = nfg_utils.StrategyAverager(2, [2, 2], window_size=5)
averager_window6 = nfg_utils.StrategyAverager(2, [2, 2], window_size=6)
for _ in range(5):
averager_full.append([first_action_strat, first_action_strat])
averager_window5.append([first_action_strat, first_action_strat])
averager_window6.append([first_action_strat, first_action_strat])
for _ in range(5):
averager_full.append([second_action_strat, second_action_strat])
averager_window5.append([second_action_strat, second_action_strat])
averager_window6.append([second_action_strat, second_action_strat])
avg_full = averager_full.average_strategies()
avg_window5 = averager_window5.average_strategies()
avg_window6 = averager_window6.average_strategies()
self.assertAlmostEqual(avg_full[0][1], 0.5)
self.assertAlmostEqual(avg_window5[0][1], 5.0 / 5.0)
self.assertAlmostEqual(avg_window6[0][1], 5.0 / 6.0)


if __name__ == '__main__':
absltest.main()

15 changes: 8 additions & 7 deletions open_spiel/python/algorithms/projected_replicator_dynamics.py
@@ -20,6 +20,8 @@

import numpy as np

from open_spiel.python.algorithms import nfg_utils


def _partial_multi_dot(player_payoff_tensor, strategies, index_avoided):
"""Computes a generalized dot product avoiding one dimension.
@@ -189,13 +191,12 @@ def projected_replicator_dynamics(payoff_tensors,
for k in range(number_players)
]

average_over_last_n_strategies = average_over_last_n_strategies or prd_iterations
averager = nfg_utils.StrategyAverager(number_players, action_space_shapes,
average_over_last_n_strategies)
averager.append(new_strategies)

meta_strategy_window = []
for i in range(prd_iterations):
for _ in range(prd_iterations):
new_strategies = _projected_replicator_dynamics_step(
payoff_tensors, new_strategies, prd_dt, prd_gamma, use_approx)
if i >= prd_iterations - average_over_last_n_strategies:
meta_strategy_window.append(new_strategies)
average_new_strategies = np.mean(meta_strategy_window, axis=0)
return average_new_strategies
averager.append(new_strategies)
return averager.average_strategies()
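
Behaviourally this refactor is intended to be a no-op: the averager reproduces the old sliding-window mean. A hedged calling sketch (keyword names taken from the hunk above; the game and the remaining defaults are assumptions, not part of the commit) on matching pennies:

import numpy as np
from open_spiel.python.algorithms import projected_replicator_dynamics

# Matching pennies: one payoff tensor per player, zero-sum.
row_payoff = np.array([[1.0, -1.0], [-1.0, 1.0]])
payoff_tensors = [row_payoff, -row_payoff]

avg_strategies = projected_replicator_dynamics.projected_replicator_dynamics(
    payoff_tensors,
    prd_iterations=1000,
    average_over_last_n_strategies=100)
# Each player's averaged strategy should be close to the uniform [0.5, 0.5] equilibrium.
print(avg_strategies)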
15 changes: 8 additions & 7 deletions open_spiel/python/algorithms/regret_matching.py
@@ -19,6 +19,8 @@
"""

import numpy as np
from open_spiel.python.algorithms import nfg_utils


# Start with initial regrets of 1 / denom
INITIAL_REGRET_DENOM = 1e6
@@ -131,13 +133,12 @@ def regret_matching(payoff_tensors,
for k in range(number_players)
]

average_over_last_n_strategies = average_over_last_n_strategies or iterations
averager = nfg_utils.StrategyAverager(number_players, action_space_shapes,
average_over_last_n_strategies)
averager.append(new_strategies)

meta_strategy_window = []
for i in range(iterations):
for _ in range(iterations):
new_strategies = _regret_matching_step(payoff_tensors, new_strategies,
regrets, gamma)
if i >= iterations - average_over_last_n_strategies:
meta_strategy_window.append(new_strategies)
average_new_strategies = np.mean(meta_strategy_window, axis=0)
return average_new_strategies
averager.append(new_strategies)
return averager.average_strategies()
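
The regret_matching change mirrors the projected replicator dynamics one; an analogous hedged call (again with parameter names from the hunk and an assumed matching-pennies game) would be:

import numpy as np
from open_spiel.python.algorithms import regret_matching

row_payoff = np.array([[1.0, -1.0], [-1.0, 1.0]])
payoff_tensors = [row_payoff, -row_payoff]

avg_strategies = regret_matching.regret_matching(
    payoff_tensors,
    iterations=1000,
    average_over_last_n_strategies=100)
print(avg_strategies)  # should be near [0.5, 0.5] for each player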
2 changes: 1 addition & 1 deletion open_spiel/scripts/python_extra_deps.sh
@@ -26,5 +26,5 @@
# scripts/global_variables.sh
export OPEN_SPIEL_PYTHON_JAX_DEPS="jax==0.3.24 jaxlib==0.3.24 dm-haiku==0.0.8 optax==0.1.3 chex==0.1.5 rlax==0.1.4"
export OPEN_SPIEL_PYTHON_PYTORCH_DEPS="torch==1.13.0"
export OPEN_SPIEL_PYTHON_TENSORFLOW_DEPS="numpy==1.21.6 tensorflow==2.9.0 tensorflow-probability==0.16.0 tensorflow_datasets==4.5.2 keras==2.9.0"
export OPEN_SPIEL_PYTHON_TENSORFLOW_DEPS="numpy==1.21.6 tensorflow==2.11.0 tensorflow-probability==0.16.0 tensorflow_datasets==4.5.2 keras==2.11.0"
export OPEN_SPIEL_PYTHON_MISC_DEPS="IPython==5.8.0 cvxopt==1.3.0 networkx==2.4 matplotlib==3.5.2 mock==4.0.2 nashpy==0.0.19 scipy==1.7.3 testresources==2.0.1 cvxpy==1.2.0 ecos==2.0.10 osqp==0.6.2.post5 clu==0.0.6 flax==0.5.3"
