Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conda/core_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- pre-commit
- pyarrow>=5
- pydata-sphinx-theme
- pynvml
- pytest
- python>=3.7
- setuptools>=60
Expand Down
22 changes: 21 additions & 1 deletion install.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,26 @@ def install(
)
dump_json_config(cuda_config, cuda_dir)

arch_config = os.path.join(legate_core_dir, ".arch.json")
if arch is None:
arch = load_json_config(arch_config)
if arch is None:
try:
import pynvml

pynvml.nvmlInit()
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(
pynvml.nvmlDeviceGetHandleByIndex(0)
)
arch = f"{major}{minor}"
pynvml.nvmlShutdown()
except Exception as exc:
raise Exception(
"Could not auto-detect CUDA GPU architecture, please "
"specify the target architecture using --arch"
) from exc
dump_json_config(arch_config, arch)

nccl_config = os.path.join(legate_core_dir, ".nccl.json")
if nccl_dir is None:
nccl_dir = load_json_config(nccl_config)
Expand Down Expand Up @@ -866,7 +886,7 @@ def driver():
dest="arch",
action="store",
required=False,
default="volta",
default=None,
help="Specify the target GPU architecture.",
)
parser.add_argument(
Expand Down