## Image registration: minimum working example
This notebook provides a sample implementation of the image registration approach described in the paper "Image registration of low signal-to-noise cryo-STEM data" (doi: XXXXXX).  The code here is the minimum working example.  For more in-depth implementations (in particular, more careful user-defined Fourier masking and outlier detection), see the two additional sample notebooks bundled with this package.  For a brief introduction to the approach implemented here, see the README.md file.  For more details, see the original paper.

If this code was a useful resource for your work, please consider citing the associated publication.

In [4]:
# Import global libraries and functions
import numpy as np
import matplotlib.pyplot as plt
from os import listdir
from os.path import splitext, basename
from time import time
from tifffile import imread, imsave
%matplotlib

# Import local libraries
import rigidregistration

Using matplotlib backend: Qt4Agg


In [5]:
# Load data.

# Here we use image stacks saved in .tif format, with 16-bit depth pixels.
# To load data saved in any of the image formats produced by standard electron microscopy acquisition software (.dm3, .ser, etc.)
# into numpy arrays, we recommend the HyperSpy package (hyperspy.org).

f="/Users/Ben/Work/Data/20170711_Nb3Cl8/tiffs/11_unsavedStack_ZoneAxix3.tif"
stack=imread(f)
stack=np.rollaxis(stack,0,3)       # Restructure 3D array such that final axis iterates over images
stack=stack[:,:,:]/float(2**16)    # For best performance, data should be normalized between 0 and 1; alter this line for images of a different bit depth
print("Analyzing {}.".format(f))

Analyzing /Users/Ben/Work/Data/20170711_Nb3Cl8/tiffs/11_unsavedStack_ZoneAxix3.tif.
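
The loading cell above hard-codes a 16-bit normalization. As a minimal sketch of one way to derive the divisor from the array's dtype instead, assuming unsigned integer data, consider the following. `normalize_stack` is a hypothetical helper written for illustration and is not part of rigidregistration; the synthetic array stands in for the output of `imread`.

```python
import numpy as np

def normalize_stack(stack):
    """Scale an unsigned-integer image stack to floats in [0, 1).

    Hypothetical helper for illustration; not part of rigidregistration.
    """
    if not np.issubdtype(stack.dtype, np.unsignedinteger):
        raise TypeError("expected an unsigned integer stack, got {}".format(stack.dtype))
    n_levels = 2 ** (8 * stack.dtype.itemsize)   # e.g. 2**16 for uint16 data
    return stack.astype(np.float64) / n_levels

# Synthetic 16-bit stack with shape (frames, rows, cols)
raw = np.random.default_rng(0).integers(0, 2**16, size=(5, 8, 8), dtype=np.uint16)
demo = np.moveaxis(raw, 0, -1)   # final axis iterates over images
demo = normalize_stack(demo)
print(demo.shape, demo.dtype)
```

Dividing by the number of grey levels rather than by the observed maximum keeps the scaling identical across datasets acquired at the same bit depth.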


In [6]:
# Inspect data to ensure it has been correctly loaded.

for i in range(24,28):
    fig,(ax1,ax2)=plt.subplots(1,2)
    ax1.matshow(stack[:,:,i],cmap='gray')
    ax2.matshow(np.log(np.abs(np.fft.fftshift(np.fft.fft2(stack[:,:,i])))),cmap='gray',vmin=np.average(np.log(np.abs(np.fft.fft2(stack[:,:,i]))))) 
    plt.show()

In [7]:
# Instantiate imstack object and acquire FFTs.

s=rigidregistration.stackregistration.imstack(stack[:,:,:27])
s.getFFTs()

### Fourier masking
Select a Fourier mask.  

The Fourier mask controls which information is used to align the image pairs.  The smaller the Fourier mask, the more the low-frequency information, i.e. slow variations in the image, is used to determine the correct alignments.  A smaller mask may be useful in cutting out noise, or in avoiding errors such as unit-cell misalignments in crystalline samples.  A mask that is too small may discard important structural information and lead to an imprecise alignment.  Generally, a smaller mask improves the accuracy of the registration, but degrades the precision.  For more information on Fourier mask selection, see the README.md file, or the associated paper.

The parameters used here for Fourier masks are as follows:

    mask:   A string, indicating the functional form i.e. shape of the mask.  
            Options are: "bandpass", "lowpass", "hann", "hamming", "blackman", "gaussian", "none".
    n:      A float, controlling the mask size.  Features smaller than ~n pixels will be smoothed out 
            and ignored during image correlation.

In all cases, the parameter n sets the mask cutoff frequency.
For data with higher SNR, choosing a mask with n at the information limit is frequently sufficient.
For low-SNR data, choosing a mask with a cutoff frequency near the primary Bragg peaks is often preferable: this heavily weights low-frequency information to avoid unit-cell hops, but ideally retains just enough lattice information to 'lock in' to the lattice.

The mask options listed above are the apodization functions supported by the makeFourierMask() method.
For lattices lacking high rotational symmetry, an anisotropic mask is generally preferable to avoid overweighting one lattice direction.  The makeFourierMask_eg() method creates an elliptical Gaussian mask, with parameters n1, n2, and theta, corresponding to the cutoff frequencies along the two primary axes and the mask tilt angle.  Note that the cells below pass theta in radians (converting from degrees with np.radians where needed).
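
To make the role of n concrete, here is a minimal sketch of a radial Hann-shaped low-pass mask built directly with NumPy. It is not the library's implementation: the function name `hann_lowpass_mask` and the convention that the mask reaches zero at a spatial frequency of 1/n cycles per pixel are assumptions made for illustration.

```python
import numpy as np

def hann_lowpass_mask(shape, n):
    """Radial Hann-shaped low-pass mask in unshifted Fourier space.

    Assumed convention (illustrative only): the mask falls to zero at a
    spatial frequency of 1/n cycles per pixel, so features smaller than
    roughly n pixels are suppressed.
    """
    ky = np.fft.fftfreq(shape[0])[:, None]   # cycles per pixel along rows
    kx = np.fft.fftfreq(shape[1])[None, :]   # cycles per pixel along columns
    k = np.hypot(kx, ky)
    k_cut = 1.0 / n
    mask = 0.5 * (1.0 + np.cos(np.pi * k / k_cut))   # 1 at k=0, 0 at k=k_cut
    mask[k > k_cut] = 0.0
    return mask

mask_n4 = hann_lowpass_mask((64, 64), n=4)
mask_n8 = hann_lowpass_mask((64, 64), n=8)
print(mask_n4.sum() > mask_n8.sum())   # larger n keeps fewer frequencies
```

Under this convention a larger n gives a smaller mask, which is the regime the text describes as favoring slow image variations over fine lattice detail.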

In [7]:
# Mask shape options

masktypes=["bandpass","lowpass","hann","hamming","blackman","gaussian","none"]
n=4

i,j = 5,9   # Choose image pair
for masktype in masktypes:
    s.makeFourierMask(mask=masktype,n=n)
    s.show_Fourier_mask(i=i,j=j)

In [8]:
# Vary cutoff frequencies

masktype="hann"

i,j = 5,9   # Choose image pair
for n in np.arange(2,12,2):
    s.makeFourierMask(mask=masktype,n=n)
    s.show_Fourier_mask(i=i,j=j)

In [9]:
# Make elliptical Gaussian mask

n1=4
n2=2
theta=20

s.makeFourierMask_eg(n1=n1,n2=n2,theta=np.radians(theta))
s.show_Fourier_mask(i=i,j=j)

In [10]:
# Vary elliptical Gaussian angle

n1=4
n2=2
thetas=[90,75,60,45,30,15,0]

i,j = 5,9   # Choose image pair
for theta in thetas:
    s.makeFourierMask_eg(n1=n1,n2=n2,theta=np.radians(theta))
    s.show_Fourier_mask(i=i,j=j)
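
For intuition about what n1, n2, and theta control, the sketch below builds an elliptical Gaussian mask from scratch. It is an illustration under stated assumptions, not the code behind makeFourierMask_eg(): the function name, the choice of 1/n1 and 1/n2 as Gaussian widths in cycles per pixel, and the rotation direction are all assumed here.

```python
import numpy as np

def elliptical_gaussian_mask(shape, n1, n2, theta):
    """Elliptical Gaussian mask in unshifted Fourier space.

    Assumptions for this sketch: theta is in radians, and the Gaussian
    standard deviations along the two rotated axes are 1/n1 and 1/n2
    cycles per pixel.
    """
    ky = np.fft.fftfreq(shape[0])[:, None]
    kx = np.fft.fftfreq(shape[1])[None, :]
    # Rotate the frequency coordinates into the frame of the ellipse
    k1 = kx * np.cos(theta) + ky * np.sin(theta)
    k2 = -kx * np.sin(theta) + ky * np.cos(theta)
    return np.exp(-0.5 * ((k1 * n1) ** 2 + (k2 * n2) ** 2))

mask_0 = elliptical_gaussian_mask((64, 64), n1=4, n2=2, theta=0.0)
mask_90 = elliptical_gaussian_mask((64, 64), n1=4, n2=2, theta=np.radians(90))
print(mask_0[0, 5] < mask_0[5, 0])   # narrower along the n1 axis when n1 > n2
```

Rotating by 90 degrees swaps the narrow and wide directions, which is the effect the loop over thetas visualizes.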

### Calculate all cross correlations
Calculate the relative shifts between all pairs of images.

findMaxima can be "pixel" or "gf".  "pixel" takes the shift between two images to be the position of the maximum pixel in their cross correlation.  "gf" fits Gaussians to the cross correlation to identify the shift, which both gives subpixel resolution and handles some sampling errors.

num_peaks sets how many maxima to fit Gaussians to; typically 3-5 are sufficient to handle sampling problems.

sigma_guess sets the initial guess for the standard deviation of the Gaussian fits, in pixels.  It may be estimated quickly from the peak widths in the cross correlations or from the width of atomic columns in the raw data.

correlationType sets the type of correlation used to determine image shifts. "cc", "mc", or "pc" set the correlation to the cross correlation, mutual correlation, or phase correlation, respectively.  **Note: only 'cc' is currently supported.**
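
The following sketch illustrates the two peak-finding ideas on synthetic data. It is a simplified stand-in, not the library's code: the function names are hypothetical, and the sub-pixel step uses a closed-form three-point Gaussian estimate along one axis rather than the 2D Gaussian fits to several candidate maxima described above.

```python
import numpy as np

def cross_correlation(im1, im2, mask=None):
    """Circular cross correlation via FFT, optionally weighted by a Fourier mask."""
    product = np.conj(np.fft.fft2(im1)) * np.fft.fft2(im2)
    if mask is not None:
        product = product * mask
    return np.real(np.fft.ifft2(product))

def pixel_shift(cc):
    """Shift (rows, cols) at the maximum pixel, wrapped to signed values."""
    peak = np.unravel_index(np.argmax(cc), cc.shape)
    return tuple(int(p - s) if p > s // 2 else int(p) for p, s in zip(peak, cc.shape))

def gaussian_offset(c_minus, c_0, c_plus):
    """Sub-pixel offset of a 1D Gaussian peak from three positive samples
    taken at positions -1, 0, +1 relative to the maximum pixel."""
    la, lb, lc = np.log(c_minus), np.log(c_0), np.log(c_plus)
    return 0.5 * (la - lc) / (la - 2.0 * lb + lc)

rng = np.random.default_rng(0)
im1 = rng.normal(size=(32, 32))
im2 = np.roll(im1, (3, -5), axis=(0, 1))   # im2 is im1 displaced by (3, -5) pixels
shift = pixel_shift(cross_correlation(im1, im2))

# Three samples of a Gaussian (sigma = 2) centered 0.3 pixels past the maximum pixel
samples = np.exp(-(np.array([-1.0, 0.0, 1.0]) - 0.3) ** 2 / (2 * 2.0 ** 2))
offset = gaussian_offset(*samples)
print(shift, offset)
```

The three-point estimate is exact only for a noise-free Gaussian with positive samples; fitting a full Gaussian over a window, as "gf" does, is more robust on noisy cross correlations.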

In [11]:
# Make elliptical Gaussian mask

n1=12
n2=9.106
theta=1.42

s.makeFourierMask_eg(n1=n1,n2=n2,theta=theta)
s.show_Fourier_mask(i=i,j=j)


In [12]:
findMaxima = 'gf'
s.setGaussianFitParams(num_peaks=5,sigma_guess=3,window_radius=4)

n1=12
n2=9.106
theta=1.4213702504683843
s.makeFourierMask_eg(n1=n1,n2=n2,theta=theta)

# Find all image shifts
t0=time()
s.findImageShifts(correlationType='cc',findMaxima=findMaxima)
t=time()-t0
print("Performed {} correlations in {} minutes {:.1f} seconds".format(s.nz*(s.nz-1)//2,int(t//60),t%60))

Correlating images 0 and 1
Correlating images 0 and 2
Correlating images 0 and 3
Correlating images 0 and 4
Correlating images 0 and 5
Correlating images 0 and 6
Correlating images 0 and 7
Correlating images 0 and 8
Correlating images 0 and 9
Correlating images 0 and 10
Correlating images 0 and 11
Correlating images 0 and 12
Correlating images 0 and 13
Correlating images 0 and 14
Correlating images 0 and 15
Correlating images 0 and 16
Correlating images 0 and 17
Correlating images 0 and 18
Correlating images 0 and 19
Correlating images 0 and 20
Correlating images 0 and 21
Correlating images 0 and 22
Correlating images 0 and 23
Correlating images 0 and 24
Correlating images 0 and 25
Correlating images 0 and 26
Correlating images 1 and 2
Correlating images 1 and 3
Correlating images 1 and 4
Correlating images 1 and 5
Correlating images 1 and 6
Correlating images 1 and 7
Correlating images 1 and 8
Correlating images 1 and 9
Correlating images 1 and 10
Correlating images 1 and 11
[output truncated]

Correlating images 16 and 18
Correlating images 16 and 19
Correlating images 16 and 20
Correlating images 16 and 21
Correlating images 16 and 22
Correlating images 16 and 23
Correlating images 16 and 24
Correlating images 16 and 25
Correlating images 16 and 26
Correlating images 17 and 18
Correlating images 17 and 19
Correlating images 17 and 20
Correlating images 17 and 21
Correlating images 17 and 22
Correlating images 17 and 23
Correlating images 17 and 24
Correlating images 17 and 25
Correlating images 17 and 26
Correlating images 18 and 19
Correlating images 18 and 20
Correlating images 18 and 21
Correlating images 18 and 22
Correlating images 18 and 23
Correlating images 18 and 24
Correlating images 18 and 25
Correlating images 18 and 26
Correlating images 19 and 20
Correlating images 19 and 21
Correlating images 19 and 22
Correlating images 19 and 23
Correlating images 19 and 24
Correlating images 19 and 25
Correlating images 19 and 26
Correlating images 20 and 21
[output truncated]

### Identify outliers in shift matrix
Show the shift matrix with the show_Rij() method.
If applicable, set the range of minimum and maximum images to include with the set_nz(nz_min,nz_max) method.
Find outliers with the get_outliers(method, *args) method.  The method parameter can be "NN" or "transitivity":

    "NN":            Identifies outliers using only nearest-neighbor elements of the shift matrix.
                     Requires an additional threshold parameter.
    "transitivity":  Identifies outliers as matrix elements which are inconsistent with physical stage
                     positions, which must obey additive transitivity.  Requires an additional threshold
                     parameter, and accepts an optional num_paths parameter.
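
Additive transitivity means that the shift from image i to image j should equal the shift from i to k plus the shift from k to j, for any intermediate image k. The sketch below shows one simple way to turn that into an outlier test. It is an illustration only: the function name, the sign convention X[i, j] = x_j - x_i, the use of every intermediate image, and the median criterion are assumptions, not the behavior of get_outliers().

```python
import numpy as np

def transitivity_outliers(X, Y, threshold):
    """Flag pairwise shifts inconsistent with R[i, j] = R[i, k] + R[k, j].

    Returns a boolean array that is True where the element is judged
    reliable. Simplified sketch: uses every intermediate image k and
    compares the median path error against the threshold (in pixels).
    """
    # paths[i, k, j] = R[i, k] + R[k, j]
    paths_x = X[:, :, None] + X[None, :, :]
    paths_y = Y[:, :, None] + Y[None, :, :]
    err = np.hypot(paths_x - X[:, None, :], paths_y - Y[:, None, :])
    return np.median(err, axis=1) <= threshold

# Synthetic stage positions and the shift matrices they imply
rng = np.random.default_rng(1)
x, y = rng.normal(scale=5, size=10), rng.normal(scale=5, size=10)
X = x[None, :] - x[:, None]   # assumed convention: X[i, j] = x_j - x_i
Y = y[None, :] - y[:, None]

# Corrupt one image pair, as a unit-cell hop would
X[2, 7] += 20
X[7, 2] -= 20

good = transitivity_outliers(X, Y, threshold=1.0)
print(np.argwhere(~good))
```

A single bad element contaminates only a few of the paths used to test its neighbors, so the median leaves the neighbors flagged as reliable while the bad element itself fails almost every path.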

In [13]:
# Show Xij and Yij matrices
s.show_Rij()

In [18]:
# Create mask, defining unusable data points

s.set_nz(0,27)               # Set min/max usable image indices
s.get_outliers("transitivity",10,10)       # Set outlier threshold
s.set_bad_images([])         # Set bad images

s.show_Rij(mask=True)

In [19]:
# Optionally, add or remove points from the outlier mask that were not correctly identified

correct_pairs=[[11,12],[11,13]]
incorrect_pairs=[[1,9]]

for pair in correct_pairs:
    i,j=pair[0],pair[1]
    s.Rij_mask[i,j]=1
    s.Rij_mask[j,i]=1
for pair in incorrect_pairs:
    i,j=pair[0],pair[1]
    s.Rij_mask[i,j]=0
    s.Rij_mask[j,i]=0
    
s.show_Rij()

In [20]:
s.Rij_mask[:5,15]=0
s.Rij_mask[:3,18]=0
s.Rij_mask[:8,24:26]=0
s.Rij_mask[15,:5]=0
s.Rij_mask[18,:3]=0
s.Rij_mask[24:26,:8]=0

s.show_Rij()

### Remove outliers
The make_corrected_Rij() method determines the correct values for outliers in the shift matrices using transitivity, and the show_Rij_c() method shows the corrected shift matrices.

This step is optional: if it is not called manually, it is performed automatically when get_averaged_image() is run.
However, manually calling make_corrected_Rij() is recommended so that the final shift matrices can be visually inspected to ensure physical consistency.
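
As a rough picture of how transitivity can supply replacement values, the sketch below re-estimates each masked element from two-step paths through reliable elements. This is a simplified assumption about the idea, not the algorithm in make_corrected_Rij(): the function name, the sign convention, and the plain average over paths are all illustrative.

```python
import numpy as np

def correct_outliers(X, good):
    """Replace unreliable elements of X with the mean of X[i, k] + X[k, j]
    over all intermediate images k for which both legs are reliable.

    Elements with no usable path are left unchanged.
    """
    both_good = good[:, :, None] & good[None, :, :]   # indexed [i, k, j]
    paths = X[:, :, None] + X[None, :, :]             # indexed [i, k, j]
    n_paths = both_good.sum(axis=1)
    estimate = np.where(both_good, paths, 0.0).sum(axis=1) / np.maximum(n_paths, 1)
    return np.where(good | (n_paths == 0), X, estimate)

# Synthetic shift matrix with one corrupted image pair
rng = np.random.default_rng(2)
x = rng.normal(scale=5, size=10)
X_true = x[None, :] - x[:, None]   # assumed convention: X[i, j] = x_j - x_i
X = X_true.copy()
X[2, 7] += 20
X[7, 2] -= 20

# Reliability mask with the corrupted pair marked as an outlier
good = np.ones((10, 10), dtype=bool)
good[2, 7] = good[7, 2] = False

X_c = correct_outliers(X, good)
print(np.abs(X_c - X_true).max())
```

Inspecting the corrected matrix, as recommended above, is a check that the replacement values vary smoothly with their neighbors.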

In [21]:
s.make_corrected_Rij()
s.show_Rij_c()

### Calculate average image

In [22]:
# Create registered image stack and average

s.get_averaged_image()
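
Conceptually, this step translates each frame by its measured shift and averages the result. A minimal sketch using the Fourier shift theorem follows. The function names and the sign convention (each frame is displaced by its shift relative to a common reference, so the correction applies the negative shift) are assumptions for illustration and may not match the library.

```python
import numpy as np

def fourier_shift(image, dy, dx):
    """Translate an image by (dy, dx) pixels, with wraparound, using the
    Fourier shift theorem. Works for non-integer shifts as well."""
    ky = np.fft.fftfreq(image.shape[0])[:, None]
    kx = np.fft.fftfreq(image.shape[1])[None, :]
    phase = np.exp(-2j * np.pi * (ky * dy + kx * dx))
    return np.real(np.fft.ifft2(np.fft.fft2(image) * phase))

def average_registered(stack, shifts_y, shifts_x):
    """Undo each frame's displacement and average.

    stack has shape (rows, cols, frames); frame k is assumed to be
    displaced by (shifts_y[k], shifts_x[k]) relative to the reference.
    """
    registered = [fourier_shift(stack[:, :, k], -shifts_y[k], -shifts_x[k])
                  for k in range(stack.shape[2])]
    return np.mean(registered, axis=0)

# Synthetic drifting stack built from one base image
rng = np.random.default_rng(3)
base = rng.normal(size=(32, 32))
shifts_y = np.array([0, 2, -3, 5])
shifts_x = np.array([0, -1, 4, 2])
frames = np.stack([np.roll(base, (sy, sx), axis=(0, 1))
                   for sy, sx in zip(shifts_y, shifts_x)], axis=-1)

avg = average_registered(frames, shifts_y, shifts_x)
print(np.allclose(avg, base))
```

Because the translation wraps around, pixels near the image border mix data from opposite edges, which is why the edges of a registered average need separate handling.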

In [23]:
# Display final image

s.show()

In [24]:
# Display report of registration procedure

s.show_report()

In [25]:
# Save report of registration procedure

s.save_report("/Users/Ben/Desktop/test_report.pdf")

Note that the full field of view of the original data has been preserved in the final image for computational simplicity; however, data at the edges is not physically meaningful.  All shifts are stored as 1D arrays in the attributes shifts_x and shifts_y.  To obtain the limits of the physically meaningful data, use e.g. s.shifts_x.max(), s.shifts_x.min(), etc.
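
As a sketch of how those extrema could be turned into crop bounds, the hypothetical helper below computes the region covered by every frame. It assumes each frame was translated by its shift in pixels, with positive values moving the image toward higher indices; check the sign convention and the axis order against your own data before relying on it.

```python
import numpy as np

def valid_region(shape, shifts_y, shifts_x):
    """Slices selecting pixels covered by every frame after shifting.

    Hypothetical helper: assumes frame k was translated by
    (shifts_y[k], shifts_x[k]) pixels toward higher indices.
    """
    top = int(np.ceil(max(np.max(shifts_y), 0)))
    bottom = shape[0] + int(np.floor(min(np.min(shifts_y), 0)))
    left = int(np.ceil(max(np.max(shifts_x), 0)))
    right = shape[1] + int(np.floor(min(np.min(shifts_x), 0)))
    return slice(top, bottom), slice(left, right)

# Example shifts standing in for s.shifts_y and s.shifts_x
demo_shifts_y = np.array([-3.2, 0.0, 4.5])
demo_shifts_x = np.array([1.0, -2.0, 0.0])
rows, cols = valid_region((100, 120), demo_shifts_y, demo_shifts_x)

image = np.zeros((100, 120))
cropped = image[rows, cols]
print(cropped.shape)
```

Rounding outward with ceil and floor discards any row or column that is only partially covered by a sub-pixel shift.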