In [1]:
lmdb_path ="data/is2re/all/train/data.lmdb"
mapping_path = "oc20_data_mapping.pkl"
index = 1
min_ab = 8.0
min_slab_size = 7.0
min_vacuum_size = 20.0

In [2]:
import lmdb
import pickle

db = lmdb.open(
        str(lmdb_path),
        subdir=False,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
    )

with db.begin() as txn:
    # Get data for index
    key = f"{index}".encode("ascii")
    value = txn.get(key)
    
    if value is None:
        raise ValueError(f"Index {index} not found in LMDB")
    
    # Deserialize with pickle
    data = pickle.loads(value)    

In [3]:
from trash_codes import extract_true_system_from_lmdb
true_system_atoms = extract_true_system_from_lmdb(lmdb_path, index)
slab_mask = (true_system_atoms.get_tags() == 0) | (true_system_atoms.get_tags() == 1)
true_slab_atoms = true_system_atoms[slab_mask].copy()


In [4]:
from trash_codes import get_sid_from_lmdb, get_slab_params_from_mapping
sid = get_sid_from_lmdb(lmdb_path, index)
slab_params = get_slab_params_from_mapping(mapping_path, sid)
bulk_src_id = slab_params['bulk_mpid']
specific_miller = slab_params['miller_index']
shift = slab_params['shift']
top = slab_params['top']

In [5]:
from fairchem.data.oc.core import Bulk
bulk = Bulk(bulk_src_id_from_db=bulk_src_id)

  import pkg_resources


In [6]:
from pymatgen.core.surface import SlabGenerator
from fairchem.data.oc.core.slab import standardize_bulk
import math

initial_structure = standardize_bulk(bulk.atoms)
slab_gen = SlabGenerator(
        initial_structure=initial_structure,
        miller_index=specific_miller,
        min_slab_size=7.0,
        min_vacuum_size=20.0,
        lll_reduce=False,
        center_slab=False,
        primitive=True,
        max_normal_search=1,
    )

height = slab_gen._proj_height
height_per_layer = round(height / slab_gen.parent.lattice.d_hkl(slab_gen.miller_index), 8)

if slab_gen.in_unit_planes:
    n_layers_slab = math.ceil(slab_gen.min_slab_size / height_per_layer)
    n_layers_vac = math.ceil(slab_gen.min_vac_size / height_per_layer)
else:
    n_layers_slab = math.ceil(slab_gen.min_slab_size / height)
    n_layers_vac = math.ceil(slab_gen.min_vac_size / height)

n_layers = n_layers_slab + n_layers_vac

In [7]:
print(n_layers)
print(n_layers_slab)
print(n_layers_vac)
print(height)


8
2
6
3.646657815772546


In [8]:
from ase.geometry import get_duplicate_atoms

slab_atoms = true_slab_atoms.copy()

new_cell = slab_atoms.get_cell()
new_cell[2] = new_cell[2] * n_layers_slab / n_layers

slab_atoms.set_cell(new_cell)

slab_atoms_wrap = slab_atoms.copy()
slab_atoms_center = slab_atoms.copy()
slab_atoms_wrap.wrap()
slab_atoms_center.center()

from helpers import calculate_rmsd_pymatgen
rmsd = calculate_rmsd_pymatgen(
        struct1=slab_atoms_wrap,
        struct2=slab_atoms_center,
        ltol=0.2, stol=0.3, angle_tol=5,
        primitive_cell = False,
        )
print(rmsd)

# slab_atoms.wrap()
get_duplicate_atoms(slab_atoms, cutoff=0.05, delete=True)


(np.float64(1.2996750573077628e-15), np.float64(4.017145845252586e-15))


array([], shape=(0, 2), dtype=int64)

In [9]:
from ase.visualize import view
# slab_atoms.center()
# view(true_slab_atoms, viewer='ngl')
view(slab_atoms, viewer='ngl')



HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'Ta', 'S'), value='All…

In [9]:
import numpy as np
from pymatgen.core import Structure, Lattice
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.ase import AseAtomsAdaptor

# 1. Setup: Load Supercell
adaptor = AseAtomsAdaptor()
supercell = adaptor.get_structure(slab_atoms)

In [10]:
# 2. Get the transformation matrix M (Standard Prim -> Supercell)
# We still need this to know "how much" to shrink the lattice.
sga = SpacegroupAnalyzer(supercell)
prim_std = sga.get_primitive_standard_structure()

view(adaptor.get_atoms(prim_std), viewer='ngl')

NameError: name 'view' is not defined

In [11]:
dataset = sga.get_symmetry_dataset()

In [12]:
M = dataset["transformation_matrix"] # Scaling matrix
origin_shift = dataset["origin_shift"]
# 3. Define the Aligned Primitive Lattice
# Shrink the supercell lattice by applying the inverse of M
# L_prim = L_super * inv(M)
M_inv = np.linalg.inv(M)

# Note: We apply the dot product to the lattice matrix
# (Pymatgen Lattice rows are vectors, so we multiply by M_inv)
# The math here: new_matrix = old_matrix @ M_inv
prim_matrix = np.dot(supercell.lattice.matrix.T, M_inv).T
aligned_lattice = Lattice(prim_matrix)

# 4. "Wrap": Create a structure with ALL atoms in the small lattice
# We dump every single atom from the supercell into this tiny box.
# Because the box is small, the fractional coordinates will be > 1.0
wrapped_struct = Structure(
    aligned_lattice,
    supercell.species,
    supercell.cart_coords,
    coords_are_cartesian=True
)

print(f"Sites before deduplication: {wrapped_struct.num_sites}")

# 5. "Deduplicate": Merge overlapping sites
# merge_sites handles the periodic boundaries. If an atom at 0.1 and 
# an atom at 1.1 (wrapped) overlap, it removes one.
wrapped_struct.merge_sites(tol=0.01, mode='delete')

print(f"Sites after deduplication: {wrapped_struct.num_sites}")

# Result is your perfectly aligned primitive cell
aligned_primitive = wrapped_struct
print("\nAligned Primitive Structure:")
print(aligned_primitive)

Sites before deduplication: 80
Sites after deduplication: 5

Aligned Primitive Structure:
Full Formula (Zr2 Se1 N2)
Reduced Formula: Zr2SeN2
abc   :   3.680750   3.680750   6.713781
angles:  90.000000  90.000000 119.999999
pbc   :       True       True       True
Sites (5)
  #  SP           a         b        c
---  ----  --------  --------  -------
  0  Se    -1.975    -3.975    1
  1  Zr    -2.30833  -4.64167  1.30435
  2  N     -2.64167  -4.30833  1.3682
  3  N     -2.30833  -4.64167  1.6318
  4  Zr    -2.64167  -4.30833  1.69565


  M = dataset["transformation_matrix"] # Scaling matrix


In [14]:
print(M)

[[ 3.00000000e+00  2.85764329e-16 -2.64942038e-16]
 [ 0.00000000e+00  2.00000000e+00  0.00000000e+00]
 [ 3.88578059e-16  4.16333634e-16  1.00000000e+00]]


In [15]:
view(adaptor.get_atoms(aligned_primitive), viewer='ngl')

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'Ta', 'S'), value='All…

In [16]:
print(supercell.lattice.abc)
print(supercell.lattice.angles)

(10.073967710071361, 14.238894448016747, 8.694892501831054)
(83.74094666579693, 86.67897510122955, 76.35931950510978)


In [17]:
M_int = np.round(M).astype(int)

recon_slab = aligned_primitive.copy()
recon_slab.make_supercell(M_int)

print(len(supercell))
print(len(recon_slab))

54
54


In [13]:
from pymatgen.analysis.structure_matcher import StructureMatcher

matcher = StructureMatcher(primitive_cell=False, ltol=0.2, stol=0.3, angle_tol=5)

# 표준화 없이 직접 primitive 추출
prim_direct = supercell.get_primitive_structure(tolerance=0.1)
print(f"Direct primitive atoms: {len(prim_direct)}")
print(f"Direct primitive composition: {prim_direct.composition}")

# Supercell matrix 계산
S_direct = np.dot(supercell.lattice.matrix, np.linalg.inv(prim_direct.lattice.matrix))
S_direct_int = np.round(S_direct).astype(int)
print(f"\nDirect S matrix:\n{S_direct_int}")
print(f"det(S_direct) = {np.linalg.det(S_direct_int)}")

# 재구성 테스트
recon_direct = prim_std.copy()
recon_direct.make_supercell(S_direct_int)
print(f"\nReconstructed atoms: {len(recon_direct)}")
print(f"Original supercell atoms: {len(supercell)}")

# 일치 확인
is_match = matcher.fit(recon_direct, supercell)
print(f"Does reconstructed match supercell? {is_match}")

Direct primitive atoms: 5
Direct primitive composition: N2 Se1 Zr2

Direct S matrix:
[[-2  2  0]
 [ 0  1 -2]
 [-2 -2  0]]
det(S_direct) = 15.999999999999998

Reconstructed atoms: 80
Original supercell atoms: 80
Does reconstructed match supercell? True


In [19]:
# view(adaptor.get_atoms(recon_slab), viewer='ngl')
view(adaptor.get_atoms(prim_direct), viewer='ngl')

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'Ta', 'S'), value='All…

In [20]:
from helpers import calculate_rmsd_pymatgen
rmsd = calculate_rmsd_pymatgen(
        struct1=recon_direct,
        struct2=supercell,
        ltol=0.2, stol=0.3, angle_tol=5,
        primitive_cell=False,
        )
print(rmsd)

(np.float64(1.7957223534309045e-07), np.float64(3.465470150891716e-07))


In [21]:
print(prim_std.lattice.abc)
print(prim_std.lattice.angles)
print(prim_direct.lattice.abc)
print(prim_direct.lattice.angles)


(3.3579892366904533, 7.119447224008374, 8.694892501831054)
(83.74094666579693, 86.67897510122955, 76.3593195051098)
(3.357989236690454, 7.119447224008374, 8.694892501831054)
(83.74094666579693, 86.67897510122955, 76.35931950510978)
