In [13]:
import numpy as np

def build_lattice_matrix(a, b, c, alpha, beta, gamma):
    # 角度转弧度
    alpha = np.radians(alpha)
    beta = np.radians(beta)
    gamma = np.radians(gamma)

    # 三角函数
    cos_alpha = np.cos(alpha)
    cos_beta = np.cos(beta)
    cos_gamma = np.cos(gamma)
    sin_gamma = np.sin(gamma)

    # 基矢构造
    a1 = [a, 0, 0]
    a2 = [b * cos_gamma, b * sin_gamma, 0]

    cx = c * cos_beta
    cy = c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    cz = c * np.sqrt(
        1 - cos_alpha**2 - cos_beta**2 - cos_gamma**2 +
        2 * cos_alpha * cos_beta * cos_gamma
    ) / sin_gamma

    a3 = [cx, cy, cz]

    return np.array([a1, a2, a3])

def frac_to_cart(frac_coords, lattice):
    # lattice: 3x3 matrix, row-wise
    return np.dot(frac_coords, lattice)

def wrap_structure(s):
    if hasattr(s, "get_wrapped_structure"):
        return s.get_wrapped_structure()
    elif hasattr(s, "wrap_sites"):
        s = s.copy()
        s.wrap_sites(in_place=True)
        return s
    else:
        raise AttributeError("Your pymatgen version is too old; please upgrade.")


dir = "../sampled_50_batchsize/"

from pymatgen.core import Lattice, Structure
import os
# files = os.listdir(dir)
# # 导入files中的每一个结构文件
# for file in files[0:1]:
#     if file.endswith(".cif"):
#         filepath = os.path.join(dir, file)
#         structure = Structure.from_file(filepath)
#         print(f"Loaded structure from {file}:")
#         print(structure)

#         lattice_matrix = structure.lattice.matrix
#         print("Lattice matrix:")
#         print(lattice_matrix)
#         a = structure.lattice.a
#         b = structure.lattice.b
#         c = structure.lattice.c
#         alpha = structure.lattice.alpha
#         beta = structure.lattice.beta
#         gamma = structure.lattice.gamma
#         lattice_another = build_lattice_matrix(a, b, c, alpha, beta, gamma)
#         print("Reconstructed lattice matrix:")
#         print(lattice_another)

#         fractional_coords = structure.frac_coords
#         print("Fractional coordinates:")
#         print(fractional_coords)
#         # cartesian_coords = frac_to_cart(fractional_coords, lattice_matrix)
#         # print("Converted Cartesian coordinates:")
#         # print(cartesian_coords)
        


# 重新定义一个结构
lattice = Lattice.from_parameters(a=3, b=3, c=3, alpha=90, beta=90, gamma=90)
coords = [[100, 0, 0]]  # Cartesian
species = ["Si"]

# 创建结构对象时指定坐标类型
s = Structure(lattice, species, coords, coords_are_cartesian=True)
print("Original Cartesian coordinates:")
print(s.cart_coords)

# # 自动将原子“包裹”进主晶胞：
# s = wrap_structure(s)
fractional_coords = s.frac_coords % 1.0  # ensure fractional coordinates are within [0, 1)
print("Fractional coordinates:")
print(fractional_coords)


Original Cartesian coordinates:
[[100.   0.   0.]]
Fractional coordinates:
[[0.33333333 0.         1.        ]]


In [5]:
import numpy as np
from pymatgen.core import Lattice, Structure, Element

# ===== 你的输入 =====
# 举例：
a, b, c = 5.0, 6.0, 7.0
alpha, beta, gamma = 90, 100, 120

# N 个原子的笛卡尔坐标（单位：Å，可能超出晶胞）
cart_coords = np.array([
    [1.0, 2.0, 3.0],
    [8.0, 0.0, 0.0],   # 超出晶胞
])

# 原子种类 one-hot (假设有 3 种元素: H, C, O)
atom_onehot = np.array([
    [1, 0, 0],  # H
    [0, 1, 0],  # C
])

# 对应的原子种类列表（顺序必须与 one-hot 的列对应）
atom_classes = ["H", "C", "O"]

# ===== 1️⃣ 解析原子种类 =====
atom_indices = np.argmax(atom_onehot, axis=1)
species = [atom_classes[i] for i in atom_indices]

# ===== 2️⃣ 构造晶格 =====
lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma)

# ===== 3️⃣ 构造结构（坐标为笛卡尔坐标）=====
structure = Structure(lattice, species, cart_coords, coords_are_cartesian=True)

# ===== 5️⃣ 导出 CIF 文件 =====
# structure.to(fmt="cif", filename="generated_structure.cif")

# print("✅ CIF 文件已保存：generated_structure.cif")
print(structure)


Full Formula (H1 C1)
Reduced Formula: HC
abc   :   5.000000   6.000000   7.000000
angles:  90.000000 100.000000 120.000000
pbc   :       True       True       True
Sites (2)
  #  SP           a        b         c
---  ----  --------  -------  --------
  0  H     0.442449  0.38691  0.48345
  1  C     1.62468   0        0.201517


In [None]:
import torch
current_space_group_ops = [
                    {
                        "R": -torch.eye(3),
                        "t": torch.zeros(3),
                    },
                    {
                        "R": torch.eye(3),
                        "t": torch.zeros(3),
                    }
                ]
print(len(current_space_group_ops))

2
