In [1]:
# Automatically reload edited modules (e.g. the project's mtr / lstm packages)
# before each cell executes, so changes to their .py files take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
#common libs
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchinfo import summary
import math
from easydict import EasyDict as edict


In [3]:
#mtr modules
from mtr.datasets import build_dataloader
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils

In [4]:
# Machine-specific locations: edit these two constants (not the calls below)
# when running on a different host.
CFG_PATH = "/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml"
LOG_PATH = "/files/waymo/damon_log.txt"

# Load the experiment YAML into the shared MTR `cfg` object. The return value
# is discarded, so this relies on cfg_from_yaml_file updating `cfg` in place.
cfg_from_yaml_file(CFG_PATH, cfg)
logger = common_utils.create_logger(LOG_PATH, rank=0)

# Run options consumed by build_dataloader in the next cell.
# NOTE: with batch_size=1 the recorded batch still held 5 center objects
# (leading dim of obj_trajs), so the leading tensor dim is not batch_size.
args = edict({
    "batch_size": 1,
    "workers": 32,
    "merge_all_iters_to_one_epoch": False,
    "epochs": 5,
    "add_worker_init_fn": False,
})

In [5]:
#prepare data
# Arguments shared by the train and test loaders, defined once so the two
# calls cannot drift apart (e.g. a batch_size change applied to only one).
_loader_kwargs = dict(
    dataset_cfg=cfg.DATA_CONFIG,
    batch_size=args.batch_size,
    dist=False,
    workers=args.workers,
    logger=logger,
)

# Training loader; the epoch-related options are passed only here.
train_set, train_loader, train_sampler = build_dataloader(
    **_loader_kwargs,
    training=True,
    merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
    total_epochs=args.epochs,
    add_worker_init_fn=args.add_worker_init_fn,
)

# Test loader built from the same DATA_CONFIG with training=False.
test_set, test_loader, test_sampler = build_dataloader(**_loader_kwargs, training=False)
sampler = test_sampler  # backward-compatible alias for the old, ambiguous name

2025-06-07 10:37:53,623   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl


2025-06-07 10:37:57,769   INFO  Total scenes before filters: 243401
2025-06-07 10:38:03,662   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-07 10:38:03,677   INFO  Total scenes after filters: 243401
2025-06-07 10:38:03,680   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_val_infos.pkl
2025-06-07 10:38:06,189   INFO  Total scenes before filters: 22089
2025-06-07 10:38:06,725   INFO  Total scenes after filter_info_by_object_type: 22089
2025-06-07 10:38:06,770   INFO  Total scenes after filters: 22089


In [None]:
from lstm.simple_lstm import MotionLSTM

# Seed torch's global RNG before building the model so weight initialization
# (and any outputs computed from the untrained model) is reproducible across
# kernel restarts. torch.manual_seed seeds all devices, including CUDA.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MotionLSTM().to(device)

In [131]:
# Pull one batch from the training loader. Each call builds a fresh iterator
# (with workers > 0, presumably starting new worker processes), so re-running
# this cell may yield a different batch if the loader shuffles.
# NOTE(review): this cell ran as In[131], i.e. after the In[79]-In[82] cells
# below. Their outputs (31 obj_ids) come from a different batch than the later
# _print_batch output (55 objects). Restart & Run All before trusting them.
batch = next(iter(train_loader))

In [79]:
# Inspect which fields the dataloader provides under batch["input_dict"].
batch["input_dict"].keys()

dict_keys(['scenario_id', 'obj_trajs', 'obj_trajs_mask', 'track_index_to_predict', 'obj_trajs_pos', 'obj_trajs_last_pos', 'obj_types', 'obj_ids', 'center_objects_world', 'center_objects_id', 'center_objects_type', 'obj_trajs_future_state', 'obj_trajs_future_mask', 'center_gt_trajs', 'center_gt_trajs_mask', 'center_gt_final_valid_idx', 'center_gt_trajs_src', 'map_polylines', 'map_polylines_mask', 'map_polylines_center', 'static_map_polylines', 'static_map_polylines_mask'])

In [80]:
# Indices along the object axis of the agents to predict. In the recorded run
# these index into obj_ids: obj_ids[[17, 30, 26, 6, 16]] == center_objects_id.
batch["input_dict"]["track_index_to_predict"]

tensor([17, 30, 26,  6, 16])

In [81]:
# Object ids of the center objects (presumably the agents to predict); in the
# recorded output they equal obj_ids[track_index_to_predict].
batch["input_dict"]["center_objects_id"]

array([ 50, 326,  95,  33,  49])

In [82]:
# Sanity check: track_index_to_predict indexes into obj_ids. In the recorded
# run obj_ids[[17, 30, 26, 6, 16]] == [50, 326, 95, 33, 49] == center_objects_id.
# Checked only for batch_size == 1: with larger batches the per-object arrays
# may be concatenated across scenarios — TODO confirm against the collate fn.
if args.batch_size == 1:
    assert np.array_equal(
        batch["input_dict"]["obj_ids"][np.asarray(batch["input_dict"]["track_index_to_predict"])],
        batch["input_dict"]["center_objects_id"],
    ), "obj_ids[track_index_to_predict] != center_objects_id"
batch["input_dict"]["obj_ids"]

array([ 25,  26,  27,  28,  31,  32,  33,  34,  35,  37,  39,  40,  41,
        42,  44,  47,  49,  50,  51,  53,  54,  56,  58, 325,  90,  94,
        95, 102, 103, 104, 326])

In [123]:
# Short aliases for the batch fields used below. Named `input_dict` (matching
# the batch key) rather than `input`, which would shadow the Python builtin.
input_dict = batch["input_dict"]
obj_trajs = input_dict["obj_trajs"]  # recorded shape: (5, 55, 11, 29)
obj_pos = input_dict["obj_trajs_pos"]  # recorded shape: (5, 55, 11, 3)
obj_last_pos = input_dict["obj_trajs_last_pos"]  # recorded shape: (5, 55, 3)
# Per-object type labels, one per object (recorded shape (55,)); presumably
# vehicle / cyclist / pedestrian — TODO confirm the exact encoding.
obj_type = input_dict["obj_types"]
obj_trajs_mask = input_dict["obj_trajs_mask"]  # recorded shape: (5, 55, 11)
obj_of_interest = input_dict["track_index_to_predict"]

In [164]:
# Unpack the trajectory tensor dimensions (recorded batch: 5 x 55 x 11 x 29).
num_center_objects, num_objects, num_timestamps, num_attrs = obj_trajs.shape
# The validity mask must cover exactly the same (center, object, time) grid.
assert tuple(obj_trajs_mask.shape) == (num_center_objects, num_objects, num_timestamps), (
    f"obj_trajs_mask shape {tuple(obj_trajs_mask.shape)} "
    f"does not match obj_trajs shape {tuple(obj_trajs.shape)}"
)

In [153]:
# Debug dump of every batch field's shape, via the model's private helper.
# NOTE(review): the helper also emits a leftover debug line ("THIS IS THE PATH
# OF THE LSTM, change when needed"); consider removing it from MotionLSTM or
# using a standalone utility instead of a private method.
model._print_batch(batch)

THIS IS THE PATH OF THE LSTM, change when needed
Key: scenario_id, Val: (5,)
Key: obj_trajs, Val: torch.Size([5, 55, 11, 29])
Key: obj_trajs_mask, Val: torch.Size([5, 55, 11])
Key: track_index_to_predict, Val: torch.Size([5])
Key: obj_trajs_pos, Val: torch.Size([5, 55, 11, 3])
Key: obj_trajs_last_pos, Val: torch.Size([5, 55, 3])
Key: obj_types, Val: (55,)
Key: obj_ids, Val: (55,)
Key: center_objects_world, Val: torch.Size([5, 10])
Key: center_objects_id, Val: (5,)
Key: center_objects_type, Val: (5,)
Key: obj_trajs_future_state, Val: torch.Size([5, 55, 80, 4])
Key: obj_trajs_future_mask, Val: torch.Size([5, 55, 80])
Key: center_gt_trajs, Val: torch.Size([5, 80, 4])
Key: center_gt_trajs_mask, Val: torch.Size([5, 80])
Key: center_gt_final_valid_idx, Val: torch.Size([5])
Key: center_gt_trajs_src, Val: torch.Size([5, 91, 10])
Key: map_polylines, Val: torch.Size([5, 768, 20, 9])
Key: map_polylines_mask, Val: torch.Size([5, 768, 20])
Key: map_polylines_center, Val: torch.Size([5, 768, 3])
Key

In [165]:
# Forward pass on the raw batch dict.
# NOTE(review): this runs with autograd enabled and in whatever train/eval mode
# the model is currently in. Wrap it in torch.no_grad() if this is inspection
# only. The batch is passed as loaded (presumably CPU tensors), while `model`
# may be on CUDA. Confirm that MotionLSTM moves its inputs to `device`.
scores = model(batch)


In [166]:
# NOTE(review): in the recorded output, each row sums to ~1 over 6 entries and
# sits near 1/6. That is consistent with a softmax from an untrained model;
# confirm against MotionLSTM.forward.
scores

tensor([[[0.1650, 0.1676, 0.1594, 0.1689, 0.1698, 0.1693],
         [0.1648, 0.1694, 0.1579, 0.1702, 0.1691, 0.1687],
         [0.1640, 0.1673, 0.1589, 0.1692, 0.1700, 0.1706],
         [0.1638, 0.1678, 0.1586, 0.1688, 0.1694, 0.1717],
         [0.1635, 0.1684, 0.1570, 0.1694, 0.1706, 0.1711],
         [0.1645, 0.1671, 0.1587, 0.1687, 0.1713, 0.1697],
         [0.1650, 0.1676, 0.1585, 0.1679, 0.1703, 0.1706],
         [0.1635, 0.1691, 0.1580, 0.1687, 0.1708, 0.1700],
         [0.1640, 0.1682, 0.1573, 0.1704, 0.1692, 0.1709],
         [0.1648, 0.1666, 0.1586, 0.1697, 0.1706, 0.1696],
         [0.1634, 0.1690, 0.1571, 0.1697, 0.1704, 0.1703]],

        [[0.1631, 0.1693, 0.1592, 0.1698, 0.1697, 0.1688],
         [0.1641, 0.1681, 0.1592, 0.1691, 0.1709, 0.1685],
         [0.1639, 0.1679, 0.1598, 0.1696, 0.1691, 0.1697],
         [0.1640, 0.1680, 0.1591, 0.1700, 0.1699, 0.1689],
         [0.1648, 0.1671, 0.1603, 0.1698, 0.1694, 0.1686],
         [0.1639, 0.1677, 0.1586, 0.1694, 0.1704, 0.16