# Module: braggdiskdetection

This module contains functions for finding the positions of the Bragg disks in a 4DSTEM scan.  Generally this involves two steps: getting a vacuum probe, then finding the Bragg disks using the vacuum probe as a template. 

## Submodule: diskdetection

This notebook demos functions related to finding the Bragg disks.  Using a vacuum probe as a template - i.e. a convolution kernel - a cross correlation (or phase or hybrid correlation) is taken between each DP and the template, and the positions and intensities of all local correlation maxima are used to identify the Bragg disks.  Erroneous peaks are filtered out with several types of thresholds.  Detected Bragg disks are generally stored in PointLists (when run on only selected DPs) or PointListArrays (when run on a full DataCube).

This notebook demos:
* Disk detection on single or selected diffraction patterns
* Disk detection on all diffraction patterns
* Additional filtering of detected Bragg disks

### Import packages, load data

In [None]:
import py4DSTEM

import numpy as np
import matplotlib.pyplot as plt

from py4DSTEM.process.braggdiskdetection import find_Bragg_disks_single_DP
from py4DSTEM.process.braggdiskdetection import find_Bragg_disks_selected
from py4DSTEM.process.braggdiskdetection import find_Bragg_disks
from py4DSTEM.process.braggdiskdetection import threshold_Braggpeaks


In [None]:
from time import time
import ipyparallel as ipp

from py4DSTEM.process.braggdiskdetection import find_Bragg_disks_single_DP_FK
from py4DSTEM.process.braggdiskdetection import PointListArray
from py4DSTEM.process.braggdiskdetection import print_progress_bar

In [None]:
# Load the raw 4D-STEM scan.
# Other dataset locations used previously (kept for reference):
#fp = "/home/ben/Data/20180905_FePO4_unlithiated/raw/Stack1_57x47+30nmss_spot 8_0p05s_CL=600_alpha=0p48_300kV_bin4.dm4"
#fp = "/Users/Ben/Work/NCEM/Projects/py4DSTEM/sample_data/20180905_FePO4_unlithiated/Stack2_60x60+30nmss_spot 8_0p05s_CL=600_alpha=0p48_300kV_bin4.dm3"

fp = "/global/u2/m/mhenders/ncem/Stack2_60x60+30nmss_spot 8_0p05s_CL=600_alpha=0p48_300kV_bin4.h5"
dc = py4DSTEM.file.readwrite.read(fp)
dc.set_scan_shape(47, 57)
# Correct for acquisition wrap-around error: roll the data by -2 along axis 1.
dc.data4D = np.roll(dc.data4D, shift=-2, axis=1)

# Load the vacuum probe template.
# Other template locations used previously (kept for reference):
#fp_probetemplate = "/home/ben/Data/20180905_FePO4_unlithiated/processing/vacuum_probe_kernel.h5"
#fp = "/Users/Ben/Work/NCEM/Projects/py4DSTEM/sample_data/20180905_FePO4_unlithiated/processing/vacuum_probe_kernel.h5"
fp_probetemplate = "/global/u2/m/mhenders/ncem/vacuum_probe_kernel.h5"
browser = py4DSTEM.file.readwrite.FileBrowser(fp_probetemplate, rawdatacube=dc)
browser.show_dataobjects()   # list the data objects found in the template file
probe_kernel = browser.get_dataobject(0)   # take the first data object as the probe kernel

#### Single DP

In [None]:
# Choose one scan position and display it next to its diffraction pattern

Rx = 20
Ry = 25
power = 0.3

DP = dc.data4D[Rx, Ry, :, :]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 12))

# Left panel: average over the last two axes of the 4D data, with the
# chosen scan position marked in red.
ax1.matshow(np.average(dc.data4D, axis=(2, 3)))
ax1.scatter(Ry, Rx, color='r')

# Right panel: the chosen diffraction pattern, raised to `power` for display.
ax2.matshow(DP**power)

for ax in (ax1, ax2):
    ax.axis('off')
plt.show()

In [None]:
# Detect the Bragg disks in the single selected diffraction pattern

corrPower = 0.9
sigma = 2
edgeBoundary = 20
maxNumPeaks = 70
minPeakSpacing = 30
minRelativeIntensity = 0.005

# Gather the detection settings once, then hand them over as keyword arguments.
single_dp_kwargs = dict(
    corrPower=corrPower,
    sigma=sigma,
    edgeBoundary=edgeBoundary,
    minRelativeIntensity=minRelativeIntensity,
    minPeakSpacing=minPeakSpacing,
    maxNumPeaks=maxNumPeaks,
)
peaks = find_Bragg_disks_single_DP(DP, probe_kernel.data2D, **single_dp_kwargs)

In [None]:
# Show the detected peaks on top of the selected diffraction pattern

power=0.3
size_scale_factor = 500       # Set to zero to make all points the same size

fig,(ax1,ax2)=plt.subplots(1,2,figsize=(12,12))
ax1.matshow(np.average(dc.data4D,axis=(2,3)))
ax1.scatter(Ry,Rx,color='r')
ax2.matshow(DP**power)
if size_scale_factor == 0:
    # Uniform markers: leave the marker size at matplotlib's default.  Passing
    # s=0 (which scaling by a zero factor would produce) draws nothing at all.
    ax2.scatter(peaks.data['qy'],peaks.data['qx'],color='r')
else:
    # Marker area proportional to each peak's intensity relative to the brightest peak.
    ax2.scatter(peaks.data['qy'],peaks.data['qx'],color='r',
                s=size_scale_factor*peaks.data['intensity']/np.max(peaks.data['intensity']))
ax1.axis('off')
ax2.axis('off')
plt.show()

#### Several DPs

In [None]:
# Choose three scan positions and display their diffraction patterns

Rxs = (20, 31, 18)
Rys = (25, 31, 10)
power = 0.3

fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=(12, 12))

# Top-left panel: average over the last two axes, with the three positions marked.
ax11.matshow(np.average(dc.data4D, axis=(2, 3)))
ax11.scatter(Rys, Rxs, color=('r', 'yellow', 'deepskyblue'))

# Remaining panels: one diffraction pattern each, raised to `power` for display.
for ax, rx, ry in zip((ax12, ax21, ax22), Rxs, Rys):
    ax.matshow(dc.data4D[rx, ry, :, :]**power)

for ax in (ax11, ax12, ax21, ax22):
    ax.axis('off')
plt.show()

In [None]:
# Detect the Bragg disks in the selected diffraction patterns

corrPower = 0.8
sigma = 2
edgeBoundary = 20
maxNumPeaks = 70
minPeakSpacing = 50
minRelativeIntensity = 0.001

# Gather the detection settings once, then hand them over as keyword arguments.
selected_kwargs = dict(
    corrPower=corrPower,
    sigma=sigma,
    edgeBoundary=edgeBoundary,
    minRelativeIntensity=minRelativeIntensity,
    minPeakSpacing=minPeakSpacing,
    maxNumPeaks=maxNumPeaks,
)
peaks = find_Bragg_disks_selected(dc, probe_kernel.data2D, Rxs, Rys, **selected_kwargs)

In [None]:
# Show the detected peaks on top of each selected diffraction pattern

power = 0.3
size_scale_factor = 500       # Set to zero to make all points the same size

fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=(12, 12))
dp_axes = (ax12, ax21, ax22)
dp_colors = ('r', 'g', 'b')

# Top-left panel: average over the last two axes, with the three positions marked.
ax11.matshow(np.average(dc.data4D, axis=(2, 3)))
ax11.scatter(Rys, Rxs, color=dp_colors)

# Remaining panels: one diffraction pattern each, raised to `power` for display.
for ax, rx, ry in zip(dp_axes, Rxs, Rys):
    ax.matshow(dc.data4D[rx, ry, :, :]**power)

# Overlay the peaks of the i-th selected pattern on the i-th panel.
for idx, (ax, marker_color) in enumerate(zip(dp_axes, dp_colors)):
    points = peaks[idx].data
    if size_scale_factor == 0:
        ax.scatter(points['qy'], points['qx'], color=marker_color)
    else:
        # Marker area scales with intensity relative to the brightest peak.
        marker_sizes = size_scale_factor*points['intensity']/np.max(points['intensity'])
        ax.scatter(points['qy'], points['qx'], color=marker_color, s=marker_sizes)

for ax in (ax11,) + dp_axes:
    ax.axis('off')
plt.show()

In [None]:
@ipp.require(
    'numpy',
    'scipy.ndimage',
    'py4DSTEM.file.datastructure',
    'py4DSTEM.process.utils'
    )
def _find_Bragg_disks_single_DP_FK(DP, probe_kernel_FT,
                                  corrPower = 1,
                                  sigma = 2,
                                  edgeBoundary = 20,
                                  minRelativeIntensity = 0.005,
                                  minPeakSpacing = 60,
                                  maxNumPeaks = 70,
                                  return_cc = False,
                                  peaks = None):
    """
    Finds the Bragg disks in DP by cross, hybrid, or phase correlation with probe_kernel_FT.

    After taking the cross/hybrid/phase correlation, a gaussian smoothing is applied
    with standard deviation sigma, and all local maxima are found. Detected peaks within
    edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks with
    intensities less than minRelativeIntensity of the brightest peak in the correlation are
    discarded. Then peaks which are within a distance of minPeakSpacing of their nearest neighbor
    peak are found, and in each such pair the peak with the lesser correlation intensity is
    removed. Finally, if the number of peaks remaining exceeds maxNumPeaks, only the maxNumPeaks
    peaks with the highest correlation intensity are retained.

    IMPORTANT NOTE: the argument probe_kernel_FT is related to the probe kernels generated by
    functions like get_probe_kernel() by:

            probe_kernel_FT = np.conj(np.fft.fft2(probe_kernel))

    If this function is simply passed a probe kernel, the results will not be meaningful! To run
    on a single DP while passing the real space probe kernel as an argument, use
    find_Bragg_disks_single_DP().

    Accepts:
        DP                   (ndarray) a diffraction pattern
        probe_kernel_FT      (ndarray) the vacuum probe template, in Fourier space. Related to the
                             real space probe kernel by probe_kernel_FT = F(probe_kernel)*, where F
                             indicates a Fourier Transform and * indicates complex conjugation.
        corrPower            (float between 0 and 1, inclusive) the cross correlation power. A
                             value of 1 corresponds to a cross correlation, and 0 corresponds to a
                             phase correlation, with intermediate values giving various hybrids.
        sigma                (float) the standard deviation for the gaussian smoothing applied to
                             the cross correlation
        edgeBoundary         (int) minimum acceptable distance from the DP edge, in pixels
        minRelativeIntensity (float) the minimum acceptable correlation peak intensity, relative to
                             the intensity of the brightest peak
        minPeakSpacing       (float) the minimum acceptable spacing between detected peaks
        maxNumPeaks          (int) the maximum number of peaks to return
        return_cc            (bool) if True, return the cross correlation
        peaks                (PointList) For internal use.
                             If peaks is None, the PointList of peak positions is created here.
                             If peaks is not None, it is the PointList that detected peaks are added
                             to, and must have the appropriate coords ('qx','qy','intensity').

    Returns:
        peaks                (PointList) the Bragg peak positions and correlation intensities
        cc                   (ndarray) the clipped and smoothed correlation; only returned, as the
                             second element of a (peaks, cc) tuple, when return_cc is True

    Raises:
        TypeError            if peaks is given but is not a PointList
    """
    # Bind every module used below by its full name inside the function itself.
    # The notebook only binds numpy as `np` and never imports scipy, so without
    # these imports a direct call would depend on names that nothing defines.
    import numpy
    import scipy.ndimage
    import py4DSTEM.file.datastructure
    import py4DSTEM.process.utils

    # Get cross correlation, clip negative values, then smooth
    cc = py4DSTEM.process.utils.get_cross_correlation_fk(DP, probe_kernel_FT, corrPower)
    cc = numpy.maximum(cc,0)
    cc = scipy.ndimage.gaussian_filter(cc, sigma)

    # Get maxima
    # NOTE(review): cc has already been smoothed above and sigma is passed on
    # again here; if get_maxima_2D smooths internally the correlation is
    # filtered twice -- confirm this is intended.
    maxima_x,maxima_y = py4DSTEM.process.utils.get_maxima_2D(
        cc, sigma=sigma, edgeBoundary=edgeBoundary,
        minRelativeIntensity=minRelativeIntensity,
        minSpacing=minPeakSpacing, maxNumPeaks=maxNumPeaks)

    # Make peaks PointList, or validate the one supplied by the caller
    if peaks is None:
        coords = [('qx',float),('qy',float),('intensity',float)]
        peaks = py4DSTEM.file.datastructure.PointList(coordinates=coords)
    elif not isinstance(peaks, py4DSTEM.file.datastructure.PointList):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise TypeError("peaks must be a PointList or None, not {}".format(type(peaks).__name__))
    peaks.add_tuple_of_nparrays((maxima_x,maxima_y,cc[maxima_x,maxima_y]))

    if return_cc:
        return peaks, cc
    return peaks


In [None]:
def _find_Bragg_disks(datacube, probe,
                     corrPower = 1,
                     sigma = 2,
                     edgeBoundary = 20,
                     minRelativeIntensity = 0.005,
                     minPeakSpacing = 60,
                     maxNumPeaks = 70,
                     verbose = False,
                     view = None):
    """
    Finds the Bragg disks in all diffraction patterns of datacube by cross, hybrid, or phase
    correlation with probe.

    Accepts:
        datacube             the 4D dataset to analyze. This function reads its R_Nx, R_Ny and
                             R_N attributes and indexes its data4D array as [Rx,Ry,:,:].
        probe                (ndarray) the vacuum probe template, in real space.
        corrPower            (float between 0 and 1, inclusive) the cross correlation power. A
                             value of 1 corresponds to a cross correlation, and 0 corresponds to a
                             phase correlation, with intermediate values giving various hybrids.
        sigma                (float) the standard deviation for the gaussian smoothing applied to
                             the cross correlation
        edgeBoundary         (int) minimum acceptable distance from the DP edge, in pixels
        minRelativeIntensity (float) the minimum acceptable correlation peak intensity, relative to
                             the intensity of the brightest peak
        minPeakSpacing       (float) the minimum acceptable spacing between detected peaks
        maxNumPeaks          (int) the maximum number of peaks to return
        verbose              (bool) if True, prints completion updates
        view                 if None, every diffraction pattern is analyzed serially in this
                             process. Otherwise an object whose apply(f, *args) method schedules
                             the call and returns a handle with a get() method (in this notebook,
                             an ipyparallel load-balanced view); one task is submitted per
                             diffraction pattern.

    Returns:
        peaks                (PointListArray) the Bragg peak positions and correlation intensities
    """
    # Make the peaks PointListArray
    coords = [('qx',float),('qy',float),('intensity',float)]
    peaks = PointListArray(coordinates=coords, shape=(datacube.R_Nx, datacube.R_Ny))

    # Get the probe kernel FT
    probe_kernel_FT = np.conj(np.fft.fft2(probe))

    num_DPs = datacube.R_Nx*datacube.R_Ny
    t0 = time()

    if view is None:
        # Serial: analyze each diffraction pattern here, writing the detected
        # peaks straight into the matching PointList of the output array.
        for Rx in range(datacube.R_Nx):
            for Ry in range(datacube.R_Ny):
                if verbose:
                    print_progress_bar(Rx*datacube.R_Ny+Ry+1, num_DPs,
                                       prefix='Analyzing:', suffix='Complete', length=50)
                DP = datacube.data4D[Rx,Ry,:,:]
                _find_Bragg_disks_single_DP_FK(DP, probe_kernel_FT,
                                              corrPower = corrPower,
                                              sigma = sigma,
                                              edgeBoundary = edgeBoundary,
                                              minRelativeIntensity = minRelativeIntensity,
                                              minPeakSpacing = minPeakSpacing,
                                              maxNumPeaks = maxNumPeaks,
                                              peaks = peaks.get_pointlist(Rx,Ry))
    else:
        # Parallel: submit one task per diffraction pattern, remembering which
        # scan position each pending result belongs to.
        pending = []
        for Rx in range(datacube.R_Nx):
            for Ry in range(datacube.R_Ny):
                if verbose:
                    # Note: in this branch the bar tracks task submission, not completion.
                    print_progress_bar(Rx*datacube.R_Ny+Ry+1, num_DPs,
                                       prefix='Analyzing:', suffix='Complete', length=50)
                DP = datacube.data4D[Rx,Ry,:,:]
                pending.append((Rx, Ry, view.apply(
                    _find_Bragg_disks_single_DP_FK,
                    DP, probe_kernel_FT, corrPower, sigma, edgeBoundary,
                    minRelativeIntensity, minPeakSpacing, maxNumPeaks)))

        print("Number of computations: {}".format(len(pending)))

        # Collect results in submission order and copy each returned PointList's
        # data into the output array.
        for Rx, Ry, result in pending:
            peaks.get_pointlist(Rx, Ry).data = result.get().data.copy()

    # Report the elapsed time as hours, minutes within the hour, and seconds
    # within the minute.
    t = time()-t0
    hours, remainder = divmod(int(t), 3600)
    minutes, seconds = divmod(remainder, 60)
    print("Analyzed {} diffraction patterns in {}h {}m {}s".format(datacube.R_N, hours,
                                                                   minutes, seconds))

    return peaks


In [None]:
# Connect to the running ipyparallel cluster and get a load-balanced view of it.
cluster_id = "cori_20402773"
c = ipp.Client(cluster_id=cluster_id)
lbv = c.load_balanced_view()

#### All DPs

In [None]:
%time
# Get peaks

corrPower = 0.8
sigma = 2
edgeBoundary = 20
maxNumPeaks = 70
minPeakSpacing = 50
minRelativeIntensity = 0.001
verbose = False

peaks = _find_Bragg_disks(dc, probe_kernel.data2D,
                         corrPower=corrPower,
                         sigma=sigma,
                         edgeBoundary=edgeBoundary,
                         minRelativeIntensity=minRelativeIntensity,
                         minPeakSpacing=minPeakSpacing,
                         maxNumPeaks=maxNumPeaks,
                         verbose=verbose,
                         view=lbv)

In [None]:
# Show the detected peaks for three scan positions of the full result

Rxs = (20, 31, 18)
Rys = (25, 31, 10)
power = 0.3
size_scale_factor = 500       # Set to zero to make all points the same size

fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=(12, 12))
dp_axes = (ax12, ax21, ax22)
dp_colors = ('r', 'g', 'b')

# Top-left panel: average over the last two axes, with the three positions marked.
ax11.matshow(np.average(dc.data4D, axis=(2, 3)))
ax11.scatter(Rys, Rxs, color=dp_colors)

# Remaining panels: one diffraction pattern each, raised to `power` for display.
for ax, rx, ry in zip(dp_axes, Rxs, Rys):
    ax.matshow(dc.data4D[rx, ry, :, :]**power)

# Fetch the PointList stored at each of the three scan positions.
peaks0, peaks1, peaks2 = [peaks.get_pointlist(rx, ry) for rx, ry in zip(Rxs, Rys)]

for ax, pointlist, marker_color in zip(dp_axes, (peaks0, peaks1, peaks2), dp_colors):
    points = pointlist.data
    if size_scale_factor == 0:
        ax.scatter(points['qy'], points['qx'], color=marker_color)
    else:
        # Marker area scales with intensity relative to the brightest peak.
        marker_sizes = size_scale_factor*points['intensity']/np.max(points['intensity'])
        ax.scatter(points['qy'], points['qx'], color=marker_color, s=marker_sizes)

for ax in (ax11,) + dp_axes:
    ax.axis('off')
plt.show()

#### Apply post-detection thresholding

In [None]:
# Apply new peak-count, peak-spacing and minimum-relative-intensity thresholds
# to the already detected peaks

maxNumPeaks = 20
minPeakSpacing = 50
minRelativeIntensity = 0.01

threshold_kwargs = dict(
    minRelativeIntensity=minRelativeIntensity,
    minPeakSpacing=minPeakSpacing,
    maxNumPeaks=maxNumPeaks,
)

# Threshold a copy named 'Braggpeaks' rather than `peaks` itself.
peaks_thresh = peaks.copy(name='Braggpeaks')
peaks_thresh = threshold_Braggpeaks(peaks_thresh, **threshold_kwargs)

In [None]:
# Show the thresholded peaks for three scan positions

Rxs = (20, 31, 18)
Rys = (25, 31, 10)
power = 0.3
size_scale_factor = 500       # Set to zero to make all points the same size

fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=(12, 12))
dp_axes = (ax12, ax21, ax22)
dp_colors = ('r', 'g', 'b')

# Top-left panel: average over the last two axes, with the three positions marked.
ax11.matshow(np.average(dc.data4D, axis=(2, 3)))
ax11.scatter(Rys, Rxs, color=dp_colors)

# Remaining panels: one diffraction pattern each, raised to `power` for display.
for ax, rx, ry in zip(dp_axes, Rxs, Rys):
    ax.matshow(dc.data4D[rx, ry, :, :]**power)

# Fetch the thresholded PointList stored at each of the three scan positions.
peaks0, peaks1, peaks2 = [peaks_thresh.get_pointlist(rx, ry) for rx, ry in zip(Rxs, Rys)]

for ax, pointlist, marker_color in zip(dp_axes, (peaks0, peaks1, peaks2), dp_colors):
    points = pointlist.data
    if size_scale_factor == 0:
        ax.scatter(points['qy'], points['qx'], color=marker_color)
    else:
        # Marker area scales with intensity relative to the brightest peak.
        marker_sizes = size_scale_factor*points['intensity']/np.max(points['intensity'])
        ax.scatter(points['qy'], points['qx'], color=marker_color, s=marker_sizes)

for ax in (ax11,) + dp_axes:
    ax.axis('off')
plt.show()