## TEM-EELS

In [None]:
import sys
import hyperspy.api as hs
from pathlib import Path as path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec
from matplotlib import ticker
from matplotlib.colors import LinearSegmentedColormap, Normalize
import xarray as xr

In [None]:
# Initial plotting setup: apply the custom matplotlib style sheet.
plt.style.use(r'C:\Users\chengliu\OneDrive - UAB\ICMAB-python\Figure\liuchzzyy.mplstyle')
# display(plt.style.available)

# Colour setup: make the local `colors` module importable, then keep the
# 'vibrant' colour set as the list of line colours used later in this notebook.
sys.path.append(r'C:\Users\chengliu\OneDrive - UAB\ICMAB-Python\Figure')
from colors import tol_cmap, tol_cset # type: ignore
colors = list(tol_cset('vibrant'))
# Register the extra colormaps only when they are not known yet, so that
# re-running this cell does not attempt a second registration.
if r'sunset' not in plt.colormaps():
    plt.colormaps.register(tol_cmap('sunset'))
if r'rainbow_PuRd' not in plt.colormaps():
    plt.colormaps.register(tol_cmap('rainbow_PuRd')) # alternative: plasma

# Output folder for figures and intermediate data.
path_out = path(r"C:\Users\chengliu\Desktop\Figure")

In [None]:
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
def add_sizebar(ax, size, bardata, color, barunits=None):
    """Add a scale bar to the lower-left corner of ``ax`` and return it.

    Parameters
    ----------
    ax
        Matplotlib axes that receives the bar. The bar length is expressed in
        the axes' data coordinates (``ax.transData``).
    size
        Physical length of the bar, in the calibration units.
    bardata
        Either a plain number giving the physical size of one data-coordinate
        unit (in which case ``barunits`` is required), or an object exposing
        ``axes_manager``. For the latter the scale and units of the first
        navigation axis are used when the navigation space is 2-D, otherwise
        those of the first signal axis when the signal space is 2-D.
    color
        Colour of the bar and of its label.
    barunits
        Unit string used in the label; only read when ``bardata`` is a number.

    Returns
    -------
    AnchoredSizeBar
        The artist that was added to ``ax``.

    Raises
    ------
    ValueError
        If ``bardata`` is a number and ``barunits`` is not a string, or if
        ``bardata`` has neither a 2-D navigation nor a 2-D signal space.
    """
    # bool is excluded explicitly because it is a subclass of int.
    if isinstance(bardata, (int, float)) and not isinstance(bardata, bool):
        if not isinstance(barunits, str):
            raise ValueError("barunits must be provided if bardata is a float.")
        pixel_size, unit = bardata, barunits
    else:
        manager = bardata.axes_manager
        if len(manager.navigation_shape) == 2:
            calibrated_axis = manager.navigation_axes[0]
        elif len(manager.signal_shape) == 2:
            calibrated_axis = manager.signal_axes[0]
        else:
            # Previously this fell through and failed later with an
            # unbound-variable error; fail early with a clear message instead.
            raise ValueError(
                "bardata must have a 2-D navigation or a 2-D signal space "
                "to derive the scale bar calibration."
            )
        pixel_size, unit = calibrated_axis.scale, calibrated_axis.units

    asb = AnchoredSizeBar(ax.transData,
                          size / pixel_size,
                          '{} {}'.format(size, unit),
                          loc='lower left',
                          pad=0.1, borderpad=0.5, sep=1,
                          frameon=False,
                          color=color,
                          label_top=True,
                          fontproperties={'size': 9})
    ax.add_artist(asb)
    return asb

### 读取数据

In [None]:
# Folder that holds the EELS acquisitions of this sample.
path_file = path(r'C:\Users\chengliu\OneDrive - UAB\ICMAB-Data\Zn-Mn\Uno\Result\TEM\ExSitu\αMnO2\Charge\1st0.9V\αMnO2 + PVDF + SP\1M ZnSO4 + 1M MnSO4\2024-EMCA\EELS\Data')
# Spectrum-image file of the selected acquisition.
file = path_file / r'SI1_80pA_10ms' / r'STEM SI.dm4'
eels_list = hs.load(file)  # type: ignore
eels_list

In [None]:
# Harmonise the spatial calibration of every loaded dataset:
#   * spatial axes given in µm are converted to nm;
#   * for 3-axis datasets the offsets of the two spatial axes are reset to 0.
# The "spatial" axes are the navigation axes when the navigation space is 2-D,
# otherwise the signal axes when the signal space is 2-D.
for file in eels_list:
    manager = file.axes_manager
    n_axes = len(manager.shape)

    if len(manager.navigation_shape) == 2:
        axes_kind, spatial_axes = "navigation", manager.navigation_axes
    elif len(manager.signal_shape) == 2:
        axes_kind, spatial_axes = "signal", manager.signal_axes
    else:
        # No 2-D space: nothing to convert or to reset.
        continue

    if n_axes >= 2 and spatial_axes[0].units == r'µm':
        manager.convert_units(axes=axes_kind, units='nm', same_units=True, factor=1000)

    if n_axes == 3:
        for axis in spatial_axes:
            axis.offset = 0

eels_list[-1].axes_manager

In [None]:
# HAADF image
%matplotlib inline

plt.close('all')
fig = plt.figure(figsize=(3.3, 2.5))
gs = gridspec.GridSpec(1, 1, width_ratios=None, height_ratios=None, wspace=0, hspace=0, figure=fig)

subfig = fig.add_subfigure(gs[0, 0], zorder=0)
ax = subfig.add_subplot()
# Let the image axes fill the whole sub-figure.
ax.set_position((0, 0, 1.0, 1.0))

ax.imshow(eels_list[1].data, cmap='gray', aspect=1.0)
# Scale bar of length 20, calibrated from the axes of eels_list[1].
add_sizebar(ax, 20, eels_list[1], 'w')
ax.set_axis_off()

# Save as LZW-compressed TIFF at 300 dpi before showing the figure.
plt.savefig(path.joinpath(path_out, r'1_TEM_EELS_HAADF_300.tif'), pad_inches=0.01, bbox_inches='tight', dpi=300, transparent=False, pil_kwargs={"compression": "tiff_lzw"})
plt.show()

### ZLP 校准

In [None]:
# Align the zero-loss peak of eels_list[-2] with sub-pixel precision;
# eels_list[-1] is passed through `also_align` so it is aligned along with it.
# NOTE(review): presumably [-2] is the low-loss and [-1] the core-loss spectrum image - confirm.
eels_list[-2].align_zero_loss_peak(subpixel=True, also_align=[eels_list[-1]])

### 确定样品厚度

In [None]:
%matplotlib ipympl
# Estimate the elastic-scattering threshold of eels_list[-2] and feed it to
# the thickness estimation.
th = eels_list[-2].estimate_elastic_scattering_threshold(window=30)
# th.T.plot()
s_thickness = eels_list[-2].estimate_thickness(threshold=th,)  # MnO2 5.03
s_thickness.plot()

# Save the data
s_thickness.save(path.joinpath(path_out, r'1-sample_thickness.hspy'), overwrite=True)

In [None]:
# thickness 图 横向图
%matplotlib inline
vmax = s_thickness.nanmax().data[0].round(2)-0.01  # 需要做调整
plt.close('all')
fig = plt.figure(figsize=(3.3, 2.5))
gs = gridspec.GridSpec(1, 1, width_ratios=None, height_ratios=None, wspace=0, hspace=0, figure=fig)

subfig = fig.add_subfigure(gs[0, 0], zorder=0)
ax = subfig.add_subplot()
ax.set_position((0, 0, 1.0, 1.0))
im = ax.imshow(s_thickness.data, cmap='gray', aspect=1.0, vmin=0.0, vmax=vmax)
add_sizebar(ax, 20, s_thickness, 'w')
ax.set_axis_off()

cax = subfig.add_subplot()
cax.set_position((0.0, 0.1, 0.4, 0.05))
subfig.colorbar(mappable=im, cax=cax, ticks=np.linspace(0, vmax, 5), format='{x:.1f}', location='bottom', orientation='horizontal')
cax.tick_params(axis='x', direction='out')

cax.text(1.1, 0.35, r'Tickness ($\frac{t}{\lambda}$)', horizontalalignment='left', verticalalignment='center', transform=cax.transAxes, fontsize=11, c='k')

plt.savefig(path.joinpath(path_out, r'1-sample_thickness_300.tif'), pad_inches=0.1, bbox_inches='tight', dpi=300, transparent=False, pil_kwargs={"compression": "tiff_lzw"})
plt.show()

### PCA 降噪

In [None]:
# PCA: decompose a deep copy so that eels_list[-1] itself is left untouched.
ps = eels_list[-1].deepcopy()
ps.decomposition(algorithm="SVD", navigation_mask=None, centre='signal')

In [None]:
%matplotlib ipympl
# ps.plot_explained_variance_ratio(n=20, threshold=3, vline=True)
num_components = ps.estimate_elbow_position()  # component number =  num_components + 1 
# ps.plot_decomposition_results()
# ps.plot_decomposition_loadings(comp_ids=num_components+1, axes_decor="off", with_factors=True, per_row=(num_components+1)//2,)
# Reconstruct the data from 2 * (num_components + 1) components, save the
# denoised model and display it.
ps_recon = ps.get_decomposition_model(components=2*(num_components+1))
ps_recon.save(path.joinpath(path_out, r'1-data_pca.hspy'), overwrite=True)
ps_recon.plot()

#### 除去 offset 背景

In [None]:
# Energy ranges (eV) of the characteristic edge of each element (extendable).
# Used by get_elements() to decide which elements a spectrum covers.
element_lines = {
    'O':   (480.0, 600.0),   # O-K
    'Mn':  (600.0, 700.0),   # Mn-L
    'Zn':  (980.0, 1180.0), # Zn-L
    'S':   (2430.0, 2550.0), # S-K
}

# Energy windows (eV) passed to the background fit as its signal range.
fit_ranges = {
    'O':   (480.0, 520.0),   # O-K
    'Mn':  (600.0, 623.0),   # Mn-L
    'Zn':  (980.0, 1010.0), # Zn-L
    'S':   (2430.0, 2450.0), # S-K
}

# Energy windows (eV) cropped out of the spectrum for each element before
# the background is fitted and subtracted.
data_ranges = {
    'O':   (480.0, 600.0),   # O-K
    'Mn':  (600.0, 700.0),   # Mn-L
    'Zn':  (980.0, 1280.0), # Zn-L
    'S':   (2430.0, 2550.0), # S-K
}

def get_elements(data, element_lines: dict[str, tuple[float, float]]) -> tuple[str, ...]:
    """Return the elements whose energy window is covered by the spectrum.

    An element is kept when the 'Energy loss' axis of ``data`` reaches at
    least ``highenergy - 20`` and starts at or below ``lowenergy + 20``,
    i.e. a 20-unit tolerance is allowed at both ends of its window.
    The result follows the iteration order of ``element_lines``.
    """
    covered = []
    for symbol, (window_low, window_high) in element_lines.items():
        reaches_top = data.axes_manager['Energy loss'].high_value >= window_high - 20
        if reaches_top and data.axes_manager['Energy loss'].low_value <= window_low + 20:
            covered.append(symbol)
    return tuple(covered)

import tqdm.notebook as tqdm
def remove_bkg(
    data,
    lowloss =None,
    element_lines: dict = element_lines,
    data_ranges: dict | None = data_ranges,
    fit_ranges: dict = fit_ranges,
    mask=None,
    component_name: str = 'PowerLaw', # type: ignore
    plot_fig: bool = False,
    save_data: bool = True,
    path_out: path = path_out,
) -> dict:
    """
    Fit and subtract a background for every element found in ``data``.

    The elements are detected with ``get_elements``; for each of them the
    spectrum is cropped to ``data_ranges[element]`` (or used whole when
    ``data_ranges`` is None), the background component named
    ``component_name`` ('PowerLaw' or 'Offset') is fitted inside
    ``fit_ranges[element]`` and subtracted.

    Returns a dict mapping each element to its background-subtracted signal.
    When two or more elements are detected, an extra key 'all' holds the
    unmodified input ``data``.

    Side effects: sets the HyperSpy log level to 'ERROR', adds the detected
    elements to ``data`` and, when ``save_data`` is true, writes one
    '2-ps_<element>_rebgk.hspy' file per element into ``path_out``.

    Raises ValueError when no element window is covered by ``data``.
    """
    hs.set_log_level('ERROR')

    detected = get_elements(data, element_lines)
    if not detected:
        raise ValueError("No recognizable elements found in the data's energy range.")

    # Register the elements on the signal (a set avoids duplicates).
    data.add_elements(set(detected))
    targets = detected + ('all',) if len(detected) >= 2 else detected

    result: dict = {}

    for element in tqdm.tqdm(targets):
        # 'all' simply carries the full, untouched spectrum image.
        if element == 'all':
            result[element] = data
            continue

        # Crop to the element's window unless cropping is disabled.
        if data_ranges is None:
            sig = data
        else:
            start, end = data_ranges[element]
            sig = data.isig[start:end]

        fit_start, fit_end = fit_ranges[element]

        available_components: dict = {
            'PowerLaw': hs.model.components1D.PowerLaw(),
            'Offset': hs.model.components1D.Offset(),
        }

        # Build a model that contains only the chosen background component
        # and fit it inside the pre-edge window.
        model = sig.create_model(low_loss=lowloss, auto_add_edges=False, auto_background=False)
        background = available_components[component_name]
        model.append(background)
        model.fit_component(
            component=available_components[component_name],
            signal_range=(fit_start, fit_end),
            mask=mask,
            fit_independent=True,
            only_current=False
        )

        if plot_fig:
            print(f"Plotting background fit for {element}")
            model.plot(plot_components=True)

        # Subtract the fitted background and optionally persist the result.
        fitted_background = model.as_signal(component_list=[component_name])
        result[element] = (sig - fitted_background).deepcopy()
        if save_data:
            result[element].save(path_out / f'2-ps_{element}_rebgk.hspy', overwrite=True)

    return result

# Subtract the background of the PCA-denoised data for every detected element,
# passing eels_list[-2] as the low-loss signal, and save the results.
results_rebgk = remove_bkg(
    ps_recon,
    lowloss=eels_list[-2],
    element_lines=element_lines,
    data_ranges=data_ranges,
    fit_ranges=fit_ranges,
    mask=None,
    plot_fig=False,
    save_data=True,
    path_out=path_out,
)

In [None]:
%matplotlib ipympl
# Inspect the background-subtracted Mn signal and list all available results.
results_rebgk['Mn'].plot()
results_rebgk

### Mask 的选择，EELS Mappings 的方法

In [None]:
%matplotlib widget
# Energy windows (eV) integrated to build the map of each element.
roi_definitions = {
    'Mn': (651.0, 655.0),
    'O':  (520.0, 550.0),
    'Zn': (1025.0, 1165.0),
    'S':  (2467.0, 2478.0)
}

# Accepted intensity interval of each map; pixels outside it are masked.
# None means "use the map's own minimum and maximum", which masks nothing.
mask_ranges: dict[str, tuple[float, float] | None] = {
    'Mn': None, # (6.0, 123.0),
    'O':   (50.0, 574.0), 
    'Zn': None, # (80.0, 148.0), 
    'S':   None, # (2.0, 20.0),
}

# Containers for the integrated maps and their masks.
ps_mappings = {}
ps_masks = {}
for element, (start, end) in roi_definitions.items():
    if element not in results_rebgk:
        print(f"Element '{element}' not found. Skipping.")
        continue
    if element == 'all':
        ps_mappings[element] = None
        continue

    # Integrate the background-subtracted signal over the element's window.
    signal = results_rebgk[element]
    roi = hs.roi.SpanROI(start, end)
    _, map_img = hs.plot.plot_roi_map(signal, rois=[roi])
    ps_mappings[element] = map_img[0]

    limits = mask_ranges[element]
    if limits is None:
        q_min, q_max = np.nanpercentile(map_img[0].data, [0, 100])
        print(f"Element: {element}, Q_min: {q_min}, Q_max: {q_max}")
        lower, upper = q_min, q_max
    else:
        lower, upper = limits[0], limits[1]

    mask = (ps_mappings[element] < lower) | (ps_mappings[element] > upper)  # type: ignore
    ps_masks[element] = mask
    # ps_mappings[element].data[mask.data] = np.nan

    # Save the map and its mask.
    ps_mappings[element].save(path.joinpath(path_out, f'3-ps_{element}_mapping.hspy'), overwrite=True)
    ps_masks[element].save(path.joinpath(path_out, f'3-ps_{element}_mask.hspy'), overwrite=True)


# Close all figure windows
plt.close('all')
element = 'O'
# Side-by-side layout: histogram (left, twice as wide) and map (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3), gridspec_kw={'width_ratios': [2, 1]}) # type: ignore
element_mapping = ps_mappings.get(element).data # type: ignore
element_mask = ps_masks.get(element).data # type: ignore
# Masked pixels are replaced by NaN.
file = np.where(element_mask, np.nan, element_mapping) # type: ignore
# Histogram of the map values
ax1.hist(file.flatten(), bins=30, color='steelblue', edgecolor='black')
# Hide the top and right spines
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)
# Show the masked map (drawn with the 'Spectral' colormap, not in grayscale)
im = ax2.imshow(file, cmap='Spectral', aspect=1.0)
ax2.axis('off')
ax2.set_title(f'{element} Mapping')
add_sizebar(ax2, 20, ps_mappings[element], 'k') 
# Automatic compact layout
plt.tight_layout()
plt.show()

In [None]:
from typing import Optional
%matplotlib widget
def EELSMappings(
    data: dict,
    colors_map: Optional[dict[str, np.ndarray]] = None,
    selected_elements: tuple[str, ...] = (r'Mn', r'Zn'),
    mask_mapping: dict | None = None,
    fig_plot: bool = False,
    path_out: Optional[path] = None,
    save_data: bool = True
) -> dict:
    """
    读取 EDS 中的元素（如 Mn, Zn）数据，生成 RGB 合成图，保留原始 NaN 区域。

    参数:
        data (dict): 包含元素映射的输入字典。
        colors_map (dict, optional): 每个元素对应的 RGB 颜色。
        selected_elements (tuple[str]): 要可视化的元素。
        fig_plot (bool): 是否绘图。
        path_out (Path, optional): 图像/数据保存路径。
        save_data (bool): 是否保存 npz 文件。
    """
    images: dict[str, np.ndarray] = {}
    eelsmappings_out: dict[str, np.ndarray] = {}

    if mask_mapping is not None:
        if len(mask_mapping) == 1:
            mask: np.ndarray = mask_mapping.data # type: ignore
        else:
            mask: np.ndarray = np.any([mask_mapping[key].data for key in mask_mapping], axis=0)
    else:
        print("No mask provided, provide the mask.")
        
    for key, file in data.items():
        if key in selected_elements and getattr(file.data, "ndim", 0) == 2:
            img = file.data
            img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
            img = (img - img.min()) / (img.max() - img.min() + 1e-8)
            images[key] = img

    if colors_map is None:
        colors_map = {
            'Mn': np.array([236, 236, 0]) / 255,
            'Zn': np.array([252, 0, 252]) / 255,
            'S':  np.array([0, 252, 252]) / 255,
            'O':  np.array([0, 0, 252]) / 255,
        }

    shape = next(iter(images.values())).shape
    eels_mapping = np.zeros((*shape, 3), dtype=np.float32)

    for element in selected_elements:
        img = images[element]
        eels_mapping += img[..., None] * colors_map[element]

    eels_mapping = np.clip(eels_mapping, 0, 1)
    eels_mapping[mask] = np.nan # type: ignore

    if fig_plot:
        plt.close('all')
        fig, ax = plt.subplots(figsize=(3.3, 2.5))
        im = ax.imshow(eels_mapping)
        ax.axis('off')
        add_sizebar(ax, 20, data[selected_elements[0]], 'k')

        for idx, element in enumerate(selected_elements):
            if element not in colors_map:
                continue
            ax.text(
                0.02 + 0.09 * idx, 0.98, element, color='k' if element == 'Mn' else 'w',
                bbox=dict(ec=None, fc=colors_map[element], alpha=1.0, boxstyle='Square, pad=0.3'),
                transform=ax.transAxes, fontsize=9, va='top', ha='left',
                fontfamily='Arial', fontweight='bold'
            )

        fig.tight_layout()

        if path_out:
            tif_path = path_out.joinpath('3-ps_eels_mapping_300.tif')
            npz_path = path_out.joinpath('3-ps_eels_mapping.npz')
            fig.savefig(
                tif_path,
                pad_inches=0.05, bbox_inches='tight', dpi=300, transparent=False,
                pil_kwargs={"compression": "tiff_lzw"}
            )
            if save_data:
                np.savez(npz_path, eels_mapping=eels_mapping)

        plt.show()
        final_mask = ps_mappings['Mn'].deepcopy()
        final_mask.data = mask
        final_mask.save(path_out.joinpath('3-ps_Zn_Mn_mask.hspy'), overwrite=True)
        eelsmappings_out['mask'] = final_mask
        eelsmappings_out['eels_mapping'] = eels_mapping
    return eelsmappings_out

# Build the Mn/Zn composite using the O-map mask, then keep the returned mask
# under the key 'Zn_Mn' for the later steps.
eelsmappings = EELSMappings(ps_mappings, selected_elements=(r'Mn', r'Zn'), mask_mapping=ps_masks['O'], fig_plot=True, path_out=path_out)
ps_masks['Zn_Mn'] = eelsmappings['mask']

#### 计算 Mn 和 Zn 的相关性

In [None]:
%matplotlib ipympl
from PIL import Image
plt.close('all')
def generate_distribution_phase_mapping(
    mapping_Mn: np.ndarray,
    mapping_Zn: np.ndarray,
    mask: np.ndarray,
    k_min: float | None = None,
    k_max: float | None = None,
    k_step: float = 0.2,
    gif_path: path = path_out,
    figsize: tuple = (12, 2.5),
    dpi: int = 100,
    bins: int = 100,
) -> None:
    """Render a multi-page TIFF that scans the Mn/Zn intensity ratio K.

    For every interval ``(k, k + k_step]`` with ``k`` in
    ``arange(k_min, k_max, k_step)`` one frame is drawn: on the left, a 2-D
    histogram of the unmasked (Mn, Zn) pairs with the pairs of the current
    interval highlighted; on the right, the masked Mn map with the pixels of
    the current interval overlaid.

    Parameters
    ----------
    mapping_Mn, mapping_Zn
        Intensity maps of identical shape.
    mask
        Boolean array of the same shape; True marks excluded pixels.
    k_min, k_max
        Scan range of K; default to the smallest / largest finite K value.
    k_step
        Width of each K interval.
    gif_path
        Output folder; the stack is written to
        '3-distribution_phase_mapping.tif' inside it.
    figsize, dpi
        Size and resolution of the rendered figure (``dpi`` was accepted but
        ignored before).
    bins
        Number of bins of the 2-D histogram.

    Raises
    ------
    ValueError
        If the input shapes differ, if no pixel has a finite K value, or if
        the requested K range contains no interval.
    """
    if not (mapping_Mn.shape == mapping_Zn.shape == mask.shape):
        raise ValueError("All input arrays must have the same shape.")
    mask = np.asarray(mask, dtype=bool)

    # K = Mn / Zn on unmasked pixels with non-zero Zn; NaN elsewhere.
    with np.errstate(divide='ignore', invalid='ignore'):
        K_map = np.where((mapping_Zn != 0) & (~mask), mapping_Mn / mapping_Zn, np.nan)

    # Flattened unmasked values restricted to finite K.
    keep = ~mask
    K_flat = K_map[keep]
    valid = np.isfinite(K_flat)
    mapping_Mn_A_K = mapping_Mn[keep][valid]
    mapping_Zn_A_K = mapping_Zn[keep][valid]
    K_flat = K_flat[valid]
    if K_flat.size == 0:
        raise ValueError("No unmasked pixel has a finite Mn/Zn ratio.")

    if k_min is None:
        k_min = float(np.nanmin(K_flat))
    if k_max is None:
        k_max = float(np.nanmax(K_flat))
    k_ranges = np.arange(k_min, k_max, k_step)
    if k_ranges.size == 0:
        raise ValueError(
            f"Empty K range: k_min={k_min}, k_max={k_max}, k_step={k_step}."
        )

    cmap_base = plt.cm.hot_r(np.linspace(0.2, 0.8, 256))
    cmap = LinearSegmentedColormap.from_list('cmap', cmap_base)
    cmap.set_over('white')
    cmap.set_under('black')

    fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi, gridspec_kw={'wspace': 0.05}, constrained_layout=True)
    try:
        ax_overlay, ax_map = axes

        # Left panel: density of all valid pairs plus a scatter that is
        # refilled for every K interval.
        ax_overlay.set_position([0.1, 0.18, 0.3, 0.7])
        ax_overlay.set_box_aspect(1.0)
        ax_overlay.hist2d(mapping_Mn_A_K, mapping_Zn_A_K, cmap=cmap, cmin=1, bins=bins)
        sc2 = ax_overlay.scatter([], [], s=2, c='magenta', alpha=0.1)
        legend_text2 = ax_overlay.text(0.95, 0.95, '', ha='right', va='top', transform=ax_overlay.transAxes, fontsize=8)
        ax_overlay.set_title("After Masking", fontsize=12, loc='left')
        ax_overlay.set_xlabel("Mn")
        ax_overlay.set_ylabel("Zn")

        # Right panel: masked Mn map with an overlay that is updated per frame.
        ax_map.set_position([0.4, 0.17, 0.3, 0.7])
        masked_map = np.where(mask, np.nan, mapping_Mn)
        ax_map.imshow(masked_map, cmap='gray')
        ax_map.set_title("Distribution of Mn and Zn", fontsize=12, loc='left')
        ax_map.axis('off')
        mask_overlay = ax_map.imshow(np.zeros_like(mapping_Mn), cmap='hot', alpha=0.7, vmin=0, vmax=1)

        tiff_stack = []
        for k_start in k_ranges:
            k_end = k_start + k_step

            in_range_flat = (K_flat > k_start) & (K_flat <= k_end)
            sc2.set_offsets(np.column_stack((mapping_Mn_A_K[in_range_flat], mapping_Zn_A_K[in_range_flat])))
            legend_text2.set_text(f'{k_start:.2f} < K ≤ {k_end:.2f}')

            in_range_2d = (K_map > k_start) & (K_map <= k_end)
            mask_overlay.set_data(in_range_2d.astype(float))

            # Rasterise the figure and keep the RGB channels as one frame.
            fig.canvas.draw()
            img_array = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]
            tiff_stack.append(Image.fromarray(img_array).convert("RGB"))

        # Save the TIFF stack
        tiff_output_path = gif_path.joinpath('3-distribution_phase_mapping.tif')
        first_frame, *other_frames = tiff_stack
        first_frame.save(
            tiff_output_path,
            save_all=True,
            append_images=other_frames,
            compression="tiff_deflate"
        )
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)
    print(f"[✔] TIFF stack saved to: {tiff_output_path.resolve()}")

# Scan the Mn/Zn ratio from -0.2 to 6.0 in steps of 0.2 on the Mn and Zn
# maps, excluding the pixels flagged in the 'Zn_Mn' mask.
generate_distribution_phase_mapping(
    ps_mappings['Mn'].data,
    ps_mappings['Zn'].data,
    ps_masks['Zn_Mn'].data,
    k_min=-0.2,
    k_max=6.0,
    k_step=0.2,
    gif_path=path_out,
    figsize=(7.0, 2.5),
    dpi=80,
    bins=60
)

### Onset energy of L3 of Mn

In [None]:
# Resample the energy axis by interpolation (see interpolate_axis below).

# Only report HyperSpy messages of level ERROR.
hs.set_log_level('ERROR')
from hyperspy.axes import UniformDataAxis
def interpolate_axis(data, new_scale:float = 0.1):
    """Resample the 'Energy loss' signal axis of ``data`` to ``new_scale``.

    A new uniform axis is built that starts at the current offset, uses
    ``new_scale`` as its step and has
    ``(high_value - low_value) // new_scale + 2`` points; the data are then
    interpolated onto it and returned as a new object. When ``data`` has no
    signal axis named 'Energy loss' it is returned unchanged.

    NOTE(review): the axis to replace is addressed by the fixed index 2, which
    presumably assumes two navigation axes plus the energy axis - confirm.
    """
    signal_axis_names = [axes.name for axes in data.axes_manager.signal_axes]
    if r'Energy loss' not in signal_axis_names:
        return data

    energy_axis = data.axes_manager['Energy loss']
    n_points = (energy_axis.high_value - energy_axis.low_value)//new_scale+2
    axis_new = UniformDataAxis(
        offset=energy_axis.offset,
        scale=new_scale,
        size=n_points,
        name="Energy loss",
        units='eV',
        navigate=False,
        is_binned=True,
    )
    return data.interpolate_on_axis(axis_new, 2, inplace=False)

In [None]:
# Work on a copy of the background-subtracted Mn signal, resampled to a
# 0.1 step along the energy axis.
ps_Mn_L = results_rebgk['Mn'].deepcopy()
ps_Mn_L = interpolate_axis(ps_Mn_L, new_scale=0.1)
# Onset threshold: 10 % of the maximum along the energy axis.
ps_onset_intensity_L3 = np.multiply(0.1, ps_Mn_L.max('Energy loss'))
ps_onset_energy_Mn_L3 = ps_onset_intensity_L3.deepcopy()
# Onset energy: the energy (restricted to values up to 644.0) at which
# |spectrum - threshold| is smallest; pixels flagged in the 'Zn_Mn' mask
# are set to NaN.
ps_onset_energy_Mn_L3.data = np.where(ps_masks['Zn_Mn'].data, np.nan, np.abs(ps_Mn_L.isig[:644.0] - ps_onset_intensity_L3).valuemin('Energy loss')) 
%matplotlib ipympl
ps_onset_energy_Mn_L3.plot()

#### 优化 onset_energy

In [None]:
plt.close('all')
def optimize_onset_energy(onset_energy_data, energy_range: tuple = (636.0, 638.0), save_path: path = path_out, fig_plot: bool = True) -> dict:
    """
    Restrict an onset-energy map to ``energy_range`` and visualise the effect.

    Pixels whose onset energy lies outside ``[min, max]`` are replaced by NaN.
    When ``fig_plot`` is true a two-panel figure is drawn: the sorted onset
    energies before and after filtering, and the filtered map with a colour
    bar.

    Parameters:
    - onset_energy_data: HyperSpy signal holding the onset-energy image.
    - energy_range: tuple, the accepted (min, max) range.
    - save_path: Path or str, output folder; nothing is saved when falsy.
    - fig_plot: bool, whether to draw (and save) the figure.

    Returns:
    - dict with 'onset_energy' (the filtered signal) and 'mask' (its NaN
      positions, as returned by ``np.isnan`` on the filtered signal).
    """
    from matplotlib import ticker

    lower, upper = energy_range[0], energy_range[1]

    # A pixel is rejected when it is below the minimum OR above the maximum.
    # The previous `&` could never be true, so nothing was ever filtered.
    mask = (onset_energy_data < lower) | (onset_energy_data > upper)
    onset_energy_filtered = onset_energy_data.deepcopy()
    onset_energy_filtered.data = np.where(mask.data, np.nan, onset_energy_data.data)

    out_dir = path(save_path) if save_path else None

    if fig_plot:
        # Flattened, NaN-free values for the sorted-value comparison.
        data_all = onset_energy_data.data.flatten()
        data_masked = onset_energy_filtered.data.flatten()
        data_all = data_all[~np.isnan(data_all)]
        data_masked = data_masked[~np.isnan(data_masked)]

        fig, axes = plt.subplots(1, 2, figsize=(7.0, 2.5), gridspec_kw=dict(wspace=0.05))

        # Left panel: sorted onset energies, raw versus filtered.
        ax0 = axes[0]  # type: ignore
        ax0.plot(np.sort(data_all), label='Raw', color=colors[1])
        ax0.plot(np.sort(data_masked), label='Masked', color=colors[4], ls='--')
        ax0.set_xlabel(r'Number (N)', fontsize=11)
        ax0.set_ylabel(r'Onset Energy (eV)', fontsize=11, labelpad=7)
        ax0.set_ylim(lower - 2, upper + 2)
        ax0.yaxis.set_major_locator(ticker.MultipleLocator(base=1, offset=0))
        ax0.yaxis.set_minor_locator(ticker.MultipleLocator(base=0.5, offset=0))
        ax0.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax0.axhline(y=lower, ls='--', c='k')
        ax0.axhline(y=upper, ls='--', c='k')
        ax0.legend(loc='lower right', fontsize=8, frameon=False)

        # Right panel: filtered map.
        ax1 = axes[1]  # type: ignore
        N_color = 5
        # Colormap with dedicated colours for values above / below the range.
        cmap = plt.cm.hot_r(np.linspace(0.2, 0.8, 256))
        cmap = LinearSegmentedColormap.from_list('cmap', cmap)
        cmap.set_over('white')        # above vmax -> white
        cmap.set_under('black')       # below vmin -> black

        # NaN pixels become 0, i.e. below vmin, and are drawn in the "under" colour.
        onset_energy_filtered_a = np.nan_to_num(onset_energy_filtered.data, copy=True, nan=0)
        norm = Normalize(vmin=lower, vmax=upper)
        im = ax1.imshow(onset_energy_filtered_a, cmap=cmap, norm=norm)
        add_sizebar(ax1, 20, onset_energy_filtered, 'w')
        ax1.axis('off')

        cax = fig.add_subplot()
        cax.set_position((1.03, 0.1, 0.03, 0.8))
        fig.colorbar(mappable=im, cax=cax, ticks=np.linspace(lower, upper, N_color), format='{x:.1f}', location='right', orientation='vertical')
        # The colour bar is vertical, so its ticks live on the y axis.
        cax.tick_params(axis='y', direction='out')
        cax.text(4.0, 0.5, r'Onset Energy (eV)', rotation=90, horizontalalignment='left', verticalalignment='center', transform=cax.transAxes, fontsize=11, c='k')

        # Save BEFORE showing: after plt.show() the current figure may already
        # have been closed by the backend, which produced a blank image.
        if out_dir is not None:
            fig.savefig(out_dir.joinpath(r'4-TEM_sTXM_L3_OnsetEnergy_300.tif'), pad_inches=0.05, bbox_inches='tight', dpi=300, transparent=False,
                        pil_kwargs={"compression": "tiff_lzw"})
        plt.show()
    else:
        plt.close('all')

    # The filtered data are saved whether or not a figure was drawn.
    if out_dir is not None:
        onset_energy_filtered.save(out_dir.joinpath(r'4-TEM_sTXM_L3_OnsetEnergy.hspy'), overwrite=True)

    onset_energy: dict = {
        'onset_energy': onset_energy_filtered,
        'mask': np.isnan(onset_energy_filtered),
    }
    return onset_energy

%matplotlib inline
# Filter the Mn L3 onset-energy map with a 634-636 eV window and keep the
# returned NaN mask under the key 'onsetenergy'.
onset_energy = optimize_onset_energy(ps_onset_energy_Mn_L3, energy_range=(634.0, 636.0), save_path=path_out, fig_plot=True)
ps_masks['onsetenergy'] = onset_energy['mask']

### 不同区域的信号，以及全局

In [None]:
%matplotlib ipympl
# Show eels_list[1] and attach an interactive rectangular ROI (drawn in
# yellow) so that regions can be picked by hand.
eels_list[1].plot()
rectangular_roi = hs.roi.RectangularROI(left=0.0, right=10.0, top=0.0, bottom=10.0)
roi2D = rectangular_roi.interactive(eels_list[1], color="yellow")

In [None]:
# Regions of interest given as (left, right, top, bottom); 'all' spans the
# full extent of the first two axes of eels_list[-1].
rois_definition = {
   # 'Bulk': (50.0,85.0,12.0,65.0),
   # 'Surface': (12.0,30.0,23.0,68.0),
   'all': (eels_list[-1].axes_manager[0].low_value, eels_list[-1].axes_manager[0].high_value, eels_list[-1].axes_manager[1].low_value, eels_list[-1].axes_manager[1].high_value),
}
# One RectangularROI per definition; the bounds must be floats.
rois: dict = {
    key: hs.roi.RectangularROI(left=value[0], right=value[1], top=value[2], bottom=value[3])
    for key, value in rois_definition.items()
}

In [None]:
import matplotlib.patches as patches
def draw_roi_rectangles(ax, roi_dict, color_dict, scale_x, scale_y, zorder=5):
    """
    Outline rectangular ROIs on a matplotlib axes and label them.

    Each ROI is converted from calibrated units to data coordinates by
    dividing its position and size by the given scales and truncating to
    integers.

    Args:
        ax (matplotlib.axes.Axes): target axes.
        roi_dict (dict): ROI name -> object exposing ``x``, ``y``, ``width``
            and ``height`` (e.g. hs.roi.RectangularROI).
        color_dict (dict): ROI name -> edge colour (e.g. {'Surface': 'y'});
            names without an entry are drawn in black ('k').
        scale_x (float): size of one step along x
            (typically `ps_recon.axes_manager[0].scale`).
        scale_y (float): size of one step along y
            (typically `ps_recon.axes_manager[1].scale`).
        zorder (int): drawing order of the rectangles (default 5); the labels
            are drawn one level above.
    """
    for name, roi in roi_dict.items():
        edge_color = color_dict.get(name, 'k')
        left_px = int(roi.x / scale_x)
        top_px = int(roi.y / scale_y)
        width_px = int(roi.width / scale_x)
        height_px = int(roi.height / scale_y)

        outline = patches.Rectangle(
            (left_px, top_px), width_px, height_px,
            linewidth=1.5, edgecolor=edge_color, facecolor='none',
            transform=ax.transData, zorder=zorder,
        )
        ax.add_patch(outline)
        # The label is anchored at the corner (x0 + w, y0 + h) of the rectangle.
        ax.text(left_px + width_px, top_px + height_px, name, color='k', fontsize=9,
                ha='left', va='center', zorder=zorder + 1)
        
from matplotlib.transforms import Bbox

%matplotlib inline
plt.close('all')
fig = plt.figure(figsize=(3.3, 2.5))
gs = gridspec.GridSpec(1, 1, width_ratios=None, height_ratios=None, wspace=0, hspace=0, figure=fig)

# Figure
subfig = fig.add_subfigure(gs[0, 0], zorder=0)
ax = subfig.add_subplot()
ax.set_position((0.0, 0, 1.0, 1.0))
ax.set_axis_off()
# Sum of the denoised data over its last axis, with the pixels flagged in
# the 'Zn_Mn' mask set to NaN.
file = np.where(ps_masks['Zn_Mn'], np.nan, ps_recon.sum(-1))
# NOTE(review): np.where returns a NumPy array, so `file.data` is presumably
# the array's raw buffer rather than a HyperSpy `.data` attribute - confirm
# that passing `file` itself is what was meant.
ax.imshow(file.data, cmap='gray',aspect=1.0)
sizebar = add_sizebar(ax, 20, ps_recon, 'k')
# Re-anchor the scale bar at the lower-left corner of the axes.
sizebar.set_bbox_to_anchor(Bbox.from_bounds(0.0, 0.0, 0.0, 0.0), transform=ax.transAxes)
ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, labelbottom=False, labelleft=False,)

# Outline every ROI in yellow.
draw_roi_rectangles(
    ax, 
    rois, 
    color_dict={'Surface': 'y', 'Bulk': 'y', 'all': 'y'}, 
    scale_x=ps_recon.axes_manager[0].scale, 
    scale_y=ps_recon.axes_manager[1].scale
)

plt.tight_layout()

plt.savefig(path.joinpath(path_out, r'5-TEM_EELS_Selected_Regions_300.tif'), pad_inches=0.05, bbox_inches='tight', dpi=300, transparent=False, pil_kwargs={"compression": "tiff_lzw"})
plt.show()

#### Cluster 分析

In [None]:
def select_roi(eels_list: dict, rois: dict, selected_elements: tuple[str, ...] = (r'Mn', r'O', r'Zn'), num_clusters: int | None =None, mask = ps_masks['Zn_Mn'], save_path: path=path_out, plt_plot:bool =False, sizebarA=50) -> None:
    """Crop every selected element signal to every ROI and cluster it.

    For each element present in ``eels_list`` and each ROI in ``rois`` the
    cropped signal and the cropped mask are saved to ``save_path`` as
    '5-Clusters_<element>_<roi>_data.hspy' / '..._mask.hspy'; non-finite
    values of the cropped signal are then replaced by 0 and the signal is
    passed to ``cluster_analysis_A``.

    Parameters
    ----------
    eels_list : dict mapping element name to its signal.
    rois : dict mapping ROI name to a callable ROI (applied as ``roi(signal)``).
    selected_elements : elements to process; missing ones are skipped.
    num_clusters : number of clusters, or None to let it be estimated.
    mask : signal marking excluded pixels, either boolean (True = excluded)
        or with NaN at the excluded pixels.
    save_path : output folder.
    plt_plot : whether the cluster figures are shown.
    sizebarA : scale-bar length forwarded to the plots.
    """
    for element in selected_elements:
        if element not in eels_list:
            continue
        for roi_name, roi in rois.items():
            roi_signal = roi(eels_list[element])

            # np.isnan() on a boolean mask is False everywhere, which
            # silently discarded the mask. Use a boolean mask directly and
            # derive it from NaNs only for non-boolean masks.
            roi_mask = roi(mask)
            if roi_mask.data.dtype == bool:
                mask_signal = roi_mask
            else:
                mask_signal = np.isnan(roi_mask)

            # Write to the folder given by the caller, not the global default.
            roi_signal.save(save_path.joinpath(f'5-Clusters_{element}_{roi_name}_data.hspy'), overwrite=True)
            mask_signal.save(save_path.joinpath(f'5-Clusters_{element}_{roi_name}_mask.hspy'), overwrite=True)

            roi_signal.data = np.nan_to_num(roi_signal.data, nan=0.0, posinf=0.0, neginf=0.0)
            cluster_analysis_A(data=roi_signal, clusters_mask=mask_signal, num_clusters=num_clusters, plt_plot=plt_plot, save_data=True, save_path=save_path, roi_name=roi_name, sizebar=sizebarA)

def cluster_analysis_A(data, clusters_mask, sizebar, num_clusters: int | None = 3, plt_plot: bool = False, save_data: bool = True, save_path: path = path_out, roi_name:int | str | None = None) -> None:
    """Run SVD + k-means clustering on ``data`` and optionally save the result.

    ``data`` is decomposed in place (SVD, centred on the signal) and then
    clustered on the decomposition results with "standard" preprocessing.
    When ``num_clusters`` is None the number of clusters is estimated first
    (at most 8). With ``save_data`` true, the result is plotted through
    ``plot_cluster_result`` and the mean cluster signals and the cluster
    labels are written to ``save_path``.
    """
    # Options shared by the estimation and by the clustering itself.
    shared_options = dict(
        cluster_source="decomposition",
        preprocessing="standard",
        navigation_mask=clusters_mask.data,
    )

    data.decomposition(algorithm="SVD", centre="signal", navigation_mask=clusters_mask.T, print_info=False)
    if num_clusters is None:
        num_clusters = data.estimate_number_of_clusters(max_clusters=8, **shared_options)

    data.cluster_analysis(n_clusters=num_clusters, algorithm="kmeans", n_init=8, **shared_options)

    # Mean spectrum of each cluster, carrying the metadata of the input.
    cluster_signals = data.get_cluster_signals(signal='mean')
    cluster_signals.metadata.add_dictionary(data.metadata.as_dictionary())
    # Register the elements covered by the cluster spectra (set avoids duplicates).
    elements = get_elements(cluster_signals, element_lines)
    cluster_signals.add_elements(set(elements))
    cluster_labels = data.get_cluster_labels()

    if not save_data:
        return

    cluster_data: dict = {
        'data': data,
        'signals': cluster_signals,
        'labels': cluster_labels
    }
    plot_cluster_result(cluster_data, elements=elements, plt_plot=plt_plot, save_path=save_path, roi_name=roi_name, sizebar=sizebar)

    element_tag = elements[0] if len(elements)==1 else 'all'
    prefix = f"5-Clusters_{element_tag}_{roi_name}_N{cluster_signals.axes_manager.navigation_size}"
    cluster_signals.save(save_path.joinpath(f"{prefix}_signals.hspy"), overwrite=True)
    cluster_labels.save(save_path.joinpath(f"{prefix}_labels.hspy"), overwrite=True)


def plot_cluster_result(cluster_data, elements: tuple, plt_plot:bool = False, save_path: path = path_out, roi_name: int | str | None = None, sizebar=50) -> None:
    """Draw the three-panel summary of a cluster analysis.

    Panel A shows the mean spectrum of each cluster, panel B the cluster
    label maps (one colour per cluster) and panel C the input data summed
    over its last axis.

    Parameters
    ----------
    cluster_data : dict with the keys 'data' (clustered signal), 'signals'
        (mean cluster spectra) and 'labels' (cluster label maps).
    elements : elements covered by the spectra; selects the energy window of
        panel A and the element tag of the title / file name.
    plt_plot : show the figure when true, otherwise close all figures.
    save_path : output folder; the figure is saved as
        '5-Clusters_<element|all>_<roi_name>_N<n>_300.tif' when truthy.
    roi_name : ROI label used in the title and the file name.
    sizebar : scale-bar length. The default used to be captured from a
        notebook-level variable named ``sizebar`` (a matplotlib artist); it
        is now a plain number matching the ``sizebarA`` default of
        ``select_roi``.
    """
    data = cluster_data['data']
    labels = cluster_data['labels']
    signals = cluster_data['signals']

    element_tag = elements[0] if len(elements)==1 else 'all'

    fig = plt.figure(figsize=(7.0, 2.5))
    gs = gridspec.GridSpec(1, 3, width_ratios=[1,1,1], height_ratios=None,
                            wspace=0.01, hspace=0, figure=fig)

    # Panel A: mean spectrum of every cluster.
    subfig = fig.add_subfigure(gs[0, 0], zorder=0)
    ax = subfig.add_subplot()
    ax.set_position((0, 0, 1.0, 1.0))
    ax.set_box_aspect(0.8)

    energy = signals.axes_manager['Energy loss'].axis
    for i in range(signals.axes_manager.navigation_size):
        # Cycle through the palette so that more clusters than colours
        # cannot raise an IndexError.
        ax.plot(energy, signals.inav[i].data, c=colors[i % len(colors)], ls='-', lw=1.0, label=f'#{i}', zorder=i)
    ax.set_xlabel(r'Energy (eV)', fontsize=13)
    ax.set_ylabel(r'Intensity', fontsize=13)
    configure_axes_for_element(ax=ax, elements=elements)
    ax.set_title(f"{element_tag}, {roi_name}, Raw", fontsize=13, loc='left', pad=0.5)
    ax.legend(loc='best', ncols=1, frameon=False, fontsize=11, labelcolor='linecolor', columnspacing=0.4)

    # Panel B: one semi-transparent two-colour layer per cluster label map.
    subfig = fig.add_subfigure(gs[0, 1], zorder=0)
    ax = subfig.add_subplot()
    ax.set_position((0.25, 0, 1.0, 1.0))
    # Transpose when the label maps sit in a 2-D navigation space. The old
    # test compared the shape tuple with the integer 2 and was never true.
    if len(labels.axes_manager.navigation_shape) == 2:
        labels = labels.T
    for i in range(labels.axes_manager.navigation_size):
        cmapping = LinearSegmentedColormap.from_list(name=r'cmapping', colors=['w', colors[i % len(colors)]], N=2)
        ax.imshow(labels.inav[i].data, cmap=cmapping, alpha=0.8, zorder=i+2, aspect=1)
    add_sizebar(ax, sizebar, labels, 'black')
    ax.set_axis_off()

    # Panel C: input data summed over its last axis.
    subfig = fig.add_subfigure(gs[0, 2], zorder=0)
    ax = subfig.add_subplot()
    ax.set_position((0.45, 0, 1.0, 1.0))
    # 'gray' is the spelling used elsewhere in this file and is accepted by
    # every matplotlib version.
    ax.imshow(data.sum(-1).data, cmap='gray', alpha=1.0, zorder=1, aspect=1)
    add_sizebar(ax, sizebar, data, 'black')
    ax.set_axis_off()
    fig.tight_layout()

    if save_path:
        prefix = f"5-Clusters_{element_tag}_{roi_name}_N{signals.axes_manager.navigation_size}"
        fig.savefig(save_path.joinpath(f'{prefix}_300.tif'), pad_inches=0.05, bbox_inches='tight', dpi=300, transparent=False, pil_kwargs={"compression": "tiff_lzw"})

    if plt_plot:
        plt.show()
    else:
        plt.close('all')

def configure_axes_for_element(ax, elements: tuple):
    """Set the x-axis limits and tick spacing of ``ax`` for the plotted element(s).

    Args:
        ax: Matplotlib axes whose x axis is configured in place.
        elements: Element symbols being plotted. Two or more entries select
            the wide 300-3000 window; a single ``'Mn'``, ``'Zn'`` or ``'O'``
            selects that element's own window. Any other input leaves ``ax``
            untouched. The y axis is never modified.
    """
    # Each entry: (x limits, major tick step, minor tick step).
    survey = ((300.0, 3000.0), 500, 250)
    per_element = (
        ('Mn', ((600.0, 700.0), 20, 10)),
        ('Zn', ((980.0, 1280.0), 60, 30)),
        ('O', ((500.0, 600.0), 20, 10)),
    )

    if len(elements) >= 2:
        chosen = survey
    elif len(elements) == 1:
        # First matching symbol wins, in the order Mn, Zn, O.
        chosen = next((cfg for symbol, cfg in per_element if symbol in elements), None)
    else:
        chosen = None

    if chosen is None:
        return

    (x_min, x_max), major_step, minor_step = chosen
    ax.set_xlim(x_min, x_max)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(major_step))
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(minor_step))

# Run the ROI clustering routine once per requested cluster count, saving the
# figures to path_out without displaying them.
# NOTE(review): num_clusters=None presumably lets select_roi fall back to its
# own default — confirm against select_roi.
for number in [None, 3, 4]:
    select_roi(results_rebgk, rois, selected_elements=(r'Mn', r'Zn', r'O'), num_clusters=number, mask=ps_masks['Zn_Mn'], save_path=path_out, plt_plot=False, sizebarA=20)

#### 纯相梯度分布

In [None]:
def select_roi2(eels_list: dict, rois: dict, selected_elements: tuple[str, ...] = (r'Mn', r'O'), mask=ps_masks['Zn_Mn'], save_path: path=path_out, plt_plot:bool =False, sizebarA=50) -> None:
    """Average each masked ROI into a line profile, then save and plot it.

    For every requested element present in ``eels_list`` and every ROI except
    ``'all'``, the ROI is cropped out, pixels where the cropped ``mask`` is NaN
    are set to NaN, and the result is NaN-averaged along axis 0.

    Args:
        eels_list: Mapping from element symbol to the signal of that element.
        rois: Mapping from ROI name to a callable that crops a signal.
        selected_elements: Element symbols to process; symbols missing from
            ``eels_list`` are skipped.
        mask: Signal whose NaN pixels mark positions to exclude.
        save_path: Directory receiving the ``6-line_*.hspy`` profile and the
            figure written by ``plot_line_result``. Nothing is written when
            this is falsy.
        plt_plot: Forwarded to ``plot_line_result``.
        sizebarA: Forwarded to ``plot_line_result`` as the scale-bar length.
    """
    for element in selected_elements:
        if element not in eels_list:
            continue
        for roi_name, roi in rois.items():
            if roi_name == 'all':
                continue

            # Work on a private copy: the cropped ROI may share its buffer with
            # eels_list[element], and the NaN masking below must not leak back
            # into the source signal.
            roi_signal = roi(eels_list[element]).deepcopy()
            excluded = np.isnan(roi(mask).data)
            roi_signal.data[excluded, :] = np.nan
            roi_signal_line = roi_signal.nanmean(axis=0)

            # Honour the save_path argument (the global output folder used to
            # be hard-wired here, ignoring what the caller asked for).
            if save_path:
                roi_signal_line.save(save_path.joinpath(f'6-line_{element}_{roi_name}.hspy'), overwrite=True)

            plot_line_result(roi_signal, roi_signal_line, elements=(element,), plt_plot=plt_plot, save_path=save_path, roi_name=roi_name, sizebarA=sizebarA)


def plot_line_result(roi_signal, roi_signal_line, sizebarA, elements: tuple, plt_plot:bool = False, save_path: path = path_out, roi_name: int | str | None = None) -> None:
    """Draw the averaged line profile beside the summed ROI image and optionally save it.

    Both input signals have their ``data`` attribute replaced by a version in
    which NaN and +/-inf are 0, so the caller's objects are modified.

    Args:
        roi_signal: Cropped ROI signal; summed over its last axis for the
            right-hand image and used to calibrate the scale bar.
        roi_signal_line: Averaged profile shown as an image on the left.
        sizebarA: Scale-bar length passed to ``add_sizebar``.
        elements: Element symbols; drives the tick spacing and the labels.
        plt_plot: Show the figure when True, otherwise close all figures.
        save_path: Directory for the 300-dpi TIFF; nothing is saved when falsy.
        roi_name: ROI label used in the title and the file name.
    """
    # Replace non-finite values so imshow and the sum below see plain numbers.
    for signal in (roi_signal, roi_signal_line):
        signal.data = np.nan_to_num(signal.data, nan=0.0, posinf=0.0, neginf=0.0)

    fig = plt.figure(figsize=(7.0, 2.5))
    layout = gridspec.GridSpec(
        1, 2,
        width_ratios=[1, 1], height_ratios=None,
        wspace=0.01, hspace=0,
        figure=fig,
    )

    # Panel A: averaged profile rendered as an image.
    profile_panel = fig.add_subfigure(layout[0, 0], zorder=0)
    profile_ax = profile_panel.add_subplot()
    profile_ax.set_position((0, 0, 1.0, 1.0))
    profile_ax.set_box_aspect(0.8)
    profile_ax.imshow(roi_signal_line.data, cmap='Spectral', alpha=1.0, zorder=1, aspect=1)
    configure_axes_for_element2(ax=profile_ax, elements=elements)

    element_tag = elements[0] if len(elements)==1 else 'all'
    plt.title(f"{element_tag}, {roi_name}, line", fontsize=13, loc='left', pad=0.5)

    # Panel B: ROI image summed over the last axis, with a scale bar.
    image_panel = fig.add_subfigure(layout[0, 1], zorder=0)
    image_ax = image_panel.add_subplot()
    image_ax.set_position((-0.4, 0, 1.0, 1.0))
    image_ax.imshow(roi_signal.nansum(-1).data, cmap='grey', alpha=1.0, zorder=2, aspect=1)
    add_sizebar(image_ax, sizebarA, roi_signal, 'k')
    image_ax.set_axis_off()
    fig.tight_layout()

    if save_path:
        prefix = f"6-line_{element_tag}_{roi_name}"
        target = save_path.joinpath(f'{prefix}_300.tif')
        plt.savefig(target, pad_inches=0.05, bbox_inches='tight', dpi=300, transparent=False, pil_kwargs={"compression": "tiff_lzw"})

    if not plt_plot:
        plt.close('all')
    else:
        plt.show()

def configure_axes_for_element2(ax, elements: tuple):
    """Set only the x-axis tick spacing of ``ax`` for the plotted element(s).

    Unlike ``configure_axes_for_element`` this leaves the axis limits alone.

    Args:
        ax: Matplotlib axes whose x-axis locators are replaced in place.
        elements: Element symbols being plotted. Two or more entries select
            the 500/250 spacing; a single ``'Mn'``, ``'Zn'`` or ``'O'`` selects
            that element's spacing. Any other input leaves ``ax`` untouched.
    """
    # (major tick step, minor tick step) per single element, checked in order.
    steps_by_symbol = (
        ('Mn', (20, 10)),
        ('Zn', (60, 30)),
        ('O', (20, 10)),
    )

    if len(elements) >= 2:
        steps = (500, 250)
    elif len(elements) == 1:
        steps = next((pair for symbol, pair in steps_by_symbol if symbol in elements), None)
    else:
        steps = None

    if steps is None:
        return

    major_step, minor_step = steps
    ax.xaxis.set_major_locator(ticker.MultipleLocator(major_step))
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(minor_step))

# Extract, save and display the masked line profile of every ROI (except 'all')
# for Mn and O.
select_roi2(results_rebgk, rois, selected_elements=(r'Mn', r'O'), mask=ps_masks['Zn_Mn'], save_path=path_out, plt_plot=True, sizebarA=10)

#### 对应的 L3 L2 比例，对应的 signals Clusters

In [None]:
from lmfit.models import LinearModel, StepModel
from scipy.signal import find_peaks
from scipy.integrate import simpson

def index_of(arr, threshold):
    """Return the index of the first element of ``arr`` greater than ``threshold``.

    ``np.argmax`` over an all-False comparison yields 0, so 0 is also
    returned when no element exceeds ``threshold``.
    """
    exceeds = arr > threshold
    return np.argmax(exceeds)

def auto_select_two_largest_peaks(data, distance=10):
    """Return the indices of the two highest local maxima, in ascending order.

    Args:
        data: 1-D array searched with ``scipy.signal.find_peaks``.
        distance: Minimum sample spacing between candidate peaks.

    Raises:
        ValueError: If fewer than two peaks are detected.
    """
    candidates, _ = find_peaks(data, distance=distance)
    if len(candidates) < 2:
        raise ValueError("找不到两个峰")

    # Rank candidates by height and keep the two tallest.
    by_height = np.argsort(data[candidates])
    tallest_two = candidates[by_height[-2:]]
    tallest_two.sort()
    return tallest_two

def find_onset_energy(energy, data, thread=0.1):
    """Locate the two main peaks, the valley after each, and both onset indices.

    Args:
        energy: Energy axis; not used by the computation itself.
        data: 1-D intensity array.
        thread: Fraction of a peak's height used to build the onset threshold.

    Returns:
        ``(ionset, imin, imax)`` where ``ionset`` is a two-element list of
        onset indices, ``imin`` lists the first local minimum after each peak
        and ``imax`` holds the two peak indices in ascending order.

    Raises:
        ValueError: If a peak has no local minimum after it, or if fewer than
            two peaks are found.
    """
    imax = auto_select_two_largest_peaks(data)
    valleys = find_peaks(-data, distance=10)[0]

    # For each peak keep the first valley lying strictly after it.
    imin = []
    for peak in imax:
        following = valleys[valleys > peak]
        imin.append(following[0] if following.size else None)
    if any(valley is None for valley in imin):
        raise ValueError("未能找到合适的峰后谷")

    first_peak, second_peak = imax
    first_valley = imin[0]

    # First onset: first sample before peak 1 above thread * peak-1 height.
    onset_first = index_of(data[:first_peak], thread * data[first_peak])

    # Second onset: searched between the first valley and peak 2. The threshold
    # is the valley level plus thread * peak-2 height, and the result is
    # shifted back into full-array coordinates.
    onset_second = index_of(
        data[first_valley:second_peak],
        data[first_valley] + thread * data[second_peak],
    ) + first_valley  # type: ignore

    return [onset_first, onset_second], imin, imax

def baseline(energy, data, ionset, imax, imin):
    """Fit a linear + double arctan-step background and integrate both peaks.

    The model is fitted only to two anchor windows: the 10 samples before the
    first onset and the 10 samples around the second valley. It is then
    evaluated over the whole ``energy`` axis and subtracted from ``data``.

    Args:
        energy: 1-D energy axis.
        data: 1-D intensity array, same length as ``energy``.
        ionset: Two onset indices; they fix the centres of the two steps.
        imax: Indices of the two peaks whose areas are integrated.
        imin: Indices of the valley following each peak.

    Returns:
        ``(background, peaks, peak_intensities, ratio, result)``: the evaluated
        background, the background-subtracted data, the two integrated peak
        areas, the ratio first/second, and the lmfit fit result.

    Raises:
        ValueError: If ``imin`` lacks two valid valleys, or an anchor window
            holds fewer than 3 samples.
    """
    if None in imin or len(imin) < 2:
        raise ValueError("找不到两个有效的主谷 imin")

    height = [data[i] for i in imin if i is not None]

    # Anchor windows, clamped so they never wrap around the array ends.
    pre_range = slice(max(0, ionset[0]-10), ionset[0])
    post_range = slice(max(0, imin[1]-5), min(len(data), imin[1]+5))

    if pre_range.stop - pre_range.start < 3 or post_range.stop - post_range.start < 3:
        raise ValueError("用于拟合的区域太窄，无法安全拟合")

    xdat = np.concatenate((energy[pre_range], energy[post_range]))
    ydat = np.concatenate((data[pre_range], data[post_range]))

    # Initial intercept: mean of everything more than 10 samples before the
    # onset. With an onset index <= 10 that slice would be empty (NaN mean) or
    # wrap around to almost the whole spectrum, so fall back to the pre-edge
    # anchor window in that case.
    far_pre_edge = data[:max(ionset[0]-10, 0)]
    if len(far_pre_edge):
        intercept_guess = np.mean(far_pre_edge)
    else:
        intercept_guess = np.mean(data[pre_range])

    line_mod = LinearModel(prefix='line_')
    step_mod1 = StepModel(form='arctan', prefix='step1_')
    step_mod2 = StepModel(form='arctan', prefix='step2_')

    params = line_mod.make_params(intercept=intercept_guess, slope=dict(value=0)) # vary=False
    # Step 1 is tied to step 2: same width, twice the amplitude.
    params.update(step_mod1.make_params(
        center=dict(value=energy[ionset[0]], vary=False),
        sigma=dict(value=0.5, expr='1*step2_sigma'),
        amplitude=dict(expr='2*step2_amplitude')
    ))
    # Step 2 amplitude is constrained so that the three step units bridge the
    # gap between the line and the second-valley height.
    params.update(step_mod2.make_params(
        center=dict(value=energy[ionset[1]], vary=False),
        sigma=dict(value=0.5, min=0.1, max=1.0),
        amplitude=dict(value=height[1]/3, vary=True, expr=f'({height[1]}-line_slope*step2_center-line_intercept)/3')
    ))

    model = line_mod + step_mod1 + step_mod2
    result = model.fit(ydat, params, x=xdat)
    background = result.eval(result.params, x=energy)
    peaks = data - background

    # Integrate a narrow window around each peak; the start is clamped so a
    # peak close to index 0 cannot produce a wrapped (empty) slice.
    peak_intensities = []
    for pid in imax:
        window = slice(max(pid - 2, 0), pid + 2)
        peak_intensities.append(simpson(y=peaks[window], x=energy[window]))
    ratio = peak_intensities[0] / peak_intensities[1]
    return background, peaks, peak_intensities, ratio, result

def plot_fit(energy, data, baseline, peaks, ratio, ionset, name, path_out):
    """Plot a spectrum with its background and subtracted signal, save and show it.

    Args:
        energy: 1-D energy axis.
        data: Spectrum as measured.
        baseline: Evaluated background, drawn dashed.
        peaks: Background-subtracted data, drawn dashed.
        ratio: Value printed in the upper-right annotation.
        ionset: Onset indices; the matching energies are printed below the ratio.
        name: Spectrum name used in the output file name ``7_<name>.tif``.
        path_out: Directory receiving the 600-dpi TIFF.
    """
    fig = plt.figure(figsize=(3.3, 2.5))
    grid = gridspec.GridSpec(1, 1, figure=fig)
    ax = fig.add_subplot(grid[0])
    ax.set_position((0.0, 0, 1.0, 1.0))
    ax.set_box_aspect(0.8)

    # (y values, line style, legend label) in drawing order.
    curves = (
        (data, '-', 'data'),
        (baseline, '--', 'bkg'),
        (peaks, '--', 'fit'),
    )
    for values, line_style, label in curves:
        ax.plot(energy, values, ls=line_style, label=label)

    ax.set_ylabel(r'Intensity (count)', fontsize=11)
    ax.set_xlabel(r'Energy (eV)', fontsize=11)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(base=20))
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))

    # Annotations in axes coordinates, right-aligned in the upper corner.
    text_style = dict(transform=ax.transAxes, ha='right', va='top', fontsize=10)
    ax.text(0.95, 0.90, f'L3/2 ratio: {ratio:.2f}', **text_style)
    onset_energies = ", ".join(f"{e:.2f}" for e in energy[ionset])
    ax.text(0.95, 0.80, f'{onset_energies}', **text_style)

    plt.tight_layout()
    target = path_out.joinpath(f"7_{name}.tif")
    plt.savefig(
        target,
        bbox_inches="tight",
        dpi=600,
        transparent=False,
        pil_kwargs={"compression": "tiff_lzw"},
    )
    plt.show()

def process_one_spectrum(name, energy, data, threshold, plot_each, path_out):
    """Run onset detection and background fitting on a single spectrum.

    Any exception raised along the way is printed and swallowed so that one
    bad spectrum does not abort a batch.

    Args:
        name: Label of the spectrum, echoed back in the result.
        energy: 1-D energy axis.
        data: 1-D intensity array.
        threshold: Onset threshold fraction passed to ``find_onset_energy``.
        plot_each: Plot and save the fit when True, otherwise close all figures.
        path_out: Output directory forwarded to ``plot_fit``.

    Returns:
        ``(name, summary)`` where ``summary`` is a dict of arrays and scalars,
        or ``(name, None)`` when processing failed.
    """
    try:
        ionset, imin, imax = find_onset_energy(energy, data, thread=threshold)
        if len(imax) != 2:
            raise ValueError("检测到的主峰数不为2")
        baseline_fit, bkg_removed, peak_intensities, ratio, _ = baseline(energy, data, ionset, imax, imin)

        if not plot_each:
            plt.close('all')
        else:
            plot_fit(energy, data, baseline_fit, bkg_removed, ratio, ionset, name, path_out)

        summary = {
            'original': data,
            'baseline': baseline_fit,
            'bkg_removed': bkg_removed,
            'intensity_L3': peak_intensities[0],
            'intensity_L2': peak_intensities[1],
            'L3_L2_ratio': ratio,
            'L3_energy': energy[ionset[0]],
            'L2_energy': energy[ionset[1]],
        }
    except Exception as e:
        print(f"[!] {name} 出错: {e}")
        return name, None

    return name, summary

def process_dataframe(df: pd.DataFrame, energy_column='energy', threshold=0.1, plot_each=False, path_out=path_out):
    """Fit every spectrum column of ``df`` and collect the results in a Dataset.

    Args:
        df: Table with one energy column and one column per spectrum.
        energy_column: Name of the energy column.
        threshold: Onset threshold fraction forwarded to ``process_one_spectrum``.
        plot_each: Forwarded to ``process_one_spectrum``.
        path_out: Directory for per-spectrum figures and for the failure log.

    Returns:
        An ``xarray.Dataset`` indexed by ``spectrum`` and ``energy`` holding
        the original data, background, subtracted signal and the scalar
        L3/L2 quantities of every spectrum that was processed successfully.

    Raises:
        ValueError: If no spectrum could be processed.
    """
    energy = df[energy_column].values

    results = {}
    failure_log = []
    for name in df.columns.drop(energy_column):
        try:
            outcome = process_one_spectrum(name, energy, df[name].values, threshold, plot_each, path_out)
            if outcome and isinstance(outcome, tuple) and outcome[1] is not None:
                _, fit = outcome
                results[name] = fit
            else:
                failure_log.append((name, "结果为空或处理失败"))
        except Exception as e:
            failure_log.append((name, str(e)))

    # Write the failure log.
    if failure_log:
        log_file = path_out / "log_failed_spectra.txt"
        with open(log_file, "w", encoding="utf-8") as f:
            f.write("以下谱线在处理过程中出错：\n")
            for name, reason in failure_log:
                f.write(f"{name}: {reason}\n")
        print(f"⚠️ 失败日志已保存到: {log_file}")

    print(f"✅ 成功处理 {len(results)} 条谱线；❌ 失败 {len(failure_log)} 条谱线")

    if not results:
        raise ValueError("所有谱线均处理失败。")

    # Assemble the xarray Dataset: spectral variables first, then the scalars.
    fits = list(results.values())
    spectral_sources = {'data': 'original', 'baseline': 'baseline', 'bkg_removed': 'bkg_removed'}
    scalar_names = ['intensity_L3', 'intensity_L2', 'L3_L2_ratio', 'L3_energy', 'L2_energy']

    data_vars = {}
    for var_name, source_key in spectral_sources.items():
        data_vars[var_name] = (['spectrum', 'energy'], np.array([fit[source_key] for fit in fits]))
    for scalar_name in scalar_names:
        data_vars[scalar_name] = (['spectrum'], [fit[scalar_name] for fit in fits])

    dataset = xr.Dataset(
        data_vars=data_vars,
        coords={
            'spectrum': list(results.keys()),
            'energy': energy,
        },
    )

    return dataset

# Collect every saved Mn cluster-signal file. The paths are sorted because the
# spectra are named by position below, so the numbering must not depend on the
# order in which the file system happens to list the files.
file_path = sorted(path_out.glob(r'5-Clusters_Mn_*_Signals.hspy'))
signals = hs.load(file_path) # type: ignore
signals = hs.stack(signals, axis=0) # type: ignore
# Build a table whose first column is the energy axis and whose remaining
# columns are the stacked spectra, named '1', '2', ...
energy_axis = signals.axes_manager["Energy loss"].axis.reshape(1, -1) # type: ignore
data = np.concatenate([energy_axis, signals.data], axis=0)
df = pd.DataFrame(data.T)
df.columns = ['energy'] + [str(i) for i in range(1, df.shape[1])]
# df.iloc[:20, 1:] = df.iloc[:20, 1:].where(lambda x: (x > -5)&(x < 100), 0) 
# Only the part of each spectrum below 680 is analysed.
xr_result = process_dataframe(df[df['energy'] < 680.0], plot_each=False, path_out=path_out, threshold=0.1)
xr_result.to_netcdf(path_out.joinpath("7_TEM_EELS_Cluster.NETCDF4"), engine="h5netcdf")

#### 对应的 L3/L2 mappings

In [None]:
import tqdm.notebook as tqdm
from lmfit.models import LinearModel, StepModel
from scipy.signal import find_peaks
from scipy.integrate import simpson

def index_of(arr, threshold):
    """Return the index of the first element of ``arr`` greater than ``threshold``.

    ``np.argmax`` over an all-False comparison yields 0, so 0 is also
    returned when no element exceeds ``threshold``.
    """
    above = arr > threshold
    return np.argmax(above)

def auto_select_two_largest_peaks(data, distance=10):
    """Return the indices of the two highest local maxima, in ascending order.

    Args:
        data: 1-D array searched with ``scipy.signal.find_peaks``.
        distance: Minimum sample spacing between candidate peaks.

    Raises:
        ValueError: If fewer than two peaks are detected.
    """
    candidates, _ = find_peaks(data, distance=distance)
    if len(candidates) < 2:
        raise ValueError("找不到两个峰")

    # Rank candidates by height and keep the two tallest.
    height_order = np.argsort(data[candidates])
    strongest = candidates[height_order[-2:]]
    strongest.sort()
    return strongest

def find_onset_energy(energy, data, thread):
    """Locate the two main peaks, the valley after each, and both onset indices.

    Args:
        energy: Energy axis; not used by the computation itself.
        data: 1-D intensity array.
        thread: Fraction of a peak's height used to build the onset threshold.

    Returns:
        ``(ionset, imin, imax)`` where ``ionset`` is a two-element list of
        onset indices, ``imin`` lists the first local minimum after each peak
        and ``imax`` holds the two peak indices in ascending order.

    Raises:
        ValueError: If a peak has no local minimum after it, or if fewer than
            two peaks are found.
    """
    imax = auto_select_two_largest_peaks(data)
    valleys = find_peaks(-data, distance=10)[0]

    # For each peak keep the first valley lying strictly after it.
    imin = []
    for peak in imax:
        later = valleys[valleys > peak]
        imin.append(later[0] if later.size else None)
    if any(valley is None for valley in imin):
        raise ValueError("未能找到合适的峰后谷")

    peak_a, peak_b = imax
    valley_a = imin[0]

    # First onset: first sample before peak A above thread * peak-A height.
    onset_a = index_of(data[:peak_a], thread * data[peak_a])

    # Second onset: searched between valley A and peak B. The threshold is the
    # valley level plus thread * peak-B height, and the result is shifted back
    # into full-array coordinates.
    onset_b = index_of(
        data[valley_a:peak_b],
        data[valley_a] + thread * data[peak_b],
    ) + valley_a  # type: ignore

    return [onset_a, onset_b], imin, imax

def baseline(energy, data, ionset, imax, imin):
    """Fit a linear + double arctan-step background and integrate both peaks.

    The model is fitted to everything more than 10 samples before the first
    onset plus the 10 samples around the second valley. It is then evaluated
    over the whole ``energy`` axis and subtracted from ``data``. Peak areas
    are integrated over a 2.0-wide energy window centred on each peak.

    Args:
        energy: 1-D energy axis.
        data: 1-D intensity array, same length as ``energy``.
        ionset: Two onset indices; they fix the centres of the two steps.
        imax: Indices of the two peaks whose areas are integrated.
        imin: Indices of the valley following each peak.

    Returns:
        ``(background, peaks, peak_intensities, ratio, result)``: the evaluated
        background, the background-subtracted data, the two integrated peak
        areas, the ratio first/second, and the lmfit fit result.

    Raises:
        ValueError: If the first onset lies within the first 10 samples, so
            that no pre-edge region is left to anchor the fit.
    """
    # Without this guard a small onset index turns ``[:ionset[0]-10]`` into a
    # negative or empty slice: the fit would then silently use almost the
    # whole spectrum, or nothing at all, as "pre-edge".
    pre_edge_stop = ionset[0] - 10
    if pre_edge_stop <= 0:
        raise ValueError("用于拟合的区域太窄，无法安全拟合")

    # Window around the second valley, clamped at the start of the array.
    valley_lo = max(imin[1] - 5, 0)
    valley_hi = imin[1] + 5

    height = list(data[imin])
    xdat = np.concatenate((energy[:pre_edge_stop], energy[valley_lo:valley_hi]))
    ydat = np.concatenate((data[:pre_edge_stop], data[valley_lo:valley_hi]))

    line_mod = LinearModel(prefix='line_')
    step_mod1 = StepModel(form='arctan', prefix='step1_')
    step_mod2 = StepModel(form='arctan', prefix='step2_')

    params = line_mod.make_params(intercept=np.mean(data[:pre_edge_stop]), slope=dict(value=0)) # vary=False
    # Step 1 is tied to step 2: same width, twice the amplitude.
    params.update(step_mod1.make_params(
        center=dict(value=energy[ionset[0]], vary=False),
        sigma=dict(value=0.5, expr='1*step2_sigma'),
        amplitude=dict(expr='2*step2_amplitude')
    ))
    # Step 2 amplitude is constrained so that the three step units bridge the
    # gap between the line and the second-valley height.
    params.update(step_mod2.make_params(
        center=dict(value=energy[ionset[1]], vary=False),
        sigma=dict(value=0.5, min=0.1, max=1.0),
        amplitude=dict(value=height[1]/3, vary=True, expr=f'({height[1]}-line_slope*step2_center-line_intercept)/3')
    ))

    model = line_mod + step_mod1 + step_mod2
    result = model.fit(ydat, params, x=xdat)
    background = result.eval(result.params, x=energy)
    peaks = data - background
    area_L3 = energy_window_integral(energy, peaks, energy[imax[0]], width_ev=2.0)
    area_L2 = energy_window_integral(energy, peaks, energy[imax[1]], width_ev=2.0)
    peak_intensities = [area_L3, area_L2]
    ratio = peak_intensities[0] / peak_intensities[1]
    return background, peaks, peak_intensities, ratio, result

def energy_window_integral(energy, data, peak, width_ev):
    """Integrate ``data`` over the energy window of width ``width_ev`` centred on ``peak``.

    Args:
        energy: 1-D energy axis.
        data: 1-D values sampled on ``energy``.
        peak: Centre of the window, in the units of ``energy``.
        width_ev: Full width of the window; both ends are inclusive.

    Returns:
        The Simpson-rule integral of the samples inside the window.

    Raises:
        ValueError: If no sample falls inside the window.
    """
    lower = peak - width_ev / 2
    upper = peak + width_ev / 2
    in_window = (energy >= lower) & (energy <= upper)
    if not np.any(in_window):
        raise ValueError(f"在能量 {peak:.2f} 附近找不到足够点进行积分")
    return simpson(y=data[in_window], x=energy[in_window])

def process_one_spectrum(name, energy, data, threshold):
    """Run onset detection and background fitting on a single spectrum.

    Any exception raised along the way is printed and swallowed so that one
    bad spectrum does not abort a whole map.

    Args:
        name: Label of the spectrum, echoed back in the result.
        energy: 1-D energy axis.
        data: 1-D intensity array.
        threshold: Onset threshold fraction passed to ``find_onset_energy``.

    Returns:
        ``(name, summary)`` where ``summary`` is a dict of arrays and scalars,
        or ``(name, None)`` when processing failed.
    """
    try:
        ionset, imin, imax = find_onset_energy(energy, data, thread=threshold)
        baseline_fit, bkg_removed, peak_intensities, ratio, _ = baseline(energy, data, ionset, imax, imin)
        summary = {
            'original': data,
            'baseline': baseline_fit,
            'bkg_removed': bkg_removed,
            'intensity_L3': peak_intensities[0],
            'intensity_L2': peak_intensities[1],
            'L3_L2_ratio': ratio,
            'L3_energy': energy[ionset[0]],
            'L2_energy': energy[ionset[1]],
        }
    except Exception as e:
        print(f"[!] {name} 出错: {e}")
        return name, None

    return name, summary

def select_roi3(eels_list: dict, rois: dict, elements: tuple[str, ...] = (r'Mn', r'O'), threshold=0.1, mask=eelsmappings['mask'], save_path: path=path_out, plt_plot=True) -> None:
    """Compute per-pixel L3/L2 maps of Mn for the 'all', 'Bulk' and 'Surface' ROIs.

    Only the ``'Mn'`` entry of ``elements`` is acted on, and only when it is
    present in ``eels_list``; every other element is ignored.

    Args:
        eels_list: Mapping from element symbol to the signal of that element.
        rois: Mapping from ROI name to a callable that crops a signal.
        elements: Element symbols to consider.
        threshold: Onset threshold fraction forwarded to ``process_dataframe``.
        mask: Signal whose NaN pixels mark positions to skip.
        save_path: Output directory forwarded to ``process_dataframe``.
        plt_plot: Forwarded to ``process_dataframe``.
    """
    wanted_rois = ('all', 'Bulk', 'Surface')

    for element in elements:
        if element != 'Mn' or element not in eels_list:
            continue
        for roi_name, roi in rois.items():
            if roi_name not in wanted_rois:
                continue
            roi_signal = roi(eels_list[element])
            mask_signal = np.isnan(roi(mask))
            process_dataframe(
                data=roi_signal,
                mask=mask_signal,
                path_out=save_path,
                threshold=threshold,
                roi_name=roi_name,
                element=element,
                plt_plot=plt_plot,
            )

def process_dataframe(data, mask, threshold, path_out, roi_name, element, plt_plot) -> None:
    """Fit every pixel of a spectrum image and store the results as a Dataset.

    Args:
        data: Signal with two navigation axes and an ``'Energy loss'`` axis.
        mask: Optional signal; pixels where it is truthy are skipped. It is
            transposed when its navigation shape differs from that of ``data``.
        threshold: Onset threshold fraction forwarded to ``process_one_spectrum``.
        path_out: Directory for the NETCDF4 file and the ratio CSV; nothing is
            written when falsy.
        roi_name: ROI label used in the output file names.
        element: Element symbol used in the output file names.
        plt_plot: Show the L3/L2 ratio map when True, otherwise close all figures.

    Skipped or failed pixels stay NaN in every output array. The arrays are
    indexed ``[i, j]`` in navigation order, i.e. the first array axis follows
    the first navigation axis.
    """
    energy = data.axes_manager['Energy loss'].axis
    nav_shape = data.axes_manager.navigation_shape
    n_first, n_second = nav_shape[0], nav_shape[1]

    if mask is not None and mask.axes_manager.navigation_shape != nav_shape:
        mask = mask.T.deepcopy()

    # Per-pixel results keyed by navigation index; None marks skipped or
    # failed pixels.
    results = {}
    for i in tqdm.tqdm(range(n_first)):
        for j in range(n_second):
            name = f"{i}_{j}"
            try:
                if mask is not None and mask.inav[i, j].data:
                    results[(i, j)] = None
                    continue

                spectrum = data.inav[i, j].data
                _, result_data = process_one_spectrum(name, energy, spectrum, threshold)
                results[(i, j)] = result_data  # may be None or a dict
            except Exception as e:
                print(f"[X] {name} -> 处理出错: {e}")
                results[(i, j)] = None

    # Allocate NaN-filled output arrays.
    spectral_keys = ['original', 'baseline', 'bkg_removed']
    scalar_keys = ['intensity_L3', 'intensity_L2', 'L3_L2_ratio', 'L3_energy', 'L2_energy']
    spectra_shape = (n_first, n_second, len(energy))

    spectra_data = {k: np.full(spectra_shape, np.nan, dtype=np.float32) for k in spectral_keys}
    scalar_data = {k: np.full(nav_shape, np.nan, dtype=np.float32) for k in scalar_keys}

    for (i, j), res in results.items():
        if res is None:
            continue
        for k in spectra_data:
            spectra_data[k][i, j, :] = res[k]
        for k in scalar_data:
            scalar_data[k][i, j] = res[k]

    dataset = xr.Dataset(
        data_vars={
            **{k: (['x', 'y', 'energy'], spectra_data[k]) for k in spectra_data},
            **{k: (['x', 'y'], scalar_data[k]) for k in scalar_data},
        },
        coords={
            'x': np.arange(n_first),
            'y': np.arange(n_second),
            'energy': energy,
        }
    )

    if path_out:
        dataset.to_netcdf(path_out.joinpath(f"7_TEM_EELS_{element}_{roi_name}.NETCDF4"), engine="h5netcdf")
        dataset['L3_L2_ratio'].to_dataframe().to_csv(path_out.joinpath(f"7_TEM_EELS_{element}_{roi_name}_L3_L2_ratio.csv"), index=True)

    if plt_plot:
        plt.close('all')
        fig = plt.figure(figsize=(3.3, 2.5))
        ax = fig.add_subplot()
        ax.imshow(dataset['L3_L2_ratio'].values, cmap='gray', aspect='equal')
        # The mask calibrates the scale bar when available; without a mask,
        # fall back to the data itself instead of failing on None.
        add_sizebar(ax, 10, mask if mask is not None else data, color='k')
        ax.axis('off')
        fig.tight_layout()
        plt.show()
    else:
        plt.close('all')

# Build the per-pixel Mn L3/L2 maps for the 'all', 'Bulk' and 'Surface' ROIs.
# select_roi3 only acts on 'Mn'; the 'O' entry passed here is ignored by it.
select_roi3(results_rebgk, rois, elements=(r'Mn', r'O'), mask=ps_masks['Zn_Mn'], save_path=path_out, threshold=0.1)

##### 画出上述的 L2 和 L3 峰的变化

In [None]:
# Reload the Mn results of the 'all' ROI from the NETCDF4 file in path_out,
# together with a saved cluster mask.
dataset = xr.open_dataset(path_out.joinpath(r"7_TEM_EELS_Mn_all.NETCDF4"), engine="h5netcdf")
mask_cluster = hs.load(path_out.joinpath(r"5-Clusters_Mn_all_mask.hspy")) # type: ignore

In [None]:
from PIL import Image
plt.close('all')
def generate_distribution_phase_mapping(
    mapping_L3: np.ndarray,
    mapping_L2: np.ndarray,
    mask: np.ndarray,
    k_min: float | None = None,
    k_max: float | None = None,
    k_step: float = 0.2,
    gif_path: path = path_out,
    figsize: tuple = (12, 2.5),
    dpi: int = 100,
    bins: int = 100,
) -> None:
    """Render a multi-page TIFF sweeping the L3/L2 ratio K through successive ranges.

    Each page shows, on the left, the 2-D histogram of (L3, L2) with the points
    of the current K range highlighted and, on the right, the masked L3 map
    with the pixels of that K range overlaid.

    Args:
        mapping_L3: 2-D map of L3 intensities.
        mapping_L2: 2-D map of L2 intensities, same shape.
        mask: Same-shaped array; truthy pixels are excluded.
        k_min: Lower end of the sweep; defaults to the smallest finite K.
        k_max: Upper end of the sweep; defaults to the largest finite K.
        k_step: Width of each K range ``k_start < K <= k_start + k_step``.
        gif_path: Directory receiving ``7-distribution_L3_L2.tif``.
        figsize: Figure size in inches.
        dpi: Figure resolution used when rasterising each frame.
        bins: Number of bins per axis of the 2-D histogram.

    Raises:
        ValueError: If the sweep contains no K range (``k_min >= k_max``).
    """
    assert mapping_L3.shape == mapping_L2.shape == mask.shape, "All input arrays must have the same shape."

    # Force a boolean mask: on an integer 0/1 mask ``~mask`` is a bitwise NOT
    # (-1 / -2), which would select the wrong pixels without any error.
    mask = np.asarray(mask, dtype=bool)
    keep = ~mask

    with np.errstate(divide='ignore', invalid='ignore'):
        K_map = np.where((mapping_L2 != 0) & keep, mapping_L3 / mapping_L2, np.nan)

    # Flatten to the unmasked pixels whose ratio is finite.
    K_flat = K_map[keep]
    valid = np.isfinite(K_flat)
    mapping_L3_A_K = mapping_L3[keep][valid]
    mapping_L2_A_K = mapping_L2[keep][valid]
    K_flat = K_flat[valid]

    if k_min is None:
        k_min = float(np.nanmin(K_flat))
    if k_max is None:
        k_max = float(np.nanmax(K_flat))
    k_ranges = np.arange(k_min, k_max, k_step)
    if len(k_ranges) == 0:
        # Fail before any figure is created instead of with an IndexError
        # on an empty frame list at save time.
        raise ValueError(f"Empty K sweep: k_min={k_min}, k_max={k_max}, k_step={k_step}")

    cmap_base = plt.cm.hot_r(np.linspace(0.2, 0.8, 256))
    cmap = LinearSegmentedColormap.from_list('cmap', cmap_base)
    cmap.set_over('white')
    cmap.set_under('black')

    # dpi is passed through so the requested resolution is actually used.
    fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi, gridspec_kw={'wspace': 0.05}, constrained_layout=True)
    ax_overlay, ax_map = axes

    try:
        # Left panel: (L3, L2) histogram plus a scatter updated on every frame.
        ax_overlay.set_position([0.1, 0.18, 0.3, 0.7])
        ax_overlay.set_box_aspect(1.0)
        ax_overlay.hist2d(mapping_L3_A_K, mapping_L2_A_K, cmap=cmap, cmin=1, bins=bins)
        sc2 = ax_overlay.scatter([], [], s=2, c='magenta', alpha=0.1)
        legend_text2 = ax_overlay.text(0.95, 0.95, '', ha='right', va='top', transform=ax_overlay.transAxes, fontsize=8)
        ax_overlay.set_title("After Masking", fontsize=12, loc='left')
        ax_overlay.set_xlabel("L3")
        ax_overlay.set_ylabel("L2")

        # Right panel: masked L3 map plus an overlay updated on every frame.
        ax_map.set_position([0.4, 0.2, 0.3, 0.7])
        masked_map = np.where(mask, np.nan, mapping_L3)
        ax_map.imshow(masked_map, cmap='gray')
        ax_map.set_title("Distribution of L3", fontsize=12, loc='left')
        ax_map.axis('off')
        mask_overlay = ax_map.imshow(np.zeros_like(mapping_L3), cmap='hot', alpha=0.7, vmin=0, vmax=1)

        tiff_stack = []
        for k_start in k_ranges:
            k_end = k_start + k_step

            in_range_flat = (K_flat > k_start) & (K_flat <= k_end)
            l3_in_range = mapping_L3_A_K[in_range_flat]
            l2_in_range = mapping_L2_A_K[in_range_flat]
            sc2.set_offsets(np.column_stack((l3_in_range, l2_in_range)))
            legend_text2.set_text(f'{k_start:.2f} < K ≤ {k_end:.2f}')

            in_range_2d = (K_map > k_start) & (K_map <= k_end)
            mask_overlay.set_data(in_range_2d.astype(float))

            # Rasterise the figure and keep the RGB channels as one TIFF page.
            fig.canvas.draw()
            img_array = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]
            pil_img = Image.fromarray(img_array)
            tiff_stack.append(pil_img.convert("RGB"))

        # Save the TIFF stack.
        tiff_output_path = gif_path.joinpath('7-distribution_L3_L2.tif')
        tiff_stack[0].save(
            tiff_output_path,
            save_all=True,
            append_images=tiff_stack[1:],
            compression="tiff_deflate"
        )
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)
    print(f"[✔] TIFF stack saved to: {tiff_output_path.resolve()}")

# Sweep the L3/L2 ratio from -0.2 to 4.0 in steps of 0.2 over the reloaded
# Mn maps and write the frames to path_out. The cluster mask is transposed
# before being passed in.
# NOTE(review): the transpose presumably aligns the mask with the (x, y) layout
# of the dataset arrays — confirm the shapes match.
generate_distribution_phase_mapping(
    dataset['intensity_L3'].values,
    dataset['intensity_L2'].values,
    mask_cluster.data.T,
    k_min=-0.2,
    k_max=4.0,
    k_step=0.2,
    gif_path=path_out,
    figsize=(7.0, 2.5),
    dpi=80,
    bins=60
)

### 拟合，在原始数据上，=== 结果并不是很喜欢，放弃 ====

#### 提取不同的片段数据

In [None]:
# # 定义元素特征峰能量范围（可扩展）
# element_lines = {
#     'O':   (480.0, 600.0),   # O-K
#     'Mn':  (600.0, 700.0),   # Mn-L
#     'Zn':  (980.0, 1180.0), # Zn-L
#     'S':   (2430.0, 2550.0), # S-K
# }

# fit_ranges = {
#     'O':   (480.0, 520.0),   # O-K
#     'Mn':  (600.0, 623.0),   # Mn-L
#     'Zn':  (980.0, 1010.0), # Zn-L
#     'S':   (2430.0, 2450.0), # S-K
# }

# data_ranges = {
#     'O':   (480.0, 600.0),   # O-K
#     'Mn':  (600.0, 700.0),   # Mn-L
#     'Zn':  (980.0, 1280.0), # Zn-L
#     'S':   (2430.0, 2550.0), # S-K
# }

# def get_elements(data, element_lines: dict[str, tuple[float, float]]) -> tuple[str, ...]:
#     """根据谱图能量范围，确定包含的元素。"""
#     elements = tuple(
#         element for element, (lowenergy, highenergy) in element_lines.items()
#         if data.axes_manager['Energy loss'].high_value >= highenergy -20 and data.axes_manager['Energy loss'].low_value <= lowenergy + 20
#     )
#     return elements

# import tqdm.notebook as tqdm
# def Clip_data(
#     data,
#     element_lines: dict = element_lines,
#     data_ranges: dict | None = data_ranges,
#     save_data: bool = True,
#     path_out: path = path_out,
# ) -> dict:
#     """
#     自动分段处理谱图并移除背景。
#     返回每个元素对应的去背景数据段组成的字典。
#     """
#     hs.set_log_level('ERROR')

#     elements = get_elements(data, element_lines)
#     if not elements:
#         raise ValueError("No recognizable elements found in the data's energy range.")

#     # 添加元素（避免重复）
#     data.add_elements(set(elements))
#     if len(elements) >=2:
#             elements += ('all',)

#     result: dict = {}

#     for element in tqdm.tqdm(elements):
        
#         # 取出能量段与拟合范围
#         if element == 'all':
#             result[element] = data
#         else:
#             if data_ranges is not None:
#                 start, end = data_ranges[element]
#                 # 提取该段信号
#                 sig = data.isig[start:end]
#             else:
#                 sig = data

#             # 保存
#             result[element] = sig
#             if save_data:
#                 result[element].save(path.joinpath(path_out, f'2-ps_{element}.hspy'), overwrite=True)
            
#     return result

# results_clip = Clip_data(
#     ps_recon,
#     element_lines=element_lines,
#     data_ranges=data_ranges,
#     save_data=True,
#     path_out=path_out,
# )

In [None]:
# %matplotlib ipympl
# data_ranges = {
#     'O':   (490.0, 580.0),   # O-K
#     'Mn':  (620.0, 670.0),   # Mn-L
#     }

# def select_roi3(
#     eels_list: dict,
#     rois: dict,
#     selected_elements: tuple[str, ...] = (r'Mn', r'O'),
#     mask=None,
#     save_path: path = path_out,
#     plt_plot: bool = False,
#     lowloss=None
# ) -> None:
#     # 自动过滤不合法元素
#     valid_elements = []

#     for element in selected_elements:
#         if element not in eels_list:
#             continue

#         data = eels_list[element]
#         start, end = data_ranges.get(element, (None, None))
#         if start is None or end is None:
#             print(f"Skipping '{element}': no data range defined.")
#             continue

#         energy_min = data.axes_manager['Energy loss'].low_value
#         energy_max = data.axes_manager['Energy loss'].high_value

#         if start >= energy_min and end <= energy_max:
#             valid_elements.append(element)
#             # print(f"Processing, Data Range of '{element}' is invalid.")

#     # 主循环：只处理有效元素
#     for element in valid_elements:
#         for roi_name, roi in rois.items():
#             roi_signal = roi(eels_list[element])
#             lowloss_roi = roi(lowloss) if lowloss is not None else None
#             mask_signal = np.isnan(roi(mask).data) if mask is not None else None
#             if mask_signal is not None:
#                 roi_signal.data[mask_signal, :] = np.nan

#             EELSfitting(
#                 data=roi_signal,
#                 lowloss=lowloss_roi,
#                 element=element,
#                 mask=mask_signal,
#                 data_ranges=data_ranges,
#                 plot_fig=plt_plot,
#                 save_data=True,
#                 path_out=save_path,
#                 roi_name=roi_name
#                 )


# def EELSfitting(
#     data,
#     element,
#     lowloss = None,
#     data_ranges: dict = data_ranges,
#     mask: np.ndarray | None = None,
#     plot_fig: bool = False,
#     save_data: bool = True,
#     path_out: path = path_out,
#     roi_name: int | str | None = None
# ) -> None:
#     """
#     对 EELS 数据进行背景去除和拟合。

#     参数:
#         data (Signal1D): 输入的 EELS 数据。
#         lowloss (Signal1D): 低损耗数据。
#         element: 元素名称。
#         data_ranges (dict): 数据范围。
#         mask (np.ndarray): 掩膜。
#         plot_fig (bool): 是否绘图。
#         save_data (bool): 是否保存数据。
#         path_out (Path): 输出路径。
#     """
    
#     import pooch
#     GOSH10 = pooch.retrieve(
#         url="https://zenodo.org/records/7645765/files/Segger_Guzzinati_Kohl_1.5.0.gosh",
#         known_hash="md5:7fee8891c147a4f769668403b54c529b",
#     )
#     GOSDIRAC = pooch.retrieve(
#         url="https://zenodo.org/records/12800856/files/Dirac_GOS.gosh",
#         known_hash="md5:02fb22ab55e39e51eb03c08dbf699545",
#     )

#     if element == r'Mn':
            
#             if data_ranges is not None:
#                 start, end = data_ranges[element]
#                 # 提取该段信号
#                 data = data.isig[start:end]
#             else:
#                 data = data

#             # 拟合
#             model = data.create_model(low_loss=lowloss, auto_add_edges=True, auto_background=True, gos_file_path=GOSDIRAC)
#             model.reset_signal_range()
#             model.fit_background(start_energy=data_ranges[element][0]+2, only_current=False, mask=mask)
#             model.components.PowerLaw.set_parameters_not_free()
#             model.enable_fine_structure()
#             model.enable_free_onset_energy()
#             model.set_all_edges_intensities_positive()
#             model.resolve_fine_structure()
#             model.multifit(mask=mask, optimizer='lm', loss_function='ls')

#             if plot_fig:
#                   model.plot(plot_components=True)
#             else:
#                 plt.close('all')
#             if save_data:
#                 model.save(path_out.joinpath(f'7-{element}_{roi_name}_results.hspy'), overwrite=True)
#                 model.save_parameters2file(path_out.joinpath(f'7-{element}_{roi_name}_model_parameters'))
#                 model_residual = data.data - model.as_signal().data
#                 xr.Dataset(
#                         data_vars=dict(data = (['y', 'x','energy_loss'], data.data),
#                                      fit = (['y', 'x','energy_loss'], model.as_signal().data),
#                                      fit_Mn_II = (['y', 'x','energy_loss'], model.as_signal(component_list=['Mn_L2']).data),
#                                      fit_Mn_III = (['y', 'x','energy_loss'], model.as_signal(component_list=['Mn_L3']).data),
#                                      fit_residual = (['y', 'x','energy_loss'], model_residual),
#                                      fit_Mn_L3_intensity = (['y', 'x'], model.components.Mn_L3.intensity.as_signal().data),
#                                      fit_Mn_L3_onset_energy = (['y', 'x'], model.components.Mn_L3.onset_energy.as_signal().data),
#                                      fit_Mn_L2_intensity = (['y', 'x'], model.components.Mn_L2.intensity.as_signal().data),
#                                      fit_Mn_L2_onset_energy = (['y', 'x'], model.components.Mn_L2.onset_energy.as_signal().data),
#                                      ),
#                         coords=dict(
#                                     energy_loss = data.axes_manager['Energy loss'].axis,
#                                     x= data.axes_manager['x'].axis,
#                                     y= data.axes_manager['y'].axis,
#                                     ),
#                         attrs=dict(
#                                     chisq = model.chisq.data[0],
#                                     red_chisq = model.red_chisq.data[0],
#                                     ),
#                             ).to_netcdf(path_out.joinpath(f'7-{element}_{roi_name}_results.NETCDF4'), engine="h5netcdf")

In [None]:
# # Disabled call of select_roi3 (defined outside this cell) on results_clip / rois with
# # selected_elements Mn and O, onset_energy['onset_energy'] passed as `mask`,
# # eels_list[-2] passed as `lowloss`, output to path_out and plt_plot switched off.
# select_roi3(results_clip, rois, selected_elements=(r'Mn', r'O'), mask=onset_energy['onset_energy'], save_path=path_out, plt_plot=False,lowloss=eels_list[-2])

### 其他

#### 其他来源数据的汇总去除背景2

In [None]:
# import numpy as np
# import pandas as pd
# import xarray as xr
# import matplotlib.pyplot as plt
# from matplotlib import ticker, gridspec
# from lmfit.models import LinearModel, StepModel
# from scipy.signal import find_peaks
# from scipy.integrate import simpson

# def index_of(arr, threshold):
#     """Return the index of the first element of ``arr`` that is greater than ``threshold``.

#     Implemented as ``np.argmax`` over the boolean mask ``arr > threshold``.
#     """
#     # NOTE(review): np.argmax returns 0 for an all-False mask, so "no element exceeds the
#     # threshold" cannot be distinguished from "element 0 exceeds it".
#     return np.argmax(arr > threshold)

# def auto_select_two_largest_peaks(data, distance=10):
#     """Pick the two tallest local maxima of ``data``.

#     Args:
#         data: Intensity values; indexed with an index array below, so presumably a NumPy array.
#         distance: Minimum separation (in samples) between candidate peaks, forwarded to
#             ``find_peaks``.

#     Returns:
#         The indices of the two highest peaks, sorted in ascending order.

#     Raises:
#         ValueError: If fewer than two peaks are found.
#     """
#     peaks, _ = find_peaks(data, distance=distance)
#     if len(peaks) < 2:
#         raise ValueError("找不到两个峰")
#     peak_values = data[peaks]
#     # argsort is ascending, so the last two entries belong to the two highest peaks.
#     top_peaks = peaks[np.argsort(peak_values)[-2:]]
#     return np.sort(top_peaks)

# def find_onset_energy(energy, data, thread=0.3):
#     """Locate the onset indices of the two main peaks of a spectrum.

#     Args:
#         energy: Energy axis. Not used inside this function.
#         data: Intensity values.
#         thread: Fraction of the peak intensity used as onset threshold (presumably meant as
#             "threshold"; the name is kept because process_one_spectrum passes it by keyword).

#     Returns:
#         Tuple ``(ionset, imin, imax, height)``: the two onset indices, the index of the first
#         local minimum after each peak, the two peak indices, and ``data`` at those minima.

#     Raises:
#         ValueError: If fewer than two peaks exist, or a peak has no local minimum after it.
#     """
#     imax = auto_select_two_largest_peaks(data)
#     # Local minima are located as the peaks of the negated signal.
#     imin = find_peaks(-data, distance=10)[0]
#     # For each main peak keep the first minimum that lies after it (None when there is none).
#     imin = [next((v for v in imin if v > peak), None) for peak in imax]
#     if None in imin:
#         raise ValueError("未能找到合适的峰后谷")

#     # Onset 1: first sample before peak 1 that exceeds thread * data[peak 1].
#     # Onset 2: first sample between valley 1 and peak 2 that exceeds
#     # data[valley 1] + thread * data[peak 2]; adding imin[0] converts the slice-relative
#     # index back to an index into the full array.
#     # NOTE(review): index_of yields 0 when no sample exceeds the threshold, so a failed
#     # search silently reports the first sample of the searched slice.
#     # NOTE(review): the second threshold adds a fraction of the absolute peak-2 intensity,
#     # not of the peak-to-valley difference - confirm this is intended.
#     ionset = [
#         index_of(data[:imax[0]], thread * data[imax[0]]),
#         index_of(data[imin[0]:imax[1]], data[imin[0]] + thread * data[imax[1]]) + imin[0]
#     ]
#     height = [data[i] for i in imin]
#     return ionset, imin, imax, height

# def baseline(energy, data, ionset, imax, imin, height):
#     """Fit a constant line plus two arctan steps as background and integrate the two peaks.

#     The model is fitted only to the samples before the first onset and to a 10-sample window
#     around the second valley, then evaluated over the full ``energy`` axis.

#     Args:
#         energy: Energy axis.
#         data: Intensity values.
#         ionset: Onset indices of the two peaks; the step centers are fixed at these energies.
#         imax: Indices of the two peaks; each is integrated over ``[pid-5, pid+5)``.
#         imin: Indices of the valley after each peak; only ``imin[1]`` is used.
#         height: ``data`` at the valleys; only ``height[1]`` is used.

#     Returns:
#         Tuple ``(baseline, peaks, peak_intensities, ratio, result)``: the evaluated background,
#         ``data - baseline``, the two Simpson integrals, their ratio (first / second) and the
#         fit result object.
#     """
#     # Fit region: everything up to 10 samples before the first onset, plus the second valley.
#     # NOTE(review): when ionset[0] <= 10 the bound ionset[0]-10 is zero (empty slice) or
#     # negative (counted from the end of the array), which changes the fit region silently.
#     xdat = np.concatenate((energy[:ionset[0]-10], energy[imin[1]-5:imin[1]+5]))
#     ydat = np.concatenate((data[:ionset[0]-10], data[imin[1]-5:imin[1]+5]))

#     line_mod = LinearModel(prefix='line_')
#     step_mod1 = StepModel(form='arctan', prefix='step1_')
#     step_mod2 = StepModel(form='arctan', prefix='step2_')

#     # The slope is fixed at 0, so the line component is a constant offset.
#     params = line_mod.make_params(intercept=np.mean(data[:ionset[0]-10]), slope=dict(value=0, vary=False))
#     # Step 1 is tied to step 2: same sigma and twice the amplitude.
#     params.update(step_mod1.make_params(
#         center=dict(value=energy[ionset[0]], vary=False),
#         sigma=dict(value=0.5, expr='1*step2_sigma'),
#         amplitude=dict(expr='2*step2_amplitude')
#     ))
#     # Step 2 amplitude is one third of (height[1] - line_intercept), so the two step
#     # amplitudes together sum to height[1] - line_intercept.
#     # NOTE(review): value/vary are given together with expr; lmfit presumably treats an
#     # expr-constrained parameter as non-varying, making them ineffective - confirm.
#     params.update(step_mod2.make_params(
#         center=dict(value=energy[ionset[1]], vary=False),
#         sigma=dict(value=0.5, min=0.1, max=1.0),
#         amplitude=dict(value=height[1]/3, vary=True, expr=f'({height[1]}-line_intercept)/3')
#     ))

#     model = line_mod + step_mod1 + step_mod2
#     result = model.fit(ydat, params, x=xdat)

#     # Evaluate the fitted background on the whole axis, not only on the fit region.
#     baseline = result.eval(result.params, x=energy)
#     peaks = data - baseline
#     # Integrate a 10-sample window around each peak of the background-subtracted signal.
#     peak_intensities = [simpson(y=peaks[pid-5:pid+5], x=energy[pid-5:pid+5]) for pid in imax]
#     ratio = peak_intensities[0] / peak_intensities[1]
#     return baseline, peaks, peak_intensities, ratio, result

# def plot_fit(energy, data, baseline, peaks, ratio, ionset, name, path_out):
#     """Plot a spectrum with its background and background-subtracted curve, then save it.

#     Args:
#         energy: Energy axis; indexed with the list ``ionset``, so presumably a NumPy array.
#         data: Raw intensity values (drawn with label 'data').
#         baseline: Background curve (drawn with label 'bkg').
#         peaks: Curve drawn with label 'fit'.
#         ratio: Value printed in the annotation as 'L3/2 ratio'.
#         ionset: Indices whose energies are printed below the ratio.
#         name: File stem of the saved figure.
#         path_out: Directory (Path) that receives ``{name}.tif``.
#     """
#     fig = plt.figure(figsize=(3.3, 2.5))
#     gs = gridspec.GridSpec(1, 1, figure=fig)   
#     ax = fig.add_subplot(gs[0])
#     # NOTE(review): this manual position is presumably overridden by plt.tight_layout()
#     # further down - confirm whether both are needed.
#     ax.set_position((0.0, 0, 1.0, 1.0))
#     ax.set_box_aspect(0.8)
#     # NOTE(review): labels are assigned but no legend is ever drawn in this function.
#     ax.plot(energy, data, ls='-', label='data')
#     ax.plot(energy, baseline, ls='--', label='bkg')
#     # NOTE(review): labelled 'fit', but the plotted array is the `peaks` argument.
#     ax.plot(energy, peaks, ls='--', label='fit')
#     ax.set_ylabel(r'Intensity (count)', fontsize=11)
#     ax.set_xlabel(r'Energy (eV)', fontsize=11)
#     ax.xaxis.set_major_locator(ticker.MultipleLocator(base=20))
#     ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))
#     ax.text(0.95, 0.90, f'L3/2 ratio: {ratio:.2f}', transform=ax.transAxes,
#             ha='right', va='top', fontsize=10)
#     # Onset energies, comma separated with two decimals.
#     energy_str = ", ".join(f"{e:.2f}" for e in energy[ionset])
#     ax.text(0.95, 0.80, f'{energy_str}', transform=ax.transAxes,
#             ha='right', va='top', fontsize=10)
#     plt.tight_layout()
#     # Save as a 600 dpi, LZW-compressed, non-transparent TIFF.
#     plt.savefig(
#             path_out.joinpath(f"{name}.tif"),
#             bbox_inches="tight",
#             dpi=600,
#             transparent=False,
#             pil_kwargs={"compression": "tiff_lzw"}
#         )
#     plt.show()

# def process_one_spectrum(name, energy, data, threshold, plot_each, path_out):
#     """Run onset detection and background fitting on a single spectrum.

#     Args:
#         name: Label of the spectrum; returned unchanged and used as the figure file stem.
#         energy: Energy axis.
#         data: Intensity values.
#         threshold: Forwarded to find_onset_energy as ``thread``.
#         plot_each: When true, plot and save the fit; otherwise close all open figures.
#         path_out: Output directory forwarded to plot_fit.

#     Returns:
#         ``(name, result_dict)`` on success, ``(name, None)`` when any exception was raised.
#         The dict holds the original data, the background, the background-subtracted data,
#         the two integrated intensities, their ratio and the two onset energies.
#     """
#     # NOTE(review): every Exception is caught and only printed; callers must check the
#     # second tuple element for None.
#     try:
#         ionset, imin, imax, height = find_onset_energy(energy, data, thread=threshold)
#         baseline_fit, bkg_removed, peak_intensities, ratio, _ = baseline(energy, data, ionset, imax, imin, height)

#         if plot_each:
#             plot_fit(energy, data, baseline_fit, bkg_removed, ratio, ionset, name, path_out)
#         else:
#             plt.close('all')

#         return name, {
#             'original': data,
#             'baseline': baseline_fit,
#             'bkg_removed': bkg_removed,
#             'intensity_L3': peak_intensities[0],
#             'intensity_L2': peak_intensities[1],
#             'L3_L2_ratio': ratio,
#             'L3_energy': energy[ionset[0]],
#             'L2_energy': energy[ionset[1]],
#         }
#     except Exception as e:
#         print(f"[!] {name} 出错: {e}")
#         return name, None

# def process_dataframe(df: pd.DataFrame, energy_column='energy', threshold=0.1, plot_each=False, path_out=path_out):
#     """Process every spectrum column of ``df`` and collect the results in an xarray Dataset.

#     Args:
#         df: Table with one energy column and one column per spectrum.
#         energy_column: Name of the energy column; all other columns are treated as spectra.
#         threshold: Onset threshold forwarded to process_one_spectrum.
#         plot_each: Whether each fit is plotted and saved.
#         path_out: Output directory; the default is the module-level ``path_out`` as it was
#             when this function was defined.

#     Returns:
#         Dataset with variables 'data', 'baseline', 'bkg_removed' on ('spectrum', 'energy') and
#         'intensity_L3', 'intensity_L2', 'L3_L2_ratio', 'L3_energy', 'L2_energy' on 'spectrum'.

#     Raises:
#         ValueError: If ``results`` is empty after the loop.
#     """
#     energy = df[energy_column].values
#     spectra_names = df.columns.drop(energy_column)

#     results = {}

#     for name in spectra_names:
#         spectrum = df[name].values
#         # NOTE(review): process_one_spectrum already catches Exception itself, so this outer
#         # try/except is presumably never triggered by it.
#         try:
#             result = process_one_spectrum(name, energy, spectrum, threshold, plot_each, path_out)
#             # NOTE(review): a failed spectrum comes back as (name, None); that tuple passes
#             # this check, so None is stored and v['original'] below raises TypeError. The
#             # "if not results" guard also cannot detect an all-failed run for the same reason.
#             if result and isinstance(result, tuple) and len(result) == 2:
#                 _, data = result
#                 results[name] = data
#         except Exception as e:
#             print(f"Error processing {name}: {e}")

#     if not results:
#         raise ValueError("No valid spectra processed.")

#     # Stack the per-spectrum arrays / scalars in the insertion order of `results`, which is
#     # also the order of the 'spectrum' coordinate.
#     dataset = xr.Dataset(
#         data_vars={
#             'data': (['spectrum', 'energy'], np.array([v['original'] for v in results.values()])),
#             'baseline': (['spectrum', 'energy'], np.array([v['baseline'] for v in results.values()])),
#             'bkg_removed': (['spectrum', 'energy'], np.array([v['bkg_removed'] for v in results.values()])),
#             'intensity_L3': (['spectrum'], [v['intensity_L3'] for v in results.values()]),
#             'intensity_L2': (['spectrum'], [v['intensity_L2'] for v in results.values()]),
#             'L3_L2_ratio': (['spectrum'], [v['L3_L2_ratio'] for v in results.values()]),
#             'L3_energy': (['spectrum'], [v['L3_energy'] for v in results.values()]),
#             'L2_energy': (['spectrum'], [v['L2_energy'] for v in results.values()])
#         },
#         coords={
#             'spectrum': list(results.keys()),
#             'energy': energy
#         }
#     )

#     return dataset


In [None]:
# # Disabled driver: read EELS_Mn_raw.csv, clean the first 20 rows, fit every spectrum for
# # rows with energy < 680.0 and write the resulting Dataset to TEM_EELS_Fit.NETCDF4.
# file_path = path(r"C:\Users\chengliu\Desktop\Figure\EELS_Mn_raw.csv")
# df = pd.read_csv(file_path, sep=',', header=0)
# # In the first 20 rows of every column except the first, keep values strictly between
# # -5 and 100 and replace all others with 0.
# df.iloc[:20, 1:] = df.iloc[:20, 1:].where(lambda x: (x > -5)&(x < 100), 0) 
# xr_result = process_dataframe(df[df['energy'] < 680.0], plot_each=True, path_out=path_out, threshold=0.1)
# xr_result.to_netcdf(path_out.joinpath(f"TEM_EELS_Fit.NETCDF4"), engine="h5netcdf")

#### 单独画 两个之间的关系

In [None]:
# # Disabled loader: read the saved Mn mapping, Zn mapping and Zn/Mn mask signals with
# # hyperspy; their .data arrays are the inputs of generate_distribution_phase_mapping.
# ps_mappings_Mn = hs.load(r"C:\Users\chengliu\Desktop\Figure\3-ps_Mn_mapping.hspy")
# ps_mappings_Zn = hs.load(r"C:\Users\chengliu\Desktop\Figure\3-ps_Zn_mapping.hspy")
# ps_mappings_mask = hs.load(r"C:\Users\chengliu\Desktop\Figure\3-ps_Zn_Mn_mask.hspy")

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# from matplotlib.colors import LinearSegmentedColormap
# from PIL import Image

# def generate_distribution_phase_mapping(
#     mapping_Mn: np.ndarray,
#     mapping_Zn: np.ndarray,
#     mask: np.ndarray,
#     k_min: float | None = None,
#     k_max: float | None = None,
#     k_step: float = 0.2,
#     gif_path: path = None,
#     figsize: tuple = (12, 2.5),
#     dpi: int = 100,
#     bins: int = 100,
# ) -> None:
#     """Render one frame per K = Mn/Zn interval and save all frames as a multi-page TIFF.

#     Each frame shows three panels: the Mn-vs-Zn 2D histogram of all pixels, the same
#     histogram for unmasked pixels, and the Mn map. Pixels whose ratio K falls in
#     (k_start, k_start + k_step] are highlighted in all three panels.

#     Args:
#         mapping_Mn: Mn map.
#         mapping_Zn: Zn map, same shape as ``mapping_Mn``.
#         mask: Same-shape mask; True marks pixels to exclude. Must be boolean for ``~mask``
#             to act as a logical NOT - TODO confirm the dtype of the loaded mask.
#         k_min: Lower bound of the K scan; defaults to the minimum finite K of unmasked pixels.
#         k_max: Upper bound of the K scan; defaults to the maximum finite K of unmasked pixels.
#         k_step: Width of each K interval and step between frames.
#         gif_path: Output directory; the file written is 'K_distribution_with_kmask_stack.tif'.
#         figsize: Figure size passed to ``plt.subplots``.
#         dpi: Not used inside this function.
#         bins: Number of bins of both 2D histograms.
#     """
#     # NOTE(review): assert is stripped under `python -O`; it is the only shape validation here.
#     assert mapping_Mn.shape == mapping_Zn.shape == mask.shape, "All input arrays must have the same shape."

#     # K_all: ratio for every pixel with Zn != 0. K_map: additionally NaN on masked pixels.
#     with np.errstate(divide='ignore', invalid='ignore'):
#         K_all = np.where(mapping_Zn != 0, mapping_Mn / mapping_Zn, np.nan)
#         K_map = np.where((mapping_Zn != 0) & (~mask), mapping_Mn / mapping_Zn, np.nan)

#     # "Before masking" data set: all pixels with a finite ratio.
#     Mn_all = mapping_Mn.flatten()
#     Zn_all = mapping_Zn.flatten()
#     K_all_flat = K_all.flatten()
#     valid_all = np.isfinite(K_all_flat)
#     Mn_all = Mn_all[valid_all]
#     Zn_all = Zn_all[valid_all]
#     K_all_flat = K_all_flat[valid_all]

#     # "After masking" data set: unmasked pixels with a finite ratio.
#     mapping_Mn_A = mapping_Mn[~mask]
#     mapping_Zn_A = mapping_Zn[~mask]
#     K_flat = K_map[~mask]
#     valid = np.isfinite(K_flat)
#     mapping_Mn_A_K = mapping_Mn_A[valid]
#     mapping_Zn_A_K = mapping_Zn_A[valid]
#     K_flat = K_flat[valid]

#     if k_min is None:
#         k_min = float(np.nanmin(K_flat))
#     if k_max is None:
#         k_max = float(np.nanmax(K_flat))
#     # NOTE(review): each frame spans k_step starting at these values, so the last interval
#     # can extend beyond k_max; the array is empty when k_min >= k_max.
#     k_ranges = np.arange(k_min, k_max, k_step)

#     # Colormap built from the 0.2-0.8 portion of hot_r, white above and black below range.
#     cmap_base = plt.cm.hot_r(np.linspace(0.2, 0.8, 256))
#     cmap = LinearSegmentedColormap.from_list('cmap', cmap_base)
#     cmap.set_over('white')
#     cmap.set_under('black')

#     # NOTE(review): constrained_layout=True is combined with manual set_position calls
#     # below; presumably the manual positions win - confirm which layout is intended.
#     fig, axes = plt.subplots(1, 3, figsize=figsize, gridspec_kw={'wspace': 0.05}, constrained_layout=True)
#     ax_hist, ax_overlay, ax_map = axes

#     ax_hist.set_position([0.05, 0.18, 0.3, 0.7])
#     ax_overlay.set_position([0.3, 0.18, 0.3, 0.7])
#     ax_map.set_position([0.55, 0.35, 0.4, 0.7])

#     # Panel 1: histogram of all pixels plus an (initially empty) scatter for the current bin.
#     ax_hist.set_box_aspect(1.0)
#     ax_hist.hist2d(Mn_all, Zn_all, cmap=cmap, cmin=1, bins=bins)
#     sc1 = ax_hist.scatter([], [], s=2, c='magenta', alpha=0.1)
#     legend_text1 = ax_hist.text(0.95, 0.95, '', ha='right', va='top', transform=ax_hist.transAxes, fontsize=8)
#     ax_hist.set_title("Before Masking", fontsize=12, loc='left')
#     ax_hist.set_xlabel("Mn")
#     ax_hist.set_ylabel("Zn")

#     # Panel 2: same as panel 1, restricted to unmasked pixels.
#     ax_overlay.set_box_aspect(1.0)
#     ax_overlay.hist2d(mapping_Mn_A_K, mapping_Zn_A_K, cmap=cmap, cmin=1, bins=bins)
#     sc2 = ax_overlay.scatter([], [], s=2, c='magenta', alpha=0.1)
#     legend_text2 = ax_overlay.text(0.95, 0.95, '', ha='right', va='top', transform=ax_overlay.transAxes, fontsize=8)
#     ax_overlay.set_title("After Masking", fontsize=12, loc='left')
#     ax_overlay.set_xlabel("Mn")
#     ax_overlay.set_ylabel("Zn")

#     # Panel 3: grayscale Mn map with masked pixels set to NaN (img_map is not used afterwards).
#     masked_map = np.where(mask, np.nan, mapping_Mn)
#     img_map = ax_map.imshow(masked_map, cmap='gray')
#     ax_map.set_title("Distribution of Mn", fontsize=12, loc='left')
#     ax_map.axis('off')

#     # Semi-transparent overlay whose data is replaced with the current K-bin mask per frame.
#     mask_overlay = ax_map.imshow(np.zeros_like(mapping_Mn), cmap='hot', alpha=0.7, vmin=0, vmax=1)

#     tiff_stack = []

#     for frame_index in range(len(k_ranges)):
#         k_start = k_ranges[frame_index]
#         k_end = k_start + k_step

#         range_mask_all = (K_all_flat > k_start) & (K_all_flat <= k_end)
#         Mn_all_selected = Mn_all[range_mask_all]
#         Zn_all_selected = Zn_all[range_mask_all]
#         sc1.set_offsets(np.column_stack((Mn_all_selected, Zn_all_selected)))
#         legend_text1.set_text(f'{k_start:.2f} < K ≤ {k_end:.2f}')

#         range_mask_flat = (K_flat > k_start) & (K_flat <= k_end)
#         Mn_filtered = mapping_Mn_A_K[range_mask_flat]
#         Zn_filtered = mapping_Zn_A_K[range_mask_flat]
#         sc2.set_offsets(np.column_stack((Mn_filtered, Zn_filtered)))
#         legend_text2.set_text(f'{k_start:.2f} < K ≤ {k_end:.2f}')

#         # Comparisons with NaN are False, so masked / Zn == 0 pixels are never highlighted.
#         range_mask_2d = (K_map > k_start) & (K_map <= k_end)
#         mask_overlay.set_data(range_mask_2d.astype(float))

#         # Render the figure and keep the RGB channels of the canvas as one frame.
#         fig.canvas.draw()
#         img_array = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]
#         pil_img = Image.fromarray(img_array)
#         tiff_stack.append(pil_img.convert("RGB"))

#     # Save the TIFF stack
#     # NOTE(review): gif_path defaults to None, so calling without it raises AttributeError
#     # here; tiff_stack[0] raises IndexError when k_ranges is empty.
#     tiff_output_path = gif_path.joinpath('K_distribution_with_kmask_stack.tif')
#     tiff_stack[0].save(
#         tiff_output_path,
#         save_all=True,
#         append_images=tiff_stack[1:],
#         compression="tiff_deflate"
#     )

#     plt.close(fig)
#     print(f"[✔] TIFF stack saved to: {tiff_output_path.resolve()}")

# # Disabled call: scan K from -0.2 to 5.0 in steps of 0.1 over the loaded Mn / Zn maps and
# # mask, writing the TIFF stack into path_out.
# # NOTE(review): dpi=80 is passed but the function body never reads its dpi parameter.
# generate_distribution_phase_mapping(
#     ps_mappings_Mn.data,
#     ps_mappings_Zn.data,
#     ps_mappings_mask.data,
#     k_min=-0.2,
#     k_max=5.0,
#     k_step=0.1,
#     gif_path=path_out,
#     figsize=(10, 2.5),
#     dpi=80,
#     bins=60
# )