# Analysis of microscopy images

author: ericmuckley@gmail.com

## Install and import libraries

First install ***ncempy***, a package for electron microscopy analysis
supported by the National Center for Electron Microscopy at the
Molecular Foundry. Code is available here:
https://github.com/ercius/openNCEM

In [0]:
!pip install ncempy

import numpy as np
from glob import glob
from time import time
from ncempy.io import dm
import scipy.fftpack as fp
import matplotlib.pyplot as plt



## Import all the images from GitHub

In [0]:
# clone the entire github repository where the data file is located
%cd /content/
!rm -rf cloned-data
!git clone -l -s git://github.com/ericmuckley/datasets.git cloned-data
# navigate to the repo
%cd cloned-data/2020-01-24_CeO2ZnO2_on_graphene/

# set path of folder which contains image files
folder_path = "/content/cloned-data/2020-01-24_CeO2ZnO2_on_graphene/*.*"
files = glob(folder_path)
print('\n-----------------------------\nFound {} files:'.format(len(files)))
for f in files:
    print(f.split('/')[-1])

/content
Cloning into 'cloned-data'...
remote: Enumerating objects: 15, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 224 (delta 2), reused 14 (delta 2), pack-reused 209[K
Receiving objects: 100% (224/224), 241.24 MiB | 31.77 MiB/s, done.
Resolving deltas: 100% (81/81), done.
Checking out files: 100% (59/59), done.
/content/cloned-data/2020-01-24_CeO2ZnO2_on_graphene

-----------------------------
Found 38 files:
1-CW_10%_1s.dm3
3-pristie.dm3
3-CW-16%10s.dm3
1-SAED_as_dep-b.dm3
1-SAED_as_dep.dm3
3-CW-15%3s.dm3
3-CW15%.dm3
3-SI_Survey_Image.dm3
1-as_dep_31.5K.dm3
1-CW_10%_1s-c.dm3
1-CW_10%_1s-b.dm3
3-spot_HAADF.dm3
3-spot_EELS-2.dm3
3-CW-18%5s.dm3
3-CW-17.5%80s.dm3
2-CW18%.dm3
3-CW-17.5%10s.dm3
3-SAED_after_18%.dm3
3-CW-17.5%140s.dm3
3-HAADF.dm3
2-CW19%.dm3
1-400K_as_dep-b.dm3
2-CW19%-2.dm3
2-pristine.dm3
3-CW-17.5%260s.dm3
3-Analog.dm3
1-CW_10%_1s-SAED.dm3
2-13-110ms.dm3
3-EELS_Spectrum_Image.dm3
3-scanning_

## Define some functions

In [0]:
# matplotlib styling: larger tick labels and thicker axes/ticks
for _rc_key, _rc_value in (
        ('xtick.labelsize', 18),
        ('ytick.labelsize', 18),
        ('axes.linewidth', 3),
        ('xtick.minor.width', 3),
        ('xtick.major.width', 3),
        ('ytick.minor.width', 3),
        ('ytick.major.width', 3)):
    plt.rcParams[_rc_key] = _rc_value


def plot_setup(labels=('X', 'Y'), fsize=18, setlimits=False,
               limits=(0, 1, 0, 1), title='',
               legend=False, save=False, filename='plot.jpg'):
    """Apply a common look to the current matplotlib figure.

    Call this between the plotting command (e.g. plt.plot()) and
    plt.show().

    labels: sequence of two axis labels, (x label, y label)
    fsize: font size for axis labels and title (legend uses fsize-4)
    setlimits: if True, apply the axis ranges given in limits
    limits: sequence (xmin, xmax, ymin, ymax), used only if setlimits
    title: plot title
    legend: if True, draw a legend
    save: if True, save the figure to filename
    filename: output path used when save is True
    """
    plt.xlabel(str(labels[0]), fontsize=fsize)
    plt.ylabel(str(labels[1]), fontsize=fsize)
    plt.title(title, fontsize=fsize)
    fig = plt.gcf()
    fig.set_size_inches(6, 6)
    if legend:
        plt.legend(fontsize=fsize-4)
    if setlimits:
        plt.xlim((limits[0], limits[1]))
        plt.ylim((limits[2], limits[3]))
    if save:
        # adjust the layout BEFORE writing the file so it affects the output
        plt.tight_layout()
        fig.savefig(filename, dpi=120, bbox_inches='tight')


def scale_array(arr, lim=(0, 1)):
    """Linearly rescale the values of an array into new limits.

    arr: array of values to rescale
    lim: (low, high) limits; min(arr) maps to low and max(arr) to high
    Returns the rescaled array. If every value of arr is identical the
    range is zero, so all values are mapped to the lower limit instead
    of dividing by zero.
    """
    low, high = lim[0], lim[1]
    arr_min = np.min(arr)
    value_range = np.max(arr) - arr_min
    if value_range == 0:
        # constant input: multiply by 0.0 (rather than building a new array)
        # so the result keeps the input's shape and any mask it carries
        return (arr - arr_min)*0.0 + low
    return (high - low)*(arr - arr_min)/value_range + low


def array_phase(x, y):
    """Get a representative phase angle from arrays of x and y values.

    The angle is computed from the medians of |x| and |y|, so it lies
    between 0 and pi/2 radians. Returns NaN when both medians are zero,
    because the angle is undefined there.
    """
    med_x = np.median(np.abs(x))
    med_y = np.median(np.abs(y))
    if med_x == 0 and med_y == 0:
        # undefined direction: keep NaN so callers can mask it out
        return np.float64(np.nan)
    # arctan2 handles med_x == 0 (giving pi/2) without dividing by zero
    return np.arctan2(med_y, med_x)


def array_mag(x, y):
    """Get a single magnitude from arrays of x and y values: the square
    root of the sum of the medians of x squared and y squared."""
    median_x_sq = np.median(np.square(x))
    median_y_sq = np.median(np.square(y))
    return np.sqrt(median_x_sq + median_y_sq)


def std_lims(arr, n=1):
    """Get (low, high) plot-scale limits placed n standard deviations
    below and above the mean of an array."""
    center = np.mean(arr)
    spread = np.std(arr)
    low = center - n*spread
    high = center + n*spread
    return (low, high)


def get_sample_grid(data, samp_num):
    """Build the list of window origins used to raster a sliding window
    over the whole image.

    Origins are spaced samp_num pixels apart along both axes of
    data['data']. Returns an array with one (x, y) origin per row,
    ordered with x varying slowest.
    """
    n_rows = len(data['data'])
    n_cols = len(data['data'][0])
    x_origins = np.arange(0, n_cols, samp_num)
    y_origins = np.arange(0, n_rows, samp_num)
    # 'ij' indexing gives grids of shape (len(x), len(y)); stacking on the
    # last axis pairs every x origin with every y origin
    grid_x, grid_y = np.meshgrid(x_origins, y_origins, indexing='ij')
    return np.stack((grid_x, grid_y), axis=-1).reshape(-1, 2)


def get_sampling_pixels(data, x0, y0, samp_num, oversamp_num):
    """Get the pixel slices for a sampling window and its oversampled
    (padded) version in a 2D array.
    x0, y0 = origin of sampling window
    samp_num = number of pixels to sample along each axis
    oversamp_num = number of extra pixels added on every side
    Returns (s_slice, os_slice): the sampling slice and the oversampling
    slice, each indexed as [rows, columns]. The oversampling slice is
    clipped to the bounds of data['data']."""
    n_rows = len(data['data'])
    n_cols = len(data['data'][0])
    # sampling window itself (not clipped)
    s_slice = np.s_[y0:y0 + samp_num, x0:x0 + samp_num]
    # padded window, clipped so it never extends past the image edges
    x_bounds = np.clip(
        (x0 - oversamp_num, x0 + samp_num + oversamp_num), 0, n_cols)
    y_bounds = np.clip(
        (y0 - oversamp_num, y0 + samp_num + oversamp_num), 0, n_rows)
    os_slice = np.s_[y_bounds[0]:y_bounds[1], x_bounds[0]:x_bounds[1]]
    return s_slice, os_slice


def read_img_data(filename):
    """Load a dm3 file with dm.dmReader and attach plotting metadata.

    Adds two entries to the returned dict:
    'label' = file name without directory or '.dm3' extension
    'span' = (0, x_span, 0, y_span), pixel size times pixel count
    """
    data = dm.dmReader(filename)
    n_rows = len(data['data'])
    n_cols = len(data['data'][0])
    pixel_size = data['pixelSize']
    # strip the directory, then everything from '.dm3' onwards
    base_name = filename.split('/')[-1]
    data['label'] = base_name.partition('.dm3')[0]
    data['span'] = (0, pixel_size[0]*n_cols, 0, pixel_size[1]*n_rows)
    return data


def run_fft_filter(img, plot_fft=False, extent=None):
    """Filter out high and low frequencies from an image using 2D FFT.

    The centred (fftshift-ed) spectrum is zeroed in two regions:
    * a small square around the centre (lowest frequencies), with
      half-width int(rows/50)
    * the outer sixth of the rows and of the columns on every side
      (highest frequencies)

    img: 2D image array, indexed as [rows, columns]
    plot_fft: if True, plot the spectrum, the filtered spectrum and the
        reconstructed image
    extent: optional matplotlib extent for the filtered plots. If None,
        falls back to the notebook-level data['span'], so plotting then
        requires a global `data` to exist.
    Returns the real part of the inverse FFT of the filtered spectrum.
    """
    # calculate the centred 2D FFT
    fft = fp.fftshift(fp.fft2(np.asarray(img).astype(float)))
    # np.shape gives (rows, columns); keep each limit with its own axis so
    # that non-square images are filtered correctly
    n_rows, n_cols = np.shape(img)
    row_lim, col_lim = int(n_rows/6), int(n_cols/6)
    row_mid, col_mid = int(n_rows/2), int(n_cols/2)
    filtered_fft = np.copy(fft)
    # filter out low frequencies around the centre of the spectrum
    n = int(n_rows/50)
    filtered_fft[row_mid-n:row_mid+n, col_mid-n:col_mid+n] = 0
    # filter out high frequencies at the edges of the spectrum; skip a
    # zero-width band, because a slice like [-0:] would select everything
    if row_lim > 0:
        filtered_fft[:row_lim] = 0
        filtered_fft[-row_lim:] = 0
    if col_lim > 0:
        filtered_fft[:, :col_lim] = 0
        filtered_fft[:, -col_lim:] = 0
    # get filtered image
    img2 = np.real(fp.ifft2(fp.ifftshift(filtered_fft)))
    if plot_fft:
        if extent is None:
            # backward-compatible fallback to the notebook-level variable
            extent = data['span']
        # plot the 2D FFT; take the magnitude first, since the spectrum is
        # complex and cannot be meaningfully cast to int
        plt.imshow((20*np.log10(0.1 + np.abs(fft))).astype(int),
                   origin='lower', cmap='gray')
        plot_setup(title='FFT of entire image',
                   labels=['Distance (nm)', 'Distance (nm)'])
        plt.show()
        # plot filtered FFT
        plt.imshow((20*np.log10(0.1 + np.abs(filtered_fft))).astype(int),
                   origin='lower', cmap='gray', extent=extent)
        plot_setup(title='FFT bandpass filtered',
                   labels=['Distance (nm)', 'Distance (nm)'])
        plt.show()
        # plot filtered image
        plt.imshow(img2, origin='lower', cmap='gray', extent=extent)
        plot_setup(title='Reconstructed filtered image',
                   labels=['Distance (nm)', 'Distance (nm)'])
        plt.show()
    return img2

def _fft_peak_stats(window):
    """Find the strongest non-central peak in the 2D FFT of a window.

    Returns (distance, angle): the pixel distance of the peak from the
    centre of the shifted spectrum, and arctan(dy/dx) of its offset.
    """
    fft_complex = fp.fftshift(fp.fft2(window.astype(float)))
    # NOTE(review): the peak search uses |Re(FFT)|, as the original did;
    # confirm whether the full magnitude was intended
    fft_re = np.abs(np.real(fft_complex))
    # filter out the centre bright peak. After fftshift it sits at
    # shape // 2 of THIS window, which is larger than samp_num when the
    # window is oversampled
    center_row, center_col = fft_re.shape[0]//2, fft_re.shape[1]//2
    fft_re[center_row, center_col] = 0
    peak_row, peak_col = np.unravel_index(np.argmax(fft_re), fft_re.shape)
    peak_loc = np.array([center_col - peak_col, center_row - peak_row])
    fft_peak_dist = np.linalg.norm(peak_loc)
    # a peak directly above/below the centre has zero x offset; the
    # division then gives +/-inf (angle +/-pi/2) or NaN, which is the
    # intended result, so silence the floating point warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        fft_peak_angle = np.arctan(np.divide(peak_loc[1], peak_loc[0]))
    return fft_peak_dist, fft_peak_angle


def map_domains(data, samp_num, oversamp_num):
    """Map the domains using a number of sampling pixels, and a number
    of oversampling pixels, which buffer the sampling pixels.
    The sliding window is rastered over the image and statistics
    are collected over each oversampled window of pixels, enabling
    compilation of the statistics into maps to be overlaid on the
    original image.

    data: dict holding the image in data['data'] and, optionally, an
        FFT-filtered image in data['filtered'], which is used if present
    samp_num: size in pixels of the sampling window
    oversamp_num: pixels of padding added around each sampling window
    Returns a dict of maps with the same shape as data['data'], keyed by
    'grad_angle', 'grad_mag', 'fft_peak_dist' and 'fft_peak_angle'.
    Each map is rescaled with scale_array and has NaN values masked.
    """
    # initialize empty maps to hold domain orientation information
    keys = ['grad_angle', 'grad_mag', 'fft_peak_dist', 'fft_peak_angle']
    remap = {k: np.zeros_like(data['data']).astype(float) for k in keys}

    # use FFT-filtered image data if it exists. If not, use raw image
    source = data['filtered'] if 'filtered' in data else data['data']

    # loop sliding window over entire image
    for x0, y0 in get_sample_grid(data, samp_num):
        # get sampling area, clipping if it lies outside image range
        s_slice, os_slice = get_sampling_pixels(data, x0, y0,
                                                samp_num, oversamp_num)
        img0 = source[os_slice]

        # statistics of the oversampled window
        fft_peak_dist, fft_peak_angle = _fft_peak_stats(img0)
        grad_y, grad_x = np.abs(np.gradient(img0))

        # save stats into the (un-padded) sampling window
        remap['grad_angle'][s_slice] = array_phase(grad_x, grad_y)
        remap['grad_mag'][s_slice] = array_mag(grad_x, grad_y)
        remap['fft_peak_dist'][s_slice] = fft_peak_dist
        remap['fft_peak_angle'][s_slice] = fft_peak_angle

    # scale final result maps, masking undefined (NaN) values
    for r in remap:
        remap[r] = scale_array(np.ma.array(
            remap[r], mask=np.isnan(remap[r])))
    return remap

## Loop over each image and perform analysis

User-defined parameters:
* samp_num: number of pixels to use for sampling window
* oversamp_num: number of pixels to oversample outside of sampling window
* perform_fft_filtering: whether to FFT-filter the raw image
prior to analysis

In [0]:
samp_num = 20
oversamp_num = 50
perform_fft_filtering = True

# get list of image files to examine: names containing 'CW' but not 'SAED'.
# Match on the file name only (not the directory), and sort so the images
# are always processed in the same order.
images = sorted(
    f for f in files
    if 'CW' in f.split('/')[-1].upper()
    and 'SAED' not in f.split('/')[-1].upper())

# loop over each image
script_start_time = time()
for filename in images:

    # read file and plot raw image
    image_start_time = time()
    data = read_img_data(filename)
    print('\n---------------------------------------------------------')
    print('File: {}'.format(data['label']))
    plt.imshow(data['data'], origin='lower', cmap='gray', extent=data['span'])
    plot_setup(title=data['label'], labels=['Distance (nm)', 'Distance (nm)'])
    plt.show()

    if perform_fft_filtering:
        data['filtered'] = run_fft_filter(data['data'], plot_fft=True)

    # plot calculated maps on top of original image
    remap = map_domains(data, samp_num, oversamp_num)
    for r in remap:
        # plot raw image
        plt.imshow(data['data'], origin='lower',
                   cmap='gray', extent=data['span'])
        # overlay the calculated map, semi-transparent, on the raw image
        plt.imshow(remap[r], origin='lower', cmap='jet',
                   zorder=2, extent=data['span'], alpha=0.2)
        plot_setup(title=data['label']+'\n'+r,
                   labels=['Distance (nm)', 'Distance (nm)'])
        plt.colorbar()
        plt.show()

    print('image runtime: {} s'.format(round(time() - image_start_time)))
    print('total runtime: {} s'.format(round(time() - script_start_time)))


Output hidden; open in https://colab.research.google.com to view.