<a href="https://colab.research.google.com/github/killercookiee/DeepM/blob/main/DeepM_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Fetch the repository that contains the NIfTI dataset (~2.2 GiB to download).
# NOTE: git refuses to clone into an existing non-empty directory, so re-running this cell fails once the repo is present.
!git clone https://github.com/AMIN-HASSAIRI/medical-image-segmentation.git

Cloning into 'medical-image-segmentation'...
remote: Enumerating objects: 1109, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 1109 (delta 1), reused 18 (delta 1), pack-reused 1089 (from 1)[K
Receiving objects: 100% (1109/1109), 2.24 GiB | 31.39 MiB/s, done.
Resolving deltas: 100% (15/15), done.
Updating files: 100% (1061/1061), done.


Understanding the file format

NIfTI (Neuroimaging Informatics Technology Initiative) is a file format designed to handle 3D (volume) and 4D (time-series) data, which is common in brain imaging, especially in functional MRI (fMRI) studies.
NIfTI files can also be stored in compressed `.nii.gz` form to save disk space.

A 4D volume has the axes height × width × depth (slices) × time.

We are working with pairs of files, for example:
1. An MRI scan of the heart.
2. The corresponding segmentation mask file, i.e. a label map that marks the different parts of the image, such as the heart chambers, valves, etc.

Each patient folder contains files such as:

patient001_4d.nii.gz -> 4D file (a sequence of 3D volumes over time); not used for segmentation
patient001_frame01.nii.gz -> A single 3D volume (a stack of 2D slices) captured at one time frame; this is the image file used for training the model
patient001_frame01_gt.nii.gz -> The ground-truth segmentation for that frame; it contains labels marking the regions of interest in the image

# Imports

In [3]:
# Install the packages behind `cv2` (opencv-python) and `skimage` (scikit-image), both imported in the next cell.
# NOTE(review): pin versions (e.g. `scikit-image==<version>`) so the notebook re-runs reproducibly.
%pip install opencv-python --quiet
%pip install scikit-image --quiet

In [1]:
import os
import random
import shutil

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from skimage.transform import resize

# Original Training Data

In [2]:
def select_training_data(dataset_dir, excluded_patients=('038', '085', '057', '089', '100')):
    """Collect (image, ground-truth, patient) path triples from the training split.

    Walks ``<dataset_dir>/training/<patient>/``. Each ground-truth file named
    ``<patient>_frameNN_gt.nii.gz`` is paired with the matching image
    ``<patient>_frameNN.nii.gz``. A pair is kept only when both files exist.
    Patients whose 3-digit number (taken from the file name) is listed in
    ``excluded_patients`` are skipped.

    Args:
        dataset_dir: Root folder that contains a ``training`` sub-folder.
        excluded_patients: 3-digit patient numbers to leave out. The default is
            the hard-coded anomaly list this notebook has always used.

    Returns:
        ``{'training': [[image_path, gt_path, patient_dir_name], ...]}`` in
        sorted patient / file order.
    """
    gt_suffix = '_gt.nii.gz'
    excluded = set(excluded_patients)

    training_data_list = {}
    for k in ['training']:

        subset_dir = os.path.join(dataset_dir, k)
        training_data_list[k] = []

        for patient in sorted(os.listdir(subset_dir)):

            patient_dir = os.path.join(subset_dir, patient)

            # Skip files that are not directories
            if not os.path.isdir(patient_dir):
                continue

            for file in sorted(os.listdir(patient_dir)):
                # Only ground-truth masks drive the pairing. This skips the 4D
                # file, Info.cfg and any other file explicitly. The old
                # positional check `file[-8] == 't'` raised IndexError on
                # names shorter than 8 characters.
                if not file.endswith(gt_suffix):
                    continue

                stem = file[:-len(gt_suffix)]      # e.g. 'patient001_frame01'
                frame = stem[-2:]                  # e.g. '01'
                patient_number = stem[-11:-8]      # e.g. '001'
                if patient_number in excluded:
                    continue

                image_name = '{0}/{1}_frame{2}.nii.gz'.format(patient_dir, patient, frame)
                segt_name = '{0}/{1}_frame{2}_gt.nii.gz'.format(patient_dir, patient, frame)

                if os.path.exists(image_name) and os.path.exists(segt_name):
                    training_data_list[k] += [[image_name, segt_name, patient]]

    return training_data_list



In [3]:
# Index the training split as {'training': [[image_path, gt_path, patient], ...]}.
original_training_data_list = select_training_data("./medical-image-segmentation/dataset")

In [None]:
# NOTE(review): this renders every entry; slicing (e.g. ['training'][:5]) keeps the output readable.
original_training_data_list

In [None]:
def select_testing_data(dataset_dir):
    """Collect (image, ground-truth, patient) path triples from the testing split.

    Walks ``<dataset_dir>/testing/<patient>/``. Each ground-truth file named
    ``<patient>_frameNN_gt.nii.gz`` is paired with the image
    ``<patient>_frameNN.nii.gz``. A pair is kept only when both files exist.
    Unlike ``select_training_data``, no patients are excluded here.

    Args:
        dataset_dir: Root folder that contains a ``testing`` sub-folder.

    Returns:
        ``{'testing': [[image_path, gt_path, patient_dir_name], ...]}`` in
        sorted patient / file order.
    """
    gt_suffix = '_gt.nii.gz'
    testing_data_list = {}

    for k in ['testing']:
        subset_dir = os.path.join(dataset_dir, k)
        testing_data_list[k] = []

        for patient in sorted(os.listdir(subset_dir)):

            patient_dir = os.path.join(subset_dir, patient)

            # Skip files that are not directories
            if not os.path.isdir(patient_dir):
                continue

            for file in sorted(os.listdir(patient_dir)):
                # Pair on ground-truth files only. The old check
                # (file[-8] not 'd' / 'I') also let image files through. They
                # only dropped out because the derived path did not exist.
                # It also raised IndexError on names shorter than 8 characters.
                if not file.endswith(gt_suffix):
                    continue

                frame = file[:-len(gt_suffix)][-2:]   # 'patient101_frame01' -> '01'

                image_name = '{0}/{1}_frame{2}.nii.gz'.format(patient_dir, patient, frame)
                segt_name = '{0}/{1}_frame{2}_gt.nii.gz'.format(patient_dir, patient, frame)

                if os.path.exists(image_name) and os.path.exists(segt_name):
                    testing_data_list[k] += [[image_name, segt_name, patient]]

    return testing_data_list


In [None]:
# Index the testing split as {'testing': [[image_path, gt_path, patient], ...]}.
original_testing_data_list = select_testing_data("./medical-image-segmentation/dataset")

In [None]:
# NOTE(review): this renders every entry; slicing (e.g. ['testing'][:5]) keeps the output readable.
original_testing_data_list

# Image Analysis

In [None]:
# Analyzing the data

def analyze_image_properties(training_data_list, subset='training'):
    """Collect simple geometry statistics for every image volume in one split.

    Only the image file of each pair (``data[0]``) is inspected; the
    segmentation file is ignored. The shape comes from ``img.shape``, which
    nibabel reads from the header. Voxel data is never loaded (the old
    ``get_fdata()`` call read and cast every volume just to get its shape).

    Args:
        training_data_list: Dict mapping a split name to
            ``[image_path, seg_path, patient]`` entries, as built by
            ``select_training_data``.
        subset: Key of the split to analyse (default ``'training'``).

    Returns:
        ``(aspect_ratios, widths, heights, resolutions, pixel_sizes, slice_counts)``.
        ``pixel_sizes`` gets the header zoom tuple of every file. The other
        lists only get an entry for images with at least two dimensions.
        ``width`` is axis 0 and ``height`` is axis 1 of the array.
    """
    aspect_ratios = []
    widths = []
    heights = []
    resolutions = []
    pixel_sizes = []
    slice_counts = []

    for data in training_data_list[subset]:
        image_file = data[0]  # image, not the segmentation file

        img = nib.load(image_file)
        img_shape = img.shape  # header-only; no voxel data is read

        if len(img_shape) >= 2:
            width = img_shape[0]
            height = img_shape[1]
            depth = img_shape[2] if len(img_shape) > 2 else 1  # number of slices

            aspect_ratios.append(width / height)   # width to height
            resolutions.append(width * height)     # pixels per 2D slice
            widths.append(width)
            heights.append(height)
            slice_counts.append(depth)

        # Physical size of each voxel along every axis.
        pixel_sizes.append(img.header.get_zooms())

    return aspect_ratios, widths, heights, resolutions, pixel_sizes, slice_counts


# Collect per-volume shape and voxel-size statistics for the image files of the training split.
aspect_ratios, widths, heights, resolutions, pixel_sizes, slice_counts = analyze_image_properties(original_training_data_list)

# Plot the results
def plot_analysis(aspect_ratios, widths, heights, resolutions, pixel_sizes, slice_counts):
    """Plot histograms of the statistics returned by ``analyze_image_properties``.

    Draws a 2x3 grid of 20-bin histograms and displays it with ``plt.show()``.
    The panels are aspect ratio, width, height, resolution, in-plane pixel
    size and slice count.

    Args:
        aspect_ratios, widths, heights, resolutions, slice_counts: Sequences of numbers.
        pixel_sizes: Sequence of voxel-size tuples (from ``header.get_zooms()``).
            Only the first two entries (x, y) of each tuple are plotted.
    """
    fig, ax = plt.subplots(2, 3, figsize=(12, 10))

    # (row, col, values, title, x-label, colour) for the single-series panels.
    panels = [
        (0, 0, aspect_ratios, "Aspect Ratio Distribution", "Aspect Ratio (Width/Height)", 'blue'),
        (0, 1, widths, "Width Distribution", "Width", 'grey'),
        (0, 2, heights, "Height Distribution", "Height", 'grey'),
        (1, 0, resolutions, "Resolution Distribution", "Resolution (Width x Height)", 'green'),
        (1, 2, slice_counts, "Slice Count Distribution", "Number of Slices", 'red'),
    ]
    for row, col, values, title, xlabel, color in panels:
        ax[row, col].hist(values, bins=20, color=color, alpha=0.7)
        ax[row, col].set_title(title)
        ax[row, col].set_xlabel(xlabel)
        ax[row, col].set_ylabel("Frequency")

    # Pixel size: keep only the in-plane (x, y) zooms. Zoom tuples can differ
    # in length (3 for a 3D file, 4 for a 4D one), and np.array() on such a
    # ragged list fails. reshape(-1, 2) also keeps an empty input 2-D so the
    # column indexing below works.
    pixel_xy = np.array([tuple(p[:2]) for p in pixel_sizes], dtype=float).reshape(-1, 2)
    ax[1, 1].hist(pixel_xy[:, 0], bins=20, color='orange', alpha=0.7, label='X-dim')
    ax[1, 1].hist(pixel_xy[:, 1], bins=20, color='purple', alpha=0.5, label='Y-dim')
    ax[1, 1].set_title("Pixel Size Distribution")
    ax[1, 1].set_xlabel("Pixel Size (in mm)")
    ax[1, 1].set_ylabel("Frequency")
    ax[1, 1].legend()

    plt.tight_layout()
    plt.show()

# Draw a 2x3 grid of histograms of the statistics collected above.
plot_analysis(aspect_ratios, widths, heights, resolutions, pixel_sizes, slice_counts)

# Data Standardization

In [None]:
def load_nifti_image(nifti_file):
    """Read ``nifti_file`` with nibabel and hand back its voxel array (``get_fdata()``)."""
    return nib.load(nifti_file).get_fdata()

def normalize_image(img_data):
    """Min-max scale image intensities into the [0, 1] range.

    A constant image (max == min) has no range to scale. It is mapped to all
    zeros instead of producing NaNs from a 0/0 division.

    Args:
        img_data: Array of intensities, any shape (must be non-empty).

    Returns:
        Floating-point array of the same shape with values in [0, 1].
    """
    img_data = np.asarray(img_data)
    img_min = np.min(img_data)
    img_range = np.max(img_data) - img_min
    if img_range == 0:
        # Flat image: return zeros with the dtype the division would have produced.
        return np.zeros_like(img_data, dtype=np.result_type(img_data, 1.0))
    return (img_data - img_min) / img_range

def standardize_image(img_data):
    """Standardize an image to zero mean and unit standard deviation.

    When the image is constant (std == 0), the mean-centred image is returned
    (all zeros) instead of dividing by zero, which would yield NaNs.

    Args:
        img_data: Array of intensities, any shape.

    Returns:
        Floating-point array of the same shape.
    """
    img_data = np.asarray(img_data)
    mean = np.mean(img_data)
    std = np.std(img_data)
    centred = img_data - mean
    if std == 0:
        return centred
    return centred / std

def pad_to_aspect_ratio(img_data, target_aspect_ratio, padding_value=0):
    """Pad an image symmetrically so that width / height equals ``target_aspect_ratio``.

    Axis 0 is treated as height and axis 1 as width. Any further axes (e.g.
    slices of a 3D volume) are left unpadded. Previously the 2-entry pad
    width made ``np.pad`` fail on anything but 2D input.

    Args:
        img_data: Array with at least two dimensions.
        target_aspect_ratio: Desired width / height ratio.
        padding_value: Constant used for the added pixels.

    Returns:
        The input unchanged if it already has the target ratio, otherwise a
        padded copy. Any odd leftover pixel goes to the right/bottom side.
    """
    h, w = img_data.shape[:2]
    current_aspect_ratio = w / h
    if current_aspect_ratio == target_aspect_ratio:
        return img_data  # No padding needed

    # No padding on any axis beyond height and width.
    trailing = [(0, 0)] * (img_data.ndim - 2)

    if current_aspect_ratio < target_aspect_ratio:
        # Pad width
        new_width = int(h * target_aspect_ratio)
        pad_left = (new_width - w) // 2
        pad_width = [(0, 0), (pad_left, new_width - w - pad_left)] + trailing
    else:
        # Pad height
        new_height = int(w / target_aspect_ratio)
        pad_top = (new_height - h) // 2
        pad_width = [(pad_top, new_height - h - pad_top), (0, 0)] + trailing

    return np.pad(img_data, pad_width, constant_values=padding_value)

def resize_image(img_data, target_size=(256, 256)):
    """Resample ``img_data`` to ``target_size`` with skimage.

    Anti-aliasing is applied before downsampling, and ``preserve_range`` keeps
    the original intensity scale instead of rescaling to [0, 1].
    """
    from skimage.transform import resize as sk_resize

    return sk_resize(
        img_data,
        target_size,
        preserve_range=True,
        anti_aliasing=True,
    )




def combine_slices_into_nifti(img_file, seg_file, save_folder, frame_label):
    """Re-save one image/segmentation pair under a normalised file name.

    Despite the name, nothing is merged. Both volumes are loaded with
    ``get_fdata()`` and their shapes are checked. They are then written to
    ``<save_folder>/<patient>_<frame_label>.nii.gz`` and
    ``<save_folder>/<patient>_<frame_label>_gt.nii.gz``, where ``<patient>``
    is the part of the image file name before the first underscore. Both
    outputs reuse the image file's affine.

    Raises:
        ValueError: if the image and segmentation shapes differ.
    """
    source_img = nib.load(img_file)
    volume = source_img.get_fdata()
    labels = nib.load(seg_file).get_fdata()

    if volume.shape != labels.shape:
        raise ValueError(f"Image and segmentation dimensions do not match for {img_file}.")

    # Both outputs share the image's affine so they stay voxel-aligned.
    affine = source_img.affine
    volume_nifti = nib.Nifti1Image(volume, affine)
    labels_nifti = nib.Nifti1Image(labels, affine)

    patient_id = os.path.basename(img_file).split('_')[0]
    stem = os.path.join(save_folder, f'{patient_id}_{frame_label}')

    nib.save(volume_nifti, stem + '.nii.gz')
    nib.save(labels_nifti, stem + '_gt.nii.gz')

def get_second_frame(patient_data):
    """Return the first entry whose image path does not contain ``frame01``.

    ``patient_data`` is a sequence of ``(image_path, seg_path, patient_id)``
    entries. ``None`` is returned when every entry is a frame01 entry or the
    sequence is empty.
    """
    return next((entry for entry in patient_data if "frame01" not in entry[0]), None)

def process_patient_data(patient_frames, save_dir):
    """Re-save a patient's frame01 pair and one other frame pair into ``save_dir/<patient_id>``.

    The entry whose image path contains ``frame01`` is written with the label
    ``frame01``. The first entry without ``frame01`` (see ``get_second_frame``)
    is written with the label ``frame02``, whatever its original frame number
    was. Patients lacking either frame are skipped and nothing is written.

    Args:
        patient_frames: List of ``(image_path, seg_path, patient_id)`` tuples
            belonging to one patient.
        save_dir: Output root. A per-patient sub-folder is created.
    """
    # Find frame01 by name instead of trusting list position. If entry 0 were
    # another frame, it would be saved as "frame01", and get_second_frame
    # would return that same entry, so it would be written twice.
    first_frame = next((f for f in patient_frames if "frame01" in f[0]), None)
    second_frame = get_second_frame(patient_frames)
    if first_frame is None or second_frame is None:
        return

    img_file_1, seg_file_1, patient_id = first_frame
    img_file_2, seg_file_2, _ = second_frame

    patient_folder = os.path.join(save_dir, patient_id)
    os.makedirs(patient_folder, exist_ok=True)

    combine_slices_into_nifti(img_file_1, seg_file_1, patient_folder, "frame01")
    combine_slices_into_nifti(img_file_2, seg_file_2, patient_folder, "frame02")

def create_new_training_dataset(training_data_list, save_dir):
    """Write re-saved frame pairs for every training patient under ``save_dir``.

    Entries of ``training_data_list['training']`` are grouped by patient id
    (in first-seen order), and each group is passed to ``process_patient_data``.
    ``save_dir`` is created if it does not exist yet.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    frames_by_patient = {}
    for img_file, seg_file, patient_id in training_data_list['training']:
        frames_by_patient.setdefault(patient_id, []).append((img_file, seg_file, patient_id))

    for frames in frames_by_patient.values():
        process_patient_data(frames, save_dir)

def create_new_testing_dataset(testing_data_list, save_dir):
    """Write re-saved frame pairs for every testing patient under ``save_dir``.

    Entries of ``testing_data_list['testing']`` are grouped by patient id
    (in first-seen order), and each group is passed to ``process_patient_data``.
    ``save_dir`` is created if it does not exist yet.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    frames_by_patient = {}
    for img_file, seg_file, patient_id in testing_data_list['testing']:
        frames_by_patient.setdefault(patient_id, []).append((img_file, seg_file, patient_id))

    for frames in frames_by_patient.values():
        process_patient_data(frames, save_dir)


# Build the standardized dataset (writes files to disk).
# Per patient: <patient>_frame01[_gt].nii.gz and <patient>_frame02[_gt].nii.gz
# under new_dataset/{training,testing}/<patient>/.
training_save_dir = './medical-image-segmentation/new_dataset/training'
create_new_training_dataset(original_training_data_list, training_save_dir)

testing_save_dir = './medical-image-segmentation/new_dataset/testing'
create_new_testing_dataset(original_testing_data_list, testing_save_dir)

# New Standardized Dataset

In [None]:
def select_new_training_data(dataset_dir):
    """Collect (image, ground-truth, patient) triples from the re-saved training split.

    ``create_new_training_dataset`` writes files in the same
    ``<patient>/<patient>_frameNN[_gt].nii.gz`` layout as the original
    dataset. So this delegates to ``select_training_data``, with the same
    default excluded-patient list, instead of keeping a second, identical
    copy of its body.

    Args:
        dataset_dir: Root folder that contains a ``training`` sub-folder.

    Returns:
        ``{'training': [[image_path, gt_path, patient_dir_name], ...]}``.
    """
    return select_training_data(dataset_dir)



def select_new_testing_data(dataset_dir):
    """Collect (image, ground-truth, patient) triples from the re-saved testing split.

    ``create_new_testing_dataset`` writes files in the same
    ``<patient>/<patient>_frameNN[_gt].nii.gz`` layout as the original
    dataset. So this delegates to ``select_testing_data`` instead of keeping
    a second, identical copy of its body.

    Args:
        dataset_dir: Root folder that contains a ``testing`` sub-folder.

    Returns:
        ``{'testing': [[image_path, gt_path, patient_dir_name], ...]}``.
    """
    return select_testing_data(dataset_dir)



# The standardized dataset was written to ./medical-image-segmentation/new_dataset
# by the Data Standardization cell. Read it back from that same relative
# location; the previous absolute /Users/... path only existed on one machine.
new_dataset_dir = "./medical-image-segmentation/new_dataset"
training_data_list = select_new_training_data(new_dataset_dir)
testing_data_list = select_new_testing_data(new_dataset_dir)

In [None]:
#training_data_list

# Data Visualization

In [None]:
import plotly.graph_objects as go  # Graph Objects module comprises visuals like Heatmaps, Figures, etc
from skimage.transform import resize
import nibabel as nib
import matplotlib.pyplot as plt

def load_nifti_image(nifti_file):
    """Return the voxel array of ``nifti_file`` as floats via nibabel's ``get_fdata()``.

    NOTE(review): this re-defines the identical ``load_nifti_image`` from the
    Data Standardization cell; one of the two definitions could be dropped.
    """
    image = nib.load(nifti_file)
    voxels = image.get_fdata()
    return voxels

# Function to resize image data to square dimensions
def resize_image(img_data, new_size=(128, 128)):
    """Resize each 2D slice ``img_data[:, :, k]`` of a 3D volume to ``new_size``.

    NOTE(review): this shadows the 2D ``resize_image(img_data, target_size)``
    from the Data Standardization cell, which has a different keyword name.

    Args:
        img_data: 3D array whose third axis indexes slices.
        new_size: ``(rows, cols)`` of every output slice.

    Returns:
        float64 array of shape ``(new_size[0], new_size[1], img_data.shape[2])``.
    """
    slice_count = img_data.shape[2]
    rows, cols = new_size[0], new_size[1]
    resized = np.zeros((rows, cols, slice_count))
    for k in range(slice_count):
        # Anti-aliased resample of one slice; written straight into the output volume.
        resized[:, :, k] = resize(img_data[:, :, k], new_size, anti_aliasing=True)
    return resized

# To visualise MRI image and mask side by side
def visualize_2d_image_and_mask(image_data, mask_data, title="MRI and Mask Visualization", new_size=(128, 128)):
    """Show an MRI volume and its segmentation mask side by side with one slice slider.

    Each slice along the third axis becomes a pair of plotly heatmaps: grey
    MRI on the left, red mask on the right. The slider makes exactly one pair
    visible and updates the title to the slice number.

    The MRI is resized with ``resize_image`` (anti-aliased interpolation).
    The mask is resized slice by slice with nearest-neighbour interpolation
    (``order=0``) and no anti-aliasing. Smoothing a label map the way the
    MRI is smoothed blends neighbouring labels into fractional values at
    region borders, so the displayed mask would no longer show the true
    classes.

    Args:
        image_data: 3D MRI volume indexed as ``[:, :, slice]``.
        mask_data: 3D label volume with the same number of slices.
        title: Initial figure title.
        new_size: ``(rows, cols)`` each displayed slice is resized to.

    Raises:
        ValueError: if the volumes have different slice counts. The slider
            pairs MRI slice ``z`` with mask slice ``z``.
    """
    if image_data.shape[2] != mask_data.shape[2]:
        raise ValueError("Image and mask must have the same number of slices.")

    resized_image_data = resize_image(image_data, new_size=new_size)

    # Nearest-neighbour resampling keeps the mask's label values intact.
    num_mask_slices = mask_data.shape[2]
    resized_mask_data = np.zeros((new_size[0], new_size[1], num_mask_slices))
    for z in range(num_mask_slices):
        resized_mask_data[:, :, z] = resize(
            mask_data[:, :, z], new_size, order=0, anti_aliasing=False, preserve_range=True
        )

    fig = go.Figure()

    # Add MRI Image slices (left)
    for z in range(resized_image_data.shape[2]):
        fig.add_trace(go.Heatmap(
            z=resized_image_data[:, :, z],  # the 2D slice
            visible=(z == 0),  # Show first slice by default
            colorscale='Gray',  # MRI slices are displayed in grayscale
            zmin=np.min(resized_image_data), zmax=np.max(resized_image_data),  # one colour range for all slices
            name="MRI Image",
            showscale=False,
            xaxis='x1',  # Left side axis
            yaxis='y1'
        ))

    # Add Segmentation mask slices (right)
    for z in range(resized_mask_data.shape[2]):
        fig.add_trace(go.Heatmap(
            z=resized_mask_data[:, :, z],  # same slice index as the MRI trace
            visible=(z == 0),  # Show first slice by default
            colorscale='Reds',  # Display masked slices in red color
            zmin=np.min(resized_mask_data), zmax=np.max(resized_mask_data),  # one colour range for all slices
            name="Segmentation Mask",
            showscale=False,
            xaxis='x2',  # Right side axis
            yaxis='y2'
        ))

    # One slider step per slice: traces [0, n) are MRI, [n, 2n) are masks.
    steps = []
    num_slices = resized_image_data.shape[2]

    for z in range(num_slices):
        slice_step = dict(
            method="update",
            args=[{"visible": [False] * num_slices * 2},  # hide everything ...
                  {"title": f"MRI and Mask Slice {z + 1}"}],  # Update title to show slice number
        )
        # ... then show the MRI and mask trace for this slice
        slice_step["args"][0]["visible"][z] = True
        slice_step["args"][0]["visible"][num_slices + z] = True
        steps.append(slice_step)

    # Define a single slider for both MRI and Mask
    sliders = [dict(
        active=0,
        currentvalue={"prefix": "Slice: "},  # Display the current slice number
        pad={"t": 50},  # Padding from the top of the plot
        steps=steps  # Synchronized steps for MRI and Mask
    )]

    # Define layout for side-by-side comparison
    fig.update_layout(
        sliders=sliders,
        title=title,
        xaxis=dict(domain=[0, 0.45], title="MRI Image"),
        yaxis=dict(scaleanchor="x", scaleratio=1),  # Ensure aspect ratio is square
        xaxis2=dict(domain=[0.55, 1], title="Segmentation Mask"),
        yaxis2=dict(scaleanchor="x2", scaleratio=1),
        height=500, width=800
    )

    fig.show()

# Function to visualize random training samples dynamically
def visualize_random_samples_side_by_side(data_list, data_type="Training", num_samples=3, new_size=(128, 128), seed=None):
    """Plot ``num_samples`` randomly chosen image/mask pairs from ``data_list``.

    Indices are drawn independently (with replacement), so the same entry
    can be shown more than once. Each chosen pair is loaded, checked for
    matching shape, and shown with ``visualize_2d_image_and_mask``.

    Args:
        data_list: Sequence of ``[image_path, mask_path, ...]`` entries.
        data_type: Label used in the printed message and the figure titles.
        num_samples: Number of pairs to show.
        new_size: ``(rows, cols)`` each slice is resized to for display.
        seed: Optional seed for a private ``random.Random``, making the chosen
            samples reproducible. ``None`` keeps using the global ``random``
            module state.

    Raises:
        ValueError: if ``num_samples > 0`` and ``data_list`` is empty.
    """
    if num_samples > 0 and len(data_list) == 0:
        raise ValueError(f"No {data_type} samples to visualize: data_list is empty.")

    rng = random if seed is None else random.Random(seed)

    for _ in range(num_samples):
        random_idx = rng.randint(0, len(data_list) - 1)
        print(f"Visualizing {data_type} Dataset - Sample {random_idx}")
        img_data = load_nifti_image(data_list[random_idx][0])  # MRI image data
        mask_data = load_nifti_image(data_list[random_idx][1])  # segmentation mask data

        # Making sure dimensions of the image and mask are aligned
        assert img_data.shape == mask_data.shape, "Image and mask dimensions do not match!"

        visualize_2d_image_and_mask(img_data, mask_data, title=f"{data_type} Sample {random_idx}", new_size=new_size)

# Show 3 randomly chosen training pairs (MRI + mask) with a slice slider, each slice resized to 128x128.
# NOTE(review): indices come from the global `random` state, so each run may show different pairs.
visualize_random_samples_side_by_side(training_data_list['training'], data_type="Training", num_samples=3, new_size=(128, 128))
