[PJRT C API] Change build wheel script to build a separate package for cuda kernels.

With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

| package                 | size | wheel name                                                                |
|-------------------------|------|---------------------------------------------------------------------------|
| jaxlib w/o cuda kernels | 76M  | jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl            |
| cuda pjrt               | 73M  | jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl      |
| cuda kernels            | 6.6M | jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl |

For comparison, the size of a single jaxlib wheel with the cuda kernels and pjrt bundled in is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new package.
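The new plugin_setup.py itself is not shown on this page, but based on the `update_setup_with_cuda_version` helper added in jax/tools/build_utils.py below, it presumably carries a `cuda_version = 0 # placeholder` line that the wheel build rewrites. A minimal, hypothetical sketch (the names and version here are illustrative only):

```python
# Hypothetical sketch of plugins/cuda/plugin_setup.py; the real file is not
# shown on this page. Only the placeholder line below is inferred from the
# update_setup_with_cuda_version() helper -- everything else is illustrative.
from setuptools import setup

# Rewritten to the real major version (e.g. 12) at wheel-build time.
cuda_version = 0 # placeholder
package_name = f"jax_cuda{cuda_version}_plugin"

setup(
    name=package_name,
    version="0.4.20.dev20231101",  # example version from the table above
    packages=[package_name],
)
```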

PiperOrigin-RevId: 579861480
jyingl3 authored and jax authors committed Nov 6, 2023
1 parent 79ca40e commit 462ef16
Showing 16 changed files with 347 additions and 51 deletions.
17 changes: 14 additions & 3 deletions build/build.py
@@ -582,16 +582,27 @@ def main():
   shell(command)
 
   if args.build_gpu_plugin:
-    build_plugin_command = ([bazel_path] + args.bazel_startup_options +
+    build_cuda_kernels_command = ([bazel_path] + args.bazel_startup_options +
+      ["run", "--verbose_failures=true"] +
+      ["//jaxlib/tools:build_cuda_kernels_wheel", "--",
+      f"--output_path={output_path}",
+      f"--cpu={wheel_cpu}",
+      f"--cuda_version={args.gpu_plugin_cuda_version}"])
+    if args.editable:
+      command.append("--editable")
+    print(" ".join(build_cuda_kernels_command))
+    shell(build_cuda_kernels_command)
+
+    build_pjrt_plugin_command = ([bazel_path] + args.bazel_startup_options +
       ["run", "--verbose_failures=true"] +
       ["//jaxlib/tools:build_gpu_plugin_wheel", "--",
       f"--output_path={output_path}",
       f"--cpu={wheel_cpu}",
       f"--cuda_version={args.gpu_plugin_cuda_version}"])
     if args.editable:
       command.append("--editable")
-    print(" ".join(build_plugin_command))
-    shell(build_plugin_command)
+    print(" ".join(build_pjrt_plugin_command))
+    shell(build_pjrt_plugin_command)
 
   shell([bazel_path] + args.bazel_startup_options + ["shutdown"])

11 changes: 11 additions & 0 deletions jax/tools/build_utils.py
@@ -84,3 +84,14 @@ def build_editable(
   )
   shutil.rmtree(output_path, ignore_errors=True)
   shutil.copytree(sources_path, output_path)
+
+
+def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str):
+  src_file = file_dir / "setup.py"
+  with open(src_file, "r") as f:
+    content = f.read()
+  content = content.replace(
+      "cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
+  )
+  with open(src_file, "w") as f:
+    f.write(content)
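The helper simply rewrites a placeholder assignment in the staged package's setup.py before the wheel is built. A hedged usage sketch (the call site and staging path below are illustrative, not part of this diff):

```python
# Illustrative only: rewrite the placeholder in a staged plugin package's
# setup.py before building its wheel. The staging path is hypothetical and
# must already contain a setup.py with the placeholder line.
import pathlib

from jax.tools import build_utils

staged_sources = pathlib.Path("/tmp/jax_cuda_plugin_staging")  # assumed dir
build_utils.update_setup_with_cuda_version(staged_sources, "12")
# staged_sources/setup.py now contains "cuda_version = 12" in place of the
# "cuda_version = 0 # placeholder" line.
```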
16 changes: 12 additions & 4 deletions jaxlib/gpu_linalg.py
@@ -14,6 +14,7 @@
 
 import functools
 from functools import partial
+import importlib
 import operator
 
 import jaxlib.mlir.ir as ir
@@ -23,12 +24,19 @@
 
 from jaxlib import xla_client
 
-try:
-  from .cuda import _linalg as _cuda_linalg  # pytype: disable=import-error
+for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+  try:
+    _cuda_linalg = importlib.import_module(
+        f"{cuda_module_name}._linalg", package="jaxlib"
+    )
+  except ImportError:
+    _cuda_linalg = None
+  else:
+    break
+
+if _cuda_linalg:
   for _name, _value in _cuda_linalg.registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="CUDA")
-except ImportError:
-  _cuda_linalg = None
 
 try:
   from .rocm import _linalg as _hip_linalg  # pytype: disable=import-error
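The lookup above, repeated with different module names in the other gpu_*.py files below, tries jaxlib's bundled `.cuda` extension first and then falls back to the separately installed `jax_cuda12_plugin` / `jax_cuda11_plugin` packages. A standalone sketch of the same resolution order (the helper name is mine, not part of the change):

```python
# Sketch only: the real modules inline this loop rather than sharing a helper.
import importlib


def _find_cuda_module(submodule: str):
  """Returns the first importable CUDA extension providing `submodule`, else None."""
  for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
    try:
      # ".cuda" is resolved relative to jaxlib; the plugin names are absolute,
      # so the package argument is ignored for them.
      return importlib.import_module(
          f"{cuda_module_name}.{submodule}", package="jaxlib"
      )
    except ImportError:
      continue
  return None


_cuda_linalg = _find_cuda_module("_linalg")  # equivalent to the loop above
```

Note that in jaxlib/gpu_solver.py the `_cublas` lookup keeps the direct `.cuda` import and only falls back to the plugin packages inside the `except ImportError:` branch, but the net effect is the same.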
16 changes: 12 additions & 4 deletions jaxlib/gpu_prng.py
@@ -15,6 +15,7 @@
 
 import functools
 from functools import partial
+import importlib
 import itertools
 import operator
 from typing import Optional, Union
@@ -26,12 +27,19 @@
 from .hlo_helpers import custom_call
 from .gpu_common_utils import GpuLibNotLinkedError
 
-try:
-  from .cuda import _prng as _cuda_prng  # pytype: disable=import-error
+for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+  try:
+    _cuda_prng = importlib.import_module(
+        f"{cuda_module_name}._prng", package="jaxlib"
+    )
+  except ImportError:
+    _cuda_prng = None
+  else:
+    break
+
+if _cuda_prng:
   for _name, _value in _cuda_prng.registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="CUDA")
-except ImportError:
-  _cuda_prng = None
 
 try:
   from .rocm import _prng as _hip_prng  # pytype: disable=import-error
17 changes: 11 additions & 6 deletions jaxlib/gpu_rnn.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
+
 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.stablehlo as hlo
 
@@ -20,14 +22,17 @@
 from jaxlib import xla_client
 from .gpu_common_utils import GpuLibNotLinkedError
 
-try:
-  from .cuda import _rnn  # pytype: disable=import-error
-  for _name, _value in _rnn.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform='CUDA')
-except ImportError:
-  _rnn = None
+for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+  try:
+    _rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib")
+  except ImportError:
+    _rnn = None
+  else:
+    break
 
 if _rnn:
+  for _name, _value in _rnn.registrations().items():
+    xla_client.register_custom_call_target(_name, _value, platform='CUDA')
   compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes
 
 
28 changes: 22 additions & 6 deletions jaxlib/gpu_solver.py
@@ -14,6 +14,7 @@
 
 from collections.abc import Sequence
 from functools import partial
+import importlib
 import math
 
 import jaxlib.mlir.ir as ir
@@ -31,17 +32,32 @@
 
 try:
   from .cuda import _blas as _cublas  # pytype: disable=import-error
+except ImportError:
+  for cuda_module_name in ["jax_cuda12_plugin", "jax_cuda11_plugin"]:
+    try:
+      _cublas = importlib.import_module(f"{cuda_module_name}._blas")
+    except ImportError:
+      _cublas = None
+    else:
+      break
+
+if _cublas:
   for _name, _value in _cublas.registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="CUDA")
-except ImportError:
-  _cublas = None
 
-try:
-  from .cuda import _solver as _cusolver  # pytype: disable=import-error
+for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+  try:
+    _cusolver = importlib.import_module(
+        f"{cuda_module_name}._solver", package="jaxlib"
+    )
+  except ImportError:
+    _cusolver = None
+  else:
+    break
+
+if _cusolver:
   for _name, _value in _cusolver.registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="CUDA")
-except ImportError:
-  _cusolver = None
 
 
 try:
17 changes: 12 additions & 5 deletions jaxlib/gpu_sparse.py
@@ -17,6 +17,7 @@
 
 import math
 from functools import partial
+import importlib
 
 import jaxlib.mlir.ir as ir
 
@@ -26,11 +27,17 @@
 
 from .hlo_helpers import custom_call, mk_result_types_and_shapes
 
-try:
-  from .cuda import _sparse as _cusparse  # pytype: disable=import-error
-except ImportError:
-  _cusparse = None
-else:
+for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+  try:
+    _cusparse = importlib.import_module(
+        f"{cuda_module_name}._sparse", package="jaxlib"
+    )
+  except ImportError:
+    _cusparse = None
+  else:
+    break
+
+if _cusparse:
   for _name, _value in _cusparse.registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="CUDA")
 
16 changes: 12 additions & 4 deletions jaxlib/gpu_triton.py
@@ -11,11 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 
 from jaxlib import xla_client
 
-try:
-  from .cuda import _triton as _cuda_triton  # pytype: disable=import-error
+for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+  try:
+    _cuda_triton = importlib.import_module(
+        f"{cuda_module_name}._triton", package="jaxlib"
+    )
+  except ImportError:
+    _cuda_triton = None
+  else:
+    break
+
+if _cuda_triton:
   xla_client.register_custom_call_target(
       "triton_kernel_call", _cuda_triton.get_custom_call(),
       platform='CUDA')
@@ -27,8 +37,6 @@
   get_compute_capability = _cuda_triton.get_compute_capability
   get_custom_call = _cuda_triton.get_custom_call
   get_serialized_metadata = _cuda_triton.get_serialized_metadata
-except ImportError:
-  _cuda_triton = None
 
 try:
   from .rocm import _triton as _hip_triton  # pytype: disable=import-error
20 changes: 20 additions & 0 deletions jaxlib/tools/BUILD.bazel
@@ -35,6 +35,7 @@ py_binary(
         "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
     ]) + if_cuda([
         "//jaxlib/cuda:cuda_gpu_support",
+        # TODO(jieying): move it out from jaxlib
         "//jaxlib:cuda_plugin_extension",
         "@local_config_cuda//cuda:cuda-nvvm",
     ]) + if_rocm([
@@ -53,6 +54,7 @@ py_binary(
         "LICENSE.txt",
         "@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so",
     ] + if_cuda([
+        "//jaxlib:version",
         "//jaxlib/cuda:cuda_gpu_support",
         "//plugins/cuda:pyproject.toml",
         "//plugins/cuda:setup.py",
@@ -64,3 +66,21 @@
         "@bazel_tools//tools/python/runfiles"
     ],
 )
+
+py_binary(
+    name = "build_cuda_kernels_wheel",
+    srcs = ["build_cuda_kernels_wheel.py"],
+    data = [
+        "LICENSE.txt",
+    ] + if_cuda([
+        "//jaxlib:version",
+        "//jaxlib/cuda:cuda_gpu_support",
+        "//plugins/cuda:plugin_pyproject.toml",
+        "//plugins/cuda:plugin_setup.py",
+        "@local_config_cuda//cuda:cuda-nvvm",
+    ]),
+    deps = [
+        "//jax/tools:build_utils",
+        "@bazel_tools//tools/python/runfiles"
+    ],
+)
