In [6]:
import os
import sys
import pandas as pd
from mp_api.client import MPRester
from pymatgen.core import Structure


def enrich_dataframe(df: pd.DataFrame, mp_api_key: str | None = None, preview_rows: int = 3) -> pd.DataFrame:
    """
    Enrich df with MP-grounded symmetry/global metadata + per-site oxidation states.
    Prints a small preview of the enriched dataframe (in-memory). Does not write files.

    Expected df column: 'material_id' (e.g., 'mp-1519998')
    """
    mp_api_key = 'RmYSluuvUC3TJADjNCdvrN1AyifPkfog'
    if not mp_api_key:
        raise ValueError("MP_API_KEY not set. Export MP_API_KEY or pass mp_api_key=...")

    if "material_id" not in df.columns:
        raise ValueError("Input df must contain a 'material_id' column.")

    enriched_rows: list[dict] = []

    # helpful display for printing
    pd.set_option("display.max_columns", 200)
    pd.set_option("display.width", 200)

    with MPRester(mp_api_key) as mpr:
        for i, row in df.iterrows():
            mp_id = str(row["material_id"])

            try:
                docs = mpr.materials.summary.search(
                    material_ids=[mp_id],
                    fields=["structure", "symmetry", "nsites", "volume", "density"],
                )
                if not docs:
                    print(f"[warn] No summary doc for {mp_id} (row={i}). Skipping enrichment for this row.")
                    enriched_rows.append(dict(row))
                    continue

                doc = docs[0]
                structure: Structure = doc.structure

                # Oxidation states: prefer MP endpoint, fallback to pymatgen guess
                ox_structure = None
                try:
                    ox_docs = mpr.materials.oxidation_states.search(material_ids=[mp_id])
                    if ox_docs and getattr(ox_docs[0], "structure", None) is not None:
                        ox_structure = ox_docs[0].structure
                except Exception as e:
                    # don't fail row—just fallback
                    ox_structure = None

                if ox_structure is None:
                    ox_structure = structure.copy()
                    try:
                        ox_structure.add_oxidation_state_by_guess()
                    except Exception:
                        # keep undecorated if guess fails
                        ox_structure = structure

                ox_states = [getattr(site.specie, "oxi_state", None) for site in ox_structure.sites]

                enriched_row = dict(row)

                # GLOBAL FEATURES (guard against None)
                sym = getattr(doc, "symmetry", None)
                enriched_row["spacegroup_number"] = getattr(sym, "number", None)
                enriched_row["spacegroup_symbol"] = getattr(sym, "symbol", None)
                enriched_row["crystal_system"] = getattr(sym, "crystal_system", None)

                nsites = getattr(doc, "nsites", None)
                vol = getattr(doc, "volume", None)

                enriched_row["nsites_mp"] = nsites
                enriched_row["volume_mp"] = vol
                enriched_row["density_mp"] = getattr(doc, "density", None)
                enriched_row["volume_per_atom"] = (vol / nsites) if (vol is not None and nsites) else None

                # NODE FEATURES
                enriched_row["site_oxidation_states"] = ox_states
                enriched_row["n_ox_states"] = len(ox_states)  # quick sanity column

                enriched_rows.append(enriched_row)

                # lightweight progress prints so you know it’s doing something
                if (len(enriched_rows) <= preview_rows) or ((len(enriched_rows) % 50) == 0):
                    print(
                        f"[ok] {mp_id} sg={enriched_row['spacegroup_symbol']} "
                        f"#{enriched_row['spacegroup_number']} nsites={enriched_row['nsites_mp']}"
                    )
                    sys.stdout.flush()

            except Exception as e:
                print(f"[error] Failed on {mp_id} (row={i}): {repr(e)}")
                # keep original row so output df stays aligned with input
                enriched_rows.append(dict(row))
                sys.stdout.flush()

    df_enriched = pd.DataFrame(enriched_rows)

    # Print proof of enrichment (new columns + preview)
    new_cols = [c for c in df_enriched.columns if c not in df.columns]
    print("\n=== Enrichment complete ===")
    print("New columns:", new_cols)
    print("\n=== Enriched preview ===")
    cols_to_show = ["material_id"] + [c for c in new_cols if c in df_enriched.columns]
    print(df_enriched[cols_to_show].head(preview_rows).to_string(index=False))
    sys.stdout.flush()

    return df_enriched


In [7]:
MP_API_KEY = "RmYSluuvUC3TJADjNCdvrN1AyifPkfog"

df = pd.read_csv("./data/mp/test.csv")  # must have 'material_id'
df_enriched = enrich_dataframe(df.head(5))  # test small first


Retrieving SummaryDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving OxidationStateDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

[ok] mp-1519998 sg=F-43m #216 nsites=10


Retrieving SummaryDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving OxidationStateDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

[ok] mp-1038035 sg=P4/mmm #123 nsites=64


Retrieving SummaryDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving OxidationStateDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

[ok] mp-1206636 sg=Pm-3m #221 nsites=4


Retrieving SummaryDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving OxidationStateDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving SummaryDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving OxidationStateDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]


=== Enrichment complete ===
New columns: ['spacegroup_number', 'spacegroup_symbol', 'crystal_system', 'nsites_mp', 'volume_mp', 'density_mp', 'volume_per_atom', 'site_oxidation_states', 'n_ox_states']

=== Enriched preview ===
material_id  spacegroup_number spacegroup_symbol crystal_system  nsites_mp  volume_mp  density_mp  volume_per_atom                                                                                                                                                                                                                                                                                                                                                                            site_oxidation_states  n_ox_states
 mp-1519998                216             F-43m          Cubic         10 128.624052    5.471870        12.862405                                                                                                                                                 

In [1]:
import os
import sys
import ast
import numpy as np
import pandas as pd

from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.analysis.local_env import CrystalNN

# Optional: MP API (will be used only if installed + key present)
try:
    from mp_api.client import MPRester
    HAS_MPAPI = True
except Exception:
    HAS_MPAPI = False

# Optional: JARVIS tools (graceful)
try:
    from jarvis.core.atoms import Atoms
    from jarvis.analysis.structure.spacegroup import Spacegroup3D
    HAS_JARVIS = True
except Exception:
    HAS_JARVIS = False


def _parse_structure(obj) -> Structure:
    """
    Accepts:
      - pymatgen Structure
      - dict (pymatgen as_dict)
      - stringified dict
    Returns:
      - pymatgen Structure
    """
    if isinstance(obj, Structure):
        return obj
    if isinstance(obj, dict):
        return Structure.from_dict(obj)
    if isinstance(obj, str):
        s = obj.strip()
        # Handle strings that look like dicts (your example)
        try:
            d = ast.literal_eval(s)
            if isinstance(d, dict):
                return Structure.from_dict(d)
        except Exception:
            pass
    raise ValueError("Could not parse structure into pymatgen Structure.")


def _mp_fetch_summary_and_ox(mp_id: str, mp_api_key: str | None):
    """
    Returns (doc, ox_structure_or_none). If MP not available, returns (None, None).
    """
    if not (HAS_MPAPI and mp_api_key):
        return None, None

    with MPRester(mp_api_key) as mpr:
        # Summary metadata (symmetry etc.)
        docs = mpr.materials.summary.search(
            material_ids=[mp_id],
            fields=["structure", "symmetry", "nsites", "volume", "density"]
        )
        doc = docs[0] if docs else None

        # Oxidation states endpoint (best-effort)
        ox_struct = None
        try:
            ox_docs = mpr.materials.oxidation_states.search(material_ids=[mp_id])
            if ox_docs and getattr(ox_docs[0], "structure", None) is not None:
                ox_struct = ox_docs[0].structure
        except Exception:
            ox_struct = None

        return doc, ox_struct


def _coordination_and_bonds(struct: Structure, cutoff: float = 6.0):
    """
    Uses CrystalNN for coordination. Also computes neighbor bond length stats
    using a simple cutoff-based neighbor search (robust + fast).
    """
    cnn = CrystalNN()

    # Coordination numbers per site
    cn_list = []
    for i in range(len(struct)):
        try:
            nn = cnn.get_nn_info(struct, i)
            cn_list.append(len(nn))
        except Exception:
            cn_list.append(None)

    # Bond length stats from a cutoff neighbor list
    # get_neighbors returns neighbors within radius for each site
    all_dists = []
    edge_count = 0
    for i, site in enumerate(struct.sites):
        neigh = struct.get_neighbors(site, r=cutoff)
        edge_count += len(neigh)
        for n in neigh:
            all_dists.append(float(n.nn_distance))

    # edge_count is directed; undirected edges ~ edge_count/2 (approx)
    all_dists = np.array(all_dists, dtype=float) if len(all_dists) else np.array([], dtype=float)

    bond_stats = {
        "nn_cutoff": cutoff,
        "edge_count_dir": int(edge_count),
        "edge_count_undir_approx": float(edge_count) / 2.0,
        "bond_len_mean": float(np.mean(all_dists)) if len(all_dists) else None,
        "bond_len_std": float(np.std(all_dists)) if len(all_dists) else None,
        "bond_len_min": float(np.min(all_dists)) if len(all_dists) else None,
        "bond_len_max": float(np.max(all_dists)) if len(all_dists) else None,
    }

    # CN stats (ignore None)
    cn_vals = np.array([c for c in cn_list if c is not None], dtype=float) if any(c is not None for c in cn_list) else np.array([], dtype=float)
    cn_stats = {
        "site_coord_nums": cn_list,  # keep per-site list for node features
        "cn_mean": float(np.mean(cn_vals)) if len(cn_vals) else None,
        "cn_std": float(np.std(cn_vals)) if len(cn_vals) else None,
        "cn_min": float(np.min(cn_vals)) if len(cn_vals) else None,
        "cn_max": float(np.max(cn_vals)) if len(cn_vals) else None,
    }

    return cn_stats, bond_stats


def _symmetry_from_pymatgen(struct: Structure):
    sga = SpacegroupAnalyzer(struct, symprec=1e-3)
    return {
        "pmg_spacegroup_symbol": sga.get_space_group_symbol(),
        "pmg_spacegroup_number": sga.get_space_group_number(),
        "pmg_crystal_system": sga.get_crystal_system(),
        "pmg_point_group": sga.get_point_group_symbol(),
    }


def _oxidation_states(struct: Structure, mp_ox_struct: Structure | None):
    """
    Prefer MP oxidation structure if provided, else guess via pymatgen.
    Returns:
      - ox_states list aligned with sites
      - flag describing source
    """
    if mp_ox_struct is not None:
        ox_list = [getattr(s.specie, "oxi_state", None) for s in mp_ox_struct.sites]
        if any(v is not None for v in ox_list):
            return ox_list, "mp_api"

    # fallback guess
    s2 = struct.copy()
    try:
        s2.add_oxidation_state_by_guess()
        ox_list = [getattr(s.specie, "oxi_state", None) for s in s2.sites]
        return ox_list, "pymatgen_guess"
    except Exception:
        return [None] * len(struct), "none"


def _composition_features(struct: Structure):
    comp = struct.composition.fractional_composition
    # Light, stable composition descriptors (no target leakage)
    el_amt = comp.get_el_amt_dict()
    n_elements = len(el_amt)
    # fractions sum to 1
    fracs = np.array(list(el_amt.values()), dtype=float)
    return {
        "n_elements": int(n_elements),
        "elem_frac_dict": el_amt,  # keep as dict for downstream featurization
        "elem_frac_max": float(fracs.max()) if len(fracs) else None,
        "elem_frac_entropy": float(-(fracs * np.log(fracs + 1e-12)).sum()) if len(fracs) else None,
    }


def _jarvis_features_from_structure(struct: Structure):
    """
    Optional: JARVIS-derived checks (only if jarvis-tools installed).
    No property values; only symmetry and basic geometry.
    """
    if not HAS_JARVIS:
        return {"jarvis_available": False}

    try:
        atoms = Atoms.from_pymatgen(struct)
        spg = Spacegroup3D(atoms)
        return {
            "jarvis_available": True,
            "jarvis_spacegroup_symbol": spg.space_group_symbol,
            "jarvis_spacegroup_number": int(spg.space_group_number),
            "jarvis_crystal_system": spg.crystal_system,
            "jarvis_volume": float(atoms.volume),
            "jarvis_density": float(atoms.density),
        }
    except Exception:
        return {"jarvis_available": True, "jarvis_error": True}


def enrich_dataframe(df: pd.DataFrame, mp_api_key: str | None = None, nn_cutoff: float = 6.0, preview_rows: int = 3) -> pd.DataFrame:
    """
    Enrich df with:
      - MP (optional): summary symmetry + oxidation endpoint
      - Pymatgen-derived symmetry
      - Oxidation states (MP if possible, else guess)
      - Coordination numbers (CrystalNN)
      - Bond-length stats + simple graph complexity
      - Lattice/volume/density + composition descriptors
      - Optional JARVIS structural symmetry (if jarvis-tools installed)
    """
    if "material_id" not in df.columns:
        raise ValueError("df must contain 'material_id' column.")
    if "structure" not in df.columns:
        raise ValueError("df must contain 'structure' column (dict/str/pymatgen Structure).")

    mp_api_key = mp_api_key or os.getenv("MP_API_KEY", None)

    out_rows = []

    pd.set_option("display.max_columns", 250)
    pd.set_option("display.width", 220)

    for idx, row in df.iterrows():
        mp_id = str(row["material_id"])
        enriched = dict(row)

        # Parse structure from dict/string
        try:
            struct = _parse_structure(row["structure"])
        except Exception as e:
            print(f"[error] row={idx} {mp_id} structure parse failed: {e}")
            out_rows.append(enriched)
            continue

        # Optionally fetch MP doc + MP oxidation structure
        mp_doc, mp_ox_struct = _mp_fetch_summary_and_ox(mp_id, mp_api_key)

        # --- Symmetry: (A) MP if available + (B) Pymatgen always
        if mp_doc is not None and getattr(mp_doc, "symmetry", None) is not None:
            sym = mp_doc.symmetry
            enriched["mp_spacegroup_symbol"] = getattr(sym, "symbol", None)
            enriched["mp_spacegroup_number"] = getattr(sym, "number", None)
            enriched["mp_crystal_system"] = getattr(sym, "crystal_system", None)
        else:
            enriched["mp_spacegroup_symbol"] = None
            enriched["mp_spacegroup_number"] = None
            enriched["mp_crystal_system"] = None

        enriched.update(_symmetry_from_pymatgen(struct))

        # --- Oxidation states
        ox_list, ox_src = _oxidation_states(struct, mp_ox_struct)
        enriched["site_oxidation_states"] = ox_list
        enriched["oxidation_state_source"] = ox_src
        enriched["n_ox_known"] = int(sum(v is not None for v in ox_list))

        # --- Lattice / geometry (from structure; reliable)
        lat = struct.lattice
        enriched.update({
            "a": float(lat.a), "b": float(lat.b), "c": float(lat.c),
            "alpha": float(lat.alpha), "beta": float(lat.beta), "gamma": float(lat.gamma),
            "pmg_volume": float(struct.volume),
            "pmg_nsites": int(len(struct)),
            "pmg_volume_per_atom": float(struct.volume) / float(len(struct)),
            "pmg_density": float(struct.density),
        })

        # --- Coordination + bonds
        cn_stats, bond_stats = _coordination_and_bonds(struct, cutoff=nn_cutoff)
        enriched.update(cn_stats)
        enriched.update(bond_stats)

        # --- Composition descriptors
        enriched.update(_composition_features(struct))

        # --- Optional JARVIS structural features (no property values)
        enriched.update(_jarvis_features_from_structure(struct))

        out_rows.append(enriched)

        if (len(out_rows) <= preview_rows) or (len(out_rows) % 50 == 0):
            print(f"[ok] {mp_id} ox_src={ox_src} sg(pmg)={enriched['pmg_spacegroup_symbol']} cn_mean={enriched['cn_mean']}")
            sys.stdout.flush()

    df_enriched = pd.DataFrame(out_rows)

    # Print proof of enrichment
    new_cols = [c for c in df_enriched.columns if c not in df.columns]
    print("\n=== Enrichment complete ===")
    print(f"Added {len(new_cols)} columns.")
    print("New columns (sample):", new_cols[:25], "..." if len(new_cols) > 25 else "")
    cols_show = ["material_id"] + [c for c in ["oxidation_state_source", "n_ox_known",
                                              "pmg_spacegroup_symbol", "pmg_spacegroup_number",
                                              "cn_mean", "bond_len_mean", "pmg_density",
                                              "jarvis_available"] if c in df_enriched.columns]
    print("\n=== Preview ===")
    print(df_enriched[cols_show].head(preview_rows).to_string(index=False))
    sys.stdout.flush()

    return df_enriched


In [3]:
MP_API_KEY = "RmYSluuvUC3TJADjNCdvrN1AyifPkfog"

df = pd.read_csv("./data/mp/test.csv")  # must have 'material_id'
df_enriched = enrich_dataframe(df.head(5))  # test small first


[ok] mp-1519998 ox_src=pymatgen_guess sg(pmg)=F-43m cn_mean=4.8


  r1 = _get_radius(structure[n])
  r2 = _get_radius(entry["site"])
  nn_data = self.get_nn_data(structure, n)


[ok] mp-1038035 ox_src=pymatgen_guess sg(pmg)=P4/mmm cn_mean=6.0
[ok] mp-1206636 ox_src=pymatgen_guess sg(pmg)=P4/mmm cn_mean=6.0

=== Enrichment complete ===
Added 37 columns.
New columns (sample): ['mp_spacegroup_symbol', 'mp_spacegroup_number', 'mp_crystal_system', 'pmg_spacegroup_symbol', 'pmg_spacegroup_number', 'pmg_crystal_system', 'pmg_point_group', 'site_oxidation_states', 'oxidation_state_source', 'n_ox_known', 'a', 'b', 'c', 'alpha', 'beta', 'gamma', 'pmg_volume', 'pmg_nsites', 'pmg_volume_per_atom', 'pmg_density', 'site_coord_nums', 'cn_mean', 'cn_std', 'cn_min', 'cn_max'] ...

=== Preview ===
material_id oxidation_state_source  n_ox_known pmg_spacegroup_symbol  pmg_spacegroup_number  cn_mean  bond_len_mean  pmg_density  jarvis_available
 mp-1519998         pymatgen_guess          10                 F-43m                    216      4.8       4.392817     5.471870             False
 mp-1038035         pymatgen_guess          64                P4/mmm                    123  

In [1]:
#!/usr/bin/env python3
"""
Compute a strong, practical feature set for materials property prediction
from CIF strings using pymatgen + matminer, then append to a CSV.

Features included (high-signal, commonly used):
  - Composition: stoichiometry, element fractions, element property stats (Magpie),
                 valence orbital fractions, ion-property (optional)
  - Structure: density, volume/atom, packing fraction (if available),
               symmetry (space group, point group, crystal system)
  - Local environment: CrystalNN coordination number stats (mean/std/min/max)

Input CSV must contain a CIF string column (default: 'cif_structure').

Usage:
  python featurize_cif.py --input ./data/mp/test.csv --output ./data/mp/test_featurized.csv
  python featurize_cif.py --input ./data/mp/test.csv --cif-col cif_structure --id-col material_id --print-new-only

Dependencies:
  pip install pymatgen matminer pandas tqdm
"""

import argparse
import warnings
from typing import Any, Dict, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

from pymatgen.core import Structure, Composition
from pymatgen.io.cif import CifParser
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from matminer.featurizers.base import MultipleFeaturizer
from matminer.featurizers.composition import (
    Stoichiometry,
    ElementFraction,
    ElementProperty,
    ValenceOrbital,
    IonProperty,
)
from matminer.featurizers.structure import DensityFeatures, GlobalSymmetryFeatures
from matminer.featurizers.site import CoordinationNumber
from matminer.featurizers.structure import SiteStatsFingerprint


# -------------------------
# Parsing helpers
# -------------------------
def structure_from_cif_string(cif_str: str) -> Optional[Structure]:
    """Parse a pymatgen Structure from a CIF string. Returns None if parsing fails."""
    if not isinstance(cif_str, str) or not cif_str.strip():
        return None
    try:
        # CifParser.from_string works on recent pymatgen; fallback: CifParser(io.StringIO(...))
        parser = CifParser.from_string(cif_str)
        structs = parser.get_structures(primitive=False)
        return structs[0] if structs else None
    except Exception:
        return None


def try_add_oxidation_states(struct: Structure) -> Tuple[Optional[Structure], str, int]:
    """
    Attempt oxidation state decoration. Returns:
      (decorated_structure_or_None, source_label, n_ox_known)
    """
    if struct is None:
        return None, "none", 0

    # Strategy: use pymatgen's internal guessing (can fail).
    # Keep it conservative: if it fails, we fall back to None and mark source.
    try:
        s = struct.copy()
        s.add_oxidation_state_by_guess()
        # Count how many sites have oxidation states assigned (should be all if successful)
        n_ox = 0
        for site in s.sites:
            sp = site.specie
            # Specie has oxidation_state attribute when decorated
            if hasattr(sp, "oxi_state"):
                n_ox += 1
        return s, "pymatgen_guess", n_ox
    except Exception:
        return None, "failed", 0


# -------------------------
# Featurizers
# -------------------------
def build_composition_featurizer() -> MultipleFeaturizer:
    """
    Composition featurizer set.
    - Magpie element property statistics are very strong baselines.
    """
    return MultipleFeaturizer(
        [
            Stoichiometry(),  # e.g., num atoms, entropy-like measures
            ElementFraction(),  # fractional amounts of elements (sparse-ish but powerful)
            ElementProperty.from_preset("magpie"),  # mean/min/max/range/std for many elem props
            ValenceOrbital(props=["frac"]),  # valence orbital fractions (s,p,d,f)
            # IonProperty can be helpful but requires oxidation states; we compute it conditionally
            # via a separate pass using decorated composition.
        ]
    )


def build_structure_featurizer() -> MultipleFeaturizer:
    """
    Structure featurizer set:
    - Density/volume/packing
    - Symmetry features
    """
    return MultipleFeaturizer(
        [
            DensityFeatures(),  # density, vpa, packing fraction (if can be computed)
            GlobalSymmetryFeatures(),  # spacegroup, crystal system, point group, etc.
        ]
    )


def build_local_env_featurizer() -> MultipleFeaturizer:
    """
    Local environment:
    - CoordinationNumber via CrystalNN then aggregate across sites with SiteStatsFingerprint
    """
    cn = CoordinationNumber(nn="CrystalNN")
    return MultipleFeaturizer(
        [
            SiteStatsFingerprint(cn, stats=("mean", "std_dev", "minimum", "maximum")),
        ]
    )


def safe_featurize(featurizer: MultipleFeaturizer, obj: Any) -> Dict[str, Any]:
    """
    Robustly featurize an object and return a {feature_name: value} dict.
    If featurization fails, returns NaNs for those features.
    """
    out: Dict[str, Any] = {}
    try:
        labels = featurizer.feature_labels()
        vals = featurizer.featurize(obj)
        out = {k: v for k, v in zip(labels, vals)}
        return out
    except Exception:
        # Fill with NaN for all expected labels
        try:
            labels = featurizer.feature_labels()
            return {k: np.nan for k in labels}
        except Exception:
            return out


def ion_property_features_from_oxidized_composition(comp: Composition) -> Dict[str, Any]:
    """
    IonProperty features require oxidation states in the Composition.
    We'll compute them if possible.
    """
    ip = IonProperty()
    try:
        labels = ip.feature_labels()
        vals = ip.featurize(comp)
        return {k: v for k, v in zip(labels, vals)}
    except Exception:
        try:
            return {k: np.nan for k in ip.feature_labels()}
        except Exception:
            return {}


# -------------------------
# Main pipeline
# -------------------------
def featurize_dataframe(
    df: pd.DataFrame,
    cif_col: str = "cif_structure",
    id_col: Optional[str] = None,
    verbose: bool = True,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Returns:
      - full_df: original df with appended features
      - new_features_df: only the newly created feature columns (plus id_col if provided)
    """
    comp_feat = build_composition_featurizer()
    struct_feat = build_structure_featurizer()
    local_feat = build_local_env_featurizer()

    records: list[Dict[str, Any]] = []

    it = tqdm(df.itertuples(index=False), total=len(df), disable=not verbose)
    for row in it:
        row_dict = row._asdict()
        cif_str = row_dict.get(cif_col, None)

        rec: Dict[str, Any] = {}
        if id_col and id_col in row_dict:
            rec[id_col] = row_dict[id_col]

        s = structure_from_cif_string(cif_str)

        # Basic parse flags
        rec["pmg_parsed_ok"] = bool(s is not None)
        if s is None:
            # Fill all features as NaN for consistency
            rec.update({k: np.nan for k in comp_feat.feature_labels()})
            rec.update({k: np.nan for k in struct_feat.feature_labels()})
            rec.update({k: np.nan for k in local_feat.feature_labels()})
            # Oxidation state meta
            rec["oxidation_state_source"] = "none"
            rec["n_ox_known_sites"] = 0
            # IonProperty placeholders (unknown labels at runtime)
            rec.update(ion_property_features_from_oxidized_composition(Composition("H")))  # to get keys
            # Set placeholders to NaN (override)
            for k in list(rec.keys()):
                if k.startswith("IonProperty"):
                    rec[k] = np.nan
            records.append(rec)
            continue

        # Composition features
        comp = s.composition
        rec.update(safe_featurize(comp_feat, comp))

        # Structure features
        # Some symmetry features can throw if lattice is weird; handle safely
        rec.update(safe_featurize(struct_feat, s))

        # Local env features
        rec.update(safe_featurize(local_feat, s))

        # Oxidation states + ion-property features (conditional)
        s_ox, ox_src, n_ox = try_add_oxidation_states(s)
        rec["oxidation_state_source"] = ox_src
        rec["n_ox_known_sites"] = int(n_ox)

        # IonProperty features require oxidized composition; if oxidation failed, NaN them
        if s_ox is not None and ox_src != "failed":
            rec.update(ion_property_features_from_oxidized_composition(s_ox.composition))
        else:
            # Create keys once, then fill with NaN
            tmp = ion_property_features_from_oxidized_composition(Composition("Fe2 O3"))
            for k in tmp.keys():
                rec[k] = np.nan

        # Extra hand-crafted features (often useful)
        try:
            rec["volume_per_atom"] = float(s.volume) / float(len(s))
        except Exception:
            rec["volume_per_atom"] = np.nan

        # Space group number/symbol (explicit, very common)
        try:
            sga = SpacegroupAnalyzer(s, symprec=1e-3)
            rec["pmg_spacegroup_number"] = int(sga.get_space_group_number())
            rec["pmg_spacegroup_symbol"] = str(sga.get_space_group_symbol())
        except Exception:
            rec["pmg_spacegroup_number"] = np.nan
            rec["pmg_spacegroup_symbol"] = np.nan

        records.append(rec)

    feats = pd.DataFrame(records)

    # Split out "new features" columns (everything except original)
    if id_col and id_col in df.columns:
        base_cols = [id_col]
    else:
        base_cols = []

    new_feature_cols = [c for c in feats.columns if c not in base_cols]
    new_features_df = feats[base_cols + new_feature_cols].copy()

    # Merge back with original df
    if id_col and id_col in df.columns and id_col in feats.columns:
        full_df = df.merge(feats, on=id_col, how="left", suffixes=("", ""))
    else:
        # row-aligned concat
        full_df = pd.concat([df.reset_index(drop=True), feats.reset_index(drop=True)], axis=1)

    return full_df, new_features_df


def main():
    warnings.filterwarnings("ignore")  # pymatgen/matminer can be chatty

    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to input CSV")
    ap.add_argument("--output", default=None, help="Path to output CSV (optional)")
    ap.add_argument("--cif-col", default="cif_structure", help="Column name containing CIF string")
    ap.add_argument("--id-col", default=None, help="Optional ID column (e.g., material_id)")
    ap.add_argument("--print-new-only", action="store_true", help="Print only the new feature columns")
    ap.add_argument("--no-progress", action="store_true", help="Disable progress bar")
    args = ap.parse_args()

    df = pd.read_csv(args.input)
    if args.cif_col not in df.columns:
        raise ValueError(f"Missing CIF column '{args.cif_col}'. Available columns: {list(df.columns)}")

    full_df, new_features_df = featurize_dataframe(
        df,
        cif_col=args.cif_col,
        id_col=args.id_col,
        verbose=not args.no_progress,
    )

    # Print
    pd.set_option("display.max_columns", 200)
    pd.set_option("display.width", 220)
    print("\n=== New Features Preview ===")
    print(new_features_df.head(10) if args.print_new_only else full_df.head(10))

    # Save if requested
    if args.output:
        full_df.to_csv(args.output, index=False)
        print(f"\nSaved featurized CSV to: {args.output}")

    # Basic quality report
    print("\n=== Feature Coverage Report ===")
    parsed_ok = full_df["pmg_parsed_ok"].mean() if "pmg_parsed_ok" in full_df.columns else np.nan
    ox_ok = (full_df["oxidation_state_source"] == "pymatgen_guess").mean() if "oxidation_state_source" in full_df.columns else np.nan
    print(f"Parsed CIF OK: {parsed_ok:.3f}")
    print(f"Oxidation states guessed OK: {ox_ok:.3f}")


if __name__ == "__main__":
    main()


usage: ipykernel_launcher.py [-h] --input INPUT [--output OUTPUT]
                             [--cif-col CIF_COL] [--id-col ID_COL]
                             [--print-new-only] [--no-progress]
ipykernel_launcher.py: error: the following arguments are required: --input


SystemExit: 2