diff --git a/build2cmake/src/cuda_supported_archs.json b/build2cmake/src/cuda_supported_archs.json index 851fa7d6..41a56d3f 100644 --- a/build2cmake/src/cuda_supported_archs.json +++ b/build2cmake/src/cuda_supported_archs.json @@ -1,13 +1 @@ -[ - "7.0", - "7.2", - "7.5", - "8.0", - "8.6", - "8.7", - "8.9", - "9.0+PTX", - "10.0", - "10.1", - "12.0+PTX" -] +["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0"] diff --git a/build2cmake/src/templates/cuda/kernel.cmake b/build2cmake/src/templates/cuda/kernel.cmake index 866a4357..8205c959 100644 --- a/build2cmake/src/templates/cuda/kernel.cmake +++ b/build2cmake/src/templates/cuda/kernel.cmake @@ -18,7 +18,7 @@ if(GPU_LANG STREQUAL "CUDA") {% if cuda_capabilities %} cuda_archs_loose_intersection({{kernel_name}}_ARCHS "{{ cuda_capabilities|join(";") }}" "${CUDA_ARCHS}") {% else %} - cuda_archs_loose_intersection({{kernel_name}}_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") + cuda_archs_loose_intersection({{kernel_name}}_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}" "${CUDA_ARCHS}") {% endif %} message(STATUS "Capabilities for kernel {{kernel_name}}: {{ '${' + kernel_name + '_ARCHS}'}}") set_gencode_flags_for_srcs(SRCS {{'"${' + kernel_name + '_SRC}"'}} CUDA_ARCHS "{{ '${' + kernel_name + '_ARCHS}'}}") diff --git a/build2cmake/src/templates/cuda/preamble.cmake b/build2cmake/src/templates/cuda/preamble.cmake index 642dc033..62f547c6 100644 --- a/build2cmake/src/templates/cuda/preamble.cmake +++ b/build2cmake/src/templates/cuda/preamble.cmake @@ -34,6 +34,13 @@ if (NOT TARGET_DEVICE STREQUAL "cuda" AND return() endif() +if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX") +else() + set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX") +endif() + if (NOT HIP_FOUND AND CUDA_FOUND) set(GPU_LANG "CUDA") diff --git a/lib/torch-extension/default.nix b/lib/torch-extension/default.nix index 84c708db..adbde6c8 100644 --- a/lib/torch-extension/default.nix +++ b/lib/torch-extension/default.nix @@ -132,9 +132,9 @@ stdenv.mkDerivation (prevAttrs: { CUDAToolkit_ROOT = "${lib.getDev cudaPackages.cuda_nvcc}"; TORCH_CUDA_ARCH_LIST = if cudaPackages.cudaOlder "12.8" then - "7.0;7.5;8.0;8.6;8.9;9.0+PTX" + "7.0;7.5;8.0;8.6;8.9;9.0" else - "7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0+PTX"; + "7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0"; } // lib.optionalAttrs rocmSupport { PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" torch.rocmArchs;