In [1]:
from typing import Tuple
import matplotlib.pyplot as plt
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
import os
import math
import numpy as np
import cv2
from waymo_open_dataset import dataset_pb2 as open_dataset
from waymo_open_dataset.wdl_limited.camera.ops import py_camera_model_ops

from waymo_open_dataset.protos import end_to_end_driving_data_pb2 as wod_e2ed_pb2
from waymo_open_dataset.protos import end_to_end_driving_submission_pb2 as wod_e2ed_submission_pb2
# Replace these paths with your own locations.
# This tutorial is based on using data in the E2E Driving proto format directly,
# so choose the correct dataset version.
DATASET_FOLDER = '/home/hansung/end2end_ad/datasets/waymo_open_dataset_end_to_end_camera_v_1_0_0'  # Raw data tfrecords directory. Modify
OUTPUT_DIR = "/home/hansung/OpenEMMA/waymo_dataset"  # Modify

TRAIN_FILES = os.path.join(DATASET_FOLDER, "training_*.tfrecord-*")
VALIDATION_FILES = os.path.join(DATASET_FOLDER, "val_*.tfrecord-*")
TEST_FILES = os.path.join(DATASET_FOLDER, "test_*.tfrecord-*")

# Split to preprocess. The value is also used as the output sub-directory name
# by the export cell, so an unknown value must not silently map to some split.
dataset_mode = 'val'  # one of: 'val', 'testing' (alias 'test'), 'training'

_FILE_PATTERN_BY_MODE = {
    'val': VALIDATION_FILES,
    'testing': TEST_FILES,
    'test': TEST_FILES,
    'training': TRAIN_FILES,
}
if dataset_mode not in _FILE_PATTERN_BY_MODE:
    raise ValueError(
        f"Unknown dataset_mode {dataset_mode!r}; expected one of {sorted(_FILE_PATTERN_BY_MODE)}"
    )

filenames = tf.io.matching_files(_FILE_PATTERN_BY_MODE[dataset_mode])
# Fail early instead of iterating over an empty dataset and producing no output.
if int(tf.size(filenames)) == 0:
    raise FileNotFoundError(
        f"No tfrecord files match {_FILE_PATTERN_BY_MODE[dataset_mode]!r}; check DATASET_FOLDER"
    )

dataset = tf.data.TFRecordDataset(filenames, compression_type='')
# One-shot iterator: it is exhausted after a single pass, re-run this cell to restart.
dataset_iter = dataset.as_numpy_iterator()



2025-12-23 18:47:03.292851: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-23 18:47:03.318009: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-12-23 18:47:04.565450: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-12-23 18

In [2]:
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from einops import rearrange
import cv2
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')

def yaw_to_rotmat_z(yaw: torch.Tensor) -> torch.Tensor:
    """Build rotation matrices about the z axis from yaw angles.

    The input is flattened first, so any tensor holding T angles is accepted.

    Returns:
      (T,3,3) float32 tensor on the same device as ``yaw``.
    """
    angles = yaw.reshape(-1)
    cos_a, sin_a = torch.cos(angles), torch.sin(angles)
    zero = torch.zeros_like(cos_a)
    one = torch.ones_like(cos_a)

    # Assemble each matrix row as a (T,3) tensor, then stack rows along dim 1.
    rows = [
        torch.stack([cos_a, -sin_a, zero], dim=-1),
        torch.stack([sin_a, cos_a, zero], dim=-1),
        torch.stack([zero, zero, one], dim=-1),
    ]
    return torch.stack(rows, dim=1).to(torch.float32)

def _states_to_world_xyz_and_yaw(states_obj: Any) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns:
      p_world: (T,3) float32
      yaw:     (T,) float32   (estimated if missing)
    """
    xs = getattr(states_obj, "pos_x", None)
    ys = getattr(states_obj, "pos_y", None)
    zs = getattr(states_obj, "pos_z", None)

    if xs is None or ys is None:
        return torch.zeros((0, 3), dtype=torch.float32), torch.zeros((0,), dtype=torch.float32)

    x = torch.as_tensor(list(xs), dtype=torch.float32)
    y = torch.as_tensor(list(ys), dtype=torch.float32)
    T = int(x.numel())
    if T == 0:
        return torch.zeros((0, 3), dtype=torch.float32), torch.zeros((0,), dtype=torch.float32)

    if zs is None or len(zs) == 0:
        z = torch.zeros((T,), dtype=torch.float32)
    else:
        z = torch.as_tensor(list(zs), dtype=torch.float32)
        if z.numel() != T:
            z = torch.zeros((T,), dtype=torch.float32)

    p_world = torch.stack([x, y, z], dim=-1)  # (T,3)

    # Prefer provided yaw/heading if it exists
    yaw = None
    if hasattr(states_obj, "yaw"):
        yaw = torch.as_tensor(list(states_obj.yaw), dtype=torch.float32)
    elif hasattr(states_obj, "heading"):
        yaw = torch.as_tensor(list(states_obj.heading), dtype=torch.float32)

    # Otherwise estimate yaw from dx,dy
    if yaw is None or yaw.numel() != T:
        dx = torch.diff(x, prepend=x[:1])
        dy = torch.diff(y, prepend=y[:1])
        speed2 = dx * dx + dy * dy
        yaw = torch.atan2(dy, dx)
        yaw = torch.where(speed2 < 1e-6, torch.zeros_like(yaw), yaw)
        yaw[0] = yaw[1] if T > 1 else 0.0

    return p_world, yaw

def _world_to_local(
    p_world: torch.Tensor,
    yaw_world: torch.Tensor,
    p0: torch.Tensor,
    yaw0: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Express world-frame positions and yaw rotations relative to the
    reference pose (p0, yaw0).

    Returns:
      xyz_local: (T,3), computed as R0^T (p_world[t] - p0)
      rot_local: (T,3,3), computed as R0^T * R_world[t]
    """
    # Inverse of the reference rotation; the transpose suffices for a rotation matrix.
    ref_rot_inv = yaw_to_rotmat_z(yaw0[None])[0].t()   # (3,3)
    world_rots = yaw_to_rotmat_z(yaw_world)            # (T,3,3)
    offsets = p_world - p0                             # (T,3)

    local_rots = torch.einsum("ij,tjk->tik", ref_rot_inv, world_rots)
    local_xyz = torch.einsum("ij,tj->ti", ref_rot_inv, offsets)
    return local_xyz, local_rots

# Label for each integer value of the `intent` field; the tuple position is the key.
INTENT = dict(enumerate((
    "UNKNOWN",
    "GO_STRAIGHT",
    "GO_LEFT",
    "GO_RIGHT",
)))

def load_waymo_e2e_data(
    e2e_data: Any,
    time_step: float = 0.25,
    camera_features: Optional[List[Union[str, int]]] = None,
    resize_mode: str = "min_hw",                 # "min_hw" or "fixed"
    fixed_hw: Tuple[int, int] = (576, 1024),     # used if resize_mode=="fixed"
) -> Dict[str, Any]:
    """
    Convert one parsed E2E frame message into a plain dict of arrays.

    Args:
      e2e_data: parsed frame message exposing ``past_states``, ``future_states``,
        ``frame`` and ``intent``.
      time_step: not used by this function; kept so existing positional
        callers keep working.
      camera_features: cameras to keep. ``None`` or ``['ALL']`` keeps every
        image in the frame; otherwise an image is kept when its ``name`` value
        or the ``str()`` of that value is in the list.
      resize_mode: ``"min_hw"`` resizes every image to the smallest height and
        smallest width found among the kept images; ``"fixed"`` resizes to
        ``fixed_hw``.
      fixed_hw: (height, width) used when ``resize_mode == "fixed"``.

    Returns:
      dict with keys
        "image_frames":     torch uint8 tensor (Ncam,1,3,H,W), cameras sorted by ``name``
        "ego_history_xyz":  np.ndarray (Th,3); x/y as stored in past_states, z set to 0
        "ego_future_xyz":   np.ndarray (Tf,3); x/y/z as stored in future_states
        "ego_intent":       label from INTENT ("UNKNOWN" for unmapped values)
        "camera_intrinsic": list of (K, dist) per camera calibration; K is 3x3,
                            dist holds intrinsic entries 4..8
        "camera_extrinsic": list of 4x4 arrays, one per camera calibration
        "vehicle_pose":     4x4 array taken from the first image of the frame
        "speed":            norm of the first image's velocity (v_x, v_y, v_z)

      NOTE: the calibration lists follow the order of
      ``frame.context.camera_calibrations`` and are not filtered by
      ``camera_features``; they are not guaranteed to line up with the camera
      order of "image_frames".

    Raises:
      RuntimeError: if past_states is empty or no image remains after filtering.
      ValueError: if ``resize_mode`` is not recognised.
    """
    # --------------------
    # States
    # --------------------
    past = e2e_data.past_states
    future = e2e_data.future_states
    e2e_frame = e2e_data.frame

    num_past = len(past.pos_x)
    if num_past == 0:
        raise RuntimeError("past_states is empty; cannot define t0.")

    # Positions are returned exactly as stored in the message (no re-localization).
    ego_history_xyz = np.stack([past.pos_x, past.pos_y, [0] * num_past], axis=1)

    # Tolerate a missing/mismatched pos_z the same way the state helper does:
    # fall back to zeros instead of failing inside np.stack.
    future_z = list(future.pos_z)
    if len(future_z) != len(future.pos_x):
        future_z = [0.0] * len(future.pos_x)
    ego_future_xyz = np.stack([future.pos_x, future.pos_y, future_z], axis=1)

    # --------------------
    # Images
    # --------------------
    imgs = list(e2e_frame.images)

    if camera_features is not None and camera_features != ["ALL"]:
        keep = set(camera_features)
        imgs = [im for im in imgs if (im.name in keep) or (str(im.name) in keep)]

    # Sort cameras deterministically by camera enum id
    imgs = sorted(imgs, key=lambda im: int(im.name))

    # Decode every image once so the target size can be derived before resizing.
    decoded = [tf.io.decode_image(im.image, channels=3).numpy() for im in imgs]  # each (H,W,3) uint8

    if len(decoded) == 0:
        raise RuntimeError("No decodable images in e2e_frame.images")

    if resize_mode == "min_hw":
        target_h = min(arr.shape[0] for arr in decoded)
        target_w = min(arr.shape[1] for arr in decoded)
    elif resize_mode == "fixed":
        target_h, target_w = fixed_hw
    else:
        raise ValueError(f"Unknown resize_mode: {resize_mode}")

    image_frames_list = []
    for arr in decoded:
        h, w = arr.shape[:2]
        if (h != target_h) or (w != target_w):
            # cv2.resize takes the size as (width, height).
            arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_AREA)

        frames_tensor = torch.from_numpy(arr).unsqueeze(0)  # (1,H,W,3)
        frames_tensor = rearrange(frames_tensor, "t h w c -> t c h w").contiguous()  # (1,3,H,W)
        image_frames_list.append(frames_tensor)

    image_frames = torch.stack(image_frames_list, dim=0)  # (Ncam,1,3,H,W)

    # --------------------
    # Calibrations
    # --------------------
    camera_intrinsics = []
    camera_extrinsics = []
    for cam_calibration in e2e_frame.context.camera_calibrations:
        intr = np.array(cam_calibration.intrinsic, dtype=np.float64)

        # First four entries: focal lengths and principal point.
        f_u, f_v, c_u, c_v = intr[:4]
        K = np.array([
            [f_u, 0., c_u],
            [0., f_v, c_v],
            [0., 0., 1.]
        ], dtype=np.float64)

        # Entries 4..8 are stored as-is. Presumably the distortion coefficients
        # in the dataset's own ordering -- TODO confirm against dataset.proto.
        dist = np.array([intr[4], intr[5], intr[6], intr[7], intr[8]], dtype=np.float64)
        camera_intrinsics.append((K, dist))

        extrinsic = np.array(cam_calibration.extrinsic.transform).reshape(4, 4)
        camera_extrinsics.append(extrinsic)

    # --------------------
    # Scalars
    # --------------------
    # Unmapped intent values fall back to "UNKNOWN" rather than raising KeyError.
    ego_intent = INTENT.get(e2e_data.intent, INTENT[0])

    # Pose and velocity are read from the first image of the (unfiltered) frame.
    first_image = e2e_frame.images[0]
    vehicle_pose = np.array(first_image.pose.transform).reshape(4, 4)
    velocity = first_image.velocity
    speed = np.sqrt((velocity.v_x)**2 + (velocity.v_y)**2 + (velocity.v_z)**2)

    return {
        "image_frames": image_frames,                   # (Ncam,1,3,H,W) uint8
        "ego_history_xyz": ego_history_xyz,             # (Th,3)
        "ego_future_xyz": ego_future_xyz,               # (Tf,3)
        "ego_intent": ego_intent,
        "camera_intrinsic": camera_intrinsics,
        "camera_extrinsic": camera_extrinsics,
        "vehicle_pose": vehicle_pose,
        "speed": speed
    }


In [None]:
#save data_dict
import pickle
import gzip
import tqdm
def save_dict_to_pickle(d: dict, path: str) -> None:
    """Pickle ``d`` to ``path``, creating missing parent directories first.

    Args:
      d: dictionary to serialize.
      path: destination file path; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        # Without this, open() fails with FileNotFoundError on a fresh output tree.
        os.makedirs(parent_dir, exist_ok=True)

    # Use highest protocol for speed/size; write in binary mode
    with open(path, "wb") as f:
        pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
# Output directory for the selected split. The trailing separator is kept so the
# value is identical to the previous `OUTPUT_DIR + '/' + dataset_mode + '/'`.
save_path = os.path.join(OUTPUT_DIR, dataset_mode, '')
# The directory is not created anywhere else; make sure it exists before writing.
os.makedirs(save_path, exist_ok=True)
ctr = 0  # NOTE(review): not used in this cell; kept in case other cells read it.
threshold = 0 #Since the dataset is large, you can choose which index you want to threshold when preprocessing data
# `record_bytes` (not `bytes`) so the builtin type is not shadowed for later cells.
for record_bytes in tqdm.tqdm(dataset_iter):

    data = wod_e2ed_pb2.E2EDFrame()
    data.ParseFromString(record_bytes)

    # Assumes context.name looks like '<prefix>-<index>' -- TODO confirm; a name
    # without '-' raises IndexError here.
    i = int(data.frame.context.name.split('-')[1])
    if i >= threshold:
        data_dict = load_waymo_e2e_data(data,0.25,camera_features=['ALL'])
        filename = data.frame.context.name + '.pkl'
        save_dict_to_pickle(data_dict, os.path.join(save_path, filename))


106360it [44:40, 39.68it/s]
