# In[65]:
import os
import cv2
import numpy as np

# In[88]:
# Largest shift (in pixels) expected in either the x or the y direction between
# any two images of the series. Must be a positive integer. It is also the width
# of the border that is excluded when two images are compared.
max_shift: int = 20

class ImgSeries:
    """Stores a series of SEM images.

    Attributes:
        plain (list): Images without pre-processing stored as NumPy arrays
        prep (list): Pre-processed images stored as NumPy arrays; prep[i] is derived from plain[i]
        len (int): Number of images loaded into self.plain
    """

    def __init__(self, folder_path: str, interval: tuple=None):
        """Initializes an instance by loading .tiff images from a given folder.

        The files are loaded in natural sort order of their names (e.g. "img2.tiff"
        before "img10.tiff"). This makes the position of each image in the series, and
        therefore the meaning of `interval`, independent of the order in which the
        file system happens to list the files.

        Args:
            folder_path (str): Path to the folder containing the images
            interval (tuple): Indicates the index of the first and last .tiff file to load (start, stop),
                both inclusive, counted in the sorted file list

        Raises:
            AssertionError: When the indicated folder contains no .tiff files
            OSError: When a .tiff file cannot be read as an image
        """
        self.plain = []
        self.prep = []
        fname_list = self._list_tiff_files(folder_path)
        assert len(fname_list) > 0, "no .tiff files in the indicated folder"
        if interval:
            fname_list = fname_list[interval[0]:interval[1]+1]
        for fname in fname_list:
            print(fname)
            img_path = os.path.join(folder_path, fname)
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread reports unreadable files by returning None instead of raising
            if img is None:
                raise OSError(f"could not read image file {img_path}")
            self.plain.append(img)
        self.len = len(self.plain)

    @staticmethod
    def _list_tiff_files(folder_path: str) -> list:
        """Returns the names of the .tiff files in folder_path in natural sort order."""
        import re

        def natural_key(name: str) -> tuple:
            # re.split with a capturing group alternates text and digit runs, so every
            # odd-indexed part is a digit run that is compared as a number.
            # The full name breaks ties such as "img01" vs "img1" deterministically.
            parts = re.split(r"(\d+)", name)
            return ([int(p) if i % 2 else p for i, p in enumerate(parts)], name)

        return sorted((f for f in os.listdir(folder_path) if f.endswith(".tiff")), key=natural_key)

    def __str__(self) -> str:
        assert len(self.plain) > 0, "ImgSeries is empty"
        return f"ImgSeries containing {len(self.plain)} plain images with size {self.plain[0].shape} and {len(self.prep)} pre-processed images"

    def crop(self, amount: tuple=(0, 0, 0, 0)) -> None:
        """Crops the images in self.plain in place by a given amount. Only run before self.prep is created.

        All images are validated before any of them is modified, so a failing call
        leaves the series unchanged.

        Args:
            amount (tuple): Specifies how many pixels are discarded (top, bottom, left, right)

        Raises:
            AssertionError: When pre-processed images have already been created
            ValueError: When a crop value is negative or the crop would leave no pixels in some image
        """
        top, bottom, left, right = amount
        assert len(self.prep) == 0, "please only run this function before creating pre-processed images"
        if min(amount) < 0:
            raise ValueError("Crop values must not be negative.")
        for img in self.plain:
            height, width = img.shape[:2]
            if top + bottom >= height or left + right >= width:
                raise ValueError("Crop values are too large for the image size.")
        # Slice assignment keeps the identity of the self.plain list object
        self.plain[:] = [img[top:img.shape[0] - bottom, left:img.shape[1] - right] for img in self.plain]

    def gaussian_blur(self, kernel_size: tuple):
        """Applies a Gaussian blur to the images in self.plain and stores the results in self.prep.

        Previously stored pre-processed images are replaced rather than appended to, so
        self.prep[i] always corresponds to self.plain[i].

        Args:
            kernel_size (tuple): Size of the blur's kernel. Both components must be positive odd integers

        Raises:
            AssertionError: When a component of kernel_size is not a positive odd integer
        """
        assert all(k > 0 and k % 2 == 1 for k in kernel_size[:2]), "both components of kernel_size must be positive odd integers"
        # Sigma 0 lets OpenCV derive the standard deviation from the kernel size
        self.prep[:] = [cv2.GaussianBlur(img, kernel_size, 0) for img in self.plain]


class ImgPair:
    """Used to store information in the ShiftGraph class and perform calculations that only include a pair of images.

    Shift convention: a shift (sx, sy) means that the content at pixel (x, y) of img1
    is found at pixel (x + sx, y + sy) of img2 (x = column, y = row).

    Attributes:
        indices (tuple): Indices of the two images (img1, img2) in the image list
        img1, img2 (np.ndarray): Two pre-processed images stored as NumPy arrays
        costs (dict): Contains all the cost values calculated for this ImgPair. The keys are (sx, sy) tuples and the values floats
        shift (tuple): Contains two integers (x, y) specifying the optimal overlay. None while not yet calculated
    """

    def __init__(self, img_series: "ImgSeries", indices: tuple):
        """Initializes an instance of ImgPair from an ImgSeries.

        Args:
            img_series (ImgSeries): The instance of ImgSeries the ShiftGraph will be based on
            indices (tuple): A tuple with two integers (img1, img2) storing the indices of the two images
        """
        self.indices = indices
        self.img1 = img_series.prep[indices[0]]
        self.img2 = img_series.prep[indices[1]]
        self.costs = {}
        self.shift = None

    def __str__(self) -> str:
        return f"img{self.indices[0]} and img{self.indices[1]} with shift {self.shift}"

    def get_cost(self, sx: int, sy: int) -> float:
        """Returns the mean squared difference of the two images for a given shift and caches it in self.costs.

        A fixed central window of img1, with max_shift pixels removed from every edge, is
        compared with the window of img2 displaced by (sx, sy). The window has the same
        size for every admissible shift, so cost values of different shifts are directly
        comparable.

        Args:
            sx, sy (int): Shift of img2 relative to img1 in x (columns) and y (rows) direction.
                Both must satisfy abs(s) <= max_shift

        Returns:
            The cost value for the given shift

        Raises:
            ValueError: If the shift exceeds max_shift, the images differ in shape,
                or the images are too small to leave a window after removing the margin
        """
        sx, sy = int(sx), int(sy)
        key = (sx, sy)
        if key in self.costs:
            return self.costs[key]
        margin = int(max_shift)
        if abs(sx) > margin or abs(sy) > margin:
            raise ValueError(f"shift ({sx}, {sy}) exceeds max_shift = {max_shift}")
        if self.img1.shape != self.img2.shape:
            raise ValueError("both images must have the same shape")
        height, width = self.img1.shape[:2]
        if height <= 2 * margin or width <= 2 * margin:
            raise ValueError("images are too small for the chosen max_shift")
        reference = self.img1[margin:height - margin, margin:width - margin]
        displaced = self.img2[margin + sy:height - margin + sy, margin + sx:width - margin + sx]
        # Cast before subtracting: with unsigned integer images (e.g. uint8) the
        # difference would wrap around instead of becoming negative
        diff = displaced.astype(np.float64) - reference.astype(np.float64)
        self.costs[key] = float(np.mean(diff ** 2))
        return self.costs[key]

    def get_optimal_shift(self) -> tuple[int, int]:
        """Determines the shift between the two images necessary for optimal overlay by minimizing the cost function.

        A multi-scale grid search is used. At each level the cost is evaluated on a 3x3 grid
        of shifts centred on the best shift found so far. The grid spacing starts at the largest
        power of two not exceeding max_shift and is halved after each level, down to 1 pixel.
        Candidates are clamped to [-max_shift, max_shift]. This is a local search, so it may
        miss the global minimum when the cost function has several minima.

        Returns:
            The determined optimal shift (sx, sy), also stored in self.shift
        """
        best_pos = (0, 0)
        min_val = self.get_cost(0, 0)
        limit = int(max_shift)
        # Halving steps p, p/2, ..., 1 can reach up to 2p - 1 >= max_shift
        step_size = 1 << (limit.bit_length() - 1) if limit > 0 else 0
        while step_size >= 1:
            # Freeze the grid centre for the whole level so the grid does not drift mid-scan
            center_x, center_y = best_pos
            for dx in (-step_size, 0, step_size):
                for dy in (-step_size, 0, step_size):
                    x = min(max(center_x + dx, -limit), limit)
                    y = min(max(center_y + dy, -limit), limit)
                    cost = self.get_cost(x, y)
                    if cost < min_val:
                        min_val = cost
                        best_pos = (x, y)
            step_size //= 2
        self.shift = best_pos
        return self.shift
            

class ShiftGraph:
    """Stores information on the relative shifts between images in a series as an adjacency matrix.

    Only the strict upper triangle (row index < column index) holds ImgPair objects;
    the diagonal and the lower triangle are None.

    Attributes:
        imgs (ImgSeries): The ImgSeries associated with this ShiftGraph
        matrix (list): The adjacency matrix as a list of rows. The values are either ImgPair objects or None
    """

    def __init__(self, img_series: "ImgSeries"):
        """Initializes an instance with an adjacency matrix of ImgPairs based on an ImgSeries.

        Args:
            img_series: An instance of the ImgSeries class whose pre-processed images have been created

        Raises:
            AssertionError: When there are no pre-processed images, or their number differs from img_series.len
        """
        self.imgs = img_series
        assert len(self.imgs.prep) > 0, "list with pre-processed images has not been initialized"
        n = len(self.imgs.prep)
        # The ImgPairs index into prep, so the matrix dimension must match the image count
        assert n == self.imgs.len, "number of pre-processed images does not match the number of images"
        self.matrix = [
            [ImgPair(self.imgs, (i, j)) if j > i else None for j in range(n)]
            for i in range(n)
        ]

    def __str__(self) -> str:
        """Returns the adjacency matrix as a string, with each ImgPair replaced by its shift"""
        return str([
            [element.shift if isinstance(element, ImgPair) else element for element in row]
            for row in self.matrix
        ])