# Target Transformation

Test reference mineral spectra against the PCA model to determine which phases
are consistent with the dataset. References that are reconstructed well from
the PCA components are candidates for phases present in the sample; poor
reconstruction suggests that the phase is not represented in the dataset.
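Here is a minimal synthetic sketch of the idea. It assumes Gaussian peaks as stand-ins for real spectra and uses a plain NumPy SVD in place of Larch's `pca_train`, which is not called here. Phases A and B are mixed into the dataset and phase C is not, so A and B should reconstruct with small residuals and C should not.

```python
import numpy as np

# Synthetic stand-ins for reference spectra (Gaussian peaks, illustrative only)
energy = np.linspace(7100, 7180, 401)

def peak(center, width):
    return np.exp(-0.5 * ((energy - center) / width) ** 2)

phase_a, phase_b, phase_c = peak(7130, 4.0), peak(7145, 6.0), peak(7165, 3.0)

# Dataset: 50 random A/B mixtures plus noise; phase C is never mixed in
rng = np.random.default_rng(0)
frac = rng.uniform(0, 1, size=50)
data = np.outer(frac, phase_a) + np.outer(1 - frac, phase_b)
data += rng.normal(scale=1e-3, size=data.shape)

# Mean-centred PCA via SVD; rows of vt are orthonormal components.
# Two end-members mixed linearly span a single direction about the mean.
mean = data.mean(axis=0)
_, _, vt = np.linalg.svd(data - mean, full_matrices=False)
components = vt[:1]

def target_transform(ref):
    """Project ref onto the components, rebuild it, return (chi2, R)."""
    weights = (ref - mean) @ components.T
    recon = weights @ components + mean
    resid = ref - recon
    return np.mean(resid ** 2), np.sum(np.abs(resid)) / np.sum(np.abs(ref))

scores = {name: target_transform(ref)
          for name, ref in [("A", phase_a), ("B", phase_b), ("C", phase_c)]}
for name, (chi2, r) in scores.items():
    print(f"{name}: chi2={chi2:.2e}  R={r:.3f}")
```

With mean-centred PCA, a series made from k end-members needs only k − 1 components plus the mean. That is why one component suffices in this toy case.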

**Requires:** Run `01_pca_clustering.ipynb` first to generate PCA results.

**Inputs:**
- `flattened-spectra/*.csv` — sample spectra
- `FeK-standards/fluorescence/flattened/*.csv` — reference mineral spectra

**Outputs:** Ranked table of reference fit quality (χ², R-factor)

## Imports and configuration

In [1]:
"""
XAS Target Transformation against a PCA Model
=============================================
Reads pre-normalized XANES spectra, rebuilds the PCA model, and tests
reference mineral spectra against it by projecting each reference onto
the leading principal components and measuring the reconstruction residual.

Assumes spectra are already normalized/flattened (e.g., via Athena or Larch).
Uses the flattened (post-edge-corrected) XANES by default to avoid
post-edge slope artifacts in PCA.

Requirements:
    pip install xraylarch scikit-learn scipy matplotlib numpy pandas

Usage:
    1. Update the CONFIGURATION section below with your paths and parameters.
    2. Run all cells in order.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from pathlib import Path
import pandas as pd

# Larch imports
from larch import Group
from larch.math import pca_train

In [2]:
# ============================================================
# CONFIGURATION
# ============================================================

# Directory containing your normalized spectra files
SPECTRA_DIR = Path("./flattened-spectra")

# Energy range to use for PCA (eV, relative to E0 or absolute)
# Set to None to use the full overlapping range
E_MIN = 7100  # e.g., -20 (relative to E0) or 7100 (absolute)
E_MAX = 7180  # e.g., 80 (relative to E0) or 7200 (absolute)
ENERGY_IS_RELATIVE = False  # True if E_MIN/E_MAX are relative to E0

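# PCA region: 'xanes' uses normalized mu(E), 'exafs' uses chi(k).
# The loading functions below only read XANES (energy, flat) files,
# so keep this as "xanes" in this notebook.
PCA_REGION = "xanes"
K_WEIGHT = 2  # assumed k-weight, only used by run_pca() if PCA_REGION == "exafs"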

# Common energy/k grid spacing for interpolation
E_STEP = 0.2  # eV step for XANES

# Clustering settings (not used by the target transformation below)
# Max number of clusters to evaluate
MAX_CLUSTERS = 15

# Set to None to use silhouette-optimal k, or an integer to override
FORCE_K = 5

# Number of PCA components to use for clustering and target transformation.
# Set to None to use the IND minimum (automatic), or override with an integer
# if IND gives an unreasonable result (common with noisy microprobe data).
N_COMPONENTS = 5

# Reference spectra for target transformation (optional)
# List of file paths to reference spectra
REFERENCE_DIR = Path("./FeK-standards/fluorescence/flattened")
REFERENCE_FILES = ["2L-Fhy on sand.csv",
                    "2L-Fhy.csv",
                    "6L-Fhy.csv",
                    "Augite.csv",
                    "Biotite.csv",
                    "FeS.csv",
                    "Ferrosmectite.csv",
                    "Goethite on sand.csv",
                    "Goethite.csv",
                    "Green Rust - Carbonate.csv",
                    "Green Rust - Chloride.csv",
                    "Green Rust - Sulfate.csv",
                    "Hematite on sand.csv",
                    "Hematite.csv",
                    "Hornblende.csv",
                    "Ilmenite.csv",
                    "Jarosite.csv",
                    "Lepidocrocite.csv",
                    "Mackinawite (aged).csv",
                    "Mackinawite.csv",
                    "Maghemite.csv",
                    "Nontronite.csv",
                    "Pyrite.csv",
                    "Pyrrhotite.csv",
                    "Schwertmannite.csv",
                    "Siderite-n.csv",
                    "Siderite-s.csv",
                    "Vivianite.csv"]

REFERENCE_PATHS = [REFERENCE_DIR / f for f in REFERENCE_FILES]

# Bulk spectra directory and file pattern
BULK_DIR = Path("./bulk")
BULK_PATTERN = '*.csv'

# Output directory
OUTPUT_DIR = Path("./pca_results")
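`run_target_transform` catches load errors and prints a `FAILED` line for each reference it cannot read. A misspelled filename in `REFERENCE_FILES` therefore only shows up late in the run. One option is to check the paths up front. The helper `missing_references` below is hypothetical and is not part of the original pipeline.

```python
import tempfile
from pathlib import Path

def missing_references(paths):
    """Hypothetical helper: return the reference paths that are not files on disk."""
    return [p for p in map(Path, paths) if not p.is_file()]

# Demo in a throwaway directory: one real file, one name with a stray space
with tempfile.TemporaryDirectory() as tmp:
    present = Path(tmp) / "Goethite.csv"
    present.write_text("# energy,flat\n7100.0,0.01\n")
    missing = missing_references([present, Path(tmp) / "Goethite .csv"])
print([p.name for p in missing])
```

In the notebook this would be called as `missing_references(REFERENCE_PATHS)` right after the configuration cell.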

## Load and prepare data

Re-run the data loading and PCA steps to get the PCA model in memory.

In [3]:
def load_ascii_spectra(spectra_dir, pattern="*.csv"):
    """
    Load pre-normalized/flattened spectra from individual CSV files.
    Expected format: comment lines starting with #, then two columns
    (energy, flat) comma-separated.
    """
    groups = []
    files = sorted(Path(spectra_dir).glob(pattern))
    for f in files:
        try:
            data = np.loadtxt(str(f), delimiter=",", comments="#")
            g = Group(
                energy=data[:, 0],
                flat=data[:, 1],
                filename=f.stem,
                _name=f.stem,
            )
            groups.append(g)
        except Exception as e:
            print(f"  Skipping {f.name}: {e}")
    print(f"Loaded {len(groups)} spectra from {spectra_dir}")
    return groups
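The loader expects each file to start with `#` comment lines followed by bare `energy,flat` rows. A header row without `#` (for example `energy,flat`) makes `np.loadtxt` raise, and the `except` branch then skips the file. Here is a small round-trip with a hypothetical file, using the same `np.loadtxt` call:

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical file contents in the format load_ascii_spectra() expects
example = (
    "# sample: grain_001 (illustrative)\n"
    "# energy,flat\n"
    "7100.0,0.012\n"
    "7100.2,0.015\n"
    "7100.4,0.019\n"
)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "grain_001.csv"
    path.write_text(example)
    # Same call as in load_ascii_spectra(): '#' lines are ignored
    data = np.loadtxt(str(path), delimiter=",", comments="#")

energy, flat = data[:, 0], data[:, 1]
print(path.stem, energy, flat)
```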

In [4]:
def build_xanes_matrix(groups, e_min=None, e_max=None):
    """
    Interpolate normalized XANES spectra onto a common energy grid.
    Returns: energy_grid (1D), matrix (n_spectra x n_energy), names list

    If ENERGY_IS_RELATIVE is True, every group needs an `e0` attribute;
    load_ascii_spectra() does not set one.
    """
    # Find the common energy range
    all_emin = max(g.energy.min() for g in groups)
    all_emax = min(g.energy.max() for g in groups)

    if ENERGY_IS_RELATIVE and (e_min is not None or e_max is not None):
        # Use median E0 as the reference for relative limits
        e0_median = np.median([g.e0 for g in groups])
        if e_min is not None:
            e_min = e0_median + e_min
        if e_max is not None:
            e_max = e0_median + e_max

    if e_min is not None:
        all_emin = max(all_emin, e_min)
    if e_max is not None:
        all_emax = min(all_emax, e_max)

    energy_grid = np.arange(all_emin, all_emax, E_STEP)
    matrix = np.zeros((len(groups), len(energy_grid)))
    names = []

    for i, g in enumerate(groups):
        matrix[i, :] = np.interp(energy_grid, g.energy, g.flat)
        names.append(g._name)

    print(f"Spectral matrix: {matrix.shape[0]} spectra × {matrix.shape[1]} energy points")
    print(f"Energy range: {energy_grid[0]:.1f} – {energy_grid[-1]:.1f} eV")
    return energy_grid, matrix, names
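`build_xanes_matrix` reads `g.e0` when `ENERGY_IS_RELATIVE = True`, but `load_ascii_spectra` never sets it. The sketch below adopts one simple convention as an assumption: E0 is taken at the maximum of the first derivative. `estimate_e0` is a hypothetical helper, demonstrated on a synthetic arctan edge.

```python
import numpy as np

def estimate_e0(energy, mu):
    """Hypothetical helper: E0 as the energy of the maximum of d(mu)/dE."""
    dmu = np.gradient(mu, energy)
    return float(energy[np.argmax(dmu)])

# Synthetic arctan edge centred at 7120 eV (illustrative only)
energy = np.linspace(7100, 7150, 501)
mu = 0.5 + np.arctan((energy - 7120) / 2.0) / np.pi
e0 = estimate_e0(energy, mu)
print(e0)
```

In the notebook this could be applied as `for g in groups: g.e0 = estimate_e0(g.energy, g.flat)` before calling `build_xanes_matrix`. On noisy microprobe spectra the raw derivative maximum can be unstable, so Larch's own `find_e0` (in `larch.xafs`) is the more careful choice.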

In [5]:
# ============================================================
# QUALITY SCREENING
# ============================================================

def screen_spectra(matrix, names, sigma_threshold=3.0):
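    """
    Remove outlier spectra based on their Euclidean distance from the mean
    spectrum: spectra farther than mean + sigma_threshold * std of all
    distances are dropped.
    Returns the filtered matrix, the filtered names, and the boolean keep-mask.
    """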
    mean_spec = matrix.mean(axis=0)
    distances = np.sqrt(np.sum((matrix - mean_spec) ** 2, axis=1))
    threshold = distances.mean() + sigma_threshold * distances.std()

    mask = distances < threshold
    n_removed = (~mask).sum()
    if n_removed > 0:
        print(f"Quality screen: removed {n_removed} spectra beyond {sigma_threshold}σ")
        removed_names = [names[i] for i in range(len(names)) if not mask[i]]
        for rn in removed_names:
            print(f"    Removed: {rn}")
    else:
        print("Quality screen: all spectra passed")

    filtered_names = [names[i] for i in range(len(names)) if mask[i]]
    return matrix[mask], filtered_names, mask
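The screening rule can be checked on toy data. This is a sketch that re-implements the same distance criterion in a hypothetical `outlier_mask` function. One property is worth knowing. With population standard deviation, any single value's z-score is at most √(N − 1), so one outlier can only exceed 3σ when there are more than 10 spectra.

```python
import numpy as np

def outlier_mask(matrix, sigma_threshold=3.0):
    """Same criterion as screen_spectra(): distance to the mean spectrum."""
    mean_spec = matrix.mean(axis=0)
    distances = np.linalg.norm(matrix - mean_spec, axis=1)
    return distances < distances.mean() + sigma_threshold * distances.std()

rng = np.random.default_rng(1)
matrix = 1.0 + rng.normal(scale=0.01, size=(40, 200))  # 40 similar toy spectra
matrix[7] += 0.5  # one hypothetical bad spectrum (offset baseline)
mask = outlier_mask(matrix)

# Same outlier among only 5 spectra: its z-score is at most sqrt(4) = 2
small = np.vstack([matrix[:4], matrix[7:8]])
small_mask = outlier_mask(small)

print("flagged:", np.flatnonzero(~mask))
print("5-spectrum subset, all kept:", small_mask.all())
```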

In [6]:
# ============================================================
# PCA
# ============================================================

def run_pca(x_grid, matrix, names):
    """
    Run PCA using Larch's pca_train.
    Returns the PCA result group.
    """
    # Build Larch groups for pca_train
    groups_for_pca = []
    for i in range(matrix.shape[0]):
        g = Group()
        if PCA_REGION == "xanes":
            g.energy = x_grid
            g.flat = matrix[i, :]
        else:
            g.k = x_grid
            g.chi = matrix[i, :] / (x_grid ** K_WEIGHT)  # undo k-weight
        g._name = names[i]
        groups_for_pca.append(g)

    if PCA_REGION == "xanes":
        pca_result = pca_train(groups_for_pca, arrayname="flat")
    else:
        pca_result = pca_train(groups_for_pca, arrayname="chi")

    return pca_result


def plot_pca_diagnostics(pca_result, output_dir):
    """Plot scree plot, IND, and component spectra."""
    fig = plt.figure(figsize=(16, 12))
    gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)

    n_show = min(20, len(pca_result.variances))

    # --- Scree plot ---
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.semilogy(range(1, n_show + 1), pca_result.variances[:n_show], "ko-")
    ax1.set_xlabel("Component number")
    ax1.set_ylabel("Eigenvalue (variance)")
    ax1.set_title("Scree Plot")
    ax1.set_xticks(range(1, n_show + 1))
    ax1.grid(True, alpha=0.3)

    # --- IND function ---
    ax2 = fig.add_subplot(gs[0, 1])
    ind = pca_result.ind
    ax2.semilogy(range(1, len(ind) + 1), ind, "rs-")
    ind_min = np.argmin(ind) + 1
    ax2.axvline(ind_min, color="blue", linestyle="--", label=f"IND min = {ind_min}")
    ax2.set_xlabel("Component number")
    ax2.set_ylabel("IND")
    ax2.set_title(f"Indicator Function (minimum at {ind_min} components)")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # --- First few component spectra ---
    ax3 = fig.add_subplot(gs[1, 0])
    n_comp_show = min(5, ind_min + 2)
    for i in range(n_comp_show):
        offset = i * 0.5
        ax3.plot(pca_result.components[i] + offset, label=f"PC{i+1}")
    ax3.set_xlabel("Point index")
    ax3.set_ylabel("Component loading (offset)")
    ax3.set_title("Principal Component Spectra")
    ax3.legend(fontsize=8)

    # --- Cumulative variance ---
    ax4 = fig.add_subplot(gs[1, 1])
    cumvar = np.cumsum(pca_result.variances) / np.sum(pca_result.variances) * 100
    ax4.plot(range(1, n_show + 1), cumvar[:n_show], "go-")
    ax4.axhline(95, color="red", linestyle="--", alpha=0.5, label="95%")
    ax4.axhline(99, color="red", linestyle=":", alpha=0.5, label="99%")
    ax4.set_xlabel("Number of components")
    ax4.set_ylabel("Cumulative variance (%)")
    ax4.set_title("Cumulative Variance Explained")
    ax4.legend()
    ax4.set_xticks(range(1, n_show + 1))
    ax4.grid(True, alpha=0.3)

    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / "pca_diagnostics.png", dpi=150, bbox_inches="tight")
    plt.close()
    print("Saved: pca_diagnostics.png")
    return ind_min
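The IND panel above plots `pca_result.ind`. The sketch below implements the indicator function as it is usually stated after Malinowski. For an r × c data matrix (c ≤ r) with eigenvalues λ_j of XᵀX, it defines RE(n) = √(Σ_{j>n} λ_j / (r(c − n))) and IND(n) = RE(n)/(c − n)². Larch's `pca_train` may apply different centring or scaling, so treat this as a cross-check of the idea, not a copy of Larch's code. `malinowski_ind` is a hypothetical name.

```python
import numpy as np

def malinowski_ind(data):
    """Textbook Malinowski IND for n = 1 .. c-1 factors (sketch)."""
    r, c = data.shape
    if c > r:  # the formula assumes c <= r; transpose if needed
        data = data.T
        r, c = c, r
    # Squared singular values = eigenvalues of data.T @ data, descending
    eig = np.linalg.svd(data, compute_uv=False) ** 2
    ind = []
    for n in range(1, c):
        re = np.sqrt(eig[n:].sum() / (r * (c - n)))  # real error with n factors
        ind.append(re / (c - n) ** 2)
    return np.array(ind)

# Toy data: 20 spectra built from 3 hypothetical basis shapes plus noise
rng = np.random.default_rng(2)
x = np.linspace(0, 1, 200)
basis = np.stack([np.exp(-((x - m) / 0.08) ** 2) for m in (0.3, 0.5, 0.7)])
data = rng.uniform(0, 1, size=(20, 3)) @ basis
data += rng.normal(scale=1e-3, size=data.shape)

ind = malinowski_ind(data)
n_best = int(np.argmin(ind)) + 1  # same "+1" convention as plot_pca_diagnostics
print(n_best)
```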

## Target transformation function

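Project each reference spectrum onto the first `n_components` PCA components,
reconstruct it, and measure the residual. A lower χ² or R-factor means the
reference is reconstructed more closely, i.e. it is more consistent with the
spectral variation in the dataset. Note that χ² here is the mean squared
residual per energy point; no measurement uncertainties enter it.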
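Written out, the calculation in `run_target_transform` uses the following symbols. **t** is the reference interpolated onto the PCA energy grid (a row vector of N points), **x̄** is the mean spectrum (`pca_result.mean`), and **V**ₙ holds the first n rows of `pca_result.components`. The calculation is:

```latex
\mathbf{w} = (\mathbf{t} - \bar{\mathbf{x}})\,\mathbf{V}_n^{\mathsf{T}}, \qquad
\hat{\mathbf{t}} = \mathbf{w}\,\mathbf{V}_n + \bar{\mathbf{x}}, \qquad
\mathbf{r} = \mathbf{t} - \hat{\mathbf{t}}

\chi^2 = \frac{1}{N}\sum_{i=1}^{N} r_i^2, \qquad
R = \frac{\sum_{i=1}^{N} \lvert r_i \rvert}{\sum_{i=1}^{N} \lvert t_i \rvert}
```

Assume, as with scikit-learn's PCA, that the component rows are orthonormal. Then **w** is the least-squares fit of **t** − **x̄** within their span, and **r** is the part of the reference that the model cannot represent.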

In [7]:
# ============================================================
# TARGET TRANSFORMATION (optional)
# ============================================================

def run_target_transform(pca_result, reference_files, n_components):
    """
    Test reference spectra against PCA model by manual projection.
    Interpolates each reference onto the PCA energy grid, projects onto
    the first n components, reconstructs, and computes residual.
    """
    if not reference_files:
        print("\nNo reference files provided — skipping target transformation.")
        return None

    energy_grid = pca_result.x
    components = pca_result.components[:n_components]
    mean_spec = pca_result.mean

    print(f"\nTarget transformation results (using {n_components} components):")
    print(f"{'Reference':<30s} {'Chi-square':>12s} {'R-factor':>10s}")
    print("-" * 54)

    results = []
    for ref_path in reference_files:
        try:
            ref_dat = np.loadtxt(str(ref_path), delimiter=",", comments="#")
            ref_energy = ref_dat[:, 0]
            ref_flat = ref_dat[:, 1]

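            # Interpolate reference onto PCA energy grid. np.interp holds the
            # end values constant outside the reference's own range, so warn.
            if ref_energy.min() > energy_grid[0] or ref_energy.max() < energy_grid[-1]:
                print(f"  Warning: {Path(ref_path).stem} does not span the full PCA energy range")
            ref_interp = np.interp(energy_grid, ref_energy, ref_flat)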

            # Mean-center and project onto components
            centered = ref_interp - mean_spec
            weights = centered @ components.T
            reconstructed = weights @ components + mean_spec

            # Compute fit quality
            residual = ref_interp - reconstructed
            chi_sq = np.sum(residual ** 2) / len(residual)
            r_factor = np.sum(np.abs(residual)) / np.sum(np.abs(ref_interp))

            print(f"{Path(ref_path).stem:<30s} {chi_sq:>12.6f} {r_factor:>10.4f}")
            results.append({
                "reference": Path(ref_path).stem,
                "chi_square": chi_sq,
                "r_factor": r_factor,
            })
        except Exception as e:
            print(f"{Path(ref_path).stem:<30s} FAILED: {e}")

    # Sort by chi-square
    if results:
        results.sort(key=lambda x: x["chi_square"])
        print(f"\nRanked by fit quality (best first):")
        for i, r in enumerate(results, 1):
            print(f"  {i:2d}. {r['reference']:<30s} χ²={r['chi_square']:.6f}  R={r['r_factor']:.4f}")

    print("\nInterpretation:")
    print("  Lower chi-square / R-factor = better reconstruction from PCA components.")
    print("  Well-reconstructed references are consistent with species in the dataset.")
    print("  Poor reconstruction suggests that species is not represented.")

    return results
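The list of dicts returned by `run_target_transform` converts directly into a ranked table that can be saved alongside the other outputs. `results_table` is a hypothetical helper, and the demo numbers are made up, not results from this dataset.

```python
import tempfile
from pathlib import Path

import pandas as pd

def results_table(results, output_dir=None):
    """Hypothetical helper: ranked DataFrame from run_target_transform() output."""
    if not results:  # None (no references) or an empty list
        return pd.DataFrame(columns=["reference", "chi_square", "r_factor"])
    df = pd.DataFrame(results).sort_values("chi_square").reset_index(drop=True)
    df.index = df.index + 1  # rank starting at 1
    df.index.name = "rank"
    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        df.to_csv(output_dir / "target_transform.csv")
    return df

# Made-up demo values, only to show the ordering
demo = [
    {"reference": "Ref A", "chi_square": 0.004, "r_factor": 0.08},
    {"reference": "Ref B", "chi_square": 0.0005, "r_factor": 0.02},
]
with tempfile.TemporaryDirectory() as tmp:
    table = results_table(demo, tmp)
    saved = (Path(tmp) / "target_transform.csv").exists()
print(table)
```

After the run below, `results_table(tt_results, OUTPUT_DIR)` would write the ranked table into the configured output directory.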

---
## Run

In [8]:
# Load and prepare data
groups = load_ascii_spectra(SPECTRA_DIR)
x_grid, matrix, names = build_xanes_matrix(groups, E_MIN, E_MAX)
matrix, names, quality_mask = screen_spectra(matrix, names)
pca_result = run_pca(x_grid, matrix, names)

# Determine number of components
if N_COMPONENTS is not None:
    n_components = N_COMPONENTS
else:
    ind_min = np.argmin(pca_result.ind) + 1
    cumvar = np.cumsum(pca_result.variances) / np.sum(pca_result.variances)
    n_components = ind_min if ind_min <= len(names) // 2 else int(np.argmax(cumvar >= 0.95)) + 1
print(f'Using {n_components} components')

Loaded 172 spectra from flattened-spectra
Spectral matrix: 172 spectra × 400 energy points
Energy range: 7100.0 – 7179.8 eV
Quality screen: all spectra passed
Using 5 components
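The component-selection rule in the cell above can be isolated and checked on toy inputs. The rule is: use `N_COMPONENTS` if set; otherwise use the IND minimum, unless it exceeds half the number of spectra, in which case take the smallest count reaching 95 % cumulative variance. `choose_n_components` is a hypothetical helper mirroring that logic, and the arrays are made-up values, not this dataset's IND or variances.

```python
import numpy as np

def choose_n_components(ind, variances, n_spectra, forced=None, var_target=0.95):
    """Hypothetical helper mirroring the selection logic in the Run cell."""
    if forced is not None:
        return forced
    ind_min = int(np.argmin(ind)) + 1
    if ind_min <= n_spectra // 2:
        return ind_min  # accept the IND minimum
    # IND minimum looks unreasonable: fall back to cumulative variance
    cumvar = np.cumsum(variances) / np.sum(variances)
    return int(np.argmax(cumvar >= var_target)) + 1

variances = np.array([0.60, 0.30, 0.06, 0.02, 0.01, 0.005, 0.003, 0.001, 0.001])
ind_ok = [5, 3, 1, 2, 4, 6, 7, 8, 9]   # minimum at 3 components
ind_bad = [9, 8, 7, 6, 5, 4, 3, 2, 1]  # minimum at 9 components
print(choose_n_components(ind_ok, variances, n_spectra=10))
print(choose_n_components(ind_bad, variances, n_spectra=10))
print(choose_n_components(ind_bad, variances, n_spectra=10, forced=5))
```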


## Run target transformation

In [9]:
# Target transformation
tt_results = run_target_transform(pca_result, REFERENCE_PATHS, n_components)



Target transformation results (using 5 components):
Reference                        Chi-square   R-factor
------------------------------------------------------
2L-Fhy on sand                     0.003484     0.0715
2L-Fhy                             0.002906     0.0662
6L-Fhy                             0.001864     0.0523
Augite                             0.000587     0.0245
Biotite                            0.001663     0.0358
FeS                                0.004224     0.0594
Ferrosmectite                      0.004359     0.0848
Goethite on sand                   0.006053     0.0964
Goethite                           0.002336     0.0553
Green Rust - Carbonate             0.001086     0.0310
Green Rust - Chloride              0.000571     0.0259
Green Rust - Sulfate               0.000521     0.0236
Hematite on sand                   0.004819     0.0818
Hematite                           0.001716     0.0465
Hornblende                         0.000718     0.0251
Ilmenite    