# Analysis of a TEM image of a grain boundary
* <a href="#download">Downloading data</a>
* <a href="#open">Exploring the data</a>
* <a href="#analysis">Analysis of the dataset</a>

### <a id="download"></a>1. Downloading the data if necessary and checking the hash
We do this in bash because it is a bit simpler than in Python, and the same commands could also be run in a shell outside the Jupyter notebook.

In [None]:
%%bash

# create the data folder in the current folder if it doesn't exist
mkdir -p data

if [ ! -f data/dataset.emd ]; then
    echo "Data not yet present, downloading"
    wget https://owncloud.gwdg.de/index.php/s/utJfj0388mp8W1S/download -O data/dataset.emd
else
    echo "Data already present"
fi

It's also important to check the hash of the downloaded file to verify that it is the file we expect and that it was not corrupted or tampered with.

In [None]:
%%bash

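# the checksum is the first field of sha256sum's output; quoting "$hash" keeps the test valid even if the file is missing
hash=$(sha256sum data/dataset.emd | awk '{print $1}')
if [ "$hash" = "777a5f480c5b10288bb9c83c12be440408e7dd25620adc56b5fad27bbfc65d05" ]; then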
    echo "Hash is ok"
else
    echo "Wrong hash! This may be the wrong file, or it may have been tampered with."
fi
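
The same check can also be done from Python with the standard-library `hashlib` module. The minimal sketch below reads the file in chunks so that large datasets do not have to fit in memory; `sha256_of_file` is a hypothetical helper name of ours, not part of any library.

In [None]:
```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in chunks of chunk_size bytes."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "777a5f480c5b10288bb9c83c12be440408e7dd25620adc56b5fad27bbfc65d05"
# usage (assumes the file was downloaded to data/dataset.emd as above):
# print("Hash is ok" if sha256_of_file("data/dataset.emd") == expected else "Wrong hash!")
```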

### <a id="open"></a> 2. Opening and exploring the data
We will work with an `.emd` file, a semi-proprietary, HDF5-based file format created by Thermo Fisher as the output format of their Velox software.
The details of this file format are not important here: HyperSpy provides a standard interface to many different microscopy data formats, so we can work with all of them in the same way.

In [None]:
import hyperspy.api as hs
import matplotlib.pyplot as plt

The file actually contains a list of different "signals".
Each signal has a data shape of 39x512x512, i.e. 39 frames of 512x512 pixels.
BF, HAADF, DF4 and DF2 correspond to individual detectors in the microscope, each recording a different signal during the experiment.
The DCFI datasets are derived from the other datasets by image processing in Velox, specifically rigid registration to align the frames and correct for drift.

In [None]:
data = hs.load("data/dataset.emd")
data

In [None]:
%matplotlib notebook

In [None]:
data[1].plot()

We see that the sample drifts during the image acquisition.
We could attempt to align the frames ourselves, but the DCFI option in Velox does a decent job, so we continue with its output.
The last frame in the DCFI dataset is usually what we want: it represents the average of all the aligned frames.

In [None]:
image = data[3].inav[-1]
image.plot()

In [None]:
plt.close("all")
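
For reference, aligning the frames ourselves could look like the minimal sketch below. It is not necessarily how DCFI works in Velox; it simply registers every frame to the first one by phase cross-correlation (`skimage.registration.phase_cross_correlation`), applies the shift with `scipy.ndimage.shift` and averages the result. The function name `rigid_align_stack` is hypothetical, and the input is assumed to be a NumPy array of shape (n_frames, ny, nx), such as the `.data` of one of the raw stacks.

In [None]:
```python
import numpy as np
from scipy import ndimage
from skimage.registration import phase_cross_correlation

def rigid_align_stack(frames, upsample_factor=10):
    """Register every frame of a (n_frames, ny, nx) stack to the first frame.

    Returns the aligned stack, its average and the (dy, dx) shift applied to each frame.
    """
    frames = np.asarray(frames, dtype=np.float64)
    reference = frames[0]
    aligned = np.empty_like(frames)
    shifts = np.zeros((frames.shape[0], 2))
    for i, frame in enumerate(frames):
        # shift that registers `frame` onto `reference`, with 1/upsample_factor pixel precision
        shift, _, _ = phase_cross_correlation(reference, frame, upsample_factor=upsample_factor)
        shifts[i] = shift
        # linear interpolation; pixels shifted in from outside take the nearest edge value
        aligned[i] = ndimage.shift(frame, shift, order=1, mode="nearest")
    return aligned, aligned.mean(axis=0), shifts

# usage on one of the raw (non-DCFI) stacks, e.g.:
# aligned, average, shifts = rigid_align_stack(data[1].data)
```

Registering against the first frame keeps errors from accumulating, but if the drift is large the overlap with later frames shrinks; registering to a running average is a common alternative.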

We can explore the acquisition conditions, microscope parameters and other metadata found in the original file through the `metadata` and `original_metadata` attributes.
For scale information (related to the axes) we look at the `axes_manager` attribute.

In [None]:
# metadata in HyperSpy's own format: only a subset of all the metadata parsed from the file, but with the same structure for every file type
image.metadata

In [None]:
# raw metadata as parsed from the file 
image.original_metadata

In [None]:
image.axes_manager
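
For each axis, the `axes_manager` reports a `scale` (the size of one pixel), an `offset` and `units`; for a uniform axis, pixel index i corresponds to the calibrated position offset + scale·i. Since the analysis below works in pixels, the minimal sketch below shows how pixel positions and distances could be converted. The helper names are hypothetical, and the usage comment assumes square pixels (equal x and y scale).

In [None]:
```python
import numpy as np

def pixel_to_physical(index, scale, offset=0.0):
    """Convert pixel indices along one uniform axis to calibrated positions."""
    return offset + scale * np.asarray(index, dtype=np.float64)

def pixel_distance_to_physical(distance, scale):
    """Convert a distance in pixels to calibrated units (the offset cancels for distances)."""
    return scale * np.asarray(distance, dtype=np.float64)

# usage with the real calibration:
# x_axis = image.axes_manager.signal_axes[0]
# print(pixel_distance_to_physical(10, x_axis.scale), x_axis.units)
```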

We can index into these structures just like Python dictionaries to query the information we care about.
For example, the convergence angle is a parameter we often care about, as it determines the shape of the electron probe and hence the physics of image formation.
We might need it for simulations, but here we just show that the information can be retrieved.

Of course, the data needs to be in the file for us to be able to query it.
If the file format does not store critical acquisition parameters, it's important to keep track of them manually, for example in an electronic lab notebook.

In [None]:
# metadata values are loaded as strings, so we must cast them to float
convergence_angle = float(image.original_metadata["Optics"]["BeamConvergence"])*1000  # stored in radians, we usually work with mrad
c2_aperture_size = float(image.original_metadata["Optics"]["Apertures"]["Aperture-1"]["Diameter"])*1e6  # stored in meters, we prefer micrometers
print("Convergence angle (mrad):", convergence_angle)
print("Aperture diameter (micron):", c2_aperture_size)
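
Indexing with `[...]` raises a `KeyError` when a key is absent, which is likely when the same code is applied to files from other instruments or software versions. A more defensive lookup is sketched below; `get_float` is a hypothetical helper of ours that works on plain nested dictionaries, such as the one returned by `original_metadata.as_dictionary()`, and returns a default instead of failing.

In [None]:
```python
def get_float(tree, keys, scale=1.0, default=None):
    """Walk a nested dictionary along `keys` and return the value as a scaled float.

    Returns `default` if a key is missing or the value cannot be parsed as a number.
    """
    node = tree
    for key in keys:
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    try:
        return float(node) * scale
    except (TypeError, ValueError):
        return default

# usage (as_dictionary() turns HyperSpy's metadata tree into plain dicts):
# om = image.original_metadata.as_dictionary()
# convergence_angle = get_float(om, ["Optics", "BeamConvergence"], scale=1000)  # rad -> mrad
```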

### <a id="analysis"></a> 3. Performing an atomic-column-based analysis
As a basic example (and to work in some machine learning), let's try to identify the atomic columns in the image and segment the image based on nearest-neighbor information.

In [None]:
import numpy as np
import atomap.api as am
from atomap.atom_finding_refining import _remove_too_close_atoms
import hyperspy.api as hs
from skimage.filters import gaussian
from skimage.exposure import rescale_intensity

First we do some image processing and locate the peaks in the image with peak finding.
Then we refine these positions using center of mass (as implemented in atomap).

In [None]:
# we only take a vertical strip of the image around the grain boundary
half_x = image.data.shape[1]//2
half_y = image.data.shape[0]//2
half_wx = 100
half_wy = 200
#sub_image = image.data[half_y-half_wy:half_y+half_wy, half_x-half_wx:half_x+half_wx]
sub_image = image.data[:, half_x-half_wx:half_x+half_wx]
# because we mainly care about the positions of the hex rings, we smooth and invert the image so that the ring centers (rather than the individual columns) become the intensity maxima
sub_image_hex = gaussian(sub_image, 0.8)
sub_image_hex = rescale_intensity(-sub_image_hex, out_range=(0, 1))

im_peaks = sub_image_hex
atom_positions = am.get_atom_positions(hs.signals.Signal2D(im_peaks), separation=1)
atom_positions = _remove_too_close_atoms(atom_positions, 5)
sublattice = am.Sublattice(atom_positions, image=im_peaks)
sublattice.find_nearest_neighbors()
sublattice.refine_atom_positions_using_center_of_mass()
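
atomap handles both steps for us. To make the idea concrete, the sketch below shows a simplified version with NumPy and SciPy: local maxima are found with `scipy.ndimage.maximum_filter`, and each position is then moved to the intensity-weighted centroid of a small square window. This is an illustration under our own assumptions (square window, image scaled to [0, 1], hypothetical function names and parameter values), not atomap's actual implementation.

In [None]:
```python
import numpy as np
from scipy import ndimage

def find_peaks(image, size=5, threshold=0.5):
    """Return (x, y) positions of local maxima above `threshold`."""
    # a pixel is a local maximum if it equals the maximum of its size x size neighborhood
    local_max = ndimage.maximum_filter(image, size=size) == image
    ys, xs = np.nonzero(local_max & (image > threshold))
    return np.column_stack([xs, ys]).astype(np.float64)

def refine_center_of_mass(image, positions, radius=3):
    """Move each (x, y) position to the intensity-weighted centroid of a square window around it."""
    positions = np.asarray(positions, dtype=np.float64)
    refined = positions.copy()
    ny, nx = image.shape
    for i, (x, y) in enumerate(positions):
        xi, yi = int(round(x)), int(round(y))
        # clip the window to the image borders
        x0, x1 = max(xi - radius, 0), min(xi + radius + 1, nx)
        y0, y1 = max(yi - radius, 0), min(yi + radius + 1, ny)
        window = image[y0:y1, x0:x1].astype(np.float64)
        total = window.sum()
        if total <= 0:
            continue  # nothing to weight by; keep the original position
        yy, xx = np.mgrid[y0:y1, x0:x1]
        refined[i] = [(xx * window).sum() / total, (yy * window).sum() / total]
    return refined
```

The window should be smaller than half the column spacing, otherwise intensity from neighboring peaks pulls the centroid towards them.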

In [None]:
print(atom_positions.shape)
fig, ax = plt.subplots(figsize=(5, 10))
ax.imshow(sub_image, cmap="Greys_r", vmax = np.percentile(sub_image, 98), vmin = np.percentile(sub_image, 2))
ax.scatter(sublattice.atom_positions[:,0], sublattice.atom_positions[:,1], s=5, c="red")
ax.axis("off")
fig.tight_layout()

In [None]:
plt.close()

To decide how large our search distance should be, we can look at the histogram of distances to the n-th nearest neighbor.

In [None]:
from scipy.spatial import cKDTree
# check the distribution of nearest neighbor spacings
columns = sublattice.atom_positions
tree_c = cKDTree(columns)
# which nearest neighbor
NN = 1
nn_distance = tree_c.query(columns, k=NN+1)[0][:,NN]
print(f"Median {NN}th nearest neighbor distance (pixels): ", np.median(nn_distance))
fig, ax = plt.subplots()
_ = ax.hist(nn_distance, bins=10)

In [None]:
plt.close("all")
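
As a sanity check on what this histogram should look like, the sketch below builds a perfect triangular lattice (the lattice formed by the centers of hexagonal rings) and queries it in the same way. Each interior point has 6 first neighbors at the spacing `a` and its next shell at √3·a, so a search radius between the two picks out exactly the first shell. The spacing of 10 pixels and the cutoff of 1.3·a are illustrative assumptions, not values measured from the data.

In [None]:
```python
import numpy as np
from scipy.spatial import cKDTree

a = 10.0  # assumed lattice spacing in pixels (illustrative only)
# triangular lattice: rows are a*sqrt(3)/2 apart and every other row is shifted by a/2
rows, cols = 12, 12
lattice = np.array([[(j + 0.5 * (i % 2)) * a, i * a * np.sqrt(3) / 2]
                    for i in range(rows) for j in range(cols)])

tree = cKDTree(lattice)
# k=2: the first result is the point itself, the second its nearest neighbor
first_nn = tree.query(lattice, k=2)[0][:, 1]
# count neighbors within a cutoff between the first (a) and second (sqrt(3)*a) shells
cutoff = 1.3 * a
neighbor_counts = np.array([len(n) - 1 for n in tree.query_ball_point(lattice, r=cutoff)])
print("median first-neighbor distance:", np.median(first_nn))
print("neighbor counts found:", np.unique(neighbor_counts))
```

Points on the border of the lattice have fewer than 6 neighbors within the cutoff, which is the same edge effect we filter out below.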

Features we might be interested in:
* 0: the average distance to the first layer of nearest neighbors
* 1: the standard deviation of this distance
* 2: the average angle between consecutive nearest neighbors
* 3: the standard deviation of these angles
* 4: the "rotation" of the Voronoi cell, taken as the smallest angle of a neighbor vector with the horizontal
* 5: the number of nearest neighbors within a certain distance
* 6: the distance to the closest nearest neighbor
* 7: the distance to the furthest nearest neighbor

In [None]:
def get_nn_features(coordinates, max_distance=12, k=10):
    # A distance of ~12 pixels seems reasonable to separate first and second layer neighbors
    # build the tree from the coordinates we are given instead of relying on the global tree_c
    tree = cKDTree(coordinates)
    # column 0 of the query result is the point itself, so we drop it;
    # neighbors beyond max_distance are reported with the index len(coordinates)
    nearest_neighbor_indices = tree.query(coordinates, k=k, distance_upper_bound=max_distance)[1][:, 1:]
    features = np.zeros((coordinates.shape[0], 8), dtype=np.float64)
    for point_index, nns in enumerate(nearest_neighbor_indices):
        # get rid of all the indices that don't correspond to a point
        nns = nns[nns != coordinates.shape[0]]
        if nns.size == 0:
            # isolated point: leave all its features (including the neighbor count) at 0
            continue
        # get the coordinates of those points
        nn_coords = coordinates[nns]
        # original coordinates
        point_coords = coordinates[point_index]
        # the difference vectors
        difference_vectors = nn_coords - point_coords
        # length information
        lengths = np.linalg.norm(difference_vectors, axis=1)
        average_length = np.mean(lengths)
        std_length = np.std(lengths)
        # angle information
        angles = np.arctan2(difference_vectors[:,1], difference_vectors[:,0]) % (2*np.pi)
        # append the smallest angle + 2*pi so the increment from the last neighbor back to the first is included
        sorted_angles = np.append(np.sort(angles), angles.min()+2*np.pi)
        angle_increments = np.diff(sorted_angles)
        average_increment = np.mean(angle_increments)
        std_increment = np.std(angle_increments)
        min_angle = np.min(angles)
        # number of vertices
        vertices = nns.shape[0]
        # fill in the values into the array
        features[point_index, 0] = average_length
        features[point_index, 1] = std_length
        features[point_index, 2] = average_increment
        features[point_index, 3] = std_increment
        features[point_index, 4] = min_angle
        features[point_index, 5] = vertices
        features[point_index, 6] = lengths.min()
        features[point_index, 7] = lengths.max()
    return features
        
features = get_nn_features(columns, max_distance=11)
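
To see what features 2 and 3 capture, the sketch below applies the same angle-increment calculation to two made-up neighborhoods: a regular hexagon and a distorted five-neighbor cell (the angles are illustrative, not taken from the data; `angle_increment_stats` is a hypothetical helper). Because the increments always sum to 2π, the mean increment is simply 2π divided by the number of neighbors, so it duplicates feature 5; the standard deviation is what carries the information about distortion.

In [None]:
```python
import numpy as np

def angle_increment_stats(difference_vectors):
    """Mean and standard deviation of the angular gaps between consecutive neighbors."""
    angles = np.arctan2(difference_vectors[:, 1], difference_vectors[:, 0]) % (2 * np.pi)
    # close the circle by appending the smallest angle + 2*pi
    sorted_angles = np.append(np.sort(angles), angles.min() + 2 * np.pi)
    increments = np.diff(sorted_angles)
    return increments.mean(), increments.std()

def unit_vectors(degrees):
    """Unit vectors pointing at the given angles (in degrees)."""
    t = np.deg2rad(degrees)
    return np.column_stack([np.cos(t), np.sin(t)])

# regular hexagon: every gap is 60 degrees, so the spread is zero
hexagon = unit_vectors(np.arange(0, 360, 60))
# distorted cell with five neighbors at uneven angles
distorted = unit_vectors([0, 50, 130, 200, 290])

for name, vectors in [("hexagon", hexagon), ("distorted", distorted)]:
    mean_inc, std_inc = angle_increment_stats(vectors)
    print(name, np.rad2deg(mean_inc), np.rad2deg(std_inc))
```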

We can plot a histogram of the number of nearest neighbors found within the search distance. Most points will have 6 (hexagons). Points at the edge usually have fewer than 5, so we discard them.

In [None]:
fig, ax = plt.subplots()
ax.hist(features[:,5], 50)
ax.set_yscale("log")

In [None]:
# we only care about points not on the edge where we have at least 5 nearest neighbors
off_edge = features[:,5] >= 5
filtered_features = features[off_edge]

Below we plot one of the features (selected with the index `FI`) on the image to see its spatial distribution.

In [None]:
# plot the raw features to see the effect
FI = 5

fig, ax = plt.subplots(figsize=(5, 10))

ax.imshow(sub_image, cmap="Greys_r", vmax = np.percentile(sub_image, 98), vmin = np.percentile(sub_image, 2))

filtered_coordinates = sublattice.atom_positions[off_edge]
points = ax.scatter(filtered_coordinates[:,0], filtered_coordinates[:,1], s=20, c=filtered_features[:,FI], cmap="viridis")
fig.colorbar(points, ax=ax)

ax.set_xlim(0,sub_image.shape[1])
ax.set_ylim(0,sub_image.shape[0])
ax.axis("off")
fig.tight_layout()

We now want to cluster the points based on these features. Since the features have very different ranges, we rescale each of them to the range [0, 1].

In [None]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler
scaler = MinMaxScaler()
scaled_features = scaler.fit_transform(filtered_features)
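
With its default settings, `MinMaxScaler` maps every column to [0, 1] using (x - min) / (max - min), with the minimum and maximum taken from the data passed to `fit`. The NumPy sketch below does the same thing explicitly (`min_max_scale` is a hypothetical name); on this data we would expect `np.allclose(min_max_scale(filtered_features), scaled_features)` to hold.

In [None]:
```python
import numpy as np

def min_max_scale(features):
    """Rescale every column of an (n_samples, n_features) array to [0, 1]."""
    features = np.asarray(features, dtype=np.float64)
    col_min = features.min(axis=0)
    col_range = features.max(axis=0) - col_min
    # avoid dividing by zero: constant columns end up all zero
    col_range[col_range == 0] = 1.0
    return (features - col_min) / col_range
```

Because the range is set by the extreme values, a single outlier compresses all other points of that feature into a narrow band; `StandardScaler` (imported above) is the usual alternative in that case.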

With PCA we check how much each of our features contributes to the principal components.

In [None]:
from sklearn.decomposition import PCA
dimension_reduction = PCA(8)
xy = dimension_reduction.fit_transform(scaled_features)

In [None]:
fig, ax = plt.subplots()
gg = ax.imshow(dimension_reduction.components_, cmap="coolwarm")
fig.colorbar(gg)
ax.set_xlabel("Feature")
ax.set_ylabel("Principal component")
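
Each row of `components_` is a direction of maximal remaining variance in the (scaled) feature space, written as weights on the original features, which is what the plot above shows. scikit-learn's `PCA` centers the data and computes it from a singular value decomposition, so the sketch below (a hypothetical `pca_svd` helper) should reproduce the same components up to the sign of each row.

In [None]:
```python
import numpy as np

def pca_svd(data, n_components):
    """Principal components of `data` (n_samples, n_features) via SVD of the centered data."""
    centered = data - data.mean(axis=0)
    # rows of vt are the principal axes, ordered by decreasing singular value
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    explained_variance = singular_values**2 / (data.shape[0] - 1)
    ratio = explained_variance / explained_variance.sum()
    components = vt[:n_components]
    # coordinates of every sample along the retained components
    projected = centered @ components.T
    return components, ratio[:n_components], projected
```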

We can then attempt to cluster the points with the various clustering algorithms implemented in scikit-learn. MeanShift seemed to give a decent result.

In [None]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.cluster import DBSCAN, KMeans, SpectralClustering, OPTICS, MeanShift

#cluster_model_1 = KMeans(5)  # optimal kmeans
cluster_model_1 = MeanShift()
#datafit = PolynomialFeatures(2).fit_transform(scaled_features)
# it seems that the most significant features for good clustering are the standard deviation in length, the angle of the cell, and the number of nearest neighbors
datafit = scaled_features[:, [1, 4, 5]]
cluster_model_1.fit(datafit)
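
Unlike KMeans, mean shift does not need the number of clusters in advance: every point is moved repeatedly to the mean of the data within a `bandwidth` of it, and points that end up at the same mode form one cluster (when no bandwidth is given, scikit-learn estimates one with `estimate_bandwidth`). The sketch below is a simplified flat-kernel version (`flat_mean_shift` is a hypothetical name); scikit-learn's implementation differs in details such as seeding and how modes are merged.

In [None]:
```python
import numpy as np

def flat_mean_shift(points, bandwidth, max_iter=300, tol=1e-6):
    """Simplified mean shift with a flat kernel; returns (cluster_centers, labels)."""
    points = np.asarray(points, dtype=np.float64)
    modes = points.copy()  # start one mode at every data point
    for _ in range(max_iter):
        # each mode moves to the mean of all data points within `bandwidth` of it
        distances = np.linalg.norm(modes[:, None, :] - points[None, :, :], axis=2)
        within = distances <= bandwidth
        counts = within.sum(axis=1, keepdims=True)
        sums = within.astype(np.float64) @ points
        # keep a mode in place if (unexpectedly) no point lies within the bandwidth
        new_modes = np.where(counts > 0, sums / np.maximum(counts, 1), modes)
        shift = np.max(np.linalg.norm(new_modes - modes, axis=1))
        modes = new_modes
        if shift < tol:
            break
    # merge modes that converged to (nearly) the same location into one cluster
    centers, labels = [], np.empty(len(points), dtype=int)
    for i, mode in enumerate(modes):
        for j, center in enumerate(centers):
            if np.linalg.norm(mode - center) < bandwidth / 2:
                labels[i] = j
                break
        else:
            centers.append(mode)
            labels[i] = len(centers) - 1
    return np.array(centers), labels
```

The bandwidth plays the role of the number of clusters in KMeans: a smaller value splits the data into more clusters.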

In [None]:
fig, ax = plt.subplots()
ax.set_aspect("equal")

dimension_reduction = PCA(2)
xy = dimension_reduction.fit_transform(datafit)
ax.scatter(xy[:,0], xy[:,1], c=cluster_model_1.labels_, cmap="Set1")

In [None]:
fig, ax = plt.subplots(figsize=(5, 10))
ax.imshow(sub_image, cmap="Greys_r", vmax = np.percentile(sub_image, 98), vmin = np.percentile(sub_image, 2))
not_classified = sublattice.atom_positions[np.invert(off_edge)]
ax.scatter(filtered_coordinates[:,0], filtered_coordinates[:,1], s=20, c=cluster_model_1.labels_, cmap="Set1")
ax.scatter(not_classified[:,0], not_classified[:,1], marker="s", s=10, c="black")
ax.set_xlim(0,sub_image.shape[1])
ax.set_ylim(0,sub_image.shape[0])

In [None]:
plt.close("all")

We can now color each pixel of the image according to the cluster label of its nearest classified point, which segments the image on a pixel-by-pixel basis.

In [None]:
tree = cKDTree(filtered_coordinates)
Y, X = np.mgrid[0:sub_image.shape[0], 0:sub_image.shape[1]]
coords = np.stack([X.ravel(), Y.ravel()]).T
nearest_point = tree.query(coords)[1]
classification = cluster_model_1.labels_[nearest_point]
labeled_image = classification.reshape(sub_image.shape)

In [None]:
fig, ax = plt.subplots(figsize=(5, 10))
ax.imshow(sub_image, cmap="Greys_r", vmax = np.percentile(sub_image, 98), vmin = np.percentile(sub_image, 2))
ax.imshow(labeled_image, cmap = "Set1", alpha=0.5)
ax.axis("off")
fig.tight_layout()

In [None]:
plt.close("all")

Further cleaning of the image and more experimenting with the clustering may improve the segmentation.