In [3]:
from ase.io import write,read
from ase.visualize import view
from src.models.loss.validation import _structural_validity
import torch
from pymatgen.io.ase import AseAtomsAdaptor

import json
import sys
import warnings
from pathlib import Path
from typing import Dict, Tuple
from multiprocessing import cpu_count, get_context
import argparse
import contextlib
import io

from ase import Atoms
from ase.io import read
from ase.optimize import LBFGS
from fairchem.core import pretrained_mlip, FAIRChemCalculator
from tqdm import tqdm
from ase.constraints import FixAtoms

# Silence every Python warning for the remainder of the kernel session.
warnings.filterwarnings('ignore')

# Per-element gas-phase reference energies computed with UMA (presumably eV, ASE's energy unit).
# They are derived from the reference molecules N2, H2O, H2 and CO:
#   E(H) = E(H2) / 2
#   E(N) = E(N2) / 2
#   E(O) = E(H2O) - E(H2)
#   E(C) = E(CO) - E(H2O) + E(H2)
# Key order (H, O, C, N) is kept because it appears in the "supported atoms" error message.
OC20_GAS_PHASE_ENERGIES = dict(
    H=-3.48483361833793,
    O=-7.185616160375758,
    C=-7.232295041080779,
    N=-8.09079187764214,
)

def get_adsorbate_energy_from_table(atoms_obj, energy_table=None):
    """Sum per-element gas-phase reference energies over the atoms of an adsorbate.

    Args:
        atoms_obj: Any object exposing ``get_chemical_symbols()`` (e.g. an ``ase.Atoms``).
        energy_table: Optional mapping from element symbol to reference energy.
            Defaults to ``OC20_GAS_PHASE_ENERGIES``.

    Returns:
        The sum of the reference energies of all atoms (``0.0`` when there are no atoms).

    Raises:
        ValueError: If an element of ``atoms_obj`` is missing from the table.
    """
    table = OC20_GAS_PHASE_ENERGIES if energy_table is None else energy_table

    total_energy = 0.0
    for atom_symbol in atoms_obj.get_chemical_symbols():
        try:
            total_energy += table[atom_symbol]
        except KeyError:
            # `from None`: the message already explains the problem; the chained
            # KeyError would only add a noisy "During handling of..." traceback.
            raise ValueError(
                f"Energy table does not contain '{atom_symbol}' atom. "
                f"Currently supported atoms are {list(table.keys())}."
            ) from None

    return total_energy


def get_uma_calculator(model_name: str = "uma-s-1p1", device: str = "cuda") -> FAIRChemCalculator:
    """Build an ASE calculator backed by a pretrained UMA model.

    Args:
        model_name: Checkpoint name forwarded to ``pretrained_mlip.get_predict_unit``.
        device: Device string the predictor is loaded on (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        A ``FAIRChemCalculator`` wrapping the predictor with ``task_name="oc20"``.
    """
    return FAIRChemCalculator(
        pretrained_mlip.get_predict_unit(model_name, device=device),
        task_name="oc20",
    )


def _relax_in_place(atoms: Atoms, fmax: float) -> Tuple[bool, int]:
    """Relax ``atoms`` in place with LBFGS down to ``fmax``; return (converged, number_of_steps)."""
    # Imported locally: a notebook cell doing `from ase import io` rebinds the global name `io`
    # to `ase.io`, which has no StringIO; that AttributeError would then be swallowed by the
    # caller's broad `except` and show up only as a silent 999.0 sentinel.
    from io import StringIO

    optimizer = LBFGS(atoms, logfile=None)
    # logfile=None already silences LBFGS itself; stdout is captured as well, presumably to
    # hide anything the calculator prints during force evaluations.
    with contextlib.redirect_stdout(StringIO()):
        converged = optimizer.run(fmax=fmax)
    return bool(converged), optimizer.get_number_of_steps()


def relaxation_and_compute_adsorption_energy(
    calc: FAIRChemCalculator,
    system: Atoms,
    slab: Atoms,
    adsorbate: Atoms,
    fmax: float = 0.05,
) -> Tuple[float, float, float, float, bool, bool]:
    """Relax the adsorbate+slab system and the bare slab, then compute the adsorption energy.

    ``system`` and ``slab`` are modified in place: ``calc`` is attached to both and their
    positions are relaxed with LBFGS.

        E_ads = E(system) - (E(slab) + E_ref(adsorbate))

    where E_ref(adsorbate) is the sum of per-element gas-phase references from
    ``get_adsorbate_energy_from_table``.

    Args:
        calc: Calculator attached to both ``system`` and ``slab``.
        system: Slab with adsorbate; relaxed in place.
        slab: Bare slab; relaxed in place.
        adsorbate: Adsorbate atoms; only their chemical symbols are used.
        fmax: LBFGS force-convergence threshold (default 0.05, as before).

    Returns:
        ``(e_ads, e_sys, e_slab, e_adsorbate, converged_system, converged_slab)``.
        If relaxation or energy evaluation raises, a warning is printed to stderr and
        ``(999.0, nan, nan, e_adsorbate, False, False)`` is returned (999.0 is a failure sentinel).

    Raises:
        ValueError: If ``adsorbate`` contains an element missing from the reference table.
    """
    # Look the reference energy up first. An unsupported element is a caller error: it should
    # surface immediately, not after two (possibly long) relaxations. It also should not be
    # re-raised from inside the best-effort handler below.
    e_adsorbate = get_adsorbate_energy_from_table(adsorbate)

    try:
        system.calc = calc
        slab.calc = calc

        converged_system, steps_system = _relax_in_place(system, fmax)
        converged_slab, steps_slab = _relax_in_place(slab, fmax)

        e_sys = system.get_potential_energy()
        e_slab = slab.get_potential_energy()
        e_ads = e_sys - (e_slab + e_adsorbate)

        # Report the actual outcome; previously "converged" was printed even on failure.
        for label, converged, steps in (
            ("System", converged_system, steps_system),
            ("Slab", converged_slab, steps_slab),
        ):
            status = "converged" if converged else "did not converge"
            print(f"{label} {status} in {steps} steps.")

        return e_ads, e_sys, e_slab, e_adsorbate, converged_system, converged_slab

    except Exception as e:
        # Deliberately best-effort (presumably so one bad sample does not abort a batch run):
        # report the failure and return a sentinel tuple instead of raising.
        print(f"WARNING: UMA calculation failed for a sample. Error: {e}", file=sys.stderr)
        return 999.0, float('nan'), float('nan'), e_adsorbate, False, False

adaptor = AseAtomsAdaptor()

# Sample index; referenced by the per-sample `details` path noted below.
idx = 3
struct_path = "/home/jovyan/MinCatFlow/unrelaxed_samples/sp_all/24685.traj"
# Per-sample generation details:
# details = torch.load(f"/home/jovyan/MinCatFlow/unrelaxed_samples/de_novo_generation/C2H2O/mask_test/dist_invalid/{idx}.pt")
gen_system = read(struct_path)

# Translate the atoms so they are centred in the unit cell (no vacuum is added).
gen_system.center()

# Optional steps, off by default. They were previously commented-out code; flags keep the
# code syntactically valid and make it explicit what the displayed structure went through.
FIX_TAG0_ATOMS = False  # freeze atoms with tag 0 (presumably sub-surface, OC20 tagging) — confirm
RUN_RELAXATION = False  # slow; get_uma_calculator() defaults to a CUDA device

if FIX_TAG0_ATOMS:
    gen_system.set_constraint(FixAtoms(indices=[atom.index for atom in gen_system if atom.tag == 0]))

if RUN_RELAXATION:
    calc = get_uma_calculator()
    slab = gen_system.copy()[gen_system.get_tags() != 2]
    adsorbate = gen_system.copy()[gen_system.get_tags() == 2]
    # Note: this relaxes `gen_system` in place, so the view below then shows the relaxed structure.
    e_ads, e_sys, e_slab, e_adsorbate, converged_system, converged_slab = relaxation_and_compute_adsorption_energy(
        calc,
        gen_system,
        slab,
        adsorbate,
    )
    print("E_ads:", e_ads)

view(gen_system, viewer='ngl')

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'H', 'Se', 'Ga'), valu…

In [None]:
import os

# Convert the last frame of every `.traj` file in `base_traj_path` to a CIF file in
# `<base_traj_path>/cif_output`.
# Uses `read`/`write` from `ase.io` (imported in the first cell) instead of `from ase import io`:
# rebinding the global name `io` to `ase.io` would shadow the stdlib `io` module imported in
# the first cell.
base_traj_path = "/home/jovyan/MinCatFlow/unrelaxed_samples/L1_test/O/"
output_dir = os.path.join(base_traj_path, "cif_output")
os.makedirs(output_dir, exist_ok=True)

# Sorted so the conversion order (and the log below) is deterministic.
for filename in sorted(os.listdir(base_traj_path)):
    if not filename.endswith(".traj"):
        continue

    full_input_path = os.path.join(base_traj_path, filename)
    # splitext swaps only the trailing extension; str.replace would rewrite every ".traj" in the name.
    output_filename = os.path.splitext(filename)[0] + ".cif"
    full_output_path = os.path.join(output_dir, output_filename)

    try:
        # index=-1 reads only the last frame of the trajectory. This is also the default,
        # but stating it makes the intent explicit.
        atoms = read(full_input_path, index=-1)
        write(full_output_path, atoms)
        print(f"Converted: {filename} -> {output_filename}")
    except Exception as e:
        # Best-effort batch conversion: report the failing file and continue with the rest.
        print(f"Error converting {filename}: {e}")

# Prints "All conversions are complete."
print("모든 변환 작업이 완료되었습니다.")

In [None]:
# Sanity-check one generated sample's tensors. The three primitive-slab atom counts (from
# coordinates, atom types and the atom mask) presumably should agree; adsorbate tensors are shown raw.
# `details` is loaded here explicitly so the cell runs on a fresh kernel instead of relying on a
# leftover variable; `idx` comes from the first cell.
DETAILS_PATH = f"/home/jovyan/MinCatFlow/unrelaxed_samples/de_novo_generation/C2H2O/mask_test/dist_invalid/{idx}.pt"
details = torch.load(DETAILS_PATH, map_location="cpu")

# Count a coordinate row as an atom if any component is non-zero. Counting non-zero scalars and
# dividing by 3 miscounts (and can give fractions) whenever a real atom has a coordinate of exactly 0.
# Assumes padding rows are all-zero and the last dim holds the 3 coordinates — TODO confirm.
n_slab_atoms_from_coords = (details['prim_slab_coords'] != 0).any(dim=-1).sum()
print("prim_slab_coords :", n_slab_atoms_from_coords)
print("prim_slab_atom_types", details['prim_slab_atom_types'].count_nonzero())
print("prim_slab_atom_mask", details['prim_slab_atom_mask'].sum())
print("ads_coords", details['ads_coords'])
print("ads_atom_mask", details['ads_atom_mask'])