In [None]:
%matplotlib qt
import matplotlib.pyplot as plt
import hyperspy.api as hs
import pyxem as pxm
import numpy as np

from pathlib import Path
from matplotlib.colors import SymLogNorm
from skimage.measure import label

from matplotlib.colors import to_rgba
from matplotlib.colors import LinearSegmentedColormap

from skimage.filters.thresholding import try_all_threshold

from skimage.filters.thresholding import threshold_triangle, threshold_li, threshold_isodata

# Discrete five-colour map; used below with vmin=0 and vmax=4 so that every
# integer phase value gets its own colour.
color_names = ['linen', 'darkorange', 'dodgerblue', 'forestgreen', 'red']
colors = list(map(to_rgba, color_names))

cmap = LinearSegmentedColormap.from_list(
    'gt_cmap',
    colors,
    N=len(color_names),
)

## Convenience classes

In [None]:
class Decomposition(object):
    """Loadings and factors of one decomposition result stored on disk.

    The two signals are read from ``<stem>_loadings<suffix>`` and
    ``<stem>_factors<suffix>``, located next to `datapath`.
    """

    def __init__(self, datapath, label='', load_output_file=True, metadata_dict=None):
        """Load the loadings and factors that belong to `datapath`.

        Parameters
        ----------
        datapath : str or Path
            Base path of the decomposition; the `_loadings` and `_factors`
            files are derived from its stem and suffix.
        label : str
            Free-form label used in the string representations.
        load_output_file : bool
            If True, look for `*.out` files in the parent directory of
            `datapath`. When exactly one is found, its text is stored under
            the key 'Decomposition_logfile' of the metadata dictionary.
        metadata_dict : dict or None
            Extra metadata added to both signals under 'Decomposition'.
            The dictionary passed by the caller is not modified.

        Raises
        ------
        TypeError
            If `metadata_dict` is neither None nor a dict.
        """
        self.datapath = Path(datapath)
        self.label = label

        if metadata_dict is not None:
            # Validate before the dictionary is used, so a wrong type gives
            # this TypeError rather than an AttributeError on `.update`.
            if not isinstance(metadata_dict, dict):
                raise TypeError(f'Could not add {metadata_dict!r} as a metadata dictionary to {self!r}. Only dictionaries are allowed.')
            # Work on a copy so the caller's dictionary is left untouched.
            metadata_dict = dict(metadata_dict)

        stem, suffix = self.datapath.stem, self.datapath.suffix
        self.loadings = hs.load(self.datapath.with_name(f'{stem}_loadings{suffix}'))
        self.factors = hs.load(self.datapath.with_name(f'{stem}_factors{suffix}'))

        if load_output_file:
            output_files = list(self.datapath.parent.glob(r'*.out'))
            if len(output_files) == 1:
                if metadata_dict is None:
                    metadata_dict = {}
                metadata_dict['Decomposition_logfile'] = output_files[0].read_text()
            elif len(output_files) == 0:
                print('No output log file detected')
            else:
                print(f'{len(output_files)} output log files found, please add the correct output log file to the metadata of the loadings and factors manually instead')

        if metadata_dict is not None:
            self.loadings.metadata.add_dictionary({'Decomposition': metadata_dict})
            self.factors.metadata.add_dictionary({'Decomposition': metadata_dict})

    @property
    def output_dimension(self):
        """Number of components, taken as ``len(self.loadings)``."""
        return len(self.loadings)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.datapath!r}, label={self.label!r})'

    def __str__(self):
        return f'{self.__class__.__name__} {self.label!s} with {self.output_dimension} components from path "{self.datapath.absolute()}"'

    def __iter__(self):
        """Yield ``(loading, factor)`` pairs, one per component."""
        yield from zip(self.loadings, self.factors)

    def estimate_threshold(self, component, method=None):
        """Estimate a threshold for the loading map of `component`.

        NaNs in the loading are replaced by the smallest non-NaN value before
        thresholding.

        Parameters
        ----------
        component : int
            Index of the loading to threshold.
        method : callable or None
            Function mapping an image to a threshold value (e.g.
            `threshold_triangle`). If None, `try_all_threshold` is called to
            show a comparison figure and nothing is returned.

        Returns
        -------
        The value returned by `method`, or None when `method` is None.
        """
        loading = self.loadings.inav[component].data
        filled = np.nan_to_num(loading, copy=True, nan=np.nanmin(loading))
        if method is None:
            # try_all_threshold returns (figure, axes); title the figure it made.
            fig, _ = try_all_threshold(filled)
            fig.suptitle(component)
            return None
        return method(filled)

    def as_dictionary(self):
        """Return the path, label and deep copies of both signals as a dict."""
        return {'path': str(self.datapath.absolute()),
                'label': self.label,
                'loadings': self.loadings.deepcopy(),
                'factors': self.factors.deepcopy()
               }

    def plot(self, *args, **kwargs):
        """Plot loadings and factors together; arguments go to `hs.plot.plot_signals`."""
        hs.plot.plot_signals([self.loadings, self.factors], *args, **kwargs)

    def export_as_png(self, axis_size = 6, dpi=150, *args, **kwargs):
        """Save one side-by-side loading/factor PNG per component.

        Files are written next to `datapath` as ``<stem>_<i>.png``.

        Parameters
        ----------
        axis_size : float
            Size in inches of each of the two panels.
        dpi : int
            Resolution of the figure and of the saved file.
        *args, **kwargs
            Accepted for signature compatibility; currently unused.
        """
        annotation_style = dict(xycoords='axes fraction', color='w', ha='left', va='top', bbox=dict(facecolor='k', alpha=0.5))
        for i, (loading, factor) in enumerate(self):
            fig, axes = plt.subplots(nrows=1, ncols=2, subplot_kw={'xticks': [], 'yticks': []}, figsize=(axis_size*2, axis_size), dpi=dpi)
            axes[0].imshow(loading.data, cmap='viridis_r')
            axes[1].imshow(factor.data, norm=SymLogNorm(0.03), cmap='viridis_r')
            axes[0].annotate(f'Loading {i}', (0.02, 0.98), **annotation_style)
            axes[1].annotate(f'Factor {i}', (0.02, 0.98), **annotation_style)
            fig.tight_layout()
            fig.savefig(self.datapath.with_name(f'{self.datapath.stem}_{i}.png'), dpi=dpi)
            # Close only the figure created here; unrelated open figures stay open.
            plt.close(fig)

    def __hash__(self):
        # tuple() takes a single iterable, so the former tuple(a, b) call always
        # raised TypeError. Hash on the path and label, which identify the
        # decomposition and are plain hashable values.
        return hash((self.datapath, self.label))

class DecomposedPhase(object):
    """A phase defined by thresholded loading maps of a decomposition.

    Each selected component has a threshold; a pixel belongs to the phase when
    the loading of at least one component is at or above its threshold.
    """

    def __init__(self, phase, decomposition, components, thresholds=None, value=1):
        """
        Parameters
        ----------
        phase : str
            Name of the phase.
        decomposition : Decomposition
            Decomposition providing the loading maps.
        components : sequence of int
            Indices of the components that make up the phase.
        thresholds : sequence or None
            One threshold per component, paired with `components` by position
            (pairing stops at the shorter sequence). None leaves every
            threshold unset.
        value : number
            Value written into the phase map where the phase is present.

        Raises
        ------
        IndexError
            If a component index is out of range for the decomposition.
        """
        self.phase = phase
        self.value = value
        self.decomposition = decomposition
        self.component_thresholds = dict()
        if thresholds is None:
            thresholds = [None]*len(components)
        for component, threshold in zip(components, thresholds):
            self.update_component(component, threshold)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.phase!r}, {self.decomposition!r}, {list(self.component_thresholds.keys())!r}, {list(self.component_thresholds.values())!r}, value={self.value!r})'

    def __str__(self):
        return f'{self.__class__.__name__} for {self.phase!s} ({self.value}) based on {self.decomposition!s}.\nComponent thresholds:{self.component_thresholds!s}'

    @property
    def data(self):
        """Phase map: `value` where any thresholded loading is set, 0 elsewhere.

        Components whose threshold is still None do not contribute.
        """
        loadings = self.decomposition.loadings
        # HyperSpy reports signal_shape as (x, y) while NumPy arrays are indexed
        # (y, x); reverse it so non-square maps get the right array shape.
        phase_map = np.zeros(loadings.axes_manager.signal_shape[::-1], dtype=bool)
        for component, threshold in self.component_thresholds.items():
            if threshold is None:
                # No threshold set yet; comparing against None would raise.
                continue
            phase_map |= loadings.inav[component].data >= threshold
        return phase_map * self.value

    def as_signal(self):
        """Return the phase map as a Signal2D with the phase definition in its metadata."""
        s = hs.signals.Signal2D(self.data)
        s.metadata.add_dictionary({
            'Phase': {
                'name': self.phase,
                'value': self.value,
                'decomposition': self.decomposition.as_dictionary(),
                'component_thresholds': {str(key): self.component_thresholds[key] for key in self.component_thresholds}},
            'General': {
                'title': f"{self.__class__.__name__} for {self.phase}"}
        })
        return s

    def update_component(self, component, threshold=None):
        """Add `component` to the phase, or replace its threshold.

        Raises
        ------
        IndexError
            If `component` is outside ``[0, len(decomposition.loadings))``.
        """
        if 0 <= component < len(self.decomposition.loadings):
            self.component_thresholds[component] = threshold
        else:
            raise IndexError(f'Cannot update component {component} in {self!r}, the component is out of range for decomposition {self.decomposition}')

    def remove_component(self, component):
        """Remove `component` from the phase; unknown components are ignored."""
        self.component_thresholds.pop(component, None)

    def estimate_thresholds(self, method=None, update=False):
        """Estimate a threshold for every component of the phase.

        Parameters
        ----------
        method : callable or None
            Passed to `Decomposition.estimate_threshold`.
        update : bool
            If True, store the estimated thresholds on the phase.

        Returns
        -------
        dict
            Mapping of component index to estimated threshold.
        """
        thresholds = {component: self.decomposition.estimate_threshold(component, method) for component in self.component_thresholds}
        if update:
            for component, threshold in thresholds.items():
                self.update_component(component, threshold)
        return thresholds

    def plot(self, axis_size = 6, dpi=150, savefig=False, *args, **kwargs):
        """Show the loading of every component next to the resulting phase map.

        Parameters
        ----------
        axis_size : float
            Size in inches of each panel.
        dpi : int
            Figure (and file) resolution.
        savefig : bool
            If True, save the figure next to the decomposition file as
            ``<stem>_<phase>.png`` and close it.
        *args
            Accepted for signature compatibility; currently unused.
        **kwargs
            Passed to `imshow` for the loading panels.
        """
        ncols = len(self.component_thresholds)+1
        # squeeze=False keeps `axes` indexable even when there is a single panel.
        fig, axes = plt.subplots(nrows=1, ncols=ncols, sharex=True, sharey=True, squeeze=False, subplot_kw={'xticks': [], 'yticks': []}, figsize=(axis_size*ncols, axis_size), dpi=dpi)
        axes = axes[0]
        for ax, component in zip(axes, self.component_thresholds):
            ax.imshow(self.decomposition.loadings.inav[component].data, **kwargs)
            ax.annotate(f'Loading {component} of phase {self.phase}', (0.02, 0.98), xycoords='axes fraction', color='w', ha='left', va='top', bbox=dict(facecolor='k', alpha=0.5))
        axes[-1].imshow(self.data, cmap='Greys_r')
        fig.tight_layout()
        if savefig:
            # Name the file after the phase. The previous name reused the last
            # loop index, which collided with the per-component images written
            # by Decomposition.export_as_png and with other phases.
            datapath = self.decomposition.datapath
            fig.savefig(datapath.with_name(f'{datapath.stem}_{self.phase}.png'), dpi=dpi)
            plt.close(fig)

    def plot_component_histograms(self, components=None, axis_size = 6, dpi=150, savefig=False, *args, **kwargs):
        """Plot a histogram of each component's loading with its threshold marked.

        Parameters
        ----------
        components : sequence of int or None
            Components to plot; defaults to all components of the phase.
        axis_size, dpi, savefig
            Accepted for signature compatibility; currently unused.
        *args, **kwargs
            Passed to `Axes.hist`.

        Raises
        ------
        ValueError
            If a requested component has no threshold stored on the phase.
        """
        if components is None:
            components = list(self.component_thresholds.keys())
        ncols = len(components)
        # squeeze=False gives a 2D axes array for any number of columns.
        fig, axes = plt.subplots(nrows=1, ncols=ncols, sharex=True, sharey=True, squeeze=False)
        axes = axes[0]
        for ax, component in zip(axes, components):
            component_threshold = self.component_thresholds.get(component, None)
            if component_threshold is None:
                raise ValueError(f'Cannot plot histogram of component {component} for {self!r}. No such component in the threshold dictionary')
            ax.hist(self.decomposition.loadings.inav[component].data.flatten(), *args, **kwargs)
            ax.axvline(component_threshold, color='r')
            ax.set_title(f'Decomposition histogram of component {component}')
        fig.tight_layout()

    def make_mask(self, add_to=None):
        """Return a boolean Signal2D that is True where the phase is present.

        Parameters
        ----------
        add_to : signal or None
            If given, the mask is also stored in its metadata under
            ``Preprocessing.Masks.Navigation.<phase>``.
        """
        navigation_mask = hs.signals.Signal2D(self.data>=self.value, metadata={'General': {'title': f'{self.phase}'}})
        if add_to is not None:
            add_to.metadata.add_dictionary({'Preprocessing': {'Masks': {'Navigation': {self.phase: navigation_mask}}}})
        return navigation_mask

class PhaseMap(object):
    """Combine several DecomposedPhase objects into one prioritized phase map."""

    def __init__(self, phases):
        """
        Create a phase map from a list of PRIORITIZED phases. The first phase will have the highest priority, and the last will have the lowest.

        Raises
        ------
        TypeError
            If any element of `phases` is not a DecomposedPhase.
        ValueError
            If the loading maps of the phases do not all have the same shape.
        """
        type_test = [isinstance(phase, DecomposedPhase) for phase in phases]
        if not all(type_test):
            raise TypeError(f'Only DecomposedPhase objects are allowed to be specified in a PhaseMap: {type_test}')
        navigation_shapes = [phase.decomposition.loadings.axes_manager.signal_shape for phase in phases]
        if not all([navigation_shape == navigation_shapes[0] for navigation_shape in navigation_shapes]):
            raise ValueError(f'Navigation shapes in supplied phases does not match: {navigation_shapes}')

        self.phases = phases
        # HyperSpy's signal_shape is ordered (x, y).
        self.nx, self.ny = navigation_shapes[0]

    def __repr__(self):
        return f'{self.__class__.__name__}({self.phases!r})'

    def __str__(self):
        return f'{self.__class__.__name__} ({self.nx}x{self.ny}) of {self.phases!s}'

    def __iter__(self):
        """Iterate over the phases in priority order."""
        yield from self.phases

    @property
    def data(self):
        """Integer map holding, per pixel, the value of the highest-priority phase present (0 if none)."""
        # NumPy arrays are indexed (row, column) = (y, x), hence (ny, nx).
        phase_map = np.zeros((self.ny, self.nx), dtype=int)
        # Paint from lowest to highest priority so earlier phases overwrite later ones.
        for phase in self.phases[::-1]:
            phase_data = phase.data  # the property recomputes the map on each access
            phase_map = np.where(phase_data>0, phase_data, phase_map)
        return phase_map

    def plot_phase_map(self, axis_size = 6, dpi=150, *args, **kwargs):
        """Show the phase map in a new figure; extra arguments go to `imshow`."""
        fig, ax = plt.subplots(nrows=1, ncols=1, subplot_kw={'xticks': [], 'yticks': []}, figsize=(axis_size, axis_size), dpi=dpi, frameon=False)
        ax.imshow(self.data, *args, **kwargs)
        fig.tight_layout()

    def as_signal(self):
        """Return the phase map as a Signal2D with every phase signal stored under metadata 'Phases'."""
        s = hs.signals.Signal2D(self.data)
        s.metadata.add_dictionary({'Phases': {phase.phase: phase.as_signal() for phase in self}})
        s.metadata.General.title = r'Phase map'
        return s

    def save(self, filename, *args, **kwargs):
        """Save the signal returned by `as_signal`; arguments go to its `save` method."""
        self.as_signal().save(filename, *args, **kwargs)

    @staticmethod
    def _scaled_by_peak(array):
        """Return `array` divided by its NaN-ignoring maximum, or unchanged if that maximum is 0 or not finite."""
        peak = np.nanmax(array)
        if peak == 0 or not np.isfinite(peak):
            return array
        return array/peak

    def make_RGBA(self, normalize_colors = True, normalize_alpha=True):
        """Show the summed loadings and factors of up to three phases as RGBA images.

        Phase `i` fills colour channel `i` (red, green, blue); the alpha
        channel is the sum over all phases. The loadings fill the figure and
        the factors are drawn in an inset in the upper right corner.

        Parameters
        ----------
        normalize_colors : bool
            Scale each phase's summed loading and factor by its maximum.
        normalize_alpha : bool
            Scale the alpha channels by their maximum.

        Raises
        ------
        ValueError
            If the map has more than three phases; only three colour channels
            exist, and a fourth phase would previously overwrite alpha.
        """
        if len(self.phases) > 3:
            raise ValueError(f'Cannot make an RGBA image of {len(self.phases)} phases, at most 3 phases (one per colour channel) are supported')

        # signal_shape is (x, y); reverse it to get the NumPy (y, x) shape.
        reference = self.phases[0].decomposition
        rgba_loadings = np.zeros(reference.loadings.axes_manager.signal_shape[::-1] + (4,))
        rgba_factors = np.zeros(reference.factors.axes_manager.signal_shape[::-1] + (4,))
        for i, phase in enumerate(self.phases):
            decomposition = phase.decomposition
            factor = np.zeros(decomposition.factors.axes_manager.signal_shape[::-1])
            loading = np.zeros(decomposition.loadings.axes_manager.signal_shape[::-1])
            for c in phase.component_thresholds:
                factor += decomposition.factors.inav[c].data
                loading += np.nan_to_num(decomposition.loadings.inav[c].data, copy=True, nan=0)

            if normalize_colors:
                # Guarded division: a phase without signal stays zero instead of becoming NaN.
                factor = self._scaled_by_peak(factor)
                loading = self._scaled_by_peak(loading)

            rgba_loadings[:, :, i] = loading
            rgba_factors[:, :, i] = factor

            rgba_loadings[:, :, -1] += loading
            rgba_factors[:, :, -1] += factor

        if normalize_alpha:
            rgba_loadings[:, :, -1] = self._scaled_by_peak(rgba_loadings[:, :, -1])
            rgba_factors[:, :, -1] = self._scaled_by_peak(rgba_factors[:, :, -1])

        figure = plt.figure(figsize=(6, 6), frameon=True)
        ax = figure.add_axes([0, 0, 1, 1], xticks=[], yticks=[], frameon=False)
        ax.imshow(rgba_loadings)
        subax = figure.add_axes([0.75, 0.75, 0.25, 0.25], xticks=[], yticks=[], frameon=True)
        subax.imshow(rgba_factors)

### Helper functions

In [None]:
def load_decompositions(path):
    """Load every decomposition found in the directory `path`.

    A decomposition is detected through its ``*_factors.hspy`` file; the
    matching base path (without the ``_factors`` tag) is handed to
    `Decomposition`.

    Parameters
    ----------
    path : str or Path
        Directory to search (not recursive).

    Returns
    -------
    list of Decomposition
        Sorted by increasing `output_dimension`.
    """
    tag = '_factors'
    decompositions = []
    # Sort the file list so equal-dimension decompositions come out in a stable order.
    for factor_path in sorted(Path(path).glob(f'*{tag}.hspy')):
        # Strip only the trailing tag; str.replace would also remove any
        # earlier '_factors' occurring in the file name.
        base_name = f'{factor_path.stem[:-len(tag)]}{factor_path.suffix}'
        decompositions.append(Decomposition(factor_path.with_name(base_name)))

    return sorted(decompositions, key=lambda decomp: decomp.output_dimension)

def read_logfile(path):
    """Return the complete text content of the file at `path`."""
    return Path(path).read_text()

## SVD

In [None]:
# NOTE(review): absolute, machine-specific path; the notebook only runs where
# this file exists. Consider a path relative to a configurable data directory.
datapath = Path(r'C:\Users\emilc\OneDrive - NTNU\NORTEM\Data\2021_10_06_2xxx_24h_250C\Preprocessed_data\SPED_600x600x12_10x10_4p63x4p63_1deg_100Hz_CL12cm_NBD_alpha5_spot1p3_preprocessed.hspy')
# Load lazily (lazy=True).
signal = hs.load(datapath, lazy=True)

In [None]:
# Display the metadata of the loaded signal.
signal.metadata

In [6]:
# Plot the explained variance ratio, then print the estimated elbow position
# as an integer.
signal.plot_explained_variance_ratio()
print(int(signal.estimate_elbow_position()))

5


## NMF

### 1st iteration

In [None]:
# Decomposition of the first iteration (the file name suggests 5 NMF components).
decomposition_1 = Decomposition(r'nndsvd/12980451/SPED_600x600x12_10x10_4p63x4p63_1deg_100Hz_CL12cm_NBD_alpha5_spot1p3_preprocessed_NMF_5.hspy')
# The following cells refer to this decomposition as `base_decomposition`, a
# name that was never bound anywhere in the notebook (NameError on a fresh
# kernel). Bind it here so Restart & Run All works.
base_decomposition = decomposition_1

Estimate thresholds

In [None]:
# With no method given, estimate_threshold opens a try_all_threshold comparison
# figure per component and returns None, so a plain loop is used instead of a
# list comprehension that would only display a list of None values.
for component in range(base_decomposition.output_dimension):
    base_decomposition.estimate_threshold(component)

Make phase maps

In [None]:
# Define two phases from components of the first decomposition; `value` is the
# integer written into the phase map where the phase is present.
T1 = DecomposedPhase('T1', base_decomposition, [1, 4], value=3)
thetap_100 = DecomposedPhase('theta_100', base_decomposition, [2, 3], value=1)

# Estimate one threshold per component with the triangle method and store the
# result on the phase (update=True). The returned dicts map component -> threshold.
T1_thresholds = T1.estimate_thresholds(method=threshold_triangle, update=True)
thethap_100_thresholds = thetap_100.estimate_thresholds(method=threshold_triangle, update=True)

print(T1.component_thresholds)
print(thetap_100.component_thresholds)

# Phases are listed in priority order: theta_100 overrides T1 where both are present.
p = PhaseMap([thetap_100, T1])
p.plot_phase_map(cmap=cmap, vmin=0, vmax=4)

Tune threshold values

In [None]:
# Raise the T1 thresholds by 4 % relative to the estimated values. Scaling from
# the stored estimates (T1_thresholds) rather than from the current thresholds
# keeps this cell idempotent: re-running it no longer compounds the factor.
T1_THRESHOLD_SCALE = 1.04
for component, estimated_threshold in T1_thresholds.items():
    T1.update_component(component, T1_THRESHOLD_SCALE*estimated_threshold)
pp = PhaseMap([thetap_100, T1])
pp.plot_phase_map(cmap=cmap, vmin=0, vmax=4)

Make exclusion mask

In [None]:
# NOTE(review): absolute path on a mapped drive; machine-specific.
signalpath = Path(r'Y:\Input\SPED\PhaseMappingPaper\New\SPED_600x600x12_10x10_4p63x4p63_1deg_100Hz_CL12cm_NBD_alpha5_spot1p3_preprocessed.hspy')
# NOTE: this rebinds `signal`, replacing the signal loaded in the SVD section.
# mode='a' presumably opens the file for in-place modification - confirm against the HyperSpy docs.
signal = hs.load(signalpath, lazy=True, mode='a')

# make_mask(add_to=signal) stores each mask in the signal metadata under
# Preprocessing.Masks.Navigation.<phase name>.
T1.make_mask(add_to=signal)
thetap_100.make_mask(add_to=signal)

print(signal.metadata)
# Save back to the same file. write_dataset=False presumably writes the
# metadata without rewriting the data - TODO confirm.
signal.save(signalpath, write_dataset=False, close_file=True)
#signal.close_file()

### 2nd iteration

In [None]:
# Load all decompositions found in this directory, sorted by number of components.
decompositions = load_decompositions(r'nndsvd/navmasked/13051293/')

In [None]:
# Inspect loadings and factors of the decomposition at list index 4.
decompositions[4].plot()

In [None]:
# Keep the decomposition inspected above (list index 4) for the second iteration.
decomposition_2 = decompositions[4]
print(decomposition_2)

In [None]:
# Third phase: component 4 of the second decomposition, written as value 2 in the phase map.
thetap_001 = DecomposedPhase('theta_001', decomposition_2, [4], value=2)

In [None]:
# No method given: opens the try_all_threshold comparison figure for each
# component; the returned thresholds are None and nothing is stored.
thetap_001.estimate_thresholds()

In [None]:
# Estimate the threshold with the isodata method and store it on the phase.
thetap_001.estimate_thresholds(method=threshold_isodata, update=True)

In [None]:
# Show the loading of each component next to the resulting phase image.
thetap_001.plot()

## Make Phase map

In [None]:
# Combine all three phases; list order is the priority (first = highest).
phasemap = PhaseMap([thetap_100, T1, thetap_001])
phasemap.plot_phase_map(cmap=cmap, vmin=0, vmax=4)
# RGBA overview: one colour channel per phase.
phasemap.make_RGBA(normalize_colors=True, normalize_alpha=True)
# Written relative to the current working directory; read again in the Compare section.
phasemap.save(r'NMF_phasemap.hspy')

## Make plots

In [None]:
# The loop reads everything from the metadata of the phase-map signal. `s` was
# never defined in this notebook, so build it from the phase map assembled above.
s = phasemap.as_signal()

axis_size = 6 #inches
dpi = 150
# Use a dedicated name for the greyscale colormap so the notebook-wide `cmap`
# (the phase colours used by the other cells) is not overwritten.
loading_cmap = plt.colormaps.get('Greys')
loading_cmap.set_bad('lightblue')
annotation_style = dict(xycoords='axes fraction', color='w', ha='left', va='top', bbox=dict(facecolor='k', alpha=0.5))


def _save_borderless_image(image, filename, image_cmap):
    """Save `image` as a 3x3 inch PNG without frame, ticks or margins, then close the figure."""
    fig = plt.figure(figsize=(3, 3), frameon=False, dpi=dpi)
    ax = fig.add_axes((0, 0, 1, 1), xticks=[], yticks=[], frameon=False)
    ax.imshow(image, cmap=image_cmap)
    fig.savefig(filename, dpi=dpi)
    plt.close(fig)


# Iterating the metadata node yields (key, value) pairs: phase name and phase signal.
for phase_name, phase in s.metadata.Phases:
    phase_info = phase.metadata.Phase
    factors = phase_info.decomposition.factors
    loadings = phase_info.decomposition.loadings

    _save_borderless_image(phase.data, f"{phase_info.name}_phaseimage.png", "gray_r")

    for component_key, threshold in phase_info.component_thresholds:
        # Keys were stored as str(component); the metadata reports them with a
        # "Number_" prefix, which is stripped to recover the integer index.
        component = int(component_key.replace("Number_", ""))
        print(f"{phase_name}: {component}")
        loading_image = loadings.inav[component].data
        factor_image = factors.inav[component].data

        # Side-by-side overview of loading and factor.
        fig, axes = plt.subplots(nrows=1, ncols=2, subplot_kw={'xticks': [], 'yticks': []}, figsize=(axis_size*2, axis_size), dpi=dpi)
        axes[0].imshow(loading_image, cmap=loading_cmap)
        axes[1].imshow(factor_image, norm=SymLogNorm(0.03), cmap='viridis_r')
        axes[0].annotate(f'Loading {component}', (0.02, 0.98), **annotation_style)
        axes[1].annotate(f'Factor {component}', (0.02, 0.98), **annotation_style)
        fig.tight_layout()
        fig.savefig(f'{phase_info.name}_{component}.png', dpi=dpi)
        plt.close(fig)

        # Stand-alone images of the loading and the factor.
        _save_borderless_image(loading_image, f'{phase_info.name}_loading_{component:03d}.png', loading_cmap)
        _save_borderless_image(factor_image, f'{phase_info.name}_factor_{component:03d}.png', loading_cmap)

## Compare

In [66]:
# `ax_size` was never defined in this notebook (NameError on a fresh kernel).
ax_size = 6  # inches
# Rebuild the phase colormap locally: the plotting cell above rebinds the
# notebook-wide `cmap` to a greyscale map, so its value here depends on which
# cells were run before.
phase_cmap = LinearSegmentedColormap.from_list('gt_cmap', colors, N=len(color_names))

# NOTE(review): absolute, machine-specific path.
gt = hs.load(r'C:\Users\emilc\OneDrive - NTNU\NORTEM\Data\2021_10_06_2xxx_24h_250C\Ground_truth_all.hdf5')
gt.plot(cmap=phase_cmap)

pm = hs.load(r'NMF_phasemap.hspy')
pm.plot(cmap=phase_cmap)

assert gt.data.shape == pm.data.shape, f'Ground truth {gt.data.shape} and phase map {pm.data.shape} differ in shape'

# True wherever the phase map assigns a different value than the ground truth.
difference = np.abs((gt - pm).data) > 0

# Error fraction over all pixels.
percentage_error = np.count_nonzero(difference) / difference.size
print(f'Percentage error of NMF phase map: {percentage_error:.0%}\nSuccessrate: {1-percentage_error:.0%}')

# Error fraction over the pixels with a non-zero ground-truth value only.
percentage_error = np.count_nonzero(difference[gt.data>0]) / np.count_nonzero(gt.data)
print(f'Percentage error of NMF phase map (disregarding ground_truth Al pixels): {percentage_error:.0%} (Successrate: {1-percentage_error:.0%})')

fig = plt.figure(figsize=(ax_size*2, ax_size*2), frameon=False)
ax = fig.add_axes([0, 0, 1, 1], xticks=[], yticks=[], frameon=False)
ax.imshow(difference, cmap='Greys_r')
fig.savefig('GT_error.png')
error_map_signal = hs.signals.Signal2D(difference)
error_map_signal.save('GT_error.hspy')

