# In [1]:
#!/usr/bin/env python3
"""
compare_to_cod_pymatgen_rmsd_all.py

Rewritten validator that attempts an RMSD value for every readable generated structure (None when no COD structure can be matched).

- Compares generated CIFs in GEN_FOLDER against COD CIFs in COD_FOLDER
- RMSD is computed preferentially against COD structures with same reduced formula
- Keeps multi-tier matching (exact, near, similar, unique)
- Suppresses spglib noise
- Outputs CSV / JSON / human-readable report
"""

from pathlib import Path
import json
import math
import sys
from tqdm import tqdm
import numpy as np
import pandas as pd
import contextlib, io

from pymatgen.core import Structure, Composition
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

# =========================================================
#                  USER SETTINGS
# =========================================================

GEN_FOLDER = Path("/Users/becca/Documents/All_CIF_Outputs/Project 580/content/mattersim_results_relaxed")
COD_FOLDER = Path("/Users/becca/Documents/All_CIF_Outputs/Project 580/CaOSiTi")

# Geometric tolerances used by the COD pre-filter and the "near" classification:
#   LATTICE_TOL - max relative difference between sorted lattice lengths
#   ANGLE_TOL   - max lattice-angle difference (degrees)
#   VOLUME_TOL  - max relative cell-volume difference
LATTICE_TOL, ANGLE_TOL, VOLUME_TOL = 0.50, 12.0, 0.40

# Looser ("near") and stricter ("exact") structure matchers.
# Earlier, tighter settings without supercell search were:
#   near:  ltol=0.10, stol=0.12    exact: ltol=0.03, stol=0.30
SM_NEAR = StructureMatcher(scale=True, attempt_supercell=True, stol=0.5, ltol=0.2)
SM_EXACT = StructureMatcher(scale=True, attempt_supercell=True, stol=0.3, ltol=0.05)


# Output files, all written into GEN_FOLDER next to the generated CIFs.
OUT_CSV = GEN_FOLDER.joinpath("comparison_results.csv")
OUT_MATCHES = GEN_FOLDER.joinpath("matches_onlynew.csv")
OUT_UNIQUE = GEN_FOLDER.joinpath("unique_only.csv")
OUT_JSON = GEN_FOLDER.joinpath("matches.json")
OUT_TXT = GEN_FOLDER.joinpath("human_readable_reportnew.txt")

# =========================================================
#                  HELPER FUNCTIONS
# =========================================================

def integer_ratio(comp: "Composition", max_denominator: int = 100):
    """Return the composition as a dict of coprime integer element amounts.

    Each element amount is turned into an exact fraction (limited to
    ``max_denominator`` so float noise such as 1.9999999 or 0.33333 snaps to
    2 or 1/3). The fractions are scaled by the least common multiple of their
    denominators and then divided by the greatest common divisor. Both
    ``Ca2Ti2O6`` and a partially occupied ``Ca0.5Ti0.5O1.5`` therefore give
    ``{"Ca": 1, "Ti": 1, "O": 3}``.

    The previous ``int(round(v))`` approach lost fractional amounts. Python's
    round-half-to-even maps 0.5 to 0, which silently dropped elements.

    Args:
        comp: object exposing ``get_el_amt_dict()`` (e.g. a pymatgen
            Composition) mapping element symbols to amounts.
        max_denominator: largest denominator allowed when converting amounts
            to fractions. Amounts smaller than about ``1 / (2 * max_denominator)``
            snap to 0 and are dropped.

    Returns:
        Dict of element -> positive int, reduced to lowest terms. Returns ``{}``
        when no element has a positive amount.
    """
    from fractions import Fraction

    # Exact fractional amounts; zero/negative amounts are ignored as before.
    fracs = {}
    for el, amt in comp.get_el_amt_dict().items():
        frac = Fraction(amt).limit_denominator(max_denominator)
        if frac > 0:
            fracs[el] = frac
    if not fracs:
        return {}

    # Least common multiple of all denominators -> every scaled amount is an int.
    lcm = 1
    for frac in fracs.values():
        lcm = lcm * frac.denominator // math.gcd(lcm, frac.denominator)
    scaled = {el: int(frac * lcm) for el, frac in fracs.items()}

    # Reduce to lowest terms (all values are >= 1, so the gcd is >= 1).
    divisor = 0
    for amount in scaled.values():
        divisor = math.gcd(divisor, amount)
    return {el: amount // divisor for el, amount in scaled.items()}


def crystal_system_quiet(struct: Structure):
    """Return the crystal-system name of *struct*, or None if analysis fails.

    Runs SpacegroupAnalyzer with symprec=1e-2 while Python-level stdout/stderr
    are redirected into throw-away buffers, so chatter printed during the
    analysis does not reach the console. Any exception results in None.
    """
    discard_out = io.StringIO()
    discard_err = io.StringIO()
    try:
        with contextlib.redirect_stdout(discard_out), contextlib.redirect_stderr(discard_err):
            analyzer = SpacegroupAnalyzer(struct, symprec=1e-2)
            return analyzer.get_crystal_system()
    except Exception:
        return None


def lattice_relative_diff(lat1, lat2):
    """Return (max relative length diff, max angle diff) between two lattices.

    Both arguments are 6-sequences ``(a, b, c, alpha, beta, gamma)``.

    Lengths and angles are each compared after sorting. The result therefore
    does not depend on which axis a given setting labels a/b/c. Previously
    only the lengths were sorted, so the same lattice written with permuted
    axes (e.g. gamma=120 vs alpha=120) showed a spurious angle difference.

    This is a cheap, permutation-invariant pre-filter, not a rigorous lattice
    comparison. Angles are not paired with their lengths, and obtuse/acute
    equivalent settings are not folded together.

    Relative length differences are measured against ``lat1``'s lengths
    (floored at 1e-6 to avoid division by zero).
    """
    a1, b1, c1, al1, be1, ga1 = lat1
    a2, b2, c2, al2, be2, ga2 = lat2

    len1 = np.sort(np.array([a1, b1, c1], dtype=float))
    len2 = np.sort(np.array([a2, b2, c2], dtype=float))
    rels = np.abs(len1 - len2) / np.maximum(len1, 1e-6)

    ang1 = np.sort(np.array([al1, be1, ga1], dtype=float))
    ang2 = np.sort(np.array([al2, be2, ga2], dtype=float))
    angs = np.abs(ang1 - ang2)
    return float(rels.max()), float(angs.max())


def volume_relative_diff(s1, s2):
    """Return |V1 - V2| / V1 for the cell volumes of *s1* and *s2*.

    The reference volume V1 (from *s1*) is floored at 1e-6 so a degenerate
    cell cannot cause a division by zero.
    """
    ref_volume = s1.lattice.volume
    other_volume = s2.lattice.volume
    denominator = max(ref_volume, 1e-6)
    return float(abs(ref_volume - other_volume) / denominator)


def try_read_structure(path: Path):
    """Read a structure file with pymatgen, or return None if it cannot be read.

    Failures are still non-fatal, so callers can skip the file. They are now
    reported on stderr (path, exception type and message) instead of being
    silently swallowed, so unreadable inputs show up in the run log.
    """
    try:
        return Structure.from_file(str(path))
    except Exception as exc:  # broad on purpose: any parse failure means "unreadable"
        print(f"WARNING: could not read {path}: {type(exc).__name__}: {exc}", file=sys.stderr)
        return None


def compute_best_rmsd(struct, cod_structures, sm=None):
    """Return the smallest RMSD between *struct* and any reference structure.

    Args:
        struct: structure to compare.
        cod_structures: iterable of reference structures.
        sm: matcher exposing ``get_rms_dist(a, b)``. When None, a
            ``StructureMatcher(primitive_cell=False, scale=True)`` is used.

    Returns:
        The minimum RMSD as a float, or None if no reference produced one.
        References for which the matcher returns None, or raises, are
        skipped. A tuple result contributes its first element.
    """
    matcher = StructureMatcher(primitive_cell=False, scale=True) if sm is None else sm

    def _rms_or_none(reference):
        # Best-effort: any failure for this reference just means "no RMSD".
        try:
            raw = matcher.get_rms_dist(struct, reference)
            if raw is None:
                return None
            return float(raw[0] if isinstance(raw, tuple) else raw)
        except Exception:
            return None

    distances = [d for d in map(_rms_or_none, cod_structures) if d is not None]
    return min(distances) if distances else None

# TODO: replace the silent try/except blocks with print statements that report what failed.
# TODO: point group / space group (presumably: add them to the comparison output).

# =========================================================
#                LOAD COD DATABASE
# =========================================================

print("Loading COD structures...")

# cod_list: every readable COD entry. cod_by_formula: the same entries grouped
# by reduced formula, used for the same-formula RMSD search.
cod_list = []
cod_by_formula = {}
skipped_cod = []

# sorted() makes the load order reproducible across filesystems. That order
# decides ties between equally scored COD candidates later on, because the
# first candidate with the best score wins.
for f in sorted(COD_FOLDER.glob("*.cif")):
    st = try_read_structure(f)
    if st is None:
        skipped_cod.append(f.name)
        continue

    comp = st.composition
    lat = st.lattice

    entry = {
        "path": f,
        "structure": st,
        "composition": comp,
        "ratio": integer_ratio(comp),
        "lat": (lat.a, lat.b, lat.c, lat.alpha, lat.beta, lat.gamma),
        "crystalsys": crystal_system_quiet(st),
        "volume": lat.volume,
    }

    cod_list.append(entry)
    cod_by_formula.setdefault(comp.reduced_formula, []).append(entry)

print(f"Loaded {len(cod_list)} COD structures.")
if skipped_cod:
    print(f"Skipped {len(skipped_cod)} unreadable COD file(s): {', '.join(skipped_cod)}")
if not cod_list:
    print(
        f"WARNING: no COD structures loaded from {COD_FOLDER}; "
        "every readable generated structure will be reported as unique.",
        file=sys.stderr,
    )


# =========================================================
#            COMPARE EACH GENERATED STRUCTURE
# =========================================================

rows = []
matches_json = {}

# Direct lookup of COD entries by file name. The names come from one folder's
# glob, so they are unique. Used to fetch geometry / crystal system for
# whichever COD file ends up reported.
cod_by_name = {entry["path"].name: entry for entry in cod_list}

# Matcher for the "always report an RMSD" search, built once. The same object
# both computes each RMSD and picks the COD file achieving the minimum. The
# reported best_rmsd and best_rmsd_cod therefore always describe the same
# comparison. Previously the file was re-identified with SM_NEAR, a different
# matcher, by exact float equality.
SM_RMSD = StructureMatcher(
    ltol=0.3,
    stol=0.7,
    angle_tol=10,
    attempt_supercell=True,
    scale=True,
)

print("Comparing generated structures...")

for gen_path in tqdm(sorted(GEN_FOLDER.glob("*.cif"))):
    g = try_read_structure(gen_path)

    if g is None:
        rows.append({
            "gen_file": gen_path.name,
            "gen_formula": None,
            "best_cod": None,
            "match_level": "unreadable",
            "best_rmsd": None,
            "best_rmsd_cod": None,
            "lattice_diff_max_rel": None,
            "angle_diff_max_deg": None,
            "volume_diff_rel": None,
            "gen_crystalsys": None,
            "cod_crystalsys": None,
            "candidates": []
        })
        matches_json[gen_path.name] = {"status": "unreadable", "matches": []}
        continue

    # ---------------- BASIC GENERATED INFO ----------------
    gen_comp = g.composition
    gen_formula = gen_comp.reduced_formula
    gen_ratio = integer_ratio(gen_comp)
    gen_lat = (g.lattice.a, g.lattice.b, g.lattice.c,
               g.lattice.alpha, g.lattice.beta, g.lattice.gamma)
    gen_crystalsys = crystal_system_quiet(g)

    best_candidate = None
    best_score = float("inf")
    candidate_list = []

    # ---------------- FAST FILTER OVER COD ----------------
    for cod in cod_list:
        # strict stoichiometry; if you want only same elements, replace with:
        # if set(gen_ratio.keys()) != set(cod["ratio"].keys()): continue
        if gen_ratio != cod["ratio"]:
            continue

        # geometry diffs
        rel_len_diff, ang_diff = lattice_relative_diff(gen_lat, cod["lat"])
        vol_rel = volume_relative_diff(g, cod["structure"])

        if rel_len_diff > (LATTICE_TOL + 0.05):
            continue
        if ang_diff > (ANGLE_TOL + 1.0):
            continue
        if vol_rel > (VOLUME_TOL + 0.1):
            continue

        cod_crys = cod["crystalsys"]
        same_crys = (gen_crystalsys == cod_crys) or (gen_crystalsys is None or cod_crys is None)

        # structurematcher tests
        try:
            near_fit = SM_NEAR.fit(g, cod["structure"])
            exact_fit = SM_EXACT.fit(g, cod["structure"])
            if near_fit:
                try:
                    rms = SM_NEAR.get_rms_dist(g, cod["structure"])
                except Exception:
                    rms = None
            else:
                rms = None
        except Exception:
            near_fit = False
            exact_fit = False
            rms = None

        if rms is not None:
            rms = float(rms[0] if isinstance(rms, tuple) else rms)

        metrics = {
            "cod_file": cod["path"].name,
            "rel_len_diff": float(rel_len_diff),
            "max_angle_diff": float(ang_diff),
            "vol_rel": float(vol_rel),
            "same_crystalsys": bool(same_crys),
            "near_fit": bool(near_fit),
            "exact_fit": bool(exact_fit),
            "rmsd": rms
        }

        candidate_list.append(metrics)

        # scoring: exact fits rank first (by RMSD), then near fits (offset by
        # 10), then purely geometric similarity
        if exact_fit:
            score = 0.0 if rms is None else rms
        elif near_fit:
            score = 10.0 if rms is None else 10.0 + rms
        else:
            score = rel_len_diff + vol_rel

        # strict '<': the first candidate (in cod_list order) wins ties
        if score < best_score:
            best_score = score
            best_candidate = metrics

    # =========================================================
    #  ALWAYS COMPUTE BEST RMSD, EVEN IF NO fast-filter MATCHES
    # =========================================================

    # Prefer COD entries with the same reduced formula; otherwise try the
    # whole COD database.
    # NOTE(review): StructureMatcher presumably compares species by default, so
    # the full-database fallback is unlikely to match structures of another
    # composition (best_rmsd then stays None) - confirm this is intended.
    search_pool = cod_by_formula.get(gen_formula) or cod_list

    best_rmsd = None
    best_rmsd_cod = None
    for entry in search_pool:
        # Single-reference call: returns this entry's RMSD, or None.
        rv = compute_best_rmsd(g, [entry["structure"]], sm=SM_RMSD)
        if rv is not None and (best_rmsd is None or rv < best_rmsd):
            best_rmsd = rv
            best_rmsd_cod = entry["path"].name

    # =========================================================
    #  GEOMETRY / CRYSTAL SYSTEM OF THE REPORTED COD STRUCTURE
    # =========================================================

    # Taken from the fast-filter winner when there is one, otherwise from the
    # RMSD-matched COD structure. The crystal system comes from the same entry,
    # so these columns all describe the file reported as best_cod.
    if best_candidate:
        geom = {
            "rel_len_diff": best_candidate["rel_len_diff"],
            "max_angle_diff": best_candidate["max_angle_diff"],
            "vol_rel": best_candidate["vol_rel"],
            "cod_crystalsys": cod_by_name[best_candidate["cod_file"]]["crystalsys"],
        }
    elif best_rmsd_cod is not None:
        rmsd_entry = cod_by_name[best_rmsd_cod]
        rel_len, ang = lattice_relative_diff(gen_lat, rmsd_entry["lat"])
        geom = {
            "rel_len_diff": rel_len,
            "max_angle_diff": ang,
            "vol_rel": volume_relative_diff(g, rmsd_entry["structure"]),
            "cod_crystalsys": rmsd_entry["crystalsys"],
        }
    else:
        geom = {}

    # =========================================================
    # CLASSIFY MATCH LEVEL (exact / near / similar / unique)
    # =========================================================

    match_level = "unique"
    if best_candidate:
        if (best_candidate["exact_fit"] or
            (best_candidate["rmsd"] is not None and
             best_candidate["rmsd"] < 0.05 and
             best_candidate["rel_len_diff"] < 0.01 and
             best_candidate["vol_rel"] < 0.01)):
            match_level = "exact"
        elif (best_candidate["near_fit"] or
              (best_candidate["rel_len_diff"] <= LATTICE_TOL and
               best_candidate["max_angle_diff"] <= ANGLE_TOL and
               best_candidate["vol_rel"] <= VOLUME_TOL)):
            match_level = "near"
        else:
            match_level = "similar"

    # =========================================================
    # RECORD RESULTS
    # =========================================================

    rows.append({
        "gen_file": gen_path.name,
        "gen_formula": gen_formula,
        "best_cod": best_candidate["cod_file"] if best_candidate else best_rmsd_cod,
        "match_level": match_level,
        "best_rmsd": float(best_rmsd) if best_rmsd is not None else None,
        "best_rmsd_cod": best_rmsd_cod,
        "lattice_diff_max_rel": geom.get("rel_len_diff"),
        "angle_diff_max_deg": geom.get("max_angle_diff"),
        "volume_diff_rel": geom.get("vol_rel"),
        "gen_crystalsys": gen_crystalsys,
        "cod_crystalsys": geom.get("cod_crystalsys"),
        "candidates": candidate_list
    })

    matches_json[gen_path.name] = {
        "gen_formula": gen_formula,
        "gen_crystalsys": gen_crystalsys,
        "candidates": candidate_list,
        "best": best_candidate,
        "best_rmsd": float(best_rmsd) if best_rmsd is not None else None,
        "best_rmsd_cod": best_rmsd_cod,
        "match_level": match_level
    }



# =========================================================
#                   WRITE OUTPUTS
# =========================================================

# Column order of the per-structure table. Passing it explicitly keeps the
# DataFrame well-formed, so the match_level filters below keep working, even
# when no generated CIFs were found and `rows` is empty.
RESULT_COLUMNS = [
    "gen_file", "gen_formula", "best_cod", "match_level",
    "best_rmsd", "best_rmsd_cod", "lattice_diff_max_rel",
    "angle_diff_max_deg", "volume_diff_rel", "gen_crystalsys",
    "cod_crystalsys", "candidates",
]
# Match tiers written to OUT_MATCHES, with their rank for the report
# (lower = closer match).
MATCH_LEVEL_RANK = {"exact": 0, "near": 1, "similar": 2}

df = pd.DataFrame(rows, columns=RESULT_COLUMNS)
df.to_csv(OUT_CSV, index=False)

is_match = df["match_level"].isin(list(MATCH_LEVEL_RANK))
df[is_match].to_csv(OUT_MATCHES, index=False)
df[df["match_level"] == "unique"].to_csv(OUT_UNIQUE, index=False)

with open(OUT_JSON, "w", encoding="utf-8") as f:
    json.dump(matches_json, f, indent=2)

# Best matches first: by tier (exact, near, similar), then by RMSD. Rows
# without an RMSD go last. Previously the "top matches" were simply the first
# 20 files, including unique ones.
top_matches = (
    df[is_match]
    .assign(_rank=lambda d: d["match_level"].map(MATCH_LEVEL_RANK))
    .sort_values(["_rank", "best_rmsd"], na_position="last")
    .head(20)
)

with open(OUT_TXT, "w", encoding="utf-8") as fh:
    fh.write("COD comparison report\n")
    fh.write("=====================\n\n")
    fh.write(f"Generated folder: {GEN_FOLDER}\n")
    fh.write(f"COD folder: {COD_FOLDER}\n\n")
    fh.write(f"Parameters: LATTICE_TOL={LATTICE_TOL}, ANGLE_TOL={ANGLE_TOL}, VOLUME_TOL={VOLUME_TOL}\n\n")

    counts = df["match_level"].value_counts(dropna=False)
    fh.write("Summary counts:\n")
    for k, v in counts.items():
        fh.write(f"  {k}: {v}\n")

    fh.write("\nTop matches (up to 20; best match level first, then lowest RMSD):\n")
    for _, r in top_matches.iterrows():
        # best_rmsd is measured against best_rmsd_cod, which can differ from best_cod
        fh.write(
            f" - {r.gen_file} -> {r.best_cod} (level={r.match_level}, "
            f"rmsd={r.best_rmsd} vs {r.best_rmsd_cod})\n"
        )

print("\nDone.")
print(f"Wrote CSV:  {OUT_CSV}")
print(f"Wrote CSV:  {OUT_MATCHES}")
print(f"Wrote CSV:  {OUT_UNIQUE}")
print(f"Wrote JSON: {OUT_JSON}")
print(f"Wrote TXT:  {OUT_TXT}")


# --- Captured notebook output from the run above ---
# (The two "struct = parser.parse_structures(...)" lines appear to be source-line
#  echoes of pymatgen CIF-parser warnings; the warning messages were not captured.)
# Loading COD structures...
# Loaded 81 COD structures.
# Comparing generated structures...
#
#
#   struct = parser.parse_structures(primitive=primitive)[0]
#   struct = parser.parse_structures(primitive=primitive)[0]
# 100%|██████████| 600/600 [02:33<00:00,  3.90it/s]
#
#
# Done.
# Wrote CSV:  /Users/becca/Documents/All_CIF_Outputs/Project 580/content/mattersim_results_relaxed/comparison_results.csv
# Wrote CSV:  /Users/becca/Documents/All_CIF_Outputs/Project 580/content/mattersim_results_relaxed/matches_onlynew.csv
# Wrote CSV:  /Users/becca/Documents/All_CIF_Outputs/Project 580/content/mattersim_results_relaxed/unique_only.csv
# Wrote JSON: /Users/becca/Documents/All_CIF_Outputs/Project 580/content/mattersim_results_relaxed/matches.json
# Wrote TXT:  /Users/becca/Documents/All_CIF_Outputs/Project 580/content/mattersim_results_relaxed/human_readable_reportnew.txt



