Skip to content

Commit

Permalink
bug fix and field name update
Browse files Browse the repository at this point in the history
  • Loading branch information
rf-ivtdai committed Jan 17, 2024
1 parent 127f95f commit 6d4c3be
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 18 deletions.
9 changes: 5 additions & 4 deletions invertedai/api/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def _tolist(input_data: List):
agent_attributes=[state.tolist() for state in agent_attributes],
recurrent_states=[r.packed for r in recurrent_states],
traffic_lights_states=traffic_lights_states,
light_recurrent_states=light_recurrent_states,
light_recurrent_states=[light_recurrent_state.tolist() for light_recurrent_state in light_recurrent_states]
if light_recurrent_states is not None else None,
get_birdview=get_birdview,
get_infractions=get_infractions,
random_seed=random_seed,
Expand Down Expand Up @@ -181,11 +182,11 @@ def _tolist(input_data: List):
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
Expand Down Expand Up @@ -265,7 +266,7 @@ def _tolist(input_data: List):
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
Expand Down
6 changes: 3 additions & 3 deletions invertedai/api/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def initialize(
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
Expand Down Expand Up @@ -265,11 +265,11 @@ async def async_initialize(
if response["infraction_indicators"]
else [],
model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
light_recurrent_states=[
LightRecurrentState(state=state_arr[0], ticks_remaining=state_arr[1])
LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
for state_arr in response["light_recurrent_states"]
]
if response["light_recurrent_states"] is not None
Expand Down
10 changes: 8 additions & 2 deletions invertedai/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,14 @@ class TrafficLightState(str, Enum):


class LightRecurrentState(BaseModel):
state: int
ticks_remaining: int
state: float
time_remaining: float

def tolist(self):
"""
Convert LightRecurrentState to a list in this order: [state, time_remaining]
"""
return [self.state, self.time_remaining]


class AgentType(str, Enum):
Expand Down
15 changes: 6 additions & 9 deletions tests/test_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,13 @@ def recurrent_states_helper(states_to_extend):
def run_initialize_drive_flow(location, states_history, agent_attributes, get_infractions, agent_count,
simulation_length: int = 20):
location_info_response = location_info(location=location, rendering_fov=200)
if any(actor.agent_type == "traffic-light" for actor in location_info_response.static_actors):
scene_has_lights = True
light_response = light(location=location)
else:
light_response = None
scene_has_lights = False
scene_has_lights = any(actor.agent_type == "traffic-light" for actor in location_info_response.static_actors)

initialize_response = initialize(
location,
agent_attributes=agent_attributes,
states_history=states_history,
traffic_light_state_history=[light_response.traffic_lights_states] if scene_has_lights else None,
traffic_light_state_history=None,
get_birdview=False,
get_infractions=get_infractions,
agent_count=agent_count,
Expand All @@ -120,14 +116,15 @@ def run_initialize_drive_flow(location, states_history, agent_attributes, get_in
agent_attributes=agent_attributes,
agent_states=updated_state.agent_states,
recurrent_states=updated_state.recurrent_states,
traffic_lights_states=light_response.traffic_lights_states if light_response is not None else None,
light_recurrent_states=updated_state.light_recurrent_states if scene_has_lights else None,
get_birdview=False,
location=location,
get_infractions=get_infractions,
)
assert isinstance(updated_state,
DriveResponse) and updated_state.agent_states is not None and updated_state.recurrent_states is not None
if updated_state.traffic_lights_states is not None:
if scene_has_lights:
assert updated_state.traffic_lights_states is not None
assert updated_state.light_recurrent_states is not None

def run_direct_drive(location, agent_states, agent_attributes, recurrent_states, get_infractions):
Expand Down

0 comments on commit 6d4c3be

Please sign in to comment.