In [2]:
import numpy as np
from pathlib import Path
import dpdata


def _resolve_deepmd_npy_root(path: str | Path) -> Path:
    """
    Resolve a user-provided path to the DeepMD-kit 'deepmd/npy' root directory.
    The root is expected to contain one or more 'set.XXX' directories (e.g., set.000).
    """
    p = Path(path).expanduser().resolve()
    if not p.exists():
        raise FileNotFoundError(f"Path does not exist: {p}")

    # Case 1: user points directly to deepmd/npy root (contains set.*)
    if any(child.is_dir() and child.name.startswith("set.") for child in p.iterdir()):
        return p

    # Case 2: user points to a parent; search for the first set.000 within a limited depth
    # (keeps it safe even if you point to a larger folder)
    max_depth = 4
    candidates = []
    for depth in range(1, max_depth + 1):
        candidates = list(p.glob("*/" * (depth - 1) + "set.000"))
        if candidates:
            break

    if not candidates:
        raise ValueError(
            f"Could not find a DeepMD 'set.000' under: {p}\n"
            "Please point to a directory that contains set.000 (deepmd/npy root),\n"
            "e.g. .../dpmd_npy_test_eq/ or .../test/T1000eq (if it is a symlink to that)."
        )

    return candidates[0].parent  # deepmd/npy root is the parent of set.000


def load_test_dataset(test_path: str | Path) -> dpdata.system.LabeledSystem:
    """
    Load a DeepMD-kit NumPy dataset (dpdata format 'deepmd/npy').

    This is the on-disk layout produced by, e.g.::

        labeled_sys.to("deepmd/npy", "./dpmd_npy_test")

    ``test_path`` may be the deepmd/npy root itself (the directory holding
    set.000, set.001, ...) or a directory above it; ``_resolve_deepmd_npy_root``
    locates the actual root, and any error it raises propagates unchanged.

    :param test_path: dataset root, or a parent directory of it.
    :returns: the loaded ``dpdata.LabeledSystem``.
    """
    deepmd_root = _resolve_deepmd_npy_root(test_path)
    # dpdata expects a plain string path for the system directory.
    return dpdata.LabeledSystem(str(deepmd_root), fmt="deepmd/npy")


def _describe(values: np.ndarray) -> str:
    """Format min/max/mean/std of ``values`` with 6 decimals for the dataset-level lines."""
    return (
        f"min={values.min():.6f}, max={values.max():.6f}, "
        f"mean={values.mean():.6f}, std={values.std():.6f}"
    )


def summarize_labeled_system(
    ds: dpdata.system.LabeledSystem,
    frame_interval: int = 10,
    print_head: bool = True,
    max_frames_for_table: int | None = None,
) -> None:
    """
    Print dataset statistics and a sampled per-frame table for a labeled system.

    Reports energy and per-atom force-norm statistics over the whole dataset,
    cell-volume statistics when cells are available, and the location of the
    single largest force. It then prints one row per sampled frame with that
    frame's energy and force-norm min/max/mean/std.

    :param ds: labeled system providing "energies", "forces", "atom_names"
        (and optionally "cells") via ``ds[key]``, plus ``get_atom_numbs()``.
    :param frame_interval: step between frames listed in the table; values
        below 1 are treated as 1.
    :param print_head: if True, print frame count and composition first.
    :param max_frames_for_table: maximum number of table rows; None means no
        limit, and 0 (or a negative value) prints no rows.
    """
    energies = np.asarray(ds["energies"])
    forces = np.asarray(ds["forces"])  # (nframes, natoms, 3)

    nframes = energies.size
    atom_names = list(ds["atom_names"])
    atom_numbs = list(map(int, ds.get_atom_numbs()))
    natoms = int(np.sum(atom_numbs))

    if print_head:
        comp = dict(zip(atom_names, atom_numbs))
        print("=== Dataset header ===")
        print(f"Frames: {nframes}")
        print(f"Atom names: {atom_names}")
        print(f"Atom counts: {comp}  (total atoms = {natoms})")
        print()

    if nframes == 0:
        # min()/max()/argmax() raise on empty arrays; report instead of crashing.
        print("Dataset contains no frames; nothing to summarize.")
        return

    # Dataset-level energy stats
    print("=== Dataset-level statistics ===")
    print(f"Energy (eV): {_describe(energies)}")

    # Dataset-level force-norm stats (norm over the xyz axis, vectorized)
    force_norm = np.linalg.norm(forces, axis=2)  # (nframes, natoms)
    print(f"Force |F| (eV/Å): {_describe(force_norm)}")

    # Optional: cell volume stats. Only a missing "cells" entry or a
    # non-square/ill-shaped cell array is tolerated; other errors surface.
    try:
        cells = np.asarray(ds["cells"])  # expected (nframes, 3, 3)
        vols = np.abs(np.linalg.det(cells))
    except (KeyError, ValueError, np.linalg.LinAlgError):
        pass
    else:
        print(f"Cell volume (Å^3): {_describe(vols)}")

    # Identify the single largest force in the dataset
    frame_idx, atom_idx = np.unravel_index(np.argmax(force_norm), force_norm.shape)
    print(
        f"Max |F| occurs at frame={int(frame_idx)}, atom={int(atom_idx)}, "
        f"|F|={force_norm[frame_idx, atom_idx]:.6f} (eV/Å)"
    )
    print()

    # Per-frame table: energy plus force-norm stats of each sampled frame
    print("=== Per-frame table (sampled) ===")
    print("FrameNo\t\tNRG(eV)\t\tF_Min\t\tF_Max\t\tF_Mean\t\tF_Std")

    rows = range(0, nframes, max(1, frame_interval))
    if max_frames_for_table is not None:
        # Slicing a range keeps it lazy; clamp so negatives mean "no rows"
        # rather than Python's "all but the last N".
        rows = rows[: max(0, max_frames_for_table)]

    for i in rows:
        fn = force_norm[i]
        print(
            f"{i:d}\t\t"
            f"{energies[i]:7.4f}\t"
            f"{fn.min():7.4f}\t\t"
            f"{fn.max():7.4f}\t\t"
            f"{fn.mean():7.4f}\t\t"
            f"{fn.std():7.4f}"
        )


# -----------------------------
# Example usage (edit this path)
# -----------------------------
# Set ``test_dir`` to any one of the test folders, for instance:
#   diffusion_fexny_vdw11/01_fe_mp13/test/T1000eq
#   diffusion_fexny_vdw11/03_fe_mp150/test/T1500eq
#   diffusion_fexny_vdw11/05_fe4n_mp535/test/T1000_Np2
# The folder may be the deepmd/npy root itself or a directory above it.
test_dir = "diffusion_fexny_vdw11/01_fe_mp13/test/T1000eq"

ds_test = load_test_dataset(test_dir)
summarize_labeled_system(
    ds_test,
    frame_interval=1,
    print_head=True,
    max_frames_for_table=30,
)


=== Dataset header ===
Frames: 40
Atom names: ['Fe']
Atom counts: {'Fe': 250}  (total atoms = 250)

=== Dataset-level statistics ===
Energy (eV): min=-2118.311837, max=-2089.113442, mean=-2103.262478, std=7.455450
Force |F| (eV/Å): min=0.041786, max=5.049533, mean=1.589553, std=0.714884
Cell volume (Å^3): min=2743.445206, max=2743.445206, mean=2743.445206, std=0.000000
Max |F| occurs at frame=11, atom=24, |F|=5.049533 (eV/Å)

=== Per-frame table (sampled) ===
FrameNo		NRG(eV)		F_Min		F_Max		F_Mean		F_Std
0		-2113.1658	 0.0993		 3.2927		 1.3589		 0.6233
1		-2107.2275	 0.2890		 4.0528		 1.5574		 0.6706
2		-2110.9208	 0.2153		 3.1439		 1.3347		 0.5702
3		-2104.5239	 0.2070		 3.9481		 1.5566		 0.7401
4		-2105.9487	 0.1070		 3.9792		 1.4025		 0.5878
5		-2101.1673	 0.1908		 3.4593		 1.5093		 0.6590
6		-2093.7641	 0.1597		 4.7662		 1.7487		 0.8188
7		-2094.0089	 0.3656		 4.2710		 1.7266		 0.7841
8		-2095.6990	 0.2723		 4.0793		 1.6605		 0.7417
9		-2096.9382	 0.1295		 3.9258		 1.6426		 0.7057
