In [2]:
import os
import cv2
import numpy as np

In [29]:
max_shift = 20 # Maximum shift in either x or y direction expected between any two given images of the series. Must be a positive integer

class ImgSeries:
    """Stores a series of SEM images.

    Attributes:
        plain (list): Images without pre-processing stored as NumPy arrays
        prep (list): Pre-processed images stored as NumPy arrays
    """

    def __init__(self, folder_path: str, interval: tuple=None) -> None:
        """Initializes an instance by loading .tiff images from a given folder.

        Args:
            folder_path (str): Path to the folder containing the images
            interval (tuple): Indicates the index of the first and last .tiff file to load (start, stop)
        
        Raises:
            FileNotFoundError: When the indicated folder contains no .tiff files
        """
        self.plain = []
        self.prep = []
        fname_list = [f for f in os.listdir(folder_path) if f.endswith(".tiff")]
        if len(fname_list) == 0:
            raise FileNotFoundError("There are no .tiff files in the indicated folder")
        if interval:
            fname_list = fname_list[interval[0]:interval[1]+1]
        for fname in fname_list:
            print(fname)
            img_path = os.path.join(folder_path, fname)
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            self.plain.append(img)
        self.len = len(self.plain)

    def __str__(self) -> str:
        if len(self.plain) == 0:
            raise FileNotFoundError("ImgSeries is empty")
        return f"ImgSeries containing {len(self.plain)} plain images with size {self.plain[0].shape} and {len(self.prep)} pre-processed images"
    
    def crop(self, amount: tuple=(0, 0, 0, 0)) -> None:
        """Crops the images in self.plain and self.prep by a given amount. Only run before self.prep is created

        Args:
            amount (tuple): Specifies how many pixels are discarded (top, bottom, left, right)

        Raises:
            Exception: If the function is run after pre-processed images have been created
            ValueError: If the crop values are too large
        """
        top, bottom, left, right = amount
        if len(self.prep) > 0:
            raise Exception("Only run this function before creating pre-processed images")
        for i, img in enumerate(self.plain):
            height, width = img.shape[:2]
            if top + bottom >= height or left + right >= width:
                raise ValueError("Crop values are too large for the image size.")
            self.plain[i] = img[top:height-bottom, left:width-right]

    def gaussian_blur(self, kernel_size: tuple):
        """Applies a Gaussian blur to a copy of the images and stores them in self.prep.

        Args:
            kernel_size (tuple): Size of the blur's kernel. Both components must be odd integers
        
        Raises:
            ValueError: When one or both of the kernel size components isn't an odd integer.
        """
        if kernel_size[0] % 2 != 1 or kernel_size[1] % 2 != 1:
            raise ValueError("Both components of kernel_size must be odd integers.")
        for img in self.plain:
            self.prep.append(cv2.GaussianBlur(img, kernel_size, 0))


class ImgPair:
    """Used to store information in the ShiftGraph class and perform calculations that only include a pair of images.
    
    Attributes:
        indices (tuple): Indices of the two images (img1, img2) in the image list
        img1, img2 (np.ndarray): Two pre-processed images stored as NumPy arrays
        costs (dict): Contains all the cost values calculated for this ImgPair. The indices are tuples and the values floats
        shift (tuple): Contains two integers (x, y) specifying the optimal overlay. None while not yet calculated
    """

    def __init__(self, img_series: ImgSeries, indices: tuple) -> None:
        """Initializes an instance of ImgPair from an ImgSeries.
        
        Args:
            img_series (ImgSeries): The instance of ImgSeries the ShiftGraph will be based on
            indices (tuple): A tuple with two integers (img1, img2) storing the indices of the two images
        """
        self.indices = indices
        self.img1 = img_series.prep[indices[0]]
        self.img2 = img_series.prep[indices[1]]
        self.costs = {}
        self.shift = None

    def __str__(self) -> str:
        return f"img{self.indices[0]} and img{self.indices[1]} with shift {self.shift}"
    
    def get_cost(self, sx: int, sy: int) -> float:
        """Adds the value of the cost function for a given shift x, y to the self.costs dictionary and returns it.

        First, the function checks whether the cost has already been calculated.
        If not, it shifts the images by cropping each so that only the common part remains and then calculates the square average.

        Args:
            sx, sy (int): Amount the images are shifted in x and y directions

        Returns:
            The cost value for the given shift
        """
        if (sx, sy) in self.costs:
            return self.costs[(sx, sy)]
        else:
            def shift_image(img: np.ndarray, sx: int, sy: int) -> tuple:
                """
                Shifts img2 relative to img1 by (sx, sy).
                Positive sx shifts right, positive sy shifts down.
                Returns the shifted img2 and corresponding slices for img1.
                """
                # Handle horizontal shift
                if sx > 0:
                    img2_shifted = img[sy:, sx:]
                    img1_cut = img[:-sy if sy !=0 else None, :-sx]
                elif sx < 0:
                    img2_shifted = img[sy:, :sx]
                    img1_cut = img[:-sy if sy !=0 else None, -sx:]
                else:
                    img2_shifted = img[sy:, :]
                    img1_cut = img[:-sy if sy !=0 else None, :]
                
                # Handle vertical shift
                if sy > 0:
                    img2_shifted = img2_shifted
                    img1_cut = img1_cut
                elif sy < 0:
                    img2_shifted = img2_shifted
                    img1_cut = img1_cut
                else:
                    img2_shifted = img2_shifted
                    img1_cut = img1_cut
                
                return img2_shifted, img1_cut

            # Shift img2 relative to img1
            img2_shifted, img1_cut = shift_image(self.img2, sx, sy)
            
            # Debugging: Print shapes after shifting
            print(f"Shift: ({sx}, {sy})")
            print(f"img1_cut shape: {img1_cut.shape}")
            print(f"img2_shifted shape: {img2_shifted.shape}")
            
            # Ensure both images are not empty
            if img1_cut.size == 0 or img2_shifted.size == 0:
                # Assign a high cost if there's no overlap
                self.costs[(sx, sy)] = float('inf')
                return self.costs[(sx, sy)]
            
            # Ensure both images have the same shape
            min_height = min(img1_cut.shape[0], img2_shifted.shape[0])
            min_width = min(img1_cut.shape[1], img2_shifted.shape[1])
            img1_cut = img1_cut[:min_height, :min_width]
            img2_shifted = img2_shifted[:min_height, :min_width]
            
            # Compute the difference and cost
            diff = (img2_shifted - img1_cut) ** 2
            cost = np.average(diff)
            self.costs[(sx, sy)] = cost
            return cost
    
    def get_optimal_shift(self) -> tuple[int]:
        """Determines the shift between the two images necessary for optimal overlay by finding the cost function's minimum.
        
        This is done using a multi-scale grid search progress, where the cost function is evaluated at shifts corresponding to a 3x3 grid. The point with the lowest value is the center of the next grid, scaled down by a factor of 2. This is repeated until reaching the 1-pixel scale.

        Returns:
            The determined optimal shift
        """
        best_pos = [0, 0]
        min_val = self.get_cost(0, 0)
        max_step_size = 1 << (max_shift - 1).bit_length() # round max_shift to the nearest power of 2
        for i in range(1, int(max_step_size**0.5)+1):
            # 3x3 grid, size reduced by factor of 2 at each step
            step_size = int(max_step_size/2**i)
            for x in [best_pos[0]-step_size, best_pos[0], best_pos[0]+step_size]:
                for y in [best_pos[1]-step_size, best_pos[1], best_pos[1]+step_size]:
                    self.get_cost(x, y)
                    if self.costs[(x, y)] < min_val:
                        min_val = self.costs[(x, y)]
                        best_pos = [int(x), int(y)]
        self.shift = tuple(best_pos)
        return self.shift
            

class ShiftGraph:
    """Stores information on the relative shifts between images in a series as an adjacency matrix.

    Attributes:
        img_series: The ImgSeries associated with this ShiftGraph
        matrix: The adjacency matrix. The values are either tuples or None
    """

    def __init__(self, img_series: ImgSeries) -> None:
        """Initializes an instance with an empty adjacency matrix based on an ImgSeries.

        Args:
            img_series: An instance of the ImgSeries class
        
        Raises:
            FileNotFoundError: When the list with pre_processed images is empty
        """
        self.imgs = img_series
        self.len = len(self.imgs.prep)
        self.matrix = [[None for _ in range(self.imgs.len)] for _ in range(self.imgs.len)]
        for i in range(self.len):
            self.matrix[i][i] = (0, 0)
    
    def __str__(self) -> str:
        """Prints an adjacency matrix with the shifts for each ImgPair"""
        return str(self.matrix)
    
    def add_shift(self, from_img: int, to_img: int) -> None:
        """Calculates the shift between two images using the ImgPair class. Then, it adds the shift as well as its opposite to the matrix.
        
        Args:
            from_img, to_img (int): indices of the two images in the ImgSeries
        
        Raises:
            IndexError: If one or both of the specified indices are outside the range of ImgSeries
        """
        if from_img > len(self.imgs.prep) or to_img > len(self.imgs.prep):
            raise IndexError("Please enter valid indices.")
        img_pair = ImgPair(self.imgs, (from_img, to_img))
        shift = img_pair.get_optimal_shift()
        self.matrix[from_img][to_img] = shift
        self.matrix[to_img][from_img] = (-shift[0], -shift[1])
    
    def remove_shift(self, from_n: int, to_n: int) -> None:
        """Removes the connections between two given nodes from the graph.
        
        Args:
            from_img, to_img (int): indices of the two images in the ImgSeries
        """
        if from_n in self.graph and to_n in self.graph[from_n]:
            del self.graph[from_n][to_n]
        if to_n in self.graph and from_n in self.graph[to_n]:
            del self.graph[to_n][from_n]
    
    def get_full_graph(self) -> None:
        "Calculates all possible shifts in the graph and adds them to the matrix"
        nodes = list(range(self.len)) # helper list with [0, 1, ...] of length n for the nested loop
        for i in range(self.len):
            for j in nodes[i+1:]:
                self.add_shift(i, j)
    
    def get_actual_shift(self, from_n: int, to_n: int) -> None:
        """Calculates the actual shift between two images using all possible 2-step connections between the two nodes
        
        Args:
            from_img, to_img (int): indices of the two images in the ImgSeries
        """
        shifts = [self.matrix[from_n][to_n]]
        nodes = list(range(self.len))
        nodes.remove(from_n)
        nodes.remove(to_n)
        for node in nodes:
            if self.matrix[from_n][node] != None and self.matrix[node][to_n] != None:
                x = self.matrix[from_n][node][0] - self.matrix[to_n][node][0]
                y = self.matrix[from_n][node][1] - self.matrix[to_n][node][1]
                shifts.append((x, y))
        avg_x = sum(x for x, y in shifts) / len(shifts)
        avg_y = sum(y for x, y in shifts) / len(shifts)
        return (avg_x, avg_y)
    
    def get_overlay_image(self) -> None:
        """Calculates the overlay image based on a shift graph.
        
        Returns:
            The shifted overlay image as a np.ndarray
        """
        cum_shift_list = [self.get_actual_shift(0, i) for i in range(1, self.len)]
        min_x, max_x = int(min(c[0] for c in cum_shift_list)), int(max(c[0] for c in cum_shift_list))
        min_y, max_y = int(min(c[1] for c in cum_shift_list)), int(max(c[1] for c in cum_shift_list))
        final_imgs = []
        for j, (x_shift, y_shift) in enumerate(cum_shift_list):
            # Calculate cut dimensions directly without intermediate variables
            x_shift, y_shift = int(x_shift), int(y_shift)
            img = self.imgs.plain[j][max_y-y_shift : self.imgs.plain[j].shape[0]-(y_shift-min_y), max_x-x_shift : self.imgs.plain[j].shape[1]-(x_shift-min_x)]
            final_imgs.append(img)
        # Compute and return the average of the final images
        return np.mean(final_imgs, axis=0)

In [30]:
test_series = ImgSeries("C:\\Users\\gioni\\Documents\\Hector-Seminar\\NanoPEACH_phase2\\2024-02-13\\partikel1\\speed2", (0, 9))
test_series.gaussian_blur((13, 13))

test_graph = ShiftGraph(test_series)

test_graph.get_full_graph()

print(test_graph.matrix)

result = test_graph.get_overlay_image()

cv2.imwrite("result_p1_s2_10.png", result)

p1_01.tiff
p1_02.tiff
p1_03.tiff
p1_04.tiff
p1_05.tiff
p1_06.tiff
p1_07.tiff
p1_08.tiff
p1_09.tiff
p1_10.tiff
Shift: (0, 0)
img1_cut shape: (768, 1024)
img2_shifted shape: (768, 1024)
Shift: (-16, -16)
img1_cut shape: (16, 1008)
img2_shifted shape: (16, 1008)
Shift: (-16, 0)
img1_cut shape: (768, 1008)
img2_shifted shape: (768, 1008)
Shift: (-16, 16)
img1_cut shape: (752, 1008)
img2_shifted shape: (752, 1008)
Shift: (0, -16)
img1_cut shape: (16, 1024)
img2_shifted shape: (16, 1024)
Shift: (0, 16)
img1_cut shape: (752, 1024)
img2_shifted shape: (752, 1024)
Shift: (16, -16)
img1_cut shape: (16, 1008)
img2_shifted shape: (16, 1008)
Shift: (16, 0)
img1_cut shape: (768, 1008)
img2_shifted shape: (768, 1008)
Shift: (16, 16)
img1_cut shape: (752, 1008)
img2_shifted shape: (752, 1008)
Shift: (-8, -8)
img1_cut shape: (8, 1016)
img2_shifted shape: (8, 1016)
Shift: (-8, 0)
img1_cut shape: (768, 1016)
img2_shifted shape: (768, 1016)
Shift: (-8, 8)
img1_cut shape: (760, 1016)
img2_shifted shape: (7

True