Skip to content

Commit

Permalink
Reduce binary size - no debug info for binaries
Browse files Browse the repository at this point in the history
ghstack-source-id: 4759ee88b471d7e10e40c72d9b9d4a441c82aeaf
Pull Request resolved: #549
  • Loading branch information
danthe3rd committed Nov 29, 2022
1 parent dc23392 commit 2eb133d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env:
FORCE_CUDA: 1
MAX_JOBS: 1 # will crash otherwise
DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
XFORMERS_BUILD_TYPE: "Release"

jobs:
build_wheels:
Expand Down
1 change: 1 addition & 0 deletions packaging/conda/build_conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _set_env_for_build(self):
os.environ["PYTORCH_VERSION"] = self.pytorch_version
os.environ["CU_VERSION"] = self.cuda_version
os.environ["SOURCE_ROOT_DIR"] = str(SOURCE_ROOT_DIR)
os.environ["XFORMERS_BUILD_TYPE"] = "Release"
cuda_constraint = version_constraint(self.cuda_version)
pytorch_version_tuple = tuple(int(v) for v in self.pytorch_version.split("."))
if pytorch_version_tuple < (1, 13):
Expand Down
17 changes: 13 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@
this_dir = os.path.dirname(os.path.abspath(__file__))


def get_extra_nvcc_flags_for_build_type() -> List[str]:
    """Return the extra nvcc flags implied by ``XFORMERS_BUILD_TYPE``.

    The build type is read from the ``XFORMERS_BUILD_TYPE`` environment
    variable; it is matched case-insensitively and surrounding whitespace is
    ignored. When the variable is unset, ``RelWithDebInfo`` is assumed.

    Supported build types:
        * ``RelWithDebInfo`` -> ``["--generate-line-info"]``
        * ``Release`` -> ``[]`` (no line info, which keeps binaries smaller)

    Returns:
        A new list of nvcc command-line flags; callers may mutate it freely.

    Raises:
        ValueError: if ``XFORMERS_BUILD_TYPE`` names an unsupported build type.
    """
    flags_by_build_type = {
        "relwithdebinfo": ("--generate-line-info",),
        "release": (),
    }
    raw_build_type = os.environ.get("XFORMERS_BUILD_TYPE", "RelWithDebInfo")
    # Tolerate stray whitespace/newlines that CI files or shell exports can
    # introduce, in addition to differences in letter case.
    build_type = raw_build_type.strip().lower()
    try:
        flags = flags_by_build_type[build_type]
    except KeyError:
        supported = ", ".join(sorted(flags_by_build_type))
        # Report the value exactly as the user set it, and say how to fix it.
        raise ValueError(
            f"Unknown build type: {raw_build_type!r} "
            f"(XFORMERS_BUILD_TYPE must be one of: {supported}; "
            "case-insensitive)"
        ) from None
    # Hand back a fresh list so no state is shared between call sites.
    return list(flags)


def fetch_requirements():
with open("requirements.txt") as f:
reqs = f.read().strip().split("\n")
Expand Down Expand Up @@ -137,10 +147,10 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
"-lineinfo",
]
+ nvcc_platform_dependant_args
+ nvcc_archs_flags,
+ nvcc_archs_flags
+ get_extra_nvcc_flags_for_build_type(),
},
include_dirs=[
Path(flash_root) / "csrc" / "flash_attn",
Expand Down Expand Up @@ -199,11 +209,10 @@ def get_extensions():
nvcc_flags = [
"-DHAS_PYTORCH",
"--use_fast_math",
"--generate-line-info",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--extended-lambda",
]
] + get_extra_nvcc_flags_for_build_type()
if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
nvcc_flags.append("-DNDEBUG")
nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", ""))
Expand Down

0 comments on commit 2eb133d

Please sign in to comment.