diff --git a/numba/cuda/cudadrv/nvvm.py b/numba/cuda/cudadrv/nvvm.py index 24569f99b43..505e6797f9d 100644 --- a/numba/cuda/cudadrv/nvvm.py +++ b/numba/cuda/cudadrv/nvvm.py @@ -272,29 +272,39 @@ def get_log(self): default_data_layout = data_layout[tuple.__itemsize__ * 8] +_supported_cc = None -try: - from numba.cuda.cudadrv.runtime import runtime - cudart_version_major = runtime.get_version()[0] -except: - # The CUDA Runtime may not be present - cudart_version_major = 0 - -# List of supported compute capability in sorted order -if cudart_version_major == 0: - SUPPORTED_CC = (), -elif cudart_version_major < 9: - # CUDA 8.x - SUPPORTED_CC = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2) -elif cudart_version_major < 10: - # CUDA 9.x - SUPPORTED_CC = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0) -elif cudart_version_major < 11: - # CUDA 10.x - SUPPORTED_CC = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5) -else: - # CUDA 11.0 and later - SUPPORTED_CC = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0) + +def get_supported_ccs(): + global _supported_cc + + if _supported_cc: + return _supported_cc + + try: + from numba.cuda.cudadrv.runtime import runtime + cudart_version_major = runtime.get_version()[0] + except: + # The CUDA Runtime may not be present + cudart_version_major = 0 + + # List of supported compute capability in sorted order + if cudart_version_major == 0: + _supported_cc = (), + elif cudart_version_major < 9: + # CUDA 8.x + _supported_cc = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2) + elif cudart_version_major < 10: + # CUDA 9.x + _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0) + elif cudart_version_major < 11: + # CUDA 10.x + _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5) + else: + # CUDA 11.0 and later + _supported_cc = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0) + + return _supported_cc def find_closest_arch(mycc): @@ -305,7 +315,9 @@ def find_closest_arch(mycc): :param mycc: Compute capability as a tuple ``(MAJOR, MINOR)`` :return: Closest supported CC as a tuple ``(MAJOR, MINOR)`` """ - for i, cc in enumerate(SUPPORTED_CC): + supported_cc = get_supported_ccs() + + for i, cc in enumerate(supported_cc): if cc == mycc: # Matches return cc @@ -317,10 +329,10 @@ def find_closest_arch(mycc): "not supported (requires >=%d.%d)" % (mycc + cc)) else: # return the previous CC - return SUPPORTED_CC[i - 1] + return supported_cc[i - 1] # CC higher than supported - return SUPPORTED_CC[-1] # Choose the highest + return supported_cc[-1] # Choose the highest def get_arch_option(major, minor): diff --git a/numba/cuda/tests/cudadrv/test_nvvm_driver.py b/numba/cuda/tests/cudadrv/test_nvvm_driver.py index 0e40a906d68..ec6d575b534 100644 --- a/numba/cuda/tests/cudadrv/test_nvvm_driver.py +++ b/numba/cuda/tests/cudadrv/test_nvvm_driver.py @@ -1,7 +1,7 @@ from llvmlite.llvmpy.core import Module, Type, Builder from numba.cuda.cudadrv.nvvm import (NVVM, CompilationUnit, llvm_to_ptx, set_cuda_kernel, fix_data_layout, - get_arch_option, SUPPORTED_CC) + get_arch_option, get_supported_ccs) from ctypes import c_size_t, c_uint64, sizeof from numba.cuda.testing import unittest from numba.cuda.cudadrv.nvvm import LibDevice, NvvmError @@ -54,7 +54,7 @@ def _test_nvvm_support(self, arch): def test_nvvm_support(self): """Test supported CC by NVVM """ - for arch in SUPPORTED_CC: + for arch in get_supported_ccs(): self._test_nvvm_support(arch=arch) @unittest.skipIf(True, "No new CC unknown to NVVM yet") @@ -80,10 +80,11 @@ def test_get_arch_option(self): self.assertEqual(get_arch_option(5, 1), 'compute_50') self.assertEqual(get_arch_option(3, 7), 'compute_35') # Test known arch. - for arch in SUPPORTED_CC: + supported_cc = get_supported_ccs() + for arch in supported_cc: self.assertEqual(get_arch_option(*arch), 'compute_%d%d' % arch) self.assertEqual(get_arch_option(1000, 0), - 'compute_%d%d' % SUPPORTED_CC[-1]) + 'compute_%d%d' % supported_cc[-1]) @skip_on_cudasim('NVVM Driver unsupported in the simulator') diff --git a/numba/cuda/tests/cudadrv/test_runtime.py b/numba/cuda/tests/cudadrv/test_runtime.py index 9973020f050..83e49d16a66 100644 --- a/numba/cuda/tests/cudadrv/test_runtime.py +++ b/numba/cuda/tests/cudadrv/test_runtime.py @@ -1,6 +1,20 @@ +import multiprocessing +import os from numba.core import config from numba.cuda.cudadrv.runtime import runtime -from numba.cuda.testing import unittest +from numba.cuda.testing import unittest, SerialMixin + + +def set_visible_devices_and_check(q): + try: + from numba import cuda + import os + + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + q.put(len(cuda.gpus.lst)) + except: # noqa: E722 + # Sentinel value for error executing test code + q.put(-1) class TestRuntime(unittest.TestCase): @@ -13,5 +27,43 @@ def test_get_version(self): self.assertIn(runtime.get_version(), supported_versions) +class TestVisibleDevices(unittest.TestCase, SerialMixin): + def test_visible_devices_set_after_import(self): + # See Issue #6149. This test checks that we can set + # CUDA_VISIBLE_DEVICES after importing Numba and have the value + # reflected in the available list of GPUs. Prior to the fix for this + # issue, Numba made a call to runtime.get_version() on import that + # initialized the driver and froze the list of available devices before + # CUDA_VISIBLE_DEVICES could be set by the user. + + # Avoid importing cuda at the top level so that + # set_visible_devices_and_check gets to import it first in its process + from numba import cuda + + if len(cuda.gpus.lst) in (0, 1): + self.skipTest('This test requires multiple GPUs') + + if os.environ.get('CUDA_VISIBLE_DEVICES'): + msg = 'Cannot test when CUDA_VISIBLE_DEVICES already set' + self.skipTest(msg) + + ctx = multiprocessing.get_context('spawn') + q = ctx.Queue() + p = ctx.Process(target=set_visible_devices_and_check, args=(q,)) + p.start() + try: + visible_gpu_count = q.get() + finally: + p.join() + + # Make an obvious distinction between an error running the test code + # and an incorrect number of GPUs in the list + msg = 'Error running set_visible_devices_and_check' + self.assertNotEqual(visible_gpu_count, -1, msg=msg) + + # The actual check that we see only one GPU + self.assertEqual(visible_gpu_count, 1) + + if __name__ == '__main__': unittest.main()