In [None]:
import hyperspy.api as hs
import pyxem as pxm
import numpy as np
import time

from pathlib import Path

In [None]:
def _sum_masks(masks, shape, kind):
    """
    Combine (logical OR) every mask stored under one metadata node into a single boolean mask signal.

    Arguments:
    ----------
    masks: The metadata node whose children are the individual masks (signals or array-likes).
    shape: Shape of the combined mask array.
    kind: 'Diffraction' or 'Navigation'; only used in the printed messages and the title.

    Returns:
    --------
    A hs.signals.Signal2D wrapping the combined boolean mask, titled '<kind> mask [<names>]'.
    """
    print(f'Found {len(masks)} {kind.lower()} masks in the metadata')
    # NOTE(review): `shape` follows hyperspy's axes-manager order; stored masks are
    # assumed to have the same array shape (always true for square data) — confirm.
    combined = np.zeros(shape, dtype=bool)
    names = []
    for mask_name in masks.keys():
        mask = masks[mask_name]
        names.append(mask_name)
        print(f'Adding mask "{mask_name}" {mask} to {kind.lower()} mask')
        # Masks may be stored as signals or plain arrays; OR them in as booleans so that
        # non-bool (e.g. 0/1 integer) masks do not raise a casting error.
        combined |= np.asarray(getattr(mask, 'data', mask), dtype=bool)
    # Callers use `.data` and plot the result, so return a signal rather than a bare array.
    combined_signal = hs.signals.Signal2D(combined)
    combined_signal.metadata.General.title = f'{kind} mask [{", ".join(names)}]'
    return combined_signal


def get_masks_from_metadata(signal):
    """
    Returns the summed diffraction and navigation masks from a signal's metadata.

    Arguments:
    ----------
    signal: The hyperspy signal to extract the masks from. Diffraction masks will be looked for under "Preprocessing/Masks/Diffraction" and navigation masks will be looked for under "Preprocessing/Masks/Navigation".

    Returns:
    --------
    diffmask, navmask: The combined diffraction mask and the combined navigation mask, each as a
    boolean Signal2D, or None when the corresponding metadata node is missing.
    """
    try:
        diffraction_masks = signal.metadata.Preprocessing.Masks.Diffraction
    except AttributeError as e:
        print(f'No Diffraction mask available in metadata:\n{e}')
        diffmask = None
    else:
        diffmask = _sum_masks(diffraction_masks, signal.axes_manager.signal_shape, 'Diffraction')

    try:
        navigation_masks = signal.metadata.Preprocessing.Masks.Navigation
    except AttributeError as e:
        print(f'No Navigation masks available in metadata:\n{e}')
        navmask = None
    else:
        navmask = _sum_masks(navigation_masks, signal.axes_manager.navigation_shape, 'Navigation')

    return diffmask, navmask

def estimate_threshold(loadings, component, method=None):
    """
    Estimate a threshold for a single decomposition loading map.

    Arguments:
    ----------
    loadings: The decomposition loadings; `loadings.inav[component].data` is the map to threshold.
    component: Index of the component whose loading map is thresholded.
    method: Callable mapping the loading-map array to a threshold value. If None, every
        scikit-image thresholding method is plotted for visual comparison instead.

    Returns:
    --------
    The value returned by `method`, or None when `method` is None.
    """
    data = loadings.inav[component].data
    # Replace NaNs with the smallest non-NaN value so thresholding works on finite input;
    # copy=True leaves the loadings themselves untouched.
    image = np.nan_to_num(data, copy=True, nan=np.nanmin(data))
    if method is not None:
        return method(image)
    # NOTE(review): try_all_threshold (skimage.filters) must be imported in this notebook.
    # It returns the (fig, ax) it created, so no pyplot global-state lookup is needed.
    fig, _ = try_all_threshold(image)
    fig.suptitle(component)
    return None

# Dataset A

## Load and prepare data

In [None]:
# Location of the dataset to analyse. A raw string lets Windows paths be pasted verbatim.
# TODO: fill in the dataset path before running the notebook.
filepath = Path(r'')

In [None]:
# Load the whole dataset into memory and work in single precision.
s = hs.load(str(filepath), lazy=False)
s.change_dtype('float32')

# Virtual bright-field image from a small disk (radius 0.01) centred at (0, 0) of the
# diffraction plane — assumes the direct beam sits at the calibrated origin; TODO confirm.
vbf = s.get_integrated_intensity(hs.roi.CircleROI(cx=0.0, cy=0.0, r=0.01))

# Per-pixel maximum over axes 0 and 1 (presumably the two navigation axes).
maximums = s.max(axis=[0, 1])

## Get the pre-made masks from the metadata (see preprocessing notebook for details)

In [None]:
diffmask, navmask = get_masks_from_metadata(s)
# Side by side: the diffraction mask, the virtual bright-field image and the masked maximum pattern.
masked_maximums = diffmask * maximums
hs.plot.plot_images([diffmask, vbf, masked_maximums], cmap='RdBu')

In [None]:
# Compare the navigation mask with the virtual bright-field image and their product.
masked_vbf = vbf * navmask
hs.plot.plot_images([navmask, vbf, masked_vbf], cmap='RdBu', axes_decor='off')
# NOTE: hyperspy's navigation-shape order is the reverse of numpy's array order, which is
# why navmask.data is transposed when it is passed to the decompositions below.

## Run SVD decomposition

In [None]:
# Masks may be missing from the metadata (get_masks_from_metadata returns None then).
# The navigation mask is transposed because hyperspy's navigation-shape order is the
# reverse of numpy's array order.
navigation_mask_array = None if navmask is None else navmask.data.transpose()
signal_mask_array = None if diffmask is None else diffmask.data

tic = time.perf_counter()
decomp = s.decomposition(
    normalize_poissonian_noise=True,
    algorithm='SVD',
    navigation_mask=navigation_mask_array,
    signal_mask=signal_mask_array,
    return_info=True
)
toc = time.perf_counter()
print(f'Finished decomposition. Elapsed time: {toc - tic} seconds')
print(f'Decomposition parameters: {decomp}')
# Fit diagnostics only exist on estimator objects (e.g. scikit-learn); SVD may return none.
if hasattr(decomp, 'reconstruction_err_'):
    print(f'Decomposition reconstruction error: {decomp.reconstruction_err_}')
if hasattr(decomp, 'n_iter_'):
    print(f'Decomposition number of iterations: {decomp.n_iter_}')

# LearningResults.save writes a numpy .npz archive, so use a matching extension.
s.learning_results.save(f'{filepath}_SVD1.npz')

## Run first NMF decomposition

In [None]:
output_dimension = 6  # The number of NMF components to extract

# Masks may be missing from the metadata (get_masks_from_metadata returns None then).
# The navigation mask is transposed because hyperspy's navigation-shape order is the
# reverse of numpy's array order.
navigation_mask_array = None if navmask is None else navmask.data.transpose()
signal_mask_array = None if diffmask is None else diffmask.data

tic = time.perf_counter()
decomp = s.decomposition(
    normalize_poissonian_noise=True,
    algorithm='NMF',
    output_dimension=output_dimension,
    navigation_mask=navigation_mask_array,
    signal_mask=signal_mask_array,
    return_info=True,
    init='nndsvd',
    max_iter=10000
)
toc = time.perf_counter()
print(f'Finished decomposition. Elapsed time: {toc - tic} seconds')
print(f'Decomposition parameters: {decomp}')
if hasattr(decomp, 'reconstruction_err_'):
    print(f'Decomposition reconstruction error: {decomp.reconstruction_err_}')
if hasattr(decomp, 'n_iter_'):
    print(f'Decomposition number of iterations: {decomp.n_iter_}')

# Save the decomposition results (LearningResults.save writes a numpy .npz archive)
s.learning_results.save(f'{filepath}_NMF1_{output_dimension}.npz')

# Save the factors and loadings individually as well, recording how they were obtained
factors = s.get_decomposition_factors()
loadings = s.get_decomposition_loadings()
if hasattr(decomp, 'get_params'):
    # get_params() gives the estimator's hyperparameters; the two fit diagnostics are added
    # explicitly (instead of the whole __dict__, which also holds the large fitted arrays).
    decomposition_info = dict(decomp.get_params())
    for attribute in ('reconstruction_err_', 'n_iter_'):
        if hasattr(decomp, attribute):
            decomposition_info[attribute] = getattr(decomp, attribute)
    factors.metadata.add_dictionary({'Decomposition': decomposition_info})
    loadings.metadata.add_dictionary({'Decomposition': decomposition_info})
factors.save(f'{filepath}_NMF1_{output_dimension}_factors.hspy')
loadings.save(f'{filepath}_NMF1_{output_dimension}_loadings.hspy')

In [None]:
# One panel per NMF component: the loading maps first, then the factor patterns.
# 'gray' is the matplotlib colormap name ('grays' is not registered).
hs.plot.plot_images(loadings, per_row=output_dimension, cmap='gray', axes_decor='off')
# Symmetric-log colour scaling for the factors, presumably so weak diffraction features stay visible.
hs.plot.plot_images(factors, per_row=output_dimension, cmap='gray', norm='symlog', axes_decor='off')

### Estimate thresholds for phase map

In [None]:
# One threshold per NMF component; the loadings navigate over the output_dimension components.
components = range(output_dimension)
# NOTE(review): `method` must be defined before this cell as a callable mapping a 2D array to
# a scalar threshold (e.g. a scikit-image threshold_* function). With method=None only the
# diagnostic comparison plots are drawn and every threshold is None, which breaks the next cell.
thresholds = {component: estimate_threshold(loadings, component, method) for component in components}

In [None]:
# NMF components attributed to each phase, and the extra scaling applied to the T1 thresholds.
THETA_100_COMPONENTS = (2, 3)
T1_COMPONENTS = (1, 4)
T1_THRESHOLD_SCALE = 1.04  # Scale the threshold slightly

theta_100 = [loadings.inav[c] >= thresholds[c] for c in THETA_100_COMPONENTS]
T1 = [loadings.inav[c] >= T1_THRESHOLD_SCALE * thresholds[c] for c in T1_COMPONENTS]

# Combine each phase's per-component masks by summation (pixels passing several
# component thresholds accumulate).
theta_100_mask = sum(theta_100)
T1_mask = sum(T1)
