In [7]:
# %matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os, sys, re, math, time, shutil
import nibabel as nib
from matplotlib import colors
from niwidgets import NiftiWidget
from scipy import ndimage
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from tqdm.notebook import tqdm
from IPython.display import clear_output
import SimpleITK as sitk

In [8]:
"""
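Data loaders:
- input: path to the directory containing the data files (searched recursively)
- output: dictionary (one per data type) mapping
    - key: integer image number parsed from the file name
    - val: full path to the corresponding file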
"""

# load data
def load_data_to_dictionaries(path):
    """
    Map the integer index found in each file name under `path` to that file's path.

    `path` is walked recursively. The index is the first run of digits in the
    file name (e.g. 'volume-12.nii' -> 12). Files whose names contain no digits
    (e.g. '.DS_Store') are skipped instead of crashing the loader.

    Parameters
    ----------
    path : str
        Directory to scan.

    Returns
    -------
    dict[int, str]
        index -> os.path.join(dirname, filename)

    Raises
    ------
    ValueError
        If two different files resolve to the same index; silently keeping one
        of them would pair volumes and masks incorrectly.
    """
    data = {}
    for dirname, _, filenames in os.walk(path):
        for filename in filenames:
            match = re.search(r'\d+', filename)
            if match is None:
                continue
            idx = int(match.group())
            filepath = os.path.join(dirname, filename)
            if idx in data:
                raise ValueError(f'duplicate index {idx}: {data[idx]!r} and {filepath!r}')
            data[idx] = filepath
    return data

# load data (LITS17 specific)
def load_filepaths_to_dictionaries(path):
    """
    Collect LiTS17 file paths under `path` (walked recursively), keyed by case number.

    File names are classified by prefix:
      'vol...'  -> volumes,  number = 2nd '-'-separated field, up to its first '.'
      'seg...'  -> segments, same field
      'test...' -> volumes,  number = 3rd '-'-separated field (test cases have no masks)
    Any other file is ignored.

    Returns
    -------
    (volumes, segments) : tuple[dict[int, str], dict[int, str]]
        Case number -> os.path.join(dirname, filename).

    When at least one segmentation was found, asserts that volumes and segments
    cover exactly the same case numbers.
    """
    volumes, segments = {}, {}
    # (file-name prefix, index of the '-'-separated field holding the number, target dict)
    rules = (('vol', 1, volumes), ('seg', 1, segments), ('test', 2, volumes))
    for dirname, _, filenames in os.walk(path):
        for filename in filenames:
            for prefix, field, target in rules:
                if filename.startswith(prefix):
                    case_id = int(filename.split('-')[field].split('.')[0])
                    target[case_id] = os.path.join(dirname, filename)
                    break
    if segments:
        assert len(volumes) == len(segments)
        assert all(key in segments for key in volumes)
    return volumes, segments

In [21]:
# Raw LiTS17 training data: one directory of CT volumes, one of segmentation masks.
RAW_TRAIN_ROOT = '/scratch/ec2684/cv/data/lits17/raw/train/'
vol_dict = load_data_to_dictionaries(os.path.join(RAW_TRAIN_ROOT, 'vol/'))
seg_dict = load_data_to_dictionaries(os.path.join(RAW_TRAIN_ROOT, 'seg/'))

# Fail fast: preprocessing looks up seg_dict[i] for every volume index i, so a
# missing mask would otherwise surface as a KeyError partway through the run.
missing_seg = sorted(set(vol_dict) - set(seg_dict))
if missing_seg:
    raise ValueError(f'volumes without a segmentation: {missing_seg}')

In [22]:
# Output directories for the preprocessed training volumes and segmentation masks.
PROC_TRAIN_VOL, PROC_TRAIN_SEG = (
    '/scratch/ec2684/cv/data/lits17/processed/train/nii/vol/',
    '/scratch/ec2684/cv/data/lits17/processed/train/nii/seg/',
)

In [23]:
# ---------------------------- Preprocessing parameters ----------------------------

# Number of consecutive slices per network input; cases with fewer slices
# after cropping are skipped.
size = 48

# Slices kept above and below the liver's z-extent when cropping.
expand_slice = 20

# Target slice spacing along z. Only the z axis is resampled to this value;
# the in-plane x/y spacing is left unchanged.
slice_thickness = 1

# Intensity clipping window for the CT data (presumably Hounsfield units).
lower, upper = -200, 200

In [25]:
# Create the output directories up front so sitk.WriteImage cannot fail mid-run.
os.makedirs(PROC_TRAIN_VOL, exist_ok=True)
os.makedirs(PROC_TRAIN_SEG, exist_ok=True)

for file in tqdm(sorted(vol_dict.keys())):
    # Load CT volume and segmentation map. GetArrayFromImage returns arrays
    # ordered (z, y, x), the reverse of GetSpacing()'s (x, y, z).
    vol = sitk.ReadImage(vol_dict[file], sitk.sitkInt16)
    vol_array = sitk.GetArrayFromImage(vol)

    seg = sitk.ReadImage(seg_dict[file], sitk.sitkUInt8)
    seg_array = sitk.GetArrayFromImage(seg)

    # Merge GT liver and tumor labels into a single foreground label (1).
    seg_array[seg_array > 0] = 1

    # Clip intensities to the [lower, upper] window.
    vol_array = np.clip(vol_array, lower, upper)

    # Resample along z only, so the slice spacing becomes slice_thickness.
    # The in-plane (y, x) grid is left untouched.
    z_factor = vol.GetSpacing()[-1] / slice_thickness
    vol_array = ndimage.zoom(vol_array, (z_factor, 1, 1), order=3)
    seg_array = ndimage.zoom(seg_array, (z_factor, 1, 1), order=0)
    # Cubic interpolation can overshoot the window; clip again to keep the range.
    vol_array = np.clip(vol_array, lower, upper)

    # z-range of the liver. A case without any foreground cannot be cropped.
    liver_slices = np.flatnonzero(np.any(seg_array, axis=(1, 2)))
    if liver_slices.size == 0:
        print(file, 'has no foreground label, skipped')
        continue

    # Expand the liver range by expand_slice slices in both directions, within bounds.
    start_slice = max(0, liver_slices[0] - expand_slice)
    end_slice = min(seg_array.shape[0] - 1, liver_slices[-1] + expand_slice)
    n_slices = end_slice - start_slice + 1

    # Skip cases that still have fewer than `size` slices (only a few in LiTS17).
    if n_slices < size:
        print('!!!!!!!!!!!!!!!!')
        print(file, 'have too little slice', n_slices)
        print('!!!!!!!!!!!!!!!!')
        continue

    vol_array = vol_array[start_slice:end_slice + 1, :, :]
    seg_array = seg_array[start_slice:end_slice + 1, :, :]

    # Pad y and x independently up to the next multiple of 16, split as evenly
    # as possible (the extra pixel, if any, goes before). CT padding uses the
    # bottom of the clip window; mask padding is background.
    pad_width = [(0, 0)]
    for extent in seg_array.shape[1:]:
        total = math.ceil(extent / 16) * 16 - extent
        pad_width.append((math.ceil(total / 2), total // 2))
    vol_array = np.pad(vol_array, pad_width, mode='constant', constant_values=lower)
    seg_array = np.pad(seg_array, pad_width, mode='constant', constant_values=0)

    # Save as NIfTI. x/y were not resampled, so each image keeps its own in-plane
    # spacing; z spacing is slice_thickness after the zoom above.
    # NOTE(review): the origin is copied unchanged although the z-crop and x/y
    # padding shift the first voxel's physical position. Confirm whether
    # downstream code relies on the origin.
    new_ct = sitk.GetImageFromArray(vol_array)
    new_ct.SetDirection(vol.GetDirection())
    new_ct.SetOrigin(vol.GetOrigin())
    new_ct.SetSpacing((vol.GetSpacing()[0], vol.GetSpacing()[1], slice_thickness))

    new_seg = sitk.GetImageFromArray(seg_array)
    new_seg.SetDirection(seg.GetDirection())
    new_seg.SetOrigin(seg.GetOrigin())
    new_seg.SetSpacing((seg.GetSpacing()[0], seg.GetSpacing()[1], slice_thickness))

    # A '.nii' mask is written as '.nii.gz'. Names already ending in '.nii.gz'
    # are kept as is, rather than becoming '.nii.gz.gz'.
    seg_name = os.path.basename(seg_dict[file])
    if seg_name.endswith('.nii'):
        seg_name += '.gz'
    sitk.WriteImage(new_ct, os.path.join(PROC_TRAIN_VOL, os.path.basename(vol_dict[file])))
    sitk.WriteImage(new_seg, os.path.join(PROC_TRAIN_SEG, seg_name))


In [None]:
## Visualization code

def plot_scan_and_masks(scaled_index, scale, vol, mask, resized_vol, resized_mask, fig_width = 30, cmap = 'gray', prepro_or_output = True):
    """
    Plot one slice of an original volume/mask pair next to the matching slice of
    a rescaled (or model-output) volume/mask pair, in a single row of 6 panels.

    All arrays are sliced along axis 0.

    Parameters
    ----------
    scaled_index : int
        Slice index (axis 0) into resized_vol / resized_mask.
    scale : float
        Maps scaled_index back to the original volume:
        original_index = floor(scale * scaled_index).
    vol, mask : np.ndarray
        Original CT volume and label map (0 background, 1 liver, 2 tumor).
    resized_vol, resized_mask : np.ndarray
        Rescaled or model-output volume and label map.
    fig_width : float
        Figure width in inches; the height is fig_width / 6 to suit a 1x6 row.
    cmap : str or matplotlib colormap
        Colormap for the CT panels.
    prepro_or_output : bool
        True titles the right-hand panels "Model Output ...", False "Scaled ...".

    Raises
    ------
    ValueError
        If scaled_index is outside resized_vol's axis 0, or the mapped original
        index is outside vol's axis 0.
    """
    # Validate against axis 0, the axis actually indexed below.
    if not 0 <= scaled_index < resized_vol.shape[0]:
        raise ValueError("Index out of range")
    original_index = math.floor(scale * scaled_index)
    if not 0 <= original_index < vol.shape[0]:
        raise ValueError("Index out of range")

    print(f'z: {original_index} / z_scaled: {scaled_index}')

    # Fixed label -> colour mapping. The BoundaryNorm keeps colours stable even
    # when a slice does not contain every label.
    bounds = [0, 1, 2, 3]
    seg_cmap = colors.ListedColormap(['black', 'white', 'red'])
    seg_norm = colors.BoundaryNorm(bounds, seg_cmap.N)
    overlay_cmap = colors.ListedColormap(['white', 'green', 'red'])
    overlay_norm = colors.BoundaryNorm(bounds, overlay_cmap.N)
    # Background pixels are masked out below; draw them fully transparent.
    overlay_cmap.set_bad(color='white', alpha=0.0)

    output_mode = 'Model Output ' if prepro_or_output else 'Scaled '
    panels = (
        ('Original ', vol[original_index], mask[original_index]),
        (output_mode, resized_vol[scaled_index], resized_mask[scaled_index]),
    )

    fig, axes = plt.subplots(nrows=1, ncols=6, figsize=(fig_width, fig_width / 6))
    for group, (prefix, image, labels) in enumerate(panels):
        ax_vol, ax_seg, ax_overlay = axes[3 * group: 3 * group + 3]
        ax_vol.set_title(prefix + 'Volume')
        ax_vol.imshow(image, cmap=cmap, interpolation='none')
        ax_seg.set_title(prefix + 'Seg.')
        ax_seg.imshow(labels, cmap=seg_cmap, norm=seg_norm, interpolation='none')
        ax_overlay.set_title(prefix + 'Overlay')
        ax_overlay.imshow(image, cmap=cmap, interpolation='none')
        ax_overlay.imshow(np.ma.masked_equal(labels, 0), cmap=overlay_cmap, norm=overlay_norm,
                          interpolation='none', alpha=0.5)

    plt.show()
    plt.close(fig)
    

def plot_mask_comparison_over_vol(vol, seg, vol_scaled, seg_scaled, scale, index_start = 0, index_end = None, step = 1):
    """
    Call plot_scan_and_masks for every `step`-th slice of the scaled volume.

    Parameters
    ----------
    vol, seg : np.ndarray
        Original volume and label map.
    vol_scaled, seg_scaled : np.ndarray
        Rescaled volume and label map.
    scale : float
        Passed through to plot_scan_and_masks to map scaled slice indices back
        to the original volume.
    index_start, index_end, step : int
        Slice range [index_start, index_end) with the given step. index_end
        defaults to vol_scaled.shape[0], the axis plot_scan_and_masks slices.

    NOTE(review): the next notebook cell redefines plot_scan_and_masks with a
    different signature. After that cell runs, this function calls the wrong
    version; consider renaming one of the pairs.
    """
    if index_end is None:
        index_end = vol_scaled.shape[0]
    for ind in range(index_start, index_end, step):
        plot_scan_and_masks(ind, scale, vol, seg, vol_scaled, seg_scaled)

In [None]:
def plot_scan_and_masks(index, vol, mask = None, pred_mask = None, fig_width = 15, colormap = 'gray'):
    """
    Plot slice `index` (last axis) of a volume, optionally with its ground-truth
    mask and/or predicted mask.

    Layouts:
    - mask and pred_mask: volume | mask | pred_mask (all grayscale)
    - mask only: volume | labelled mask | volume with semi-transparent label overlay
    - pred_mask only: volume | pred_mask
    - neither: volume only

    Parameters
    ----------
    index : int
        Slice index along the last axis of vol (and of the masks).
    vol : np.ndarray
        CT volume.
    mask : np.ndarray, optional
        Label map (0 background, 1 liver, 2 tumor).
    pred_mask : np.ndarray, optional
        Predicted mask.
    fig_width : float
        Figure width in inches; the height is fig_width / ncols so each panel is
        roughly square.
    colormap : str or matplotlib colormap
        CT colormap for the mask-only layout.

    Raises
    ------
    ValueError
        If index is outside the last axis of vol.
    """
    if index >= vol.shape[-1] or index < 0:
        raise ValueError("Index out of range")

    bounds = [0, 1, 2, 3]
    # Label map colours: black = background, white = liver, red = tumor.
    label_cmap = colors.ListedColormap(['black', 'white', 'red'])
    # Overlay colours. Background (0) is masked out and drawn fully transparent.
    overlay_cmap = colors.ListedColormap(['white', 'green', 'red'])
    overlay_cmap.set_bad(color='white', alpha=0.0)
    # Both colormaps have 3 entries, so one norm serves both.
    norm = colors.BoundaryNorm(bounds, label_cmap.N)

    vol_slice = vol[..., index]
    if mask is not None and pred_mask is not None:
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(fig_width, fig_width / 3))
        axes[0].imshow(vol_slice, cmap='gray')
        axes[1].imshow(mask[..., index], cmap='gray')
        axes[2].imshow(pred_mask[..., index], cmap='gray')
    elif mask is not None:
        mask_slice = mask[..., index]
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(fig_width, fig_width / 3))
        axes[0].imshow(vol_slice, cmap=colormap, interpolation='none')
        axes[1].imshow(mask_slice, cmap=label_cmap, norm=norm, interpolation='none')
        axes[2].imshow(vol_slice, cmap=colormap, interpolation='none')
        axes[2].imshow(np.ma.masked_equal(mask_slice, 0), cmap=overlay_cmap, norm=norm,
                       interpolation='none', alpha=0.5)
    elif pred_mask is not None:
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(fig_width, fig_width / 2))
        axes[0].imshow(vol_slice, cmap='gray')
        axes[1].imshow(pred_mask[..., index], cmap='gray')
    else:
        fig, ax = plt.subplots(nrows=1, ncols=1)
        ax.imshow(vol_slice, cmap='gray')
    plt.show()
    plt.close(fig)
    
def plot_mask_comparison_over_vol(vol, mask, pred_mask, index_start = 0, index_end = None, step = 3):
    """
    Plot every `step`-th slice (last axis) in [index_start, index_end) via
    plot_scan_and_masks.

    Parameters
    ----------
    vol : np.ndarray
        CT volume.
    mask, pred_mask : np.ndarray or None
        Ground-truth and predicted masks, forwarded to plot_scan_and_masks.
    index_start, index_end, step : int
        Slice range and stride. index_end defaults to vol.shape[-1], so the last
        slice is included.
    """
    if index_end is None:
        index_end = vol.shape[-1]
    for ind in range(index_start, index_end, step):
        plot_scan_and_masks(ind, vol, mask, pred_mask)

In [None]:
def plot_3d(image, seg=None, threshold=200):
    """
    Render a 3-D isosurface of `image` and, optionally, label surfaces from `seg`.

    Parameters
    ----------
    image : np.ndarray
        3-D volume; its surface is extracted at `threshold` with marching cubes
        and drawn in blue-grey.
    seg : np.ndarray, optional
        3-D label array with the same grid as `image`. Label 1 (liver) is drawn
        green and label 2 (lesion/tumor) red. Labels absent from `seg` are
        skipped, because marching cubes cannot extract a surface from an empty
        mask.
    threshold : float
        Isosurface level for `image`.

    Axis limits are set from image.shape.
    """
    def _surface_mesh(volume, level, face_color):
        # Triangulate the isosurface of `volume` at `level` into a coloured mesh.
        verts, faces, _normals, _values = measure.marching_cubes(volume, level)
        mesh = Poly3DCollection(verts[faces], alpha=0.70)
        mesh.set_facecolor(face_color)
        return mesh

    meshes = [_surface_mesh(image, threshold, [0.45, 0.45, 0.75])]
    if seg is not None:
        for label, face_color in ((1, [0.0, 1.0, 0.0]), (2, [1.0, 0.0, 0.0])):
            label_mask = seg == label
            if np.any(label_mask):
                # 0.5 lies midway between background (0) and foreground (1).
                meshes.append(_surface_mesh(label_mask.astype(float), 0.5, face_color))

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    for mesh in meshes:
        ax.add_collection3d(mesh)

    ax.set_xlim(0, image.shape[0])
    ax.set_ylim(0, image.shape[1])
    ax.set_zlim(0, image.shape[2])

    plt.show()