/
PyTorchUtils.py
266 lines (220 loc) · 10.1 KB
/
PyTorchUtils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import qt
import logging
import slicer
from slicer.ScriptedLoadableModule import (
ScriptedLoadableModule,
ScriptedLoadableModuleWidget,
ScriptedLoadableModuleLogic,
ScriptedLoadableModuleTest,
)
class PyTorchUtils(ScriptedLoadableModule):
    """Hidden utility module that registers PyTorch tooling inside 3D Slicer."""

    def __init__(self, parent):
        ScriptedLoadableModule.__init__(self, parent)
        # Module metadata shown in Slicer's module finder / about panels.
        self.parent.title = "PyTorch Utils"
        self.parent.categories = ['Utilities']
        self.parent.dependencies = []
        self.parent.contributors = ["Fernando Perez-Garcia (University College London)"]
        self.parent.helpText = 'This hidden module containing some tools to work with PyTorch inside Slicer.'
        acknowledgement = (
            'This work was funded by the Engineering and Physical Sciences'
            ' Research Council (EPSRC) and supported by the UCL Centre for Doctoral'
            ' Training in Intelligent, Integrated Imaging in Healthcare, the UCL'
            ' Wellcome / EPSRC Centre for Interventional and Surgical Sciences (WEISS),'
            ' and the School of Biomedical Engineering & Imaging Sciences (BMEIS)'
            " of King's College London."
        )
        self.parent.acknowledgementText = acknowledgement
class PyTorchUtilsWidget(ScriptedLoadableModuleWidget):
    """GUI for detecting computation backends and installing/uninstalling PyTorch."""

    def setup(self):
        """Build the widget from its .ui file and connect button callbacks."""
        super().setup()
        self.logic = PyTorchUtilsLogic()
        # Load widget from .ui file (created by Qt Designer).
        # Additional widgets can be instantiated manually and added to self.layout.
        uiWidget = slicer.util.loadUI(self.resourcePath('UI/PyTorchUtils.ui'))
        self.layout.addWidget(uiWidget)
        self.ui = slicer.util.childWidgetVariables(uiWidget)
        self.ui.detectPushButton.clicked.connect(self.onDetect)
        self.ui.installPushButton.clicked.connect(self.onInstallTorch)
        self.ui.uninstallPushButton.clicked.connect(self.onUninstallTorch)
        self.ui.restartPushButton.clicked.connect(self.onApplicationRestart)
        self.updateVersionInformation()

    def onDetect(self):
        """Query compatible computation backends and repopulate the combo box."""
        with slicer.util.tryWithErrorDisplay("Failed to detect compatible computation backends.", waitCursor=True):
            backends = self.logic.getCompatibleComputationBackends()
            # Remember the current choice so repopulating the list does not lose it.
            currentBackend = self.ui.backendComboBox.currentText
            self.ui.backendComboBox.clear()
            self.ui.backendComboBox.addItem("automatic")
            for backend in backends:
                self.ui.backendComboBox.addItem(backend)
            self.ui.backendComboBox.currentText = currentBackend
            self.ui.backendComboBox.showPopup()
            self.updateVersionInformation()

    def onInstallTorch(self):
        """Install PyTorch using the backend selected in the combo box (or report it as already installed)."""
        with slicer.util.tryWithErrorDisplay("Failed to install PyTorch. Some PyTorch files may be in use or corrupted. Please restart the application, uninstall PyTorch, and try installing again.", waitCursor=True):
            if self.logic.torchInstalled():
                torch = self.logic.torch
                slicer.util.delayDisplay(f'PyTorch {torch.__version__} is already installed, using {self.logic.getDevice()}.', autoCloseMsec=2000)
            else:
                backend = self.ui.backendComboBox.currentText
                automaticBackend = (backend == "automatic")
                # Only ask for confirmation when the backend is chosen automatically;
                # an explicit user choice is treated as confirmation already.
                askConfirmation = automaticBackend
                torch = self.logic.installTorch(askConfirmation, None if automaticBackend else backend)
                if torch is not None:
                    slicer.util.delayDisplay(f'PyTorch {torch.__version__} installed successfully using {self.logic.getDevice()}.', autoCloseMsec=2000)
            self.updateVersionInformation()

    def onUninstallTorch(self):
        """Uninstall PyTorch via pip and refresh the version labels."""
        with slicer.util.tryWithErrorDisplay("Failed to uninstall PyTorch. Probably PyTorch is already in use. Please restart the application and try again.", waitCursor=True):
            self.logic.uninstallTorch()
            # fix: removed extraneous f-prefix from a string without placeholders
            slicer.util.delayDisplay('PyTorch uninstalled successfully.', autoCloseMsec=2000)
            self.updateVersionInformation()

    def updateVersionInformation(self):
        """Refresh the PyTorch and NVIDIA driver version labels, degrading gracefully on errors."""
        try:
            self.ui.torchVersionInformation.text = self.logic.torchVersionInformation
        except Exception as e:
            logging.error(str(e))
            self.ui.torchVersionInformation.text = "unknown (corrupted installation?)"
        try:
            info = self.logic.nvidiaDriverVersionInformation
            self.ui.nvidiaVersionInformation.text = info if info else "unknown"
        except Exception as e:
            logging.error(str(e))
            self.ui.nvidiaVersionInformation.text = "unknown"

    def onApplicationRestart(self):
        """Restart the Slicer application (needed to fully unload torch)."""
        slicer.util.restart()
class PyTorchUtilsLogic(ScriptedLoadableModuleLogic):
    """Install, import and query PyTorch from within Slicer.

    Uses the ``light-the-torch`` package to pick a PyTorch build compatible
    with the installed NVIDIA drivers, falling back to a CPU build.
    """

    def __init__(self):
        # fix: the base-class initializer was previously skipped
        super().__init__()
        # Cached ``torch`` module; imported (and installed if needed) lazily.
        self._torch = None

    @property
    def nvidiaDriverVersionInformation(self):
        """Get NVIDIA driver version information as a string that can be displayed to the user.

        If light-the-torch is not installed yet then empty string is returned.
        """
        try:
            import light_the_torch._cb as computationBackend
            return f"installed version {str(computationBackend._detect_nvidia_driver_version())}"
        except Exception:
            # Don't install light-the-torch just for getting the NVIDIA driver version
            return ""

    @property
    def torchVersionInformation(self):
        """Get PyTorch version information as a string that can be displayed to the user."""
        if not self.torchInstalled():
            return "not installed"
        import torch
        return f"installed version {torch.__version__}"

    @property
    def torch(self):
        """``torch`` Python module. It will be installed if necessary."""
        if self._torch is None:
            logging.info('Importing torch...')
            self._torch = self.importTorch()
        return self._torch

    @staticmethod
    def torchInstalled():
        """Return True if a torch package with metadata is present and importable.

        Checks package metadata first so that a corrupted install can be detected
        without importing torch (an attempted import could load some files, which
        could prevent uninstalling a corrupted pytorch install).
        """
        import importlib.metadata
        try:
            # files() may return None for packages without a RECORD file.
            distFiles = importlib.metadata.files('torch') or []
        except importlib.metadata.PackageNotFoundError:
            return False
        # fix: previously indexed [0] into the filtered list, raising an uncaught
        # IndexError when no METADATA file was found; treat that as not installed.
        if not any('METADATA' in str(p) for p in distFiles):
            return False
        try:
            import torch  # noqa: F401
        except ModuleNotFoundError:
            return False
        return True

    def importTorch(self):
        """Import the ``torch`` Python module, installing it if necessary.

        Returns the module, or None if installation was aborted.
        """
        if self.torchInstalled():
            import torch
        else:
            torch = self.installTorch()
        if torch is None:
            logging.warning('PyTorch was not installed')
        else:
            logging.info(f'PyTorch {torch.__version__} imported successfully')
            logging.info(f'CUDA available: {torch.cuda.is_available()}')
        return torch

    def installTorch(self, askConfirmation=False, forceComputationBackend=None):
        """Install PyTorch and return the ``torch`` Python module.

        :param askConfirmation: if True (and not running tests), ask the user to
            confirm before downloading; returns None if the user cancels.
        :param forceComputationBackend: optional parameter to set computation backend (cpu, cu116, cu117, ...)

        If computation backend is not specified then the ``light-the-torch`` Python package is used to get the most recent version of
        PyTorch compatible with the installed NVIDIA drivers. If CUDA-compatible device is not found, a version compiled for CPU will be installed.
        """
        args = PyTorchUtilsLogic._getPipInstallArguments(forceComputationBackend)
        if askConfirmation and not slicer.app.commandOptions().testingEnabled:
            install = slicer.util.confirmOkCancelDisplay(
                f'PyTorch will be downloaded and installed using light-the-torch (ltt {" ".join(args)}).'
                ' The process might take some minutes.'
            )
            if not install:
                logging.info('Installation of PyTorch aborted by user')
                return None
        try:
            import light_the_torch._patch  # noqa: F401
        except ImportError:  # fix: was a bare except
            PyTorchUtilsLogic._installLightTheTorch()
            import light_the_torch._patch  # noqa: F401
        slicer.util._executePythonModule('light_the_torch', args)
        import torch
        logging.info(f'PyTorch {torch.__version__} installed successfully.')
        return torch

    def uninstallTorch(self, askConfirmation=False, forceComputationBackend=None):
        """Uninstall PyTorch.

        The parameters are accepted for API compatibility and are currently unused.
        """
        slicer.util.pip_uninstall('torch')
        # fix: removed extraneous f-prefix from a string without placeholders
        logging.info('PyTorch uninstalled successfully.')

    @staticmethod
    def _getPipInstallArguments(forceComputationBackend=None):
        """Build the argument list passed to light-the-torch's pip front end."""
        args = ["install", "torch"]
        if forceComputationBackend is not None:
            args.append(f"--pytorch-computation-backend={forceComputationBackend}")
        return args

    @staticmethod
    def _installLightTheTorch():
        """Install the light-the-torch helper package into Slicer's Python."""
        slicer.util.pip_install('light-the-torch>=0.5')

    @staticmethod
    def getCompatibleComputationBackends(forceComputationBackend=None):
        """Get the list of computation backends compatible with the available hardware.

        :param forceComputationBackend: optional parameter to set computation backend (cpu, cu116, cu117, ...)

        If computation backend is not specified then the ``light-the-torch`` is used to get the most recent version of
        PyTorch compatible with the installed NVIDIA drivers.
        """
        try:
            import light_the_torch._patch
        except ImportError:  # fix: was a bare except
            PyTorchUtilsLogic._installLightTheTorch()
            import light_the_torch._patch
        args = PyTorchUtilsLogic._getPipInstallArguments(forceComputationBackend)
        try:
            backends = sorted(light_the_torch._patch.LttOptions.from_pip_argv(args).computation_backends)
        except Exception as e:
            logging.warning(str(e))
            # fix: chain the original cause instead of discarding it
            raise ValueError(f"Failed to get computation backend. Requested computation backend: `{forceComputationBackend}`.") from e
        return backends

    def getPyTorchHubModel(self, repoOwner, repoName, modelName, addPretrainedKwarg=True, *args, **kwargs):
        """Use PyTorch Hub to download a PyTorch model, typically pre-trained.

        More information can be found at https://pytorch.org/hub/.
        """
        repo = f'{repoOwner}/{repoName}'
        if addPretrainedKwarg:
            kwargs['pretrained'] = True
        model = self.torch.hub.load(repo, modelName, *args, **kwargs)
        return model

    def getDevice(self):
        """Get CUDA device if available and CPU otherwise.

        NOTE(review): returns a torch.device for CUDA but the plain string 'cpu'
        otherwise; callers (e.g. the ``cuda`` property) rely on the string form.
        """
        return self.torch.device('cuda') if self.torch.cuda.is_available() else 'cpu'

    @property
    def cuda(self):
        """Return True if a CUDA-compatible device is available."""
        return self.getDevice() != 'cpu'
class PyTorchUtilsTest(ScriptedLoadableModuleTest):
    """Self-test for the PyTorchUtils module."""

    def runTest(self):
        """Entry point invoked by Slicer's test runner."""
        self.test_PyTorchUtils()

    def _delayDisplay(self, message):
        # Skip the blocking message box when running in automated-testing mode.
        if slicer.app.testingEnabled():
            return
        self.delayDisplay(message)

    def test_PyTorchUtils(self):
        """Import torch through the logic class and report CUDA availability."""
        self._delayDisplay('Starting the test')
        logic = PyTorchUtilsLogic()
        self._delayDisplay(f'CUDA available: {logic.torch.cuda.is_available()}')
        self._delayDisplay('Test passed!')