In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # must be set before importing tensorflow; '1' filters out INFO-level native log messages (warnings and errors are still shown)

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import utils

In [2]:
# Feature specification needed to deserialize the dataset: it is passed as the
# `features` argument of tf.io.parse_single_example / tf.io.parse_example in the
# parse functions below, and its keys become the keys of each decoded example.
features_description = utils.get_features_description()

In [3]:
def parse_example_basic(value):
    """Deserialize one serialized record into a single scene.

    The result is a dictionary whose keys are the ones defined in
    ``features_description`` (e.g. "state/current/x", "state/current/y", etc.).
    """
    return tf.io.parse_single_example(serialized=value, features=features_description)

# TFRecord shards read by this notebook
FILES = [
    "./data/training_tfexample.tfrecord-00000-of-01000",
    "./data/training_tfexample.tfrecord-00001-of-01000",
]

# Build the dataset: read the serialized records, then decode each one with the basic parser
dataset = tf.data.TFRecordDataset(FILES).map(parse_example_basic)


# Look only at the first scene and report the shape and type of every entry
for example in dataset.take(1):
    for key, value in example.items():
        print(f"Key: {key}", f"Shape: {value.shape}", f"Type: {type(value)}", "", sep="\n")

Key: roadgraph_samples/dir
Shape: (20000, 3)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: roadgraph_samples/id
Shape: (20000, 1)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: roadgraph_samples/type
Shape: (20000, 1)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: roadgraph_samples/valid
Shape: (20000, 1)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: roadgraph_samples/xyz
Shape: (20000, 3)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: scenario/id
Shape: (1,)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: state/current/bbox_yaw
Shape: (128, 1)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: state/current/height
Shape: (128, 1)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: state/current/length
Shape: (128, 1)
Type: <class 'tensorflow.python.framework.ops.EagerTensor'>

Key: state/current/speed
Shape: (128, 1)
Type: <class 'tensorflow.p

Note that these values are TensorFlow tensors, which can be cumbersome to work with depending on the task. You can get NumPy arrays directly by iterating over the dataset with `as_numpy_iterator()`:

In [4]:
# Same inspection as before, but every value is converted to a NumPy array
for example in dataset.take(1).as_numpy_iterator():
    for key, value in example.items():
        print(f"Key: {key}", f"Shape: {value.shape}", f"Type: {type(value)}", "", sep="\n")

Key: roadgraph_samples/dir
Shape: (20000, 3)
Type: <class 'numpy.ndarray'>

Key: roadgraph_samples/id
Shape: (20000, 1)
Type: <class 'numpy.ndarray'>

Key: roadgraph_samples/type
Shape: (20000, 1)
Type: <class 'numpy.ndarray'>

Key: roadgraph_samples/valid
Shape: (20000, 1)
Type: <class 'numpy.ndarray'>

Key: roadgraph_samples/xyz
Shape: (20000, 3)
Type: <class 'numpy.ndarray'>

Key: scenario/id
Shape: (1,)
Type: <class 'numpy.ndarray'>

Key: state/current/bbox_yaw
Shape: (128, 1)
Type: <class 'numpy.ndarray'>

Key: state/current/height
Shape: (128, 1)
Type: <class 'numpy.ndarray'>

Key: state/current/length
Shape: (128, 1)
Type: <class 'numpy.ndarray'>

Key: state/current/speed
Shape: (128, 1)
Type: <class 'numpy.ndarray'>

Key: state/current/timestamp_micros
Shape: (128, 1)
Type: <class 'numpy.ndarray'>

Key: state/current/valid
Shape: (128, 1)
Type: <class 'numpy.ndarray'>

Key: state/current/vel_yaw
Shape: (128, 1)
Type: <class 'numpy.ndarray'>

Key: state/current/velocity_x
Shape:

### Modifying data while loading
You might want to load the data in a specific format; you can do this in the parse function. For simplicity, let's load only the x and y positions of the road users (RUs), together with other relevant scene info (e.g. the roadgraph).

In [5]:
state_features = ["x", "y"]

def _parse_states(decoded_example):
    """Put the past, current and future states of each feature together.

    The time axis is addressed relative to the trailing dimensions rather than
    by an absolute index, so the same code handles a single decoded scene and
    a decoded batch of scenes (which has one extra leading dimension).

    Args:
        decoded_example: mapping that contains, for each period in
            ("past", "current", "future"), the entries
            'state/<period>/<feature>' for every feature in ``state_features``
            and 'state/<period>/valid'. Each entry is assumed to be laid out
            as (..., agents, timesteps) -- TODO confirm against
            ``features_description``.

    Returns:
        A tuple ``(states, states_is_valid, sample_is_valid)``:
        states: the whole trajectory, (..., agents, timesteps, features),
            e.g. (128, 91, 2) for a single scene.
        states_is_valid: boolean mask per agent and timestep,
            (..., agents, timesteps).
        sample_is_valid: boolean mask per agent, (..., agents); True when the
            agent is valid in at least one past or current timestep.
    """
    per_period_states = []
    for period in ("past", "current", "future"):
        # "valid" is handled separately as a mask below, never as a feature
        features = [decoded_example['state/{}/{}'.format(period, feature)]
                    for feature in state_features if feature != "valid"]
        # features become the new last axis: (..., agents, timesteps, features)
        per_period_states.append(tf.stack(features, -1))

    # entire trajectory ground truth; after stacking, time is the second-to-last axis
    states = tf.concat(per_period_states, axis=-2)

    past_is_valid = decoded_example['state/past/valid'] > 0
    current_is_valid = decoded_example['state/current/valid'] > 0
    future_is_valid = decoded_example['state/future/valid'] > 0

    # the masks have no feature axis, so time is their last axis
    states_is_valid = tf.concat([past_is_valid, current_is_valid, future_is_valid], axis=-1)

    # If a sample was not seen at all in the past or current timesteps, we declare it invalid.
    sample_is_valid = tf.reduce_any(tf.concat([past_is_valid, current_is_valid], axis=-1), axis=-1)

    return states, states_is_valid, sample_is_valid

def parse_examples(examples):
    """Decode one serialized scene and keep only the fields of interest.

    Returns a dictionary holding the stacked agent states, their validity
    masks, the scenario id and the roadgraph tensors; this dictionary is what
    the dataset yields while loading the data.
    """
    decoded = tf.io.parse_single_example(serialized=examples, features=features_description)
    trajectories, step_mask, agent_mask = _parse_states(decoded)

    # agent-related entries first ...
    parsed_scene = {
        'states': trajectories,
        'states_is_valid': step_mask,
        'scenario_id': decoded['scenario/id'],
        'sample_is_valid': agent_mask,
    }

    # ... followed by the roadgraph tensors, passed through unchanged
    parsed_scene['rg_xyz'] = decoded["roadgraph_samples/xyz"]
    parsed_scene['rg_dir'] = decoded["roadgraph_samples/dir"]
    parsed_scene['rg_type'] = decoded["roadgraph_samples/type"]
    parsed_scene['rg_valid'] = decoded["roadgraph_samples/valid"]
    parsed_scene['rg_ids'] = decoded["roadgraph_samples/id"]

    return parsed_scene


# Rebuild the dataset so that every scene goes through the custom parse function
dataset = tf.data.TFRecordDataset(FILES).map(parse_examples)


# Report the shape and type of every entry of the first parsed scene
for example in dataset.take(1).as_numpy_iterator():
    for key, value in example.items():
        print(f"Key: {key}", f"Shape: {value.shape}", f"Type: {type(value)}", "", sep="\n")

Key: states
Shape: (128, 91, 2)
Type: <class 'numpy.ndarray'>

Key: states_is_valid
Shape: (128, 91)
Type: <class 'numpy.ndarray'>

Key: scenario_id
Shape: (1,)
Type: <class 'numpy.ndarray'>

Key: sample_is_valid
Shape: (128,)
Type: <class 'numpy.ndarray'>

Key: rg_xyz
Shape: (20000, 3)
Type: <class 'numpy.ndarray'>

Key: rg_dir
Shape: (20000, 3)
Type: <class 'numpy.ndarray'>

Key: rg_type
Shape: (20000, 1)
Type: <class 'numpy.ndarray'>

Key: rg_valid
Shape: (20000, 1)
Type: <class 'numpy.ndarray'>

Key: rg_ids
Shape: (20000, 1)
Type: <class 'numpy.ndarray'>



Note that the shape of `states` is now (128, 91, 2) = (#agents, #timesteps, #features)

### More efficient loading: batching, prefetching...etc
Loading scenes one by one is extremely slow; it is much more efficient to load the data and apply operations in batches. There are also some additional parameters to optimize the whole process (more details to come).

In [6]:
state_features = ["x", "y"]

def _parse_states(decoded_examples):
    """Put the past, current and future states of each feature together.

    The time axis is addressed relative to the trailing dimensions rather than
    by an absolute index, so the function works on a decoded batch of scenes
    (with a leading batch dimension) as well as on a single decoded scene.

    Args:
        decoded_examples: mapping that contains, for each period in
            ("past", "current", "future"), the entries
            'state/<period>/<feature>' for every feature in ``state_features``
            and 'state/<period>/valid'. Each entry is assumed to be laid out
            as (..., agents, timesteps) -- TODO confirm against
            ``features_description``.

    Returns:
        A tuple ``(states, states_is_valid, sample_is_valid)``:
        states: the whole trajectory, (..., agents, timesteps, features),
            e.g. (16, 128, 91, 2) for a batch of 16 scenes.
        states_is_valid: boolean mask per agent and timestep,
            (..., agents, timesteps).
        sample_is_valid: boolean mask per agent, (..., agents); True when the
            agent is valid in at least one past or current timestep.
    """
    per_period_states = []
    for period in ("past", "current", "future"):
        # "valid" is handled separately as a mask below, never as a feature
        features = [decoded_examples['state/{}/{}'.format(period, feature)]
                    for feature in state_features if feature != "valid"]
        # features become the new last axis: (..., agents, timesteps, features)
        per_period_states.append(tf.stack(features, -1))

    # entire trajectory ground truth; after stacking, time is the second-to-last axis
    states = tf.concat(per_period_states, axis=-2)

    past_is_valid = decoded_examples['state/past/valid'] > 0
    current_is_valid = decoded_examples['state/current/valid'] > 0
    future_is_valid = decoded_examples['state/future/valid'] > 0

    # the masks have no feature axis, so time is their last axis
    states_is_valid = tf.concat([past_is_valid, current_is_valid, future_is_valid], axis=-1)

    # If a sample was not seen at all in the past or current timesteps, we declare it invalid.
    sample_is_valid = tf.reduce_any(tf.concat([past_is_valid, current_is_valid], axis=-1), axis=-1)

    return states, states_is_valid, sample_is_valid

def parse_examples(examples):
    """Decode a batch of serialized scenes and keep only the fields of interest.

    Returns a dictionary holding the stacked agent states, their validity
    masks, the scenario ids and the roadgraph tensors; this dictionary is what
    the dataset yields while loading the data.
    """
    # !!! note: tf.io.parse_example (batched input) rather than tf.io.parse_single_example
    decoded = tf.io.parse_example(serialized=examples, features=features_description)
    trajectories, step_mask, agent_mask = _parse_states(decoded)

    # agent-related entries first ...
    parsed_batch = {
        'states': trajectories,
        'states_is_valid': step_mask,
        'scenario_id': decoded['scenario/id'],
        'sample_is_valid': agent_mask,
    }

    # ... followed by the roadgraph tensors, passed through unchanged
    parsed_batch['rg_xyz'] = decoded["roadgraph_samples/xyz"]
    parsed_batch['rg_dir'] = decoded["roadgraph_samples/dir"]
    parsed_batch['rg_type'] = decoded["roadgraph_samples/type"]
    parsed_batch['rg_valid'] = decoded["roadgraph_samples/valid"]
    parsed_batch['rg_ids'] = decoded["roadgraph_samples/id"]

    return parsed_batch

# Parameters for optimizing loading time: both are left to tf.data's AUTOTUNE setting
num_parallel_calls = tf.data.AUTOTUNE
prefetch_size = tf.data.AUTOTUNE

# Adjust depending on how much RAM you have. The author usually uses 128, but it
# depends on how many features you load and the type of operations you do.
batch_size = 16

# The raw serialized records are batched first, so parse_examples receives a whole batch per call
dataset = (
    tf.data.TFRecordDataset(FILES)
    .batch(batch_size)
    .map(parse_examples, num_parallel_calls=num_parallel_calls)
    .prefetch(prefetch_size)
)


# Report the shape and type of every entry of the first batch
for batch in dataset.take(1).as_numpy_iterator():
    for key, value in batch.items():
        print(f"Key: {key}", f"Shape: {value.shape}", f"Type: {type(value)}", "", sep="\n")

Key: states
Shape: (16, 128, 91, 2)
Type: <class 'numpy.ndarray'>

Key: states_is_valid
Shape: (16, 128, 91)
Type: <class 'numpy.ndarray'>

Key: scenario_id
Shape: (16, 1)
Type: <class 'numpy.ndarray'>

Key: sample_is_valid
Shape: (16, 128)
Type: <class 'numpy.ndarray'>

Key: rg_xyz
Shape: (16, 20000, 3)
Type: <class 'numpy.ndarray'>

Key: rg_dir
Shape: (16, 20000, 3)
Type: <class 'numpy.ndarray'>

Key: rg_type
Shape: (16, 20000, 1)
Type: <class 'numpy.ndarray'>

Key: rg_valid
Shape: (16, 20000, 1)
Type: <class 'numpy.ndarray'>

Key: rg_ids
Shape: (16, 20000, 1)
Type: <class 'numpy.ndarray'>



Note the extra leading dimension (the batch size, 16 here), since we now load multiple scenes at once.