In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
# NOTE(review): the `TypeError: super(type, obj)` shown in the last cell's
# output is a typical symptom of a class object being replaced while code still
# refers to the old one — presumably related to autoreload; confirm.
%load_ext autoreload
%autoreload 2

In [3]:
#common libs
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchinfo import summary
import math
from easydict import EasyDict as edict


In [4]:
#mtr modules
from mtr.datasets import build_dataloader
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils

In [5]:
# Populate the shared `cfg` object from the project YAML and open the run logger.
cfg_from_yaml_file("/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml", cfg)
logger = common_utils.create_logger("/files/waymo/damon_log.txt", rank=0)

# Notebook stand-in for command-line arguments; read by the dataloader cell below.
args = edict(dict(
    batch_size=1,
    workers=32,
    merge_all_iters_to_one_epoch=False,
    epochs=5,
    add_worker_init_fn=False,
))

In [6]:
#prepare data
# Keyword arguments shared by both dataloader builds.
common_loader_kwargs = dict(
    dataset_cfg=cfg.DATA_CONFIG,
    batch_size=args.batch_size,
    dist=False,
    workers=args.workers,
    logger=logger,
)

# Training split: also receives the epoch-related options from `args`.
train_set, train_loader, train_sampler = build_dataloader(
    training=True,
    merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
    total_epochs=args.epochs,
    add_worker_init_fn=args.add_worker_init_fn,
    **common_loader_kwargs,
)

# Non-training split (built with training=False and the builder's other defaults).
test_set, test_loader, sampler = build_dataloader(
    training=False,
    **common_loader_kwargs,
)

2025-06-08 14:42:29,976   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl
2025-06-08 14:42:36,327   INFO  Total scenes before filters: 243401
2025-06-08 14:42:43,140   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-08 14:42:43,150   INFO  Total scenes after filters: 243401
2025-06-08 14:42:43,152   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_val_infos.pkl
2025-06-08 14:42:44,973   INFO  Total scenes before filters: 22089
2025-06-08 14:42:45,617   INFO  Total scenes after filter_info_by_object_type: 22089
2025-06-08 14:42:45,622   INFO  Total scenes after filters: 22089


In [121]:
from lstm.simple_lstm import MotionLSTM
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the model with its default constructor arguments and place it on `device`.
model = MotionLSTM()
model = model.to(device)

In [79]:
# Take the first batch yielded by a fresh iterator over the training loader;
# the inspection cells below all operate on this one batch.
batch = next(iter(train_loader))

In [18]:
# List the field names stored under the batch's `input_dict`.
batch["input_dict"].keys()

dict_keys(['scenario_id', 'obj_trajs', 'obj_trajs_mask', 'track_index_to_predict', 'obj_trajs_pos', 'obj_trajs_last_pos', 'obj_types', 'obj_ids', 'center_objects_world', 'center_objects_id', 'center_objects_type', 'obj_trajs_future_state', 'obj_trajs_future_mask', 'center_gt_trajs', 'center_gt_trajs_mask', 'center_gt_final_valid_idx', 'center_gt_trajs_src', 'map_polylines', 'map_polylines_mask', 'map_polylines_center', 'static_map_polylines', 'static_map_polylines_mask'])

In [19]:
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because later cells read from it.
input = batch["input_dict"]

# Pull out the per-object fields that the following cells inspect.
(
    obj_trajs,
    obj_pos,
    obj_last_pos,
    obj_type,  # car, bicycle, pedestrian
    obj_trajs_mask,
    obj_of_interest,
) = (
    input[field]
    for field in (
        "obj_trajs",
        "obj_trajs_pos",
        "obj_trajs_last_pos",
        "obj_types",
        'obj_trajs_mask',
        'track_index_to_predict',
    )
)

In [20]:
# Name the four axes of `obj_trajs`; this raises if the tensor is not 4-D.
num_center_objects, num_objects, num_timestamps, num_attrs = obj_trajs.shape

In [12]:
# Move the static map polylines and their mask to the device selected in the
# model-setup cell instead of a hard-coded 'cuda', so the cell also runs on
# machines without a GPU.
static_map_polylines = input["static_map_polylines"].to(device)  # (batch_size, num_polylines, num_points_each_polyline, 7)
static_map_polylines_mask = input["static_map_polylines_mask"].to(device)  # (batch_size, num_polylines, num_points_each_polyline)

In [66]:
# Print every batch field together with its shape (see the output below).
# Note this calls an underscore-prefixed (private) helper of the model.
model._print_batch(batch)

Key: scenario_id, Val: (1,)
Key: obj_trajs, Val: torch.Size([1, 41, 11, 29])
Key: obj_trajs_mask, Val: torch.Size([1, 41, 11])
Key: track_index_to_predict, Val: torch.Size([1])
Key: obj_trajs_pos, Val: torch.Size([1, 41, 11, 3])
Key: obj_trajs_last_pos, Val: torch.Size([1, 41, 3])
Key: obj_types, Val: (41,)
Key: obj_ids, Val: (41,)
Key: center_objects_world, Val: torch.Size([1, 10])
Key: center_objects_id, Val: (1,)
Key: center_objects_type, Val: (1,)
Key: obj_trajs_future_state, Val: torch.Size([1, 41, 80, 4])
Key: obj_trajs_future_mask, Val: torch.Size([1, 41, 80])
Key: center_gt_trajs, Val: torch.Size([1, 80, 4])
Key: center_gt_trajs_mask, Val: torch.Size([1, 80])
Key: center_gt_final_valid_idx, Val: torch.Size([1])
Key: center_gt_trajs_src, Val: torch.Size([1, 91, 10])
Key: map_polylines, Val: torch.Size([1, 742, 20, 9])
Key: map_polylines_mask, Val: torch.Size([1, 742, 20])
Key: map_polylines_center, Val: torch.Size([1, 742, 3])
Key: static_map_polylines, Val: torch.Size([1, 4000,

In [81]:
# Display the `center_gt_final_valid_idx` field of the current batch.
batch["input_dict"]["center_gt_final_valid_idx"]

tensor([79., 79., 79., 79.])

In [83]:
# Index of the last nonzero mask entry along the final axis for element [0, 0]
# of `obj_trajs_mask` (presumably the last valid history timestep of the first
# object — confirm against the dataset's mask convention).
torch.nonzero(obj_trajs_mask[0, 0, :])[-1].squeeze()

tensor(10)

In [112]:
from lstm.loss import MotionLoss

# Instantiate the loss with its default constructor arguments.
criterion = MotionLoss()

In [126]:
# Run the model on the inspection batch; it returns a pair that is unpacked
# into per-mode scores and predicted trajectories.
pred_scores, pred_trajs = model(batch)

In [127]:
# Check the shape of the predicted scores.
pred_scores.shape

torch.Size([4, 6])

In [132]:
# `loss` was not assigned by any cell in this notebook, so displaying it only
# worked from leftover kernel state. Compute it explicitly from the forward-pass
# outputs, using the same call signature the training loop uses.
loss = criterion(pred_scores, pred_trajs, batch)
loss

tensor(14007.6367, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
# Plain-text print of `loss`; the variable must already exist in the kernel.
print(loss)

In [None]:
# Check the shape of the ground-truth trajectory mask.
input['center_gt_trajs_mask'].shape

In [None]:
#define loss function
class MotionLoss(nn.Module):
    """
    Winner-takes-all loss for multi-modal trajectory prediction.

    For every sample, the mode with the lowest masked, time-weighted squared
    error (position + velocity) is selected. The regression term is the error
    of that mode; the classification term is the cross-entropy that pushes
    ``pred_scores`` towards that mode.

    Args:
        regression_loss_weight: Multiplier applied to the regression term.
        classification_loss_weight: Multiplier applied to the classification term.
        future_loss_weight: Stored as ``self.future_weight``; ``forward`` does
            not use it (kept so existing constructor calls keep working).
        time_decay: Rate ``k`` of the per-step weight ``exp(-k * t)`` that
            emphasises the near future. The default 0.1 matches the value that
            was previously hard-coded.
    """

    def __init__(self,
                 regression_loss_weight=1.0,
                 classification_loss_weight=1.0,
                 future_loss_weight=1.0,
                 time_decay=0.1):
        # Zero-argument super() does not look up the global name `MotionLoss`,
        # so it keeps working when this cell is re-run or the class is reloaded
        # while older instances are still alive.
        super().__init__()
        self.reg_weight = regression_loss_weight
        self.cls_weight = classification_loss_weight
        self.future_weight = future_loss_weight
        self.time_decay = time_decay

    def forward(self, pred_scores, pred_trajs, batch_dict):
        """
        Compute the loss components.

        Args:
            pred_scores: (batch_size, num_modes) unnormalised mode scores (logits).
            pred_trajs: (batch_size, num_modes, future_steps, 4) predictions,
                laid out as (x, y, vx, vy) along the last axis.
            batch_dict: Batch whose ``'input_dict'`` entry holds
                ``'center_gt_trajs'`` (batch_size, future_steps, 4) and
                ``'center_gt_trajs_mask'`` (batch_size, future_steps).

        Returns:
            dict with scalar tensors ``total_loss``, ``regression_loss``,
            ``classification_loss``, ``pos_loss`` and ``vel_loss``.
        """
        input_dict = batch_dict['input_dict']

        # Follow the predictions' device/dtype instead of assuming CUDA, so the
        # loss works on CPU as well as on any GPU index.
        device, dtype = pred_trajs.device, pred_trajs.dtype
        gt_trajs = input_dict['center_gt_trajs'].to(device=device, dtype=dtype)
        gt_mask = input_dict['center_gt_trajs_mask'].to(device=device, dtype=dtype)

        batch_size, num_modes, future_steps, _ = pred_trajs.shape

        # (B, 1, T): broadcast the validity mask over the mode axis.
        mask = gt_mask.unsqueeze(1)

        # Masked squared error per state component, (B, M, T, 4).
        masked_err = (pred_trajs[..., :4] - gt_trajs.unsqueeze(1)[..., :4]) * mask.unsqueeze(-1)
        sq_err = masked_err.pow(2)

        # Exponentially decaying per-step weights, zeroed on invalid steps.
        steps = torch.arange(future_steps, device=device, dtype=dtype)
        step_weights = torch.exp(-self.time_decay * steps).view(1, 1, -1) * mask

        pos_loss = (sq_err[..., :2].sum(dim=-1) * step_weights).sum(dim=-1)   # (B, M)
        vel_loss = (sq_err[..., 2:4].sum(dim=-1) * step_weights).sum(dim=-1)  # (B, M)

        # Winner-takes-all: pick the mode with the lowest combined error.
        total_traj_loss = pos_loss + vel_loss
        best_mode_indices = torch.argmin(total_traj_loss.detach(), dim=1)  # (B,)

        sample_idx = torch.arange(batch_size, device=device)
        best_pos_loss = pos_loss[sample_idx, best_mode_indices].mean()
        best_vel_loss = vel_loss[sample_idx, best_mode_indices].mean()
        regression_loss = best_pos_loss + best_vel_loss

        # Class-index targets give the same value as a one-hot probability
        # target but are supported by every PyTorch version.
        classification_loss = F.cross_entropy(pred_scores, best_mode_indices)

        total_loss = (self.reg_weight * regression_loss +
                      self.cls_weight * classification_loss)

        return {
            'total_loss': total_loss,
            'regression_loss': regression_loss,
            'classification_loss': classification_loss,
            'pos_loss': best_pos_loss,
            'vel_loss': best_vel_loss,
        }


In [None]:
from lstm.loss import MotionLoss

#train loop
def _move_batch_to_device(batch, device):
    """Move every tensor in ``batch`` (including nested dicts) to ``device``, in place."""
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, dict):
            _move_batch_to_device(value, device)
    return batch


def _total_loss(criterion_output):
    """Return the scalar to optimise from either a loss dict or a bare loss tensor."""
    if isinstance(criterion_output, dict):
        return criterion_output['total_loss']
    return criterion_output


def train_model(model, train_dataloader, val_dataloader, num_epochs=100, lr=1e-3,
                save_path='/code/jjiang23/csc587/KimchiVision/best_motion_lstm.pth'):
    """
    Train ``model`` with AdamW + cosine LR annealing and keep the best checkpoint.

    Each epoch runs one pass over ``train_dataloader`` (with gradient-norm
    clipping at 1.0) followed by one no-grad pass over ``val_dataloader``.
    The model's ``state_dict`` is written to ``save_path`` whenever the mean
    validation loss improves.

    Args:
        model: Module called as ``model(batch_dict)`` and returning
            ``(pred_scores, pred_trajs)``.
        train_dataloader: Iterable of batch dicts used for optimisation.
        val_dataloader: Iterable of batch dicts used for validation.
        num_epochs: Number of epochs; also the cosine schedule's ``T_max``.
        lr: Initial learning rate for AdamW.
        save_path: Destination of the best checkpoint. Defaults to the path
            that was previously hard-coded.

    Returns:
        The trained model (same object as ``model``).

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), "CUDA is not available. Please check your PyTorch installation."
    device = torch.device('cuda')
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = MotionLoss()

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_losses = []

        for batch_idx, batch_dict in enumerate(train_dataloader):
            # Tensors live under nested dicts (e.g. 'input_dict'), so the move
            # has to recurse; a top-level-only loop would move nothing.
            _move_batch_to_device(batch_dict, device)

            optimizer.zero_grad()

            # Forward pass
            pred_scores, pred_trajs = model(batch_dict)

            # The criterion may return a dict of components or a bare tensor;
            # either way, optimise the total loss.
            loss = _total_loss(criterion(pred_scores, pred_trajs, batch_dict))

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_losses.append(loss.item())

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Validation phase
        model.eval()
        val_losses = []

        with torch.no_grad():
            for batch_dict in val_dataloader:
                _move_batch_to_device(batch_dict, device)

                pred_scores, pred_trajs = model(batch_dict)
                val_loss = _total_loss(criterion(pred_scores, pred_trajs, batch_dict))
                val_losses.append(val_loss.item())

        scheduler.step()

        # An empty loader yields NaN explicitly rather than a numpy warning.
        avg_train_loss = float(np.mean(train_losses)) if train_losses else float('nan')
        avg_val_loss = float(np.mean(val_losses)) if val_losses else float('nan')

        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        # Save best model (a NaN validation loss never compares as smaller).
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)

    return model

In [139]:
from lstm.train_util import train_model
from lstm.simple_lstm import MotionLSTM

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fresh model instance with default constructor arguments, placed on `device`.
model = MotionLSTM()
model = model.to(device)

# Train with the function's default epoch count and learning rate.
trained_model = train_model(model, train_loader, test_loader)

TypeError: super(type, obj): obj must be an instance or subtype of type