# 0. Imports and helper functions

In [1]:
## import libraries
import os

import tomopy
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation
import skimage
from skimage import io
import tifffile as tif
from skimage import transform as tf
import pandas

In [2]:
def Gauss_blur(I, sigma):
    """Gaussian-blur every 2-D slice ``I[i, :, :]`` of a stack with OpenCV.

    Parameters
    ----------
    I : ndarray
        Stack indexed as ``I[i, :, :]``; each slice is blurred independently.
    sigma : int or float
        Gaussian standard deviation, used for both image axes.

    Returns
    -------
    ndarray
        Blurred copy of ``I`` (same shape and dtype); ``I`` is not modified.
    """
    # Kernel roughly 3*sigma wide, forced odd because cv2.GaussianBlur requires
    # odd kernel sizes (identical to the previous formula for integer sigma).
    ksize = int(np.ceil(sigma * 3)) | 1
    output = np.copy(I)
    for i in range(I.shape[0]):
        # Blur the argument `I` (the previous version read a global `data`).
        output[i, :, :] = cv2.GaussianBlur(I[i, :, :], (ksize, ksize), sigma)
    return output
    
def normalization_both(I0, satuation):
    """Rescale a stack to [0, 1], saturating extreme values at both ends.

    The data are flattened and sorted in descending order. The value at
    index ``ceil(N * satuation)`` becomes the upper limit (mapped to 1). The
    value at index ``ceil(N * (1 - satuation))`` becomes the lower limit
    (mapped to 0). Everything outside that range is clipped.

    Parameters
    ----------
    I0 : ndarray
        Input array of any shape. It is not modified.
    satuation : float
        Fraction of values saturated at each end (e.g. 0.001), in [0, 1).
        The misspelled name is kept for backward compatibility.

    Returns
    -------
    ndarray
        Normalized copy of ``I0`` with values in [0, 1]. Floating-point
        inputs keep their dtype; integer inputs are returned as float64.
        If both limits coincide (e.g. a constant image), an all-zero array
        is returned instead of NaNs.
    """
    I0 = np.asarray(I0)
    if not np.issubdtype(I0.dtype, np.floating):
        # Integer (especially unsigned) subtraction below would wrap around.
        I0 = I0.astype(np.float64)
    if I0.size == 0:
        return I0.copy()

    values = np.sort(I0, axis=None)[::-1]  # flattened, descending
    n = values.size
    # Clamp the indices so satuation == 0 selects max/min instead of
    # indexing one past the end.
    high = values[min(int(np.ceil(n * satuation)), n - 1)]
    low = values[min(int(np.ceil(n * (1 - satuation))), n - 1)]
    if high == low:
        return np.zeros_like(I0)
    return np.clip((I0 - low) / (high - low), 0, 1)

# A patched version of tomopy's align_joint that Gaussian-filters both the
# measured and the re-projected images before the cross-correlation.
# This patch is based on tomopy 1.14.4.

def align_joint_modify(
        prj, ang, blr, fdir='.', iters=10, pad=(0, 0),
        blur=True, center=None, algorithm='sirt',
        upsample_factor=10, rin=0.5, rout=0.8,
        save=False, debug=True):
    """Jointly reconstruct and align a projection stack.

    Each iteration does three steps:

    * It reconstructs ``prj``, seeded with the previous reconstruction.
    * It re-projects that reconstruction.
    * It registers every measured projection ``prj[m]`` against its simulated
      counterpart with sub-pixel phase cross-correlation.

    Both images are blurred with ``cv2.GaussianBlur`` before the correlation.
    The measured projection is then warped by the estimated translation.

    Parameters
    ----------
    prj : ndarray
        Projection stack indexed as ``prj[m]`` per projection (presumably
        (n_angles, rows, cols) as in tomopy -- TODO confirm). Not modified.
    ang : ndarray
        Projection angles, passed to tomopy ``recon`` / ``project``.
    blr : int
        Gaussian sigma applied before the cross-correlation. The kernel size
        is ``5 * blr``, rounded up to the next odd integer because OpenCV
        requires odd kernel sizes.
    fdir, save :
        Unused; kept for signature compatibility with tomopy's align_joint.
    iters : int
        Number of reconstruct / re-project / register iterations.
    pad : (int, int)
        Zero padding: ``pad[0]`` on both sides of axis 2 and ``pad[1]`` on
        both sides of axis 1.
    blur : bool
        If True, apply tomopy ``blur_edges(..., rin, rout)`` to the measured
        and simulated projections before registration.
    center, algorithm :
        Passed to tomopy ``recon`` (``center`` also to ``project``).
    upsample_factor : int
        Sub-pixel precision passed to ``phase_cross_correlation``.
    debug : bool
        If True, print the error norm after every iteration.

    Returns
    -------
    prj : ndarray
        Aligned (and padded) projections, multiplied back by the scale factor.
    sx, sy : ndarray
        Accumulated shift per projection along axis 0 / axis 1 of each
        projection image, as reported by ``phase_cross_correlation``.
    conv : ndarray
        For each iteration, the L2 norm of the per-projection shift magnitudes.
    _sim : ndarray or None
        Simulated projections of the last iteration (edge-blurred if ``blur``);
        None when ``iters == 0``.
    rec : ndarray
        Reconstruction from the last iteration.
    """
    from tomopy.recon.algorithm import recon
    from tomopy.sim.project import project
    from tomopy.prep.alignment import scale, blur_edges
    from skimage.registration import phase_cross_correlation
    from skimage import transform as tf
    from cv2 import GaussianBlur

    # Work on a copy so the caller's array is never modified (scale() may
    # operate in place). Scaling is needed for skimage float operations.
    prj, scl = scale(np.copy(prj))

    n_proj = prj.shape[0]
    sx = np.zeros(n_proj)
    sy = np.zeros(n_proj)
    conv = np.zeros(iters)

    # Pad images.
    npad = ((0, 0), (pad[1], pad[1]), (pad[0], pad[0]))
    prj = np.pad(prj, npad, mode='constant', constant_values=0)

    # Initialization of reconstruction.
    rec = 1e-12 * np.ones((prj.shape[1], prj.shape[2], prj.shape[2]))

    extra_kwargs = {}
    if algorithm != 'gridrec':
        extra_kwargs['num_iter'] = 1

    # cv2.GaussianBlur rejects even kernel sizes; 5*blr is even for even blr.
    ksize = int(5 * blr) | 1
    kernel = (ksize, ksize)

    _sim = None
    for n in range(iters):
        # Reconstruct, seeded with the previous iteration's volume.
        rec = recon(prj, ang, center=center, algorithm=algorithm,
                    init_recon=rec, **extra_kwargs)

        # Re-project data and obtain simulated data.
        sim = project(rec, ang, center=center, pad=False)

        # Blur edges.
        if blur:
            _prj = blur_edges(prj, rin, rout)
            _sim = blur_edges(sim, rin, rout)
        else:
            _prj = prj
            _sim = sim

        # Shift magnitude of every projection in this iteration.
        err = np.zeros(n_proj)

        for m in range(n_proj):
            # Register the low-pass filtered projection in sub-pixel precision.
            shift, _, _ = phase_cross_correlation(
                GaussianBlur(_prj[m], kernel, blr),
                GaussianBlur(_sim[m], kernel, blr),
                upsample_factor=upsample_factor, normalization=None)
            err[m] = np.hypot(shift[0], shift[1])
            sx[m] += shift[0]
            sy[m] += shift[1]

            # Register current image with the simulated one; the (row, col)
            # shift is reordered to the (x, y) translation skimage expects.
            tform = tf.SimilarityTransform(translation=(shift[1], shift[0]))
            prj[m] = tf.warp(prj[m], tform, order=5)

        # Record convergence regardless of `debug` so the returned
        # history is always meaningful.
        conv[n] = np.linalg.norm(err)
        if debug:
            print('iter=' + str(n) + ', err=' + str(conv[n]))

    # Re-normalize data
    prj *= scl
    return prj, sx, sy, conv, _sim, rec


# 1. Load the tilt series and tilt angles

In [3]:
## Load the coarse-aligned tilt series (rotation is already corrected) and
## its tilt angles. Change both file names below for a new dataset.
STACK_PATH = 'input/_20240612_124708_log_norm.tif'
ANGLE_PATH = 'input/tiltAnglesSeries__20240612_124708.csv'

tilt_stack = io.imread(STACK_PATH)  # optional crop: [:, 64:-64, 64:-64]
print(tilt_stack.shape)
print(tilt_stack.dtype)
print(tilt_stack.max())
print(tilt_stack.min())

# The angle CSV has no header row; column 1 is taken as the per-frame tilt angle.
tilt_angle = pandas.read_csv(ANGLE_PATH, header=None).to_numpy().squeeze()[:, 1]

# Frames and angles are paired by index later on, so the counts must match.
assert tilt_angle.shape[0] == tilt_stack.shape[0], (
    f'{tilt_angle.shape[0]} angles for {tilt_stack.shape[0]} frames')

(1002, 512, 512)
float32
2.529988
-1.6859161


# 2. Split the tilt angles into individual tilt series

In [4]:
## Split the concatenated angle list into individual tilt series.
## Set tilt_range to the number of tilts per series in
## _20240612_124708_log_norm. Consecutive series share their boundary frame
## (each series starts on the last frame of the previous one), hence the
## step of tilt_range - 1.
tilt_range = 41
if tilt_range < 2:
    # A step of tilt_range - 1 == 0 would never advance.
    raise ValueError('tilt_range must be at least 2')

series_starts = range(0, tilt_angle.shape[0] - tilt_range + 1, tilt_range - 1)
angle_matrix = np.array([tilt_angle[s:s + tilt_range] for s in series_starts])
lut = np.array([(s, s + tilt_range) for s in series_starts])  # [start, stop) frame ranges
assert len(lut) > 0, 'no complete tilt series found; check tilt_range'
print(angle_matrix)
print(lut)

[[-45.  -42.5 -40.  ...  50.   52.5  55. ]
 [ 55.   52.5  50.  ... -40.  -42.5 -45. ]
 [-45.  -42.5 -40.  ...  50.   52.5  55. ]
 ...
 [-45.  -42.5 -40.  ...  50.   52.5  55. ]
 [ 55.   52.5  50.  ... -40.  -42.5 -45. ]
 [-45.  -42.5 -40.  ...  50.   52.5  55. ]]
[[   0   41]
 [  40   81]
 [  80  121]
 [ 120  161]
 [ 160  201]
 [ 200  241]
 [ 240  281]
 [ 280  321]
 [ 320  361]
 [ 360  401]
 [ 400  441]
 [ 440  481]
 [ 480  521]
 [ 520  561]
 [ 560  601]
 [ 600  641]
 [ 640  681]
 [ 680  721]
 [ 720  761]
 [ 760  801]
 [ 800  841]
 [ 840  881]
 [ 880  921]
 [ 920  961]
 [ 960 1001]]


# 3. Iterative alignment

In [None]:
## Align every tilt series and write its aligned projections to one TIFF
## per series. Change OUTPUT_DIR for a new dataset; it is created if missing.
OUTPUT_DIR = 'output/align/_20240612_124708/'
sigma = 1               # Gaussian sigma passed to align_joint_modify as `blr`
ROTATION_CENTER = 256   # presumably the mid-column of 512-px-wide frames; update if cropping
N_ITERS = 20

os.makedirs(OUTPUT_DIR, exist_ok=True)
for i, (start, stop) in enumerate(lut):
    data = np.copy(tilt_stack[start:stop, :, :])
    angle = angle_matrix[i, :] / 180 * np.pi  # degrees -> radians
    print('Now aligning tilt series: ' + str(i + 1) + '/' + str(len(angle_matrix)))
    data = normalization_both(data, 0.001)
    proj_aligned = align_joint_modify(data, angle, sigma, center=ROTATION_CENTER,
                                      algorithm='sirt', iters=N_ITERS)
    tif.imwrite(os.path.join(OUTPUT_DIR, '{:03d}'.format(i) + '.tif'), proj_aligned[0])