Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@

PACKAGE_NAME = "flash_dmattn"

# URL template for wheel files attached to the project's GitHub releases.
# `{tag_name}` and `{wheel_name}` are placeholders meant to be filled in with `str.format`.
BASE_WHEEL_URL = (
    "https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
# Also useful when user only wants Triton/Flex backends without CUDA compilation
Expand Down Expand Up @@ -307,6 +311,67 @@ def get_package_version():
return str(public_version)


def get_wheel_url():
    """Guess the URL and file name of a prebuilt wheel matching this environment.

    The file name encodes the package version, the major CUDA version torch was
    built with, the torch ``major.minor`` version, the C++11 ABI flag reported by
    torch, the CPython tag of the running interpreter and the platform name.

    Returns:
        A ``(wheel_url, wheel_filename)`` tuple.
    """
    parsed_torch = parse(torch.__version__)
    py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    plat = get_platform()
    pkg_version = get_package_version()
    torch_tag = f"{parsed_torch.major}.{parsed_torch.minor}"
    abi_flag = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # The wheel is keyed on the CUDA version torch was built with, not on the
    # toolkit installed locally. To save CI time, wheels are only compiled for
    # CUDA 11.8 and CUDA 12.3; minor versions should be compatible, so only the
    # major number ends up in the file name.
    built_with_cuda = parse(torch.version.cuda)
    if built_with_cuda.major == 11:
        pinned_cuda = parse("11.8")
    else:
        pinned_cuda = parse("12.3")
    cuda_tag = f"{pinned_cuda.major}"

    # Wheel file name derived from CUDA version, torch version, python version and OS.
    local_version = f"cu{cuda_tag}torch{torch_tag}cxx11abi{abi_flag}"
    wheel_filename = (
        f"{PACKAGE_NAME}-{pkg_version}+{local_version}-{py_tag}-{py_tag}-{plat}.whl"
    )

    release_tag = f"v{pkg_version}"
    wheel_url = BASE_WHEEL_URL.format(tag_name=release_tag, wheel_name=wheel_filename)

    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip when
    it cannot find an existing wheel. We use the environment parameters (see `get_wheel_url`) to
    guess whether a pre-built, compatible wheel is already available and, if so, short-circuit
    the standard full build pipeline by placing the downloaded wheel where the build output is
    expected. If anything goes wrong while fetching or placing the wheel, the regular build runs.
    """

    # Seconds allowed for each blocking network operation while fetching the wheel,
    # so that a stalled connection cannot hang the installation indefinitely.
    _DOWNLOAD_TIMEOUT = 60

    def run(self):
        """Use a prebuilt wheel when one can be downloaded, otherwise build from source."""
        if FORCE_BUILD:
            return super().run()

        # Function-scope imports: only this command needs these modules.
        import http.client
        import shutil

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            # The local file is only created once the server has answered successfully,
            # so a missing wheel (e.g. HTTP 404) leaves nothing behind.
            with urllib.request.urlopen(wheel_url, timeout=self._DOWNLOAD_TIMEOUT) as response:
                with open(wheel_filename, "wb") as wheel_file:
                    shutil.copyfileobj(response, wheel_file)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            os.makedirs(self.dist_dir, exist_ok=True)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            # shutil.move (unlike os.rename) also works when the download location and
            # dist_dir are on different filesystems.
            shutil.move(wheel_filename, wheel_path)
        except (OSError, http.client.HTTPException) as err:
            # urllib's HTTPError/URLError and socket timeouts are OSError subclasses;
            # HTTPException covers responses that break off mid-transfer.
            if os.path.exists(wheel_filename):
                # Do not leave a partially downloaded file behind.
                os.remove(wheel_filename)
            print(f"Could not use precompiled wheel from {wheel_url}: {type(err).__name__}: {err}")
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


class NinjaBuildExtension(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
# do not override env MAX_JOBS if already exists
Expand All @@ -329,7 +394,9 @@ def __init__(self, *args, **kwargs) -> None:

setup(
ext_modules=ext_modules,
cmdclass={"build_ext": NinjaBuildExtension}
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
if ext_modules
else {},
else {
"bdist_wheel": CachedWheelsCommand,
},
)