In [1]:
#!/usr/bin/env python3
import ijson, random, sys
from ase import Atoms
from ase.io import write

# --- config ---
INPATH   = "/home/phanim/harshitrawat/MPtrj_2022.9_full.json"
OUTPATH  = "/home/phanim/harshitrawat/summer/replay_LiLaZrO_5k_le200.extxyz"
WANTED   = {"Li", "La", "Zr", "O"}
MAX_AT   = 200
TARGET_N = 5000
SEED     = 42

random.seed(SEED)

def valid_entry(entry):
    try:
        syms = entry["symbols"]
        uniq = set(syms)
        if not uniq.issubset(WANTED): return False
        if not (1 <= len(uniq) <= 4): return False
        if len(syms) > MAX_AT: return False
        if "positions" not in entry or "cell" not in entry or "energy" not in entry: return False
        return True
    except Exception:
        return False

def make_atoms(entry):
    ats = Atoms(
        symbols=entry["symbols"],
        positions=entry["positions"],
        cell=entry["cell"],
        pbc=True
    )
    ats.info["energy"] = float(entry["energy"])
    return ats

def main():
    reservoir = []
    kept = 0
    seen_valid = 0

    with open(INPATH, "rb") as f:
        # stream top-level dict: material_id → dict of entries
        for mid, entries in ijson.kvitems(f, "", multiple_values=True):
            for eid, entry in entries.items():
                if not valid_entry(entry):
                    continue

                seen_valid += 1
                if kept < TARGET_N:
                    try:
                        reservoir.append(make_atoms(entry))
                        kept += 1
                    except Exception:
                        seen_valid -= 1
                        continue
                else:
                    # reservoir sampling replacement
                    j = random.randrange(seen_valid)
                    if j < TARGET_N:
                        try:
                            reservoir[j] = make_atoms(entry)
                        except Exception:
                            seen_valid -= 1
                            continue

    write(OUTPATH, reservoir)
    print(f"[done] wrote {len(reservoir)} structures to {OUTPATH}")
    print(f"[stats] seen_valid={seen_valid}, kept={kept}")

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(130)


IncompleteJSONError: lexical error: invalid char in json text.
                                                            (right here) ------^


In [2]:
import ijson

f = "/home/phanim/harshitrawat/MPtrj_2022.9_full.json"
with open(f, "rb") as fd:
    for i, (k,v) in enumerate(ijson.kvitems(fd, "", multiple_values=True)):
        print(i, k, list(v.keys())[:5])
        if i==3: break


0 mp-1005792 ['mp-1012897-0-0', 'mp-1005792-0-1', 'mp-1005792-0-0', 'mp-1005792-1-1', 'mp-1005792-1-0']
1 mp-1006278 ['mp-1006287-0-0', 'mp-1006278-0-4', 'mp-1006278-0-3', 'mp-1006278-0-2', 'mp-1006278-0-1']
2 mp-10068 ['mp-910115-0-0', 'mp-10068-0-2', 'mp-10068-1-4', 'mp-10068-1-2', 'mp-10068-1-0']
3 mp-1007758 ['mp-1007758-0-0', 'mp-1007758-1-10', 'mp-1007758-1-9', 'mp-1007758-1-8', 'mp-1007758-1-6']


In [12]:
import json, time
from concurrent.futures import ProcessPoolExecutor, as_completed

# --- robust extractor for CHGNet MPtrj ---
def elements_from_structure(record):
    s = record.get("structure")
    if not isinstance(s, dict):
        return set()

    # 1) composition dict (common in pymatgen as_dict)
    comp = s.get("composition")
    if isinstance(comp, dict):
        # comp might be {"Li":7, "La":3, "Zr":2, "O":12} or {"@class":"Composition", ...}
        # Try dict of element->count first:
        keys = [k for k in comp.keys() if isinstance(k, str) and len(k) <= 3]
        if keys:
            return set(keys)
        # Sometimes nested like {"reduced_cell_composition": {"Li":7,...}}; try to find a flat dict of elements
        for v in comp.values():
            if isinstance(v, dict):
                el_keys = [k for k in v.keys() if isinstance(k, str) and len(k) <= 3]
                if el_keys:
                    return set(el_keys)

    # 2) sites[*].species[*].element (CHGNet docs)
    sites = s.get("sites")
    if isinstance(sites, list):
        out = set()
        for site in sites:
            species = site.get("species", [])
            # species can be [{"element":"Li","occu":1}] or [{"element":{"element":"Li"},...}]
            for sp in species:
                el = sp.get("element")
                if isinstance(el, dict) and "element" in el:
                    el = el["element"]
                if isinstance(el, str) and el:
                    out.add(el)
            # fallback: label field sometimes present ("Li", "O", ...)
            lab = site.get("label")
            if isinstance(lab, str) and len(lab) <= 3:
                out.add(lab)
        if out:
            return out

    # 3) last resort: try top-level shortcut if present
    els = record.get("elements")
    if isinstance(els, list) and els:
        return set(map(str, els))

    return set()

TARGET = {"Li","La","Zr","O"}
ALLOWED = {2,3,4}

def filter_chunk(in_path, out_path):
    kept = seen = 0
    with open(in_path, "r", encoding="utf-8", errors="ignore") as fi, \
         open(out_path, "w", encoding="utf-8") as fo:
        for line in fi:
            s = line.strip()
            if not s:
                continue
            try:
                d = json.loads(s)
            except Exception:
                continue
            seen += 1
            elems = elements_from_structure(d)
            if elems and (len(elems) in ALLOWED) and elems.issubset(TARGET):
                if not d.get("elements"):
                    d = dict(d); d["elements"] = sorted(elems)
                fo.write(json.dumps(d, separators=(",",":")) + "\n")
                kept += 1
    return in_path, seen, kept


In [15]:
# --- CONFIG ---
SRC = "/home/phanim/harshitrawat/MPtrj_2022.9_full.json"   # or .json.gz
WORKDIR = "/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb"
NDJSON = f"{WORKDIR}/MPtrj_base.ndjson"

import os, gzip, io, ijson, json, time
os.makedirs(WORKDIR, exist_ok=True)

def open_bin(path):
    return gzip.open(path, "rb") if path.endswith(".gz") else open(path, "rb")

def first_non_ws_byte(f):
    pos = f.tell()
    while True:
        chunk = f.read(65536)
        if not chunk:
            f.seek(pos); return None
        for b in chunk:
            if chr(b) not in " \t\r\n":
                f.seek(pos); return b

t0 = time.time()
with open_bin(SRC) as fb:
    b = first_non_ws_byte(fb)
    if b is None:
        raise RuntimeError("Source file is empty.")
    if chr(b) == '[':
        # Big JSON array -> stream to NDJSON
        print("Detected big JSON array. Streaming -> NDJSON ...")
        it = ijson.items(fb, "item")
        count = 0
        with open(NDJSON, "w", encoding="utf-8") as out:
            for obj in it:
                out.write(json.dumps(obj, separators=(",", ":")) + "\n")
                count += 1
                if count % 100000 == 0:
                    print(f"  wrote {count:,} objects...", flush=True)
        print(f"Done: wrote {count:,} objects to {NDJSON} in {time.time()-t0:.1f}s")
    else:
        # NDJSON already -> just copy as-is (strip blanks)
        print("Detected NDJSON. Copying -> normalized NDJSON ...")
        count = 0
        with io.TextIOWrapper(fb, encoding="utf-8", errors="ignore") as fi, \
             open(NDJSON, "w", encoding="utf-8") as out:
            for line in fi:
                s = line.strip()
                if not s: continue
                out.write(s + "\n")
                count += 1
        print(f"Done: copied {count:,} lines to {NDJSON} in {time.time()-t0:.1f}s")


Detected NDJSON. Copying -> normalized NDJSON ...
Done: copied 1 lines to /home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/MPtrj_base.ndjson in 48.2s


In [16]:
NUM_CHUNKS = 30  # set to number of CPU cores you want

chunk_paths = [f"{WORKDIR}/chunk_{i:02d}.ndjson" for i in range(NUM_CHUNKS)]
files = [open(p, "w", encoding="utf-8") for p in chunk_paths]

import time
t0 = time.time()
line_count = 0
with open(NDJSON, "r", encoding="utf-8", errors="ignore") as f:
    for i, line in enumerate(f):
        if not line.strip(): continue
        files[i % NUM_CHUNKS].write(line)
        line_count += 1

for fh in files: fh.close()
print(f"Split {line_count:,} lines into {NUM_CHUNKS} chunks in {time.time()-t0:.1f}s")
chunk_paths


Split 1 lines into 30 chunks in 52.7s


['/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_00.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_01.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_02.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_03.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_04.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_05.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_06.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_07.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_08.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_09.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_10.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_11.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_12.ndjson',
 '/home/phanim/harshitrawat/summer/mptrj_li_lazr_o_nb/chunk_13.n

In [17]:
import ijson, gzip, json, time

SRC = "/home/phanim/harshitrawat/MPtrj_2022.9_full.json"   # or .gz
NDJSON = f"{WORKDIR}/MPtrj_base.ndjson"

t0 = time.time()
count = 0
with open(SRC, "rb") as f, open(NDJSON, "w", encoding="utf-8") as out:
    for obj in ijson.items(f, "item"):
        out.write(json.dumps(obj, separators=(",",":")) + "\n")
        count += 1
        if count % 100000 == 0:
            print(f"wrote {count:,} objects...", flush=True)

print(f"Done: wrote {count:,} objects to {NDJSON} in {time.time()-t0:.1f}s")


IncompleteJSONError: lexical error: invalid char in json text.
                                                            (right here) ------^


In [None]:
#!/usr/bin/env python3
from pymatgen.core import Structure
import numpy as np, os, json

SRC = "/home/phanim/harshitrawat/summer/md/mdcifs_strained_perturbed_prime/cellrelaxed_LLZO_011_La_code71_sto__Li_100_slab_heavy_T300_0138_strain+0.015_perturbed.cif"
s = Structure.from_file(SRC)
lengths = np.array([s.lattice.a, s.lattice.b, s.lattice.c])
frac = s.frac_coords.copy()

occupied, vac = [], []
for ax in range(3):
    c = frac[:, ax]
    span_frac = max(1e-9, c.max() - c.min())
    span_cart = span_frac * lengths[ax]
    occupied.append(span_cart)
    vac.append(max(0.0, lengths[ax] - span_cart))
vac_axis = int(np.argmax(vac))

# center along vacuum axis
minf, maxf = frac[:, vac_axis].min(), frac[:, vac_axis].max()
shift = 0.5 - 0.5*(minf + maxf)
frac[:, vac_axis] = (frac[:, vac_axis] + shift) % 1.0

s_centered = s.copy()
s_centered.remove_sites(range(len(s_centered)))
for i, sp in enumerate(s.species):
    s_centered.append(sp, frac[i], coords_are_cartesian=False)

OUT = os.path.splitext(SRC)[0] + "_centered.cif"
s_centered.to(filename=OUT)
print(f"vac_axis={vac_axis} (0=a, 1=b, 2=c); wrote {OUT}")


vac_axis=2 (0=a, 1=b, 2=c); wrote /home/phanim/harshitrawat/summer/md/mdcifs_strained_perturbed_prime/cellrelaxed_LLZO_011_La_code71_sto__Li_100_slab_heavy_T300_0138_strain+0.015_perturbed_centered.cif


  struct = parser.parse_structures(primitive=primitive)[0]


In [6]:
# Robust centering for slabs split across periodic boundaries
# - Finds the largest fractional gap along each axis (vacuum)
# - Rotates fractional coords so slab is contiguous and centered at 0.5
# - Optional: add vacuum padding and disable PBC along vacuum axis
# Requires: pip install ase

from ase.io import read, write
import numpy as np, os
from typing import Tuple

AXMAP = {"a":0,"b":1,"c":2}

def largest_gap_shift(fracs_1d: np.ndarray) -> Tuple[float, float]:
    """
    Given fractional coords in [0,1), find the circular largest gap.
    Returns (gap_center, gap_size), where shifting all fracs by
    (-gap_center + 0.5) moves the gap center to 0.5, i.e., centers the slab.
    """
    f = np.sort(fracs_1d % 1.0)
    if len(f) == 0:
        return 0.5, 1.0
    # Gaps between consecutive points on the circle:
    diffs = np.diff(f, append=f[0] + 1.0)
    i = int(np.argmax(diffs))
    gap_size = float(diffs[i])
    # Gap runs from f[i] -> f[i]+gap_size. Its center:
    gap_center = (f[i] + gap_size/2.0) % 1.0
    return gap_center, gap_size

def detect_vac_axis_by_gap(atoms) -> Tuple[int, list]:
    """
    Vacuum thickness (Å) = largest fractional gap * axis length.
    Returns (vac_axis, [vacA_a, vacA_b, vacA_c]).
    """
    cell = atoms.get_cell()
    lengths = cell.lengths()
    frac = atoms.get_scaled_positions() % 1.0
    vacA = []
    for ax in range(3):
        _, gap = largest_gap_shift(frac[:, ax])
        vacA.append(float(gap * lengths[ax]))
    return int(np.argmax(vacA)), vacA

def rotate_and_center(atoms, ax: int):
    """
    Rotate fractional coords so the largest gap is centered at 0.5,
    making the slab contiguous and centered along axis ax.
    """
    frac = atoms.get_scaled_positions() % 1.0
    gap_center, _ = largest_gap_shift(frac[:, ax])
    # We want the gap center at 0.5 -> slab center at 0.0 and symmetric.
    # Shift = 0.5 - gap_center
    shift = (0.5 - gap_center) % 1.0
    frac[:, ax] = (frac[:, ax] + shift) % 1.0

    # Now center the slab block exactly at 0.5 (continuous segment midpoint)
    # Compute occupied segment [min,max] after rotation (no wrap now)
    lo, hi = float(frac[:, ax].min()), float(frac[:, ax].max())
    mid = 0.5 * (lo + hi)
    frac[:, ax] = (frac[:, ax] + (0.5 - mid)) % 1.0

    atoms.set_scaled_positions(frac)

def add_vacuum_padding(atoms, ax: int, pad_A: float):
    """Increase lattice length along axis ax by pad_A while preserving tilt."""
    if pad_A <= 0: 
        return
    cell = atoms.get_cell().array.copy()
    vec = cell[ax]
    L = np.linalg.norm(vec)
    if L > 1e-12:
        cell[ax] = vec * ((L + pad_A) / L)
        atoms.set_cell(cell, scale_atoms=False)

def center_slab_gap_method(
    infile: str,
    axis: str = "auto",      # "auto" or "a"/"b"/"c"
    pad_A: float = 0.0,      # extra vacuum to add (Å)
    nonperiodic: bool = True,
    basename: str | None = None,
    write_xyz: bool = True,
    write_lammps: bool = True,
):
    atoms = read(infile)
    stem = os.path.splitext(os.path.basename(infile))[0]
    base = basename or f"{stem}_centered"

    if axis == "auto":
        vac_ax, vacA = detect_vac_axis_by_gap(atoms)
    else:
        vac_ax = AXMAP[axis.lower()]
        _, vacA = detect_vac_axis_by_gap(atoms)  # still report

    # First rotate+center into a contiguous block, then optionally add padding.
    # (Order doesn’t matter for fractional rotation, but this keeps intuition clean.)
    rotate_and_center(atoms, vac_ax)
    add_vacuum_padding(atoms, vac_ax, pad_A)
    # Re-center again after padding to be super safe numerically:
    rotate_and_center(atoms, vac_ax)

    if nonperiodic:
        pbc = list(atoms.get_pbc())
        pbc[vac_ax] = False
        atoms.set_pbc(pbc)

    out_cif = f"{base}.cif"
    write(out_cif, atoms)
    out_xyz = out_lmp = None

    if write_xyz:
        out_xyz = f"{base}.xyz"
        write(out_xyz, atoms)   # XYZ has no PBC → Avogadro shows a single slab

    if write_lammps:
        out_lmp = f"{base}.data"
        write(out_lmp, atoms, format="lammps-data")  # respects PBC → gives p p f if nonperiodic

    return {
        "vacuum_axis_index": vac_ax,                 # 0=a,1=b,2=c
        "vacuum_estimates_A": {"a":vacA[0], "b":vacA[1], "c":vacA[2]},
        "outputs": {"cif": out_cif, "xyz": out_xyz, "lammps_data": out_lmp},
        "nonperiodic_applied": bool(nonperiodic),
        "extra_padding_A": float(pad_A),
    }


In [7]:
# === EDIT THIS ===
infile = "/home/phanim/harshitrawat/summer/md/mdcifs_strained_perturbed_prime/cellrelaxed_LLZO_011_La_code71_sto__Li_100_slab_heavy_T300_0138_strain+0.015_perturbed.cif"

rep = center_slab(
    infile,
    axis="auto",        # or "c" if you know vacuum is along c
    pad_A=10.0,         # add 10 Å more vacuum (set 0.0 to skip)
    nonperiodic=True,   # off-PBC along vacuum (XYZ, LAMMPS respect it)
    basename=None,      # custom base name if you want
)
rep


{'vacuum_axis_index': 2,
 'vacuum_estimates_A': {'a': 0.028919990786413052,
  'b': 0.03174803773705648,
  'c': 0.12013604529882116},
 'outputs': {'cif': 'cellrelaxed_LLZO_011_La_code71_sto__Li_100_slab_heavy_T300_0138_strain+0.015_perturbed_centered.cif',
  'xyz': 'cellrelaxed_LLZO_011_La_code71_sto__Li_100_slab_heavy_T300_0138_strain+0.015_perturbed_centered.xyz',
  'lammps_data': 'cellrelaxed_LLZO_011_La_code71_sto__Li_100_slab_heavy_T300_0138_strain+0.015_perturbed_centered.data'},
 'nonperiodic_applied': True,
 'extra_padding_A': 10.0}