Skip to content

Commit

Permalink
Use the same state to string function in the C++ and Python mean field routing game.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 480925836
Change-Id: Ie479e7d23db40d6365f9d9b504840879271d880f
  • Loading branch information
TheoCabannes authored and lanctot committed Oct 18, 2022
1 parent 444e9e7 commit 12f3992
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 131 deletions.
20 changes: 8 additions & 12 deletions open_spiel/games/mfg/dynamic_routing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ MeanFieldRoutingGame::MeanFieldRoutingGame(const GameParameters& params)
SPIEL_CHECK_NE(max_num_time_step, 0);
time_step_length_ =
ParameterValue<double>("time_step_length", kDefaultTimeStepLength);
network_name_ =
ParameterValue<std::string>("network", kDefaultNetworkName);
network_name_ = ParameterValue<std::string>("network", kDefaultNetworkName);
SPIEL_CHECK_NE(network_name_, "");
perform_sanity_checks_ = ParameterValue<bool>("perform_sanity_checks", true);
std::unique_ptr<DynamicRoutingData> data =
Expand Down Expand Up @@ -217,26 +216,23 @@ std::string MeanFieldRoutingGameState::StateToString(
if (is_chance_init_) {
return "initial chance node";
}
if (player_id == PlayerId::kDefaultPlayerId) {
time = absl::StrFormat("%d_default", time_step);
if (player_id == PlayerId::kDefaultPlayerId ||
player_id == PlayerId::kTerminalPlayerId) {
time = absl::StrCat(time_step);
} else if (player_id == PlayerId::kMeanFieldPlayerId) {
time = absl::StrFormat("%d_mean_field", time_step);
} else if (player_id == PlayerId::kChancePlayerId) {
time = absl::StrFormat("%d_chance", time_step);
} else if (player_id == PlayerId::kTerminalPlayerId) {
time = absl::StrFormat("%d_terminal", time_step);
} else {
SpielFatalError(
"Player id should be DEFAULT_PLAYER_ID, MEAN_FIELD or CHANCE");
}
if (vehicle_final_travel_time_ != 0.0) {
return absl::StrFormat(
"Arrived at %s, with travel time %f, t=%s, return=%.2f", location,
vehicle_final_travel_time_, time, ret);
return absl::StrFormat("Arrived at %s, with arrival time %.2f, t=%s",
location, vehicle_final_travel_time_, time);
}
return absl::StrFormat(
"Location=%s, waiting time=%d, t=%s, destination=%s, return=%.2f",
location, waiting_time, time, destination, ret);
return absl::StrFormat("Location=%s, waiting time=%d, t=%s, destination=%s",
location, waiting_time, time, destination);
}

std::vector<Action> MeanFieldRoutingGameState::LegalActions() const {
Expand Down
150 changes: 51 additions & 99 deletions open_spiel/games/mfg/dynamic_routing_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,71 +68,56 @@ void TestWholeGameWithLineNetwork() {
SPIEL_CHECK_EQ(state->CurrentPlayer(), kDefaultPlayerId);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=bef_O->O, waiting time=0, t=0_default, destination=D->aft_D"
", return=0.00");
"Location=bef_O->O, waiting time=0, t=0, destination=D->aft_D");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{3});
state->ApplyAction(3);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=O->A, waiting time=-1, t=1_mean_field, destination=D->aft_D"
", return=0.00");
"Location=O->A, waiting time=-1, t=1_mean_field, destination=D->aft_D");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=O->A, waiting time=1, t=1_default, destination=D->aft_D"
", return=0.00");
SPIEL_CHECK_EQ(state->ToString(),
"Location=O->A, waiting time=1, t=1, destination=D->aft_D");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{0});
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=O->A, waiting time=0, t=2_mean_field, destination=D->aft_D"
", return=0.00");
"Location=O->A, waiting time=0, t=2_mean_field, destination=D->aft_D");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=O->A, waiting time=0, t=2_default, destination=D->aft_D"
", return=0.00");
SPIEL_CHECK_EQ(state->ToString(),
"Location=O->A, waiting time=0, t=2, destination=D->aft_D");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{1});
state->ApplyAction(1);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->D, waiting time=-1, t=3_mean_field, destination=D->aft_D"
", return=0.00");
"Location=A->D, waiting time=-1, t=3_mean_field, destination=D->aft_D");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->D, waiting time=1, t=3_default, destination=D->aft_D"
", return=0.00");
SPIEL_CHECK_EQ(state->ToString(),
"Location=A->D, waiting time=1, t=3, destination=D->aft_D");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{0});
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->D, waiting time=0, t=4_mean_field, destination=D->aft_D"
", return=0.00");
"Location=A->D, waiting time=0, t=4_mean_field, destination=D->aft_D");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->D, waiting time=0, t=4_default, destination=D->aft_D"
", return=0.00");
SPIEL_CHECK_EQ(state->ToString(),
"Location=A->D, waiting time=0, t=4, destination=D->aft_D");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{2});
state->ApplyAction(2);
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->aft_D, with travel time 4.000000, t=5_terminal"
", return=-2.00");
"Arrived at D->aft_D, with arrival time 4.00, t=5");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->aft_D, with travel time 4.000000, t=5_terminal"
", return=-2.00");
"Arrived at D->aft_D, with arrival time 4.00, t=5");
}

void TestWholeGameWithBraessNetwork() {
Expand All @@ -150,149 +135,124 @@ void TestWholeGameWithBraessNetwork() {
state->ApplyAction(0);
SPIEL_CHECK_EQ(state->CurrentPlayer(), kDefaultPlayerId);
SPIEL_CHECK_EQ(state->ToString(),
"Location=O->A, waiting time=0, t=0_default, destination=D->E"
", return=0.00");
"Location=O->A, waiting time=0, t=0, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), (std::vector<Action>{1, 2}));
state->ApplyAction(1);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->B, waiting time=-1, t=1_mean_field, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=-1, t=1_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=A->B, waiting time=3, t=1_default, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=3, t=1, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), (std::vector<Action>{0}));
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->B, waiting time=2, t=2_mean_field, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=2, t=2_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=A->B, waiting time=2, t=2_default, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=2, t=2, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), (std::vector<Action>{0}));
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->B, waiting time=1, t=3_mean_field, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=1, t=3_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=A->B, waiting time=1, t=3_default, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=1, t=3, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), (std::vector<Action>{0}));
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->B, waiting time=0, t=4_mean_field, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=0, t=4_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=A->B, waiting time=0, t=4_default, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=0, t=4, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), (std::vector<Action>{3, 4}));
state->ApplyAction(3);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=B->C, waiting time=-1, t=5_mean_field, destination=D->E"
", return=0.00");
"Location=B->C, waiting time=-1, t=5_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=B->C, waiting time=0, t=5_default, destination=D->E"
", return=0.00");
"Location=B->C, waiting time=0, t=5, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{5});
state->ApplyAction(5);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=C->D, waiting time=-1, t=6_mean_field, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=-1, t=6_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=C->D, waiting time=3, t=6_default, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=3, t=6, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{0});
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=C->D, waiting time=2, t=7_mean_field, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=2, t=7_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=C->D, waiting time=2, t=7_default, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=2, t=7, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{0});
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=C->D, waiting time=1, t=8_mean_field, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=1, t=8_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=C->D, waiting time=1, t=8_default, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=1, t=8, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{0});
state->ApplyAction(0);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=C->D, waiting time=0, t=9_mean_field, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=0, t=9_mean_field, destination=D->E");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Location=C->D, waiting time=0, t=9_default, destination=D->E"
", return=0.00");
"Location=C->D, waiting time=0, t=9, destination=D->E");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{6});
state->ApplyAction(6);
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->E, with travel time 9.000000, t=10_mean_field"
", return=0.00");
"Arrived at D->E, with arrival time 9.00, t=10_mean_field");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(
state->ToString(),
"Arrived at D->E, with travel time 9.000000, t=10_default, return=0.00");
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->E, with arrival time 9.00, t=10");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{0});
state->ApplyAction(0);
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->E, with travel time 9.000000, t=11_mean_field, "
"return=0.00");
"Arrived at D->E, with arrival time 9.00, t=11_mean_field");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(
state->ToString(),
"Arrived at D->E, with travel time 9.000000, t=11_default, return=0.00");
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->E, with arrival time 9.00, t=11");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{0});
state->ApplyAction(0);
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->E, with travel time 9.000000, t=12_terminal, "
"return=-4.50");
"Arrived at D->E, with arrival time 9.00, t=12");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at D->E, with travel time 9.000000, t=12_terminal, "
"return=-4.50");
"Arrived at D->E, with arrival time 9.00, t=12");

SPIEL_CHECK_EQ(state->LegalActions(), std::vector<Action>{});
}
Expand All @@ -313,25 +273,20 @@ void TestPreEndedGameWithLineNetwork() {
SPIEL_CHECK_EQ(state->CurrentPlayer(), kDefaultPlayerId);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=bef_O->O, waiting time=0, t=0_default, destination=D->aft_D"
", return=0.00");
"Location=bef_O->O, waiting time=0, t=0, destination=D->aft_D");

state->ApplyAction(state->LegalActions()[0]);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=O->A, waiting time=-1, t=1_mean_field, destination=D->aft_D"
", return=0.00");
"Location=O->A, waiting time=-1, t=1_mean_field, destination=D->aft_D");

state->UpdateDistribution(distribution);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=O->A, waiting time=1, t=1_default, destination=D->aft_D"
", return=0.00");
SPIEL_CHECK_EQ(state->ToString(),
"Location=O->A, waiting time=1, t=1, destination=D->aft_D");

state->ApplyAction(state->LegalActions()[0]);
SPIEL_CHECK_EQ(
state->ToString(),
"Arrived at O->A, with travel time 3.000000, t=2_terminal, return=-1.50");
SPIEL_CHECK_EQ(state->ToString(),
"Arrived at O->A, with arrival time 3.00, t=2");
}

void TestRandomPlayWithLineNetwork() {
Expand All @@ -358,14 +313,12 @@ void TestCorrectTravelTimeUpdate() {
SPIEL_CHECK_EQ(state->ToString(), "Before initial chance node.");
state->ApplyAction(0);
SPIEL_CHECK_EQ(state->ToString(),
"Location=O->A, waiting time=0, t=0_default, destination=D->E"
", return=0.00");
"Location=O->A, waiting time=0, t=0, destination=D->E");
SPIEL_CHECK_EQ(state->LegalActions(), (std::vector<Action>{1, 2}));
state->ApplyAction(1);
SPIEL_CHECK_EQ(
state->ToString(),
"Location=A->B, waiting time=-1, t=1_mean_field, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=-1, t=1_mean_field, destination=D->E");

std::vector<double> distribution{1};
state->UpdateDistribution({.5});
Expand All @@ -374,8 +327,7 @@ void TestCorrectTravelTimeUpdate() {
// Waiting time (in time step) = 1.5 / 0.05 (time step length)
// - 1 (one time step for the current time running) = 29
SPIEL_CHECK_EQ(state->ToString(),
"Location=A->B, waiting time=29, t=1_default, destination=D->E"
", return=0.00");
"Location=A->B, waiting time=29, t=1, destination=D->E");
}
} // namespace
} // namespace open_spiel::dynamic_routing
Expand Down

0 comments on commit 12f3992

Please sign in to comment.