# Code to run basic non-rigid registration on images

To run this code, you must first install SimpleITK. On a Diamond Linux machine this must be a user install, because users are not granted access for a system-wide install.

To install SimpleITK, run: `conda install -c simpleitk simpleitk --use-local`

First, import all necessary packages.

In [None]:
import numpy as np
import hyperspy.api as hs
import SimpleITK as sitk

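Define the function for non-rigid registration. It uses the SimpleITK demons algorithm to register every frame to the average of the stack, and repeats this until the frames stop moving or `max_it` iterations have been run.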

In [None]:
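def nonrigid(im_stack, demons_it=20, filter_size=5.0, max_it=3):
    """Non-rigidly register a series of images to their average.

    Input
    -----
    im_stack: 3D numpy array of data, shape (frames, rows, columns)
    demons_it: int
        The number of iterations for the demons algorithm to perform.
    filter_size: float
        Standard deviation (in pixels) of the Gaussian used to smooth the
        displacement field.
    max_it: int
        The number of iterations to apply the full non-rigid alignment algorithm.

    Output
    ------
    3D numpy array (float32) of registered frames, same shape as im_stack
    """
    demons = sitk.DemonsRegistrationFilter()
    demons.SetNumberOfIterations(demons_it)
    # Standard deviation for Gaussian smoothing of displacement field
    demons.SetStandardDeviations(filter_size)

    im_stack = np.asarray(im_stack, dtype=np.float32)

    for j in range(max_it):
        # Reference image: average of the current stack
        av_im = sitk.GetImageFromArray(im_stack.mean(axis=0))

        out_stack = []
        max_disp = 0.0

        for frame in im_stack:
            moving = sitk.GetImageFromArray(frame)

            # Displacement field defined on the grid of the average image
            displacementField = demons.Execute(av_im, moving)

            # Read the field before building the transform, which takes ownership of it
            dispfield = sitk.GetArrayFromImage(displacementField)
            # Largest displacement length (in pixels) seen so far in this iteration
            max_disp = max(max_disp, np.sqrt((dispfield ** 2).sum(axis=-1)).max())

            outTx = sitk.DisplacementFieldTransform(displacementField)

            # Warp the frame onto the grid of the average image
            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(av_im)
            resampler.SetInterpolator(sitk.sitkLinear)
            # Value given to pixels that map from outside the frame
            resampler.SetDefaultPixelValue(100)
            resampler.SetTransform(outTx)

            out_stack.append(sitk.GetArrayFromImage(resampler.Execute(moving)))

        im_stack = np.array(out_stack)
        print("Iteration " + str(j + 1) + ": max displacement " + str(max_disp))

        # Stop once no pixel in any frame moved by 0.3 pixels or more
        if max_disp < 0.3:
            print("NRR stopped after " + str(j + 1) + " iterations.")
            break

    return im_stack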
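The stopping test in `nonrigid` compares the largest displacement against 0.3, which is in pixels because the images keep SimpleITK's default unit spacing. `sitk.GetArrayFromImage` returns a 2D displacement field as an array of shape `(rows, columns, 2)`, so the length of each pixel's displacement is the norm over the last axis. A minimal NumPy-only sketch of that check, using the hypothetical helpers `max_displacement` and `has_converged` on synthetic fields, might look like:

```python
import numpy as np


def max_displacement(fields):
    """Largest displacement length (in pixels) over a list of vector fields.

    Each field is assumed to have shape (rows, columns, 2), as returned by
    sitk.GetArrayFromImage for a 2D displacement field.
    """
    return max(float(np.linalg.norm(f, axis=-1).max()) for f in fields)


def has_converged(fields, tol=0.3):
    """Hypothetical helper: True once every pixel of every frame moves < tol."""
    return max_displacement(fields) < tol


# Two synthetic 4x4 fields: one almost still, one with a single 0.5 px shift
still = np.full((4, 4, 2), 0.1)
moved = np.zeros((4, 4, 2))
moved[2, 1] = [0.0, 0.5]

print(max_displacement([still, moved]))  # 0.5
print(has_converged([still, moved]))     # False
print(has_converged([still]))            # True
```

Taking the maximum over all frames, rather than over the last frame only, means a single badly registered frame is enough to keep the iteration going.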

Load the multi-frame file using HyperSpy.

In [None]:
filename = ""  # Insert file location here
images = hs.load(filename,stack=True)
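If no data file is at hand, a synthetic stack can be used to try the workflow. The sketch below is an assumption for testing only, not part of the original notebook: the function `synthetic_stack` and all of its parameters are hypothetical. It builds a square lattice of Gaussian peaks and applies a different smooth sinusoidal warp to each frame, which is the kind of non-rigid distortion this notebook is meant to remove.

```python
import numpy as np


def synthetic_stack(n_frames=10, size=128, spacing=8.0, warp=1.0,
                    sigma=1.5, noise=0.05, seed=0):
    """Hypothetical test data: a square lattice of Gaussian peaks, distorted by
    a different smooth sinusoidal warp in each frame, plus Gaussian noise.

    Returns a float32 array of shape (n_frames, size, size).
    """
    rng = np.random.default_rng(seed)
    y, x = np.mgrid[0:size, 0:size].astype(np.float64)
    frames = []
    for _ in range(n_frames):
        # Smooth, frame-dependent displacement of up to `warp` pixels
        phase = rng.uniform(0.0, 2.0 * np.pi)
        u = x + warp * np.sin(2.0 * np.pi * y / size + phase)
        v = y + warp * np.cos(2.0 * np.pi * x / size + phase)
        # Offset (in pixels) from each warped coordinate to its nearest lattice site
        du = u - spacing * np.round(u / spacing)
        dv = v - spacing * np.round(v / spacing)
        peaks = np.exp(-(du ** 2 + dv ** 2) / (2.0 * sigma ** 2))
        frames.append(peaks + rng.normal(0.0, noise, peaks.shape))
    return np.asarray(frames, dtype=np.float32)


stack = synthetic_stack()
print(stack.shape)  # (10, 128, 128)
```

In the notebook, such a stack could stand in for the loaded file via `images = hs.signals.Signal2D(stack)`.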

In [None]:
images.plot()

Before the non-rigid registration code can be used, the image series must be rigidly aligned. HyperSpy has its own rigid registration function, used below, but it does not always work (particularly for highly periodic data). Alternatively, the rigid alignment can be done in other software before using this notebook.

In [None]:
images.align2D()
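As an illustration of what a rigid (translation-only) alignment does, here is a minimal phase-correlation sketch in NumPy. It is not HyperSpy's `align2D` implementation, and the names `estimate_shift` and `rigid_align` are hypothetical. It finds whole-pixel shifts only and wraps frames around at the edges with `np.roll`; a practical tool would typically add sub-pixel refinement and non-periodic edge handling.

```python
import numpy as np


def estimate_shift(reference, image):
    """Integer (row, col) shift that, passed to np.roll, moves image onto reference.

    Phase correlation: the inverse FFT of the normalised cross-power
    spectrum peaks at the translation between the two images.
    """
    cross = np.fft.fft2(reference) * np.conj(np.fft.fft2(image))
    cross /= np.abs(cross) + 1e-12  # keep only the phase
    corr = np.fft.ifft2(cross).real
    peak = np.array(np.unravel_index(np.argmax(corr), corr.shape))
    # Indices past the midpoint wrap around to negative shifts
    shape = np.array(corr.shape)
    wrapped = peak > shape // 2
    peak[wrapped] -= shape[wrapped]
    return tuple(int(p) for p in peak)


def rigid_align(stack):
    """Shift every frame onto the first one (whole pixels, periodic edges)."""
    reference = stack[0]
    return np.array([np.roll(frame, estimate_shift(reference, frame), axis=(0, 1))
                     for frame in stack])


rng = np.random.default_rng(0)
ref = rng.random((64, 64))
moved = np.roll(ref, (3, -5), axis=(0, 1))
print(estimate_shift(ref, moved))  # (-3, 5)
```

For strongly periodic images the correlation has several near-equal peaks (one per lattice translation), which is one reason translation-only alignment can fail on such data.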

Once the series is rigidly aligned, the non-rigid registration can be run (see the docstring above for more information on the parameters).

In [None]:
images_nrr = nonrigid(images.data)

The aligned series can then be summed and processed further, for example with Atomap.

In [None]:
# Wrap the registered frames as a HyperSpy signal and sum the non-rigidly registered stack
images_nrr = hs.signals.Signal2D(images_nrr)
image_sum = images_nrr.sum()
image_sum.plot()