In [None]:
# View the structures in `checkpoints/modal_results/single_modal.xyz` using the `plot_crystal_with_points` function from pete_data_gen.py


In [None]:
import importlib.util
import os
import sys
from pathlib import Path

import torch
import numpy as np
from ase.io import read

from omg.datamodule.dataloader import OMGData

# Repository root. Defaults to the original author's checkout; set the
# OMATG_PROJECT_ROOT environment variable to run this notebook elsewhere
# without editing the path.
PROJECT_ROOT = Path(
    os.environ.get(
        "OMATG_PROJECT_ROOT",
        "/Users/kosta/Documents/Research/martiniani/omatg_TomEgg",
    )
)
DATA_GEN_DIR = PROJECT_ROOT / "code" / "data-gen"
PETE_DATA_GEN_PATH = DATA_GEN_DIR / "pete_data_gen.py"

# pete_data_gen.py is loaded straight from its file path (it is not imported as
# a package here). This follows the importlib "import a source file directly"
# recipe: the module is registered in sys.modules before it is executed.
if not PETE_DATA_GEN_PATH.is_file():
    raise FileNotFoundError(
        f"pete_data_gen.py not found at {PETE_DATA_GEN_PATH}; "
        "set OMATG_PROJECT_ROOT to the repository root."
    )
spec = importlib.util.spec_from_file_location("pete_data_gen", PETE_DATA_GEN_PATH)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when it cannot pick a loader.
    raise ImportError(f"Cannot create an import spec for {PETE_DATA_GEN_PATH}")
pete_data_gen = importlib.util.module_from_spec(spec)
sys.modules["pete_data_gen"] = pete_data_gen
try:
    spec.loader.exec_module(pete_data_gen)
except BaseException:
    # Do not leave a half-initialised module in sys.modules for later imports.
    sys.modules.pop("pete_data_gen", None)
    raise
plot_crystal_with_points = pete_data_gen.plot_crystal_with_points


def ase_atoms_to_omg_data(atoms):
    """Convert one ASE ``Atoms`` structure into a single-graph ``OMGData``.

    Ghost atoms are identified from:

    * a per-atom ``"is_ghost"`` array (truthy entries are ghosts), and/or
    * a ``"ghost_atomic_number"`` array: every atom whose atomic number equals
      any value appearing in that array is a ghost;
    * only if neither source flags anything, atoms with atomic number ``<= 0``
      (e.g. ASE's dummy element ``X``).

    Ghost atoms get species ``-1``; every other atom keeps its atomic number.

    Args:
        atoms: an ``ase.Atoms`` instance.

    Returns:
        ``OMGData`` with
        ``pos`` (float64, one row per atom),
        ``species`` (int64, one entry per atom),
        ``cell`` (float64, the 3x3 cell with a leading batch axis of size 1),
        ``n_atoms`` (0-dim int64) and
        ``batch`` (int64 zeros, one per atom).
        All tensors are fresh copies, so mutating the result never alters ``atoms``.
    """
    species_np = atoms.numbers.astype("int64")
    ghost_mask = np.zeros_like(species_np, dtype=bool)

    if "is_ghost" in atoms.arrays:
        ghost_mask |= atoms.get_array("is_ghost").astype(bool)

    if "ghost_atomic_number" in atoms.arrays:
        ghost_values = np.unique(atoms.get_array("ghost_atomic_number")).astype("int64")
        ghost_mask |= np.isin(species_np, ghost_values)

    if not ghost_mask.any():
        # Fallback when no explicit ghost annotation marked anything.
        ghost_mask |= species_np <= 0

    species = torch.from_numpy(np.where(ghost_mask, -1, species_np).astype(np.int64))

    data = OMGData()
    # torch.tensor always copies. torch.from_numpy(...).to(float64) would alias
    # the live ASE arrays (``.to`` is a no-op for float64 input), so in-place
    # edits to data.pos / data.cell would silently corrupt the source Atoms.
    data.pos = torch.tensor(atoms.positions, dtype=torch.float64)
    data.species = species
    data.cell = torch.tensor(atoms.cell.array, dtype=torch.float64).unsqueeze(0)
    data.n_atoms = torch.tensor(len(atoms), dtype=torch.long)
    data.batch = torch.zeros(len(atoms), dtype=torch.long)
    return data


# Read every frame of the single-modal output file and plot one chosen frame.
XYZ_PATH = PROJECT_ROOT.joinpath("checkpoints", "modal_results", "single_modal.xyz")
atoms_list = read(XYZ_PATH, index=":")  # ":" -> all frames, returned as a list
print(f"Loaded {len(atoms_list)} structures from {XYZ_PATH}")

sample_idx = 0  # change this to inspect a different structure
sample_data = ase_atoms_to_omg_data(atoms_list[sample_idx])
plot_crystal_with_points(
    sample_data,
    sample_idx,
    title=f"single_modal idx {sample_idx}",
)


Loaded 44 structures from /Users/kosta/Documents/Research/martiniani/omatg_TomEgg/checkpoints/modal_results/single_modal.xyz


In [4]:
# Alternative input file (swap with the active XYZ_PATH line to use it):
# XYZ_PATH = PROJECT_ROOT / "checkpoints" / "modal_results" / "20251122-122521" / "generated_modal.xyz"
XYZ_PATH = PROJECT_ROOT / "checkpoints" / "inference_results" / "20251122-144212" / "generated_modal.xyz"
N_PREVIEW = 5  # number of leading structures to plot; adjust as needed

# NOTE: this reads the whole file (44888 frames in the recorded run) even though
# only N_PREVIEW frames are plotted; pass index=f":{N_PREVIEW}" instead if the
# full atoms_list is not needed by later cells.
atoms_list = read(XYZ_PATH, index=":")   # load every structure
print(f"Loaded {len(atoms_list)} structures from {XYZ_PATH}")

for sample_idx, atoms in enumerate(atoms_list[:N_PREVIEW]):
    sample_data = ase_atoms_to_omg_data(atoms)
    plot_crystal_with_points(sample_data, sample_idx, title=f"generated_modal idx {sample_idx}")

Loaded 44888 structures from /Users/kosta/Documents/Research/martiniani/omatg_TomEgg/checkpoints/inference_results/20251122-144212/generated_modal.xyz
