# Waymo Open Dataset Motion Tutorial

- Website: https://waymo.com/open
- GitHub: https://github.com/waymo-research/waymo-open-dataset

This tutorial demonstrates:
- How to decode and interpret the data.
- How to train a simple model with TensorFlow.

Visit the [Waymo Open Dataset Website](https://waymo.com/open) to download the full dataset.

To use, open this notebook in [Colab](https://colab.research.google.com).

If you run this Colab directly from the remote kernel, uncheck the box "Reset all runtimes before running". Alternatively, make a copy before running it by following "File > Save copy in Drive ...".

# Package installation

Please follow the instructions in [tutorial.ipynb](https://github.com/waymo-research/waymo-open-dataset/blob/master/tutorial/tutorial.ipynb).

# Imports and global definitions

In [14]:
# Data location. Please edit.

# A tfrecord containing tf.Example protos as downloaded from the Waymo dataset
# webpage.

# Replace this path with your own tfrecords.
FILENAME = 'data_processing/data/training_tfexample.tfrecord-00009-of-01000'

In [15]:
import math
import os
import uuid
import time

from matplotlib import cm
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pandas as pd

import numpy as np
from IPython.display import HTML
import itertools
import tensorflow as tf

from google.protobuf import text_format
from waymo_open_dataset.metrics.ops import py_metrics_ops
from waymo_open_dataset.metrics.python import config_util_py as config_util
from waymo_open_dataset.protos import motion_metrics_pb2

scenario_features = {
    'scenario/id':
        tf.io.FixedLenFeature([1], tf.string, default_value=None)
}

# Roadgraph (map) sample features.
roadgraph_features = {
    'roadgraph_samples/dir':
        tf.io.FixedLenFeature([20000, 3], tf.float32, default_value=None),
    'roadgraph_samples/id':
        tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None),
    'roadgraph_samples/type':
        tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None),
    'roadgraph_samples/valid':
        tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None),
    'roadgraph_samples/xyz':
        tf.io.FixedLenFeature([20000, 3], tf.float32, default_value=None),
}

# Features of all agents, including the self-driving car (see 'state/is_sdc').
state_features = {
    'state/id':
        tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    'state/type':
        tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    'state/is_sdc':
        tf.io.FixedLenFeature([128], tf.int64, default_value=None),
    'state/tracks_to_predict':
        tf.io.FixedLenFeature([128], tf.int64, default_value=None),
    'state/current/bbox_yaw':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/height':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/length':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/timestamp_micros':
        tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None),
    'state/current/valid':
        tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None),
    'state/current/vel_yaw':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/velocity_x':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/velocity_y':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/width':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/x':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/y':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/z':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/future/bbox_yaw':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/height':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/length':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/timestamp_micros':
        tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None),
    'state/future/valid':
        tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None),
    'state/future/vel_yaw':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/velocity_x':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/velocity_y':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/width':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/x':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/y':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/z':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/past/bbox_yaw':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/height':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/length':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/timestamp_micros':
        tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),
    'state/past/valid':
        tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),
    'state/past/vel_yaw':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/velocity_x':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/velocity_y':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/width':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/x':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/y':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/z':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/objects_of_interest':
        tf.io.FixedLenFeature([128], tf.int64, default_value=None)
}

traffic_light_features = {
    'traffic_light_state/current/state':
        tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
    'traffic_light_state/current/valid':
        tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
    'traffic_light_state/current/x':
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    'traffic_light_state/current/y':
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    'traffic_light_state/current/z':
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    'traffic_light_state/past/state':
        tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
    'traffic_light_state/past/valid':
        tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
    'traffic_light_state/past/x':
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
    'traffic_light_state/past/y':
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
    'traffic_light_state/past/z':
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
    'traffic_light_state/current/id':
        tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None)
}

features_description = {}
features_description.update(scenario_features)
features_description.update(roadgraph_features)
features_description.update(state_features)
features_description.update(traffic_light_features)
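
The feature shapes above split each agent track into 10 past steps, 1 current step, and 80 future steps. As a rough sketch, assuming the 10 Hz sampling rate described in the Waymo Open Motion Dataset documentation (it is not stated in this notebook), the 91 steps can be mapped to time offsets relative to the current step. The names `NUM_PAST`, `NUM_FUTURE`, `SAMPLE_HZ`, and `step_offsets_seconds` are illustrative and not part of the dataset API.

```python
# Step counts are taken from the feature shapes above. The sampling rate is an
# assumption based on the dataset documentation, not on this notebook.
NUM_PAST = 10
NUM_CURRENT = 1
NUM_FUTURE = 80
SAMPLE_HZ = 10  # Assumed sampling rate in Hz.


def step_offsets_seconds():
  """Returns time offsets in seconds for past, current, and future steps."""
  past = [-(NUM_PAST - i) / SAMPLE_HZ for i in range(NUM_PAST)]
  current = [0.0] * NUM_CURRENT
  future = [(i + 1) / SAMPLE_HZ for i in range(NUM_FUTURE)]
  return past + current + future


offsets = step_offsets_seconds()
print(len(offsets), offsets[0], offsets[NUM_PAST], offsets[-1])
```

Under these assumptions each example would cover one second of history and eight seconds of future around the current step.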

# Visualize TF Example sample

## Create Dataset.

In [16]:
dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')
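# Avoid the name `iter`, which would shadow the Python builtin.
dataset_iterator = dataset.as_numpy_iterator()

In [17]:
# Uncomment to count the examples in the file.
# num_examples = 0
# for data in dataset.as_numpy_iterator():
#     num_examples += 1
# print('num_examples', num_examples)

# Take the first serialized example from the file.
data = next(dataset_iterator)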
parsed = tf.io.parse_single_example(data, features_description)

## Generate visualization images.

In [26]:
def create_figure_and_axes(size_pixels):
  """Initializes a unique figure and axes for plotting."""
  fig, ax = plt.subplots(1, 1, num=uuid.uuid4())

  # Sets output image to pixel resolution.
  dpi = 100
  size_inches = size_pixels / dpi
  fig.set_size_inches([size_inches, size_inches])
  fig.set_dpi(dpi)
  fig.set_facecolor('white')
  ax.set_facecolor('white')
  ax.xaxis.label.set_color('black')
  ax.tick_params(axis='x', colors='black')
  ax.yaxis.label.set_color('black')
  ax.tick_params(axis='y', colors='black')
  fig.set_tight_layout(True)
  ax.grid(False)
  return fig, ax


def fig_canvas_image(fig):
  """Returns a [H, W, 3] uint8 np.array image from fig.canvas.tostring_rgb()."""
  # Just enough margin in the figure to display xticks and yticks.
  fig.subplots_adjust(
      left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.0, hspace=0.0)
  fig.canvas.draw()
  data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
  return data.reshape(fig.canvas.get_width_height()[::-1] + (3,))


def get_colormap(num_agents):
  """Compute a color map array of shape [num_agents, 4]."""
  colors = cm.get_cmap('jet', num_agents)
  colors = colors(range(num_agents))
  np.random.shuffle(colors)
  return colors


def get_viewport(all_states, all_states_mask):
  """Gets the region containing the data.

  Args:
    all_states: states of agents as an array of shape [num_agents, num_steps,
      2].
    all_states_mask: binary mask of shape [num_agents, num_steps] for
      `all_states`.

  Returns:
    center_y: float. y coordinate for center of data.
    center_x: float. x coordinate for center of data.
    width: float. Width of data.
  """
  valid_states = all_states[all_states_mask]
  all_y = valid_states[..., 1]
  all_x = valid_states[..., 0]

  center_y = (np.max(all_y) + np.min(all_y)) / 2
  center_x = (np.max(all_x) + np.min(all_x)) / 2

  range_y = np.ptp(all_y)
  range_x = np.ptp(all_x)

  width = max(range_y, range_x)

  return center_y, center_x, width


def visualize_one_step(states,
                       mask,
                       roadgraph,
                       title,
                       center_y,
                       center_x,
                       width,
                       color_map,
                       size_pixels=1000):
  """Generate visualization for a single step."""

  # Create figure and axes.
  fig, ax = create_figure_and_axes(size_pixels=size_pixels)

  # Plot roadgraph.
  rg_pts = roadgraph[:, :2].T
  ax.plot(rg_pts[0, :], rg_pts[1, :], 'k.', alpha=1, ms=2)

  masked_x = states[:, 0][mask]
  masked_y = states[:, 1][mask]
  colors = color_map[mask]

  # Plot agent current position.
  ax.scatter(
      masked_x,
      masked_y,
      marker='o',
      linewidths=3,
      color=colors,
  )

  # Title.
  ax.set_title(title)

  # Set axes. The view is at least 10 m on a side and spans the full extent
  # of the agents (scale factor 1.0).
  size = max(10, width * 1.0)
  ax.axis([
      -size / 2 + center_x, size / 2 + center_x, -size / 2 + center_y,
      size / 2 + center_y
  ])
  ax.set_aspect('equal')

  image = fig_canvas_image(fig)
  plt.close(fig)
  return image


def visualize_all_agents_smooth(
    decoded_example,
    size_pixels=1000,
):
  """Visualizes the trajectories of all agents as a series of images.

  Args:
    decoded_example: Dictionary containing agent info about all modeled agents.
    size_pixels: The size in pixels of the output image.

  Returns:
    A list of T [H, W, 3] uint8 np.arrays, one rendered figure canvas per step.
  """
  # [num_agents, num_past_steps, 2] float32.
  past_states = tf.stack(
      [decoded_example['state/past/x'], decoded_example['state/past/y']],
      -1).numpy()
  past_states_mask = decoded_example['state/past/valid'].numpy() > 0.0

  # [num_agents, 1, 2] float32.
  current_states = tf.stack(
      [decoded_example['state/current/x'], decoded_example['state/current/y']],
      -1).numpy()
  current_states_mask = decoded_example['state/current/valid'].numpy() > 0.0

  # [num_agents, num_future_steps, 2] float32.
  future_states = tf.stack(
      [decoded_example['state/future/x'], decoded_example['state/future/y']],
      -1).numpy()
  future_states_mask = decoded_example['state/future/valid'].numpy() > 0.0

  # [num_points, 3] float32.
  roadgraph_xyz = decoded_example['roadgraph_samples/xyz'].numpy()

  num_agents, num_past_steps, _ = past_states.shape
  num_future_steps = future_states.shape[1]

  color_map = get_colormap(num_agents)

  # [num_agents, num_past_steps + 1 + num_future_steps, 2] float32.
  all_states = np.concatenate([past_states, current_states, future_states], 1)

  # [num_agents, num_past_steps + 1 + num_future_steps] bool.
  all_states_mask = np.concatenate(
      [past_states_mask, current_states_mask, future_states_mask], 1)

  center_y, center_x, width = get_viewport(all_states, all_states_mask)

  images = []

  # Generate images from past time steps.
  for i, (s, m) in enumerate(
      zip(
          np.split(past_states, num_past_steps, 1),
          np.split(past_states_mask, num_past_steps, 1))):
    im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz,
                            'past: %d' % (num_past_steps - i), center_y,
                            center_x, width, color_map, size_pixels)
    images.append(im)

  # Generate one image for the current time step.
  s = current_states
  m = current_states_mask

  im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz, 'current', center_y,
                          center_x, width, color_map, size_pixels)
  images.append(im)

  # Generate images from future time steps.
  for i, (s, m) in enumerate(
      zip(
          np.split(future_states, num_future_steps, 1),
          np.split(future_states_mask, num_future_steps, 1))):
    im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz,
                            'future: %d' % (i + 1), center_y, center_x, width,
                            color_map, size_pixels)
    images.append(im)

  return images


images = visualize_all_agents_smooth(parsed)
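
To make the viewport computation concrete, here is a small sketch on synthetic data that needs only NumPy. It repeats the logic of `get_viewport` on two hypothetical agents, one of which has a padded step that the mask excludes. The positions are made up for illustration.

```python
import numpy as np

# Two hypothetical agents, three steps each, with (x, y) positions.
all_states = np.array([
    [[0.0, 0.0], [2.0, 1.0], [4.0, 2.0]],
    [[10.0, 6.0], [12.0, 6.0], [-1.0, -1.0]],  # Last step is padding.
])
all_states_mask = np.array([
    [True, True, True],
    [True, True, False],
])

# Boolean indexing keeps only the valid steps: shape [num_valid, 2].
valid_states = all_states[all_states_mask]
all_x = valid_states[..., 0]
all_y = valid_states[..., 1]

# Center of the bounding box and the larger of its two side lengths.
center_x = (np.max(all_x) + np.min(all_x)) / 2
center_y = (np.max(all_y) + np.min(all_y)) / 2
width = max(np.ptp(all_x), np.ptp(all_y))

print(center_x, center_y, width)
```

Because the padded step is masked out, its placeholder coordinates do not stretch the viewport.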

## Display animation.

In [27]:
def create_animation(images):
  """ Creates a Matplotlib animation of the given images.

  Args:
    images: A list of numpy arrays representing the images.

  Returns:
    A matplotlib.animation.Animation.

  Usage:
    anim = create_animation(images)
    anim.save('/tmp/animation.avi')
    HTML(anim.to_html5_video())
  """

  plt.ioff()
  fig, ax = plt.subplots()
  dpi = 100
  size_inches = 1000 / dpi
  fig.set_size_inches([size_inches, size_inches])
  plt.ion()

  def animate_func(i):
    ax.imshow(images[i])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)

  # Note: frames=len(images) // 2 plays only the first half of `images`.
  anim = animation.FuncAnimation(
      fig, animate_func, frames=len(images) // 2, interval=100)
  plt.close(fig)
  return anim


anim = create_animation(images[::5])
HTML(anim.to_html5_video())
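
The call above subsamples the frames before animating them, and `create_animation` then plays only the first half of what it receives. A quick sketch of the arithmetic, assuming the 91 frames (10 past, 1 current, 80 future) implied by the feature shapes; the variable names are illustrative.

```python
# 10 past + 1 current + 80 future frames, per the feature shapes.
num_frames = 10 + 1 + 80

# images[::5] keeps every fifth frame, starting at index 0.
kept_indices = list(range(num_frames))[::5]

# create_animation uses frames=len(images) // 2.
num_animated = len(kept_indices) // 2
animated_indices = kept_indices[:num_animated]

# interval=100 means 100 ms between consecutive frames.
interval_ms = 100
duration_s = num_animated * interval_ms / 1000

print(len(kept_indices), num_animated, animated_indices[-1], duration_s)
```

Under these assumptions the animation ends at original frame index 40, partway into the future steps. Passing `frames=len(images)` would play every subsampled frame.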

## Generate interaction visualization images.

In [28]:
def create_figure_and_axes(size_pixels):
  """Initializes a unique figure and axes for plotting."""
  fig, ax = plt.subplots(1, 1, num=uuid.uuid4())

  # Sets output image to pixel resolution.
  dpi = 100
  size_inches = size_pixels / dpi
  fig.set_size_inches([size_inches, size_inches])
  fig.set_dpi(dpi)
  fig.set_facecolor('white')
  ax.set_facecolor('white')
  ax.xaxis.label.set_color('black')
  ax.tick_params(axis='x', colors='black')
  ax.yaxis.label.set_color('black')
  ax.tick_params(axis='y', colors='black')
  fig.set_tight_layout(True)
  ax.grid(False)
  return fig, ax


def fig_canvas_image(fig):
  """Returns a [H, W, 3] uint8 np.array image from fig.canvas.tostring_rgb()."""
  # Just enough margin in the figure to display xticks and yticks.
  fig.subplots_adjust(
      left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.0, hspace=0.0)
  fig.canvas.draw()
  data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
  return data.reshape(fig.canvas.get_width_height()[::-1] + (3,))


def get_colormap(num_agents):
  """Compute a color map array of shape [num_agents, 4]."""
  colors = cm.get_cmap('jet', num_agents)
  colors = colors(range(num_agents))
  np.random.shuffle(colors)
  return colors


def get_viewport(all_states, all_states_mask):
  """Gets the region containing the data.

  Args:
    all_states: states of agents as an array of shape [num_agents, num_steps,
      2].
    all_states_mask: binary mask of shape [num_agents, num_steps] for
      `all_states`.

  Returns:
    center_y: float. y coordinate for center of data.
    center_x: float. x coordinate for center of data.
    width: float. Width of data.
  """
  valid_states = all_states[all_states_mask]
  all_y = valid_states[..., 1]
  all_x = valid_states[..., 0]

  center_y = (np.max(all_y) + np.min(all_y)) / 2
  center_x = (np.max(all_x) + np.min(all_x)) / 2

  range_y = np.ptp(all_y)
  range_x = np.ptp(all_x)

  width = max(range_y, range_x)

  return center_y, center_x, width


def visualize_one_step(states,
                       mask,
                       roadgraph_xyz,
                       roadgraph_type,
                       roadgraph_id,
                       title,
                       center_y,
                       center_x,
                       width,
                       color_map,
                       size_pixels=1000):
  """Generate visualization for a single step."""

  # Create figure and axes.
  fig, ax = create_figure_and_axes(size_pixels=size_pixels)

  # Plot roadgraph.
  rg_pts = roadgraph_xyz[:, :2].T
  roadgraph_data = pd.DataFrame({
      'x': rg_pts[0, :],
      'y': rg_pts[1, :],
      # Use roadgraph_id[:, 0] here instead to color points by feature id.
      'category': roadgraph_type[:, 0],
  })

  groups = roadgraph_data.groupby('category')
  for cat, group in groups:
    ax.plot(group['x'], group['y'], '.', alpha=1, ms=3, label=cat)
  ax.legend(fontsize='xx-large')

  masked_x = states[:, 0][mask]
  masked_y = states[:, 1][mask]
  colors = color_map[mask]

  # Plot agent current position.
  ax.scatter(
      masked_x,
      masked_y,
      marker='o',
      linewidths=3,
      color=colors,
  )

  # Title.
  ax.set_title(title)

  # Set axes. The view is at least 10 m on a side and spans the full extent
  # of the agents (scale factor 1.0).
  size = max(10, width * 1.0)
  ax.axis([
      -size / 2 + center_x, size / 2 + center_x, -size / 2 + center_y,
      size / 2 + center_y
  ])
  ax.set_aspect('equal')

  image = fig_canvas_image(fig)
  plt.close(fig)
  return image


def visualize_all_agents_smooth(
    decoded_example,
    size_pixels=1000,
):
  """Visualizes the trajectories of the interactive agents as a series of images.

  Args:
    decoded_example: Dictionary containing agent info about all modeled agents.
    size_pixels: The size in pixels of the output image.

  Returns:
    A list of T [H, W, 3] uint8 np.arrays, one rendered figure canvas per step.
  """
  
  # Indices of the interactive agents, shape [num_interactive, 1].
  ia_idx = tf.where(decoded_example['state/objects_of_interest'] == 1).numpy()

  # [2, num_past_steps, 2] float32.
  past_states = tf.stack(
      [tf.gather_nd(decoded_example['state/past/x'], indices=ia_idx), tf.gather_nd(decoded_example['state/past/y'], indices=ia_idx)],
      -1).numpy()
  past_states_mask = tf.gather_nd(decoded_example['state/past/valid'], indices=ia_idx).numpy() > 0.0

  # [2, 1, 2] float32.
  current_states = tf.stack(
      [tf.gather_nd(decoded_example['state/current/x'], indices=ia_idx), tf.gather_nd(decoded_example['state/current/y'], indices=ia_idx)],
      -1).numpy()
  current_states_mask = tf.gather_nd(decoded_example['state/current/valid'], indices=ia_idx).numpy() > 0.0

  # [2, num_future_steps, 2] float32.
  future_states = tf.stack(
      [tf.gather_nd(decoded_example['state/future/x'], indices=ia_idx), tf.gather_nd(decoded_example['state/future/y'], indices=ia_idx)],
      -1).numpy()
  future_states_mask = tf.gather_nd(decoded_example['state/future/valid'], indices=ia_idx).numpy() > 0.0

  # [num_points, 3] float32.
  roadgraph_xyz = decoded_example['roadgraph_samples/xyz'].numpy()
    
  # [num_points, 1]
  roadgraph_type = decoded_example['roadgraph_samples/type'].numpy()
    
  # [num_points, 1]
  roadgraph_id = decoded_example['roadgraph_samples/id'].numpy()

  num_agents, num_past_steps, _ = past_states.shape
  num_future_steps = future_states.shape[1]

  color_map = get_colormap(num_agents)

  # [2, num_past_steps + 1 + num_future_steps, 2] float32.
  all_states = np.concatenate([past_states, current_states, future_states], 1)

  # [2, num_past_steps + 1 + num_future_steps] bool.
  all_states_mask = np.concatenate(
      [past_states_mask, current_states_mask, future_states_mask], 1)

  center_y, center_x, width = get_viewport(all_states, all_states_mask)

  images = []

  # Generate images from past time steps.
  for i, (s, m) in enumerate(
      zip(
          np.split(past_states, num_past_steps, 1),
          np.split(past_states_mask, num_past_steps, 1))):
    im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz, roadgraph_type, roadgraph_id,
                            'past: %d' % (num_past_steps - i), center_y,
                            center_x, width, color_map, size_pixels)
    images.append(im)

  # Generate one image for the current time step.
  s = current_states
  m = current_states_mask

  im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz, roadgraph_type, roadgraph_id, 'current', center_y,
                          center_x, width, color_map, size_pixels)
  images.append(im)

  # Generate images from future time steps.
  for i, (s, m) in enumerate(
      zip(
          np.split(future_states, num_future_steps, 1),
          np.split(future_states_mask, num_future_steps, 1))):
    im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz, roadgraph_type, roadgraph_id,
                            'future: %d' % (i + 1), center_y, center_x, width,
                            color_map, size_pixels)
    images.append(im)

  return images


images = visualize_all_agents_smooth(parsed)
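
The `tf.where` / `tf.gather_nd` pattern used above selects the rows of the agents flagged in `state/objects_of_interest`. As a minimal sketch, the same selection can be written in NumPy on made-up data. The flag values 1, 0, and -1 (padding) mirror the tensors printed in this notebook; the array contents are hypothetical.

```python
import numpy as np

# Hypothetical flags for 6 agent slots: 1 = interactive agent, 0 = other
# agent, -1 = padding slot.
objects_of_interest = np.array([0, 1, 0, 1, -1, -1])

# Hypothetical past x positions, shape [num_slots, num_past_steps].
past_x = np.arange(6 * 3, dtype=np.float32).reshape(6, 3)

# Indices of the flagged agents; the counterpart of tf.where(...)[:, 0].
ia_idx = np.flatnonzero(objects_of_interest == 1)

# Row selection; the counterpart of tf.gather_nd with [[i], [j]] indices.
ia_past_x = past_x[ia_idx]

print(ia_idx, ia_past_x.shape)
```

Comparing against 1 rather than testing for a nonzero value keeps the -1 padding slots out of the selection.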

## Display interaction animation.

In [29]:
def create_animation(images):
  """ Creates a Matplotlib animation of the given images.

  Args:
    images: A list of numpy arrays representing the images.

  Returns:
    A matplotlib.animation.Animation.

  Usage:
    anim = create_animation(images)
    anim.save('/tmp/animation.avi')
    HTML(anim.to_html5_video())
  """

  plt.ioff()
  fig, ax = plt.subplots()
  dpi = 100
  size_inches = 1000 / dpi
  fig.set_size_inches([size_inches, size_inches])
  plt.ion()

  def animate_func(i):
    ax.imshow(images[i])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)

  # Note: frames=len(images) // 2 plays only the first half of `images`.
  anim = animation.FuncAnimation(
      fig, animate_func, frames=len(images) // 2, interval=100)
  plt.close(fig)
  return anim

anim = create_animation(images[::5])
HTML(anim.to_html5_video())

In [30]:
def create_animation(images):
  """ Creates a Matplotlib animation of the given images.

  Args:
    images: A list of numpy arrays representing the images.

  Returns:
    A matplotlib.animation.Animation.

  Usage:
    anim = create_animation(images)
    anim.save('/tmp/animation.avi')
    HTML(anim.to_html5_video())
  """

  plt.ioff()
  fig, ax = plt.subplots()
  dpi = 100
  size_inches = 1000 / dpi
  fig.set_size_inches([size_inches, size_inches])
  plt.ion()

  def animate_func(i):
    ax.imshow(images[i])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)

  anim = animation.FuncAnimation(
      fig, animate_func, frames=len(images), interval=100)
  plt.close(fig)
  return anim

anim = create_animation(images[::1])
HTML(anim.to_html5_video())

# Details of tf.Example format

In [18]:
decoded_example = tf.io.parse_single_example(data, features_description)

In [19]:
# scenario id
print(decoded_example['scenario/id'])

tf.Tensor([b'7bf4f2ce60321112'], shape=(1,), dtype=string)


In [20]:
# id
print('id for each object:', decoded_example['state/id'])

id for each object: tf.Tensor(
[ 6.290e+02  6.730e+02  7.190e+02  7.260e+02  7.330e+02  7.360e+02
  7.410e+02  7.450e+02  1.001e+03  6.340e+02  6.500e+02  6.250e+02
  6.210e+02  6.320e+02  6.360e+02  6.240e+02  6.330e+02  7.280e+02
  6.350e+02  6.200e+02  6.470e+02  6.490e+02  6.230e+02  6.260e+02
  7.100e+02  7.050e+02  7.300e+02  6.590e+02  7.350e+02  6.600e+02
  6.400e+02  6.450e+02  6.710e+02  7.270e+02  6.520e+02  6.310e+02
  6.380e+02  6.370e+02  6.460e+02  7.000e+02  6.530e+02  7.130e+02
  6.660e+02  7.250e+02  6.510e+02  7.320e+02  7.120e+02  6.560e+02
  7.080e+02  7.200e+02  6.980e+02  6.570e+02  6.720e+02  6.970e+02
  7.210e+02  7.430e+02  7.240e+02  6.270e+02  7.460e+02  6.550e+02
  7.440e+02  6.860e+02  7.390e+02  7.140e+02  7.380e+02  6.870e+02
  7.400e+02  7.290e+02  7.470e+02  7.510e+02  7.540e+02  7.560e+02
  7.580e+02  7.530e+02  7.570e+02  7.620e+02  7.590e+02  7.600e+02
  7.500e+02  9.790e+02  7.640e+02  6.680e+02  7.550e+02  7.650e+02
  7.680e+02  7.630e+02  9.870e+

In [21]:
# indicates whether each object is the self-driving car (SDC)
print('is_sdc flag for each object:', decoded_example['state/is_sdc'])

is_sdc flag for each object: tf.Tensor(
[ 0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1], shape=(128,), dtype=int64)


In [22]:
# display the state of the self-driving car (SDC)
sdc_idx = tf.where(decoded_example['state/is_sdc'] == 1)[0][0].numpy()
print('SDC index', sdc_idx)
# SDC's type
print('the type of the SDC (1 represents a vehicle):', decoded_example['state/type'][sdc_idx])
# SDC's current x
print('current x position of the SDC:', decoded_example['state/current/x'][sdc_idx])
# SDC's past x (10 time steps)
print('past x positions (10 time steps) of the SDC:', decoded_example['state/past/x'][sdc_idx])
# SDC's future x (80 time steps)
print('future x positions (80 time steps) of the SDC:', decoded_example['state/future/x'][sdc_idx])
# SDC's current x velocity
print('current x velocity of the SDC:', decoded_example['state/current/velocity_x'][sdc_idx])
# SDC's (current) height
print('current height of the SDC:', decoded_example['state/current/height'][sdc_idx])

SDC index 8
the type of the SDC (1 represents a vehicle): tf.Tensor(1.0, shape=(), dtype=float32)
current x position of the SDC: tf.Tensor([9984.675], shape=(1,), dtype=float32)
past x positions (10 time steps) of the SDC: tf.Tensor(
[9983.588 9983.727 9983.859 9983.983 9984.101 9984.212 9984.317 9984.416
 9984.508 9984.595], shape=(10,), dtype=float32)
future x positions (80 time steps) of the SDC: tf.Tensor(
[9984.749  9984.817  9984.881  9984.9375 9984.991  9985.039  9985.083
 9985.124  9985.161  9985.195  9985.225  9985.251  9985.275  9985.296
 9985.313  9985.331  9985.346  9985.359  9985.37   9985.381  9985.39
 9985.397  9985.401  9985.404  9985.405  9985.406  9985.403  9985.4
 9985.4    9985.399  9985.396  9985.3955 9985.395  9985.393  9985.392
 9985.39   9985.391  9985.391  9985.39   9985.391  9985.391  9985.391
 9985.391  9985.391  9985.391  9985.391  9985.39   9985.39   9985.39
 9985.39   9985.39   9985.39   9985.391  9985.391  9985.39   9985.39
 9985.391  9985.39   9985.39   

In [23]:
# interactive agents
print('indicate the two interactive agents:', decoded_example['state/objects_of_interest'])
# extract interactive agents:
ia_idx = tf.where((decoded_example['state/objects_of_interest']==1)).numpy()
print('the index of interactive agents', ia_idx)
# interactive agent types:
print('the type of interactive agent 1 (1: vehicle; 2: pedestrian)', decoded_example['state/type'][ia_idx[0][0]])
print('the type of interactive agent 2 (1: vehicle; 2: pedestrian)', decoded_example['state/type'][ia_idx[1][0]])
# interactive agent x position
print('the x position of the interactive agents', tf.gather_nd(decoded_example['state/past/x'], indices=ia_idx))

indicate the two interactive agents: tf.Tensor(
[ 0  1  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1], shape=(128,), dtype=int64)
the index of interactive agents [[1]
 [3]]
the type of interactive agent 1 (1: vehicle; 2: pedestrian) tf.Tensor(1.0, shape=(), dtype=float32)
the type of interactive agent 2 (1: vehicle; 2: pedestrian) tf.Tensor(1.0, shape=(), dtype=float32)
the x position of the interactive agents tf.Tensor(
[[9974.768 9975.397 9976.012 9976.586 9977.134 9977.678 9978.175 9978.656
  9979.119 9979.519]
 [9955.44  9956.448 9957.477 9958.469 9959.442 9960.389 9961.317 9962.207
  9963.069 9963.903]], shape=(2, 10), dtype=float32)


In [24]:
# roadgraph point position
print('the coordinate positions of the map data point 0: ', decoded_example['roadgraph_samples/xyz'][0])
# roadgraph point direction
print('the direction of the map data point 0: ', decoded_example['roadgraph_samples/dir'][0])
# roadgraph point type
print('the type of the map data point 0: ', decoded_example['roadgraph_samples/type'][0])
s = set()
for i in decoded_example['roadgraph_samples/type'].numpy():
    s.add(i[0])
print('roadgraph type', s)
# roadgraph_samples/id
print('the id of the map data point: ', decoded_example['roadgraph_samples/id'])

the coordinate positions of the map data point 0:  tf.Tensor([9898.255   8018.4634   -72.05614], shape=(3,), dtype=float32)
the direction of the map data point 0:  tf.Tensor([ 0.01621103 -0.9998632   0.00327913], shape=(3,), dtype=float32)
the type of the map data point 0:  tf.Tensor([2], shape=(1,), dtype=int64)
roadgraph type {2, 3}
the id of the map data point:  tf.Tensor(
[[  1]
 [  1]
 [  1]
 ...
 [572]
 [572]
 [572]], shape=(20000, 1), dtype=int64)


In [25]:
a = decoded_example['roadgraph_samples/xyz'].numpy()
print(a)
print(a.max(0))
print(a.max(0)-a.min(0))

[[ 9898.255     8018.4634     -72.05614 ]
 [ 9898.263     8017.966      -72.054504]
 [ 9898.2705    8017.468      -72.05287 ]
 ...
 [10068.251     7971.714      -71.97878 ]
 [10068.591     7971.37       -71.97661 ]
 [10068.936     7971.0303     -71.974434]]
[10083.896    8019.093     -71.70226]
[229.40723   199.9414      1.2299957]


## A Note for the roadgraph point type
- LaneCenter-Freeway = 1
- LaneCenter-SurfaceStreet = 2
- LaneCenter-BikeLane = 3
- RoadLine-BrokenSingleWhite = 6
- RoadLine-SolidSingleWhite = 7
- RoadLine-SolidDoubleWhite = 8
- RoadLine-BrokenSingleYellow = 9
- RoadLine-BrokenDoubleYellow = 10
- RoadLine-SolidSingleYellow = 11
- RoadLine-SolidDoubleYellow = 12
- RoadLine-PassingDoubleYellow = 13
- RoadEdgeBoundary = 15
- RoadEdgeMedian = 16
- StopSign = 17
- Crosswalk = 18
- SpeedBump = 19

Other values are unknown types and should not be present.
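
For plotting or filtering it can help to turn the codes listed above into names. The following is a small sketch of such a lookup. `ROADGRAPH_TYPE_NAMES` and `roadgraph_type_name` are hypothetical helpers, not part of the Waymo Open Dataset package, and the table simply restates the values from this note.

```python
# Hypothetical lookup table restating the roadgraph type codes from this note.
ROADGRAPH_TYPE_NAMES = {
    1: 'LaneCenter-Freeway',
    2: 'LaneCenter-SurfaceStreet',
    3: 'LaneCenter-BikeLane',
    6: 'RoadLine-BrokenSingleWhite',
    7: 'RoadLine-SolidSingleWhite',
    8: 'RoadLine-SolidDoubleWhite',
    9: 'RoadLine-BrokenSingleYellow',
    10: 'RoadLine-BrokenDoubleYellow',
    11: 'RoadLine-SolidSingleYellow',
    12: 'RoadLine-SolidDoubleYellow',
    13: 'RoadLine-PassingDoubleYellow',
    15: 'RoadEdgeBoundary',
    16: 'RoadEdgeMedian',
    17: 'StopSign',
    18: 'Crosswalk',
    19: 'SpeedBump',
}


def roadgraph_type_name(code):
  """Returns a readable name for a roadgraph type code."""
  return ROADGRAPH_TYPE_NAMES.get(int(code), 'Unknown')


# The sample inspected above contains roadgraph types 2 and 3.
print([roadgraph_type_name(code) for code in (2, 3)])
```

Such a table could also supply legend labels in place of the raw integer categories.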

In [77]:
# traffic light state 
print('the state of the traffic light 1: ', decoded_example['traffic_light_state/current/state'][0][1])
# traffic light current x position
print('the current x position of the traffic light 1: ', decoded_example['traffic_light_state/current/x'][0][1])
# id of the lane controlled by the traffic light
print('the lane id controlled by the traffic light 1: ', decoded_example['traffic_light_state/current/id'][0][1])

the state of the traffic light 1:  tf.Tensor(0, shape=(), dtype=int64)
the current x position of the traffic light 1:  tf.Tensor(1686.2325, shape=(), dtype=float32)
the lane id controlled by the traffic light 1:  tf.Tensor(57, shape=(), dtype=int64)


## A Note for the Traffic Light State
- Unknown = 0
- Arrow_Stop = 1
- Arrow_Caution = 2
- Arrow_Go = 3
- Stop = 4
- Caution = 5
- Go = 6
- Flashing_Stop = 7
- Flashing_Caution = 8
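
The integer codes above can be decoded with a small enum. This is only a sketch: `TrafficLightState` and `decode_state` are local, hypothetical names that restate the values from this note and are not imported from the Waymo Open Dataset package.

```python
import enum


class TrafficLightState(enum.IntEnum):
  """Local mirror of the traffic light state codes listed in this note."""
  UNKNOWN = 0
  ARROW_STOP = 1
  ARROW_CAUTION = 2
  ARROW_GO = 3
  STOP = 4
  CAUTION = 5
  GO = 6
  FLASHING_STOP = 7
  FLASHING_CAUTION = 8


def decode_state(code):
  """Returns the state name, or a placeholder for codes outside the table."""
  try:
    return TrafficLightState(int(code)).name
  except ValueError:
    return 'INVALID(%d)' % int(code)


# Traffic light 1 in the sample above has state 0.
print(decode_state(0), decode_state(6), decode_state(-1))
```

The fallback branch is a design choice for values outside the table, such as a padding value, rather than documented dataset behavior.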

## Standard and Interactive Datasets
The standard and interactive training sets are the same.

The standard validation and test sets provide up to 8 objects to predict in each scene. Selection is biased to require objects that do not follow a constant-velocity model or straight paths. The interactive versions of the validation and test sets focus on the interactive portion of the segment and require only the 2 mined interactive objects to be predicted.
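
The two selection schemes can be sketched as a per-slot flag lookup. This assumes that `state/tracks_to_predict` marks the standard-set targets with 1 (the notebook parses this field but does not print it, so the encoding is an assumption) and that `state/objects_of_interest` marks the interactive pair with 1, as in the tensors printed earlier. `indices_to_predict` and the flag lists are hypothetical.

```python
def indices_to_predict(flags, max_objects):
  """Returns the slot indices flagged with 1, checking an expected limit.

  Args:
    flags: per-slot integers; 1 marks an object to predict, 0 another object,
      and -1 a padding slot (assumed encoding).
    max_objects: expected upper bound, e.g. 8 (standard) or 2 (interactive).
  """
  idx = [i for i, flag in enumerate(flags) if flag == 1]
  if len(idx) > max_objects:
    raise ValueError(
        'Expected at most %d flagged objects, got %d' % (max_objects, len(idx)))
  return idx


# Hypothetical flags for 8 slots; real examples have 128 slots.
tracks_to_predict = [0, 1, 1, 0, 1, 0, -1, -1]
objects_of_interest = [0, 1, 0, 1, 0, 0, -1, -1]

standard_targets = indices_to_predict(tracks_to_predict, max_objects=8)
interactive_targets = indices_to_predict(objects_of_interest, max_objects=2)
print(standard_targets, interactive_targets)
```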