Merge pull request #5 from deepmind/master
update from base
alexminnaar committed Oct 17, 2019
2 parents ca8a9b8 + 05f860a commit d92614f
Showing 35 changed files with 1,457 additions and 404 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -40,16 +40,16 @@ For a longer introduction to the core concepts, formalisms, and terminology,
 including an overview of the algorithms and some results, please see
 [OpenSpiel: A Framework for Reinforcement Learning in Games](https://arxiv.org/abs/1908.09453).
 
-If you use OpenSpiel in a research paper, please cite the paper using the
-following BibTeX:
+If you use OpenSpiel in your research, please cite the paper using the following
+BibTeX:
 
 ```
 @article{LanctotEtAl2019OpenSpiel,
   title = {{OpenSpiel}: A Framework for Reinforcement Learning in Games},
   author = {Marc Lanctot and Edward Lockhart and Jean-Baptiste Lespiau and Vinicius Zambaldi and
       Satyaki Upadhyay and Julien P\'{e}rolat and Sriram Srinivasan and Finbarr Timbers and
       Karl Tuyls and Shayegan Omidshafiei and Daniel Hennes and Dustin Morrill and Paul Muller and
-      Timo Ewalds and Ryan Faulkner and J\'{a}nos Kramár and Bart De Vylder and Brennan Saeta and
+      Timo Ewalds and Ryan Faulkner and J\'{a}nos Kram\'{a}r and Bart De Vylder and Brennan Saeta and
       James Bradbury and David Ding and Sebastian Borgeaud and Matthew Lai and Julian Schrittwieser and
       Thomas Anthony and Edward Hughes and Ivo Danihelka and Jonah Ryan-Davis},
   year = {2019},
4 changes: 0 additions & 4 deletions docs/contributing.md
@@ -83,10 +83,6 @@ release!). Contributions are certainly not limited to these suggestions!
   Preferably, an implementation that closely matches the pseudo-code provided
   in the paper.
 
-- **Baselines for Monte Carlo CFR**. Implementations of the variance-reduction
-  techniques for MCCFR ([Ref1](https://arxiv.org/abs/1809.03057),
-  [Ref2](https://arxiv.org/abs/1907.09633)).
-
 - **Checkers / Draughts**. This is a classic game and an important one in the
   history of game AI
   (["Checkers is solved"](https://science.sciencemag.org/content/317/5844/1518)).
6 changes: 6 additions & 0 deletions open_spiel/algorithms/CMakeLists.txt
@@ -25,6 +25,8 @@ add_library (algorithms OBJECT
   mcts.cc
   minimax.h
   minimax.cc
+  outcome_sampling_mccfr.h
+  outcome_sampling_mccfr.cc
   tabular_exploitability.h
   tabular_exploitability.cc
   trajectories.h
@@ -78,6 +80,10 @@ add_executable(minimax_test minimax_test.cc
                $<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
 add_test(minimax_test minimax_test)
 
+add_executable(outcome_sampling_mccfr_test outcome_sampling_mccfr_test.cc
+               $<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
+add_test(outcome_sampling_mccfr_test outcome_sampling_mccfr_test)
+
 add_executable(tabular_exploitability_test tabular_exploitability_test.cc
                $<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
 add_test(tabular_exploitability_test tabular_exploitability_test)
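These targets wire the new outcome sampling MCCFR implementation into the algorithms library and register its test with CTest. A minimal sketch of driving the solver follows; the class name OutcomeSamplingMCCFRSolver and the RunIteration() method are inferred from the file and test names in this diff, so check outcome_sampling_mccfr.h for the actual interface.

```cpp
// Sketch only: run outcome sampling MCCFR on Kuhn poker.
// OutcomeSamplingMCCFRSolver and RunIteration() are assumed names, inferred
// from outcome_sampling_mccfr.h / outcome_sampling_mccfr_test above.
#include "open_spiel/algorithms/outcome_sampling_mccfr.h"
#include "open_spiel/spiel.h"

int main() {
  auto game = open_spiel::LoadGame("kuhn_poker");
  open_spiel::algorithms::OutcomeSamplingMCCFRSolver solver(*game);
  // Each iteration samples a single trajectory through the game tree and
  // updates regrets only along that trajectory.
  for (int i = 0; i < 100000; ++i) solver.RunIteration();
}
```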
14 changes: 14 additions & 0 deletions open_spiel/algorithms/cfr.cc
@@ -336,6 +336,20 @@ std::vector<double> CFRSolverBase::GetPolicy(
   return entry->second.current_policy;
 }
 
+std::string CFRInfoStateValues::ToString() const {
+  std::string str = "";
+  absl::StrAppend(&str, "Legal actions: ", absl::StrJoin(legal_actions, ", "),
+                  "\n");
+  absl::StrAppend(&str, "Current policy: ", absl::StrJoin(current_policy, ", "),
+                  "\n");
+  absl::StrAppend(&str, "Cumulative regrets: ",
+                  absl::StrJoin(cumulative_regrets, ", "), "\n");
+  absl::StrAppend(&str,
+                  "Cumulative policy: ", absl::StrJoin(cumulative_policy, ", "),
+                  "\n");
+  return str;
+}
+
 void CFRInfoStateValues::ApplyRegretMatching() {
   double sum_positive_regrets = 0.0;
   for (int aidx = 0; aidx < num_actions(); ++aidx) {
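The new ToString() makes CFRInfoStateValues easy to inspect mid-run. A hedged usage sketch: CFRSolver and EvaluateAndUpdatePolicy() come from cfr.h, but the InfoStateValuesTable() accessor is an assumption made here for illustration.

```cpp
// Sketch: print per-information-state CFR values after a few iterations.
// InfoStateValuesTable() is an assumed accessor for the solver's internal
// map from info-state string to CFRInfoStateValues.
#include <iostream>
#include "open_spiel/algorithms/cfr.h"
#include "open_spiel/spiel.h"

void DumpCFRValues() {
  auto game = open_spiel::LoadGame("kuhn_poker");
  open_spiel::algorithms::CFRSolver solver(*game);
  for (int i = 0; i < 100; ++i) solver.EvaluateAndUpdatePolicy();
  for (const auto& [info_state, values] : solver.InfoStateValuesTable()) {
    std::cout << info_state << "\n" << values.ToString() << std::endl;
  }
}
```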
3 changes: 3 additions & 0 deletions open_spiel/algorithms/cfr.h
@@ -35,6 +35,9 @@ struct CFRInfoStateValues {
   bool empty() const { return legal_actions.empty(); }
   int num_actions() const { return legal_actions.size(); }
 
+  // A string representation of the information state values
+  std::string ToString() const;
+
   // Samples from current policy using randomly generated z, adding epsilon
   // exploration (mixing in uniform).
   int SampleActionIndex(double epsilon, double z);
2 changes: 1 addition & 1 deletion open_spiel/algorithms/evaluate_bots.cc
@@ -26,7 +26,7 @@ std::vector<double> EvaluateBots(State* state, const std::vector<Bot*>& bots,
   while (!state->IsTerminal()) {
     if (state->IsChanceNode()) {
       state->ApplyAction(
-          SampleChanceOutcome(state->ChanceOutcomes(), uniform(rng)));
+          SampleChanceOutcome(state->ChanceOutcomes(), uniform(rng)).first);
     } else if (state->IsSimultaneousNode()) {
       for (auto p = Player{0}; p < num_players; ++p) {
         if (state->LegalActions(p).empty()) {
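The trailing .first here, and at the MCCFR and MCTS call sites below, reflects SampleChanceOutcome now returning the sampled action together with its probability, presumably as a std::pair<Action, double>; the exact return type is inferred from these call sites. A sketch of the calling pattern:

```cpp
// Sketch: the pair-returning SampleChanceOutcome, as used in this commit.
// The element types {Action, double} are inferred from the .first usage at
// the call sites in this diff.
#include <random>
#include <utility>
#include "open_spiel/spiel.h"

open_spiel::Action SampleChance(const open_spiel::State& state,
                                std::mt19937* rng) {
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  std::pair<open_spiel::Action, double> outcome =
      open_spiel::SampleChanceOutcome(state.ChanceOutcomes(), uniform(*rng));
  // outcome.second is the sampled action's probability (useful, e.g., for
  // importance weighting); every call site in this commit only needs .first.
  return outcome.first;
}
```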
3 changes: 2 additions & 1 deletion open_spiel/algorithms/external_sampling_mccfr.cc
@@ -60,7 +60,8 @@ double ExternalSamplingMCCFRSolver::UpdateRegrets(const State& state,
   if (state.IsTerminal()) {
     return state.PlayerReturn(player);
   } else if (state.IsChanceNode()) {
-    Action action = SampleChanceOutcome(state.ChanceOutcomes(), dist_(*rng));
+    Action action =
+        SampleChanceOutcome(state.ChanceOutcomes(), dist_(*rng)).first;
     return UpdateRegrets(*state.Child(action), player, rng);
   } else if (state.IsSimultaneousNode()) {
     SpielFatalError(
55 changes: 33 additions & 22 deletions open_spiel/algorithms/mcts.cc
@@ -47,7 +47,9 @@ std::vector<double> RandomRolloutEvaluator::evaluate(const State& state) const {
     if (working_state->IsChanceNode()) {
       ActionsAndProbs outcomes = working_state->ChanceOutcomes();
       Action action = SampleChanceOutcome(
-          outcomes, std::uniform_real_distribution<double>(0.0, 1.0)(rng_));
+                          outcomes, std::uniform_real_distribution<double>(
+                                        0.0, 1.0)(rng_))
+                          .first;
       working_state->ApplyAction(action);
     } else {
       std::vector<Action> actions = working_state->LegalActions();
@@ -151,8 +153,10 @@ std::string SearchNode::ToString(const State& state) const {
   return absl::StrFormat(
       "%6s: player: %d, prior: %5.3f, value: %6.3f, sims: %5d, outcome: %s, "
       "%3d children",
-      (action >= 0 ? state.ActionToString(player, action) : "none"), player,
-      prior, (explore_count ? total_reward / explore_count : 0.), explore_count,
+      (action != kInvalidAction ? state.ActionToString(player, action)
+                                : "none"),
+      player, prior, (explore_count ? total_reward / explore_count : 0.),
+      explore_count,
       (outcome.empty() ? "none" : absl::StrFormat("%4.1f", outcome[player])),
       children.size());
 }
@@ -177,10 +181,16 @@ MCTSBot::MCTSBot(
       rng_(seed),
       evaluator_{evaluator} {
   GameType game_type = game.GetType();
-  if (game_type.reward_model != GameType::RewardModel::kTerminal ||
-      game_type.dynamics != GameType::Dynamics::kSequential) {
-    SpielFatalError("Game must have sequential turns and terminal rewards.");
-  }
+  if (game_type.reward_model != GameType::RewardModel::kTerminal)
+    SpielFatalError("Game must have terminal rewards.");
+  if (game_type.dynamics != GameType::Dynamics::kSequential)
+    SpielFatalError("Game must have sequential turns.");
+  if (game_type.information != GameType::Information::kPerfectInformation)
+    SpielFatalError("Game must be perfect information.");
+  if (player < 0 || player >= game.NumPlayers())
+    SpielFatalError(absl::StrFormat(
+        "Game doesn't support that many players. Max: %d, player: %d",
+        game.NumPlayers(), player));
 }
 
 std::pair<ActionsAndProbs, Action> MCTSBot::Step(const State& state) {
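The constructor now reports each unsupported feature separately and validates the player id. A sketch of tripping the new range check; the constructor argument order is taken from the InitBot helper in mcts_test.cc below, so treat it as illustrative:

```cpp
// Sketch: the split checks give targeted failure messages. Constructing a
// bot for player 2 of two-player tic_tac_toe now dies with
// "Game doesn't support that many players. Max: 2, player: 2".
#include "open_spiel/algorithms/mcts.h"
#include "open_spiel/spiel.h"

void TriggerPlayerRangeCheck() {
  auto game = open_spiel::LoadGame("tic_tac_toe");
  open_spiel::algorithms::RandomRolloutEvaluator evaluator(/*rollouts*/ 20,
                                                           /*seed*/ 42);
  open_spiel::algorithms::MCTSBot bot(
      *game, /*player=*/2, evaluator, /*uct_c=*/2,
      /*max_simulations=*/10, /*max_memory_mb=*/1,
      /*solve=*/true, /*seed=*/42, /*verbose=*/false);  // SpielFatalError here.
}
```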
@@ -198,11 +208,12 @@ std::pair<ActionsAndProbs, Action> MCTSBot::Step(const State& state) {
     std::cerr << root->ToString(state) << std::endl;
     std::cerr << "Children:" << std::endl;
     std::cerr << root->ChildrenStr(state) << std::endl;
-    std::unique_ptr<State> chosen_state = state.Clone();
-    chosen_state->ApplyAction(best.action);
-    std::cerr << std::endl;
-    std::cerr << "Children of chosen:" << std::endl;
-    std::cerr << best.ChildrenStr(*chosen_state) << std::endl;
+    if (!best.children.empty()) {
+      std::unique_ptr<State> chosen_state = state.Clone();
+      chosen_state->ApplyAction(best.action);
+      std::cerr << "Children of chosen:" << std::endl;
+      std::cerr << best.ChildrenStr(*chosen_state) << std::endl;
+    }
   }
 
   return {{{best.action, 1.0}}, best.action};
@@ -232,14 +243,11 @@ std::unique_ptr<State> MCTSBot::ApplyTreePolicy(
     if (working_state->IsChanceNode()) {
       // For chance nodes, rollout according to chance node's probability
       // distribution
-      ActionsAndProbs outcomes = working_state->ChanceOutcomes();
-
-      double rand = std::uniform_real_distribution<double>(0.0, 1.0)(rng_);
-      int index = 0;
-      for (double sum = 0; sum < rand; ++index) {
-        sum += outcomes[index].second;
-      }
-      Action chosen_action = outcomes[index].first;
+      Action chosen_action =
+          SampleChanceOutcome(
+              working_state->ChanceOutcomes(),
+              std::uniform_real_distribution<double>(0.0, 1.0)(rng_))
+              .first;
 
       for (SearchNode& child : current_node->children) {
         if (child.action == chosen_action) {
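Besides deduplicating the sampling logic, this hunk looks like a bug fix: as rendered, the deleted loop incremented index once more after the cumulative sum first covered rand, so outcomes[index] named the outcome after the sampled one (and could even run past the end of the vector). For reference, a sketch of the inverse-CDF sampling that SampleChanceOutcome presumably performs, assuming non-negative probabilities that sum to roughly one and z drawn uniformly from [0, 1):

```cpp
// Sketch: inverse-CDF sampling over an {action, probability} outcome list,
// i.e. the behavior the deleted loop was aiming for.
#include <utility>
#include "open_spiel/spiel.h"

std::pair<open_spiel::Action, double> InverseCdfSample(
    const open_spiel::ActionsAndProbs& outcomes, double z) {
  double cumulative = 0.0;
  for (const std::pair<open_spiel::Action, double>& outcome : outcomes) {
    cumulative += outcome.second;
    if (z < cumulative) return outcome;  // First outcome whose CDF passes z.
  }
  return outcomes.back();  // Guard against floating-point rounding at z ~ 1.
}
```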
@@ -269,7 +277,8 @@ std::unique_ptr<State> MCTSBot::ApplyTreePolicy(
 
 std::unique_ptr<SearchNode> MCTSBot::MCTSearch(const State& state) {
   memory_used_ = 0;
-  auto root = std::make_unique<SearchNode>(-1, state.CurrentPlayer(), 1);
+  auto root = std::make_unique<SearchNode>(
+      kInvalidAction, state.CurrentPlayer(), 1);
   std::vector<SearchNode*> visit_path;
   std::vector<double> returns;
   visit_path.reserve(64);
@@ -295,7 +304,9 @@ std::unique_ptr<SearchNode> MCTSBot::MCTSearch(const State& state) {
     for (auto it = visit_path.rbegin(); it != visit_path.rend(); ++it) {
       SearchNode* node = *it;
 
-      node->total_reward += returns[node->player];
+      if (node->player != kChancePlayerId) {
+        node->total_reward += returns[node->player];
+      }
       node->explore_count += 1;
 
       // Back up solved results as well.
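The guard matters because returns is indexed by player, and at chance nodes node->player holds kChancePlayerId, which is -1 in OpenSpiel, so the unguarded accumulation indexed out of bounds; chance nodes still get their visit counts updated. A minimal illustration of the hazard:

```cpp
// Sketch: why the check is needed. kChancePlayerId mirrors OpenSpiel's
// constant; indexing a returns vector with it is undefined behavior.
#include <vector>

constexpr int kChancePlayerId = -1;  // Matches open_spiel::kChancePlayerId.

double RewardFor(int node_player, const std::vector<double>& returns) {
  // Without this guard, returns[-1] would be read at chance nodes.
  if (node_player == kChancePlayerId) return 0.0;
  return returns[node_player];
}
```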
140 changes: 121 additions & 19 deletions open_spiel/algorithms/mcts_test.cc
@@ -14,35 +14,137 @@
 
 #include "open_spiel/algorithms/mcts.h"
 
+#include <memory>
+#include <utility>
+
+#include "open_spiel/abseil-cpp/absl/strings/string_view.h"
 #include "open_spiel/algorithms/evaluate_bots.h"
 #include "open_spiel/spiel.h"
 #include "open_spiel/spiel_bots.h"
 #include "open_spiel/spiel_utils.h"
 
 namespace open_spiel {
 namespace {
 
-void BotTest_RandomVsRandom() {
-  auto game = LoadGame("kuhn_poker");
-  std::vector<std::shared_ptr<Bot>> bots = {
-      MakeUniformRandomBot(*game, /*player_id=*/0, /*seed=*/1234),
-      MakeUniformRandomBot(*game, /*player_id=*/1, /*seed=*/4321)};
-  std::vector<Bot*> bot_ptrs = {bots[0].get(), bots[1].get()};
-  constexpr int num_players = 2;
-  std::vector<double> average_results(num_players);
-  constexpr int num_iters = 100000;
-  for (int iteration = 0; iteration < num_iters; ++iteration) {
-    auto this_results = EvaluateBots(game->NewInitialState().get(), bot_ptrs,
-                                     /*seed=*/iteration);
-    for (auto i = Player{0}; i < num_players; ++i)
-      average_results[i] += this_results[i];
+constexpr double UCT_C = 2;
+
+std::unique_ptr<open_spiel::Bot> InitBot(
+    const open_spiel::Game& game, open_spiel::Player player,
+    int max_simulations, const open_spiel::algorithms::Evaluator& evaluator) {
+  return std::make_unique<open_spiel::algorithms::MCTSBot>(
+      game,
+      player,
+      evaluator,
+      UCT_C,
+      max_simulations,
+      /* max_memory_mb */ 5,
+      /* solve */ true,
+      /* seed */ 42,
+      /* verbose */ false);
+}
+
+void MCTSTest_CanPlayTicTacToe() {
+  auto game = LoadGame("tic_tac_toe");
+  int max_simulations = 100;
+  open_spiel::algorithms::RandomRolloutEvaluator evaluator(20, 42);
+  auto bot0 = InitBot(*game, 0, max_simulations, evaluator);
+  auto bot1 = InitBot(*game, 1, max_simulations, evaluator);
+  auto results = EvaluateBots(
+      game->NewInitialState().get(), {bot0.get(), bot1.get()}, 42);
+  SPIEL_CHECK_EQ(results[0] + results[1], 0);
+}
+
+void MCTSTest_CanPlaySinglePlayer() {
+  auto game = LoadGame("catch");
+  int max_simulations = 100;
+  open_spiel::algorithms::RandomRolloutEvaluator evaluator(20, 42);
+  auto bot = InitBot(*game, 0, max_simulations, evaluator);
+  auto results = EvaluateBots(
+      game->NewInitialState().get(), {bot.get()}, 42);
+  SPIEL_CHECK_GT(results[0], 0);
+}
+
+void MCTSTest_CanPlayThreePlayerStochasticGames() {
+  auto game = LoadGame("pig(players=3,winscore=20,horizon=30)");
+  int max_simulations = 1000;
+  open_spiel::algorithms::RandomRolloutEvaluator evaluator(20, 42);
+  auto bot0 = InitBot(*game, 0, max_simulations, evaluator);
+  auto bot1 = InitBot(*game, 1, max_simulations, evaluator);
+  auto bot2 = InitBot(*game, 2, max_simulations, evaluator);
+  auto results = EvaluateBots(
+      game->NewInitialState().get(), {bot0.get(), bot1.get(), bot2.get()}, 42);
+  SPIEL_CHECK_FLOAT_EQ(results[0] + results[1] + results[2], 0);
+}
+
+open_spiel::Action GetAction(const open_spiel::State& state,
+                             const absl::string_view action_str) {
+  for (open_spiel::Action action : state.LegalActions()) {
+    if (action_str == state.ActionToString(state.CurrentPlayer(), action))
+      return action;
   }
-  for (auto i = Player{0}; i < num_players; ++i)
-    average_results[i] /= num_iters;
+  open_spiel::SpielFatalError(absl::StrCat("Illegal action: ", action_str));
 }
 
-  SPIEL_CHECK_FLOAT_NEAR(average_results[0], 0.125, 0.01);
-  SPIEL_CHECK_FLOAT_NEAR(average_results[1], -0.125, 0.01);
+std::pair<std::unique_ptr<algorithms::SearchNode>, std::unique_ptr<State>>
+SearchTicTacToeState(const absl::string_view initial_actions) {
+  auto game = LoadGame("tic_tac_toe");
+  std::unique_ptr<State> state = game->NewInitialState();
+  for (const auto& action_str : absl::StrSplit(initial_actions, ' ')) {
+    state->ApplyAction(GetAction(*state, action_str));
+  }
+  open_spiel::algorithms::RandomRolloutEvaluator evaluator(20, 42);
+  algorithms::MCTSBot bot(
+      *game,
+      state->CurrentPlayer(),
+      evaluator,
+      UCT_C,
+      /* max_simulations */ 10000,
+      /* max_memory_mb */ 10,
+      /* solve */ true,
+      /* seed */ 42,
+      /* verbose */ false);
+  return {bot.MCTSearch(*state), std::move(state)};
+}
+
+void MCTSTest_SolveDraw() {
+  auto [root, state] = SearchTicTacToeState("x(1,1) o(0,0) x(2,2)");
+  SPIEL_CHECK_EQ(state->ToString(), "o..\n.x.\n..x");
+  SPIEL_CHECK_EQ(root->outcome[root->player], 0);
+  for (const algorithms::SearchNode& c : root->children)
+    SPIEL_CHECK_LE(c.outcome[c.player], 0);  // No winning moves.
+  const algorithms::SearchNode& best = root->BestChild();
+  SPIEL_CHECK_EQ(best.outcome[best.player], 0);
+  std::string action_str = state->ActionToString(best.player, best.action);
+  if (action_str != "o(0,2)" && action_str != "o(2,0)")  // All others lose.
+    SPIEL_CHECK_EQ(action_str, "o(0,2)");  // "o(2,0)" is also valid.
+}
+
+void MCTSTest_SolveLoss() {
+  auto [root, state] = SearchTicTacToeState(
+      "x(1,1) o(0,0) x(2,2) o(1,0) x(2,0)");
+  SPIEL_CHECK_EQ(state->ToString(), "oox\n.x.\n..x");
+  SPIEL_CHECK_EQ(root->outcome[root->player], -1);
+  for (const algorithms::SearchNode& c : root->children)
+    SPIEL_CHECK_EQ(c.outcome[c.player], -1);  // All losses.
+}
+
+void MCTSTest_SolveWin() {
+  auto [root, state] = SearchTicTacToeState("x(1,0) o(2,2)");
+  SPIEL_CHECK_EQ(state->ToString(), ".x.\n...\n..o");
+  SPIEL_CHECK_EQ(root->outcome[root->player], 1);
+  const algorithms::SearchNode& best = root->BestChild();
+  SPIEL_CHECK_EQ(best.outcome[best.player], 1);
+  SPIEL_CHECK_EQ(state->ActionToString(best.player, best.action), "x(2,0)");
+}
 
 }  // namespace
 }  // namespace open_spiel
 
-int main(int argc, char** argv) { open_spiel::BotTest_RandomVsRandom(); }
+int main(int argc, char** argv) {
+  open_spiel::MCTSTest_CanPlayTicTacToe();
+  open_spiel::MCTSTest_CanPlaySinglePlayer();
+  open_spiel::MCTSTest_CanPlayThreePlayerStochasticGames();
+  open_spiel::MCTSTest_SolveDraw();
+  open_spiel::MCTSTest_SolveLoss();
+  open_spiel::MCTSTest_SolveWin();
+}
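The rewritten tests keep OpenSpiel's convention of plain functions with SPIEL_CHECK_* macros called from main rather than a test framework. Outside of tests, the Step() method whose signature appears in the mcts.cc hunks above returns both a policy and the chosen action; a short sketch (driving both sides with one bot is for illustration only, since an MCTSBot is constructed for a single player):

```cpp
// Sketch: stepping an MCTSBot through a game with the
// std::pair<ActionsAndProbs, Action> Step() shown in this diff.
#include <iostream>
#include <memory>
#include "open_spiel/algorithms/mcts.h"
#include "open_spiel/spiel.h"

void PlayOneGame(open_spiel::algorithms::MCTSBot* bot) {
  auto game = open_spiel::LoadGame("tic_tac_toe");
  std::unique_ptr<open_spiel::State> state = game->NewInitialState();
  while (!state->IsTerminal()) {
    // Step() runs the search internally; .second is the selected action.
    open_spiel::Action action = bot->Step(*state).second;
    state->ApplyAction(action);
  }
  std::cout << state->ToString() << std::endl;
}
```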