# Point Transformer V3 编解码与结果可视化

这个 Notebook 用于加载一个基于 `PointTransformerV3` 训练好的 `MillerIndexerV3` 模型和一个衍射数据样本，然后在一个可交互的3D图表中进行可视化。

**核心改动:**
本脚本使用 **PyTorch Hooks** 来捕获模型中间层的输出，以适应 `PointTransformerV3` 的模块化设计，无需修改模型源代码即可实现对编码器和解码器各个阶段点云变化的追踪。

**可视化内容:**
1.  **编解码过程**: 查看在Encoder和Decoder各个阶段的点云下采样和上采样过程。
2.  **米勒指数 (颜色)**: 以点的颜色表示hkl值，点的大小表示衍射强度。
3.  **米勒指数 (方向)**: 以从点出发的矢量线段表示hkl值的方向，线段的颜色和长度均随衍射强度变化。hkl为(0,0,0)的点将显示为灰色圆点。

**步骤:**
1.  **安装依赖**: 确保 `pointcept` 库已安装 (`pip install pointcept`)，并确保提供 `MillerIndexerV3` 的 `modelv3` 模块位于可导入路径中。
2.  **配置路径**: 在第一个代码单元（“1. 设置和导入”）中，设置好模型（`.pth`）和数据（`.jsonl`）的正确路径，并确认 `num_classes` 与训练时一致。
3.  **运行所有单元**: 点击菜单栏的 `Cell` > `Run All`。
4.  **交互查看**: 在最后一个代码单元输出的第一个图表中，使用下拉菜单选择不同的模式进行查看；随后的两个图表分别显示被遮蔽点的分布以及预测与真实标签之间的 hkl 差异。

### 1. 设置和导入

In [None]:
import torch
import json
import os
import numpy as np
import plotly.graph_objects as go
from collections import OrderedDict

# 确保这些模块可以被导入
# 注意：模型已更新为 MillerIndexerV3
from modelv3 import MillerIndexerV3 as MillerIndexer
# 确保 pointcept 库已安装: pip install pointcept
try:
    from pointcept.models.utils.structure import Point
except ImportError:
    print("错误: 未找到 'pointcept' 库。请运行 'pip install pointcept' 进行安装。")
    raise

# Alternative checkpoint location, kept for reference:
# MODEL_PATH = '/root/autodl-tmp/tf-logs/xrdbert_pt/pt_v3_20250728_121347/best_model.pth' # example path, change as needed
MODEL_PATH = 'pretrain/best_model.pth' # bigger
# Number of classes per Miller-index output head. Values < 11 make the notebook
# use |hkl| labels; with 11 the labels are used signed and shifted by +5 into
# class indices (presumably hkl in [-5, 5] — confirm against training config).
num_classes = 11
DATA_FILE = '/media/max/Data/datasets/mp_random_150k_v3_canonical/test/test_000004.jsonl' # example path, change as needed # 13

# Fraction of points whose hkl input features are zeroed ("masked"); the
# accuracy report is computed over exactly these points. The hkl features are
# only fed to the model when in_channels != 4. 1.0 masks every point.
VIS_MASKING_RATIO = 1.0

# -------------------

def collate_fn_offset(batch):
    """Concatenate a batch of point clouds into the offset-based layout.

    Args:
        batch: sequence of ``(points, labels)`` pairs. ``points`` is an
            ``(N_i, C)`` tensor and ``labels`` a tensor whose first dimension
            is ``N_i``.

    Returns:
        ``(p, feats, labels, offsets)``:
        - ``p``: all point tensors concatenated along dim 0 (every one of the
          ``C`` input columns, not only x, y, z).
          NOTE(review): the notebook's own forward pass feeds only the first
          3 columns as coordinates; slice ``p[:, :3]`` if that is intended.
        - ``feats``: a separate tensor with the same concatenated values
          (all ``C`` columns are used as features).
        - ``labels``: all label tensors concatenated along dim 0.
        - ``offsets``: cumulative point counts, i.e. the end index of each
          sample in the concatenated tensors.
    """
    point_sets = [points for points, _ in batch]
    # Concatenate the points directly. A batch-index column is not needed,
    # because the sample boundaries are encoded by `offsets` instead.
    p_out = torch.cat(point_sets, dim=0)
    feats = torch.cat(point_sets, dim=0)
    labels = torch.cat([label for _, label in batch], dim=0)
    offsets = torch.tensor([points.shape[0] for points in point_sets], dtype=torch.long).cumsum(0)
    return p_out, feats, labels, offsets

### 2. 加载模型和数据

In [None]:
print(f"--> 正在加载模型: {MODEL_PATH}")
# Number of per-point input channels: 4 = (x, y, z, intensity). Any other value
# makes this cell append the (masked) normalized hkl features to the inputs.
in_channels = 4
model = MillerIndexer(in_channels=in_channels, num_classes=num_classes)
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"错误: 模型文件未找到 at '{MODEL_PATH}'")

# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
# Only load checkpoints that come from a trusted source.
ckpt = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
# Checkpoints may store the weights under different keys. Resolve the state
# dict once so the strict attempt and the non-strict fallback load the SAME
# weights (including checkpoints that use the 'model' key).
if 'model_state_dict' in ckpt:
    state_dict = ckpt['model_state_dict']
elif 'model' in ckpt:
    state_dict = ckpt['model']
else:
    state_dict = ckpt
try:
    model.load_state_dict(state_dict)
except RuntimeError as e:
    # load_state_dict raises RuntimeError on missing/unexpected keys or size mismatches.
    print(f"--> 模型加载失败: {e}")
    print("--> 将尝试以非严格模式加载")
    load_result = model.load_state_dict(state_dict, strict=False)
    # Surface what the lenient load skipped instead of silently ignoring it.
    print(f"    missing keys: {load_result.missing_keys}")
    print(f"    unexpected keys: {load_result.unexpected_keys}")
print("--> 模型加载成功。")
# nn.Module starts in training mode; switch to inference mode so layers such as
# dropout / normalization (if the model has them) behave as at evaluation time.
model.eval()

if not os.path.isfile(DATA_FILE):
    raise FileNotFoundError(f"错误: 数据文件未找到 at '{DATA_FILE}'")
# Only the first line (one sample) of the .jsonl file is visualized.
with open(DATA_FILE, 'r', encoding='utf-8') as f:
    line = f.readline()
    sample_data = json.loads(line)
print(f"--> 使用可视化遮蔽比例: {VIS_MASKING_RATIO}")
points_raw = torch.tensor(sample_data['input_sequence'], dtype=torch.float32)
labels_raw = torch.tensor(sample_data['labels'], dtype=torch.long)
if labels_raw.ndim == 3:
    # Presumably a leading batch dimension; keep the first sample.
    labels_raw = labels_raw[0]
# NOTE: basic slicing returns a view of points_raw, so the in-place y/z
# rescaling further down is also reflected in coords_only.
coords_only = points_raw[:, :3]  # (N, 3)
# Fewer than 11 classes -> labels are |hkl|; otherwise signed hkl shifted by +5.
is_abs = num_classes < 11
miller_offset = 0 if is_abs else 5
if is_abs:
    hkl_features_unmasked = torch.abs(labels_raw).clone().float() / 5.0
else:
    hkl_features_unmasked = labels_raw.clone().float() / 5.0
num_points = points_raw.shape[0]
num_to_mask = int(num_points * VIS_MASKING_RATIO)
# Random choice of masked points (unseeded, so it differs between runs).
perm = torch.randperm(num_points)
masked_indices = perm[:num_to_mask]  # kept for the accuracy report and plots
hkl_features_masked = hkl_features_unmasked.clone()
if num_to_mask > 0:
    hkl_features_masked[masked_indices, :] = 0.0  # zero the h, k, l features
points_raw[:, 3] /= 10  # intensity normalization
points_raw[:, 1:3] = (points_raw[:, 1:3] - 0.5) * 0.99 / torch.max(points_raw[:, 1:3]) + 0.5
if in_channels == 4:
    feats_with_hkl = points_raw
else:
    feats_with_hkl = torch.cat([points_raw, hkl_features_masked], dim=1)  # (N, 7)
if is_abs:
    labels_final = torch.abs(labels_raw)
else:
    labels_final = labels_raw + miller_offset
offsets = torch.tensor([len(coords_only)], dtype=torch.long)
original_labels_tensor = labels_final
original_intensities = points_raw[:, 3] * 10  # undo the intensity normalization
print("--> 数据加载和预处理完成。")
print(f"    坐标 shape: {coords_only.shape}")
print(f"    特征 shape: {feats_with_hkl.shape}")

### 3. 使用 Hooks 执行前向传播并捕获所有数据

In [None]:
# [Cell 4]

def run_and_capture_with_hooks(model, p, x, o, original_labels, original_intensities_data, masked_indices_data, is_abs_label):
    """Run one forward pass and collect everything the visualizations need.

    Forward hooks on the backbone's embedding, encoder and decoder stages
    record the ``coord`` attribute of each stage's output. The model runs in
    eval mode under ``torch.no_grad()``. Afterwards its previous train/eval
    mode is restored and all hooks are removed, even if the forward pass
    raises.

    Args:
        model: model exposing ``backbone.embedding``, ``backbone.enc.enc0``..
            ``enc4`` and ``backbone.dec.dec0``..``dec3``. Calling
            ``model(p, x, o)`` must return a dict with per-point logits under
            'h', 'k' and 'l' (the argmax is taken over dim 1).
        p, x, o: coordinates, features and offsets passed to ``model(p, x, o)``.
        original_labels: label tensor with one row of 3 class indices per point.
        original_intensities_data: per-point intensity tensor.
        masked_indices_data: tensor of indices of the masked points.
        is_abs_label: True if labels are |hkl| (no +5 class offset).

    Returns:
        dict of NumPy arrays. It contains the captured stage coordinates keyed
        by stage name, plus 'p0_original', 'predictions' and 'labels' (shifted
        back by the class offset), 'intensities' and 'masked_indices'.
    """
    captured_coords = {}

    def make_hook(name):
        def hook(module, inputs, output):
            captured_coords[name] = output.coord.cpu().numpy()
        return hook

    modules_to_hook = OrderedDict([
        ('p0_embedding', model.backbone.embedding),
        ('p1_enc', model.backbone.enc.enc0),
        ('p2_enc', model.backbone.enc.enc1),
        ('p3_enc', model.backbone.enc.enc2),
        ('p4_enc', model.backbone.enc.enc3),
        ('p5_enc', model.backbone.enc.enc4),
        ('p4_dec', model.backbone.dec.dec3),
        ('p3_dec', model.backbone.dec.dec2),
        ('p2_dec', model.backbone.dec.dec1),
        ('p1_dec', model.backbone.dec.dec0),
    ])

    hooks = []
    was_training = model.training
    try:
        # nn.Module defaults to training mode. Inference must run in eval mode
        # so that dropout / normalization layers (if any) behave correctly.
        model.eval()
        with torch.no_grad():
            for name, module in modules_to_hook.items():
                hooks.append(module.register_forward_hook(make_hook(name)))
            predictions_dict = model(p, x, o)
    finally:
        # Always detach the hooks; leftovers would keep firing on later calls.
        for handle in hooks:
            handle.remove()
        model.train(was_training)

    vis_data = dict(captured_coords)
    vis_data['p0_original'] = p.cpu().numpy()

    predictions = torch.stack(
        [torch.argmax(predictions_dict[axis], dim=1) for axis in ('h', 'k', 'l')],
        dim=1,
    )

    miller_offset_val = 0 if is_abs_label else 5
    # Shift class indices back to Miller indices (e.g. h in [-5, 5] for signed labels).
    vis_data['predictions'] = predictions.cpu().numpy() - miller_offset_val
    vis_data['labels'] = original_labels.cpu().numpy() - miller_offset_val
    vis_data['intensities'] = original_intensities_data.cpu().numpy()
    vis_data['masked_indices'] = masked_indices_data.cpu().numpy()
    return vis_data

print("--> 正在使用 Hooks 执行完整前向传播以捕获所有数据...")
# Prefer CUDA but fall back to CPU instead of failing immediately on machines
# without a GPU. NOTE(review): the PTv3 backbone may rely on CUDA-only kernels,
# in which case the CPU path will still fail inside the model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
coords_only_gpu = coords_only.to(device)
feats_with_hkl_gpu = feats_with_hkl.to(device)
offsets_gpu = offsets.to(device)
original_labels_tensor_gpu = original_labels_tensor.to(device)
original_intensities_gpu = original_intensities.to(device)
# masked_indices stays on the CPU (it is only used for indexing). as_tensor also
# accepts a NumPy array, in case the name was rebound to one elsewhere in the
# notebook before this cell is re-run.
masked_indices_cpu = torch.as_tensor(masked_indices)

vis_data = run_and_capture_with_hooks(
    model,
    coords_only_gpu,
    feats_with_hkl_gpu,
    offsets_gpu,
    original_labels_tensor_gpu,
    original_intensities_gpu,
    masked_indices_cpu,
    is_abs_label=is_abs
)
print("--> 数据捕获完成。")

# --- Accuracy report, computed over the masked points only ---
print("\n--- 精度指标汇报 (仅计算被遮蔽的点) ---")
masked_preds = vis_data['predictions'][vis_data['masked_indices']]
masked_labels = vis_data['labels'][vis_data['masked_indices']]
num_masked_points = len(masked_labels)

if num_masked_points > 0:
    matches = masked_preds == masked_labels  # per-point, per-axis hits, shape (M, 3)
    h_correct, k_correct, l_correct = matches.sum(axis=0)
    all_correct = matches.all(axis=1).sum()

    print(f"被遮蔽点数量: {num_masked_points}")
    for axis_name, n_correct in (("H", h_correct), ("K", k_correct), ("L", l_correct)):
        print(f"{axis_name}  轴准确率: {n_correct / num_masked_points * 100:.2f}%")
    print(f"HKL完全匹配准确率: {all_correct / num_masked_points * 100:.2f}%")
else:
    print("没有被遮蔽的点，无法计算精度。")
print("-" * 40)

### 4. 创建交互式3D可视化图表

In [None]:
# [Cell 5]

# --- 图 1：主交互式可视化图表 ---

# 定义映射函数
def hkl_to_rgb(hkl_array, is_abs=False, max_val=5.0, offset=None):
    """Map hkl triples to Plotly ``'rgb(r,g,b)'`` colour strings.

    h, k and l drive the R, G and B channels respectively. Each value is
    shifted by ``offset``, divided by ``max_val + offset``, clipped to [0, 1]
    and scaled to 0-255 (truncated to int).

    Args:
        hkl_array: array-like of shape (N, 3) with Miller indices.
        is_abs: if True, colour by |hkl|.
        max_val: largest expected index value (default 5.0).
        offset: shift applied before scaling. The default ``None`` keeps the
            data-dependent choice ``-min(hkl_array)``. Pass a fixed value
            (e.g. 0 for |hkl|, 5 for signed hkl) to get colours that are
            comparable across arrays, such as labels vs. predictions.

    Returns:
        list[str]: one colour string per row; an empty list for empty input.
    """
    hkl_array = np.abs(hkl_array) if is_abs else np.asarray(hkl_array)
    if hkl_array.size == 0:
        # np.min would raise on an empty array.
        return []
    if offset is None:
        offset = np.min(hkl_array) * -1
    normalized = (hkl_array + offset) / (max_val + offset + 1e-6)
    rgb_array = (np.clip(normalized, 0, 1) * 255).astype(int)
    return [f'rgb({r},{g},{b})' for r, g, b in rgb_array]

def intensity_to_size(intensities):
    """Turn intensities into Plotly marker sizes.

    Intensities are min-max normalized to [0, 1], shifted by 0.5 and cubed, so
    sizes span 0.125 (weakest) to 3.375 (strongest). If all intensities are
    equal, every marker gets size 4.
    """
    lowest = intensities.min()
    highest = intensities.max()
    if lowest == highest:
        return np.full_like(intensities, 4)
    shifted = (intensities - lowest) / (highest - lowest) + 0.5
    return shifted ** 3

# Every view selectable from the dropdown: key -> display name, render type and
# (for 'structure' views) marker colour. Dict order defines trace/button order.
STAGES = {
    'gt_color': {'name': '真实标签 (颜色)', 'type': 'color'},
    'pred_color': {'name': '模型预测 (颜色)', 'type': 'color'},
    'gt_flow': {'name': '真实标签 (方向)', 'type': 'flow'},
    'pred_flow': {'name': '模型预测 (方向)', 'type': 'flow'},
    'p0_original': {'name': '原始点云', 'type': 'structure', 'color': 'royalblue'},
    'p0_embedding': {'name': 'Embedding Out', 'type': 'structure', 'color': 'cyan'},
    'p1_enc': {'name': 'Encoder 1', 'type': 'structure', 'color': 'darkorange'},
    'p2_enc': {'name': 'Encoder 2', 'type': 'structure', 'color': 'green'},
    'p3_enc': {'name': 'Encoder 3', 'type': 'structure', 'color': 'firebrick'},
    'p4_enc': {'name': 'Encoder 4', 'type': 'structure', 'color': 'purple'},
    'p5_enc': {'name': 'Encoder 5 (Bottleneck)', 'type': 'structure', 'color': 'saddlebrown'},
    'p4_dec': {'name': 'Decoder 4', 'type': 'structure', 'color': 'mediumpurple'},
    'p3_dec': {'name': 'Decoder 3', 'type': 'structure', 'color': 'lightcoral'},
    'p2_dec': {'name': 'Decoder 2', 'type': 'structure', 'color': 'lightgreen'},
    'p1_dec': {'name': 'Decoder 1 (Final)', 'type': 'structure', 'color': 'sandybrown'},
}


def _color_trace(key, stage_info, visible):
    """Return one marker trace: points coloured by hkl and sized by intensity."""
    hkl_data = vis_data['labels'] if 'gt' in key else vis_data['predictions']
    points = vis_data['p0_original']
    intns = vis_data['intensities']
    return go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2], mode='markers',
        marker=dict(size=intensity_to_size(intns), color=hkl_to_rgb(hkl_data, is_abs=is_abs), opacity=0.9, line=dict(width=0)),
        name=stage_info['name'],
        customdata=np.hstack((hkl_data, intns[:, np.newaxis])),
        hovertemplate='<b>hkl:</b> (%{customdata[0]}, %{customdata[1]}, %{customdata[2]})<br><b>强度:</b> %{customdata[3]:.2f}<br><b>坐标:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<extra></extra>',
        visible=visible
    )


def _flow_traces(key, stage_info, visible):
    """Return the two traces of a direction view.

    The first trace shows points whose hkl is (0, 0, 0) as grey dots. The
    second draws a segment from every other point along its normalized hkl
    direction, with length and colour scaled by intensity. It is an empty
    placeholder when no point has a non-zero hkl.
    """
    hkl_data = vis_data['labels'] if 'gt' in key else vis_data['predictions']
    points = vis_data['p0_original']
    intensities = vis_data['intensities']
    hkl_vectors = hkl_data.astype(np.float32)
    norms = np.linalg.norm(hkl_vectors, axis=1)
    non_zero_mask = norms > 1e-6
    zero_points = points[~non_zero_mask]
    zero_trace = go.Scatter3d(x=zero_points[:, 0], y=zero_points[:, 1], z=zero_points[:, 2], mode='markers', marker=dict(color='grey', size=2, opacity=0.6), hoverinfo='skip', visible=visible, name=f"{stage_info['name']} (zero hkl)")
    if not np.any(non_zero_mask):
        return [zero_trace, go.Scatter3d(x=[], y=[], z=[], visible=visible)]

    start_points = points[non_zero_mask]
    directions = hkl_vectors[non_zero_mask] / norms[non_zero_mask, np.newaxis]
    intensities_nz = intensities[non_zero_mask]
    # Segment length grows from 0 to roughly 0.1 with min-max normalized intensity.
    lengths = 0.00 + 0.1 * (intensities_nz - intensities_nz.min()) / (intensities_nz.max() - intensities_nz.min() + 1e-6)
    end_points = start_points + directions * lengths[:, np.newaxis]
    # All segments share one trace; a None after each segment breaks the polyline.
    lines_x, lines_y, lines_z = [], [], []
    for start, end in zip(start_points, end_points):
        for axis, line_coords in enumerate((lines_x, lines_y, lines_z)):
            line_coords.extend([start[axis], end[axis], None])
    intensities_repeat = np.repeat(intensities_nz, 3)
    vector_trace = go.Scatter3d(x=lines_x, y=lines_y, z=lines_z, mode='lines', line=dict(width=0.4, color=intensities_repeat, colorscale='Bluered', cmin=intensities_nz.min(), cmax=intensities_nz.max()), customdata=intensities_repeat, hovertemplate='<b>强度:</b> %{customdata:.2f}<extra></extra>', visible=visible, name=f"{stage_info['name']} (vectors)")
    return [zero_trace, vector_trace]


def _structure_trace(key, stage_info, visible):
    """Return a marker trace of the coordinates captured for one backbone stage."""
    if key not in vis_data:
        print(f"警告: 在vis_data中未找到键 '{key}'，跳过此阶段。")
        return go.Scatter3d(x=[], y=[], z=[], visible=visible)
    points = vis_data[key]
    # NOTE(review): x is doubled only in the structure views (the hkl views use
    # unscaled coordinates) — confirm this stretch is intentional.
    return go.Scatter3d(x=points[:, 0] * 2, y=points[:, 1], z=points[:, 2], mode='markers', marker=dict(size=1, color=stage_info['color'], opacity=0.8), name=f"{stage_info['name']} ({len(points)} points)", visible=visible)


traces = []
# key -> range of indices into `traces` belonging to that stage. These ranges
# are recorded while building, so the dropdown always matches the real traces.
stage_trace_ranges = {}
for key, stage_info in STAGES.items():
    visible = (key == 'gt_color')  # only the first view is shown initially
    first = len(traces)
    stage_type = stage_info['type']
    if stage_type == 'color':
        traces.append(_color_trace(key, stage_info, visible))
    elif stage_type == 'flow':
        traces.extend(_flow_traces(key, stage_info, visible))
    elif stage_type == 'structure':
        traces.append(_structure_trace(key, stage_info, visible))
    stage_trace_ranges[key] = range(first, len(traces))

fig = go.Figure(data=traces)
buttons = []
for key, stage_info in STAGES.items():
    own_traces = stage_trace_ranges[key]
    visibility = [i in own_traces for i in range(len(traces))]
    buttons.append(dict(label=stage_info['name'], method='update', args=[{'visible': visibility}, {'title': f"当前可视化: {stage_info['name']}"}]))
fig.update_layout(
    title_text="当前可视化: 真实标签 (颜色)",
    updatemenus=[dict(active=0, buttons=buttons, direction="down", pad={"r": 10, "t": 10}, showactive=True, x=0.01, xanchor="left", y=1.1, yanchor="top")],
    scene=dict(
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        zaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        aspectmode='data'
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(orientation="h", yanchor="bottom", y=0.01, xanchor="right", x=1)
)
fig.show()


# --- Figure 2: distribution of masked vs. unmasked points ---
# Use names distinct from `masked_indices`. That module-level torch tensor is
# consumed by the forward-pass cell, and rebinding it to a NumPy array here
# would break re-running that cell.
all_points = vis_data['p0_original']
mask_idx = vis_data['masked_indices']
unmasked_idx = np.setdiff1d(np.arange(len(all_points)), mask_idx)
masked_points = all_points[mask_idx]
unmasked_points = all_points[unmasked_idx]
print(len(unmasked_points), len(masked_points))

fig_mask = go.Figure()
# Unmasked points: black.
fig_mask.add_trace(go.Scatter3d(
    x=unmasked_points[:, 0], y=unmasked_points[:, 1], z=unmasked_points[:, 2],
    mode='markers',
    marker=dict(size=1, color='black', opacity=1),
    name=f'Unmasked Points ({len(unmasked_points)})'
))
# Masked points: small and white.
# NOTE(review): white markers may be nearly invisible on a light background —
# confirm this is the intended look.
fig_mask.add_trace(go.Scatter3d(
    x=masked_points[:, 0], y=masked_points[:, 1], z=masked_points[:, 2],
    mode='markers',
    marker=dict(size=0.3, color='white'),
    name=f'Masked Points ({len(masked_points)})'
))
fig_mask.update_layout(
    title='可视化：被遮蔽点(Masked)的分布',
    scene=dict(
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        zaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        aspectmode='data'
    ),
    margin=dict(l=0, r=0, b=0, t=40)
)
fig_mask.show()


# --- Figure 3: per-point hkl difference between prediction and label ---
preds = vis_data['predictions']
labels = vis_data['labels']
points = vis_data['p0_original']

# Euclidean (L2) distance between predicted and true hkl for each point (>= 0),
# matching the title and colour-bar label. A negated L1 sum combined with
# cmin=0 would clamp every wrong point to the same colour.
l2_diff = np.linalg.norm(preds - labels, axis=1)
# Keep cmax > cmin even when every prediction is exact (or there are no points).
color_max = max(float(l2_diff.max()), 1e-6) if l2_diff.size else 1.0

fig_diff = go.Figure()
fig_diff.add_trace(go.Scatter3d(
    x=points[:, 0], y=points[:, 1], z=points[:, 2],
    mode='markers',
    marker=dict(
        size=1,
        color=l2_diff,  # larger difference -> darker red
        colorscale='Reds',
        cmin=0,
        cmax=color_max,
        colorbar=dict(title="HKL L2 差异"),
        line=dict(width=0)
    ),
    customdata=np.hstack((preds, labels, l2_diff[:, np.newaxis])),
    hovertemplate=(
        '<b>Pred:</b> (%{customdata[0]}, %{customdata[1]}, %{customdata[2]})<br>'
        '<b>Label:</b> (%{customdata[3]}, %{customdata[4]}, %{customdata[5]})<br>'
        '<b>L2 Diff:</b> %{customdata[6]:.2f}<br>'
        '<b>Coord:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<extra></extra>'
    ),
    name='HKL Difference'
))
fig_diff.update_layout(
    title='可视化：预测与真实标签的 hkl 差异 (L2距离)',
    scene=dict(
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        zaxis=dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks=''),
        aspectmode='data'
    ),
    margin=dict(l=0, r=0, b=0, t=40)
)
fig_diff.show()