Skip to content

Commit

Permalink
Merge pull request #882 from rezunli96:policy_aggregator
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 461606727
Change-Id: I4cdcd2f50047e559128523c113a76f13b289d85c
  • Loading branch information
lanctot committed Jul 18, 2022
2 parents 2651e77 + 92a2a40 commit b5e0bf6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
3 changes: 2 additions & 1 deletion open_spiel/python/algorithms/policy_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
policy by sweeping over the state space.
"""

import copy
import numpy as np
from open_spiel.python import policy
import pyspiel
Expand Down Expand Up @@ -236,7 +237,7 @@ def assert_type(cond, msg):
used_moves = np.unique(used_moves)

for uid in used_moves:
new_reaches = np.copy(my_reaches)
new_reaches = copy.deepcopy(my_reaches)
if pid == turn_player:
for i in range(len(legal_policies)):
# compute the new reach for each policy for this action
Expand Down
6 changes: 3 additions & 3 deletions open_spiel/python/algorithms/policy_aggregator_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
policy.
"""

import numpy as np
import copy
from open_spiel.python import policy
import pyspiel

Expand Down Expand Up @@ -159,7 +159,7 @@ def _sub_aggregate(self, pid, weights):
self._policy = {}

state = self._game.new_initial_state()
self._rec_aggregate(pid, state, weights.copy())
self._rec_aggregate(pid, state, copy.deepcopy(weights))

# Now normalize
for key in self._policy:
Expand Down Expand Up @@ -215,7 +215,7 @@ def _rec_aggregate(self, pid, state, my_reaches):
self._policy[state_key] = {}

for action in state.legal_actions():
new_reaches = np.copy(my_reaches)
new_reaches = copy.deepcopy(my_reaches)
if pid == current_player:
for idx, state_action_probs in enumerate(action_probabilities_list):
# compute the new reach for each policy for this action
Expand Down
26 changes: 26 additions & 0 deletions open_spiel/python/algorithms/policy_aggregator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import policy_aggregator
import pyspiel


class PolicyAggregatorTest(parameterized.TestCase):
Expand Down Expand Up @@ -84,6 +85,31 @@ def test_policy_aggregation_tabular_randinit(self, game_name):
for key in value_normal.keys():
self.assertAlmostEqual(value[key], value_normal[key], 8)

@parameterized.named_parameters({
    "testcase_name": "tic_tac_toe",
    "game_name": "tic_tac_toe",
})
def test_policy_aggregation_variadic(self, game_name):
  """Checks aggregation when players have different numbers of policies.

  Player 0 mixes over two policies (uniform and first-action) while every
  other player has a single uniform policy — i.e. the per-player policy
  lists have different ("variadic") lengths. With weights [1.0, 0.0] the
  first-action policy gets zero mass, so the aggregate for player 0 must
  equal the uniform random policy at the initial state.
  """
  game = pyspiel.load_game(game_name)

  uniform_policy = policy.UniformRandomPolicy(game)
  first_action_policy = policy.FirstActionPolicy(game)

  pol_ag = policy_aggregator.PolicyAggregator(game)

  # All weight on the uniform policy; the first-action policy contributes
  # nothing to the aggregate.
  weights0 = [1.0, 0.0]
  player0 = pol_ag.aggregate(
      list(range(game.num_players())),
      # Player 0 gets two policies; every other player gets one.
      [[uniform_policy, first_action_policy]] + [[uniform_policy]] *
      (game.num_players() - 1),
      [weights0] + [[1.0]] * (game.num_players() - 1))
  state = game.new_initial_state()
  action_prob = player0.action_probabilities(state)
  for action in action_prob:
    if action_prob[action] > 0:
      # Every supported action should carry uniform probability, since the
      # aggregate reduces to the uniform random policy at this state.
      self.assertAlmostEqual(action_prob[action],
                             1. / len(state.legal_actions()))


# Standard test entry point: run all test cases in this module.
if __name__ == "__main__":
  unittest.main()

0 comments on commit b5e0bf6

Please sign in to comment.