Skip to content

Commit

Permalink
Merge branch 'documentation' into mock-api
Browse files Browse the repository at this point in the history
  • Loading branch information
AlirezaMorsali committed Oct 19, 2022
2 parents 0156d93 + 5b098e9 commit cb16c70
Show file tree
Hide file tree
Showing 9 changed files with 4,978 additions and 253 deletions.
32 changes: 11 additions & 21 deletions examples/Demo_Drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,9 @@
parser.add_argument("--location", type=str, default="iai:ubc_roundabout")
args = parser.parse_args()

api_key = args.api_key or os.environ.get('IAI_API_KEY', None)
if api_key:
iai.add_apikey(api_key)
else:
print("Running mock API - specify IAI_API_KEY to obtain real results")
iai.use_mock_api()
iai.add_apikey(args.api_key)

# response = iai.available_locations("carla", "roundabout")
response = iai.location_info(location=args.location)

file_name = args.location.replace(":", "_")
Expand All @@ -40,30 +36,24 @@
rendered_map = np.array(response.birdview_image, dtype=np.uint8)
image = cv2.imdecode(rendered_map, cv2.IMREAD_COLOR)
cv2.imwrite(file_path, image)
response = iai.initialize(
simulation = iai.Simulation(
location=args.location,
agent_count=10,
agent_attributes=None,
monitor_infractions=True,
render_birdview=True,
)
agent_attributes = response.agent_attributes
frames = []
pbar = tqdm(range(50))
for i in pbar:
response = iai.drive(
agent_attributes=agent_attributes,
agent_states=response.agent_states,
recurrent_states=response.recurrent_states,
get_birdviews=True,
location=args.location,
get_infractions=True,
)
simulation.step(current_ego_agent_states=[])
collision, offroad, wrong_way = simulation.infractions
pbar.set_description(
f"Collision rate: {100*np.array([inf.collisions for inf in response.infractions]).mean():.2f}% | "
+ f"Off-road rate: {100*np.array([inf.offroad for inf in response.infractions]).mean():.2f}% | "
+ f"Wrong-way rate: {100*np.array([inf.wrong_way for inf in response.infractions]).mean():.2f}%"
f"Collision rate: {100*np.array(collision).mean():.2f}% | "
+ f"Off-road rate: {100*np.array(offroad).mean():.2f}% | "
+ f"Wrong-way rate: {100*np.array(wrong_way).mean():.2f}%"
)

birdview = np.array(response.bird_view, dtype=np.uint8)
birdview = np.array(simulation.birdview, dtype=np.uint8)
image = cv2.imdecode(birdview, cv2.IMREAD_COLOR)
frames.append(image)
im = PImage.fromarray(image)
Expand Down
64 changes: 64 additions & 0 deletions examples/Demo_Drive_REST.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env ipython
import os
import sys
from PIL import Image as PImage
import imageio
import numpy as np
import cv2
from tqdm import tqdm
import argparse
from dotenv import load_dotenv

load_dotenv()
if os.environ.get("DEV", False):
sys.path.append("../")
import invertedai as iai

# logger.setLevel(10)

parser = argparse.ArgumentParser(description="Simulation Parameters.")
parser.add_argument("--api_key", type=str, default="")
parser.add_argument("--location", type=str, default="iai:ubc_roundabout")
args = parser.parse_args()

iai.add_apikey("")

response = iai.location_info(location=args.location)

file_name = args.location.replace(":", "_")
if response.osm_map is not None:
file_path = f"{file_name}.osm"
with open(file_path, "w") as f:
f.write(response.osm_map[0])
if response.birdview_image is not None:
file_path = f"{file_name}.jpg"
rendered_map = np.array(response.birdview_image, dtype=np.uint8)
image = cv2.imdecode(rendered_map, cv2.IMREAD_COLOR)
cv2.imwrite(file_path, image)
response = iai.initialize(
location=args.location,
agent_count=10,
)
agent_attributes = response.agent_attributes
frames = []
pbar = tqdm(range(50))
for i in pbar:
response = iai.drive(
agent_attributes=agent_attributes,
agent_states=response.agent_states,
recurrent_states=response.recurrent_states,
get_birdviews=True,
location=args.location,
get_infractions=True,
)
pbar.set_description(
f"Collision rate: {100*np.array([inf.collisions for inf in response.infractions]).mean():.2f}% | "
+ f"Off-road rate: {100*np.array([inf.offroad for inf in response.infractions]).mean():.2f}% | "
+ f"Wrong-way rate: {100*np.array([inf.wrong_way for inf in response.infractions]).mean():.2f}%"
)

birdview = np.array(response.bird_view, dtype=np.uint8)
image = cv2.imdecode(birdview, cv2.IMREAD_COLOR)
frames.append(image)
im = PImage.fromarray(image)
imageio.mimsave("iai-drive.gif", np.array(frames), format="GIF-PIL")
15 changes: 4 additions & 11 deletions examples/Drive-Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,15 @@
"rendered_map = cv2.imdecode(np.array(rendered_map, dtype=np.uint8), cv2.IMREAD_COLOR)\n",
"renderer.add_frame(rendered_map)\n",
"\n",
"response = iai.initialize(\n",
"simulation = iai.Simulation(\n",
" location=location,\n",
" agent_count=10,\n",
" agent_attributes=None,\n",
" render_birdview=True,\n",
")\n",
"agent_attributes = response.agent_attributes\n",
"frames = []\n",
"for t in range(simulation_length):\n",
" response = iai.drive(\n",
" agent_attributes=agent_attributes,\n",
" agent_states=response.agent_states,\n",
" recurrent_states=response.recurrent_states,\n",
" get_birdviews=True,\n",
" location=location,\n",
" )\n",
" birdview = cv2.imdecode(np.array(response.bird_view, dtype=np.uint8), cv2.IMREAD_COLOR)\n",
" simulation.step(current_ego_agent_states=[])\n",
" birdview = cv2.imdecode(np.array(simulation.birdview, dtype=np.uint8), cv2.IMREAD_COLOR)\n",
" renderer.add_frame(birdview)"
]
},
Expand Down
Binary file modified examples/iai-drive.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 3 additions & 2 deletions invertedai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
initialize,
location_info,
)
from invertedai.simulation import Simulation
from invertedai.utils import Jupyter_Render, IAILogger, Session

load_dotenv()
Expand All @@ -28,12 +29,12 @@
use_mock_api()

model_resources = {
"initialize": ("get", "/initialize"),
"initialize": ("post", "/initialize"),
"drive": ("post", "/drive"),
"location_info": ("get", "/location_info"),
"available_locations": ("get", "/available_locations"),
}
__all__ = [
"Simulation",
"drive",
"initialize",
"location_info",
Expand Down
73 changes: 41 additions & 32 deletions invertedai/api_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@
from typing import List, Optional, Dict
import time
import invertedai as iai
from invertedai.mock import get_mock_birdview, get_mock_agent_attributes, get_mock_agent_state, \
get_mock_recurrent_state, mock_update_agent_state, get_mock_infractions
from invertedai.mock import (
get_mock_birdview,
get_mock_agent_attributes,
get_mock_agent_state,
get_mock_recurrent_state,
mock_update_agent_state,
get_mock_infractions,
)
from invertedai.models import (
LocationResponse,
InitializeResponse,
Expand Down Expand Up @@ -80,7 +86,14 @@ def location_info(
"""

if mock_api:
response = LocationResponse(rendered_map=get_mock_birdview(), lanelet_map_source=None, static_actors=None)
response = LocationResponse(
version="v0.0.0",
birdview_image=get_mock_birdview(),
osm_map=None,
static_actors=[],
bounding_polygon=[],
max_agent_number=10,
)
return response

start = time.time()
Expand All @@ -94,6 +107,9 @@ def location_info(
response["static_actors"] = [
StaticMapActor(**actor) for actor in response["static_actors"]
]
if response["osm_map"] is not None:
response["osm_map"] = (response["osm_map"], response["map_origin"])
del response["map_origin"]
return LocationResponse(**response)
except TryAgain as e:
if timeout is not None and time.time() > start + timeout:
Expand Down Expand Up @@ -164,36 +180,29 @@ def initialize(
agent_states = states_history[-1]
recurrent_states = [get_mock_recurrent_state() for _ in range(agent_count)]
response = InitializeResponse(
agent_states=agent_states, agent_attributes=agent_attributes, recurrent_states=recurrent_states
agent_states=agent_states,
agent_attributes=agent_attributes,
recurrent_states=recurrent_states,
)
return response

model_inputs = dict(
location=location,
num_agents_to_spawn=agent_count,
states_history=states_history
if states_history is None
else [state.tolist() for state in states_history],
agent_attributes=agent_attributes
if agent_attributes is None
else [state.tolist() for state in agent_attributes],
traffic_light_state_history=traffic_light_state_history,
random_seed=random_seed,
)
start = time.time()
timeout = TIMEOUT

while True:
try:
include_recurrent_states = (
False if location.split(":")[0] == "huawei" else True
)
params = {
"location": location,
"num_agents_to_spawn": agent_count,
"include_recurrent_states": include_recurrent_states,
}
model_inputs = dict(
states_history=states_history
if states_history is None
else [state.tolist() for state in states_history],
agent_attributes=agent_attributes
if agent_attributes is None
else [state.tolist() for state in agent_attributes],
traffic_light_state_history=traffic_light_state_history,
random_seed=random_seed,
)
initial_states = iai.session.request(
model="initialize", params=params, data=model_inputs
)
initial_states = iai.session.request(model="initialize", data=model_inputs)
agents_spawned = len(initial_states["agent_states"])
if agents_spawned != agent_count:
iai.logger.warning(
Expand Down Expand Up @@ -221,9 +230,7 @@ def drive(
agent_states: List[AgentState] = [],
agent_attributes: List[AgentAttributes] = [],
recurrent_states: List[RecurrentState] = [],
traffic_lights_states: Optional[
Dict[TrafficLightId, TrafficLightState]
] = None,
traffic_lights_states: Optional[Dict[TrafficLightId, TrafficLightState]] = None,
get_birdviews: bool = False,
get_infractions: bool = False,
random_seed: Optional[int] = None,
Expand Down Expand Up @@ -294,8 +301,11 @@ def drive(
bird_view = get_mock_birdview()
infractions = get_mock_infractions(len(agent_states))
response = DriveResponse(
agent_states=agent_states, present_mask=present_mask, recurrent_states=recurrent_states,
bird_view=bird_view, infractions=infractions
agent_states=agent_states,
is_inside_supported_area=present_mask,
recurrent_states=recurrent_states,
bird_view=bird_view,
infractions=infractions,
)
return response

Expand All @@ -318,7 +328,6 @@ def _tolist(input_data: List):
get_infractions=get_infractions,
random_seed=random_seed,
)

start = time.time()
timeout = TIMEOUT

Expand Down

0 comments on commit cb16c70

Please sign in to comment.