<img src="https://raw.githubusercontent.com/inverted-ai/invertedai/master/docs/images/banner-small.png" alt="InvertedAI" width="200"/>


In [None]:
import sys
sys.path.insert(0, "..")

import IPython
from IPython.display import display, Image, clear_output
from ipywidgets import interact
from IPython.utils import io

import matplotlib.pyplot as plt
import imageio
import numpy as np
import cv2
import invertedai as iai

from shapely.geometry import Polygon
from shapely.errors import GEOSException

from dataclasses import dataclass
from typing import Optional, Tuple

In [None]:
# API key: read it from the IAI_API_KEY environment variable so the secret does not have
# to be typed into (and saved with) the notebook. Falls back to "" -- the original
# placeholder -- when the variable is unset.
# NOTE(review): confirm IAI_API_KEY matches the variable name the invertedai SDK documents.
import os

iai.add_apikey(os.environ.get("IAI_API_KEY", ""))

In [None]:
# Pick a location: a 4-way, signalized intersection.
location = "iai:drake_street_and_pacific_blvd"

In [None]:
# Fetch the static map of the chosen location and build a plotter that will record
# and later animate the simulated scene on top of it.
location_info_response = iai.location_info(location=location)
rendered_static_map = location_info_response.birdview_image.decode()
map_center = location_info_response.map_center
scene_plotter = iai.utils.ScenePlotter(
    rendered_static_map,
    location_info_response.map_fov,
    (map_center.x, map_center.y),
    location_info_response.static_actors,
)

In [None]:
# Maps each colliding pair of agent indices to its LogCollision record.
collisions = dict()

@dataclass
class LogCollision:
    """Record of one collision between a pair of agents during a simulation.

    In this notebook, times are the indices ``t`` of the ``drive`` steps at which the
    collision was observed.
    """
    collision_agents: Tuple[int, int]  # indices of the two colliding agents
    start_time: int  # first step at which the pair was detected overlapping
    end_time: Optional[int] = None  # step at which the collision was closed; None while ongoing

def compute_pairwise_collisions(agent_states, agent_attributes):
    """
    Use polygon intersections to check every pair of agents for a collision.

    Each agent is modelled as a ``length`` x ``width`` rectangle centred on
    ``state.center`` and rotated by ``state.orientation`` (passed straight to
    ``np.cos``/``np.sin``, i.e. radians). A pair collides when the intersection of the
    two rectangles has non-zero area; rectangles that only touch along an edge or at a
    corner have zero intersection area and are not reported.

    Args:
        agent_states: List of all current agent states; each must expose
            ``center.x``, ``center.y`` and ``orientation``.
        agent_attributes: List of static attributes of all agents, aligned
            index-by-index with ``agent_states``; each must expose ``length`` and ``width``.

    Returns:
        List of colliding pairs of agent indices ``[j, k]`` with ``j < k``, ordered by
        ``j`` then ``k``. => List[List[int, int]]

    Raises:
        ValueError: If ``agent_states`` and ``agent_attributes`` differ in length.
            (``ValueError`` subclasses ``Exception``, so existing handlers still catch it.)
    """
    num_agents = len(agent_states)
    if num_agents != len(agent_attributes):
        raise ValueError("Incorrect number of agents or agent attributes.")

    polygons = []
    for state, attributes in zip(agent_states, agent_attributes):
        half_length = attributes.length / 2
        half_width = attributes.width / 2

        # Corners in the agent's own frame (x along the heading), in traversal order so
        # they form a simple, non-self-intersecting ring.
        local_corners = np.array([[half_length, half_length, -half_length, -half_length],
                                  [half_width, -half_width, -half_width, half_width]])

        c, s = np.cos(state.orientation), np.sin(state.orientation)
        # Standard 2D counter-clockwise rotation matrix. The previous version used
        # [[c, -s], [c, s]], which is not a rotation and distorted every box whose
        # orientation was non-zero.
        rotation = np.array([[c, -s],
                             [s, c]])

        origin = np.array([[state.center.x], [state.center.y]])
        # np.around returns a new array; the previous version discarded it, so the
        # intended 3-decimal rounding never happened.
        corners = np.around(origin + rotation @ local_corners, 3)
        polygons.append(Polygon(corners.T.tolist()))

    collisions = []
    for j in range(num_agents):
        for k in range(j + 1, num_agents):
            try:
                overlap_area = polygons[j].intersection(polygons[k]).area
            except GEOSException as e:
                # Best effort: report the failing pair and keep checking the others.
                print(f"Collision candidates {j} and {k} failed with error {e}.")
                continue
            if overlap_area > 0:
                collisions.append([j, k])

    return collisions

# Simulate with `initialize`, `drive` and `light` until at least one collision is detected.
# The later cells read the module-level results: `collisions`, `agent_state_history`,
# `traffic_light_state_history`, `agent_attributes` and the recording held by
# `scene_plotter`. All of them describe the LAST simulation attempted.
MAX_ATTEMPTS = 20  # independent simulations to try before giving up
SIMULATION_LENGTH = 100  # `drive` steps per simulation (a 10-second scene)

# The attempt counter has a real name because it is printed. `_` conventionally means
# "unused", and IPython also uses it to hold the last cell output.
for attempt in range(MAX_ATTEMPTS):
    light_response = iai.light(location=location)

    response = iai.initialize(
        location=location,
        agent_count=15,
        get_birdview=True,
        traffic_light_state_history=[light_response.traffic_lights_states]
    )
    agent_attributes = response.agent_attributes
    scene_plotter.initialize_recording(response.agent_states,
                                       agent_attributes=agent_attributes,
                                       traffic_light_states=light_response.traffic_lights_states)

    # Index t of each history holds what the t-th `drive` call below produced. This is
    # the same t that is stored in LogCollision.start_time / end_time.
    agent_state_history = []
    traffic_light_state_history = []

    for t in range(SIMULATION_LENGTH):
        light_response = iai.light(location=location, recurrent_states=light_response.recurrent_states)
        response = iai.drive(
            location=location,
            agent_attributes=agent_attributes,
            agent_states=response.agent_states,
            recurrent_states=response.recurrent_states,
            get_birdview=False,
            traffic_lights_states=light_response.traffic_lights_states,
            get_infractions=True,
            random_seed=1
        )
        scene_plotter.record_step(response.agent_states, traffic_light_states=light_response.traffic_lights_states)
        agent_state_history.append(response.agent_states)
        traffic_light_state_history.append(light_response.traffic_lights_states)

        # Check for collisions and log: agent IDs, collision start time, collision end time.
        detected_collisions = compute_pairwise_collisions(response.agent_states, agent_attributes)
        detected_pairs = {tuple(pair) for pair in detected_collisions}

        # Start tracking every pair seen for the first time. Iterate the list rather than
        # the set so `collisions` keeps the deterministic (j, k) detection order. A pair
        # that separates and later collides again keeps its first record.
        for pair in detected_collisions:
            agent_tuple = tuple(pair)
            if agent_tuple not in collisions:
                collisions[agent_tuple] = LogCollision(
                    collision_agents=agent_tuple,
                    start_time=t,
                    end_time=None
                )

        # Close every open collision whose pair is no longer overlapping. On the final
        # step, close whatever is still open so every record has an end time.
        is_last_step = t == SIMULATION_LENGTH - 1
        for agent_tuple, collision in collisions.items():
            if collision.end_time is None and (is_last_step or agent_tuple not in detected_pairs):
                collision.end_time = t

        print(f"Attempted collision simulation number {attempt} iteration number {t}.")
        clear_output(wait=True)

    if collisions:
        # A collision was found: stop, so the histories above belong to this simulation.
        break

print(collisions)

In [None]:
# Ask the Blame API which agent(s) are at fault for each logged collision.
blame_responses = []
for pair, record in collisions.items():
    # The history is truncated to the steps *before* `record.start_time`, so the step in
    # which the overlap was first detected is excluded.
    # NOTE(review): confirm against the `iai.blame` docs whether the history should
    # include the collision step instead (`[:record.start_time + 1]`). Also note that a
    # collision detected at step 0 yields an empty history.
    history = agent_state_history[:record.start_time]
    light_history = traffic_light_state_history[:record.start_time]
    blame_response = iai.blame(
        location=location,
        colliding_agents=pair,
        agent_state_history=history,
        traffic_light_state_history=light_history,
        agent_attributes=agent_attributes,
        get_reasons=True,
        get_confidence_score=True,
        get_birdviews=False,
    )
    print(blame_response.agents_at_fault)
    blame_responses.append(blame_response)

In [None]:
# Show the reasons returned for each blame query, one response per line.
for blame in blame_responses:
    print(blame.reasons)

In [None]:
# Show the confidence score returned for each blame query, one response per line.
for blame in blame_responses:
    print(blame.confidence_score)

In [None]:
%%capture
# Render the recorded simulation to an animated GIF (agent indices and frame numbers
# drawn; direction and velocity vectors omitted).
gif_name = 'blame-example.gif'
fig, ax = plt.subplots(figsize=(50, 50), constrained_layout=True)
scene_plotter.animate_scene(
    output_name=gif_name,
    ax=ax,
    numbers=True,
    plot_frame_number=True,
    direction_vec=False,
    velocity_vec=False,
)

In [None]:
# Display the GIF written by the animation cell. The Image object is the cell's last
# expression, so Jupyter renders it inline at 1000x800.
Image(gif_name, width=1000, height=800)