Skip to content

Commit

Permalink
Merge pull request #20 from inverted-ai/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
Ruishenl committed Sep 30, 2022
2 parents c85581f + 7eaf513 commit 4652fa0
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 78 deletions.
1 change: 0 additions & 1 deletion examples/Demo_Drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
recurrent_states=response["recurrent_states"],
get_birdviews=True,
location=args.location,
obs_length=1,
steps=1,
)

Expand Down
62 changes: 15 additions & 47 deletions examples/Drive-Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "911445d4-8b61-49a1-9620-d1f3eb0c4199",
"metadata": {},
"outputs": [],
Expand All @@ -28,61 +28,21 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "069cf9cb-30eb-4378-b4f9-9cb7d9038620",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8aa8557b12314e4aac4f7be83c3b093d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Jupyter_Render(children=(HBox(children=(Play(value=0, description='Press play', max=0), IntSlider(value=0, des…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "72f0b0e272df42088db3c8eec85853fe",
"version_major": 2,
"version_minor": 0
},
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfQAAAH0CAYAAADL1t+KAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIBElEQVR4nO3aMQrEMAwAwfOR/39Z6d2FFCbLTKlK3SLQmpn5AQCf9j+9AADwnqADQICgA0DAtQ/WWif2AAAe2F/gXOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABAg6AAQIOgAECDoABBw7YOZObEHAPCCCx0AAgQdAAIEHQACbpFBDeWvnr/eAAAAAElFTkSuQmCC",
"text/html": [
"\n",
" <div style=\"display: inline-block;\">\n",
" <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
" Figure\n",
" </div>\n",
" <img src='' width=500.0/>\n",
" </div>\n",
" "
],
"text/plain": [
"Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"iai.add_apikey(\"\")\n",
"location=\"CARLA:Town03:Roundabout\"\n",
"simulation_length = 30\n",
"simulation_length = 100\n",
"renderer = Jupyter_Render()\n",
"display(renderer)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"id": "94506da0-9a15-4dc0-9d57-5ac30bc36d72",
"metadata": {},
"outputs": [],
Expand All @@ -103,8 +63,8 @@
" recurrent_states=response[\"recurrent_states\"],\n",
" get_birdviews=True,\n",
" location=location,\n",
" obs_length=1,\n",
" steps=1,\n",
" exclude_ego_agent=1\n",
" )\n",
" birdview = cv2.imdecode(np.array(response[\"bird_view\"], dtype=np.uint8), cv2.IMREAD_COLOR)\n",
" renderer.add_frame(birdview)"
Expand All @@ -117,6 +77,14 @@
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "43110bde",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -135,7 +103,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.0"
}
},
"nbformat": 4,
Expand Down
27 changes: 3 additions & 24 deletions invertedai/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@dataclass
class Config:
api_key: str = ""
location: str = "Town03_Roundabout"
location: str = "CARLA:Town03:Roundabout"
agent_count: int = 100
batch_size: int = 1
obs_length: int = 1
Expand All @@ -32,7 +32,6 @@ def initialize(
batch_size=1,
min_speed=1,
max_speed=5,
fix_carla_coord=False,
) -> dict:
start = time.time()
timeout = TIMEOUT
Expand All @@ -45,7 +44,6 @@ def initialize(
batch_size=batch_size,
min_speed=np.ceil(min_speed / 3.6).astype(int),
max_speed=np.ceil(max_speed / 3.6).astype(int),
fix_carla_coord=fix_carla_coord,
)
response = {
"states": initial_states["initial_condition"]["agent_states"],
Expand All @@ -63,21 +61,16 @@ def drive(
states: dict,
agent_attributes: dict,
recurrent_states: Optional[InputDataType] = None,
present_masks: Optional[InputDataType] = None,
get_birdviews: bool = False,
location="CARLA:Town03:Roundabout",
obs_length: int = 1,
steps: int = 1,
batch_size: int = 1,
fix_carla_coord: bool = False,
get_infractions: bool = False,
exclude_ego_agent: bool = True
) -> dict:
def _validate(input_dict: dict, input_name: str):
input_data = input_dict[input_name]
if isinstance(input_data, list):
input_data = torch.Tensor(input_data)
if input_data.shape[0] != batch_size:
raise Exception(f"{input_name} has the wrong batch size (dim 0)")
if input_data.shape[1] != agent_count:
raise Exception(f"{input_name} has the wrong agent counts (dim 1)")
if len(input_data.shape) > 2:
Expand All @@ -88,8 +81,6 @@ def _validate(input_dict: dict, input_name: str):
def _validate_recurrent_states(input_data: InputDataType):
if isinstance(input_data, list):
input_data = torch.Tensor(input_data)
if input_data.shape[0] != batch_size:
raise Exception("Recurrent states has the wrong batch size (dim 0)")
if input_data.shape[1] != agent_count:
raise Exception("Recurrent states has the wrong agent counts (dim 2)")
if input_data.shape[2] != 2:
Expand All @@ -108,11 +99,6 @@ def _validate_and_tolist(input_data: dict, input_name: str):
return _tolist(_validate(input_data, input_name))

agent_count = len(states[0])
present_masks = (
_validate_and_tolist(present_masks, "present_masks")
if present_masks is not None
else None
) # BxA
recurrent_states = (
_tolist(_validate_recurrent_states(recurrent_states))
if recurrent_states is not None
Expand All @@ -127,17 +113,10 @@ def _validate_and_tolist(input_data: dict, input_name: str):
),
recurrent_states=recurrent_states,
# Expand from BxA to BxAxT_total for the API interface
present_masks=[
[[a for _ in range(obs_length + steps)] for a in b]
for b in present_masks
]
if present_masks
else None,
batch_size=batch_size,
steps=steps,
get_birdviews=get_birdviews,
fix_carla_coord=fix_carla_coord,
get_infractions=get_infractions,
exclude_ego_agent=exclude_ego_agent
)

start = time.time()
Expand Down
8 changes: 3 additions & 5 deletions invertedai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,13 @@ def initialize(
batch_size=1,
min_speed=1,
max_speed=3,
fix_carla_coord=False,
):
params = {
"location": location,
"num_agents_to_spawn": agent_count,
"num_samples": batch_size,
"spawn_min_speed": min_speed,
"spawn_max_speed": max_speed,
"fix_carla_coord": fix_carla_coord,
# "spawn_min_speed": min_speed,
# "spawn_max_speed": max_speed,
}

response = self._request(
Expand Down Expand Up @@ -106,7 +104,7 @@ def _get_base_url(self) -> str:
The method path should be appended to the base_url
"""
if not iai.dev:
base_url = "https://api.inverted.ai/drive"
base_url = "https://api.inverted.ai/v0/aws/m1"
else:
base_url = iai.dev_url
# TODO: Add endpoint option and versioning to base_url
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "invertedai"
version = "0.0.2"
version = "0.0.2post1"
description = "Client SDK for InvertedAI"
authors = ["Inverted AI <info@inverted.ai>"]
readme = "README.md"
Expand Down

0 comments on commit 4652fa0

Please sign in to comment.