# HTGP模型接入

## 1. 导入checkpoint

In [2]:
from lmy.src.models import HTGPModel
import torch
import sys
sys.path.append("./lmy/")
# Checkpoint file to load; the path is relative to the current working directory.
HTGP_PATH = "../lmy_checkpoints/Checkpoints_break_2/model_epoch_47.pt"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# dtypes = torch.float32  (disabled: no dtype is forced here)

def load_HTGP_model(path: str, device: torch.device) -> HTGPModel:
    """Build an HTGPModel from a training checkpoint and move it to ``device``.

    The checkpoint must contain ``'model_config'`` (passed to ``HTGPModel``)
    and ``'model_state_dict'``. A leading ``module.`` on parameter names is
    stripped before loading (presumably left by a DataParallel/DDP wrapper
    during training -- confirm against the training script).

    Weights are loaded with ``strict=False``, which does not raise for
    missing or unexpected keys. Those keys are therefore reported explicitly,
    so a partially initialised model is never announced as a successful load.
    A ``RuntimeError`` from the load is reported and the model is still
    returned, exactly as before.

    Args:
        path: Path of the checkpoint file.
        device: Used as ``map_location`` and as the final device of the model.

    Returns:
        The constructed model, moved to ``device``.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    config = checkpoint['model_config']
    model = HTGPModel(config)

    new_state_dict = {
        key.removeprefix('module.'): value
        for key, value in checkpoint['model_state_dict'].items()
    }
    try:
        load_result = model.load_state_dict(new_state_dict, strict=False)
    except RuntimeError as e:
        print(f"❌ 加载依然失败，请检查 Config 是否与训练一致。\n详细错误: {e}")
    else:
        if load_result.missing_keys or load_result.unexpected_keys:
            # strict=False hides these mismatches, so surface them here.
            print(
                "⚠️ 模型参数仅部分加载 (strict=False)。\n"
                f"缺失的键: {load_result.missing_keys}\n"
                f"多余的键: {load_result.unexpected_keys}"
            )
        else:
            print("✅ 模型参数加载成功！")

    # Print a short summary of the model
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型总参数量: {total_params}")
    print(model.cfg)
    dtypes = {p.dtype for p in model.parameters()}
    print(f"模型参数数据类型: {dtypes}")

    model.to(device)

    return model

# Load the checkpointed model once; later cells reuse `lmy_model`.
lmy_model = load_HTGP_model(HTGP_PATH, device)



✅ 模型参数加载成功！
模型总参数量: 630440
HTGPConfig(num_atom_types=100, hidden_dim=96, num_layers=2, cutoff=6.0, num_rbf=10, atom_types_map=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], use_L0=True, use_L1=True, use_L2=True, use_gating=True, avg_neighborhood=61.481112820672955, use_long_range=False, use_charge=False, use_vdw=False, use_dipole=False, FINETUNE_MODE=True, PRETRAINED_CKPT='Checkpoints_Old/model_epoch_50.pt', steps_per_epoch=None, long_range_scale=1, active_paths={(0, 0, 0, 'prod'): True, (0, 1, 1, 'prod'): True, (0, 2, 2, 'prod'): True, (1, 0, 1, 'prod'): True, (1, 1, 0, 'dot'): True, (1, 1, 1, 'cross'): True, (1, 1, 2, 'outer'): True, (2, 0, 2,

## 2. 架构内数据流模拟

In [3]:
# 1. Basic parameters for a synthetic dense (padded) batch
B, N, C = 2, 4, 60  # Batch Size, Max Atoms, Element Types
pos = torch.rand((B, N, 3), device=device) * 1.0  # random positions; torch.rand scaled by 1.0 gives values in [0, 1)
# Class indices are drawn from 1..C-1 (torch.randint excludes the upper bound), so index 0 never occurs.
species_indices = torch.randint(1, C, (B, N), device=device)
species = torch.nn.functional.one_hot(species_indices, num_classes=C).float()
cell = torch.eye(3, device=device).unsqueeze(0).repeat(B, 1, 1) * 10.0  # simple cubic cell, 10.0 on the diagonal
# The last entry of the second structure is 0.0; the conversion helpers below keep only entries > 0.
node_mask = torch.tensor([
    [[1.0], [1.0], [1.0], [1.0]], 
    [[1.0], [1.0], [1.0], [0.0]]
], device=device)
# Print the shapes as a sanity check
print(f"Pos shape:      {pos.shape}")         # [2, 4, 3]
print(f"Species shape:  {species.shape}")     # [2, 4, 60]
print(f"Cell shape:     {cell.shape}")        # [2, 3, 3]
print(f"Node Mask shape: {node_mask.shape}")      # [2, 4, 1]

Pos shape:      torch.Size([2, 4, 3])
Species shape:  torch.Size([2, 4, 60])
Cell shape:     torch.Size([2, 3, 3])
Node Mask shape: torch.Size([2, 4, 1])


## 3. 数据转换成 HTGP 输入形式

In [4]:
import torch
import numpy as np
from ase import Atoms
from ase.neighborlist import neighbor_list
from torch_geometric.data import Data, Batch


def dense_to_atoms_list(pos, species, cell, node_mask) -> list[Atoms]:
    """Turn a padded (dense) batch into one ASE ``Atoms`` object per structure.

    Args:
        pos: Coordinates, expected as ``[B, N, 3]``.
        species: Per-atom class rows, expected as ``[B, N, C]``; the argmax
            over the last axis is handed to ASE as the atomic number.
        cell: Lattice matrices, expected as ``[B, 3, 3]``.
        node_mask: Expected as ``[B, N, 1]``; entries greater than 0 are kept,
            all other atoms are dropped.

    Returns:
        A list with one periodic (``pbc=True``) ``Atoms`` object per batch row.
    """
    batch_size, _, _ = pos.shape

    # ASE operates on NumPy arrays, so detach everything and copy it to host memory.
    positions_np = pos.detach().cpu().numpy()
    cells_np = cell.detach().cpu().numpy()
    numbers_np = torch.argmax(species, dim=-1).detach().cpu().numpy()
    keep_np = node_mask.detach().cpu().numpy().squeeze(-1) > 0

    return [
        Atoms(
            numbers=numbers_np[b][keep_np[b]],
            positions=positions_np[b][keep_np[b]],
            cell=cells_np[b],
            pbc=True,
        )
        for b in range(batch_size)
    ]

def dense_to_pyg_batch(pos, species, cell, node_mask, cutoff, device):
    """Turn a padded (dense) batch into a PyG ``Batch`` with ASE neighbor lists.

    Args:
        pos: Coordinates, expected as ``[B, N, 3]``.
        species: Per-atom class rows, expected as ``[B, N, C]``; the argmax
            over the last axis is used as the atomic number.
        cell: Lattice matrices, expected as ``[B, 3, 3]``.
        node_mask: Expected as ``[B, N, 1]``; entries greater than 0 are kept.
        cutoff: Cutoff passed to ASE's ``neighbor_list``.
        device: Device on which the graph tensors are placed.

    Returns:
        A tuple ``(batch, data_list)``: the merged ``Batch`` moved to
        ``device`` and the list of per-structure ``Data`` objects.
    """
    num_structures, _, _ = pos.shape

    # The ASE neighbor list runs on NumPy arrays, hence the host copies.
    positions_np = pos.detach().cpu().numpy()
    cells_np = cell.detach().cpu().numpy()
    numbers_np = torch.argmax(species, dim=-1).detach().cpu().numpy()
    keep_np = node_mask.detach().cpu().numpy().squeeze(-1) > 0

    graphs = []
    for b in range(num_structures):
        # Keep only the atoms selected by the mask.
        keep = keep_np[b]
        numbers = numbers_np[b][keep]

        # Temporary periodic Atoms object, built only to reuse neighbor_list.
        structure = Atoms(
            numbers=numbers,
            positions=positions_np[b][keep],
            cell=cells_np[b],
            pbc=True,
        )
        graph_index = torch.zeros(len(structure), dtype=torch.long).to(device)

        # 'ijdS': first index, second index, distance (unused), integer cell shifts.
        first_idx, second_idx, _, cell_shifts = neighbor_list('ijdS', structure, cutoff)

        graph = Data(
            z=torch.from_numpy(numbers).to(torch.long).to(device),
            pos=pos[b][keep],  # indexes the original tensor, not the detached copy
            cell=cell[b].unsqueeze(0),  # shape [1, 3, 3]
            edge_index=torch.tensor(np.vstack((first_idx, second_idx)), dtype=torch.long).to(device),
            shifts_int=torch.from_numpy(cell_shifts).to(torch.float32).to(device),
            batch=graph_index,
        )
        graph.num_nodes = len(numbers)
        graph.num_graphs = 1
        graphs.append(graph)

    # Merge the per-structure graphs into a single PyG Batch.
    merged = Batch.from_data_list(graphs)

    return merged.to(device), graphs

## 4. 模型 forward 计算各种性质

In [5]:
from lmy.src.utils import HTGP_Calculator

# Convert the dense batch into both representations (PyG batch and ASE Atoms).
# `pyg_data` and `data_list` are not used further in this cell.
pyg_data, data_list = dense_to_pyg_batch(pos, species, cell, node_mask, cutoff=6.0, device=device)
atoms_list = dense_to_atoms_list(pos, species, cell, node_mask)

# HTGP calculator wrapping the loaded model (model switched to eval mode first)
lmy_model.eval()
lmy_calculator = HTGP_Calculator(lmy_model, cutoff=6.0, device=device)

# Query energy, forces and stress for every structure through the calculator.
for atoms in atoms_list:
    atoms.calc = lmy_calculator
    energy = lmy_calculator.get_potential_energy(atoms)
    print(f"Predicted Energy: {energy} eV")
    force = lmy_calculator.get_forces(atoms)
    print(f"positions:, {atoms.get_positions()}, \nPredicted Forces:\n{force}")
    stress = lmy_calculator.get_stress(atoms)
    print(f"Predicted Stress:\n{stress}")

Predicted Energy: 775.1927490234375 eV
positions:, [[0.83899093 0.4275907  0.77417934]
 [0.12738533 0.98398918 0.37296629]
 [0.02028096 0.18894708 0.65469432]
 [0.33082721 0.44240227 0.49594921]], 
Predicted Forces:
[[ 241.72588   203.49521    21.66552 ]
 [-215.67813   444.99152     3.932743]
 [-643.8226    -96.32204    60.000954]
 [ 617.7749   -552.16473   -85.59921 ]]
Predicted Stress:
[-0.36665094 -0.26240054 -0.01506919  0.01339851  0.00842356 -0.04279144]
Predicted Energy: 289.4280700683594 eV
positions:, [[0.59786505 0.17394595 0.52950513]
 [0.91123998 0.02964883 0.98903912]
 [0.85767126 0.32990023 0.30413806]], 
Predicted Forces:
[[-456.71783  -251.79047   346.9703  ]
 [   9.063694   57.66429  -122.765656]
 [ 447.65414   194.12616  -224.20465 ]]
Predicted Stress:
[-0.11914366 -0.02195402  0.00588663  0.01725095  0.09672148 -0.06850573]


## 5. 通过本地数据计算 energy_above_hull

### 5.1 导入本地数据库

In [6]:
from monty.serialization import loadfn
# Local reference-phase database; the path is relative to the current working directory.
database_path =  "../mp_stable_reference_84el.json.gz"
# presumably deserialises to a list of pymatgen entries (it is passed as `all_entries` later) -- verify
local_database = loadfn(database_path)

  syms: list[str] = sorted(sym_amt, key=lambda x: [get_el_sp(x).X, x])
  syms: list[str] = sorted(sym_amt, key=lambda x: [get_el_sp(x).X, x])
  syms: list[str] = sorted(sym_amt, key=lambda x: [get_el_sp(x).X, x])


### 5.2 本地计算器

In [10]:
import numpy as np
from pymatgen.core import Composition, Structure
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDEntry
from sklearn.linear_model import LinearRegression 

class LocalHullCalculator:
    """Offline energy-above-hull calculator backed by a local list of entries.

    Typical use: construct it with the reference entries, obtain per-element
    model offsets through ``calibrate`` or ``set_model_offsets``, then call
    ``get_ehull`` for each predicted structure.
    """

    def __init__(self, all_entries):
        """Store the reference entries; no offsets are set yet.

        Args:
            all_entries: Reference phases loaded from a local file. Each item
                must expose ``.composition`` and be accepted by
                ``PhaseDiagram`` (e.g. ``PDEntry`` objects).
        """
        self.all_entries = all_entries
        self.model_offsets = {}
        self.is_calibrated = False

    def calibrate(self, validation_data):
        """Fit per-element offsets between MLP and MP energies.

        Solves ``A x = b`` by least squares without intercept, where each row
        of ``A`` holds the element amounts of one sample, ``b`` is
        ``e_mlp - e_mp`` and ``x`` is the per-element offset.

        Args:
            validation_data: Sequence of dicts such as
                ``{"composition": "Li2O", "e_mlp": -14.2, "e_mp": -15.1}``.
                Energies must be total energies (eV), not per atom.

        Raises:
            ValueError: If ``validation_data`` is empty.
        """
        if not validation_data:
            # Fail early with a clear message instead of deep inside the regression.
            raise ValueError("validation_data must contain at least one sample")

        print(">>> [本地] 正在计算能量校准参数...")

        # Parse every composition once and collect the element symbols.
        compositions = [Composition(d["composition"]) for d in validation_data]
        sorted_elems = sorted(
            {str(el) for comp in compositions for el in comp.elements}
        )
        el_map = {el: i for i, el in enumerate(sorted_elems)}

        # A = matrix of element amounts, b = (E_mlp - E_mp)
        A = []
        b = []
        for comp, d in zip(compositions, validation_data):
            row = [0.0] * len(sorted_elems)
            for el, amt in comp.items():
                row[el_map[str(el)]] = amt
            A.append(row)
            b.append(d["e_mlp"] - d["e_mp"])

        # The offsets are purely per-atom terms, so the fit has no intercept.
        reg = LinearRegression(fit_intercept=False)
        reg.fit(A, b)

        self.model_offsets = dict(zip(sorted_elems, reg.coef_))
        self.is_calibrated = True

        print(f">>> 校准完成。共校准 {len(sorted_elems)} 种元素。")

    def set_model_offsets(self, offsets_dict):
        """Install precomputed offsets instead of fitting them.

        Args:
            offsets_dict: Mapping ``{element symbol: offset in eV/atom}``.
        """
        self.model_offsets = offsets_dict
        self.is_calibrated = True
        print(f">>> 手动设置模型偏差，共设置 {len(offsets_dict)} 种元素。")

    def get_ehull(self, composition_dict, total_energy):
        """Compute the energy above hull against the local reference entries.

        Args:
            composition_dict: Element amounts, e.g. ``{"Li": 2, "Fe": 1, "O": 4}``.
            total_energy: Model-predicted total energy of the structure (eV),
                not the energy per atom.

        Returns:
            The energy above hull in eV/atom. The value is negative when the
            corrected energy lies below the hull of the local references.
            ``None`` is returned when no reference entry belongs to the
            chemical system or when the phase diagram cannot be evaluated.
        """
        comp = Composition(composition_dict)
        chemsys = {str(el) for el in comp.elements}

        # Shift the model energy by the per-element offsets; elements without
        # a known offset contribute 0.
        correction = sum(
            comp[el] * self.model_offsets.get(str(el), 0) for el in comp.elements
        )
        corrected_energy = total_energy - correction

        # Keep only entries whose elements all belong to this chemical system.
        relevant_entries = [
            e for e in self.all_entries
            if {str(el) for el in e.composition.elements} <= chemsys
        ]
        if not relevant_entries:
            print(f"Error: 本地库中找不到体系 {chemsys} 的任何参考条目！")
            return None

        target_entry = PDEntry(comp, corrected_energy)

        try:
            phase_diagram = PhaseDiagram(relevant_entries)
            # allow_negative=True: an entry below the hull yields a negative
            # value instead of a ValueError that would be misreported below
            # as a phase-diagram construction failure.
            _, e_above_hull = phase_diagram.get_decomp_and_e_above_hull(
                target_entry, allow_negative=True
            )
            return e_above_hull
        except Exception as e:
            # Deliberate best-effort boundary: e.g. the local database lacks a
            # terminal (pure element) entry, so no closed hull can be built.
            print(f"相图构建失败 (可能缺失端点元素): {chemsys}, 错误信息: {e}")
            return None


# Create the calculator on top of the local reference database.
HullCalculator = LocalHullCalculator(all_entries=local_database)

# Calibrate on a validation set so that MLP energies line up with MP energies.
my_val_data = [
    {"composition": "Li", "e_mlp": -1.8, "e_mp": -1.9},     # MLP is 0.1 higher than MP
    {"composition": "O2", "e_mlp": -5.8, "e_mp": -9.88},    # MLP is 4.08 higher than MP (no gas correction)
]
HullCalculator.calibrate(my_val_data)

# Offline evaluation of one hypothetical MLP-predicted structure.
new_comp = {"Li": 2, "O": 1}
new_energy = -12.1  # MLP-predicted total energy

ehull = HullCalculator.get_ehull(new_comp, new_energy)

print(f"Structure: {new_comp}")
# get_ehull returns None when the hull cannot be evaluated; formatting None
# with ':.4f' would raise TypeError.
if ehull is None:
    print("Energy Above Hull: unavailable (get_ehull returned None)")
else:
    print(f"Energy Above Hull: {ehull:.4f} eV/atom")

>>> [本地] 正在计算能量校准参数...
>>> 校准完成。共校准 2 种元素。
Structure: {'Li': 2, 'O': 1}
Energy Above Hull: 0.8289 eV/atom


  warn("Using UFloat objects with std_dev==0 may give unexpected results.")


### 5.3 批量计算 offset

In [11]:
import json
import random
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from sklearn.linear_model import LinearRegression
from pymatgen.core import Composition
from pymatgen.io.ase import AseAtomsAdaptor

def generate_offsets_for_htgp(all_entries, lmy_model, device, save_path, num_samples=1500,
                              samples_per_el=10, cutoff=6.0):
    """Fit per-element HTGP-to-MP energy offsets on a sampled set of entries.

    Sampling first draws up to ``samples_per_el`` entries for every element
    that occurs in ``all_entries`` and then tops the selection up with random
    entries until ``num_samples`` is reached. The per-element draws are never
    trimmed, so the final sample count can exceed ``num_samples``.

    Args:
        all_entries: Entries exposing ``composition``, ``entry_id``,
            ``structure`` and ``energy``.
        lmy_model: HTGP model wrapped by ``HTGP_Calculator``; it is switched
            to eval mode.
        device: Device passed to the calculator.
        save_path: Path of the JSON file the offsets are written to.
        num_samples: Target number of sampled entries.
        samples_per_el: Entries drawn per element before the random top-up.
        cutoff: Cutoff passed to ``HTGP_Calculator``.

    Returns:
        Mapping ``{element symbol: offset}`` with plain ``float`` values.

    Raises:
        ValueError: If no sampled entry could be evaluated.
    """
    # 1. Model wrapper and structure converter
    adaptor = AseAtomsAdaptor()
    calculator = HTGP_Calculator(lmy_model, cutoff=cutoff, device=device)
    lmy_model.eval()

    # 2. Per-element guaranteed sampling
    print(">>> 正在进行元素分布分析与保底抽样...")
    el_to_entries = defaultdict(list)
    for entry in all_entries:
        for el in entry.composition.elements:
            el_to_entries[str(el)].append(entry)

    selected_ids = set()
    for entries in el_to_entries.values():
        for drawn in random.sample(entries, min(len(entries), samples_per_el)):
            selected_ids.add(drawn.entry_id)

    # Top up with random entries when the target size has not been reached.
    remaining_count = num_samples - len(selected_ids)
    if remaining_count > 0:
        unselected = [e for e in all_entries if e.entry_id not in selected_ids]
        for drawn in random.sample(unselected, min(len(unselected), remaining_count)):
            selected_ids.add(drawn.entry_id)

    sampled_entries = [e for e in all_entries if e.entry_id in selected_ids]

    # 3. Run the model on every sampled structure
    val_data = []
    failures = []
    print(f">>> 开始模型推理，生成校准数据 (有效样本数: {len(sampled_entries)})...")

    for entry in tqdm(sampled_entries):
        try:
            # Pymatgen -> ASE
            atoms = adaptor.get_atoms(entry.structure)
            atoms.calc = calculator
            e_mlp = calculator.get_potential_energy(atoms)

            val_data.append({
                "composition": entry.composition,
                "e_mlp": float(e_mlp),
                "e_mp": float(entry.energy)
            })
        except Exception as exc:
            # Best effort: skip the sample, but remember it so the loss of
            # calibration data is visible instead of silent.
            failures.append((entry.entry_id, exc))

    if failures:
        first_id, first_exc = failures[0]
        print(
            f">>> 警告: {len(failures)} 个样本推理失败已跳过。"
            f"首个失败条目: {first_id}, 错误信息: {first_exc!r}"
        )
    if not val_data:
        raise ValueError("no sampled entry could be evaluated; cannot fit offsets")

    # 4. Least-squares problem A x = b (no intercept)
    sorted_elems = sorted(
        {str(el) for d in val_data for el in d["composition"].elements}
    )
    el_map = {el: i for i, el in enumerate(sorted_elems)}

    A, b = [], []
    for d in val_data:
        row = [0.0] * len(sorted_elems)
        for el, amt in d["composition"].items():
            row[el_map[str(el)]] = amt
        A.append(row)
        b.append(d["e_mlp"] - d["e_mp"])

    reg = LinearRegression(fit_intercept=False)
    reg.fit(A, b)
    # Cast to built-in floats so the mapping is JSON-serialisable as-is.
    offsets = {el: float(coef) for el, coef in zip(sorted_elems, reg.coef_)}

    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(offsets, f, indent=4)

    print(f">>> 校准完成！Offset 文件已保存至: {save_path}")
    print(f">>> 共覆盖元素数量: {len(offsets)}")

    # Optional hint when fewer than the 84 expected elements were covered.
    if len(offsets) < 84:
        print(f"提示：当前数据库及抽样仅覆盖了 {len(offsets)} 种元素，请确保这已包含你研究的所有体系。")

    return offsets

# JSON file that receives the fitted offsets (relative to the working directory).
OFFSET_PATH = "../htgp_to_mp_offsets.json"
# The return value is discarded here; the offsets are read back from the file below.
generate_offsets_for_htgp(local_database, lmy_model, device, OFFSET_PATH, num_samples=1500)
with open(OFFSET_PATH, "r") as f:
    model_offsets = json.load(f)
print("succuessfully generated offsets !")

>>> 正在进行元素分布分析与保底抽样...
>>> 开始模型推理，生成校准数据 (有效样本数: 2086)...


  warn("Using UFloat objects with std_dev==0 may give unexpected results.")
100%|██████████| 2086/2086 [03:31<00:00,  9.86it/s]

>>> 校准完成！Offset 文件已保存至: ../htgp_to_mp_offsets.json
>>> 共覆盖元素数量: 84
succuessfully generated offsets !





### 5.4 最终脚本

In [14]:
from collections import Counter

# Use the offsets read from the offsets file instead of the toy calibration.
HullCalculator.set_model_offsets(model_offsets)

new_comp = {"Li": 2, "O": 1}
new_energy = -12.1  # MLP-predicted total energy
ehull = HullCalculator.get_ehull(new_comp, new_energy)
print(f"Structure: {new_comp}")
# get_ehull may return None; formatting None with ':.4f' would raise TypeError.
if ehull is None:
    print("Energy Above Hull: unavailable (get_ehull returned None)")
else:
    print(f"Energy Above Hull: {ehull:.4f} eV/atom")

for atoms in atoms_list:
    # Deliberately not named `pos`: that would overwrite the global position
    # tensor created in section 2 and break re-running the earlier cells.
    atom_positions = atoms.get_positions()
    comp = dict(Counter(atoms.get_chemical_symbols()))  # e.g. {'Br': 1, 'P': 1, 'C': 1}

    # MLP energy calculation
    atoms.calc = lmy_calculator
    energy = lmy_calculator.get_potential_energy(atoms)
    force = lmy_calculator.get_forces(atoms)
    stress = lmy_calculator.get_stress(atoms)

    # Energy above hull from the element counts and the predicted total energy
    ehull = HullCalculator.get_ehull(comp, energy)
    if ehull is None:
        print("Energy Above Hull: unavailable (get_ehull returned None)")
    else:
        print(f"Energy Above Hull: {ehull:.4f} eV/atom")

>>> 手动设置模型偏差，共设置 84 种元素。
Structure: {'Li': 2, 'O': 1}
Energy Above Hull: 1.8633 eV/atom
Energy Above Hull: 210.2184 eV/atom
Energy Above Hull: 106.1852 eV/atom
