Skip to content

Commit

Permalink
Add ObservationString, ObservationTensor, and ObservationTensorShape …
Browse files Browse the repository at this point in the history
…to Goofspiel and make both tensors observing player-relative

PiperOrigin-RevId: 283363694
Change-Id: I205363c9d641a574e2018cb9312672c8b15e8d9f
  • Loading branch information
DeepMind Technologies Ltd authored and open_spiel@google.com committed Dec 2, 2019
1 parent 2776d59 commit 84883f1
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 44 deletions.
178 changes: 163 additions & 15 deletions open_spiel/games/goofspiel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <memory>
#include <utility>

#include "open_spiel/abseil-cpp/absl/strings/str_cat.h"
#include "open_spiel/game_parameters.h"
#include "open_spiel/spiel.h"

Expand All @@ -37,8 +38,8 @@ const GameType kGameType{
/*min_num_players=*/2,
/*provides_information_state_string=*/true,
/*provides_information_state_tensor=*/true,
/*provides_observation_string=*/false,
/*provides_observation_tensor=*/false,
/*provides_observation_string=*/true,
/*provides_observation_tensor=*/true,
/*parameter_specification=*/
{{"imp_info", GameParameter(kDefaultImpInfo)},
{"num_cards", GameParameter(kDefaultNumCards)},
Expand Down Expand Up @@ -378,6 +379,64 @@ std::string GoofspielState::InformationStateString(Player player) const {
}
}

std::string GoofspielState::ObservationString(Player player) const {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);

// Perfect info case, show:
// - current point card
// - everyone's current points
// - everyone's current hands
// Imperfect info case, show:
// - current point card
// - everyone's current points
// - my current hand
// - current win sequence
std::string current_trick =
absl::StrCat("Current point card: ", point_card_index_ + 1);
std::string points_line = "Points: ";
std::string hands = "";
std::string win_seq = "Win Sequence: ";
for (auto p = Player{0}; p < num_players_; ++p) {
absl::StrAppend(&points_line, points_[p], " ");
}

if (impinfo_) {
// Only my hand
absl::StrAppend(&hands, "P", player, " hand: ");
for (int c = 0; c < num_cards_; ++c) {
if (player_hands_[player][c]) {
absl::StrAppend(&hands, c + 1, " ");
}
}
absl::StrAppend(&hands, "\n");

// Show the win sequence.
for (int i = 0; i < win_sequence_.size(); ++i) {
absl::StrAppend(&win_seq, win_sequence_[i], " ");
}
return absl::StrCat(current_trick, "\n", points_line, "\n", hands, win_seq,
"\n");
} else {
// Show the hands in the perfect info case.
for (auto p = Player{0}; p < num_players_; ++p) {
absl::StrAppend(&hands, "P", p, " hand: ");
for (int c = 0; c < num_cards_; ++c) {
if (player_hands_[p][c]) {
absl::StrAppend(&hands, c + 1, " ");
}
}
absl::StrAppend(&hands, "\n");
}
return absl::StrCat(current_trick, "\n", points_line, "\n", hands);
}
}

void GoofspielState::NextPlayer(int* count, Player* player) const {
*count += 1;
*player = (*player + 1) % num_players_;
}

void GoofspielState::InformationStateTensor(Player player,
std::vector<double>* values) const {
SPIEL_CHECK_GE(player, 0);
Expand All @@ -386,13 +445,9 @@ void GoofspielState::InformationStateTensor(Player player,
values->clear();
values->reserve(game_->InformationStateTensorSize());

// 1-hot vector for the observing player.
for (auto p = Player{0}; p < num_players_; ++p) {
values->push_back(p == player ? 1 : 0);
}

// Point totals: one-hot vector encoding points, per player.
for (auto p = Player{0}; p < num_players_; ++p) {
Player p = player;
for (int n = 0; n < num_players_; NextPlayer(&n, &p)) {
// Cards numbered 1 .. K
int max_points_slots = (num_cards_ * (num_cards_ + 1)) / 2 + 1;
for (int i = 0; i < max_points_slots; ++i) {
Expand Down Expand Up @@ -451,12 +506,77 @@ void GoofspielState::InformationStateTensor(Player player,
for (int i = 0; i < future_tricks * num_cards_; ++i) values->push_back(0);

// Bit vectors encoding all players' hands.
for (auto p = Player{0}; p < num_players_; ++p) {
p = player;
for (int n = 0; n < num_players_; NextPlayer(&n, &p)) {
for (int c = 0; c < num_cards_; ++c) {
values->push_back(player_hands_[p][c] ? 1 : 0);
}
}
}

SPIEL_CHECK_EQ(values->size(), game_->InformationStateTensorSize());
}

void GoofspielState::ObservationTensor(Player player,
std::vector<double>* values) const {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);

values->clear();
values->reserve(game_->ObservationTensorSize());

// Perfect info case, show:
// - one-hot encoding the current point card
// - everyone's current points
// - everyone's current hands
// Imperfect info case, show:
// - one-hot encoding the current point card
// - everyone's current points
// - my current hand
// - current win sequence

// Current point card.
for (int i = 0; i < num_cards_; ++i) {
values->push_back(i == point_card_index_ ? 1.0 : 0.0);
}

// Point totals: one-hot vector encoding points, per player.
Player p = player;
for (int n = 0; n < num_players_; NextPlayer(&n, &p)) {
// Cards numbered 1 .. K
int max_points_slots = (num_cards_ * (num_cards_ + 1)) / 2 + 1;
for (int i = 0; i < max_points_slots; ++i) {
values->push_back(i == points_[p] ? 1 : 0);
}
}

if (impinfo_) {
// Bit vector of observing player's hand.
for (int c = 0; c < num_cards_; ++c) {
values->push_back(player_hands_[player][c] ? 1 : 0);
}

// Sequence of who won each trick.
for (int i = 0; i < win_sequence_.size(); ++i) {
for (auto p = Player{0}; p < num_players_; ++p) {
values->push_back(win_sequence_[i] == p ? 1 : 0);
}
}

// Padding for future tricks
int future_tricks = num_cards_ - win_sequence_.size();
for (int i = 0; i < future_tricks * num_players_; ++i) values->push_back(0);
} else {
// Bit vectors encoding all players' hands.
p = player;
for (int n = 0; n < num_players_; NextPlayer(&n, &p)) {
for (int c = 0; c < num_cards_; ++c) {
values->push_back(player_hands_[p][c] ? 1 : 0);
}
}
}

SPIEL_CHECK_EQ(values->size(), game_->ObservationTensorSize());
}

std::unique_ptr<State> GoofspielState::Clone() const {
Expand Down Expand Up @@ -486,9 +606,7 @@ int GoofspielGame::MaxChanceOutcomes() const {

std::vector<int> GoofspielGame::InformationStateTensorShape() const {
if (impinfo_) {
return {// 1-hot bit vector for observing player
num_players_ +
// 1-hot bit vector for point total per player; upper bound is 1 +
return {// 1-hot bit vector for point total per player; upper bound is 1 +
// 2 + ... + K = K*(K+1) / 2, but must add one to include 0 points.
num_players_ * ((num_cards_ * (num_cards_ + 1)) / 2 + 1) +
// Bit vector for my remaining cards:
Expand All @@ -501,9 +619,7 @@ std::vector<int> GoofspielGame::InformationStateTensorShape() const {
// The observing player's own action sequence
num_cards_ * num_cards_};
} else {
return {// 1-hot bit vector for observing player
num_players_ +
// 1-hot bit vector for point total per player; upper bound is 1 +
return {// 1-hot bit vector for point total per player; upper bound is 1 +
// 2 + ... + K = K*(K+1) / 2, but must add one to include 0 points.
num_players_ * ((num_cards_ * (num_cards_ + 1)) / 2 + 1) +
// A sequence of 1-hot bit vectors encoding the point card sequence
Expand All @@ -513,5 +629,37 @@ std::vector<int> GoofspielGame::InformationStateTensorShape() const {
}
}

std::vector<int> GoofspielGame::ObservationTensorShape() const {
// Perfect info case, show:
// - current point card showing
// - everyone's current points
// - everyone's current hands
// Imperfect info case, show:
// - current point card showing
// - everyone's current points
// - my current hand
// - current win sequence
if (impinfo_) {
return {// 1-hot bit to encode the current point card
num_cards_ +
// 1-hot bit vector for point total per player; upper bound is 1 +
// 2 + ... + K = K*(K+1) / 2, but must add one to include 0 points.
num_players_ * ((num_cards_ * (num_cards_ + 1)) / 2 + 1) +
// Bit vector for my remaining cards:
num_cards_ +
// A sequence of 1-hot bit vectors encoding the player who won that
// turn, where max number of turns is num_cards
num_cards_ * num_players_};
} else {
return {// 1-hot bit to encode the current point card
num_cards_ +
// 1-hot bit vector for point total per player; upper bound is 1 +
// 2 + ... + K = K*(K+1) / 2, but must add one to include 0 points.
num_players_ * ((num_cards_ * (num_cards_ + 1)) / 2 + 1) +
// Bit vector for each card per player
num_players_ * num_cards_};
}
}

} // namespace goofspiel
} // namespace open_spiel
9 changes: 8 additions & 1 deletion open_spiel/games/goofspiel.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,12 @@ class GoofspielState : public SimMoveState {
bool IsTerminal() const override;
std::vector<double> Returns() const override;
std::string InformationStateString(Player player) const override;
std::string ObservationString(Player player) const override;

void InformationStateTensor(Player player,
std::vector<double>* values) const override;
void ObservationTensor(Player player,
std::vector<double>* values) const override;
std::unique_ptr<State> Clone() const override;
std::vector<std::pair<Action, double>> ChanceOutcomes() const override;

Expand All @@ -84,6 +87,9 @@ class GoofspielState : public SimMoveState {
void DoApplyActions(const std::vector<Action>& actions) override;

private:
// Increments the count and increments the player mod num_players_.
void NextPlayer(int* count, Player* player) const;

int num_cards_;
PointsOrder points_order_;
bool impinfo_;
Expand All @@ -96,7 +102,7 @@ class GoofspielState : public SimMoveState {
std::vector<int> point_deck_; // Current point deck.
std::vector<std::vector<bool>> player_hands_; // true if card is in hand.
std::vector<int> point_card_sequence_;
std::vector<int> win_sequence_;
std::vector<int> win_sequence_; // Which player won
std::vector<std::vector<Action>> actions_history_;
};

Expand All @@ -115,6 +121,7 @@ class GoofspielGame : public Game {
return std::shared_ptr<const Game>(new GoofspielGame(*this));
}
std::vector<int> InformationStateTensorShape() const override;
std::vector<int> ObservationTensorShape() const override;
int MaxGameLength() const override { return num_cards_; }

private:
Expand Down

0 comments on commit 84883f1

Please sign in to comment.