Skip to content

Commit

Permalink
Refactors state_distribution.cc to remove special-casing for Universa…
Browse files Browse the repository at this point in the history
…l Poker.

PiperOrigin-RevId: 287336222
Change-Id: I13d43fdf30520052050894cb0247f1156fc20a0d
  • Loading branch information
DeepMind Technologies Ltd authored and open_spiel@google.com committed Dec 27, 2019
1 parent b33a612 commit f76d74e
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 18 deletions.
14 changes: 6 additions & 8 deletions open_spiel/algorithms/state_distribution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

#include "open_spiel/abseil-cpp/absl/algorithm/container.h"
#include "open_spiel/abseil-cpp/absl/strings/str_cat.h"
#include "open_spiel/games/universal_poker.h"
#include "open_spiel/simultaneous_move_game.h"
#include "open_spiel/spiel.h"

Expand Down Expand Up @@ -153,13 +152,12 @@ HistoryDistribution UpdateIncrementalStateDistribution(
const State& state, const Policy* opponent_policy, int player_id,
const HistoryDistribution& previous) {
if (previous.first.empty()) {
// We have a special case here as universal_poker can be very big.
if (state.GetGame()->GetType().short_name == "universal_poker" &&
state.GetGame()->NumPlayers() == 2) {
auto up_state =
dynamic_cast<const universal_poker::UniversalPokerState*>(&state);
return up_state->GetHistoriesConsistentWithInfostate();
}
HistoryDistribution dist = state.GetHistoriesConsistentWithInfostate();

// If the game didn't implement GetHistoriesConsistentWithInfostate, then
// this is empty, otherwise, we're good.
if (!dist.first.empty()) return dist;

// If the previous pair is empty, then we have to do a BFS to find all
// relevant nodes:
return GetStateDistribution(state, opponent_policy);
Expand Down
3 changes: 0 additions & 3 deletions open_spiel/algorithms/state_distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
namespace open_spiel {
namespace algorithms {

using HistoryDistribution =
std::pair<std::vector<std::unique_ptr<State>>, std::vector<double>>;

// Returns a distribution over states at the information state containing the
// specified state given the opponents' policies. That is, it returns
// Pr(h | s, \pi_{-i}) by normalizing the opponents' reach probabilities over
Expand Down
3 changes: 2 additions & 1 deletion open_spiel/games/universal_poker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,8 @@ double UniversalPokerState::GetTotalReward(Player player) const {
HistoryDistribution UniversalPokerState::GetHistoriesConsistentWithInfostate()
const {
// This is only implemented for 2 players.
SPIEL_CHECK_EQ(acpc_game_->GetNbPlayers(), 2);
if (acpc_game_->GetNbPlayers() != 2) return {};

logic::CardSet is_cards;
const logic::CardSet &our_cards = hole_cards_[cur_player_];
for (uint8_t card : our_cards.ToCardArray()) is_cards.AddCard(card);
Expand Down
7 changes: 1 addition & 6 deletions open_spiel/games/universal_poker.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@
namespace open_spiel {
namespace universal_poker {

// We alias this here as we can't import state_distribution.h or we'd have a
// circular dependency.
using HistoryDistribution =
std::pair<std::vector<std::unique_ptr<State>>, std::vector<double>>;

class UniversalPokerGame;

constexpr uint8_t kMaxUniversalPokerPlayers = 10;
Expand Down Expand Up @@ -72,7 +67,7 @@ class UniversalPokerState : public State {
std::vector<Action> LegalActions() const override;

// Used to make UpdateIncrementalStateDistribution much faster.
HistoryDistribution GetHistoriesConsistentWithInfostate() const;
HistoryDistribution GetHistoriesConsistentWithInfostate() const override;

protected:
void DoApplyAction(Action action_id) override;
Expand Down
13 changes: 13 additions & 0 deletions open_spiel/spiel.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,14 @@ class State {
SpielFatalError("ResampleFromInfostate() not implemented.");
}

// Returns a vector of states & probabilities that are consistent with the
// infostate from the view of the current player. By default, this is not
// implemented and returns an empty list.
virtual std::pair<std::vector<std::unique_ptr<State>>, std::vector<double>>
GetHistoriesConsistentWithInfostate() const {
return {};
}

protected:
// See ApplyAction.
virtual void DoApplyAction(Action action_id) {
Expand Down Expand Up @@ -782,6 +790,11 @@ std::string SerializeGameAndState(const Game& game, const State& state);
std::pair<std::shared_ptr<const Game>, std::unique_ptr<State>>
DeserializeGameAndState(const std::string& serialized_state);

// We alias this here as we can't import state_distribution.h or we'd have a
// circular dependency.
using HistoryDistribution =
std::pair<std::vector<std::unique_ptr<State>>, std::vector<double>>;

} // namespace open_spiel

#endif // THIRD_PARTY_OPEN_SPIEL_SPIEL_H_

0 comments on commit f76d74e

Please sign in to comment.