Skip to content

Commit

Permalink
Merge pull request #8012 from kmaehashi/fix-cutensor-wheel-discovery
Browse files Browse the repository at this point in the history
Avoid using `pkg_resources` for cuTENSOR wheel discovery
  • Loading branch information
emcastillo committed Dec 4, 2023
2 parents 27cafa0 + 09b42a5 commit 83d0917
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions cupy/_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import json
import os
import os.path
import re
import shutil
import sys
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import warnings


Expand Down Expand Up @@ -379,30 +380,46 @@ def _preload_library(lib):
_log(f'Library {lib} could not be preloaded: {e}')


def _parse_version(version: str) -> Tuple[int, int, int]:
parts = re.split(r'[^\d]', version, maxsplit=3)
major = int(parts[0])
minor = int(parts[1]) if len(parts) >= 2 else 0
patch = int(parts[2]) if len(parts) >= 3 else 0
return major, minor, patch


def _get_cutensor_from_wheel(version: str, cuda: str) -> List[str]:
"""
Returns the list of shared library path candidates for cuTENSOR
installed via Pip (cutensor-cuXX package).
"""
cuda_major_ver, _ = cuda.split('.')
import pkg_resources # defer import # NOQA
cutensor_pkg = f'cutensor-cu{cuda_major_ver}'
try:
# load any compatible version (ex: version=1.6.2, load >=1.6.2,<1.7)
dist = pkg_resources.get_distribution(
f'cutensor-cu{cuda_major_ver}~={version}')
except pkg_resources.ResolutionError as e:
_log(f'cuTENSOR wheel could not be loaded: {type(e).__name__}: {e}')
cutensor_dist = importlib.metadata.distribution(cutensor_pkg)
except importlib.metadata.PackageNotFoundError:
_log(f'cuTENSOR wheel package not installed: {cutensor_pkg}')
return []

actual = _parse_version(cutensor_dist.version)
expected = _parse_version(version)
is_compatible = (
actual[0] == expected[0] and
actual[1] >= expected[1] and
actual[2] >= expected[2]
)
if not is_compatible:
_log('cuTENSOR wheel incompatible: '
f'expected {version}, found {cutensor_dist.version}')
return []

if sys.platform == 'linux':
shared_lib = os.path.join(
dist.module_path, 'cutensor', 'lib',
f'libcutensor.so.{version.split(".")[0]}'
shared_lib = cutensor_dist.locate_file(
f'cutensor/lib/libcutensor.so.{version.split(".")[0]}'
)
else:
shared_lib = os.path.join(
dist.module_path, 'cutensor', 'bin', 'cutensor.dll'
)
return [shared_lib]
shared_lib = cutensor_dist.locate_file('cutensor\\bin\\cutensor.dll')
return [str(shared_lib)]


def _preload_warning(lib, exc):
Expand Down

0 comments on commit 83d0917

Please sign in to comment.