In [1]:
import os
import cv2
import numpy as np
import random

In [6]:
# Maximum shift (in pixels) in either the x or y direction expected between any
# two given images of the series. Must be a positive integer.
# Read by ImgPair.get_optimal_shift to choose the initial step size of its grid search.
max_shift = 20

class ImgSeries:
    """Stores a series of SEM images.

    Attributes:
        plain (list): Images without pre-processing stored as NumPy arrays
        prep (list): Pre-processed images stored as NumPy arrays
        len (int): Number of images loaded into ``plain`` at construction time
    """

    def __init__(self, folder_path: str, interval: tuple = None) -> None:
        """Initializes an instance by loading .tiff images from a given folder.

        Files are loaded in sorted (lexicographic) file-name order so that the
        indices used by ``interval`` and the order of the series are reproducible.

        Args:
            folder_path (str): Path to the folder containing the images
            interval (tuple): Indicates the index of the first and last .tiff file to load (start, stop), both inclusive

        Raises:
            FileNotFoundError: When the indicated folder contains no .tiff files
            OSError: When a .tiff file cannot be read as an image
        """
        self.plain = []
        self.prep = []
        # os.listdir returns entries in arbitrary order, so sort explicitly.
        fname_list = sorted(f for f in os.listdir(folder_path) if f.endswith(".tiff"))
        if not fname_list:
            raise FileNotFoundError("There are no .tiff files in the indicated folder")
        if interval:
            fname_list = fname_list[interval[0]:interval[1] + 1]
        for fname in fname_list:
            print(fname)
            img_path = os.path.join(folder_path, fname)
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise OSError(f"Could not read image file: {img_path}")
            self.plain.append(img)
        self.len = len(self.plain)

    def __str__(self) -> str:
        """Returns a short summary of the series (never raises, even when empty)."""
        if not self.plain:
            return f"ImgSeries containing 0 plain images and {len(self.prep)} pre-processed images"
        return f"ImgSeries containing {len(self.plain)} plain images with size {self.plain[0].shape} and {len(self.prep)} pre-processed images"

    def crop(self, amount: tuple = (0, 0, 0, 0)) -> None:
        """Crops the images in self.plain by a given amount. Only run before self.prep is created.

        All images are validated before any of them is modified, so a failing
        call leaves the series unchanged.

        Args:
            amount (tuple): Specifies how many pixels are discarded (top, bottom, left, right)

        Raises:
            Exception: If the function is run after pre-processed images have been created
            ValueError: If a crop value is negative or the crop values are too large
        """
        top, bottom, left, right = amount
        if len(self.prep) > 0:
            raise Exception("Only run this function before creating pre-processed images")
        if min(top, bottom, left, right) < 0:
            raise ValueError("Crop values must be non-negative.")
        # Validate every image first so a failure cannot leave a half-cropped series.
        for img in self.plain:
            height, width = img.shape[:2]
            if top + bottom >= height or left + right >= width:
                raise ValueError("Crop values are too large for the image size.")
        # Slice assignment keeps the same list object, as the in-place original did.
        self.plain[:] = [
            img[top:img.shape[0] - bottom, left:img.shape[1] - right]
            for img in self.plain
        ]

    def gaussian_blur(self, kernel_size: tuple) -> None:
        """Applies a Gaussian blur to a copy of the images and stores them in self.prep.

        Any previously stored pre-processed images are replaced, so self.prep
        always has one entry per image in self.plain, in the same order.

        Args:
            kernel_size (tuple): Size of the blur's kernel. Both components must be positive odd integers

        Raises:
            ValueError: When one or both of the kernel size components isn't a positive odd integer.
        """
        if any(k <= 0 or k % 2 != 1 for k in kernel_size[:2]):
            raise ValueError("Both components of kernel_size must be odd integers.")
        # Rebuild rather than append: repeated calls must not duplicate entries,
        # because other code indexes prep and plain with the same indices.
        self.prep = [cv2.GaussianBlur(img, kernel_size, 0) for img in self.plain]


class ImgPair:
    """Used to store information in the ShiftGraph class and perform calculations that only include a pair of images.

    Attributes:
        indices (tuple): Indices of the two images (img1, img2) in the image list
        img1, img2 (np.ndarray): Two pre-processed images stored as NumPy arrays
        costs (dict): Contains all the cost values calculated for this ImgPair. The keys are (sx, sy) tuples and the values floats
        shift (tuple): Contains two integers (x, y) specifying the optimal overlay. None while not yet calculated
    """

    def __init__(self, img_series: "ImgSeries", indices: tuple) -> None:
        """Initializes an instance of ImgPair from an ImgSeries.

        Args:
            img_series (ImgSeries): The instance of ImgSeries the ShiftGraph will be based on
            indices (tuple): A tuple with two integers (img1, img2) storing the indices of the two images
        """
        self.indices = indices
        self.img1 = img_series.prep[indices[0]]
        self.img2 = img_series.prep[indices[1]]
        self.costs = {}
        self.shift = None

    def __str__(self) -> str:
        return f"img{self.indices[0]} and img{self.indices[1]} with shift {self.shift}"

    def get_cost(self, sx: int, sy: int) -> float:
        """
        Computes the cost (MSE) for shifting img2 by (sx, sy) relative to img1.

        Results are cached in self.costs, so each shift is evaluated only once.

        Args:
            sx (int): Shift in the x-direction.
            sy (int): Shift in the y-direction.

        Returns:
            float: The computed cost. Returns infinity if no overlap exists.
        """
        key = (sx, sy)
        if key in self.costs:
            return self.costs[key]

        # Determine overlapping regions based on shift
        img1_overlap, img2_overlap = self._get_overlapping_regions(sx, sy)

        # Without any overlap the shift cannot be evaluated
        if img1_overlap.size == 0 or img2_overlap.size == 0:
            self.costs[key] = float('inf')
            return self.costs[key]

        # Compute the difference in float64: subtracting and squaring integer
        # images (e.g. uint8) in their native dtype wraps around and corrupts the MSE.
        diff = np.subtract(img1_overlap, img2_overlap, dtype=np.float64)
        mse = float(np.mean(np.square(diff)))
        self.costs[key] = mse
        return mse

    def _get_overlapping_regions(self, sx: int, sy: int) -> tuple:
        """
        Determines the overlapping regions of img1 and shifted img2 based on the shift.

        A pixel (x, y) of img1 is paired with pixel (x + sx, y + sy) of img2.

        Args:
            sx (int): Shift in the x-direction.
            sy (int): Shift in the y-direction.

        Returns:
            tuple: Overlapping regions (img1_overlap, img2_overlap). Both are
            empty when the shift is at least as large as the image.
        """
        height, width = self.img1.shape

        # A shift as large as the image leaves nothing to compare. Handle it
        # explicitly instead of relying on negative slice bounds.
        if abs(sx) >= width or abs(sy) >= height:
            return self.img1[0:0, 0:0], self.img2[0:0, 0:0]

        overlap_w = width - abs(sx)
        overlap_h = height - abs(sy)

        # For a positive shift the window starts at 0 in img1 and at the shift
        # in img2; for a negative shift the roles are swapped.
        x1_start, x2_start = max(0, -sx), max(0, sx)
        y1_start, y2_start = max(0, -sy), max(0, sy)

        img1_overlap = self.img1[y1_start:y1_start + overlap_h, x1_start:x1_start + overlap_w]
        img2_overlap = self.img2[y2_start:y2_start + overlap_h, x2_start:x2_start + overlap_w]

        return img1_overlap, img2_overlap

    def get_optimal_shift(self) -> tuple[int, int]:
        """Determines the shift between the two images necessary for optimal overlay by finding the cost function's minimum.

        This is done using a multi-scale grid search, where the cost function is evaluated at shifts corresponding to a 3x3 grid. The point with the lowest value is the center of the next grid, scaled down by a factor of 2. This is repeated until reaching the 1-pixel scale.

        The first step size is the largest power of two not exceeding the
        module-level ``max_shift``, so the summed steps can reach any shift up
        to ``max_shift`` in each direction.

        Returns:
            The determined optimal shift (x, y)

        Raises:
            ValueError: If ``max_shift`` is not positive
        """
        if max_shift < 1:
            raise ValueError("max_shift must be a positive integer")

        best_pos = (0, 0)
        min_val = self.get_cost(0, 0)
        # Largest power of two <= max_shift; halved each level down to 1.
        step_size = 1 << (max_shift.bit_length() - 1)
        while step_size >= 1:
            # Freeze the centre for this level so the grid is a true 3x3 grid
            # even when a better position is found part-way through the scan.
            center_x, center_y = best_pos
            for x in (center_x - step_size, center_x, center_x + step_size):
                for y in (center_y - step_size, center_y, center_y + step_size):
                    cost = self.get_cost(x, y)
                    if cost < min_val:
                        min_val = cost
                        best_pos = (int(x), int(y))
            step_size //= 2
        self.shift = best_pos
        return self.shift
            

class ShiftGraph:
    """Stores information on the relative shifts between images in a series as an adjacency matrix.

    Attributes:
        imgs: The ImgSeries associated with this ShiftGraph
        len: Number of pre-processed images (nodes) in the graph
        matrix: The adjacency matrix. The values are either (x, y) tuples or None.
            ``matrix[a][b]`` holds the shift of image b relative to image a and
            ``matrix[b][a]`` its negation
    """

    def __init__(self, img_series: "ImgSeries") -> None:
        """Initializes an instance with an empty adjacency matrix based on an ImgSeries.

        Args:
            img_series: An instance of the ImgSeries class

        Raises:
            FileNotFoundError: When the list with pre_processed images is empty
            ValueError: When the numbers of plain and pre-processed images differ
        """
        self.imgs = img_series
        self.len = len(self.imgs.prep)
        if self.len == 0:
            raise FileNotFoundError("The ImgSeries contains no pre-processed images")
        if len(self.imgs.plain) != self.len:
            # Shifts are computed on prep but the overlay is built from plain,
            # so both lists must describe the same images.
            raise ValueError("The ImgSeries must contain one pre-processed image per plain image")
        self.matrix = [[None for _ in range(self.len)] for _ in range(self.len)]
        for i in range(self.len):
            self.matrix[i][i] = (0, 0)

    def __str__(self) -> str:
        """Returns the adjacency matrix with the shifts for each ImgPair as a string"""
        return str(self.matrix)

    def add_shift(self, from_img: int, to_img: int) -> None:
        """Calculates the shift between two images using the ImgPair class. Then, it adds the shift as well as its opposite to the matrix.

        Args:
            from_img, to_img (int): indices of the two images in the ImgSeries

        Raises:
            IndexError: If one or both of the specified indices are outside the range of ImgSeries
        """
        if not (0 <= from_img < self.len and 0 <= to_img < self.len):
            raise IndexError("Please enter valid indices.")
        img_pair = ImgPair(self.imgs, (from_img, to_img))
        shift = img_pair.get_optimal_shift()
        self.matrix[from_img][to_img] = shift
        self.matrix[to_img][from_img] = (-shift[0], -shift[1])

    def remove_shift(self, from_n: int, to_n: int) -> None:
        """Removes the connections between two given nodes from the graph.

        Both directions are cleared. Indices outside the graph and the
        self-connection of a node (``from_n == to_n``) are left untouched.

        Args:
            from_n, to_n (int): indices of the two images in the ImgSeries
        """
        if from_n == to_n:
            return
        if 0 <= from_n < self.len and 0 <= to_n < self.len:
            self.matrix[from_n][to_n] = None
            self.matrix[to_n][from_n] = None

    def get_full_graph(self) -> None:
        """Calculates all possible shifts in the graph and adds them to the matrix"""
        # Each unordered pair once; add_shift fills in both directions.
        for i in range(self.len):
            for j in range(i + 1, self.len):
                self.add_shift(i, j)

    def get_partial_graph(self, n: int = 5) -> None:
        """Calculates some of the shifts in the graph (to save computation time with large datasets)

        Every image is connected to its successor and to ``n`` randomly chosen
        reference images. Pairs that already have a shift in the matrix are
        not recomputed.

        Args:
            n (int): Number of "reference nodes" in the graph to which all shifts are calculated. Defaults to 5

        Raises:
            ValueError: If n is smaller than 3 or too large for the dataset
                (the dataset must contain at least n + 2 images)
        """
        if n < 3:
            raise ValueError("The value of n must be at least 3.")
        if self.len < n + 2:
            raise ValueError("The value of n is too large for the length of this dataset.")
        universal_refs = random.sample(range(self.len), n)  # reference points for all other images
        for i in range(self.len):
            targets = list(universal_refs)
            if i != self.len - 1:
                targets.insert(0, i + 1)
            for j in targets:
                # Skip self-pairs and pairs already computed in either direction.
                if j != i and self.matrix[i][j] is None:
                    self.add_shift(i, j)

    def get_actual_shift(self, from_n: int, to_n: int) -> tuple:
        """
        Calculates the actual shift between two images using the direct connection and all possible 2-step connections between the two nodes.

        Args:
            from_n (int): Index of the source image in the ImgSeries.
            to_n (int): Index of the target image in the ImgSeries.

        Returns:
            tuple: The averaged shift (avg_x, avg_y) between the two images,
            rounded to integers. (0, 0) if no connection exists.
        """
        if from_n == to_n:
            return (0, 0)

        shifts = []

        # Direct shift
        direct_shift = self.matrix[from_n][to_n]
        if direct_shift is not None:
            shifts.append(direct_shift)

        # Indirect shifts via intermediate images
        for intermediate in range(self.len):
            if intermediate in (from_n, to_n):
                continue
            shift_from_intermediate = self.matrix[from_n][intermediate]
            shift_to_intermediate = self.matrix[to_n][intermediate]
            if shift_from_intermediate is None or shift_to_intermediate is None:
                continue
            # from -> intermediate -> to equals
            # matrix[from][intermediate] - matrix[to][intermediate]
            shifts.append((
                shift_from_intermediate[0] - shift_to_intermediate[0],
                shift_from_intermediate[1] - shift_to_intermediate[1],
            ))

        if not shifts:
            # If no shifts are found, assume no shift
            return (0, 0)

        # Average all estimates and round to whole pixels
        avg_x = sum(shift[0] for shift in shifts) / len(shifts)
        avg_y = sum(shift[1] for shift in shifts) / len(shifts)
        return (int(round(avg_x)), int(round(avg_y)))

    def get_overlay_image(self) -> np.ndarray:
        """
        Calculates the overlay image based on the shift graph.

        Every plain image is placed on a common canvas according to its shift
        relative to image 0, and overlapping pixels are averaged.

        Returns:
            np.ndarray: The shifted and averaged overlay image (uint8).
        """
        # Calculate cumulative shifts relative to the first image (index 0)
        cum_shift_list = [self.get_actual_shift(0, i) for i in range(self.len)]

        # Determine the overall shift bounds
        x_shifts = [shift[0] for shift in cum_shift_list]
        y_shifts = [shift[1] for shift in cum_shift_list]
        min_x, max_x = min(x_shifts), max(x_shifts)
        min_y, max_y = min(y_shifts), max(y_shifts)

        # Calculate the size of the final overlay image
        ref_height, ref_width = self.imgs.plain[0].shape
        final_height = ref_height + (max_y - min_y)
        final_width = ref_width + (max_x - min_x)

        # Initialize arrays to accumulate pixel values and counts
        overlay_accumulator = np.zeros((final_height, final_width), dtype=np.float32)
        overlay_count = np.zeros((final_height, final_width), dtype=np.float32)

        for idx, (sx, sy) in enumerate(cum_shift_list):
            img = self.imgs.plain[idx]
            img_height, img_width = img.shape

            # An image shifted by (sx, sy) is moved back by that amount; the
            # largest shift therefore lands at canvas coordinate 0.
            y_start = max_y - sy
            y_end = y_start + img_height
            x_start = max_x - sx
            x_end = x_start + img_width

            # Clip anything that would fall outside the canvas
            if y_start < 0:
                img = img[-y_start:, :]
                y_start = 0
            if x_start < 0:
                img = img[:, -x_start:]
                x_start = 0
            if y_end > final_height:
                img = img[:final_height - y_start, :]
                y_end = final_height
            if x_end > final_width:
                img = img[:, :final_width - x_start]
                x_end = final_width

            # Add the image to the accumulator and count the contribution
            overlay_accumulator[y_start:y_end, x_start:x_end] += img.astype(np.float32)
            overlay_count[y_start:y_end, x_start:x_end] += 1.0

        # Avoid division by zero for pixels no image covers
        overlay_count[overlay_count == 0] = 1.0
        overlay_average = overlay_accumulator / overlay_count
        # Convert to uint8
        return np.clip(overlay_average, 0, 255).astype(np.uint8)