Skip to content

Commit

Permalink
Merge pull request #170 from inverted-ai/traffic_light_state_machines
Browse files Browse the repository at this point in the history
traffic light state machines
  • Loading branch information
rf-ivtdai committed Jan 16, 2024
2 parents b98927f + f29cd4e commit 127f95f
Show file tree
Hide file tree
Showing 19 changed files with 348 additions and 53 deletions.
32 changes: 29 additions & 3 deletions invertedai/api/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
InfractionIndicators,
AgentAttributes,
TrafficLightStatesDict,
LightRecurrentStates,
LightRecurrentState,
)


Expand All @@ -41,7 +43,8 @@ class DriveResponse(BaseModel):
is_inside_supported_area: List[
bool
] #: For each agent, indicates whether the predicted state is inside supported area.
model_version: str # Model version used for this API call
traffic_lights_states: Optional[TrafficLightStatesDict] # Traffic light states for the next time step
light_recurrent_states: Optional[LightRecurrentStates] # Recurrent states for all traffic lights at the next time step


@validate_arguments
Expand All @@ -51,6 +54,7 @@ def drive(
agent_attributes: List[AgentAttributes],
recurrent_states: List[RecurrentState],
traffic_lights_states: Optional[TrafficLightStatesDict] = None,
light_recurrent_states: Optional[LightRecurrentStates] = None,
get_birdview: bool = False,
rendering_center: Optional[Tuple[float, float]] = None,
rendering_fov: Optional[float] = None,
Expand Down Expand Up @@ -99,6 +103,9 @@ def drive(
their current state should be provided here. Any traffic light for which no
state is provided will be ignored by the agents.
light_recurrent_states:
Specifies the state and time remaining for each light group in the scene.
random_seed:
Controls the stochastic aspects of agent behavior for reproducibility.
Expand Down Expand Up @@ -141,6 +148,7 @@ def _tolist(input_data: List):
agent_attributes=[state.tolist() for state in agent_attributes],
recurrent_states=[r.packed for r in recurrent_states],
traffic_lights_states=traffic_lights_states,
light_recurrent_states=light_recurrent_states,
get_birdview=get_birdview,
get_infractions=get_infractions,
random_seed=random_seed,
Expand Down Expand Up @@ -172,7 +180,16 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)

return response
Expand Down Expand Up @@ -243,7 +260,16 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)

return response
26 changes: 24 additions & 2 deletions invertedai/api/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
TrafficLightStatesDict,
Image,
InfractionIndicators,
LightRecurrentStates,
LightRecurrentState,
)


Expand All @@ -41,6 +43,8 @@ class InitializeResponse(BaseModel):
infractions: Optional[
List[InfractionIndicators]
] #: If `get_infractions` was set, they are returned here.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states at the first time step
light_recurrent_states: Optional[LightRecurrentStates] # To pass to :func:`iai.drive` at the first time step
model_version: str # Model version used for this API call


Expand Down Expand Up @@ -181,7 +185,16 @@ def initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)
return response
except TryAgain as e:
Expand Down Expand Up @@ -251,6 +264,15 @@ async def async_initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)
return response
6 changes: 6 additions & 0 deletions invertedai/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ class TrafficLightState(str, Enum):
red = "red"


class LightRecurrentState(BaseModel):
state: int
ticks_remaining: int


class AgentType(str, Enum):
car = "car"
pedestrian = "pedestrian"
Expand Down Expand Up @@ -262,3 +267,4 @@ def fromdict(cls, d):


TrafficLightStatesDict = Dict[TrafficLightId, TrafficLightState]
LightRecurrentStates = List[LightRecurrentState]
1 change: 1 addition & 0 deletions invertedai_cpp/examples/drive_body.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"get_birdview": true,
"get_infractions": true,
"traffic_lights_states": null,
"light_recurrent_states": null,
"random_seed": null,
"rendering_fov": null,
"rendering_center": null
Expand Down
12 changes: 9 additions & 3 deletions invertedai_cpp/examples/initialize_body.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
{
"location": "carla:Town03",
"num_agents_to_spawn": 10,
"location": "iai:terminal_and_quebec",
"num_agents_to_spawn": 20,
"states_history": null,
"agent_attributes": null,
"agent_attributes": [
["car"],
["pedestrian"],
["pedestrian"],
["pedestrian"],
["pedestrian"]
],
"traffic_light_state_history": null,
"get_birdview": true,
"get_infractions": true,
Expand Down
4 changes: 2 additions & 2 deletions invertedai_cpp/examples/initialize_sampling_with_types.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"location": "carla:Town04",
"num_agents_to_spawn": 10,
"location": "iai:drake_street_and_pacific_blvd",
"num_agents_to_spawn": 20,
"states_history": null,
"agent_attributes": [
["car"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"location": "canada:ubc_roundabout",
"num_agents_to_spawn": 10,
"num_agents_to_spawn": 20,
"states_history": [
[
[
Expand Down
8 changes: 8 additions & 0 deletions invertedai_cpp/invertedai/data_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ struct TrafficLightState {
std::string value;
};

/**
* Light recurrent state that contains the current state and ticks remaining in this state.
*/
struct LightRecurrentState {
float state;
float ticks_remaining;
};

/**
* Infractions committed by a given agent, as returned from invertedai::drive().
*/
Expand Down
72 changes: 57 additions & 15 deletions invertedai_cpp/invertedai/drive_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,33 @@ DriveRequest::DriveRequest(const std::string &body_str) {
}
this->recurrent_states_.push_back(recurrent_state);
}
this->traffic_lights_states_.clear();
for (const auto &element : this->body_json_["traffic_lights_states"]) {
TrafficLightState traffic_light_state = {
element[0],
element[1]
};
this->traffic_lights_states_.push_back(traffic_light_state);
if (this->body_json_["traffic_lights_states"].is_null()) {
this->traffic_lights_states_ = std::nullopt;
} else {
if (this->traffic_lights_states_.has_value()) {
this->traffic_lights_states_.value().clear();
} else {
this->traffic_lights_states_ = std::map<std::string, std::string>();
}
for (const auto &element : this->body_json_["traffic_lights_states"].items()) {
this->traffic_lights_states_.value()[element.key()] = element.value();
}
}
if (this->body_json_["light_recurrent_states"].is_null()) {
this->light_recurrent_states_ = std::nullopt;
} else {
if (this->light_recurrent_states_.has_value()) {
this->light_recurrent_states_.value().clear();
} else {
this->light_recurrent_states_ = std::vector<LightRecurrentState>();
}
for (const auto &element : this->body_json_["light_recurrent_states"]) {
LightRecurrentState light_recurrent_state = {
element[0],
element[1]
};
this->light_recurrent_states_.value().push_back(light_recurrent_state);
}
}
this->get_birdview_ = this->body_json_["get_birdview"].is_boolean()
? this->body_json_["get_birdview"].get<bool>()
Expand Down Expand Up @@ -89,12 +109,24 @@ void DriveRequest::refresh_body_json_() {
this->body_json_["recurrent_states"].push_back(elements);
}
this->body_json_["traffic_lights_states"].clear();
for (const TrafficLightState &traffic_light_state : this->traffic_lights_states_) {
json element = {
traffic_light_state.id,
traffic_light_state.value
};
this->body_json_["traffic_lights_states"].push_back(element);
if (this->traffic_lights_states_.has_value()) {
for (const auto &pair : this->traffic_lights_states_.value()) {
this->body_json_["traffic_lights_states"][pair.first] = pair.second;
}
} else {
this->body_json_["traffic_lights_states"] = nullptr;
}
this->body_json_["light_recurrent_states"].clear();
if (this->light_recurrent_states_.has_value()) {
for (const LightRecurrentState &light_recurrent_state : this->light_recurrent_states_.value()) {
json element = {
light_recurrent_state.state,
light_recurrent_state.ticks_remaining
};
this->body_json_["light_recurrent_states"].push_back(element);
}
} else {
this->body_json_["light_recurrent_states"] = nullptr;
}
this->body_json_["get_birdview"] = this->get_birdview_;
this->body_json_["get_infractions"] = this->get_infractions_;
Expand Down Expand Up @@ -125,11 +157,13 @@ void DriveRequest::update(const InitializeResponse &init_res) {
this->agent_states_ = init_res.agent_states();
this->agent_attributes_ = init_res.agent_attributes();
this->recurrent_states_ = init_res.recurrent_states();
this->light_recurrent_states_ = init_res.light_recurrent_states();
}

void DriveRequest::update(const DriveResponse &drive_res) {
this->agent_states_ = drive_res.agent_states();
this->recurrent_states_ = drive_res.recurrent_states();
this->light_recurrent_states_ = drive_res.light_recurrent_states();
}

std::string DriveRequest::body_str() {
Expand All @@ -149,14 +183,18 @@ std::vector<AgentAttributes> DriveRequest::agent_attributes() const {
return this->agent_attributes_;
};

std::vector<TrafficLightState> DriveRequest::traffic_lights_states() const {
std::optional<std::map<std::string, std::string>> DriveRequest::traffic_lights_states() const {
return this->traffic_lights_states_;
};

std::vector<std::vector<double>> DriveRequest::recurrent_states() const {
return this->recurrent_states_;
};

std::optional<std::vector<LightRecurrentState>> DriveRequest::light_recurrent_states() const {
return this->light_recurrent_states_;
};

bool DriveRequest::get_birdview() const {
return this->get_birdview_;
}
Expand Down Expand Up @@ -193,10 +231,14 @@ void DriveRequest::set_agent_attributes(const std::vector<AgentAttributes> &agen
this->agent_attributes_ = agent_attributes;
}

void DriveRequest::set_traffic_lights_states(const std::vector<TrafficLightState> &traffic_lights_states) {
void DriveRequest::set_traffic_lights_states(const std::map<std::string, std::string> &traffic_lights_states) {
this->traffic_lights_states_ = traffic_lights_states;
}

void DriveRequest::set_light_recurrent_states(const std::vector<LightRecurrentState> &light_recurrent_states) {
this->light_recurrent_states_ = light_recurrent_states;
}

void DriveRequest::set_recurrent_states(const std::vector<std::vector<double>> &recurrent_states) {
this->recurrent_states_ = recurrent_states;
}
Expand Down
17 changes: 14 additions & 3 deletions invertedai_cpp/invertedai/drive_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <optional>
#include <string>
#include <vector>
#include <map>

#include "externals/json.hpp"

Expand All @@ -20,7 +21,8 @@ class DriveRequest {
std::string location_;
std::vector<AgentState> agent_states_;
std::vector<AgentAttributes> agent_attributes_;
std::vector<TrafficLightState> traffic_lights_states_;
std::optional<std::map<std::string, std::string>> traffic_lights_states_;
std::optional<std::vector<LightRecurrentState>> light_recurrent_states_;
std::vector<std::vector<double>> recurrent_states_;
bool get_birdview_;
bool get_infractions_;
Expand Down Expand Up @@ -73,11 +75,15 @@ class DriveRequest {
/**
* Get the states of traffic lights.
*/
std::vector<TrafficLightState> traffic_lights_states() const;
std::optional<std::map<std::string, std::string>> traffic_lights_states() const;
/**
* Get the recurrent states for all agents.
*/
std::vector<std::vector<double>> recurrent_states() const;
/**
* Get the recurrent states for all light groups in location.
*/
std::optional<std::vector<LightRecurrentState>> light_recurrent_states() const;
/**
* Check whether to return an image visualizing the simulation state.
*/
Expand Down Expand Up @@ -126,13 +132,18 @@ class DriveRequest {
* traffic light for which no state is provided will be ignored by the agents.
*/
void set_traffic_lights_states(
const std::vector<TrafficLightState> &traffic_lights_states);
const std::map<std::string, std::string> &traffic_lights_states);
/**
* Set the recurrent states for all agents, obtained from the
* previous call to drive() or initialize().
*/
void set_recurrent_states(
const std::vector<std::vector<double>> &recurrent_states);
/**
* Set light recurrent states for all light groups in location.
*/
void set_light_recurrent_states(
const std::vector<LightRecurrentState> &light_recurrent_states);
/**
* Set whether to return an image visualizing the simulation state.
* This is very slow and should only be used for debugging.
Expand Down

0 comments on commit 127f95f

Please sign in to comment.