In [29]:
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from torchinfo import summary
import pickle

In [4]:
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils
import numpy as np
from mtr.datasets.dataset import DatasetTemplate
from mtr.datasets import build_dataloader

In [6]:
cfg_from_yaml_file("/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml", cfg)
logger = common_utils.create_logger("/files/waymo/log.txt", rank=0)

from easydict import EasyDict as edict
args = edict({
    "batch_size": 2,
    "workers": 4,
    "merge_all_iters_to_one_epoch": False,
    "epochs": 10,
    "add_worker_init_fn": False,
})
train_set, train_loader, train_sampler = build_dataloader(
    dataset_cfg=cfg.DATA_CONFIG,
    batch_size=args.batch_size,
    dist=False, workers=args.workers,
    logger=logger,
    training=True,
    merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
    total_epochs=args.epochs,
    add_worker_init_fn=args.add_worker_init_fn,
)

2025-06-05 19:51:05,455   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl
2025-06-05 19:51:05,455   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl
2025-06-05 19:51:05,455   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl
2025-06-05 19:51:09,999   INFO  Total scenes before filters: 243401
2025-06-05 19:51:09,999   INFO  Total scenes before filters: 243401
2025-06-05 19:51:09,999   INFO  Total scenes before filters: 243401
2025-06-05 19:51:16,394   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-05 19:51:16,394   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-05 19:51:16,394   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-05 19:51:16,408   INFO  Total scenes after filters: 243401
2025-06-05 19:51:16,408   INFO  Total scenes after filters: 243401
2025-06-05 19:51:

In [9]:
sample =next(iter(train_loader))

In [11]:
sample['input_dict'].keys()

dict_keys(['scenario_id', 'obj_trajs', 'obj_trajs_mask', 'track_index_to_predict', 'obj_trajs_pos', 'obj_trajs_last_pos', 'obj_types', 'obj_ids', 'center_objects_world', 'center_objects_id', 'center_objects_type', 'obj_trajs_future_state', 'obj_trajs_future_mask', 'center_gt_trajs', 'center_gt_trajs_mask', 'center_gt_final_valid_idx', 'center_gt_trajs_src', 'map_polylines', 'map_polylines_center'])

In [13]:
input_dict = sample['input_dict']
obj_trajs = input_dict['obj_trajs']
obj_trajs.shape

torch.Size([5, 40, 11, 29])

In [None]:
#inputs i want to use

In [None]:
import torch.nn as nn
class WemoLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, future_steps=80,dropout=0.5):
        super(WemoLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

        self.obj_feature_encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size)
        )
        self.obj_feature_decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.ReLU(),
            nn.Dropout(dropout)
            nn.Linear(hidden_dim, future_steps * 4)  # x, y, vx, vy
        )

    def forward(self, x):
        # encode object trajectories for LSTM input
        batch_size, num_objects, num_timestamps, input_dim = x.shape

        obj_features = self.obj_feature_encoder(x.view(-1, input_dim))

        out, _ = self.lstm(obj_features)

        out = self.obj_feature_decoder(out)
        # Reshape output to match the input shape

        out = out.view(batch_size, num_objects, future_steps, input_dim)

        return out

def train_model(model, train_loader, num_epochs=10):
    #output_dim = 4  # x, y, vx, vy vs future_steps = 80 x y vx, vy
    model.train()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(num_epochs):
        for i, sample in enumerate(train_loader):
            input_dict = sample['input_dict']
            obj_trajs = input_dict['obj_trajs'].cuda()  # Move to GPU if available

            optimizer.zero_grad()
            outputs = model(obj_trajs)

            # Assuming the target is the same as the input for simplicity
            targets = obj_trajs[:, :, -80:, :4].cuda()  # Last 80 timesteps as target

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')



In [36]:
lstm_model = WemoLSTM(input_size=29, hidden_size=256, num_layers=2)
lstm_model = lstm_model.cuda()


In [41]:
summary(lstm_model, input_size=(32, 50, 10, 29))  #batch_size, num_objects, num_timestamps, input_dim

Layer (type:depth-idx)                   Output Shape              Param #
WemoLSTM                                 [32, 50, 10, 29]          --
├─Sequential: 1-1                        [16000, 256]              --
│    └─Linear: 2-1                       [16000, 256]              7,680
│    └─ReLU: 2-2                         [16000, 256]              --
│    └─Dropout: 2-3                      [16000, 256]              --
│    └─Linear: 2-4                       [16000, 256]              65,792
├─LSTM: 1-2                              [16000, 256]              1,052,672
├─Sequential: 1-3                        [16000, 29]               --
│    └─Linear: 2-5                       [16000, 29]               7,453
│    └─ReLU: 2-6                         [16000, 29]               --
│    └─Dropout: 2-7                      [16000, 29]               --
Total params: 1,133,597
Trainable params: 1,133,597
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 4.31
Input size (MB): 1.86


In [42]:
train_model(lstm_model, train_loader, num_epochs=10)
# Save the model
torch.save(lstm_model.state_dict(), '/code/jjiang23/csc587/KimchiVision/lstm_model.pth')

Epoch [1/10], Step [10/121701], Loss: 61.3961
Epoch [1/10], Step [20/121701], Loss: 91.8938
Epoch [1/10], Step [30/121701], Loss: 75.7614
Epoch [1/10], Step [40/121701], Loss: 58.2785
Epoch [1/10], Step [50/121701], Loss: 98.7127
Epoch [1/10], Step [60/121701], Loss: 58.7620
Epoch [1/10], Step [70/121701], Loss: 62.2473
Epoch [1/10], Step [80/121701], Loss: 78.8073
Epoch [1/10], Step [90/121701], Loss: 91.1985


KeyboardInterrupt: 