# Correcting distortion in SAED data #

## A NOTE BEFORE STARTING ##

Since the ``emicroml`` git repository tracks this notebook under its original
basename ``correcting_distortion_in_saed_data.ipynb``, we recommend that, before
executing any of the notebook cells below, you copy the original notebook and
give the copy a basename that differs from all of the original basenames in the
``<root>/examples`` directory, where ``<root>`` is the root of the ``emicroml``
repository. For example, you could name the copy
``correcting_distortion_in_saed_data_copy.ipynb``. This way, you can explore the
notebook by executing and modifying cells without changing the original
notebook, which git tracks.

## Import necessary modules ##

In [None]:
# For pattern matching.
import re

# For listing the files in a given directory.
import os



# For general array handling.
import numpy as np

# For creating and plotting figures.
import hyperspy.api as hs
import matplotlib.pyplot as plt

# For minimizing objective functions.
import scipy.optimize



# For loading ML models for distortion estimation in CBED.
import emicroml.modelling.cbed.distortion.estimation

In [None]:
%matplotlib ipympl

## Introduction ##

In this notebook, we show how one can use the machine learning (ML) model that
is trained as a result of executing the "action" described in the page [Training
a machine learning
model](https://mrfitzpa.github.io/emicroml/examples/modelling/cbed/distortion/estimation/train_ml_model_set.html)
to correct distortion in selected area electron diffraction (SAED)
data. Strictly speaking, this ML model is trained to estimate distortion in
convergent beam electron diffraction (CBED) patterns. However, by exploiting the
fact that distortions come predominantly from post-specimen lenses,
e.g. projection lenses, we can estimate and correct distortion in SAED data as
follows:

1. Collect the target experimental SAED data;
2. Modify only the pre-specimen lenses to produce CBED data;
3. Use the ML model to estimate the distortion field in the CBED data;
4. Correct distortion in the SAED data using the distortion field from step 3.

We demonstrate steps 3 and 4 using pre-collected experimental SAED and CBED
patterns of a calibration sample of single-crystal Au oriented in the \[100\]
direction. This experimental data was collected on a modified Hitachi SU9000
scanning electron microscope operated at 20 keV.

You can find the documentation for the ``emicroml`` library
[here](https://mrfitzpa.github.io/emicroml/_autosummary/emicroml.html). We
recommend that you consult this documentation as you explore the notebook.
Moreover, you should execute the cells in the order that they appear, i.e. from
top to bottom, as some cells reference variables that are set in cells above
them.

## Loading and visualizing the SAED and CBED patterns ##

Let's load and visualize the target SAED pattern:

In [None]:
path_to_data_dir = "../data"
filename = (path_to_data_dir 
            + "/for_demo_of_distortion_correction_in_saed_data"
            + "/distorted_saed_pattern.npy")

kwargs = {"file": filename}
distorted_saed_pattern_image = np.load(**kwargs)

kwargs = {"data": distorted_saed_pattern_image}
distorted_saed_pattern_signal = hs.signals.Signal2D(**kwargs)



kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
distorted_saed_pattern_signal.plot(**kwargs)

This SAED pattern is subject to optical distortion, which we want to correct.
To do this, we keep the sample in place and modify only the pre-specimen lenses
to produce a CBED pattern. Since the post-specimen lenses are unchanged, this
CBED pattern should be subject to approximately the same distortion as the SAED
pattern.

Let's load and visualize the target CBED pattern:

In [None]:
filename = (path_to_data_dir 
            + "/for_demo_of_distortion_correction_in_saed_data"
            + "/distorted_cbed_pattern.npy")

kwargs = {"file": filename}
distorted_cbed_pattern_image = np.load(**kwargs)

kwargs = {"data": distorted_cbed_pattern_image}
distorted_cbed_pattern_signal = hs.signals.Signal2D(**kwargs)



kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
distorted_cbed_pattern_signal.plot(**kwargs)

## Estimating the distortion in the CBED pattern ##

Now let's load our ML model so that we can estimate the distortion in the CBED
pattern:

In [None]:
path_to_ml_model_state_dicts = path_to_data_dir + "/ml_models/ml_model_0"
pattern = r"ml_model_at_lr_step_[0-9]+\.pth"
# Compare the learning rate step indices as integers rather than as strings, so
# that e.g. step 10 ranks above step 9.
largest_lr_step_idx = max([int(name.split("_")[-1].split(".")[0])
                           for name in os.listdir(path_to_ml_model_state_dicts)
                           if re.fullmatch(pattern, name)])

ml_model_state_dict_filename = \
    (path_to_ml_model_state_dicts
     + "/ml_model_at_lr_step_{}.pth".format(largest_lr_step_idx))



module_alias = emicroml.modelling.cbed.distortion.estimation
kwargs = {"ml_model_state_dict_filename": ml_model_state_dict_filename,
          "device_name": None}  # Default to CUDA device if available.
ml_model = module_alias.load_ml_model_from_file(**kwargs)

_ = ml_model.eval()
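Selecting the checkpoint with the largest learning rate step index requires comparing the indices as integers: compared as strings, ``"9"`` ranks above ``"10"``. The sketch below isolates that selection logic in a hypothetical helper, ``latest_checkpoint_filename``, which assumes checkpoints are named ``ml_model_at_lr_step_<integer>.pth`` as in the cell above and operates on a plain list of filenames rather than on a directory.

```python
import re


def latest_checkpoint_filename(filenames):
    """Return the checkpoint filename with the largest learning rate step index.

    Hypothetical helper: assumes checkpoint files are named
    "ml_model_at_lr_step_<integer>.pth".
    """
    pattern = re.compile(r"ml_model_at_lr_step_([0-9]+)\.pth")
    step_idx_to_filename = {}
    for name in filenames:
        match = pattern.fullmatch(name)
        if match is not None:
            # Store the step index as an integer so that, e.g., 10 > 9.
            step_idx_to_filename[int(match.group(1))] = name
    if not step_idx_to_filename:
        raise ValueError("No checkpoint files found.")
    return step_idx_to_filename[max(step_idx_to_filename)]


# Hypothetical filenames; in the notebook they would come from os.listdir().
filenames = ["ml_model_at_lr_step_2.pth",
             "ml_model_at_lr_step_10.pth",
             "ml_model_at_lr_step_9.pth",
             "notes.txt"]

print(latest_checkpoint_filename(filenames))  # ml_model_at_lr_step_10.pth
print(max(["2", "10", "9"]))  # 9 (string comparison picks the wrong step)
```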

With the ML model loaded, let's estimate the distortion in the CBED pattern:

In [None]:
sampling_grid_dims_in_pixels = distorted_cbed_pattern_image.shape
distorted_cbed_pattern_images = distorted_cbed_pattern_image[None, :, :]

kwargs = {"cbed_pattern_images": distorted_cbed_pattern_images,
          "sampling_grid_dims_in_pixels": sampling_grid_dims_in_pixels}
distortion_models = ml_model.predict_distortion_models(**kwargs)

distortion_model = distortion_models[0]

Note that any input distorted CBED pattern must have image dimensions, in units
of pixels, equal to
``2*(ml_model.core_attrs["num_pixels_across_each_cbed_pattern"],)``. This is
because a given ML model is trained for images of fixed dimensions, in units of
pixels.
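A simple guard against passing an image of the wrong size is to check its shape before calling ``predict_distortion_models``. The sketch below uses a hypothetical helper, ``check_cbed_image_shape``, and an illustrative size of 512 pixels; in the notebook, the expected size would instead come from ``ml_model.core_attrs["num_pixels_across_each_cbed_pattern"]``. It also shows the leading batch axis that ``[None, :, :]`` adds in the cell above.

```python
import numpy as np


def check_cbed_image_shape(image, num_pixels_across):
    """Raise ValueError unless image has shape (num_pixels_across,)*2.

    Hypothetical helper, not part of emicroml.
    """
    expected_shape = 2*(num_pixels_across,)
    if image.shape != expected_shape:
        raise ValueError("Expected image shape {}, got {}.".format(
            expected_shape, image.shape))


# Illustrative 512 x 512 pattern; 512 is an assumed size, not the model's.
image = np.zeros((512, 512))
check_cbed_image_shape(image, 512)

# The model takes a batch of images, hence the leading axis of length 1.
batch = image[None, :, :]
print(batch.shape)  # (1, 512, 512)
```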

Let's visualize the predicted distortion field:

In [None]:
slice_step = 16



quiver_kwargs = {"angles": "uv",
                 "pivot": "middle",
                 "scale_units": "width"}



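# Retrieve the grid of fractional coordinates at which the flow field is
# sampled, and subsample it so that the quiver plot stays legible.
# ``numpy(force=True)`` also works if the tensors live on a GPU.
attr_name = "sampling_grid"
sampling_grid = getattr(distortion_model, attr_name)
sampling_grid = (sampling_grid[0].numpy(force=True),
                 sampling_grid[1].numpy(force=True))

X = sampling_grid[0][::slice_step, ::slice_step]
Y = sampling_grid[1][::slice_step, ::slice_step]



fig, ax = plt.subplots()

attr_name = "flow_field_of_coord_transform"
flow_field = getattr(distortion_model, attr_name)
flow_field = (flow_field[0].numpy(force=True),
              flow_field[1].numpy(force=True))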

U = flow_field[0][::slice_step, ::slice_step]
V = flow_field[1][::slice_step, ::slice_step]

kwargs = quiver_kwargs
ax.quiver(X, Y, U, V, **kwargs)
ax.set_title("Flow Field Of Coordinate Transformation")
ax.set_xlabel("fractional horizontal coordinate")
ax.set_ylabel("fractional vertical coordinate")

plt.gca().set_aspect('equal')
plt.tight_layout()
plt.show()

## Correcting the distortion in the SAED pattern ##

Let's use the predicted distortion model to correct the distortion in the SAED
pattern:

In [None]:
kwargs = \
    {"distorted_images": distorted_saed_pattern_image[None, None, :, :]}
undistorted_then_resampled_images = \
    distortion_model.undistort_then_resample_images(**kwargs)

undistorted_saed_pattern_image = \
    undistorted_then_resampled_images[0, 0].numpy(force=True)



kwargs = {"data": undistorted_saed_pattern_image}
undistorted_saed_pattern_signal = hs.signals.Signal2D(**kwargs)



kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
undistorted_saed_pattern_signal.plot(**kwargs)
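This notebook does not show the internals of ``undistort_then_resample_images``. For intuition only, the sketch below illustrates a generic way to correct a distorted image given a coordinate map: each pixel of the corrected image takes the interpolated value of the distorted image at the location where that pixel ended up under the distortion. It uses ``scipy.ndimage.map_coordinates`` with bilinear interpolation and a toy "distortion" that is a pure shift. This is not necessarily how ``emicroml`` implements the correction, and the helper name ``resample_with_coordinate_map`` is hypothetical.

```python
import numpy as np
import scipy.ndimage


def resample_with_coordinate_map(distorted_image, rows_in_distorted,
                                 cols_in_distorted):
    """Build a corrected image by sampling the distorted image.

    Pixel (i, j) of the output takes the bilinearly interpolated value of the
    distorted image at (rows_in_distorted[i, j], cols_in_distorted[i, j]).
    Locations outside the distorted image are filled with zeros.
    """
    coordinates = np.stack([rows_in_distorted, cols_in_distorted])
    return scipy.ndimage.map_coordinates(distorted_image, coordinates,
                                         order=1, mode="constant", cval=0.0)


# Toy example: the "distortion" shifted the pattern 2 pixels to the right, so
# corrected pixel (i, j) is found at (i, j + 2) in the distorted image.
N = 8
distorted = np.zeros((N, N))
distorted[3, 5] = 1.0  # A single bright "reflection".
rows, cols = np.meshgrid(np.arange(N), np.arange(N), indexing="ij")
corrected = resample_with_coordinate_map(distorted, rows, cols + 2.0)
print(np.argwhere(np.isclose(corrected, 1.0)))  # [[3 3]]
```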

## Assessing the accuracy of the distortion correction ##

The calibration sample is single-crystal Au oriented along the \[100\]
direction. As such, in the absence of distortion, the zero-order Laue zone
(ZOLZ) reflections should lie on a square lattice. Therefore, to assess the
accuracy of the distortion correction, we can fit a square lattice to the ZOLZ
reflections of both the distorted and the undistorted (i.e. corrected) SAED
patterns, and compare the errors of the fits.
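In the fits below, the square lattice is parameterized by its origin $\mathbf{u}_O = (u_{O,x}, u_{O,y})$ in pixel coordinates, its lattice spacing $b$ in pixels, and a rotation angle $\theta$. Written out, its points are

$$
\mathbf{b}_1 = b\,(\cos\theta, \sin\theta), \qquad
\mathbf{b}_2 = b\,(-\sin\theta, \cos\theta), \qquad
\mathbf{r}_{m_1, m_2} = \mathbf{u}_O + m_1\mathbf{b}_1 + m_2\mathbf{b}_2,
$$

with $m_1, m_2 \in \mathbb{Z}$. Since $\mathbf{b}_1 \cdot \mathbf{b}_2 = 0$ and $|\mathbf{b}_1| = |\mathbf{b}_2| = b$, these points form a square lattice.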

The first step is to locate the ZOLZ reflections that are sufficiently visible
in the SAED patterns. We can do this by applying masks, peak-finding algorithms,
and manual curation. Let's do this for the distorted SAED pattern:

In [None]:
N_y, N_x = sampling_grid_dims_in_pixels



L = 25
R = N_x-335
B = N_y-460
T = 120

rectangular_mask_image = np.zeros((N_y, N_x), dtype=bool)
rectangular_mask_image[T:N_y-B, L:N_x-R] = True

rectangular_mask_signal = hs.signals.Signal2D(data=rectangular_mask_image)



masked_distorted_saed_pattern_signal = (distorted_saed_pattern_signal
                                        * rectangular_mask_signal)

kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
distorted_saed_pattern_signal.plot(**kwargs)

kwargs = {"method": "difference_of_gaussian", 
          "overlap": 0, 
          "threshold": 0.0025, 
          "min_sigma": 1,
          "max_sigma": 2,
          "interactive": False, 
          "show_progressbar": False}
find_peaks_result = masked_distorted_saed_pattern_signal.find_peaks(**kwargs)
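# Reverse each found peak from (row, column) order to (x, y) order, which is
# the order expected by the marker offsets.
candidate_peak_locations = find_peaks_result.data[0][:, ::-1].tolist()

# Manually add ZOLZ reflections that the peak finder missed.
candidate_peak_locations += ([335.5, 417.5],
                             [269.5, 112.5],
                             [342.0, 160.0],
                             [354.0, 218.0],
                             [368.0, 275.5],
                             [384.0, 341.5])



# Manually discard (None) or relocate specific candidate peaks, keyed by the
# string representation of their [x, y] locations.
curation_instructions = {"[287, 162]": None,
                         "[236, 242]": [235.5, 241],
                         "[262, 357]": [261.5, 357.5],
                         "[248, 298]": [248.5, 297.5]}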

zolz_reflections = tuple()
for candidate_peak_location in candidate_peak_locations:
    candidate_peak_location_as_str = str(candidate_peak_location)
    if candidate_peak_location_as_str in curation_instructions:
        key = candidate_peak_location_as_str
        if curation_instructions[key] is None:
            continue
        else:
            zolz_reflection = curation_instructions[key]
    else:
        zolz_reflection = candidate_peak_location
    
    kwargs = {"color": "black", 
              "sizes": 3, 
              "offsets": zolz_reflection}
    marker = hs.plot.markers.Points(**kwargs)
    distorted_saed_pattern_signal.add_marker(marker, permanent=False)

    zolz_reflections += (zolz_reflection,)

zolz_reflections_of_distorted_saed_pattern = zolz_reflections
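The curation step in the cell above can be expressed as a small standalone function. The hypothetical ``curate_peaks`` helper below follows the same convention: keys are the string forms of candidate locations, a value of ``None`` discards that candidate, and a list replaces it. Because the keys are strings, a candidate matches only if its ``str`` representation is exactly the key; for instance, a key written as ``"[287.0, 162.0]"`` would not match the candidate ``[287, 162]``.

```python
def curate_peaks(candidate_peak_locations, curation_instructions):
    """Apply manual curation to candidate peak locations.

    Hypothetical helper. curation_instructions maps str([x, y]) of a
    candidate to either None (discard that candidate) or a replacement
    location [x, y]. Candidates not in the mapping are kept unchanged.
    """
    curated = []
    for location in candidate_peak_locations:
        key = str(location)
        if key not in curation_instructions:
            curated.append(location)
        elif curation_instructions[key] is not None:
            curated.append(curation_instructions[key])
    return tuple(curated)


candidates = [[287, 162], [236, 242], [100, 100]]
instructions = {"[287, 162]": None, "[236, 242]": [235.5, 241]}
print(curate_peaks(candidates, instructions))  # ([235.5, 241], [100, 100])
```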

Now, let's do this for the undistorted SAED pattern:

In [None]:
L = 50
R = N_x-330
B = N_y-440
T = 129

rectangular_mask_image = np.zeros((N_y, N_x), dtype=bool)
rectangular_mask_image[T:N_y-B, L:N_x-R] = True

rectangular_mask_signal = hs.signals.Signal2D(data=rectangular_mask_image)



masked_undistorted_saed_pattern_signal = (undistorted_saed_pattern_signal
                                          * rectangular_mask_signal)

kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
undistorted_saed_pattern_signal.plot(**kwargs)

kwargs = {"method": "difference_of_gaussian", 
          "overlap": 0, 
          "threshold": 0.0025, 
          "min_sigma": 1,
          "max_sigma": 2,
          "interactive": False, 
          "show_progressbar": False}
find_peaks_result = masked_undistorted_saed_pattern_signal.find_peaks(**kwargs)
candidate_peak_locations = find_peaks_result.data[0][:, ::-1].tolist()

candidate_peak_locations += ([334.5, 402.5],
                             [261.5, 119.5],
                             [332.0, 161.0],
                             [347.5, 217.5],
                             [362.5, 274.0],
                             [377.5, 332.0])



curation_instructions = {"[282, 164]": None,
                         "[235, 242]": [235, 241], 
                         "[263, 356]": [262.5, 355.5], 
                         "[248, 298]": [248.5, 298], 
                         "[305, 287]": [305.5, 286.5], 
                         "[137, 321]": [136.5, 320.5]}

zolz_reflections = tuple()
for candidate_peak_location in candidate_peak_locations:
    candidate_peak_location_as_str = str(candidate_peak_location)
    if candidate_peak_location_as_str in curation_instructions:
        key = candidate_peak_location_as_str
        if curation_instructions[key] is None:
            continue
        else:
            zolz_reflection = curation_instructions[key]
    else:
        zolz_reflection = candidate_peak_location
    
    kwargs = {"color": "black", 
              "sizes": 3, 
              "offsets": zolz_reflection}
    marker = hs.plot.markers.Points(**kwargs)
    undistorted_saed_pattern_signal.add_marker(marker, permanent=False)

    zolz_reflections += (zolz_reflection,)

zolz_reflections_of_undistorted_saed_pattern = zolz_reflections
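Both cells above define the rectangular window through the margins ``L``, ``R``, ``B`` and ``T``, with ``R`` and ``B`` set to ``N_x`` and ``N_y`` minus the stop column and stop row of the window. A hypothetical helper that takes the window bounds directly may be easier to read. The sketch below reproduces the window used for the distorted SAED pattern; the 512 x 512 image size is an assumption for illustration only, and the count of masked pixels does not depend on it as long as the window fits.

```python
import numpy as np


def rectangular_mask(shape, row_start, row_stop, col_start, col_stop):
    """Boolean mask that is True inside rows [row_start, row_stop) and
    columns [col_start, col_stop), and False elsewhere (hypothetical helper).
    """
    mask = np.zeros(shape, dtype=bool)
    mask[row_start:row_stop, col_start:col_stop] = True
    return mask


# With T = 120, N_y-B = 460, L = 25 and N_x-R = 335, the kept window spans
# rows 120 to 459 and columns 25 to 334.
mask = rectangular_mask((512, 512), 120, 460, 25, 335)
print(mask.sum())  # 105400
```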

Now we need to perform the fits. The objective function that we will minimize
is the mean of the squared Euclidean distances between the ZOLZ reflections and
their nearest points on the square lattice fit, normalized by the image width:
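Written out, with $n$ ZOLZ reflections $\mathbf{k}_i = (k_{x,i}, k_{y,i})$ and $N$ the image width in pixels, the objective computed by the cell below is

$$
\alpha_i = \frac{(k_{x,i}-u_{O,x})\cos\theta + (k_{y,i}-u_{O,y})\sin\theta}{b},
\qquad
\beta_i = \frac{-(k_{x,i}-u_{O,x})\sin\theta + (k_{y,i}-u_{O,y})\cos\theta}{b},
$$

$$
f(u_{O,x}, u_{O,y}, b, \theta) = \frac{1}{n}\left(\frac{b}{N}\right)^2
\sum_{i=1}^{n}\left[\left(\alpha_i - \operatorname{round}(\alpha_i)\right)^2
+ \left(\beta_i - \operatorname{round}(\beta_i)\right)^2\right].
$$

Here $\alpha_i$ and $\beta_i$ are the coordinates of the $i$th reflection along the two primitive lattice vectors, in units of $b$. For a square lattice, rounding each coordinate to the nearest integer gives the nearest lattice point, so each bracketed term is the squared distance from the reflection to its nearest lattice point, in units of $b^2$.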

In [None]:
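def objective(x, zolz_reflections):
    u_O_x, u_O_y, b, theta = x

    # u_O_x: horizontal pixel coordinate of the origin of the lattice fit.
    # u_O_y: vertical pixel coordinate of the origin of the lattice fit.
    # b: length of the primitive lattice vectors, in pixels.
    # theta: rotation angle applied to the lattice, in radians.

    # Image size in pixels, used to express the result in units of the image
    # width (the images here are square, so N_y equals N_x).
    N = N_y

    result = 0.0

    for (k_x, k_y) in zolz_reflections:
        # Coordinate of the reflection along the first primitive lattice
        # vector, in units of b, and its offset from the nearest integer.
        to_round = ((k_x-u_O_x)*np.cos(theta) + (k_y-u_O_y)*np.sin(theta)) / b
        rounded = np.round(to_round)
        result += (to_round-rounded)**2

        # Same for the second primitive lattice vector.
        to_round = (-(k_x-u_O_x)*np.sin(theta) + (k_y-u_O_y)*np.cos(theta)) / b
        rounded = np.round(to_round)
        result += (to_round-rounded)**2

    # Convert from units of b**2 to units of N**2, and average over the
    # reflections.
    result *= ((b/N)**2) / len(zolz_reflections)

    return result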

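We define the fitting error to be the square root of the objective function,
i.e. the root-mean-square distance between the ZOLZ reflections and their
nearest lattice points, in units of the image width.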
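The loop in ``objective`` can also be written with NumPy array operations. The sketch below assumes the same parameter conventions; the name ``objective_vectorized`` and its explicit ``N`` argument (instead of the global ``N_y``) are our own choices, not part of the notebook. A synthetic lattice provides a sanity check: reflections placed exactly on lattice points give an error of zero.

```python
import numpy as np


def objective_vectorized(x, zolz_reflections, N):
    """Vectorized equivalent of `objective`, with the image width N in pixels
    passed explicitly."""
    u_O_x, u_O_y, b, theta = x
    k = np.asarray(zolz_reflections, dtype=float)
    dx = k[:, 0] - u_O_x
    dy = k[:, 1] - u_O_y
    # Coordinates of each reflection along the two primitive lattice vectors,
    # in units of the lattice spacing b.
    alpha = (dx*np.cos(theta) + dy*np.sin(theta)) / b
    beta = (-dx*np.sin(theta) + dy*np.cos(theta)) / b
    # Squared distance to the nearest lattice point, in units of b**2.
    sq_dist = (alpha-np.round(alpha))**2 + (beta-np.round(beta))**2
    return (b/N)**2 * sq_dist.mean()


# Synthetic check with illustrative lattice parameters (not fitted values).
N = 512
x_true = (240.0, 250.0, 40.0, 0.3)
u_O_x, u_O_y, b, theta = x_true
b_1 = b*np.array((np.cos(theta), np.sin(theta)))
b_2 = b*np.array((-np.sin(theta), np.cos(theta)))
points = [np.array((u_O_x, u_O_y)) + m_1*b_1 + m_2*b_2
          for m_1 in range(-2, 3) for m_2 in range(-2, 3)]
print(np.sqrt(objective_vectorized(x_true, points, N)) < 1e-9)  # True
```

As a second check, a single reflection 1 pixel away from the nearest lattice point in a 100-pixel-wide image gives an error of 1/100 of the image width.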

Let's fit the ZOLZ reflections of the distorted SAED pattern:

In [None]:
zolz_reflections = np.array(zolz_reflections_of_distorted_saed_pattern)
differences = zolz_reflections[1:]-zolz_reflections[0]



u_O_x_guess = 235.5
u_O_y_guess = 241
b_guess = np.linalg.norm(differences, axis=(1,)).min().item()
theta_guess = np.arctan2(299-u_O_y_guess, 248-u_O_x_guess)

initial_guesses = (u_O_x_guess,
                   u_O_y_guess,
                   b_guess,
                   theta_guess)



u_O_x_bounds = (25, 335)
u_O_y_bounds = (120, 460)
b_bounds = (b_guess-5, b_guess+5)
theta_bounds = (0.9*theta_guess, 1.1*theta_guess)

bounds = (u_O_x_bounds,
          u_O_y_bounds,
          b_bounds,
          theta_bounds)



kwargs = {"fun": objective,
          "args": (zolz_reflections,),
          "x0": initial_guesses,
          "bounds": bounds}
minimization_result = scipy.optimize.minimize(**kwargs)



u_O_x, u_O_y, b, theta = minimization_result.x

u_O = np.array((u_O_x, u_O_y))
b_1 = b*np.array((np.cos(theta), np.sin(theta)))
b_2 = b*np.array((-np.sin(theta), np.cos(theta)))

M = 3

lattice_positions = tuple()
for m_1 in range(-M, M+1):
    for m_2 in range(-M, M+1):
        lattice_position = (u_O + m_1*b_1 + m_2*b_2).tolist()
        lattice_positions += (lattice_position,)

        

kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
distorted_saed_pattern_signal.plot(**kwargs)

for lattice_position in lattice_positions:
    kwargs = {"color": "black", 
              "sizes": 3, 
              "offsets": lattice_position}
    marker = hs.plot.markers.Points(**kwargs)
    distorted_saed_pattern_signal.add_marker(marker, permanent=False)



error = np.sqrt(minimization_result.fun)
msg = "The error of the fit is: {}, in units of the image width.".format(error)
print(msg)

The black dots in the figure directly above mark the points of the best-fit
square lattice.
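The nested loops over ``m_1`` and ``m_2`` that generate the lattice points can be replaced by broadcasting. The sketch below uses a hypothetical helper, ``square_lattice_positions``, with illustrative parameter values rather than fitted ones; with ``indexing="ij"``, the points come out in the same order as the nested loops.

```python
import numpy as np


def square_lattice_positions(u_O, b, theta, M):
    """Return the (2M+1)**2 points u_O + m_1*b_1 + m_2*b_2, with m_1 and m_2
    ranging over -M, ..., M, as an array of shape ((2M+1)**2, 2)."""
    b_1 = b*np.array((np.cos(theta), np.sin(theta)))
    b_2 = b*np.array((-np.sin(theta), np.cos(theta)))
    m = np.arange(-M, M+1)
    m_1, m_2 = np.meshgrid(m, m, indexing="ij")
    return (np.asarray(u_O, dtype=float)
            + m_1.reshape(-1, 1)*b_1
            + m_2.reshape(-1, 1)*b_2)


positions = square_lattice_positions((240.0, 240.0), 50.0, 0.25, M=3)
print(positions.shape)  # (49, 2)
```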

Now let's do the same fitting procedure for the undistorted SAED pattern:

In [None]:
zolz_reflections = np.array(zolz_reflections_of_undistorted_saed_pattern)
differences = zolz_reflections[1:]-zolz_reflections[0]



u_O_x_guess = 235
u_O_y_guess = 241
b_guess = np.linalg.norm(differences, axis=(1,)).min().item()
theta_guess = np.arctan2(298-u_O_y_guess, 249-u_O_x_guess)

initial_guesses = (u_O_x_guess,
                   u_O_y_guess,
                   b_guess,
                   theta_guess)



u_O_x_bounds = (50, 330)
u_O_y_bounds = (129, 440)
b_bounds = (b_guess-5, b_guess+5)
theta_bounds = (0.9*theta_guess, 1.1*theta_guess)

bounds = (u_O_x_bounds,
          u_O_y_bounds,
          b_bounds,
          theta_bounds)



kwargs = {"fun": objective,
          "args": (zolz_reflections,),
          "x0": initial_guesses,
          "bounds": bounds}
minimization_result = scipy.optimize.minimize(**kwargs)



u_O_x, u_O_y, b, theta = minimization_result.x

u_O = np.array((u_O_x, u_O_y))
b_1 = b*np.array((np.cos(theta), np.sin(theta)))
b_2 = b*np.array((-np.sin(theta), np.cos(theta)))

M = 3

lattice_positions = tuple()
for m_1 in range(-M, M+1):
    for m_2 in range(-M, M+1):
        lattice_position = (u_O + m_1*b_1 + m_2*b_2).tolist()
        lattice_positions += (lattice_position,)

        

kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
undistorted_saed_pattern_signal.plot(**kwargs)

for lattice_position in lattice_positions:
    kwargs = {"color": "black", 
              "sizes": 3, 
              "offsets": lattice_position}
    marker = hs.plot.markers.Points(**kwargs)
    undistorted_saed_pattern_signal.add_marker(marker, permanent=False)



error = np.sqrt(minimization_result.fun)
msg = "The error of the fit is: {}, in units of the image width.".format(error)
print(msg)

As we can see both visually and from the lattice fit errors, which are smaller
for the undistorted SAED pattern than for the distorted one, our ML approach
corrects an appreciable amount of the distortion in the SAED pattern.
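One way to quantify the improvement is the relative reduction in fit error. Note that both fitting cells assign their result to ``error``, so computing this in the notebook would first require saving each error under its own name. The sketch below uses a hypothetical helper and illustrative numbers only, not results from this notebook.

```python
def relative_error_reduction(error_distorted, error_corrected):
    """Fraction by which the lattice fit error drops after correction.

    A value of 0 means no improvement; a value of 1 means the corrected
    pattern is fit perfectly by a square lattice.
    """
    return (error_distorted-error_corrected) / error_distorted


# Illustrative numbers only; substitute the two fit errors printed above.
reduction = relative_error_reduction(0.004, 0.001)
print("Relative reduction in fit error: {:.0%}".format(reduction))
```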