In [1]:
from mtr.datasets import build_dataloader

In [2]:
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils
import torch

In [3]:
# Load the KimchiVision YAML file into the shared `cfg` object. The call is kept as the
# cell's last expression so its return value is still echoed below.
# NOTE(review): absolute, user-specific path — consider deriving it from a configurable root.
KIMCHI_CFG_PATH = "/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml"
cfg_from_yaml_file(KIMCHI_CFG_PATH, cfg)

{'ROOT_DIR': PosixPath('/code/jjiang23/csc587/KimchiVision'),
 'LOCAL_RANK': 0,
 'DATA_CONFIG': {'DATASET': 'WaymoDataset',
  'OBJECT_TYPE': ['TYPE_VEHICLE', 'TYPE_PEDESTRIAN', 'TYPE_CYCLIST'],
  'DATA_ROOT': '/files/waymo/code/MTR/data/waymo',
  'SPLIT_DIR': {'train': 'processed_scenarios_training',
   'test': 'processed_scenarios_validation'},
  'INFO_FILE': {'train': 'processed_scenarios_training_infos.pkl',
   'test': 'processed_scenarios_val_infos.pkl'},
  'SAMPLE_INTERVAL': {'train': 1, 'test': 1},
  'INFO_FILTER_DICT': {'filter_info_by_object_type': ['TYPE_VEHICLE',
    'TYPE_PEDESTRIAN',
    'TYPE_CYCLIST']},
  'POINT_SAMPLED_INTERVAL': 1,
  'NUM_POINTS_EACH_POLYLINE': 20,
  'VECTOR_BREAK_DIST_THRESH': 1.0,
  'NUM_OF_SRC_POLYLINES': 768,
  'CENTER_OFFSET_OF_MAP': [30.0, 0]},
 'MODEL': {'CONTEXT_ENCODER': {'NAME': 'MTREncoder',
   'NUM_OF_ATTN_NEIGHBORS': 16,
   'NUM_INPUT_ATTR_AGENT': 29,
   'NUM_INPUT_ATTR_MAP': 9,
   'NUM_CHANNEL_IN_MLP_AGENT': 256,
   'NUM_CHANNEL_IN_MLP_MAP

In [4]:
# Logger that is later passed to build_dataloader; path and rank are unchanged.
LOG_FILE = "/files/waymo/log.txt"
logger = common_utils.create_logger(LOG_FILE, rank=0)


In [5]:
from easydict import EasyDict as edict

# Run options read by the build_dataloader call (args.batch_size, args.workers, ...).
# EasyDict gives attribute-style access to these keys.
args = edict(
    batch_size=2,
    workers=4,
    merge_all_iters_to_one_epoch=False,
    epochs=10,
    add_worker_init_fn=False,
)

In [6]:
# Collect every build_dataloader option in one place, then build the training
# dataset, loader and sampler (dist=False, training=True).
loader_options = dict(
    dataset_cfg=cfg.DATA_CONFIG,
    batch_size=args.batch_size,
    dist=False,
    workers=args.workers,
    logger=logger,
    training=True,
    merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
    total_epochs=args.epochs,
    add_worker_init_fn=args.add_worker_init_fn,
)
train_set, train_loader, train_sampler = build_dataloader(**loader_options)

2025-06-05 19:19:06,535   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl
2025-06-05 19:19:12,252   INFO  Total scenes before filters: 243401
2025-06-05 19:19:18,649   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-05 19:19:18,677   INFO  Total scenes after filters: 243401


In [7]:
# Fetch one batch from train_loader. A fresh iterator is created every time this cell runs,
# so re-running it rebinds `sample` to a newly fetched batch; cells that inspect `sample`
# (or anything derived from it) must be re-run afterwards to stay consistent.
sample = next(iter(train_loader))

In [8]:
# Top-level keys of the batch dict (recorded: 'batch_size', 'input_dict', 'batch_sample_count').
sample.keys()

dict_keys(['batch_size', 'input_dict', 'batch_sample_count'])

In [9]:
# Shorthand for the per-batch input dict.
# NOTE: this binds to the *current* `sample`; after fetching a new batch, re-run this cell,
# otherwise `input_dict` keeps pointing at the old batch while sample["input_dict"] is the new one.
input_dict = sample["input_dict"]

In [None]:
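# obj_trajs_full — intended shape: (num_obj, timesteps, attributes).
# NOTE(review): the attribute list originally written here,
#   [x, y, z, dir_x, dir_y, dir_z, global_type, pre_x, pre_y],
# is identical to the 9-attribute map_polylines layout, so it looks copy-pasted;
# TODO confirm the real per-timestep attributes of obj_trajs_full.
# 'obj_trajs_full' is not a key of input_dict (this lookup raised KeyError), so it stays disabled:
# input_dict['obj_trajs_full'].shape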

KeyError: 'obj_trajs_full'

In [11]:
# List every key available in the batch's input dict (use this to check a key exists
# before indexing it).
sample["input_dict"].keys()

dict_keys(['scenario_id', 'obj_trajs', 'obj_trajs_mask', 'track_index_to_predict', 'obj_trajs_pos', 'obj_trajs_last_pos', 'obj_types', 'obj_ids', 'center_objects_world', 'center_objects_id', 'center_objects_type', 'obj_trajs_future_state', 'obj_trajs_future_mask', 'center_gt_trajs', 'center_gt_trajs_mask', 'center_gt_final_valid_idx', 'center_gt_trajs_src', 'map_polylines', 'map_polylines_center'])

In [12]:
# Shape of the mask that accompanies obj_trajs. The recorded (8, 92, 11) equals the first
# three dims of the recorded obj_trajs shape (8, 92, 11, 29).
input_dict["obj_trajs_mask"].shape

torch.Size([8, 92, 11])

In [14]:
# Mask entries along the last axis for the first center object / first object.
input_dict["obj_trajs_mask"][0, 0]

tensor([True, True, True, True, True, True, True, True, True, True, True])

In [13]:
# Shape of the agent trajectory tensor. The trailing 29 in the recorded output equals
# MODEL.CONTEXT_ENCODER.NUM_INPUT_ATTR_AGENT in the loaded config.
input_dict["obj_trajs"].shape

torch.Size([8, 92, 11, 29])

In [13]:
# (num_center_objects, num_topk_polylines, num_points_each_polyline, 9): [x, y, z, dir_x, dir_y, dir_z, global_type, pre_x, pre_y]
# In the recorded output, 768 / 20 / 9 equal NUM_OF_SRC_POLYLINES / NUM_POINTS_EACH_POLYLINE /
# NUM_INPUT_ATTR_MAP from the loaded config.
sample["input_dict"]['map_polylines'].shape

torch.Size([8, 768, 20, 9])

In [24]:
# Shape of the per-polyline center tensor: recorded (8, 768, 3), i.e. 3 values for each of
# the 768 polylines of map_polylines.
input_dict["map_polylines_center"].shape

torch.Size([8, 768, 3])

In [23]:
# One value per center object; presumably an index into the object axis of obj_trajs — TODO confirm.
# NOTE(review): the recorded output contains 99 while the recorded obj_trajs shape has 92 objects,
# so if these are object indices the two outputs come from different batches — re-run to confirm.
input_dict['track_index_to_predict']

tensor([99,  4, 46, 40, 28,  3,  6, 69])

In [14]:
# Batch size reported by the collated batch.
# NOTE(review): the recorded output is 1 although args.batch_size is 2 — the output looks stale
# from an earlier run with different settings; re-run after building the loader to confirm.
sample['batch_size']

1

In [22]:
# Scenario id array; ids repeat in the recorded output (7 entries, 2 distinct ids), which
# suggests one entry per center object rather than per scene — TODO confirm.
sample["input_dict"]["scenario_id"]

array(['29a0bc9c441e35d8', '29a0bc9c441e35d8', '4ea07205c8472eb5',
       '4ea07205c8472eb5', '4ea07205c8472eb5', '4ea07205c8472eb5',
       '4ea07205c8472eb5'], dtype='<U16')

In [16]:
# Shape of center_objects_world; recorded (3, 10), presumably (num_center_objects, 10 state
# attributes) — TODO confirm the meaning of the 10 attributes.
sample["input_dict"]["center_objects_world"].shape

torch.Size([3, 10])

In [27]:
# (num_center_objects, num_objects, num_timestamps, num_attrs)
# num_attrs is 29 in the recorded output, matching NUM_INPUT_ATTR_AGENT in the loaded config;
# the first two dims vary from batch to batch.
sample["input_dict"]["obj_trajs"].shape

torch.Size([7, 68, 11, 29])

In [21]:
# Keep a handle on the agent trajectory tensor of the *current* batch; re-run this cell
# after fetching a new `sample`, otherwise obj_trajs refers to the previous batch.
obj_trajs = sample["input_dict"]["obj_trajs"]

In [None]:
# (num_center_objects, num_future_timestamps, 4): [x, y, vx, vy] ground-truth trajectories
# Recorded output: (3, 80, 4), i.e. 80 future timestamps per center object.
sample["input_dict"]['center_gt_trajs'].shape

torch.Size([3, 80, 4])

In [27]:
# Recorded shape (3, 91, 10). 91 equals the 11 timestamps seen in obj_trajs plus the 80 future
# timestamps in center_gt_trajs — presumably the full, unsplit track of each center object
# with 10 raw attributes; TODO confirm.
sample["input_dict"]['center_gt_trajs_src'].shape

torch.Size([3, 91, 10])

In [30]:
# Per-timestamp attribute rows for the first center object / first object (29 values per row).
obj_trajs[0, 0]

tensor([[-2.6953e+01, -6.1536e+00, -4.0184e-02,  4.9420e+00,  2.2433e+00,
          1.8850e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  2.5558e-02,  9.9967e-01,
          6.0192e+00,  1.5370e-01, -2.8333e-01, -1.0392e+00],
        [-2.6354e+01, -6.1486e+00, -4.4090e-02,  4.9374e+00,  2.2427e+00,
          1.8797e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.0003e-01,  2.1916e-02,  9.9976e-01,
          5.9909e+00,  4.9778e-02, -2.8333e-01, -1.0392e+00],
        [-2.5744e+01, -6.1284e+00, -4.0466e-02,  4.9259e+00,  2.2430e+00,
          1.8701e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  

In [26]:
# Per-step timestamps; the recorded output holds 91 float64 values spaced roughly 0.1 apart.
# NOTE(review): 'timestamps' is not among the keys listed by sample["input_dict"].keys()
# earlier in this notebook, so this output presumably came from a different dataset version
# or kernel state — confirm the key exists before relying on this cell.
sample["input_dict"]["timestamps"]

tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000,
        0.9000, 1.0000, 1.1001, 1.2001, 1.3001, 1.4001, 1.5002, 1.6002, 1.7003,
        1.8003, 1.9003, 2.0004, 2.1004, 2.2004, 2.3004, 2.4005, 2.5004, 2.6005,
        2.7005, 2.8005, 2.9005, 3.0005, 3.1005, 3.2005, 3.3005, 3.4005, 3.5005,
        3.6005, 3.7005, 3.8005, 3.9005, 4.0005, 4.1005, 4.2005, 4.3005, 4.4005,
        4.5005, 4.6006, 4.7006, 4.8006, 4.9006, 5.0006, 5.1006, 5.2006, 5.3006,
        5.4006, 5.5006, 5.6007, 5.7007, 5.8007, 5.9007, 6.0007, 6.1007, 6.2007,
        6.3007, 6.4007, 6.5007, 6.6006, 6.7006, 6.8005, 6.9004, 7.0003, 7.1001,
        7.2000, 7.2998, 7.3995, 7.4993, 7.5991, 7.6988, 7.7985, 7.8983, 7.9980,
        8.0977, 8.1974, 8.2970, 8.3966, 8.4961, 8.5957, 8.6952, 8.7947, 8.8942,
        8.9936], dtype=torch.float64)