Skip to content

Commit

Permalink
Add abstract module and use SlicerPyTorch
Browse files Browse the repository at this point in the history
Resolves #7.
  • Loading branch information
fepegar committed Aug 31, 2021
1 parent aca9dec commit ec1274a
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 168 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include(${Slicer_USE_FILE})
#-----------------------------------------------------------------------------
# Extension modules
add_subdirectory(TorchIOTransforms)
add_subdirectory(TorchIOModule)
## NEXT_MODULE

#-----------------------------------------------------------------------------
Expand Down
30 changes: 30 additions & 0 deletions TorchIOModule/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#-----------------------------------------------------------------------------
set(MODULE_NAME TorchIOModule)

#-----------------------------------------------------------------------------
set(MODULE_PYTHON_SCRIPTS
${MODULE_NAME}.py
)

set(MODULE_PYTHON_RESOURCES
Resources/Icons/${MODULE_NAME}.png
)

#-----------------------------------------------------------------------------
slicerMacroBuildScriptedModule(
NAME ${MODULE_NAME}
SCRIPTS ${MODULE_PYTHON_SCRIPTS}
RESOURCES ${MODULE_PYTHON_RESOURCES}
WITH_GENERIC_TESTS
)

#-----------------------------------------------------------------------------
if(BUILD_TESTING)

# Register the unittest subclass in the main script as a ctest.
# Note that the test will also be available at runtime.
slicer_add_python_unittest(SCRIPT ${MODULE_NAME}.py)

# Additional build-time testing
add_subdirectory(Testing)
endif()
Binary file added TorchIOModule/Resources/Icons/TorchIOModule.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions TorchIOModule/Testing/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Python)
2 changes: 2 additions & 0 deletions TorchIOModule/Testing/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

#slicer_add_python_unittest(SCRIPT ${MODULE_NAME}ModuleTest.py)
158 changes: 158 additions & 0 deletions TorchIOModule/TorchIOModule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import logging
from pathlib import Path
from contextlib import contextmanager

import numpy as np
import SimpleITK as sitk

import qt, slicer
import sitkUtils as su
from slicer.ScriptedLoadableModule import (
ScriptedLoadableModule,
ScriptedLoadableModuleLogic,
)

import PyTorchUtils


MRML_LABEL = 'vtkMRMLLabelMapVolumeNode'
MRML_SCALAR = 'vtkMRMLScalarVolumeNode'


class TorchIOModule(ScriptedLoadableModule):

def __init__(self, parent):
ScriptedLoadableModule.__init__(self, parent)
self.parent.title = 'TorchIO Abstract Module'
self.parent.categories = []
self.parent.dependencies = []
self.parent.contributors = [
"Fernando Perez-Garcia (University College London and King's College London)"
]
self.parent.helpText = (
'This module can be used to quickly visualize the effect of each'
' transform parameter. That way, users can have an intuitive feeling of'
' what the output of a transform looks like without any coding at all.\n\n'
)
self.parent.helpText += self.getDefaultModuleDocumentationLink()
self.parent.acknowledgementText = (
'This work was was funded by the Engineering and Physical Sciences'
' Research Council (EPSRC) and supported by the UCL Centre for Doctoral'
' Training in Intelligent, Integrated Imaging in Healthcare, the UCL'
' Wellcome / EPSRC Centre for Interventional and Surgical Sciences (WEISS),'
' and the School of Biomedical Engineering & Imaging Sciences (BMEIS)'
" of King's College London."
)

def getDefaultModuleDocumentationLink(self):
docsUrl = 'https://torchio.readthedocs.io/slicer.html'
linkText = f'See <a href="{docsUrl}">the documentation</a> for more information.'
return linkText


class TorchIOModuleLogic(ScriptedLoadableModuleLogic):
def __init__(self):
self._torchio = None
self.torchLogic = PyTorchUtils.PyTorchUtilsLogic()

@property
def torchio(self):
if self._torchio is None:
logging.info('Importing torchio...')
self._torchio = self.importTorchIO()
return self._torchio

def importTorchIO(self):
if not self.torchLogic.torchInstalled():
logging.info('PyTorch module not found')
torch = self.torchLogic.installTorch(askConfirmation=True)
if torch is None:
slicer.util.errorDisplay(
'PyTorch needs to be installed to use the TorchIO extension.'
' Please reload this module to install PyTorch.'
)
return
try:
import torchio
except ModuleNotFoundError:
with self.showWaitCursor(), self.peakPythonConsole():
torchio = self.installTorchIO()
logging.info(f'TorchIO {torchio.__version__} imported correctly')
return torchio

@staticmethod
def installTorchIO(confirm=True):
if confirm:
install = slicer.util.confirmOkCancelDisplay(
'TorchIO will be downloaded and installed now. The process might take some minutes.'
)
if not install:
logging.info('Installation of TorchIO aborted by user')
return None
slicer.util.pip_install('torchio')
import torchio
logging.info(f'TorchIO {torchio.__version__} installed correctly')
return torchio

def getTorchIOImageFromVolumeNode(self, volumeNode):
image = su.PullVolumeFromSlicer(volumeNode)
tio = self.torchio
if volumeNode.IsA('vtkMRMLScalarVolumeNode'):
image = sitk.Cast(image, sitk.sitkFloat32)
class_ = tio.ScalarImage
elif volumeNode.IsA('vtkMRMLLabelMapVolumeNode'):
class_ = tio.LabelMap
tensor, affine = tio.io.sitk_to_nib(image)
return class_(tensor=tensor, affine=affine)

def getVolumeNodeFromTorchIOImage(self, image, outputVolumeNode=None):
tio = self.torchio
kwargs = {}
if outputVolumeNode is None:
kwargs = {'className': MRML_LABEL if isinstance(image, tio.LabelMap) else MRML_SCALAR}
else:
kwargs = {'targetNode': outputVolumeNode}
su.PushVolumeToSlicer(image.as_sitk(), **kwargs)
return outputVolumeNode

def getPythonConsoleWidget(self):
return slicer.util.mainWindow().pythonConsole().parent()

@contextmanager
def peakPythonConsole(self, show=True):
if show:
console = self.getPythonConsoleWidget()
pythonVisible = console.visible
console.setVisible(True)
yield
if show:
console.setVisible(pythonVisible)

@contextmanager
def showWaitCursor(self, show=True):
if show:
qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
yield
if show:
qt.QApplication.restoreOverrideCursor()

def getNodesFromSubject(self, subject):
nodes = {}
for name, image in subject.get_images_dict(intensity_only=False).items():
nodes[name] = self.getVolumeNodeFromTorchIOImage(image, name=name)
return nodes

def getColin(self, version=1998):
colin = self.torchio.datasets.Colin27(version=version)
nodes = self.getNodesFromSubject(colin)
if version == 1998:
slicer.util.setSliceViewerLayers(
background=nodes['t1'],
label=nodes['brain'],
)
elif version == 2008:
slicer.util.setSliceViewerLayers(
background=nodes['t1'],
foreground=nodes['t2'],
label=nodes['cls'],
)

0 comments on commit ec1274a

Please sign in to comment.