Refactors setup.py for production-ready package build #32
Merged
Changes from all commits
setup.py:

```diff
@@ -1,96 +1,297 @@
 # Copyright (c) 2025, Jingze Shi.

+import sys
+import functools
+import warnings
 import os
+import re
+import ast
+import glob
+import shutil
+from pathlib import Path
+from packaging.version import parse, Version
 import platform

 from setuptools import setup, find_packages
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension

-# Get the CUDA home directory
-CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
-if not os.path.exists(CUDA_HOME):
-    # Try the standard location
-    if os.path.exists('/usr/local/cuda'):
-        CUDA_HOME = '/usr/local/cuda'
-    elif platform.system() == 'Windows':
-        # Try the default locations on Windows
-        for cuda_version in range(12, 9, -1):  # try CUDA 12 down to 10
-            cuda_path = f"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v{cuda_version}.0"
-            if os.path.exists(cuda_path):
-                CUDA_HOME = cuda_path
-                break
-
-# Get the current directory
-current_dir = os.path.dirname(os.path.abspath(__file__))
-
-# Define all include paths
-include_dirs = [
-    os.path.join(CUDA_HOME, 'include'),
-    os.path.join(current_dir, 'csrc'),                  # project source directory
-    os.path.join(current_dir, 'csrc/cutlass/include'),  # CUTLASS headers
-    os.path.join(current_dir, 'csrc/cub/cub'),          # CUB headers
-    os.path.join(current_dir, 'csrc/src'),              # project source subdirectory
-    # os.path.join(current_dir, 'fcsrc'),                # Flash Attention source directory
-    # os.path.join(current_dir, 'fcsrc/src'),            # Flash Attention source subdirectory
-]
-
-# Compile flags, with noisy warnings suppressed
-extra_compile_args = {
-    'cxx': ['-O3'],
-    'nvcc': [
-        '-O3',
-        '-gencode=arch=compute_60,code=sm_60',
-        '-gencode=arch=compute_70,code=sm_70',
-        '-gencode=arch=compute_75,code=sm_75',
-        '-gencode=arch=compute_80,code=sm_80',
-        '-gencode=arch=compute_86,code=sm_86',
-        '-gencode=arch=compute_86,code=compute_86',
-        '--use_fast_math',
-        '--expt-relaxed-constexpr',  # allow more constructs in constexpr contexts
-        '--extended-lambda',         # enable extended lambda support
-        '-U__CUDA_NO_HALF_OPERATORS__',
-        '-U__CUDA_NO_HALF_CONVERSIONS__',
-        '-U__CUDA_NO_BFLOAT16_OPERATORS__',
-        '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
-        '-U__CUDA_NO_BFLOAT162_OPERATORS__',
-        '-U__CUDA_NO_BFLOAT162_CONVERSIONS__',
-        # Suppress specific warnings
-        '-Xcudafe', '--diag_suppress=177',
-        '-Xcudafe', '--diag_suppress=550',
-    ]
-}
-
-# Source file list
-sources = [
-    # 'csrc/apply_dynamic_mask_api.cpp',
-    # 'csrc/apply_dynamic_mask_kernel.cu',
-    'csrc/apply_dynamic_mask_attention_api.cpp',
-    'csrc/apply_dynamic_mask_attention_kernel.cu',
-    # 'fcsrc/apply_attention_api.cpp',
-    # 'fcsrc/apply_attention_kernel.cu',
-]
-
-# Create the extension
-ext_modules = [
-    CUDAExtension(
-        name='flash_dma_cpp',
-        sources=sources,
-        include_dirs=include_dirs,
-        extra_compile_args=extra_compile_args,
+import subprocess
+
+import urllib.request
+import urllib.error
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+import torch
+from torch.utils.cpp_extension import (
+    BuildExtension,
+    CppExtension,
+    CUDAExtension,
+    CUDA_HOME,
+)
+
+
+with open("README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
+
+
+# ninja build does not work unless include_dirs are abs path
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+PACKAGE_NAME = "flash_dma"
+
+# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+FORCE_BUILD = os.getenv("FLASH_DMA_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("FLASH_DMA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
+FORCE_CXX11_ABI = os.getenv("FLASH_DMA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+
+
+@functools.lru_cache(maxsize=None)
+def cuda_archs():
+    # return os.getenv("FLASH_DMA_CUDA_ARCHS", "80;90;100;120").split(";")
+    return os.getenv("FLASH_DMA_CUDA_ARCHS", "80").split(";")
+
+
+def get_platform():
+    """
+    Returns the platform name as used in wheel filenames.
+    """
+    if sys.platform.startswith("linux"):
+        return f'linux_{platform.uname().machine}'
+    elif sys.platform == "darwin":
+        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
+        return f"macosx_{mac_version}_x86_64"
+    elif sys.platform == "win32":
+        return "win_amd64"
+    else:
+        raise ValueError("Unsupported platform: {}".format(sys.platform))
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return raw_output, bare_metal_version
+
+
+def check_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+    # in that case.
+    warnings.warn(
+        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
     )
```
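The refactor stops guessing CUDA install paths and instead asks the toolkit directly via `nvcc -V`. As a quick illustration of what `get_cuda_bare_metal_version` extracts, here is the same parsing applied to a canned banner (the sample text below is illustrative, not output captured in this PR):

```python
from packaging.version import parse

# Canned `nvcc -V` banner (illustrative). The parser only cares about the
# token that follows the word "release".
raw_output = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 12.4, V12.4.131\n"
)
output = raw_output.split()
release_idx = output.index("release") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])
print(bare_metal_version)  # -> 12.4
```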
```diff
-]

-# Set up the package
+def append_nvcc_threads(nvcc_extra_args):
+    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+    return nvcc_extra_args + ["--threads", nvcc_threads]
+
+
+cmdclass = {}
+ext_modules = []
+
+# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
+# files included in the source distribution, in case the user compiles from source.
+if os.path.isdir(".git"):
+    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True)
+else:
+    assert (
+        os.path.exists("csrc/cutlass/include/cutlass/cutlass.h")
+    ), "csrc/cutlass is missing, please use source distribution or git clone"
+
+if not SKIP_CUDA_BUILD:
+    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+    check_if_cuda_home_none("flash_dma")
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    if CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version < Version("11.7"):
+            raise RuntimeError(
+                "Flash Dynamic Mask Attention is only supported on CUDA 11.7 and above. "
+                "Note: make sure nvcc has a supported version by running nvcc -V."
+            )
+
+    if "80" in cuda_archs():
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_80,code=sm_80")
+    if CUDA_HOME is not None:
+        if bare_metal_version >= Version("11.8") and "90" in cuda_archs():
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
+        if bare_metal_version >= Version("12.8") and "100" in cuda_archs():
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_100,code=sm_100")
+        if bare_metal_version >= Version("12.8") and "120" in cuda_archs():
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_120,code=sm_120")
```
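Each architecture is compiled only when both the `FLASH_DMA_CUDA_ARCHS` list requests it and the installed toolkit is new enough for it. A minimal standalone rehearsal of that gating, under illustrative assumptions (CUDA 12.4 and `FLASH_DMA_CUDA_ARCHS="80;90"`):

```python
from packaging.version import Version

archs = "80;90".split(";")            # as if FLASH_DMA_CUDA_ARCHS="80;90"
bare_metal_version = Version("12.4")  # as if nvcc reported "release 12.4"

cc_flag = []
if "80" in archs:
    cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]
if bare_metal_version >= Version("11.8") and "90" in archs:
    cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]
if bare_metal_version >= Version("12.8") and "100" in archs:
    cc_flag += ["-gencode", "arch=compute_100,code=sm_100"]

print(cc_flag)
# ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']
```

sm_100 and sm_120 stay disabled in this scenario because they additionally require CUDA 12.8 or newer.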
```diff
+    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
+    # torch._C._GLIBCXX_USE_CXX11_ABI
+    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
+    if FORCE_CXX11_ABI:
+        torch._C._GLIBCXX_USE_CXX11_ABI = True
+
+    ext_modules.append(
+        CUDAExtension(
+            name="flash_dma_cuda",
+            sources=[
+                "csrc/flash_api.cpp",
+                # Forward kernels - regular
+                "csrc/src/flash_fwd_hdim32_fp16_sm80.cu",
+                "csrc/src/flash_fwd_hdim32_bf16_sm80.cu",
+                "csrc/src/flash_fwd_hdim64_fp16_sm80.cu",
+                "csrc/src/flash_fwd_hdim64_bf16_sm80.cu",
+                "csrc/src/flash_fwd_hdim96_fp16_sm80.cu",
+                "csrc/src/flash_fwd_hdim96_bf16_sm80.cu",
+                "csrc/src/flash_fwd_hdim128_fp16_sm80.cu",
+                "csrc/src/flash_fwd_hdim128_bf16_sm80.cu",
+                "csrc/src/flash_fwd_hdim192_fp16_sm80.cu",
+                "csrc/src/flash_fwd_hdim192_bf16_sm80.cu",
+                "csrc/src/flash_fwd_hdim256_fp16_sm80.cu",
+                "csrc/src/flash_fwd_hdim256_bf16_sm80.cu",
+                # Forward kernels - causal
+                "csrc/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
+                # Forward kernels - split
+                "csrc/src/flash_fwd_split_hdim32_fp16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim32_bf16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim64_fp16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim64_bf16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim96_fp16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim96_bf16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim128_fp16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim128_bf16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim192_fp16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim192_bf16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim256_fp16_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+                # Forward kernels - split causal
+                "csrc/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+                "csrc/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-std=c++17"],
+                "nvcc": append_nvcc_threads(
+                    [
+                        "-O3",
+                        "-std=c++17",
+                        "-U__CUDA_NO_HALF_OPERATORS__",
+                        "-U__CUDA_NO_HALF_CONVERSIONS__",
+                        "-U__CUDA_NO_HALF2_OPERATORS__",
+                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                        "--expt-relaxed-constexpr",
+                        "--expt-extended-lambda",
+                        "--use_fast_math",
+                        # "--ptxas-options=-v",
+                        # "--ptxas-options=-O2",
+                        # "-lineinfo",
+                        "-DFLASHATTENTION_DISABLE_BACKWARD",  # Only forward pass
+                        # "-DFLASHATTENTION_DISABLE_DROPOUT",
+                        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
+                        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
+                    ]
+                    + cc_flag
+                ),
+            },
+            include_dirs=[
+                Path(this_dir) / "csrc",
+                Path(this_dir) / "csrc" / "src",
+                Path(this_dir) / "csrc" / "cutlass" / "include",
+            ],
+        )
+    )
+
+
+def get_package_version():
+    return "0.1.0"
+
+
+class NinjaBuildExtension(BuildExtension):
+    def __init__(self, *args, **kwargs) -> None:
+        # do not override env MAX_JOBS if already exists
+        if not os.environ.get("MAX_JOBS"):
+            import psutil
+
+            # calculate the maximum allowed NUM_JOBS based on cores
+            max_num_jobs_cores = max(1, (os.cpu_count() or 1) // 2)
+
+            # calculate the maximum allowed NUM_JOBS based on free memory
+            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB
+            max_num_jobs_memory = int(free_memory_gb / 9)  # each JOB peak memory cost is ~8-9GB when threads = 4
+
+            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
+            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
+            os.environ["MAX_JOBS"] = str(max_jobs)
+
+        super().__init__(*args, **kwargs)
```
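`NinjaBuildExtension` keeps ninja from over-parallelizing: it caps `MAX_JOBS` by both core count and free memory, since each compile job peaks around 8-9 GB with four nvcc threads. A rough standalone sketch of the same heuristic, assuming `psutil` is installed:

```python
import os
import psutil

# Cap by cores: use at most half the visible CPUs.
max_num_jobs_cores = max(1, (os.cpu_count() or 1) // 2)

# Cap by memory: budget ~9 GB of free RAM per compile job.
free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)
max_num_jobs_memory = int(free_memory_gb / 9)

# Take the lower of the two caps to avoid OOM and swap thrash.
max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
print(f"MAX_JOBS={max_jobs} (cores cap: {max_num_jobs_cores}, memory cap: {max_num_jobs_memory})")
```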
```diff
 setup(
-    name='flash_dma',
-    version='0.1',
-    description='Dynamic Mask Attention and Standard Attention for PyTorch',
-    author='AI Assistant',
-    author_email='example@example.com',
-    packages=find_packages(),
+    name=PACKAGE_NAME,
+    version=get_package_version(),
+    packages=find_packages(
+        exclude=(
+            "build",
+            "csrc",
+            "include",
+            "tests",
+            "dist",
+            "docs",
+            "benchmarks",
+            "flash_dma.egg-info",
+        )
+    ),
+    author="Jingze Shi",
+    author_email="losercheems@gmail.com",
+    description="Flash Dynamic Mask Attention: Fast and Memory-Efficient Trainable Dynamic Mask Sparse Attention",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/SmallDoge/flash-dmattn",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: BSD License",
+        "Operating System :: Unix",
+    ],
     ext_modules=ext_modules,
-    cmdclass={
-        'build_ext': BuildExtension
-    },
+    cmdclass={"build_ext": NinjaBuildExtension}
+    if ext_modules
+    else {},
+    python_requires=">=3.9",
     install_requires=[
-        'torch>=1.10.0',
+        "torch",
+        "einops",
     ],
+    setup_requires=[
+        "packaging",
+        "psutil",
+        "ninja",
+    ],
-    python_requires='>=3.7',
 )
```
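With these changes the build is driven entirely by environment variables. A minimal sketch of one way to exercise the new knobs from the repository root (the chosen values are illustrative, and `--no-build-isolation` reflects the fact that this setup.py imports torch at module import time, so torch must already be installed in the invoking environment):

```python
import os
import subprocess
import sys

env = dict(
    os.environ,
    FLASH_DMA_CUDA_ARCHS="80;90",  # build SM 8.0 and SM 9.0 kernels only
    NVCC_THREADS="4",              # threads handed to each nvcc invocation
    # MAX_JOBS="8",                # uncomment to override the auto heuristic
)
subprocess.run(
    [sys.executable, "-m", "pip", "install", ".", "--no-build-isolation"],
    env=env,
    check=True,
)
```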
Review comment: [nitpick] There are several imports (e.g., re, ast, glob, shutil) that do not appear to be used; consider removing any unused imports to improve code clarity.