In [11]:
# 可视化数据集中的点云和末端位姿
#!/usr/bin/env python3
import os
import sys
import numpy as np
import pandas as pd
import h5py

In [12]:
# Read one demo's observations from the HDF5 file.
hdf5_file_path = '/home/lgx/Project/AFP/src/il_capture/data/layup_1767859976_20260108_161256_corrected.hdf5'
demo_name = "demo_2"

# Reset the outputs first: otherwise a failed or skipped read would leave
# arrays from an earlier run in the kernel, and the plotting cell (which only
# checks `pcds is not None`) would silently visualize stale data.
pcds = ee_pos = ee_quat = wrench = None

try:
    with h5py.File(hdf5_file_path, 'r') as hdf5_file:
        print("HDF5 file keys:", list(hdf5_file.keys()))

        if 'data' in hdf5_file and demo_name in hdf5_file['data']:
            obs = hdf5_file['data'][demo_name]['obs']
            # Read every dataset before assigning any of them, so a missing key
            # cannot leave a half-populated set of arrays behind.
            pcds, ee_pos, ee_quat, wrench = [
                obs[key][:]
                for key in ('pointcloud', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_eef_wrench')
            ]

            print(f"\n--- {demo_name} 数据读取成功 ---")
            print(f"点云数据类型: {type(pcds)}")
            print(f"点云数据形状 (Shape): {pcds.shape}")
        else:
            print(f"未找到 data/{demo_name}，未读取任何数据。")
except Exception as e:
    print(f"读取HDF5文件时出错: {e}")

HDF5 file keys: ['data']

--- demo_2 数据读取成功 ---
点云数据类型: <class 'numpy.ndarray'>
点云数据形状 (Shape): (199, 2048, 6)


In [13]:
import numpy as np
import plotly.graph_objects as go

# Number of upcoming end-effector poses drawn ahead of the current one.
TRAJ_WINDOW = 10

if 'pcds' in locals() and pcds is not None and len(pcds) > 0:
    # --- 1. Fixed global axis ranges ---
    # Computed once over every frame plus the whole trajectory, so all content
    # stays in view and the scene does not rescale during playback.
    ee_xyz = np.asarray(ee_pos)[:, :3]
    all_xyz = np.concatenate([np.asarray(p)[:, :3] for p in pcds] + [ee_xyz])
    # Drop non-finite points: a single NaN/inf would turn min/max into NaN/inf
    # and break the axis ranges.
    all_xyz = all_xyz[np.isfinite(all_xyz).all(axis=1)]
    xyz_min, xyz_max = all_xyz.min(axis=0), all_xyz.max(axis=0)

    def add_padding(r, pad=0.1):
        """Widen r = [lo, hi] by pad x width on each side (width 1.0 if r is degenerate)."""
        diff = (r[1] - r[0]) or 1.0
        return [r[0] - diff * pad, r[1] + diff * pad]

    x_range, y_range, z_range = (
        add_padding([float(xyz_min[k]), float(xyz_max[k])]) for k in range(3)
    )

    # --- 2. Per-frame traces ---
    # Step only through frames present in both the point clouds and the trajectory.
    n_frames = min(len(pcds), len(ee_xyz))

    def make_frame_traces(i):
        """Build the [point cloud, upcoming poses, current pose] traces for frame i.

        The same builder is used for the initial figure and every animation
        frame, so frame 0 of the animation matches what is first displayed.
        Each frame carries its own marker colors; a frame that omits
        marker.color would leave the initial (frame 0) z-coloring in place.
        """
        cloud = np.asarray(pcds[i])
        window = ee_xyz[i + 1:min(n_frames, i + 1 + TRAJ_WINDOW)]
        return [
            go.Scatter3d(
                x=cloud[:, 0], y=cloud[:, 1], z=cloud[:, 2],
                mode='markers', name='point cloud',
                marker=dict(size=1, color=cloud[:, 2], colorscale='Viridis'),
            ),
            go.Scatter3d(
                x=window[:, 0], y=window[:, 1], z=window[:, 2],
                mode='lines+markers', name=f'next {TRAJ_WINDOW} poses',
                marker=dict(color='blue', size=1),
            ),
            go.Scatter3d(
                x=[ee_xyz[i, 0]], y=[ee_xyz[i, 1]], z=[ee_xyz[i, 2]],
                mode='markers', name='current EE',
                marker=dict(size=2, color='red'),
            ),
        ]

    trace_pcd, trace_window, trace_current = make_frame_traces(0)
    frames = [go.Frame(data=make_frame_traces(i), name=str(i)) for i in range(n_frames)]
    fig = go.Figure(data=[trace_pcd, trace_window, trace_current], frames=frames)

    # --- 3. Slider: one animation step per frame ---
    sliders = [dict(
        steps=[dict(
            method='animate',
            args=[[str(i)], dict(mode='immediate', frame=dict(duration=50, redraw=True), transition=dict(duration=0))],
            label=str(i)
        ) for i in range(n_frames)],
        transition=dict(duration=0),
        x=0.1, y=0,
        currentvalue=dict(font=dict(size=12), prefix='Frame: ', visible=True, xanchor='right'),
        len=0.9
    )]

    # --- 4. Layout with fixed axis ranges and play/pause controls ---
    fig.update_layout(
        scene=dict(
            xaxis=dict(title='X (m)', range=x_range, autorange=False),
            yaxis=dict(title='Y (m)', range=y_range, autorange=False),
            zaxis=dict(title='Z (m)', range=z_range, autorange=False),
            # Draw axes in proportion to their ranges so the trajectory is not
            # visually stretched along any axis.
            aspectmode='data'
        ),
        title='Fixed Axis Trajectory Visualization',
        sliders=sliders,
        updatemenus=[dict(
            type='buttons',
            showactive=False,
            y=0, x=0,
            buttons=[
                dict(label='Play', method='animate',
                     args=[None, dict(frame=dict(duration=50, redraw=True), fromcurrent=True, transition=dict(duration=0))]),
                # Animating to [None] with zero-duration frames/transition stops playback.
                dict(label='Pause', method='animate',
                     args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate', transition=dict(duration=0))])
            ]
        )]
    )

    fig.show()

# 可视化并选择性删除示教数据

In [1]:
import h5py
import numpy as np
import open3d as o3d
import time
import os
from collections import OrderedDict

class DataInspector:
    """Interactively review the demos of an HDF5 dataset and write a cleaned copy.

    Every group under ``data/`` is replayed in an Open3D window. The window shows
    the end-effector trajectory (``obs/robot0_eef_pos``) as a red polyline, the
    per-frame point cloud (``obs/pointcloud``), and a green sphere at the current
    end-effector position. The user keeps (Y) or drops (N) each demo. After all
    demos are reviewed, the kept ones are copied into a sibling
    ``<name>_cleaned<ext>`` file.

    Args:
        h5_path: path to the source HDF5 file; it is opened read-only here.
        voxel_size: voxel size for display-only down-sampling; <= 0 disables it.
        play_dt: seconds to wait between frames while auto-playing.
    """

    # GLFW key codes passed to VisualizerWithKeyCallback.register_key_callback.
    KEY_SPACE = 32
    KEY_ESCAPE = 256
    KEY_RIGHT = 262
    KEY_LEFT = 263

    def __init__(self, h5_path, voxel_size=0.05, play_dt=0.05):
        self.h5_path = h5_path
        self.voxel_size = voxel_size
        self.play_dt = play_dt

        self.good_demos = []       # marked keep (Y)
        self.bad_demos = []        # marked drop (N)
        self.undecided_demos = []  # skipped, or window closed without a verdict

        self.f = h5py.File(h5_path, 'r')
        # Numeric order (demo_2 before demo_10) instead of plain string order.
        self.demos = sorted(self.f['data'].keys(), key=self._demo_sort_key)
        print(f"检测到 {len(self.demos)} 组演示数据。")

    @staticmethod
    def _demo_sort_key(name):
        """Sort key putting 'demo_2' before 'demo_10'; names without a numeric suffix sort by name."""
        prefix, _, suffix = name.rpartition('_')
        # isdecimal (not isdigit) guarantees int() accepts the suffix.
        return (prefix, int(suffix)) if suffix.isdecimal() else (name, -1)

    @staticmethod
    def _cleaned_path(h5_path):
        """Return the output path '<root>_cleaned<ext>' for h5_path.

        Built with os.path.splitext so it can never equal h5_path. A plain
        str.replace('.hdf5', ...) returns the input unchanged for e.g. '.h5'
        files, and opening that with 'w' would truncate the source.
        """
        root, ext = os.path.splitext(h5_path)
        return f"{root}_cleaned{ext or '.hdf5'}"

    def close(self):
        """Close the source HDF5 file handle."""
        self.f.close()

    def run(self):
        """Review every demo in order, then write the cleaned dataset.

        ESC aborts the review immediately without saving anything. The source
        file handle is closed on every exit path, including exceptions.
        """
        try:
            for demo in self.demos:
                print(f"\n检查 demo: {demo}")
                decision = self.inspect_demo(demo)

                if decision == "EXIT":
                    print("收到 ESC，退出程序。")
                    return

                if decision is True:
                    self.good_demos.append(demo)
                    print(f"✅ 保留 {demo}")
                elif decision is False:
                    self.bad_demos.append(demo)
                    print(f"❌ 删除 {demo}")
                else:
                    self.undecided_demos.append(demo)
                    print(f"⚠️ 未决 {demo}")
        finally:
            self.close()

        self.save_cleaned_dataset()

    # ================= Core visualization =================
    def inspect_demo(self, demo_name):
        """Replay one demo in an Open3D window and wait for the user's verdict.

        Keys: Right / Left step one frame, Space toggles auto-play, Y keeps,
        N drops, ESC aborts the whole review.

        Returns:
            True (keep), False (drop), "EXIT" (ESC pressed), or None when the
            demo lacks the required fields, has no frames, or the window was
            closed without a verdict.
        """
        group = self.f['data'][demo_name]['obs']

        if 'robot0_eef_pos' not in group or 'pointcloud' not in group:
            print("缺失关键字段，自动跳过")
            return None

        # Small enough to read whole; float64 is what Open3D vectors use.
        eef_pos = np.asarray(group['robot0_eef_pos'][:], dtype=np.float64)
        pcd_ds = group['pointcloud']  # h5py Dataset: frames are read on demand

        # Only step through frames that exist in both datasets.
        T = min(eef_pos.shape[0], pcd_ds.shape[0])
        if T == 0:
            print("轨迹为空，自动跳过")
            return None

        state = {
            "t": 0,
            "playing": False,
            "decision": None,
            "exit_all": False
        }

        # ---- Open3D setup ----
        vis = o3d.visualization.VisualizerWithKeyCallback()
        vis.create_window(f"Inspecting {demo_name}", 1280, 800)
        try:
            # Trajectory polyline through every recorded end-effector position (red).
            n_pos = eef_pos.shape[0]
            lines = np.column_stack([np.arange(n_pos - 1), np.arange(1, n_pos)]).astype(np.int32)
            line_set = o3d.geometry.LineSet(
                points=o3d.utility.Vector3dVector(eef_pos),
                lines=o3d.utility.Vector2iVector(lines),
            )
            line_set.colors = o3d.utility.Vector3dVector(np.tile([1.0, 0.0, 0.0], (len(lines), 1)))

            # Single point-cloud geometry whose points are swapped per frame.
            pcd = o3d.geometry.PointCloud()

            # Green marker at the current end-effector position.
            sphere = o3d.geometry.TriangleMesh.create_sphere(0.001)
            sphere.paint_uniform_color([0, 1, 0])

            def load_pcd_frame(t):
                """Read frame t lazily and return it as a (down-sampled) Open3D point cloud.

                Raises:
                    ValueError: if the frame is not (N, >=3) or has no finite points.
                """
                pts = pcd_ds[t]  # expected (N, C); only the first three columns (xyz) are used

                if pts.ndim != 2 or pts.shape[1] < 3:
                    raise ValueError(f"点云格式异常: {pts.shape}")

                pts = pts[:, :3]
                pts = pts[np.isfinite(pts).all(axis=1)]  # drop NaN/inf points
                if pts.shape[0] == 0:
                    raise ValueError("该帧点云为空")

                # Open3D requirement: contiguous float64 (N, 3).
                temp = o3d.geometry.PointCloud()
                temp.points = o3d.utility.Vector3dVector(np.ascontiguousarray(pts, dtype=np.float64))
                if self.voxel_size > 0:
                    temp = temp.voxel_down_sample(self.voxel_size)
                return temp

            def show_frame(t):
                """Move `pcd` and `sphere` to frame t (geometry only, no redraw)."""
                try:
                    pcd.points = load_pcd_frame(t).points
                except ValueError as e:
                    # A malformed or empty frame should not end the review session.
                    print(f"第 {t} 帧点云无效: {e}")
                    pcd.points = o3d.utility.Vector3dVector(np.empty((0, 3)))
                sphere.translate(eef_pos[t] - sphere.get_center(), relative=True)

            def update():
                show_frame(state["t"])
                vis.update_geometry(pcd)
                vis.update_geometry(sphere)

            # Fill frame 0 before add_geometry, so the initial view bounds
            # (reset on add by default) include the point cloud and not only
            # the trajectory.
            show_frame(0)
            vis.add_geometry(line_set)
            vis.add_geometry(pcd)
            vis.add_geometry(sphere)

            # ---- Keyboard callbacks ----
            def next_frame(vis):
                if state["t"] < T - 1:
                    state["t"] += 1
                    update()

            def prev_frame(vis):
                if state["t"] > 0:
                    state["t"] -= 1
                    update()

            def toggle(vis):
                state["playing"] = not state["playing"]

            def mark_good(vis):
                state["decision"] = True
                vis.close()

            def mark_bad(vis):
                state["decision"] = False
                vis.close()

            def exit_program(vis):
                state["exit_all"] = True
                state["decision"] = None
                vis.close()

            vis.register_key_callback(self.KEY_RIGHT, next_frame)
            vis.register_key_callback(self.KEY_LEFT, prev_frame)
            vis.register_key_callback(self.KEY_SPACE, toggle)
            vis.register_key_callback(ord('Y'), mark_good)
            vis.register_key_callback(ord('N'), mark_bad)
            vis.register_key_callback(self.KEY_ESCAPE, exit_program)

            print("→ / ← 切帧 | Space 播放 | Y 保留 | N 删除 | ESC 退出")

            # ---- Main loop ----
            while state["decision"] is None and not state["exit_all"]:
                # poll_events() returns False once the window is closed (e.g. via
                # its close button); without this check the loop never ends.
                if not vis.poll_events():
                    break
                vis.update_renderer()

                if state["playing"]:
                    if state["t"] < T - 1:
                        state["t"] += 1
                        update()
                        time.sleep(self.play_dt)
                    else:
                        state["playing"] = False
        finally:
            vis.destroy_window()

        if state["exit_all"]:
            return "EXIT"
        return state["decision"]

    # ================= Saving =================
    def save_cleaned_dataset(self):
        """Copy the kept demos into a new '<name>_cleaned<ext>' file next to the source.

        Only ``good_demos`` are copied; bad and undecided demos are both left
        out. Root and ``data`` group attributes are carried over. Nothing is
        written when no demo was marked bad.
        """
        if not self.bad_demos:
            print("未发现坏数据，未生成新文件。")
            return

        new_path = self._cleaned_path(self.h5_path)
        print(f"生成新文件: {new_path}")

        with h5py.File(self.h5_path, 'r') as f_src, \
             h5py.File(new_path, 'w') as f_dst:

            # Copying individual demo groups does not carry over metadata stored
            # on the root or on `data` itself, so copy those attributes explicitly.
            f_dst.attrs.update(f_src.attrs)
            dst_data = f_dst.create_group('data')
            dst_data.attrs.update(f_src['data'].attrs)

            for demo in self.good_demos:
                f_src.copy(f"data/{demo}", dst_data)

            # NOTE(review): assumes a 'total' attr on `data` is the sum of the
            # per-demo 'num_samples' attrs (robomimic-style layout; confirm).
            # Recompute it for the kept subset so it is not stale.
            if 'total' in dst_data.attrs and all(
                'num_samples' in dst_data[d].attrs for d in self.good_demos
            ):
                dst_data.attrs['total'] = sum(
                    int(dst_data[d].attrs['num_samples']) for d in self.good_demos
                )

        print("清洗完成。")

# ================= Entry point =================
if __name__ == "__main__":
    # Recording to review interactively (machine-specific absolute path).
    path = "/home/lgx/Project/AFP/src/il_capture/data/layup_1767859976_20260108_161256.hdf5"
    inspector = DataInspector(h5_path=path, voxel_size=0.01)
    inspector.run()


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
检测到 82 组演示数据。

检查 demo: demo_0
→ / ← 切帧 | Space 播放 | Y 保留 | N 删除 | ESC 退出
收到 ESC，退出程序。
