-
Notifications
You must be signed in to change notification settings - Fork 41
Add wheel URL generation and caching for pre-built wheels in setup.py #156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -37,6 +37,10 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| PACKAGE_NAME = "flash_dmattn" | ||||||||||||||||||
|
|
||||||||||||||||||
| BASE_WHEEL_URL = ( | ||||||||||||||||||
| "https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels | ||||||||||||||||||
| # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation | ||||||||||||||||||
| # Also useful when user only wants Triton/Flex backends without CUDA compilation | ||||||||||||||||||
|
|
@@ -307,6 +311,67 @@ def get_package_version(): | |||||||||||||||||
| return str(public_version) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_wheel_url(): | ||||||||||||||||||
| torch_version_raw = parse(torch.__version__) | ||||||||||||||||||
| python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" | ||||||||||||||||||
| platform_name = get_platform() | ||||||||||||||||||
| flash_version = get_package_version() | ||||||||||||||||||
| torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" | ||||||||||||||||||
| cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() | ||||||||||||||||||
|
|
||||||||||||||||||
| # Determine the version numbers that will be used to determine the correct wheel | ||||||||||||||||||
| # We're using the CUDA version used to build torch, not the one currently installed | ||||||||||||||||||
| # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) | ||||||||||||||||||
| torch_cuda_version = parse(torch.version.cuda) | ||||||||||||||||||
| # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 | ||||||||||||||||||
| # to save CI time. Minor versions should be compatible. | ||||||||||||||||||
| torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") | ||||||||||||||||||
| # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" | ||||||||||||||||||
| cuda_version = f"{torch_cuda_version.major}" | ||||||||||||||||||
|
|
||||||||||||||||||
| # Determine wheel URL based on CUDA version, torch version, python version and OS | ||||||||||||||||||
| wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" | ||||||||||||||||||
|
|
||||||||||||||||||
| wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename) | ||||||||||||||||||
|
|
||||||||||||||||||
| return wheel_url, wheel_filename | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class CachedWheelsCommand(_bdist_wheel): | ||||||||||||||||||
| """ | ||||||||||||||||||
| The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot | ||||||||||||||||||
| find an existing wheel (which is currently the case for all flash attention installs). We use | ||||||||||||||||||
| the environment parameters to detect whether there is already a pre-built version of a compatible | ||||||||||||||||||
| wheel available and short-circuits the standard full build pipeline. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| def run(self): | ||||||||||||||||||
| if FORCE_BUILD: | ||||||||||||||||||
| return super().run() | ||||||||||||||||||
|
|
||||||||||||||||||
| wheel_url, wheel_filename = get_wheel_url() | ||||||||||||||||||
| print("Guessing wheel URL: ", wheel_url) | ||||||||||||||||||
| try: | ||||||||||||||||||
| urllib.request.urlretrieve(wheel_url, wheel_filename) | ||||||||||||||||||
|
Comment on lines
+354
to
+355
|
||||||||||||||||||
|
|
||||||||||||||||||
| # Make the archive | ||||||||||||||||||
| # Lifted from the root wheel processing command | ||||||||||||||||||
| # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 | ||||||||||||||||||
| if not os.path.exists(self.dist_dir): | ||||||||||||||||||
| os.makedirs(self.dist_dir) | ||||||||||||||||||
|
|
||||||||||||||||||
| impl_tag, abi_tag, plat_tag = self.get_tag() | ||||||||||||||||||
| archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" | ||||||||||||||||||
|
|
||||||||||||||||||
| wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") | ||||||||||||||||||
| print("Raw wheel path", wheel_path) | ||||||||||||||||||
|
||||||||||||||||||
| os.rename(wheel_filename, wheel_path) | ||||||||||||||||||
| except (urllib.error.HTTPError, urllib.error.URLError): | ||||||||||||||||||
|
||||||||||||||||||
| except (urllib.error.HTTPError, urllib.error.URLError): | |
| except (urllib.error.HTTPError, urllib.error.URLError) as e: | |
| print(f"Failed to download precompiled wheel from {wheel_url}.") | |
| print(f"Error type: {type(e).__name__}") | |
| if hasattr(e, 'code'): | |
| print(f"HTTP status code: {getattr(e, 'code', 'N/A')}") | |
| if hasattr(e, 'reason'): | |
| print(f"Reason: {getattr(e, 'reason', 'N/A')}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using print statements for logging in production code is not ideal. Consider using the logging module for better control over log levels and output formatting.