Skip to content

Commit

Permalink
remove time dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
AlirezaMorsali committed Oct 17, 2022
1 parent d5dbe5a commit 35b49bd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 141 deletions.
Binary file modified examples/iai-drive.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
93 changes: 18 additions & 75 deletions invertedai/api_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,6 @@ def initialize(
agent_count : int
Number of cars to spawn on the map
min_speed : Optional[int]
Not available yet, (for setting the minimum speed of spawned cars)
max_speed : Optional[int]
Not available yet, (for setting the minimum speed of spawned cars)
Returns
-------
Response : InitializeResponse
Expand Down Expand Up @@ -183,16 +177,10 @@ def initialize(
iai.logger.warning(
f"Unable to spawn a scenario for {agent_count} agents, {agents_spawned} spawned instead."
)
# response = InitializeResponse(
# agent_states=initial_states["agent_states"],
# agent_attributes=initial_states["agent_attributes"],
# recurrent_states=initial_states["recurrent_states"],
# )

response = InitializeResponse(
agent_states=[
AgentState(*state[0]) for state in initial_states["agent_states"]
], # TODO: Remove [0] after time dimension is removed
AgentState(*state) for state in initial_states["agent_states"]
],
agent_attributes=[
AgentAttributes(*attr)
for attr in initial_states["agent_attributes"]
Expand All @@ -215,7 +203,6 @@ def drive(
steps: int = 1,
get_infractions: bool = False,
exclude_ego_agent: bool = True,
traffic_light_state: Optional[Dict[TrafficLightId, TrafficLightState]] = {},
present_mask: Optional[List] = None,
) -> DriveResponse:
"""
Expand All @@ -224,40 +211,22 @@ def drive(
location : str
Name of the location.
states : List[List[Tuple[(float,) * 4]]] (AxTx4)
List of positions and speeds of agents.
List of A (number of actors) lists,
each element is of T (number of time steps) list,
each element is a list of 4 floats (x,y,speed, orientation)
states : List[AgentState]
List of agent states.
agent_attributes : List[Tuple[(float,) * 3]] (Ax3)
agent_attributes : List[AgentAttributes]
List of agent attributes
List of A (number of actors) lists,
each element is a list of x floats (width, length, lr)
recurrent_states : List[Tuple[(Tuple[(float,) * 64],) * 2]] (Ax2x64)
Internal state of simulation, which must be fedback to continue simulation
This should have been obtained either from iai.drive or iai.initialize.
recurrent_states : List[RecurrentStates]
Internal simulation state
get_birdviews: bool = False
If True, a rendered bird's-eye view of the map with agents is returned
steps: int
Number of time-steps to run the simulation
get_infractions: bool = False
If True, 'collision', 'offroad', 'wrong_way' infractions of each agent
is returned.
traffic_states_id: str = "000:0"
An id to set the state of the traffic-lights.
If Traffic-lights are controlled by Inverted-AI
This parameter should be set using "traffic_states_id" returned by
"invertedia.initialize" or "invertedai.drive"
exclude_ego_agent: bool = True,
This parameter will be deprecated soon.
present_mask: Optional[List] = None
A list of booleans of size A (number of agents), which is false when
an agent has crossed the boundary of the map.
Expand Down Expand Up @@ -298,13 +267,10 @@ def _tolist(input_data: List):
recurrent_states = (
_tolist(recurrent_states) if recurrent_states is not None else None
) # AxTx2x64
agent_states = [[state.tolist()] for state in agent_states]
# TODO: Rmove [] around state.tolist() after time is removed
agent_attributes = [state.tolist() for state in agent_attributes]
model_inputs = dict(
location=location,
agent_states=agent_states,
agent_attributes=agent_attributes,
agent_states=[state.tolist() for state in agent_states],
agent_attributes=[state.tolist() for state in agent_attributes],
recurrent_states=recurrent_states,
# Expand from A to AxT_total for the API interface
steps=steps,
Expand All @@ -321,41 +287,18 @@ def _tolist(input_data: List):
try:
response = iai.session.request(model="drive", data=model_inputs)

agent_states = [AgentState(*state[0]) for state in response["agent_states"]]
recurrent_states = response["recurrent_states"]
bird_view = response["bird_view"]

infractions = InfractionIndicators(
collisions=response["collision"],
offroad=response["offroad"],
wrong_way=response["wrong_way"],
)
present_mask = response["present_mask"]
out = DriveResponse(
agent_states=agent_states,
recurrent_states=recurrent_states,
bird_view=bird_view,
infractions=infractions,
present_mask=present_mask,
agent_states=[AgentState(*state) for state in response["agent_states"]],
recurrent_states=response["recurrent_states"],
bird_view=response["bird_view"],
infractions=InfractionIndicators(
collisions=response["collision"],
offroad=response["offroad"],
wrong_way=response["wrong_way"],
),
present_mask=response["present_mask"],
)

# out = DriveResponse(
# agent_states=[
# AgentState(*state[0]) for state in response["agent_states"]
# ], # TODO: Remove [0] after time dimension is removed
# recurrent_states=response["recurrent_states"],
# bird_view=response["bird_view"],
# infractions=InfractionIndicators(
# {
# "collisions": response["collision"],
# "offroad": response["offroad"],
# "wrong_way": response["wrong_way"],
# }
# ),
# present_mask=response["present_mask"],
# )

# out = DriveResponse(**response)
return out
except Exception as e:
iai.logger.warning("Retrying")
Expand Down
66 changes: 0 additions & 66 deletions invertedai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,72 +62,6 @@ class DriveResponse:
infractions: Optional[InfractionIndicators]


# @dataclass
# class DrivePayload:
# """
# agent_states : List[List[Tuple[(float,) * 4]]] (AxTx4)
# List of positions and speeds of agents.
# List of A (number of actors) lists,
# each element is of T (number of time steps) list,
# each element is a list of 4 floats (x,y,speed, orientation)

# present_mask : List[int]
# A list of booleans of size A (number of agents), which is false when
# an agent has crossed the boundary of the map.

# recurrent_states : List[Tuple[(Tuple[(float,) * 64],) * 2]] (Ax2x64)
# Internal state of simulation, which must be fedback to continue simulation

# attributes : List[Tuple[(float,) * 3]] (Ax3)
# List of agent attributes
# List of A (number of actors) lists,
# each element is a list of x floats (width, length, lr)

# traffic_light_state: Dict[str, str]
# Dictionary of traffic light states.
# Keys are the traffic-light ids and
# values are light state: 'red', 'green', 'yellow' and 'red'

# traffic_state_id: str
# The id of the current stat of the traffic light,
# which must be fedback to get the next state of the traffic light

# bird_view : List[int]
# Rendered image of the amp with agents encoded in JPEG format,
# (for decoding use JPEG decoder
# e.g., cv2.imdecode(response["rendered_map"], cv2.IMREAD_COLOR) ).

# collision : List[Tuple[(float,) * T_obs+T]] (AxT_obs+T)
# List of collision infraction for each of the agents.
# List of A (number of actors) lists,
# each element is a list of size T_obs+T (number of time steps)
# floats (intersection over union).

# offroad : List[Tuple[(float,) * T_obs+T]] (AxT_obs+T)
# List of offroad infraction for each of the agents.
# List of A (number of actors) lists,
# each element is a list of size T_obs+T (number of time steps) floats.

# wrong_way : List[Tuple[(float,) * T_obs+T]] (AxT_obs+T)
# List of wrong_way infraction for each of the agents.
# List of A (number of actors) lists,
# each element is a list of size T_obs+T (number of time steps) floats.
# """

# location: str
# agent_states: AgentStates
# agent_attributes: AgentSizes
# steps: int
# recurrent_states: Optional[RecurrentStates]
# get_birdviews: bool = False
# get_infractions: bool = False
# include_traffic_controls: bool = False
# traffic_lights_states: Optional[List[TrafficLightStates]] = None
# exclude_ego_agent: bool = False
# traffic_states_id: Optional[str] = None
# present_mask: Optional[List[bool]] = None # xA


@dataclass
class InitializeResponse:
"""
Expand Down

0 comments on commit 35b49bd

Please sign in to comment.