In [5]:
import os
from pathlib import Path

import h5py
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
from matplotlib.colors import LogNorm
from scipy.stats import pearsonr

In [None]:
# Data locations. The defaults are the original local checkout; set the
# GP17_MAP_DIR / GP17_PCA_DIR environment variables to run the notebook on
# another machine without editing this cell.
MAP_DIR = Path(os.environ.get("GP17_MAP_DIR", "/Users/christiandewey/Code/gp17-ant_sediments/maps"))
PCA_DIR = Path(os.environ.get("GP17_PCA_DIR", "/Users/christiandewey/Code/gp17-ant_sediments/pca_results"))

# ROI names looked up in each map file, in display order.
ELEMENTS = ["Fe Ka", "Ca Ka", "K Ka", "Ti Ka", "Mn Ka", "S Ka", "Si Ka", "P Ka"]
# ROI name → short element symbol (the " Ka" part removed), used as dict keys downstream.
SHORT = {e: e.replace(" Ka", "") for e in ELEMENTS}

# Integer key → hex colour; presumably keyed by the `cluster` values in
# cluster_assignments.csv (not used in the cells visible here — confirm).
CLUSTER_COLORS = {1: '#1f77b4', 2: '#ff7f0e', 3: '#2ca02c', 4: '#d62728', 5: '#9467bd'}

# Map labels: filename stem → display label (consistent across all figures)
MAP_LABELS = {
    "1x1_10um_flaky_dark_gt15_001": "Map 1",
    "1x1_10um_flaky_gray_mix_gt15_001": "Map 2",
    "1x1_10um_rectangles_flakes_gt15_2_001": "Map 3",
    "2x2_10um_concentric_gray_1_001": "Map 4",
    "2x2_10um_concentric_gray_3_001": "Map 5",
    "2x2_10um_flaky_1_001": "Map 6",
    "2x2_10um_flaky_2_001": "Map 7",
    "2x2_10um_flaky_nodule_001": "Map 8",
    "2x2_10um_flaky_smooth_2_001": "Map 9",
    "2x2_10um_rectangles_gt15_1_001": "Map 10",
    "2x2_10um_striated_gt15_2_001": "Map 11",
    "2x2_10um_super_dark_gt15_4_001": "Map 12",
    "2x2_10um_white_band_001": "Map 13",
}


def get_roi_map(f, roi_name):
    """Return the summed, corrected ROI array for ``roi_name`` from an open map file.

    Parameters
    ----------
    f : mapping-like (e.g. an open ``h5py.File``)
        Must provide ``"xrmmap/roimap/sum_name"`` (the ROI names, as bytes or
        str) and ``"xrmmap/roimap/sum_cor"`` (a 3-D array whose last axis is
        indexed by ROI).
    roi_name : str
        ROI name to look up, e.g. ``"Fe Ka"``.

    Returns
    -------
    numpy.ndarray or None
        Float array taken from ``sum_cor`` for the matching ROI, with the first
        and last entries along the second axis dropped; ``None`` when the ROI
        name is not present in the file.
    """
    decoded = []
    for raw in f["xrmmap/roimap/sum_name"][:]:
        # Names may be stored as bytes; normalise everything to str.
        decoded.append(raw.decode() if isinstance(raw, bytes) else raw)

    if roi_name not in decoded:
        return None

    column = decoded.index(roi_name)
    roi_data = f["xrmmap/roimap/sum_cor"][:, 1:-1, column]
    return roi_data.astype(float)


# Discover all maps: every .h5 file in MAP_DIR except those whose file name
# contains "test_map" or "elongated_particle", in sorted path order
all_h5 = sorted(
    p
    for p in MAP_DIR.glob("*.h5")
    if not ("test_map" in p.name or "elongated_particle" in p.name)
)
map_names = [p.stem for p in all_h5]
map_lookup = dict(zip(map_names, all_h5))

# Load cluster assignments and index them as spectrum → cluster
clusters_df = pd.read_csv(PCA_DIR / "cluster_assignments.csv")
cluster_lookup = {
    spectrum: cluster
    for spectrum, cluster in zip(clusters_df["spectrum"], clusters_df["cluster"])
}

# Report what was found, using the display label where one is defined
print(f"Found {len(all_h5)} maps:")
for name in map_names:
    print(f"  {MAP_LABELS.get(name, name)}: {name}")

In [7]:
def load_map_data(h5_path):
    """Read the per-element ROI arrays from one HDF5 map file.

    Parameters
    ----------
    h5_path : path-like
        Location of the map file; it is opened read-only and closed on return.

    Returns
    -------
    dict
        Short element name (``SHORT[elem]``) → array returned by
        ``get_roi_map``. Elements for which ``get_roi_map`` returns ``None``
        are left out.
    """
    loaded = {}
    with h5py.File(h5_path, "r") as h5:
        for roi in ELEMENTS:
            data = get_roi_map(h5, roi)
            if data is None:
                # ROI not present in this file; skip it.
                continue
            loaded[SHORT[roi]] = data
    return loaded


# Pre-scan which elements are available per map: only the ROI name list is
# read from each file here, not the pixel data
map_elements = {}
for name, path in map_lookup.items():
    with h5py.File(path, "r") as f:
        roi_names = [
            n.decode() if isinstance(n, bytes) else n
            for n in f["xrmmap/roimap/sum_name"][:]
        ]
    # Keep ELEMENTS order, then translate to the short symbols
    found = [e for e in ELEMENTS if e in roi_names]
    avail = [SHORT[e] for e in found]
    map_elements[name] = avail
    print(f"  {name}: {', '.join(avail)}")

  1x1_10um_flaky_dark_gt15_001: Fe, Ca, K, Ti, Mn, S, Si, P
  1x1_10um_flaky_gray_mix_gt15_001: Fe, Ca, K, Ti, Mn, S, Si, P
  1x1_10um_rectangles_flakes_gt15_2_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_concentric_gray_1_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_concentric_gray_3_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_elongated_particle_gt15_1_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_flaky_1_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_flaky_2_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_flaky_nodule_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_flaky_smooth_2_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_rectangles_gt15_1_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_striated_gt15_2_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_striated_gt15_2_002: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_super_dark_gt15_4_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_white_band_001: Fe, Ca, K, Ti, Mn, S, Si, P
  2x2_10um_white_gray_particles_002: Fe, Ca, K, Ti, Mn, S, Si, P


In [None]:
import io
from IPython.display import Image as IPImage

# Pre-load all maps once, so density_scatter_all can read pixel data from memory
cached_maps = {name: load_map_data(map_lookup[name]) for name in map_names}
print(f"Loaded pixel data for {len(cached_maps)} maps")

# Number of maps, and the sorted union of elements present in any map
n_maps = len(map_names)
all_map_elems = sorted({elem for n in map_names for elem in map_elements[n]})

def show_fig(fig):
    """Render ``fig`` to a PNG in memory, close it, and display the image.

    The figure is saved at 150 dpi with a tight bounding box and shown through
    ``IPython.display``. The figure is closed even when saving raises, so a
    failed plot does not stay open in the kernel.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to render; it is closed by this call.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, whether or not rendering succeeded.
        plt.close(fig)
    display(IPImage(data=buf.getvalue()))

def density_scatter_all(x_elem="Fe", y_elem="Ca"):
    """Plot pixel-level density scatter for all maps in a 4-column grid.

    One 2-D histogram (log colour scale on the counts) is drawn per entry of
    ``map_names`` from the arrays in ``cached_maps``. Only pixels where both
    elements are strictly positive are used. Each panel title shows the map's
    display label and the Pearson r of those pixels. The finished figure is
    rendered through ``show_fig``.

    Parameters
    ----------
    x_elem, y_elem : str
        Short element names (keys of the per-map dicts in ``cached_maps``,
        e.g. ``"Fe"``) for the x and y axes. They must differ.

    Returns
    -------
    None
        Prints a message instead of plotting when the two elements are the
        same or when there are no maps to show.
    """
    if x_elem == y_elem:
        print("Select two different elements.")
        return

    # Size the grid from the list actually iterated, so the grid and the loop
    # can never disagree.
    n_panels = len(map_names)
    if n_panels == 0:
        # plt.subplots cannot build a zero-row grid; bail out with a message.
        print("No maps available to plot.")
        return

    ncols = 4
    min_pixels = 10  # panels with fewer usable pixels are blanked out
    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3.5 * nrows), squeeze=False)
    axes = axes.ravel()

    for i, map_label in enumerate(map_names):
        ax = axes[i]
        maps = cached_maps[map_label]
        label = MAP_LABELS.get(map_label, map_label)
        if x_elem not in maps or y_elem not in maps:
            ax.set_title(f"{label}\n(missing data)", fontsize=7)
            ax.axis("off")
            continue

        xd = maps[x_elem].ravel()
        yd = maps[y_elem].ravel()
        mask = (xd > 0) & (yd > 0)
        if mask.sum() < min_pixels:
            ax.set_title(f"{label}\n(too few pixels)", fontsize=7)
            ax.axis("off")
            continue

        r, _ = pearsonr(xd[mask], yd[mask])
        ax.hist2d(xd[mask], yd[mask], bins=80, cmap="inferno", norm=LogNorm(), rasterized=True)
        ax.set_title(f"{label}\nr={r:.2f}", fontsize=7)
        ax.tick_params(labelsize=5)
        ax.ticklabel_format(style="scientific", axis="both", scilimits=(0, 0))
        ax.xaxis.get_offset_text().set_fontsize(5)
        ax.yaxis.get_offset_text().set_fontsize(5)
        if i % ncols == 0:
            ax.set_ylabel(f"{y_elem} K\u03b1", fontsize=7)
        # Label the bottom-most panel of every column, including columns that
        # end one row early because the last row is only partly filled.
        if i + ncols >= n_panels:
            ax.set_xlabel(f"{x_elem} K\u03b1", fontsize=7)

    # Hide the unused grid cells after the last map.
    for ax in axes[n_panels:]:
        ax.axis("off")

    fig.suptitle(f"{x_elem} vs {y_elem} \u2014 All Maps", fontsize=12)
    fig.tight_layout()
    show_fig(fig)

In [None]:
# Controls: one dropdown per axis, a button to trigger the plot, and an output
# area that holds the rendered figure
_dd_x = widgets.Dropdown(description="X:", options=all_map_elems, value="Fe")
_dd_y = widgets.Dropdown(description="Y:", options=all_map_elems, value="Ca")
_dd_btn = widgets.Button(description="Plot", button_style="primary")
_dd_out = widgets.Output()


def _on_click(_):
    """Redraw the all-maps density scatter for the currently selected elements."""
    x_choice, y_choice = _dd_x.value, _dd_y.value
    # wait=True defers clearing until new output arrives
    _dd_out.clear_output(wait=True)
    with _dd_out:
        density_scatter_all(x_choice, y_choice)


_dd_btn.on_click(_on_click)

_controls = widgets.HBox([_dd_x, _dd_y, _dd_btn])
display(widgets.VBox([_controls, _dd_out]))

VBox(children=(HBox(children=(Dropdown(description='X:', index=1, options=('Ca', 'Fe', 'K', 'Mn', 'P', 'S', 'S…