# Data exploration & preprocessing

This notebook extends an [example](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) from the MONAI tutorials repository with interactive visualizations of CT scans and their segmentation labels.

**Hint:** You may need to run `jupyter nbextension enable --py widgetsnbextension` and then restart the kernel to enable the JavaScript widgets used in this notebook.

## Imports & configuration

In [39]:
from pathlib import Path
import glob

import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, fixed, interactive
from IPython.display import display
import numpy as np

import torch
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.transforms import (
    LoadImage,
    LoadImaged,
    AddChanneld,
    Spacingd,
    Orientationd,
    RandAffined,
    Rand3DElasticd,
)

# Print the MONAI version and its optional dependencies so the environment that
# produced the outputs below is recorded alongside them.
print_config()

MONAI version: 0.4.0
Numpy version: 1.19.4
Pytorch version: 1.7.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 0563a4467fa602feca92d91c7f47261868d171a1

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.1
scikit-image version: 0.18.0
Pillow version: 8.0.1
Tensorboard version: 2.4.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.8.2
ITK version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.54.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [2]:
# Location of the Task06_Lung data, resolved relative to the grandparent of the current
# working directory (Path().absolute().parents[1]).
data_dir = Path().absolute().parents[1] / "data/lung-tumor-segmentation/Task06_Lung"
# parents=True also creates any missing intermediate folders; exist_ok=True keeps the
# cell safe to re-run. NOTE: this only creates an empty folder; the imagesTr/labelsTr
# files still have to be downloaded/extracted into it.
data_dir.mkdir(parents=True, exist_ok=True)

## Load data

In [3]:
# Collect the NIfTI training volumes and their segmentation labels (sorted so pairs line up).
train_images = sorted(glob.glob(str(data_dir / "imagesTr" / "*.nii.gz")))
train_labels = sorted(glob.glob(str(data_dir / "labelsTr" / "*.nii.gz")))

# Fail early with a clear message instead of an IndexError in later cells when the
# dataset has not been placed in `data_dir` yet.
if not train_images:
    raise FileNotFoundError(f"No training images (*.nii.gz) found in {data_dir / 'imagesTr'}")

# zip() silently drops unmatched files, so verify that every image has a label with the
# same file name (e.g. lung_001.nii.gz in both folders) before pairing them.
image_file_names = [Path(path).name for path in train_images]
label_file_names = [Path(path).name for path in train_labels]
if image_file_names != label_file_names:
    raise ValueError(
        f"imagesTr/ and labelsTr/ do not match one-to-one "
        f"({len(image_file_names)} images vs {len(label_file_names)} labels)"
    )

data_dicts = [
    {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
]
train_data_dicts = data_dicts

In [4]:
# Inspect the first image/label pair (absolute file paths built from data_dir).
train_data_dicts[0]

{'image': '/Users/felix/code/ml/ml-experiments/data/lung-tumor-segmentation/Task06_Lung/imagesTr/lung_001.nii.gz',
 'label': '/Users/felix/code/ml/ml-experiments/data/lung-tumor-segmentation/Task06_Lung/labelsTr/lung_001.nii.gz'}

In [5]:
# Array loader: reads a single image file as float32 and, as the next cell shows,
# returns the voxel array together with the file's metadata dict.
loader = LoadImage(dtype=np.float32)

In [6]:
# `img` is the voxel array; `meta` holds header metadata such as 'affine' and 'pixdim'.
img, meta = loader(train_data_dicts[0]["image"])

In [7]:
first_image_path = train_data_dicts[0]["image"]
print(f"input: {first_image_path}")
print(f"image shape: {img.shape}")
# Spatial metadata the resampling and reorientation steps below depend on.
for meta_key in ("affine", "pixdim"):
    print(f"image {meta_key}:\n{meta[meta_key]}")

input: /Users/felix/code/ml/ml-experiments/data/lung-tumor-segmentation/Task06_Lung/imagesTr/lung_001.nii.gz
image shape: (512, 512, 304)
image affine:
[[  -0.69335938    0.            0.          182.15332031]
 [   0.            0.69335938    0.          -40.15332031]
 [   0.            0.            1.         -305.        ]
 [   0.            0.            0.            1.        ]]
image pixdim:
[-1.         0.6933594  0.6933594  1.         0.         0.
  0.         0.       ]


In [8]:
# Dictionary loader: loads both the "image" and the "label" entry of a data dict.
# NOTE(review): this rebinds `loader`, which the cells above use as a LoadImage
# (path -> (array, meta)); re-running those cells after this one would call the
# dictionary loader with a plain path. Consider a distinct name.
loader = LoadImaged(keys=("image", "label"))

In [9]:
# Load the first image/label pair; the image header is available under 'image_meta_dict'.
data_dict = loader(train_data_dicts[0])
print(f"input: {train_data_dicts[0]}")
print(f"image shape: {data_dict['image'].shape}")
print(f"label shape: {data_dict['label'].shape}")
print(f"image pixdim:\n{data_dict['image_meta_dict']['pixdim']}")

input:, {'image': '/Users/felix/code/ml/ml-experiments/data/lung-tumor-segmentation/Task06_Lung/imagesTr/lung_001.nii.gz', 'label': '/Users/felix/code/ml/ml-experiments/data/lung-tumor-segmentation/Task06_Lung/labelsTr/lung_001.nii.gz'}
image shape: (512, 512, 304)
label shape: (512, 512, 304)
image pixdim:
[-1.         0.6933594  0.6933594  1.         0.         0.
  0.         0.       ]


## Visualize CT images

In [10]:
def plot_seg_image(image, label, dim, dim_val, separate_plots=False):
    """Plot one slice of a volume together with its segmentation label.

    Args:
        image: 3D array (anything supporting numpy's ``take(indices=, axis=)``)
            holding the CT intensities.
        label: 3D array with the segmentation; expected to match ``image``'s shape.
        dim: axis to slice along, either an int axis index or one of
            ``"Width"``, ``"Height"``, ``"Depth"`` (mapped to axes 0, 1, 2).
        dim_val: index of the slice along ``dim``.
        separate_plots: if True, draw image and label side by side; otherwise
            overlay the label on top of the image.
    """
    if isinstance(dim, str):
        dim = ["Width", "Height", "Depth"].index(dim)
    image_slice = image.take(indices=dim_val, axis=dim)
    label_slice = label.take(indices=dim_val, axis=dim)
    if separate_plots:
        plt.figure("visualize", (16, 8))
        plt.subplot(1, 2, 1)
        plt.title("image")
        plt.imshow(image_slice, cmap="gray")
        plt.subplot(1, 2, 2)
        plt.title("label")
        plt.imshow(label_slice)
    else:
        plt.figure("visualize", (8, 8))
        plt.title("image & label")
        plt.imshow(image_slice, cmap="gray")
        # Mask the background (label == 0) so only labelled voxels are tinted, instead of
        # washing the whole slice with the colormap's zero color. A fixed color range
        # (from the whole volume) keeps the tint identical across slices and copes with
        # slices that contain no label at all.
        overlay = np.ma.masked_equal(label_slice, 0)
        plt.imshow(overlay, alpha=0.3, vmin=0, vmax=max(float(label.max()), 1.0))
    plt.show()

In [11]:
image, label = data_dict["image"], data_dict["label"]
shape = image.shape
axis_names = ("Width", "Height", "Depth")

dim = widgets.RadioButtons(options=axis_names, value='Depth', description='Dimension: ', disabled=False)
# Size the slider for the initially selected axis (not the smallest axis) so every slice is reachable.
dim_val = widgets.IntSlider(
    value=0, max=shape[axis_names.index(dim.value)] - 1, description="Level: ", continuous_update=False
)
separate_plots = widgets.Checkbox(value=False, description="Separate plots?", disabled=False)

def update_dim_val_range(change, slider=dim_val, vol_shape=shape, names=axis_names):
    """Set the slice slider's maximum to the last index along the newly selected axis.

    The slider, volume shape and axis names are bound as default arguments so this
    callback keeps controlling *this* viewer; reading the globals instead would make
    it act on whichever viewer a later cell created last.
    """
    slider.max = vol_shape[names.index(change["new"])] - 1
dim.observe(update_dim_val_range, 'value')

ui = widgets.HBox([dim, dim_val, separate_plots])

out = widgets.interactive_output(plot_seg_image, {'dim': dim, 
                                                  'dim_val': dim_val, 
                                                  'separate_plots': separate_plots, 
                                                  'image': fixed(image), 
                                                  'label': fixed(label)})
display(ui, out)

HBox(children=(RadioButtons(description='Dimension: ', index=2, options=('Width', 'Height', 'Depth'), value='D…

Output()

## Standard transformations

### Add channel dimension

In [16]:
# Prepend a channel axis to image and label (see the printed shape: (1, 512, 512, 304)).
# NOTE(review): `data_dict` is rebound by the orientation cell further down; re-running
# this cell after that one would add a second leading axis to already channel-first data.
add_channel = AddChanneld(keys=["image", "label"])
datac_dict = add_channel(data_dict)
print(f"image shape: {datac_dict['image'].shape}")

image shape: (1, 512, 512, 304)


### Resample to consistent voxel size

In [19]:
# Affine (voxel -> world mapping) of the loaded volume before resampling.
# NOTE(review): if Spacingd writes its new affine into the shared meta dict in place,
# re-running this cell after the spacing cell shows the resampled affine — confirm.
print(f"image affine before spacing:\n{data_dict['image_meta_dict']['affine']}")

image affine before spacing:
[[  -0.69335938    0.            0.          182.15332031]
 [   0.            0.69335938    0.          -40.15332031]
 [   0.            0.            1.         -305.        ]
 [   0.            0.            0.            1.        ]]


In [22]:
# Resample image and label to 1.5 x 1.5 x 5.0 voxel spacing. The mode tuple presumably
# maps one interpolation mode per key: bilinear for the image, nearest for the label
# so label values stay discrete.
spacing = Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 5.0), mode=("bilinear", "nearest"))

In [23]:
# Apply the resampling and show the new shape and affine (diagonal now 1.5, 1.5, 5).
# NOTE(review): this rebinds `datac_dict` to its own transformed value, so the cell is
# not idempotent — running it twice feeds already-resampled data back into Spacingd.
# Assigning to a new name would fix this but requires updating the later cells too.
datac_dict = spacing(datac_dict)
print(f"image shape: {datac_dict['image'].shape}")
print(f"image affine after spacing:\n{datac_dict['image_meta_dict']['affine']}")

image shape: (1, 237, 237, 62)
image affine after spacing:
[[  -1.5           0.            0.          182.15332031]
 [   0.            1.5           0.          -40.15332031]
 [   0.            0.            5.         -305.        ]
 [   0.            0.            0.            1.        ]]


In [28]:
image, label = datac_dict["image"].squeeze(), datac_dict["label"].squeeze()
shape = image.shape
axis_names = ("Width", "Height", "Depth")

dim = widgets.RadioButtons(options=axis_names, value='Depth', description='Dimension: ', disabled=False)
# Size the slider for the initially selected axis (not the smallest axis) so every slice is reachable.
dim_val = widgets.IntSlider(
    value=0, max=shape[axis_names.index(dim.value)] - 1, description="Level: ", continuous_update=False
)
separate_plots = widgets.Checkbox(value=False, description="Separate plots?", disabled=False)

def update_dim_val_range(change, slider=dim_val, vol_shape=shape, names=axis_names):
    """Set the slice slider's maximum to the last index along the newly selected axis.

    The slider, volume shape and axis names are bound as default arguments so this
    callback keeps controlling *this* viewer; reading the globals instead would make
    it act on whichever viewer a later cell created last.
    """
    slider.max = vol_shape[names.index(change["new"])] - 1
dim.observe(update_dim_val_range, 'value')

ui = widgets.HBox([dim, dim_val, separate_plots])

out = widgets.interactive_output(plot_seg_image, {'dim': dim, 
                                                  'dim_val': dim_val, 
                                                  'separate_plots': separate_plots, 
                                                  'image': fixed(image), 
                                                  'label': fixed(label)})
display(ui, out)

HBox(children=(RadioButtons(description='Dimension: ', index=2, options=('Width', 'Height', 'Depth'), value='D…

Output()

### Reorient to the designated axis codes

In [30]:
# Permute/flip the axes of image and label to match the axis codes "PLI"
# (nibabel-style codes: the axes point towards Posterior, Left, Inferior).
orientation = Orientationd(keys=["image", "label"], axcodes="PLI")

In [31]:
# Apply the reorientation; the printed affine shows the first two axes swapped and flipped.
# NOTE(review): this overwrites `data_dict`, which earlier cells (the first viewer and the
# add-channel cell) treat as the raw loaded volume. Re-running those cells afterwards
# operates on the resampled, reoriented, channel-first data instead.
data_dict = orientation(datac_dict)
print(f"image shape: {data_dict['image'].shape}")
print(f"image affine after orientation:\n{data_dict['image_meta_dict']['affine']}")

image shape: (1, 237, 237, 62)
image affine after orientation:
[[  0.          -1.5          0.         182.15332031]
 [ -1.5          0.           0.         313.84667969]
 [  0.           0.          -5.           0.        ]
 [  0.           0.           0.           1.        ]]


In [32]:
image, label = data_dict["image"].squeeze(), data_dict["label"].squeeze()
shape = image.shape
axis_names = ("Width", "Height", "Depth")

dim = widgets.RadioButtons(options=axis_names, value='Depth', description='Dimension: ', disabled=False)
# Size the slider for the initially selected axis (not the smallest axis) so every slice is reachable.
dim_val = widgets.IntSlider(
    value=0, max=shape[axis_names.index(dim.value)] - 1, description="Level: ", continuous_update=False
)
separate_plots = widgets.Checkbox(value=False, description="Separate plots?", disabled=False)

def update_dim_val_range(change, slider=dim_val, vol_shape=shape, names=axis_names):
    """Set the slice slider's maximum to the last index along the newly selected axis.

    The slider, volume shape and axis names are bound as default arguments so this
    callback keeps controlling *this* viewer; reading the globals instead would make
    it act on whichever viewer a later cell created last.
    """
    slider.max = vol_shape[names.index(change["new"])] - 1
dim.observe(update_dim_val_range, 'value')

ui = widgets.HBox([dim, dim_val, separate_plots])

out = widgets.interactive_output(plot_seg_image, {'dim': dim, 
                                                  'dim_val': dim_val, 
                                                  'separate_plots': separate_plots, 
                                                  'image': fixed(image), 
                                                  'label': fixed(label)})
display(ui, out)

HBox(children=(RadioButtons(description='Dimension: ', index=2, options=('Width', 'Height', 'Depth'), value='D…

Output()

## Random transformations

### Apply random affine transformation

In [34]:
# Random affine augmentation (translation, rotation, scaling) cropped/padded to a
# fixed output size. prob=1.0 so the transform is always applied for the demo.
rand_affine = RandAffined(
    keys=["image", "label"],
    # One interpolation mode per key: bilinear for the image, nearest for the label
    # so label values are not blended.
    mode=("bilinear", "nearest"),
    prob=1.0,
    spatial_size=(224, 224, 55),
    translate_range=(40, 40, 2),
    rotate_range=(np.pi / 36, np.pi / 36, np.pi / 4),
    scale_range=(0.15, 0.15, 0.15),
    padding_mode="border",
)
# Seed the transform so "Restart & Run All" reproduces the same augmented sample.
rand_affine.set_random_state(seed=0)

In [37]:
# Draw one random affine augmentation of the reoriented sample.
affined_data_dict = rand_affine(data_dict)
affined_image = affined_data_dict["image"]
print(f"image shape: {affined_image.shape}")

image shape: torch.Size([1, 224, 224, 55])


In [38]:
image, label = affined_data_dict["image"].squeeze().numpy(), affined_data_dict["label"].squeeze().numpy()
shape = image.shape
axis_names = ("Width", "Height", "Depth")

dim = widgets.RadioButtons(options=axis_names, value='Depth', description='Dimension: ', disabled=False)
# Size the slider for the initially selected axis (not the smallest axis) so every slice is reachable.
dim_val = widgets.IntSlider(
    value=0, max=shape[axis_names.index(dim.value)] - 1, description="Level: ", continuous_update=False
)
separate_plots = widgets.Checkbox(value=False, description="Separate plots?", disabled=False)

def update_dim_val_range(change, slider=dim_val, vol_shape=shape, names=axis_names):
    """Set the slice slider's maximum to the last index along the newly selected axis.

    The slider, volume shape and axis names are bound as default arguments so this
    callback keeps controlling *this* viewer; reading the globals instead would make
    it act on whichever viewer a later cell created last.
    """
    slider.max = vol_shape[names.index(change["new"])] - 1
dim.observe(update_dim_val_range, 'value')

ui = widgets.HBox([dim, dim_val, separate_plots])

out = widgets.interactive_output(plot_seg_image, {'dim': dim, 
                                                  'dim_val': dim_val, 
                                                  'separate_plots': separate_plots, 
                                                  'image': fixed(image), 
                                                  'label': fixed(label)})
display(ui, out)

HBox(children=(RadioButtons(description='Dimension: ', index=2, options=('Width', 'Height', 'Depth'), value='D…

Output()

### Apply random elastic deformation

In [40]:
# Random 3D elastic deformation combined with an affine component, cropped/padded to a
# fixed output size. prob=1.0 so the transform is always applied for the demo.
rand_elastic = Rand3DElasticd(
    keys=["image", "label"],
    # One interpolation mode per key: bilinear for the image, nearest for the label
    # so label values are not blended.
    mode=("bilinear", "nearest"),
    prob=1.0,
    sigma_range=(5, 8),
    magnitude_range=(100, 200),
    spatial_size=(224, 224, 20),
    translate_range=(50, 50, 2),
    rotate_range=(np.pi / 36, np.pi / 36, np.pi),
    scale_range=(0.15, 0.15, 0.15),
    padding_mode="border",
)
# Seed the transform so "Restart & Run All" reproduces the same deformed sample.
rand_elastic.set_random_state(seed=0)

In [41]:
# Draw one random elastic deformation of the reoriented sample.
deformed_data_dict = rand_elastic(data_dict)
deformed_image = deformed_data_dict["image"]
print(f"image shape: {deformed_image.shape}")

image shape: (1, 224, 224, 20)


In [42]:
image, label = deformed_data_dict["image"].squeeze(), deformed_data_dict["label"].squeeze()
shape = image.shape
axis_names = ("Width", "Height", "Depth")

dim = widgets.RadioButtons(options=axis_names, value='Depth', description='Dimension: ', disabled=False)
# Size the slider for the initially selected axis (not the smallest axis) so every slice is reachable.
dim_val = widgets.IntSlider(
    value=0, max=shape[axis_names.index(dim.value)] - 1, description="Level: ", continuous_update=False
)
separate_plots = widgets.Checkbox(value=False, description="Separate plots?", disabled=False)

def update_dim_val_range(change, slider=dim_val, vol_shape=shape, names=axis_names):
    """Set the slice slider's maximum to the last index along the newly selected axis.

    The slider, volume shape and axis names are bound as default arguments so this
    callback keeps controlling *this* viewer; reading the globals instead would make
    it act on whichever viewer a later cell created last.
    """
    slider.max = vol_shape[names.index(change["new"])] - 1
dim.observe(update_dim_val_range, 'value')

ui = widgets.HBox([dim, dim_val, separate_plots])

out = widgets.interactive_output(plot_seg_image, {'dim': dim, 
                                                  'dim_val': dim_val, 
                                                  'separate_plots': separate_plots, 
                                                  'image': fixed(image), 
                                                  'label': fixed(label)})
display(ui, out)

HBox(children=(RadioButtons(description='Dimension: ', index=2, options=('Width', 'Height', 'Depth'), value='D…

Output()