Title: FocusStackerDedrifter
AUTHOR: Caroline Berthebaud Cheung
DATE: 2022/05/23 (YYYY/MM/DD)

[] =  FocusStackerDedrifter(input_dir, output_dir_fs, output_dir_ds, output_dir, output_format, fs=True,
      image_range_fs=True, ds=True, image_range_ds=True, shift_data=True, pad=True, overwrite=True, parallel=True,
      crop=True)

    This script is used as a pipeline to run FocusStacker, Drifty_Shifty, and crop_aligned_images sequentially in
    order to merge slices of an image taken at different focal points into one focused image, to dedrift and align the
    images, and then to crop and retain only the parts of the images that are visible in every single frame. This
    script reads images from path='input_dir', focus stacks the images, and saves them into an output directory
    (output_dir_fs). It then dedrifts these images and saves them into another output directory (output_dir_ds). Since
    the dedrifting process will create a black border around the images, it may be removed using the
    crop_aligned_images function, which takes the images in the 'output_dir_ds' directory as input, removes the black
    border as well as parts of the images that are not in every single frame, and then saves the cropped images in the
    final output directory (output_dir). All of the edited images will be saved in the given output format
    (output_format, i.e. png or tif(f)).

    To be clear, the input and output directories for the three different processes (FocusStacker, Drifty_Shifty,
    crop_aligned_images) are as follows:
    1. FocusStacker input directory = input_dir
    2. FocusStacker output directory = output_dir_fs
    3. Drifty_Shifty input directory = output_dir_fs
    4. Drifty_Shifty output directory = output_dir_ds
    5. crop_aligned_images input directory = output_dir_ds
    6. crop_aligned_images output directory = output_dir
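
    The code in this notebook writes into the three output directories but does not create them. A minimal
    sketch of preparing them first, assuming they may not exist yet; `ensure_dirs` is a hypothetical helper and
    not part of the original script:

```python
import os


def ensure_dirs(*dirs):
    """Create each directory if it is missing and return the absolute paths."""
    absolute = [os.path.abspath(d) for d in dirs]
    for d in absolute:
        # exist_ok=True makes the call safe when the directory is already there
        os.makedirs(d, exist_ok=True)
    return absolute
```

    Absolute paths are worth using here because each stage changes the working directory to its input folder,
    so a relative output path would otherwise be resolved against that folder.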

    This pipeline takes 5 required inputs (input_dir, output_dir_fs, output_dir_ds, output_dir, output_format) and
    9 optional parameters (image_range_fs, fs, ds, image_range_ds, shift_data, pad, overwrite, parallel, crop) which
    are all True by default.
    
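    If you don't want to convert all of the images in your folder, then pass a tuple of two integers, bracketed and
    separated by a comma, as image_range_fs (for FocusStacker) or image_range_ds (for Drifty_Shifty). The two
    integers are the start and end positions in the sorted list of image files; the end position is not included.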
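
    The range selection can be sketched as follows. `select_range` and the file names are hypothetical and
    only illustrate how the tuple slices the sorted file list:

```python
def select_range(file_names, image_range=True):
    """Return every name when image_range is True, otherwise the slice [start:end)."""
    if image_range is True:
        return list(file_names)
    if len(image_range) != 2:
        raise ValueError("image_range must be a tuple of two integers")
    start, end = image_range
    return list(file_names[start:end])


# Hypothetical file names, used only for this illustration
names = sorted(f"img_{i:03d}.png" for i in range(10))
subset = select_range(names, (2, 5))
```

    Here `subset` holds the files at positions 2, 3 and 4 of the sorted list.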
    
    If you want to just perform FocusStacker, then input ds=False, and if you wish to perform only
    Drifty_Shifty, then input fs=False.

    By default, "shift_data" is set to True, which means that the function calc_shift will be called and a
    shift array will be calculated and saved as "shift_arrays.npz". However, if a shift array has already been
    calculated for this set of images, then move that file into "output_dir_fs" and pass the file name as the
    "shift_data" argument instead.

    If you want to just calculate the shift array and not dedrift the images, then set pad=False.
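
    The shift array file stores the x and y shifts under the keys 'x' and 'y'. A minimal sketch of writing and
    reading such a file with NumPy; the shift values and the temporary directory are made up for illustration:

```python
import os
import tempfile

import numpy as np

# Made-up shifts for three frames
shiftx = np.array([0.0, 1.0, -2.0])
shifty = np.array([0.0, 3.0, 4.0])

# Save under the same keys that the pipeline uses ('x' and 'y')
path = os.path.join(tempfile.mkdtemp(), "shift_arrays.npz")
np.savez(path, x=shiftx, y=shifty)

# Load the arrays back, as pad_images does before dedrifting
with np.load(path) as data:
    loaded_x = data["x"]
    loaded_y = data["y"]
```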

    By default, "overwrite" is set to True, so existing output files with the same name are replaced. If it is set
    to False, new files are instead saved under the same base name with a sequential number added at the end.
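
    The numbering scheme can be sketched like this. `next_available_name` is a hypothetical helper that
    counts the files already matching the base name and appends that count, which is the idea behind
    overwrite=False:

```python
import glob
import os


def next_available_name(directory, base, suffix, ext):
    """Return a path whose trailing number equals the count of existing matches."""
    pattern = os.path.join(directory, f"{glob.escape(base)}{suffix}*.{ext}")
    count = len(glob.glob(pattern))
    # No number for the first file, then 1, 2, 3, ...
    number = "" if count == 0 else str(count)
    return os.path.join(directory, f"{base}{suffix}{number}.{ext}")
```

    Counting matches is simple, but it assumes that earlier files were not deleted out of order; a stricter
    version would test each candidate name with os.path.exists.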

    You can also choose whether to run the functions in parallel (via joblib) or serially with the "parallel"
    argument, which is True by default but can be set to False.

    Finally, by default, the dedrifted/aligned images will be cropped to remove the black padding created by the
    dedrifting process and to retain only the parts of the images that are visible in every single frame (via the
    crop_aligned_images function). But if crop is set to False, this function will not be called.
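
    The cropping idea can be illustrated on small synthetic frames. This is a sketch that assumes the image
    content has no completely black rows or columns at its edges; `common_visible_box` is a hypothetical name:

```python
import numpy as np


def common_visible_box(frames):
    """Return (y0, y1, x0, x1), the box covered by non-zero pixels in every frame."""
    tops, bottoms, lefts, rights = [], [], [], []
    for frame in frames:
        ys, xs = np.nonzero(frame)
        tops.append(ys.min())
        bottoms.append(ys.max())
        lefts.append(xs.min())
        rights.append(xs.max())
    # Intersect the per-frame boxes; + 1 turns the inclusive maximum into a slice end
    return max(tops), min(bottoms) + 1, max(lefts), min(rights) + 1


# Two padded frames whose visible areas only partly overlap
a = np.zeros((6, 6), dtype=np.uint8)
a[0:4, 0:4] = 255
b = np.zeros((6, 6), dtype=np.uint8)
b[2:6, 1:5] = 255

y0, y1, x0, x1 = common_visible_box([a, b])
cropped = [f[y0:y1, x0:x1] for f in (a, b)]
```

    Every cropped frame has the same shape and contains no black padding.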

# Import all the dependencies/libraries

In [1]:
import numpy as np
import cv2
from joblib import Parallel, delayed
from timeit import default_timer as timer
import itertools
import os
import sys
import glob
import scipy.fft as fft
from iteration_utilities import deepflatten

# All the functions for FocusStacker

In [None]:
def get_laplacian_pyramid(img, N):
    """
    Returns the N-level Laplacian pyramid of the input image as a list.
    @input: img - image
            N   - number of pyramid levels
    @output: list of N + 1 images: Laplacian levels 0 to N - 1, followed by the top-level
             (most downsampled) Gaussian image
    """
    # current level image
    curr_img = img

    lap_pyramids = []
    gaussian_pyramids = [curr_img, ]

    # for N level
    for i in range(N):
        down = cv2.pyrDown(curr_img)
        gaussian_pyramids.append(down)
        up = cv2.pyrUp(down, dstsize=(curr_img.shape[1], curr_img.shape[0]))
        # NOTE: subtract in int16 rather than with cv2.subtract, which clips the result
        #       to 0-255; the Laplacian levels need signed integer values.
        lap = curr_img - up.astype('int16')
        lap_pyramids.append(lap)
        curr_img = down
        # top level laplacian be a gaussian downsampled
        if i == N - 1:
            lap_pyramids.append(curr_img)

    return lap_pyramids

def get_probabilities(gray_image):
    levels, counts = np.unique(gray_image.astype(np.uint8), return_counts=True)
    probabilities = np.zeros((256,), dtype=np.float64)
    probabilities[levels] = counts.astype(np.float64) / counts.sum()
    return probabilities


def _area_entropy(area, probabilities):
    levels = area.flatten()
    return -1. * (levels * np.log(probabilities[levels])).sum()


def entropy(image, kernel_size):
    probabilities = get_probabilities(image)
    pad_amount = int((kernel_size - 1) / 2)
    padded_image = cv2.copyMakeBorder(image, pad_amount, pad_amount, pad_amount, pad_amount, cv2.BORDER_REFLECT101)
    entropies = np.zeros(image.shape[:2], dtype=np.float64)
    offset = np.arange(-pad_amount, pad_amount + 1)
    for row in range(entropies.shape[0]):
        for column in range(entropies.shape[1]):
            area = padded_image[row + pad_amount + offset[:, np.newaxis], column + pad_amount + offset]
            entropies[row, column] = _area_entropy(area, probabilities)

    return entropies


def _area_deviation(area):
    average = np.average(area).astype(np.float64)
    return np.square(area - average).sum() / area.size


# calculates the D: Deviation for every pixel locations
# Source: https://github.com/sjawhar/focus-stacking/blob/master/focus_stack/pyramid.py - Line 108-122
def deviation(image, kernel_size):
    pad_amount = int((kernel_size - 1) / 2)
    padded_image = cv2.copyMakeBorder(image, pad_amount, pad_amount, pad_amount, pad_amount, cv2.BORDER_REFLECT101)
    deviations = np.zeros(image.shape[:2], dtype=np.float64)
    offset = np.arange(-pad_amount, pad_amount + 1)
    for row in range(deviations.shape[0]):
        for column in range(deviations.shape[1]):
            area = padded_image[row + pad_amount + offset[:, np.newaxis], column + pad_amount + offset]
            deviations[row, column] = _area_deviation(area)

    return deviations


def generating_kernel(a):
    kernel = np.array([0.25 - a / 2.0, 0.25, a, 0.25, 0.25 - a / 2.0])
    return np.outer(kernel, kernel)


def convolve(image, kernel=generating_kernel(0.4)):
    return cv2.filter2D(src=image.astype(np.float64), ddepth=-1, kernel=np.flip(kernel))


# calculated RE: regional energy for every pixel locations
# Source: https://github.com/sjawhar/focus-stacking/blob/master/focus_stack/pyramid.py - Line 167-169
def region_energy(laplacian):
    return convolve(np.square(laplacian))


#focus-stacking (laplacian pyramid fusion method)
def lap_focus_stacking(images, N=5, kernel_size=5):
    """
    achieves the functionality of focus stacking using Laplacian Pyramid Fusion described 
        in Wang and Chang's 2011 paper (regional fusion)
    @input: images - array of images
            N      - Depth of Laplacian Pyramid, default is 5
            kernel_size - integer represents the side length of Gaussian kernel, default is 5
    @output: single image that stacked the depth of fields of all images
    """

    # 1- Generate array of Laplacian pyramids
    list_lap_pyramids = np.array([get_laplacian_pyramid(img, N)[:-1] for img in images], dtype=object)

    LP_f = []


    # 2 - Regional fusion using these Laplacian pyramids
    # fuse level = N laplacian pyramid, D=deviation, E=entropy
    D_N = np.array([deviation(lap, kernel_size) for lap in list_lap_pyramids[:, -1]])
    E_N = np.array([entropy(lap, kernel_size) for lap in list_lap_pyramids[:, -1]])

    # 2.1 - init level N fusion canvas
    LP_N = np.zeros(list_lap_pyramids[0, -1].shape)
    for m in range(LP_N.shape[0]):
        for n in range(LP_N.shape[1]):
            D_max_idx = np.argmax(D_N[:, m, n])
            E_max_idx = np.argmax(E_N[:, m, n])
            D_min_idx = np.argmin(D_N[:, m, n])
            E_min_idx = np.argmin(E_N[:, m, n])
            # if the image maximizes BOTH the deviation and entropy, use the pixel from that image
            if D_max_idx == E_max_idx:
                LP_N[m, n] = list_lap_pyramids[D_max_idx, -1][m, n]
            # if the image minimizes BOTH the deviation and entropy, use the pixel from that image
            elif D_min_idx == E_min_idx: 
                LP_N[m, n] = list_lap_pyramids[D_min_idx, -1][m, n]
            # else average across all images
            else:
                for k in range(list_lap_pyramids.shape[0]):
                    LP_N[m, n] += list_lap_pyramids[k, -1][m, n]
                LP_N[m, n] /= list_lap_pyramids.shape[0]

    LP_f.append(LP_N)

    # 2.2 - Fusion other levels of Laplacian pyramid (N-1 to 0)
    for l in reversed(range(0, N-1)):
        # level l final laplacian canvas
        LP_l = np.zeros(list_lap_pyramids[0, l].shape)

        # region energy map for level l
        RE_l = np.array([region_energy(lap) for lap in list_lap_pyramids[:, l]], dtype=object)

        for m in range(LP_l.shape[0]):
            for n in range(LP_l.shape[1]):
                RE_max_idx = np.argmax(RE_l[:, m, n])
                LP_l[m, n] = list_lap_pyramids[RE_max_idx, l][m, n]

        LP_f.append(LP_l)

    LP_f = np.array(LP_f, dtype=object)
    LP_f = np.flip(LP_f)


    # 3 - time to reconstruct final laplacian pyramid(LP_f) back to original image!
    # get the top level of the gaussian pyramid; the last image in the stack supplies the base
    base = get_laplacian_pyramid(images[-1], N)[-1]
    fused_img = cv2.pyrUp(base, dstsize=(LP_f[-1].shape[1], LP_f[-1].shape[0])).astype(np.float64)

    for i in reversed(range(N)):
        # combine with laplacian pyramid at the level
        fused_img += LP_f[i]
        if i != 0:
            fused_img = cv2.pyrUp(fused_img, dstsize=(LP_f[i-1].shape[1], LP_f[i-1].shape[0]))

    return fused_img


# Recreates the stacked image and saves it into output_dir
def merged_focus(group, output_dir_fs, output_format, overwrite=True):

    # 1 - load images and check that every file could be read
    image = [cv2.imread(g) for g in group]
    if any(img is None for img in image):
        raise RuntimeError("Cannot load one or more input files.")

    # convert to grayscale (cv2.imread returns 3-channel BGR arrays by default)
    if image[0].ndim == 3 and image[0].shape[2] == 3:
        images = np.array([cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in image])
    else:
        images = np.array(image)

    # defines the base name in order to name the files correctly
    z = group[0].rfind('z')
    names = group[0][:z]

    # 2 - focus stacking: fuse the grayscale slices into one image with a trailing channel axis
    canvas = np.array([lap_focus_stacking(images)])
    canvas = np.moveaxis(canvas, 0, -1)

    # 3 - write to file (grayscale); output other than png is first rescaled to 8-bit
    if output_format.lower() != 'png':
        canvas = cv2.normalize(src=canvas, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    # if overwrite is False, count the existing "<name>_merged*" files and append that number to the new file name
    suffix = ''
    if overwrite == False:
        b = len(glob.glob(os.path.join(output_dir_fs, f"{names}_merged*.{output_format}")))
        if b > 0:
            suffix = str(b)
    cv2.imwrite(os.path.join(output_dir_fs, f'{names}_merged{suffix}.{output_format}'), canvas)

    return None

def FocusStacker(input_dir, output_dir_fs, output_format, parallel=True, image_range_fs=True, overwrite=True):

    start = timer()
    print('start merging')

    # remember the current working dir so that it can be restored at the end
    cwd = os.getcwd()

    # change working dir to image directory
    os.chdir(input_dir)

    # These are all the accepted types of extensions
    expected_ext = ['png', 'tif', 'tiff']

    # sorting and grouping images according to the base name and the number of slices (z)
    image_files = sorted(os.listdir(input_dir))
    file_names = [img for img in image_files if img.split(".")[-1].lower() in expected_ext and img[0].isalnum()]

    # If you don't want to convert all of your images in your folder, then input a tuple of integers as image_range.
    # Input 2 integers (bracketed and separated by a comma) to denote the start and the end of the images to be
    # converted. If you do not input exactly 2 integers, an error will be raised. If image_range is True,
    # all the files in the folder will be passed into the function.
    if image_range_fs != True:
        if len(image_range_fs)!=2:
            sys.exit("image_range doesn't exist. Please input a tuple of two values.")
        else:
            file_names = file_names[image_range_fs[0]:image_range_fs[1]]

    # input sanity checks
    num_files = len(file_names)
    assert num_files > 1, "Provide at least 2 images."

    z = file_names[0].rfind('z')
    groups = [list(g) for _, g in itertools.groupby(sorted(file_names), lambda x: x[0:z])]

    # determines how many slices (z - images with different focus) are in each group
    # (all groups are assumed to contain the same number of slices)
    num_of_zplane_images = len(groups[-1])

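    # determines the number of files already in output directory; if it is neither empty nor complete,
    # only the groups that do not have a merged image yet are processed
    output_files = os.listdir(output_dir_fs)
    if len(output_files) not in (0, len(file_names) / num_of_zplane_images):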
        output_files2 = [output_file.split('_merged')[0] for output_file in output_files]
        file_names2 = [f for f in file_names if f[0:z] not in output_files2]
        groups = [list(g) for _, g in itertools.groupby(sorted(file_names2), lambda x: x[0:z])]


    # If you want to run the main focus stacking function, merged_focus, serially which would be slower
    if parallel == False:
        [merged_focus(group, output_dir_fs, output_format, overwrite) for group in groups]
    else:
        # run the main focus stacking function, merged_focus, in parallel with joblib
        Parallel(n_jobs=-1)(delayed(merged_focus)(group, output_dir_fs, output_format, overwrite) for group in groups)

    # change working dir back to original working directory
    os.chdir(cwd)
    end = timer()
    print(f'elapsed time {end-start}, Focus Stacking successful')
    return None
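
The grouping step in FocusStacker can be illustrated on its own. This is a sketch that assumes every file name contains a 'z' marker right before the slice number; the file names and `group_z_slices` are hypothetical. Unlike FocusStacker, which takes the position of 'z' from the first file only, this version looks it up per file:

```python
import itertools


def group_z_slices(file_names):
    """Group file names by the part that precedes the last 'z' in each name."""
    def base(name):
        return name[:name.rfind("z")]

    # groupby only merges neighbours, so sort by base name first
    ordered = sorted(file_names, key=lambda n: (base(n), n))
    return [list(g) for _, g in itertools.groupby(ordered, key=base)]


# Hypothetical names: two time points with two focal slices each
names = ["t02_z1.png", "t01_z0.png", "t02_z0.png", "t01_z1.png"]
groups = group_z_slices(names)
```

Each inner list then holds the slices that are merged into one focused image.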

# All the functions for Drifty_shifty (dedrifter)

In [None]:
def get_ref(images):

    # Get and process reference frame (first one in the sequence)
    frameref = cv2.imread(images[0])
    if frameref.shape[2] == 3:
        frameref = cv2.cvtColor(frameref, cv2.COLOR_BGR2GRAY)
    fft_ref = fft.fft2(frameref)

    # frameref.shape is (rows, columns), i.e. (height, width)
    vidHeight = frameref.shape[0]
    vidWidth = frameref.shape[1]
    centery = (vidHeight / 2) + 1
    centerx = (vidWidth / 2) + 1

    return fft_ref, vidHeight, vidWidth, centery, centerx


# this function performs fourier transformation for each image and returns the maximum x and y indices
def calc_shift(images, fft_ref):

    img = cv2.imread(images)
    if img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Calculate fourier transformation of each image
    fft_frame = fft.fft2(img)

    # vector multiplication of the fourier-transformed image reference with the
    # complex conjugate array of each fourier-transformed image
    prod = fft_ref * np.conjugate(fft_frame)

    # get inverse fourier transformation of the product
    # need to get the "real" numbers and not the imaginary numbers
    cc = (fft.ifft2(prod)).real

    # 'fftshift' moves the corners to the center, 'argmax' finds the largest element in the whole array, and
    # 'unravel_index' converts it into a single (y, x) index pair, even if the maximum value occurs more than once
    maxYX = np.unravel_index(np.argmax(fft.fftshift(cc)), cc.shape)

    return maxYX


# this function then takes the maximum x and y indices to calculate the x-y shift array
def calc_shift2(images, maxYX, vidHeight, vidWidth, centery, centerx):

    nFrames = len(images)

    shifty = np.zeros((nFrames,))
    shiftx = np.zeros((nFrames,))

    maxYX2 = list(deepflatten(maxYX))
    maxY = maxYX2[::2]
    maxX = maxYX2[1::2]

    for i in range(nFrames):

        shifty[i] = maxY[i] - centery
        shiftx[i] = maxX[i] - centerx

        # Previous version didn't subtract center point here
        if i > 0:  # Checks to see if there is an ambiguity problem with FFT because of the periodic boundary in FFT
            if np.abs(shifty[i] - shifty[i - 1]) > vidHeight / 2:
                shifty[i] = shifty[i] - np.sign(shifty[i] - shifty[i - 1]) * vidHeight

            if np.abs(shiftx[i] - shiftx[i - 1]) > vidWidth / 2:
                shiftx[i] = shiftx[i] - np.sign(shiftx[i] - shiftx[i - 1]) * vidWidth


    return shifty, shiftx


# Step 2: Pads and dedrifts images

# This function is the core function and actually pads, centers, and dedrifts the images according to the calculated
# shift data. It loads the shift arrays from the given .npz file and dedrifts the images accordingly.
def pad_images(images, shift_arrays, output_dir, output_format, overwrite=True):

    # this will load the shift_data from the calc_shift function
    with np.load(shift_arrays) as data:
        shiftx = data['x']
        shifty = data['y']

    # number of images
    nFrames = len(images)

    # Get Height & Width of reference image (first one in the sequence)
    frameref = cv2.imread(images[0])
    if frameref.shape[2] == 3:
        frameref = cv2.cvtColor(frameref, cv2.COLOR_BGR2GRAY)
    frameref = frameref.astype(dtype='uint8')
    vidHeight, vidWidth = frameref.shape[0:2]

    # Pad the images. Use first image as "center"
    newsizey = round(2 * np.max(np.abs(shifty)) + vidHeight)
    newsizex = round(2 * np.max(np.abs(shiftx)) + vidWidth)

    # Assume max positive shift = max negative shift; centers reference frame
        # This was the original code but for some reason works for some but not all sets of images
        # midindexy = (newsizey - vidHeight) / 2 + 1
        # midindexx = (newsizex - vidWidth) / 2 + 1
    midindexy = (newsizey - vidHeight) / 2
    midindexx = (newsizex - vidWidth) / 2

    # Determine how many images are in the output directory in case run was stopped while in progress
    files_in_outputdir = glob.glob(os.path.join(output_dir, f"*_dedrifted*.{output_format}"))
    num_files_in_outputdir = len(files_in_outputdir)

    # If the 'output_dir' does not contain any images or contains the entire set of dedrifted images, it will start
    # processing the images from the beginning.
    # If the 'output_dir' contains some of the dedrifted images but not all from the 'input_dir', it will continue
    # processing the images where it left off.
    if num_files_in_outputdir in (0, nFrames):
        range_of_images = range(nFrames)
    else:
        range_of_images = range(num_files_in_outputdir, nFrames)

    # The following code takes each image and places it, shifted according to the shift array, in a frame padded
    # with a black border. If overwrite is False, newly dedrifted images are saved with an extra number at the end
    # when the same file already exists in the 'output_dir'.
    for i in range_of_images:
        frame_shift = np.zeros((newsizey, newsizex), dtype='uint8')

        img = cv2.imread(images[i])
        if img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # For every image, want to shift the frame according to shifty/shiftx and center and pad it.
        # starty/x and endy/x are the coordinates inside the frame_shift frame in which to place the image (img)
        starty = (midindexy + shifty[i]).astype(int)
        endy   = (midindexy + shifty[i] + (vidHeight)).astype(int)
        startx = (midindexx + shiftx[i]).astype(int)
        endx   = (midindexx + shiftx[i] + (vidWidth)).astype(int)
        frame_shift[starty:endy, startx:endx] = img

        # This is the corrected image, which is subsequently saved (frame_shift is already uint8, so no rounding
        # is needed)
        dedrifted = frame_shift

        # This code saves the dedrifted images into the output directory. If overwrite is False, newly saved
        # images get a number appended instead of overwriting files with the same name in the output directory.
        base_name = images[i].rsplit('.', 1)[0]
        suffix = ''
        if overwrite == False:
            b = len(glob.glob(os.path.join(output_dir, f"{base_name}_dedrifted*.{output_format}")))
            if b > 0:
                suffix = str(b)
        cv2.imwrite(os.path.join(output_dir, f"{base_name}_dedrifted{suffix}.{output_format}"), dedrifted)

        # Print how many images are done to show progress
        if i % 50 == 0:
            print(f'{i} frames out of {nFrames} are done')

    return None



# Step 3: Call all of the functions using the main function
def Drifty_Shifty(input_dir, output_dir, output_format,
                  image_range_ds=True, shift_data="True", pad=True, overwrite=True, parallel=True):

    start = timer()
    print(f'start shifting')

    # remember the current working dir as cwd so that it can be restored at the end
    cwd = os.getcwd()

    # Change the working directory to the input directory
    os.chdir(input_dir)

    # Get all the individual files in the input directory
    files = sorted(os.listdir(input_dir))

    # These are all the accepted types of extensions
    expected_ext = ['png', 'tif', 'tiff']

    # Determine input format type
    fileNames = [img for img in files if (img.split(".")[-1].lower() in expected_ext and img[0].isalnum())]

    # in case someone accidentally inputs another type of format (ie. jpg), the output format will be tif
    if output_format not in expected_ext:
        output_format = 'tif'


    # If you don't want to convert all of your images in your folder, then input a tuple of integers as image_range.
    # Input 2 integers (bracketed and separated by a comma) to denote the start and the end of the images to be
    # converted. If you do not input exactly 2 integers, an error will be raised. If image_range is True,
    # all the files in the folder will be passed into the function.
    if image_range_ds != True:
        if len(image_range_ds)!=2:
            sys.exit("image_range doesn't exist. Please input a tuple of two values.")
        else:
            fileNames = fileNames[image_range_ds[0]:image_range_ds[1]]

    # this code calls the functions to calculate the shift array and saves it as an .npz file in the 'input_dir'.
    # If shift_data is not True and is instead a path to an .npz file that contains the shift data, then the code
    # will bypass this part and go directly to pad images.
    if shift_data is True or shift_data == "True":
        fft_ref, vidHeight, vidWidth, centery, centerx = get_ref(fileNames)
        if parallel == False:
            # Serial processing
            maxYX = [calc_shift(fileNames[i], fft_ref) for i in range(len(fileNames))]
            shift_arrays = calc_shift2(fileNames, maxYX, vidHeight, vidWidth, centery, centerx)
        else:
            # Joblib multiprocessing
            maxYX = Parallel(n_jobs=-1, prefer="threads")(delayed(calc_shift)(fileNames[i], fft_ref)
                                                          for i in range(len(fileNames)))
            shift_arrays = calc_shift2(fileNames, maxYX, vidHeight, vidWidth, centery, centerx)


        # if overwrite is False, then any previously saved shift arrays will not be overwritten and new arrays will
        # instead have an extra number at the end.
        suffix = ''
        if overwrite == False:
            b = len(glob.glob('shift_arrays*.npz'))
            if b > 0:
                suffix = str(b)
        np.savez(f'shift_arrays{suffix}.npz', x=shift_arrays[1], y=shift_arrays[0])
        shift_arrays = f'shift_arrays{suffix}.npz'
    # but if a shift array for this particular set of images is already saved somewhere, then put that file into the
    # input_dir and pass its path as the shift_data argument
    else:
        print('importing shift arrays')
        shift_arrays = shift_data
    end = timer()
    print(f'elapsed time: {end - start}')

    # use this function to actually shift the images in the frame with a black padding
    if pad == True:
        start = timer()
        print(f'start padding')

        pad_images(fileNames, shift_arrays, output_dir, output_format, overwrite)

        end = timer()
        print(f'elapsed time: {end - start}')


    # change the working directory back to the original one
    os.chdir(cwd)

    print('dedrifting successful')
    return None
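
The cross-correlation idea behind calc_shift can be tried on synthetic data. This is a minimal sketch, not the pipeline's own code: it uses numpy.fft instead of scipy.fft, applies a circular shift with np.roll instead of padding, and `estimate_shift` is a hypothetical name:

```python
import numpy as np


def estimate_shift(ref, frame):
    """Return (dy, dx) such that rolling `frame` by (dy, dx) lines it up with `ref`."""
    # Cross-correlation through the Fourier domain
    cc = np.fft.ifft2(np.fft.fft2(ref) * np.conj(np.fft.fft2(frame))).real
    peak_y, peak_x = np.unravel_index(np.argmax(np.fft.fftshift(cc)), cc.shape)
    # After fftshift, the zero-shift position sits at index N // 2 on each axis
    return peak_y - ref.shape[0] // 2, peak_x - ref.shape[1] // 2


rng = np.random.default_rng(0)
ref = rng.random((32, 48))
# Simulate drift: move the content 3 rows down and 5 columns to the left
frame = np.roll(ref, shift=(3, -5), axis=(0, 1))

dy, dx = estimate_shift(ref, frame)
restored = np.roll(frame, shift=(dy, dx), axis=(0, 1))
```

The estimated correction is the opposite of the simulated drift, so `restored` matches `ref` again.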

# All the functions for Crop_aligned_images

In [None]:
def remove_black_border(images):
    y_nonzero, x_nonzero = list(zip(*[np.nonzero(i) for i in images]))

    min_y = np.max([np.min(y) for y in y_nonzero])
    max_y = np.min([np.max(y) for y in y_nonzero])
    min_x = np.max([np.min(x) for x in x_nonzero])
    max_x = np.min([np.max(x) for x in x_nonzero])

    return min_y, min_x, max_y, max_x



def crop_aligned_images(input_dir, output_dir, output_format):
    start = timer()
    print('start cropping')

    # remember the current working dir as cwd so that it can be restored at the end
    cwd = os.getcwd()

    # Change the working directory to the input directory
    os.chdir(input_dir)

    # Get all the individual files in the input directory
    files = sorted(os.listdir(input_dir))

    # These are all the accepted types of extensions
    expected_ext = ['png', 'tif', 'tiff']

    # Keep only the accepted image formats and load every image as grayscale
    fileNames = [img for img in files if (img.split(".")[-1].lower() in expected_ext and img[0].isalnum())]
    images = [cv2.imread(f) for f in fileNames]
    images = [cv2.cvtColor(i, cv2.COLOR_BGR2GRAY) if i.ndim == 3 and i.shape[2] == 3 else i for i in images]
    names = [f.rsplit('.', 1)[0] for f in fileNames]

    min_y, min_x, max_y, max_x = remove_black_border(images)

    for i in range(len(images)):
        # max_y and max_x are inclusive indices, so add 1 to keep the last visible row and column
        image = images[i][min_y:max_y + 1, min_x:max_x + 1]
        cv2.imwrite(os.path.join(output_dir, f"{names[i]}_crop.{output_format}"), image)

    # change the working directory back to the original one
    os.chdir(cwd)

    end = timer()
    print(f'elapsed time {end-start}')
    print('cropping successful')

    return None

# Main function for FocusStackerDedrifter

In [None]:
def FocusStackerDedrifter(input_dir, output_dir_fs, output_dir_ds, output_dir, output_format, fs=True,
                           image_range_fs=True, ds=True, image_range_ds=True, shift_data=True, pad=True, overwrite=True,
                           parallel=True, crop=True):


    # Drifty_Shifty uses the string "True" as its default for shift_data
    if shift_data is True:
        shift_data = "True"

    if fs == True:
        FocusStacker(input_dir, output_dir_fs, output_format, parallel=parallel, image_range_fs=image_range_fs,
                     overwrite=overwrite)
    if ds == True:
        Drifty_Shifty(output_dir_fs, output_dir_ds, output_format, image_range_ds=image_range_ds,
                      shift_data=shift_data, pad=pad, overwrite=overwrite, parallel=parallel)
    if crop == True:
        crop_aligned_images(output_dir_ds, output_dir, output_format)

# All the inputs: please refer to the introduction section and read carefully.

In [None]:
input_dir = "/Volumes/Caro2/finalstack_orig"
output_dir_fs = "/Volumes/Caro2/finalstack_orig"
output_dir_ds = "/Volumes/Caro2/finalstack_dedrift"
output_dir = "/Volumes/Caro2/finalstack_crop"
output_format = 'png'
fs=True
image_range_fs=True
ds=True
image_range_ds=True
shift_data=True
pad=True
overwrite=True
parallel=True
crop=True

# Call the FocusStackerDedrifter function

In [None]:
FocusStackerDedrifter(input_dir, output_dir_fs, output_dir_ds, output_dir, output_format, fs=fs,
                      image_range_fs=image_range_fs, ds=ds, image_range_ds=image_range_ds, shift_data=shift_data,
                      pad=pad, overwrite=overwrite, parallel=parallel, crop=crop)