Skip to content

Commit

Permalink
some changes and cleanup docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
AlirezaMorsali committed Oct 18, 2022
1 parent 4c74d26 commit c2a512d
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 48 deletions.
6 changes: 3 additions & 3 deletions examples/Demo_Drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
get_infractions=True,
)
pbar.set_description(
f"Collision rate: {100*np.array(response.infractions.collisions)[-1, :].mean():.2f}% | "
+ f"Off-road rate: {100*np.array(response.infractions.offroad)[-1, :].mean():.2f}% | "
+ f"Wrong-way rate: {100*np.array(response.infractions.wrong_way)[-1, :].mean():.2f}%"
f"Collision rate: {100*np.array(response.infractions.collisions).mean():.2f}% | "
+ f"Off-road rate: {100*np.array(response.infractions.offroad).mean():.2f}% | "
+ f"Wrong-way rate: {100*np.array(response.infractions.wrong_way).mean():.2f}%"
)

birdview = np.array(response.bird_view, dtype=np.uint8)
Expand Down
Binary file modified examples/iai-drive.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
94 changes: 65 additions & 29 deletions invertedai/api_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,22 @@
DriveResponse,
AgentState,
AgentAttributes,
TrafficLightId,
TrafficLightState,
InfractionIndicators,
StaticMapActors,
StaticMapActor,
RecurrentStates,
TrafficLightStates,
TrafficLightId,
TrafficLightState,
)

TIMEOUT = 10


def location_info(
location: str = "iai:ubc_roundabout", include_map_source: bool = True
location: str = "iai:ubc_roundabout", include_map_source: bool = False
) -> LocationResponse:
"""
Providing map information, i.e., rendered Bird's-eye view image, map in OSM format,
list of static agents (traffic lights and traffic signs).
Providing map information, i.e., rendered bird's-eye view image, map in OSM format,
list of static agents (traffic lights).
Parameters
----------
Expand Down Expand Up @@ -82,7 +81,7 @@ def location_info(
response = iai.session.request(model="location_info", params=params)
if response["static_actors"] is not None:
response["static_actors"] = [
StaticMapActors(**actor) for actor in response["static_actors"]
StaticMapActor(**actor) for actor in response["static_actors"]
]
return LocationResponse(**response)
except TryAgain as e:
Expand All @@ -93,28 +92,40 @@ def location_info(

def initialize(
location: str = "iai:ubc_roundabout",
agent_count: Optional[int] = None,
agent_attributes: Optional[List[AgentAttributes]] = None,
states_history: Optional[List[List[AgentState]]] = None,
traffic_light_state_history: Optional[TrafficLightStates] = None,
traffic_light_state_history: Optional[
List[Dict[TrafficLightId, TrafficLightState]]
] = None,
agent_count: Optional[int] = None,
random_seed: Optional[int] = None,
) -> InitializeResponse:
"""
Parameters
----------
location : str
Name of the location.
agent_count : int
Number of cars to spawn on the map
agent_attributes : List[AgentAttributes]
List of agent attributes
agent_attributes : Optional[List[AgentAttributes]]
List of agent attributes. Each agent requires, length: [float]
width: [float] and rear_axis_offset: [float] all in meters.
states_history: [List[List[AgentState]]]
History of agent states
states_history: Optional[[List[List[AgentState]]]]
History of list of agent states. Each agent state must include x: [float],
y: [float] corrdinate in meters orientation: [float] in radians with 0
pointing along x and pi/2 pointing along y and speed: [float] in m/s.
traffic_light_state_history: Optional[List[Dict[TrafficLightId, TrafficLightState]]]
History of traffic light states
History of traffic light states.
agent_count : Optional[int]
Number of cars to spawn on the map.
random_seed: Optional[int]
This parameter controls the stochastic behavior of INITIALIZE. With the
same seed and the same inputs, the outputs will be approximately the same
with high accuracy.
Returns
-------
Expand Down Expand Up @@ -146,7 +157,19 @@ def initialize(
"num_agents_to_spawn": agent_count,
"include_recurrent_states": include_recurrent_states,
}
initial_states = iai.session.request(model="initialize", params=params)
model_inputs = dict(
states_history=states_history
if states_history is None
else [state.tolist() for state in states_history],
agent_attributes=agent_attributes
if agent_attributes is None
else [state.tolist() for state in agent_attributes],
traffic_light_state_history=traffic_light_state_history,
random_seed=random_seed,
)
initial_states = iai.session.request(
model="initialize", params=params, data=model_inputs
)
agents_spawned = len(initial_states["agent_states"])
if agents_spawned != agent_count:
iai.logger.warning(
Expand Down Expand Up @@ -174,9 +197,12 @@ def drive(
agent_states: List[AgentState] = [],
agent_attributes: List[AgentAttributes] = [],
recurrent_states: RecurrentStates = [],
traffic_lights_states: Optional[
Dict[TrafficLightId, List[TrafficLightState]]
] = None,
get_birdviews: bool = False,
get_infractions: bool = False,
traffic_lights_states: Optional[TrafficLightStates] = None,
random_seed: Optional[int] = None,
) -> DriveResponse:
"""
Parameters
Expand All @@ -185,23 +211,32 @@ def drive(
Name of the location.
agent_states : List[AgentState]
List of agent states.
List of agent states. The state must include x: [float], y: [float] corrdinate in meters
orientation: [float] in radians with 0 pointing along x and pi/2 pointing along y and
speed: [float] in m/s.
agent_attributes : List[AgentAttributes]
List of agent attributes
List of agent attributes. Each agent requires, length: [float]
width: [float] and rear_axis_offset: [float] all in meters.
recurrent_states : List[RecurrentStates]
Internal simulation state
Internal simulation state obtained from previous calls to DRIVE or INITIZLIZE.
get_birdviews: bool = False
If True, a rendered bird's-eye view of the map with agents is returned
If True, a rendered bird's-eye view of the map with agents is returned.
get_infractions: bool = False
If True, 'collision', 'offroad', 'wrong_way' infractions of each agent
is returned.
are returned.
traffic_light_state_history: Optional[Dict[TrafficLightId, List[TrafficLightState]]]
Traffic light states.
random_seed: Optional[int]
This parameter controls the stochastic behavior of DRIVE. With the
same seed and the same inputs, the outputs will be approximately the same
with high accuracy.
traffic_light_state_history: Optional[List[TrafficLightStates]]
Traffic light states
Returns
-------
Expand Down Expand Up @@ -246,6 +281,7 @@ def _tolist(input_data: List):
traffic_lights_states=traffic_lights_states,
get_birdviews=get_birdviews,
get_infractions=get_infractions,
random_seed=random_seed,
)

start = time.time()
Expand All @@ -255,7 +291,7 @@ def _tolist(input_data: List):
try:
response = iai.session.request(model="drive", data=model_inputs)

out = DriveResponse(
response = DriveResponse(
agent_states=[AgentState(*state) for state in response["agent_states"]],
recurrent_states=response["recurrent_states"],
bird_view=response["bird_view"],
Expand All @@ -267,7 +303,7 @@ def _tolist(input_data: List):
present_mask=response["present_mask"],
)

return out
return response
except Exception as e:
iai.logger.warning("Retrying")
if timeout is not None and time.time() > start + timeout:
Expand Down
31 changes: 15 additions & 16 deletions invertedai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from enum import Enum

RecurrentStates = List[float] # Recurrent Dim
TrafficLightId = str
TrafficLightId = int
Point = Tuple[float, float]
Origin = Tuple[
float, float
] # lat/lon of the origin point use to project the OSM map to UTM
Map = Tuple[str, Origin] # serialized OSM file and the associated origin point
Image = List[int] # Map birdview encoded in JPEG format
# (for decoding use a JPEG decoder
# such as cv2.imdecode(birdview_image: Image, cv2.IMREAD_COLOR) ).
Map = Tuple[str, Origin] # Serialized OSM file and the associated origin point
Image = List[int] # Images encoded in JPEG format
# for decoding use a JPEG decoder
# such as cv2.imdecode(birdview_image: Image, cv2.IMREAD_COLOR).


class TrafficLightState(Enum):
Expand Down Expand Up @@ -62,31 +62,31 @@ class TrafficLightStates:

@dataclass
class InfractionIndicators:
collisions: List[bool]
offroad: List[bool]
wrong_way: List[bool]
collisions: bool
offroad: bool
wrong_way: bool


@dataclass
class DriveResponse:
agent_states: List[AgentState]
present_mask: List[bool] # A
recurrent_states: List[RecurrentStates] # Ax2x64
bird_view: Optional[Image]
infractions: Optional[InfractionIndicators]
infractions: Optional[List[InfractionIndicators]]
present_mask: List[bool] # A


@dataclass
class InitializeResponse:
recurrent_states: List[RecurrentStates]
agent_states: List[AgentState]
agent_attributes: List[AgentAttributes]
recurrent_states: List[RecurrentStates]


@dataclass
class StaticMapActors:
class StaticMapActor:
track_id: int
agent_type: Literal["traffic-light", "stop-sign"]
agent_type: Literal["traffic-light"] # Kept for possible changes in the future
x: float
y: float
psi_rad: float
Expand All @@ -98,10 +98,9 @@ class StaticMapActors:
class LocationResponse:
version: str
max_agent_number: int
birdview_image: Image
bounding_polygon: Optional[
List[Point]
] # “inner” polygon – the map may extend beyond this
# birdview_image:
birdview_image: Image
osm_map: Optional[Map]
static_actors: Optional[List[StaticMapActors]]
static_actors: List[StaticMapActor]

0 comments on commit c2a512d

Please sign in to comment.