In [4]:
# %% Notebook cell: count atoms from a .inp file

from collections import Counter
from typing import Iterable, Tuple, Dict, List, Union
import pandas as pd

# Periodic table: index == atomic number
_PT = [
    None,
    "H","He","Li","Be","B","C","N","O","F","Ne",
    "Na","Mg","Al","Si","P","S","Cl","Ar",
    "K","Ca","Sc","Ti","V","Cr","Mn","Fe","Co","Ni","Cu","Zn",
    "Ga","Ge","As","Se","Br","Kr",
    "Rb","Sr","Y","Zr","Nb","Mo","Tc","Ru","Rh","Pd","Ag","Cd",
    "In","Sn","Sb","Te","I","Xe",
    "Cs","Ba","La","Ce","Pr","Nd","Pm","Sm","Eu","Gd","Tb","Dy","Ho","Er","Tm","Yb","Lu",
    "Hf","Ta","W","Re","Os","Ir","Pt","Au","Hg",
    "Tl","Pb","Bi","Po","At","Rn",
    "Fr","Ra","Ac","Th","Pa","U","Np","Pu","Am","Cm","Bk","Cf","Es","Fm","Md","No","Lr",
    "Rf","Db","Sg","Bh","Hs","Mt","Ds","Rg","Cn","Nh","Fl","Mc","Lv","Ts","Og"
]

def _z_to_symbol(z: int) -> str:
    return _PT[z] if 0 < z < len(_PT) else f"Z{z}"

def _parse_species(token: str) -> str:
    """Normalize species token to a symbol; accepts atomic number or symbol."""
    t = token.strip()
    try:
        return _z_to_symbol(int(t))
    except ValueError:
        return t

def _is_comment_or_blank(line: str, comment_chars: str = "#;!") -> bool:
    s = line.strip()
    return (not s) or any(s.startswith(c) for c in comment_chars)

def count_atoms_from_lines(
    lines: Iterable[str],
    species_col: int = 0,
    comment_chars: str = "#;!"
) -> Tuple[int, Counter]:
    """
    Tally species over an iterable of text lines.

    Every line that is neither blank nor a comment, and that has a
    whitespace-separated column at index ``species_col``, counts as one atom.

    - species_col: 0-based index of the column holding the species
    - comment_chars: a line whose first non-blank character is one of these is ignored

    Returns ``(total_atoms, per_species_counter)``.
    """
    def _species_labels():
        for raw in lines:
            if _is_comment_or_blank(raw, comment_chars):
                continue
            fields = raw.split()
            # Rows too short to contain the species column are skipped silently.
            if fields and len(fields) > species_col:
                yield _parse_species(fields[species_col])

    tally = Counter(_species_labels())
    return sum(tally.values()), tally

def count_atoms_from_file(
    path: str,
    species_col: int = 0,
    comment_chars: str = "#;!"
) -> Tuple[int, Counter]:
    """Open ``path`` in text mode and tally its species column line by line.

    Thin wrapper around ``count_atoms_from_lines``; both options are passed
    through unchanged and the same ``(total, counts)`` pair is returned.
    """
    with open(path, "r") as handle:
        result = count_atoms_from_lines(
            handle,
            species_col=species_col,
            comment_chars=comment_chars,
        )
    return result

def counts_to_dataframe(counts: Counter) -> pd.DataFrame:
    """Tabulate per-species counts.

    Rows are ordered by descending count, ties broken by species label.
    Columns: ``Species``, ``Count`` and ``Fraction`` (count / sum of counts).
    """
    def _rank(item):
        label, n = item
        return (-n, label)

    rows = sorted(counts.items(), key=_rank)
    table = pd.DataFrame(rows, columns=["Species", "Count"])
    grand_total = table["Count"].sum()
    table["Fraction"] = table["Count"].div(grand_total)
    return table

def summarize_counts(
    counts: Counter,
    title: str = "Per-species counts"
) -> pd.DataFrame:
    """
    Build the per-species table and add a running 'Cumulative' fraction column.

    The table comes from ``counts_to_dataframe``. Two entries are stored in
    ``DataFrame.attrs``: ``title`` (the given title) and ``total_atoms``
    (the sum of the ``Count`` column). Only the DataFrame is returned; read
    the total from ``attrs["total_atoms"]`` when it is needed.
    """
    summary = counts_to_dataframe(counts)
    summary["Cumulative"] = summary["Fraction"].cumsum()
    summary.attrs.update(title=title, total_atoms=summary["Count"].sum())
    return summary
# --------------------------
# Example usage (uncomment and set your file path):
# path = "your_file.inp"
# total, counts = count_atoms_from_file(path, species_col=0)  # change species_col if needed
# print(f"Total atoms: {total}")
# display(summarize_counts(counts))
# --------------------------


In [16]:
# %% Formula unit from a .inp file (extension of the count code)

from collections import Counter
from typing import Iterable, Tuple, Dict, List, Union, Optional
import math
import pandas as pd

# ---------- Periodic Table + parser (same as before) ----------
_PT = [
    None,
    "H","He","Li","Be","B","C","N","O","F","Ne",
    "Na","Mg","Al","Si","P","S","Cl","Ar",
    "K","Ca","Sc","Ti","V","Cr","Mn","Fe","Co","Ni","Cu","Zn",
    "Ga","Ge","As","Se","Br","Kr",
    "Rb","Sr","Y","Zr","Nb","Mo","Tc","Ru","Rh","Pd","Ag","Cd",
    "In","Sn","Sb","Te","I","Xe",
    "Cs","Ba","La","Ce","Pr","Nd","Pm","Sm","Eu","Gd","Tb","Dy","Ho","Er","Tm","Yb","Lu",
    "Hf","Ta","W","Re","Os","Ir","Pt","Au","Hg",
    "Tl","Pb","Bi","Po","At","Rn",
    "Fr","Ra","Ac","Th","Pa","U","Np","Pu","Am","Cm","Bk","Cf","Es","Fm","Md","No","Lr",
    "Rf","Db","Sg","Bh","Hs","Mt","Ds","Rg","Cn","Nh","Fl","Mc","Lv","Ts","Og"
]

def _z_to_symbol(z: int) -> str:
    return _PT[z] if 0 < z < len(_PT) else f"Z{z}"

def _parse_species(token: str) -> str:
    t = token.strip()
    try:
        return _z_to_symbol(int(t))
    except ValueError:
        return t

def _is_comment_or_blank(line: str, comment_chars: str = "#;!") -> bool:
    s = line.strip()
    return (not s) or any(s.startswith(c) for c in comment_chars)

# ---------- Counting helpers (as before) ----------
def count_atoms_from_lines(
    lines: Iterable[str],
    species_col: int = 0,
    comment_chars: str = "#;!"
) -> Tuple[int, Counter]:
    """
    Tally species over an iterable of text lines.

    Blank and comment lines are ignored, as are rows with no column at
    index ``species_col``. Returns ``(total_atoms, per_species_counter)``.
    """
    def _species_labels():
        for raw in lines:
            if _is_comment_or_blank(raw, comment_chars):
                continue
            fields = raw.split()
            # Rows too short to contain the species column are skipped silently.
            if fields and len(fields) > species_col:
                yield _parse_species(fields[species_col])

    tally = Counter(_species_labels())
    return sum(tally.values()), tally

def count_atoms_from_file(
    path: str,
    species_col: int = 0,
    comment_chars: str = "#;!"
) -> Tuple[int, Counter]:
    """Read ``path`` as text and return ``count_atoms_from_lines`` applied to its lines."""
    with open(path, "r") as handle:
        result = count_atoms_from_lines(
            handle,
            species_col=species_col,
            comment_chars=comment_chars,
        )
    return result

def counts_to_dataframe(counts: Counter) -> pd.DataFrame:
    """Tabulate counts as Species / Count / Fraction, largest count first (ties by label)."""
    def _rank(item):
        label, n = item
        return (-n, label)

    rows = sorted(counts.items(), key=_rank)
    table = pd.DataFrame(rows, columns=["Species", "Count"])
    grand_total = table["Count"].sum()
    table["Fraction"] = table["Count"].div(grand_total)
    return table

def summarize_counts(counts: Counter, title: str = "Per-species counts") -> pd.DataFrame:
    """
    Build the per-species table and add a running 'Cumulative' fraction column.

    The table comes from ``counts_to_dataframe``. ``DataFrame.attrs`` receives
    ``title`` and ``total_atoms`` (sum of the ``Count`` column).

    This definition replaces the same-named function from the earlier
    counting cell once this cell runs, so it sets the same ``total_atoms``
    attribute as that version; code reading ``attrs["total_atoms"]`` keeps
    working regardless of which cell ran last.
    """
    df = counts_to_dataframe(counts)
    df["Cumulative"] = df["Fraction"].cumsum()
    df.attrs["title"] = title
    df.attrs["total_atoms"] = df["Count"].sum()
    return df

# ---------- NEW: formula unit computation ----------
def _gcd_list(ints: List[int]) -> int:
    g = 0
    for x in ints:
        g = math.gcd(g, int(x))
    return max(g, 1)

def _order_species(species: List[str], hill: bool = True, custom_order: Optional[List[str]] = None) -> List[str]:
    """
    Ordering for pretty formula strings:
      - If custom_order given, use that order first; remaining species follow after alphabetically.
      - Else if hill=True: C, H first (if present), then others alphabetical.
      - Else alphabetical.
    """
    sp_set = set(species)
    if custom_order:
        used = [s for s in custom_order if s in sp_set]
        rest = sorted([s for s in species if s not in used])
        return used + rest
    if hill:
        head = [s for s in ["C", "H"] if s in sp_set]
        tail = sorted([s for s in species if s not in {"C","H"}])
        return head + tail
    return sorted(species)

def empirical_formula_from_counts(
    counts: Counter,
    hill: bool = True,
    custom_order: Optional[List[str]] = None
) -> Tuple[Dict[str, int], str]:
    """
    Divide all per-species counts by their greatest common divisor.

    Returns ``(formula_unit_counts, formula_string)``. The string concatenates
    ``<species><n>`` in the order chosen by ``_order_species``; a count of 1
    is written without the number. Empty input gives ``({}, "")``.
    """
    if not counts:
        return {}, ""

    # Shared divisor of all counts
    divisor = _gcd_list([int(n) for n in counts.values()])
    fu = {sp: n // divisor for sp, n in counts.items()}

    # Assemble the text form species by species
    ordered = _order_species(list(fu.keys()), hill=hill, custom_order=custom_order)
    pieces = []
    for sp in ordered:
        n = fu[sp]
        pieces.append(f"{sp}" if n == 1 else f"{sp}{int(n)}")

    return fu, "".join(pieces)

def number_of_formula_units(total_counts: Counter, fu_counts: Dict[str, int]) -> float:
    """
    Estimate how many copies of the formula unit make up ``total_counts``.

    For every species with a non-zero entry in ``fu_counts`` the ratio
    total / per-formula-unit is formed. If all ratios agree to within 1e-8
    that common value is returned; otherwise the mean of the ratios is
    returned as a rough scale. With no usable species the result is 0.0.
    """
    per_species = [
        total_counts[sp] / per_fu
        for sp, per_fu in fu_counts.items()
        if per_fu != 0  # zero entries cannot be divided by and are left out
    ]
    if len(per_species) == 0:
        return 0.0

    reference = per_species[0]
    disagrees = any(abs(value - reference) > 1e-8 for value in per_species[1:])
    if disagrees:
        # Not an exact multiple of the formula unit: report the average ratio.
        return sum(per_species) / len(per_species)
    return reference

def formula_unit_report(
    counts: Counter,
    hill: bool = True,
    custom_order: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Summarise counts next to their reduced (per-formula-unit) values.

    One row per species, in the iteration order of ``counts``, with columns
    ``Species``, ``TotalCount``, ``PerFU`` and ``Fraction``. The formula
    string and the number of formula units are stored in ``DataFrame.attrs``
    under ``formula_str`` and ``n_formula_units``.
    """
    fu_counts, formula_str = empirical_formula_from_counts(counts, hill=hill, custom_order=custom_order)
    n_fu = 0.0
    if fu_counts:
        n_fu = number_of_formula_units(counts, fu_counts)

    labels = list(counts.keys())
    report = pd.DataFrame({
        "Species": labels,
        "TotalCount": [int(counts[sp]) for sp in labels],
        "PerFU": [fu_counts.get(sp, 0) for sp in labels],
    })
    report["Fraction"] = report["TotalCount"] / report["TotalCount"].sum()
    report.attrs["formula_str"] = formula_str
    report.attrs["n_formula_units"] = n_fu
    return report

# --------------------------
# Example usage (uncomment and set your file path):
# path = "your_file.inp"
# total, counts = count_atoms_from_file(path, species_col=1)  # e.g., second column holds Z
# print(f"Total atoms: {total}")
# display(summarize_counts(counts))
# rep = formula_unit_report(counts, hill=False, custom_order=["Li","La","Zr","O"])  # nice order for LLZO
# print("Empirical formula:", rep.attrs["formula_str"])
# print("Number of formula units in supercell:", rep.attrs["n_formula_units"])
# display(rep.sort_values("Species"))
# --------------------------


In [5]:
# %% Script to make a template model from given hyperparameters and save it
# Write E0s JSON

!mace_run_train \
  --name mace_T1_w1_template \
  --model MACE \
  --num_interactions 2 \
  --foundation_model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
  --foundation_model_readout \
  --multiheads_finetuning True \
  --heads "{"target_head": {"train_file": "/home/phanim/harshitrawat/summer/dummy.extxyz", "E0s": "{"3": -1.882, "8": -4.913, "40": -8.509, "57": -4.894}"},"pt_head": {"train_file": "/home/phanim/harshitrawat/summer/dummy.extxyz", "E0s": "{"3": -1.882, "8": -4.913, "40": -8.509, "57": -4.894}"}" \
  --atomic_numbers "[3,8,40,57]" \
  --valid_file /home/phanim/harshitrawat/summer/dummy.extxyz \
  --batch_size 2 \
  --valid_batch_size 1 \
  --device cpu \
  --forces_weight 10 \
  --energy_weight 50 \
  --stress_weight 0 \
  --lr 0.0002 \
  --scheduler_patience 4 \
  --clip_grad 1 \
  --weight_decay 1e-8 \
  --r_max 5.0 \
  --max_num_epochs 1 \
  --seed 10 \
  --patience 8 \
  --restart_latest




  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))
Error while loading libcue_ops.so: libcuda.so.1: cannot open shared object file: No such file or directory
usage: mace_run_train [-h] [--config CONFIG] --name NAME [--seed SEED]
                      [--work_dir WORK_DIR] [--log_dir LOG_DIR]
                      [--model_dir MODEL_DIR]
                      [--checkpoints_dir CHECKPOINTS_DIR]
                      [--results_dir RESULTS_DIR]
                      [--downloads_dir DOWNLOADS_DIR]
                      [--device {cpu,cuda,mps,xpu}]
                      [--default_dtype {float32,float64}] [--distributed]
                      [--launcher {slurm,torchrun,mpi,none}]
                      [--log_level LOG_LEVEL] [--plot PLOT]
                      [--plot_frequency PLOT_FREQUENCY]
                      [--error_table {PerAtomRMSE,TotalRMSE,PerAtomRMSEstressvirials,PerAtomMAEstressvirials,PerAtomMAE,TotalMAE,DipoleRMSE,Dipole

In [1]:
# === Cell 1: utils ============================================================
from pathlib import Path
import numpy as np

from pymatgen.core import Lattice, Structure, Element
from pymatgen.io.cif import CifWriter

BOHR_TO_ANG = 0.529177210903  # Å per Bohr (CODATA 2018)

def read_lattice_bohr(path) -> np.ndarray:
    """
    Load the three lattice vectors stored in Bohr in a domainVectors.inp-style
    file and return them as a 3×3 array in Å (rows = a, b, c).

    Blank lines are skipped. Raises ValueError when a non-blank line does not
    have exactly three fields or when the file does not contain exactly three
    such lines.
    """
    path = Path(path)
    vectors = []
    with path.open("r") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                fields = text.split()
                n_fields = len(fields)
                if n_fields != 3:
                    raise ValueError(f"{path.name}: expected 3 numbers/line, got {n_fields} → {text}")
                vectors.append(list(map(float, fields)))

    n_rows = len(vectors)
    if n_rows != 3:
        raise ValueError(f"{path.name}: expected exactly 3 lines, got {n_rows}.")

    # Unit conversion: Bohr → Å
    return BOHR_TO_ANG * np.array(vectors, dtype=float)


def read_fractional_coords_with_Z(path):
    """
    Read coordinates.inp lines of:  Z  tag  fx  fy  fz
    - Z   : atomic number (int), taken from the first column
    - tag : ignored (e.g., species index/pseudopotential id)
    - fx,fy,fz : fractional coordinates in [0,1), taken from the LAST three
                 columns (values are kept as-is, no wrapping is applied here)

    Blank lines and lines starting with '#' are skipped.

    Returns:
      species : list[str]  (element symbols)
      frac    : list[list[float]] (fractional coords)

    Raises:
      ValueError : a data line has fewer than 5 columns, a field cannot be
                   parsed, or the file contains no atoms. Parse errors name
                   the file and line and chain the original exception.
    """
    path = Path(path)
    species, frac = [], []
    with path.open("r") as f:
        for i, line in enumerate(f, start=1):
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) < 5:
                raise ValueError(f"{path.name} line {i}: need ≥5 columns, got {len(parts)} → {line}")
            try:
                Z = int(parts[0])
                fx, fy, fz = map(float, parts[-3:])
                # Looked up inside the try so an unknown Z is reported with file/line
                # context too (presumably Element.from_Z raises ValueError then — confirm).
                elem = Element.from_Z(Z).symbol
            except ValueError as e:
                # `from e` keeps the original traceback attached to the re-raised error.
                raise ValueError(f"{path.name} line {i}: parse error → {e}\n  line: {line}") from e
            species.append(elem)
            frac.append([fx, fy, fz])

    if not species:
        raise ValueError(f"{path.name}: found 0 atoms.")
    return species, frac


def build_structure_from_inp(lattice_inp, coords_inp, coords_are_cartesian=False) -> Structure:
    """
    Assemble a pymatgen Structure from a lattice file and a coordinates file.

    The lattice file is read with ``read_lattice_bohr`` (3×3 array in Å, rows
    a, b, c) and the atoms with ``read_fractional_coords_with_Z``. The
    ``coords_are_cartesian`` flag is passed straight through to ``Structure``;
    the structure is created with ``to_unit_cell=True`` and
    ``validate_proximity=False``.
    """
    cell_matrix = read_lattice_bohr(lattice_inp)
    symbols, positions = read_fractional_coords_with_Z(coords_inp)

    return Structure(
        lattice=Lattice(cell_matrix),  # rows are a,b,c
        species=symbols,
        coords=positions,
        coords_are_cartesian=coords_are_cartesian,
        to_unit_cell=True,
        validate_proximity=False,
    )


def write_cif(struct: Structure, out_path):
    """Write ``struct`` to ``out_path`` as CIF (``CifWriter`` with ``symprec=None``).

    Missing parent directories are created first. Returns the output
    location as a ``Path``.
    """
    target = Path(out_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    writer = CifWriter(struct, symprec=None)
    writer.write_file(str(target))
    return target


In [5]:
# === Cell 2: run conversion ====================================================
# Inputs and output for this run (paths may be relative or absolute).
lattice_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/Li_222/R2SCAN/domainVectors.inp"  # lattice vectors, 3×3, Bohr
coords_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/Li_222/R2SCAN/coordinates.inp"  # rows of: Z tag fx fy fz (fractional)
out_cif = "/home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/Li_Srinibas.cif"

# Switch to True only when coordinates.inp stores Cartesian positions (rare for this format).
coords_are_cartesian = False

# Assemble the structure and export it as CIF.
struct = build_structure_from_inp(
    lattice_inp,
    coords_inp,
    coords_are_cartesian=coords_are_cartesian,
)
out = write_cif(struct, out_cif)

# Report what was written: cell lengths, cell angles, composition and atom count.
(a, b, c), (alpha, beta, gamma) = struct.lattice.abc, struct.lattice.angles
print(
    f"Wrote: {out}",
    f"Formula: {struct.composition.formula}",
    f"a,b,c (Å): {a:.6f}, {b:.6f}, {c:.6f}",
    f"α,β,γ (°): {alpha:.6f}, {beta:.6f}, {gamma:.6f}",
    f"Atoms: {len(struct)}",
    sep="\n",
)


Wrote: /home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/Li_Srinibas.cif
Formula: Li16
a,b,c (Å): 6.987679, 6.987679, 6.987679
α,β,γ (°): 90.000000, 90.000000, 90.000000
Atoms: 16


In [6]:
# === Cell 2: run conversion ====================================================
# Inputs and output for this run (paths may be relative or absolute).
lattice_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/La_222/R2SCAN/domainVectors.inp"  # lattice vectors, 3×3, Bohr
coords_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/La_222/R2SCAN/coordinates.inp"  # rows of: Z tag fx fy fz (fractional)
out_cif = "/home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/La_Srinibas.cif"

# Switch to True only when coordinates.inp stores Cartesian positions (rare for this format).
coords_are_cartesian = False

# Assemble the structure and export it as CIF.
struct = build_structure_from_inp(
    lattice_inp,
    coords_inp,
    coords_are_cartesian=coords_are_cartesian,
)
out = write_cif(struct, out_cif)

# Report what was written: cell lengths, cell angles, composition and atom count.
(a, b, c), (alpha, beta, gamma) = struct.lattice.abc, struct.lattice.angles
print(
    f"Wrote: {out}",
    f"Formula: {struct.composition.formula}",
    f"a,b,c (Å): {a:.6f}, {b:.6f}, {c:.6f}",
    f"α,β,γ (°): {alpha:.6f}, {beta:.6f}, {gamma:.6f}",
    f"Atoms: {len(struct)}",
    sep="\n",
)


Wrote: /home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/La_Srinibas.cif
Formula: La32
a,b,c (Å): 7.635224, 7.635224, 24.467712
α,β,γ (°): 90.000000, 90.000000, 120.001290
Atoms: 32


In [7]:
# === Cell 2: run conversion ====================================================
# Inputs and output for this run (paths may be relative or absolute).
lattice_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/Zr_222/R2SCAN/domainVectors.inp"  # lattice vectors, 3×3, Bohr
coords_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/Zr_222/R2SCAN/coordinates.inp"  # rows of: Z tag fx fy fz (fractional)
out_cif = "/home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/Zr_Srinibas.cif"

# Switch to True only when coordinates.inp stores Cartesian positions (rare for this format).
coords_are_cartesian = False

# Assemble the structure and export it as CIF.
struct = build_structure_from_inp(
    lattice_inp,
    coords_inp,
    coords_are_cartesian=coords_are_cartesian,
)
out = write_cif(struct, out_cif)

# Report what was written: cell lengths, cell angles, composition and atom count.
(a, b, c), (alpha, beta, gamma) = struct.lattice.abc, struct.lattice.angles
print(
    f"Wrote: {out}",
    f"Formula: {struct.composition.formula}",
    f"a,b,c (Å): {a:.6f}, {b:.6f}, {c:.6f}",
    f"α,β,γ (°): {alpha:.6f}, {beta:.6f}, {gamma:.6f}",
    f"Atoms: {len(struct)}",
    sep="\n",
)


Wrote: /home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/Zr_Srinibas.cif
Formula: Zr16
a,b,c (Å): 6.494704, 6.495736, 10.305029
α,β,γ (°): 90.000000, 90.000000, 119.994744
Atoms: 16


In [8]:
# === Cell 2: run conversion ====================================================
# Inputs and output for this run (paths may be relative or absolute).
lattice_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/O2/R2SCAN/domainVectors.inp"  # lattice vectors, 3×3, Bohr
coords_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/O2/R2SCAN/coordinates.inp"  # rows of: Z tag fx fy fz (fractional)
out_cif = "/home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/O2_Srinibas.cif"

# Switch to True only when coordinates.inp stores Cartesian positions (rare for this format).
coords_are_cartesian = False

# Assemble the structure and export it as CIF.
struct = build_structure_from_inp(
    lattice_inp,
    coords_inp,
    coords_are_cartesian=coords_are_cartesian,
)
out = write_cif(struct, out_cif)

# Report what was written: cell lengths, cell angles, composition and atom count.
(a, b, c), (alpha, beta, gamma) = struct.lattice.abc, struct.lattice.angles
print(
    f"Wrote: {out}",
    f"Formula: {struct.composition.formula}",
    f"a,b,c (Å): {a:.6f}, {b:.6f}, {c:.6f}",
    f"α,β,γ (°): {alpha:.6f}, {beta:.6f}, {gamma:.6f}",
    f"Atoms: {len(struct)}",
    sep="\n",
)


Wrote: /home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/O2_Srinibas.cif
Formula: O2
a,b,c (Å): 26.458861, 26.458861, 26.458861
α,β,γ (°): 90.000000, 90.000000, 90.000000
Atoms: 2


In [9]:
# === Cell 2: run conversion ====================================================
# Inputs and output for this run (paths may be relative or absolute).
lattice_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/LLZO/R2SCAN/domainVectors.inp"  # lattice vectors, 3×3, Bohr
coords_inp = "/home/phanim/harshitrawat/summer/llzo_data_srinibas/LLZO/R2SCAN/coordinates.inp"  # rows of: Z tag fx fy fz (fractional)
out_cif = "/home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/LLZO_Srinibas.cif"

# Switch to True only when coordinates.inp stores Cartesian positions (rare for this format).
coords_are_cartesian = False

# Assemble the structure and export it as CIF.
struct = build_structure_from_inp(
    lattice_inp,
    coords_inp,
    coords_are_cartesian=coords_are_cartesian,
)
out = write_cif(struct, out_cif)

# Report what was written: cell lengths, cell angles, composition and atom count.
(a, b, c), (alpha, beta, gamma) = struct.lattice.abc, struct.lattice.angles
print(
    f"Wrote: {out}",
    f"Formula: {struct.composition.formula}",
    f"a,b,c (Å): {a:.6f}, {b:.6f}, {c:.6f}",
    f"α,β,γ (°): {alpha:.6f}, {beta:.6f}, {gamma:.6f}",
    f"Atoms: {len(struct)}",
    sep="\n",
)


Wrote: /home/phanim/harshitrawat/summer/llzo_sanity_check_srinibas/LLZO_Srinibas.cif
Formula: Li56 La24 Zr16 O96
a,b,c (Å): 13.133798, 13.133815, 12.596967
α,β,γ (°): 90.000070, 90.000005, 90.000005
Atoms: 192


In [2]:
import os
import re  # used by parse_filename; it was missing, so the rebuild branch raised NameError

import numpy as np
import pandas as pd

# Paths
DATA_DIR = "universal_embeddings_results"
FILES = {
    "T1": os.path.join(DATA_DIR, "Universal_on_T1.xyz"),
    "T2": os.path.join(DATA_DIR, "Universal_on_T2.xyz"),
    "T3": os.path.join(DATA_DIR, "Universal_on_T3.xyz")
}

# Checkpoint Configuration
LOAD_FROM_CHECKPOINT = True
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
CP_DATA_LOADED = os.path.join(CHECKPOINT_DIR, "checkpoint_1_data_loaded.pkl")
CP_OOD_RESULTS = os.path.join(CHECKPOINT_DIR, "checkpoint_2_ood_results.pkl")


def parse_filename(filename):
    """
    Extract metadata from a structure filename.

    Expected convention:
      cellrelaxed_LLZO_{cleavingdir}_{termination}_{order}_{sto|offsto}__Li_{facet}_slab_heavy_T{Temp}_{Index}.cif
    with an optional strain suffix: ..._strain{+/-}{val}_perturbed.cif

    Returns a dict with keys 'filename', 'strain', 'is_perturbed', 'temp',
    'facet' and 'termination'. Fields that cannot be found fall back to
    0.0 / False / None / "Unknown" / "Unknown" respectively.
    """
    meta = {'filename': filename}

    # Strain
    strain_match = re.search(r"strain([+-]?[\d\.]+)_perturbed", filename)
    meta['strain'] = float(strain_match.group(1)) if strain_match else 0.0
    meta['is_perturbed'] = strain_match is not None

    # Temperature
    temp_match = re.search(r"_T(\d+)_", filename)
    meta['temp'] = int(temp_match.group(1)) if temp_match else None

    # Facet (e.g., Li_100_slab)
    facet_match = re.search(r"Li_(\d+)_slab", filename)
    meta['facet'] = facet_match.group(1) if facet_match else "Unknown"

    # Termination (e.g., LLZO_010_La_order0): take the block between LLZO_ and __Li.
    # Heuristic: the block is usually {cleaving}_{termination}_{order}_{sto}.
    meta['termination'] = "Unknown"
    term_match = re.search(r"LLZO_(.*?)__Li", filename)
    if term_match:
        parts = term_match.group(1).split('_')
        if len(parts) >= 2:
            meta['termination'] = parts[1]  # e.g. La

    return meta


if LOAD_FROM_CHECKPOINT and os.path.exists(CP_DATA_LOADED):
    print(f"Loading data from checkpoint: {CP_DATA_LOADED}")
    # SECURITY: unpickling can execute arbitrary code; only load checkpoints
    # that this notebook wrote itself.
    df = pd.read_pickle(CP_DATA_LOADED)
    # Re-create X from the stored per-structure latent vectors
    X = np.stack(df['latent'].values)
    print(f"Loaded {len(df)} structures.")
    print(f"Feature Matrix Shape: {X.shape}")
else:
    # Imported here so the checkpoint path above keeps working without ASE;
    # the original cell never imported it, so this branch raised NameError.
    import ase.io

    data_list = []
    n_skipped = 0

    for dataset_name, filepath in FILES.items():
        print(f"Loading {dataset_name} from {filepath}...")
        atoms_list = ase.io.read(filepath, index=":")

        for atoms in atoms_list:
            info = atoms.info
            arrays = atoms.arrays

            # 'mace_latent' is per-atom; use the MEAN over atoms as the
            # per-structure descriptor. Frames without it are skipped.
            if 'mace_latent' not in arrays:
                n_skipped += 1
                continue
            latent = np.mean(arrays['mace_latent'], axis=0)

            # Extract Energy
            energy = info.get('mace_energy', np.nan)

            # The source filename is looked up under 'filename', then 'comment'.
            # When neither key exists a placeholder is used, and parse_filename
            # then yields only its fallback values for that frame.
            fname = info.get('filename', info.get('comment', f"unknown_{dataset_name}"))

            entry = parse_filename(fname)
            entry['dataset'] = dataset_name
            entry['energy'] = energy
            entry['latent'] = latent

            data_list.append(entry)

    if n_skipped:
        print(f"Skipped {n_skipped} frames without a 'mace_latent' array.")
    if not data_list:
        # np.stack on an empty list fails with a much less helpful message.
        raise ValueError(
            "No frames with a 'mace_latent' array were found in: "
            + ", ".join(FILES.values())
        )

    df = pd.DataFrame(data_list)
    print(f"Loaded {len(df)} structures.")

    # Create Feature Matrix X
    X = np.stack(df['latent'].values)
    print(f"Feature Matrix Shape: {X.shape}")

    # Save Checkpoint
    print(f"Saving data to checkpoint: {CP_DATA_LOADED}")
    df.to_pickle(CP_DATA_LOADED)


Loading data from checkpoint: checkpoints/checkpoint_1_data_loaded.pkl
Loaded 8654 structures.
Feature Matrix Shape: (8654, 256)


In [3]:
    # Preview the first rows of the frame built or loaded by the previous cell.
    df.head()


Unnamed: 0,filename,strain,is_perturbed,temp,facet,termination,dataset,energy,latent
0,unknown_T1,0.0,False,,Unknown,Unknown,T1,-2985.323995,"[-0.022429647077777785, 0.37030010755555526, 0..."
1,unknown_T1,0.0,False,,Unknown,Unknown,T1,-2918.476146,"[-0.030376922312206605, 0.32244097245305237, 0..."
2,unknown_T1,0.0,False,,Unknown,Unknown,T1,-2853.348134,"[-0.043781398527315905, 0.28577672785035624, -..."
3,unknown_T1,0.0,False,,Unknown,Unknown,T1,-2677.91998,"[-0.0532804312654321, 0.24710180930555542, -0...."
4,unknown_T1,0.0,False,,Unknown,Unknown,T1,-2382.882676,"[-0.047781491575342436, 0.2694620743493152, -0..."


In [None]:
from ase.io import read

# Absolute path: the original string lacked the leading "/", which made it relative to
# the current working directory, unlike every other /home/phanim/... path in this notebook.
# NOTE(review): read() is called without index=..., unlike the loader cell (index=":"),
# so this presumably yields a single frame rather than the whole file — confirm that is intended.
atoms = read("/home/phanim/harshitrawat/summer/universal_embeddings_results/Universal_on_T2.xyz")
type(atoms)


In [None]:
# Inspect the per-frame metadata mapping; the loader cell looks here for the
# 'filename' / 'comment' / 'mace_energy' keys.
atoms.info
