In [33]:
%matplotlib qt

import hyperspy.api as hs
import pyxem as pxm
import numpy as np
import time
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.colors import to_rgba
from matplotlib.colors import LinearSegmentedColormap
# Named colours from which the discrete colormap below is built.
color_names = ['linen', 'darkorange', 'dodgerblue', 'forestgreen', 'red']
# The same colours as RGBA tuples, in the same order as `color_names`.
colors = list(map(to_rgba, color_names))

# Discrete colormap with exactly one level per named colour.
cmap = LinearSegmentedColormap.from_list(name='gt_cmap', colors=colors, N=len(color_names))

# Grey colormap in which "bad" (masked/NaN) pixels are drawn in light blue.
gray_cmap = plt.colormaps.get('Greys')
gray_cmap.set_bad('lightblue')

In [2]:
def _compute_if_lazy(obj):
    """Call `obj.compute()` if the object offers it (lazy signals); otherwise do nothing."""
    compute = getattr(obj, 'compute', None)
    if compute is not None:
        compute()


def _sum_masks(masks, kind):
    """Combine a list of `(name, mask)` pairs into one boolean mask signal titled '<kind> mask [...]'."""
    print(f'Found {len(masks)} {kind.lower()} masks in the metadata')
    if not masks:
        # The metadata node exists but holds no masks: nothing to combine.
        return None

    # Use a copy of the first mask as a template so axes and metadata are inherited.
    total = masks[0][1].deepcopy()
    _compute_if_lazy(total)
    # Start from an all-False array with the template's own data shape. This avoids both
    # relying on a notebook-level signal and the (x, y) vs. (row, column) ordering
    # difference between `axes_manager` shapes and data arrays.
    total.data = np.zeros(total.data.shape, dtype=bool)

    names = []
    for mask_name, mask in masks:  # Iterate and add masks together
        _compute_if_lazy(mask)
        names.append(mask_name.strip())
        print(f'Adding mask "{mask_name}" {mask} to {kind.lower()} mask')
        total += mask
    total.metadata.General.title = f'{kind} mask [{", ".join(names)}]'
    return total


def get_masks_from_metadata(signal):
    """
    Returns the summed diffraction and navigation masks from a signal's metadata
    
    Arguments:
    ----------
    signal: The hyperspy signal to extract the masks from. Diffraction masks will be looked for under "Preprocessing/Masks/Diffraction" and navigation masks will be looked for under "Preprocessing/Masks/Navigation". Each node is iterated as (name, mask) pairs.
    
    Returns:
    --------
    diffmask, navmask: The summed diffraction mask and the summed navmask. Either is None if the corresponding metadata node is missing or empty.
    """
    
    try:
        masks = list(signal.metadata.Preprocessing.Masks.Diffraction)
    except AttributeError as e:
        print(f'No Diffraction mask available in metadata:\n{e}')
        diffmask = None
    else:
        diffmask = _sum_masks(masks, 'Diffraction')
    
    try:
        masks = list(signal.metadata.Preprocessing.Masks.Navigation)  # Extract the navigation masks from the metadata
    except AttributeError as e:
        print(f'No Navigation masks available in metdata:\n{e}')
        navmask = None
    else:
        navmask = _sum_masks(masks, 'Navigation')
    
    return diffmask, navmask

def estimate_threshold(loadings, component, method=None):
    """
    Estimate a threshold for one component of a set of decomposition loadings
    
    NaN values in the selected loading map are replaced by the smallest non-NaN value of that map before thresholding. The input data is not modified.
    
    Arguments:
    ----------
    loadings: The signal holding the loadings; the map for a component is taken as `loadings.inav[component].data`.
    component: The index of the component to threshold. Also used as the figure title when `method` is None.
    method: A callable that takes the NaN-filled array and returns a threshold (e.g. a `skimage.filters.threshold_*` function). If None, `skimage.filters.try_all_threshold` is used to plot a comparison of thresholding methods instead.
    
    Returns:
    --------
    The value returned by `method`, or None if `method` is None (only a figure is produced in that case).
    """
    data = loadings.inav[component].data
    filled = np.nan_to_num(data, copy=True, nan=np.nanmin(data))

    if method is None:
        # Imported here so the function does not depend on a later notebook cell having run.
        from skimage.filters import try_all_threshold

        # Title the figure that was actually created rather than whatever figure is "current".
        fig, _ = try_all_threshold(filled)
        fig.suptitle(component)
        return None

    return method(filled)

# Dataset A

## Load and prepare data

In [3]:
# Location of the preprocessed Dataset A file. Later cells load this file and derive
# their output file names from it via `filepath.with_name(...)`.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
filepath = Path(r'C:\Users\emilc\OneDrive - NTNU\NORTEM\Data\PhaseMappingPaper\Data\Dataset A\datasetA_preprocessed.hspy')

In [50]:
# Load the dataset non-lazily and convert it to single precision.
s = hs.load(str(filepath), lazy=False)
s.change_dtype('float32')
# Intensity integrated inside a circular ROI centred at (0, 0) with radius 0.1
# (presumably a virtual bright-field image, going by the variable name).
vbf = s.get_integrated_intensity(hs.roi.CircleROI(0.0, 0.0, 0.1))
# Maximum over axes 0 and 1; the recorded output shows the result as a kx/ky image.
maximums = s.max(axis=[0, 1])
# Objects without a `compute` method raise AttributeError, which is ignored here.
# If `vbf.compute()` raises, `maximums.compute()` is skipped as well.
try:
    vbf.compute()
    maximums.compute()
except AttributeError:
    pass
# Show both images side by side on a symmetric-log intensity scale.
hs.plot.plot_images([vbf, maximums], norm='symlog', axes_decor='off', colorbar=None, cmap='gray_r')

[<AxesSubplot: title={'center': 'Integrated intensity'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': 'Dataset A'}, xlabel='kx axis ($A^{-1}$)', ylabel='ky axis ($A^{-1}$)'>]

## Get the pre-made masks from the metadata (see preprocessing notebook for details)

In [51]:
# Collect the summed masks stored in the signal metadata; either one may be None.
diffmask, navmask = get_masks_from_metadata(s)
# Overlay each mask (red, half transparent) on the matching reference image.
# Multiplying by 1.0 turns the boolean mask into a numeric image for plotting.
if diffmask is not None:
    hs.plot.plot_images([maximums, diffmask*1.0], overlay=True, alphas=[1, 0.5], colors=['w', 'r'], axes_decor='off')
if navmask is not None:
    # Keyword was misspelled as `colros`, which would fail as soon as a navigation mask exists.
    hs.plot.plot_images([vbf, navmask*1.0], overlay=True, alphas=[1, 0.5], colors=['w', 'r'], axes_decor='off')

Found 3 diffraction masks in the metadata
Adding mask "cutoff" <Signal2D, title: <7.001e-01 $A^{-1}$ mask, dimensions: (|128, 128)> to diffraction mask
Adding mask "direct_beam" <Signal2D, title: >1.595e-01 $A^{-1}$ mask, dimensions: (|128, 128)> to diffraction mask
Adding mask "reflections" <Signal2D, title: Reflection mask, dimensions: (|128, 128)> to diffraction mask
No Navigation masks available in metdata:
Navigation


## Run SVD decomposition

In [53]:
tic = time.time()
# Either mask may be missing from the metadata (None); in that case no mask is passed on.
# The navigation mask data is transposed before it is handed to the decomposition.
transposed_navmask = navmask.data.transpose() if navmask is not None else None
signal_mask = diffmask.data if diffmask is not None else None

decomp = s.decomposition(
    normalize_poissonian_noise=True,
    algorithm='SVD',
    navigation_mask=transposed_navmask,
    signal_mask=signal_mask,
    return_info=True
)
toc = time.time()
print(f'Finished decomposition. Elapsed time: {toc - tic} seconds')
# Store the learning results next to the input file so the (slow) SVD need not be repeated.
s.learning_results.save(filepath.with_name(f'{filepath.stem}_SVD1.hspy'))

  explained_variance_ratio = explained_variance / explained_variance.sum()


Decomposition info:
  normalize_poissonian_noise=True
  algorithm=SVD
  output_dimension=None
  centre=None
Finished decomposition. Elapsed time: 3151.565973520279 seconds
Decoposition parameters: None


AttributeError: 'NoneType' object has no attribute 'reconstruction_err_'

In [63]:
threshold = 6 #where the estimated threshold in the explained variance is - tune!
markersize = 4 #Marker size for the plot
dpi=300 #DPI
figwidth = 468/3 #figure size in points
pt2in = 0.01389 #conversion from points to inches
figsize = (figwidth*pt2in, figwidth*pt2in) #Figuresize in inches
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
# Components up to `threshold` are drawn with filled markers, the rest with open markers.
# The first format keyword was misspelled as `ignal_fmt`; it must be `signal_fmt`.
s.plot_explained_variance_ratio(n=32,
                                threshold=threshold,
                                xaxis_type='number',
                                xaxis_labeling='ordinal',
                                signal_fmt={'marker': 'o', 'color': 'k', 'markerfacecolor': 'k', 'lw': 0, 'ms':markersize},
                                noise_fmt={'marker':'o', 'color': 'k', 'markerfacecolor': 'w', 'lw': 0, 'ms':markersize},
                                fig=fig,
                                ax=ax
                               )
ax.set_title('')
ax.set_xlim(0)
plt.tight_layout()
# Save the figure next to the input data file.
fig.savefig(filepath.with_name(f'{filepath.stem}_SVD1.png'), dpi=dpi)

## Run first NMF decomposition

In [None]:
output_dimension = 6 #The number of components to allow
tic = time.time()
# Either mask may be missing from the metadata (None); in that case no mask is passed on.
transposed_navmask = navmask.data.transpose() if navmask is not None else None
signal_mask = diffmask.data if diffmask is not None else None
decomp = s.decomposition(
    normalize_poissonian_noise=True,
    algorithm='NMF',
    output_dimension=output_dimension,
    navigation_mask=transposed_navmask,
    signal_mask=signal_mask,
    return_info=True,
    init='nndsvd',
    max_iter=10000
)
toc = time.time()
print(f'Finished decomposition. Elapsed time: {toc - tic} seconds')
print(f'Decoposition parameters: {decomp}')
# `decomposition` can return None (an earlier recorded run did), so only read the
# estimator attributes when an info object was actually returned.
if decomp is not None:
    print(f'Decomposition reconstruction error: {decomp.reconstruction_err_}')
    print(f'Decomposition number of iterations: {decomp.n_iter_}')

#Save the decomposition results
s.learning_results.save(filepath.with_name(f'{filepath.stem}_NMF1_{output_dimension}.hspy'))

#Save the factors and loadings individually as well
# The decomposed signal is `s`; the name `signal` is not defined at notebook level.
factors = s.get_decomposition_factors()
loadings = s.get_decomposition_loadings()
if decomp is not None:
    # `__dict__` is an attribute, not a method; copy it so the metadata does not alias the estimator state.
    decomposition_info = dict(vars(decomp))
    factors.metadata.add_dictionary({'Decomposition': decomposition_info})
    loadings.metadata.add_dictionary({'Decomposition': decomposition_info})
factors.save(filepath.with_name(f'{filepath.stem}_NMF1_{output_dimension}_factors.hspy'))
loadings.save(filepath.with_name(f'{filepath.stem}_NMF1_{output_dimension}_loadings.hspy'))

In [None]:
# Show all loadings in one row and all factors in another (factors on a symmetric-log scale).
# 'grays' is not a registered matplotlib colormap name; 'Greys' is the map this notebook
# also uses for `gray_cmap`.
hs.plot.plot_images(loadings, per_row=output_dimension, cmap='Greys', axes_decor='off')
hs.plot.plot_images(factors, per_row=output_dimension, cmap='Greys', norm='symlog', axes_decor='off')

### Estimate thresholds for phase map

In [51]:
# Load the NMF loadings saved by two separate runs (folders 17428150 and 17442425) for comparison.
# NOTE(review): absolute, machine-specific paths — adjust before running on another machine.
l1 = hs.load(Path(r'C:\Users\emilc\OneDrive - NTNU\NORTEM\Data\PhaseMappingPaper\Data\NMF troubleshooting\17428150\datasetA_preprocessed_NMF_5_loadings.hspy'))
l2 = hs.load(Path(r'C:\Users\emilc\OneDrive - NTNU\NORTEM\Data\PhaseMappingPaper\Data\NMF troubleshooting\17442425\datasetA_preprocessed_NMF_5_loadings.hspy'))

In [53]:
# Plot both sets of loadings with five images per row, so each run occupies one row.
hs.plot.plot_images([l1, l2], per_row=5, axes_decor='off')

[<AxesSubplot: title={'center': ' (0,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (1,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (2,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (3,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (4,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (0,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (1,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (2,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (3,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>,
 <AxesSubplot: title={'center': ' (4,)'}, xlabel='x axis (nm)', ylabel='y axis (nm)'>]

In [47]:
# Inspect the metadata attached to the loadings signal.
loadings.metadata

In [5]:
# Open the interactive viewer for the loadings.
loadings.plot()

In [8]:
from skimage.filters import try_all_threshold, threshold_li, threshold_triangle

In [10]:
# NOTE: this notebook-level `method` is not the one used in the comprehension below;
# the comprehension's own loop variable `method` shadows it.
method = None
# Components to threshold, paired by position with the thresholding function used for each.
components = [1, 2, 3, 4]
methods = [threshold_triangle, threshold_triangle, threshold_li, threshold_li]
# Map component index -> estimated threshold. zip() stops at the shorter list,
# so `components` and `methods` must be kept the same length.
thresholds = {component: estimate_threshold(loadings, component, method) for (component, method) in zip(components, methods)}

In [11]:
# Display the estimated threshold for each component.
thresholds

{1: 0.00012760356, 2: 0.00014046804, 3: 0.00057723, 4: 0.0006080403}

In [13]:
# Boolean masks per component: pixels whose loading is at or above that component's threshold.
# Components 3 and 4 are grouped as `theta_100`, components 1 and 2 as `T1`.
theta_100 = [loadings.inav[component]>=thresholds[component] for component in (3,4)]
T1 = [loadings.inav[component]>=thresholds[component] for component in (1, 2)] # NOTE(review): thresholds are applied as estimated; no scaling is done here

# Disabled: summing the per-component masks into one mask per phase.
#theta_100_mask = sum(theta_100)
#T1_mask = sum(T1)


In [28]:
# Display the concatenated list of mask signals (theta_100 masks first, then T1 masks).
theta_100 + T1

[<Signal2D, title: Decomposition loadings of Dataset A, dimensions: (|512, 512)>,
 <Signal2D, title: Decomposition loadings of Dataset A, dimensions: (|512, 512)>,
 <Signal2D, title: Decomposition loadings of Dataset A, dimensions: (|512, 512)>,
 <Signal2D, title: Decomposition loadings of Dataset A, dimensions: (|512, 512)>]

In [34]:
# Display the colour names defined in the setup cell.
color_names

['linen', 'darkorange', 'dodgerblue', 'forestgreen', 'red']

In [35]:
# Display the RGBA tuples corresponding to `color_names`.
colors

[(0.9803921568627451, 0.9411764705882353, 0.9019607843137255, 1.0),
 (1.0, 0.5490196078431373, 0.0, 1.0),
 (0.11764705882352941, 0.5647058823529412, 1.0, 1.0),
 (0.13333333333333333, 0.5450980392156862, 0.13333333333333333, 1.0),
 (1.0, 0.0, 0.0, 1.0)]

In [46]:
# Overlay the four boolean masks: both T1 masks in colors[3], both theta_100 masks in colors[1].
# Multiplying by 1.0 turns each boolean mask into a numeric image for plotting.
# The labels are raw strings: in a normal string literal the '\t' in '\theta' is a TAB
# character, which corrupts the LaTeX label.
hs.plot.plot_images(
    [m*1.0 for m in T1 + theta_100],
    overlay=True,
    colors=[colors[3], colors[3], colors[1], colors[1]],
    label=[r'T$_1$', r'T$_1$', r'$\theta_{100}$', r'$\theta_{100}$'],
)

[<AxesSubplot: >]

In [48]:
# Build an integer-coded phase map: 0 = unassigned, 1 = T1, 2 = theta_100.
# The canvas takes its shape from the mask data itself, so it always matches the arrays used
# for indexing. `axes_manager.signal_shape` is in (x, y) order, which only coincides with
# the data shape for square scans.
phasemap = hs.signals.Signal2D(np.zeros(T1[0].data.shape))

# Later assignments overwrite earlier ones, so theta_100 wins wherever the phases overlap.
for phase_label, phase_masks in enumerate((T1, theta_100), start=1):
    for phase_mask in phase_masks:
        # Index with the underlying boolean array rather than with the Signal object.
        phasemap.data[np.asarray(phase_mask.data, dtype=bool)] = phase_label

phasemap.plot()

In [22]:
# Show the phase map again.
phasemap.plot()

In [16]:
# Inspect the second T1 mask (built from component 2) on its own.
T1[1].plot()

In [17]:
# Check the signal-space shape of the loadings (used to size the phase map).
loadings.axes_manager.signal_shape

(512, 512)

In [26]:
# Load an earlier phase map file (`NMF_phasemap.hspy`).
# NOTE(review): the recorded run of this cell failed with OSError (errno 22, "Unable to open file");
# the path is absolute and machine-specific — confirm the file is available locally before running.
old_map = hs.load(Path(r'C:\Users\emilc\OneDrive - NTNU\NORTEM\Data\NMF\NMF_phasemap.hspy'))

ERROR:hyperspy.io:If this file format is supported, please report this error to the HyperSpy developers.


OSError: [Errno 22] Unable to open file (file read failed: time = Tue Mar 14 08:30:25 2023
, filename = 'C:\Users\emilc\OneDrive - NTNU\NORTEM\Data\NMF\NMF_phasemap.hspy', file descriptor = 7, errno = 22, error message = 'Invalid argument', buf = 000000D7BCBEA0D0, total read size = 8, bytes this sub-read = 8, bytes actually read = 18446744073709551615, offset = 0)