# In[1]:
import os
import numpy as np
import tifffile

# In[8]:
def blend_overlap(mask1, mask2, overlap, direction='horizontal'):
    """
    Linearly blend the trailing overlap strip of ``mask1`` with the leading
    overlap strip of ``mask2``.

    For ``direction='horizontal'`` the last ``overlap`` columns of ``mask1`` are
    combined with the first ``overlap`` columns of ``mask2``. For
    ``direction='vertical'`` the same is done with rows. At step ``i`` of the
    strip the weight of ``mask2`` is ``alpha = i / overlap``, so it ramps from
    0 up to ``(overlap - 1) / overlap``; ``mask1`` gets ``1 - alpha``.

    The blended values are written into a copy of ``mask1``, so they are cast
    back to ``mask1``'s dtype (fractions are truncated for integer dtypes).
    For integer label masks the blended values lie between the two inputs
    and may not equal either input label.

    Args:
        mask1 (ndarray): First segmentation mask (its trailing strip is blended).
        mask2 (ndarray): Second segmentation mask (its leading strip is blended).
        overlap (int): Width (horizontal) or height (vertical) of the strip.
            ``overlap <= 0`` returns an unchanged copy of ``mask1``.
        direction (str): ``'horizontal'`` or ``'vertical'``.

    Returns:
        ndarray: Copy of ``mask1`` with its trailing strip replaced by the blend.

    Raises:
        ValueError: If ``direction`` is neither ``'horizontal'`` nor ``'vertical'``.
    """
    if direction not in ('horizontal', 'vertical'):
        raise ValueError(
            f"direction must be 'horizontal' or 'vertical', got {direction!r}"
        )

    blended_mask = mask1.copy()
    # Guard needed: a slice like [:, -0:] would select the WHOLE array.
    if overlap <= 0:
        return blended_mask

    alpha = np.arange(overlap) / overlap
    if direction == 'horizontal':
        # Weights vary along axis 1; broadcast over rows and any trailing axes.
        alpha = alpha.reshape((1, overlap) + (1,) * (mask1.ndim - 2))
        blended_mask[:, -overlap:] = (
            (1 - alpha) * mask1[:, -overlap:] + alpha * mask2[:, :overlap]
        )
    else:
        # Weights vary along axis 0; broadcast over columns and trailing axes.
        alpha = alpha.reshape((overlap,) + (1,) * (mask1.ndim - 1))
        blended_mask[-overlap:] = (
            (1 - alpha) * mask1[-overlap:] + alpha * mask2[:overlap]
        )
    return blended_mask

def calculate_image_shape(input_folder, slice_size, overlap):
    """
    Infer the full image size from the tile filenames in ``input_folder``.

    Every ``.tif``/``.tiff`` filename must end in
    ``_<left>_<upper>_<right>_<lower>``. The image extent is the largest
    ``right`` (width) and the largest ``lower`` (height) found. Files with
    other extensions are ignored.

    Args:
        input_folder (str): Path to the folder containing segmentation mask TIFF files.
        slice_size (tuple): Size of each slice (width, height). Not used here;
            kept for interface compatibility.
        overlap (int): Overlap between adjacent slices. Not used here; kept for
            interface compatibility.

    Returns:
        tuple: ``(height, width)`` of the original image; ``(0, 0)`` when no
        TIFF tiles are present.
    """
    tiff_names = (
        name for name in os.listdir(input_folder)
        if name.endswith(('.tif', '.tiff'))
    )

    width = height = 0
    for name in tiff_names:
        stem, _ext = os.path.splitext(name)
        _left, _upper, right_edge, lower_edge = map(int, stem.split('_')[-4:])
        width = max(width, right_edge)
        height = max(height, lower_edge)

    return (height, width)

def stitch_masks(input_folder, output_file, slice_size=(512, 512), overlap=64):
    """
    Stitch segmentation mask tiles back together into one mask for the original image.

    Tiles are read from ``.tif``/``.tiff`` files whose names end in
    ``_<left>_<upper>_<right>_<lower>``. The output size comes from
    ``calculate_image_shape``. The work happens in two passes:

    1. Every tile is pasted at its ``(left, upper)`` origin.
    2. If a tile's left or upper neighbour exists, the strip they share is
       overwritten with a linear blend of the two tiles (``blend_overlap``).
       The neighbour is expected one stride, ``slice_size - overlap``, away.

    Blending runs only after all tiles are pasted, so a later paste cannot
    overwrite an earlier blend. Where four tiles meet, blend strips overlap
    each other. Tiles are visited in row-major ``(upper, left)`` order, and
    in those corners the blend written last wins.

    Args:
        input_folder (str): Path to the folder containing segmentation mask TIFF files.
        output_file (str): Path to the output file where the stitched segmentation mask will be saved.
        slice_size (tuple): Size of each slice (width, height). Used with
            ``overlap`` to locate neighbouring tiles.
        overlap (int): Overlap between adjacent slices in pixels.
            ``overlap <= 0`` disables blending.

    Returns:
        None
    """
    # Calculate the shape of the original image
    image_shape = calculate_image_shape(input_folder, slice_size, overlap)

    # Load every tile keyed by its (left, upper) origin. Sorting the listing
    # makes processing independent of filesystem order.
    slices = {}
    for filename in sorted(os.listdir(input_folder)):
        if not filename.endswith(('.tif', '.tiff')):
            continue
        parts = os.path.splitext(filename)[0].split('_')
        left, upper, _right, _lower = map(int, parts[-4:])
        slices[(left, upper)] = tifffile.imread(os.path.join(input_folder, filename))

    # Use the tiles' dtype (and any trailing channel axes) instead of forcing
    # uint8, which silently wrapped label values above 255.
    if slices:
        dtype = np.result_type(*(tile.dtype for tile in slices.values()))
        extra_dims = next(iter(slices.values())).shape[2:]
    else:
        dtype, extra_dims = np.uint8, ()
    stitched_mask = np.zeros(tuple(image_shape) + tuple(extra_dims), dtype=dtype)

    # Pass 1: paste each tile. Slicing by the tile's own shape also copes with
    # smaller tiles at the image border.
    for (left, upper), mask in slices.items():
        stitched_mask[upper:upper + mask.shape[0], left:left + mask.shape[1]] = mask

    # Pass 2: blend overlap strips with the left and upper neighbours.
    if overlap > 0:
        stride_x = slice_size[0] - overlap
        stride_y = slice_size[1] - overlap
        for left, upper in sorted(slices, key=lambda origin: (origin[1], origin[0])):
            mask = slices[(left, upper)]

            left_tile = slices.get((left - stride_x, upper)) if left > 0 else None
            if left_tile is not None:
                blended = blend_overlap(left_tile[:, -overlap:], mask[:, :overlap],
                                        overlap, direction='horizontal')
                rows, cols = blended.shape[:2]
                stitched_mask[upper:upper + rows, left:left + cols] = blended

            upper_tile = slices.get((left, upper - stride_y)) if upper > 0 else None
            if upper_tile is not None:
                blended = blend_overlap(upper_tile[-overlap:, :], mask[:overlap, :],
                                        overlap, direction='vertical')
                rows, cols = blended.shape[:2]
                stitched_mask[upper:upper + rows, left:left + cols] = blended

    # Save the stitched segmentation mask
    tifffile.imwrite(output_file, stitched_mask)

# Example usage: runs when executed as a script or notebook cell, but not when
# this module is imported (hard-coded paths would otherwise run on import).
if __name__ == "__main__":
    input_folder = r'E:\SEGMENTATIONMASKCOMP\GT_membrane'
    output_file = r'E:\stitched_mask.tiff'
    stitch_masks(input_folder, output_file)