Skip to content

Commit

Permalink
Merge pull request #101 from inverted-ai/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
Ruishenl committed Dec 9, 2022
2 parents cc99d04 + 17d4f62 commit 865d580
Show file tree
Hide file tree
Showing 3 changed files with 945 additions and 462 deletions.
196 changes: 195 additions & 1 deletion invertedai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
import invertedai.api.config
from invertedai import error, api
import logging
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib import animation
import numpy as np

TIMEOUT_SECS = 600
MAX_RETRIES = 10


class Session:
def __init__(self, api_token: str = ""):
self.session = requests.Session()
Expand Down Expand Up @@ -339,3 +342,194 @@ def fmt(key, val):
return f"{key}={val}"

return " ".join([fmt(key, val) for key, val in sorted(props.items())])


def rot(rot):
"""Rotate in 2d"""
return np.array([[np.cos(rot), -np.sin(rot)],
[np.sin(rot), np.cos(rot)]])


class ScenePlotter:
def __init__(self, map_image, fov, xy_offset, static_actors):
self.conditional_agents = None
self.agent_attributes = None
self.traffic_lights_history = None
self.agent_states_history = None
self.map_image = map_image
self.fov = fov
self.extent = (- self.fov / 2 + xy_offset[0], self.fov / 2 + xy_offset[0]) + \
(- self.fov / 2 + xy_offset[1], self.fov / 2 + xy_offset[1])

self.traffic_lights = {static_actor.actor_id: static_actor
for static_actor in static_actors
if static_actor.agent_type == 'traffic-light'}

self.traffic_light_colors = {
'red': (1.0, 0.0, 0.0),
'green': (0.0, 1.0, 0.0),
'yellow': (1.0, 0.8, 0.0)
}

self.agent_c = (0.2, 0.2, 0.7)
self.cond_c = (0.75, 0.35, 0.35)
self.dir_c = (0.9, 0.9, 0.9)
self.v_c = (0.2, 0.75, 0.2)

self.dir_lines = {}
self.v_lines = {}
self.actor_boxes = {}
self.traffic_light_boxes = {}
self.box_labels = {}
self.frame_label = None
self.current_ax = None

self.reset_recording()

self.numbers = False

def initialize_recording(self, agent_states, agent_attributes, traffic_light_states=None, conditional_agents=None):
self.agent_states_history = [agent_states]
self.traffic_lights_history = [traffic_light_states]
self.agent_attributes = agent_attributes
if conditional_agents is not None:
self.conditional_agents = conditional_agents
else:
self.conditional_agents = []

def reset_recording(self):
self.agent_states_history = []
self.traffic_lights_history = []
self.agent_attributes = None
self.conditional_agents = []

def record_step(self, agent_states, traffic_light_states=None):
self.agent_states_history.append(agent_states)
self.traffic_lights_history.append(traffic_light_states)

def plot_scene(self, agent_states, agent_attributes, traffic_light_states=None, conditional_agents=None,
ax=None, numbers=False, direction_vec=True, velocity_vec=False):
self.initialize_recording(agent_states, agent_attributes,
traffic_light_states=traffic_light_states,
conditional_agents=conditional_agents)

self.plot_frame(idx=0, ax=ax, numbers=numbers, direction_vec=direction_vec,
velocity_vec=velocity_vec, plot_frame_number=False)

self.reset_recording()

def plot_frame(self, idx, ax=None, numbers=False, direction_vec=False, velocity_vec=False, plot_frame_number=False):
self._initialize_plot(ax=ax, numbers=numbers, direction_vec=direction_vec,
velocity_vec=velocity_vec, plot_frame_number=plot_frame_number)
self._update_frame_to(idx)

def animate_scene(self, output_name=None, start_idx=0, end_idx=-1, ax=None,
numbers=False, direction_vec=True, velocity_vec=False,
plot_frame_number=False):
self._initialize_plot(ax=ax, numbers=numbers, direction_vec=direction_vec,
velocity_vec=velocity_vec, plot_frame_number=plot_frame_number)
end_idx = len(self.agent_states_history) if end_idx == -1 else end_idx
fig = self.current_ax.figure

def animate(i):
self._update_frame_to(i)

ani = animation.FuncAnimation(fig, animate, np.arange(start_idx, end_idx), interval=100)
if output_name is not None:
ani.save(f'{output_name}', writer='pillow')
return ani

def _initialize_plot(self, ax=None, numbers=False, direction_vec=True, velocity_vec=False, plot_frame_number=False):
if ax is None:
plt.clf()
ax = plt.gca()
self.current_ax = ax
ax.imshow(self.map_image, extent=self.extent)

self.dir_lines = {}
self.v_lines = {}
self.actor_boxes = {}
self.traffic_light_boxes = {}
self.box_labels = {}
self.frame_label = None

self.numbers = numbers
self.direction_vec = direction_vec
self.velocity_vec = velocity_vec
self.plot_frame_number = plot_frame_number

self._update_frame_to(0)

def _update_frame_to(self, frame_idx):
for i, (agent, agent_attribute) in enumerate(zip(self.agent_states_history[frame_idx], self.agent_attributes)):
self._update_agent(i, agent, agent_attribute)

if self.traffic_lights_history[frame_idx] is not None:
for light_id, light_state in self.traffic_lights_history[frame_idx].items():
self._plot_traffic_light(light_id, light_state)

if self.plot_frame_number:
if self.frame_label is None:
self.frame_label = self.current_ax.text(self.extent[0], self.extent[2], str(frame_idx), c='r', fontsize=18)
else:
self.frame_label.set_text(str(frame_idx))

self.current_ax.set_xlim(*self.extent[0:2])
self.current_ax.set_ylim(*self.extent[2:4])

def _update_agent(self, agent_idx, agent, agent_attribute):
l, w = agent_attribute.length, agent_attribute.width
x, y = agent.center.x, agent.center.y
v = agent.speed
psi = agent.orientation
box = np.array([
[0, 0], [l * 0.5, 0], # direction vector
[0, 0], [v * 0.5, 0], # speed vector at (0.5 m / s ) / m
])
box = np.matmul(rot(psi), box.T).T + np.array([[x, y]])
if self.direction_vec:
if agent_idx not in self.dir_lines:
self.dir_lines[agent_idx] = self.current_ax.plot(box[0:2,0], box[0:2,1], lw=2.0, c=self.dir_c)[0] # plot the direction vector
else:
self.dir_lines[agent_idx].set_xdata(box[0:2,0])
self.dir_lines[agent_idx].set_ydata(box[0:2,1])

if self.velocity_vec:
if agent_idx not in self.v_lines:
self.v_lines[agent_idx] = self.current_ax.plot(box[2:4,0], box[2:4,1], lw=1.5 , c=self.v_c)[0] # plot the speed
else:
self.v_lines[agent_idx].set_xdata(box[2:4,0])
self.v_lines[agent_idx].set_ydata(box[2:4,1])
if self.numbers:
if agent_idx not in self.box_labels:
self.box_labels[agent_idx] = self.current_ax.text(x, y, str(agent_idx), c='r', fontsize=18)
self.box_labels[agent_idx].set_clip_on(True)
else:
self.box_labels[agent_idx].set_x(x)
self.box_labels[agent_idx].set_y(y)

if agent_idx in self.conditional_agents:
c = self.cond_c
else:
c = self.agent_c

rect = Rectangle((x - l / 2,y - w / 2), l, w, angle=psi * 180 / np.pi, rotation_point='center', fc=c, lw=0)
if agent_idx in self.actor_boxes:
self.actor_boxes[agent_idx].remove()
self.actor_boxes[agent_idx] = rect
self.actor_boxes[agent_idx].set_clip_on(True)
self.current_ax.add_patch(self.actor_boxes[agent_idx])

def _plot_traffic_light(self, light_id, light_state):
light = self.traffic_lights[light_id]
x, y = light.center.x, light.center.y
psi = light.orientation
l, w = light.length, light.width

rect = Rectangle((x - l / 2,y - w / 2), l, w, angle=psi * 180 / np.pi,
rotation_point='center',
fc=self.traffic_light_colors[light_state], lw=0)
if light_id in self.traffic_light_boxes:
self.traffic_light_boxes[light_id].remove()
self.current_ax.add_patch(rect)
self.traffic_light_boxes[light_id] = rect

0 comments on commit 865d580

Please sign in to comment.