# Publication Figures: XFM Maps & Representative XANES

This notebook generates the two main manuscript figures:

- **Figure 1** — 2×2 panel µ-XRF tricolor maps (Fe=red, Ca=green, K=blue) with
  cluster-assigned point markers and cross-references to Figure 2 spectra.
- **Figure 2** — 6-panel representative Fe K-edge µ-XANES spectra with LCF fits,
  individual reference components, and residuals.

**Requires:** Run `01_pca_clustering.ipynb` and `03_lcf_microprobe.ipynb` first.

**Inputs:**
- `maps/*.h5` — HDF5 µ-XRF map files
- `flattened-spectra/*.csv` — normalized sample spectra
- `FeK-standards/fluorescence/flattened/*.csv` — reference mineral spectra
- `pca_results/cluster_assignments.csv` — cluster assignments
- `pca_results/lcf_individual.csv` — per-spectrum LCF results

**Outputs:**
- `figure_xfm_maps.png` / `.pdf` — Figure 1
- `figure_xanes_lcf.png` / `.pdf` — Figure 2

## Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from scipy.optimize import nnls
from sklearn.cluster import KMeans
from pathlib import Path
import h5py, os, glob

## Configuration

In [None]:
# ---------- Shared config ----------
# Resolution used when creating and saving both figures.
DPI = 300
# Target figure width; matplotlib sizes are in inches, hence the mm -> in
# conversion (25.4 mm per inch).
TARGET_WIDTH_MM = 180
TARGET_WIDTH_IN = TARGET_WIDTH_MM / 25.4

# Input/output locations, relative to the current working directory.
MAP_DIR = Path("maps")
PCA_DIR = Path("pca_results")
SPEC_DIR = Path("flattened-spectra")
REF_DIR = Path("FeK-standards/fluorescence/flattened")
OUT_DIR = Path(".")

# Energy grid for XANES
# The half-step added to the stop value makes np.arange include E_MAX itself.
E_MIN, E_MAX, E_STEP = 7100, 7180, 0.2
E_GRID = np.arange(E_MIN, E_MAX + E_STEP / 2, E_STEP)
# Sub-range of E_GRID that is actually drawn in Figure 2.
PLOT_E_MIN, PLOT_E_MAX = 7105, 7170

# RGB channel ROI names (Figure 1)
R_ROI, G_ROI, B_ROI = "Fe Ka", "Ca Ka", "K Ka"

# Figure 1 panel definitions: key -> (h5 filename, scale bar mm)
PANEL_FILES = {
    "a": ("2x2_10um_striated_gt15_2_001.h5",   0.5),
    "b": ("2x2_10um_rectangles_gt15_1_001.h5",  0.5),
    "c": ("2x2_10um_flaky_nodule_001.h5",       0.5),
    "d": ("1x1_10um_flaky_dark_gt15_001.h5",    0.2),
}

# Cluster marker styles
# Keys mix ints and strings: cluster 3 is split into "3a"/"3b", the other
# clusters keep their integer id.
CLUSTER_STYLE = {
    1:    {"marker": "o", "label": "Group 1"},
    2:    {"marker": "s", "label": "Group 2"},
    "3a": {"marker": "^", "label": "Group 3a"},
    "3b": {"marker": "v", "label": "Group 3b"},
    4:    {"marker": "D", "label": "Group 4"},
    5:    {"marker": "p", "label": "Group 5"},
}

# Map panel spectrum prefixes (spectra on the 4 Figure 1 maps)
# The same prefixes are repeated in SPEC_TO_PANEL (Figure 2 helpers); keep the
# two lists in sync.
# NOTE(review): the flaky_nodule entry uses "GT5" (not "GT15") and has no
# trailing underscore, unlike the others - confirm it matches the file names.
MAP_PREFIXES = [
    "FeXANES_GT15_FeTiXRD_striated2_",
    "FeXANES_GT15_Fe_striated2_",
    "FeXANES_GT15_rectangles_Fe_",
    "FeXANES_GT5_flaky_nodule_Fe",
    "FeXANES_GT15_flakydark_Fe_",
]

# Reference standard names
# Each name is looked up as REF_DIR / "<name>.csv" when the references are loaded.
REF_NAMES = [
    "2L-Fhy on sand", "2L-Fhy", "6L-Fhy", "Augite", "Biotite",
    "FeS", "Ferrosmectite", "Goethite on sand", "Goethite",
    "Green Rust - Carbonate", "Green Rust - Chloride",
    "Green Rust - Sulfate", "Hematite on sand", "Hematite",
    "Hornblende", "Ilmenite", "Jarosite", "Lepidocrocite",
    "Mackinawite (aged)", "Mackinawite", "Maghemite", "Nontronite",
    "Pyrite", "Pyrrhotite", "Schwertmannite", "Siderite-n",
    "Siderite-s", "Vivianite",
]

## Phase grouping and colors

Map mineral reference names to geochemical phase groups. Used by both figures:
Figure 1 legend labels and Figure 2 component colors.

In [None]:
# Reference mineral name -> phase group label. Several minerals share one
# group so that their LCF weights can be summed per group; names missing from
# this table are passed through unchanged by phase_label().
PHASE_GROUPS = {
    "6L-Fhy": "Fe(III) oxyhydroxide", "2L-Fhy": "Fe(III) oxyhydroxide",
    "2L-Fhy on sand": "Fe(III) oxyhydroxide", "Goethite": "Fe(III) oxyhydroxide",
    "Goethite on sand": "Fe(III) oxyhydroxide", "Lepidocrocite": "Fe(III) oxyhydroxide",
    "Schwertmannite": "Fe(III) oxyhydroxide",
    "Hematite": "Fe(III) oxide", "Hematite on sand": "Fe(III) oxide",
    "Maghemite": "Fe(III) oxide",
    "Ferrosmectite": "Fe(III) phyllosilicate", "Nontronite": "Fe(III) phyllosilicate",
    "Biotite": "Fe(II) phyllosilicate",
    "Hornblende": "Fe(II) silicate", "Augite": "Fe(II) silicate",
    "Mackinawite (aged)": "Fe sulfide", "Mackinawite": "Fe sulfide",
    "Pyrrhotite": "Fe sulfide", "Pyrite": "Fe sulfide", "FeS": "Fe sulfide",
    "Siderite-s": "Fe(II) carbonate", "Siderite-n": "Fe(II) carbonate",
    "Ilmenite": "Fe-Ti oxide",
    "Vivianite": "Fe(II) phosphate",
    "Jarosite": "Fe(III) sulfate",
    "Green Rust - Carbonate": "Green rust", "Green Rust - Chloride": "Green rust",
    "Green Rust - Sulfate": "Green rust",
}

def phase_label(ref_name):
    """Return the phase group of a reference mineral.

    Names that have no entry in PHASE_GROUPS are returned unchanged, so an
    unmapped mineral is labelled with its own name.
    """
    try:
        return PHASE_GROUPS[ref_name]
    except KeyError:
        return ref_name

# Colors for phase groups (Figure 2 component curves)
PHASE_COLORS = {
    "Fe(III) oxyhydroxide":   '#ff7f0e',
    "Fe(III) oxide":          '#d62728',
    "Fe(III) phyllosilicate": '#2ca02c',
    "Fe(II) phyllosilicate":  '#1f77b4',
    "Fe(II) silicate":        '#17becf',
    "Fe sulfide":             '#9467bd',
    "Fe(II) carbonate":       '#8c564b',
    "Fe-Ti oxide":            '#e377c2',
    "Fe(II) phosphate":       '#bcbd22',
    "Fe(III) sulfate":        '#7f7f7f',
    "Green rust":             '#aec7e8',
}

COMP_COLORS = [
    '#1f77b4', '#2ca02c', '#9467bd', '#8c564b', '#e377c2',
    '#7f7f7f', '#bcbd22', '#17becf', '#ff7f0e', '#aec7e8',
]

## Load cluster assignments, LCF results, and sub-cluster labels

This shared data loading step provides the cluster and LCF information
needed by both figures.

In [None]:
print("Loading cluster assignments and LCF results...")
cluster_df = pd.read_csv(PCA_DIR / "cluster_assignments.csv")
# spectrum name -> cluster id; used later to look up the cluster of each map area
cluster_lookup = dict(zip(cluster_df["spectrum"], cluster_df["cluster"]))

lcf_df = pd.read_csv(PCA_DIR / "lcf_individual.csv")
# Inner join on (spectrum, cluster): rows present in only one of the two
# files are dropped from df.
df = cluster_df.merge(lcf_df, on=["spectrum", "cluster"], how="inner")

# Split cluster 3 into "3a" / "3b" with a two-cluster k-means fit on the
# columns whose names start with "PC"
c3_mask = df["cluster"] == 3
c3_df = df[c3_mask].copy()
pc_cols = [col for col in df.columns if col.startswith("PC")]
km = KMeans(n_clusters=2, random_state=42, n_init=10)
c3_labels = km.fit_predict(c3_df[pc_cols].values)

# "3a" is the sub-cluster with the higher mean Pyrrhotite weight (a tie goes
# to k-means label 0); an empty sub-cluster counts as a mean of 0
in_sub0 = c3_labels == 0
in_sub1 = c3_labels == 1
sub0_pyrr = c3_df.iloc[in_sub0]["Pyrrhotite"].mean() if in_sub0.any() else 0
sub1_pyrr = c3_df.iloc[in_sub1]["Pyrrhotite"].mean() if in_sub1.any() else 0
if sub0_pyrr >= sub1_pyrr:
    label_map = {0: "3a", 1: "3b"}
else:
    label_map = {1: "3a", 0: "3b"}

# spectrum name -> "3a" / "3b" for every cluster-3 row of the merged frame
sub3_lookup = dict(zip(c3_df["spectrum"], (label_map[lbl] for lbl in c3_labels)))

# Group column of the merged frame: the cluster id as a string, with the
# cluster-3 rows overwritten by their sub-cluster label
c3_df["sub_cluster"] = c3_labels
c3_df["group"] = c3_df["sub_cluster"].map(label_map)
df["group"] = df["cluster"].astype(str)
df.loc[c3_mask, "group"] = c3_df["group"].values

def get_style_key(spec_name, cluster_id):
    """Return the marker-style key for a spectrum.

    Clusters other than 3 use the cluster id itself. Cluster 3 is resolved to
    "3a" or "3b" through sub3_lookup; a cluster-3 spectrum that is missing
    from that lookup falls back to "3a".
    """
    if cluster_id != 3:
        return cluster_id
    return sub3_lookup.get(spec_name, "3a")

# Reverse lookup: spectrum -> group key (int cluster id, or "3a"/"3b")
spec_to_group = {
    row["spectrum"]: get_style_key(row["spectrum"], row["cluster"])
    for _, row in cluster_df.iterrows()
}

# Reference mineral columns from LCF results: everything that is not one of
# the bookkeeping columns listed here
ref_columns = [
    col for col in lcf_df.columns
    if col not in ["spectrum", "cluster", "r_factor", "chi_sq", "weight_sum", "n_refs"]
]

print(f"Loaded {len(cluster_df)} spectra, {len(lcf_df)} LCF results")
group_counts = df['group'].value_counts().sort_index()
print(f"Group counts:\n{group_counts}")

## Select representative spectra

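For each group, select a representative spectrum that:
1. Comes from one of the 4 Figure 1 maps (preferred)
2. Has a 3+ component LCF fit (preferred)
3. Has good fit quality (R ≤ 0.025) (preferred)
4. Is closest to the median R-factor among the remaining candidates

Criteria 1–3 are applied in order; each one is skipped if no candidate satisfies it.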

The same selections drive both the annotation arrows in Figure 1 and the
spectra plotted in Figure 2.

In [None]:
def is_map_spectrum(name):
    """Return True when *name* starts with one of the MAP_PREFIXES entries."""
    return any(name.startswith(prefix) for prefix in MAP_PREFIXES)

print("Selecting representative spectra...")
# group key (int or "3a"/"3b") -> chosen row of df (cluster + LCF columns)
selected = {}
for grp in [1, 2, "3a", "3b", 4, 5]:
    # df["group"] holds strings, so compare against the string form of the key
    grp_str = str(grp)
    grp_df = df[df["group"] == grp_str].copy()

    # Each filter below narrows the candidates only if at least one row
    # survives it; otherwise the previous candidate set is kept.
    map_mask = grp_df["spectrum"].apply(is_map_spectrum)
    candidates = grp_df[map_mask] if map_mask.any() else grp_df

    multi_comp = candidates[candidates["n_refs"] >= 3]
    if len(multi_comp) >= 1:
        candidates = multi_comp

    good_fits = candidates[candidates["r_factor"] <= 0.025]
    if len(good_fits) >= 1:
        candidates = good_fits

    # Pick the candidate whose R-factor is closest to the candidates' median.
    # NOTE(review): a group with no rows in df leaves `candidates` empty and
    # idxmin() then fails - confirm every group is always populated.
    median_r = candidates["r_factor"].median()
    idx = (candidates["r_factor"] - median_r).abs().idxmin()
    row = candidates.loc[idx]
    selected[grp] = row
    print(f"  Group {grp}: {row['spectrum']} (R={row['r_factor']:.4f}, n_refs={int(row['n_refs'])})")

# Build cross-reference lookups used by both figures
SELECTED_SPECTRA = {grp: row["spectrum"] for grp, row in selected.items()}

# Figure 2 panel labels for Figure 1 annotation arrows
FIG2_PANEL = {1: "2a", 2: "2b", "3a": "2c", "3b": "2d", 4: "2e", 5: "2f"}

# Selected area names for quick lookup (Figure 1)
# Strips the "FeXANES_" / ".001" decoration that area_to_spectrum() adds, so
# the keys can be compared with the area names stored in the HDF5 maps.
selected_areas = {}
for grp, spec_name in SELECTED_SPECTRA.items():
    area_name = spec_name.replace("FeXANES_", "").replace(".001", "")
    selected_areas[area_name] = grp

## Build legend labels from LCF

Aggregate individual mineral weights into phase groups to create compact
composition labels for the Figure 1 legend strip.

In [None]:
# One legend entry per group; a "label" field is filled in by the loop below.
GROUPS = [
    {"key": "1",  "marker": "o"},
    {"key": "2",  "marker": "s"},
    {"key": "3a", "marker": "^"},
    {"key": "3b", "marker": "v"},
    {"key": "4",  "marker": "D"},
    {"key": "5",  "marker": "p"},
]

for grp_info in GROUPS:
    grp_key = grp_info["key"]
    # GROUPS keys are all strings, whereas SELECTED_SPECTRA uses ints for the
    # unsplit clusters and strings for "3a"/"3b": try the string, then the int.
    if grp_key in SELECTED_SPECTRA:
        sk = grp_key
    elif grp_key.isdigit() and int(grp_key) in SELECTED_SPECTRA:
        sk = int(grp_key)
    else:
        grp_info["label"] = ""
        continue

    spec_name = SELECTED_SPECTRA[sk]
    row = lcf_df[lcf_df["spectrum"] == spec_name]
    if row.empty:
        grp_info["label"] = ""
        continue
    # If several LCF rows share the spectrum name, only the first is used.
    row = row.iloc[0]
    weight_sum = row["weight_sum"]

    # phase group -> summed percentage of the fit (weight / weight_sum * 100);
    # references with weight <= 0.005 are ignored.
    phase_pcts = {}
    for ref_name in ref_columns:
        w = row[ref_name]
        if w > 0.005:
            pg = phase_label(ref_name)
            pct = w / weight_sum * 100 if weight_sum > 0 else 0
            phase_pcts[pg] = phase_pcts.get(pg, 0) + pct

    # Largest contribution first; phase groups below 1 % are left out of the label.
    sorted_phases = sorted(phase_pcts.items(), key=lambda x: x[1], reverse=True)
    parts = [f"{name} {pct:.0f}%" for name, pct in sorted_phases if pct >= 1]
    grp_info["label"] = ", ".join(parts)

print("Legend labels:")
for g in GROUPS:
    print(f"  Group {g['key']}: {g['label']}")

## Load reference spectra

Used by Figure 2 to reconstruct LCF fits and plot individual components.

In [None]:
def load_csv_spectrum(filepath):
    """Load a two-column comma-separated spectrum file.

    Lines starting with '#' are treated as comments and skipped.

    Parameters
    ----------
    filepath : str or path-like
        Location of the CSV file.

    Returns
    -------
    tuple of numpy.ndarray
        ``(column 0, column 1)`` of the file, each as a 1-D array.
    """
    # ndmin=2 keeps the result two-dimensional even when the file holds a
    # single data row; without it loadtxt returns a 1-D array and the column
    # slicing below raises IndexError.
    data = np.loadtxt(str(filepath), comments='#', delimiter=',', ndmin=2)
    return data[:, 0], data[:, 1]

def interp_to_grid(energy, mu, grid=E_GRID):
    """Linearly interpolate a spectrum onto *grid*.

    Parameters
    ----------
    energy, mu : array-like
        Sample points and the values measured at those points.
    grid : array-like, optional
        Points at which to evaluate the spectrum. Defaults to E_GRID.

    Returns
    -------
    numpy.ndarray
        Interpolated values, one per grid point. Grid points outside the range
        of *energy* take the first/last value of *mu* (np.interp's default).
    """
    energy = np.asarray(energy)
    mu = np.asarray(mu)
    # np.interp requires increasing sample points but does not check for it;
    # an unsorted scan would silently yield meaningless values. Sort only when
    # needed so already-ordered input is passed through untouched.
    if energy.size > 1 and np.any(np.diff(energy) < 0):
        order = np.argsort(energy, kind="stable")
        energy = energy[order]
        mu = mu[order]
    return np.interp(grid, energy, mu)

print("Loading reference standards...")
# reference name -> spectrum interpolated onto the shared energy grid
ref_spectra = {}
for name in REF_NAMES:
    fp = REF_DIR / (name + ".csv")
    if not fp.exists():
        # A missing standard is reported and skipped, not treated as fatal
        print(f"  WARNING: {fp} not found")
        continue
    ref_spectra[name] = interp_to_grid(*load_csv_spectrum(fp))
print(f"Loaded {len(ref_spectra)} reference spectra")

---
# Figure 1: XFM Tricolor Maps

## HDF5 map helpers

In [None]:
def area_to_spectrum(area_name):
    """Return the spectrum name for a map area: 'FeXANES_<area_name>.001'."""
    return "FeXANES_{}.001".format(area_name)

def get_roi_map(f, roi_name):
    """Return the 2-D map of one ROI from an open map file, or None.

    The "sum_cor" dataset is tried first, then "sum_raw". ROI names are read
    from "xrmmap/roimap/sum_name" (bytes entries are decoded). The first and
    last columns of the map are dropped and the result is cast to float.
    None is returned when neither dataset exists or the ROI name is unknown.
    """
    for dataset_path in ("xrmmap/roimap/sum_cor", "xrmmap/roimap/sum_raw"):
        if dataset_path not in f:
            continue
        roi_names = []
        for entry in f["xrmmap/roimap/sum_name"][:]:
            roi_names.append(entry.decode() if isinstance(entry, bytes) else entry)
        try:
            position = roi_names.index(roi_name)
        except ValueError:
            continue
        return f[dataset_path][:, 1:-1, position].astype(float)
    return None

def make_rgb(f):
    """Build a contrast-stretched RGB image from the three ROI maps.

    The R_ROI, G_ROI and B_ROI maps are read with get_roi_map(). Each channel
    is rescaled so its 1st percentile maps to 0 and its 99.5th percentile to
    1, then clipped to [0, 1]; a channel with no spread becomes all zeros.

    Parameters
    ----------
    f : open map file
        Object accepted by get_roi_map().

    Returns
    -------
    numpy.ndarray
        Array of shape ``(rows, cols, 3)``. If none of the three ROIs is
        found, the result has shape ``(1, 1, 3)`` and is all zeros.
    """
    raw_channels = [get_roi_map(f, name) for name in (R_ROI, G_ROI, B_ROI)]

    # A missing ROI is shown as a black channel. It must have the same shape
    # as the channels that were found, otherwise np.stack raises on the
    # mismatch; (1, 1) is kept only for the case where every ROI is missing.
    found = [ch for ch in raw_channels if ch is not None]
    fill_shape = found[0].shape if found else (1, 1)

    channels = []
    for ch in raw_channels:
        if ch is None:
            channels.append(np.zeros(fill_shape))
            continue
        vmin = np.percentile(ch, 1)
        vmax = np.percentile(ch, 99.5)
        if vmax > vmin:
            ch = np.clip((ch - vmin) / (vmax - vmin), 0, 1)
        else:
            ch = np.zeros_like(ch)
        channels.append(ch)
    return np.stack(channels, axis=-1)

def get_area_centroids(f):
    """Return {area name: (mean row, mean column)} for the masks in "xrmmap/areas".

    Masks with no set element are skipped. An empty dict is returned when the
    file has no "xrmmap/areas" group.
    """
    areas_grp = f.get("xrmmap/areas")
    if areas_grp is None:
        return {}

    centroids = {}
    for area_name in areas_grp:
        mask = areas_grp[area_name][:]
        if not mask.any():
            continue
        row_idx, col_idx = np.nonzero(mask)
        centroids[area_name] = (row_idx.mean(), col_idx.mean())
    return centroids

## RGB triangle helper

In [None]:
def make_rgb_triangle(ax):
    """Draw the Fe (red) / K (blue) / Ca (green) tricolor key on *ax*.

    A 200 x 200 image is filled with the barycentric colour blend of an
    equilateral triangle (Fe at the top-left vertex, K at the top-right, Ca
    at the bottom); pixels outside the triangle stay white. The triangle
    outline and the three element labels are drawn on top and the axes frame
    is hidden.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into; its limits are set to the unit square.
    """
    top_left = np.array([0.2, 0.95])
    top_right = np.array([0.8, 0.95])
    bottom = np.array([0.5, 0.95 - 0.6 * np.sqrt(3) / 2])

    n = 200
    tri_img = np.ones((n, n, 3))

    # The edge vectors and their dot products are the same for every pixel,
    # so they are computed once instead of inside a per-pixel loop.
    v0 = top_right - top_left
    v1 = bottom - top_left
    d00 = np.dot(v0, v0)
    d01 = np.dot(v0, v1)
    d11 = np.dot(v1, v1)
    denom = d00 * d11 - d01 * d01

    # A (near-)zero denominator means a degenerate triangle: keep the image white.
    if abs(denom) >= 1e-10:
        ticks = np.arange(n) / (n - 1)
        px = ticks[np.newaxis, :]          # x grows with the column index
        py = 1.0 - ticks[:, np.newaxis]    # image row 0 is the top (y = 1)
        dx = px - top_left[0]
        dy = py - top_left[1]
        d20 = dx * v0[0] + dy * v0[1]
        d21 = dx * v1[0] + dy * v1[1]

        # Barycentric weights: u -> top-left, v -> top-right, w -> bottom
        v = (d11 * d20 - d01 * d21) / denom
        w = (d00 * d21 - d01 * d20) / denom
        u = 1.0 - v - w

        # The small negative tolerance lets the colour fill reach the outline
        inside = (u >= -0.01) & (v >= -0.01) & (w >= -0.01)

        u = np.maximum(u, 0)
        v = np.maximum(v, 0)
        w = np.maximum(w, 0)
        # u + v + w was 1 before clamping, so the clamped sum is never zero
        total = u + v + w
        # Channel order is (red, green, blue) = (Fe, Ca, K) = (u, w, v)
        blend = np.stack([u, w, v], axis=-1) / total[..., np.newaxis]
        tri_img[inside] = blend[inside]

    ax.imshow(tri_img, extent=[0, 1, 0, 1], aspect='auto')
    tri_coords = [top_left, top_right, bottom, top_left]
    ax.plot([c[0] for c in tri_coords], [c[1] for c in tri_coords], 'k-', lw=1.0)
    ax.text(top_left[0] - 0.03, top_left[1] + 0.03, "Fe", fontsize=8,
            fontweight='bold', color='red', ha='center', va='bottom')
    ax.text(top_right[0] + 0.03, top_right[1] + 0.03, "K", fontsize=8,
            fontweight='bold', color='blue', ha='center', va='bottom')
    ax.text(bottom[0], bottom[1] - 0.05, "Ca", fontsize=8,
            fontweight='bold', color='green', ha='center', va='top')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')

## Render map panels

Extract RGB composites and cluster point locations from each HDF5 file.

In [None]:
print("Rendering maps from HDF5...")

# panel key -> (rgb image, extent, style_points, sel_points, scale bar mm)
panel_data = {}

for panel_key, (h5_name, bar_mm) in PANEL_FILES.items():
    h5_path = MAP_DIR / h5_name
    with h5py.File(h5_path, "r") as f:
        rgb = make_rgb(f)
        centroids = get_area_centroids(f)
        pos = f["xrmmap/positions/pos"]
        ny, nx_full = pos.shape[:2]
        # Drop the first and last column, matching the [:, 1:-1, idx] slice
        # that get_roi_map() applies to the ROI maps.
        x_pos = pos[:, 1:-1, 0][:]
        y_pos = pos[:, 1:-1, 1][:]
        nx = nx_full - 2
        # [x_min, x_max, y_min, y_max] of the trimmed map, in position units
        extent = [float(x_pos.min()), float(x_pos.max()),
                  float(y_pos.min()), float(y_pos.max())]

    # style key -> (list of x, list of y) of the points to mark on this panel
    style_points = {k: ([], []) for k in CLUSTER_STYLE}
    # (x, y, group) of the points that are also plotted in Figure 2
    sel_points = []

    for area_name, (row_c, col_c) in centroids.items():
        spec_name = area_to_spectrum(area_name)
        cluster_id = cluster_lookup.get(spec_name)
        # Areas without a cluster assignment are not drawn
        if cluster_id is None:
            continue
        sk = get_style_key(spec_name, cluster_id)
        if sk not in style_points:
            continue
        # Shift the column index by one to account for the trimmed first
        # column; centroids that fall in a trimmed column are skipped.
        # (presumably the area masks cover the untrimmed map - TODO confirm)
        col_adj = col_c - 1
        if col_adj < 0 or col_adj >= nx:
            continue
        # Linear pixel-index -> position mapping across the extent.
        # NOTE(review): assumes positions increase with the column/row index.
        x_disp = np.interp(col_adj, [0, nx - 1], [extent[0], extent[1]])
        y_disp = np.interp(row_c, [0, ny - 1], [extent[2], extent[3]])
        style_points[sk][0].append(x_disp)
        style_points[sk][1].append(y_disp)

        if area_name in selected_areas:
            sel_points.append((x_disp, y_disp, selected_areas[area_name]))

    panel_data[panel_key] = (rgb, extent, style_points, sel_points, bar_mm)
    print(f"  Panel ({panel_key}): {rgb.shape[1]}x{rgb.shape[0]}, "
          f"{sum(len(v[0]) for v in style_points.values())} points, "
          f"{len(sel_points)} selected")

## Assemble Figure 1

2×2 map panels with a shared legend strip below containing the RGB triangle
and group/phase labels. Selected spectra are annotated with arrows pointing
to their Figure 2 panel labels.

In [None]:
print("Assembling Figure 1...")

# All layout lengths below are in inches and converted to figure fractions by
# to_fig() when the axes are created.
fig_w = TARGET_WIDTH_IN
# Slightly under half the width, so two panels plus the gap fit inside fig_w
panel_w_in = fig_w / 2.06
# Height/width ratio is taken from panel "a" and applied to every panel
aspect = panel_data["a"][0].shape[0] / panel_data["a"][0].shape[1]
panel_h_in = panel_w_in * aspect
gap = 0.04
legend_h_in = 1.8
total_h_in = 2 * panel_h_in + gap + legend_h_in + 0.15

fig1 = plt.figure(figsize=(fig_w, total_h_in), dpi=DPI)

def to_fig(x, y, w, h):
    """Convert a rectangle in inches to [left, bottom, width, height] figure fractions."""
    return [x / fig_w, y / total_h_in, w / fig_w, h / total_h_in]

# Centre the 2x2 grid horizontally; the grid sits above the legend strip
x_start = (fig_w - 2 * panel_w_in - gap) / 2
y_grid_bottom = legend_h_in + 0.15

panel_layout = [["a", "b"], ["c", "d"]]
panel_labels = {"a": "(a)", "b": "(b)", "c": "(c)", "d": "(d)"}

for row_idx, row_keys in enumerate(panel_layout):
    for col_idx, key in enumerate(row_keys):
        x = x_start + col_idx * (panel_w_in + gap)
        # Figure y grows upward, so layout row 0 gets the larger offset (top row)
        y = y_grid_bottom + (1 - row_idx) * (panel_h_in + gap)
        rect = to_fig(x, y, panel_w_in, panel_h_in)
        ax = fig1.add_axes(rect)

        rgb, extent, style_points, sel_points, bar_mm = panel_data[key]

        ax.imshow(rgb, extent=extent, aspect="equal", interpolation="nearest",
                  origin="lower")

        # Hollow white markers, one shape per group
        for sk, style in CLUSTER_STYLE.items():
            xs, ys = style_points[sk]
            if xs:
                ax.scatter(xs, ys, marker=style["marker"], facecolors="none",
                           edgecolors="white", s=80, linewidths=1.0, zorder=5)

        # Arrowed tag naming the Figure 2 panel that shows this point's spectrum
        for x_disp, y_disp, grp in sel_points:
            fig2_label = FIG2_PANEL[grp]
            ax.annotate(
                fig2_label,
                xy=(x_disp, y_disp),
                xytext=(14, 14), textcoords="offset points",
                fontsize=8, fontweight="bold", color="yellow",
                bbox=dict(boxstyle="round,pad=0.2", facecolor="black",
                          alpha=0.8, edgecolor="yellow", lw=0.8),
                arrowprops=dict(arrowstyle="-|>", color="yellow",
                                lw=1.2, mutation_scale=8,
                                shrinkA=0, shrinkB=3),
                zorder=10,
            )

        ax.text(0.04, 0.96, panel_labels[key], transform=ax.transAxes,
                fontsize=11, fontweight='bold', color='white', va='top', ha='left',
                bbox=dict(boxstyle='round,pad=0.15', facecolor='black',
                          alpha=0.5, edgecolor='none'))

        # Scale bar, drawn in axes-fraction coordinates.
        # NOTE(review): bar_frac equals bar_mm only if the extent is in mm and
        # the image spans the full axes width - confirm both.
        x_range = extent[1] - extent[0]
        bar_frac = bar_mm / x_range
        x_sb = 0.03; x_end_sb = x_sb + bar_frac; y_sb = 0.04
        ax.plot([x_sb, x_end_sb], [y_sb, y_sb], color="white", linewidth=3,
                solid_capstyle="butt", zorder=10, clip_on=False, transform=ax.transAxes)
        label = f"{bar_mm:.1f} mm" if bar_mm < 1 else f"{bar_mm:.0f} mm"
        ax.text((x_sb + x_end_sb) / 2, y_sb + 0.03, label, color="white",
                fontsize=7, ha="center", va="bottom", fontweight="bold",
                zorder=10, clip_on=False, transform=ax.transAxes)

        ax.set_xticks([]); ax.set_yticks([])

# Legend strip
# The colour triangle sits at the left, vertically centred in the strip
legend_y = 0.0; legend_x = x_start
tri_w_in = 1.0; tri_h_in = 1.0
tri_y = legend_y + (legend_h_in - tri_h_in) / 2
ax_tri = fig1.add_axes(to_fig(legend_x, tri_y, tri_w_in, tri_h_in))
make_rgb_triangle(ax_tri)

# Group entries fill the rest of the strip, stacked top to bottom. The strip
# height is divided by a hard-coded 6, the number of entries in GROUPS.
legend_start_x = legend_x + tri_w_in + 0.08
entry_w = fig_w - legend_start_x - 0.05
entry_h = legend_h_in / 6

for i, grp in enumerate(GROUPS):
    x = legend_start_x
    y = legend_y + legend_h_in - (i + 1) * entry_h
    rect = to_fig(x, y, entry_w, entry_h)
    ax_entry = fig1.add_axes(rect)
    ax_entry.axis('off')
    ax_entry.scatter([0.02], [0.5], marker=grp["marker"], s=50,
                     facecolors='none', edgecolors='black', linewidths=1.0,
                     transform=ax_entry.transAxes, clip_on=False, zorder=5)
    ax_entry.text(0.05, 0.5, f"Group {grp['key']}:  {grp['label']}",
                  transform=ax_entry.transAxes, fontsize=5.5, va='center', ha='left')

plt.show()

# Write the raster and the vector version with identical save options
for out_name in ("figure_xfm_maps.png", "figure_xfm_maps.pdf"):
    fig1.savefig(OUT_DIR / out_name, dpi=DPI,
                 bbox_inches='tight', facecolor='white', pad_inches=0.05)
print("Figure 1 saved: figure_xfm_maps.png and .pdf")

---
# Figure 2: Representative XANES with LCF Fits

## Figure 2 helpers

In [None]:
# Map spectrum name -> Figure 1 panel cross-reference
# Ordered (prefix, panel label) pairs; map_panel_label() returns the first
# prefix that matches. The prefixes repeat MAP_PREFIXES from the configuration
# cell, so the two lists must be kept in sync.
SPEC_TO_PANEL = [
    ("FeXANES_GT15_FeTiXRD_striated2_", "(a)"),
    ("FeXANES_GT15_Fe_striated2_",      "(a)"),
    ("FeXANES_GT15_rectangles_Fe_",     "(b)"),
    ("FeXANES_GT5_flaky_nodule_Fe",     "(c)"),
    ("FeXANES_GT15_flakydark_Fe_",      "(d)"),
]

def map_panel_label(spectrum_name):
    """Return a 'Fig. 1X, pt. N' cross-reference for a Figure 1 map spectrum.

    X is the panel letter of the first SPEC_TO_PANEL prefix that matches and
    N is what remains of the name after the prefix, with '.001' removed and
    surrounding '_', '.' and spaces stripped ('1' when nothing remains).
    An empty string is returned when no prefix matches.
    """
    hit = next(
        (pair for pair in SPEC_TO_PANEL if spectrum_name.startswith(pair[0])),
        None,
    )
    if hit is None:
        return ""
    prefix, panel = hit
    point = spectrum_name[len(prefix):].replace(".001", "").strip("_. ") or "1"
    # panel looks like "(a)": index 1 is the bare letter
    letter = panel[1]
    return f"Fig. 1{letter}, pt. {point}"

def short_name(spectrum_name):
    """Return a compact name with the fixed naming tokens removed.

    The tokens are removed one after another, in the order listed, from
    wherever they occur in the name.
    """
    tokens = (
        "FeXANES_", ".001",
        "GT15_FeTiXRD_", "GT15_", "GT5_",
        "Fe_", "_Fe",
    )
    shortened = spectrum_name
    for token in tokens:
        shortened = shortened.replace(token, "")
    return shortened

## Build Figure 2

Each panel shows:
- **Black line**: measured XANES spectrum
- **Red dashed**: LCF best fit (sum of weighted references)
- **Colored lines**: individual reference component contributions
- **Gray line**: fit residual (offset below data)
- **Text box**: R-factor and phase composition (%)

In [None]:
print("Building Figure 2...")

fig2, axes = plt.subplots(3, 2, figsize=(TARGET_WIDTH_IN, 8.5), dpi=DPI)
fig2.subplots_adjust(hspace=0.35, wspace=0.25, left=0.08, right=0.97, top=0.97, bottom=0.06)

# One panel per group, filled row by row in this order
group_order = [1, 2, "3a", "3b", 4, 5]
panel_labels_list = ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]

for idx, grp in enumerate(group_order):
    row_idx = idx // 2
    col_idx = idx % 2
    ax = axes[row_idx, col_idx]

    # Row of the merged cluster/LCF frame chosen as representative for this group
    info = selected[grp]
    spec_name = info["spectrum"]

    # Load the measured spectrum
    # Try "<name>.csv", then the name with a single ".001" suffix, then any
    # file in SPEC_DIR whose name starts with the spectrum name.
    # NOTE(review): if nothing matches, spec_fp still points at a missing
    # file and the load below fails.
    spec_fp = SPEC_DIR / (spec_name + ".csv")
    if not spec_fp.exists():
        alt = spec_name.replace(".001", "") + ".001.csv"
        spec_fp = SPEC_DIR / alt
    if not spec_fp.exists():
        matches = list(SPEC_DIR.glob(spec_name + "*"))
        if matches:
            spec_fp = matches[0]

    e_data, mu_data = load_csv_spectrum(spec_fp)
    mu_grid = interp_to_grid(e_data, mu_data)

    # Get active references and weights
    # (weights <= 0.005 are ignored; largest weight first)
    active_refs = []
    for ref_name in ref_columns:
        w = info[ref_name]
        if w > 0.005:
            active_refs.append((ref_name, w))
    active_refs.sort(key=lambda x: x[1], reverse=True)

    # Reconstruct fit
    # The fit is the sum of weight * reference spectrum; an active reference
    # whose spectrum was not loaded is left out of both the fit and the text.
    weight_sum = info["weight_sum"]
    fit = np.zeros_like(E_GRID)
    components = []
    for ref_name, w in active_refs:
        if ref_name in ref_spectra:
            comp = w * ref_spectra[ref_name]
            fit += comp
            pct = w / weight_sum * 100 if weight_sum > 0 else 0
            components.append((ref_name, w, pct, comp))

    residual = mu_grid - fit
    # The R-factor shown is the stored value, not recomputed from this residual
    r_factor = info["r_factor"]

    # Plot range
    e_mask = (E_GRID >= PLOT_E_MIN) & (E_GRID <= PLOT_E_MAX)
    e_plot = E_GRID[e_mask]

    # Residual offset
    # Places the residual's zero line below the data minimum by the largest
    # absolute residual plus a 0.08 margin, so the two traces do not overlap.
    data_min = mu_grid[e_mask].min()
    resid_max = np.abs(residual[e_mask]).max()
    resid_offset = data_min - resid_max - 0.08

    # Data + fit
    ax.plot(e_plot, mu_grid[e_mask], 'k-', lw=1.2, label='Data', zorder=5)
    ax.plot(e_plot, fit[e_mask], 'r--', lw=1.0, label='LCF fit', zorder=4)

    # Individual components
    # Coloured by phase group; COMP_COLORS is used only for a phase label
    # without a PHASE_COLORS entry.
    for ci, (ref_name, w, pct, comp) in enumerate(components):
        pg = phase_label(ref_name)
        color = PHASE_COLORS.get(pg, COMP_COLORS[ci % len(COMP_COLORS)])
        ax.plot(e_plot, comp[e_mask], '-', color=color, lw=0.7, alpha=0.8, zorder=3)

    # Residual
    ax.plot(e_plot, residual[e_mask] + resid_offset, '-', color='gray', lw=0.7, zorder=2)
    ax.axhline(y=resid_offset, color='gray', lw=0.3, ls=':', zorder=1)

    # Composition text
    # One line per component, so two minerals of the same phase group appear
    # as two separate lines with the same label.
    comp_text_parts = []
    for ref_name, w, pct, comp in components:
        pg = phase_label(ref_name)
        comp_text_parts.append(f"{pg} {pct:.0f}%")
    comp_text = "\n".join(comp_text_parts)

    ax.text(0.98, 0.97, f"R = {r_factor:.4f}\n{comp_text}",
            transform=ax.transAxes, fontsize=5.5, va='top', ha='right',
            bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.85,
                      edgecolor='gray', lw=0.3))

    # Panel label and title with Figure 1 cross-reference
    panel_ref = map_panel_label(spec_name)
    title_str = f"Group {grp}"
    if panel_ref:
        title_str += f"  ({panel_ref})"
    ax.text(0.02, 0.97, f"{panel_labels_list[idx]}",
            transform=ax.transAxes, fontsize=9, fontweight='bold', va='top', ha='left')
    ax.set_title(title_str, fontsize=7.5, pad=4)

    # Axes
    # Only the bottom row gets x tick labels and an x label; only the left
    # column gets a y label.
    if row_idx == 2:
        ax.set_xlabel("Energy (eV)", fontsize=8)
    else:
        ax.set_xticklabels([])
    if col_idx == 0:
        ax.set_ylabel("Flattened µ(E)", fontsize=8)

    ax.tick_params(labelsize=7, direction='in', top=True, right=True)
    ax.set_xlim(PLOT_E_MIN, PLOT_E_MAX)

plt.show()

# Save through OUT_DIR, as Figure 1 does, so that changing the configured
# output directory moves both figures rather than only the first one.
for out_name in ("figure_xanes_lcf.png", "figure_xanes_lcf.pdf"):
    fig2.savefig(OUT_DIR / out_name, dpi=DPI, bbox_inches='tight',
                 facecolor='white', pad_inches=0.05)
print("\nFigure 2 saved: figure_xanes_lcf.png and .pdf")