In [1]:
import os
import pydicom
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob
import re
import numpy as np
from skimage import exposure
from scipy.ndimage import gaussian_filter


In this notebook we convert the DICOM images (read with pydicom) into normalised PNG images that can be used directly for training. The approach is inspired by [this notebook](https://www.kaggle.com/code/itsuki9180/rsna2024-lsdc-making-dataset).

In [2]:
# Root directory of the competition data; train.csv, train_series_descriptions.csv
# and the train_images/ folder are all read relative to this path.
BASE_URL = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/'
# NOTE(review): EPS and GRAYSCALE are not referenced by any code visible in this notebook.
EPS = 1e-6
GRAYSCALE = 255
# Output size of every written PNG; passed to cv2.resize as IMAGE_SIZE[::-1].
IMAGE_SIZE = (224, 224)

# Crop windows expressed as fractions of the image width (X) / height (Y).
# SAGITTAL_*_Y bound the vertical crop in crop_sagittal_images.
# NOTE(review): SAGITTAL_START_X / SAGITTAL_END_X are not referenced by the visible code.
SAGITTAL_START_X = 0.4
SAGITTAL_END_X = 0.8
SAGITTAL_START_Y = 0.1
SAGITTAL_END_Y = 0.9
# AXIAL_* bound the fixed crop applied in crop_axial_images.
AXIAL_START_X = 0.3
AXIAL_END_X = 0.7
AXIAL_START_Y = 1.0 - 0.85
AXIAL_END_Y = 1.0 - 0.35

**TODO:** check whether the image quality can be improved during this conversion; the resizing step currently uses `cv2.INTER_CUBIC` interpolation.

In [3]:
# Load train.csv from the competition directory; it has one `study_id` column (used in the next cell).
train_df = pd.read_csv(BASE_URL + 'train.csv')

In [4]:
# Number of distinct studies listed in train.csv.
len(train_df['study_id'].unique())

1975

In [5]:
# Per-series metadata; the columns used below are study_id, series_id and series_description.
train_descriptions = pd.read_csv(BASE_URL + 'train_series_descriptions.csv')

In [6]:
# Distinct study ids that have at least one described series; the conversion loop iterates over these.
study_ids = train_descriptions['study_id'].unique()
len(study_ids)

1975

In [7]:
# Distinct series descriptions; the conversion loop creates one output folder per description.
description_list = train_descriptions['series_description'].unique()
print(description_list)

['Sagittal T2/STIR' 'Sagittal T1' 'Axial T2']


In [8]:
def atoi(text):
    """Convert a text chunk to ``int`` when it consists only of decimal digits.

    Args:
        text: A string chunk, typically one piece of a filename split on digit runs.

    Returns:
        ``int(text)`` when every character is a decimal digit, otherwise ``text``
        unchanged (this includes the empty string).
    """
    # str.isdecimal() is used instead of str.isdigit(): isdigit() is also True for
    # characters such as superscript digits ('²'), which int() rejects with ValueError.
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Build a sort key that orders embedded numbers numerically ("2" before "10").

    The text is split on runs of digits (the runs are kept thanks to the capturing
    group) and every chunk is passed through ``atoi``.
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

In [9]:
def analyze_image(image):
    """Summarise the intensity distribution of an image.

    Args:
        image: Image array passed to ``cv2.calcHist`` with 256 bins over [0, 256);
            the caller in this notebook passes a uint8 array.

    Returns:
        dict with keys ``'mean'``, ``'std'``, ``'dynamic_range'`` (max minus min)
        and ``'entropy'`` (Shannon entropy, in bits, of the normalised histogram).
    """
    # 256-bin intensity histogram, normalised so the bins sum to 1
    raw_hist = cv2.calcHist([image], [0], None, [256], [0, 256])
    probabilities = raw_hist.flatten() / raw_hist.sum()

    # The 1e-7 offset keeps log2 finite for empty bins
    entropy = -np.sum(probabilities * np.log2(probabilities + 1e-7))

    return dict(
        mean=np.mean(image),
        std=np.std(image),
        dynamic_range=image.max() - image.min(),
        entropy=entropy,
    )

def adaptive_normalize_image(image, lower_percentile=1, upper_percentile=99):
    """
    Dynamically choose and apply a normalization method based on image properties.

    The image is first min-max scaled to [0, 1]; statistics of its 8-bit version
    (see ``analyze_image``) then select one of several ``skimage.exposure``
    corrections.

    Args:
        image: numpy array of pixel values (any numeric dtype).
        lower_percentile: Lower percentile used by the contrast-stretching fallback.
        upper_percentile: Upper percentile used by the contrast-stretching fallback.

    Returns:
        The corrected image as returned by the selected ``skimage.exposure``
        function, or an all-zero uint8 array of the same shape when the input is
        (nearly) uniform.
    """
    # Convert to float64 BEFORE any arithmetic. In the native dtype, signed integer
    # inputs (e.g. int16 DICOM data) can overflow in `max - min` / `image - min`,
    # which wraps around and can make a high-contrast image look uniform.
    pixels = np.asarray(image, dtype=np.float64)
    min_val = pixels.min()
    max_val = pixels.max()
    eps = 1e-6  # Small epsilon value to avoid division by near-zero
    if max_val - min_val <= eps:
        # If all pixels have very similar values, return a uniform image
        return np.zeros_like(image, dtype=np.uint8)
    image_float = (pixels - min_val) / (max_val - min_val)

    # Analyze the 8-bit version of the min-max scaled image
    stats = analyze_image((image_float * 255).astype(np.uint8))

    # Determine normalization method based on image properties
    # NOTE(review): image_float always spans exactly [0, 1], so the 8-bit dynamic
    # range is always 255 and this first branch can never be taken as written.
    # Left unchanged to keep the pipeline output stable; confirm the intended metric.
    if stats['dynamic_range'] < 75:  # Low contrast image
        return exposure.equalize_adapthist(image_float, clip_limit=0.03)
    elif stats['entropy'] < 6:  # Image with limited intensity levels
        return exposure.equalize_hist(image_float)
    elif stats['mean'] < 50 and stats['std'] < 30:  # Dark image with low variance
        # Apply gamma correction; gamma < 1 brightens, more so for darker images
        gamma = 1 - stats['mean'] / 255  # Adaptive gamma
        return exposure.adjust_gamma(image_float, gamma)
    elif stats['mean'] > 150:  # Very bright image
        # Apply logarithmic correction
        return exposure.adjust_log(image_float, 1)
    else:
        # For well-balanced images, apply mild contrast stretching
        p_low, p_high = np.percentile(image_float, (lower_percentile, upper_percentile))
        return exposure.rescale_intensity(image_float, in_range=(p_low, p_high))

# Final normalization function
def normalize_image(image):
    """Normalise an image with ``adaptive_normalize_image`` and return it as uint8.

    The adaptive result is multiplied by 255 and cast (truncated) to ``np.uint8``.
    """
    scaled = adaptive_normalize_image(image) * 255
    return scaled.astype(np.uint8)

In [10]:
def crop_sagittal_images(images, lower_percentile=20, upper_percentile=95, sigma=1, target_width_ratio=0.6, padding_ratio=0.1, edge_threshold=10):
    """
    Crop every image of a sagittal scan to one common window.

    The window is computed once, from the middle image of the list, and then
    applied unchanged to all images. The horizontal range is derived from the
    column-wise median brightness of that reference image (dark edges removed,
    bright outer columns excluded, minimum widths enforced). The vertical range
    is the SAGITTAL_START_Y..SAGITTAL_END_Y fraction of the image height,
    clamped to the detected non-dark rows.

    Args:
        images: List of 2D image arrays; the middle one must not be all zeros.
        lower_percentile: Percentile of the central column medians used as the
            lower brightness threshold.
        upper_percentile: Percentile of the central column medians used as the
            upper brightness threshold.
        sigma: Standard deviation of the Gaussian smoothing applied to the
            column medians.
        target_width_ratio: Target crop width as a fraction of the non-dark width.
        padding_ratio: Padding added on each side, as a fraction of the non-dark width.
        edge_threshold: A border column/row whose mean intensity is below this
            value is treated as a dark edge.

    Returns:
        List of cropped images (same order as ``images``). Only the images are
        returned; the crop start coordinates are not.

    Raises:
        AssertionError: If the reference (middle) image sums to zero.
    """
    
    num_images = len(images)
    # The middle slice serves as the reference for all crop decisions
    image = images[num_images // 2]
    assert np.sum(image) > 0
    
    # Median brightness of each column (one value per x position)
    vertical_median = np.median(image, axis=0)
    
    # Detect and remove dark edges
    left_edge = 0
    right_edge = image.shape[1] - 1
    top_edge = 0
    bottom_edge = image.shape[0] - 1
    
    # Walk inwards from each border while the column/row mean stays below edge_threshold
    while left_edge < image.shape[1] and np.mean(image[:, left_edge]) < edge_threshold:
        left_edge += 1
    
    while right_edge > 0 and np.mean(image[:, right_edge]) < edge_threshold:
        right_edge -= 1
    
    while top_edge < image.shape[0] and np.mean(image[top_edge, :]) < edge_threshold:
        top_edge += 1
    
    while bottom_edge > 0 and np.mean(image[bottom_edge, :]) < edge_threshold:
        bottom_edge -= 1
    
    # Focus on the central region of the image (the middle half of the non-dark width)
    center_x = (left_edge + right_edge) // 2
    central_width = (right_edge - left_edge) // 2
    central_start = max(left_edge, center_x - central_width // 2)
    central_end = min(right_edge, center_x + central_width // 2)
    
    # Use dynamic thresholding based on the central region
    central_vertical_median = vertical_median[central_start:central_end]
    lower_threshold = np.percentile(central_vertical_median, lower_percentile)
    upper_threshold = np.percentile(central_vertical_median, upper_percentile)
    
    smoothed_median = gaussian_filter(central_vertical_median, sigma=sigma)
    
    # Find regions above lower threshold and below upper threshold in the central region
    relevant_indices = np.where((smoothed_median > lower_threshold) & (smoothed_median < upper_threshold))[0]
    
    if len(relevant_indices) == 0:
        # If no relevant region is found, use the full central width
        x_start = central_start
        x_end = central_end
    else:
        # Indices are relative to the central slice, so shift back to image coordinates
        x_start = central_start + relevant_indices[0]
        x_end = central_start + relevant_indices[-1]
    
    # Calculate the center of the detected region
    center = (x_start + x_end) // 2
    
    # Calculate the target width (minimum half of the original width)
    target_width = max(int(target_width_ratio * (right_edge - left_edge)), (right_edge - left_edge) // 2)
    
    # Calculate padding
    padding = int(padding_ratio * (right_edge - left_edge))
    
    # Detect and exclude bright noisy regions on the left
    left_region = vertical_median[left_edge:center]
    left_smoothed = gaussian_filter(left_region, sigma=sigma)
    left_bright_threshold = np.percentile(left_smoothed, 98)  # Adjust this percentile as needed
    left_bright_indices = np.where(left_smoothed > left_bright_threshold)[0]
    
    if len(left_bright_indices) > 0:
        # Find the rightmost bright region on the left
        last_bright_left = left_bright_indices[-1] + left_edge
        # Adjust x_start to exclude the bright region
        x_start = max(x_start, last_bright_left)
    
    # Detect and exclude bright noisy regions on the right
    right_region = vertical_median[center:right_edge]
    right_smoothed = gaussian_filter(right_region, sigma=sigma)
    right_bright_threshold = np.percentile(right_smoothed, 98)  # Adjust this percentile as needed
    right_bright_indices = np.where(right_smoothed > right_bright_threshold)[0]
    
    if len(right_bright_indices) > 0:
        # Find the leftmost bright region on the right
        first_bright_right = right_bright_indices[0] + center
        # Adjust x_end to exclude the bright region
        x_end = min(x_end, first_bright_right)
    
    # Ensure the spine is centered in the final crop
    crop_width = min(target_width, x_end - x_start)
    center = (x_start + x_end) // 2
    # NOTE(review): the next two assignments are overwritten a few lines below
    # before being read; only crop_width and center carry forward.
    x_start = max(left_edge, center - crop_width // 2 - padding)
    x_end = min(right_edge, x_start + crop_width + 2 * padding)
    
    # Ensure the crop is at least 1/3 of the original width
    min_crop_width = image.shape[1] // 3
    crop_width = max(crop_width, min_crop_width)
    
    x_start = max(left_edge, center - crop_width // 2 - padding)
    x_end = min(right_edge, x_start + crop_width + 2 * padding)
    
    # If the crop is still too narrow, expand it while keeping the center
    if x_end - x_start < min_crop_width:
        extra_width = min_crop_width - (x_end - x_start)
        x_start = max(left_edge, x_start - extra_width // 2)
        x_end = min(right_edge, x_end + extra_width // 2)
        
        # If we hit the image edges, shift the crop to ensure minimum width
        if x_start == left_edge:
            x_end = min(right_edge, x_start + min_crop_width)
        elif x_end == right_edge:
            x_start = max(left_edge, x_end - min_crop_width)
    # Ensure the crop is at least half the original width
    if x_end - x_start < (right_edge - left_edge) // 2:
        extra_width = ((right_edge - left_edge) // 2) - (x_end - x_start)
        x_start = max(left_edge, x_start - extra_width // 2)
        x_end = min(right_edge, x_end + extra_width // 2)
    
    # Adjust vertical cropping: fixed fractions of the height, clamped to the non-dark rows
    y_start = max(top_edge, int(SAGITTAL_START_Y * image.shape[0]))
    y_end = min(bottom_edge, int(SAGITTAL_END_Y * image.shape[0]))
    
    # The same window (computed from the reference image) is applied to every image
    return [image[y_start:y_end, x_start:x_end] for image in images]

In [11]:
def crop_axial_images(images, flipped=(False, False)):
    '''
    Crop axial images to a fixed fractional window.

    The window is given by the module-level AXIAL_START_X / AXIAL_END_X
    (fractions of the width) and AXIAL_START_Y / AXIAL_END_Y (fractions of the
    height). The pixel bounds are computed separately for every image from its
    own shape, so a scan that mixes image sizes is cropped to the same relative
    region in each image. Previously the bounds came from ``images[0]`` only.

    Args:
        images: List of image arrays (at least 2D: rows, columns).
        flipped: Unused; kept only so existing callers remain valid.

    Returns:
        List of cropped images in the same order as ``images`` (an empty list
        for empty input). The crop coordinates themselves are not returned.
    '''
    cropped = []
    for image in images:
        height, width = image.shape[0], image.shape[1]

        x_start = round(width * AXIAL_START_X)
        x_end = round(width * AXIAL_END_X)
        y_start = round(height * AXIAL_START_Y)
        y_end = round(height * AXIAL_END_Y)

        cropped.append(image[y_start: y_end, x_start: x_end])

    return cropped

In [12]:
def extract_z_value(dicom):
    """Return the third component (index 2) of ``dicom.ImagePositionPatient``."""
    position = dicom.ImagePositionPatient
    return position[2]

def extract_slice_thickness(dicom):
    """Return the ``SliceThickness`` attribute of a DICOM dataset, unchanged."""
    thickness = dicom.SliceThickness
    return thickness

In [13]:
def write_axial_images(study_id, destination_path):
    '''
    Write axial images to a destination path

    All 'Axial T2' series of the study are read, merged into one stack sorted by
    z position, thinned so that near-duplicate z positions are dropped, and up to
    10 evenly spaced slices are cropped, resized to IMAGE_SIZE, normalised and
    written as 00.png, 01.png, ... into ``destination_path``.

    Args:
        study_id: Study identifier, matched against train_descriptions['study_id'].
        destination_path: Existing directory that receives the PNG files.

    Returns:
        None. Prints a message and returns early when the study has no axial
        series or no image could be read.
    '''
    series_ids = train_descriptions[(train_descriptions['study_id'] == study_id) & 
                                        (train_descriptions['series_description'] == 'Axial T2')]['series_id']
    
    if len(series_ids) == 0:
        print(f"No images found for Axial study_id: {study_id}")
        return
    
    consolidated_scan = []
    z_values = []
    slice_thicknesses = []
    
    for series_id in series_ids:
        image_paths = glob(f'{BASE_URL}/train_images/{study_id}/{series_id}/*.dcm')
        # Natural sort so that e.g. 2.dcm comes before 10.dcm
        image_paths.sort(key=natural_keys)
        
        for path in image_paths:
            try:
                dcm = pydicom.dcmread(path)
                # The thickness is recorded before the image is processed, so it is
                # kept even if a later step for this file raises.
                slice_thicknesses.append(extract_slice_thickness(dcm))
                z_value = extract_z_value(dcm)
                
                image = normalize_image(dcm.pixel_array)
                # Flip according to the sign of ImageOrientationPatient components 0 and 4
                # (presumably to bring all slices into one common orientation - TODO confirm)
                if dcm.ImageOrientationPatient[0] < 0:
                    image = np.fliplr(image)
                if dcm.ImageOrientationPatient[4] < 0:
                    image = np.flipud(image)
                
                consolidated_scan.append(image)
                z_values.append(z_value)
            except Exception as e:
                # A failing file is reported and skipped; the rest of the scan is still used
                print(f"Error processing image {path}: {str(e)}")
    
    if not consolidated_scan:
        print(f"No images found for Axial study_id: {study_id}")
        return
    
    # Order the merged stack by ascending z position
    sorted_indices = np.argsort(z_values)
    consolidated_scan = [consolidated_scan[i] for i in sorted_indices]
    z_values = [z_values[i] for i in sorted_indices]
    
    min_slice_thickness = min(slice_thicknesses)
    filtered_scan = []
    filtered_z_values = []
    
    # Keep a slice only if it is at least half the smallest slice thickness away
    # from the previously kept slice
    for i, (image, z_value) in enumerate(zip(consolidated_scan, z_values)):
        if not filtered_z_values or abs(z_value - filtered_z_values[-1]) >= min_slice_thickness / 2:
            filtered_scan.append(image)
            filtered_z_values.append(z_value)

    # Select 10 evenly spaced images from filtered_scan
    num_images = len(filtered_scan)
    if num_images > 10:
        # start equals 0.1 * n, so the sample positions are 0.1n, 0.2n, ..., 1.0n;
        # int(i - 0.50001) turns each position into a zero-based index (last one is n - 1)
        step = len(filtered_scan) / 10.0
        start = len(filtered_scan) / 2.0 - 4.0 * step
        end = len(filtered_scan) + 0.0001
        selected_indices = [max(0, int(i - 0.50001)) for i in np.arange(start, end, step)]
        assert len(selected_indices) == 10
    else:
        # NOTE(review): this branch is also taken for exactly 10 images, in which
        # case all 10 are used even though the message says "not enough".
        print('Not enough images to select 10 evenly spaced images, Axial study_id:', study_id)
        selected_indices = range(num_images)
    
    images = [filtered_scan[i] for i in selected_indices]
    cropped_images = crop_axial_images(images)
    # IMAGE_SIZE is reversed because cv2.resize takes its size as (width, height)
    images = [cv2.resize(cropped_image, IMAGE_SIZE[::-1], interpolation = cv2.INTER_CUBIC) for cropped_image in cropped_images]

    # Second normalisation pass, applied to the cropped and resized images
    normalized_images = [normalize_image(image) for image in images]
    for idx, image in enumerate(normalized_images):
        cv2.imwrite(f'{destination_path}/{idx:02d}.png', image)
    


In [14]:
def write_sagittal_images(study_id, destination_path, series_description):
    '''
    Write sagittal images to a destination path

    The first series of the study that matches ``series_description`` is read,
    normalised, flipped according to the first slice's ImageOrientationPatient,
    reduced to at most 10 evenly spaced slices, cropped with
    ``crop_sagittal_images``, normalised again, resized to IMAGE_SIZE and written
    as 00.png, 01.png, ... into ``destination_path``.

    Args:
        study_id: Study identifier, matched against train_descriptions['study_id'].
        destination_path: Existing directory that receives the PNG files.
        series_description: Series description to select, e.g. 'Sagittal T1'.

    Returns:
        None. Prints a message and returns early when no matching series or
        DICOM file exists, or when the series cannot be loaded.
    '''
    matching_series = train_descriptions[(train_descriptions['study_id'] == study_id) &
                                         (train_descriptions['series_description'] == series_description)]

    if len(matching_series) == 0:
        print(f"No images found for Sagittal study_id: {study_id} series_description: {series_description}")
        return

    # Only the first matching series is used
    series_id = matching_series['series_id'].values[0]
    image_paths = glob(f'{BASE_URL}/train_images/{study_id}/{series_id}/*.dcm')
    # Natural sort so that e.g. 2.dcm comes before 10.dcm
    image_paths.sort(key=natural_keys)

    if not image_paths:
        # Without this guard the empty list would only surface as an IndexError below
        print(f"No images found for Sagittal study_id: {study_id} series_description: {series_description}")
        return

    try:
        first_image = pydicom.dcmread(image_paths[0])

        # The flip decision is taken once, from the first slice, and applied to the whole series
        flipped = [first_image.ImageOrientationPatient[1] < 0, first_image.ImageOrientationPatient[5] > 0]

        images = [normalize_image(pydicom.dcmread(path).pixel_array) for path in image_paths]

        if flipped[0]:
            images = [np.fliplr(image) for image in images]
        if flipped[1]:
            images = [np.flipud(image) for image in images]
    except Exception as e:
        # Report the series that actually failed (not a hard-coded "Sagittal T2") and
        # stop here: continuing would hit unbound `images` and raise NameError.
        print(f"Error loading {series_description} study_id: {study_id}. Error: {str(e)}")
        return

    num_images = len(images)
    if num_images > 10:
        # start equals 0.1 * n, so the sample positions are 0.1n, 0.2n, ..., 1.0n;
        # int(i - 0.50001) turns each position into a zero-based index (last one is n - 1)
        step = num_images / 10.0
        start = num_images / 2.0 - 4.0 * step
        end = num_images + 0.0001
        selected_indices = [max(0, int(i - 0.50001)) for i in np.arange(start, end, step)]
        assert len(selected_indices) == 10
    else:
        print(f'Not enough images to select 10 evenly spaced images, Sagittal study_id: {study_id} series_id: {series_id}')
        selected_indices = range(num_images)

    images = [images[i] for i in selected_indices]

    cropped_images = crop_sagittal_images(images)
    # Second normalisation pass, applied to the cropped images
    normalized_images = [normalize_image(image) for image in cropped_images]
    # IMAGE_SIZE is reversed because cv2.resize takes its size as (width, height)
    resized_images = [cv2.resize(image, IMAGE_SIZE[::-1], interpolation = cv2.INTER_CUBIC) for image in normalized_images]
    for idx, image in enumerate(resized_images):
        cv2.imwrite(f'{destination_path}/{idx:02d}.png', image)
    

Each scan is subsampled into evenly spaced slices (see the selection logic in `write_axial_images` and `write_sagittal_images`). I am guessing this is done to avoid using the uninformative outer images of the sagittal scans, where nothing is visible. It is unclear why such a large step is used, though: surely we would not want to discard some of the good images. **TODO:** explain this further or improve it.

## Loop that converts the images

In [15]:
# Convert every study: one output folder per (study, series description).
for study_id in tqdm(study_ids):
    for description in description_list:
        # Folder-safe version of the description ("Sagittal T2/STIR" -> "Sagittal_T2-STIR")
        description_ = description.replace(' ', '_').replace('/', '-')

        destination = f'Converted_smaller_images/{study_id}/{description_}'
        # The folder is created even when the writer later finds no images for it
        os.makedirs(destination, exist_ok=True)

        if description == 'Axial T2':
            write_axial_images(study_id, destination)
        elif description in ('Sagittal T2/STIR', 'Sagittal T1'):
            # Both sagittal sequences go through the same writer
            write_sagittal_images(study_id, destination, description)
        
        

  2%|▏         | 33/1975 [01:30<1:39:34,  3.08s/it]

Not enough images to select 10 evenly spaced images, Sagittal study_id: 82066307 series_id: 2003554076


 17%|█▋        | 344/1975 [15:18<1:36:44,  3.56s/it]

Not enough images to select 10 evenly spaced images, Sagittal study_id: 767443105 series_id: 3268622861


 47%|████▋     | 928/1975 [41:23<33:58,  1.95s/it]

Not enough images to select 10 evenly spaced images, Sagittal study_id: 2053213309 series_id: 2972736368


 57%|█████▋    | 1126/1975 [50:23<32:29,  2.30s/it]

No images found for Sagittal study_id: 2492114990 series_description: Sagittal T1


 64%|██████▍   | 1272/1975 [56:38<33:07,  2.83s/it]

No images found for Sagittal study_id: 2780132468 series_description: Sagittal T1


 70%|██████▉   | 1378/1975 [1:01:40<20:31,  2.06s/it]

No images found for Sagittal study_id: 3008676218 series_description: Sagittal T2/STIR


 71%|███████▏  | 1408/1975 [1:02:59<22:47,  2.41s/it]

Not enough images to select 10 evenly spaced images, Sagittal study_id: 3074144108 series_id: 845353263


 77%|███████▋  | 1517/1975 [1:07:54<18:50,  2.47s/it]

Not enough images to select 10 evenly spaced images, Axial study_id: 3303545110


100%|██████████| 1975/1975 [1:28:18<00:00,  2.68s/it]
