In [9]:
import os

import itk
import vtk
from vtk import vtkCommand
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Shared ITK template instantiations (all 3-D).
IF3 = itk.Image[itk.F, 3]  # float pixels
IUC3 = itk.Image[itk.UC, 3]  # unsigned-char pixels
IRGBUC3 = itk.Image[itk.RGBPixel[itk.UC], 3]  # RGB pixels with unsigned-char components
LMLOUL3 = itk.LabelMap[itk.StatisticsLabelObject[itk.UL, 3]]  # label map of StatisticsLabelObject (unsigned-long labels)

class DicomSeries:
    """Read one DICOM series with ITK and convert it to an 8-bit 3-D volume.

    Pipeline built by ``generatePipeline``:
        itk.ImageSeriesReader[IF3]
          -> RescaleIntensityImageFilter[IF3, IF3]   (intensities to [0, 255])
          -> CastImageFilter[IF3, IUC3]              (float -> unsigned char)
    Conversion to VTK and display are done outside this class.

    Attributes:
        fileNames: the file names passed to the constructor.
        type: ITK image type produced by ``GetOutput()`` (``IUC3``).
        reader, rescaler, caster: the pipeline stages. They are kept on the
            instance so the upstream filters stay referenced (and can be
            inspected) after construction, not just the final caster.
    """

    def __init__(self, fileNames):
        self.fileNames = fileNames
        self.type = IUC3
        self.generatePipeline(fileNames)

    def generatePipeline(self, fileNames):
        """Build and execute the read -> rescale -> cast pipeline.

        Args:
            fileNames: ordered sequence of slice file names for one series,
                e.g. as returned by ``itk.GDCMSeriesFileNames.GetFileNames``.

        Raises:
            ValueError: if ``fileNames`` is empty (fail early with a clear
                message instead of an error from inside the ITK reader).
        """
        if not fileNames:
            raise ValueError("DicomSeries needs at least one file name")

        self.reader = itk.ImageSeriesReader[IF3].New()
        self.reader.SetFileNames(fileNames)

        # Stretch intensities to the 8-bit range while still in float, so the
        # following cast to unsigned char does not clip or wrap values.
        self.rescaler = itk.RescaleIntensityImageFilter[IF3, IF3].New()
        self.rescaler.SetInput(self.reader.GetOutput())
        self.rescaler.SetOutputMinimum(0)
        self.rescaler.SetOutputMaximum(255)

        self.caster = itk.CastImageFilter[IF3, IUC3].New()
        self.caster.SetInput(self.rescaler.GetOutput())
        # Execute the whole pipeline now so GetOutput() is immediately usable.
        self.caster.Update()

    def GetOutput(self):
        """Return the 8-bit volume (an ``IUC3`` image) produced by the caster."""
        return self.caster.GetOutput()

In [10]:
# Folder holding the study files. Set the HEADHUNTER_NOTEBOOKS_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
DATA_DIR = os.environ.get(
    "HEADHUNTER_NOTEBOOKS_DIR",
    "/Users/benjaminhon/Developer/HeadHunter/notebooks",
)
NII_PATH = os.path.join(DATA_DIR, "220259.nii")  # NIfTI file (not used in the cells shown)
DICOM_PATH = os.path.join(DATA_DIR, "220259")  # directory scanned for DICOM series

def generateSeries(path):
    """Group the DICOM files found in a directory by series UID.

    Args:
        path: directory to scan; passed to
            ``itk.GDCMSeriesFileNames.SetDirectory``.

    Returns:
        ``(series, seriesUIDs)``, where ``series`` maps each series UID to the
        file names GDCM reports for it, and ``seriesUIDs`` is the UID sequence
        exactly as returned by ``GetSeriesUIDs()``.
    """
    generator = itk.GDCMSeriesFileNames.New()
    generator.SetDirectory(path)
    # Query the UIDs once and reuse them, so the dict keys and the returned
    # UID sequence come from the same call.
    seriesUIDs = generator.GetSeriesUIDs()
    series = {uid: generator.GetFileNames(uid) for uid in seriesUIDs}
    return series, seriesUIDs

In [13]:
# Position (in GetSeriesUIDs() order) of the series to load from DICOM_PATH.
SERIES_INDEX = 2

series, seriesUIDs = generateSeries(DICOM_PATH)
if len(seriesUIDs) <= SERIES_INDEX:
    # Same exception type as the bare indexing below would raise, but with a
    # message that says which directory and how many series were found.
    raise IndexError(
        "{!r} contains {} series; SERIES_INDEX={} is out of range".format(
            DICOM_PATH, len(seriesUIDs), SERIES_INDEX
        )
    )
model = DicomSeries(series[seriesUIDs[SERIES_INDEX]])

# Convert the ITK volume (of type model.type) to VTK image data for the viewer.
imageToVTKImageFilter = itk.ImageToVTKImageFilter[model.type].New()
imageToVTKImageFilter.SetInput(model.GetOutput())
imageToVTKImageFilter.Update()

In [14]:
class SliceViewInteractorStyle(vtk.vtkInteractorStyleUser):
    """Interactor style that steps through a vtkImageViewer2's slices with the
    mouse wheel: forward shows the next slice, backward the previous one.
    """

    def __init__(self, parent=None, imageViewer=None):
        """
        Args:
            parent: unused; kept for signature compatibility.
            imageViewer: the vtkImageViewer2 to control. If None, wheel
                events are ignored.

        The slice range is read from the viewer once, here; the viewer is
        then moved to the middle slice.
        """
        self.AddObserver(vtkCommand.MouseWheelForwardEvent, self.mouseWheelForwardEvent)
        self.AddObserver(vtkCommand.MouseWheelBackwardEvent, self.mouseWheelBackwardEvent)
        # Always define the attribute so the wheel handlers can test it even
        # when no viewer was given (otherwise they raise AttributeError).
        self.imageViewer = imageViewer
        if imageViewer is not None:
            self.minSlice = imageViewer.GetSliceMin()
            self.maxSlice = imageViewer.GetSliceMax()
            self.slice = round((self.minSlice + self.maxSlice) / 2)
            # Push the tracked slice to the viewer so what is displayed matches
            # self.slice; otherwise the first wheel step jumps from the viewer's
            # own slice to the middle of the volume.
            imageViewer.SetSlice(self.slice)
            print(self.minSlice, self.maxSlice)

    def nextSlice(self):
        """Advance one slice, stopping at maxSlice."""
        if self.imageViewer is not None and self.slice < self.maxSlice:
            self.slice += 1
            self.imageViewer.SetSlice(self.slice)

    def previousSlice(self):
        """Go back one slice, stopping at minSlice."""
        if self.imageViewer is not None and self.slice > self.minSlice:
            self.slice -= 1
            self.imageViewer.SetSlice(self.slice)

    def mouseWheelForwardEvent(self, obj, event):
        """VTK observer callback for MouseWheelForwardEvent."""
        self.nextSlice()

    def mouseWheelBackwardEvent(self, obj, event):
        """VTK observer callback for MouseWheelBackwardEvent."""
        self.previousSlice()

# Show the converted volume in a vtkImageViewer2 window; mouse-wheel events
# are routed to SliceViewInteractorStyle to change the displayed slice.
imageViewer = vtk.vtkImageViewer2()
imageViewer.SetInputData(imageToVTKImageFilter.GetOutput())

renderWindow = imageViewer.GetRenderWindow()
# Use our own interactor and attach the custom slice-scrolling style to it.
interactor = vtk.vtkRenderWindowInteractor()
# The style reads the slice range from the viewer, so the viewer must already
# have its input (set above) when the style is constructed.
sliceInteractorStyle = SliceViewInteractorStyle(imageViewer=imageViewer)
interactor.SetInteractorStyle(sliceInteractorStyle)
interactor.SetRenderWindow(renderWindow)

renderWindow.Render()
interactor.Initialize()
# NOTE(review): Start() enters the VTK event loop; presumably this cell blocks
# until the render window is closed.
interactor.Start()

0 22
