From 628e1208ae1516f1a2a2a2df37ea4900cead4c7b Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 27 Jun 2025 11:54:51 +0800 Subject: [PATCH] Refactors setup.py for production-ready package build Transforms the basic CUDA extension setup into a comprehensive package configuration with proper metadata, dependencies, and build optimization. Adds copyright header, comprehensive package metadata including author information, description, and PyPI classifiers for better discoverability. Implements dynamic CUDA architecture detection, version checking, and proper error handling for unsupported CUDA versions (requires 11.7+). Introduces NinjaBuildExtension with intelligent job allocation based on available CPU cores and memory to prevent OOM during compilation. Expands source file coverage to include comprehensive flash attention kernels for multiple head dimensions, data types, and attention variants (regular, causal, split). Adds build environment controls through environment variables for forced builds, CUDA skipping, and ABI compatibility. --- setup.py | 373 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 287 insertions(+), 86 deletions(-) diff --git a/setup.py b/setup.py index 6264f50..2c2c19e 100644 --- a/setup.py +++ b/setup.py @@ -1,96 +1,297 @@ +# Copyright (c) 2025, Jingze Shi. + +import sys +import functools +import warnings import os +import re +import ast +import glob +import shutil +from pathlib import Path +from packaging.version import parse, Version import platform + from setuptools import setup, find_packages -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -# 获取CUDA主目录 -CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda') -if not os.path.exists(CUDA_HOME): - # 尝试标准位置 - if os.path.exists('/usr/local/cuda'): - CUDA_HOME = '/usr/local/cuda' - elif platform.system() == 'Windows': - # Windows上尝试默认位置 - for cuda_version in range(12, 9, -1): # 尝试CUDA 12至10 - cuda_path = f"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v{cuda_version}.0" - if os.path.exists(cuda_path): - CUDA_HOME = cuda_path - break - -# 获取当前目录 -current_dir = os.path.dirname(os.path.abspath(__file__)) - -# 定义所有包含路径 -include_dirs = [ - os.path.join(CUDA_HOME, 'include'), - os.path.join(current_dir, 'csrc'), # 项目源目录 - os.path.join(current_dir, 'csrc/cutlass/include'), # CUTLASS头文件 - os.path.join(current_dir, 'csrc/cub/cub'), # CUB头文件 - os.path.join(current_dir, 'csrc/src'), # 项目源代码子目录 - # os.path.join(current_dir, 'fcsrc'), # Flash attention 源目录 - # os.path.join(current_dir, 'fcsrc/src'), # Flash attention 源代码子目录 -] - -# 禁用警告的编译标志 -extra_compile_args = { - 'cxx': ['-O3'], - 'nvcc': [ - '-O3', - '-gencode=arch=compute_60,code=sm_60', - '-gencode=arch=compute_70,code=sm_70', - '-gencode=arch=compute_75,code=sm_75', - '-gencode=arch=compute_80,code=sm_80', - '-gencode=arch=compute_86,code=sm_86', - '-gencode=arch=compute_86,code=compute_86', - '--use_fast_math', - '--expt-relaxed-constexpr', # 允许在constexpr中使用更多功能 - '--extended-lambda', # 支持更高级的lambda功能 - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_BFLOAT16_OPERATORS__', - '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', - '-U__CUDA_NO_BFLOAT162_OPERATORS__', - '-U__CUDA_NO_BFLOAT162_CONVERSIONS__', - # 抑制特定警告 - '-Xcudafe', '--diag_suppress=177', - '-Xcudafe', '--diag_suppress=550', - ] -} - -# 源文件列表 -sources = [ - # 'csrc/apply_dynamic_mask_api.cpp', - # 'csrc/apply_dynamic_mask_kernel.cu', - 'csrc/apply_dynamic_mask_attention_api.cpp', - 'csrc/apply_dynamic_mask_attention_kernel.cu', - # 'fcsrc/apply_attention_api.cpp', - # 'fcsrc/apply_attention_kernel.cu', -] - -# 创建扩展 -ext_modules = [ - CUDAExtension( - name='flash_dma_cpp', - sources=sources, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = "flash_dma" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_DMA_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_DMA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FLASH_DMA_FORCE_CXX11_ABI", "FALSE") == "TRUE" + +@functools.lru_cache(maxsize=None) +def cuda_archs(): + # return os.getenv("FLASH_DMA_CUDA_ARCHS", "80;90;100;120").split(";") + return os.getenv("FLASH_DMA_CUDA_ARCHS", "80").split(";") + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return f'linux_{platform.uname().machine}' + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." ) -] -# 设置包 + +def append_nvcc_threads(nvcc_extra_args): + nvcc_threads = os.getenv("NVCC_THREADS") or "4" + return nvcc_extra_args + ["--threads", nvcc_threads] + + +cmdclass = {} +ext_modules = [] + +# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp +# files included in the source distribution, in case the user compiles from source. +if os.path.isdir(".git"): + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True) +else: + assert ( + os.path.exists("csrc/cutlass/include/cutlass/cutlass.h") + ), "csrc/cutlass is missing, please use source distribution or git clone" + +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + check_if_cuda_home_none("flash_dma") + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.7"): + raise RuntimeError( + "Flash Dynamic Mask Attention is only supported on CUDA 11.7 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." + ) + + if "80" in cuda_archs(): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + if CUDA_HOME is not None: + if bare_metal_version >= Version("11.8") and "90" in cuda_archs(): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + if bare_metal_version >= Version("12.8") and "100" in cuda_archs(): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_100,code=sm_100") + if bare_metal_version >= Version("12.8") and "120" in cuda_archs(): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_120,code=sm_120") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + ext_modules.append( + CUDAExtension( + name="flash_dma_cuda", + sources=[ + "csrc/flash_api.cpp", + # Forward kernels - regular + "csrc/src/flash_fwd_hdim32_fp16_sm80.cu", + "csrc/src/flash_fwd_hdim32_bf16_sm80.cu", + "csrc/src/flash_fwd_hdim64_fp16_sm80.cu", + "csrc/src/flash_fwd_hdim64_bf16_sm80.cu", + "csrc/src/flash_fwd_hdim96_fp16_sm80.cu", + "csrc/src/flash_fwd_hdim96_bf16_sm80.cu", + "csrc/src/flash_fwd_hdim128_fp16_sm80.cu", + "csrc/src/flash_fwd_hdim128_bf16_sm80.cu", + "csrc/src/flash_fwd_hdim192_fp16_sm80.cu", + "csrc/src/flash_fwd_hdim192_bf16_sm80.cu", + "csrc/src/flash_fwd_hdim256_fp16_sm80.cu", + "csrc/src/flash_fwd_hdim256_bf16_sm80.cu", + # Forward kernels - causal + "csrc/src/flash_fwd_hdim32_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim32_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim64_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim64_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim96_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim96_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim128_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim128_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim192_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim192_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim256_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_hdim256_bf16_causal_sm80.cu", + # Forward kernels - split + "csrc/src/flash_fwd_split_hdim32_fp16_sm80.cu", + "csrc/src/flash_fwd_split_hdim32_bf16_sm80.cu", + "csrc/src/flash_fwd_split_hdim64_fp16_sm80.cu", + "csrc/src/flash_fwd_split_hdim64_bf16_sm80.cu", + "csrc/src/flash_fwd_split_hdim96_fp16_sm80.cu", + "csrc/src/flash_fwd_split_hdim96_bf16_sm80.cu", + "csrc/src/flash_fwd_split_hdim128_fp16_sm80.cu", + "csrc/src/flash_fwd_split_hdim128_bf16_sm80.cu", + "csrc/src/flash_fwd_split_hdim192_fp16_sm80.cu", + "csrc/src/flash_fwd_split_hdim192_bf16_sm80.cu", + "csrc/src/flash_fwd_split_hdim256_fp16_sm80.cu", + "csrc/src/flash_fwd_split_hdim256_bf16_sm80.cu", + # Forward kernels - split causal + "csrc/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", + "csrc/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + # "--ptxas-options=-v", + # "--ptxas-options=-O2", + # "-lineinfo", + "-DFLASHATTENTION_DISABLE_BACKWARD", # Only forward pass + # "-DFLASHATTENTION_DISABLE_DROPOUT", + # "-DFLASHATTENTION_DISABLE_SOFTCAP", + # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + ] + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "src", + Path(this_dir) / "csrc" / "cutlass" / "include", + ], + ) + ) + + +def get_package_version(): + return "0.1.0" + + +class NinjaBuildExtension(BuildExtension): + def __init__(self, *args, **kwargs) -> None: + # do not override env MAX_JOBS if already exists + if not os.environ.get("MAX_JOBS"): + import psutil + + # calculate the maximum allowed NUM_JOBS based on cores + max_num_jobs_cores = max(1, (os.cpu_count() or 1) // 2) + + # calculate the maximum allowed NUM_JOBS based on free memory + free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB + max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4 + + # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation + max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) + os.environ["MAX_JOBS"] = str(max_jobs) + + super().__init__(*args, **kwargs) + + setup( - name='flash_dma', - version='0.1', - description='Dynamic Mask Attention and Standard Attention for PyTorch', - author='AI Assistant', - author_email='example@example.com', - packages=find_packages(), + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + "flash_dma.egg-info", + ) + ), + author="Jingze Shi", + author_email="losercheems@gmail.com", + description="Flash Dynamic Mask Attention: Fast and Memory-Efficient Trainable Dynamic Mask Sparse Attention", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/SmallDoge/flash-dmattn", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], ext_modules=ext_modules, - cmdclass={ - 'build_ext': BuildExtension - }, + cmdclass={"build_ext": NinjaBuildExtension} + if ext_modules + else {}, + python_requires=">=3.9", install_requires=[ - 'torch>=1.10.0', + "torch", + "einops", + ], + setup_requires=[ + "packaging", + "psutil", + "ninja", ], - python_requires='>=3.7', ) \ No newline at end of file