Skip to content

Commit

Permalink
Add helper function to get bargaining action by quantities.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 473450346
Change-Id: I1a963aae23ab73d22f81cf4b68bd791d120f83c3
  • Loading branch information
lanctot committed Sep 11, 2022
1 parent 7fa9941 commit fec7620
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
11 changes: 11 additions & 0 deletions open_spiel/games/bargaining.cc
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,17 @@ int BargainingGame::NumDistinctActions() const {
return all_offers_.size() + 1;
}

std::pair<Offer, Action> BargainingGame::GetOfferByQuantities(
const std::vector<int>& quantities) const {
for (int i = 0; i < all_offers_.size(); ++i) {
if (quantities == all_offers_[i].quantities) {
return {all_offers_[i], i};
}
}
return {Offer(), kInvalidAction};
}


std::vector<int> BargainingGame::ObservationTensorShape() const {
return {
1 + // Agreement reached?
Expand Down
2 changes: 2 additions & 0 deletions open_spiel/games/bargaining.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ class BargainingGame : public Game {
const std::vector<Offer>& AllOffers() const { return all_offers_; }
const Instance& GetInstance(int num) const { return all_instances_[num]; }
const Offer& GetOffer(int num) const { return all_offers_[num]; }
std::pair<Offer, Action> GetOfferByQuantities(
const std::vector<int>& quantities) const;

private:
void ParseInstancesFile(const std::string& filename);
Expand Down
8 changes: 8 additions & 0 deletions open_spiel/python/pybind11/games_bargaining.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using open_spiel::State;
using open_spiel::bargaining::BargainingGame;
using open_spiel::bargaining::BargainingState;
using open_spiel::bargaining::Instance;
using open_spiel::bargaining::Offer;

PYBIND11_SMART_HOLDER_TYPE_CASTERS(BargainingGame);
PYBIND11_SMART_HOLDER_TYPE_CASTERS(BargainingState);
Expand All @@ -34,6 +35,10 @@ void open_spiel::init_pyspiel_games_bargaining(py::module& m) {
.def_readwrite("pool", &Instance::pool)
.def_readwrite("values", &Instance::values);

py::class_<Offer>(m, "Offer")
.def(py::init<>())
.def_readwrite("quantities", &Offer::quantities);

py::classh<BargainingState, State>(m, "BargainingState")
.def("instance", &BargainingState::instance)
.def("agree_action", &BargainingState::AgreeAction)
Expand All @@ -54,6 +59,9 @@ void open_spiel::init_pyspiel_games_bargaining(py::module& m) {

py::classh<BargainingGame, Game>(m, "BargainingGame")
.def("all_instances", &BargainingGame::AllInstances)
// get_offer_by_quantities(quantities: List[int]). Returns a tuple
// of (offer, OpenSpiel action)
.def("get_offer_by_quantities", &BargainingGame::GetOfferByQuantities)
// Pickle support
.def(py::pickle(
[](std::shared_ptr<const BargainingGame> game) { // __getstate__
Expand Down

0 comments on commit fec7620

Please sign in to comment.