## Create index file

In [None]:
import glob
import os
import struct
import warnings

# Select protobuf's "upb" backend unless the environment already picked one.
# NOTE(review): presumably this has to run before protobuf is first imported
# (the waymo_open_dataset import in a later cell) — confirm.
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "upb")

DATA_ROOT = '/path/to/your/tfrecords'
INDEX_FILE = 'index.idx'
BATCH_SIZE = 32
NUM_CAMS = 8


def write_tfrecord_index(file_paths, index_path):
    """Scan TFRecord shards and write a flat binary index of their payloads.

    Each record in a shard is framed as: 8-byte little-endian payload length,
    4-byte length checksum, the payload itself, 4-byte payload checksum.
    Checksums are skipped, not verified.

    For every record one 24-byte entry ``<QQQ`` is appended to the index:
    ``(file_idx, payload_offset, payload_size)`` where ``file_idx`` is the
    position of the shard in ``file_paths`` and ``payload_offset`` is the
    absolute byte offset of the payload inside that shard.

    The index is first written to ``index_path + '.tmp'`` and only renamed
    into place once every shard was scanned, so an interrupted or failed build
    never leaves a partial index behind at ``index_path``.

    Args:
        file_paths: Sequence of shard paths, in the order ``file_idx`` refers to.
        index_path: Destination path of the index file.

    Returns:
        The number of records indexed.

    Raises:
        ValueError: If a shard ends in the middle of a record.
        OSError: If a shard or the index file cannot be read or written.
    """
    length_bytes = 8   # size of the little-endian uint64 length field
    crc_bytes = 4      # size of each checksum field
    entry = struct.Struct('<QQQ')
    tmp_path = f"{index_path}.tmp"
    num_records = 0

    try:
        with open(tmp_path, 'wb') as index_f:
            for file_idx, file_path in enumerate(file_paths):
                shard_size = os.path.getsize(file_path)
                with open(file_path, 'rb') as f:
                    while True:
                        record_start = f.tell()
                        header = f.read(length_bytes)
                        if not header:
                            break  # clean EOF
                        if len(header) != length_bytes:
                            raise ValueError(
                                f"Truncated record header in {file_path!r} "
                                f"at byte {record_start}"
                            )

                        (proto_size,) = struct.unpack('<Q', header)
                        # Payload starts after the length field and its checksum.
                        payload_offset = record_start + length_bytes + crc_bytes
                        # Record ends after the payload and its checksum.
                        record_end = payload_offset + proto_size + crc_bytes
                        if record_end > shard_size:
                            raise ValueError(
                                f"Truncated record in {file_path!r} at byte "
                                f"{record_start}: needs {record_end} bytes, "
                                f"file has {shard_size}"
                            )

                        index_f.write(entry.pack(file_idx, payload_offset, proto_size))
                        f.seek(record_end)
                        num_records += 1

        # Atomic publish: readers see either no index or a complete one.
        os.replace(tmp_path, index_path)
    finally:
        # Only still present when the build did not complete.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return num_records


train_file_paths = sorted(glob.glob(os.path.join(DATA_ROOT, 'training_*.tfrecord*')))

# The index is a cache: it is only built when it does not exist yet.
if not os.path.exists(INDEX_FILE):
    if train_file_paths:
        write_tfrecord_index(train_file_paths, INDEX_FILE)
    else:
        # Do not cache an empty index for a wrong/placeholder DATA_ROOT.
        warnings.warn(
            f"No TFRecord shards matched 'training_*.tfrecord*' under "
            f"{DATA_ROOT!r}; {INDEX_FILE!r} was not written."
        )


## Load index

In [38]:
# Read the index back into memory as a list of (file_idx, offset, size)
# tuples. Every entry is three consecutive little-endian uint64 values,
# i.e. 24 bytes.
index = []
with open(INDEX_FILE, 'rb') as index_f:
    # iter(callable, sentinel): keep reading 24-byte entries until read()
    # returns b'' at EOF.
    for raw_entry in iter(lambda: index_f.read(24), b''):
        index.append(struct.unpack('<QQQ', raw_entry))


## Set up dataset

In [None]:
import glob
import os
import numpy as np

from torch.utils.data import IterableDataset, get_worker_info
from waymo_open_dataset.protos import end_to_end_driving_data_pb2

class WaymoDataset(IterableDataset):
    """Iterable dataset that reads serialized ``E2EDFrame`` records via a byte-offset index.

    Every index entry is a ``(file_idx, offset, size)`` triple: ``file_idx``
    selects a shard from the sorted ``training_*.tfrecord*`` files under
    ``data_root``, and ``offset``/``size`` give the byte range of one
    serialized record inside that shard.

    Iterating yields ``(images, future_states)`` per record:

    * ``images``: list with one 1-D ``uint8`` array per entry of
      ``record.frame.images``, holding the raw bytes of ``img.image``
      (no decoding is done here).
    * ``future_states``: ``float32`` array of shape ``[T, 3]`` with columns
      ``(pos_x, pos_y, pos_z)``, or shape ``(0, 3)`` when ``pos_x`` is empty.

    The shard currently being read is kept open in ``self.file`` and is
    closed when iteration finishes, fails, or is abandoned.
    """

    def __init__(self, data_root, index):
        """Store the index and resolve the shard list.

        Args:
            data_root: Directory containing the ``training_*.tfrecord*`` shards.
            index: Iterable of ``(file_idx, offset, size)`` triples. The
                ``file_idx`` values must refer to the same sorted shard order
                that is resolved here.
        """
        self.data_root = data_root
        self.index = index
        self.file_paths = sorted(glob.glob(os.path.join(data_root, 'training_*.tfrecord*')))
        # Handle and path of the shard currently open for reading (if any).
        self.file = None
        self.file_path = None

    def _shard_index(self):
        """Return the index entries this process is responsible for.

        Outside a DataLoader worker (``get_worker_info()`` is ``None``) the
        full index is returned. Inside a worker, entries are dealt out
        round-robin: worker ``id`` gets entries ``id, id + num_workers, ...``.
        """
        wi = get_worker_info()
        if wi is None:
            return self.index

        return [rec for i, rec in enumerate(self.index) if i % wi.num_workers == wi.id]

    def _close_file(self):
        """Close the currently open shard (if any) and forget its path."""
        if self.file is not None:
            self.file.close()
        self.file = None
        self.file_path = None

    def __iter__(self):
        """Yield ``(images, future_states)`` for every entry of this process's shard.

        Raises:
            EOFError: If an index entry points past the end of its shard
                (e.g. the index is stale relative to the files on disk).
            IndexError: If ``file_idx`` does not exist in ``self.file_paths``.
        """
        try:
            for file_idx, offset, size in self._shard_index():
                file_path = self.file_paths[file_idx]
                if file_path != self.file_path:
                    # Switch shards: release the old handle first, and only
                    # record the new path once the open succeeded.
                    self._close_file()
                    self.file = open(file_path, 'rb')
                    self.file_path = file_path

                self.file.seek(offset)
                proto_raw = self.file.read(size)
                if len(proto_raw) != size:
                    raise EOFError(
                        f"Index entry (file={file_path!r}, offset={offset}, size={size}) "
                        f"runs past end of file: only {len(proto_raw)} bytes read"
                    )

                record = end_to_end_driving_data_pb2.E2EDFrame()
                record.ParseFromString(proto_raw)

                images = [np.frombuffer(img.image, dtype=np.uint8) for img in record.frame.images]

                fs = record.future_states
                if fs.pos_x:
                    px = np.asarray(fs.pos_x, dtype=np.float32)
                    py = np.asarray(fs.pos_y, dtype=np.float32)
                    pz = np.asarray(fs.pos_z, dtype=np.float32)
                    future_states = np.stack([px, py, pz], axis=-1)   # [T, 3]
                else:
                    future_states = np.empty((0, 3), dtype=np.float32)

                yield images, future_states
        finally:
            # Runs on exhaustion, on error, and when the generator is closed
            # early, so no file handle outlives the iteration.
            self._close_file()

## Set up pipeline

In [40]:
from nvidia.dali import pipeline_def, types, fn

def waymo_generator(dataset):
    """Flatten each dataset sample into one tuple: all image buffers, then future states.

    Args:
        dataset: Iterable of ``(images_bufs, future_states)`` pairs.

    Yields:
        ``(*images_bufs, future_states)``, i.e. ``NUM_CAMS + 1`` items.

    Raises:
        ValueError: If a sample does not carry exactly ``NUM_CAMS`` image buffers.
    """
    for sample in dataset:
        images_bufs, future_states = sample
        num_images = len(images_bufs)
        if num_images != NUM_CAMS:
            raise ValueError(f"Expected {NUM_CAMS} images, but got {num_images}")
        yield tuple(images_bufs) + (future_states,)


@pipeline_def
def waymo_data_pipe(source):
    """DALI pipeline: fetch per-sample buffers from ``source`` and decode the images.

    ``source`` is consumed sample by sample (``batch=False``) and must provide
    ``NUM_CAMS + 1`` outputs per sample: ``NUM_CAMS`` image buffers followed
    by the future states.

    Returns:
        Tuple of ``NUM_CAMS`` RGB images decoded with the ``"mixed"`` decoder,
        followed by the future states passed through unchanged.
    """
    *encoded_images, future_states = fn.external_source(
        source=source,
        num_outputs=NUM_CAMS + 1,
        batch=False,
        device="cpu",
        no_copy=False
    )

    # One decoder per camera stream.
    decoded_images = []
    for encoded in encoded_images:
        decoded = fn.decoders.image(
            encoded,
            device="mixed",
            output_type=types.RGB,
            use_fast_idct=True,
        )
        decoded_images.append(decoded)

    return (*decoded_images, future_states)


In [41]:
import time
from tqdm import tqdm

# Throughput benchmark: run the pipeline for as many full batches as the
# index provides and report batches per second.
dataset = WaymoDataset(DATA_ROOT, index)

pipe = waymo_data_pipe(
    source=waymo_generator(dataset),
    batch_size=BATCH_SIZE,
    num_threads=8,
    device_id=0,
    prefetch_queue_depth=3
)

pipe.build()

num_batches = len(index) // BATCH_SIZE

# Count the batches that actually ran so an early stop does not inflate
# the reported rate.
completed_batches = 0

# perf_counter is monotonic, so the measurement is immune to wall-clock jumps.
start = time.perf_counter()
for _ in tqdm(range(num_batches)):
    try:
        outputs = pipe.run()
    except StopIteration:
        print("Stop Iteration!")
        break
    # Drop the batch right away; only throughput is being measured.
    del outputs
    completed_batches += 1

end = time.perf_counter()
elapsed = end - start
rate = completed_batches / elapsed if elapsed > 0 else float('nan')
print(f"iterations per second: {rate}")

100%|██████████| 246/246 [00:19<00:00, 12.43it/s]

iterations per second: 12.432738168611408



