# HTGP模型接入

## 1. 导入checkpoint

In [10]:
from lmy.src.models import HTGPModel
import torch

# Relative path of the HTGP checkpoint that load_HTGP_model reads below.
HTGP_PATH = "../lmy_checkpoints/Checkpoints_break_2/model_epoch_47.pt"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# dtypes = torch.float32 

def load_HTGP_model(path: str, device: torch.device) -> HTGPModel:
    """Rebuild an ``HTGPModel`` from a checkpoint file and move it to ``device``.

    The checkpoint is read as a mapping with two entries: ``'model_config'``
    (passed to the ``HTGPModel`` constructor) and ``'model_state_dict'`` (the
    parameters). A leading ``'module.'`` on a parameter key is stripped before
    loading -- presumably left over from saving through a parallel wrapper.

    Args:
        path: Location of the checkpoint file.
        device: Device used both as ``map_location`` while reading the file
            and as the final device of the returned model.

    Returns:
        The model with the checkpoint parameters loaded, on ``device``. The
        train/eval mode is not changed here.

    Raises:
        RuntimeError: Re-raised (after printing a diagnostic) when
            ``load_state_dict`` rejects the parameters, so callers never
            receive a model that silently kept its random initialisation.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. It is kept because the checkpoint stores
    # a config object next to the tensors; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    config = checkpoint['model_config']
    model = HTGPModel(config)

    # Normalise key names so they match the bare (unwrapped) model.
    state_dict = {
        key.removeprefix('module.'): value
        for key, value in checkpoint['model_state_dict'].items()
    }

    try:
        # strict=False tolerates key mismatches, so the returned report has to
        # be inspected explicitly; shape mismatches still raise RuntimeError.
        load_report = model.load_state_dict(state_dict, strict=False)
    except RuntimeError as e:
        print(f"❌ 加载依然失败，请检查 Config 是否与训练一致。\n详细错误: {e}")
        raise

    if load_report.missing_keys:
        print(f"⚠️ checkpoint 中缺少以下参数 (保持随机初始化): {load_report.missing_keys}")
    if load_report.unexpected_keys:
        print(f"⚠️ checkpoint 中存在模型未使用的参数: {load_report.unexpected_keys}")
    if not load_report.missing_keys and not load_report.unexpected_keys:
        print("✅ 模型参数加载成功！")

    # Print a short summary of the model.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型总参数量: {total_params}")
    print(model.cfg)
    dtypes = {p.dtype for p in model.parameters()}
    print(f"模型参数数据类型: {dtypes}")

    model.to(device)

    return model

# Load the checkpoint once at module level; the forward-pass cell further down
# reuses `lmy_model` (it calls `.eval()` on it and wraps it in a calculator).
lmy_model = load_HTGP_model(HTGP_PATH, device)



✅ 模型参数加载成功！
模型总参数量: 630440
HTGPConfig(num_atom_types=100, hidden_dim=96, num_layers=2, cutoff=6.0, num_rbf=10, atom_types_map=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], use_L0=True, use_L1=True, use_L2=True, use_gating=True, avg_neighborhood=61.481112820672955, use_long_range=False, use_charge=False, use_vdw=False, use_dipole=False, FINETUNE_MODE=True, PRETRAINED_CKPT='Checkpoints_Old/model_epoch_50.pt', steps_per_epoch=None, long_range_scale=1, active_paths={(0, 0, 0, 'prod'): True, (0, 1, 1, 'prod'): True, (0, 2, 2, 'prod'): True, (1, 0, 1, 'prod'): True, (1, 1, 0, 'dot'): True, (1, 1, 1, 'cross'): True, (1, 1, 2, 'outer'): True, (2, 0, 2,

## 2. 架构内数据流模拟

In [11]:
# Build a small random, padded ("dense") batch used by the cells below.
# Statement order matters: torch.rand and torch.randint draw from the same RNG.
B, N, C = 2, 4, 5  # Batch Size, Max Atoms, Element Types
pos = torch.rand((B, N, 3), device=device) * 5.0  # random positions in [0, 5)
# NOTE(review): these indices lie in [0, C) and are later passed to
# dense_to_pyg_batch / dense_to_atoms_list, whose docstrings describe the
# argument as atomic numbers -- so 0 can show up as an atomic number. Confirm
# that this is intended.
species_indices = torch.randint(0, C, (B, N), device=device)
species = torch.nn.functional.one_hot(species_indices, num_classes=C).float()
cell = torch.eye(3, device=device).unsqueeze(0).repeat(B, 1, 1) * 10.0  # simple cubic cell (10.0 * identity per structure)
# Per-atom validity mask: 1.0 marks a real atom, 0.0 a padding slot.
# The second structure has its last slot padded out.
node_mask = torch.tensor([
    [[1.0], [1.0], [1.0], [1.0]], 
    [[1.0], [1.0], [1.0], [0.0]]
], device=device)
# Print the shapes as a sanity check.
print(f"Pos shape:      {pos.shape}")         # [2, 4, 3]
print(f"Species shape:  {species.shape}")     # [2, 4, 5]
print(f"Cell shape:     {cell.shape}")        # [2, 3, 3]
print(f"Node Mask shape: {node_mask.shape}")      # [2, 4, 1]

Pos shape:      torch.Size([2, 4, 3])
Species shape:  torch.Size([2, 4, 5])
Cell shape:     torch.Size([2, 3, 3])
Node Mask shape: torch.Size([2, 4, 1])


## 3. 数据转换成 HTGP 输入形式

In [14]:
import torch
import numpy as np
from ase import Atoms
from ase.neighborlist import neighbor_list
from torch_geometric.data import Data, Batch


def dense_to_atoms_list(pos, species, cell, node_mask) -> list[Atoms]:
    """Split a padded dense batch into one ASE ``Atoms`` object per structure.

    Args:
        pos: [B, N, 3] coordinates.
        species: [B, N] atomic numbers (take the argmax first if the data is
            one-hot encoded).
        cell: [B, 3, 3] lattice matrices.
        node_mask: [B, N, 1] mask; entries greater than 0 mark atoms to keep.

    Returns:
        A list with B ``Atoms`` objects, each built with ``pbc=True`` from the
        unmasked atoms of the corresponding structure.
    """
    num_structures, _, _ = pos.shape

    def _as_numpy(tensor):
        # ASE works on NumPy arrays, so detach and copy to host memory first.
        return tensor.detach().cpu().numpy()

    coords = _as_numpy(pos)
    lattices = _as_numpy(cell)
    numbers = _as_numpy(species)
    keep = _as_numpy(node_mask).squeeze(-1) > 0

    # One Atoms object per structure, restricted to its unmasked atoms.
    return [
        Atoms(
            numbers=numbers[idx][keep[idx]],
            positions=coords[idx][keep[idx]],
            cell=lattices[idx],
            pbc=True,
        )
        for idx in range(num_structures)
    ]

def dense_to_pyg_batch(pos, species, cell, node_mask, cutoff, device):
    """Convert a padded dense batch into a PyG ``Batch`` plus per-graph ``Data``.

    For every structure the unmasked atoms are wrapped in a temporary ASE
    ``Atoms`` object (``pbc=True``) so that ``neighbor_list`` can build the
    edge list, and the result is packed into a ``Data`` object whose tensors
    all live on ``device``.

    Args:
        pos: [B, N, 3] coordinates. The per-graph ``pos`` is sliced from this
            tensor directly, so gradient tracking is preserved.
        species: [B, N] atomic numbers (take the argmax first if the data is
            one-hot encoded).
        cell: [B, 3, 3] lattice matrices.
        node_mask: [B, N, 1] mask; entries greater than 0 mark atoms to keep.
        cutoff: Neighbour cutoff handed to ``ase.neighborlist.neighbor_list``.
        device: Target device for every tensor in the returned objects.

    Returns:
        ``(pyg_batch, data_list)``: the merged batch on ``device`` and the
        list of single-graph ``Data`` objects it was built from.
    """
    num_structures, _, _ = pos.shape

    # ASE builds neighbour lists from NumPy arrays, so make host copies.
    pos_cpu = pos.detach().cpu().numpy()
    cell_cpu = cell.detach().cpu().numpy()
    z_cpu = species.detach().cpu().numpy()

    # Keep the mask in two forms: NumPy for the host arrays, and a torch mask
    # on pos's device for slicing the differentiable `pos` tensor.
    keep = node_mask.detach().squeeze(-1) > 0
    keep_cpu = keep.cpu().numpy()
    keep_on_pos = keep.to(pos.device)

    data_list = []
    for g in range(num_structures):
        # 1. Select the valid (unmasked) atoms of this structure.
        curr_z = z_cpu[g][keep_cpu[g]]
        curr_pos = pos_cpu[g][keep_cpu[g]]
        num_atoms = len(curr_z)

        # 2. Temporary ASE Atoms object, used only for the neighbour search.
        # NOTE: periodic boundaries are assumed; an isolated molecule would
        # need pbc=[False, False, False].
        temp_atoms = Atoms(numbers=curr_z, positions=curr_pos, cell=cell_cpu[g], pbc=True)

        # 3. Neighbour list with the same 'ijdS' request as the reference
        # implementation; the distances are not needed here.
        i_idx, j_idx, _, S_integers = neighbor_list('ijdS', temp_atoms, cutoff)

        # 4. Assemble the PyG Data object. pos and cell are moved to `device`
        # as well, so a standalone element of data_list never mixes devices
        # when the inputs live elsewhere (`.to` keeps the autograd link and is
        # a no-op when the tensor is already on `device`).
        data = Data(
            z=torch.from_numpy(curr_z).to(torch.long).to(device),
            pos=pos[g][keep_on_pos[g]].to(device),
            cell=cell[g].unsqueeze(0).to(device),  # shape [1, 3, 3]
            edge_index=torch.tensor(np.vstack((i_idx, j_idx)), dtype=torch.long).to(device),
            shifts_int=torch.from_numpy(S_integers).to(torch.float32).to(device),
            batch=torch.zeros(num_atoms, dtype=torch.long).to(device),
        )
        data.num_nodes = num_atoms
        data.num_graphs = 1
        data_list.append(data)

    # 5. Merge the per-structure graphs into a single PyG Batch.
    pyg_batch = Batch.from_data_list(data_list)

    return pyg_batch.to(device), data_list

## 4. 模型 forward 计算各种性质

In [None]:
from lmy.src.utils import HTGP_Calculator

# convert batch
# Convert the dense toy batch into both representations; `species_indices`
# (integer indices, not the one-hot `species`) is what both converters receive.
# NOTE(review): in the visible part of this cell, `pyg_data` and `data_list`
# are only referenced by the commented-out experiment below; the active loop
# uses `atoms_list`.
pyg_data, data_list = dense_to_pyg_batch(pos, species_indices, cell, node_mask, cutoff=6.0, device=device)
atoms_list = dense_to_atoms_list(pos, species_indices, cell, node_mask)

# HTGP calculator
# Put the model in eval mode, then wrap it in the calculator. The cutoff
# passed here is the same 6.0 used for the graph conversion above.
lmy_model.eval()
calculator = HTGP_Calculator(lmy_model, cutoff=6.0, device=device)

# for data in data_list:
#     data.pos.requires_grad = True
#     calc_stress = True
#     displacement = torch.zeros((1, 3, 3), dtype=data.pos.dtype, device=device)
#     displacement.requires_grad = True
#     symmetric_strain = 0.5 * (displacement + displacement.transpose(-1, -2))
#     strain_on_graph = symmetric_strain[0] # Assuming single graph for simplicity
#     pos_deformed = pos + torch.matmul(pos, strain_on_graph.T)
#     cell_deformed = cell + torch.bmm(cell, symmetric_strain)
#     data.pos = pos_deformed
#     data.cell = cell_deformed
#     energy = lmy_model(data, capture_weights=self.capture_weights, capture_descriptors=self.capture_descriptors)  

for atoms in atoms_list:
    atoms.calc = calculator
    energy = calculator.get_potential_energy(atoms)
    print(f"Predicted Energy: {energy} eV")
    force = calculator.get_forces(atoms)
    print(f"positions:, {atoms.get_positions()}, Predicted Forces:\n{force}")
    stress = calculator.get_stress(atoms)