Skip to content

Commit

Permalink
Merge pull request #6831 from chainer-ci/bp-6819-v10-infer-version-by…
Browse files Browse the repository at this point in the history
…-nvrtc

[backport] cupy-wheel: Use NVRTC to infer the toolkit version
  • Loading branch information
kmaehashi committed Jun 29, 2022
2 parents 4a24ef4 + 38c4619 commit bf54a5b
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 13 deletions.
77 changes: 65 additions & 12 deletions install/universal_pkg/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def _log(msg: str) -> None:


def _get_version_from_library(
libnames: List[str], funcname: str) -> Optional[int]:
libnames: List[str],
funcname: str,
nvrtc: bool = False,
) -> Optional[int]:
"""Returns the library version from list of candidate libraries."""

for libname in libnames:
Expand All @@ -70,30 +73,69 @@ def _get_version_from_library(
if func is None:
raise AutoDetectionFailed(
f'{libname}: {func} could not be found')
func.restype = ctypes.c_int # cudaError_t
func.argtypes = [ctypes.POINTER(ctypes.c_int)]
version_ptr = ctypes.c_int()
retval = func(version_ptr)
if retval != 0:
func.restype = ctypes.c_int

if nvrtc:
# nvrtcVersion
func.argtypes = [
ctypes.POINTER(ctypes.c_int),
ctypes.POINTER(ctypes.c_int),
]
major = ctypes.c_int()
minor = ctypes.c_int()
retval = func(major, minor)
version = major.value * 1000 + minor.value * 10
else:
# cudaRuntimeGetVersion
func.argtypes = [
ctypes.POINTER(ctypes.c_int),
]
version_ref = ctypes.c_int()
retval = func(version_ref)
version = version_ref.value

if retval != 0: # NVRTC_SUCCESS or cudaSuccess
raise AutoDetectionFailed(
f'{libname}: {func} returned error: {retval}')
version = version_ptr.value
_log(f'Detected version: {version}')
return version


def _setup_win32_dll_directory() -> None:
if not hasattr(os, 'add_dll_directory'):
# Python 3.7 or earlier.
return
cuda_path = os.environ.get('CUDA_PATH', None)
if cuda_path is None:
_log('CUDA_PATH is not set.'
'cupy-wheel may not be able to discover NVRTC to probe version')
return
os.add_dll_directory(os.path.join(cuda_path, 'bin')) # type: ignore[attr-defined] # NOQA


def _get_cuda_version() -> Optional[int]:
"""Returns the detected CUDA version or None."""

if sys.platform == 'linux':
libnames = ['libcudart.so']
libnames = [
'libnvrtc.so.11.2',
'libnvrtc.so.11.1',
'libnvrtc.so.11.0',
'libnvrtc.so.10.2',
]
elif sys.platform == 'win32':
libnames = ['cudart64_110.dll', 'cudart64_102.dll']
libnames = [
'nvrtc64_112_0.dll',
'nvrtc64_111_0.dll',
'nvrtc64_110_0.dll',
'nvrtc64_102_0.dll',
]
_setup_win32_dll_directory()
else:
_log(f'CUDA detection unsupported on platform: {sys.platform}')
return None
_log(f'Trying to detect CUDA version from libraries: {libnames}')
version = _get_version_from_library(libnames, 'cudaRuntimeGetVersion')
version = _get_version_from_library(libnames, 'nvrtcVersion', True)
return version


Expand Down Expand Up @@ -159,10 +201,21 @@ def _cuda_version_to_package(ver: int) -> str:


def _rocm_version_to_package(ver: int) -> str:
if 400 <= ver < 410:
"""
ROCm 4.0.x = 3212
ROCm 4.1.x = 3241
ROCm 4.2.0 = 3275
ROCm 4.3.0 = 40321300
ROCm 4.3.1 = 40321331
ROCm 4.5.0 = 40421401
ROCm 4.5.1 = 40421432
ROCm 5.0.0 = 50013601
ROCm 5.1.0 = 50120531
"""
if ver == 3212:
# ROCm 4.0
suffix = '4-0'
elif 420 <= ver < 430:
elif ver == 3275:
# ROCm 4.2
suffix = '4-2'
elif 4_03_00000 <= ver < 4_04_00000:
Expand Down
2 changes: 1 addition & 1 deletion tests/install_tests/test_universal_pkg/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_cuda_version_to_package():
def test_rocm_version_to_package():
with pytest.raises(setup.AutoDetectionFailed):
assert setup._rocm_version_to_package(399)
assert setup._rocm_version_to_package(400) == 'cupy-rocm-4-0'
assert setup._rocm_version_to_package(3212) == 'cupy-rocm-4-0'
assert setup._rocm_version_to_package(5_00_13601) == 'cupy-rocm-5-0'
with pytest.raises(setup.AutoDetectionFailed):
assert setup._rocm_version_to_package(9_00_00000)
Expand Down

0 comments on commit bf54a5b

Please sign in to comment.