<img src="https://raw.githubusercontent.com/inverted-ai/invertedai/master/docs/images/banner-small.png" alt="InvertedAI" width="200"/>


In [1]:
import IPython
from IPython.display import display, Image, clear_output
from ipywidgets import interact
from IPython.utils import io

import matplotlib.pyplot as plt
import imageio
import numpy as np
import cv2
import invertedai as iai

from shapely.geometry import Polygon
from shapely.errors import GEOSException

from dataclasses import dataclass
from typing import Tuple

In [2]:
# API key:
iai.add_apikey("")

In [3]:
# pick a location (4 way, signalized intgersection)
location = "iai:drake_street_and_pacific_blvd"

In [4]:
location_info_response = iai.location_info(location=location)
rendered_static_map = location_info_response.birdview_image.decode()
scene_plotter = iai.utils.ScenePlotter(rendered_static_map,
                                       location_info_response.map_fov,
                                       (location_info_response.map_center.x, location_info_response.map_center.y),
                                       location_info_response.static_actors)

In [6]:
@dataclass
class LogCollision:
    collision_agents: Tuple[int, int]
    start_time: int 
    end_time: int

def compute_pairwise_collisions(agent_states_history,agent_attributes):
    """
    Use polygon intersections to check each agent combination whether there is a collision.
    
    Args:
        agent_states: List of all current agent states including x and y coordinates and angle.
        agent_attributes: List of static attributes of all agents including agent length and width.

    Returns:
        List of colliding pairs of agents. => List[List[int,int]]
    """
    
    collisions_ongoing = {}
    collisions_all = []
    
    #Iterate through the entire agent state history to find collisions at each time step
    for t, agent_states in enumerate(agent_states_history):
        num_agents = len(agent_states)
        if not num_agents == len(agent_attributes):
            raise Exception("Incorrect number of agents or agent attributes.")

        #Transform the points of the agents polygon into the common world coordinate frame
        polygons = [None]*num_agents
        for i, (state, attributes) in enumerate(zip(agent_states,agent_attributes)):
            dx = attributes.length/2
            dy = attributes.width/2

            origin = np.array([[state.center.x], [state.center.y]])

            angle = state.orientation
            c, s = np.cos(angle), np.sin(angle)
            
            #Rotation matrix
            R = np.array([[c, -s],
                          [c, s]])
            #Stacked vectors of agent polygon points in the agents coordinate frame
            #Points are in order around the polygon
            V = np.array([[dx,dx,-dx,-dx],
                          [dy,-dy,-dy,dy]])
            #Apply the rotation into the world frame
            RV = np.matmul(R,V)
            
            #Add the displacement between the world frame and the agent frame
            TV = origin + RV
            polygons[i] = TV

        #Check every combination of of agent polygons
        detected_collisions = []
        polygons = [Polygon([p[:,0], p[:,1], p[:,2], p[:,3]]) for p in polygons]
        for j in range(num_agents):
            a = polygons[j]
            for k in range(j+1,num_agents):
                b = polygons[k]
                try:
                    if a.intersection(b).area:
                        detected_collisions.append([j,k])
                except GEOSException as e:
                    print(f"Collision candidates {j} and {k} failed with error {e}.")
                    pass
        
        #Check for collisions and acquire: agent IDs, collision start time, collision end time
        #Collect all collisions in the simulation
        for collision in detected_collisions:
            agent_tuple = tuple(collision)
            #If a colliding pair is detected and are not in the dictionary, add them here to track them
            if not agent_tuple in collisions_ongoing:
                collisions_ongoing[agent_tuple] = LogCollision(
                    collision_agents=agent_tuple,
                    start_time=t,
                    end_time=None
                )
        keys_to_delete = []
        for agent_tuple, collision in collisions_ongoing.items():
            if collision.end_time == None:
                if not list(collision.collision_agents) in detected_collisions or t >= SIMULATION_LENGTH-1:
                    #If a collision with no end time is no longer detected within the dictionary, assume the agents
                    #are no longer colliding and the end time is found.
                    collisions_ongoing[agent_tuple].end_time = t
                    collisions_all.append(collisions_ongoing[agent_tuple])
                    keys_to_delete.append(agent_tuple)
        for key in keys_to_delete:
            #Cannot delete items from dictionary while iterating over it
            #Capture collisions that are finished and delete them separately in this loop
            del collisions_ongoing[key]
    
    return collisions_all

# Simulate with `initialize`, `drive` and `light` until there are collisions.
for _ in range(20): #Attempt 20 simulations looking for a collision
    light_response = iai.light(location=location)

    response = iai.initialize(
        location=location,
        agent_count=15,
        get_birdview=True,
        traffic_light_state_history=[light_response.traffic_lights_states]
    )
    agent_attributes = response.agent_attributes
    scene_plotter.initialize_recording(response.agent_states,
                                       agent_attributes=agent_attributes,
                                       traffic_light_states=light_response.traffic_lights_states)

    agent_state_history = []
    traffic_light_state_history = []

    # 10-second scene
    SIMULATION_LENGTH = 100
    for t in range(SIMULATION_LENGTH):
        light_response = iai.light(location=location, recurrent_states=light_response.recurrent_states)
        response = iai.drive(
            location=location,
            agent_attributes=agent_attributes,
            agent_states=response.agent_states,
            recurrent_states=response.recurrent_states,
            get_birdview=False,
            traffic_lights_states=light_response.traffic_lights_states,
            get_infractions=True,
            random_seed=1
        )
        scene_plotter.record_step(response.agent_states, traffic_light_states=light_response.traffic_lights_states)
        agent_state_history.append(response.agent_states)
        traffic_light_state_history.append(light_response.traffic_lights_states)
        
        print(f"Attempted collision simulation number {_} iteration number {t}.")
        clear_output(wait=True)
          
    collisions = compute_pairwise_collisions(agent_state_history,agent_attributes)
    if collisions: 
        #If a collision is detected, cease generating more simulations
        break

print(collisions)

[LogCollision(collision_agents=(2, 4), start_time=85, end_time=88)]


In [7]:
blame_responses = []
for collision_data in collisions:
    blame_response = iai.blame(
        location=location,
        colliding_agents=collision_data.collision_agents,
        agent_state_history=agent_state_history[:collision_data.start_time],
        traffic_light_state_history=traffic_light_state_history[:collision_data.start_time],
        agent_attributes=agent_attributes,
        get_reasons=True,
        get_confidence_score=True,
        get_birdviews=False
    )
    print(blame_response.agents_at_fault)
    blame_responses.append(blame_response)

ValidationError: 2 validation errors for Blame
candidate_agents
  field required (type=value_error.missing)
kwargs
  unexpected keyword argument: 'colliding_agents' (type=type_error)

In [None]:
for response in blame_responses:
    print(response.reasons)

In [None]:
for response in blame_responses:
    print(response.confidence_score)

In [None]:
%%capture
fig, ax = plt.subplots(constrained_layout=True, figsize=(50, 50))
gif_name = 'blame-example.gif'
scene_plotter.animate_scene(
    output_name=gif_name,
    ax=ax,
    numbers=True,
    direction_vec=False,
    velocity_vec=False,
    plot_frame_number=True
)

In [None]:
Image(gif_name, width=1000, height=800)