Skip to content

Commit

Permalink
Allow creation of per-player random policies.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 520870021
Change-Id: I6bd0ab06164cd36f386088e7712b27c8681252a3
  • Loading branch information
lukemarris authored and lanctot committed Mar 31, 2023
1 parent a4efe30 commit 77aca74
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 28 deletions.
11 changes: 10 additions & 1 deletion open_spiel/algorithms/expected_returns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "open_spiel/simultaneous_move_game.h"
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_utils.h"

namespace open_spiel {
namespace algorithms {
Expand Down Expand Up @@ -101,12 +102,17 @@ std::vector<double> ExpectedReturnsImpl(
SpielFatalError("Error in ExpectedReturnsImpl; infostate not found.");
}
values = state.Rewards();
float total_prob = 0.0;
for (const Action action : state.LegalActions()) {
std::unique_ptr<State> child = state.Child(action);
// GetProb can return -1 for legal actions not in the policy. We treat
// these as having zero probability, but check that at least some actions
// have positive probability.
double action_prob = GetProb(state_policy, action);
SPIEL_CHECK_GE(action_prob, 0.0);
SPIEL_CHECK_LE(action_prob, 1.0);
if (action_prob > prob_cut_threshold) {
SPIEL_CHECK_GE(action_prob, 0.0);
total_prob += action_prob;
std::vector<double> child_values =
ExpectedReturnsImpl(
*child, policy_func, depth_limit - 1, prob_cut_threshold);
Expand All @@ -115,6 +121,9 @@ std::vector<double> ExpectedReturnsImpl(
}
}
}
// Check that at least one action has positive probability mass.
// Consider using: SPIEL_CHECK_FLOAT_EQ(total_prob, 1.0);
SPIEL_CHECK_GT(total_prob, 0.0);
}
SPIEL_CHECK_EQ(values.size(), state.NumPlayers());
return values;
Expand Down
3 changes: 3 additions & 0 deletions open_spiel/algorithms/expected_returns.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ namespace algorithms {
// prob_cut_threshold > 0 will cut the tree search if the reach probability
// goes below this value resulting in an approximate return.
//
// Policies need not be complete; any missing legal actions will be assumed to
// have zero probability.
//
// The second overloaded function acts the same way, except assumes that all of
// the players' policies are encapsulated in one joint policy.
//
Expand Down
100 changes: 81 additions & 19 deletions open_spiel/policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ ActionsAndProbs PartialTabularPolicy::GetStatePolicy(
}

TabularPolicy GetEmptyTabularPolicy(const Game& game,
bool initialize_to_uniform) {
bool initialize_to_uniform,
Player player) {
std::unordered_map<std::string, ActionsAndProbs> policy;
if (game.GetType().dynamics != GameType::Dynamics::kSequential) {
SpielFatalError("Game is not sequential.");
Expand All @@ -272,20 +273,24 @@ TabularPolicy GetEmptyTabularPolicy(const Game& game,
std::vector<Action> legal_actions = state->LegalActions();
const int num_legal_actions = legal_actions.size();
SPIEL_CHECK_GT(num_legal_actions, 0.);
double action_probability = 1.;
if (initialize_to_uniform) {
action_probability = 1. / num_legal_actions;
}

infostate_policy.reserve(num_legal_actions);
for (Action action : legal_actions) {
to_visit.push_back(state->Child(action));
infostate_policy.push_back({action, action_probability});
}
if (infostate_policy.empty()) {
SpielFatalError("State has zero legal actions.");
if (player < 0 || state->IsPlayerActing(player)) {
double action_probability = 1.;
if (initialize_to_uniform) {
action_probability = 1. / num_legal_actions;
}
ActionsAndProbs infostate_policy;
infostate_policy.reserve(num_legal_actions);
for (Action action : legal_actions) {
infostate_policy.push_back({action, action_probability});
}
if (infostate_policy.empty()) {
SpielFatalError("State has zero legal actions.");
}
policy.insert({state->InformationStateString(), infostate_policy});
}
policy.insert({state->InformationStateString(), infostate_policy});
}
}
return TabularPolicy(policy);
Expand All @@ -297,9 +302,9 @@ TabularPolicy GetUniformPolicy(const Game& game) {

template <typename RandomNumberDistribution>
TabularPolicy SamplePolicy(
const Game& game, int seed, RandomNumberDistribution& dist) {
const Game& game, int seed, RandomNumberDistribution& dist, Player player) {
std::mt19937 gen(seed);
TabularPolicy policy = GetEmptyTabularPolicy(game);
TabularPolicy policy = GetEmptyTabularPolicy(game, false, player);
std::unordered_map<std::string, ActionsAndProbs>& policy_table =
policy.PolicyTable();
for (auto& kv : policy_table) {
Expand All @@ -311,8 +316,8 @@ TabularPolicy SamplePolicy(
double sum = 0;
double prob;
for (const auto& action_and_prob : kv.second) {
// We multiply the original probability by a random number between 0
// and 1. We then normalize. This has the effect of randomly permuting the
// We multiply the original probability by a random number greater than
// 0. We then normalize. This has the effect of randomly permuting the
// policy but all illegal actions still have zero probability.
prob = dist(gen) * action_and_prob.second;
sum += prob;
Expand All @@ -333,14 +338,71 @@ TabularPolicy SamplePolicy(
return policy;
}

TabularPolicy GetRandomPolicy(const Game& game, int seed) {
TabularPolicy GetRandomPolicy(const Game& game, int seed, Player player) {
std::uniform_real_distribution<double> dist(0, 1);
return SamplePolicy(game, seed, dist);
return SamplePolicy(game, seed, dist, player);
}

TabularPolicy GetFlatDirichletPolicy(const Game& game, int seed) {
TabularPolicy GetFlatDirichletPolicy(
const Game& game, int seed, Player player) {
std::gamma_distribution<double> dist(1.0, 1.0);
return SamplePolicy(game, seed, dist);
return SamplePolicy(game, seed, dist, player);
}

// Builds a random deterministic tabular policy: every information state that
// gets recorded maps to exactly one legal action with probability 1.
//
// The game tree is walked from the initial state using an explicit stack.
//   game:   must have sequential dynamics, otherwise SpielFatalError is called.
//   seed:   seeds the std::mt19937 used to sample the actions.
//   player: if negative, an entry is recorded at every decision node;
//           otherwise only at nodes where `player` is acting, and all legal
//           actions are expanded at the remaining decision nodes.
//
// In perfect-information games only the recorded action is followed below a
// recorded node; for any other information type every legal action is
// expanded.
TabularPolicy GetRandomDeterministicPolicy(
    const Game& game, int seed, Player player) {
  std::mt19937 gen(seed);
  // One uniform index distribution per distinct action count, created lazily.
  std::unordered_map<int, std::uniform_int_distribution<int>> dists;
  std::unordered_map<std::string, ActionsAndProbs> policy;
  if (game.GetType().dynamics != GameType::Dynamics::kSequential) {
    SpielFatalError("Game is not sequential.");
    return TabularPolicy(policy);
  }
  const GameType::Information information = game.GetType().information;
  std::list<std::unique_ptr<State>> to_visit;
  to_visit.push_back(game.NewInitialState());
  while (!to_visit.empty()) {
    std::unique_ptr<State> state = std::move(to_visit.back());
    to_visit.pop_back();
    if (state->IsTerminal()) {
      continue;
    }
    if (state->IsChanceNode()) {
      for (const auto& outcome_and_prob : state->ChanceOutcomes()) {
        to_visit.emplace_back(state->Child(outcome_and_prob.first));
      }
      continue;
    }

    // Decision node.
    const std::vector<Action> legal_actions = state->LegalActions();
    const int num_legal_actions = legal_actions.size();
    SPIEL_CHECK_GT(num_legal_actions, 0.);

    if (player < 0 || state->IsPlayerActing(player)) {
      auto dist_it = dists.find(num_legal_actions);
      if (dist_it == dists.end()) {
        dist_it = dists.emplace(
            num_legal_actions,
            std::uniform_int_distribution<int>(0, num_legal_actions - 1)).first;
      }
      const int legal_action_index = dist_it->second(gen);
      SPIEL_CHECK_GE(legal_action_index, 0);
      SPIEL_CHECK_LT(legal_action_index, num_legal_actions);
      // Keep the full Action type; storing it in an int could truncate.
      const Action sampled_action = legal_actions[legal_action_index];

      ActionsAndProbs infostate_policy;
      infostate_policy.reserve(1);
      infostate_policy.push_back({sampled_action, 1.0});
      // insert() leaves an already-present entry untouched, so the action the
      // policy actually plays here is the one stored in the table, which may
      // differ from `sampled_action` if this information state was seen before.
      const auto insertion = policy.insert(
          {state->InformationStateString(), std::move(infostate_policy)});
      const Action recorded_action = insertion.first->second.front().first;

      if (information == GameType::Information::kPerfectInformation) {
        // Only the subtree the policy can reach needs to be visited.
        to_visit.push_back(state->Child(recorded_action));
        continue;
      }
    }

    // Either another player is acting or the game is not perfect information:
    // every legal action must be explored.
    for (const Action legal_action : legal_actions) {
      to_visit.push_back(state->Child(legal_action));
    }
  }
  return TabularPolicy(policy);
}

TabularPolicy GetFirstActionPolicy(const Game& game) {
Expand Down
13 changes: 10 additions & 3 deletions open_spiel/policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,11 +374,18 @@ class PreferredActionPolicy : public Policy {
TabularPolicy ToTabularPolicy(const Game& game, const Policy* policy);

// Helper functions that generate policies for testing.
// The player parameter can be used to only generate policies for a single
// player. By default -1 will generate policies for all players.
TabularPolicy GetEmptyTabularPolicy(const Game& game,
bool initialize_to_uniform = false);
bool initialize_to_uniform = false,
Player player = -1);
TabularPolicy GetUniformPolicy(const Game& game);
TabularPolicy GetRandomPolicy(const Game& game, int seed = 0);
TabularPolicy GetFlatDirichletPolicy(const Game& game, int seed = 0);
TabularPolicy GetRandomPolicy(
const Game& game, int seed = 0, Player player = -1);
TabularPolicy GetFlatDirichletPolicy(
const Game& game, int seed = 0, Player player = -1);
TabularPolicy GetRandomDeterministicPolicy(
const Game& game, int seed = 0, Player player = -1);
TabularPolicy GetFirstActionPolicy(const Game& game);

// Returns a preferred action policy as a tabular policy.
Expand Down
9 changes: 7 additions & 2 deletions open_spiel/python/pybind11/policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,13 @@ void init_pyspiel_policy(py::module& m) {
.def("policy_table",
py::overload_cast<>(&open_spiel::PartialTabularPolicy::PolicyTable));

m.def("GetRandomPolicy", &open_spiel::GetRandomPolicy);
m.def("GetFlatDirichletPolicy", &open_spiel::GetFlatDirichletPolicy);
m.def("GetRandomPolicy", &open_spiel::GetRandomPolicy,
py::arg("game"), py::arg("seed"), py::arg("player") = -1);
m.def("GetFlatDirichletPolicy", &open_spiel::GetFlatDirichletPolicy,
py::arg("game"), py::arg("seed"), py::arg("player") = -1);
m.def("GetRandomDeterministicPolicy",
&open_spiel::GetRandomDeterministicPolicy,
py::arg("game"), py::arg("seed"), py::arg("player") = -1);
m.def("UniformRandomPolicy", &open_spiel::GetUniformPolicy);
py::class_<open_spiel::UniformPolicy,
std::shared_ptr<open_spiel::UniformPolicy>, open_spiel::Policy>(
Expand Down
21 changes: 19 additions & 2 deletions open_spiel/python/tests/policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
]


def test_policy_on_game(self, game, policy_object):
def test_policy_on_game(self, game, policy_object, player=-1):
"""Checks the policy conforms to the conventions.
Checks the Policy.action_probabilities contains only legal actions (but not
Expand All @@ -62,6 +62,7 @@ def test_policy_on_game(self, game, policy_object):
function to test policies.
game: A `pyspiel.Game`, same as the one used in the policy.
policy_object: A `policy.Policy` object on `game` to test.
player: Restrict testing policy to a player.
"""

all_states = get_all_states.get_all_states(
Expand Down Expand Up @@ -92,7 +93,10 @@ def test_policy_on_game(self, game, policy_object):
for prob in action_probabilities.values():
sum_ += prob
self.assertGreaterEqual(prob, 0)
self.assertAlmostEqual(1, sum_)
if player < 0 or state.current_player() == player:
self.assertAlmostEqual(1, sum_)
else:
self.assertAlmostEqual(0, sum_)


_LEDUC_POKER = pyspiel.load_game("leduc_poker")
Expand All @@ -115,10 +119,23 @@ def test_policy_on_leduc(self, policy_object):
pyspiel.GetRandomPolicy(_LEDUC_POKER, 1)),
("pyspiel.GetFlatDirichletPolicy",
pyspiel.GetFlatDirichletPolicy(_LEDUC_POKER, 1)),
("pyspiel.GetRandomDeterministicPolicy",
pyspiel.GetRandomDeterministicPolicy(_LEDUC_POKER, 1)),
])
def test_cpp_policies_on_leduc(self, policy_object):
test_policy_on_game(self, _LEDUC_POKER, policy_object)

@parameterized.named_parameters([
    ("pyspiel.GetRandomPolicy0",
     pyspiel.GetRandomPolicy(_LEDUC_POKER, 1, 0), 0),
    ("pyspiel.GetFlatDirichletPolicy1",
     pyspiel.GetFlatDirichletPolicy(_LEDUC_POKER, 1, 1), 1),
    ("pyspiel.GetRandomDeterministicPolicym1",
     pyspiel.GetRandomDeterministicPolicy(_LEDUC_POKER, 1, -1), -1),
])
def test_cpp_player_policies_on_leduc(self, policy_object, player):
  """Checks C++ policies built with an explicit player argument on Leduc poker.

  Each case builds a policy with seed 1 and a player argument, then passes the
  same player value to `test_policy_on_game` together with the policy.
  """
  test_policy_on_game(self, _LEDUC_POKER, policy_object, player)


class TabularTicTacToePolicyTest(parameterized.TestCase):

Expand Down
3 changes: 2 additions & 1 deletion open_spiel/tests/spiel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ void PolicyTest() {
};
std::vector<PolicyGenerator> policy_generators = {
GetUniformPolicy, random_policy_default_seed, GetFirstActionPolicy,
flat_dirichlet_policy_default_seed};
flat_dirichlet_policy_default_seed,
};

// For some reason, this can't seem to be brace-initialized, so instead we use
// push_back.
Expand Down

0 comments on commit 77aca74

Please sign in to comment.