Skip to content

Commit

Permalink
Merge pull request #171 from inverted-ai/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
rf-ivtdai committed Jan 17, 2024
2 parents 3b50af9 + 246b76e commit 206e606
Show file tree
Hide file tree
Showing 21 changed files with 382 additions and 72 deletions.
39 changes: 35 additions & 4 deletions invertedai/api/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
InfractionIndicators,
AgentAttributes,
TrafficLightStatesDict,
LightRecurrentStates,
LightRecurrentState,
)


Expand All @@ -41,7 +43,8 @@ class DriveResponse(BaseModel):
is_inside_supported_area: List[
bool
] #: For each agent, indicates whether the predicted state is inside supported area.
model_version: str # Model version used for this API call
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map, each element corresponds to one light group.


@validate_arguments
Expand All @@ -51,6 +54,7 @@ def drive(
agent_attributes: List[AgentAttributes],
recurrent_states: List[RecurrentState],
traffic_lights_states: Optional[TrafficLightStatesDict] = None,
light_recurrent_states: Optional[LightRecurrentStates] = None,
get_birdview: bool = False,
rendering_center: Optional[Tuple[float, float]] = None,
rendering_fov: Optional[float] = None,
Expand Down Expand Up @@ -97,7 +101,14 @@ def drive(
traffic_lights_states:
If the location contains traffic lights within the supported area,
their current state should be provided here. Any traffic light for which no
state is provided will be ignored by the agents.
state is provided will have a state generated by iai.
light_recurrent_states:
Light recurrent states for all agents, obtained from the previous call to
:func:`drive` or :func:`initialize`.
Specifies the state and time remaining for each light group in the map.
If manual control of individual traffic lights is desired, modify the relevant state(s)
in traffic_lights_states, then pass in light_recurrent_states as usual.
random_seed:
Controls the stochastic aspects of agent behavior for reproducibility.
Expand Down Expand Up @@ -141,6 +152,8 @@ def _tolist(input_data: List):
agent_attributes=[state.tolist() for state in agent_attributes],
recurrent_states=[r.packed for r in recurrent_states],
traffic_lights_states=traffic_lights_states,
light_recurrent_states=[light_recurrent_state.tolist() for light_recurrent_state in light_recurrent_states]
if light_recurrent_states is not None else None,
get_birdview=get_birdview,
get_infractions=get_infractions,
random_seed=random_seed,
Expand Down Expand Up @@ -172,7 +185,16 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)

return response
Expand Down Expand Up @@ -243,7 +265,16 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)

return response
30 changes: 26 additions & 4 deletions invertedai/api/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
TrafficLightStatesDict,
Image,
InfractionIndicators,
LightRecurrentStates,
LightRecurrentState,
)


Expand All @@ -41,6 +43,8 @@ class InitializeResponse(BaseModel):
infractions: Optional[
List[InfractionIndicators]
] #: If `get_infractions` was set, they are returned here.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map, each element corresponds to one light group.
model_version: str # Model version used for this API call


Expand Down Expand Up @@ -92,8 +96,8 @@ def initialize(
traffic_light_state_history:
History of traffic light states - the list is over time, in chronological order, i.e.
the last element is the current state. Not specifying traffic light state is equivalent
to disabling traffic lights.
the last element is the current state. If there are traffic lights in the map,
not specifying traffic light state is equivalent to using iai generated light states.
location_of_interest:
Optional coordinates for spawning agents with the given location as center instead of the default map center
Expand Down Expand Up @@ -181,7 +185,16 @@ def initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)
return response
except TryAgain as e:
Expand Down Expand Up @@ -251,6 +264,15 @@ async def async_initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"]
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
else None
)
return response
15 changes: 15 additions & 0 deletions invertedai/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,20 @@ class TrafficLightState(str, Enum):
red = "red"


class LightRecurrentState(BaseModel):
"""
Recurrent state of all the traffic lights in one light group (one intersection).
"""
state: float
time_remaining: float

def tolist(self):
"""
Convert LightRecurrentState to a list in this order: [state, time_remaining]
"""
return [self.state, self.time_remaining]


class AgentType(str, Enum):
car = "car"
pedestrian = "pedestrian"
Expand Down Expand Up @@ -262,3 +276,4 @@ def fromdict(cls, d):


TrafficLightStatesDict = Dict[TrafficLightId, TrafficLightState]
LightRecurrentStates = List[LightRecurrentState]
18 changes: 11 additions & 7 deletions invertedai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ async def reinit(reinitialize_agent_state, reinitialize_agent_attrs, area_center
get_birdview=get_birdview,
)
except BaseException:
return [], [], []
return [], [], [], ""
SLACK = 0
valid_agents = list(filter(lambda x: inside_fov(
center=area_center, initialize_fov=initialize_fov - SLACK, point=x[0].center),
Expand All @@ -361,7 +361,7 @@ async def reinit(reinitialize_agent_state, reinitialize_agent_attrs, area_center
file_path = f"{birdview_path}-{(area_center.x, area_center.y)}.jpg"
response.birdview.decode_and_save(file_path)

return valid_agent_state, valid_agent_attrs, valid_agent_rs
return valid_agent_state, valid_agent_attrs, valid_agent_rs, response.model_version

stride = initialize_fov / 2

Expand Down Expand Up @@ -393,15 +393,19 @@ async def reinit(reinitialize_agent_state, reinitialize_agent_attrs, area_center

results = await asyncio.gather(*[reinit(agnts["state"], agnts["attr"], agnts["center"]) for agnts in initialize_payload])

model_version = ""
for result in results:
new_agent_state += result[0]
new_attributes += result[1]
new_recurrent_states += result[2]
model_version = result[3]

return invertedai.api.InitializeResponse(
recurrent_states=new_recurrent_states,
agent_states=new_agent_state,
agent_attributes=new_attributes)
agent_attributes=new_attributes,
model_version=model_version
)


def area_re_initialization(location, agent_attributes, states_history, traffic_lights_states=None, random_seed=None,
Expand Down Expand Up @@ -477,10 +481,10 @@ def inside_fov(center: Point, initialize_fov: float, point: Point) -> bool:
file_path = f"{birdview_path}-{(area_center.x, area_center.y)}.jpg"
response.birdview.decode_and_save(file_path)

return invertedai.api.InitializeResponse(
recurrent_states=new_recurrent_states,
agent_states=new_agent_state,
agent_attributes=new_attributes)
response.recurrent_states = new_recurrent_states
response.agent_states = new_agent_state
response.agent_attributes = new_attributes
return response


def area_initialization(location, agent_density, traffic_lights_states=None, random_seed=None, map_center=(0, 0),
Expand Down
1 change: 1 addition & 0 deletions invertedai_cpp/examples/drive_body.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"get_birdview": true,
"get_infractions": true,
"traffic_lights_states": null,
"light_recurrent_states": null,
"random_seed": null,
"rendering_fov": null,
"rendering_center": null
Expand Down
12 changes: 9 additions & 3 deletions invertedai_cpp/examples/initialize_body.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
{
"location": "carla:Town03",
"num_agents_to_spawn": 10,
"location": "iai:terminal_and_quebec",
"num_agents_to_spawn": 20,
"states_history": null,
"agent_attributes": null,
"agent_attributes": [
["car"],
["pedestrian"],
["pedestrian"],
["pedestrian"],
["pedestrian"]
],
"traffic_light_state_history": null,
"get_birdview": true,
"get_infractions": true,
Expand Down
4 changes: 2 additions & 2 deletions invertedai_cpp/examples/initialize_sampling_with_types.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"location": "carla:Town04",
"num_agents_to_spawn": 10,
"location": "iai:drake_street_and_pacific_blvd",
"num_agents_to_spawn": 20,
"states_history": null,
"agent_attributes": [
["car"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"location": "canada:ubc_roundabout",
"num_agents_to_spawn": 10,
"num_agents_to_spawn": 20,
"states_history": [
[
[
Expand Down
8 changes: 8 additions & 0 deletions invertedai_cpp/invertedai/data_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ struct TrafficLightState {
std::string value;
};

/**
* Recurrent state of all the traffic lights in one light group (one intersection).
*/
struct LightRecurrentState {
float state;
float time_remaining;
};

/**
* Infractions committed by a given agent, as returned from invertedai::drive().
*/
Expand Down

0 comments on commit 206e606

Please sign in to comment.