In [4]:
from train_utils import DrivingData
from bokeh.plotting import figure, show, ColumnDataSource
from bokeh.io import output_notebook
import ipywidgets
import numpy as np
import bokeh.plotting as bkp
from bokeh.io import output_notebook, push_notebook
# Load one preprocessed scenario and route Bokeh output into the notebook.
path = ["/cailiu2/Diffusion-Planner/data/processed/us-nv-las-vegas-strip_00015fc2840d5313.npz"]
output_notebook()

train_set = DrivingData(path, 10, 10)
(
    ego,
    agents,
    map_lanes,
    map_crosswalks,
    route_lanes,
    ego_future_gt,
    agents_future_gt,
    _,
    _,
) = train_set[0]

# Sanity-check the layouts. In the recorded run these printed (21, 7) and
# (20, 21, 11): the last axis holds more than two features, and only the
# first two columns are used as x / y below.
print("ego shape:", ego.shape)  # (num_frames, num_features)
print("agents shape:", agents.shape)  # (num_agents, num_frames, num_features)

num_frames, num_agents = len(ego), len(agents)

# Ego and agent positions: feature column 0 is x, column 1 is y.
ego_x = ego[:, 0]
ego_y = ego[:, 1]
agents_x = agents[..., 0]
agents_y = agents[..., 1]

# Map lanes: one polyline (list of x / list of y) per lane.
num_map_lanes = len(map_lanes)
num_points_of_lane = len(map_lanes[0])

lane_xs = map_lanes[:, :, 0].tolist()
lane_ys = map_lanes[:, :, 1].tolist()

# Crosswalks: same polyline layout as the lanes.
num_crosswalks = len(map_crosswalks)
num_points_of_crosswalks = len(map_crosswalks[0])
crosswalks_xs = map_crosswalks[:, :, 0].tolist()
crosswalks_ys = map_crosswalks[:, :, 1].tolist()

# Route lanes: same polyline layout as the lanes.
num_route_lane = len(route_lanes)
num_points_of_route_lane = len(route_lanes[0])
route_lane_xs = route_lanes[:, :, 0].tolist()
route_lane_ys = route_lanes[:, :, 1].tolist()

# Ground-truth future trajectory of the ego vehicle (8 s horizon per the
# original note).
ego_future_gt_xs, ego_future_gt_ys = (ego_future_gt[:, axis].tolist() for axis in (0, 1))

# Ground-truth future trajectories of the surrounding agents, kept as one
# 1-D array per agent (not converted to Python lists).
num_agents_future_gt = agents_future_gt.shape[0]
agents_future_gt_xs = list(agents_future_gt[:, :, 0])
agents_future_gt_ys = list(agents_future_gt[:, :, 1])

# Initial view: a 100 x 100 window centred on the ego position at frame 0.
initial_x, initial_y = ego_x[0], ego_y[0]
p = bkp.figure(
    title="Ego Vehicle & agents",
    x_range=(initial_x - 50, initial_x + 50),
    y_range=(initial_y - 50, initial_y + 50),
    width=600, height=600
)

# Data sources.
# The ego source is seeded with the frame-0 position so that, like the agents
# source below, it already holds a valid marker the first time the figure is
# rendered (previously it started empty and the ego was invisible until the
# slider callback had run once).
source = ColumnDataSource(data={
    'ego_x': [initial_x],
    'ego_y': [initial_y],
})

# Agent positions at frame 0; the slider callback overwrites both columns.
agents_source = ColumnDataSource(data={
    'agents_x': agents_x[:, 0].tolist(),
    'agents_y': agents_y[:, 0].tolist()
})

# Static layers: every column is a list of polylines (one entry per line).
map_lane_source = ColumnDataSource(data={'lane_xs': lane_xs, 'lane_ys': lane_ys})

map_crosswalks_source = ColumnDataSource(data={'crosswalks_xs': crosswalks_xs, 'crosswalks_ys': crosswalks_ys})

route_lane_source = ColumnDataSource(data={'lane_xs': route_lane_xs, 'lane_ys': route_lane_ys})

# Single polyline: the ego ground-truth future trajectory.
ego_future_gt_source = ColumnDataSource(data={'xs': ego_future_gt_xs, 'ys': ego_future_gt_ys})

# One polyline per agent: the agents' ground-truth future trajectories.
agents_future_gt_source = ColumnDataSource(data={
    'agents_x': agents_future_gt_xs,
    'agents_y': agents_future_gt_ys
})

# Glyphs.
# Use scatter() for the ego marker as well: circle() with a screen-unit `size`
# is deprecated in recent Bokeh releases, and scatter() (default marker:
# circle) is what the agents layer already uses.
p.scatter('ego_x', 'ego_y', source=source, size=12, color="red", legend_label="Ego Vehicle")
p.scatter('agents_x', 'agents_y', source=agents_source, size=6, color="blue", alpha=0.6, legend_label="agents")

p.multi_line('lane_xs', 'lane_ys', source=map_lane_source, line_width=2, color="blue", alpha=0.6, legend_label="map lines")
p.multi_line('crosswalks_xs', 'crosswalks_ys', source=map_crosswalks_source, line_width=1, color="red", alpha=0.6, legend_label="crosswalks")
p.multi_line('lane_xs', 'lane_ys', source=route_lane_source, line_width=1, color="green", alpha=0.6, legend_label="route lane")
p.line('xs', 'ys', source=ego_future_gt_source, line_width=1.5, color="yellow", alpha=0.9, legend_label="ego future gt")

p.multi_line('agents_x', 'agents_y', source=agents_future_gt_source, line_width=1, color="black", alpha=0.6, legend_label="agents gt traj")
# Frame slider: one step per recorded frame.
slider_class = ipywidgets.IntSlider(
    value=0, min=0, max=num_frames - 1, step=1, description="Frame"
)


def _set_frame(frame):
    """Write the ego and agent positions of `frame` into the Bokeh data sources."""
    source.data.update({'ego_x': [ego_x[frame]], 'ego_y': [ego_y[frame]]})
    agents_source.data.update({
        'agents_x': agents_x[:, frame].tolist(),
        'agents_y': agents_y[:, frame].tolist(),
    })


def slider_callback(frame):
    """Move the ego and agent markers to `frame` and refresh the displayed plot.

    Args:
        frame: Index into the frame axis of `ego_x` / `ego_y` / `agents_x` /
            `agents_y`; supplied by the slider.

    The figure is rendered once (below) and only the changed source data is
    pushed to it here. Calling `show()` on every slider event, as before,
    re-serialised and re-rendered the whole figure each time and left the
    following `push_notebook()` with nothing to send.
    """
    print("frame: ", frame)
    print("new ego_x: ", ego_x[frame])
    print("new ego_y: ", ego_y[frame])

    _set_frame(frame)
    push_notebook(handle=plot_handle)


# Seed the sources with the slider's starting frame, then render the figure a
# single time and keep the comms handle so the callback can push updates to it.
_set_frame(slider_class.value)
plot_handle = bkp.show(p, notebook_handle=True)

# Bind the slider to the callback; as the last expression of the cell the
# resulting widget is displayed below the plot.
ipywidgets.interactive(slider_callback, frame=slider_class)




ego shape: (21, 7)
agents shape: (20, 21, 11)


interactive(children=(IntSlider(value=0, description='Frame', max=20), Output()), _dom_classes=('widget-intera…

In [None]:
# Inspect the array layouts consumed by the visualisation cell.
print(map_lanes.shape)
num_lanes, num_points_of_lane = len(map_lanes), len(map_lanes[0])
print(num_lanes, num_points_of_lane, sep="\n")

# Crosswalk polylines; entry 4 printed as all zeros in the recorded run
# (presumably a padding slot -- TODO confirm against the preprocessing code).
print(map_crosswalks.shape)
print(map_crosswalks[4])

# Remaining layouts, one value per output line.
print(
    route_lanes.shape,
    ego_future_gt.shape,
    ego_future_gt_xs,
    agents_future_gt.shape,
    num_agents,
    sep="\n",
)

(40, 50, 7)
40
50
(5, 30, 3)
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
(10, 50, 3)
(80, 3)
[1.2110577821731567, 2.4194631576538086, 3.628206253051758, 4.832894802093506, 6.032045364379883, 7.243168830871582, 8.452635765075684, 9.652140617370605, 10.8611421585083, 12.070699691772461, 13.261822700500488, 14.468297004699707, 15.672863006591797, 16.866498947143555, 18.07832908630371, 19.279552459716797, 20.467708587646484, 21.659832000732422, 22.846923828125, 24.035789489746094, 25.228199005126953, 26.407556533813477, 27.59702491760254, 28.768293380737305, 29.960952758789062, 31.129106521606445, 32.301048278808594, 33.47299575805664, 34.63808822631836, 35.80805587768555, 36.98