In [1]:
from mtr.datasets import build_dataloader

In [2]:
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils
import torch

In [3]:
# Location of the experiment YAML handed to the config loader together with `cfg`.
# NOTE(review): absolute, user-specific path — consider making it configurable.
CFG_YAML_PATH = "/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml"

# Kept as the last expression so the returned config is echoed as the cell output.
cfg_from_yaml_file(CFG_YAML_PATH, cfg)

{'ROOT_DIR': PosixPath('/code/jjiang23/csc587/KimchiVision'),
 'LOCAL_RANK': 0,
 'DATA_CONFIG': {'DATASET': 'WaymoDataset',
  'OBJECT_TYPE': ['TYPE_VEHICLE', 'TYPE_PEDESTRIAN', 'TYPE_CYCLIST'],
  'DATA_ROOT': '/files/waymo/code/MTR/data/waymo',
  'SPLIT_DIR': {'train': 'processed_scenarios_training',
   'test': 'processed_scenarios_validation'},
  'INFO_FILE': {'train': 'processed_scenarios_training_infos.pkl',
   'test': 'processed_scenarios_val_infos.pkl'},
  'SAMPLE_INTERVAL': {'train': 1, 'test': 1},
  'INFO_FILTER_DICT': {'filter_info_by_object_type': ['TYPE_VEHICLE',
    'TYPE_PEDESTRIAN',
    'TYPE_CYCLIST']},
  'POINT_SAMPLED_INTERVAL': 1,
  'NUM_POINTS_EACH_POLYLINE': 20,
  'VECTOR_BREAK_DIST_THRESH': 1.0,
  'NUM_OF_SRC_POLYLINES': 768,
  'CENTER_OFFSET_OF_MAP': [30.0, 0]},
 'MODEL': {'CONTEXT_ENCODER': {'NAME': 'MTREncoder',
   'NUM_OF_ATTN_NEIGHBORS': 16,
   'NUM_INPUT_ATTR_AGENT': 29,
   'NUM_INPUT_ATTR_MAP': 9,
   'NUM_CHANNEL_IN_MLP_AGENT': 256,
   'NUM_CHANNEL_IN_MLP_MAP

In [4]:
# Logger that is later passed to build_dataloader; created against a fixed log file path.
LOG_FILE_PATH = "/files/waymo/log.txt"
logger = common_utils.create_logger(LOG_FILE_PATH, rank=0)


In [5]:
from easydict import EasyDict as edict

# Loader options collected in one attribute-accessible namespace; the
# dataloader cell reads them as args.batch_size, args.workers, and so on.
args = edict(
    batch_size=3,
    workers=64,
    merge_all_iters_to_one_epoch=False,
    epochs=10,
    add_worker_init_fn=False,
)

In [6]:
# Keyword arguments for build_dataloader: the dataset section of `cfg`, the
# shared logger, and the tunables stored on `args`.
dataloader_kwargs = {
    "dataset_cfg": cfg.DATA_CONFIG,
    "batch_size": args.batch_size,
    "dist": False,
    "workers": args.workers,
    "logger": logger,
    "training": True,
    "merge_all_iters_to_one_epoch": args.merge_all_iters_to_one_epoch,
    "total_epochs": args.epochs,
    "add_worker_init_fn": args.add_worker_init_fn,
}

# build_dataloader hands back the dataset, its loader and the sampler as a 3-tuple.
train_set, train_loader, train_sampler = build_dataloader(**dataloader_kwargs)

2025-06-07 03:03:06,383   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl
2025-06-07 03:03:11,432   INFO  Total scenes before filters: 243401
2025-06-07 03:03:17,651   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-07 03:03:17,678   INFO  Total scenes after filters: 243401


In [7]:
# Take the first batch yielded by a fresh iterator over train_loader so its
# structure can be inspected in the cells below.
sample = next(iter(train_loader))

In [8]:
# Display the top-level keys of the batch dictionary.
sample.keys()

dict_keys(['batch_size', 'input_dict', 'batch_sample_count'])

In [9]:
# Alias the nested "input_dict" entry of the batch so later cells can index it directly.
input_dict = sample["input_dict"]

In [10]:
# One line per entry of input_dict: the shape for torch tensors, the Python
# type for everything else.
for key, value in input_dict.items():
    if not isinstance(value, torch.Tensor):
        print(f"{key}: {type(value)}")
        continue
    print(f"{key}: {value.shape}")

scenario_id: <class 'numpy.ndarray'>
obj_trajs: torch.Size([11, 177, 11, 29])
obj_trajs_mask: torch.Size([11, 177, 11])
track_index_to_predict: torch.Size([11])
obj_trajs_pos: torch.Size([11, 177, 11, 3])
obj_trajs_last_pos: torch.Size([11, 177, 3])
obj_types: <class 'numpy.ndarray'>
obj_ids: <class 'numpy.ndarray'>
center_objects_world: torch.Size([11, 10])
center_objects_id: <class 'numpy.ndarray'>
center_objects_type: <class 'numpy.ndarray'>
obj_trajs_future_state: torch.Size([11, 177, 80, 4])
obj_trajs_future_mask: torch.Size([11, 177, 80])
center_gt_trajs: torch.Size([11, 80, 4])
center_gt_trajs_mask: torch.Size([11, 80])
center_gt_final_valid_idx: torch.Size([11])
center_gt_trajs_src: torch.Size([11, 91, 10])
map_polylines: torch.Size([11, 768, 20, 9])
map_polylines_mask: torch.Size([11, 768, 20])
map_polylines_center: torch.Size([11, 768, 3])
static_map_polylines: torch.Size([12000, 20, 7])
static_map_polylines_mask: torch.Size([12000, 20])


In [None]:
# Shapes of the static map polylines and of their mask, displayed as a 2-tuple.
tuple(
    input_dict[name].shape
    for name in ("static_map_polylines", "static_map_polylines_mask")
)

In [None]:
# NOTE(review): displays the same two shapes as the preceding cell; one of the
# two cells can likely be dropped.
input_dict["static_map_polylines"].shape, input_dict["static_map_polylines_mask"].shape

In [None]:
# Shape of the map polyline tensor; the per-key summary printed earlier showed
# torch.Size([11, 768, 20, 9]) for one batch.
input_dict["map_polylines"].shape

In [None]:
# Shape of center_objects_world; the per-key summary printed earlier showed
# torch.Size([11, 10]) for one batch.
input_dict["center_objects_world"].shape

In [None]:
# Shape of the mask that accompanies map_polylines; the per-key summary printed
# earlier showed torch.Size([11, 768, 20]) for one batch.
input_dict["map_polylines_mask"].shape

In [None]:
# Mask row at index 0 of both leading axes of obj_trajs_mask.
input_dict["obj_trajs_mask"][0, 0]

In [None]:
# Shape of obj_trajs; the per-key summary printed earlier showed
# torch.Size([11, 177, 11, 29]) for one batch.
input_dict["obj_trajs"].shape

In [None]:
# Layout: (num_center_objects, num_topk_polylines, num_points_each_polyline, 9),
# where the last axis holds [x, y, z, dir_x, dir_y, dir_z, global_type, pre_x, pre_y].
sample["input_dict"]['map_polylines'].shape

In [None]:
# Shape of map_polylines_center; the per-key summary printed earlier showed
# torch.Size([11, 768, 3]) for one batch.
input_dict["map_polylines_center"].shape

In [None]:
# Display the full track_index_to_predict tensor (shape torch.Size([11]) in the
# per-key summary printed earlier).
input_dict['track_index_to_predict']

In [None]:
# Batch-level field stored next to input_dict in the batch dictionary.
# NOTE(review): the loader was built with batch_size=3 while the tensors in the
# earlier summary have a leading dimension of 11 — presumably that axis counts
# center objects across the batched scenes; confirm.
sample['batch_size']

In [None]:
# Display the scenario_id entry (reported as a numpy.ndarray by the per-key
# summary printed earlier).
sample["input_dict"]["scenario_id"]

In [None]:
# Shape of center_objects_world, read through `sample` rather than the
# `input_dict` alias (torch.Size([11, 10]) in the per-key summary printed earlier).
sample["input_dict"]["center_objects_world"].shape

In [None]:
# Layout: (num_center_objects, num_objects, num_timestamps, num_attrs).
sample["input_dict"]["obj_trajs"].shape

In [None]:
# Keep a direct handle on the obj_trajs tensor; a later cell indexes into it.
obj_trajs = sample["input_dict"]["obj_trajs"]

In [None]:
# Ground-truth trajectories, layout (num_center_objects, num_future_timestamps, 4),
# where the last axis holds [x, y, vx, vy].
sample["input_dict"]['center_gt_trajs'].shape

In [None]:
# Shape of center_gt_trajs_src; the per-key summary printed earlier showed
# torch.Size([11, 91, 10]) for one batch.
sample["input_dict"]['center_gt_trajs_src'].shape

In [None]:
# Slice of obj_trajs at index 0 of both leading axes.
obj_trajs[0, 0]

In [None]:
# The per-key summary printed earlier in this notebook lists no "timestamps"
# entry in input_dict, so plain indexing here raises KeyError and stops a
# Restart & Run All. dict.get() returns the value when the key exists and None
# (no cell output) when it does not.
# NOTE(review): assumes input_dict is a plain dict — confirm.
sample["input_dict"].get("timestamps")