# Flow Matching 模型可视化

本notebook用于可视化并评估Flow Matching抓取模型的预测结果：
1. 场景点云 + 预测抓取分布 + 目标抓取分布
2. 匹配后的抓取3D模型 + 场景点云
3. 匹配后的Chamfer距离与姿态误差（平移/旋转/关节角）统计

数据：使用ObjectCentric数据集（无RGB，无物体掩码）。


In [None]:
# Imports and process-level setup for the visualization notebook.
import torch
import os
# Select the GPU to use. CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialized yet in this kernel; the single visible device is then addressed as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"  # GPU to use
import numpy as np
from torch.utils.data import DataLoader
from utils.color_utils import get_random_color
from models.fm_lightning import FlowMatchingLightning
from utils.hand_model import HandModel, HandModelType
from datasets import build_datasets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random
import trimesh
from omegaconf import OmegaConf
from utils.hand_helper import norm_hand_pose_robust, denorm_hand_pose_robust
import matplotlib.pyplot as plt

print("✅ 导入完成")


## 1. 配置和模型加载


In [None]:
# Paths and constants shared by the rest of the notebook.
CKPT_PATH = 'experiments/fm_objcentric/checkpoints/epoch=479-val_loss=7.80.ckpt'  # Flow Matching checkpoint
CONFIG_PATH = 'experiments/fm_objcentric/config/whole_config.yaml'
DEVICE = 'cuda:0'
BATCH_SIZE = 1  # a single sample per batch keeps the plots readable
NUM_GRASPS = 64

# Echo the settings so the notebook output records which run was visualized.
for _label, _value in (("Checkpoint", CKPT_PATH), ("Config", CONFIG_PATH), ("Device", DEVICE)):
    print(f"{_label}: {_value}")


In [None]:
# Load the experiment config from CONFIG_PATH.
cfg = OmegaConf.load(CONFIG_PATH)
# NOTE(review): OmegaConf.create on an existing DictConfig just yields another config
# object; this line looks redundant.
cfg = OmegaConf.create(cfg)

# Optionally override the matcher's cost weights. This happens before the model is
# built from cfg.model, so the criterion presumably picks up these values.
cfg.model.criterion.cost_weights.translation = 100.0
cfg.model.criterion.cost_weights.rotation = 10.0

# Print the key model / data / solver settings.
print("配置信息:")
print(f"  模型: {cfg.model.name}")
print(f"  数据: {cfg.data_cfg.name}")
print(f"  旋转类型: {cfg.model.rot_type}")
print(f"  坐标系: {cfg.model.mode}")
print(f"  预测模式: {cfg.model.decoder.pred_mode}")
print(f"  求解器: {cfg.model.solver.type}, NFE={cfg.model.solver.nfe}")


In [None]:
# Build the Flow Matching Lightning module from the model config.
model = FlowMatchingLightning(cfg.model)

# Force text_encoder initialization when text conditioning is enabled. This presumably
# makes its weights exist as state_dict keys before the strict checkpoint load below.
if hasattr(model.model, '_ensure_text_encoder') and cfg.model.decoder.use_text_condition:
    print("正在初始化text_encoder...")
    model.model._ensure_text_encoder()
    print("✅ text_encoder初始化完成")

# Load the checkpoint's state_dict with strict key matching. If the file is missing,
# the notebook deliberately continues with randomly initialized weights.
# NOTE(review): on torch versions where torch.load defaults to weights_only=True, a
# Lightning checkpoint holding non-tensor objects may fail to unpickle; pass
# weights_only=False only for trusted checkpoints.
if os.path.exists(CKPT_PATH):
    checkpoint = torch.load(CKPT_PATH, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    print(f"✅ 模型加载成功 (strict=True)")
else:
    print(f"⚠️  Checkpoint不存在: {CKPT_PATH}")
    print("   将使用未训练的模型进行可视化")

# Move to the device and switch to eval mode.
model.to(DEVICE).eval()
print(f"✅ 模型已移动到 {DEVICE} 并设置为评估模式")

# LEAP hand model used to turn poses into surface points / meshes for plotting.
# NOTE(review): this uses cfg.model.criterion.rot_type, while denormalization elsewhere
# uses cfg.model.rot_type -- confirm the two agree.
hand_model = HandModel(
    HandModelType.LEAP, 
    cfg.model.criterion.hand_model.n_surface_points, 
    cfg.model.criterion.rot_type, 
    DEVICE
)
print(f"✅ Hand Model创建完成")


## 2. 数据加载



In [None]:
# Build the train/val/test datasets from the ObjectCentric data config.
# Only the validation split is used for visualization below.
print("加载ObjectCentric数据集...")
train_dataset, val_dataset, test_dataset = build_datasets(cfg.data_cfg, stage='fit')

print(f"✅ 数据集加载完成")
print(f"   训练集: {len(train_dataset)}")
print(f"   验证集: {len(val_dataset)}")

# Deterministic (unshuffled), single-process validation loader using the dataset's
# own collate_fn.
val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=val_dataset.collate_fn
)

print(f"✅ DataLoader创建完成，batch_size={BATCH_SIZE}")


In [None]:
# Take the first validation batch for visualization.
batch = next(iter(val_loader))

# Move every tensor entry to DEVICE in place; non-tensor entries are left untouched.
for field_name in list(batch):
    value = batch[field_name]
    if isinstance(value, torch.Tensor):
        batch[field_name] = value.to(DEVICE)

# Summarize the shapes of the main tensors (and the prompt, when present).
print("Batch信息:")
for label, field in (("场景点云", 'scene_pc'), ("手部姿态", 'hand_model_pose'), ("SE3", 'se3')):
    print(f"  {label}: {batch[field].shape}")
if 'positive_prompt' in batch:
    print(f"  物体名称: {batch['positive_prompt']}")


## 3. Flow Matching采样生成预测


In [None]:
# Draw grasp samples from the Flow Matching model and report the solver settings.
print("开始Flow Matching采样...")
print(f"  求解器: {cfg.model.solver.type}")
print(f"  NFE: {cfg.model.solver.nfe}")
print(f"  CFG: {cfg.model.guidance.enable_cfg}")  # presumably classifier-free guidance on/off

with torch.no_grad():
    # ODE sampling via model.sample. The output is in the normalized pose space
    # (it is denormalized in the next cell). The shape below is the original author's
    # annotation -- TODO confirm what k controls.
    pred_x0 = model.sample(batch, k=1)  # [B, num_grasps, D]

print(f"✅ 采样完成")
print(f"   预测姿态: {pred_x0.shape}")
print(f"   范围: [{pred_x0.min():.3f}, {pred_x0.max():.3f}]")


In [None]:
# Denormalize the sampled poses back into real hand poses.
# NOTE: denorm_hand_pose_robust is already imported in the first cell; this re-import
# is redundant but harmless.
from utils.hand_helper import denorm_hand_pose_robust

pred_hand_pose = denorm_hand_pose_robust(pred_x0, cfg.model.rot_type, cfg.model.mode)
# Targets are taken from the batch as-is. NOTE(review): this assumes the dataset
# returns hand_model_pose unnormalized -- confirm.
target_hand_pose = batch['hand_model_pose']

print("手部姿态:")
print(f"  预测: {pred_hand_pose.shape}")
print(f"  目标: {target_hand_pose.shape}")


## 4. 生成手部表面点云


In [None]:
# Run the hand model on every predicted and ground-truth grasp to get surface points
# (for point-cloud plots) and meshes (for mesh plots). No gradients are needed.
B, num_grasps, _ = pred_hand_pose.shape

# The ground truth may hold a different number of grasps than the sampler produced,
# so derive its grasp count from its own shape instead of reusing num_grasps.
target_flat_pose = target_hand_pose.reshape(-1, target_hand_pose.shape[-1])
num_target_grasps = target_flat_pose.shape[0] // B

print(f"生成 {B} x {num_grasps} = {B*num_grasps} 个手部模型...")

with torch.no_grad():
    # Predicted hands: flatten [B, num_grasps, D] -> [B*num_grasps, D] for the hand model.
    pred_hand_dict = hand_model(
        pred_hand_pose.reshape(B * num_grasps, -1),
        with_surface_points=True,
        with_meshes=True
    )

    # Back to [B, num_grasps, N_hand, 3]; meshes stay a flat list of length B*num_grasps.
    pred_surface_points = pred_hand_dict['surface_points'].reshape(B, num_grasps, -1, 3)
    pred_meshes = pred_hand_dict['meshes']

    print(f"✅ 预测手部表面点: {pred_surface_points.shape}")

    # Target hands, shaped by their own grasp count.
    target_hand_dict = hand_model(
        target_flat_pose,
        with_surface_points=True,
        with_meshes=True
    )

    target_surface_points = target_hand_dict['surface_points'].reshape(B, num_target_grasps, -1, 3)
    target_meshes = target_hand_dict['meshes']

print(f"✅ 目标手部表面点: {target_surface_points.shape}")


## 5. 可视化1：场景点云 + 预测抓取分布 + 目标抓取分布


In [None]:
def visualize_grasps_distribution(scene_pc, pred_points, target_points, sample_idx=0, 
                                   num_grasps_to_show=64, title="Grasp Distribution"):
    """
    Plot a scene point cloud together with predicted and target grasp distributions.

    Each grasp is drawn as a cloud of hand surface points: predictions in blue,
    targets in red. Only the first three grasps of each group get a legend entry;
    the rest share the group's legend toggle.

    Args:
        scene_pc: scene point cloud tensor, indexed as [B, N, 3]
        pred_points: predicted hand surface points, [B, num_grasps, N_hand, 3]
        target_points: target hand surface points, [B, num_grasps, N_hand, 3]
        sample_idx: which batch element to plot
        num_grasps_to_show: maximum number of grasps drawn per group
        title: figure title

    Returns:
        plotly.graph_objects.Figure
    """
    def point_trace(points, marker, **extra):
        # One Scatter3d trace of an (M, 3) array of points.
        return go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=marker,
            **extra
        )

    fig = go.Figure()

    # Scene point cloud (light gray).
    scene = scene_pc[sample_idx].cpu().numpy()
    fig.add_trace(point_trace(scene, dict(size=2, color='lightgray', opacity=0.3), name='场景点云'))

    # Predictions (blue) first, then targets (red).
    groups = (
        (pred_points, 'blue', '预测抓取', 'pred'),
        (target_points, 'red', '目标抓取', 'target'),
    )
    for hands, color, label, group in groups:
        shown = min(num_grasps_to_show, hands.shape[1])
        for g in range(shown):
            pts = hands[sample_idx, g].cpu().numpy()
            in_legend = g < 3
            fig.add_trace(point_trace(
                pts,
                dict(size=1, color=color, opacity=0.4),
                name=f'{label}{g}' if in_legend else None,
                showlegend=in_legend,
                legendgroup=group
            ))

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
            aspectmode='data'
        ),
        width=1000,
        height=800
    )

    return fig


In [None]:
# Scene + predicted (blue) vs. target (red) grasp distributions for the first sample.
fig = visualize_grasps_distribution(
    scene_pc=batch['scene_pc'],
    pred_points=pred_surface_points,
    target_points=target_surface_points,
    sample_idx=0,
    num_grasps_to_show=64,
    title="Flow Matching: 场景点云 + 预测抓取(蓝) + 目标抓取(红)",
)
fig.show()


## 6. 匹配后的抓取可视化


In [None]:
# Match predicted grasps to ground-truth grasps with the criterion's matcher.
from models.utils.prediction import build_pred_dict_adaptive
from utils.hand_helper import process_hand_pose_test
from models.utils.pose_processor import PoseProcessor

# Convert the batch with the configured rotation type and coordinate mode.
batch_processed = process_hand_pose_test(batch, rot_type=cfg.model.rot_type, mode=cfg.model.mode)

# Wrap the (normalized) sampler output into a prediction dict.
pred_dict = build_pred_dict_adaptive(pred_x0)

pose_processor = PoseProcessor(hand_model, cfg.model.rot_type, cfg.model.mode)

# No gradients are needed; this also keeps the matched tensors safe for the
# .cpu().numpy() calls in later cells.
with torch.no_grad():
    outputs = pose_processor.get_hand_model_pose_test(pred_dict)
    assignments = model.criterion.matcher(outputs, batch_processed)
    matched_preds, matched_targets = pose_processor.get_matched_by_assignment(
        outputs, batch_processed, assignments
    )

print("匹配结果:")
print(f"  per_query_gt_inds: {assignments['per_query_gt_inds'].shape}")
print(f"  匹配的预测: {matched_preds['hand_model_pose'].shape}")
print(f"  匹配的目标: {matched_targets['hand_model_pose'].shape}")

# Matched ground-truth index for the first 5 queries of sample 0.
# "\n" (not "\\n") so a real blank line is printed instead of a literal backslash-n.
print("\n前5个抓取的匹配索引:")
print(assignments['per_query_gt_inds'][0, :5].cpu().numpy())


In [None]:
def visualize_matched_grasp_pair(scene_pc, pred_hand_pose, target_hand_pose, 
                                 hand_model, grasp_idx=0, sample_idx=0):
    """
    Plot one matched grasp pair (prediction vs. target) over the scene point cloud.

    Both hands are drawn as surface-point clouds (prediction blue, target red). A yellow
    segment connects the first three entries of each pose (the wrist translation), to
    show the positional offset between the pair.

    Args:
        scene_pc: scene point cloud tensor, indexed as [B, N, 3]
        pred_hand_pose: predicted hand poses, [B, num_grasps, D]
        target_hand_pose: target hand poses, [B, num_grasps, D], already matched to pred
        hand_model: HandModel instance; called with with_surface_points=True
        grasp_idx: index of the pair to plot
        sample_idx: batch element to plot

    Returns:
        plotly.graph_objects.Figure
    """
    fig = go.Figure()

    # Scene point cloud.
    scene = scene_pc[sample_idx].detach().cpu().numpy()
    fig.add_trace(go.Scatter3d(
        x=scene[:, 0],
        y=scene[:, 1],
        z=scene[:, 2],
        mode='markers',
        marker=dict(size=2, color='lightgray', opacity=0.5),
        name='场景'
    ))

    pred_pose = pred_hand_pose[sample_idx, grasp_idx].unsqueeze(0)
    target_pose = target_hand_pose[sample_idx, grasp_idx].unsqueeze(0)

    # Only surface points are plotted, so meshes are not requested (they were built and
    # discarded before). No gradients are needed for plotting.
    with torch.no_grad():
        pred_points = hand_model(pred_pose, with_surface_points=True)['surface_points'][0].cpu().numpy()
        target_points = hand_model(target_pose, with_surface_points=True)['surface_points'][0].cpu().numpy()

    # Predicted hand.
    fig.add_trace(go.Scatter3d(
        x=pred_points[:, 0],
        y=pred_points[:, 1],
        z=pred_points[:, 2],
        mode='markers',
        marker=dict(size=3, color='blue', opacity=0.8),
        name=f'预测抓取#{grasp_idx}'
    ))

    # Target hand.
    fig.add_trace(go.Scatter3d(
        x=target_points[:, 0],
        y=target_points[:, 1],
        z=target_points[:, 2],
        mode='markers',
        marker=dict(size=3, color='red', opacity=0.8),
        name=f'目标抓取#{grasp_idx}'
    ))

    # Line between the two wrist positions (first three pose entries).
    pred_wrist = pred_pose[0, :3].detach().cpu().numpy()
    target_wrist = target_pose[0, :3].detach().cpu().numpy()

    fig.add_trace(go.Scatter3d(
        x=[pred_wrist[0], target_wrist[0]],
        y=[pred_wrist[1], target_wrist[1]],
        z=[pred_wrist[2], target_wrist[2]],
        mode='lines+markers',
        line=dict(color='yellow', width=5),
        marker=dict(size=5, color='yellow'),
        name='手腕对应'
    ))

    fig.update_layout(
        title=f"匹配的抓取对 #{grasp_idx}",
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
            aspectmode='data'
        ),
        width=1000,
        height=800
    )

    return fig


In [None]:
# Plot a few matched pairs (prediction vs. its assigned target).
# Indices beyond the number of matched grasps are skipped instead of raising.
num_matched = matched_preds['hand_model_pose'].shape[1]
for grasp_idx in (0, 5, 10):
    if grasp_idx >= num_matched:
        continue
    fig = visualize_matched_grasp_pair(
        batch['scene_pc'],
        matched_preds['hand_model_pose'],
        matched_targets['hand_model_pose'],
        hand_model,
        grasp_idx=grasp_idx
    )
    fig.show()
    # "\n" (not "\\n") so a real newline is printed.
    print(f"\n抓取 #{grasp_idx} 可视化完成")


## 7. 可视化2：带Mesh的抓取可视化


In [None]:
def visualize_hand_mesh_with_scene(scene_pc, hand_mesh, title="Hand Mesh", sample_idx=0):
    """
    Plot one hand mesh over a scene point cloud.

    Args:
        scene_pc: batched scene point cloud tensor, indexed as [B, N, 3]; only
            scene_pc[sample_idx] is drawn. (The previous docstring said [N, 3], which
            does not match the indexing below.)
        hand_mesh: mesh object exposing .vertices (V, 3) and .faces (F, 3), e.g. a
            trimesh.Trimesh
        title: figure title
        sample_idx: which batch element of scene_pc to draw

    Returns:
        plotly.graph_objects.Figure
    """
    fig = go.Figure()

    # Scene point cloud.
    scene = scene_pc[sample_idx].detach().cpu().numpy()
    fig.add_trace(go.Scatter3d(
        x=scene[:, 0],
        y=scene[:, 1],
        z=scene[:, 2],
        mode='markers',
        marker=dict(size=2, color='gray', opacity=0.4),
        name='场景'
    ))

    # Hand mesh as a triangle mesh.
    vertices = hand_mesh.vertices
    faces = hand_mesh.faces

    fig.add_trace(go.Mesh3d(
        x=vertices[:, 0],
        y=vertices[:, 1],
        z=vertices[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color='lightblue',
        opacity=0.8,
        name='手部Mesh'
    ))

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
            aspectmode='data'
        ),
        width=1000,
        height=800
    )

    return fig


In [None]:
# Plot a few predicted hand meshes. pred_meshes is a flat list over B * num_grasps;
# with one sample per batch, index i is grasp i of sample 0. Indices past the end
# of the list are skipped instead of raising IndexError.
for i in (0, 10, 20):
    if i >= len(pred_meshes):
        continue
    fig = visualize_hand_mesh_with_scene(
        batch['scene_pc'],
        pred_meshes[i],
        title=f"预测抓取 #{i} - Mesh可视化"
    )
    fig.show()


## 8. 对比可视化：并排显示预测和目标


In [None]:
def visualize_pred_vs_target_sidebyside(scene_pc, pred_mesh, target_mesh, 
                                         grasp_idx=0, sample_idx=0):
    """
    Show a predicted and a target hand mesh in two side-by-side 3D panels.

    Both panels contain the same scene point cloud. The left panel adds the predicted
    mesh (blue) and the right panel the target mesh (red). The scene gets a single
    legend entry.

    Args:
        scene_pc: scene point cloud tensor, indexed as [B, N, 3]
        pred_mesh: predicted hand mesh exposing .vertices and .faces
        target_mesh: target hand mesh exposing .vertices and .faces
        grasp_idx: pair index, used only in the figure title
        sample_idx: which batch element of scene_pc to draw

    Returns:
        plotly.graph_objects.Figure
    """
    from plotly.subplots import make_subplots

    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
        subplot_titles=('预测抓取', '目标抓取')
    )

    scene = scene_pc[sample_idx].cpu().numpy()

    # (column, mesh, mesh color, mesh legend name) for the left and right panels.
    panels = (
        (1, pred_mesh, 'blue', '预测手部'),
        (2, target_mesh, 'red', '目标手部'),
    )
    for col, mesh, color, mesh_name in panels:
        fig.add_trace(go.Scatter3d(
            x=scene[:, 0], y=scene[:, 1], z=scene[:, 2],
            mode='markers',
            marker=dict(size=2, color='gray', opacity=0.3),
            name='场景',
            showlegend=(col == 1)
        ), row=1, col=col)

        verts, faces = mesh.vertices, mesh.faces
        fig.add_trace(go.Mesh3d(
            x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
            i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
            color=color, opacity=0.7,
            name=mesh_name
        ), row=1, col=col)

    fig.update_layout(
        title=f"匹配的抓取对 #{grasp_idx} 对比",
        scene=dict(aspectmode='data'),
        scene2=dict(aspectmode='data'),
        width=1600,
        height=700
    )

    return fig


In [None]:
# Side-by-side comparison of matched pairs.
# pred_meshes / target_meshes are in sampling / dataset order, so pred_meshes[i] and
# target_meshes[i] are NOT a matched pair. Build the meshes from the poses paired by
# model.criterion.matcher instead.
matched_pred_poses_0 = matched_preds['hand_model_pose'][0]
matched_target_poses_0 = matched_targets['hand_model_pose'][0]

for i in (0, 5, 10):
    if i >= matched_pred_poses_0.shape[0]:
        continue  # fewer matched grasps than requested index
    with torch.no_grad():
        pair_pred_mesh = hand_model(
            matched_pred_poses_0[i:i + 1], with_surface_points=True, with_meshes=True
        )['meshes'][0]
        pair_target_mesh = hand_model(
            matched_target_poses_0[i:i + 1], with_surface_points=True, with_meshes=True
        )['meshes'][0]
    fig = visualize_pred_vs_target_sidebyside(
        batch['scene_pc'],
        pair_pred_mesh,
        pair_target_mesh,
        grasp_idx=i
    )
    fig.show()


## 9. 统计分析：Chamfer距离


In [None]:
# Chamfer distance between the surface point clouds of each matched (pred, target) pair.
from pytorch3d.loss import chamfer_distance

# Read the sizes and the pose width from the matched tensors instead of hard-coding 23.
n_batch, n_matched, pose_dim = matched_preds['hand_model_pose'].shape

with torch.no_grad():
    matched_pred_hand = hand_model(
        matched_preds['hand_model_pose'].reshape(-1, pose_dim),
        with_surface_points=True
    )
    matched_target_hand = hand_model(
        matched_targets['hand_model_pose'].reshape(-1, pose_dim),
        with_surface_points=True
    )

    # [n_batch, n_matched, N_hand, 3]
    matched_pred_points = matched_pred_hand['surface_points'].reshape(n_batch, n_matched, -1, 3)
    matched_target_points = matched_target_hand['surface_points'].reshape(n_batch, n_matched, -1, 3)

    # pytorch3d's chamfer_distance only accepts (N, P, 3) tensors. Index (not slice) the
    # grasp axis so each pair is passed as [n_batch, N_hand, 3]; the previous
    # [:, i:i+1] slice produced a 4-D tensor and raised ValueError.
    chamfer_dists = np.array([
        chamfer_distance(
            matched_pred_points[:, i],
            matched_target_points[:, i],
            point_reduction="mean"
        )[0].item()
        for i in range(n_matched)
    ])

print("Chamfer距离统计 (匹配后):")
print(f"  平均: {chamfer_dists.mean():.4f}")
print(f"  最小: {chamfer_dists.min():.4f}")
print(f"  最大: {chamfer_dists.max():.4f}")
print(f"  中位数: {np.median(chamfer_dists):.4f}")
print(f"  标准差: {chamfer_dists.std():.4f}")


In [None]:
# Histogram of the per-pair Chamfer distances after matching.
fig = go.Figure()
fig.add_trace(go.Histogram(
    x=chamfer_dists,
    nbinsx=30,
    name='Chamfer Distance',
    marker_color='blue'
))

fig.update_layout(
    title="匹配后Chamfer距离分布",
    xaxis_title="Chamfer Distance",
    yaxis_title="Count",
    width=800,
    height=400
)

fig.show()

# Best / worst matched pairs. Cast to plain int so the indices behave uniformly
# when used to index tensors and in f-strings later.
best_idx = int(chamfer_dists.argmin())
worst_idx = int(chamfer_dists.argmax())

# "\n" (not "\\n") so a real newline is printed.
print(f"\n最佳抓取: #{best_idx}, Chamfer={chamfer_dists[best_idx]:.4f}")
print(f"最差抓取: #{worst_idx}, Chamfer={chamfer_dists[worst_idx]:.4f}")


In [None]:
# Plot the best and worst matched pairs by Chamfer distance.
# "\n" (not "\\n") so real newlines are printed.
print("\n=== 最佳匹配抓取 ===")
fig_best = visualize_matched_grasp_pair(
    batch['scene_pc'],
    matched_preds['hand_model_pose'],
    matched_targets['hand_model_pose'],
    hand_model,
    grasp_idx=best_idx
)
fig_best.show()

print("\n=== 最差匹配抓取 ===")
fig_worst = visualize_matched_grasp_pair(
    batch['scene_pc'],
    matched_preds['hand_model_pose'],
    matched_targets['hand_model_pose'],
    hand_model,
    grasp_idx=worst_idx
)
fig_worst.show()


## 10. 姿态误差分析


In [None]:
# Per-component pose errors of the matched pairs for sample 0.
matched_pred_pose = matched_preds['hand_model_pose'][0]  # [num_grasps, D]
matched_target_pose = matched_targets['hand_model_pose'][0]  # [num_grasps, D]

# Split the poses into components. NOTE(review): this assumes the 23-value layout
# [translation(3) | joint angles(16) | quaternion(4)]; confirm against the hand
# model's pose convention for the configured rot_type.
pred_trans = matched_pred_pose[:, :3]
pred_qpos = matched_pred_pose[:, 3:19]
pred_rot = matched_pred_pose[:, 19:23]

target_trans = matched_target_pose[:, :3]
target_qpos = matched_target_pose[:, 3:19]
target_rot = matched_target_pose[:, 19:23]

# L2 distance of translations and of joint-angle vectors.
trans_error = torch.norm(pred_trans - target_trans, dim=-1).cpu().numpy()
qpos_error = torch.norm(pred_qpos - target_qpos, dim=-1).cpu().numpy()

# Rotation error: angle of the relative rotation R_pred^T @ R_target,
# arccos((trace - 1) / 2), with the argument clamped to [-1, 1] for numerical safety,
# converted to degrees.
from pytorch3d.transforms import quaternion_to_matrix
pred_rot_mat = quaternion_to_matrix(pred_rot)
target_rot_mat = quaternion_to_matrix(target_rot)
R_diff = torch.bmm(pred_rot_mat.transpose(-2, -1), target_rot_mat)
trace = R_diff.diagonal(dim1=-2, dim2=-1).sum(-1)
rot_error = torch.acos(torch.clamp((trace - 1) / 2, -1, 1)).cpu().numpy() * 180 / np.pi

# "\n" (not "\\n") so real newlines separate the sections.
print("姿态误差统计:")
print(f"\n平移误差 (m):")
print(f"  平均: {trans_error.mean():.4f}")
print(f"  中位数: {np.median(trans_error):.4f}")
print(f"  最大: {trans_error.max():.4f}")

print(f"\n旋转误差 (度):")
print(f"  平均: {rot_error.mean():.4f}")
print(f"  中位数: {np.median(rot_error):.4f}")
print(f"  最大: {rot_error.max():.4f}")

print(f"\n关节角误差:")
print(f"  平均: {qpos_error.mean():.4f}")
print(f"  中位数: {np.median(qpos_error):.4f}")
print(f"  最大: {qpos_error.max():.4f}")


In [None]:
# Three side-by-side histograms of the matched pose errors.
error_panels = (
    (trans_error, '平移', 'blue'),
    (rot_error, '旋转', 'green'),
    (qpos_error, '关节角', 'red'),
)

fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=('平移误差分布', '旋转误差分布', '关节角误差分布')
)

for col, (values, label, color) in enumerate(error_panels, start=1):
    fig.add_trace(go.Histogram(x=values, nbinsx=20, name=label, marker_color=color), row=1, col=col)

fig.update_layout(title="匹配后姿态误差分布", width=1400, height=400, showlegend=False)

fig.show()


## 11. 批量可视化（多个样本）


In [None]:
# Distribution-level Chamfer distance over several validation batches.
# All grasps of a sample are pooled into one point set on each side, so predictions
# and targets are compared as unordered sets and no matching is needed here.
num_samples_to_vis = 5
all_chamfer_dists = []

for batch_idx, batch_sample in enumerate(val_loader):
    if batch_idx >= num_samples_to_vis:
        break

    # Move tensors to the device.
    for key in batch_sample:
        if isinstance(batch_sample[key], torch.Tensor):
            batch_sample[key] = batch_sample[key].to(DEVICE)

    with torch.no_grad():
        # Sample and denormalize.
        pred_x0_sample = model.sample(batch_sample, k=1)
        pred_pose_sample = denorm_hand_pose_robust(pred_x0_sample, cfg.model.rot_type, cfg.model.mode)
        target_pose_sample = batch_sample['hand_model_pose']

        # Sizes come from the data itself: the last batch may be smaller than BATCH_SIZE,
        # and the grasp count / pose width are not assumed to be NUM_GRASPS / 23.
        n_in_batch = pred_pose_sample.shape[0]

        pred_hand = hand_model(
            pred_pose_sample.reshape(-1, pred_pose_sample.shape[-1]),
            with_surface_points=True
        )
        target_hand = hand_model(
            target_pose_sample.reshape(-1, target_pose_sample.shape[-1]),
            with_surface_points=True
        )

        # chamfer_distance requires (N, P, 3): pool every grasp's surface points per sample.
        # (Passing [B, NUM_GRASPS, N_hand, 3] directly raised ValueError.)
        pred_pts = pred_hand['surface_points'].reshape(n_in_batch, -1, 3)
        target_pts = target_hand['surface_points'].reshape(n_in_batch, -1, 3)

        chamfer = chamfer_distance(
            pred_pts,
            target_pts,
            point_reduction="sum",
            batch_reduction="mean"
        )[0]

        all_chamfer_dists.append(chamfer.item())

    print(f"Sample {batch_idx}: Chamfer={chamfer.item():.2f}")

# Report the number of batches actually processed (the loader may hold fewer than
# num_samples_to_vis) and avoid averaging an empty list.
if all_chamfer_dists:
    print(f"\n{len(all_chamfer_dists)}个样本的平均Chamfer距离: {np.mean(all_chamfer_dists):.2f}")
else:
    print("\n没有可用的样本")


## 12. 保存可视化结果
