diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index acf7e696ad00..aff975280a44 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -216,10 +216,7 @@ def make_cpu_client() -> xla_client.Client: def _check_cuda_versions(): - # TODO(phawkins): remove the test for None cuda_versions after jaxlib 0.4.17 - # is the minimum. - if cuda_versions is None: - return + assert cuda_versions is not None def _version_check(name, get_version, get_build_version, scale_for_comparison=1): @@ -256,9 +253,14 @@ def _version_check(name, get_version, get_build_version, scale_for_comparison=100) _version_check("cuPTI", cuda_versions.cupti_get_version, cuda_versions.cupti_build_version) - # TODO(phawkins): ideally we'd check cublas and cusparse here also, but their - # "get version" APIs require initializing those libraries, which we don't want - # to do here. + _version_check("cuBLAS", cuda_versions.cublas_get_version, + cuda_versions.cublas_build_version, + # Ignore patch versions. + scale_for_comparison=100) + _version_check("cuSPARSE", cuda_versions.cusparse_get_version, + cuda_versions.cusparse_build_version, + # Ignore patch versions. + scale_for_comparison=100) def make_gpu_client( diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index b62d070e3f32..5c819ce13cc7 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -458,6 +458,30 @@ pybind_extension( ], ) +cc_library( + name = "versions_helpers", + srcs = ["versions_helpers.cc"], + hdrs = ["versions_helpers.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib:absl_status_casters", + "//jaxlib:kernel_nanobind_helpers", + "@tsl//tsl/cuda:cublas", + "@tsl//tsl/cuda:cudart", + "@tsl//tsl/cuda:cudnn", + "@tsl//tsl/cuda:cufft", + "@tsl//tsl/cuda:cupti", + "@tsl//tsl/cuda:cusolver", + "@tsl//tsl/cuda:cusparse", + ], +) + pybind_extension( name = "_versions", srcs = ["versions.cc"], @@ -482,6 +506,7 @@ pybind_extension( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + ":versions_helpers", "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "@tsl//tsl/cuda:cublas", diff --git a/jaxlib/cuda/versions.cc b/jaxlib/cuda/versions.cc index ce422abf9bbd..a99cd683a120 100644 --- a/jaxlib/cuda/versions.cc +++ b/jaxlib/cuda/versions.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "jaxlib/cuda/versions_helpers.h" + #include "nanobind/nanobind.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" namespace jax::cuda { @@ -22,40 +23,6 @@ namespace { namespace nb = nanobind; -#if CUDA_VERSION < 11080 -#error "JAX requires CUDA 11.8 or newer." -#endif // CUDA_VERSION < 11080 - -int CudaRuntimeGetVersion() { - int version; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version))); - return version; -} - -int CudaDriverGetVersion() { - int version; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version))); - return version; -} - -uint32_t CuptiGetVersion() { - uint32_t version; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version))); - return version; -} - -int CufftGetVersion() { - int version; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version))); - return version; -} - -int CusolverGetVersion() { - int version; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version))); - return version; -} - NB_MODULE(_versions, m) { // Nanobind's leak checking sometimes returns false positives for this file. // The problem appears related to forming a closure of a nanobind function. @@ -70,14 +37,14 @@ NB_MODULE(_versions, m) { m.def("cusolver_build_version", []() { return CUSOLVER_VERSION; }); m.def("cusparse_build_version", []() { return CUSPARSE_VERSION; }); - // TODO(phawkins): annoyingly cublas and cusparse have "get version" APIs that - // require the library to be initialized. m.def("cuda_runtime_get_version", &CudaRuntimeGetVersion); m.def("cuda_driver_get_version", &CudaDriverGetVersion); m.def("cudnn_get_version", &cudnnGetVersion); m.def("cupti_get_version", &CuptiGetVersion); m.def("cufft_get_version", &CufftGetVersion); m.def("cusolver_get_version", &CusolverGetVersion); + m.def("cublas_get_version", &CublasGetVersion); + m.def("cusparse_get_version", &CusparseGetVersion); } } // namespace diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc new file mode 100644 index 000000000000..52906258e6a5 --- /dev/null +++ b/jaxlib/cuda/versions_helpers.cc @@ -0,0 +1,76 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/cuda/versions_helpers.h" + +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax::cuda { + +#if CUDA_VERSION < 11080 +#error "JAX requires CUDA 11.8 or newer." +#endif // CUDA_VERSION < 11080 + +int CudaRuntimeGetVersion() { + int version; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version))); + return version; +} + +int CudaDriverGetVersion() { + int version; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version))); + return version; +} + +uint32_t CuptiGetVersion() { + uint32_t version; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version))); + return version; +} + +int CufftGetVersion() { + int version; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version))); + return version; +} + +int CusolverGetVersion() { + int version; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version))); + return version; +} + +int CublasGetVersion() { + int version; + // NVIDIA promise that it's safe to parse nullptr as the handle to this + // function. + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version))); + return version; +} + +int CusparseGetVersion() { + // cusparseGetVersion is unhappy if passed a null library handle. But + // cusparseGetProperty doesn't require one. + int major, minor, patch; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch))); + return major * 1000 + minor * 100 + patch; +} + +} // namespace jax::cuda \ No newline at end of file diff --git a/jaxlib/cuda/versions_helpers.h b/jaxlib/cuda/versions_helpers.h new file mode 100644 index 000000000000..af890f91202f --- /dev/null +++ b/jaxlib/cuda/versions_helpers.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CUDA_VERSIONS_HELPERS_H_ +#define JAXLIB_CUDA_VERSIONS_HELPERS_H_ + +#include + +namespace jax::cuda { + +int CudaRuntimeGetVersion(); +int CudaDriverGetVersion(); +uint32_t CuptiGetVersion(); +int CufftGetVersion(); +int CusolverGetVersion(); +int CublasGetVersion(); +int CusparseGetVersion(); + +} // namespace jax::cuda + +#endif // JAXLIB_CUDA_VERSIONS_HELPERS_H_ \ No newline at end of file