Skip to content

Commit

Permalink
Add game-specific Python bindings for gin rummy.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 481646858
Change-Id: I1136c561bbb41eed00ef321a6bc43f2c67dd7443
  • Loading branch information
jhtschultz authored and lanctot committed Oct 18, 2022
1 parent 12f3992 commit ad1a0a1
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 6 deletions.
25 changes: 19 additions & 6 deletions open_spiel/games/gin_rummy.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,6 @@ class GinRummyState : public State {
std::vector<Action> LegalActions() const override;
std::vector<std::pair<Action, double>> ChanceOutcomes() const override;

protected:
void DoApplyAction(Action action) override;

private:
friend class GinRummyObserver;

enum class Phase {
kDeal,
kFirstUpcard,
Expand All @@ -124,6 +118,21 @@ class GinRummyState : public State {
kGameOver
};

// Used for Python bindings.
Phase CurrentPhase() const { return phase_; }
absl::optional<int> Upcard() const { return upcard_; }
int StockSize() const { return stock_size_; }
std::vector<std::vector<int>> Hands() const { return hands_; }
std::vector<int> DiscardPile() const { return discard_pile_; }
std::vector<int> Deadwood() const { return deadwood_; }
std::vector<bool> Knocked() const { return knocked_; }

protected:
void DoApplyAction(Action action) override;

private:
friend class GinRummyObserver;

inline static constexpr std::array<absl::string_view, 8> kPhaseString = {
"Deal", "FirstUpcard", "Draw", "Discard",
"Knock", "Layoff", "Wall", "GameOver"};
Expand Down Expand Up @@ -243,6 +252,10 @@ class GinRummyGame : public Game {
std::shared_ptr<GinRummyObserver> default_observer_;
std::shared_ptr<GinRummyObserver> info_state_observer_;

// Used for Python bindings.
bool Oklahoma() const { return oklahoma_; }
int KnockCard() const { return knock_card_; }

private:
const bool oklahoma_;
const int knock_card_;
Expand Down
3 changes: 3 additions & 0 deletions open_spiel/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ set(PYTHON_BINDINGS ${PYTHON_BINDINGS}
pybind11/games_colored_trails.h
pybind11/games_euchre.cc
pybind11/games_euchre.h
pybind11/games_gin_rummy.cc
pybind11/games_gin_rummy.h
pybind11/games_kuhn_poker.cc
pybind11/games_kuhn_poker.h
pybind11/games_leduc_poker.cc
Expand Down Expand Up @@ -233,6 +235,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS}
tests/game_transforms_test.py
tests/games_bridge_test.py
tests/games_euchre_test.py
tests/games_gin_rummy_test.py
tests/games_sim_test.py
tests/policy_test.py
tests/pyspiel_test.py
Expand Down
83 changes: 83 additions & 0 deletions open_spiel/python/pybind11/games_gin_rummy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2022 DeepMind Technologies Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "open_spiel/python/pybind11/games_gin_rummy.h"

#include <memory>

#include "open_spiel/games/gin_rummy.h"
#include "open_spiel/python/pybind11/pybind11.h"
#include "open_spiel/spiel.h"

// Several function return absl::optional or lists of absl::optional, so must
// use pybind11_abseil here.
#include "pybind11_abseil/absl_casters.h"

PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::gin_rummy::GinRummyGame);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::gin_rummy::GinRummyState);

namespace open_spiel {

namespace py = ::pybind11;
using gin_rummy::GinRummyGame;
using gin_rummy::GinRummyState;

void init_pyspiel_games_gin_rummy(py::module& m) {
py::classh<GinRummyState, State> state_class(m, "GinRummyState");
state_class.def("current_phase", &GinRummyState::CurrentPhase)
.def("current_player", &GinRummyState::CurrentPlayer)
.def("upcard", &GinRummyState::Upcard)
.def("stock_size", &GinRummyState::StockSize)
.def("hands", &GinRummyState::Hands)
.def("discard_pile", &GinRummyState::DiscardPile)
.def("deadwood", &GinRummyState::Deadwood)
.def("knocked", &GinRummyState::Knocked)
// Pickle support
.def(py::pickle(
[](const GinRummyState& state) { // __getstate__
return SerializeGameAndState(*state.GetGame(), state);
},
[](const std::string& data) { // __setstate__
std::pair<std::shared_ptr<const Game>, std::unique_ptr<State>>
game_and_state = DeserializeGameAndState(data);
return dynamic_cast<GinRummyState*>(
game_and_state.second.release());
}));

py::enum_<gin_rummy::GinRummyState::Phase>(state_class, "Phase")
.value("DEAL", gin_rummy::GinRummyState::Phase::kDeal)
.value("FIRST_UPCARD", gin_rummy::GinRummyState::Phase::kFirstUpcard)
.value("DRAW", gin_rummy::GinRummyState::Phase::kDraw)
.value("DISCARD", gin_rummy::GinRummyState::Phase::kDiscard)
.value("KNOCK", gin_rummy::GinRummyState::Phase::kKnock)
.value("LAYOFF", gin_rummy::GinRummyState::Phase::kLayoff)
.value("WALL", gin_rummy::GinRummyState::Phase::kWall)
.value("GAME_OVER", gin_rummy::GinRummyState::Phase::kGameOver)
.export_values();

py::classh<GinRummyGame, Game>(m, "GinRummyGame")
.def("oklahoma", &GinRummyGame::Oklahoma)
.def("knock_card", &GinRummyGame::KnockCard)
// Pickle support
.def(py::pickle(
[](std::shared_ptr<const GinRummyGame> game) { // __getstate__
return game->ToString();
},
[](const std::string& data) { // __setstate__
return std::dynamic_pointer_cast<GinRummyGame>(
std::const_pointer_cast<Game>(LoadGame(data)));
}));
}
} // namespace open_spiel

25 changes: 25 additions & 0 deletions open_spiel/python/pybind11/games_gin_rummy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2022 DeepMind Technologies Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef OPEN_SPIEL_PYTHON_PYBIND11_GAMES_GIN_RUMMY_H_
#define OPEN_SPIEL_PYTHON_PYBIND11_GAMES_GIN_RUMMY_H_

#include "open_spiel/python/pybind11/pybind11.h"

// Initialize the Python interface for gin_rummy.
namespace open_spiel {
void init_pyspiel_games_gin_rummy(::pybind11::module &m);
}

#endif // OPEN_SPIEL_PYTHON_PYBIND11_GAMES_GIN_RUMMY_H_
2 changes: 2 additions & 0 deletions open_spiel/python/pybind11/pyspiel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "open_spiel/python/pybind11/games_chess.h"
#include "open_spiel/python/pybind11/games_colored_trails.h"
#include "open_spiel/python/pybind11/games_euchre.h"
#include "open_spiel/python/pybind11/games_gin_rummy.h"
#include "open_spiel/python/pybind11/games_kuhn_poker.h"
#include "open_spiel/python/pybind11/games_leduc_poker.h"
#include "open_spiel/python/pybind11/games_negotiation.h"
Expand Down Expand Up @@ -624,6 +625,7 @@ PYBIND11_MODULE(pyspiel, m) {
init_pyspiel_games_chess(m); // Chess game.
init_pyspiel_games_colored_trails(m); // Colored Trails game.
init_pyspiel_games_euchre(m); // Game-specific functions for euchre.
init_pyspiel_games_gin_rummy(m); // Game-specific functions for gin_rummy.
init_pyspiel_games_kuhn_poker(m); // Kuhn Poker game.
init_pyspiel_games_leduc_poker(m); // Leduc poker game.
init_pyspiel_games_negotiation(m); // Negotiation game.
Expand Down
41 changes: 41 additions & 0 deletions open_spiel/python/tests/games_gin_rummy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the game-specific functions for gin rummy."""


from absl.testing import absltest

import pyspiel


class GamesGinRummyTest(absltest.TestCase):

def test_bindings(self):
game = pyspiel.load_game('gin_rummy')
self.assertFalse(game.oklahoma())
self.assertEqual(game.knock_card(), 10)
state = game.new_initial_state()
self.assertEqual(state.current_phase(), state.Phase.DEAL)
self.assertEqual(state.current_player(), pyspiel.PlayerId.CHANCE)
self.assertIsNone(state.upcard())
self.assertEqual(state.stock_size(), 52)
self.assertEqual(state.hands(), [[], []])
self.assertEqual(state.discard_pile(), [])
self.assertEqual(state.deadwood(), [0, 0])
self.assertEqual(state.knocked(), [False, False])


if __name__ == '__main__':
absltest.main()

0 comments on commit ad1a0a1

Please sign in to comment.