diff --git a/build/build.py b/build/build.py index 4ae9c1c8f872..907bc0849bdb 100755 --- a/build/build.py +++ b/build/build.py @@ -582,7 +582,18 @@ def main(): shell(command) if args.build_gpu_plugin: - build_plugin_command = ([bazel_path] + args.bazel_startup_options + + build_cuda_kernels_command = ([bazel_path] + args.bazel_startup_options + + ["run", "--verbose_failures=true"] + + ["//jaxlib/tools:build_cuda_kernels_wheel", "--", + f"--output_path={output_path}", + f"--cpu={wheel_cpu}", + f"--cuda_version={args.gpu_plugin_cuda_version}"]) + if args.editable: + command.append("--editable") + print(" ".join(build_cuda_kernels_command)) + shell(build_cuda_kernels_command) + + build_pjrt_plugin_command = ([bazel_path] + args.bazel_startup_options + ["run", "--verbose_failures=true"] + ["//jaxlib/tools:build_gpu_plugin_wheel", "--", f"--output_path={output_path}", @@ -590,8 +601,8 @@ def main(): f"--cuda_version={args.gpu_plugin_cuda_version}"]) if args.editable: command.append("--editable") - print(" ".join(build_plugin_command)) - shell(build_plugin_command) + print(" ".join(build_pjrt_plugin_command)) + shell(build_pjrt_plugin_command) shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 563f95ee4a0c..ecf6669c4435 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -84,3 +84,14 @@ def build_editable( ) shutil.rmtree(output_path, ignore_errors=True) shutil.copytree(sources_path, output_path) + + +def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str): + src_file = file_dir / "setup.py" + with open(src_file, "r") as f: + content = f.read() + content = content.replace( + "cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}" + ) + with open(src_file, "w") as f: + f.write(content) diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 273a1017f145..1002b66171b7 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -14,6 +14,7 @@ import functools from functools import partial +import importlib import operator import jaxlib.mlir.ir as ir @@ -23,12 +24,19 @@ from jaxlib import xla_client -try: - from .cuda import _linalg as _cuda_linalg # pytype: disable=import-error +for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: + try: + _cuda_linalg = importlib.import_module( + f"{cuda_module_name}._linalg", package="jaxlib" + ) + except ImportError: + _cuda_linalg = None + else: + break + +if _cuda_linalg: for _name, _value in _cuda_linalg.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -except ImportError: - _cuda_linalg = None try: from .rocm import _linalg as _hip_linalg # pytype: disable=import-error diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 8a5c28577985..44cdb6b292e8 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -15,6 +15,7 @@ import functools from functools import partial +import importlib import itertools import operator from typing import Optional, Union @@ -26,12 +27,19 @@ from .hlo_helpers import custom_call from .gpu_common_utils import GpuLibNotLinkedError -try: - from .cuda import _prng as _cuda_prng # pytype: disable=import-error +for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: + try: + _cuda_prng = importlib.import_module( + f"{cuda_module_name}._prng", package="jaxlib" + ) + except ImportError: + _cuda_prng = None + else: + break + +if _cuda_prng: for _name, _value in _cuda_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -except ImportError: - _cuda_prng = None try: from .rocm import _prng as _hip_prng # pytype: disable=import-error diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index fe69a7c0c6b5..4dde76a4e8ec 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib + import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo @@ -20,14 +22,17 @@ from jaxlib import xla_client from .gpu_common_utils import GpuLibNotLinkedError -try: - from .cuda import _rnn # pytype: disable=import-error - for _name, _value in _rnn.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform='CUDA') -except ImportError: - _rnn = None +for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: + try: + _rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib") + except ImportError: + _rnn = None + else: + break if _rnn: + for _name, _value in _rnn.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform='CUDA') compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 9650b71b2fc4..4f850e6f4acd 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -14,6 +14,7 @@ from collections.abc import Sequence from functools import partial +import importlib import math import jaxlib.mlir.ir as ir @@ -31,17 +32,32 @@ try: from .cuda import _blas as _cublas # pytype: disable=import-error +except ImportError: + for cuda_module_name in ["jax_cuda12_plugin", "jax_cuda11_plugin"]: + try: + _cublas = importlib.import_module(f"{cuda_module_name}._blas") + except ImportError: + _cublas = None + else: + break + +if _cublas: for _name, _value in _cublas.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -except ImportError: - _cublas = None -try: - from .cuda import _solver as _cusolver # pytype: disable=import-error +for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: + try: + _cusolver = importlib.import_module( + f"{cuda_module_name}._solver", package="jaxlib" + ) + except ImportError: + _cusolver = None + else: + break + +if _cusolver: for _name, _value in _cusolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -except ImportError: - _cusolver = None try: diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index 40f6bf6c0f21..12120b32e86b 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -17,6 +17,7 @@ import math from functools import partial +import importlib import jaxlib.mlir.ir as ir @@ -26,11 +27,17 @@ from .hlo_helpers import custom_call, mk_result_types_and_shapes -try: - from .cuda import _sparse as _cusparse # pytype: disable=import-error -except ImportError: - _cusparse = None -else: +for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: + try: + _cusparse = importlib.import_module( + f"{cuda_module_name}._sparse", package="jaxlib" + ) + except ImportError: + _cusparse = None + else: + break + +if _cusparse: for _name, _value in _cusparse.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index 5b342c92ea1e..a004954a6164 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -11,11 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib from jaxlib import xla_client -try: - from .cuda import _triton as _cuda_triton # pytype: disable=import-error +for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]: + try: + _cuda_triton = importlib.import_module( + f"{cuda_module_name}._triton", package="jaxlib" + ) + except ImportError: + _cuda_triton = None + else: + break + +if _cuda_triton: xla_client.register_custom_call_target( "triton_kernel_call", _cuda_triton.get_custom_call(), platform='CUDA') @@ -27,8 +37,6 @@ get_compute_capability = _cuda_triton.get_compute_capability get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata -except ImportError: - _cuda_triton = None try: from .rocm import _triton as _hip_triton # pytype: disable=import-error diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index f30d2d0118d5..236c5e09d7eb 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -35,6 +35,7 @@ py_binary( "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]) + if_cuda([ "//jaxlib/cuda:cuda_gpu_support", + # TODO(jieying): move it out from jaxlib "//jaxlib:cuda_plugin_extension", "@local_config_cuda//cuda:cuda-nvvm", ]) + if_rocm([ @@ -53,6 +54,7 @@ py_binary( "LICENSE.txt", "@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so", ] + if_cuda([ + "//jaxlib:version", "//jaxlib/cuda:cuda_gpu_support", "//plugins/cuda:pyproject.toml", "//plugins/cuda:setup.py", @@ -64,3 +66,21 @@ py_binary( "@bazel_tools//tools/python/runfiles" ], ) + +py_binary( + name = "build_cuda_kernels_wheel", + srcs = ["build_cuda_kernels_wheel.py"], + data = [ + "LICENSE.txt", + ] + if_cuda([ + "//jaxlib:version", + "//jaxlib/cuda:cuda_gpu_support", + "//plugins/cuda:plugin_pyproject.toml", + "//plugins/cuda:plugin_setup.py", + "@local_config_cuda//cuda:cuda-nvvm", + ]), + deps = [ + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles" + ], +) diff --git a/jaxlib/tools/build_cuda_kernels_wheel.py b/jaxlib/tools/build_cuda_kernels_wheel.py new file mode 100644 index 000000000000..ff5a74640061 --- /dev/null +++ b/jaxlib/tools/build_cuda_kernels_wheel.py @@ -0,0 +1,120 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script that builds a jax-cuda12-plugin wheel for cuda kernels, intended to be +# run via bazel run as part of the jax cuda plugin build process. + +# Most users should not run this script directly; use build.py instead. + +import argparse +import functools +import os +import pathlib +import tempfile + +from bazel_tools.tools.python.runfiles import runfiles +from jax.tools import build_utils + +parser = argparse.ArgumentParser() +parser.add_argument( + "--output_path", + default=None, + required=True, + help="Path to which the output wheel should be written. Required.", +) +parser.add_argument( + "--cpu", default=None, required=True, help="Target CPU architecture. Required." +) +parser.add_argument( + "--cuda_version", + default=None, + required=True, + help="Target CUDA version. Required.", +) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax cuda plugin build instead of a wheel.", +) +args = parser.parse_args() + +r = runfiles.Create() +pyext = "pyd" if build_utils.is_windows() else "so" + + +def write_setup_cfg(sources_path, cpu): + tag = build_utils.platform_tag(cpu) + with open(sources_path / "setup.cfg", "w") as f: + f.write(f"""[metadata] +license_files = LICENSE.txt + +[bdist_wheel] +plat-name={tag} +""") + + +def prepare_wheel( + sources_path: pathlib.Path, *, cpu, cuda_version +): + """Assembles a source tree for the cuda kernel wheel in `sources_path`.""" + copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + + copy_runfiles( + "__main__/plugins/cuda/plugin_pyproject.toml", + dst_dir=sources_path, + dst_filename="pyproject.toml", + ) + copy_runfiles( + "__main__/plugins/cuda/plugin_setup.py", + dst_dir=sources_path, + dst_filename="setup.py", + ) + build_utils.update_setup_with_cuda_version(sources_path, cuda_version) + write_setup_cfg(sources_path, cpu) + + plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" + copy_runfiles( + dst_dir=plugin_dir / "nvvm" / "libdevice", + src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], + ) + copy_runfiles( + dst_dir=plugin_dir, + src_files=[ + f"__main__/jaxlib/cuda/_solver.{pyext}", + f"__main__/jaxlib/cuda/_blas.{pyext}", + f"__main__/jaxlib/cuda/_linalg.{pyext}", + f"__main__/jaxlib/cuda/_prng.{pyext}", + f"__main__/jaxlib/cuda/_rnn.{pyext}", + f"__main__/jaxlib/cuda/_sparse.{pyext}", + f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_versions.{pyext}", + "__main__/jaxlib/version.py", + ], + ) + +# Build wheel for cuda kernels +tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin") +sources_path = tmpdir.name +try: + os.makedirs(args.output_path, exist_ok=True) + prepare_wheel( + pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version + ) + package_name = f"jax cuda{args.cuda_version} plugin" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel(sources_path, args.output_path, package_name) +finally: + tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 3ab942a77de9..020bbc922208 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -72,22 +72,11 @@ def write_setup_cfg(sources_path, cpu): ) -def update_setup(file_dir, cuda_version): - src_file = file_dir / "setup.py" - with open(src_file, "r") as f: - content = f.read() - content = content.replace( - "cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}" - ) - with open(src_file, "w") as f: - f.write(content) - - def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): """Assembles a source tree for the wheel in `sources_path`.""" copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) - plugin_dir = sources_path / "jax_plugins" / f"xla_cuda_cu{cuda_version}" + plugin_dir = sources_path / "jax_plugins" / f"xla_cuda{cuda_version}" copy_runfiles( dst_dir=sources_path, src_files=[ @@ -95,12 +84,13 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): "__main__/plugins/cuda/setup.py", ], ) - update_setup(sources_path, cuda_version) + build_utils.update_setup_with_cuda_version(sources_path, cuda_version) write_setup_cfg(sources_path, cpu) copy_runfiles( dst_dir=plugin_dir, src_files=[ "__main__/plugins/cuda/__init__.py", + "__main__/jaxlib/version.py", ], ) copy_runfiles( @@ -113,7 +103,7 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): tmpdir = None sources_path = args.sources_path if sources_path is None: - tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudaplugin") + tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudapjrt") sources_path = tmpdir.name try: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 786d6996141f..9bec96807785 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -216,7 +216,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi ], ) - if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"): + if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not include_gpu_plugin_extension: copy_runfiles( dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice", src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], diff --git a/plugins/cuda/BUILD.bazel b/plugins/cuda/BUILD.bazel index 0d7863738015..0718ea8588be 100644 --- a/plugins/cuda/BUILD.bazel +++ b/plugins/cuda/BUILD.bazel @@ -21,6 +21,8 @@ package( exports_files([ "__init__.py", + "plugin_pyproject.toml", + "plugin_setup.py", "pyproject.toml", "setup.py", ]) diff --git a/plugins/cuda/plugin_pyproject.toml b/plugins/cuda/plugin_pyproject.toml new file mode 100644 index 000000000000..8fe2f47af9a1 --- /dev/null +++ b/plugins/cuda/plugin_pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/plugins/cuda/plugin_setup.py b/plugins/cuda/plugin_setup.py new file mode 100644 index 000000000000..9f5653b345e1 --- /dev/null +++ b/plugins/cuda/plugin_setup.py @@ -0,0 +1,75 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from setuptools import setup +from setuptools.dist import Distribution + +__version__ = None +cuda_version = 0 # placeholder +project_name = f"jax-cuda{cuda_version}-plugin" +package_name = f"jax_cuda{cuda_version}_plugin" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(package_name) +__version__ = _version_module._get_version_for_build() +_cmdclass = _version_module._get_cmdclass(package_name) + +cudnn_version = os.environ.get("JAX_CUDNN_VERSION") +if cudnn_version: + __version__ += f"+cudnn{cudnn_version.replace('.', '')}" + +class BinaryDistribution(Distribution): + """This class makes 'bdist_wheel' include an ABI tag on the wheel.""" + + def has_ext_modules(self): + return True + +setup( + name=project_name, + version=__version__, + cmdclass=_cmdclass, + description="JAX Plugin for NVIDIA GPUs", + long_description="", + long_description_content_type="text/markdown", + author="JAX team", + author_email="jax-dev@google.com", + packages=[package_name], + python_requires=">=3.9", + install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + package_data={ + package_name: [ + "*", + "nvvm/libdevice/libdevice*", + ], + }, + zip_safe=False, + distclass=BinaryDistribution, +) diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py index 563985950037..96ce577fc643 100644 --- a/plugins/cuda/setup.py +++ b/plugins/cuda/setup.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import os from setuptools import setup, find_namespace_packages -__version__ = "0.0" +__version__ = None cuda_version = 0 # placeholder -project_name = f"jax-cuda-plugin-cu{cuda_version}" -package_name = f"jax_plugins.xla_cuda_cu{cuda_version}" +project_name = f"jax-cuda{cuda_version}-pjrt" +package_name = f"jax_plugins.xla_cuda{cuda_version}" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(f"jax_plugins/xla_cuda{cuda_version}") +__version__ = _version_module._get_version_for_build() packages = find_namespace_packages( include=[ @@ -48,7 +60,7 @@ zip_safe=False, entry_points={ "jax_plugins": [ - f"xla_cuda_cu{cuda_version} = {package_name}", + f"xla_cuda{cuda_version} = {package_name}", ], }, )