In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#common libs
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchinfo import summary
import math
from easydict import EasyDict as edict


In [3]:
#mtr modules
from mtr.datasets import build_dataloader
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils

In [4]:
cfg_from_yaml_file("/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml", cfg)
logger = common_utils.create_logger("/files/waymo/damon_log.txt", rank=0)
args = edict({
    "batch_size": 1,
    "workers": 32,
    "merge_all_iters_to_one_epoch": False,
    "epochs": 5,
    "add_worker_init_fn": False,
    
})

In [5]:
#prepare data
train_set, train_loader, train_sampler = build_dataloader(
    dataset_cfg=cfg.DATA_CONFIG,
    batch_size=args.batch_size,
    dist=False, workers=args.workers,
    logger=logger,
    training=True,
    merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
    total_epochs=args.epochs,
    add_worker_init_fn=args.add_worker_init_fn,
)

test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        batch_size=args.batch_size,
        dist=False, workers=args.workers, logger=logger, training=False
)

2025-06-07 10:37:53,623   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl


2025-06-07 10:37:57,769   INFO  Total scenes before filters: 243401
2025-06-07 10:38:03,662   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-07 10:38:03,677   INFO  Total scenes after filters: 243401
2025-06-07 10:38:03,680   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_val_infos.pkl
2025-06-07 10:38:06,189   INFO  Total scenes before filters: 22089
2025-06-07 10:38:06,725   INFO  Total scenes after filter_info_by_object_type: 22089
2025-06-07 10:38:06,770   INFO  Total scenes after filters: 22089


In [6]:
from lstm.lstm import MotionLSTM

model = MotionLSTM()

THIS IS THE PATH OF THE LSTM, change when needed


In [78]:
batch = next(iter(train_loader))

In [79]:
batch["input_dict"].keys()

dict_keys(['scenario_id', 'obj_trajs', 'obj_trajs_mask', 'track_index_to_predict', 'obj_trajs_pos', 'obj_trajs_last_pos', 'obj_types', 'obj_ids', 'center_objects_world', 'center_objects_id', 'center_objects_type', 'obj_trajs_future_state', 'obj_trajs_future_mask', 'center_gt_trajs', 'center_gt_trajs_mask', 'center_gt_final_valid_idx', 'center_gt_trajs_src', 'map_polylines', 'map_polylines_mask', 'map_polylines_center', 'static_map_polylines', 'static_map_polylines_mask'])

In [80]:
batch["input_dict"]["track_index_to_predict"]

tensor([17, 30, 26,  6, 16])

In [81]:
batch["input_dict"]["center_objects_id"]

array([ 50, 326,  95,  33,  49])

In [82]:
batch["input_dict"]["obj_ids"]

array([ 25,  26,  27,  28,  31,  32,  33,  34,  35,  37,  39,  40,  41,
        42,  44,  47,  49,  50,  51,  53,  54,  56,  58, 325,  90,  94,
        95, 102, 103, 104, 326])

In [89]:
input = batch["input_dict"]
obj_trajs = input["obj_trajs"]
obj_pos = input["obj_trajs_pos"]
obj_last_pos = input["obj_trajs_last_pos"]
obj_type = input["obj_types"] # car, bicycycle, pedestrian
obj_trajs_mask = input['obj_trajs_mask']
obj_of_interest = input['track_index_to_predict']

In [90]:
num_center_objects, num_objects, num_timestamps, num_attrs = obj_trajs.shape

In [None]:

obj_type = model.convert_type(obj_type)

In [92]:
from einops import rearrange, repeat

# expanding everything
obj_last_pos = repeat(obj_last_pos, "c o p -> c o timestamps p", timestamps=num_timestamps)
obj_type = repeat(obj_type, "type -> centers objects timestamps type", centers=num_center_objects, objects=num_objects, timestamps=num_timestamps)

In [94]:
torch.cat([obj_trajs, obj_pos, obj_last_pos, obj_type], dim=-1).shape

torch.Size([5, 31, 11, 66])

In [95]:
model._print_batch(batch)

Key: scenario_id, Val: (5,)
Key: obj_trajs, Val: torch.Size([5, 31, 11, 29])
Key: obj_trajs_mask, Val: torch.Size([5, 31, 11])
Key: track_index_to_predict, Val: torch.Size([5])
Key: obj_trajs_pos, Val: torch.Size([5, 31, 11, 3])
Key: obj_trajs_last_pos, Val: torch.Size([5, 31, 3])
Key: obj_types, Val: (31,)
Key: obj_ids, Val: (31,)
Key: center_objects_world, Val: torch.Size([5, 10])
Key: center_objects_id, Val: (5,)
Key: center_objects_type, Val: (5,)
Key: obj_trajs_future_state, Val: torch.Size([5, 31, 80, 4])
Key: obj_trajs_future_mask, Val: torch.Size([5, 31, 80])
Key: center_gt_trajs, Val: torch.Size([5, 80, 4])
Key: center_gt_trajs_mask, Val: torch.Size([5, 80])
Key: center_gt_final_valid_idx, Val: torch.Size([5])
Key: center_gt_trajs_src, Val: torch.Size([5, 91, 10])
Key: map_polylines, Val: torch.Size([5, 768, 20, 9])
Key: map_polylines_mask, Val: torch.Size([5, 768, 20])
Key: map_polylines_center, Val: torch.Size([5, 768, 3])
Key: static_map_polylines, Val: torch.Size([5, 4000,

In [9]:
model(batch)

{'batch_size': 1, 'input_dict': {'scenario_id': array(['8d96c7371cc64881'], dtype='<U16'), 'obj_trajs': tensor([[[[ 2.1549e+01, -4.8992e+01, -5.3284e-02,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 2.1549e+01, -4.8992e+01, -6.0562e-02,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 2.1549e+01, -4.8992e+01, -6.7688e-02,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [ 2.1549e+01, -4.8992e+01, -6.0287e-02,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 2.1549e+01, -4.8992e+01, -6.5933e-02,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 2.1549e+01, -4.8992e+01, -6.5033e-02,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]],

         [[ 2.1594e+01, -6.4921e+01,  1.3637e-01,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 2.1594e+01, -6.4921e+01,  1.2909e-01,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 2.1594e+01,