Skip to content

Commit

Permalink
Use smart_holder branch machinery to be able to pass a Bot instance b…
Browse files Browse the repository at this point in the history
…y both unique_ptr and shared_ptr interchangeably.

PiperOrigin-RevId: 517914864
Change-Id: Ia4a10b798d81d622f681ef769082ca01d833238d
  • Loading branch information
DeepMind Technologies Ltd authored and lanctot committed Mar 20, 2023
1 parent c017cdb commit f7a4340
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
11 changes: 4 additions & 7 deletions open_spiel/python/pybind11/bots.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
#include <stdint.h>

#include <memory>
#include <new>
#include <string>
#include <utility>

#include "open_spiel/algorithms/evaluate_bots.h"
#include "open_spiel/algorithms/is_mcts.h"
Expand Down Expand Up @@ -99,8 +97,7 @@ class PyBot : public Bot {
"inform_action", // Name of function in Python
InformAction, // Name of function in C++
state, // Arguments
player_id,
action);
player_id, action);
}
void InformActions(const State& state,
const std::vector<Action>& actions) override {
Expand Down Expand Up @@ -153,7 +150,7 @@ class PyBot : public Bot {
} // namespace

void init_pyspiel_bots(py::module& m) {
py::class_<Bot, PyBot> bot(m, "Bot");
py::classh<Bot, PyBot> bot(m, "Bot");
bot.def(py::init<>())
.def("step", &Bot::Step)
.def("restart", &Bot::Restart)
Expand Down Expand Up @@ -227,7 +224,7 @@ void init_pyspiel_bots(py::module& m) {
.def("to_string", &SearchNode::ToString)
.def("children_str", &SearchNode::ChildrenStr);

py::class_<algorithms::MCTSBot, Bot>(m, "MCTSBot")
py::classh<algorithms::MCTSBot, Bot>(m, "MCTSBot")
.def(
py::init([](std::shared_ptr<const Game> game,
std::shared_ptr<Evaluator> evaluator, double uct_c,
Expand All @@ -253,7 +250,7 @@ void init_pyspiel_bots(py::module& m) {
algorithms::ISMCTSFinalPolicyType::kMaxVisitCount)
.value("MAX_VALUE", algorithms::ISMCTSFinalPolicyType::kMaxValue);

py::class_<algorithms::ISMCTSBot, Bot>(m, "ISMCTSBot")
py::classh<algorithms::ISMCTSBot, Bot>(m, "ISMCTSBot")
.def(py::init<int, std::shared_ptr<Evaluator>, double, int, int,
algorithms::ISMCTSFinalPolicyType, bool, bool>(),
py::arg("seed"), py::arg("evaluator"), py::arg("uct_c"),
Expand Down
11 changes: 11 additions & 0 deletions open_spiel/python/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
// in one place to help with consistency.

namespace open_spiel {

class NormalFormGame;
class Bot;

namespace matrix_game {
class MatrixGame;
Expand All @@ -43,13 +45,22 @@ class MatrixGame;
namespace tensor_game {
class TensorGame;
}

namespace algorithms {
class MCTSBot;
class ISMCTSBot;
} // namespace algorithms

} // namespace open_spiel

PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::State);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::Game);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::NormalFormGame);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::matrix_game::MatrixGame);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::tensor_game::TensorGame);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::Bot);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::algorithms::MCTSBot);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::algorithms::ISMCTSBot);

// Custom caster for GameParameter (essentially a variant).
namespace pybind11 {
Expand Down

0 comments on commit f7a4340

Please sign in to comment.