### 1. 设置和导入

In [None]:
import torch
import torch.nn.functional as F
import json
import os
import numpy as np
import plotly.graph_objects as go
from collections import OrderedDict

from RCT.model import RCT as MillerIndexer
from pointcept.models.utils.structure import Point


# Keys of the six unit-cell parameters inside `metadata.crystal_params`, in the
# order used for every lattice tensor: lengths a, b, c, then angles alpha, beta, gamma.
lattice_order = [
    '_cell_length_a',
    '_cell_length_b',
    '_cell_length_c',
    '_cell_angle_alpha',
    '_cell_angle_beta',
    '_cell_angle_gamma',
]
# Classes per Miller-index head (h, k, l); labels are shifted by num_classes // 2.
num_classes = 11

# Input files, relative to the working directory.
MODEL_PATH = 'pretrained/best_model_v123_angle_limited.pth'
DATA_FILE = 'ihep_data2.jsonl'
LATTICE_STATS_JSON = 'cell_params_statistics.json'

def collate_fn_offset(batch):
    """Collate ``(points, labels)`` samples into the offset-batched layout.

    Samples are concatenated along the point (row) dimension instead of being
    stacked; ``offsets`` holds the cumulative point counts, so sample ``i``
    occupies rows ``offsets[i-1]:offsets[i]`` (``0:offsets[0]`` for the first).

    Args:
        batch: sequence of ``(p, l)`` pairs, where ``p`` is a 2-D tensor with one
            row per point and ``l`` is a tensor concatenable along dim 0.

    Returns:
        ``(p_out, feats, labels, offsets)``: ``p_out`` and ``feats`` both hold the
        row-wise concatenation of every ``p`` (as independent tensors, so an
        in-place edit of one does not affect the other); ``labels`` is the
        concatenation of every ``l``; ``offsets`` is an int64 tensor of
        cumulative point counts.
    """
    points = [p for p, _ in batch]
    labels = torch.cat([l for _, l in batch], dim=0)

    # Concatenate the points directly. Prepending a batch-index column only to
    # slice it off again was dead work and forced an int64/float dtype promotion.
    feats = torch.cat(points, dim=0)
    p_out = feats.clone()

    offsets = torch.tensor([p.shape[0] for p in points], dtype=torch.long).cumsum(0)
    return p_out, feats, labels, offsets

### 2. 加载模型和数据

In [None]:
# Load the pretrained model and the first record of DATA_FILE, then build every
# tensor the capture cell needs (coords, features, labels, lattice targets).
model = MillerIndexer(in_channels=4, num_classes=num_classes)
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Error: Model file not found at '{MODEL_PATH}'")

# weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
ckpt = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)

# Resolve the state dict once so the strict load and the non-strict fallback
# use the same weights (the fallback used to ignore the 'model' key).
if 'model_state_dict' in ckpt:
    state_dict = ckpt['model_state_dict']
elif 'model' in ckpt:
    state_dict = ckpt['model']
else:
    state_dict = ckpt

try:
    model.load_state_dict(state_dict)
except RuntimeError as e:
    print(f"--> Fail to load model: {e}")
    print("--> Try to load in non-strict mode")
    incompatible = model.load_state_dict(state_dict, strict=False)
    # Non-strict loading silently skips mismatched parameter names; report them.
    print(f"--> Missing keys: {incompatible.missing_keys}")
    print(f"--> Unexpected keys: {incompatible.unexpected_keys}")
print("--> Model loaded successfully.")
# Inference only: without eval(), layers such as dropout / batch norm (if the
# model has any) would run with their training-time behaviour.
model.eval()

if not os.path.isfile(DATA_FILE):
    raise FileNotFoundError(f"Error: Data file not found at '{DATA_FILE}'")
# Only the first JSONL record is visualized.
with open(DATA_FILE, 'r', encoding='utf-8') as f:
    line = f.readline()
sample_data = json.loads(line)

points_raw = torch.tensor(sample_data['input_sequence'], dtype=torch.float32)
labels_raw = torch.tensor(sample_data['labels'], dtype=torch.long)
if labels_raw.ndim == 3:
    labels_raw = labels_raw[0]
# NOTE: this slice is a *view* of points_raw, so the in-place normalization of
# columns 1:3 below is also visible through coords_only (and therefore in the
# coordinates fed to the model). NOTE(review): confirm the aliasing is intended.
coords_only = points_raw[:, :3]
miller_offset = num_classes // 2

# TODO: unify the data preprocessing method.
points_raw[:, 3] /= 10
points_raw[:, 1:3] = (points_raw[:, 1:3] - 0.5) * 0.99 / torch.max(points_raw[:, 1:3]) + 0.5
feats_with_hkl = points_raw
# Shift signed Miller indices to non-negative class ids (inverse of the
# "- miller_offset" applied to model predictions).
labels_final = labels_raw + miller_offset

offsets = torch.tensor([len(coords_only)], dtype=torch.long)
original_labels_tensor = labels_final
# Undo the /10 intensity scaling above (up to float rounding) for display.
original_intensities = points_raw[:, 3] * 10

metadata = sample_data.get('metadata', {})
crystal_params = metadata.get('crystal_params', {})
if not crystal_params:
    raise ValueError("Error: Missing 'metadata.crystal_params'")
# Work on a copy so the loaded sample is not mutated; the H-M symbol is unused.
crystal_params = dict(crystal_params)
crystal_params.pop('_symmetry_space_group_name_H-M', None)

try:
    lattice_label_raw_tensor = torch.tensor([float(crystal_params[key]) for key in lattice_order], dtype=torch.float32)
    # The data stores a 1-based space-group number; the model uses 0-based class ids.
    sg_label_value = int(crystal_params['_symmetry_Int_Tables_number']) - 1
except KeyError as e:
    raise KeyError(f"Error: Missing crystal parameters: {e}") from e

sg_label_tensor = torch.tensor([sg_label_value], dtype=torch.long)

if os.path.isfile(LATTICE_STATS_JSON):
    with open(LATTICE_STATS_JSON, 'r', encoding='utf-8') as f:
        stats_payload = json.load(f)
    stats = stats_payload.get('stats')
    order_names = ['a', 'b', 'c', 'alpha', 'beta', 'gamma']
    try:
        lattice_mean_list = [float(stats[name]['mean']) for name in order_names]
        lattice_std_list = [float(stats[name]['std']) for name in order_names]
    except (KeyError, TypeError) as e:
        raise ValueError(
            f"Error: Invalid lattice stats in '{LATTICE_STATS_JSON}' "
            f"(expected stats.<a|b|c|alpha|beta|gamma>.mean/std): {e!r}"
        ) from e
    # Avoid dividing by ~0 for parameters whose recorded std is zero.
    lattice_std_list = [std if abs(std) > 1e-8 else 1.0 for std in lattice_std_list]
    lattice_mean_tensor = torch.tensor(lattice_mean_list, dtype=torch.float32)
    lattice_std_tensor = torch.tensor(lattice_std_list, dtype=torch.float32)
else:
    print("Warning: Lattice stats JSON not found, using default normalization parameters")
    lattice_mean_tensor = torch.zeros(6, dtype=torch.float32)
    lattice_std_tensor = torch.tensor([10.0, 10.0, 10.0, 180.0, 180.0, 180.0], dtype=torch.float32)

lattice_label_norm_tensor = (lattice_label_raw_tensor - lattice_mean_tensor) / lattice_std_tensor

print("--> 数据加载和预处理完成。")
print(f"    坐标 shape: {coords_only.shape}")
print(f"    特征 shape: {feats_with_hkl.shape}")
print(f"    晶体参数(原始单位): {lattice_label_raw_tensor.tolist()}")
print(f"    空间群 (1-based): {sg_label_value + 1}")

--> 正在加载模型: pretrained/best_model_v123_angle_limited.pth
--> 模型加载成功。
--> 数据加载和预处理完成。
    坐标 shape: torch.Size([5289, 3])
    特征 shape: torch.Size([5289, 4])
    晶体参数(原始单位): [4.589300155639648, 4.589300155639648, 7.286399841308594, 90.0, 90.0, 120.0]
    空间群 (1-based): 164


### 3. 使用 Hooks 执行前向传播并捕获所有数据

In [None]:
# [Cell 4]

def run_and_capture_with_hooks(
    model,
    p,
    x,
    o,
    original_labels,
    original_intensities_data,
    lattice_mean,
    lattice_std,
    lattice_label_norm,
    lattice_label_raw,
    sg_label,
):
    """Run one forward pass and collect everything the visualization cells need.

    Forward hooks record ``output.coord`` of the backbone embedding, each
    encoder stage and each decoder stage. The model is put in eval mode for the
    pass and restored to its previous train/eval mode afterwards. The hooks are
    always removed, even when the forward pass raises.

    Args:
        model: network exposing ``backbone.embedding``, ``backbone.enc.enc0..enc4``
            and ``backbone.dec.dec0..dec3``; called as ``model(p, x, o)`` and
            returning a dict with ``'h'``, ``'k'``, ``'l'`` logits (classes on
            dim 1), ``'lattice_params'`` and ``'space_group'``.
        p: point coordinates passed to the model (also stored as ``p0_original``).
        x: per-point input features.
        o: offsets tensor passed through to the model.
        original_labels: per-point hkl class ids, in the same shifted encoding
            as the logits' class indices.
        original_intensities_data: per-point intensities for display.
        lattice_mean, lattice_std: normalization statistics used to map the
            normalized lattice prediction back to physical units.
        lattice_label_norm, lattice_label_raw: normalized / raw lattice targets.
        sg_label: single-element tensor holding the 0-based space-group id.

    Returns:
        dict of numpy arrays / ints: captured coords per hooked stage,
        ``p0_original``, signed hkl ``predictions`` and ``labels``,
        ``intensities``, lattice predictions/labels (normalized and raw), and
        ``sg_pred`` / ``sg_label`` / ``sg_prob``.
    """
    captured_coords = {}

    def make_hook(name):
        def hook(module, inputs, output):
            captured_coords[name] = output.coord.detach().cpu().numpy()
        return hook

    backbone = model.backbone
    modules_to_hook = OrderedDict([
        ('p0_embedding', backbone.embedding),
        ('p1_enc', backbone.enc.enc0),
        ('p2_enc', backbone.enc.enc1),
        ('p3_enc', backbone.enc.enc2),
        ('p4_enc', backbone.enc.enc3),
        ('p5_enc', backbone.enc.enc4),
        ('p4_dec', backbone.dec.dec3),
        ('p3_dec', backbone.dec.dec2),
        ('p2_dec', backbone.dec.dec1),
        ('p1_dec', backbone.dec.dec0),
    ])

    hooks = []
    was_training = model.training
    try:
        for name, module in modules_to_hook.items():
            hooks.append(module.register_forward_hook(make_hook(name)))
        model.eval()
        with torch.no_grad():
            predictions_dict = model(p, x, o)
    finally:
        # Never leave hooks attached or the caller's mode changed, even on error;
        # leaked hooks would fire on every later forward pass.
        for h in hooks:
            h.remove()
        model.train(was_training)

    vis_data = dict(captured_coords)
    vis_data['p0_original'] = p.detach().cpu().numpy()

    predictions = torch.stack(
        [torch.argmax(predictions_dict[axis], dim=1) for axis in ('h', 'k', 'l')],
        dim=1,
    )
    # Class ids are Miller indices shifted by (number of classes) // 2; take the
    # class count from the logits rather than a module-level global.
    miller_offset_val = predictions_dict['h'].shape[1] // 2
    vis_data['predictions'] = predictions.detach().cpu().numpy() - miller_offset_val
    vis_data['labels'] = original_labels.detach().cpu().numpy() - miller_offset_val
    vis_data['intensities'] = original_intensities_data.detach().cpu().numpy()

    # Only the first sample's lattice / space-group outputs are used.
    lattice_pred_norm = predictions_dict['lattice_params'][0].detach()
    target = dict(device=lattice_pred_norm.device, dtype=lattice_pred_norm.dtype)
    lattice_mean_dev = lattice_mean.detach().to(**target)
    lattice_std_dev = lattice_std.detach().to(**target)
    lattice_pred_unnorm = lattice_pred_norm * lattice_std_dev + lattice_mean_dev

    vis_data['lattice_pred_norm'] = lattice_pred_norm.cpu().numpy()
    vis_data['lattice_pred'] = lattice_pred_unnorm.cpu().numpy()
    vis_data['lattice_label_norm'] = lattice_label_norm.detach().to(dtype=lattice_pred_norm.dtype).cpu().numpy()
    vis_data['lattice_label'] = lattice_label_raw.detach().to(dtype=lattice_pred_norm.dtype).cpu().numpy()

    sg_prob = torch.softmax(predictions_dict['space_group'][0].detach(), dim=0)
    vis_data['sg_pred'] = int(torch.argmax(sg_prob).item())
    vis_data['sg_label'] = int(sg_label.item())
    vis_data['sg_prob'] = sg_prob.cpu().numpy()

    return vis_data

print("--> 正在使用 Hooks 执行完整前向传播以捕获所有数据...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Place every per-sample input on the same device as the model.
(
    coords_only_dev,
    feats_with_hkl_dev,
    offsets_dev,
    original_labels_tensor_dev,
    original_intensities_dev,
) = (
    t.to(device)
    for t in (coords_only, feats_with_hkl, offsets, original_labels_tensor, original_intensities)
)

vis_data = run_and_capture_with_hooks(
    model,
    coords_only_dev,
    feats_with_hkl_dev,
    offsets_dev,
    original_labels_tensor_dev,
    original_intensities_dev,
    lattice_mean=lattice_mean_tensor,
    lattice_std=lattice_std_tensor,
    lattice_label_norm=lattice_label_norm_tensor,
    lattice_label_raw=lattice_label_raw_tensor,
    sg_label=sg_label_tensor,
)
print("--> 数据捕获完成。")

# --- Per-axis and exact-match hkl accuracy over all points ---
print("\n--- 精度指标汇报 (全体点) ---")
all_preds = vis_data['predictions']
all_labels = vis_data['labels']
num_points = all_labels.shape[0]
if num_points <= 0:
    print("无可用点，无法计算准确率。")
else:
    # Boolean matrix: True where the predicted index equals the label on that axis.
    axis_matches = all_preds == all_labels
    h_correct, k_correct, l_correct = (np.sum(axis_matches[:, axis]) for axis in range(3))
    all_correct = np.sum(np.all(axis_matches, axis=1))
    print(f"总点数: {num_points}")
    print(f"H 轴准确率: {h_correct / num_points * 100:.2f}%")
    print(f"K 轴准确率: {k_correct / num_points * 100:.2f}%")
    print(f"L 轴准确率: {l_correct / num_points * 100:.2f}%")
    print(f"HKL 完全匹配准确率: {all_correct / num_points * 100:.2f}%")
print("-" * 40)

# --- Lattice parameters and space group: prediction vs. label ---
param_names = ['a (Å)', 'b (Å)', 'c (Å)', 'α (°)', 'β (°)', 'γ (°)']
lattice_pred_vals = vis_data['lattice_pred']
lattice_label_vals = vis_data['lattice_label']
print("--- 晶体参数 (预测 vs. 标签) ---")
for name, pred_v, label_v in zip(param_names, lattice_pred_vals, lattice_label_vals):
    abs_err = abs(pred_v - label_v)
    print(f"{name}: 预测 {pred_v:.4f} | 标签 {label_v:.4f} | 误差 {abs_err:.4f}")
print(f"空间群 (1-based): 预测 {vis_data['sg_pred'] + 1} | 标签 {vis_data['sg_label'] + 1}")
print("-" * 40)

--> 正在使用 Hooks 执行完整前向传播以捕获所有数据...
--> 数据捕获完成。

--- 精度指标汇报 (全体点) ---
总点数: 5289
H 轴准确率: 20.91%
K 轴准确率: 0.00%
L 轴准确率: 1.02%
HKL 完全匹配准确率: 0.00%
----------------------------------------
--- 晶体参数 (预测 vs. 标签) ---
a (Å): 预测 9.8248 | 标签 4.5893 | 误差 5.2355
b (Å): 预测 10.0047 | 标签 4.5893 | 误差 5.4154
c (Å): 预测 16.2475 | 标签 7.2864 | 误差 8.9611
α (°): 预测 92.5055 | 标签 90.0000 | 误差 2.5055
β (°): 预测 98.6958 | 标签 90.0000 | 误差 8.6958
γ (°): 预测 100.6639 | 标签 120.0000 | 误差 19.3361
空间群 (1-based): 预测 225 | 标签 164
----------------------------------------


### 4. 创建交互式3D可视化图表

In [None]:
# [Cell 5]

# --- 图 1：主交互式可视化图表 ---

# 定义映射函数
def hkl_to_rgb(hkl_array, max_val=5.0, min_val=None):
    """Map each (h, k, l) row to a Plotly ``'rgb(r,g,b)'`` colour string.

    Each component is mapped linearly from ``[min_val, max_val]`` to
    ``[0, 255]`` (values outside are clipped): h -> red, k -> green, l -> blue.
    The range is fixed instead of being taken from the data's minimum, so the
    same hkl triple gets the same colour in the ground-truth and prediction
    views. With the defaults, output matches the previous version whenever the
    data minimum was -5.

    Args:
        hkl_array: array-like with 3 columns (h, k, l).
        max_val: value mapped to 255 (the previous hard-coded bound, 5.0).
        min_val: value mapped to 0; defaults to ``-max_val``.

    Returns:
        list with one colour string per row (empty for empty input).
    """
    hkl = np.asarray(hkl_array, dtype=np.float64)
    if min_val is None:
        min_val = -max_val
    # The small epsilon keeps the upper bound just below 1.0 and avoids a zero span.
    normalized = (hkl - min_val) / (max_val - min_val + 1e-6)
    rgb_array = (np.clip(normalized, 0, 1) * 255).astype(int)
    return [f'rgb({r},{g},{b})' for r, g, b in rgb_array.reshape(-1, 3)]

def intensity_to_size(intensities):
    """Convert intensities to Plotly marker sizes.

    Intensities are min-max normalized to [0, 1] and mapped through
    ``(normalized + 0.5) ** 3``, giving sizes in [0.125, 3.375] that grow
    faster than linearly with intensity.

    Args:
        intensities: 1-D array-like of intensities.

    Returns:
        numpy array of sizes, one per input value. Every size is 4 when all
        intensities are equal; an empty input yields an empty array.
    """
    values = np.asarray(intensities)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(float)
    if values.size == 0:
        # min()/max() would raise on an empty array.
        return values.copy()
    min_val, max_val = values.min(), values.max()
    if max_val == min_val:
        return np.full_like(values, 4)
    normalized = (values - min_val) / (max_val - min_val)
    return (normalized + 0.5) ** 3

# Dropdown entries. 'color' and 'flow' stages draw the hkl labels/predictions on
# the input point cloud; 'structure' stages draw the coordinates captured from
# one backbone layer (the vis_data entry with the same key).
STAGES = {
    'gt_color': {'name': '真实标签 (颜色)', 'type': 'color'},
    'pred_color': {'name': '模型预测 (颜色)', 'type': 'color'},
    'gt_flow': {'name': '真实标签 (方向)', 'type': 'flow'},
    'pred_flow': {'name': '模型预测 (方向)', 'type': 'flow'},
    'p0_original': {'name': '原始点云', 'type': 'structure', 'color': 'royalblue'},
    'p0_embedding': {'name': 'Embedding Out', 'type': 'structure', 'color': 'cyan'},
    'p1_enc': {'name': 'Encoder 1', 'type': 'structure', 'color': 'darkorange'},
    'p2_enc': {'name': 'Encoder 2', 'type': 'structure', 'color': 'green'},
    'p3_enc': {'name': 'Encoder 3', 'type': 'structure', 'color': 'firebrick'},
    'p4_enc': {'name': 'Encoder 4', 'type': 'structure', 'color': 'purple'},
    'p5_enc': {'name': 'Encoder 5 (Bottleneck)', 'type': 'structure', 'color': 'saddlebrown'},
    'p4_dec': {'name': 'Decoder 4', 'type': 'structure', 'color': 'mediumpurple'},
    'p3_dec': {'name': 'Decoder 3', 'type': 'structure', 'color': 'lightcoral'},
    'p2_dec': {'name': 'Decoder 2', 'type': 'structure', 'color': 'lightgreen'},
    'p1_dec': {'name': 'Decoder 1 (Final)', 'type': 'structure', 'color': 'sandybrown'},
}

traces = []
# Half-open [start, stop) range of trace indices contributed by each stage. The
# dropdown masks are built from these, so they always match what was appended
# instead of relying on a separate hard-coded per-type trace count.
stage_trace_ranges = {}
for key, stage_info in STAGES.items():
    visible = (key == 'gt_color')
    first_trace = len(traces)
    stage_type = stage_info['type']
    if stage_type == 'color':
        hkl_data = vis_data['labels'] if 'gt' in key else vis_data['predictions']
        points = vis_data['p0_original']
        intns = vis_data['intensities']  # use the captured intensities directly
        traces.append(go.Scatter3d(
            x=points[:, 0], y=points[:, 1], z=points[:, 2], mode='markers',
            marker=dict(size=intensity_to_size(intns), color=hkl_to_rgb(hkl_data), opacity=0.9, line=dict(width=0)),
            name=stage_info['name'],
            customdata=np.hstack((hkl_data, intns[:, np.newaxis])),
            hovertemplate='<b>hkl:</b> (%{customdata[0]}, %{customdata[1]}, %{customdata[2]})<br><b>强度:</b> %{customdata[3]:.2f}<br><b>坐标:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<extra></extra>',
            visible=visible
        ))
    elif stage_type == 'flow':
        hkl_data = vis_data['labels'] if 'gt' in key else vis_data['predictions']
        points = vis_data['p0_original']
        intensities = vis_data['intensities']  # use the captured intensities directly
        hkl_vectors = hkl_data.astype(np.float32)
        norms = np.linalg.norm(hkl_vectors, axis=1)
        non_zero_mask = norms > 1e-6
        # Points with hkl == (0, 0, 0) have no direction: draw them as grey dots.
        zero_points = points[~non_zero_mask]
        traces.append(go.Scatter3d(
            x=zero_points[:, 0], y=zero_points[:, 1], z=zero_points[:, 2], mode='markers',
            marker=dict(color='grey', size=2, opacity=0.6), hoverinfo='skip',
            visible=visible, name=f"{stage_info['name']} (zero hkl)"
        ))
        if np.any(non_zero_mask):
            start_points = points[non_zero_mask]
            directions = hkl_vectors[non_zero_mask] / norms[non_zero_mask, np.newaxis]
            intensities_nz = intensities[non_zero_mask]
            # Segment length grows linearly with intensity, from 0 up to ~0.1.
            lengths = 0.00 + 0.1 * (intensities_nz - intensities_nz.min()) / (intensities_nz.max() - intensities_nz.min() + 1e-6)
            end_points = start_points + directions * lengths[:, np.newaxis]
            # A single 'lines' trace draws every segment; None breaks the line between segments.
            lines_x, lines_y, lines_z = [], [], []
            for start, end in zip(start_points, end_points):
                lines_x.extend([start[0], end[0], None])
                lines_y.extend([start[1], end[1], None])
                lines_z.extend([start[2], end[2], None])
            intensities_repeat = np.repeat(intensities_nz, 3)
            traces.append(go.Scatter3d(
                x=lines_x, y=lines_y, z=lines_z, mode='lines',
                line=dict(width=0.4, color=intensities_repeat, colorscale='Bluered', cmin=intensities_nz.min(), cmax=intensities_nz.max()),
                customdata=intensities_repeat,
                hovertemplate='<b>强度:</b> %{customdata:.2f}<extra></extra>',
                visible=visible, name=f"{stage_info['name']} (vectors)"
            ))
        else:
            # Empty placeholder so every flow stage still has a vectors trace.
            traces.append(go.Scatter3d(x=[], y=[], z=[], visible=visible))
    elif stage_type == 'structure':
        if key in vis_data:
            points = vis_data[key]
            # NOTE(review): x is doubled for structure stages only, so these views are
            # stretched relative to the colour/flow views; confirm this is intentional.
            traces.append(go.Scatter3d(
                x=points[:, 0] * 2, y=points[:, 1], z=points[:, 2], mode='markers',
                marker=dict(size=1, color=stage_info['color'], opacity=0.8),
                name=f"{stage_info['name']} ({len(points)} points)", visible=visible
            ))
        else:
            print(f"警告: 在vis_data中未找到键 '{key}'，跳过此阶段。")
            traces.append(go.Scatter3d(x=[], y=[], z=[], visible=visible))
    stage_trace_ranges[key] = (first_trace, len(traces))

fig = go.Figure(data=traces)
buttons = []
for key, stage_info in STAGES.items():
    start, stop = stage_trace_ranges[key]
    # Show only this stage's traces.
    visibility = [start <= idx < stop for idx in range(len(traces))]
    buttons.append(dict(label=stage_info['name'], method='update', args=[{'visible': visibility}, {'title': f"当前可视化: {stage_info['name']}"}]))

hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks='')
fig.update_layout(
    title_text="当前可视化: 真实标签 (颜色)",
    updatemenus=[dict(active=0, buttons=buttons, direction="down", pad={"r": 10, "t": 10}, showactive=True, x=0.01, xanchor="left", y=1.1, yanchor="top")],
    scene=dict(
        xaxis=hidden_axis,
        yaxis=hidden_axis,
        zaxis=hidden_axis,
        aspectmode='data'
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(orientation="h", yanchor="bottom", y=0.01, xanchor="right", x=1)
)
fig.show()


# --- 新增 图 2：晶体参数与空间群对比表 ---
param_labels = ['a (Å)', 'b (Å)', 'c (Å)', 'α (°)', 'β (°)', 'γ (°)', 'Space Group']
lattice_pred_arr = vis_data['lattice_pred']
lattice_label_arr = vis_data['lattice_label']
sg_pred_idx = vis_data['sg_pred']
sg_label_idx = vis_data['sg_label']

# Rows 0-5: lattice parameters rounded to 4 decimals. Last row: the 1-based
# space-group number, whose "error" is the absolute difference of the numbers.
param_pred = [*np.round(lattice_pred_arr, 4), int(sg_pred_idx + 1)]
param_label = [*np.round(lattice_label_arr, 4), int(sg_label_idx + 1)]
param_error = [
    *np.round(np.abs(lattice_pred_arr - lattice_label_arr), 4),
    abs(int(sg_pred_idx - sg_label_idx)),
]

# Lattice rows are printed with 4 decimals, the space-group row as a plain value.
formatted_columns = [
    [f"{v:.4f}" if row < 6 else str(v) for row, v in enumerate(column)]
    for column in (param_pred, param_label, param_error)
]

table_fig = go.Figure(data=[go.Table(
    columnwidth=[140] * 4,
    header=dict(values=['参数', '预测值', '标签值', '绝对误差'], align='center', font=dict(size=14), fill_color='#1f77b4'),
    cells=dict(values=[param_labels, *formatted_columns], align='center', font=dict(size=12)),
)])
table_fig.update_layout(title='晶体参数与空间群预测对比', margin=dict(l=0, r=0, b=0, t=40))
table_fig.show()


# --- 新增 图 3：可视化 Label 与 Output 的 hkl 差异 ---
preds = vis_data['predictions']
labels = vis_data['labels']
points = vis_data['p0_original']

# Per-point Euclidean distance between predicted and true (h, k, l).
l2_diff = np.linalg.norm(preds - labels, axis=1)
if l2_diff.size > 0:
    max_diff = l2_diff.max()
else:
    max_diff = 0.0
# Fall back to 1.0 so the colour range is never degenerate ([0, 0]).
color_scale_max = max_diff if max_diff > 0 else 1.0

diff_marker = dict(
    size=1,
    color=l2_diff,
    colorscale='Reds',
    cmin=0,
    cmax=color_scale_max,
    colorbar=dict(title="HKL L2 差异"),
    line=dict(width=0),
)
diff_hover = (
    '<b>Pred:</b> (%{customdata[0]}, %{customdata[1]}, %{customdata[2]})<br>'
    '<b>Label:</b> (%{customdata[3]}, %{customdata[4]}, %{customdata[5]})<br>'
    '<b>L2 Diff:</b> %{customdata[6]:.2f}<br>'
    '<b>Coord:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<extra></extra>'
)
# Axis styling shared by all three scene axes: no grid, ticks or background.
hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False, showspikes=False, showbackground=False, ticks='')

fig_diff = go.Figure()
fig_diff.add_trace(go.Scatter3d(
    x=points[:, 0], y=points[:, 1], z=points[:, 2],
    mode='markers',
    marker=diff_marker,
    customdata=np.hstack((preds, labels, l2_diff[:, np.newaxis])),
    hovertemplate=diff_hover,
    name='HKL Difference',
))
fig_diff.update_layout(
    title='可视化：预测与真实标签的 hkl 差异 (L2距离)',
    scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis, aspectmode='data'),
    margin=dict(l=0, r=0, b=0, t=40),
)
fig_diff.show()