In [2]:
#common libs
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchinfo import summary
import math
from easydict import EasyDict as edict


In [3]:
#mtr modules
from mtr.datasets import build_dataloader
from mtr.config import cfg, cfg_from_yaml_file
from mtr.utils import common_utils

In [4]:
import os
print(os.cpu_count())

64


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MotionLSTM(nn.Module):
    '''
    Input: 
    - obj_trajs (num_center_objects(batch_size), num_objects, num_timestamps, num_attrs)
    - obj_trajs_mask (num_center_objects(batch_size), num_objects, num_timestamps)
    - map_polylines (num_center_objects(batch_size),num_polylines, num_points_each_polyline, 7)
    - map_polylines_mask (num_center_objects(batch_size),num_polylines(4000), num_points_each_polyline(20))
    - track index (num_center_objects(batch_size), )
    '''
    def __init__(self, 
                input_dim=29,  # Based on MTR dataset obj_trajs feature dimension
                # Map polylines encoder parameters
                map_polyline_encoder_output_dim=256,  # Hidden dimension for the map polyline encoder
                map_polyline_encoder_hidden_dim=512,  # Hidden dimension for the map polyline encoder
                # Encoder parameters for object trajectories
                encoder_hidden_dim=256,
                encoder_output_dim=256,  # Output dimension of the encoder
                # LSTM parameters
                lstm_hidden_dim=256,
                lstm_num_layers=2,
                # Mode predictor parameters
                mode_predictor_hidden_dim=256,
                # Trajectory decoder parameters
                trajectory_decoder_hidden_dim=256,
                num_modes=6,  # Number of prediction modes
                future_steps=80,  # Number of future timesteps to predict
                dropout=0.1):
        super(MotionLSTM, self).__init__()

        self.input_dim = input_dim
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_output_dim = encoder_output_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.lstm_num_layers = lstm_num_layers
        self.num_modes = num_modes
        self.future_steps = future_steps
        self.dropout = dropout
        self.map_polyline_encoder_output_dim = map_polyline_encoder_output_dim
        
        # Map polylines encoder - will be initialized dynamically
        self.map_polyline_encoder = None
        self.map_polyline_encoder_hidden_dim = map_polyline_encoder_hidden_dim

        # Feature encoder for input trajectories
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, encoder_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(encoder_hidden_dim, encoder_output_dim)
        )

        # Fusion layer for feature encoder with map polylines
        self.fusion_layer = nn.Sequential(
            nn.Linear(encoder_output_dim + map_polyline_encoder_output_dim, encoder_output_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # LSTM for temporal modeling
        self.lstm = nn.LSTM(
            input_size=encoder_output_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            dropout=dropout if lstm_num_layers > 1 else 0,
            bidirectional=False
        )
        
        # Multi-modal prediction heads
        self.mode_predictor = nn.Sequential(
            nn.Linear(lstm_hidden_dim, mode_predictor_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mode_predictor_hidden_dim, num_modes)
        )
        
        # Trajectory decoder for each mode
        self.traj_decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(lstm_hidden_dim, trajectory_decoder_hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(trajectory_decoder_hidden_dim, trajectory_decoder_hidden_dim),
                nn.ReLU(),
                nn.Linear(trajectory_decoder_hidden_dim, future_steps * 4)  # x, y, vx, vy for each timestep
            ) for _ in range(num_modes)
        ])
        
        # Attention mechanism for object interactions
        self.attention = nn.MultiheadAttention(
            embed_dim=lstm_hidden_dim,
            num_heads=8,
            dropout=dropout,
            batch_first=True
        )
        
        self._init_weights()
    
    def _init_map_encoder(self, input_size):
        """Initialize map encoder with correct input size"""
        if self.map_polyline_encoder is None:
            self.map_polyline_encoder = nn.Sequential(
                nn.Linear(input_size, self.map_polyline_encoder_hidden_dim),  
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.map_polyline_encoder_hidden_dim, self.map_polyline_encoder_output_dim)
            )
            # Move to same device as other parameters
            device = next(self.parameters()).device
            self.map_polyline_encoder = self.map_polyline_encoder.to(device)
            # Initialize weights for the new layers
            for module in self.map_polyline_encoder:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    nn.init.constant_(module.bias, 0)
    
    def _init_weights(self):
        """Initialize model weights"""
        for name, param in self.named_parameters():
            if 'weight' in name and len(param.shape) >= 2:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
        
    def forward(self, batch_dict):
        """
        Forward pass of the model
        
        Args:
            batch_dict: Dictionary containing:
                - obj_trajs: (batch_size, num_objects, num_timestamps, input_dim)
                - obj_trajs_mask: (batch_size, num_objects, num_timestamps)
                - track_index_to_predict: (batch_size,) indices of center objects
                - static_map_polylines: (batch_size, num_polylines, num_points_each_polyline, 7)
                - static_map_polylines_mask: (batch_size, num_polylines, num_points_each_polyline)
        
        Returns:
            pred_scores: (batch_size, num_modes) - confidence scores for each mode
            pred_trajs: (batch_size, num_modes, future_steps, 4) - predicted trajectories
        """
        input_dict = batch_dict["input_dict"]
        
        # Get device from model parameters
        device = next(self.parameters()).device
        
        obj_trajs = input_dict['obj_trajs'].to(device)  # (batch_size, num_objects, num_timestamps, input_dim)
        obj_trajs_mask = input_dict['obj_trajs_mask'].to(device)  # (batch_size, num_objects, num_timestamps)
        track_indices = input_dict['track_index_to_predict'].to(device)  # (batch_size,)
        
        static_map_polylines = input_dict["static_map_polylines"].to(device)  # (batch_size, num_polylines, num_points_each_polyline, 7)
        static_map_polylines_mask = input_dict["static_map_polylines_mask"].to(device)  # (batch_size, num_polylines, num_points_each_polyline)
        
        batch_size, num_objects, num_timestamps, input_dim = obj_trajs.shape
        
        # Encode map polylines
        map_polyline_features = static_map_polylines * static_map_polylines_mask.unsqueeze(-1).float()  # Apply mask to polylines
        map_flat_size = map_polyline_features.shape[1] * map_polyline_features.shape[2] * map_polyline_features.shape[3]
        map_polyline_features = map_polyline_features.view(batch_size, map_flat_size)  # Flatten
        
        # Initialize map encoder with correct input size
        self._init_map_encoder(map_flat_size)
        map_polyline_features = self.map_polyline_encoder(map_polyline_features)  # (batch_size, map_polyline_encoder_output_dim)

        # Encode input features
        obj_features = self.feature_encoder(obj_trajs.view(-1, input_dim))
        obj_features = obj_features.view(batch_size, num_objects, num_timestamps, self.encoder_output_dim)
        
        # Apply mask to features (only once)
        mask_expanded = obj_trajs_mask.unsqueeze(-1).expand(-1, -1, -1, self.encoder_output_dim)
        obj_features = obj_features * mask_expanded.float()  # (batch_size, num_objects, num_timestamps, encoder_output_dim)
        
        # Concatenate object features with map polyline features
        map_polyline_expanded = map_polyline_features[:, None, None, :]  # (batch_size, 1, 1, map_encoder_output_dim)
        map_polyline_expanded = map_polyline_expanded.expand(-1, num_objects, num_timestamps, -1)
        obj_map_features = torch.cat((obj_features, map_polyline_expanded), dim=-1)  
        obj_features = self.fusion_layer(obj_map_features)  # (batch_size, num_objects, num_timestamps, encoder_output_dim)
        
        # Process trajectories through LSTM more efficiently
        # Reshape to process all objects together
        obj_features_reshaped = obj_features.view(batch_size * num_objects, num_timestamps, self.encoder_output_dim)
        obj_mask_reshaped = obj_trajs_mask.view(batch_size * num_objects, num_timestamps)
        
        # Create packed sequence for efficient LSTM processing
        seq_lengths = obj_mask_reshaped.sum(dim=1).cpu()  # (batch_size * num_objects,)
        
        # Only process sequences with length > 0
        valid_indices = seq_lengths > 0
        if valid_indices.sum() > 0:
            valid_features = obj_features_reshaped[valid_indices]
            valid_lengths = seq_lengths[valid_indices]
            
            # Pack sequences
            packed_input = nn.utils.rnn.pack_padded_sequence(
                valid_features, valid_lengths, batch_first=True, enforce_sorted=False
            )
            
            # Process through LSTM
            packed_output, _ = self.lstm(packed_input)
            lstm_output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
            
            # Get last valid output for each sequence
            last_outputs = []
            for i, length in enumerate(output_lengths):
                if length > 0:
                    last_outputs.append(lstm_output[i, length-1, :])
                else:
                    last_outputs.append(torch.zeros(self.lstm_hidden_dim, device=device))
            
            valid_last_outputs = torch.stack(last_outputs, dim=0)
        
        # Reconstruct full output tensor
        all_lstm_outputs = torch.zeros(batch_size * num_objects, self.lstm_hidden_dim, device=device)
        if valid_indices.sum() > 0:
            all_lstm_outputs[valid_indices] = valid_last_outputs
        
        all_lstm_outputs = all_lstm_outputs.view(batch_size, num_objects, self.lstm_hidden_dim)
        
        # Apply attention mechanism for object interactions
        # Create attention mask: True for positions to ignore
        attn_mask = ~(obj_trajs_mask.sum(dim=2) > 0)  # (batch_size, num_objects)
        
        attn_output, _ = self.attention(
            all_lstm_outputs, all_lstm_outputs, all_lstm_outputs,
            key_padding_mask=attn_mask
        )
        
        # Extract center object features
        center_features = []
        for b in range(batch_size):
            center_idx = track_indices[b]
            center_features.append(attn_output[b, center_idx, :])
        center_features = torch.stack(center_features, dim=0)  # (batch_size, lstm_hidden_dim)
        
        # Predict mode probabilities
        mode_logits = self.mode_predictor(center_features)  # (batch_size, num_modes)
        pred_scores = F.softmax(mode_logits, dim=-1)
        
        # Predict trajectories for each mode
        pred_trajs_list = []
        for mode_idx in range(self.num_modes):
            traj_flat = self.traj_decoders[mode_idx](center_features)  # (batch_size, future_steps * 4)
            traj = traj_flat.view(batch_size, self.future_steps, 4)  # (batch_size, future_steps, 4)
            pred_trajs_list.append(traj)
        
        pred_trajs = torch.stack(pred_trajs_list, dim=1)  # (batch_size, num_modes, future_steps, 4)
        
        return pred_scores, pred_trajs

In [None]:
#define loss function
class MotionLoss(nn.Module):
    """
    Loss function for trajectory prediction
    """
    
    def __init__(self, 
                 regression_loss_weight=1.0,
                 classification_loss_weight=1.0,
                 future_loss_weight=1.0):
        super(MotionLoss, self).__init__()
        self.reg_weight = regression_loss_weight
        self.cls_weight = classification_loss_weight
        self.future_weight = future_loss_weight
    
    def forward(self, pred_scores, pred_trajs, batch_dict):
        """
        Compute loss
        
        Args:
            pred_scores: (batch_size, num_modes)
            pred_trajs: (batch_size, num_modes, future_steps, 4)
            batch_dict: Contains ground truth data
        
        Returns:
            loss_dict: Dictionary containing different loss components
        """
        center_gt_trajs = batch_dict['input_dict']['center_gt_trajs'].to('cuda')  # (batch_size, future_steps, 4)
        center_gt_trajs_mask = batch_dict['input_dict']['center_gt_trajs_mask'].to('cuda')  # (batch_size, future_steps)
        
        batch_size, num_modes, future_steps, _ = pred_trajs.shape
        
        # Compute trajectory regression loss for each mode
        gt_trajs_expanded = center_gt_trajs.unsqueeze(1).expand(-1, num_modes, -1, -1)
        gt_mask_expanded = center_gt_trajs_mask.unsqueeze(1).expand(-1, num_modes, -1)
        
        # L2 loss for position (x, y)
        pos_loss = F.mse_loss(
            pred_trajs[:, :, :, :2] * gt_mask_expanded.unsqueeze(-1),
            gt_trajs_expanded[:, :, :, :2] * gt_mask_expanded.unsqueeze(-1),
            reduction='none'
        ).sum(dim=-1)  # (batch_size, num_modes, future_steps)
        
        # L2 loss for velocity (vx, vy)
        vel_loss = F.mse_loss(
            pred_trajs[:, :, :, 2:4] * gt_mask_expanded.unsqueeze(-1),
            gt_trajs_expanded[:, :, :, 2:4] * gt_mask_expanded.unsqueeze(-1),
            reduction='none'
        ).sum(dim=-1)  # (batch_size, num_modes, future_steps)
        
        # Weighted loss over time (give more weight to near future)
        time_weights = torch.exp(-0.025 * torch.arange(future_steps, device=pred_trajs.device))
        time_weights = time_weights.view(1, 1, -1)
        
        pos_loss = (pos_loss * time_weights * gt_mask_expanded).sum(dim=-1)  # (batch_size, num_modes)
        vel_loss = (vel_loss * time_weights * gt_mask_expanded).sum(dim=-1)  # (batch_size, num_modes)
        
        # Find best mode for each sample
        total_traj_loss = pos_loss + vel_loss  # (batch_size, num_modes)
        best_mode_indices = torch.argmin(total_traj_loss, dim=1)  # (batch_size,)
        
        # Regression loss (best mode)
        best_pos_loss = pos_loss[torch.arange(batch_size), best_mode_indices].mean()
        best_vel_loss = vel_loss[torch.arange(batch_size), best_mode_indices].mean()
        regression_loss = best_pos_loss + best_vel_loss
        
        # Classification loss (encourage higher confidence for best mode)
        target_scores = torch.zeros_like(pred_scores)
        target_scores[torch.arange(batch_size), best_mode_indices] = 1.0
        classification_loss = F.cross_entropy(pred_scores, target_scores)
        
        # Total loss
        total_loss = (self.reg_weight * regression_loss + 
                     self.cls_weight * classification_loss)
        
        loss_dict = {
            'total_loss': total_loss,
            'regression_loss': regression_loss,
            'classification_loss': classification_loss,
            'pos_loss': best_pos_loss,
            'vel_loss': best_vel_loss
        }
        
        return loss_dict


In [6]:
#train loop
def train_model(model, train_dataloader, val_dataloader, num_epochs=100, lr=1e-3):
    """
    Training loop for the LSTM model
    """
    assert torch.cuda.is_available(), "CUDA is not available. Please check your PyTorch installation."
    device = torch.device('cuda')
    model.to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = MotionLoss()
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_losses = []
        
        for batch_idx, batch_dict in enumerate(train_dataloader):
            # Move data to device
            for key, value in batch_dict.items():
                if isinstance(value, torch.Tensor):
                    batch_dict[key] = value.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            pred_scores, pred_trajs = model(batch_dict)
            
            # Compute loss
            loss_dict = criterion(pred_scores, pred_trajs, batch_dict)
            loss = loss_dict['total_loss']
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_losses.append(loss.item())
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Validation phase
        model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch_dict in val_dataloader:
                # Move data to device
                for key, value in batch_dict.items():
                    if isinstance(value, torch.Tensor):
                        batch_dict[key] = value.to(device)
                
                pred_scores, pred_trajs = model(batch_dict)
                loss_dict = criterion(pred_scores, pred_trajs, batch_dict)
                val_losses.append(loss_dict['total_loss'].item())
        
        scheduler.step()
        
        avg_train_loss = np.mean(train_losses)
        avg_val_loss = np.mean(val_losses)
        
        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), '/code/jjiang23/csc587/KimchiVision/best_motion_lstm.pth')
    
    return model

In [10]:
cfg_from_yaml_file("/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml", cfg)
logger = common_utils.create_logger("/files/waymo/log.txt", rank=0)
args = edict({
    "batch_size": 32,
    "workers": 6,
    "merge_all_iters_to_one_epoch": False,
    "epochs": 5,
    "add_worker_init_fn": False,
    
})

In [11]:
#prepare data
train_set, train_loader, train_sampler = build_dataloader(
    dataset_cfg=cfg.DATA_CONFIG,
    batch_size=args.batch_size,
    dist=False, workers=args.workers,
    logger=logger,
    training=True,
    merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
    total_epochs=args.epochs,
    add_worker_init_fn=args.add_worker_init_fn,
)
test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        batch_size=args.batch_size,
        dist=False, workers=args.workers, logger=logger, training=False
)

2025-06-09 14:18:54,264   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_training_infos.pkl
2025-06-09 14:18:58,558   INFO  Total scenes before filters: 243401
2025-06-09 14:19:04,202   INFO  Total scenes after filter_info_by_object_type: 243401
2025-06-09 14:19:04,229   INFO  Total scenes after filters: 243401
2025-06-09 14:19:04,231   INFO  Start to load infos from /files/waymo/code/MTR/data/waymo/processed_scenarios_val_infos.pkl
2025-06-09 14:19:06,035   INFO  Total scenes before filters: 22089
2025-06-09 14:19:06,552   INFO  Total scenes after filter_info_by_object_type: 22089
2025-06-09 14:19:06,554   INFO  Total scenes after filters: 22089


In [None]:
# Initialize model
model = MotionLSTM()

# Train the model
trained_model = train_model(model, train_loader, test_loader)

Epoch 0, Batch 0, Loss: 1889.6151


In [None]:
# evaluate the model

In [None]:
# loaded_model = MotionLSTM()
# loaded_model.load_state_dict(torch.load('/code/jjiang23/csc587/KimchiVision/best_motion_lstm.pth'))
# loaded_model.eval()
# loaded_model.to('cuda')

MotionLSTM(
  (map_polyline_encoder): Sequential(
    (0): Linear(in_features=560000, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
  )
  (feature_encoder): Sequential(
    (0): Linear(in_features=29, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=256, bias=True)
  )
  (fusion_layer): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
  )
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.1)
  (mode_predictor): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=6, bias=True)
  )
  (traj_decoders): ModuleList(
    (0-5): 6 x Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): ReLU()
      (2): Dro

In [None]:
# for batch_idx, batch_dict in enumerate(test_loader):
#     # Move data to device
#     for key, value in batch_dict.items():
#         if isinstance(value, torch.Tensor):
#             batch_dict[key] = value.to("cuda")
    
#     with torch.no_grad():
#         pred_scores, pred_trajs = loaded_model(batch_dict)
    
#     # Process predictions as needed
#     print(f'Batch {batch_idx}: Predicted scores shape: {pred_scores.shape}, Predicted trajectories shape: {pred_trajs.shape}')
    
#     if batch_idx >= 10:  # Limit to first 10 batches for demonstration
#         break

# pred_scores

KeyboardInterrupt: 

In [16]:
# Motion Transformer (MTR): https://arxiv.org/abs/2209.13508
# Published at NeurIPS 2022
# Written by Shaoshuai Shi 
# All Rights Reserved
import argparse
import datetime
import glob
import os
from pathlib import Path
import math

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_sched
from tensorboardX import SummaryWriter

from mtr.datasets import build_dataloader
from mtr.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
from mtr.utils import common_utils


from train_utils.train_utils import train_model


def parse_config():
    # parser = argparse.ArgumentParser(description='arg parser')
    # parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')

    # parser.add_argument('--batch_size', type=int, default=None, required=False, help='batch size for training')
    # parser.add_argument('--epochs', type=int, default=None, required=False, help='number of epochs to train for')
    # parser.add_argument('--workers', type=int, default=8, help='number of workers for dataloader')
    # parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
    # parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
    # parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
    # parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
    # parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
    # parser.add_argument('--without_sync_bn', action='store_true', default=False, help='whether to use sync bn')
    # parser.add_argument('--fix_random_seed', action='store_true', default=False, help='')
    # parser.add_argument('--ckpt_save_interval', type=int, default=2, help='number of training epochs')
    # parser.add_argument('--local_rank', type=int, default=None, help='local rank for distributed training')
    # parser.add_argument('--max_ckpt_save_num', type=int, default=5, help='max number of saved checkpoint')
    # parser.add_argument('--merge_all_iters_to_one_epoch', action='store_true', default=False, help='')
    # parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
    #                     help='set extra config keys if needed')

    # parser.add_argument('--max_waiting_mins', type=int, default=0, help='max waiting minutes')
    # parser.add_argument('--start_epoch', type=int, default=0, help='')
    # parser.add_argument('--save_to_file', action='store_true', default=False, help='')
    # parser.add_argument('--not_eval_with_train', action='store_true', default=False, help='')
    # parser.add_argument('--logger_iter_interval', type=int, default=50, help='')
    # parser.add_argument('--ckpt_save_time_interval', type=int, default=300, help='in terms of seconds')

    # parser.add_argument('--add_worker_init_fn', action='store_true', default=False, help='')
    # args = parser.parse_args()
    
    cfg_from_yaml_file("/code/jjiang23/csc587/KimchiVision/cfg/kimchiConfig.yaml", cfg)
    # take all default args
    args = edict({
    "batch_size": 32,
    "workers": 4,
    "merge_all_iters_to_one_epoch": False,
    "epochs": 5,
    "add_worker_init_fn": False,
    "extra_tag": 'default',
    "launcher": 'none',
    "tcp_port": 18888,
    "without_sync_bn": False,
    "fix_random_seed": False,
    "ckpt_save_interval": 2,
    "local_rank": None,
    "max_ckpt_save_num": 5,
    "set_cfgs": None,
    "max_waiting_mins": 0,
    "start_epoch": 0,
    "save_to_file": False,
    "not_eval_with_train": False,
    "logger_iter_interval": 50,
    "ckpt_save_time_interval": 300,
    "add_worker_init_fn": False,
    "pretrained_model": None,
    "ckpt": None,
    "cfg_file": None,
    "fix_random_seed": False,
    "extra_tag": 'default',


    
    
})
    return args, cfg


def build_optimizer(model, opt_cfg):
    if opt_cfg.OPTIMIZER == 'Adam':
        optimizer = torch.optim.Adam(
            [each[1] for each in model.named_parameters()],
            lr=opt_cfg.LR, weight_decay=opt_cfg.get('WEIGHT_DECAY', 0)
        )
    elif opt_cfg.OPTIMIZER == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), lr=opt_cfg.LR, weight_decay=opt_cfg.get('WEIGHT_DECAY', 0))
    else:
        assert False

    return optimizer


def build_scheduler(optimizer, dataloader, opt_cfg, total_epochs, total_iters_each_epoch, last_epoch):
    decay_steps = [x * total_iters_each_epoch for x in opt_cfg.get('DECAY_STEP_LIST', [5, 10, 15, 20])]
    def lr_lbmd(cur_epoch):
        cur_decay = 1
        for decay_step in decay_steps:
            if cur_epoch >= decay_step:
                cur_decay = cur_decay * opt_cfg.LR_DECAY
        return max(cur_decay, opt_cfg.LR_CLIP / opt_cfg.LR)

    if opt_cfg.get('SCHEDULER', None) == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=2 * len(dataloader),
            T_mult=1,
            eta_min=max(1e-2 * opt_cfg.LR, 1e-6),
            last_epoch=-1,
        )
    elif opt_cfg.get('SCHEDULER', None) == 'lambdaLR':
        scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd, last_epoch=last_epoch)
    elif opt_cfg.get('SCHEDULER', None) == 'linearLR':
        total_iters = total_iters_each_epoch * total_epochs
        scheduler = lr_sched.LinearLR(optimizer, start_factor=1.0, end_factor=opt_cfg.LR_CLIP / opt_cfg.LR, total_iters=total_iters, last_epoch=last_epoch)
    else:
        scheduler = None

    return scheduler


def main():
    args, cfg = parse_config()
    if args.launcher == 'none':
        dist_train = False
        total_gpus = 1
        args.without_sync_bn = True
    else:
        if args.local_rank is None:
            args.local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        total_gpus, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
            args.tcp_port, args.local_rank, backend='nccl'
        )
        dist_train = True

    if args.batch_size is None:
        args.batch_size = cfg.OPTIMIZATION.BATCH_SIZE_PER_GPU
    else:
        assert args.batch_size % total_gpus == 0, 'Batch size should match the number of gpus'
        args.batch_size = args.batch_size // total_gpus

    args.epochs = cfg.OPTIMIZATION.NUM_EPOCHS if args.epochs is None else args.epochs

    if args.fix_random_seed:
        common_utils.set_random_seed(666)

    output_dir = Path("/code/jjiang23/csc587/KimchiVision/output/")
    ckpt_dir = output_dir / 'ckpt'
    output_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    log_file = output_dir / ('log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)

    # log to file
    logger.info('**********************Start logging**********************')
    gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys() else 'ALL'
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    if dist_train:
        logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    log_config_to_file(cfg, logger=logger)
    if cfg.LOCAL_RANK == 0:
        os.system('cp %s %s' % (args.cfg_file, output_dir))
    tb_log = SummaryWriter(log_dir=str(output_dir / 'tensorboard')) if cfg.LOCAL_RANK == 0 else None

    train_set, train_loader, train_sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        batch_size=args.batch_size,
        dist=dist_train, workers=args.workers,
        logger=logger,
        training=True,
        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
        total_epochs=args.epochs,
        add_worker_init_fn=args.add_worker_init_fn,
    )

    model = MotionLSTM()
    if not args.without_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()

    optimizer = build_optimizer(model, cfg.OPTIMIZATION)

    # load checkpoint if it is possible
    start_epoch = it = 0
    last_epoch = -1

    if args.pretrained_model is not None:
        model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist_train, logger=logger)

    if args.ckpt is not None:
        it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist_train, optimizer=optimizer,
                                                           logger=logger)
        last_epoch = start_epoch + 1
    else:
        ckpt_list = glob.glob(str(ckpt_dir / '*.pth'))
        if len(ckpt_list) > 0:
            ckpt_list.sort(key=os.path.getmtime)
            while len(ckpt_list) > 0:
                basename = os.path.basename(ckpt_list[-1])
                if basename == 'best_model.pth':
                    ckpt_list = ckpt_list[:-1]
                    continue

                try:
                    it, start_epoch = model.load_params_with_optimizer(
                        ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger
                    )
                    last_epoch = start_epoch + 1
                    break
                except:
                    ckpt_list = ckpt_list[:-1]

    scheduler = build_scheduler(
        optimizer, train_loader, cfg.OPTIMIZATION, total_epochs=args.epochs,
        total_iters_each_epoch=len(train_loader), last_epoch=last_epoch
    )

    model.train()  # before wrap to DistributedDataParallel to support to fix some parameters

    if dist_train:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()], find_unused_parameters=True)
    logger.info(model)
    num_total_params = sum([x.numel() for x in model.parameters()])
    logger.info(f'Total number of parameters: {num_total_params}')

    test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        batch_size=args.batch_size,
        dist=dist_train, workers=args.workers, logger=logger, training=False
    )

    eval_output_dir = Path("/code/jjiang23/csc587/KimchiVision/output/eval")
    eval_output_dir.mkdir(parents=True, exist_ok=True)

    # -----------------------start training---------------------------
    logger.info('**********************Start training %s/%s(%s)**********************'
                )
    train_model(
        model,
        optimizer,
        train_loader,
        optim_cfg=cfg.OPTIMIZATION,
        start_epoch=start_epoch,
        total_epochs=args.epochs,
        start_iter=it,
        rank=cfg.LOCAL_RANK,
        ckpt_save_dir=ckpt_dir,
        train_sampler=train_sampler,
        ckpt_save_interval=args.ckpt_save_interval,
        max_ckpt_save_num=args.max_ckpt_save_num,
        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
        tb_log=tb_log,
        scheduler=scheduler,
        logger=logger,
        eval_output_dir=eval_output_dir,
        test_loader=test_loader if not args.not_eval_with_train else None,
        cfg=cfg, dist_train=dist_train, logger_iter_interval=args.logger_iter_interval,
        ckpt_save_time_interval=args.ckpt_save_time_interval
    )

    logger.info('**********************End training %s/%s(%s)**********************\n\n\n'
                )


    logger.info('**********************Start evaluation %s/%s(%s)**********************' 
                )

    eval_output_dir = output_dir / 'eval' / 'eval_with_train'
    eval_output_dir.mkdir(parents=True, exist_ok=True)
    args.start_epoch = max(args.epochs - 0, 0)  # Only evaluate the last 10 epochs
    cfg.DATA_CONFIG.SAMPLE_INTERVAL.val = 1

    test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        batch_size=args.batch_size,
        dist=dist_train, workers=args.workers, logger=logger, training=False
    )

    from test import repeat_eval_ckpt, eval_single_ckpt
    repeat_eval_ckpt(
        model.module if dist_train else model,
        test_loader, args, eval_output_dir, logger, ckpt_dir,
        dist_test=dist_train
    )

    logger.info('**********************End evaluation %s/%s(%s)**********************' 
                )


if __name__ == '__main__':
    main()



2025-06-09 16:28:31,359   INFO  **********************Start logging**********************
2025-06-09 16:28:31,359   INFO  **********************Start logging**********************
2025-06-09 16:28:31,364   INFO  CUDA_VISIBLE_DEVICES=1
2025-06-09 16:28:31,364   INFO  CUDA_VISIBLE_DEVICES=1
2025-06-09 16:28:31,365   INFO  batch_size       32
2025-06-09 16:28:31,365   INFO  batch_size       32
2025-06-09 16:28:31,366   INFO  workers          4
2025-06-09 16:28:31,366   INFO  workers          4
2025-06-09 16:28:31,367   INFO  merge_all_iters_to_one_epoch False
2025-06-09 16:28:31,367   INFO  merge_all_iters_to_one_epoch False
2025-06-09 16:28:31,368   INFO  epochs           5
2025-06-09 16:28:31,368   INFO  epochs           5
2025-06-09 16:28:31,369   INFO  add_worker_init_fn False
2025-06-09 16:28:31,369   INFO  add_worker_init_fn False
2025-06-09 16:28:31,370   INFO  extra_tag        default
2025-06-09 16:28:31,370   INFO  extra_tag        default
2025-06-09 16:28:31,371   INFO  launcher

ValueError: not enough values to unpack (expected 3, got 2)