In [2]:
import os
import cv2

In [63]:
class ImgSeries:
    """Stores a series of SEM images.

    Attributes:
        plain: A list storing images without pre-processing as NumPy arrays
        prep: A list storing pre-processed images as NumPy arrays
        len: The number of images loaded into self.plain
    """

    def __init__(self, folder_path: str, interval: tuple = None):
        """Initializes an instance by loading .tiff images from a given folder.

        Files are loaded in sorted filename order, because os.listdir returns
        names in an arbitrary, platform-dependent order and image indices must
        be reproducible.

        Args:
            folder_path: A string indicating the folder to use
            interval: A tuple formatted (start, stop) indicating the index of the
                first and last .tiff file to load (both inclusive). A stop of -1
                loads up to and including the last file.

        Raises:
            AssertionError: When the indicated folder contains no .tiff files,
                or when interval selects none of them
            OSError: When a .tiff file cannot be read by OpenCV
        """
        self.plain = []
        self.prep = []
        fname_list = sorted(f for f in os.listdir(folder_path) if f.endswith(".tiff"))
        assert len(fname_list) > 0, "no .tiff files in the indicated folder"
        if interval:
            start, stop = interval[0], interval[1]
            # stop is inclusive; for stop == -1, stop + 1 would be 0 and the slice empty
            end = None if stop == -1 else stop + 1
            fname_list = fname_list[start:end]
            assert len(fname_list) > 0, "interval selects no .tiff files"
        for fname in fname_list:
            print(fname)
            img_path = os.path.join(folder_path, fname)
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None rather than raising
            if img is None:
                raise OSError(f"could not read image file {img_path}")
            self.plain.append(img)
        self.len = len(self.plain)

    def __str__(self) -> str:
        assert len(self.plain) > 0, "ImgSeries is empty"
        return f"ImgSeries containing {len(self.plain)} plain images with size {self.plain[0].shape} and {len(self.prep)} pre-processed images"

    def crop(self, amount: tuple = (0, 0, 0, 0)) -> None:
        """Crops the images in self.plain in place by a given amount.

        Only run before self.prep is created; pre-processed images are not
        cropped. Every image is validated before any is modified, so a
        ValueError leaves self.plain unchanged.

        Args:
            amount: A tuple specifying how many pixels are discarded (top, bottom, left, right)

        Raises:
            AssertionError: When self.prep has already been created
            ValueError: When a crop value is negative or the crop would leave no pixels
        """
        top, bottom, left, right = amount
        assert len(self.prep) == 0, "please only run this function before creating pre-processed images"
        if min(top, bottom, left, right) < 0:
            raise ValueError("Crop values must be non-negative.")
        # Validate all images first so a failure cannot leave the series half-cropped
        for img in self.plain:
            height, width = img.shape[:2]
            if top + bottom >= height or left + right >= width:
                raise ValueError("Crop values are too large for the image size.")
        # Slice assignment keeps the same list object, as the original in-place loop did
        self.plain[:] = [
            img[top:img.shape[0] - bottom, left:img.shape[1] - right] for img in self.plain
        ]

    def gaussian_blur(self, kernel_size: tuple):
        """Applies a Gaussian blur to a copy of the images and stores them in self.prep.

        Any previously stored pre-processed images are replaced, so self.prep
        holds exactly one blurred image per image in self.plain.

        Args:
            kernel_size: A tuple specifying the size of the blur's kernel. Both components must be positive odd integers

        Raises:
            AssertionError: When a kernel_size component is not a positive odd integer
        """
        # Note: -1 % 2 == 1 in Python, so positivity must be checked explicitly
        assert all(k > 0 and k % 2 == 1 for k in (kernel_size[0], kernel_size[1])), \
            "both components of kernel_size must be positive odd integers"
        self.prep[:] = [cv2.GaussianBlur(img, kernel_size, 0) for img in self.plain]


class ImgPair:
    """Holds two images of a series for use inside a ShiftGraph.

    Calculations that only involve a pair of images belong here.

    Attributes:
        indices: A tuple (img1, img2) with the indices of the two images in the series
        img1, img2: The two pre-processed images, taken from the series' prep list
        costs: A dictionary of computed cost values, keyed by tuples, with float values
        shift: A tuple (x, y) giving the optimal overlay; None until it is calculated
    """

    def __init__(self, img_series: ImgSeries, indices: tuple):
        """Creates an ImgPair that references two images of an ImgSeries.

        Args:
            img_series: An ImgSeries object whose prep list holds the pre-processed images
            indices: A tuple with two integers (img1, img2) storing the indices of the two images
        """
        first, second = indices[0], indices[1]
        prepared = img_series.prep
        self.indices = indices
        self.img1, self.img2 = prepared[first], prepared[second]
        self.costs = {}
        self.shift = None

    def __str__(self) -> str:
        first, second = self.indices[0], self.indices[1]
        return "img{} and img{} with shift {}".format(first, second, self.shift)


class ShiftGraph:
    """Stores information on the relative shifts between images in a series as an adjacency matrix.

    Only the strict upper triangle (row i < column j) holds ImgPair objects;
    the diagonal and the lower triangle are None.

    Attributes:
        imgs: The ImgSeries associated with this ShiftGraph
        matrix: The adjacency matrix as a list of lists. The values are either ImgPair objects or None
    """

    def __init__(self, img_series: ImgSeries):
        """Initializes an instance with an empty adjacency matrix based on an ImgSeries.

        Args:
            img_series: An instance of the ImgSeries class with pre-processed images

        Raises:
            AssertionError: When img_series has no pre-processed images
        """
        self.imgs = img_series
        n = len(self.imgs.prep)
        assert n > 0, "list with pre-processed images has not been initialized"
        # Size the matrix from prep, the list the pairs are built from, so the
        # matrix dimensions and pair indices can never disagree.
        self.matrix = [[None] * n for _ in range(n)]
        for i in range(n):
            for j in range(i + 1, n):
                self.matrix[i][j] = ImgPair(self.imgs, (i, j))

    def __str__(self) -> str:
        """Returns the adjacency matrix as a string, with each ImgPair replaced by its shift."""
        shifts = [
            [cell.shift if isinstance(cell, ImgPair) else cell for cell in row]
            for row in self.matrix
        ]
        return str(shifts)

In [64]:
data_folder = "C:\\Users\\gioni\\Documents\\Hector-Seminar\\NanoPEACH_phase2\\2024-02-13\\partikel2\\speed4"
# Load files 1..5 (inclusive), blur them, then print the shift matrix
# (all shifts are None until they are calculated).
test_series = ImgSeries(data_folder, (1, 5))
test_series.gaussian_blur((13, 13))
test_graph = ShiftGraph(test_series)
print(test_graph)

p2_10.tiff
p2_11.tiff
p2_12.tiff
p2_13.tiff
p2_14.tiff
[[None, None, None, None, None], [None, None, None, None, None], [None, None, None, None, None], [None, None, None, None, None], [None, None, None, None, None]]
