# XFM Tricolor Maps with Cluster Overlay

Interactive RGB composite maps from µ-XRF HDF5 data, with XANES point locations
overlaid as markers whose shape encodes the cluster assignment. Cluster 3 is
further split into sub-clusters 3a and 3b.

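**Requires:** Run `01_pca_clustering.ipynb` and `03_lcf_microprobe.ipynb` first.
The cluster assignments are mandatory. The LCF results are optional: when present,
they are used only to decide which half of cluster 3 is labeled 3a vs. 3b (the
sub-cluster split itself is recomputed in this notebook).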

**Inputs:**
- `maps/*.h5` — HDF5 µ-XRF map files
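- `pca_results/cluster_assignments.csv` — cluster assignments (`spectrum`, `cluster`, and `PC*` score columns)
- `pca_results/lcf_individual.csv` — per-spectrum LCF results (optional; used to order sub-clusters 3a/3b)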

**Outputs:** `pca_results/xfm_*.png` — tricolor maps with cluster overlays

## Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.gridspec as gridspec
from pathlib import Path
import h5py, io
import ipywidgets as widgets
from IPython.display import display, clear_output, Image as IPImage
import warnings
warnings.filterwarnings('ignore')

def show_fig(fig, dpi=200):
    """Render a figure to an in-memory PNG, close it, and display the PNG.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to render. It is always closed before this function returns,
        including when saving fails.
    dpi : int, optional
        Resolution passed to ``fig.savefig``. Defaults to 200.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    finally:
        # Close unconditionally so a failed save does not leave the figure
        # registered with pyplot (repeated renders would otherwise pile up).
        plt.close(fig)
    display(IPImage(data=buf.getvalue()))

## Load cluster assignments

Load cluster assignments from the PCA/clustering notebook output.

In [None]:
# Input/output locations, relative to the notebook's working directory.
MAP_DIR = Path('maps')
CLUSTER_CSV = Path('pca_results/cluster_assignments.csv')
OUT_DIR = Path('pca_results')

# Map file stems in display order; position N in this tuple is shown as "Map N"
# so that labels stay consistent across all figures.
_MAP_STEMS = (
    "1x1_10um_flaky_dark_gt15_001",
    "1x1_10um_flaky_gray_mix_gt15_001",
    "1x1_10um_rectangles_flakes_gt15_2_001",
    "2x2_10um_concentric_gray_1_001",
    "2x2_10um_concentric_gray_3_001",
    "2x2_10um_flaky_1_001",
    "2x2_10um_flaky_2_001",
    "2x2_10um_flaky_nodule_001",
    "2x2_10um_flaky_smooth_2_001",
    "2x2_10um_rectangles_gt15_1_001",
    "2x2_10um_striated_gt15_2_001",
    "2x2_10um_super_dark_gt15_4_001",
    "2x2_10um_white_band_001",
)

# Map labels: filename stem → display label
MAP_LABELS = {
    stem: f"Map {number}" for number, stem in enumerate(_MAP_STEMS, start=1)
}

# Spectrum name → cluster id, as written by the PCA/clustering notebook.
cluster_df = pd.read_csv(CLUSTER_CSV)
cluster_lookup = {
    spectrum: cluster
    for spectrum, cluster in zip(cluster_df['spectrum'], cluster_df['cluster'])
}
print(f'Loaded {len(cluster_df)} cluster assignments')

## Sub-cluster 3 setup

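Split cluster 3 into sub-clusters 3a and 3b using 2-cluster k-means on the PC
scores of the cluster 3 spectra (a lightweight re-clustering done in this notebook).
When LCF results are available, 3a is the sub-cluster with the higher mean
`Pyrrhotite` value; otherwise the k-means label order is used as-is.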

In [3]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# PC score columns written by the clustering notebook (any column named "PC...").
pc_cols = [c for c in cluster_df.columns if c.startswith('PC')]
c3_df = cluster_df[cluster_df['cluster'] == 3].copy()

if len(c3_df) > 1 and pc_cols:
    km = KMeans(n_clusters=2, random_state=42, n_init=10)
    c3_labels = km.fit_predict(c3_df[pc_cols].values)

    # k-means label order is arbitrary. By default keep it as-is; if LCF
    # results are available, make 3a (index 0) the sub-cluster with the higher
    # mean Pyrrhotite value.
    label_map = {0: 0, 1: 1}
    lcf_path = Path('pca_results/lcf_individual.csv')
    if lcf_path.exists():
        lcf_df = pd.read_csv(lcf_path)
        if 'Pyrrhotite' in lcf_df.columns:
            c3_lcf = lcf_df[lcf_df['cluster'] == 3].copy()
            if 'spectrum' in c3_lcf.columns:
                # Align by spectrum name so that row order or row count
                # differences between the two CSVs cannot mislabel spectra.
                raw_lookup = dict(zip(c3_df['spectrum'], c3_labels))
                c3_lcf['sub_label'] = c3_lcf['spectrum'].map(raw_lookup)
            elif len(c3_lcf) == len(c3_labels):
                # No name column to join on: fall back to positional alignment,
                # which assumes both CSVs list cluster 3 in the same order.
                c3_lcf['sub_label'] = c3_labels
            else:
                c3_lcf['sub_label'] = float('nan')
                print('LCF rows do not match cluster 3 spectra; '
                      'keeping k-means label order')
            sub0_pyrr = c3_lcf.loc[c3_lcf['sub_label'] == 0, 'Pyrrhotite'].mean()
            sub1_pyrr = c3_lcf.loc[c3_lcf['sub_label'] == 1, 'Pyrrhotite'].mean()
            # Only swap when both means are defined; a NaN mean (no matching
            # LCF rows for a sub-cluster) must not force a swap.
            if pd.notna(sub0_pyrr) and pd.notna(sub1_pyrr) and sub0_pyrr < sub1_pyrr:
                label_map = {1: 0, 0: 1}

    sub3_names = c3_df['spectrum'].tolist()
    sub3_labels = [label_map[raw] for raw in c3_labels]
    # Spectrum name → 0 (3a) or 1 (3b); consumed by get_style_key.
    sub3_lookup = dict(zip(sub3_names, sub3_labels))
    print(f'Cluster 3 split: 3a={sub3_labels.count(0)}, 3b={sub3_labels.count(1)}')
else:
    sub3_lookup = {}
    print('Could not split cluster 3')

Cluster 3 split: 3a=17, 3b=9


## Cluster style definitions

In [4]:
# Style key → marker shape and legend label. Keys are cluster ids, except
# cluster 3, which is represented by its two sub-cluster keys '3a' and '3b'.
CLUSTER_STYLE = {
    key: {'marker': marker, 'label': f'Group {key}'}
    for key, marker in (
        (1, 'o'),
        (2, 's'),
        ('3a', '^'),
        ('3b', 'v'),
        (4, 'D'),
        (5, 'p'),
    )
}

def get_style_key(spec_name, cluster_id):
    """Return the CLUSTER_STYLE key for a spectrum.

    Clusters other than 3 use their own id as the key. Cluster 3 spectra are
    looked up in ``sub3_lookup``: a value of 1 gives '3b'; a value of 0, or a
    spectrum that is not in the lookup, gives '3a'.
    """
    if cluster_id != 3:
        return cluster_id
    return '3b' if sub3_lookup.get(spec_name) == 1 else '3a'

## HDF5 map utilities

In [None]:
# Discover map files and the ROIs available for the RGB channels.
# Skipped files: test maps, elongated-particle maps, and any file whose stem
# ends in '_002'.
h5_files = sorted(
    p for p in MAP_DIR.glob('*.h5')
    if 'test_map' not in p.name
    and 'elongated_particle' not in p.name
    and not p.stem.endswith('_002')
)
if not h5_files:
    # Fail with an actionable message instead of an IndexError on h5_files[0].
    raise FileNotFoundError(f'No usable .h5 map files found in {MAP_DIR.resolve()}')

# ROI names are read from the first map only.
# NOTE(review): assumes every map file shares the same ROI list — confirm.
with h5py.File(h5_files[0], 'r') as f:
    _roi_names = [n.decode() if isinstance(n, bytes) else n
                  for n in f['xrmmap/roimap/sum_name'][:]]

# Keep only ROIs whose name contains an emission-line tag.
ELEMENT_ROIS = [n for n in _roi_names
                if any(tag in n for tag in ('Ka', 'Kb', 'La', 'Ma'))]
print(f'Found {len(h5_files)} map files, {len(ELEMENT_ROIS)} element ROIs')

def area_to_spectrum(area_name):
    """Return the spectrum name used in the cluster table for a map area name."""
    return 'FeXANES_{}.001'.format(area_name)

def get_roi_map(f, roi_name):
    """Return one ROI's 2-D map as a float array, or None if unavailable.

    The dataset 'xrmmap/roimap/sum_cor' is tried first, then
    'xrmmap/roimap/sum_raw'. The ROI is located by name in
    'xrmmap/roimap/sum_name'. The first and last columns of the map are
    dropped (slice ``1:-1`` on the second axis).
    """
    names_key = 'xrmmap/roimap/sum_name'
    for data_key in ('xrmmap/roimap/sum_cor', 'xrmmap/roimap/sum_raw'):
        if data_key not in f or names_key not in f:
            continue
        # Names may be stored as bytes; normalize to str before comparing.
        labels = [raw.decode() if isinstance(raw, bytes) else raw
                  for raw in f[names_key][:]]
        try:
            column = labels.index(roi_name)
        except ValueError:
            continue
        return f[data_key][:, 1:-1, column].astype(float)
    return None

def make_rgb(f, r_name, g_name, b_name):
    """Build a contrast-stretched RGB image from three ROI maps.

    Each channel is read with ``get_roi_map`` and rescaled to [0, 1] between
    its 1st and 99.5th percentiles. A channel with no intensity spread becomes
    all zeros.

    A ROI that cannot be found is rendered as an all-zero channel with the
    same shape as the channels that were found, so the composite can still be
    stacked. If no ROI is found at all, the result is a (1, 1, 3) zero image.

    Parameters
    ----------
    f : open HDF5 file (or mapping) accepted by ``get_roi_map``
    r_name, g_name, b_name : str
        ROI names for the red, green and blue channels.

    Returns
    -------
    numpy.ndarray
        Array with the channels stacked on the last axis.
    """
    raw_channels = [get_roi_map(f, name) for name in (r_name, g_name, b_name)]
    # Missing channels borrow the shape of the first available one; stacking
    # a (1, 1) placeholder next to a real map would raise in np.stack.
    fill_shape = next((ch.shape for ch in raw_channels if ch is not None), (1, 1))

    channels = []
    for ch in raw_channels:
        if ch is None:
            channels.append(np.zeros(fill_shape))
            continue
        vmin = np.percentile(ch, 1)
        vmax = np.percentile(ch, 99.5)
        if vmax > vmin:
            channels.append(np.clip((ch - vmin) / (vmax - vmin), 0, 1))
        else:
            channels.append(np.zeros_like(ch))
    return np.stack(channels, axis=-1)

def get_area_centroids(f):
    """Return ``{area_name: (mean_row, mean_col)}`` for each non-empty area mask.

    Masks are read from the 'xrmmap/areas' group. Areas whose mask has no
    set pixels are omitted. If the group is absent, the result is empty.
    """
    areas = f.get('xrmmap/areas')
    if areas is None:
        return {}

    found = {}
    for name in areas:
        mask = areas[name][:]
        if not mask.any():
            continue
        row_idx, col_idx = np.nonzero(mask)
        found[name] = (row_idx.mean(), col_idx.mean())
    return found

## RGB triangle and scale bar

In [6]:
def make_rgb_triangle(size=100):
    """Rasterize an RGB color-mixing triangle for use as a legend.

    Each pixel inside the triangle is colored by its barycentric weights with
    respect to the red, green and blue vertices. The weights are clamped at 0
    and renormalized to sum to 1. Pixels outside the triangle stay fully
    transparent. A tolerance of -0.01 on the weights keeps pixels on the edges.

    Parameters
    ----------
    size : int
        Width and height of the square image in pixels.

    Returns
    -------
    img : numpy.ndarray
        ``(size, size, 4)`` float32 RGBA image, indexed ``[y, x]``.
    v_r, v_g, v_b : numpy.ndarray
        ``(x, y)`` pixel coordinates of the red, green and blue vertices.
    """
    img = np.zeros((size, size, 4), dtype=np.float32)
    v_r = np.array([0.1 * size, 0.05 * size])
    v_b = np.array([0.9 * size, 0.05 * size])
    v_g = np.array([0.5 * size, 0.95 * size])

    # The barycentric denominator depends only on the vertices, so compute it
    # once instead of per pixel. A degenerate triangle leaves the image empty.
    denom = (v_g[1] - v_b[1]) * (v_r[0] - v_b[0]) + (v_b[0] - v_g[0]) * (v_r[1] - v_b[1])
    if abs(denom) < 1e-10:
        return img, v_r, v_g, v_b

    # Integer pixel grids: ys varies along axis 0, xs along axis 1.
    ys, xs = np.mgrid[0:size, 0:size]
    dx = xs - v_b[0]
    dy = ys - v_b[1]
    w_r = ((v_g[1] - v_b[1]) * dx + (v_b[0] - v_g[0]) * dy) / denom
    w_g = ((v_b[1] - v_r[1]) * dx + (v_r[0] - v_b[0]) * dy) / denom
    w_b = 1 - w_r - w_g

    inside = (w_r >= -0.01) & (w_g >= -0.01) & (w_b >= -0.01)

    # Clamp the slightly negative edge weights, then renormalize.
    w_r = np.maximum(w_r, 0)
    w_g = np.maximum(w_g, 0)
    w_b = np.maximum(w_b, 0)
    total = w_r + w_g + w_b
    safe_total = np.where(total > 0, total, 1.0)
    weights = np.stack([w_r, w_g, w_b], axis=-1) / safe_total[..., None]

    img[inside, :3] = weights[inside]
    img[inside, 3] = 1.0
    return img, v_r, v_g, v_b

# Build the 80-px legend triangle once when this cell runs; render_maps reuses
# the image and its vertex coordinates for every figure.
_tri_img, _tri_vr, _tri_vg, _tri_vb = make_rgb_triangle(80)

def add_scale_bar(ax, x_range, bar_length_mm=0.5):
    """Draw a labeled scale bar just below the bottom-right corner of *ax*.

    The bar is drawn in axes-fraction coordinates. Its length is
    ``bar_length_mm / x_range`` of the axes width, so *x_range* is the width
    of the plotted data in the same units as *bar_length_mm*. The label shows
    one decimal for bars shorter than 1 mm and none otherwise.
    """
    right = 0.98
    left = right - bar_length_mm / x_range
    baseline = -0.04
    # Both artists sit outside the axes box, so clipping must be disabled.
    placement = dict(zorder=10, clip_on=False, transform=ax.transAxes)

    ax.plot([left, right], [baseline, baseline], color='black',
            linewidth=3, solid_capstyle='butt', **placement)

    if bar_length_mm < 1:
        label = f'{bar_length_mm:.1f} mm'
    else:
        label = f'{bar_length_mm:.0f} mm'
    ax.text((left + right) / 2, baseline - 0.02, label,
            color='black', fontsize=9, ha='center', va='top',
            fontweight='bold', **placement)

## Legend and map rendering

In [None]:
def render_maps(r_name, g_name, b_name, show_xanes=True):
    """Render all maps as a single gridded figure with RGB tricolor.

    One panel is drawn per file in ``h5_files``. The panel after the last map
    holds the legend: the RGB mixing triangle and, when *show_xanes* is true,
    the cluster marker key. The grid always reserves a panel for the legend,
    adding a row if the maps fill the grid exactly.

    Parameters
    ----------
    r_name, g_name, b_name : str
        ROI names for the red, green and blue channels.
    show_xanes : bool
        Overlay the XANES area centroids, with marker shape set by cluster.

    The figure is displayed through ``show_fig``; nothing is returned.
    """
    if not h5_files:
        # plt.subplots cannot build a zero-row grid; report and stop instead.
        print('No map files to render')
        return

    # Collect map data
    map_data = []
    for h5_path in h5_files:
        with h5py.File(h5_path, 'r') as f:
            rgb = make_rgb(f, r_name, g_name, b_name)
            centroids = get_area_centroids(f)
            pos = f['xrmmap/positions/pos']
            ny, nx_full = pos.shape[:2]
            # Drop the first and last columns, matching the trim applied to
            # the ROI maps in get_roi_map.
            x_pos = pos[:, 1:-1, 0]
            y_pos = pos[:, 1:-1, 1]
            nx = nx_full - 2
            extent = [x_pos.min(), x_pos.max(), y_pos.min(), y_pos.max()]
        map_data.append({
            'path': h5_path, 'rgb': rgb, 'centroids': centroids,
            'extent': extent, 'ny': ny, 'nx': nx,
        })

    n_maps = len(map_data)
    ncols = 4
    # +1 guarantees a spare panel for the legend even when n_maps is a
    # multiple of ncols.
    nrows = int(np.ceil((n_maps + 1) / ncols))

    # squeeze=False always yields a 2-D array of axes, whatever the grid size.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4.5 * ncols, 4 * nrows),
                             squeeze=False)
    axes = axes.ravel()

    def short(name):
        """Strip the emission-line suffix from a ROI name for display."""
        return name.replace(' Ka', '').replace(' Kb', '').replace(' La', '').replace(' Ma', '')

    for idx, md in enumerate(map_data):
        ax = axes[idx]
        ax.imshow(md['rgb'], extent=md['extent'], aspect='equal',
                  interpolation='nearest', origin='lower')

        if show_xanes:
            style_points = {k: ([], []) for k in CLUSTER_STYLE}
            for area_name, (row_c, col_c) in md['centroids'].items():
                spec_name = area_to_spectrum(area_name)
                cluster_id = cluster_lookup.get(spec_name)
                if cluster_id is None:
                    continue
                sk = get_style_key(spec_name, cluster_id)
                if sk not in style_points:
                    continue
                # Centroids are in full-width pixel columns; shift by one to
                # account for the trimmed first column and skip points that
                # fall in the trimmed margins.
                col_adj = col_c - 1
                if col_adj < 0 or col_adj >= md['nx']:
                    continue
                # Convert pixel indices to the position units of the extent.
                x_disp = np.interp(col_adj, [0, md['nx'] - 1],
                                   [md['extent'][0], md['extent'][1]])
                y_disp = np.interp(row_c, [0, md['ny'] - 1],
                                   [md['extent'][2], md['extent'][3]])
                style_points[sk][0].append(x_disp)
                style_points[sk][1].append(y_disp)

            for sk, style in CLUSTER_STYLE.items():
                xs, ys = style_points[sk]
                if xs:
                    ax.scatter(xs, ys, marker=style['marker'], facecolors='none',
                               edgecolors='white', s=60, linewidths=0.8, zorder=5)

        title = MAP_LABELS.get(md['path'].stem, md['path'].stem)
        ax.set_title(title, fontsize=7)
        ax.set_xticks([])
        ax.set_yticks([])

        # NOTE(review): treats the position extent as millimetres — confirm.
        x_range = md['extent'][1] - md['extent'][0]
        add_scale_bar(ax, x_range, bar_length_mm=0.5 if x_range > 1.2 else 0.2)

    for j in range(n_maps, len(axes)):
        axes[j].axis('off')

    # RGB triangle legend in the first unused panel (always present, see nrows)
    leg_ax = axes[n_maps]
    leg_ax.axis('on')
    leg_ax.set_xticks([])
    leg_ax.set_yticks([])
    for spine in leg_ax.spines.values():
        spine.set_visible(False)

    tri_ax = leg_ax.inset_axes([0.05, 0.5, 0.9, 0.45])
    tri_ax.imshow(_tri_img, origin='upper', interpolation='bilinear')
    tri_ax.set_xlim(0, _tri_img.shape[1])
    tri_ax.set_ylim(_tri_img.shape[0], 0)
    tri_ax.axis('off')
    s = _tri_img.shape[0]
    tri_ax.text(_tri_vr[0], _tri_vr[1] - s * 0.08, short(r_name), color='red',
                fontsize=9, fontweight='bold', ha='center', va='bottom')
    tri_ax.text(_tri_vb[0], _tri_vb[1] - s * 0.08, short(b_name), color='blue',
                fontsize=9, fontweight='bold', ha='center', va='bottom')
    tri_ax.text(_tri_vg[0], _tri_vg[1] + s * 0.05, short(g_name), color='lime',
                fontsize=9, fontweight='bold', ha='center', va='top')

    if show_xanes:
        # Marker key, stacked downwards below the triangle.
        y_cur = 0.45
        for sk, style in CLUSTER_STYLE.items():
            leg_ax.scatter([0.15], [y_cur], marker=style['marker'], s=50,
                           facecolors='none', edgecolors='black', linewidths=0.8,
                           transform=leg_ax.transAxes, clip_on=False)
            leg_ax.text(0.25, y_cur, style['label'], fontsize=8,
                        va='center', transform=leg_ax.transAxes)
            y_cur -= 0.07

    fig.suptitle(f'Tricolor: {short(r_name)} (R) / {short(g_name)} (G) / {short(b_name)} (B)',
                 fontsize=12, y=1.01)
    fig.tight_layout()
    show_fig(fig)

---
## Interactive map viewer

Select elements for the RGB channels and click **Render Maps** to generate
tricolor composites with cluster overlays for all map files.

In [8]:
def _default_roi(preferred):
    """Return *preferred* if it is an available ROI, else the first ROI (or None)."""
    if preferred in ELEMENT_ROIS:
        return preferred
    return ELEMENT_ROIS[0] if ELEMENT_ROIS else None


# Channel selectors. The preferred defaults are used only when present in the
# data; a Dropdown rejects an initial value that is not among its options.
_dd_r = widgets.Dropdown(options=ELEMENT_ROIS, value=_default_roi('Fe Ka'), description='Red:')
_dd_g = widgets.Dropdown(options=ELEMENT_ROIS, value=_default_roi('Ca Ka'), description='Green:')
_dd_b = widgets.Dropdown(options=ELEMENT_ROIS, value=_default_roi('K Ka'), description='Blue:')
_cb_xanes = widgets.Checkbox(value=True, description='Show XANES spots', indent=False)
_btn = widgets.Button(description='Render Maps', button_style='primary')
_out = widgets.Output()


def _on_render(_):
    """Button callback: redraw the maps into the output area with the current selections."""
    # wait=True defers clearing until new output arrives, avoiding flicker.
    _out.clear_output(wait=True)
    with _out:
        render_maps(_dd_r.value, _dd_g.value, _dd_b.value, show_xanes=_cb_xanes.value)


_btn.on_click(_on_render)
display(widgets.VBox([
    widgets.HBox([_dd_r, _dd_g, _dd_b]),
    widgets.HBox([_cb_xanes, _btn]),
    _out,
]))

VBox(children=(HBox(children=(Dropdown(description='Red:', index=14, options=('Mg Ka', 'Al Ka', 'Si Ka', 'P Ka…