From 0ebcdb9d9b556a392fa38c90f67d5df6e0d67834 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 19:54:45 +0800 Subject: [PATCH 1/4] Adds arch input parameter to build workflow Enables targeting specific compute capabilities during the build process by adding an optional arch parameter to the workflow inputs. This provides more granular control over the build configuration while maintaining backward compatibility with the existing default behavior. --- .github/workflows/build.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 25ea5e8..af3d00b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,6 +33,10 @@ on: description: "Upload wheel to this release" required: false type: string + arch: + description: "Target a single compute capability. Leave empty to use project default" + required: false + type: string jobs: build-wheels: @@ -45,3 +49,4 @@ jobs: cxx11_abi: ${{ inputs.cxx11_abi }} upload-to-release: ${{ inputs.upload-to-release }} release-version: ${{ inputs.release-version }} + arch: ${{ inputs.arch }} From 9609d4febd6ada3f72b7afc1b424c32383662e82 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 19:55:50 +0800 Subject: [PATCH 2/4] Adds GPU architecture matrix to publish workflow Enables building packages for multiple GPU architectures (80, 86, 89, 90, 100, 120) by adding an architecture matrix parameter to the publish workflow. Expands compatibility across different NVIDIA GPU generations and ensures optimal performance for each target architecture. 
--- .github/workflows/publish.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 000d723..38bdce0 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -52,6 +52,7 @@ jobs: # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) # when building without C++11 ABI and using it on nvcr images. cxx11_abi: ["FALSE", "TRUE"] + arch: ["80", "86", "89", "90", "100", "120"] include: - torch-version: "2.9.0.dev20250904" cuda-version: "13.0" @@ -70,6 +71,7 @@ jobs: cuda-version: ${{ matrix.cuda-version }} torch-version: ${{ matrix.torch-version }} cxx11_abi: ${{ matrix.cxx11_abi }} + arch: ${{ matrix.arch }} release-version: ${{ needs.setup_release.outputs.release-version }} upload-to-release: true From 206598aae8a15175714e4c15fa61434812690d3f Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 19:59:31 +0800 Subject: [PATCH 3/4] Adds configurable arch input for targeted builds Enables building for a single compute capability to reduce build time when targeting specific GPU architectures. Updates wheel naming convention to include arch identifier when specified, ensuring proper artifact identification for architecture-specific builds. --- .github/workflows/_build.yml | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index bb13eb6..67d0cb7 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -32,6 +32,10 @@ on: description: "Upload wheel to this release" required: false type: string + arch: + description: "Target a single compute capability. Leave empty to build default archs" + required: false + type: string defaults: run: @@ -59,6 +63,7 @@ jobs: echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." 
$2'})" >> $GITHUB_ENV echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_ARCH=${{ inputs.arch }}" >> $GITHUB_ENV - name: Free up disk space if: ${{ runner.os == 'Linux' }} @@ -170,12 +175,21 @@ jobs: export FLASH_DMATTN_FORCE_BUILD="TRUE" export FLASH_DMATTN_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + # If specified, limit to a single compute capability to speed up build + if [ -n "${MATRIX_ARCH}" ]; then + export FLASH_DMATTN_CUDA_ARCHS="${MATRIX_ARCH}" + fi + # 5h timeout since GH allows max 6h and we want some buffer EXIT_CODE=0 timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? if [ $EXIT_CODE -eq 0 ]; then - tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + if [ -n "${MATRIX_ARCH}" ]; then + tmpname=sm${MATRIX_ARCH}cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + else + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + fi wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} echo "wheel_name=${wheel_name}" >> $GITHUB_ENV From 8a8a456fba624a8f2924a485d941e6818370f242 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 20:01:39 +0800 Subject: [PATCH 4/4] Adds SM architecture detection for wheel naming Introduces automatic detection of the preferred SM (Streaming Multiprocessor) architecture from the current CUDA device to improve wheel filename specificity. The detection function safely handles cases where CUDA is unavailable or detection fails by returning None. This enhancement allows for more precise wheel identification based on the actual hardware capabilities rather than relying solely on CUDA version information. 
Removes unused imports to clean up the codebase. --- setup.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index cd325d8..8c847ac 100644 --- a/setup.py +++ b/setup.py @@ -7,12 +7,12 @@ import re import ast import glob -import shutil from pathlib import Path from packaging.version import parse, Version import platform +from typing import Optional -from setuptools import setup, find_packages +from setuptools import setup import subprocess import urllib.request @@ -22,7 +22,6 @@ import torch from torch.utils.cpp_extension import ( BuildExtension, - CppExtension, CUDAExtension, CUDA_HOME, ) @@ -83,6 +82,20 @@ def cuda_archs(): return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;86;89;90;100;120").split(";") +def detect_preferred_sm_arch() -> Optional[str]: + """Detect the preferred SM arch from the current CUDA device. + Returns None if CUDA is unavailable or detection fails. + """ + try: + if torch.cuda.is_available(): + idx = torch.cuda.current_device() + major, minor = torch.cuda.get_device_capability(idx) + return f"{major}{minor}" + except Exception: + pass + return None + + def get_platform(): """ Returns the platform name as used in wheel filenames. 
@@ -237,6 +250,7 @@ def get_package_version():
 
 
 def get_wheel_url():
+    sm_arch = detect_preferred_sm_arch()
     torch_version_raw = parse(torch.__version__)
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
@@ -255,7 +269,7 @@ def get_wheel_url():
     cuda_version = f"{torch_cuda_version.major}"
 
     # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+{f'sm{sm_arch}' if sm_arch else ''}cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
     wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)