In [1]:
import os

import glob
import time
import numpy as np
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

from scipy.ndimage import gaussian_filter, binary_fill_holes
from skimage import morphology
from skimage.measure import label, regionprops

from skimage import filters

import matplotlib.pyplot as plt

In [2]:
def load_images(files, resize=(256,256)):
    """Load image files as a stack of single-channel 8-bit arrays.

    Every file is opened with PIL, rotated according to its EXIF orientation
    tag, optionally resized, and converted to grayscale ('L' mode).

    Parameters:
        files (iterable of str): Paths of the images to load.
        resize (tuple or None): Size handed to ``Image.resize``. A falsy value
            (e.g. ``None``) keeps each image at its native size.

    Returns:
        ndarray: ``np.array`` built from the per-file arrays, in input order.
        With resizing disabled the files must already share one size for the
        result to be a regular N-d array.
    """
    images = []
    for file in files:
        # The context manager releases the underlying file handle as soon as
        # the pixel data has been copied out, instead of leaving every opened
        # file to the garbage collector.
        with Image.open(file) as opened:
            im = ImageOps.exif_transpose(opened)
            if resize:
                im = im.resize(resize)
            im = im.convert('L')
            images.append(np.array(im))
    return np.array(images)

In [3]:
# Root output directory; run_exp writes its masks into a per-size sub-folder
# (e.g. "256x256") below this path.
BASE_OUTPUT_PATH = 'outputs/segmentation_otsu'

In [4]:
# Sub-directories of data/: input pictures and their masks.
mode = ['images', 'masks']
BASE_PATH_IMGS = os.path.join('data', mode[0])
BASE_PATH_MASKS = os.path.join('data', mode[1])
# glob.glob returns entries in arbitrary, filesystem-dependent order. Sorting
# both listings makes runs reproducible and keeps files[i] / files_masks[i]
# aligned when they are indexed together later (assumes an image and its mask
# sort to the same position, e.g. share a base name — TODO confirm).
files = sorted(glob.glob(os.path.join(BASE_PATH_IMGS, '*')))
files_masks = sorted(glob.glob(os.path.join(BASE_PATH_MASKS, '*')))

In [5]:
# Load both sets once at load_images' default size (256x256).
# Note: run_exp reloads the files itself at the size it is given, so it does
# not read these two module-level arrays.
images = load_images(files)
masks = load_images(files_masks)

In [13]:
def verify_binary_image(binary_image):
    """Normalize mask polarity so the image border reads as background (0).

    Only the top row is inspected: when its median equals 1 the whole mask is
    assumed to be inverted and ``1 - binary_image`` is returned; otherwise the
    input is handed back unchanged (same object).

    Parameters:
        binary_image (ndarray): 2-D mask holding 0s and 1s.

    Returns:
        ndarray: Mask with the polarity corrected as described above.
    """
    top_row = binary_image[0, :]
    border_is_foreground = np.median(top_row) == 1
    return 1 - binary_image if border_is_foreground else binary_image

In [7]:
def thresholding(image):
    """Binarize a grayscale image with Otsu's threshold, then remove speckle.

    Pixels strictly above the Otsu threshold are marked True. Connected
    foreground regions below the 64-pixel ``min_size`` are then dropped and
    background holes below the 64-pixel ``area_threshold`` are filled.

    Parameters:
        image (ndarray): Grayscale image.

    Returns:
        ndarray: Cleaned binary mask as returned by ``remove_small_holes``.
    """
    otsu_level = filters.threshold_otsu(image)
    foreground = morphology.remove_small_objects(image > otsu_level, min_size=64)
    return morphology.remove_small_holes(foreground, area_threshold=64)

def keep_largest_connected_component(segmented_mask):
    """
    Keep only the largest connected component in a segmented binary mask.

    Parameters:
        segmented_mask (ndarray): Binary mask from segmentation (0s and 1s).

    Returns:
        ndarray: uint8 mask of the same shape with 1 on the largest connected
        component and 0 elsewhere. A mask without any foreground yields an
        all-zero mask.
    """
    labeled_mask = label(segmented_mask)

    regions = regionprops(labeled_mask)
    if not regions:
        # No foreground at all: max() over an empty sequence would raise
        # ValueError, so return an empty mask instead of aborting the caller.
        return np.zeros(labeled_mask.shape, dtype=np.uint8)

    # On equal areas max() keeps whichever region regionprops listed first.
    largest_component = max(regions, key=lambda r: r.area)

    return (labeled_mask == largest_component.label).astype(np.uint8)


In [19]:
def run_exp(resize, debug=False):
    """Segment every input image with the Otsu pipeline and save the masks.

    Each path in the module-level ``files`` list is loaded at ``resize``,
    Gaussian-blurred, thresholded, inverted, cleaned up and reduced to its
    largest connected component. The result is written to
    ``<BASE_OUTPUT_PATH>/<resize[0]>x<resize[1]>/<name>.png`` with foreground
    stored as 255. The elapsed processing time is printed at the end.

    Parameters:
        resize (tuple): Two-element size handed to ``load_images``; also used
            to name the output sub-directory.
        debug (bool): When True, plot the blurred input, the mask loaded from
            ``files_masks`` at the same index, and the predicted mask for
            every image.
    """
    images = load_images(files, resize=resize)
    # The masks from files_masks are only displayed in debug mode, so skip
    # loading them otherwise.
    masks = load_images(files_masks, resize=resize) if debug else None

    # The destination depends only on `resize`: create it once up front
    # instead of re-checking it for every image.
    out_full_path = os.path.join(BASE_OUTPUT_PATH, f'{resize[0]}x{resize[1]}')
    os.makedirs(out_full_path, exist_ok=True)

    # perf_counter is a monotonic clock intended for measuring intervals.
    t1 = time.perf_counter()
    for i in range(len(images)):
        # Strip directory and extension portably; slicing off four characters
        # breaks on extensions such as ".jpeg" and on non-'/' separators.
        name = os.path.splitext(os.path.basename(files[i]))[0]
        image = gaussian_filter(images[i], sigma=3)
        binary = thresholding(image)

        # thresholding() marks pixels above the Otsu level; invert so that the
        # pixels at or below it become the candidate foreground.
        mask = np.zeros_like(image)
        mask[np.logical_not(binary)] = 1
        # Flip again if the top border ended up as foreground.
        mask = verify_binary_image(mask)
        mask = morphology.remove_small_objects(mask.astype('bool'), min_size=64, connectivity=1)
        mask = binary_fill_holes(mask).astype('uint8')
        mask = keep_largest_connected_component(mask)

        # Scale {0, 1} to {0, 255} so the saved PNG is directly viewable.
        mask_image = Image.fromarray(mask * 255)
        mask_image.save(os.path.join(out_full_path, f'{name}.png'))

        if debug:
            plt.figure(figsize=(16, 9))
            plt.subplot(2, 3, 1)
            plt.imshow(image)

            plt.subplot(2, 3, 2)
            plt.imshow(masks[i])

            plt.subplot(2, 3, 3)
            plt.imshow(mask_image)

            plt.show()

    t2 = time.perf_counter()
    print(f"Took {t2 - t1} seconds to process {len(images)} images.")


In [20]:
# Run the pipeline at 256x256; masks are saved under BASE_OUTPUT_PATH/256x256.
run_exp((256, 256))

Took 0.1279003620147705 seconds to process 19 images.


In [21]:
# Same pipeline at 64x64; masks are saved under BASE_OUTPUT_PATH/64x64.
run_exp((64, 64))

Took 0.034329891204833984 seconds to process 19 images.


In [22]:
# Same pipeline at 1024x1024; masks are saved under BASE_OUTPUT_PATH/1024x1024.
run_exp((1024, 1024))

Took 1.8242249488830566 seconds to process 19 images.
