In [17]:
from nuplan.database.nuplan_db_orm.nuplandb import NuPlanDB
from pathlib import Path
from nuplan.database.nuplan_db_orm.nuplandb_wrapper import NuPlanDBWrapper
from nuplan.database.nuplan_db_orm.frame import Frame
import os

from nuplan.planning.utils.multithreading.worker_pool import Task
from nuplan.planning.utils.multithreading.worker_parallel import SingleMachineParallelExecutor
from nuplan.planning.scenario_builder.scenario_filter import ScenarioFilter
from nuplan.planning.scenario_builder.nuplan_db.nuplan_scenario_builder import NuPlanScenarioBuilder
from nuplan.planning.scenario_builder.nuplan_db.nuplan_scenario_utils import ScenarioMapping

import os
import math
import argparse
import matplotlib.pyplot as plt
from tqdm import tqdm
from data_utils import *
from trajectory_tree_planner import *
from common_utils import get_filter_parameters, get_scenario_map
from data_process import DataProcessor
from train_utils import DrivingData
import numpy as np
from torch.utils.data import DataLoader

In [18]:
db_path = "/cailiu2/Diffusion-Planner/data/2021.10.21.14.43.30_veh-28_01244_01519.db" # single db file
db_path = "/share/data_cold/open_data/nuplan/data/cache/mini"  # multi db files
data_root = "/cailiu2/Diffusion-Planner/data/"
map_path = "/cailiu2/Diffusion-Planner/data/maps"
map_version = "nuplan-maps-v1.0"

save_processed_path = "/cailiu2/Diffusion-Planner/data/processed"

In [19]:
path = ["/cailiu2/Diffusion-Planner/data/processed/us-nv-las-vegas-strip_7fc811bcf45f5e79.npz"] # the path must be list
train_set = DrivingData(path, 10, 10) # DrivingData继承自torch.Dataset，存储和加载npz数据
print("len train set: ", len(train_set))
print("train_set: ", train_set)
ego, neighbors, map_lanes, map_crosswalks, route_lanes, ego_future_gt, neighbors_future_gt, first_stage, second_stage = train_set[0]


len train set:  1
train_set:  <train_utils.DrivingData object at 0x7f126c04d3a0>


In [20]:
print("len ego: ", len(ego))
print("len neighbors: ", len(neighbors))
print("len map_lanes: ", len(map_lanes))
print("len map_crosswalks: ", len(map_crosswalks))
print("len route_lanes: ", len(route_lanes))
print("len ego_future_gt: ", len(ego_future_gt))
print("len route_lanes: ", len(route_lanes))
print("len ego_future_gt: ", len(ego_future_gt))
print("len neighbors_future_gt: ", len(neighbors_future_gt))
print("len first_stage: ", len(first_stage))
print("len second_stage: ", len(second_stage))



len ego:  21
len neighbors:  20
len map_lanes:  40
len map_crosswalks:  5
len route_lanes:  10
len ego_future_gt:  80
len route_lanes:  10
len ego_future_gt:  80
len neighbors_future_gt:  10
len first_stage:  10
len second_stage:  10


In [21]:
neighbor = neighbors[0]
print("neighbor: ", neighbor)

neighbor:  [[-1.69942074e+01 -9.70942751e-02 -3.26274359e-03  1.32518077e+00
  -1.40478015e-02  4.15088376e-04  4.37049246e+00  1.82270801e+00
   1.00000000e+00  0.00000000e+00  0.00000000e+00]
 [-1.67705040e+01  1.45155443e-02 -2.84765521e-03  1.60856521e+00
  -1.47808194e-02 -5.75511949e-03  4.37055159e+00  1.83040953e+00
   1.00000000e+00  0.00000000e+00  0.00000000e+00]
 [-1.64909954e+01  1.42736183e-02 -4.41359542e-03  1.75110054e+00
  -8.00150633e-03 -5.79216983e-03  4.37110329e+00  1.84281862e+00
   1.00000000e+00  0.00000000e+00  0.00000000e+00]
 [-1.64630928e+01 -4.16522510e-02 -4.00591595e-03  1.93580914e+00
  -6.53553009e-03  5.61464229e-04  4.39109659e+00  1.84603357e+00
   1.00000000e+00  0.00000000e+00  0.00000000e+00]
 [-1.61835842e+01 -4.18941788e-02 -4.30131936e-03  2.18200779e+00
  -4.81873751e-03  4.60099755e-03  4.42417240e+00  1.84593010e+00
   1.00000000e+00  0.00000000e+00  0.00000000e+00]
 [-1.59040756e+01 -4.21361029e-02 -3.08585400e-03  2.36826086e+00
  -1.398

In [16]:
data = np.load(path[0])
print("data files: ", data.files)

print("token: ", data['token'])
print(data['ego_agent_past'].shape)

data files:  ['map_name', 'token', 'ego_agent_past', 'ego_agent_future', 'first_stage_ego_trajectory', 'second_stage_ego_trajectory', 'neighbor_agents_past', 'neighbor_agents_future', 'map_lanes', 'map_crosswalks', 'route_lanes']
token:  7fc811bcf45f5e79
(21, 7)


In [24]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [26]:
# test inputs
train_loader = DataLoader(train_set, batch_size= 12, num_workers=os.cpu_count())

#prepare data
with tqdm(train_loader, desc="Training", unit = "batch") as deta_epoch:
  for batch in deta_epoch:
    inputs = {'ego_agent_past': batch[0].to(device),
              'neighbor_agents_past': batch[1].to(device),
              'map_lanes': batch[2].to(device),
              'map_crosswalks': batch[3].to(device),
              'route_lanes': batch[4].to(device)
              }
    ego_gt_future = batch[5].to(device)
    neighbors_gt_future = batch[6].to(device)
    neighbors_future_valid = torch.ne(neighbors_gt_future[..., :3], 0)
    print("inputs: ", inputs)
    


Training: 100%|██████████| 1/1 [00:07<00:00,  7.58s/batch]

inputs:  {'ego_agent_past': tensor([[[-8.0778e+00, -4.8910e-02,  1.1531e-02,  2.0099e+00, -4.2228e-02,
           1.7718e+00, -9.4090e-02],
         [-7.7983e+00, -4.9152e-02,  1.1245e-02,  2.2141e+00, -5.4906e-02,
           1.8612e+00, -1.4924e-01],
         [-7.5467e+00,  6.5320e-03,  1.1099e-02,  2.4225e+00, -5.6634e-02,
           1.7653e+00,  2.1776e-03],
         [-7.2672e+00,  6.2901e-03,  1.1107e-02,  2.6288e+00, -6.0398e-02,
           1.8552e+00, -1.1647e-01],
         [-7.2114e+00, -1.0556e-01,  1.1141e-02,  2.8362e+00, -6.9311e-02,
           1.9565e+00, -2.0264e-01],
         [-6.7082e+00,  5.8062e-03,  1.1089e-02,  3.0483e+00, -7.7475e-02,
           2.0824e+00,  8.3421e-02],
         [-6.4287e+00,  5.5643e-03,  1.0864e-02,  3.2604e+00, -7.1105e-02,
           2.0798e+00, -1.2008e-01],
         [-6.1213e+00, -5.0604e-02,  1.0791e-02,  3.4628e+00, -8.1433e-02,
           1.9574e+00,  1.5863e-02],
         [-5.8418e+00, -5.0845e-02,  1.0486e-02,  3.6699e+00, -7.1902e-02,
 

Training: 100%|██████████| 1/1 [00:07<00:00,  7.85s/batch]
