## Model Evaluation Notebook

In [None]:
import torch
from your_module import TrajectoryLSTM  # wherever you defined it

# Pick the compute device up front: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the network with the same architecture/hyperparameters that produced
# the checkpoint (so parameter names and shapes line up when the weights are
# loaded below), and place it on the chosen device straight away.
model = TrajectoryLSTM(
    input_dim=29,
    hidden_dim=256,
    num_layers=2,
    num_modes=6,
    future_steps=80,
    dropout=0.1,
).to(device)

# Restore the trained weights. `map_location` remaps the stored tensors onto
# `device`, so a checkpoint saved on GPU can still be opened on a CPU-only box.
# NOTE(review): torch.load unpickles the file, so only point this at checkpoints
# you trust. The absolute path is machine-specific; consider making it configurable.
checkpoint_path = '/code/jjiang23/csc587/KimchiVision/best_trajectory_lstm.pth'
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)

# Inference mode: turns off dropout and other train-only behaviour.
# Kept as the last expression so the cell displays the model summary.
model.eval()

In [None]:
import pickle
import mtr.datasets.waymo.waymo_eval as waymo_eval

# Run the model over the validation set, collect one prediction record per
# sample, and pickle the records for the Waymo evaluation script.
# Relies on `model` (loaded in an earlier cell) and `val_dataloader` (defined
# elsewhere in the notebook) already being in the kernel namespace.
model.eval()
pred_dicts = []

# Follow the device the model's weights actually live on instead of
# hard-coding CUDA, so this cell also runs when the model was loaded on CPU.
model_device = next(model.parameters()).device

with torch.no_grad():
    for batch_dict in val_dataloader:
        # Move tensors to the model's device.
        # NOTE(review): only top-level tensors are moved; anything nested under
        # batch_dict["input_dict"] stays wherever the dataloader put it —
        # confirm the model handles that.
        for k, v in batch_dict.items():
            if isinstance(v, torch.Tensor):
                batch_dict[k] = v.to(model_device)

        # 1) forward
        scores, trajs = model(batch_dict)
        # scores: (B,6), trajs: (B,6,80,4)  [x,y,vx,vy]

        B = scores.shape[0]

        # One device->host transfer per batch rather than one per sample.
        scores_np = scores.cpu().numpy()                 # (B,6)
        # only feed (x,y) to Waymo script; vx,vy are dropped here
        trajs_xy_np = trajs[:, :, :, :2].cpu().numpy()   # (B,6,80,2)

        # assume your batch_dict contains these GT fields (TODO confirm against
        # the dataset's collate function):
        #   batch_dict["input_dict"]["scenario_id"]       (B,)
        #   batch_dict["input_dict"]["track_index_to_predict"]   (B,)
        #   batch_dict["input_dict"]["object_type"]       (B,) or (B,num_objects)
        #   batch_dict["input_dict"]["gt_trajs_full"]     (B,T_total,7)
        #   batch_dict["input_dict"]["gt_trajs_mask_full"] (B,T_total)
        input_dict = batch_dict["input_dict"]

        for i in range(B):
            sid    = input_dict["scenario_id"][i].item()
            oid    = input_dict["track_index_to_predict"][i].item()
            otype  = input_dict["object_type"][i,oid].item()  # if per-object
            gt7    = input_dict["gt_trajs_full"][i].cpu().numpy()
            mask   = input_dict["gt_trajs_mask_full"][i].cpu().numpy()

            pred_dicts.append({
              "scenario_id":   sid,
              "object_id":     oid,
              "object_type":   otype,
              "pred_scores":   scores_np[i],                 # (6,)
              "pred_trajs":    trajs_xy_np[i],               # (6,80,2)
              "gt_trajs":      gt7,                          # (T_total,7)
              "gt_valid_mask": mask,                         # (T_total,)
            })

# dump to disk
with open("preds.pkl","wb") as f:
    pickle.dump(pred_dicts, f)

# then in shell:
# python mtr/datasets/waymo/waymo_eval.py --pred_infos preds.pkl --eval_second 8 --num_modes_for_eval 6
