Skip to content

Commit

Permalink
initialize from get to post
Browse files Browse the repository at this point in the history
  • Loading branch information
AlirezaMorsali committed Oct 18, 2022
1 parent 6f886c3 commit c3e5961
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 29 deletions.
Binary file modified examples/iai-drive.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 1 addition & 2 deletions invertedai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
session = Session(api_key)
add_apikey = session.add_apikey
model_resources = {
"initialize": ("get", "/initialize"),
"initialize": ("post", "/initialize"),
"drive": ("post", "/drive"),
"location_info": ("get", "/location_info"),
"available_locations": ("get", "/available_locations"),
}
__all__ = [
"drive",
Expand Down
40 changes: 14 additions & 26 deletions invertedai/api_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,32 +144,23 @@ def initialize(
>>> response = iai.initialize(location="iai:ubc_roundabout", agent_count=10)
"""

model_inputs = dict(
location=location,
num_agents_to_spawn=agent_count,
states_history=states_history
if states_history is None
else [state.tolist() for state in states_history],
agent_attributes=agent_attributes
if agent_attributes is None
else [state.tolist() for state in agent_attributes],
traffic_light_state_history=traffic_light_state_history,
random_seed=random_seed,
)
start = time.time()
timeout = TIMEOUT

while True:
try:
include_recurrent_states = (
False if location.split(":")[0] == "huawei" else True
)
params = {
"location": location,
"num_agents_to_spawn": agent_count,
"include_recurrent_states": include_recurrent_states,
}
model_inputs = dict(
states_history=states_history
if states_history is None
else [state.tolist() for state in states_history],
agent_attributes=agent_attributes
if agent_attributes is None
else [state.tolist() for state in agent_attributes],
traffic_light_state_history=traffic_light_state_history,
random_seed=random_seed,
)
initial_states = iai.session.request(
model="initialize", params=params, data=model_inputs
)
initial_states = iai.session.request(model="initialize", data=model_inputs)
agents_spawned = len(initial_states["agent_states"])
if agents_spawned != agent_count:
iai.logger.warning(
Expand Down Expand Up @@ -197,9 +188,7 @@ def drive(
agent_states: List[AgentState] = [],
agent_attributes: List[AgentAttributes] = [],
recurrent_states: List[RecurrentState] = [],
traffic_lights_states: Optional[
Dict[TrafficLightId, TrafficLightState]
] = None,
traffic_lights_states: Optional[Dict[TrafficLightId, TrafficLightState]] = None,
get_birdviews: bool = False,
get_infractions: bool = False,
random_seed: Optional[int] = None,
Expand Down Expand Up @@ -283,7 +272,6 @@ def _tolist(input_data: List):
get_infractions=get_infractions,
random_seed=random_seed,
)

start = time.time()
timeout = TIMEOUT

Expand Down
4 changes: 3 additions & 1 deletion invertedai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def _request(
result.raise_for_status()
except requests.exceptions.RequestException as e:
if e.response.status_code == 403:
raise error.APIConnectionError("Connection forbidden. Please check the provided API key.")
raise error.APIConnectionError(
"Connection forbidden. Please check the provided API key."
)
elif e.response.status_code in [400, 422]:
raise e
else:
Expand Down

0 comments on commit c3e5961

Please sign in to comment.