373 changes: 287 additions & 86 deletions setup.py
@@ -1,96 +1,297 @@
# Copyright (c) 2025, Jingze Shi.

import sys
import functools
import warnings
import os
import re
import ast
import glob
import shutil
Comment on lines +7 to +10

Copilot AI, Jun 27, 2025:

[nitpick] There are several imports (e.g., re, ast, glob, shutil) that do not appear to be used; consider removing any unused imports to improve code clarity.

Suggested change
import re
import ast
import glob
import shutil

from pathlib import Path
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Get the CUDA home directory
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
if not os.path.exists(CUDA_HOME):
# Try the standard location
if os.path.exists('/usr/local/cuda'):
CUDA_HOME = '/usr/local/cuda'
elif platform.system() == 'Windows':
# On Windows, try the default install locations
for cuda_version in range(12, 9, -1): # try CUDA 12 down to 10
cuda_path = f"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v{cuda_version}.0"
if os.path.exists(cuda_path):
CUDA_HOME = cuda_path
break

# Get the current directory
current_dir = os.path.dirname(os.path.abspath(__file__))

# Define all include paths
include_dirs = [
os.path.join(CUDA_HOME, 'include'),
os.path.join(current_dir, 'csrc'), # project source directory
os.path.join(current_dir, 'csrc/cutlass/include'), # CUTLASS headers
os.path.join(current_dir, 'csrc/cub/cub'), # CUB headers
os.path.join(current_dir, 'csrc/src'), # project source subdirectory
# os.path.join(current_dir, 'fcsrc'), # Flash attention source directory
# os.path.join(current_dir, 'fcsrc/src'), # Flash attention source subdirectory
]

# Compile flags to suppress warnings
extra_compile_args = {
'cxx': ['-O3'],
'nvcc': [
'-O3',
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_75,code=sm_75',
'-gencode=arch=compute_80,code=sm_80',
'-gencode=arch=compute_86,code=sm_86',
'-gencode=arch=compute_86,code=compute_86',
'--use_fast_math',
'--expt-relaxed-constexpr', # allow more features inside constexpr
'--extended-lambda', # enable extended lambda support
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_BFLOAT16_OPERATORS__',
'-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
'-U__CUDA_NO_BFLOAT162_OPERATORS__',
'-U__CUDA_NO_BFLOAT162_CONVERSIONS__',
# Suppress specific warnings
'-Xcudafe', '--diag_suppress=177',
'-Xcudafe', '--diag_suppress=550',
]
}

# List of source files
sources = [
# 'csrc/apply_dynamic_mask_api.cpp',
# 'csrc/apply_dynamic_mask_kernel.cu',
'csrc/apply_dynamic_mask_attention_api.cpp',
'csrc/apply_dynamic_mask_attention_kernel.cu',
# 'fcsrc/apply_attention_api.cpp',
# 'fcsrc/apply_attention_kernel.cu',
]

# Create the extension
ext_modules = [
CUDAExtension(
name='flash_dma_cpp',
sources=sources,
include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
CUDAExtension,
CUDA_HOME,
)


with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()


# ninja build does not work unless include_dirs are absolute paths
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "flash_dma"

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_DMA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_DMA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_DMA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
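
For illustration only (editorial sketch, not part of this diff): the three switches above are read with os.getenv and compared against the string "TRUE", so they can be set in the shell or from a Python driver script before the build runs.

# Editorial sketch, not part of the PR: force a fresh local compile while
# keeping the CUDA extension build enabled (variable names taken from above).
import os
os.environ["FLASH_DMA_FORCE_BUILD"] = "TRUE"        # skip any prebuilt-wheel path
os.environ["FLASH_DMA_SKIP_CUDA_BUILD"] = "FALSE"   # still compile the CUDA sources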

@functools.lru_cache(maxsize=None)
def cuda_archs():
# return os.getenv("FLASH_DMA_CUDA_ARCHS", "80;90;100;120").split(";")
return os.getenv("FLASH_DMA_CUDA_ARCHS", "80").split(";")
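
A small usage sketch (editorial, values assumed): cuda_archs() splits FLASH_DMA_CUDA_ARCHS on semicolons and defaults to "80", so requesting additional architectures is just:

# Editorial sketch: the env var must be set before the first (lru_cached) call.
import os
os.environ["FLASH_DMA_CUDA_ARCHS"] = "80;90"
# cuda_archs() now returns ["80", "90"]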


def get_platform():
"""
Returns the platform name as used in wheel filenames.
"""
if sys.platform.startswith("linux"):
return f'linux_{platform.uname().machine}'
elif sys.platform == "darwin":
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
return f"macosx_{mac_version}_x86_64"
elif sys.platform == "win32":
return "win_amd64"
else:
raise ValueError("Unsupported platform: {}".format(sys.platform))
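
For reference (editorial note, example values assumed), the tag returned here is the platform part of the wheel filename:

# Editorial sketch of expected get_platform() results:
#   Linux on x86_64 -> "linux_x86_64"
#   Windows         -> "win_amd64"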


def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])

return raw_output, bare_metal_version
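
To make the tokenization above concrete (editorial sketch; the sample nvcc output line is an assumption):

# Editorial sketch: extracting the release number from assumed `nvcc -V` output.
from packaging.version import parse
raw = "Cuda compilation tools, release 12.1, V12.1.105"
tokens = raw.split()
bare_metal_version = parse(tokens[tokens.index("release") + 1].split(",")[0])
# bare_metal_version == Version("12.1")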


def check_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
# warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
# in that case.
warnings.warn(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
]

# Set up the package

def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
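
A quick illustration of the helper above (editorial sketch, assuming NVCC_THREADS is unset):

# Editorial sketch: the default of 4 nvcc compile threads is appended.
args = append_nvcc_threads(["-O3", "-std=c++17"])
# args == ["-O3", "-std=c++17", "--threads", "4"]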


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if os.path.isdir(".git"):
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True)
else:
assert (
os.path.exists("csrc/cutlass/include/cutlass/cutlass.h")
), "csrc/cutlass is missing, please use source distribution or git clone"

if not SKIP_CUDA_BUILD:
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

check_if_cuda_home_none("flash_dma")
# Check if CUDA 11 is installed for compute capability 8.0
cc_flag = []
if CUDA_HOME is not None:
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.7"):
raise RuntimeError(
"Flash Dynamic Mask Attention is only supported on CUDA 11.7 and above. "
"Note: make sure nvcc has a supported version by running nvcc -V."
)

if "80" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is not None:
if bare_metal_version >= Version("11.8") and "90" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
if bare_metal_version >= Version("12.8") and "100" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_100,code=sm_100")
if bare_metal_version >= Version("12.8") and "120" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_120,code=sm_120")

# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
torch._C._GLIBCXX_USE_CXX11_ABI = True

ext_modules.append(
CUDAExtension(
name="flash_dma_cuda",
sources=[
"csrc/flash_api.cpp",
# Forward kernels - regular
"csrc/src/flash_fwd_hdim32_fp16_sm80.cu",
"csrc/src/flash_fwd_hdim32_bf16_sm80.cu",
"csrc/src/flash_fwd_hdim64_fp16_sm80.cu",
"csrc/src/flash_fwd_hdim64_bf16_sm80.cu",
"csrc/src/flash_fwd_hdim96_fp16_sm80.cu",
"csrc/src/flash_fwd_hdim96_bf16_sm80.cu",
"csrc/src/flash_fwd_hdim128_fp16_sm80.cu",
"csrc/src/flash_fwd_hdim128_bf16_sm80.cu",
"csrc/src/flash_fwd_hdim192_fp16_sm80.cu",
"csrc/src/flash_fwd_hdim192_bf16_sm80.cu",
"csrc/src/flash_fwd_hdim256_fp16_sm80.cu",
"csrc/src/flash_fwd_hdim256_bf16_sm80.cu",
# Forward kernels - causal
"csrc/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
# Forward kernels - split
"csrc/src/flash_fwd_split_hdim32_fp16_sm80.cu",
"csrc/src/flash_fwd_split_hdim32_bf16_sm80.cu",
"csrc/src/flash_fwd_split_hdim64_fp16_sm80.cu",
"csrc/src/flash_fwd_split_hdim64_bf16_sm80.cu",
"csrc/src/flash_fwd_split_hdim96_fp16_sm80.cu",
"csrc/src/flash_fwd_split_hdim96_bf16_sm80.cu",
"csrc/src/flash_fwd_split_hdim128_fp16_sm80.cu",
"csrc/src/flash_fwd_split_hdim128_bf16_sm80.cu",
"csrc/src/flash_fwd_split_hdim192_fp16_sm80.cu",
"csrc/src/flash_fwd_split_hdim192_bf16_sm80.cu",
"csrc/src/flash_fwd_split_hdim256_fp16_sm80.cu",
"csrc/src/flash_fwd_split_hdim256_bf16_sm80.cu",
# Forward kernels - split causal
"csrc/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
"csrc/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"],
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
# "--ptxas-options=-v",
# "--ptxas-options=-O2",
# "-lineinfo",
"-DFLASHATTENTION_DISABLE_BACKWARD", # Only forward pass
# "-DFLASHATTENTION_DISABLE_DROPOUT",
# "-DFLASHATTENTION_DISABLE_SOFTCAP",
# "-DFLASHATTENTION_DISABLE_UNEVEN_K",
]
+ cc_flag
),
},
include_dirs=[
Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "src",
Path(this_dir) / "csrc" / "cutlass" / "include",
],
)
)


def get_package_version():
return "0.1.0"


class NinjaBuildExtension(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
# do not override MAX_JOBS if it is already set in the environment
if not os.environ.get("MAX_JOBS"):
import psutil

# calculate the maximum allowed NUM_JOBS based on cores
max_num_jobs_cores = max(1, (os.cpu_count() or 1) // 2)

# calculate the maximum allowed NUM_JOBS based on free memory
free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB
max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4

# pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
os.environ["MAX_JOBS"] = str(max_jobs)

super().__init__(*args, **kwargs)
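
As a worked example of the job-count heuristic above (machine figures assumed):

# Editorial sketch: 16 CPU cores and 36 GB of free memory.
cpu_count = 16
free_memory_gb = 36.0
max_num_jobs_cores = max(1, cpu_count // 2)                       # 8
max_num_jobs_memory = int(free_memory_gb / 9)                     # 4
max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))   # MAX_JOBS = "4"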


setup(
name='flash_dma',
version='0.1',
description='Dynamic Mask Attention and Standard Attention for PyTorch',
author='AI Assistant',
author_email='example@example.com',
packages=find_packages(),
name=PACKAGE_NAME,
version=get_package_version(),
packages=find_packages(
exclude=(
"build",
"csrc",
"include",
"tests",
"dist",
"docs",
"benchmarks",
"flash_dma.egg-info",
)
),
author="Jingze Shi",
author_email="losercheems@gmail.com",
description="Flash Dynamic Mask Attention: Fast and Memory-Efficient Trainable Dynamic Mask Sparse Attention",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/SmallDoge/flash-dmattn",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
],
ext_modules=ext_modules,
cmdclass={
'build_ext': BuildExtension
},
cmdclass={"build_ext": NinjaBuildExtension}
if ext_modules
else {},
python_requires=">=3.9",
install_requires=[
'torch>=1.10.0',
"torch",
"einops",
],
setup_requires=[
"packaging",
"psutil",
"ninja",
],
python_requires='>=3.7',
)