First, minimal, single‑purpose script that just pulls bulk binary structures from Materials Project and saves only CIFs (+ a tiny Excel index)

In [2]:
#!/usr/bin/env python3
from pathlib import Path
import pandas as pd
from pymatgen.ext.matproj import MPRester
from pymatgen.io.ase import AseAtomsAdaptor
from ase.io import write

# ====== CONFIG ======
API_KEY = "j3J85pX4nLw6asHG9E2lbbCHEKDKgrjc"  # put your MP API key here
OUT_ROOT = Path("/home/phanim/harshitrawat/summer/binaries_bulk")
PAIRS = [("Li","O"), ("Li","La"), ("Li","Zr"), ("La","O"), ("Zr","O"), ("La","Zr")]
E_ABOVE_HULL_MAX = 0.03
MAX_DOCS_PER_PAIR = None
SAVE_CONVENTIONAL = False
INDEX_XLSX = OUT_ROOT / "index.xlsx"
# ====================

def safe(s): return s.replace(" ", "")

def main():
    OUT_ROOT.mkdir(parents=True, exist_ok=True)
    rows = []
    with MPRester(API_KEY) as mpr:
        for a,b in PAIRS:
            pair = f"{a}-{b}"
            pair_dir = OUT_ROOT / f"pair_{a}_{b}" / "cifs"
            pair_dir.mkdir(parents=True, exist_ok=True)

            # replace the summary_search call with this
            docs = mpr.summary_search(
                chemsys=pair,
                nelements=2,
                is_stable=True,  # optional prefilter; keeps results small
                _fields=[
                    "material_id", "formula_pretty", "chemsys", "is_stable",
                    "energy_above_hull", "structure"
                ],
            )

            # client-side filter on energy_above_hull
            docs = [d for d in docs
                    if (d.get("energy_above_hull") is not None)
                    and (float(d["energy_above_hull"]) <= E_ABOVE_HULL_MAX)]

            if MAX_DOCS_PER_PAIR:
                docs = docs[:MAX_DOCS_PER_PAIR]

            print(f"{pair}: {len(docs)} structures")
            for d in docs:
                mpid = d["material_id"]
                formula = safe(d["formula_pretty"])
                pm = d["structure"]
                if SAVE_CONVENTIONAL:
                    try:
                        pm = pm.get_conventional_standard_structure()
                    except Exception:
                        pass
                cif_path = pair_dir / f"mpid-{mpid}_{formula}__bulk.cif"
                write(cif_path, AseAtomsAdaptor.get_atoms(pm), format="cif")

                rows.append({
                    "mpid": mpid,
                    "formula": formula,
                    "chemsys": d["chemsys"],
                    "is_stable": bool(d["is_stable"]),
                    "energy_above_hull_eV_per_atom": float(d.get("energy_above_hull", 0.0) or 0.0),
                    "n_atoms": len(pm),
                    "cif": str(cif_path),
                })

    if rows:
        pd.DataFrame(rows).sort_values(["chemsys","formula","mpid"]).to_excel(INDEX_XLSX, index=False)
        print(f"\nSaved {len(rows)} CIFs to {OUT_ROOT}\nIndex: {INDEX_XLSX}")
    else:
        print("No structures found.")

if __name__ == "__main__":
    main()


Li-O: 3 structures
Li-La: 0 structures
Li-Zr: 0 structures
La-O: 1 structures
Zr-O: 3 structures
La-Zr: 0 structures

Saved 7 CIFs to /home/phanim/harshitrawat/summer/binaries_bulk
Index: /home/phanim/harshitrawat/summer/binaries_bulk/index.xlsx


In [3]:
## Now we download elemental bulk structures from MP

In [4]:
#!/usr/bin/env python3
from pathlib import Path
import pandas as pd
from pymatgen.ext.matproj import MPRester
from pymatgen.io.ase import AseAtomsAdaptor
from ase.io import write

# ====== CONFIG ======
API_KEY = "j3J85pX4nLw6asHG9E2lbbCHEKDKgrjc"  # Put your MP API key here
OUT_ROOT = Path("/home/phanim/harshitrawat/summer/elementals_bulk")
ELEMENTS = ["Li", "La", "Zr", "O"]
MAX_DOCS_PER_EL = 1       # Usually 1 stable phase per element is enough
SAVE_CONVENTIONAL = False
INDEX_XLSX = OUT_ROOT / "index.xlsx"
# ====================

def safe(s): return s.replace(" ", "")

def main():
    OUT_ROOT.mkdir(parents=True, exist_ok=True)
    rows = []

    with MPRester(API_KEY) as mpr:
        for el in ELEMENTS:
            el_dir = OUT_ROOT / f"element_{el}" / "cifs"
            el_dir.mkdir(parents=True, exist_ok=True)

            docs = mpr.summary_search(
                chemsys=el,
                nelements=1,
                is_stable=True,
                _fields=["material_id","formula_pretty","chemsys","is_stable","energy_above_hull","structure"],
            )

            # Sort by energy_above_hull and take top N
            docs = sorted(docs, key=lambda x: float(x.get("energy_above_hull", 0.0) or 0.0))
            if MAX_DOCS_PER_EL:
                docs = docs[:MAX_DOCS_PER_EL]

            print(f"{el}: {len(docs)} structure(s)")

            for d in docs:
                mpid = d["material_id"]
                formula = safe(d["formula_pretty"])
                pm = d["structure"]
                if SAVE_CONVENTIONAL:
                    try:
                        pm = pm.get_conventional_standard_structure()
                    except Exception:
                        pass
                cif_path = el_dir / f"mpid-{mpid}_{formula}__bulk.cif"
                write(cif_path, AseAtomsAdaptor.get_atoms(pm), format="cif")

                rows.append({
                    "element": el,
                    "mpid": mpid,
                    "formula": formula,
                    "chemsys": d["chemsys"],
                    "is_stable": bool(d["is_stable"]),
                    "energy_above_hull_eV_per_atom": float(d.get("energy_above_hull", 0.0) or 0.0),
                    "n_atoms": len(pm),
                    "cif": str(cif_path),
                })

    if rows:
        pd.DataFrame(rows).sort_values(["element","mpid"]).to_excel(INDEX_XLSX, index=False)
        print(f"\nSaved {len(rows)} elemental bulk CIFs to {OUT_ROOT}\nIndex: {INDEX_XLSX}")
    else:
        print("No structures found.")

if __name__ == "__main__":
    main()


Li: 1 structure(s)
La: 1 structure(s)
Zr: 1 structure(s)
O: 1 structure(s)

Saved 4 elemental bulk CIFs to /home/phanim/harshitrawat/summer/elementals_bulk
Index: /home/phanim/harshitrawat/summer/elementals_bulk/index.xlsx


# Now we do MD

In [6]:
#!/usr/bin/env python3
# run_md_binaries_and_elements_it2.py
import os, sys, json, time, re, random
from pathlib import Path
from datetime import datetime
import pandas as pd
import numpy as np

from ase.io import read, write, Trajectory
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase import units
from math import isfinite

# ---- Try both CHGNet calculator import paths ----
try:
    from chgnet.model.dynamics import CHGNetCalculator
except Exception:
    from chgnet.model.ase import CHGNetCalculator  # older versions

# ======= CONFIG =======
BINARIES_ROOT   = Path("/home/phanim/harshitrawat/summer/binaries_bulk")
ELEMENTALS_ROOT = Path("/home/phanim/harshitrawat/summer/elementals_bulk")

OUT_MD_ROOT   = Path("/home/phanim/harshitrawat/summer/md")
MD_TRAJ_DIR   = OUT_MD_ROOT / "mdtraj_it2"
MD_CIFS_DIR   = OUT_MD_ROOT / "mdcifs_it2"
MD_META_JSONL = OUT_MD_ROOT / "mdinfo_it2.jsonl"
MD_META_XLSX  = OUT_MD_ROOT / "mdinfo_it2.xlsx"

TEMPS_K = [360, 480]

STEPS = 1800             # total steps per temperature
DT_FS = 1.0              # timestep (fs)
WARMUP_STEPS = 200       # don't sample during warmup
SNAPSHOT_STRIDE = 10     # keep every Nth step after warmup
FRICTION_FS_INV = 0.02   # Langevin gamma (fs^-1)

TARGET_SNAPSHOTS = 3000  # total snapshots target (across temps & seeds)
RNG_SEED = 123
DEVICE = "cuda"          # "cuda" or "cpu"
# ======================

random.seed(RNG_SEED)
np.random.seed(RNG_SEED)

OUT_MD_ROOT.mkdir(parents=True, exist_ok=True)
MD_TRAJ_DIR.mkdir(parents=True, exist_ok=True)
MD_CIFS_DIR.mkdir(parents=True, exist_ok=True)

UUID_RE = re.compile(r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})", re.I)

def find_seed_cifs():
    bins = sorted(BINARIES_ROOT.glob("pair_*/*/*.cif"))
    elems = sorted(ELEMENTALS_ROOT.glob("element_*/*/*.cif"))
    seeds = bins + elems
    if not seeds:
        print("No CIFs found. Check BINARIES_ROOT/ELEMENTALS_ROOT.")
        sys.exit(2)
    return seeds

def extract_uuid(path_str):
    m = UUID_RE.search(path_str)
    if m:
        return m.group(1)
    import uuid
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"seed|{Path(path_str).resolve()}"))

def estimate_snaps_per_temp(steps, warmup, stride):
    usable = max(0, steps - warmup)
    return usable // stride

def init_calc():
    calc = CHGNetCalculator(use_device=DEVICE, model_name=None)
    try:
        dev = getattr(calc.model, "device", None)
        if dev is not None:
            print(f"[CHGNet] device = {dev}")
    except Exception:
        pass
    return calc

def safe_energy_forces(atoms):
    e = float(atoms.get_potential_energy())
    f = atoms.get_forces()
    if not isfinite(e) or np.isnan(f).any() or np.isinf(f).any():
        raise RuntimeError("Non-finite energy/forces encountered")
    return e, f

def run_md_on_seed(cif_path, calc, temps, steps, dt_fs, warmup, stride, meta_rows, remaining_quota):
    """Runs MD for this seed across temps; returns snapshots_saved."""
    uid = extract_uuid(str(cif_path))
    atoms0 = read(cif_path)
    atoms0.calc = calc

    saved = 0
    for T in temps:
        if remaining_quota <= 0:
            break

        A = atoms0.copy()
        A.calc = calc  # IMPORTANT: calculator isn't copied by ASE

        # fresh velocities per temperature (Kelvin), remove drift
        MaxwellBoltzmannDistribution(A, T)
        Stationary(A)

        # reproducible RNG per (uuid, T)
        rng_seed = abs(hash((uid, int(T), RNG_SEED))) % (1 << 32)
        rng = np.random.default_rng(rng_seed)

        # ASE Langevin signature: (atoms, timestep, temperature_K, friction, ...)
        try:
            dyn = Langevin(A, dt_fs * units.fs, T, FRICTION_FS_INV / units.fs, rng=rng)
        except TypeError:
            # older ASE without rng=
            dyn = Langevin(A, dt_fs * units.fs, T, FRICTION_FS_INV / units.fs)

        traj_path = MD_TRAJ_DIR / f"{uid}__T{int(T)}K.traj"
        steps_done = 0

        with Trajectory(traj_path, "w", A, properties=["energy", "forces"]) as traj:
            for _ in range(steps):
                dyn.run(1)
                steps_done += 1

                # throttle traj I/O
                if steps_done > warmup and ((steps_done - warmup) % stride == 0):
                    traj.write(A)
                elif steps_done == warmup:
                    traj.write(A)

                if steps_done <= warmup or (steps_done - warmup) % stride != 0:
                    continue
                if remaining_quota <= 0:
                    break

                # Save snapshot + metadata
                try:
                    e, f = safe_energy_forces(A)
                except Exception as ex:
                    print(f"    !! NaN detected @ {T}K step {steps_done}: {ex}. Skipping this temperature.")
                    break

                snap_name = f"{uid}__T{int(T)}K__step{steps_done}.cif"
                snap_path = MD_CIFS_DIR / snap_name
                write(snap_path, A, format="cif")

                fmag = np.linalg.norm(f, axis=1)
                fmax = float(fmag.max())
                frms = float(np.sqrt((f**2).sum() / f.shape[0]))

                row = {
                    "uuid": uid,
                    "source_file": str(cif_path),
                    "snapshot_file": str(snap_path),
                    "traj_file": str(traj_path),
                    "temperature_K": int(T),
                    "step": int(steps_done),
                    "dt_fs": float(dt_fs),
                    "thermostat": "Langevin",
                    "friction_fs_inv": float(FRICTION_FS_INV),
                    "energy_eV": e,
                    "force_rms_eV_per_A": frms,
                    "fmax_eV_per_A": fmax,
                    "n_atoms": len(A),
                    "created_at_iso": datetime.utcnow().isoformat() + "Z",
                    # keep full forces in JSONL (Excel will omit)
                    "forces_eV_per_A": f.tolist(),
                }
                meta_rows.append(row)
                saved += 1
                remaining_quota -= 1

        print(f"  - {Path(cif_path).name} @ {T}K: saved {saved} so far (quota left {remaining_quota})")
        if remaining_quota <= 0:
            break

    return saved

def main():
    seeds = find_seed_cifs()
    random.shuffle(seeds)

    snaps_per_temp = estimate_snaps_per_temp(STEPS, WARMUP_STEPS, SNAPSHOT_STRIDE)
    est_per_seed = snaps_per_temp * len(TEMPS_K)
    print(f"Seeds: {len(seeds)} | snaps/temp ≈ {snaps_per_temp} | ≈{est_per_seed} per seed | target={TARGET_SNAPSHOTS}")

    # resume: count existing JSONL rows if present
    existing_rows = 0
    if MD_META_JSONL.exists():
        try:
            with open(MD_META_JSONL) as f:
                existing_rows = sum(1 for _ in f)
            if existing_rows >= TARGET_SNAPSHOTS:
                print(f"Found {existing_rows} existing rows in JSONL; target met. Exiting.")
                return
        except Exception:
            pass

    calc = init_calc()
    meta_rows = []
    remaining = TARGET_SNAPSHOTS - existing_rows

    t0 = time.time()
    for i, cif in enumerate(seeds, 1):
        if remaining <= 0:
            break
        print(f"[{i}/{len(seeds)}] MD on {cif}")
        try:
            saved = run_md_on_seed(
                cif_path=cif,
                calc=calc,
                temps=TEMPS_K,
                steps=STEPS,
                dt_fs=DT_FS,
                warmup=WARMUP_STEPS,
                stride=SNAPSHOT_STRIDE,
                meta_rows=meta_rows,
                remaining_quota=remaining,
            )
            remaining -= saved
        except Exception as e:
            print(f"!! Failed {cif}: {e}")

    # Append to JSONL (resume-friendly)
    if meta_rows:
        with open(MD_META_JSONL, "a") as f:
            for r in meta_rows:
                f.write(json.dumps(r) + "\n")

        # Excel: drop giant forces column
        df = pd.DataFrame(meta_rows)
        if "forces_eV_per_A" in df.columns:
            df = df.drop(columns=["forces_eV_per_A"])
        df.sort_values(["temperature_K","uuid","step"], inplace=True)
        df.to_excel(MD_META_XLSX, index=False)

    dt = time.time() - t0
    total = existing_rows + len(meta_rows)
    print(f"\nMD finished in {dt/60:.1f} min. Saved {len(meta_rows)} new snapshots (total {total}).")
    print(f"Traj dir:   {MD_TRAJ_DIR}")
    print(f"Snapshots:  {MD_CIFS_DIR}")
    print(f"Metadata:   {MD_META_XLSX}")
    print(f"JSONL:      {MD_META_JSONL}")

if __name__ == "__main__":
    main()


Seeds: 11 | snaps/temp ≈ 160 | ≈320 per seed | target=3000
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cuda
[1/11] MD on /home/phanim/harshitrawat/summer/elementals_bulk/element_O/cifs/mpid-mp-12957_O2__bulk.cif


  "created_at_iso": datetime.utcnow().isoformat() + "Z",


  - mpid-mp-12957_O2__bulk.cif @ 360K: saved 160 so far (quota left 2840)
  - mpid-mp-12957_O2__bulk.cif @ 480K: saved 320 so far (quota left 2680)
[2/11] MD on /home/phanim/harshitrawat/summer/elementals_bulk/element_Li/cifs/mpid-mp-1018134_Li__bulk.cif
  - mpid-mp-1018134_Li__bulk.cif @ 360K: saved 160 so far (quota left 2520)
  - mpid-mp-1018134_Li__bulk.cif @ 480K: saved 320 so far (quota left 2360)
[3/11] MD on /home/phanim/harshitrawat/summer/elementals_bulk/element_La/cifs/mpid-mp-26_La__bulk.cif
  - mpid-mp-26_La__bulk.cif @ 360K: saved 160 so far (quota left 2200)
  - mpid-mp-26_La__bulk.cif @ 480K: saved 320 so far (quota left 2040)
[4/11] MD on /home/phanim/harshitrawat/summer/binaries_bulk/pair_Li_O/cifs/mpid-mp-841_Li2O2__bulk.cif
  - mpid-mp-841_Li2O2__bulk.cif @ 360K: saved 160 so far (quota left 1880)
  - mpid-mp-841_Li2O2__bulk.cif @ 480K: saved 320 so far (quota left 1720)
[5/11] MD on /home/phanim/harshitrawat/summer/binaries_bulk/pair_Zr_O/cifs/mpid-mp-14024_Zr3O__b

In [13]:
#!/usr/bin/env python3
# relabel_md_snapshots_fresh_simple.py
import re, json, time
from pathlib import Path
from datetime import datetime
from math import isfinite

import numpy as np
import pandas as pd
from ase.io import read

# CHGNet import
try:
    from chgnet.model.dynamics import CHGNetCalculator
except Exception:
    from chgnet.model.ase import CHGNetCalculator

# ==== CONFIG (edit here) ====
INPUT_DIRS = [
    "/home/phanim/harshitrawat/summer/md/mdcifs_it2"
]
DEVICE = "cuda"
BATCH_FLUSH_EVERY = 200
PRINT_EVERY = 100
OUT_ROOT = Path("/home/phanim/harshitrawat/summer/md/labels_fresh_it2")
TIMESTAMP_OUTPUT = True  # if False, overwrite "latest"
# ============================

FNAME_RE = re.compile(
    r"(?P<uuid>[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})__T(?P<T>\d+)K__step(?P<step>\d+)\.cif",
    re.I,
)

def parse_from_name(p):
    m = FNAME_RE.search(p.name)
    if not m:
        return None, None, None
    return m.group("uuid"), int(m.group("T")), int(m.group("step"))

def safe_eval(atoms, name):
    e = float(atoms.get_potential_energy())
    f = atoms.get_forces()
    n = len(atoms)
    if f is None or f.shape != (n, 3):
        raise RuntimeError(f"{name}: forces.shape={None if f is None else f.shape}, expected ({n}, 3)")
    if (not isfinite(e)) or np.isnan(f).any() or np.isinf(f).any():
        raise RuntimeError(f"{name}: non-finite energy/forces")
    return e, f

def collect_inputs(dirs):
    files = []
    for d in dirs:
        d = Path(d)
        if not d.exists():
            print(f"[warn] Missing dir: {d}")
            continue
        files.extend(sorted(d.glob("*.cif")))
    return [p.resolve() for p in files]

# ==== main ====
snaps = collect_inputs(INPUT_DIRS)
if not snaps:
    raise SystemExit("No CIFs found.")

if TIMESTAMP_OUTPUT:
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = OUT_ROOT / f"run_{stamp}"
else:
    out_dir = OUT_ROOT / "latest"
out_dir.mkdir(parents=True, exist_ok=True)

OUT_JSONL = out_dir / "labels.jsonl"
OUT_EXCEL = out_dir / "labels.xlsx"

calc = CHGNetCalculator(use_device=DEVICE, model_name=None)
print(f"[CHGNet] device = {getattr(calc.model, 'device', 'unknown')}")

buffer = []
excel_rows = []
t0 = time.time()
total = len(snaps)
print(f"Found {total} snapshots → labeling to {out_dir}")

for i, p in enumerate(snaps, 1):
    try:
        atoms = read(p)
        atoms.calc = calc

        energy_eV, forces = safe_eval(atoms, p.name)
        n = len(atoms)
        fmag = np.linalg.norm(forces, axis=1)
        fmax = float(fmag.max())
        frms = float(np.sqrt((forces**2).sum() / n))
        uid, T, step = parse_from_name(p)

        buffer.append({
            "snapshot_file": str(p),
            "uuid": uid,
            "temperature_K": T,
            "step": step,
            "n_atoms": n,
            "energy_eV": float(energy_eV),
            "forces_per_atom_eV_per_A": forces.tolist(),
            "force_rms_eV_per_A": frms,
            "fmax_eV_per_A": fmax,
            "created_at_iso": datetime.utcnow().isoformat() + "Z",
        })

        excel_rows.append({
            "snapshot_file": str(p),
            "uuid": uid,
            "temperature_K": T,
            "step": step,
            "n_atoms": n,
            "energy_eV": float(energy_eV),
            "force_rms_eV_per_A": frms,
            "fmax_eV_per_A": fmax,
        })

        if len(buffer) >= BATCH_FLUSH_EVERY:
            with open(OUT_JSONL, "a") as f:
                for r in buffer:
                    f.write(json.dumps(r) + "\n")
            buffer.clear()

        if (i % PRINT_EVERY == 0) or (i == total):
            print(f"[{i}/{total}] {p.name} (n={n}, E={energy_eV:.6f} eV, fmax={fmax:.3f})")

    except Exception as e:
        print(f"!! Failed {p}: {e}")

# final flush
if buffer:
    with open(OUT_JSONL, "a") as f:
        for r in buffer:
            f.write(json.dumps(r) + "\n")
buffer.clear()

if excel_rows:
    df = pd.DataFrame(excel_rows)
    sort_cols = [c for c in ["temperature_K", "uuid", "step"] if c in df.columns]
    df.sort_values(sort_cols, inplace=True, na_position="last")
    df.to_excel(OUT_EXCEL, index=False)

print(f"\nDone in {(time.time()-t0)/60:.1f} min.")
print(f"JSONL: {OUT_JSONL}")
print(f"Excel: {OUT_EXCEL}")


CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cuda
[CHGNet] device = unknown
Found 3000 snapshots → labeling to /home/phanim/harshitrawat/summer/md/labels_fresh_it2/run_20250811_121909


  "created_at_iso": datetime.utcnow().isoformat() + "Z",


[100/3000] 47f7ac20-5f9d-577f-87ae-0d21207606bf__T360K__step390.cif (n=3, E=-7.346784 eV, fmax=101.292)
[200/3000] 47f7ac20-5f9d-577f-87ae-0d21207606bf__T480K__step1390.cif (n=3, E=-6.970852 eV, fmax=16.672)
[300/3000] 47f7ac20-5f9d-577f-87ae-0d21207606bf__T480K__step790.cif (n=3, E=-13.155951 eV, fmax=1.973)
[400/3000] 4bdcb7da-d4f6-58d7-9481-3d71d389469d__T360K__step1790.cif (n=8, E=5.275574 eV, fmax=86.150)
[500/3000] 4bdcb7da-d4f6-58d7-9481-3d71d389469d__T480K__step1190.cif (n=8, E=-24.523636 eV, fmax=26.406)
[600/3000] 4bdcb7da-d4f6-58d7-9481-3d71d389469d__T480K__step590.cif (n=8, E=-24.549477 eV, fmax=39.591)
[700/3000] 68d75590-319e-5978-aab1-dd84a8cea44d__T360K__step1590.cif (n=8, E=5.874779 eV, fmax=10.296)
[800/3000] 68d75590-319e-5978-aab1-dd84a8cea44d__T360K__step990.cif (n=8, E=36.176064 eV, fmax=34.752)
[900/3000] 68d75590-319e-5978-aab1-dd84a8cea44d__T480K__step390.cif (n=8, E=33.053322 eV, fmax=39.865)
[1000/3000] 849d67b4-45cb-55e3-b935-93e0e31cb2bf__T360K__step1390.ci

In [2]:
# ==== MAX-PARALLEL ZERO-SKIP EXTXYZ BUILDER (Notebook-Ready) ====
# Parallelized stages:
#  1) CIF indexing (dir walks)             -> ProcessPool
#  2) Label file parsing/ingestion         -> ProcessPool
#  3) Curation (natoms vs forces)          -> ProcessPool
#  4) Sharded EXTXYZ writing               -> ProcessPool
#
# Outputs: T1.extxyz, val.extxyz, T2.extxyz (every frame has energy & forces)

import os, json, gzip, math
from typing import List, Dict, Tuple, Optional, Any, Iterable
from collections import defaultdict, Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import lru_cache
from pathlib import Path

import numpy as np
from ase.io import read, write

# --------- CONFIG (edit here) ---------
BASE = "/home/phanim/harshitrawat/summer"

CIF_DIRS = [
    f"{BASE}/md/mdcifs",
    f"{BASE}/md/mdcifs_strained_perturbed",
    f"{BASE}/md/mdcifs_strained_perturbed_prime",
    f"{BASE}/md/mdcifs_it2",
]

LABEL_INPUTS = [
    f"{BASE}/T1_T2_T3_data",
    f"{BASE}/md/labels_fresh_it2",
]

SPLITS_DIR = f"{BASE}/md/splits_global_49_7_44/run_20250811_164008"
OUT_DIR    = f"{BASE}/final_work_extxyz"

# Workers & chunking
MAX_WORKERS   = min(48, max(1, (os.cpu_count() or 8)))
INDEX_WORKERS = max(1, min(16, MAX_WORKERS // 2))   # dir-walkers
LABEL_WORKERS = max(1, min(16, MAX_WORKERS // 2))   # json parsers
CURATE_WORKERS= max(1, MAX_WORKERS // 2)            # natoms checks
WRITE_WORKERS = max(1, MAX_WORKERS)                 # shard writers
CHUNK_SIZE    = 600                                 # frames/shard

# Source keys in label dumps
ENERGY_SRC_KEY = "energy_eV"
FORCES_SRC_KEY = "forces_per_atom_eV_per_A"

# Keep threads sane in notebooks
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")


# --------------- UTIL: robust JSON / JSONL reader ---------------
def _read_json_or_jsonl(path: Path) -> Iterable[dict]:
    opener = gzip.open if path.suffixes[-1:] == ['.gz'] else open
    try:
        with opener(path, "rt") as f:
            buf = f.read(2048)
            if not buf:
                return
            f.seek(0)
            first = next((c for c in buf if not c.isspace()), "")
            if first == "[":
                try:
                    arr = json.load(f)
                    for rec in arr:
                        if isinstance(rec, dict):
                            yield rec
                except json.JSONDecodeError as e:
                    print(f"[WARN] JSON array parse failed in {path}: {e}")
            else:
                for ln, line in enumerate(f, 1):
                    s = line.strip()
                    if not s:
                        continue
                    try:
                        rec = json.loads(s)
                        if isinstance(rec, dict):
                            yield rec
                    except json.JSONDecodeError as e:
                        # skip malformed lines, keep going
                        if ln <= 3:
                            print(f"[WARN] Skipping bad JSONL line {ln} in {path}: {e}")
                        continue
    except Exception as e:
        print(f"[WARN] Cannot open/parse {path}: {e}")


# --------------- PARALLEL CIF INDEXING ---------------
def _walk_one_dir(d: str) -> Dict[str, List[str]]:
    idx = defaultdict(list)
    if not os.path.isdir(d):
        return idx
    for root, _, files in os.walk(d):
        for fn in files:
            low = fn.lower()
            if low.endswith(".cif") or low.endswith(".cif.gz"):
                ap = os.path.abspath(os.path.join(root, fn))
                idx[fn].append(ap)
                if low.endswith(".cif.gz"):
                    idx[fn[:-3]].append(ap)      # also index ".cif"
                if low.endswith(".cif"):
                    idx[fn + ".gz"].append(ap)    # also index ".cif.gz"
    return idx

def build_cif_index_parallel(dirs: List[str], workers: int) -> Dict[str, List[str]]:
    out = defaultdict(list)
    with ProcessPoolExecutor(max_workers=workers) as ex:
        futs = {ex.submit(_walk_one_dir, d): d for d in dirs}
        for fut in as_completed(futs):
            sub = fut.result()
            for k, v in sub.items():
                out[k].extend(v)
    # dedup lists
    for k, v in out.items():
        out[k] = sorted(set(v))
    return out


# --------------- PATH RESOLUTION ---------------
def resolve_from_index(name_or_path: str, index: Dict[str, List[str]]) -> Optional[str]:
    if os.path.isfile(name_or_path):
        return os.path.abspath(name_or_path)
    base = os.path.basename(name_or_path)
    cands = index.get(base)
    if not cands:
        # cross-swap .cif <-> .cif.gz
        if base.lower().endswith(".cif"):
            cands = index.get(base + ".gz")
        elif base.lower().endswith(".cif.gz"):
            cands = index.get(base[:-3])
    if not cands:
        return None
    # prefer shorter path (usually closer) then lexicographic for determinism
    return sorted(cands, key=lambda p: (len(p), p))[0]


# --------------- PARALLEL LABEL INGEST ---------------
def _list_label_files(inputs: List[str]) -> List[Path]:
    ALLOWED = (".json", ".jsonl", ".json.gz", ".jsonl.gz")
    out = []
    for item in inputs:
        p = Path(item)
        if not p.exists(): 
            continue
        if p.is_file() and p.name.endswith(ALLOWED):
            out.append(p)
        elif p.is_dir():
            for q in p.rglob("*"):
                if q.is_file() and q.name.endswith(ALLOWED):
                    out.append(q)
    # de-dup preserving order
    seen, uniq = set(), []
    for p in out:
        if p in seen: continue
        seen.add(p); uniq.append(p)
    return uniq

def _ingest_one_label_file(args) -> Tuple[Dict[str, dict], List[dict], Dict[str,int]]:
    fp, index = args
    m = {}
    issues = []
    counts = {"CHGNet":0, "MD_JSONL":0}
    src = "MD_JSONL" if "jsonl" in fp.suffixes else "CHGNet"
    for rec in _read_json_or_jsonl(fp):
        # unified name fields to resolve
        name = rec.get("snapshot_file") or rec.get("file") or rec.get("src_path") or rec.get("path")
        if not name:
            continue
        cif_path = resolve_from_index(name, index)
        if not cif_path:
            issues.append({"file": name, "why": "cif_not_found", "source": fp.name}); continue
        e, f = rec.get(ENERGY_SRC_KEY), rec.get(FORCES_SRC_KEY)
        if e is None or f is None:
            issues.append({"file": name, "why": "missing_energy_or_forces", "source": fp.name}); continue
        m[cif_path] = {
            "energy_eV": float(e),
            "forces": np.asarray(f, dtype=np.float64),
            "meta": {"label_source": src}
        }
        for k in ("temperature_K", "step", "uuid"):
            if k in rec: m[cif_path]["meta"][k] = rec[k]
        counts[src] += 1
    return m, issues, counts

def build_label_index_parallel(label_files: List[Path], index: Dict[str, List[str]], workers: int
                              ) -> Tuple[Dict[str, dict], List[dict]]:
    label_map: Dict[str, dict] = {}
    issues_all: List[dict] = []
    totals = Counter()
    if not label_files:
        return {}, [{"why":"no_label_files_found"}]
    # fan out files to workers
    with ProcessPoolExecutor(max_workers=workers) as ex:
        futs = {ex.submit(_ingest_one_label_file, (fp, index)): fp for fp in label_files}
        for fut in as_completed(futs):
            m, iss, counts = fut.result()
            issues_all.extend(iss)
            totals.update(counts)
            # newer files should override older — we iterate in as_completed (random),
            # so enforce order by filename mtime: rebuild at the end deterministically
            label_map.update(m)
    # deterministic preference: sort label_files by mtime ascending, then re-apply to ensure newest overrides
    for fp in sorted(label_files, key=lambda p: p.stat().st_mtime):
        m, _, _ = _ingest_one_label_file((fp, index))
        label_map.update(m)
    print(f"[labels] Indexed {len(label_map)} unique CIFs "
          f"(CHGNet={totals['CHGNet']}, MD_JSONL={totals['MD_JSONL']}, issues={len(issues_all)})")
    return label_map, issues_all


# --------------- CURATION (Parallel) ---------------
@lru_cache(maxsize=100_000)
def _natoms_of(path: str) -> Optional[int]:
    try: return len(read(path))
    except Exception: return None

def _check_one(pair: Tuple[str, int]) -> Tuple[str, bool, Optional[int]]:
    p, flen = pair
    n = _natoms_of(p)
    ok = (n is not None) and (n == flen)
    return p, ok, n

def curate_split_parallel(entries: List[str], index: Dict[str, List[str]], label_map: Dict[str, dict], workers: int
                         ) -> Tuple[List[str], List[dict]]:
    resolved = []
    dropped = []
    for e in entries:
        p = resolve_from_index(e, index)
        if p is None:
            dropped.append({"file": e, "why": "cif_not_found_in_index"}); continue
        lab = label_map.get(p)
        if lab is None:
            dropped.append({"file": p, "why": "no_label_found"}); continue
        resolved.append((p, len(lab["forces"])))

    curated = []
    with ProcessPoolExecutor(max_workers=workers) as ex:
        futs = {ex.submit(_check_one, pair): pair for pair in resolved}
        for fut in as_completed(futs):
            p, flen = futs[fut]
            try:
                _p, ok, n = fut.result()
                if ok: curated.append(p)
                else: dropped.append({"file": p, "why": f"forces_shape_mismatch:{flen}vs{n}"})
            except Exception as e:
                dropped.append({"file": p, "why": f"shape_check_error:{e}"})
    return curated, dropped


# --------------- SHARDED WRITER (Parallel) ---------------
def _config_type_from_name(path: str) -> str:
    pl = path.lower()
    if "mdcifs_it2" in pl: return "md_it2"
    if "prime" in pl:      return "strain_perturb_prime"
    if "perturbed" in pl:  return "strain_perturb"
    return "base"

def _write_chunk(args) -> Tuple[str, int]:
    split_name, chunk_id, paths, label_map, shard_dir = args
    shard_path = Path(shard_dir) / f"{split_name}.chunk{chunk_id:05d}.extxyz"
    frames = []
    for cif_path in paths:
        lab = label_map[cif_path]
        a = read(cif_path)
        a.info["energy"] = float(lab["energy_eV"])                   # eV
        a.arrays["forces"] = np.asarray(lab["forces"], np.float64)   # eV/Å
        a.info["file"] = os.path.basename(cif_path)
        a.info["src_path"] = cif_path
        a.info["label_source"] = lab["meta"].get("label_source")
        for k in ("temperature_K", "step", "uuid"):
            if k in lab["meta"]: a.info[k] = lab["meta"][k]
        a.info["config_type"] = _config_type_from_name(cif_path)
        frames.append(a)
    write(str(shard_path), frames, format="extxyz")
    return str(shard_path), len(frames)

def write_split_zero_skip_parallel(split_name: str, curated_paths: List[str], label_map: Dict[str, dict],
                                   out_dir: Path, workers: int, chunk_size: int) -> int:
    shard_dir = out_dir / "_shards" / split_name
    shard_dir.mkdir(parents=True, exist_ok=True)

    chunks = [curated_paths[i:i+chunk_size] for i in range(0, len(curated_paths), chunk_size)]
    print(f"[{split_name}] writing {len(curated_paths)} frames → {len(chunks)} chunks (chunk={chunk_size}, workers={workers})")

    total, shard_paths = 0, []
    with ProcessPoolExecutor(max_workers=workers) as ex:
        futs = {ex.submit(_write_chunk, (split_name, cid, ch, label_map, shard_dir)): cid
                for cid, ch in enumerate(chunks)}
        for fut in as_completed(futs):
            cid = futs[fut]
            try:
                sp, w = fut.result()
                shard_paths.append(sp); total += w
                if (cid+1) % max(1, math.ceil(len(chunks)/10)) == 0:
                    print(f"[{split_name}] progress: chunk {cid+1}/{len(chunks)}")
            except Exception as e:
                print(f"[{split_name}] ERROR chunk {cid}: {e}")

    final_path = out_dir / f"{split_name}.extxyz"
    if final_path.exists(): final_path.unlink()
    with open(final_path, "ab") as outf:
        for sp in sorted(shard_paths):
            try:
                with open(sp, "rb") as sf:
                    outf.write(sf.read())
            except Exception as e:
                print(f"[{split_name}] WARN merge failed for {sp}: {e}")
            try:
                Path(sp).unlink()
            except Exception:
                pass
    print(f"[{split_name}] DONE: wrote {total} frames → {final_path}")
    return total


# ===================== RUN EVERYTHING =====================
out_dir = Path(OUT_DIR); out_dir.mkdir(parents=True, exist_ok=True)

print(f"[INDEX] workers={INDEX_WORKERS}")
cif_index = build_cif_index_parallel(CIF_DIRS, workers=INDEX_WORKERS)
print(f"[INDEX] basenames={len(cif_index)} total_paths={sum(len(v) for v in cif_index.values())}")

label_files = _list_label_files(LABEL_INPUTS)
print(f"[LABEL] files={len(label_files)} workers={LABEL_WORKERS}")
label_map, label_issues = build_label_index_parallel(label_files, cif_index, workers=LABEL_WORKERS)

# Load split lists
with open(Path(SPLITS_DIR)/"T1.json")  as f: T1  = json.load(f)
with open(Path(SPLITS_DIR)/"Val.json") as f: VAL = json.load(f)
with open(Path(SPLITS_DIR)/"T2.json")  as f: T2  = json.load(f)
print(f"[SPLIT] T1={len(T1)} Val={len(VAL)} T2={len(T2)}")

# Curate (zero-skip guarantee)
print(f"[CURATE] workers={CURATE_WORKERS}")
T1_cur,  T1_drop  = curate_split_parallel(T1,  cif_index, label_map, workers=CURATE_WORKERS)
VAL_cur, VAL_drop = curate_split_parallel(VAL, cif_index, label_map, workers=CURATE_WORKERS)
T2_cur,  T2_drop  = curate_split_parallel(T2,  cif_index, label_map, workers=CURATE_WORKERS)
print(f"[CURATE] keep: T1={len(T1_cur)} Val={len(VAL_cur)} T2={len(T2_cur)} | drop_total={len(T1_drop)+len(VAL_drop)+len(T2_drop)}")

# Write EXTXYZ (parallel shards)
print(f"[WRITE] workers={WRITE_WORKERS} chunk={CHUNK_SIZE}")
n1 = write_split_zero_skip_parallel("T1",  T1_cur,  label_map, out_dir, workers=WRITE_WORKERS, chunk_size=CHUNK_SIZE)
nv = write_split_zero_skip_parallel("val", VAL_cur, label_map, out_dir, workers=WRITE_WORKERS, chunk_size=CHUNK_SIZE)
n2 = write_split_zero_skip_parallel("T2",  T2_cur,  label_map, out_dir, workers=WRITE_WORKERS, chunk_size=CHUNK_SIZE)

# Manifest + curation report
manifest = {
    "outputs": {"T1": str(out_dir/"T1.extxyz"), "val": str(out_dir/"val.extxyz"), "T2": str(out_dir/"T2.extxyz")},
    "written": {"T1": n1, "val": nv, "T2": n2},
    "workers": {"index": INDEX_WORKERS, "label": LABEL_WORKERS, "curate": CURATE_WORKERS, "write": WRITE_WORKERS},
    "chunk_size": CHUNK_SIZE,
}
(Path(OUT_DIR)/"manifest_zero_skip_parallel.json").write_text(json.dumps(manifest, indent=2))
(Path(OUT_DIR)/"curation_report_parallel.json").write_text(json.dumps({
    "label_issues": label_issues,
    "drops": {
        "T1": T1_drop[:200], "Val": VAL_drop[:200], "T2": T2_drop[:200],
        "note": "Only a 200-sample of drops kept for brevity."
    }
}, indent=2))

print("\n[OK] Wrote:")
print(" -", manifest["outputs"]["T1"])
print(" -", manifest["outputs"]["val"])
print(" -", manifest["outputs"]["T2"])
print("\n[train hint] Use: --energy_key energy --forces_key forces")


[INDEX] workers=16
[INDEX] basenames=23308 total_paths=23308
[LABEL] files=9 workers=16
[WARN] Skipping bad JSONL line 1 in /home/phanim/harshitrawat/summer/T1_T2_T3_data/ood_per_model_dynamic.json: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)
[WARN] Skipping bad JSONL line 2 in /home/phanim/harshitrawat/summer/T1_T2_T3_data/ood_per_model_dynamic.json: Extra data: line 1 column 8 (char 7)
[WARN] Skipping bad JSONL line 3 in /home/phanim/harshitrawat/summer/T1_T2_T3_data/ood_per_model_dynamic.json: Extra data: line 1 column 10 (char 9)
[WARN] Skipping bad JSONL line 1 in /home/phanim/harshitrawat/summer/T1_T2_T3_data/ood_ensemble_dynamic_k2.json: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)
[WARN] Skipping bad JSONL line 2 in /home/phanim/harshitrawat/summer/T1_T2_T3_data/ood_ensemble_dynamic_k2.json: Extra data: line 1 column 11 (char 10)
[WARN] Skipping bad JSONL line 3 in /home/phanim/harshitrawat/summer/T1_T2_T3_data/ood_

In [9]:
#!/usr/bin/env python3
import os
import subprocess
import sys

def main():
    # ——— Environment setup ———
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # Uncomment if you want to force your local clone
    os.environ["PYTHONPATH"] = "/home/phanim/harshitrawat/mace/mace"

    cmd = [
        "mace_run_train",
        "--name",              "mace_T1_singlehead_sanity",
        "--model",             "MACE",
        "--num_interactions",  "2",
        "--foundation_model",  "/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",

        "--train_file",        "/home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz",
        "--valid_file",        "/home/phanim/harshitrawat/summer/final_work_extxyz/val.extxyz",

        # ⚠️ Adjust these to the actual keys in your extxyz
        "--energy_key", "energy",
        "--forces_key", "forces",

        "--device",            "cuda",
        "--batch_size",        "2",
        "--valid_batch_size",  "1",
        "--valid_fraction",    "0.0",
    
        "--r_max",             "5.0",
        "--max_num_epochs",    "15",
        "--restart_latest",

        # Safer force/energy weighting
        "--forces_weight",     "30.0",
        "--energy_weight",     "1.0",

        "--lr",                "5e-3",
        "--weight_decay",      "1e-8",
        "--clip_grad",         "3",
        "--patience",          "7",

        "--E0s",               "average"
    ]

    print("Running:", " \\\n    ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    main()


Running: mace_run_train \
    --name \
    mace_T1_singlehead_sanity \
    --model \
    MACE \
    --num_interactions \
    2 \
    --foundation_model \
    /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
    --foundation_model_readout \
    --train_file \
    /home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/final_work_extxyz/val.extxyz \
    --energy_key \
    energy \
    --forces_key \
    forces \
    --device \
    cuda \
    --batch_size \
    2 \
    --valid_batch_size \
    1 \
    --valid_fraction \
    0.0 \
    --r_max \
    5.0 \
    --max_num_epochs \
    15 \
    --restart_latest \
    --forces_weight \
    30.0 \
    --energy_weight \
    1.0 \
    --lr \
    5e-3 \
    --weight_decay \
    1e-8 \
    --clip_grad \
    3 \
    --patience \
    7 \
    --E0s \
    average
  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 

2025-08-12 20:40:44.167 INFO: MACE version: 0.3.14
2025-08-12 20:40:44.798 INFO: CUDA version: 12.6, CUDA device: 0


  model_foundation = torch.load(


2025-08-12 20:40:46.420 INFO: Using foundation model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model as initial checkpoint.
2025-08-12 20:40:46.421 INFO: Using heads: ['Default']
2025-08-12 20:40:46.421 INFO: Using the key specifications to parse data:
2025-08-12 20:40:46.421 INFO: Default: KeySpecification(info_keys={'energy': 'energy', 'stress': 'REF_stress', 'virials': 'REF_virials', 'dipole': 'dipole', 'head': 'head', 'elec_temp': 'elec_temp', 'total_charge': 'total_charge', 'polarizability': 'polarizability', 'total_spin': 'total_spin'}, arrays_keys={'forces': 'forces', 'charges': 'REF_charges'})
2025-08-12 20:41:06.579 INFO: Training set 1/1 [energy: 5986, stress: 0, virials: 0, dipole components: 0, head: 5986, elec_temp: 0, total_charge: 0, polarizability: 0, total_spin: 0, forces: 5986, charges: 0]
2025-08-12 20:41:06.597 INFO: Total Training set [energy: 5986, stress: 0, virials: 0, dipole components: 0, head: 5986, elec_temp: 0, 

  torch.load(f=checkpoint_info.path, map_location=device),


2025-08-12 20:42:31.594 INFO: Total number of parameters: 894362
2025-08-12 20:42:31.595 INFO: 
2025-08-12 20:42:31.595 INFO: Using ADAM as parameter optimizer
2025-08-12 20:42:31.595 INFO: Batch size: 2
2025-08-12 20:42:31.595 INFO: Number of gradient updates: 44895
2025-08-12 20:42:31.595 INFO: Learning rate: 0.005, weight decay: 1e-08
2025-08-12 20:42:31.595 INFO: WeightedEnergyForcesLoss(energy_weight=1.000, forces_weight=30.000)
2025-08-12 20:42:31.647 INFO: Loading checkpoint: ./checkpoints/mace_T1_singlehead_sanity_run-123_epoch-4.pt
2025-08-12 20:42:31.688 INFO: Using gradient clipping with tolerance=3.000
2025-08-12 20:42:31.688 INFO: 
2025-08-12 20:42:31.688 INFO: Started training, reporting errors on validation set
2025-08-12 20:42:31.688 INFO: Loss metrics on validation set
2025-08-12 20:42:51.796 INFO: Initial: head: Default, loss=1015074.68241551, RMSE_E_per_atom=44005.39 meV, RMSE_F=145037.88 meV / A
2025-08-12 21:10:31.706 INFO: Epoch 4: head: Default, loss=6896951.4614

KeyboardInterrupt: 

In [12]:
#!/usr/bin/env python3
import os, subprocess, sys

def main():
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    os.environ["PYTHONPATH"] = "/home/phanim/harshitrawat/mace/mace"

    cmd = [
        "mace_run_train",
        "--name","mace_T1_singlehead_sanity",
        "--model","MACE",
        "--num_interactions","2",
        "--foundation_model","/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",

        "--train_file","/home/phanim/harshitrawat/summer/final_work_extxyz/T1_labeled.extxyz",
        "--valid_file","/home/phanim/harshitrawat/summer/final_work_extxyz/val_labeled.extxyz",
        "--energy_key","energy_eV",
        "--forces_key","forces_per_atom_eV_per_A",

        "--device","cuda",
        "--batch_size","2","--valid_batch_size","2","--valid_fraction","0.0",
        "--r_max","5.0","--max_num_epochs","","--restart_latest",

        "--forces_weight","30.0","--energy_weight","1.0",
        "--lr","3e-3","--weight_decay","1e-8","--clip_grad","3","--patience","3",
        "--ema_decay","0.0",

        "--E0s","average",
    ]
    print("Running:", " \\\n    ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    main()


Running: mace_run_train \
    --name \
    mace_T1_singlehead_sanity \
    --model \
    MACE \
    --num_interactions \
    2 \
    --foundation_model \
    /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
    --foundation_model_readout \
    --train_file \
    /home/phanim/harshitrawat/summer/final_work_extxyz/T1_labeled.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/final_work_extxyz/val_labeled.extxyz \
    --energy_key \
    energy_eV \
    --forces_key \
    forces_per_atom_eV_per_A \
    --device \
    cuda \
    --batch_size \
    2 \
    --valid_batch_size \
    2 \
    --valid_fraction \
    0.0 \
    --r_max \
    5.0 \
    --max_num_epochs \
    8 \
    --restart_latest \
    --forces_weight \
    30.0 \
    --energy_weight \
    1.0 \
    --lr \
    3e-3 \
    --weight_decay \
    1e-8 \
    --clip_grad \
    3 \
    --patience \
    3 \
    --ema_decay \
    0.0 \
    --E0s \
    average
  _Jd, _W3j_flat, _W

2025-08-12 23:02:18.200 INFO: MACE version: 0.3.14
2025-08-12 23:02:18.773 INFO: CUDA version: 12.6, CUDA device: 0


  model_foundation = torch.load(


2025-08-12 23:02:19.250 INFO: Using foundation model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model as initial checkpoint.
2025-08-12 23:02:19.256 INFO: Using heads: ['Default']
2025-08-12 23:02:19.256 INFO: Using the key specifications to parse data:
2025-08-12 23:02:19.256 INFO: Default: KeySpecification(info_keys={'energy': 'energy_eV', 'stress': 'REF_stress', 'virials': 'REF_virials', 'dipole': 'dipole', 'head': 'head', 'elec_temp': 'elec_temp', 'total_charge': 'total_charge', 'polarizability': 'polarizability', 'total_spin': 'total_spin'}, arrays_keys={'forces': 'forces_per_atom_eV_per_A', 'charges': 'REF_charges'})
2025-08-12 23:02:40.690 INFO: Training set 1/1 [energy: 5986, stress: 0, virials: 0, dipole components: 0, head: 5986, elec_temp: 0, total_charge: 0, polarizability: 0, total_spin: 0, forces: 5986, charges: 0]
2025-08-12 23:02:40.707 INFO: Total Training set [energy: 5986, stress: 0, virials: 0, dipole components: 0, head:

  torch.load(f=checkpoint_info.path, map_location=device),


2025-08-12 23:04:05.662 INFO: Total number of parameters: 894362
2025-08-12 23:04:05.663 INFO: 
2025-08-12 23:04:05.663 INFO: Using ADAM as parameter optimizer
2025-08-12 23:04:05.663 INFO: Batch size: 2
2025-08-12 23:04:05.663 INFO: Number of gradient updates: 23944
2025-08-12 23:04:05.663 INFO: Learning rate: 0.003, weight decay: 1e-08
2025-08-12 23:04:05.663 INFO: WeightedEnergyForcesLoss(energy_weight=1.000, forces_weight=30.000)
2025-08-12 23:04:05.686 INFO: Loading checkpoint: ./checkpoints/mace_T1_singlehead_sanity_run-123_epoch-6.pt
2025-08-12 23:04:05.725 INFO: Using gradient clipping with tolerance=3.000
2025-08-12 23:04:05.725 INFO: 
2025-08-12 23:04:05.725 INFO: Started training, reporting errors on validation set
2025-08-12 23:04:05.725 INFO: Loss metrics on validation set
2025-08-12 23:04:15.811 INFO: Initial: head: Default, loss=367457.72067584, RMSE_E_per_atom=42656.90 meV, RMSE_F=125061.38 meV / A
2025-08-12 23:31:45.944 INFO: Epoch 6: head: Default, loss=10280532.5591

In [1]:
from pymatgen.ext.matproj import MPRester
from pymatgen.core import Element
from pymatgen.analysis.phase_diagram import PhaseDiagram

# 🔑 Insert your Materials Project API key here
API_KEY = "j3J85pX4nLw6asHG9E2lbbCHEKDKgrjc"  # <-- paste your key as a string

ELEMENTS = ["Li", "La", "Zr", "O"]

with MPRester(API_KEY) as mpr:
    print("🔍 Fetching elemental entries...")
    entries = mpr.get_entries_in_chemsys(ELEMENTS, inc_structure="final")
    el_entries = [e for e in entries if len(e.composition) == 1]

    pd = PhaseDiagram(el_entries)
    stable = {}

    for el in ELEMENTS:
        try:
            entry = pd.get_stable_entry(Element(el))
            stable[el] = entry
        except:
            print(f"⚠️ Could not find stable entry for {el}")

    for el, entry in stable.items():
        structure = entry.structure
        filename = f"{el}.cif"
        structure.to(fmt="cif", filename=filename)
        print(f"✅ Saved {el}: {filename}")


🔍 Fetching elemental entries...
⚠️ Could not find stable entry for Li
⚠️ Could not find stable entry for La
⚠️ Could not find stable entry for Zr
⚠️ Could not find stable entry for O


In [2]:
from pymatgen.ext.matproj import MPRester
from pymatgen.core import Composition
import json

# Insert your Materials Project API key here
API_KEY = "j3J85pX4nLw6asHG9E2lbbCHEKDKgrjc"  # <--- Replace this with your key

ELEMENTS = ["Li", "La", "Zr", "O"]

with MPRester(API_KEY) as mpr:
    for el in ELEMENTS:
        print(f"🔍 Searching lowest-energy structure for {el}...")
        entries = mpr.get_entries(el, inc_structure=True)
        
        # Filter only elemental ones
        elemental_entries = [e for e in entries if Composition(el).reduced_formula == e.composition.reduced_formula]

        if not elemental_entries:
            print(f"❌ No elemental structure found for {el}")
            continue

        # Sort by energy per atom
        elemental_entries.sort(key=lambda e: e.energy_per_atom)
        best_entry = elemental_entries[0]
        structure = best_entry.structure
        filename = f"{el}.cif"
        structure.to(fmt="cif", filename=filename)
        print(f"✅ Saved {el}: {filename}")


🔍 Searching lowest-energy structure for Li...


  entries = mpr.get_entries(el, inc_structure=True)


✅ Saved Li: Li.cif
🔍 Searching lowest-energy structure for La...


  entries = mpr.get_entries(el, inc_structure=True)


✅ Saved La: La.cif
🔍 Searching lowest-energy structure for Zr...


  entries = mpr.get_entries(el, inc_structure=True)


✅ Saved Zr: Zr.cif
🔍 Searching lowest-energy structure for O...


  entries = mpr.get_entries(el, inc_structure=True)


✅ Saved O: O.cif
