In [1]:
import joblib
import pandas as pd
import os
import SimpleITK as sitk
import cc3d
import numpy as np
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import matplotlib.patches as patches

# Utility functions

In [5]:
def get_min_max(pos, shape, box_size = 128):
    """
    Return a crop window of ``box_size`` voxels centred on ``pos``, shifted to stay inside the volume.

    For each of the first three axes the window is ``[pos - box_size // 2, pos + box_size // 2)``.
    If that would start before 0, the window is anchored at 0. If it would run past the end of
    the axis, it is anchored at ``shape[axis]``. When an axis is shorter than ``box_size``, the
    window covers the whole axis. Note: in the centred case an odd ``box_size`` yields
    ``box_size - 1`` voxels.

    Parameters
    ----------
    pos : sequence of int
        Centre voxel, one index per axis.
    shape : sequence of int
        Shape of the volume being cropped.
    box_size : int
        Desired window length along every axis.

    Returns
    -------
    list of (min, max) tuples
        One tuple per axis; ``max`` is exclusive (slice-style).
    """
    half_box_size = int(box_size / 2)
    bounds = []
    for axis in range(3):
        centre, length = pos[axis], shape[axis]
        if centre < half_box_size:
            # Too close to the start: anchor the window at 0.
            min_val = 0
            max_val = min(box_size, length)
        elif centre + half_box_size > length:
            # Too close to the end: anchor the window at the end of *this* axis.
            max_val = length
            min_val = max(max_val - box_size, 0)
        else:
            min_val = centre - half_box_size
            max_val = centre + half_box_size
        bounds.append((min_val, max_val))
    return bounds

def create_folder(path):
    """
    Ensure the directory ``path`` exists, creating any missing parent directories.

    Calling it on an existing directory is a no-op. ``exist_ok=True`` avoids the
    check-then-create race of testing ``os.path.exists`` first.
    """
    os.makedirs(path, exist_ok=True)

In [15]:
# Output root for the cropped volumes, plus a sub-folder for preview images.
new_preprocess_img_folder = './inputs/cc3d'
image_out_folder = new_preprocess_img_folder + '/img'


In [16]:
# Create the output root first, then its image sub-folder.
for folder in (new_preprocess_img_folder, image_out_folder):
    create_folder(folder)

In [17]:
def show_mid_slice(out_folder, img_numpy, mask, is_mask=False, idx='1', title='img'):
    """
    Plot the central slice of a 3D volume along each of its three axes and save the figure.

    The slices cut through index ``int((n - 1) / 2)`` of each axis; ``mask`` is sliced
    identically and forwarded to ``show_slices`` together with ``is_mask``. The figure
    is saved as ``<out_folder>/<idx>.png`` and then closed.
    """
    assert img_numpy.ndim == 3
    mid0, mid1, mid2 = (int((n - 1) / 2) for n in img_numpy.shape)

    image_views = [img_numpy[mid0, :, :], img_numpy[:, mid1, :], img_numpy[:, :, mid2]]
    mask_views = [mask[mid0, :, :], mask[:, mid1, :], mask[:, :, mid2]]
    show_slices(image_views, mask_views, is_mask)

    plt.suptitle(title)
    plt.savefig(os.path.join(out_folder, idx) + '.png')
    plt.close()

def show_slices(slices, mask_slices, is_mask):
    """
    Display a row of 2D image slices, optionally with a red label overlay.

    Each slice is transposed before display and drawn in grey with a fixed
    [-1024, 512] window. When ``is_mask`` is true, the matching mask slice is
    overlaid with the 'Reds' colormap at alpha 0.7. Pixels with labels 0, 1 and 3
    are masked out (drawn with the black "bad" colour), so only the other labels
    show in red.

    Parameters
    ----------
    slices : list of 2D arrays
        Image slices, one per subplot.
    mask_slices : list of 2D arrays
        Label slices aligned with ``slices``. They are not modified.
    is_mask : bool
        Whether to draw the label overlay.
    """
    # squeeze=False keeps a 2D axes array even for a single slice.
    _, axes = plt.subplots(1, len(slices), squeeze=False)
    for ax, image_slice, mask_slice in zip(axes[0], slices, mask_slices):
        ax.imshow(image_slice.T, cmap="gray", vmin=-1024, vmax=512)
        if not is_mask:
            continue
        # Copy: `.T` is a view, and zeroing labels in place would corrupt the caller's segmentation.
        mask = mask_slice.T.copy()
        mask[(mask == 1) | (mask == 3)] = 0
        masked_array = np.ma.masked_where(mask == 0, mask)

        cmap = plt.get_cmap('Reds').copy()
        cmap.set_bad(color='black')
        ax.imshow(masked_array, cmap=cmap, alpha=0.7, vmin=0)


In [18]:
# Row indices into the per-row lists (img_numpy, mask, crop, slice_pos) used by the preview plots.
TUMOR_BASE, KIDNEY_BASE = 0, 1
def _kidney_slice_pos(kidney_shape, crop, slice_pos):
    """
    Choose the slice indices for the kidney row.

    Prefer the tumor row's slice position translated into kidney-crop coordinates,
    so both rows cut through the same location of the full volume. Fall back to
    ``slice_pos[KIDNEY_BASE]`` when either crop is missing or the translated
    position lies outside the kidney volume.
    """
    tumor_crop, kidney_crop = crop[TUMOR_BASE], crop[KIDNEY_BASE]
    if tumor_crop is not None and kidney_crop is not None:
        mapped = [slice_pos[TUMOR_BASE][axis] + tumor_crop[axis][0] - kidney_crop[axis][0]
                  for axis in range(3)]
        # Validate against the real crop shape; crops are not necessarily 128 voxels wide.
        if all(0 <= value < size for value, size in zip(mapped, kidney_shape)):
            return mapped
    return slice_pos[KIDNEY_BASE]


def show_slice_v2(img_numpy, mask, image_out_folder, crop, is_mask=False,
                  slice_pos=None, idx='1', title='img'):
    """
    Save a 3x3 preview figure of orthogonal slices to ``<image_out_folder>/<idx>.png``.

    - Row 0 (TUMOR_BASE) shows ``img_numpy[0]``, with the label overlay when ``is_mask`` is true.
    - Row 1 (KIDNEY_BASE) shows ``img_numpy[1]`` without overlay. Its slices follow the tumor
      row's position when that position falls inside the kidney crop.
    - Row 2 shows ``img_numpy[2]`` with the crop boxes outlined.

    A None volume in rows 0/1 is drawn as a black 128x128 placeholder.

    Parameters
    ----------
    img_numpy : list
        ``[tumor_volume or None, kidney_volume or None, full_volume]``.
    mask : list
        Label volumes aligned with ``img_numpy``.
    image_out_folder : str
        Destination folder for the PNG.
    crop : list
        ``[tumor_crop, kidney_crop]``. Each is None or ``[[start, stop]] * 3``, giving the
        crop's position in the full volume.
    is_mask : bool
        Overlay the labels on the tumor row.
    slice_pos : list, optional
        ``[tumor_pos, kidney_pos]``, slice indices inside each crop. Defaults to
        (64, 64, 64) for both rows.
    idx : str
        PNG file stem.
    title : str
        Figure title.
    """
    if slice_pos is None:
        # One position per row: a single flat [64, 64, 64] cannot be indexed per row.
        slice_pos = [np.array([64, 64, 64]), np.array([64, 64, 64])]

    fig, axes = plt.subplots(3, 3, figsize=(8, 8))
    for row_idx in (TUMOR_BASE, KIDNEY_BASE):
        volume = img_numpy[row_idx]
        if volume is None:
            for ax in axes[row_idx]:
                ax.imshow(np.zeros((128, 128)), cmap='gray')
            continue

        if row_idx == TUMOR_BASE:
            p0, p1, p2 = slice_pos[TUMOR_BASE][:3]
            overlay = is_mask
        else:
            p0, p1, p2 = _kidney_slice_pos(volume.shape, crop, slice_pos)[:3]
            overlay = False  # only the tumor row gets the label overlay

        labels = mask[row_idx]
        _show_slices_v2(axes,
                        [volume[p0, :, :], volume[:, p1, :], volume[:, :, p2]],
                        [labels[p0, :, :], labels[:, p1, :], labels[:, :, p2]],
                        overlay, row_idx)

    _show_row3(axes, img_numpy[2], mask[2], crop, slice_pos)
    fig.suptitle(title)
    fig.savefig(os.path.join(image_out_folder, idx) + '.png')
    plt.close(fig)
# For a slice taken across array axis `k`, imshow puts the remaining two axes on
# screen as (horizontal, vertical) = (higher-numbered axis, lower-numbered axis).
_ROW3_DISPLAY_AXES = {0: (2, 1), 1: (2, 0), 2: (1, 0)}


def _crop_rectangle(box, slice_axis, **style):
    """Outline of a crop box (``[[start, stop]] * 3``) on the slice taken across ``slice_axis``."""
    x_axis, y_axis = _ROW3_DISPLAY_AXES[slice_axis]
    (x0, x1), (y0, y1) = box[x_axis], box[y_axis]
    return patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, facecolor='none', **style)


def _show_row3(axes, raw_image, raw_mask, crop, slice_pos):
    """
    Fill the third figure row with orthogonal slices of the uncropped volume.

    Slices go through the tumor slice position if a tumor crop exists, else the
    kidney position, else the volume centre. The tumor crop is outlined in red
    and the kidney crop in blue, each with its own size. Along each slice axis,
    the kidney outline is:

    - dash-dotted when the tumor crop starts after the kidney crop ends;
    - dashed when the tumor crop ends before the kidney crop starts;
    - solid otherwise.

    Parameters
    ----------
    axes : 3x3 grid of matplotlib Axes
        Row 2 is drawn into.
    raw_image : 3D array
        The full volume.
    raw_mask : any
        Unused; accepted for call compatibility.
    crop : list
        ``[tumor_crop, kidney_crop]``. Each is None or ``[[start, stop]] * 3`` in
        ``raw_image`` coordinates.
    slice_pos : list
        ``[tumor_pos, kidney_pos]``, slice indices relative to each crop.
    """
    tumor_crop, kidney_crop = crop[TUMOR_BASE], crop[KIDNEY_BASE]
    if tumor_crop is not None:
        centre = [slice_pos[TUMOR_BASE][axis] + tumor_crop[axis][0] for axis in range(3)]
    elif kidney_crop is not None:
        centre = [slice_pos[KIDNEY_BASE][axis] + kidney_crop[axis][0] for axis in range(3)]
    else:
        centre = [size // 2 for size in raw_image.shape[:3]]

    slices = [raw_image[centre[0], :, :], raw_image[:, centre[1], :], raw_image[:, :, centre[2]]]
    for axis, image_slice in enumerate(slices):
        ax = axes[2][axis]
        ax.imshow(image_slice, cmap='gray', vmin=-1024, vmax=512)

        if tumor_crop is None or kidney_crop is None:
            kidney_linestyle = '-'
        elif tumor_crop[axis][0] > kidney_crop[axis][1]:
            kidney_linestyle = '-.'
        elif tumor_crop[axis][1] < kidney_crop[axis][0]:
            kidney_linestyle = '--'
        else:
            kidney_linestyle = '-'

        if tumor_crop is not None:
            ax.add_patch(_crop_rectangle(tumor_crop, axis, edgecolor='r'))
        if kidney_crop is not None:
            ax.add_patch(_crop_rectangle(kidney_crop, axis, edgecolor='b', linestyle=kidney_linestyle))
def _show_slices_v2(axes,slices, mask_slices, is_mask,row_idx):
   min_ = min(np.min(slices[0]),np.min(slices[1]),np.min(slices[2]))
   max_ = max(np.max(slices[0]),np.max(slices[1]),np.max(slices[2]))
   for i, slice in enumerate(slices):
      mask_slice = mask_slices[i]
      if is_mask == True:
         mask = mask_slice
         mask[np.where(mask == 1)] = 0
         mask[np.where(mask == 3)] = 0

         masked_array = np.ma.masked_where(mask == 0, mask)

         slice = (slice - min_) / (max_ - min_)
         slice = np.stack((slice,)*3,-1)
         masked_array = np.stack((masked_array,) * 3,-1)
         red_image = np.zeros_like(slice)
         red_image[:,:,0] = 1

         axes[row_idx][i].imshow(np.where(masked_array, red_image, slice), vmin=0, vmax=1)
      else:
         axes[row_idx][i].imshow(slice, cmap="gray", vmin=-1024, vmax=512)

In [19]:
def resample_image(file_path, out_spacing=(1.0, 1.0, 1.0), is_label=False, default_value=None):
    """
    Read a volume from disk and resample it to ``out_spacing``, keeping its physical extent.

    The output keeps the input's origin and direction. The grid size per axis is
    ``round(size * spacing / new_spacing)``.

    Parameters
    ----------
    file_path : str
        Path readable by ``sitk.ReadImage``.
    out_spacing : sequence of float
        Target spacing, one entry per image dimension (same units as the input spacing).
    is_label : bool
        True for segmentations: use nearest-neighbour interpolation so label values
        are preserved. Otherwise B-spline interpolation is used.
    default_value : float, optional
        Value for output voxels that map outside the input. When None, use 0 for
        labels and the input's minimum intensity for images.

    Returns
    -------
    numpy.ndarray
        The resampled volume, as produced by ``sitk.GetArrayFromImage``.
    """
    itk_image = sitk.ReadImage(file_path)
    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()

    # Same physical size on the new grid.
    out_size = [int(round(size * spacing / new_spacing))
                for size, spacing, new_spacing in zip(original_size, original_spacing, out_spacing)]

    if default_value is None:
        # Padding must be a real intensity. GetPixelIDValue() is the pixel *type* enum,
        # so it is not a meaningful fill value.
        default_value = 0 if is_label else float(sitk.GetArrayViewFromImage(itk_image).min())

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing([float(s) for s in out_spacing])
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(float(default_value))
    resample.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline)

    return sitk.GetArrayFromImage(resample.Execute(itk_image))

In [23]:
def preprocess_func(img_path, mask_path, uid,
                    output_img_folder, output_3d_folder,
                    is_vghtc=True):
    """
    Crop an image and its segmentation around the largest tumor and save the results.

    Steps:

    1. Resample both inputs to [1.0, 1.0, 1.0] spacing.
    2. Keep label-2 (tumor) voxels and drop components under 100 voxels.
    3. Take the largest remaining 26-connected component.
    4. Crop image and mask to its bounding box plus a 5-voxel margin.
    5. Write ``img_<uid>.nii.gz``, ``seg_<uid>.nii.gz`` and a ``<uid>.png`` preview
       into ``output_img_folder``.

    Parameters
    ----------
    img_path, mask_path : str
        Input image and segmentation volumes.
    uid : str
        Case identifier used in messages and output file names.
    output_img_folder : str
        Folder receiving the cropped volumes and the preview PNG.
    output_3d_folder : str
        Currently unused; kept for call compatibility.
    is_vghtc : bool
        Currently unused; kept for call compatibility.

    Returns
    -------
    tuple or None
        ``(uid, voxel_count)`` of the largest tumor component, or None when an
        input file is missing or no tumor component survives the size filter.
    """
    if not (os.path.exists(img_path) and os.path.exists(mask_path)):
        print(f"{uid} does not exist.")
        return None

    img = resample_image(img_path, out_spacing=[1.0, 1.0, 1.0], is_label=False)
    seg_img = resample_image(mask_path, out_spacing=[1.0, 1.0, 1.0], is_label=True)

    # Tumor voxels, with components smaller than 100 voxels removed as noise.
    tumor = cc3d.dust(seg_img == 2, threshold=100, connectivity=26, in_place=False)
    labels_out, n_components = cc3d.connected_components(tumor, connectivity=26, return_N=True)
    if n_components == 0:
        print(f"{uid} does not have tumor.")
        return None

    # statistics() includes the background as label 0; skip it when picking the largest tumor.
    stats = cc3d.statistics(labels_out)
    largest = int(np.argmax(stats['voxel_counts'][1:])) + 1
    voxel_count = stats['voxel_counts'][largest]

    # Bounding box (one slice per array axis) grown by a 5-voxel margin and clipped to the volume.
    margin = 5
    crop_value = [[max(int(box.start) - margin, 0), min(int(box.stop) + margin, size)]
                  for box, size in zip(stats['bounding_boxes'][largest], img.shape)]
    crop_slices = tuple(slice(lo, hi) for lo, hi in crop_value)
    process_img = img[crop_slices]
    process_seg_img = seg_img[crop_slices]
    assert process_img.shape == process_seg_img.shape

    # Preview slices go through the centroid of the selected tumor, in crop coordinates.
    # Re-labelling the crop could instead pick a smaller component that also falls in the box.
    centroid = np.rint(stats['centroids'][largest]).astype(int)
    slice_pos_tumor_base = centroid - np.array([lo for lo, _ in crop_value])

    sitk.WriteImage(sitk.GetImageFromArray(process_img),
                    os.path.join(output_img_folder, 'img_' + uid) + '.nii.gz')
    sitk.WriteImage(sitk.GetImageFromArray(process_seg_img),
                    os.path.join(output_img_folder, 'seg_' + uid) + '.nii.gz')

    show_slice_v2(img_numpy=[process_img, process_img, img],
                  mask=[process_seg_img, process_seg_img, seg_img],
                  crop=[crop_value, crop_value],
                  image_out_folder=output_img_folder,
                  is_mask=True,
                  idx=uid,
                  slice_pos=[slice_pos_tumor_base, slice_pos_tumor_base],
                  title=f"UID : {uid} \n" + f"bounding box: {crop_value}\n" + f"voxel:{voxel_count}")

    return uid, voxel_count

In [None]:
# Case ids to preprocess. Each case is read from inputs/img_<uid>.nii.gz and
# inputs/seg_<uid>.nii.gz; add ids here to batch-process more cases.
case_uids = ['case_0']

# One entry per case: (uid, tumor voxel count), or None when the case was skipped.
res = [
    preprocess_func(
        os.path.join('inputs', f'img_{uid}.nii.gz'),
        os.path.join('inputs', f'seg_{uid}.nii.gz'),
        uid,
        new_preprocess_img_folder,
        image_out_folder,
        True,
    )
    for uid in case_uids
]

In [25]:
# preprocess_func returns None for skipped cases; keep only the (uid, voxel) results.
successful = [item for item in res if item]
data_uid = [item[0] for item in successful]
data_voxel = [item[1] for item in successful]
data_dict = dict(uid=data_uid, voxel=data_voxel)

# Record the tumor voxel count of each case.
csv_path = os.path.join(new_preprocess_img_folder, 'data_voxel.csv')
pd.DataFrame(data_dict).to_csv(csv_path)