In [6]:

from ipywidgets import interact, fixed
import ipywidgets as widgets


import numpy as np
import scipy.ndimage as ndi
import matplotlib.pyplot as plt
from medpy.io import load
import os

# For interactive plotting (not used yet)
import plotly.plotly as py
import plotly.graph_objs as go

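To reproduce this notebook on your machine:
- `pip install` the dependencies imported above (ipywidgets, numpy, scipy, matplotlib, medpy, plotly)
- Download the dataset from HanQi's Google Drive
- Put the training and testing data in `data/Training` and `data/Testing`; the viewer below reads patient folders from `data/Training/LGG` and `data/Training/HGG`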

In [51]:
def _find_modality_folders(patient_path, im_types):
    """Map each modality tag in `im_types` to the subfolder of `patient_path` that names it.

    A subfolder matches a tag when it contains ``tag + "."``; the trailing "." keeps "T1"
    from matching a "T1c" subfolder. Folders are scanned in sorted order and the last
    match wins, so the result is deterministic.
    """
    found = {}
    for subfolder in sorted(os.listdir(patient_path)):
        for im_type in im_types:
            if im_type + "." in subfolder:
                found[im_type] = subfolder
    return found


def _find_mha_file(folder):
    """Return the path of the `.mha` file inside `folder`.

    Raises FileNotFoundError when the folder contains no `.mha` file, instead of silently
    reusing a file found for a different modality.
    """
    mha_files = sorted(name for name in os.listdir(folder) if name.endswith(".mha"))
    if not mha_files:
        raise FileNotFoundError(f"No .mha file found in {folder}")
    # Last match wins (as in the original loop), but over a sorted, deterministic listing.
    return os.path.join(folder, mha_files[-1])


def _plot_planes(axes_row, volume, mask, x, y, z, cancer_types, title_suffix, keep_selected):
    """Draw sagittal, coronal and horizontal slices of `volume` onto three axes.

    Voxels whose `mask` value is in `cancer_types` are "selected". With `keep_selected=True`
    only selected voxels are shown (others set to 0); otherwise the selected voxels are
    set to 0 and everything else is shown. Each slice is rotated by 90 degrees for display.
    """
    everything = slice(None)
    planes = [
        ("Sagittal plane (left/right) for ", (x, everything, everything)),
        ("Coronal plane (front/back) for ", (everything, y, everything)),
        ("Horizontal plane (top/bottom) for ", (everything, everything, z)),
    ]
    for ax, (title_prefix, index) in zip(axes_row, planes):
        image_slice = volume[index]
        selected = np.isin(mask[index], cancer_types)
        if keep_selected:
            shown = np.where(selected, image_slice, 0)
        else:
            shown = np.where(selected, 0, image_slice)
        ax.imshow(ndi.rotate(input = shown, angle = 90), cmap = "viridis")
        ax.set_title(title_prefix + title_suffix)
        ax.axis("off")


def show_all_images(folder_index, x, y, z, cancer_type = "LGG", type1 = True, type2 = True, type3 = True, type4 = True):
    """
    Function:
        - Shows one training-set patient in the four modalities (T1, T1c, T2, Flair) and
          the label volume (OT), each sliced along the sagittal, coronal and horizontal planes
    Arguments:
        - folder_index (int): index of the patient folder within `data/Training/LGG` or
          `data/Training/HGG`, counted over the folder names in sorted order
        - x (int, between 0 and 250): point of sagittal (left/right) slicing
        - y (int, between 0 and 250): point of coronal (front/back) slicing
        - z (int, between 0 and 155): point of horizontal (top/bottom) slicing
        - cancer_type (str): either "LGG" or "HGG"
        - type1 .. type4 (boolean): if True, label value 1 .. 4 of the OT volume is selected.
          Selected voxels are set to 0 in the modality images and are the only voxels shown
          in the label (OT) panels
    Raises:
        - AssertionError: if `cancer_type` is not "LGG" or "HGG"
        - IndexError: if `folder_index` is outside [0, number of patient folders - 1]
        - FileNotFoundError: if a modality subfolder or its .mha file is missing
    """
    assert cancer_type in ["HGG", "LGG"], "`cancer_type` has to be either 'HGG' or 'LGG'"

    root_path = os.path.join("data", "Training", cancer_type)

    # os.listdir returns entries in arbitrary order; sort so an index always maps to the
    # same patient.
    patient_folders = sorted(os.listdir(root_path))
    max_folder_index = len(patient_folders) - 1
    # Explicit bounds check: a bare try/except hid unrelated errors, and negative indices
    # silently selected patients from the end of the list.
    if not 0 <= folder_index <= max_folder_index:
        raise IndexError(f"Your index is out of bound. The maximum index for {cancer_type} is {max_folder_index}")
    patient_folder = patient_folders[folder_index]
    patient_path = os.path.join(root_path, patient_folder)

    # Locate every modality (and its .mha file) before loading anything, so a malformed
    # patient folder fails fast with a clear message.
    IM_TYPES = ["T1", "T1c", "T2", "Flair", "OT"]
    subfolder_dict = _find_modality_folders(patient_path, IM_TYPES)
    missing = [im_type for im_type in IM_TYPES if im_type not in subfolder_dict]
    if missing:
        raise FileNotFoundError(f"Missing modality folder(s) {missing} in {patient_path}")
    mha_paths = {
        im_type: _find_mha_file(os.path.join(patient_path, subfolder))
        for im_type, subfolder in subfolder_dict.items()
    }

    # Load every volume together with its header.
    data_dict = {}
    for im_type, mha_path in mha_paths.items():
        image_data, image_header = load(mha_path)
        data_dict[im_type] = {"image_data": image_data, "image_header": image_header}

    # Label values selected by the checkboxes.
    cancer_types = [label for label, wanted in zip((1, 2, 3, 4), (type1, type2, type3, type4)) if wanted]

    # Data visualization: one row per modality, the OT label volume last.
    print(f"Showing images for {cancer_type} patient index number {folder_index} out of {max_folder_index}, patient code {patient_folder}")
    _, axes = plt.subplots(nrows = 5, ncols = 3, figsize = (16, 16))
    mask = data_dict["OT"]["image_data"]
    for row, image_type in enumerate(IM_TYPES):
        is_label = image_type == "OT"
        _plot_planes(
            axes[row],
            data_dict[image_type]["image_data"],
            mask,
            x, y, z,
            cancer_types,
            "label (OT)" if is_label else image_type,
            keep_selected = is_label,
        )
    plt.subplots_adjust(wspace=0.1, hspace=0.2)
    plt.show()
    


In [52]:
# Interactive viewer for show_all_images. The sliders pick the patient folder and the three
# slicing points; the checkboxes pick which label values (1-4) are selected.
# show_all_images loads its own data, so no `data_dict` is passed here (it exists only as a
# local inside that function, so referencing it fails on a fresh kernel).
# continuous_update=False is set on each slider so the .mha volumes are reloaded only when a
# slider is released, not on every intermediate value while dragging.
# NOTE(review): folder_index max=100 is hard-coded; indices past the last patient folder of
# the chosen cancer_type raise IndexError inside show_all_images.
interact(show_all_images,
         cancer_type = ["LGG", "HGG"],
         type1 = widgets.Checkbox(value=False, description='1', disabled=False),
         type2 = widgets.Checkbox(value=False, description='2', disabled=False),
         type3 = widgets.Checkbox(value=False, description='3', disabled=False),
         type4 = widgets.Checkbox(value=False, description='4', disabled=False),
         folder_index = widgets.IntSlider(min=0, max=100, step=1, value=0, continuous_update=False),
         x = widgets.IntSlider(min=0, max=250, step=1, value=0, continuous_update=False),
         y = widgets.IntSlider(min=0, max=250, step=1, value=0, continuous_update=False),
         z = widgets.IntSlider(min=0, max=155, step=1, value=100, continuous_update=False));

interactive(children=(IntSlider(value=0, description='folder_index'), IntSlider(value=0, description='x', max=…

<function __main__.show_all_images(folder_index, x, y, z, cancer_type='LGG', type1=True, type2=True, type3=True, type4=True)>