SimpleITK-based Segmentation

In [1]:
import SimpleITK as sitk
import numpy as np
from helpers import *

# Sample volume used by every cell below; the path is relative to the working directory.
sample_mri_path = 'MRNet-v1.0/train/coronal/0002.npy'
mri_array = np.load(file=sample_mri_path)

In [2]:
# Wrap the NumPy volume in a SimpleITK image so the sitk filters below can operate on it.
# NOTE(review): SimpleITK indexes images as (x, y, z) while the NumPy shape is (z, y, x);
# confirm this axis reversal when reading the seed coordinates used in later cells.
mri_image = sitk.GetImageFromArray(mri_array)

In [3]:
# Inspect the dimensions of the loaded NumPy volume (recorded run: (32, 256, 256)).
mri_array.shape

(32, 256, 256)

In [4]:
# Open the slice browser on the raw volume (the rendered widget exposes a SLICE slider).
# explore_3D_array is not defined in this notebook; presumably it is provided by the
# `from helpers import *` star import — confirm.
explore_3D_array(mri_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

In [5]:
# Naive global threshold: label every voxel whose intensity exceeds the cut-off.
bright_cutoff = 200
seg = mri_image > bright_cutoff

# Draw the label mask over the scan, then convert back to NumPy so the
# slice browser in the next cell can display it.
thresh_img = sitk.LabelOverlay(mri_image, seg)
thresh_array = sitk.GetArrayFromImage(thresh_img)

In [6]:
# Browse the overlay produced by the fixed "> 200" threshold in the previous cell.
explore_3D_array(thresh_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

In [7]:
# Band threshold: voxels between lowerThreshold and upperThreshold receive
# insideValue (1); all remaining voxels receive outsideValue (0).
seg = sitk.BinaryThreshold(
    mri_image,
    lowerThreshold=100,
    upperThreshold=400,
    insideValue=1,
    outsideValue=0,
)

# Same display pipeline as before: overlay the mask on the scan, convert to NumPy.
thresh_img = sitk.LabelOverlay(mri_image, seg)
thresh_array = sitk.GetArrayFromImage(thresh_img)

In [8]:
# Browse the overlay produced by the 100-400 BinaryThreshold in the previous cell.
explore_3D_array(thresh_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

In [9]:
# Let the Otsu filter choose the cut-off instead of hard-coding one.
# NOTE(review): inside=0 / outside=1 presumably makes the brighter class the
# labelled one — confirm against the OtsuThresholdImageFilter documentation.
otsu_filter = sitk.OtsuThresholdImageFilter()
otsu_filter.SetOutsideValue(1)
otsu_filter.SetInsideValue(0)
seg = otsu_filter.Execute(mri_image)

# Overlay the resulting mask on the scan and show it in the slice browser.
thresh_img = sitk.LabelOverlay(mri_image, seg)
thresh_array = sitk.GetArrayFromImage(thresh_img)
explore_3D_array(thresh_array)

# Report the threshold the filter settled on during Execute().
otsu_threshold = otsu_filter.GetThreshold()
print(otsu_threshold)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

85.0


In [10]:
# Print the NumPy shape again — presumably as a reminder of the valid index range
# before the seed coordinates are chosen in the following cells.
print(mri_array.shape)

(32, 256, 256)


In [11]:
# Visualise where a candidate seed sits inside the volume.
seed = (128, 128, 18)

# Empty 8-bit label image with the same size and geometry as the scan,
# with only the seed voxel switched on.
seg = sitk.Image(mri_image.GetSize(), sitk.sitkUInt8)
seg.CopyInformation(mri_image)
seg[seed] = 1

# Grow the single voxel so the marker is large enough to spot in the viewer.
marker_radius = (9, 9, 9)
seg = sitk.BinaryDilate(seg, marker_radius)

thresh_img = sitk.LabelOverlay(mri_image, seg)
thresh_array = sitk.GetArrayFromImage(thresh_img)
explore_3D_array(thresh_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

In [12]:
# Region growing from a single seed with ConfidenceConnected.
# `seed` is also read by the fast-marching cell further down, so keep the name.
seed = (128, 150, 18)
seg = sitk.ConfidenceConnected(
    mri_image,
    seedList=[seed],
    numberOfIterations=1,
    multiplier=2.5,
    initialNeighborhoodRadius=1,
    replaceValue=1,
)

# Overlay the grown region on the scan and browse it slice by slice.
thresh_img = sitk.LabelOverlay(mri_image, seg)
thresh_array = sitk.GetArrayFromImage(thresh_img)
explore_3D_array(thresh_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

In [13]:
# Build the speed image for fast marching from the gradient magnitude of the scan.
feature_img = sitk.GradientMagnitudeRecursiveGaussian(mri_image, sigma=0.5)

# BoundedReciprocal is used because, unlike the Sigmoid, it needs no parameters.
speed_img = sitk.BoundedReciprocal(feature_img)

# Despite its name, thresh_array holds the speed image here, not a thresholded result.
thresh_array = sitk.GetArrayFromImage(speed_img)
explore_3D_array(thresh_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

In [14]:
# Fast marching over the speed image from the previous cell.
# `seed` is whatever the preceding cells left behind; in notebook order that is
# the (128, 150, 18) point set in the ConfidenceConnected cell.
fm_filter = sitk.FastMarchingBaseImageFilter()
fm_filter.SetStoppingValue(1000)
fm_filter.SetTrialPoints([seed])
fm_img = fm_filter.Execute(speed_img)

# Keep values in [0, stopping value]; everything else is replaced by
# stopping value + 1.
stopping_value = fm_filter.GetStoppingValue()
thresh_img = sitk.Threshold(
    fm_img,
    lower=0.0,
    upper=stopping_value,
    outsideValue=stopping_value + 1,
)
thresh_array = sitk.GetArrayFromImage(thresh_img)
explore_3D_array(thresh_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…

In [15]:
# Build a small seed region and derive an intensity window from its statistics.
seed = (128, 128, 18)
seg = sitk.Image(mri_image.GetSize(), sitk.sitkUInt8)
seg.CopyInformation(mri_image)
seg[seed] = 1
seg = sitk.BinaryDilate(seg, (3, 3, 3))

# Intensity statistics of the scan restricted to label 1 (the dilated seed).
stats = sitk.LabelStatisticsImageFilter()
stats.Execute(mri_image, seg)
label_mean = stats.GetMean(1)
label_sigma = stats.GetSigma(1)

# Window = mean +/- factor * sigma; consumed by the level-set cell below.
factor = 3.5
lower_threshold = label_mean - factor * label_sigma
upper_threshold = label_mean + factor * label_sigma
print(lower_threshold, upper_threshold)

78.5598429607918 225.5295425140685


In [16]:
# Initial level set: signed distance map of the dilated seed mask (`seg` from the
# previous cell), with positive distances inside the mask.
init_ls = sitk.SignedMaurerDistanceMap(
    seg,
    insideIsPositive=True,
    useImageSpacing=True,
)


In [17]:
# Threshold-driven level-set segmentation, started from init_ls.
lsFilter = sitk.ThresholdSegmentationLevelSetImageFilter()

# Intensity window derived from the seed-region statistics.
lsFilter.SetLowerThreshold(lower_threshold)
lsFilter.SetUpperThreshold(upper_threshold)

# Evolution weights.
lsFilter.SetPropagationScaling(1)
lsFilter.SetCurvatureScaling(0.5)
lsFilter.ReverseExpansionDirectionOn()

# Stopping criteria: RMS change below 0.02 or 1000 iterations.
lsFilter.SetMaximumRMSError(0.02)
lsFilter.SetNumberOfIterations(1000)

# The scan is cast to 32-bit float before being passed as the feature image.
feature_float = sitk.Cast(mri_image, sitk.sitkFloat32)
ls = lsFilter.Execute(init_ls, feature_float)

# Printing the filter dumps its settings plus run info (elapsed iterations, RMS change).
print(lsFilter)

itk::simple::ThresholdSegmentationLevelSetImageFilter
  LowerThreshold: 78.5598
  UpperThreshold: 225.53
  MaximumRMSError: 0.02
  PropagationScaling: 1
  CurvatureScaling: 0.5
  NumberOfIterations: 1000
  ReverseExpansionDirection: 1
  ElapsedIterations: 940
  RMSChange: 0.0198485
  Debug: 0
  NumberOfThreads: 8
  NumberOfWorkUnits: 0
  Commands: (none)
  ProgressMeasurement: 0.94
  ActiveProcess: (none)



In [18]:
# Voxels where the evolved level set is positive form the final segmentation mask.
ls_mask = ls > 0

# Overlay the mask on the scan and browse the result slice by slice.
thresh_img = sitk.LabelOverlay(mri_image, ls_mask)
thresh_array = sitk.GetArrayFromImage(thresh_img)
explore_3D_array(thresh_array)

interactive(children=(IntSlider(value=15, description='SLICE', max=31), Output()), _dom_classes=('widget-inter…