Tutorial for Medical Image Processing using Python (SimpleITK and Scikit-Image)
===============

------------------------------------

### Load Modules

In [None]:
%matplotlib inline
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Make grayscale the default colormap for every imshow call in this notebook.
plt.rc('image', cmap='gray')

In [None]:
# Find the DICOM series stored in the input directory.
imag_directory = 'dcm'
series_IDs = sitk.ImageSeriesReader.GetGDCMSeriesIDs(imag_directory)

# Stop here instead of only printing a message. Every later cell indexes series_IDs[0],
# which would otherwise fail with an unrelated-looking IndexError.
if not series_IDs:
    raise ValueError("given directory \"" + imag_directory + "\" does not contain a DICOM series.")

print(series_IDs)

---------------------------------
### Get the File Name List for a Series ID

In [None]:
# Slice file names that belong to the first series found in the directory.
series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(
    imag_directory,
    series_IDs[0],
)

In [None]:
# print(series_file_names)

---------------------------------
### Configure the Series Reader

In [None]:
# Series reader over the slice files found above. Per-slice metadata dictionaries and
# private tags are kept so that header values can be queried after Execute().
series_reader = sitk.ImageSeriesReader()
series_reader.SetFileNames(series_file_names)
series_reader.SetMetaDataDictionaryArrayUpdate(True)
series_reader.SetLoadPrivateTags(True)

---------------------------------
### Execute

In [None]:
# Read every configured file and assemble the slices into a single SimpleITK volume.
images = series_reader.Execute()

---------------------------------
### Convert to a NumPy Array

In [None]:
# Copy the voxel data into a NumPy array. The viewer below indexes it as
# [slice, row, column], i.e. presumably (z, y, x): the reverse of SimpleITK's (x, y, z)
# index order. TODO confirm against the SimpleITK docs.
mainBuffer = sitk.GetArrayFromImage(images)
type(mainBuffer)

---------------------------------
### Useful (Default) Image Parameters: Size, Origin, Spacing

In [None]:
# Volume geometry: size in voxels, physical origin and voxel spacing.
dim, org, dh = images.GetSize(), images.GetOrigin(), images.GetSpacing()

print(dim, org, dh)

---------------------------------
### Functions for Image Rendering

In [None]:
def get_image(images, vpos, aspect, vmin, vmax):
    """Return a callback that draws slice ``idx`` of a 3-D array in one plane.

    Parameters
    ----------
    images : numpy.ndarray
        3-D volume; 'axial' slices axis 0, 'coronal' axis 1, 'sagittal' axis 2.
    vpos : str
        'axial', 'coronal' or 'sagittal'. Any other value draws an empty frame.
    aspect : float
        Passed straight to ``imshow`` (pixel height / pixel width).
    vmin, vmax : float
        Display intensity limits passed to ``imshow``.

    Returns
    -------
    callable
        ``current_slice(idx)``, which opens a 6x6 figure, draws the slice and shows it.
    """
    draw_kw = {'aspect': aspect, 'vmin': vmin, 'vmax': vmax}

    def current_slice(idx):
        _fig, axis = plt.subplots(figsize=(6, 6))
        if vpos == 'axial':
            axis.imshow(images[idx, :, :], **draw_kw)
        elif vpos in ('coronal', 'sagittal'):
            # Both reformatted planes are drawn with origin='lower'.
            plane = images[:, idx, :] if vpos == 'coronal' else images[:, :, idx]
            axis.imshow(plane, origin='lower', **draw_kw)
        axis.set_axis_off()
        plt.show()

    return current_slice

def sliceimageview(images, vpos, dh, level=-40, window=400):
    """Show an interactive slider that browses a 3-D volume one slice at a time.

    Parameters
    ----------
    images : numpy.ndarray
        3-D volume indexed as [axis0, axis1, axis2]. 'axial' slices axis 0,
        'coronal' axis 1 and 'sagittal' axis 2 (see ``get_image``).
    vpos : {'axial', 'coronal', 'sagittal'}
        Plane to browse.
    dh : sequence of float
        Voxel spacing, presumably in SimpleITK ``GetSpacing()`` (x, y, z) order;
        used only for the pixel aspect ratio.
    level, window : float
        Display window centre and width. Intensities outside
        ``[level - window/2, level + window/2]`` saturate.

    Raises
    ------
    ValueError
        If ``vpos`` is not one of the three supported planes.
    """
    from ipywidgets import IntSlider, interact

    # plane -> (array axis being sliced, aspect = row spacing / column spacing)
    planes = {
        'axial': (0, dh[1] / dh[0]),
        'coronal': (1, dh[2] / dh[0]),
        'sagittal': (2, dh[2] / dh[1]),
    }
    if vpos not in planes:
        raise ValueError(
            "vpos must be 'axial', 'coronal' or 'sagittal', got {!r}".format(vpos))
    axis, aspect = planes[vpos]

    current_slice = get_image(images, vpos, aspect=aspect,
                              vmin=level - window / 2, vmax=level + window / 2)

    n_slices = images.shape[axis]
    # Start in the middle of the axis actually being sliced. The previous sagittal branch
    # used shape[1], which could be out of range for non-cubic volumes.
    step_slider = IntSlider(min=0, max=n_slices - 1, value=n_slices // 2)

    interact(current_slice, idx=step_slider)
    
    

In [None]:
# Browse coronal slices using the function's default display window (level=-40, window=400).
sliceimageview(mainBuffer, 'coronal', dh, level=-40, window=400)

---------------------------------
### DICOM tag dictionary

#### DICOM tag는 파일 하나를 읽어서 활용합니다. (DICOM Series 에서는 활용할 수 없음)

In [None]:
# Read one DICOM file on its own so that its header tags can be inspected.
filepath = 'dcm/000000.dcm'
image = sitk.ReadImage(filepath)
# Uncomment to list every metadata key of this file together with its value:
# for key in image.GetMetaDataKeys():
#     print('{:7s} - {:s}'.format(key, image.GetMetaData(key)))

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))    
buffer = sitk.GetArrayFromImage(image)

# This call is expected to fail. Finding out why is Quiz 1 below (hint: inspect buffer.shape).
ax.imshow(buffer, vmin=-240, vmax=160)
ax.set_axis_off()

### Quiz 1. How To Draw Correctly?

In [None]:
# answer
# NOTE: %load downloads solution.py from GitHub and replaces this cell's contents with it.
# Review the fetched code before running the cell again.
%load https://raw.github.com/jeonkiwan/PythonTutorial/master/Tutorial02/solution.py

----------------------------------

## Image Denoising



### [Non Local Mean Denoising for Texture Preserving](https://scikit-image.org/docs/dev/auto_examples/filters/plot_nonlocal_means.html)

In [None]:
# Load Module
from skimage.restoration import denoise_nl_means, denoise_wavelet, estimate_sigma

In [None]:
# Take the first (and presumably only) slice of the single-file buffer as float32.
imag = buffer[0, :, :].astype(dtype=np.float32, copy=True)

# Alternatively, use a slice of the series volume:
# imag = mainBuffer[308, :, :]

In [None]:
# `imag` is a single 2-D grayscale slice, so estimate_sigma is called without a channel axis.
# The previous `multichannel=True` treated every image column as a separate 1-D channel.
# That argument was also deprecated in scikit-image 0.19 and later removed.
sigma_est = np.mean(estimate_sigma(imag))

print("estimated noise standard deviation                 = {}".format(sigma_est))

patch_kw = dict(patch_size=5,      # 5x5 patches
                patch_distance=7)  # 15x15 search area (2 * 7 + 1)

# slow (exact) non-local means; filter strength h tied to the estimated noise level
denoised_imag = denoise_nl_means(imag, h=1.0*sigma_est, fast_mode=False, **patch_kw)

# noise estimation after denoising
sigma_est = np.mean(estimate_sigma(denoised_imag))
print("estimated noise standard deviation after denoising = {}".format(sigma_est))


### Draw Results

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 8))
ax = axes.flatten()

# (image, display minimum, display maximum, panel title) for each panel, left to right
panels = [
    (imag, -240, 160, 'original image'),
    (denoised_imag, -240, 160, 'denoised image'),
    (denoised_imag - imag, -10, 10, 'difference image'),
]
for panel_ax, (panel_img, lo, hi, title) in zip(ax, panels):
    panel_ax.imshow(panel_img, vmin=lo, vmax=hi)
    panel_ax.set_axis_off()
    panel_ax.set_title(title)

fig.tight_layout()

### [Wavelet Denoising](https://scikit-image.org/docs/dev/api/skimage.restoration.html#skimage.restoration.denoise_wavelet)

In [None]:
# `imag` is a single 2-D grayscale slice, so no channel axis is passed. The old
# `multichannel=True` treated each column as a separate channel and is gone in newer scikit-image.
sigma_est = np.mean(estimate_sigma(imag))

print("estimated noise standard deviation                 = {}".format(sigma_est))

mn = np.amin(imag)
mx = np.amax(imag)

# Scale to [0, 1] before wavelet denoising. A constant slice (mx == mn) would otherwise
# divide by zero and turn the whole image into NaN.
value_range = (mx - mn) if mx > mn else 1.0
t_imag = (imag - mn) / value_range

# wavelet algorithm
t_imag = denoise_wavelet(t_imag)

# undo the scaling
denoised_imag = value_range * t_imag + mn

# noise estimation after denoising
sigma_est = np.mean(estimate_sigma(denoised_imag))
print("estimated noise standard deviation after denoising = {}".format(sigma_est))

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 8))
ax = axes.flatten()

titles = ('original image', 'denoised image', 'difference image')
images_to_show = (imag, denoised_imag, denoised_imag - imag)
windows = ((-240, 160), (-240, 160), (-10, 10))  # (vmin, vmax) per panel

for i, title in enumerate(titles):
    low, high = windows[i]
    ax[i].imshow(images_to_show[i], vmin=low, vmax=high)
    ax[i].set_axis_off()
    ax[i].set_title(title)

fig.tight_layout()

------------------------------------------

## [Image Segmentation](https://scikit-image.org/docs/dev/api/skimage.segmentation.html)

### Whole Lung Segmentation

In [None]:
# Axial slice 158 of the series volume, copied as float32.
imag = np.array(mainBuffer[158, :, :], dtype=np.float32)

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(imag, vmin=-240, vmax=160)  # display intensities limited to [-240, 160]
ax.axis('off');

### Image Scaling (Lung Enhancement)

In [None]:
# Map the intensity range [mn, mx] = [-1100, -300] linearly onto [0, 1] and clamp the rest.
mx = -300
mn = -1100
t_imag = np.clip((imag - mn) / (mx - mn), 0.0, 1.0)

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(t_imag)  # scaled image in [0, 1]
ax.axis('off');

### [Gaussian Mixture (Pixel) Separation](https://scipy-lectures.org/advanced/image_processing/auto_examples/plot_GMM.html)

#### Pixel Intensity를 두 개의 Gaussian Distribution으로 Fitting 하여 분리합니다

In [None]:
# Fit a two-component Gaussian mixture to the scaled pixel intensities and keep, for every
# pixel, the posterior probability of the component with the lower mean (darker pixels).
from sklearn import mixture

# One feature per pixel: its scaled intensity. An extra all-zero column, as used before,
# carries no information and only yields a degenerate covariance direction.
mixture_buffer = t_imag.reshape(-1, 1)

gm_segment = np.zeros(t_imag.size, dtype=t_imag.dtype)

# The fit depends strongly on the random initial guess, so the result can be averaged over
# `numIter` independently initialised runs. Each run gets its own fixed seed so that
# re-running the notebook reproduces the same segmentation.
numIter = 1
for run_idx in range(numIter):

    # fit !!
    dpgmm = mixture.GaussianMixture(n_components=2, covariance_type='full', max_iter=10,
                                    random_state=run_idx).fit(mixture_buffer)

    # per-pixel component probabilities, shape (n_pixels, 2)
    prob_map = dpgmm.predict_proba(mixture_buffer)

    # normalise every row to sum to 1 and drop negligible probabilities
    prob_map = prob_map / prob_map.sum(axis=1, keepdims=True)
    prob_map[prob_map < 0.001] = 0.0

    # accumulate the component with the lower mean intensity
    dark_component = np.argmin(dpgmm.means_[:, 0])
    gm_segment += prob_map[:, dark_component]

# average over runs and revert to the image shape
gm_segment = (gm_segment / numIter).reshape(t_imag.shape)

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(gm_segment)  # averaged probability of the darker mixture component
ax.axis('off');

### [Morphological Processing](https://scikit-image.org/docs/dev/api/skimage.morphology.html)

In [None]:
# call module
from skimage.morphology import erosion, dilation, remove_small_holes, disk

# `np.bool` was only a deprecated alias of the builtin `bool` and was removed in NumPy 1.24,
# so the builtin is used. Any pixel with a non-zero probability counts as foreground;
# enclosed holes smaller than 128*128 pixels are then filled.
filled_segment = remove_small_holes(gm_segment.astype(bool), area_threshold=128*128)

# additional process
# filled_segment = remove_small_objects(filled_segment.astype(bool))

filled_segment = filled_segment.astype(np.float32)

In [None]:
# Side-by-side: scaled input, GMM probability map, and the hole-filled mask.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
ax = axes.flatten()

for panel_ax, panel_img, title in zip(
        ax,
        (t_imag, gm_segment, filled_segment),
        ('original image', 'gaussian mixture segmentation', 'with morphological process')):
    panel_ax.imshow(panel_img)
    panel_ax.set_axis_off()
    panel_ax.set_title(title)

fig.tight_layout()

### [Active Contour Segmentation](https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_morphsnakes.html#sphx-glr-auto-examples-segmentation-plot-morphsnakes-py)

In [None]:
from skimage.segmentation import (morphological_geodesic_active_contour, 
                                  inverse_gaussian_gradient,
                                  circle_level_set)

In [None]:
# Initial level set: two radius-4 seed circles, presumably one per lung. The centres
# (row 256; columns 144 and 368) look chosen for a 512x512 slice -- TODO confirm.
init_lvs = (
    circle_level_set(filled_segment.shape, center=(256, 144), radius=4)
    + circle_level_set(filled_segment.shape, center=(256, 368), radius=4)
)

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(filled_segment)
ax.contour(init_lvs, [0.5], colors='r')  # outline of the seed circles
ax.axis('off');

In [None]:
# Edge-stopping image for the geodesic active contour, computed from the filled mask.
# Per the scikit-image docs, values are close to 0 near borders and close to 1 in flat areas.
grad_imag = inverse_gaussian_gradient(filled_segment)

In [None]:
def store_evolution_in(lst):
    """Build an iteration callback that records snapshots of the level set.

    Parameters
    ----------
    lst : list
        Destination list; every call appends an independent copy of its argument.

    Returns
    -------
    callable
        A one-argument function suitable for ``iter_callback``; it returns None.
    """
    def _append_snapshot(level_set):
        # copy, so later in-place updates of the level set do not alter stored snapshots
        snapshot = np.copy(level_set)
        lst.append(snapshot)

    return _append_snapshot

In [None]:
# Run 100 iterations of the morphological geodesic active contour on the filled mask,
# recording a snapshot of the level set after every iteration into `evolution`.
evolution = []
callback = store_evolution_in(evolution)
final_lvs = morphological_geodesic_active_contour(
    grad_imag, 100,
    init_level_set=init_lvs,
    smoothing=1,
    balloon=2.0,
    iter_callback=callback,
)

In [None]:
# Overlay the 0.5 iso-contour of the level set at several stages of the evolution.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(filled_segment)

# (level set, contour colour, legend label) for each stage
stages = [
    (init_lvs, plt.cm.tab10(0), "Initial Levelset"),
    (evolution[10], plt.cm.tab10(1), "After 10 Iteration"),
    (evolution[50], plt.cm.tab10(2), "After 50 Iteration"),
    (final_lvs, 'r', "Final Results"),
]
for level_set, color, label in stages:
    ax.contour(level_set, [0.5], colors=[color])
    # `ContourSet.collections` was deprecated in Matplotlib 3.8 and removed in 3.10, so each
    # stage is labelled through an empty proxy line that the legend picks up instead.
    ax.plot([], [], color=color, label=label)

ax.legend(loc="upper right")
ax.set_axis_off()

### What If We Apply the Algorithm to the Original (Scaled) Image?

In [None]:
# Same contour evolution, but driven by edges of the scaled image instead of the mask,
# and with a weaker balloon force.
grad_imag = inverse_gaussian_gradient(t_imag)
evolution = []
callback = store_evolution_in(evolution)
final_lvs = morphological_geodesic_active_contour(
    grad_imag, 100,
    init_level_set=init_lvs,
    smoothing=1,
    balloon=1.0,
    iter_callback=callback,
)

In [None]:
# Overlay the 0.5 iso-contour of the level set at several stages, on the scaled image.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(t_imag)

# (level set, contour colour, legend label) for each stage
stages = [
    (init_lvs, plt.cm.tab10(0), "Initial Levelset"),
    (evolution[10], plt.cm.tab10(1), "After 10 Iteration"),
    (evolution[50], plt.cm.tab10(2), "After 50 Iteration"),
    (final_lvs, 'r', "Final Results"),
]
for level_set, color, label in stages:
    ax.contour(level_set, [0.5], colors=[color])
    # `ContourSet.collections` was deprecated in Matplotlib 3.8 and removed in 3.10, so each
    # stage is labelled through an empty proxy line that the legend picks up instead.
    ax.plot([], [], color=color, label=label)

ax.legend(loc="upper right")
ax.set_axis_off()