In [1]:
cd ..

/home/xiantuo/source/grasp/SceneLeapPlus


In [2]:
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import numpy as np
# import open3d as o3d
from torch.utils.data import DataLoader
# from utils.config_utils import EasyConfig
from utils.color_utils import get_random_color
from models.diffuser_lightning import DDPMLightning
from models.cvae import GraspCVAELightning
from utils.hand_model import HandModel, HandModelType
from datasets.sceneleapplus_dataset import SceneLeapPlusDataset
from datasets import build_datasets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random
import trimesh
from omegaconf import OmegaConf
from utils.hand_helper import norm_hand_pose_robust, denorm_hand_pose_robust
import matplotlib

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Experiment constants
CKPT_PATH = 'experiments/hpsearch/A5_adamw_cos_t500_min1e-5_lr3e-4_wd1e-3/checkpoints/epoch=489-val_loss=11.73.ckpt'
CONFIG_PATH = 'experiments/hpsearch/A5_adamw_cos_t500_min1e-5_lr3e-4_wd1e-3/config/whole_config.yaml'
# CUDA_VISIBLE_DEVICES="1" (set in the imports cell) exposes only physical GPU 1,
# which PyTorch then addresses as cuda:0.
DEVICE = 'cuda:0'
BATCH_SIZE = 5
VAL_NUM_GRASPS = 64
SEED = 0

# Seed every RNG the notebook relies on so shuffled DataLoader batches and other random
# draws repeat across Restart & Run All (GPU kernels may still be nondeterministic).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [4]:
# OmegaConf.load already returns a DictConfig, so no extra OmegaConf.create() copy is needed.
cfg = OmegaConf.load(CONFIG_PATH)

# Override the criterion's cost weights for this analysis: hand-mesh and qpos terms are
# disabled, translation and rotation are emphasized.
# NOTE(review): presumably these weights drive the pred/target matching done by
# forward_get_pose_matched — confirm in the criterion code.
cfg.model.criterion.cost_weights.hand_mesh = 0.0
cfg.model.criterion.cost_weights.qpos = 0.0
cfg.model.criterion.cost_weights.translation = 100.0
cfg.model.criterion.cost_weights.rotation = 10.0
# cfg.model.fix_num_grasps = False

# Build the model
model = DDPMLightning(cfg.model)

# Initialize the text encoder up front so its weights exist in the module when the
# checkpoint is loaded with strict=True below.
if hasattr(model.eps_model, '_ensure_text_encoder'):
    print("正在初始化text_encoder...")
    model.eps_model._ensure_text_encoder()
    print("✅ text_encoder初始化完成")

# Lightning checkpoints can contain non-tensor objects (e.g. saved hyperparameters), so
# weights_only=False is needed. It is passed explicitly because newer PyTorch releases
# default to weights_only=True. This unpickles the file: only load trusted checkpoints.
checkpoint = torch.load(CKPT_PATH, map_location='cpu', weights_only=False)

# Strict loading: every key must match now that the text encoder exists
model.load_state_dict(checkpoint['state_dict'], strict=True)
print("✅ 模型加载成功 (strict=True)")

# Move to the target device and switch to evaluation mode
model.to(DEVICE).eval()
print(f"✅ 模型已移动到 {DEVICE} 并设置为评估模式")

# LEAP hand model configured from the same criterion settings (surface-point count, rotation type)
hand_model = HandModel(HandModelType.LEAP, cfg.model.criterion.hand_model.n_surface_points, cfg.model.criterion.rot_type, DEVICE)

正在初始化text_encoder...
✅ text_encoder初始化完成
✅ 模型加载成功 (strict=True)
✅ 模型已移动到 cuda:0 并设置为评估模式


In [5]:
# Display whether the loaded config enables object masks
cfg.use_object_mask

True

In [23]:
def _split_dataset_kwargs(split_cfg):
    """Collect the SceneLeapPlusDataset arguments that both splits read straight from config.

    Args:
        split_cfg: one split section of the config (cfg.data.train or cfg.data.val).

    Returns:
        dict of keyword arguments; per-split overrides (grasp counts) are added by the caller.
    """
    return dict(
        root_dir=split_cfg.root_dir,
        succ_grasp_dir=split_cfg.succ_grasp_dir,
        obj_root_dir=split_cfg.obj_root_dir,
        mode=split_cfg.mode,
        mesh_scale=split_cfg.mesh_scale,
        num_neg_prompts=split_cfg.num_neg_prompts,
        enable_cropping=split_cfg.enable_cropping,
        max_points=split_cfg.max_points,
        grasp_sampling_strategy=split_cfg.grasp_sampling_strategy,
        use_exhaustive_sampling=split_cfg.use_exhaustive_sampling,
    )


# Training split: grasp limits come from the config
train_dataset = SceneLeapPlusDataset(
    **_split_dataset_kwargs(cfg.data.train),
    max_grasps_per_object=cfg.data.train.max_grasps_per_object,
    num_grasps=cfg.data.train.num_grasps,
)

# Validation split: the per-object cap is hard-coded to 1280 (the config value is not
# used) and the grasp count comes from VAL_NUM_GRASPS
test_dataset = SceneLeapPlusDataset(
    **_split_dataset_kwargs(cfg.data.val),
    max_grasps_per_object=1280,
    num_grasps=VAL_NUM_GRASPS,
)

# Both loaders shuffle and use the dataset's own collate function
_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, collate_fn=SceneLeapPlusDataset.collate_fn)
train_loader = DataLoader(train_dataset, **_loader_kwargs)
test_loader = DataLoader(test_dataset, **_loader_kwargs)

In [24]:
# Draw one shuffled batch from test_loader (built from the cfg.data.val split).
# A fresh iterator is created each run, so re-running this cell yields a new batch.
test_ret_dict = next(iter(test_loader))

In [25]:
# Draw one shuffled batch from train_loader (a new batch on every re-run of this cell)
train_ret_dict = next(iter(train_loader))

In [26]:
# Keys copied into the model input dict that is passed to forward_get_pose_matched;
# other batch entries (object meshes, scene/object ids, ...) are not copied.
MODEL_INPUT_KEYS = ('scene_pc', 'object_mask', 'se3', 'hand_model_pose', 'positive_prompt', 'negative_prompts')


def move_batch_to_device(batch, keys, device):
    """Return a new dict holding `batch[key]` for each requested key, moved to `device`.

    Args:
        batch: collated batch dict from a SceneLeapPlusDataset DataLoader.
        keys: keys to copy; keys missing from `batch` are skipped.
        device: target passed to `Tensor.to`.

    Returns:
        dict: tensors are moved with `.to(device)`; lists get each tensor element moved
        (non-tensor elements such as prompt strings are kept as-is); any other value is
        passed through unchanged.
    """
    moved = {}
    for key in keys:
        if key not in batch:
            continue
        value = batch[key]
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        elif isinstance(value, list):
            value = [item.to(device) if isinstance(item, torch.Tensor) else item for item in value]
        moved[key] = value
    return moved


new_dict = move_batch_to_device(test_ret_dict, MODEL_INPUT_KEYS, DEVICE)

In [27]:
# Copy the model-input entries of the training batch onto DEVICE: tensors directly,
# lists of tensors element-wise, anything else unchanged. Missing keys are skipped.
train_new_dict = {}
for key in ('scene_pc', 'object_mask', 'se3', 'hand_model_pose', 'positive_prompt', 'negative_prompts'):
    if key not in train_ret_dict:
        continue
    value = train_ret_dict[key]
    if isinstance(value, torch.Tensor):
        value = value.to(DEVICE)
    elif isinstance(value, list) and value and isinstance(value[0], torch.Tensor):
        value = [t.to(DEVICE) for t in value]
    train_new_dict[key] = value

In [28]:
# # 对train_new_dict和test_new_dict里面的每项数据进行分析

# def analyze_dict(data_dict, dict_name=""):
#     print(f"\n==== {dict_name} 分析 ====")
#     for key, value in data_dict.items():
#         print(f"\n键: {key}")
#         if isinstance(value, torch.Tensor):
#             print(f"  类型: torch.Tensor")
#             print(f"  形状: {value.shape}")
#             print(f"  数据类型: {value.dtype}")
#             print(f"  设备: {value.device}")
#             print(f"  前5个元素: {value.flatten()[:5]}")
#         elif isinstance(value, list):
#             print(f"  类型: list, 长度: {len(value)}")
#             if len(value) > 0 and isinstance(value[0], torch.Tensor):
#                 print(f"  子元素类型: torch.Tensor")
#                 print(f"  第一个子元素形状: {value[0].shape}")
#                 print(f"  第一个子元素设备: {value[0].device}")
#                 print(f"  第一个子元素前5个元素: {value[0].flatten()[:5]}")
#             else:
#                 print(f"  子元素类型: {type(value[0]) if len(value)>0 else '空'}")
#                 print(f"  前5个元素: {value[:5]}")
#         else:
#             print(f"  类型: {type(value)}")
#             print(f"  值: {value}")

# analyze_dict(train_new_dict, "train_new_dict")
# analyze_dict(new_dict, "test_new_dict")


In [29]:
# Device check to run before the model forward pass
def check_tensor_devices(data_dict, name=""):
    """Print the device of every tensor-valued entry in `data_dict` under a titled header.

    Plain tensors are reported as `key: device - shape: ...`. Non-empty lists whose first
    element is a tensor are reported with the per-element devices and lengths. All other
    values are skipped.
    """
    print(f"\n=== 设备检查 {name} ===")
    for entry_key, entry in data_dict.items():
        if isinstance(entry, torch.Tensor):
            print(f"{entry_key}: {entry.device} - shape: {entry.shape}")
            continue
        is_tensor_list = isinstance(entry, list) and len(entry) > 0 and isinstance(entry[0], torch.Tensor)
        if not is_tensor_list:
            continue
        device_list = list(map(lambda t: t.device, entry))
        length_list = list(map(len, entry))
        print(f"{entry_key}: {device_list} - lengths: {length_list}")

# Check the devices of the model-input batch
check_tensor_devices(new_dict, "输入数据")

# Spot-check where the model's parameters live
print(f"\n=== 模型参数设备 ===")
for name, param in model.named_parameters():
    # Only parameters whose name mentions 'text' or 'embed' are considered
    if 'text' in name.lower() or 'embed' in name.lower():
        print(f"{name}: {param.device}")
        break  # stops at the first match, so exactly one representative parameter is printed


=== 设备检查 输入数据 ===
scene_pc: cuda:0 - shape: torch.Size([5, 10000, 6])
object_mask: cuda:0 - shape: torch.Size([5, 10000])
se3: cuda:0 - shape: torch.Size([5, 64, 4, 4])
hand_model_pose: cuda:0 - shape: torch.Size([5, 64, 23])

=== 模型参数设备 ===
eps_model.time_embed.0.weight: cuda:0


In [30]:
# Inspect the text-conditioning processor (its type, tokenizer, and the device of every
# parameter and buffer), then move it onto DEVICE explicitly.
print(f"\n=== 文本编码器详细设备检查 ===")
if hasattr(model.eps_model, 'text_processor'):
    text_processor = model.eps_model.text_processor
    print(f"text_processor类型: {type(text_processor)}")

    if hasattr(text_processor, 'tokenizer'):
        print(f"tokenizer类型: {type(text_processor.tokenizer)}")

    # Parameters first, then buffers, each reported with its device
    for label, named_tensors in (("参数", text_processor.named_parameters()),
                                 ("缓冲区", text_processor.named_buffers())):
        for tensor_name, tensor in named_tensors:
            print(f"  {label} {tensor_name}: {tensor.device}")

    # Reassign the result of .to() so the model holds the moved module
    model.eps_model.text_processor = text_processor.to(DEVICE)
    print(f"✅ 文本编码器已移动到 {DEVICE}")


=== 文本编码器详细设备检查 ===
text_processor类型: <class 'models.utils.text_encoder.TextConditionProcessor'>
✅ 文本编码器已移动到 cuda:0


In [31]:
# Run the model on the device-resident validation batch. The call returns predictions
# matched to targets, plus the raw outputs/targets dicts (whose 'hand' entries are
# inspected below). NOTE(review): the meaning of k=1 is defined in DDPMLightning — confirm.
matched_preds, matched_targets, outputs, targets = model.forward_get_pose_matched(new_dict, k=1)


Converting mask without torch.bool dtype to bool; this will negatively affect performance. Prefer to use a boolean mask directly. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343962757/work/aten/src/ATen/native/transformers/attention.cpp:150.)



=== Matcher Results ===
Cost Matrix Shape: (5, 64, 64)
Batch 0 Matches:
  Matched Queries: 64/64
  Query 0 -> Target 9
  Query 1 -> Target 52
  Query 2 -> Target 1
  Query 3 -> Target 8
  Query 4 -> Target 25
  Query 5 -> Target 26
  Query 6 -> Target 7
  Query 7 -> Target 53
  Query 8 -> Target 20
  Query 9 -> Target 36
  Query 10 -> Target 61
  Query 11 -> Target 22
  Query 12 -> Target 19
  Query 13 -> Target 63
  Query 14 -> Target 28
  Query 15 -> Target 6
  Query 16 -> Target 24
  Query 17 -> Target 11
  Query 18 -> Target 5
  Query 19 -> Target 3
  Query 20 -> Target 30
  Query 21 -> Target 51
  Query 22 -> Target 0
  Query 23 -> Target 47
  Query 24 -> Target 35
  Query 25 -> Target 2
  Query 26 -> Target 55
  Query 27 -> Target 18
  Query 28 -> Target 46
  Query 29 -> Target 4
  Query 30 -> Target 29
  Query 31 -> Target 48
  Query 32 -> Target 49
  Query 33 -> Target 32
  Query 34 -> Target 10
  Query 35 -> Target 57
  Query 36 -> Target 23
  Query 37 -> Target 42
  Query 38 

In [32]:
# Summarize each entry of the matched prediction/target dicts. Tensors show shape and
# device, lists show length and element type, and anything else shows its type.
for header, entries in (
    ("=== matched_preds 数据项和形状 ===", matched_preds),
    ("\n=== matched_targets 数据项和形状 ===", matched_targets),
):
    print(header)
    for key, value in entries.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: 形状 {value.shape}, 设备 {value.device}")
            continue
        if isinstance(value, list):
            elem_type = type(value[0]) if value else '未知'
            print(f"{key}: 列表，长度 {len(value)}，元素类型 {elem_type}")
            continue
        print(f"{key}: 类型 {type(value)}")

=== matched_preds 数据项和形状 ===
pred_pose_norm: 形状 torch.Size([5, 64, 25]), 设备 cuda:0
hand_model_pose: 形状 torch.Size([5, 64, 25]), 设备 cuda:0

=== matched_targets 数据项和形状 ===
norm_pose: 形状 torch.Size([5, 64, 25]), 设备 cuda:0
hand_model_pose: 形状 torch.Size([5, 64, 25]), 设备 cuda:0
scene_pc: 形状 torch.Size([5, 10000, 6]), 设备 cuda:0


In [33]:
# List every entry of outputs['hand'] and targets['hand']: shape and device for
# tensors, type for anything else.
for header, hand_dict in (
    ("outputs['hand'] 是字典，包含以下数据项及其形状：", outputs['hand']),
    ("targets['hand'] 是字典，包含以下数据项及其形状：", targets['hand']),
):
    print(header)
    for key, value in hand_dict.items():
        if isinstance(value, torch.Tensor):
            description = f"形状 {value.shape}, 设备 {value.device}"
        else:
            description = f"类型 {type(value)}"
        print(f"  {key}: {description}")

outputs['hand'] 是字典，包含以下数据项及其形状：
  surface_points: 形状 torch.Size([320, 1024, 3]), 设备 cuda:0
  contact_candidates_dis: 形状 torch.Size([320, 10000]), 设备 cuda:0
  vertices: 形状 torch.Size([320, 29152, 3]), 设备 cuda:0
  faces: 形状 torch.Size([12996, 3]), 设备 cuda:0
targets['hand'] 是字典，包含以下数据项及其形状：
  surface_points: 形状 torch.Size([320, 1024, 3]), 设备 cuda:0
  contact_candidates_dis: 形状 torch.Size([320, 10000]), 设备 cuda:0
  vertices: 形状 torch.Size([320, 29152, 3]), 设备 cuda:0
  faces: 形状 torch.Size([12996, 3]), 设备 cuda:0


In [34]:
# Build red-family and blue-family color palettes
def get_color_list(base_color, num):
    """Return `num` hex colors sampled from the matplotlib 'Reds' or 'Blues' colormap.

    Args:
        base_color: 'red' or 'blue'; any other value raises ValueError.
        num: number of colors wanted.

    Returns:
        list[str]: [] when num <= 0. A single color at colormap position 0.6 when
        num == 1. Otherwise `num` colors evenly spaced over [0.3, 0.9], which skips the
        near-white and near-black ends of the map.
    """
    if base_color not in ('red', 'blue'):
        raise ValueError("base_color必须为'red'或'blue'")
    cmap = matplotlib.colormaps['Reds' if base_color == 'red' else 'Blues']

    if num <= 0:
        return []
    if num == 1:
        positions = [0.6]  # middle of the [0.3, 0.9] band
    else:
        positions = [0.3 + 0.6 * k / (num - 1) for k in range(num)]
    return [matplotlib.colors.rgb2hex(cmap(pos)) for pos in positions]

# One color per grasp slot: red shades for targets, blue shades for predictions
num_grasps = VAL_NUM_GRASPS  # 64 in this notebook; one palette entry per grasp
red_colors = get_color_list('red', num_grasps)
blue_colors = get_color_list('blue', num_grasps)

In [35]:


# for i in range(BATCH_SIZE):
#     fig = go.Figure()
    
#     # 点云数据处理与绘制
#     scene_pc = test_ret_dict['scene_pc'][i].cpu()
#     rgb = (scene_pc[:,3:6] * 255).int() if scene_pc[:,3:6].max() <= 1.0 else scene_pc[:,3:6].int()
#     fig.add_trace(go.Scatter3d(
#         x=scene_pc[:,0], y=scene_pc[:,1], z=scene_pc[:,2], mode='markers',
#         marker=dict(size=2, color=[f'rgb({r},{g},{b})' for r,g,b in rgb.tolist()], opacity=0.8),
#         name='场景点云'
#     ))

#     obj_vertices, obj_faces = test_ret_dict['obj_verts'][i].cpu(), test_ret_dict['obj_faces'][i].cpu()
#     print(test_ret_dict['scene_id'][i])
#     print(test_ret_dict['obj_code'][i])
    
#     # 创建物体网格
#     meshes = [
#         go.Mesh3d(x=obj_vertices[:,0], y=obj_vertices[:,1], z=obj_vertices[:,2], i=obj_faces[:,0], j=obj_faces[:,1], 
#                   k=obj_faces[:,2], color=get_random_color(), opacity=1.0, name='object mesh')
#     ]
#     for mesh in meshes:
#         fig.add_trace(mesh)

#     # 可视化outputs['hand']和targets['hand']
#     hand_outputs = outputs['hand']
#     hand_targets = targets['hand']

#     # 注意：根据最新的数据结构，vertices形状为[40, 29152, 3]，faces为[12996, 3]
#     # 这里假设num_grasps * BATCH_SIZE = 40
#     # faces为所有手网格共用的索引，不再是每个抓取单独的faces

#     hand_verts_all = hand_outputs['vertices'].cpu()  # [40, 29152, 3]
#     hand_faces = hand_outputs['faces'].cpu()         # [12996, 3]
#     target_verts_all = hand_targets['vertices'].cpu()
#     target_faces = hand_targets['faces'].cpu()

#     for j in range(num_grasps):
#         idx = i * num_grasps + j

#         # 预测手部（蓝色系，每个抓取不同色）
#         hand_verts = hand_verts_all[idx]  # [29152, 3]
#         fig.add_trace(go.Mesh3d(
#             x=hand_verts[:,0], y=hand_verts[:,1], z=hand_verts[:,2],
#             i=hand_faces[:,0], j=hand_faces[:,1], k=hand_faces[:,2],
#             color=blue_colors[j], opacity=0.5, name=f'预测手部_{j+1}'
#         ))

#         # 目标手部（红色系，每个抓取不同色）
#         target_verts = target_verts_all[idx]
#         fig.add_trace(go.Mesh3d(
#             x=target_verts[:,0], y=target_verts[:,1], z=target_verts[:,2],
#             i=target_faces[:,0], j=target_faces[:,1], k=target_faces[:,2],
#             color=red_colors[j], opacity=0.3, name=f'目标手部_{j+1}'
#         ))
        
#     fig.update_layout(
#         scene=dict(aspectmode='data'), width=800, height=800,
#         title='场景点云、手部网格、目标手部网格与物体网格可视化',
#         updatemenus=[dict(type="buttons", direction="right", x=0.7, y=1.2)]
#     )
#     fig.show()

In [36]:
num_grasps = VAL_NUM_GRASPS

# Matched poses for the whole batch. They are the same for every scene, so fetch them once
# outside the loop. Shape is [B, G, D] (D = 25 in the run above); only the first three
# components of each pose are plotted, as xyz marker positions.
preds = matched_preds['hand_model_pose'].cpu()
tgts = matched_targets['hand_model_pose'].cpu()

# Clamp loop bounds to what actually exists: the final loader batch can hold fewer than
# BATCH_SIZE scenes, and each color palette has only VAL_NUM_GRASPS entries.
n_scenes = min(BATCH_SIZE, len(test_ret_dict['scene_pc']), preds.shape[0])
n_grasps = min(num_grasps, preds.shape[1], tgts.shape[1], len(blue_colors), len(red_colors))

for i in range(n_scenes):
    fig = go.Figure()

    # Scene point cloud; RGB values stored in [0, 1] are rescaled to 0-255 for plotly
    scene_pc = test_ret_dict['scene_pc'][i].cpu()
    rgb = (scene_pc[:,3:6] * 255).int() if scene_pc[:,3:6].max() <= 1.0 else scene_pc[:,3:6].int()
    fig.add_trace(go.Scatter3d(
        x=scene_pc[:,0], y=scene_pc[:,1], z=scene_pc[:,2], mode='markers',
        marker=dict(size=2, color=[f'rgb({r},{g},{b})' for r,g,b in rgb.tolist()], opacity=0.8),
        name='场景点云'
    ))

    # Object mesh, plus the scene / object identifiers for reference
    obj_vertices, obj_faces = test_ret_dict['obj_verts'][i].cpu(), test_ret_dict['obj_faces'][i].cpu()
    print(test_ret_dict['scene_id'][i])
    print(test_ret_dict['obj_code'][i])
    fig.add_trace(go.Mesh3d(
        x=obj_vertices[:,0], y=obj_vertices[:,1], z=obj_vertices[:,2],
        i=obj_faces[:,0], j=obj_faces[:,1], k=obj_faces[:,2],
        color=get_random_color(), opacity=1.0, name='object mesh'
    ))

    # One marker per matched grasp: the prediction in blue, its matched target in red.
    # .tolist() hands plotly plain floats instead of 0-d tensors.
    for j in range(n_grasps):
        px, py, pz = preds[i, j, :3].tolist()
        tx, ty, tz = tgts[i, j, :3].tolist()
        fig.add_trace(go.Scatter3d(
            x=[px], y=[py], z=[pz], mode='markers',
            marker=dict(size=5, color=blue_colors[j]),
            name=f'预测手部点_{j+1}'
        ))
        fig.add_trace(go.Scatter3d(
            x=[tx], y=[ty], z=[tz], mode='markers',
            marker=dict(size=5, color=red_colors[j]),
            name=f'目标手部点_{j+1}'
        ))

    fig.update_layout(
        scene=dict(aspectmode='data'),
        width=800, height=800,
        title='场景、物体和手部位姿点可视化',
        updatemenus=[dict(type="buttons", direction="right", x=0.7, y=1.2)]
    )
    fig.show()


scene_981249816_2097087825
a_gray_box_8dc1c973268e4831871544a31077704d


scene_663037648_19493478_initial_collisions
a_grey_sphere_ba7e1f48b70541dcbe19c53fff5abe59


scene_529544505_553660571
a_wooden_table_5cc921060e2244409aa2b878224f8f13


scene_182815798_4175994613
a_green_apple_e53e4c9bec39456c8fa49b5757d940e5


scene_150095602_3125852469
a_tree_trunk_b2e8047b4e254662a42be38281ddbcf4


In [37]:

# for i in range(BATCH_SIZE):
#     fig = go.Figure()
    
#     # 点云数据处理与绘制
#     scene_pc = train_ret_dict['scene_pc'][i].cpu()
#     rgb = (scene_pc[:,3:6] * 255).int() if scene_pc[:,3:6].max() <= 1.0 else scene_pc[:,3:6].int()
#     fig.add_trace(go.Scatter3d(
#         x=scene_pc[:,0], y=scene_pc[:,1], z=scene_pc[:,2], mode='markers',
#         marker=dict(size=2, color=[f'rgb({r},{g},{b})' for r,g,b in rgb.tolist()], opacity=0.8),
#         name='场景点云'
#     ))

    
#     obj_vertices, obj_faces = train_ret_dict['obj_verts'][i].cpu(), train_ret_dict['obj_faces'][i].cpu()
    
    
#     # 创建网格
#     meshes = [
#         go.Mesh3d(x=obj_vertices[:,0], y=obj_vertices[:,1], z=obj_vertices[:,2], i=obj_faces[:,0], j=obj_faces[:,1], 
#                   k=obj_faces[:,2], color=get_random_color(), opacity=1.0, name='object mesh')
#     ]
#     for mesh in meshes:
#         fig.add_trace(mesh)
        
#     fig.update_layout(
#         scene=dict(aspectmode='data'), width=800, height=800,
#         title='场景点云、手部网格、目标手部网格与物体网格可视化',
#         updatemenus=[dict(type="buttons", direction="right", x=0.7, y=1.2)]
#     )
#     fig.show()

In [38]:
# Shape of the raw (unmatched) target poses from the loader, printed as [5, 64, 23].
# The last dim differs from the 25-dim matched poses shown earlier.
test_ret_dict['hand_model_pose'].shape

torch.Size([5, 64, 23])

In [39]:
# # 可视化场景点云和手部位置
# fig = go.Figure()
# hand_positions = test_ret_dict['hand_model_pose'][0, :, :3].cpu()
# scene_pc = test_ret_dict['scene_pc'][0].cpu()

# # 添加点云和手部位置
# fig.add_trace(go.Scatter3d(
#     x=scene_pc[:,0], 
#     y=scene_pc[:,1], 
#     z=scene_pc[:,2],
#     mode='markers', 
#     marker=dict(
#         size=2, 
#         color=[f'rgb({r},{g},{b})' for r,g,b in zip(scene_pc[:,3], scene_pc[:,4], scene_pc[:,5])],
#         opacity=0.8
#     ), 
#     name='场景点云'
# ))
# fig.add_trace(go.Scatter3d(
#     x=hand_positions[:,0], 
#     y=hand_positions[:,1], 
#     z=hand_positions[:,2],
#     mode='markers', 
#     marker=dict(size=4, color='red', opacity=0.8), 
#     name='手部位置'
# ))

# fig.update_layout(scene=dict(aspectmode='data'), width=800, height=800, title='手部位置与场景点云可视化')
# fig.show()
