Merge pull request #1034 from rezunli96:policy_aggregator
PiperOrigin-RevId: 519133651
Change-Id: I6e0fa9a16faafe3c3591600df1d8c00260b60a4d
lanctot committed Mar 27, 2023
2 parents c8eaa57 + b0dc735 commit 181aca5
Showing 4 changed files with 93 additions and 52 deletions.
20 changes: 15 additions & 5 deletions open_spiel/python/algorithms/exploitability.py
@@ -38,6 +38,7 @@

import numpy as np

+from open_spiel.python import policy as policy_lib
from open_spiel.python.algorithms import best_response as pyspiel_best_response
import pyspiel

@@ -47,11 +48,20 @@ def _state_values(state, num_players, policy):
  if state.is_terminal():
    return np.array(state.returns())
  else:
-    p_action = (
-        state.chance_outcomes() if state.is_chance_node() else
-        policy.action_probabilities(state).items())
-    return sum(prob * _state_values(state.child(action), num_players, policy)
-               for action, prob in p_action)
+    if state.is_simultaneous_node():
+      p_action = tuple(policy_lib.joint_action_probabilities(state, policy))
+    else:
+      p_action = (
+          state.chance_outcomes()
+          if state.is_chance_node()
+          else policy.action_probabilities(state).items()
+      )
+    return sum(
+        prob
+        * _state_values(policy_lib.child(state, action), num_players, policy)
+        for action, prob in p_action
+    )


def best_response(game, policy, player_id):
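For orientation, here is a minimal, hedged sketch of what the updated _state_values now supports: expected returns at a simultaneous-move node, computed by enumerating joint actions. The game choice ("matrix_mp", matching pennies) and the uniform-random policy are illustrative assumptions, not part of this change.

# Hedged usage sketch (not from the commit): expected returns at the root of
# a simultaneous-move matrix game under a uniform-random policy.
import numpy as np
import pyspiel
from open_spiel.python import policy as policy_lib

game = pyspiel.load_game("matrix_mp")  # matching pennies (illustrative choice)
uniform = policy_lib.UniformRandomPolicy(game)
state = game.new_initial_state()

# joint_action_probabilities yields (joint_action, probability) pairs at a
# simultaneous node; child applies a single or joint action to a copy of state.
values = sum(
    prob * np.array(policy_lib.child(state, actions).returns())
    for actions, prob in policy_lib.joint_action_probabilities(state, uniform)
)
print(values)  # matching pennies is zero-sum, so this should be [0. 0.]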
73 changes: 42 additions & 31 deletions open_spiel/python/algorithms/policy_aggregator.py
@@ -19,9 +19,8 @@
"""

import copy
-import numpy as np
+import itertools
from open_spiel.python import policy
-import pyspiel


class PolicyFunction(policy.Policy):
@@ -74,12 +73,9 @@ def action_probabilities(self, state, player_id=None):
    """
    state_key = self._state_key(state, player_id=player_id)
    if state.is_simultaneous_node():
-      # Policy aggregator doesn't yet support simultaneous moves nodes.
-      # The below lines are one step towards that direction.
-      result = []
-      for player_pol in self._policies:
-        result.append(player_pol[state_key])
-      return result
+      # At simultaneous nodes, a player_id must be provided.
+      assert player_id >= 0
+      return self._policies[player_id][state_key]
    if player_id is None:
      player_id = state.current_player()
    return self._policies[player_id][state_key]
@@ -188,29 +184,47 @@ def _rec_aggregate(self, pid, state, my_reaches):
    if state.is_terminal():
      return
    elif state.is_simultaneous_node():
-      # TODO(author10): this is assuming that if there is a sim.-move state, it is
-      # the only state, i.e., the game is a normal-form game
-      def assert_type(cond, msg):
-        assert cond, msg
-      assert_type(self._game_type.dynamics ==
-                  pyspiel.GameType.Dynamics.SIMULTANEOUS,
-                  "Game must be simultaneous-move")
-      assert_type(self._game_type.chance_mode ==
-                  pyspiel.GameType.ChanceMode.DETERMINISTIC,
-                  "Chance nodes not supported")
-      assert_type(self._game_type.information ==
-                  pyspiel.GameType.Information.ONE_SHOT,
-                  "Only one-shot NFGs supported")
-
      policies = self._policy_pool(state, pid)
      state_key = self._state_key(state, pid)
      self._policy[state_key] = {}
-      for player_policy, weight in zip(policies, my_reaches[pid]):
-        for action in player_policy.keys():
-          if action in self._policy[state_key]:
-            self._policy[state_key][action] += weight * player_policy[action]
-          else:
-            self._policy[state_key][action] = weight * player_policy[action]
+      used_moves = state.legal_actions(pid)
+
+      for uid in used_moves:
+        new_reaches = copy.deepcopy(my_reaches)
+        for i in range(len(policies)):
+          # compute the new reach for each policy for this action
+          new_reaches[pid][i] *= policies[i].get(uid, 0)
+          # add reach * prob(a) for this policy to the computed policy
+          if uid in self._policy[state_key].keys():
+            self._policy[state_key][uid] += new_reaches[pid][i]
+          else:
+            self._policy[state_key][uid] = new_reaches[pid][i]
+
+      num_players = self._game.num_players()
+      all_other_used_moves = []
+      for player in range(num_players):
+        if player != pid:
+          all_other_used_moves.append(state.legal_actions(player))
+
+      other_joint_actions = itertools.product(*all_other_used_moves)
+
+      # enumerate every possible other-agent action profile for the next state
+      for other_joint_action in other_joint_actions:
+        for uid in used_moves:
+          new_reaches = copy.deepcopy(my_reaches)
+          for i in range(len(policies)):
+            # compute the new reach for each policy for this action
+            new_reaches[pid][i] *= policies[i].get(uid, 0)
+
+          joint_action = list(
+              other_joint_action[:pid] + (uid,) + other_joint_action[pid:]
+          )
+          new_state = state.clone()
+          new_state.apply_actions(joint_action)
+          self._rec_aggregate(pid, new_state, new_reaches)
      return

    elif state.is_chance_node():
      # do not factor in opponent reaches
      outcomes, _ = zip(*state.chance_outcomes())
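To make the new recursion step concrete, the following toy snippet (action sets invented for illustration; no OpenSpiel game involved) shows how a joint action is assembled for the aggregating player pid from every combination of the other players' legal actions:

# Toy illustration of the joint-action construction used in the diff above.
import itertools

pid = 1
legal_actions = [[0, 1], [0, 1, 2], [0, 1]]  # invented per-player legal actions

all_other_used_moves = [
    moves for player, moves in enumerate(legal_actions) if player != pid
]

for other_joint_action in itertools.product(*all_other_used_moves):
  for uid in legal_actions[pid]:
    # Insert pid's own action at position pid, keeping player order.
    joint_action = list(
        other_joint_action[:pid] + (uid,) + other_joint_action[pid:]
    )
    print(joint_action)  # e.g. [0, 2, 1]: player 0 -> 0, player 1 -> 2, player 2 -> 1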
@@ -228,13 +242,10 @@
    if pid == turn_player:
      # update the current node
      # will need the observation to query the policies
-      if state not in self._policy:
+      if state_key not in self._policy:
        self._policy[state_key] = {}

-      used_moves = []
-      for k in range(len(legal_policies)):
-        used_moves += [a[0] for a in legal_policies[k].items()]
-      used_moves = np.unique(used_moves)
+      used_moves = state.legal_actions(turn_player)

      for uid in used_moves:
        new_reaches = copy.deepcopy(my_reaches)
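For context on how this class is driven, here is a hedged usage sketch on a one-shot simultaneous-move game. It mirrors the pattern used in the aggregation test (per-player lists of candidate policies plus per-player mixture weights); the game name "matrix_rps" and the uniform-random candidates are assumptions made for the example, not part of the commit.

# Hedged usage sketch of PolicyAggregator on a simultaneous-move game.
import pyspiel
from open_spiel.python import policy as policy_lib
from open_spiel.python.algorithms import policy_aggregator

game = pyspiel.load_game("matrix_rps")  # rock-paper-scissors (illustrative)
num_players = game.num_players()

# Two candidate policies per player, both uniform random here for simplicity.
policies = [
    [policy_lib.UniformRandomPolicy(game) for _ in range(2)]
    for _ in range(num_players)
]
weights = [[0.5, 0.5] for _ in range(num_players)]

aggregator = policy_aggregator.PolicyAggregator(game)
aggregated = aggregator.aggregate(list(range(num_players)), policies, weights)

# At a simultaneous node the aggregated policy must be queried per player,
# matching the new assert in PolicyFunction.action_probabilities.
state = game.new_initial_state()
for pid in range(num_players):
  print(pid, aggregated.action_probabilities(state, player_id=pid))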
50 changes: 35 additions & 15 deletions open_spiel/python/algorithms/policy_aggregator_joint.py
@@ -20,8 +20,8 @@
"""

import copy
+import itertools
from open_spiel.python import policy
-import pyspiel


def _aggregate_at_state(joint_policies, state, player):
@@ -176,25 +176,45 @@ def _rec_aggregate(self, pid, state, my_reaches):
      return

    if state.is_simultaneous_node():
-      assert (self._game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS
-             ), "Game must be simultaneous-move"
-      assert (self._game_type.chance_mode == pyspiel.GameType.ChanceMode
-              .DETERMINISTIC), "Chance nodes not supported"
-      assert (self._game_type.information == pyspiel.GameType.Information
-              .ONE_SHOT), "Only one-shot NFGs supported"
      policies = _aggregate_at_state(self._joint_policies, state, pid)
      state_key = self._state_key(state, pid)

      self._policy[state_key] = {}
+      used_moves = state.legal_actions(pid)

-      for player_policies, weight in zip(policies, my_reaches):
-        player_policy = player_policies[pid]
-        for action in player_policy.keys():
-          if action in self._policy[state_key]:
-            self._policy[state_key][action] += weight * player_policy[action]
-          else:
-            self._policy[state_key][action] = weight * player_policy[action]
-      # No recursion because we only support one shot simultaneous games.
+      for uid in used_moves:
+        new_reaches = copy.deepcopy(my_reaches)
+        for i in range(len(policies)):
+          # compute the new reach for each policy for this action
+          new_reaches[i] *= policies[i].get(uid, 0)
+          # add reach * prob(a) for this policy to the computed policy
+          if uid in self._policy[state_key].keys():
+            self._policy[state_key][uid] += new_reaches[i]
+          else:
+            self._policy[state_key][uid] = new_reaches[i]
+
+      num_players = self._game.num_players()
+      all_other_used_moves = []
+      for player in range(num_players):
+        if player != pid:
+          all_other_used_moves.append(state.legal_actions(player))
+
+      other_joint_actions = itertools.product(*all_other_used_moves)
+
+      # enumerate every possible other-agent action profile for the next state
+      for other_joint_action in other_joint_actions:
+        for uid in used_moves:
+          new_reaches = copy.deepcopy(my_reaches)
+          for i in range(len(policies)):
+            # compute the new reach for each policy for this action
+            new_reaches[i] *= policies[i].get(uid, 0)
+
+          joint_action = list(
+              other_joint_action[:pid] + (uid,) + other_joint_action[pid:]
+          )
+          new_state = state.clone()
+          new_state.apply_actions(joint_action)
+          self._rec_aggregate(pid, new_state, new_reaches)
      return

    if state.is_chance_node():
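The joint variant differs from policy_aggregator.py mainly in how reaches are indexed: one reach weight per joint policy (new_reaches[i]) rather than one per player and per policy (new_reaches[pid][i]). A toy illustration of that bookkeeping, with invented numbers rather than output from any game:

# Toy reach bookkeeping for the joint aggregator (invented numbers).
pid = 0
# Per-joint-policy action probabilities for player pid at the current state,
# in the shape returned by _aggregate_at_state.
policies = [{0: 0.7, 1: 0.3}, {0: 0.2, 1: 0.8}]
my_reaches = [0.5, 0.5]  # prior weight on each joint policy

uid = 0
new_reaches = list(my_reaches)
for i in range(len(policies)):
  new_reaches[i] *= policies[i].get(uid, 0)

# Reach-weighted mass on action 0: 0.5 * 0.7 + 0.5 * 0.2 = 0.45.
print(f"aggregated mass on action 0: {sum(new_reaches):.2f}")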
@@ -211,7 +231,7 @@ def _rec_aggregate(self, pid, state, my_reaches):
    if pid == current_player:
      # update the current node
      # will need the observation to query the policies
-      if state not in self._policy:
+      if state_key not in self._policy:
        self._policy[state_key] = {}

      for action in state.legal_actions():
2 changes: 1 addition & 1 deletion (policy aggregation test file)
@@ -56,7 +56,7 @@ def test_policy_aggregation_random(self, game_name):
      probs = list(state_action_probs.values())
      expected_prob = 1. / len(probs)
      for prob in probs:
-        self.assertEqual(expected_prob, prob)
+        self.assertAlmostEqual(expected_prob, prob, places=10)


if __name__ == "__main__":
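The switch to assertAlmostEqual reflects that aggregated probabilities are accumulated as sums of reach-weighted floating-point terms, so exact equality is too strict. A toy illustration of the rounding effect (numbers unrelated to the test):

total = 0.1 + 0.1 + 0.1          # reach-weighted accumulation, toy numbers
print(total == 0.3)              # False on IEEE-754 doubles
print(abs(total - 0.3) < 1e-10)  # True, comfortably within places=10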
