# Medical Image Segmentation using AI Models

This notebook demonstrates whole-body CT segmentation with three pre-trained models: TotalSegmentator, the whole-body CT segmentation bundle from the MONAI Model Zoo, and an STU-Net model loaded through MedIM.

---

## Part 1: Data Preparation

This section handles the download and extraction of the dataset required for the segmentation tasks.

This cell installs the necessary libraries, downloads a sample CT scan dataset from a Google Drive link using `gdown`, and then unzips the downloaded archive. The dataset contains CT images and corresponding segmentations.

In [None]:
!pip install gdown
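# Recent gdown releases take the file ID as a positional argument (the --id flag is deprecated)
!gdown 1no_qpQSWioNIu5CChE0bnNldkiY63dgg
# -o overwrites existing files so that re-running the cell does not stop at a prompt
!unzip -o CT_subset_big.zip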

!pip install --quiet nibabel numpy scipy scikit-image matplotlib totalsegmentator
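
Once the archive is extracted, it can help to confirm where the CT volumes landed before hard-coding a path. The following is a minimal sketch using only the standard library. It assumes each case folder contains a file named `ct.nii.gz`, as in the `/content/s0010/ct.nii.gz` path used later; the helper name `find_ct_volumes` is hypothetical.

```python
from pathlib import Path


def find_ct_volumes(root, filename="ct.nii.gz"):
    """Return the sorted paths of every `filename` found below `root`."""
    return sorted(Path(root).rglob(filename))


# Example use in Colab (the archive is assumed to be extracted under /content):
# for ct_path in find_ct_volumes("/content"):
#     print(ct_path)
```

If the list is empty, the archive was probably extracted into a different working directory than expected.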

---

## Part 2: TotalSegmentator

This section utilizes the TotalSegmentator library, a widely used tool for automatic segmentation of various organs and structures in CT scans.

This cell performs a comprehensive total body segmentation on the downloaded CT scan using the `totalsegmentator` library. It first sets up the input and output paths, then cleans up any previous output directory. The `totalsegmentator` function is called with options to save individual organ segmentations and use the more accurate segmentation mode. Finally, it includes a function to remove any generated segmentation files that are found to be completely empty.

In [None]:
import shutil
from totalsegmentator.python_api import totalsegmentator
import os
import nibabel as nib
import numpy as np

# Define the path to the input CT scan file
input_file = '/content/s0010/ct.nii.gz'

# Set the name for the output directory where segmentations will be saved
output_dir_total = "TotalSegmentator_Output"

# Clean up the output directory if it already exists
if os.path.exists(output_dir_total):
    shutil.rmtree(output_dir_total)
    print(f"Previous output directory removed: {output_dir_total}")
else:
    print(f"Output directory not found, proceeding with creation: {output_dir_total}")

# Create the output directory to store the segmentation results
os.makedirs(output_dir_total, exist_ok=True)

# Execute TotalSegmentator to perform segmentation on the input scan
print("Initiating TotalSegmentator for comprehensive segmentation...")
totalsegmentator(
    input=input_file,
    output=output_dir_total,
    ml=False,               # One binary mask file per structure (True writes a single multi-label file)
    task="total",           # Segment all structures of the whole-body task
    fast=False,             # Use the full-resolution model rather than the faster low-resolution one
    preview=False,          # Do not render a preview image of the result
    output_type="nifti"     # NIfTI output, written as compressed .nii.gz files
)
print(f"Segmentation process completed! Results are in: {output_dir_total}")

# Optional step: Remove any generated segmentation files that are completely empty
def filter_empty_files(output_dir):
    print("\nChecking for and removing empty segmentation files...")
    for root, dirs, files in os.walk(output_dir):
        for file in files:
            if file.endswith('.nii.gz'):
                file_path = os.path.join(root, file)
                try:
                    img = nib.load(file_path)
                    # Read the stored data as-is instead of converting it to float64
                    data = np.asanyarray(img.dataobj)
                    if not data.any():
                        print(f"Identified and removing empty file: {file_path}")
                        os.remove(file_path)
                except Exception as e:
                    print(f"An error occurred while processing file {file_path}: {e}")

filter_empty_files(output_dir_total)
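
With `ml=False`, each structure ends up in its own binary mask file. If a single label volume is more convenient downstream, the masks can be merged again. The sketch below shows the idea in NumPy on toy arrays. It assumes the masks have already been loaded as arrays of identical shape; `merge_binary_masks` and the label IDs in the example are hypothetical and not part of TotalSegmentator.

```python
import numpy as np


def merge_binary_masks(masks):
    """Combine {label_id: binary_mask} into one integer label volume.

    Label IDs are processed in ascending order, so where masks overlap the
    highest label ID wins.
    """
    if not masks:
        raise ValueError("no masks given")
    shape = np.asarray(next(iter(masks.values()))).shape
    label_volume = np.zeros(shape, dtype=np.uint16)
    for label_id in sorted(masks):
        mask = np.asarray(masks[label_id])
        if mask.shape != shape:
            raise ValueError(f"mask {label_id} has shape {mask.shape}, expected {shape}")
        label_volume[mask > 0] = label_id
    return label_volume


# Toy 1x4 "volume" with two structures that overlap at index 1
spleen = np.array([[1, 1, 0, 0]], dtype=np.uint8)
liver = np.array([[0, 1, 1, 0]], dtype=np.uint8)
merged = merge_binary_masks({1: spleen, 5: liver})
print(merged)  # [[1 5 5 0]]
```

The overlap rule is a design choice for this sketch; real masks from one model run should rarely overlap.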

This cell compresses the output directory generated by TotalSegmentator into a zip file named `TotalSegmentator_Output.zip`. It then provides a download link for this zip file, allowing you to easily access the segmentation results on your local machine.

In [None]:
from google.colab import files
import shutil

# Zip the output directory
shutil.make_archive(output_dir_total, 'zip', output_dir_total)

# Download the zipped directory
files.download(f'{output_dir_total}.zip')
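
`shutil.make_archive` returns the path of the archive it wrote, which makes it easy to check the contents before downloading. The sketch below uses only the standard library. The helper name `zip_and_list` is hypothetical, and the Colab-specific `files.download` step is left out so that the code also runs outside Colab.

```python
import shutil
import zipfile
from pathlib import Path


def zip_and_list(directory):
    """Zip `directory` next to itself and return (archive_path, file_names)."""
    directory = Path(directory)
    # The base name is passed without ".zip"; make_archive appends the suffix
    archive_path = shutil.make_archive(str(directory), "zip", root_dir=directory)
    with zipfile.ZipFile(archive_path) as archive:
        # Keep regular files only; entries ending in "/" are directories
        names = sorted(n for n in archive.namelist() if not n.endswith("/"))
    return archive_path, names


# Example use after the segmentation cell has run:
# archive_path, names = zip_and_list("TotalSegmentator_Output")
# print(archive_path, len(names), "files")
```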

---

## Part 3: Wholebody CT Segmentation Model

This section uses the pre-trained whole-body CT segmentation model from the MONAI Model Zoo. It covers downloading and loading the model, preprocessing the input image, running sliding-window inference, and saving the resulting organ masks.
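
Two pieces of arithmetic drive the inference step in this section: padding each spatial dimension up to a multiple of the patch size, and stepping a fixed-size window across the volume with overlap. The sketch below reproduces that arithmetic in plain Python for a single axis. It is an illustration only, not MONAI's implementation; the function names are hypothetical and MONAI's own window placement may differ in detail.

```python
import math


def padded_length(length, k):
    """Smallest multiple of k that is >= length."""
    return math.ceil(length / k) * k


def window_starts(length, roi, overlap):
    """Start indices of windows of size `roi` that cover [0, length).

    Assumes length >= roi. The stride is roi * (1 - overlap); if the regular
    grid stops short of the end, one extra window is placed flush with it.
    """
    stride = max(1, int(roi * (1 - overlap)))
    starts = list(range(0, length - roi + 1, stride))
    if starts[-1] + roi < length:
        starts.append(length - roi)
    return starts


print(padded_length(250, 96))        # 288
print(window_starts(288, 96, 0.5))   # [0, 48, 96, 144, 192]
```

With a 96-voxel patch and 50% overlap, an axis of 288 voxels needs five windows, so a volume padded to 288 in every direction needs 5 x 5 x 5 = 125 forward passes at `sw_batch_size=1`.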

In [None]:
# Install and import the libraries needed for the whole-body CT segmentation model
!pip install -q "monai[all]"
!pip install -q nibabel

import os
import torch
import nibabel as nib
import numpy as np
from google.colab import files
import shutil
from monai.bundle import download, load
from monai.transforms import (
    LoadImage,
    EnsureChannelFirst,
    ScaleIntensity,
    Orientation,
    CropForeground,
    DivisiblePad,
    EnsureType
)

# Download the MONAI whole body CT segmentation model
bundle_name = "wholeBody_ct_segmentation"

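# Download the model bundle (from MONAI Model Zoo).
# download() does not return a path, so the target directory is chosen explicitly.
bundle_root = "/content/monai_bundles"
os.makedirs(bundle_root, exist_ok=True)
download(name=bundle_name, bundle_dir=bundle_root, source="github", progress=True)
bundle_dir = os.path.join(bundle_root, bundle_name)
print(f"Model download complete: {bundle_dir}")

# Load the downloaded model
model = load(name=bundle_name, bundle_dir=bundle_root)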
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
print(f"Using device: {device}")

# Specify the path to the input CT scan file
input_path = "/content/s0010/ct.nii.gz"

# Verify the input file exists
if not os.path.exists(input_path):
    print(f"Input file not found: {input_path}")
    print("Ensure your CT scan is located at /content/s0010/ct.nii.gz")
    print("\nFile upload options:")
    print("1. Direct upload:")
    print("   !mkdir -p /content/s0010")
    print("   from google.colab import files")
    print("   uploaded = files.upload()")
    print("   !mv uploaded_file.nii.gz /content/s0010/ct.nii.gz")
    print("\n2. Mount Google Drive:")
    print("   from google.colab import drive")
    print("   drive.mount('/content/drive')")
    print("   !cp /content/drive/MyDrive/path/to/ct.nii.gz /content/s0010/ct.nii.gz")
else:
    print(f"Input file found: {input_path}")

# Define and create the output directory
output_dir = "/content/wholebody_ct_segmentation"
os.makedirs(output_dir, exist_ok=True)

# Load and preprocess the input image
loader = LoadImage(image_only=True)
image = loader(input_path)

# Define image transformations
transforms = [
    EnsureChannelFirst(),
    ScaleIntensity(),
    Orientation(axcodes="RAS"),
    CropForeground(),
    DivisiblePad(k=96),
    EnsureType()
]

# Apply transformations to the image
for t in transforms:
    image = t(image)

image = image.unsqueeze(0).float().to(device)
print(f"Processed image shape: {image.shape}")

# Perform segmentation using sliding window inference
from monai.inferers import sliding_window_inference

print("Performing segmentation with sliding window...")
print("This process may take some time due to memory constraints.")

# Clear GPU cache before inference
torch.cuda.empty_cache()

# Run inference and store output on CPU
with torch.no_grad():
    output = sliding_window_inference(
        inputs=image,
        roi_size=(96, 96, 96),      # Patch size for processing
        sw_batch_size=1,             # Batch size for sliding window
        predictor=model,
        overlap=0.5,                 # Overlap between patches
        mode="gaussian",             # Blending mode
        device=torch.device("cpu")   # Output device
    )
    output = torch.argmax(output, dim=1).cpu().numpy().squeeze()

# Clear GPU cache after inference
torch.cuda.empty_cache()

print(f"Segmentation inference complete. Output shape: {output.shape}")

# Define organ labels mapping
organ_labels = {
    1: "spleen",
    2: "kidney_right",
    3: "kidney_left",
    4: "gallbladder",
    5: "liver",
    6: "stomach",
    7: "aorta",
    8: "inferior_vena_cava",
    9: "portal_vein_and_splenic_vein",
    10: "pancreas",
    11: "adrenal_gland_right",
    12: "adrenal_gland_left",
    13: "lung_upper_lobe_left",
    14: "lung_lower_lobe_left",
    15: "lung_upper_lobe_right",
    16: "lung_middle_lobe_right",
    17: "lung_lower_lobe_right",
    18: "vertebrae_L5",
    19: "vertebrae_L4",
    20: "vertebrae_L3",
    21: "vertebrae_L2",
    22: "vertebrae_L1",
    23: "vertebrae_T12",
    24: "vertebrae_T11",
    25: "vertebrae_T10",
    26: "vertebrae_T9",
    27: "vertebrae_T8",
    28: "vertebrae_T7",
    29: "vertebrae_T6",
    30: "vertebrae_T5",
    31: "vertebrae_T4",
    32: "vertebrae_T3",
    33: "vertebrae_T2",
    34: "vertebrae_T1",
    35: "vertebrae_C7",
    36: "vertebrae_C6",
    37: "vertebrae_C5",
    38: "vertebrae_C4",
    39: "vertebrae_C3",
    40: "vertebrae_C2",
    41: "vertebrae_C1",
    42: "esophagus",
    43: "trachea",
    44: "heart_myocardium",
    45: "heart_atrium_left",
    46: "heart_ventricle_left",
    47: "heart_atrium_right",
    48: "heart_ventricle_right",
    49: "pulmonary_artery",
    50: "brain",
    51: "iliac_artery_left",
    52: "iliac_artery_right",
    53: "iliac_vena_left",
    54: "iliac_vena_right",
    55: "small_bowel",
    56: "duodenum",
    57: "colon",
    58: "rib_left_1",
    59: "rib_left_2",
    60: "rib_left_3",
    61: "rib_left_4",
    62: "rib_left_5",
    63: "rib_left_6",
    64: "rib_left_7",
    65: "rib_left_8",
    66: "rib_left_9",
    67: "rib_left_10",
    68: "rib_left_11",
    69: "rib_left_12",
    70: "rib_right_1",
    71: "rib_right_2",
    72: "rib_right_3",
    73: "rib_right_4",
    74: "rib_right_5",
    75: "rib_right_6",
    76: "rib_right_7",
    77: "rib_right_8",
    78: "rib_right_9",
    79: "rib_right_10",
    80: "rib_right_11",
    81: "rib_right_12",
    82: "humerus_left",
    83: "humerus_right",
    84: "scapula_left",
    85: "scapula_right",
    86: "clavicula_left",
    87: "clavicula_right",
    88: "femur_left",
    89: "femur_right",
    90: "hip_left",
    91: "hip_right",
    92: "sacrum",
    93: "face",
    94: "gluteus_maximus_left",
    95: "gluteus_maximus_right",
    96: "gluteus_medius_left",
    97: "gluteus_medius_right",
    98: "gluteus_minimus_left",
    99: "gluteus_minimus_right",
    100: "autochthon_left",
    101: "autochthon_right",
    102: "iliopsoas_left",
    103: "iliopsoas_right",
    104: "urinary_bladder"
}

# Save individual organ masks
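# Orientation, CropForeground and DivisiblePad changed the voxel grid, so the
# prediction no longer lines up with the original file's affine. Use the affine
# that MONAI tracked on the preprocessed MetaTensor, falling back to the original.
nii = nib.load(input_path)
header = nii.header
tracked_affine = getattr(image, "affine", None)
affine = nii.affine if tracked_affine is None else tracked_affine.detach().cpu().numpy()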

saved_count = 0
for label_id, organ_name in organ_labels.items():
    mask = (output == label_id).astype(np.uint8)

    if np.sum(mask) == 0:
        continue  # Skip empty organs

    seg_img = nib.Nifti1Image(mask, affine, header)
    out_file = os.path.join(output_dir, f"{organ_name}.nii.gz")
    nib.save(seg_img, out_file)
    saved_count += 1
    print(f"Saved: {organ_name}.nii.gz")

print(f"\nSegmentation process finished. {saved_count} organ masks were saved.")

# Download the results as a ZIP file
!zip -r segmentation_results.zip {output_dir}
files.download('segmentation_results.zip')

# Option to download to Google Drive (uncomment to use)
# import shutil
# drive_output = "/content/drive/MyDrive/segmentation_output"
# shutil.copytree(output_dir, drive_output)
# print(f"Results copied to Google Drive: {drive_output}")
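
The step from network output to per-organ files is an argmax over the class channel followed by one equality test per label. The sketch below shows this in NumPy on toy data. It assumes logits laid out as (classes, H, W); the three-class example and the helper name `split_labels` are hypothetical.

```python
import numpy as np


def split_labels(label_map, names):
    """Return {name: binary uint8 mask} for every label present in label_map."""
    masks = {}
    for label_id, name in names.items():
        mask = (label_map == label_id).astype(np.uint8)
        if mask.any():  # skip structures that were not predicted at all
            masks[name] = mask
    return masks


# Toy logits: 3 classes on a 1x3 grid, shape (classes, H, W)
logits = np.array([
    [[2.0, 0.1, 0.3]],   # class 0: background
    [[0.5, 1.5, 0.2]],   # class 1
    [[0.1, 0.2, 0.9]],   # class 2
])
label_map = np.argmax(logits, axis=0)   # shape (1, 3)
masks = split_labels(label_map, {1: "spleen", 2: "liver", 3: "stomach"})
print(label_map)       # [[0 1 2]]
print(sorted(masks))   # ['liver', 'spleen']
```

Label 3 never wins the argmax in this toy example, so no mask is produced for it, which mirrors how the cell above skips empty organs.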

---

## Part 4: MedIM Segmentation

This section uses the MedIM library for whole-body CT segmentation, showcasing a different model and workflow.

This cell installs the `medim` library and its dependencies, then runs a pre-trained STU-Net model (`STU-Net-B`, trained on the TotalSegmentator dataset) through `medim`. It defines a label-to-name mapping for the anatomical structures and preprocesses the input CT scan by clipping, normalizing, resampling, and resizing. After inference, it saves each non-empty organ mask as a NIfTI file named after the structure, plus a combined multi-label segmentation file. If segmentation fails, the cell prints likely causes and tries to load a smaller `STU-Net-S` model as a starting point; it does not re-run the segmentation with it.
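
Before the full cell, the clipping, normalization, and resampling arithmetic can be looked at in isolation. The sketch below uses the same intensity window of -1024 to 1024 HU and the same 1.5 mm target spacing as the cell; the helper names `normalize_hu` and `resample_factors` and the example spacing are hypothetical.

```python
import numpy as np


def normalize_hu(volume, hu_min=-1024.0, hu_max=1024.0):
    """Clip Hounsfield units to [hu_min, hu_max] and rescale to [0, 1]."""
    clipped = np.clip(np.asarray(volume, dtype=np.float32), hu_min, hu_max)
    return (clipped - hu_min) / (hu_max - hu_min)


def resample_factors(spacing, target_spacing):
    """Per-axis zoom factors that convert `spacing` to `target_spacing`."""
    return tuple(float(s) / float(t) for s, t in zip(spacing, target_spacing))


hu = np.array([-2000.0, -1024.0, 0.0, 1024.0, 3000.0])
print(normalize_hu(hu).tolist())                              # [0.0, 0.0, 0.5, 1.0, 1.0]
print(resample_factors((0.75, 0.75, 3.0), (1.5, 1.5, 1.5)))   # (0.5, 0.5, 2.0)
```

A factor below 1 shrinks an axis (fine in-plane resolution is coarsened), while a factor above 1 stretches it (thick slices are interpolated).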

In [None]:
# Install MedIM and required dependencies
!pip install medim
!pip install nibabel numpy torch torchvision
!pip install monai  # For medical image preprocessing utilities

# TotalSegmentator organ class mapping (104 classes + background)
# Based on TotalSegmentator v1 (104 anatomical structures)
TOTALSEGMENTATOR_CLASS_MAPPING = {
    0: "background",
    1: "spleen",
    2: "kidney_right",
    3: "kidney_left",
    4: "gallbladder",
    5: "liver",
    6: "stomach",
    7: "aorta",
    8: "inferior_vena_cava",
    9: "portal_vein_and_splenic_vein",
    10: "pancreas",
    11: "adrenal_gland_right",
    12: "adrenal_gland_left",
    13: "lung_upper_lobe_left",
    14: "lung_lower_lobe_left",
    15: "lung_upper_lobe_right",
    16: "lung_middle_lobe_right",
    17: "lung_lower_lobe_right",
    18: "vertebrae_L5",
    19: "vertebrae_L4",
    20: "vertebrae_L3",
    21: "vertebrae_L2",
    22: "vertebrae_L1",
    23: "vertebrae_T12",
    24: "vertebrae_T11",
    25: "vertebrae_T10",
    26: "vertebrae_T9",
    27: "vertebrae_T8",
    28: "vertebrae_T7",
    29: "vertebrae_T6",
    30: "vertebrae_T5",
    31: "vertebrae_T4",
    32: "vertebrae_T3",
    33: "vertebrae_T2",
    34: "vertebrae_T1",
    35: "vertebrae_C7",
    36: "vertebrae_C6",
    37: "vertebrae_C5",
    38: "vertebrae_C4",
    39: "vertebrae_C3",
    40: "vertebrae_C2",
    41: "vertebrae_C1",
    42: "esophagus",
    43: "trachea",
    44: "heart_myocardium",
    45: "heart_atrium_left",
    46: "heart_ventricle_left",
    47: "heart_atrium_right",
    48: "heart_ventricle_right",
    49: "pulmonary_artery",
    50: "brain",
    51: "iliac_artery_left",
    52: "iliac_artery_right",
    53: "iliac_vena_left",
    54: "iliac_vena_right",
    55: "small_bowel",
    56: "duodenum",
    57: "colon",
    58: "rib_left_1",
    59: "rib_left_2",
    60: "rib_left_3",
    61: "rib_left_4",
    62: "rib_left_5",
    63: "rib_left_6",
    64: "rib_left_7",
    65: "rib_left_8",
    66: "rib_left_9",
    67: "rib_left_10",
    68: "rib_left_11",
    69: "rib_left_12",
    70: "rib_right_1",
    71: "rib_right_2",
    72: "rib_right_3",
    73: "rib_right_4",
    74: "rib_right_5",
    75: "rib_right_6",
    76: "rib_right_7",
    77: "rib_right_8",
    78: "rib_right_9",
    79: "rib_right_10",
    80: "rib_right_11",
    81: "rib_right_12",
    82: "humerus_left",
    83: "humerus_right",
    84: "scapula_left",
    85: "scapula_right",
    86: "clavicula_left",
    87: "clavicula_right",
    88: "femur_left",
    89: "femur_right",
    90: "hip_left",
    91: "hip_right",
    92: "sacrum",
    93: "face",
    94: "gluteus_maximus_left",
    95: "gluteus_maximus_right",
    96: "gluteus_medius_left",
    97: "gluteus_medius_right",
    98: "gluteus_minimus_left",
    99: "gluteus_minimus_right",
    100: "autochthon_left",
    101: "autochthon_right",
    102: "iliopsoas_left",
    103: "iliopsoas_right",
    104: "urinary_bladder"
}

import os
import shutil
import torch
import numpy as np
import nibabel as nib
from scipy import ndimage
import medim

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Input CT scan
input_file = '/content/s0010/ct.nii.gz'

# Define output directory
output_dir_medim = "MedIm"

# Remove existing output directory if it exists
if os.path.exists(output_dir_medim):
    shutil.rmtree(output_dir_medim)
    print(f"Removed directory: {output_dir_medim}")

# Create output directory
os.makedirs(output_dir_medim, exist_ok=True)

def preprocess_ct_scan(input_file, target_spacing=(1.5, 1.5, 1.5), target_size=(128, 128, 128)):
    """
    Preprocess CT scan for MedIM STU-Net model
    """
    print("Loading and preprocessing CT scan...")

    # Load NIfTI file
    img = nib.load(input_file)
    data = img.get_fdata().astype(np.float32)

    # Get original spacing and affine
    original_spacing = img.header.get_zooms()[:3]
    affine = img.affine

    print(f"Original shape: {data.shape}")
    print(f"Original spacing: {original_spacing}")

    # Clip HU values (typical CT range)
    data = np.clip(data, -1024, 1024)

    # Normalize to [0, 1] range
    data = (data + 1024) / 2048.0

    # Resample to target spacing if needed
    if original_spacing != target_spacing:
        zoom_factors = [orig/target for orig, target in zip(original_spacing, target_spacing)]
        data = ndimage.zoom(data, zoom_factors, order=1, mode='nearest')
        print(f"Resampled shape: {data.shape}")

    # Resize to target size for model input
    current_shape = data.shape
    resize_factors = [target/current for target, current in zip(target_size, current_shape)]
    data_resized = ndimage.zoom(data, resize_factors, order=1, mode='nearest')

    print(f"Final preprocessed shape: {data_resized.shape}")

    # Convert to tensor and add batch and channel dimensions
    data_tensor = torch.from_numpy(data_resized).float()
    data_tensor = data_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W, D)

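    # Return the shape of the original volume (not the resampled one) so that
    # predictions are resized back onto the grid described by `affine`
    return data_tensor, data, img.shape[:3], affine, img.header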

def postprocess_segmentation(pred_tensor, original_shape, original_affine, original_header, output_dir):
    """
    Postprocess segmentation results and save individual organ masks with proper names
    """
    print("Postprocessing segmentation results...")

    # Remove batch dimension and convert to numpy
    pred_np = pred_tensor.squeeze(0).cpu().numpy()  # (num_classes, H, W, D)

    # Get number of classes
    num_classes = pred_np.shape[0]
    print(f"Number of segmented classes: {num_classes}")

    # Resize back to original CT scan dimensions
    resize_factors = [orig/current for orig, current in zip(original_shape, pred_np.shape[1:])]

    # Create a multi-label segmentation volume
    multi_label_seg = np.zeros(original_shape, dtype=np.uint16)

    # Keep track of saved organs
    saved_organs = []

    # Process each class
    for class_idx in range(num_classes):
        if class_idx == 0:  # Skip background class
            continue

        # Get class probability map
        class_prob = pred_np[class_idx]

        # Resize to original dimensions
        class_resized = ndimage.zoom(class_prob, resize_factors, order=1, mode='nearest')

        # Threshold to get binary mask (adjust threshold as needed)
        class_mask = (class_resized > 0.5).astype(np.uint16)

        # Add to multi-label volume
        multi_label_seg[class_mask > 0] = class_idx

        # Save individual organ mask with proper name
        if np.sum(class_mask) > 0:  # Only save non-empty masks
            # Get organ name from mapping
            organ_name = TOTALSEGMENTATOR_CLASS_MAPPING.get(class_idx, f"unknown_class_{class_idx}")

            organ_img = nib.Nifti1Image(class_mask, original_affine, original_header)
            organ_filename = os.path.join(output_dir, f"{organ_name}.nii.gz")
            nib.save(organ_img, organ_filename)

            voxel_count = np.sum(class_mask)
            saved_organs.append((organ_name, voxel_count))
            print(f"Saved {organ_name}: {voxel_count} voxels")

    # Save multi-label segmentation
    multi_label_img = nib.Nifti1Image(multi_label_seg, original_affine, original_header)
    multi_label_filename = os.path.join(output_dir, "segmentations.nii.gz")
    nib.save(multi_label_img, multi_label_filename)
    print(f"Saved multi-label segmentation: {multi_label_filename}")

    # Print summary of saved organs
    print(f"\nSegmentation Summary:")
    print(f"Total organs segmented: {len(saved_organs)}")
    print("Organs found:")
    for organ_name, voxel_count in sorted(saved_organs, key=lambda x: x[1], reverse=True):
        print(f"  - {organ_name}: {voxel_count:,} voxels")

    return multi_label_seg

def run_medim_segmentation(input_file, output_dir):
    """
    Main function to run MedIM segmentation
    """
    print("Creating MedIM model...")

    # Create STU-Net model pre-trained on TotalSegmentator dataset
    # This should give similar results to TotalSegmentator
    model = medim.create_model("STU-Net-B", dataset="TotalSegmentator")
    model = model.to(device)
    model.eval()

    print("Model loaded successfully!")

    # Preprocess input
    input_tensor, original_data, original_shape, affine, header = preprocess_ct_scan(input_file)
    input_tensor = input_tensor.to(device)

    print("Running inference...")
    with torch.no_grad():
        # Run inference
        output = model(input_tensor)

        # Apply softmax to get probabilities
        if len(output.shape) == 5:  # (batch, classes, H, W, D)
            output = torch.softmax(output, dim=1)

    print("Inference completed!")

    # Postprocess and save results
    segmentation_result = postprocess_segmentation(
        output, original_shape, affine, header, output_dir
    )

    return segmentation_result

# Run the segmentation
print(f"Running MedIM segmentation on: {input_file}")
try:
    segmentation_result = run_medim_segmentation(input_file, output_dir_medim)
    print(f"Segmentation completed! Results saved in: {output_dir_medim}")

    # List generated files with proper names
    print("\nGenerated organ files:")
    organ_files = []
    for file in os.listdir(output_dir_medim):
        if file.endswith('.nii.gz') and file != 'segmentations.nii.gz':
            file_path = os.path.join(output_dir_medim, file)
            img = nib.load(file_path)
            data = img.get_fdata()
            non_zero_voxels = np.count_nonzero(data)
            organ_files.append((file.replace('.nii.gz', ''), non_zero_voxels))

    # Sort by voxel count (largest organs first)
    organ_files.sort(key=lambda x: x[1], reverse=True)
    for organ_name, voxel_count in organ_files:
        print(f"  {organ_name}: {voxel_count:,} voxels")

    print(f"\nMain segmentation file: segmentations.nii.gz")
    print(f"Total files created: {len(organ_files) + 1}")

except Exception as e:
    print(f"Error during segmentation: {e}")
    print("This might be due to:")
    print("1. Model not available for the specified dataset")
    print("2. Input image format issues")
    print("3. Memory constraints")

    # Fallback: try with a different model configuration
    print("\nTrying with basic STU-Net model...")
    try:
        model = medim.create_model("STU-Net-S")
        model = model.to(device)
        model.eval()
        print("Basic model loaded - you may need to adapt the preprocessing/postprocessing")
    except Exception as e2:
        print(f"Fallback also failed: {e2}")

# Note: postprocess_segmentation() only writes non-empty masks, so unlike the
# TotalSegmentator workflow above there are no empty files to remove afterwards
print("Processing complete!")
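
The post-processing in this cell turns each class probability map into a mask by thresholding at 0.5. Because softmax probabilities sum to one, at most one class can exceed 0.5 at a voxel, so voxels where no class is that confident stay unlabeled, whereas an argmax assigns every voxel to its most likely class. The toy NumPy sketch below contrasts the two rules; the logits are made up for illustration.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Toy logits for 3 classes at 2 voxels, shape (classes, voxels)
logits = np.array([
    [0.0, 0.2],   # class 0 (background)
    [3.0, 0.3],   # class 1
    [0.0, 0.1],   # class 2
])
probs = softmax(logits, axis=0)

# Thresholding: a voxel gets a label only if one class exceeds 0.5
thresholded = np.zeros(probs.shape[1], dtype=np.uint16)
for class_idx in range(1, probs.shape[0]):
    thresholded[probs[class_idx] > 0.5] = class_idx

# Argmax: every voxel gets its most likely class
argmaxed = probs.argmax(axis=0)

print(thresholded)  # [1 0]
print(argmaxed)     # [1 1]
```

At the second voxel, class 1 is the most likely class but has a probability of only about 0.37, so thresholding leaves the voxel as background while argmax labels it.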