In [5]:
%matplotlib tk

import os

import diffpy.structure
import hyperspy as hs
import numpy as np
import pyxem as px
import tensorflow as tf
from diffsims import generators
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LogNorm, Normalize
from tqdm import tqdm

### Load phase maps and ground truth

In [6]:
# Reference labelling that every predicted phase map is compared against.
ground_truth = hs.io.load(r'.\Final_outputs\ground_truth\ground_truth_combined.hdf5')

# Predicted phase maps for Dataset A; the order of the paths matches the
# order of the names they are unpacked into below.
prediction_files = (
    r'.\Final_outputs\phasemaps\DatasetA\pred-NMF.hdf5',
    r'.\Final_outputs\phasemaps\DatasetA\pred-vector.hdf5',
    r'.\Final_outputs\phasemaps\DatasetA\pred-template-dog.hdf5',
    r'.\Final_outputs\phasemaps\DatasetA\pred-ANN.hdf5',
)
phase_map_nmf, phase_map_vector, phase_map_template, phase_map_ANN = (
    hs.io.load(path) for path in prediction_files
)

### Estimate the error

In [7]:
def mislabel_fraction(prediction, truth):
    """Return the fraction of pixels where `prediction` disagrees with `truth`.

    Parameters
    ----------
    prediction, truth :
        Objects exposing their pixel labels as a NumPy array through `.data`
        (here: the loaded phase maps and the ground truth).

    Returns
    -------
    float
        Number of differing pixels divided by the total number of pixels in
        `truth`.

    Raises
    ------
    ValueError
        If the two label arrays do not have the same shape.
    """
    predicted_labels = prediction.data
    true_labels = truth.data
    if predicted_labels.shape != true_labels.shape:
        raise ValueError(
            f'Shape mismatch: prediction {predicted_labels.shape} '
            f'vs ground truth {true_labels.shape}'
        )
    # The denominator comes from the data itself instead of a hard-coded
    # 512*512, so the estimate stays correct for maps of any size.
    return np.count_nonzero(predicted_labels != true_labels) / true_labels.size


error_nmf = mislabel_fraction(phase_map_nmf, ground_truth)
print(f'NMF mislabels {error_nmf:.2%} of the pixels.')

error_vector = mislabel_fraction(phase_map_vector, ground_truth)
print(f'Vector analysis mislabels {error_vector:.2%} of the pixels.')

error_template = mislabel_fraction(phase_map_template, ground_truth)
print(f'Template matching mislabels {error_template:.2%} of the pixels.')

error_ANN = mislabel_fraction(phase_map_ANN, ground_truth)
print(f'ANN mislabels {error_ANN:.2%} of the pixels.')

NMF mislabels 1.34% of the pixels.
Vector analysis mislabels 1.54% of the pixels.
Template matching mislabels 1.76% of the pixels.
ANN mislabels 0.99% of the pixels.


### Plot the difference maps

In [8]:
# Boolean masks, one per method: True where the predicted label equals the
# ground-truth label and False where it differs (despite the "diff" name,
# these mark agreement).
truth_labels = ground_truth.data
diff_nmf = np.equal(truth_labels, phase_map_nmf.data)
diff_vector = np.equal(truth_labels, phase_map_vector.data)
diff_template = np.equal(truth_labels, phase_map_template.data)
diff_ANN = np.equal(truth_labels, phase_map_ANN.data)

In [9]:
# Select which agreement mask to draw and where the PNG is written.
img = diff_ANN
file_name = '/diff-ANN.png'
directory = r'.\Final_outputs\phasemaps\DatasetA'

fig, ax = plt.subplots()
ax.axis('off')  # image only: no ticks, labels or frame
ax.imshow(img, cmap='Greys')

# Write the file BEFORE showing the figure and address the figure explicitly.
# On blocking or inline backends plt.show() releases the current figure, so a
# later plt.savefig() would write an empty canvas.
fig.savefig(directory + file_name, transparent=True, bbox_inches='tight', pad_inches=0, dpi=300)
plt.show()

### Plot phase maps

In [None]:
from matplotlib.colors import LinearSegmentedColormap, to_rgba

# Named colours, indexed by integer phase label, and their RGBA equivalents.
color_names = ['linen', 'darkorange', 'dodgerblue', 'forestgreen']
colors = list(map(to_rgba, color_names))

# Quantise the colour map to as many levels as there are colours.
cmap = LinearSegmentedColormap.from_list('gt_cmap', colors, N=len(colors))

In [None]:
# --- Settings for this figure -------------------------------------------------
img = phase_map_ANN.data
# NOTE(review): the image above is the ANN phase map, but the output name says
# "nmf". Confirm which one is intended before running, otherwise this cell
# writes the ANN map under the NMF file name.
file_name = '/pred_nmf.png'
directory = r'.\Final_outputs\phasemaps\DatasetA'
cmap_inplot = False  # draw a horizontal colour bar with the phase names
cmap_gt = True       # colour each pixel by its integer label using `colors`
scalebar = False     # add the `scalebar_` artist to the axes

if cmap_gt:
    # Build an RGBA image with the same pixel grid as `img` (no hard-coded
    # 512x512). Labels without an entry in `colors` keep the all-zero RGBA
    # value, i.e. they stay fully transparent.
    image = np.zeros(img.shape + (4,))
    for label, rgba in enumerate(colors):
        image[img == label] = rgba
else:
    image = img

fig, ax = plt.subplots()
ax.axis('off')
if cmap_gt:
    # `image` is already RGBA, so no colour map is needed (imshow ignores one
    # for RGB(A) input).
    ax.imshow(image)
else:
    ax.imshow(image, cmap='magma_r', norm=LogNorm())
if scalebar:
    # NOTE(review): `scalebar_` is not defined in the cells shown here; it must
    # be created in an earlier cell before enabling `scalebar`.
    ax.add_artist(scalebar_)

if cmap_inplot:
    n_colors = len(color_names)
    cmappable = ScalarMappable(norm=Normalize(0, 1), cmap=cmap)
    # Pass the axes explicitly: the mappable is not attached to any axes, so
    # matplotlib cannot infer where to place the colour bar.
    cbar = fig.colorbar(cmappable, ax=ax, orientation='horizontal', fraction=0.043, pad=0.04)

    # One tick at the centre of each colour band. The number of ticks follows
    # the number of colours so ticks and bands stay aligned; labels beyond the
    # available colours (e.g. 'Not indexed') have no band and are dropped.
    phase_labels = ['Al', r"$\theta'_{\langle100\rangle}$", r"$\theta'_{[001]}$", 'T1', 'Not indexed']
    tick_locs = (np.arange(n_colors) + 0.5) / n_colors
    cbar.set_ticks(tick_locs)
    cbar.set_ticklabels(phase_labels[:n_colors])

    # The bar is horizontal, so its labels sit on the x axis; tick_params sets
    # the font size regardless of orientation.
    cbar.ax.tick_params(labelsize=14)

# Write the file BEFORE showing the figure and address the figure explicitly.
# On blocking or inline backends plt.show() releases the current figure, so a
# later plt.savefig() would write an empty canvas.
fig.savefig(directory + file_name, transparent=True, bbox_inches='tight', pad_inches=0, dpi=300)
plt.show()
