In [5]:
from ase.io import iread
import numpy as np, json, os
from collections import Counter

T1 = "/home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz"
E0S_JSON = "/home/phanim/harshitrawat/summer/final_work_extxyz/e0s_from_T1.json"

def _get_energy(at):
    # Try common info keys; return None if not found
    info = at.info or {}
    for k in ("energy", "total_energy", "E", "dft_energy", "energy_eV"):
        if k in info:
            try:
                return float(info[k])
            except Exception:
                pass
    # last resort: if a calculator is attached (unlikely here)
    try:
        return float(at.get_potential_energy())
    except Exception:
        return None

# Pass 1: discover elements
Zs = set()
n_frames = 0
for at in iread(T1):
    n_frames += 1
    Zs.update(at.get_atomic_numbers().tolist())

Zs = sorted(Zs)  # e.g., [3, 8, 40, 57]
print(f"Found {len(Zs)} elements across {n_frames} frames: {Zs}")

# Pass 2: build counts and energies (skip frames with missing energy)
X_rows, y_vals = [], []
kept, skipped = 0, 0
for at in iread(T1):
    E = _get_energy(at)
    if E is None or not np.isfinite(E):
        skipped += 1
        continue
    nums = at.get_atomic_numbers()
    counts = [int((nums == Z).sum()) for Z in Zs]
    X_rows.append(counts)
    y_vals.append(E)
    kept += 1

X = np.asarray(X_rows, dtype=float)
y = np.asarray(y_vals, dtype=float)
print(f"Using {kept} frames (skipped {skipped} without a usable energy)")

if kept < len(Zs):
    raise RuntimeError("Not enough frames with energies to solve for all E0s—need at least one per element.")

# Least squares solve: X * E0s ≈ y
E0s_vec, *_ = np.linalg.lstsq(X, y, rcond=None)
E0s = {int(Z): float(v) for Z, v in zip(Zs, E0s_vec)}

# Save JSON (best for CLI): use string keys for safety
E0s_json = {str(k): v for k, v in E0s.items()}
with open(E0S_JSON, "w") as f:
    json.dump(E0s_json, f, indent=2)
print("Saved E0s to:", E0S_JSON)
print("E0s:", E0s)

# Also show brace format if you still want inline (be careful with quoting)
cli_inline = "{" + ",".join(f"{Z}:{E0s[Z]:.6f}" for Z in Zs) + "}"
print("Inline CLI (brace) form:", cli_inline)


Found 4 elements across 5986 frames: [3, 8, 40, 57]
Using 5986 frames (skipped 0 without a usable energy)
Saved E0s to: /home/phanim/harshitrawat/summer/final_work_extxyz/e0s_from_T1.json
E0s: {3: -1.1190404292132379, 8: 0.5759887109337836, 40: 5.377267508500145, 57: -69.06937595870446}
Inline CLI (brace) form: {3:-1.119040,8:0.575989,40:5.377268,57:-69.069376}


In [7]:
from ase.io import read
import numpy as np

def fit_e0s(files):
    Zs = sorted({Z for f in files for at in read(f, ":") for Z in at.get_atomic_numbers()})
    X = []
    y = []
    for f in files:
        for at in read(f, ":"):
            nums = at.get_atomic_numbers()
            counts = [np.sum(nums == Z) for Z in Zs]
            X.append(counts)
            y.append(at.info.get("energy", at.get_potential_energy()))
    X = np.atleast_2d(np.asarray(X, float))  # force 2D
    y = np.asarray(y, float)
    E0s_vec, *_ = np.linalg.lstsq(X, y, rcond=None)
    return {int(Z): float(v) for Z, v in zip(Zs, E0s_vec)}

T1   = "/home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz"
VAL  = "/home/phanim/harshitrawat/summer/final_work_extxyz/val.extxyz"
T2   = "/home/phanim/harshitrawat/summer/final_work_extxyz/T2.extxyz"

e0_T1   = fit_e0s([T1])
e0_T1T2 = fit_e0s([T1, T2])
e0_ALL  = fit_e0s([T1, VAL, T2])

print("E0s from T1 only:   ", e0_T1)
print("E0s from T1 + T2:   ", e0_T1T2)
print("E0s from ALL splits:", e0_ALL)


E0s from T1 only:    {3: -1.1190404292132379, 8: 0.5759887109337836, 40: 5.377267508500145, 57: -69.06937595870446}
E0s from T1 + T2:    {3: -1.2467890359561093, 8: -15.762976385313458, 40: 7.038908378067, 57: -3.0029005518499665}
E0s from ALL splits: {3: -1.2501344927589009, 8: -14.54274319010543, 40: 5.24387544506227, 57: -6.629770984694096}


In [8]:
from ase.io import read
a = read("/home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz", index=0)
print("INFO:", a.info.keys())      # should contain 'energy'
print("ARRAYS:", a.arrays.keys())  # should contain 'forces'


INFO: dict_keys(['spacegroup', 'unit_cell', 'occupancy', 'file', 'src_path', 'label_source', 'config_type'])
ARRAYS: dict_keys(['numbers', 'positions', 'spacegroup_kinds'])


In [9]:
# Quick peek: confirm what the first frame actually has in the file
path = "/home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz"
with open(path, "r") as f:
    for i in range(2):  # first two lines of first frame
        print(f.readline().rstrip())


648
Lattice="13.36304565 0.0 0.0 -0.024529949824237295 13.766710625906558 0.0 -0.15653770325558072 0.04068489902763997 78.56882197629467" Properties=species:S:1:pos:R:3:spacegroup_kinds:I:1:forces:R:3 spacegroup="P 1" unit_cell=conventional occupancy="_JSON {\"0\": {\"Li\": 1.0}, \"1\": {\"Li\": 1.0}, \"2\": {\"Li\": 1.0}, \"3\": {\"Li\": 1.0}, \"4\": {\"Li\": 1.0}, \"5\": {\"Li\": 1.0}, \"6\": {\"Li\": 1.0}, \"7\": {\"Li\": 1.0}, \"8\": {\"Li\": 1.0}, \"9\": {\"Li\": 1.0}, \"10\": {\"Li\": 1.0}, \"11\": {\"Li\": 1.0}, \"12\": {\"Li\": 1.0}, \"13\": {\"Li\": 1.0}, \"14\": {\"Li\": 1.0}, \"15\": {\"Li\": 1.0}, \"16\": {\"Li\": 1.0}, \"17\": {\"Li\": 1.0}, \"18\": {\"Li\": 1.0}, \"19\": {\"Li\": 1.0}, \"20\": {\"Li\": 1.0}, \"21\": {\"Li\": 1.0}, \"22\": {\"Li\": 1.0}, \"23\": {\"Li\": 1.0}, \"24\": {\"Li\": 1.0}, \"25\": {\"Li\": 1.0}, \"26\": {\"Li\": 1.0}, \"27\": {\"Li\": 1.0}, \"28\": {\"Li\": 1.0}, \"29\": {\"Li\": 1.0}, \"30\": {\"Li\": 1.0}, \"31\": {\"Li\": 1.0}, \"32\": {\"Li\"

In [11]:
from ase.io import read

path = "/home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz"
atoms_list = list(read(path, index=":"))

n = len(atoms_list)
has_E  = [i for i,a in enumerate(atoms_list) if "energy" in a.info]
has_F  = [i for i,a in enumerate(atoms_list) if "forces" in a.arrays]
print(f"Frames: {n}")
print(f"with energy: {len(has_E)}")
print(f"with forces: {len(has_F)}")
print("example frame keys:")
for k in [0, n//2, n-1]:
    a = atoms_list[k]
    print(k, "INFO:", a.info.keys(), "ARRAYS:", a.arrays.keys())


Frames: 5986
with energy: 0
with forces: 0
example frame keys:
0 INFO: dict_keys(['spacegroup', 'unit_cell', 'occupancy', 'file', 'src_path', 'label_source', 'config_type']) ARRAYS: dict_keys(['numbers', 'positions', 'spacegroup_kinds'])
2993 INFO: dict_keys(['spacegroup', 'unit_cell', 'occupancy', 'file', 'src_path', 'label_source', 'config_type']) ARRAYS: dict_keys(['numbers', 'positions', 'spacegroup_kinds'])
5985 INFO: dict_keys(['spacegroup', 'unit_cell', 'occupancy', 'file', 'src_path', 'label_source', 'config_type']) ARRAYS: dict_keys(['numbers', 'positions', 'spacegroup_kinds'])


In [3]:
from ase.io import iread
import numpy as np

path = "/home/phanim/harshitrawat/summer/final_work_extxyz/T1.extxyz"

def show(a, i):
    info, arrs = a.info, a.arrays
    print(f"\n--- Frame {i} ---")
    print("info keys:", sorted(info.keys()))
    print("array keys:", sorted(arrs.keys()))
    # try any key containing 'energy' (case-insensitive)
    ekeys = [k for k in info if 'energy' in k.lower()]
    fkeys = [k for k in arrs if 'force' in k.lower()]
    print("candidate energy keys:", ekeys)
    print("candidate force keys:", fkeys)
    if ekeys:
        E = float(info[ekeys[0]])
        print("E(raw) =", E, " per-atom=", E/len(a))
    if fkeys:
        F = arrs[fkeys[0]]
        print("F shape:", F.shape, " |F|_mean:", np.linalg.norm(F,axis=1).mean())

for i, a in enumerate(iread(path, index=":")):
    show(a, i)
    if i == 2: break



--- Frame 0 ---
info keys: ['config_type', 'file', 'label_source', 'occupancy', 'spacegroup', 'src_path', 'unit_cell']
array keys: ['numbers', 'positions', 'spacegroup_kinds']
candidate energy keys: []
candidate force keys: []

--- Frame 1 ---
info keys: ['config_type', 'file', 'label_source', 'occupancy', 'spacegroup', 'src_path', 'unit_cell']
array keys: ['numbers', 'positions', 'spacegroup_kinds']
candidate energy keys: []
candidate force keys: []

--- Frame 2 ---
info keys: ['config_type', 'file', 'label_source', 'occupancy', 'spacegroup', 'src_path', 'unit_cell']
array keys: ['numbers', 'positions', 'spacegroup_kinds']
candidate energy keys: []
candidate force keys: []


In [10]:
import json, gzip
from pathlib import Path

# === Pick any one label file to inspect ===
label_file = Path("/home/phanim/harshitrawat/summer/T1_T2_T3_data/mdinfo_chgnet_predictions_forces.json")  

# Auto-handle .gz or plain JSON/JSONL
opener = gzip.open if label_file.suffix == ".gz" else open

with opener(label_file, "rt") as f:
    first_lines = []
    for _ in range(5):  # just preview 5 entries
        line = f.readline()
        if not line.strip():
            continue
        try:
            rec = json.loads(line)  # JSONL style
        except json.JSONDecodeError:
            f.seek(0)
            recs = json.load(f)  # JSON array style
            rec = recs[0]
        first_lines.append(rec)

# Print the keys in the first few entries
for i, rec in enumerate(first_lines):
    print(f"\n--- Record {i} ---")
    for k, v in rec.items():
        if isinstance(v, (list, dict)):
            print(f"{k}: type={type(v)}, len={len(v) if hasattr(v,'__len__') else '?'}")
        else:
            print(f"{k}: {v}")



--- Record 0 ---
file: cellrelaxed_LLZO_001_Zr_code93_sto__Li_100_slab_heavy_T300_0000.cif
energy_eV: -2787.9942054748535
forces_per_atom_eV_per_A: type=<class 'list'>, len=648
stress_tensor: type=<class 'list'>, len=3
magmom_total: None


In [1]:
#!/usr/bin/env python3
import os, json, gzip, math, sys, random
from pathlib import Path
from typing import Dict, List, Tuple, Iterable
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
from ase.io import read, iread, write

# ===== CONFIG =====
BASE = "/home/phanim/harshitrawat/summer"
IN_TRAIN = f"{BASE}/final_work_extxyz/T1.extxyz"      # unlabeled or mislabeled extxyz to take frame order/meta from
IN_VALID = f"{BASE}/final_work_extxyz/val.extxyz"
OUT_TRAIN = f"{BASE}/final_work_extxyz/T1_labeled.extxyz"
OUT_VALID = f"{BASE}/final_work_extxyz/val_labeled.extxyz"

LABEL_DIRS = [
    f"{BASE}/T1_T2_T3_data",
    f"{BASE}/md/labels_fresh_it2",
]

# shards
OUT_DIR   = Path(f"{BASE}/final_work_extxyz")
SHARD_DIR = OUT_DIR / "_shards_fast"
SHARD_SIZE = 2000  # frames per shard
MAX_WORKERS = max(1, (os.cpu_count() or 8))
JSON_WORKERS = min(16, MAX_WORKERS)
WRITE_WORKERS = max(1, MAX_WORKERS)

random.seed(0)
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

# ===== Helpers =====
def _read_json_like(path: Path) -> Iterable[dict]:
    opener = gzip.open if path.suffix == ".gz" else open
    try:
        with opener(path, "rt") as f:
            head = f.read(2048)
            if not head.strip():
                return
            f.seek(0)
            if head.lstrip().startswith("["):
                arr = json.load(f)
                for r in arr:
                    if isinstance(r, dict): yield r
            else:
                for line in f:
                    s = line.strip()
                    if not s: continue
                    try:
                        r = json.loads(s)
                        if isinstance(r, dict): yield r
                    except json.JSONDecodeError:
                        continue
    except Exception:
        return

def _list_label_files() -> List[Path]:
    ALLOWED = (".json", ".jsonl", ".json.gz", ".jsonl.gz")
    files = []
    for d in LABEL_DIRS:
        p = Path(d)
        if not p.exists(): continue
        if p.is_file() and p.name.endswith(ALLOWED):
            files.append(p)
        elif p.is_dir():
            files += [q for q in p.rglob("*") if q.is_file() and q.name.endswith(ALLOWED)]
    return sorted(set(files))

def _ingest_one(fp: Path) -> Dict[str, Tuple[float, np.ndarray]]:
    m: Dict[str, Tuple[float, np.ndarray]] = {}
    for rec in _read_json_like(fp) or []:
        nm = rec.get("snapshot_file") or rec.get("file") or rec.get("src_path") or rec.get("path")
        if not nm: continue
        base = os.path.basename(nm)
        if "energy_eV" not in rec: continue
        F = rec.get("forces_per_atom_eV_per_A", rec.get("forces"))
        if F is None: continue
        F = np.asarray(F, float)
        if F.ndim == 1:  # flat 3N
            if F.size % 3 != 0: continue
            F = F.reshape(-1, 3)
        elif F.ndim != 2 or F.shape[1] != 3:
            continue
        m[base] = (float(rec["energy_eV"]), F)
    return m

def build_label_map_parallel() -> Dict[str, Tuple[float, np.ndarray]]:
    files = _list_label_files()
    if not files:
        print("[fatal] no label files found", file=sys.stderr); sys.exit(2)
    out: Dict[str, Tuple[float, np.ndarray]] = {}
    with ProcessPoolExecutor(max_workers=JSON_WORKERS) as ex:
        futs = {ex.submit(_ingest_one, fp): fp for fp in files}
        for fut in as_completed(futs):
            sub = fut.result()
            # later files override earlier deterministically due to sorted(files)
            out.update(sub)
    print(f"[labels] loaded {len(out)} unique basenames from {len(files)} files")
    return out

def _chunk(lst, n):
    for i in range(0, len(lst), n):
        yield i//n, lst[i:i+n]

# ===== Relabel core =====
def _read_extxyz_meta(path: str):
    """Return list of (info, arrays minimal) to extract 'file'/'src_path' and natoms per frame."""
    metas = []
    for a in iread(path, index=":"):
        info = dict(a.info)
        key = os.path.basename(info.get("file", info.get("src_path","")))
        metas.append((key, len(a), a))
    return metas

def _write_shard(args):
    (split_name, shard_id, metas, label_map, shard_dir) = args
    shard_path = Path(shard_dir) / f"{split_name}.shard{shard_id:05d}.extxyz"
    frames = []
    kept = 0
    for key, nat, a in metas:
        lab = label_map.get(key)
        if lab is None: continue
        E, F = lab
        # shape check
        if F.shape != (nat, 3):
            continue
        # write canonical keys
        info = dict(a.info)
        info["energy_eV"] = float(E)
        a.info = info
        arrs = dict(a.arrays)
        arrs["forces_per_atom_eV_per_A"] = F
        arrs.pop("forces", None)
        a.arrays = arrs
        frames.append(a)
        kept += 1
    if not frames:
        return str(shard_path), 0
    write(str(shard_path), frames, format="extxyz")
    # verify by reading one frame back
    try:
        b0 = next(iread(str(shard_path), index="0"))
        assert "energy_eV" in b0.info
        assert "forces_per_atom_eV_per_A" in b0.arrays
        assert b0.arrays["forces_per_atom_eV_per_A"].shape[1] == 3
    except Exception as e:
        return str(shard_path), -1  # signal verification failure
    return str(shard_path), kept

def relabel_one_split(split_name: str, in_path: str, out_path: str, label_map: Dict[str, Tuple[float, np.ndarray]]):
    metas = _read_extxyz_meta(in_path)
    if not metas:
        raise RuntimeError(f"No frames in {in_path}")
    print(f"[{split_name}] frames_in={len(metas)}; shard_size={SHARD_SIZE}; workers={WRITE_WORKERS}")
    SHARD_DIR.mkdir(parents=True, exist_ok=True)

    total_written = 0
    shard_paths = []
    with ProcessPoolExecutor(max_workers=WRITE_WORKERS) as ex:
        futs = {ex.submit(_write_shard, (split_name, sid, chunk, label_map, SHARD_DIR)): sid
                for sid, chunk in _chunk(metas, SHARD_SIZE)}
        for fut in as_completed(futs):
            sid = futs[fut]
            path, kept = fut.result()
            if kept == -1:
                print(f"[{split_name}] verify FAIL on shard {sid}: {path}")
                continue
            if kept > 0:
                shard_paths.append(path); total_written += kept

    if not shard_paths:
        raise RuntimeError(f"[{split_name}] no valid shards written (check labels/keys)")

    # merge shards via ASE to avoid header mismatches
    images = []
    for sp in sorted(shard_paths):
        images.extend(list(iread(sp, index=":")))
    write(out_path, images, format="extxyz")
    # final verify
    b0 = next(iread(out_path, index="0"))
    assert "energy_eV" in b0.info and "forces_per_atom_eV_per_A" in b0.arrays
    assert b0.arrays["forces_per_atom_eV_per_A"].shape[1] == 3
    print(f"[{split_name}] DONE: wrote {total_written} frames → {out_path}")

def main():
    label_map = build_label_map_parallel()
    relabel_one_split("T1",  IN_TRAIN, OUT_TRAIN, label_map)
    relabel_one_split("val", IN_VALID, OUT_VALID, label_map)

if __name__ == "__main__":
    main()


[labels] loaded 11654 unique basenames from 9 files
[T1] frames_in=5986; shard_size=2000; workers=224
[T1] DONE: wrote 5986 frames → /home/phanim/harshitrawat/summer/final_work_extxyz/T1_labeled.extxyz
[val] frames_in=640; shard_size=2000; workers=224
[val] DONE: wrote 640 frames → /home/phanim/harshitrawat/summer/final_work_extxyz/val_labeled.extxyz
