In [1]:
#
# Code based on https://www.kaggle.com/code/paulbacher/custom-preprocessor-rsna-breast-cancer
#

import os
import time
import numpy as np
import pandas as pd
import cv2
import pydicom
import dicomsdl
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
from joblib import Parallel, delayed
import gc

In [2]:
# Dataset location. The default is the original machine-specific path; set the
# RSNA_BCD_DATA_DIR environment variable to run elsewhere without editing code.
DATA_DIR = os.environ.get('RSNA_BCD_DATA_DIR', '/mnt/striped/kaggle/rsna_bcd/data')
csv_path = os.path.join(DATA_DIR, 'train.csv')          # per-image metadata
train_path = os.path.join(DATA_DIR, 'train_images')     # <patient_id>/<image_id>.dcm
data = pd.read_csv(csv_path)

In [3]:
def get_paths(n: int=None, shuffle: bool=False, df: pd.DataFrame=None, root: str=None):
    """Build DICOM file paths (and id records) for the first ``n`` metadata rows.

    Args:
        n: number of rows to use; ``None`` (default) means every row. This is
            resolved at call time, so it follows the current ``data`` instead of
            the length ``data`` had when the cell was first executed.
        shuffle: if True, rows are shuffled first with a fixed seed
            (``random_state=0``), so the selection is reproducible.
        df: metadata frame with ``patient_id`` and ``image_id`` columns;
            defaults to the module-level ``data``.
        root: directory containing ``<patient_id>/<image_id>.dcm``; defaults to
            the module-level ``train_path``.

    Returns:
        ``(paths, ids_cache)``: a list of path strings and a parallel list of
        ``{'patient_id': str, 'scan_id': str}`` dicts.

    Raises:
        IndexError: if ``n`` is larger than the number of rows.
    """
    if df is None:
        df = data
    if root is None:
        root = train_path
    if shuffle:
        df = df.sample(frac=1, random_state=0)
    if n is None:
        n = len(df)
    if n > len(df):
        raise IndexError(f"n={n} exceeds the {len(df)} available rows")
    # max(n, 0): a non-positive n selects nothing (iloc[:-k] would drop rows instead).
    subset = df.iloc[:max(n, 0)]

    paths, ids_cache = [], []
    # Read the two columns directly: `df.iloc[i]` builds a row Series with the
    # common dtype of ALL columns, which turns int ids into floats ("123.0")
    # whenever every column is numeric.
    for patient_id, image_id in zip(subset['patient_id'], subset['image_id']):
        patient, scan = str(patient_id), str(image_id)
        paths.append(root + '/' + patient + '/' + scan + '.dcm')
        ids_cache.append({'patient_id': patient, 'scan_id': scan})
    return paths, ids_cache


# Example: five shuffled rows (fixed seed) to show the on-disk path layout
paths = get_paths(n=5, shuffle=True)[0]
paths

['/mnt/striped/kaggle/rsna_bcd/data/train_images/58224/132390955.dcm',
 '/mnt/striped/kaggle/rsna_bcd/data/train_images/21809/1307476428.dcm',
 '/mnt/striped/kaggle/rsna_bcd/data/train_images/58351/1689606258.dcm',
 '/mnt/striped/kaggle/rsna_bcd/data/train_images/60826/1465194139.dcm',
 '/mnt/striped/kaggle/rsna_bcd/data/train_images/1250/1329687627.dcm']

In [5]:
def calculate_aspect_ratios(paths: list, preprocessor=None):
    """Return the height / width ratio of every image in ``paths``.

    Args:
        paths: DICOM file paths.
        preprocessor: optional object with ``preprocess_single_image(path)``;
            when given, the ratio is measured on its output image instead of
            the raw scan.

    Returns:
        list of float ratios, in the order of ``paths``.
    """
    ratios = []
    for path in tqdm(paths):
        if preprocessor:
            height, width = preprocessor.preprocess_single_image(path).shape
        else:
            # Rows/Columns hold the stored image size; stop_before_pixels avoids
            # decoding the full pixel data just to learn its shape.
            header = pydicom.dcmread(path, stop_before_pixels=True)
            height, width = header.Rows, header.Columns
        ratios.append(height / width)
    return ratios


# Example: aspect ratios of the sample paths plus summary statistics
ratios = calculate_aspect_ratios(paths)
print("Ratios:", ratios)
for label, stat in (("Min:", np.min), ("Max:", np.max), ("Avg:", np.mean)):
    print(label, stat(ratios))

  0%|          | 0/5 [00:00<?, ?it/s]

Ratios: [1.3, 1.1985370950888192, 1.2307692307692308, 1.279030910609858, 1.0895218718209563]
Min: 1.0895218718209563
Max: 1.3
Avg: 1.219571821657773


In [9]:
def image_resize(image, width = None, height = None, inter = cv2.INTER_NEAREST):
    """Resize ``image`` to a target width OR height, keeping its aspect ratio.

    If ``width`` is given it takes precedence and ``height`` is ignored;
    otherwise the image is scaled to ``height``. With neither, ``image`` is
    returned as-is. The derived side is truncated with ``int()``. ``inter`` is
    the OpenCV interpolation flag passed to ``cv2.resize``.
    """
    src_h, src_w = image.shape[:2]

    if height is None and width is None:
        return image

    if width is not None:
        scale = width / float(src_w)
        new_size = (width, int(src_h * scale))
    else:
        scale = height / float(src_h)
        new_size = (int(src_w * scale), height)

    # cv2.resize expects the destination size as (width, height).
    return cv2.resize(image, new_size, interpolation = inter)

class MammographyPreprocessor():
    """Turn RSNA breast-cancer DICOM scans into cropped 8-bit images.

    ``preprocess_single_image`` applies, in order: VOI windowing, photometric
    normalisation (MONOCHROME1 is inverted; both end with minimum 0), scaling
    to uint8, a horizontal flip so the brighter half of the image lies on
    ``breast_side``, a crop to the largest foreground contour, and (when
    ``size`` is set) an aspect-preserving resize zero-padded to ``size``.
    """

    # Constructor
    def __init__(self, size: tuple=None, breast_side: str='L',
                 csv_path=None, train_path=None, save_root=None):
        """
        Args:
            size: optional ``(width, height)`` of the output images; ``None``
                keeps the cropped size.
            breast_side: ``'L'`` or ``'R'``, the side the breast is flipped to.
            csv_path: metadata CSV with ``patient_id``/``image_id`` columns.
                Only ``get_paths`` needs it; when omitted ``self.df`` is None.
            train_path: directory containing ``<patient_id>/<image_id>.dcm``.
                Only ``get_paths`` needs it.
            save_root: root directory for saved images; defaults to the current
                working directory at construction time.
        """
        assert breast_side in ['L', 'R'], "breast_side should be 'L' or 'R'"
        self.size = size
        self.breast_side = breast_side
        self.csv_path = csv_path
        self.train_path = train_path
        # Previously an omitted csv_path crashed here (self.csv_path never set);
        # the CSV is only needed to build paths, so preprocessing works without it.
        self.df = pd.read_csv(csv_path) if csv_path else None
        # os.path.join(None, ...) would fail at save time, so default to the cwd.
        self.save_root = save_root if save_root is not None else os.getcwd()

    # Get the paths from the preprocessor (V2)
    def get_paths(self, n: int=None, shuffle: bool=False, return_cache: bool=False):
        """Build DICOM paths for the first ``n`` rows of ``self.df`` (all when None).

        ``shuffle`` shuffles rows first with ``random_state=0``. Returns
        ``paths`` or, with ``return_cache``, ``(paths, ids_cache)`` where each
        cache entry is ``{'patient_id': str, 'scan_id': str}``.

        Raises:
            ValueError: if the preprocessor was built without csv_path/train_path.
            IndexError: if ``n`` exceeds the number of rows.
        """
        if self.df is None or self.train_path is None:
            raise ValueError("get_paths requires csv_path and train_path "
                             "to be passed to the constructor")
        df = self.df.sample(frac=1, random_state=0) if shuffle else self.df
        if n is None:
            n = len(df)
        if n > len(df):
            raise IndexError(f"n={n} exceeds the {len(df)} available rows")
        subset = df.iloc[:max(n, 0)]
        paths = []
        ids_cache = []
        # Column-wise access keeps each column's own dtype; `df.iloc[i]` would
        # upcast ints to floats ("123.0") when every column is numeric.
        for patient_id, image_id in zip(subset['patient_id'], subset['image_id']):
            patient, scan = str(patient_id), str(image_id)
            paths.append(self.train_path + '/' + patient + '/' + scan + '.dcm')
            ids_cache.append({'patient_id': patient, 'scan_id': scan})
        if return_cache:
            return paths, ids_cache
        return paths

    # Read from a path and convert to image array
    def read_image(self, path: str):
        """Return the raw pixel array of a DICOM file (no preprocessing)."""
        scan = pydicom.dcmread(path)
        img = scan.pixel_array
        return img

    # Apply the preprocessing methods on one image
    def preprocess_single_image(self, path: str, save: bool=False,
                                save_dir: str=None, png: bool=True):
        """Run the full pipeline on one DICOM file.

        Returns the uint8 image, or None when ``save`` is True (the image is
        written to disk instead so parallel workers do not ship it back).
        """
        scan = dicomsdl.open(path)
        img = scan.pixelData()
        img = self._windowing(img, scan)
        img = self._fix_photometric_interpretation(img, scan)
        img = self._normalize_to_255(img)
        img = self._flip_breast_side(img)
        img = self._crop(img)
        if self.size:
            img = self._resize(img)
        if save:
            self._save_image(img, path, png, save_dir)
            return # do not return the images to avoid memory leak
        return img

    # Preprocess all the images from the paths
    def preprocess_all(self, paths: list, save: bool=True,
                       save_dir: str='train_images', png: bool=True,
                       parallel: bool=False, n_jobs: int=4):
        """Preprocess every path, optionally with joblib using ``n_jobs`` workers.

        Results are not returned; use ``save=True`` to keep the outputs.
        """
        clock = time.time()
        if parallel:
            Parallel(n_jobs=n_jobs)(
                delayed(self.preprocess_single_image)(path, save, save_dir, png)
                for path in tqdm(paths, total=len(paths)))
            print("Parallel preprocessing done!")
        else:
            for path in tqdm(paths):
                self.preprocess_single_image(path, save, save_dir, png)
            print("Sequential preprocessing done!")
        print("Time =", np.around(time.time() - clock, 3), 'sec')

    # Display the images from the dicom paths with optional preprocessing
    def display(self, paths: list, rows: int, cols: int,
                preprocess: bool=False, cmap='bone', cbar: bool=False,
                save_fig: bool=False, save_name: str='myplot.png'):
        """Plot the first ``rows * cols`` paths in a grid, raw or preprocessed."""
        assert len(paths) >= (rows * cols), \
        f"Not enough paths for the display. " \
        f"Please give at least {rows * cols} paths."
        plt.figure(figsize=(18, 26 * rows / cols))
        for i in trange(rows * cols):
            path = paths[i]
            if preprocess:
                img = self.preprocess_single_image(path, save=False)
            else:
                img = self.read_image(path)
            plt.subplot(rows, cols, i+1)
            plt.imshow(img, cmap=cmap)
            if cbar:
                plt.colorbar()
            plt.grid(False)
            plt.title(path.split('/')[-1][:-4])
        plt.suptitle("Preprocessed images" if preprocess \
                     else "Raw images", fontsize=25)
        if save_fig:
            plt.savefig(save_name, facecolor='white')
        plt.show()

    @staticmethod
    def _first_value(value):
        """Return the first item of a multi-valued header field, else the value itself."""
        if isinstance(value, (list, tuple, np.ndarray)):
            return value[0] if len(value) else None
        return value

    # Adjust the contrast of an image
    def _windowing(self, img, scan):
        """Apply the header's VOI window (LINEAR or SIGMOID) mapping to [0, 2**BitsStored - 1].

        If the window tags are absent (None), the image is returned unchanged.
        """
        center = self._first_value(scan.WindowCenter)
        width = self._first_value(scan.WindowWidth)
        bits_stored = scan.BitsStored
        function = scan.VOILUTFunction
        if center is None or width is None or bits_stored is None:
            return img
        center, width = float(center), float(width)
        # Integer input would silently truncate (or overflow) the windowed values.
        if not np.issubdtype(img.dtype, np.floating):
            img = img.astype(np.float32)
        y_range = float(2**bits_stored - 1)
        if function == 'SIGMOID':
            img = y_range / (1 + np.exp(-4 * (img - center) / width))
        else: # LINEAR
            center -= 0.5
            width -= 1
            below = img <= (center - width / 2)
            above = img > (center + width / 2)
            between = np.logical_and(~below, ~above)
            img[below] = 0
            img[above] = y_range
            img[between] = ((img[between] - center) / width + 0.5) * y_range
        return img

    # Interpret pixels in a consistent way
    def _fix_photometric_interpretation(self, img, scan):
        """Invert MONOCHROME1, shift MONOCHROME2 to min 0; raise ValueError otherwise."""
        if scan.PhotometricInterpretation == 'MONOCHROME1':
            return img.max() - img
        elif scan.PhotometricInterpretation == 'MONOCHROME2':
            return img - img.min()
        else:
            raise ValueError("Invalid Photometric Interpretation: {}"
                               .format(scan.PhotometricInterpretation))

    # Cast into 8-bits for saving
    def _normalize_to_255(self, img):
        """Scale so the maximum becomes 255 and cast to uint8 (without mutating ``img``)."""
        peak = img.max()
        scaled = img / peak if peak != 0 else img
        return (scaled * 255).astype(np.uint8)

    # Flip the breast horizontally on the chosen side
    def _flip_breast_side(self, img):
        """Mirror the image left-right unless it already matches ``self.breast_side``."""
        if self._determine_breast_side(img) == self.breast_side:
            return img
        return np.fliplr(img)

    # Determine the current breast side
    def _determine_breast_side(self, img):
        """Return 'L' if the left half of the columns has the larger pixel sum, else 'R'."""
        col_sums_split = np.array_split(np.sum(img, axis=0), 2)
        left_col_sum = np.sum(col_sums_split[0])
        right_col_sum = np.sum(col_sums_split[1])
        return 'L' if left_col_sum > right_col_sum else 'R'

    # Crop the useless background of the image
    def _crop(self, img):
        """Zero everything outside the largest contour and crop to its bbox (+~1% margin)."""
        bin_img = self._binarize(img, threshold=5)
        if not bin_img.any():
            # No foreground pixel: there is no contour (max() would raise), keep as-is.
            return img
        contour = self._extract_contour(bin_img)
        img = self._erase_background(img, contour)
        xs, ys = contour[:, :, 0], contour[:, :, 1]
        x_min, x_max = int(xs.min()), int(xs.max())
        y_min, y_max = int(ys.min()), int(ys.max())
        # Slice ends are exclusive; keep at least max+1 so small coords
        # (where int(1.01 * max) == max) do not drop the last row/column.
        x1, x2 = int(0.99 * x_min), max(int(1.01 * x_max), x_max + 1)
        y1, y2 = int(0.99 * y_min), max(int(1.01 * y_max), y_max + 1)
        return img[y1:y2, x1:x2]

    # Binarize the image at the threshold
    def _binarize(self, img, threshold):
        """Return a uint8 mask: 1 where ``img > threshold``, else 0."""
        return (img > threshold).astype(np.uint8)

    # Get contour points of the breast
    def _extract_contour(self, bin_img):
        """Return the external contour with the largest area (needs >= 1 contour)."""
        contours, _ = cv2.findContours(
            bin_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        contour = max(contours, key=cv2.contourArea)
        return contour

    # Set the background pixels of the image to zero
    def _erase_background(self, img, contour):
        """Keep only pixels inside the filled ``contour``; zero the rest."""
        mask = np.zeros(img.shape, np.uint8)
        cv2.drawContours(mask, [contour], -1, 255, cv2.FILLED)
        output = cv2.bitwise_and(img, mask)
        return output

    # Resize the image to the preprocessor size
    def _resize(self, img):
        """Fit ``img`` inside ``self.size`` = (width, height), zero-padding the short side."""
        target_w, target_h = self.size[0], self.size[1]
        if (img.shape[1] / img.shape[0]) > (target_w / target_h):
            # Relatively wider than the target: match width, pad top/bottom.
            resized = image_resize(img, width=target_w)
            pad = target_h - resized.shape[0]
            return cv2.copyMakeBorder(resized, pad // 2, pad - pad // 2, 0, 0, cv2.BORDER_CONSTANT, value=0)
        # Relatively taller (or equal): match height, pad left/right.
        resized = image_resize(img, height=target_h)
        pad = target_w - resized.shape[1]
        return cv2.copyMakeBorder(resized, 0, 0, pad // 2, pad - pad // 2, cv2.BORDER_CONSTANT, value=0)

    # Get the save path of a given dicom file
    def _get_save_path(self, path, png, save_dir):
        """Map ``.../<patient>/<scan>.dcm`` to ``save_root[/save_dir]/<patient>/<scan>.png|jpeg``."""
        parts = path.split('/')
        patient = parts[-2]
        stem = os.path.splitext(parts[-1])[0]
        filename = stem + ('.png' if png else '.jpeg')
        if save_dir:
            return os.path.join(self.save_root, save_dir, patient, filename)
        return os.path.join(self.save_root, patient, filename)

    # Save the preprocessed image
    def _save_image(self, img, path, png, save_dir):
        """Write ``img`` to its save path, creating the patient folder if needed."""
        save_path = self._get_save_path(path, png, save_dir)
        patient_folder = os.path.split(save_path)[0]

        os.makedirs(patient_folder, exist_ok=True)
        cv2.imwrite(save_path, img)

In [92]:
# Full preprocessing run. Everything is built explicitly from the config cell:
# the earlier `paths` holds only the 5-path example, so relying on it (or on a
# value left over from a deleted cell) is not reproducible on a fresh kernel.
mp = MammographyPreprocessor(size=(384, 768),  # (width, height) of the output
                             csv_path=csv_path, train_path=train_path,
                             save_root=os.getcwd())
all_paths = mp.get_paths()  # one path per row of train.csv

mp.preprocess_all(all_paths, save=True, save_dir='../preprocessed_mp_aspectratio_768_384', parallel=True, n_jobs=32)

  0%|          | 0/54706 [00:00<?, ?it/s]

Parallel preprocessing done!
Time = 2104.782 sec
