-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #135 from inverted-ai/develop
Update BLAME Documentation
- Loading branch information
Showing
69 changed files
with
481,476 additions
and
1,534 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
# Dev Stuff | ||
dev/ | ||
dev.py | ||
|
||
img/ | ||
gifs/ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# BLAME | ||
|
||
|
||
```{eval-rst} | ||
.. autofunction:: invertedai.api.blame | ||
``` | ||
|
||
--- | ||
```{eval-rst} | ||
.. autoclass:: invertedai.api.BlameResponse | ||
:members: | ||
:undoc-members: | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,356 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "uYrB4Yk3aizw" | ||
}, | ||
"source": [ | ||
"<img src=\"https://raw.githubusercontent.com/inverted-ai/invertedai/master/docs/images/banner-small.png\" alt=\"InvertedAI\" width=\"200\"/>\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "b7l_n8sULmAX" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import IPython\n", | ||
"from IPython.display import display, Image, clear_output\n", | ||
"from ipywidgets import interact\n", | ||
"from IPython.utils import io\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import imageio\n", | ||
"import numpy as np\n", | ||
"import cv2\n", | ||
"import invertedai as iai\n", | ||
"\n", | ||
"from shapely.geometry import Polygon\n", | ||
"from shapely.errors import GEOSException\n", | ||
"\n", | ||
"from dataclasses import dataclass\n", | ||
"from typing import Tuple" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "SFu9oOcDIQGs" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# API key:\n", | ||
"iai.add_apikey(\"\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "PiySgisPG6mG" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# pick a location (4 way, signalized intgersection)\n", | ||
"location = \"iai:drake_street_and_pacific_blvd\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "kT3m0-zoNdvW" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"location_info_response = iai.location_info(location=location)\n", | ||
"rendered_static_map = location_info_response.birdview_image.decode()\n", | ||
"scene_plotter = iai.utils.ScenePlotter(rendered_static_map,\n", | ||
" location_info_response.map_fov,\n", | ||
" (location_info_response.map_center.x, location_info_response.map_center.y),\n", | ||
" location_info_response.static_actors)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "L0ufxMO8NdvX", | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"@dataclass\n", | ||
"class LogCollision:\n", | ||
" collision_agents: Tuple[int, int]\n", | ||
" start_time: int \n", | ||
" end_time: int\n", | ||
"\n", | ||
"def transform_all_agent_vertices_into_world_frame(agent_states,agent_attributes):\n", | ||
" \"\"\"\n", | ||
" Transform the vertices of all agents into points within the world frame of the map environment\n", | ||
" Args:\n", | ||
" agent_states: List[AgentState] #List of all current agent states including x and y coordinates and angle.\n", | ||
" agent_attributes: List[AgentAttributes] #List of static attributes of all agents including agent length and width.\n", | ||
" Returns:\n", | ||
" List[Polygon] #List of Polygon data types containing a list of vertices for each agent\n", | ||
" \"\"\"\n", | ||
" polygons = [None]*len(agent_states)\n", | ||
" for i, (state, attributes) in enumerate(zip(agent_states,agent_attributes)):\n", | ||
" dx = attributes.length/2\n", | ||
" dy = attributes.width/2\n", | ||
"\n", | ||
" vehicle_origin = np.array([state.center.x, state.center.y])\n", | ||
"\n", | ||
" vehicle_orientation = state.orientation\n", | ||
" c, s = np.cos(vehicle_orientation), np.sin(vehicle_orientation)\n", | ||
"\n", | ||
" rotation_matrix = np.array([[c, s],\n", | ||
" [-s, c]])\n", | ||
" stacked_vertices = np.array([[dx,dy],[dx,-dy],[-dx,-dy],[-dx,dy]]) #Formatted in a continuous sequence\n", | ||
" rotated_vertices = np.matmul(stacked_vertices,rotation_matrix)\n", | ||
" \n", | ||
" polygons[i] = vehicle_origin + rotated_vertices\n", | ||
"\n", | ||
" return [Polygon(p) for p in polygons]\n", | ||
"\n", | ||
"def check_agent_pairwise_intersections(polygons):\n", | ||
" \"\"\"\n", | ||
" Check all pairs of polygons in a list for intersections in their area.\n", | ||
" Args:\n", | ||
" polygons: List[Polygon] #List of polygons representing agents in an environment\n", | ||
" Returns:\n", | ||
" List[Tuple[int,int]] #List of agent ID pair tuples indicating collisions\n", | ||
" \"\"\"\n", | ||
" \n", | ||
" detected_overlap_agent_pairs = []\n", | ||
" num_agents = len(polygons)\n", | ||
" for j in range(num_agents):\n", | ||
" for k in range(j+1,num_agents):\n", | ||
" try:\n", | ||
" if polygons[j].intersection(polygons[k]).area:\n", | ||
" detected_overlap_agent_pairs.append((j,k))\n", | ||
" except GEOSException as e:\n", | ||
" print(f\"Collision candidates {j} and {k} failed with error {e}.\")\n", | ||
" pass\n", | ||
" \n", | ||
" return detected_overlap_agent_pairs\n", | ||
" \n", | ||
"def compute_pairwise_collisions(agent_states_history,agent_attributes):\n", | ||
" \"\"\"\n", | ||
" Use polygon intersections to check each agent combination whether there is a collision.\n", | ||
" Args:\n", | ||
" agent_states: List[List[AgentState]] #At all time steps, list of all current agent states including x and y coordinates and angle.\n", | ||
" agent_attributes: List[AgentAttributes] #List of static attributes of all agents including agent length and width.\n", | ||
" Returns:\n", | ||
" List[LogCollision] #List of collisions logs containing information about the colliding agent pairs IDs and the time period of the collision\n", | ||
" \"\"\"\n", | ||
" \n", | ||
" collisions_ongoing = {}\n", | ||
" collisions_all = []\n", | ||
" \n", | ||
" for t, agent_states in enumerate(agent_states_history):\n", | ||
" if len(agent_states) != len(agent_attributes):\n", | ||
" raise Exception(\"Incorrect number of agents or agent attributes.\")\n", | ||
"\n", | ||
" polygons = transform_all_agent_vertices_into_world_frame(agent_states,agent_attributes)\n", | ||
" detected_agent_pairs = check_agent_pairwise_intersections(polygons) \n", | ||
" \n", | ||
" for agent_tuple in detected_agent_pairs:\n", | ||
" if agent_tuple not in collisions_ongoing:\n", | ||
" collisions_ongoing[agent_tuple] = LogCollision(\n", | ||
" collision_agents=agent_tuple,\n", | ||
" start_time=t,\n", | ||
" end_time=None\n", | ||
" )\n", | ||
" untracked_agent_pairs = []\n", | ||
" for agent_tuple, collision in collisions_ongoing.items():\n", | ||
" if agent_tuple not in detected_agent_pairs:\n", | ||
" #The previous time step is the last in which the collision was observed\n", | ||
" collisions_ongoing[agent_tuple].end_time = t-1\n", | ||
" elif t >= SIMULATION_LENGTH-1:\n", | ||
" #The collision has not necessarily ended at this time step but it is the last time step it was observed to occur\n", | ||
" collisions_ongoing[agent_tuple].end_time = t\n", | ||
" \n", | ||
" if collisions_ongoing[agent_tuple].end_time is not None:\n", | ||
" collisions_all.append(collisions_ongoing[agent_tuple])\n", | ||
" untracked_agent_pairs.append(agent_tuple)\n", | ||
" \n", | ||
" collisions_ongoing = {k:v for k, v in collisions_ongoing.items() if not k in untracked_agent_pairs}\n", | ||
" \n", | ||
" return collisions_all\n", | ||
"\n", | ||
"# Simulate with `initialize`, `drive` and `light` until there are collisions.\n", | ||
"for _ in range(20): #Attempt 20 simulations looking for a collision\n", | ||
" light_response = iai.light(location=location)\n", | ||
"\n", | ||
" response = iai.initialize(\n", | ||
" location=location,\n", | ||
" agent_count=15,\n", | ||
" get_birdview=True,\n", | ||
" traffic_light_state_history=[light_response.traffic_lights_states]\n", | ||
" )\n", | ||
" agent_attributes = response.agent_attributes\n", | ||
" scene_plotter.initialize_recording(\n", | ||
" response.agent_states,\n", | ||
" agent_attributes=agent_attributes,\n", | ||
" traffic_light_states=light_response.traffic_lights_states\n", | ||
" )\n", | ||
"\n", | ||
" agent_state_history = []\n", | ||
" traffic_light_state_history = []\n", | ||
"\n", | ||
" # 10-second scene\n", | ||
" SIMULATION_LENGTH = 100\n", | ||
" for t in range(SIMULATION_LENGTH):\n", | ||
" light_response = iai.light(\n", | ||
" location=location, \n", | ||
" recurrent_states=light_response.recurrent_states\n", | ||
" )\n", | ||
" response = iai.drive(\n", | ||
" location=location,\n", | ||
" agent_attributes=agent_attributes,\n", | ||
" agent_states=response.agent_states,\n", | ||
" recurrent_states=response.recurrent_states,\n", | ||
" get_birdview=False,\n", | ||
" traffic_lights_states=light_response.traffic_lights_states,\n", | ||
" get_infractions=True,\n", | ||
" random_seed=1\n", | ||
" )\n", | ||
" scene_plotter.record_step(\n", | ||
" response.agent_states, \n", | ||
" traffic_light_states=light_response.traffic_lights_states\n", | ||
" )\n", | ||
" agent_state_history.append(response.agent_states)\n", | ||
" traffic_light_state_history.append(light_response.traffic_lights_states)\n", | ||
" \n", | ||
" print(f\"Attempted collision simulation number {_} iteration number {t}.\")\n", | ||
" clear_output(wait=True)\n", | ||
" \n", | ||
" collisions = compute_pairwise_collisions(agent_state_history,agent_attributes)\n", | ||
" if collisions: \n", | ||
" #If a collision is detected, cease generating more simulations\n", | ||
" break\n", | ||
"\n", | ||
"print(collisions)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "YQ4dXaKQNdvZ" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"blame_responses = []\n", | ||
"all_collision_agents = []\n", | ||
"for collision_data in collisions:\n", | ||
" all_collision_agents.extend(list(collision_data.collision_agents))\n", | ||
" blame_response = iai.blame(\n", | ||
" location=location,\n", | ||
" colliding_agents=collision_data.collision_agents,\n", | ||
" agent_state_history=agent_state_history[:collision_data.start_time],\n", | ||
" traffic_light_state_history=traffic_light_state_history[:collision_data.start_time],\n", | ||
" agent_attributes=agent_attributes,\n", | ||
" get_reasons=True,\n", | ||
" get_confidence_score=True,\n", | ||
" get_birdviews=False\n", | ||
" )\n", | ||
" print(blame_response.agents_at_fault)\n", | ||
" blame_responses.append(blame_response)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "jRRaiLInNdvZ" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"for response in blame_responses:\n", | ||
" print(response.reasons)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "oZ79k112Ndva" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"for response in blame_responses:\n", | ||
" print(response.confidence_score)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "QzbqXvS7Ndva" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%%capture\n", | ||
"fig, ax = plt.subplots(constrained_layout=True, figsize=(50, 50))\n", | ||
"gif_name = 'blame-example.gif'\n", | ||
"scene_plotter.animate_scene(\n", | ||
" output_name=gif_name,\n", | ||
" ax=ax,\n", | ||
" numbers=all_collision_agents,\n", | ||
" direction_vec=False,\n", | ||
" velocity_vec=False,\n", | ||
" plot_frame_number=True\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "aUYCAHfsNdva" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"Image(gif_name, width=1000, height=800)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |
Oops, something went wrong.