# Image processing and segmentation pipeline

## Imports

In [1]:
import cv2
import shutil
from skimage.filters import roberts, prewitt, threshold_otsu
from skimage.io import imread, imsave
from skimage.morphology import disk, binary_erosion
from skimage.measure import label, regionprops
from skimage.segmentation import clear_border
from skimage.exposure import equalize_adapthist
from scipy import ndimage as ndi
from tqdm import tqdm
import os
import numpy as np
import pandas as pd
import warnings


import matplotlib.pyplot as plt
# Default size for every figure that is created without its own figsize.
plt.rcParams["figure.figsize"] = (10, 10)


## Setting up constants for directory management

In [2]:
# ENV selects the storage root: "local" uses LOCAL_DIR, any other value
# (normally "colab") falls back to the mounted Google Drive folder.
# Set it correctly before running, otherwise every path below is wrong.
ENV = "local"  # "local" or "colab"
LOCAL_DIR = os.path.join("path", "to", "local", "dir")
GOOGLEDRIVE_DIR = os.path.join("/content", "drive", "My Drive")
WORKING_DIR = os.path.join("dataset", "COVID-CTset")

_ROOT_DIR = LOCAL_DIR if ENV == "local" else GOOGLEDRIVE_DIR
# Raw COVID-CTset images are read from here ...
DATA_DIR = os.path.join(_ROOT_DIR, WORKING_DIR)
# ... and the preprocessed PNGs are written here.
EXPORT_PATH = os.path.join(_ROOT_DIR, "dataset", "processed")


## CSV file generation

Image paths are read from the dataset directory to generate a CSV file containing those paths and their labels, which makes the later processing steps easier.

In [24]:
# Collect every DATA_DIR/<folder>/<patient>/<series>/*.tif path.
# Directory listings are sorted so the collection order -- and therefore the
# DataFrame index that later becomes the exported file name -- is the same on
# every run (iterating a set of strings depends on the per-process hash seed).
# Only sub-directories are descended into, so stray files such as desktop.ini
# or the generated CSV/pickle files are skipped without a hand-kept list.


def _sorted_subdirs(parent):
    """Return the full paths of the sub-directories of *parent*, sorted by name."""
    return [
        os.path.join(parent, name)
        for name in sorted(os.listdir(parent))
        if os.path.isdir(os.path.join(parent, name))
    ]


images_dir_list = []
for folder in tqdm(_sorted_subdirs(DATA_DIR)):
    for patient in _sorted_subdirs(folder):
        for sr in _sorted_subdirs(patient):
            for image in sorted(os.listdir(sr)):
                if image.endswith(".tif"):
                    images_dir_list.append(os.path.join(sr, image))


100%|██████████| 12/12 [00:21<00:00,  1.77s/it]


In [25]:
# One row per discovered .tif file; the "class" label column is added later.
image_columns = ["images_dir"]
images_dir_df = pd.DataFrame(data=images_dir_list, columns=image_columns)
images_dir_df.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 63849 entries, 0 to 63848
Data columns (total 1 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   images_dir  63849 non-null  object
dtypes: object(1)
memory usage: 498.9+ KB


In [26]:
def get_img_lables(list):
    """
    Derive a binary class label for every image path.

    The label comes from the dataset folder, i.e. the fourth path component
    counted from the end (``<folder>/<patient>/<series>/<image>``): 1 when
    that folder name contains the substring ``"covid"`` (case-sensitive),
    0 otherwise.

        Parameters
        ----------
        list : list
            List of image paths

        Returns
        -------
        list of int
            One label per input path, in input order.
    """

    def _label_for(path):
        # Normalise separators first so the split works on any OS-style path.
        dataset_folder = os.path.normpath(path).split(os.sep)[-4]
        return 1 if "covid" in dataset_folder else 0

    return [_label_for(path) for path in list]


In [27]:
# Persist paths + labels next to the raw data so later runs can reuse them.
labels_csv_path = os.path.join(DATA_DIR, "image_path_and_labels.csv")
images_dir_df["class"] = get_img_lables(images_dir_df["images_dir"])
images_dir_df.to_csv(labels_csv_path, index=False)
images_dir_df.head()


Unnamed: 0,images_dir,class
0,G:My Drive\MSc CS Project\Dataset\COVID-CTset\...,0
1,G:My Drive\MSc CS Project\Dataset\COVID-CTset\...,0
2,G:My Drive\MSc CS Project\Dataset\COVID-CTset\...,0
3,G:My Drive\MSc CS Project\Dataset\COVID-CTset\...,0
4,G:My Drive\MSc CS Project\Dataset\COVID-CTset\...,0


## Segmentation functions

In [3]:
def get_segmented_lung(src_image, plot=False):
    """
    Function to segment lungs from chest CT scans.

    Steps: Otsu binarisation (pixels darker than the threshold are kept),
    removal of components touching the image border, keeping only the
    largest connected components, a disk(2) erosion and finally hole
    filling of the Prewitt edge map.

        Parameters
        ----------
        src_image : numpy array
            Source image (assumed to be a single 2-D grayscale slice -- the
            2-D disk footprint and 2-D coordinate handling require it)
        plot : bool
            Whether to plot the results for each step

        Returns
        -------
        numpy array of bool
            Lung mask with the same shape as ``src_image``.
    """

    def _show(ax, title, img):
        # Render one pipeline stage on its subplot.
        ax.axis('off')
        ax.set_title(title)
        ax.imshow(img, cmap=plt.cm.bone)

    image = src_image.copy()
    if plot:
        f, plots = plt.subplots(2, 4, figsize=(40, 20))
        f.tight_layout()
        _show(plots[0, 0], 'Original Image', image)

    # Binarise with Otsu's threshold: pixels darker than the threshold
    # (presumably lungs and surrounding air) become True.
    threshold = threshold_otsu(image)
    binary = src_image < threshold
    if plot:
        _show(plots[0, 1], 'Binary Image', binary)

    # Drop dark components connected to the image border.
    cleared = clear_border(binary)
    if plot:
        _show(plots[0, 2], 'Clear Borders', cleared)

    # Label connected components for the area-based filtering below.
    label_image = label(cleared)
    if plot:
        _show(plots[0, 3], 'Labelled Image', label_image)

    # Keep only components whose area is at least the second-largest area
    # (normally the two lungs; equal areas may keep more). All smaller
    # components are cleared with one vectorised mask.
    regions = regionprops(label_image)
    if len(regions) > 2:
        second_largest = sorted(region.area for region in regions)[-2]
        small_labels = [region.label for region in regions
                        if region.area < second_largest]
        label_image[np.isin(label_image, small_labels)] = 0
    if plot:
        # Show the components that survived the area filter.
        _show(plots[1, 0], 'Region Finding & Sorting', label_image)

    binary = label_image > 0

    # Erosion to remove border pixels that might contain extra information
    binary = binary_erosion(binary, footprint=disk(2))
    if plot:
        _show(plots[1, 1], 'Image with Erosion', binary)

    # Fill holes: the Prewitt edge map outlines each region and
    # binary_fill_holes fills the enclosed interiors.
    edges = prewitt(binary)
    binary = ndi.binary_fill_holes(edges)
    if plot:
        _show(plots[1, 2], 'Region Filling', binary)
        # The masked slice is only needed for the visualisation; `image` is a
        # private copy, so src_image is never modified.
        image[binary == 0] = 0
        _show(plots[1, 3], 'Final binary mask', image)

    return binary


In [5]:
def check_image_size(image):
    """
    Decide whether a cropped lung mask is large and non-square enough to use.

    The crop is rejected when either side is shorter than 120 pixels, or
    when neither side is longer than half of ``height + width``.

        Parameters
        ----------
        image : numpy array
            2-D array (the lung mask cropped to its bounding box)

        Returns
        -------
        bool
            True when the crop passes both checks, False otherwise.
    """
    height, width = image.shape
    if min(height, width) < 120:
        return False

    # NOTE(review): for positive sizes, "one side exceeds half of the sum"
    # only fails when height == width, so perfectly square crops are
    # rejected -- confirm that this is the intended aspect-ratio rule.
    total = height + width
    return max(height, width) / total > 0.50


In [8]:
def contrast_adjustment(image, plot=False):
    """
    Enhance local contrast with skimage's adaptive histogram equalisation
    (``equalize_adapthist`` with ``clip_limit=0.01``).

        Parameters
        ----------
        image : numpy array
            Source image
        plot : bool
            Whether to plot the image before and after the adjustment

        Returns
        -------
        numpy array
            The image returned by ``equalize_adapthist``.
    """
    show_steps = plot == True

    if show_steps:
        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        fig.tight_layout()
        axes[0].axis('off')
        axes[0].set_title('Original Image')
        axes[0].imshow(image, cmap=plt.cm.bone)

    adjusted = equalize_adapthist(image, clip_limit=0.01)

    if show_steps:
        axes[1].axis('off')
        axes[1].set_title('Contrast Adjusted Image')
        axes[1].imshow(adjusted, cmap=plt.cm.bone)

    return adjusted


In [31]:
def check_valid_lung(image):
    """
    Check if the lung is valid based on the area of the white (non-zero)
    pixels in the lung mask. If they cover less than 10% of the mask, the
    image is rejected. A mask holding fewer than two distinct values
    (empty, all black or all white) is rejected as well.

        Parameters
        ----------
        image : numpy array
            Lung mask, e.g. the boolean output of get_segmented_lung

        Returns
        -------
        bool
            True when the mask is usable, False otherwise.
    """
    # A mask with a single value (or none) means segmentation found nothing usable.
    if np.unique(image).size < 2:
        return False

    # Every non-zero pixel counts as lung, not just the second unique value.
    white_fraction = np.count_nonzero(image) / image.size
    return bool(white_fraction >= 0.10)

In [38]:
def image_read_crop(image_path, plot=False):
    """
    Function to read the image and crop it to the bounding box of its lung mask.

    The slice is segmented with ``get_segmented_lung``; the mask must pass
    ``check_valid_lung`` and its bounding-box crop must pass
    ``check_image_size``. The returned crop is taken from the unmasked
    source image.

        Parameters
        ----------
        image_path : str
            Path to the image
        plot : bool
            Whether to plot the results for each step

        Returns
        -------
        numpy array or None
            The source image cropped to the lung bounding box (inclusive of
            the last lung row/column), or None when the slice is rejected by
            either check.
    """

    image = imread(image_path)
    if plot:
        f, plots = plt.subplots(2, 2, figsize=(20, 20))
        f.tight_layout()

        plots[0, 0].axis('off')
        plots[0, 0].set_title('Original Image')
        plots[0, 0].imshow(image, cmap=plt.cm.bone)

    seg_image = get_segmented_lung(image)
    if plot:
        plots[0, 1].axis('off')
        plots[0, 1].set_title('Segmented Lungs')
        plots[0, 1].imshow(seg_image, cmap=plt.cm.bone)

    if not check_valid_lung(seg_image):
        return None

    # np.nonzero yields (row indices, column indices) of the lung pixels.
    rows, cols = np.nonzero(seg_image)
    # max() is the last lung pixel; +1 because slice stops are exclusive.
    row_start, row_stop = rows.min(), rows.max() + 1
    col_start, col_stop = cols.min(), cols.max() + 1

    if not check_image_size(seg_image[row_start:row_stop, col_start:col_stop]):
        return None

    if plot:
        # cv2.bitwise_and needs an 8-bit mask; the masked image is only
        # displayed, never returned.
        masked_image = cv2.bitwise_and(
            image, image, mask=seg_image.astype(np.uint8))
        plots[1, 0].axis('off')
        plots[1, 0].set_title('Masked Image')
        plots[1, 0].imshow(masked_image, cmap=plt.cm.bone)

    lung_crop = image[row_start:row_stop, col_start:col_stop]
    if plot:
        plots[1, 1].axis('off')
        plots[1, 1].set_title('Lungs Cropped')
        plots[1, 1].imshow(lung_crop, cmap=plt.cm.bone)

    return lung_crop


In [33]:
def start_preprocess(images_dir_df, export_path, plot=False):
    """
    Helper function to start preprocessing the images.

    Every row is read, lung-cropped (``image_read_crop``), contrast-adjusted
    (``contrast_adjustment``), rescaled to 0-255 uint8 and written as
    ``<export_path>/<normal|covid>/<row index>.png``. Rows that are rejected
    or fail with ValueError/AttributeError are skipped, and a single warning
    reports how many were skipped.

    WARNING: ``export_path`` is deleted and recreated before any image is
    written.

        Parameters
        ----------
        images_dir_df : pandas dataframe
            Dataframe containing the image paths ("images_dir") and labels
            ("class", 1 = covid, anything else = normal)
        export_path : str
            Path to the export folder
        plot : bool
            Whether to plot the results for each step
    """

    # Start from an empty export folder. The class sub-folders are created
    # under the given export_path; makedirs also creates missing parents.
    if os.path.exists(export_path):
        shutil.rmtree(export_path)
    os.makedirs(os.path.join(export_path, "normal"))
    os.makedirs(os.path.join(export_path, "covid"))

    skipped = 0
    total = images_dir_df.shape[0]
    for index, row in tqdm(images_dir_df.iterrows(), total=total):
        try:
            image = image_read_crop(row["images_dir"], plot=plot)
            if image is None:
                # Rejected by the lung-validity or crop-size check.
                skipped += 1
                continue
            image = contrast_adjustment(image)
        except (ValueError, AttributeError):
            skipped += 1
            continue

        # Rescale to the full 0-255 range; an all-zero image stays zero
        # instead of turning into NaNs.
        peak = image.max()
        if peak > 0:
            image = image / peak
        image = (255 * image).astype(np.uint8)

        image_label = "covid" if row["class"] == 1 else "normal"
        out_path = os.path.join(export_path, image_label, f"{index}.png")
        imsave(out_path, image)

    if skipped:
        warnings.warn(
            f"start_preprocess skipped {skipped} of {total} images "
            "(rejected by the lung checks or raised ValueError/AttributeError)."
        )


## Start preprocessing

In [34]:
# NOTE: start_preprocess wipes EXPORT_PATH before writing the processed PNGs.
start_preprocess(images_dir_df=images_dir_df, export_path=EXPORT_PATH, plot=False)

100%|██████████| 63849/63849 [1:39:25<00:00, 10.70it/s]  
