diff --git a/pyproject.toml b/pyproject.toml index b8c414b..981a53f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,9 @@ requires = [ "wheel", "packaging", "psutil", - "ninja" + "ninja", + "scikit-build-core>=0.10", + "torch>=2.8.0", ] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 1e0f075..f8af57c 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ def should_skip_cuda_build(): @functools.lru_cache(maxsize=None) def cuda_archs(): - return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;86,89,90;100;120").split(";") + return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;86;89;90;100;120").split(";") def get_platform(): @@ -151,6 +151,7 @@ def append_nvcc_threads(nvcc_extra_args): if "80" in cuda_archs(): cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") + if CUDA_HOME is not None: if bare_metal_version >= Version("11.8") and "86" in cuda_archs(): cc_flag.append("-gencode")