diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index 3fe7dc81ac92..16d1d1e67f59 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -145,6 +145,7 @@ cc_library(
     srcs = ["lapack_kernels.cc"],
     hdrs = ["lapack_kernels.h"],
     deps = [
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@com_google_absl//absl/base:dynamic_annotations",
     ],
 )
@@ -198,6 +199,7 @@ cc_library(
     features = ["-use_header_modules"],
     deps = [
         ":pocketfft_flatbuffers_cc",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@flatbuffers//:runtime_cc",
         "@pocketfft",
     ],
diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py
index 2bebe0228603..0326f0cb3e67 100644
--- a/jaxlib/lapack.py
+++ b/jaxlib/lapack.py
@@ -94,7 +94,9 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
           Shape.array_shape(dtype, (), ()),
           Shape.array_shape(dtype, a_shape.dimensions(), layout),
           Shape.array_shape(dtype, b_shape.dimensions(), layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
 jax_trsm = trsm
 
 # # ?getrf: LU decomposition
@@ -149,7 +151,9 @@ def getrf(c, a):
               dtype,
               batch_dims + (m, n),
               (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(3))
 
 # # ?geqrf: QR decomposition
@@ -212,7 +216,9 @@ def geqrf(c, a):
               dtype,
               batch_dims + (m, n),
               (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(3))
 
 # # ?orgqr: product of elementary Householder reflectors:
@@ -282,7 +288,9 @@ def orgqr(c, a, tau):
               dtype,
               batch_dims + (k,),
               tuple(range(num_bd, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(2))
 
 
@@ -326,7 +334,9 @@ def potrf(c, a, lower=False):
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return tuple(_ops.GetTupleElement(out, i) for i in range(2))
 
 
@@ -420,7 +430,9 @@ def gesdd(c, a, full_matrices=True, compute_uv=True):
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2),
           _ops.GetTupleElement(out, 3), _ops.GetTupleElement(out, 4))
 
@@ -491,7 +503,9 @@ def syevd(c, a, lower=False):
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(np.dtype(np.int32), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
           _ops.GetTupleElement(out, 2))
 
@@ -575,7 +589,9 @@ def geev(c, a, jobvl=True, jobvr=True):
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   if real:
     return (_ops.Complex(_ops.GetTupleElement(out, 3),
                          _ops.GetTupleElement(out, 4)),
@@ -653,7 +669,9 @@ def gees(c, a, jobvs=True, sort=False, select=None):
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(np.dtype(np.uint8), (), ()),
           Shape.array_shape(dtype, dims, layout),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   if sort == ord('S'):
     return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 3),
             _ops.GetTupleElement(out, 4), _ops.GetTupleElement(out, 5))
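Reviewer note: `API_VERSION_STATUS_RETURNING` switches these CPU custom calls to the ABI in which XLA passes a trailing `XlaCustomCallStatus*` to the target, so a kernel can fail the running executable with a message instead of crashing the process. A minimal sketch of that contract, assuming only the C API from `custom_call_status.h` (`MyKernel` is a hypothetical target, not part of this change):

```cpp
#include <cstring>

#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Hypothetical status-returning CPU custom call target. `out` and `in` follow
// the usual XLA CPU custom call conventions; the trailing status pointer is
// what API_VERSION_STATUS_RETURNING adds to the signature.
void MyKernel(void* out, void** in, XlaCustomCallStatus* status) {
  const float x = *reinterpret_cast<const float*>(in[0]);
  if (x < 0.0f) {
    // Fails the XLA execution with this message instead of aborting.
    constexpr char kMsg[] = "MyKernel: negative input";
    XlaCustomCallStatusSetFailure(status, kMsg, std::strlen(kMsg));
    return;
  }
  *reinterpret_cast<float*>(out) = x;
}
```

A target that never calls `XlaCustomCallStatusSetFailure` is treated as having succeeded, which is why the kernels below can accept the new argument and ignore it for now.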
diff --git a/jaxlib/lapack_kernels.cc b/jaxlib/lapack_kernels.cc
index 670498eccacc..b07532bfdfe4 100644
--- a/jaxlib/lapack_kernels.cc
+++ b/jaxlib/lapack_kernels.cc
@@ -30,7 +30,7 @@ template <typename T>
 typename Trsm<T>::FnType* Trsm<T>::fn = nullptr;
 
 template <typename T>
-void Trsm<T>::Kernel(void* out, void** data) {
+void Trsm<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) {
   int32_t left_side = *reinterpret_cast<int32_t*>(data[0]);
   int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
   int32_t trans_a = *reinterpret_cast<int32_t*>(data[2]);
@@ -82,7 +82,7 @@ template <typename T>
 typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;
 
 template <typename T>
-void Getrf<T>::Kernel(void* out_tuple, void** data) {
+void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int m = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -116,7 +116,7 @@ template <typename T>
 typename Geqrf<T>::FnType* Geqrf<T>::fn = nullptr;
 
 template <typename T>
-void Geqrf<T>::Kernel(void* out_tuple, void** data) {
+void Geqrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int m = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -163,7 +163,7 @@ template <typename T>
 typename Orgqr<T>::FnType* Orgqr<T>::fn = nullptr;
 
 template <typename T>
-void Orgqr<T>::Kernel(void* out_tuple, void** data) {
+void Orgqr<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int m = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -211,7 +211,7 @@ template <typename T>
 typename Potrf<T>::FnType* Potrf<T>::fn = nullptr;
 
 template <typename T>
-void Potrf<T>::Kernel(void* out_tuple, void** data) {
+void Potrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
   int b = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -260,7 +260,7 @@ template <typename T>
 typename RealGesdd<T>::FnType* RealGesdd<T>::fn = nullptr;
 
 template <typename T>
-void RealGesdd<T>::Kernel(void* out_tuple, void** data) {
+void RealGesdd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
   int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
   int b = *(reinterpret_cast<int32_t*>(data[2]));
@@ -332,7 +332,8 @@ template <typename T>
 typename ComplexGesdd<T>::FnType* ComplexGesdd<T>::fn = nullptr;
 
 template <typename T>
-void ComplexGesdd<T>::Kernel(void* out_tuple, void** data) {
+void ComplexGesdd<T>::Kernel(void* out_tuple, void** data,
+                             XlaCustomCallStatus*) {
   int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
   int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
   int b = *(reinterpret_cast<int32_t*>(data[2]));
@@ -411,7 +412,7 @@ template <typename T>
 typename RealSyevd<T>::FnType* RealSyevd<T>::fn = nullptr;
 
 template <typename T>
-void RealSyevd<T>::Kernel(void* out_tuple, void** data) {
+void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
   int b = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -459,7 +460,8 @@ template <typename T>
 typename ComplexHeevd<T>::FnType* ComplexHeevd<T>::fn = nullptr;
 
 template <typename T>
-void ComplexHeevd<T>::Kernel(void* out_tuple, void** data) {
+void ComplexHeevd<T>::Kernel(void* out_tuple, void** data,
+                             XlaCustomCallStatus*) {
   int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
   int b = *(reinterpret_cast<int32_t*>(data[1]));
   int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -531,7 +533,7 @@ template <typename T>
 typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;
 
 template <typename T>
-void RealGeev<T>::Kernel(void* out_tuple, void** data) {
+void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
@@ -590,7 +592,8 @@ template <typename T>
 typename ComplexGeev<T>::FnType* ComplexGeev<T>::fn = nullptr;
 
 template <typename T>
-void ComplexGeev<T>::Kernel(void* out_tuple, void** data) {
+void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
+                            XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
@@ -648,7 +651,7 @@ template <typename T>
 typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;
 
 template <typename T>
-void RealGees<T>::Kernel(void* out_tuple, void** data) {
+void RealGees<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
@@ -708,7 +711,8 @@ template <typename T>
 typename ComplexGees<T>::FnType* ComplexGees<T>::fn = nullptr;
 
 template <typename T>
-void ComplexGees<T>::Kernel(void* out_tuple, void** data) {
+void ComplexGees<T>::Kernel(void* out_tuple, void** data,
+                            XlaCustomCallStatus*) {
   int b = *(reinterpret_cast<int32_t*>(data[0]));
   int n_int = *(reinterpret_cast<int32_t*>(data[1]));
   int64_t n = n_int;
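The LAPACK kernels above only thread the new parameter through; `info` results are still materialized as output buffers and checked on the Python side. A possible follow-up (not part of this change) could surface a nonzero `info` directly through the status, along these lines:

```cpp
#include <string>

#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Hypothetical helper: translate a LAPACK `info` result into a custom call
// failure. Plain int stands in for lapack_int to keep the sketch
// self-contained.
void ReportLapackError(int info, const char* routine,
                       XlaCustomCallStatus* status) {
  if (info != 0) {
    std::string msg =
        std::string(routine) + " failed with info = " + std::to_string(info);
    XlaCustomCallStatusSetFailure(status, msg.c_str(), msg.length());
  }
}
```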
diff --git a/jaxlib/lapack_kernels.h b/jaxlib/lapack_kernels.h
index 22dc4bd45820..03cf73cc8283 100644
--- a/jaxlib/lapack_kernels.h
+++ b/jaxlib/lapack_kernels.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <complex>
 #include <cstdint>
+#include "tensorflow/compiler/xla/service/custom_call_status.h"
 
 // Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
 // by the pybind wrapper that links them to an existing SciPy lapack instance,
@@ -35,7 +36,7 @@ struct Trsm {
                       lapack_int* lda, T* b, lapack_int* ldb);
 
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -44,7 +45,7 @@ struct Getrf {
                       lapack_int* ipiv, lapack_int* info);
 
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -53,7 +54,7 @@ struct Geqrf {
                       T* tau, T* work, lapack_int* lwork, lapack_int* info);
 
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
   static int64_t Workspace(lapack_int m, lapack_int n);
 };
 
@@ -64,7 +65,7 @@ struct Orgqr {
                       lapack_int* lda, T* tau, T* work, lapack_int* lwork,
                       lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
   static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k);
 };
 
@@ -73,7 +74,7 @@ struct Potrf {
   using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
                       lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 lapack_int GesddIworkSize(int64_t m, int64_t n);
 
 template <typename T>
@@ -85,7 +86,7 @@ struct RealGesdd {
                       lapack_int* ldvt, T* work, lapack_int* lwork,
                       lapack_int* iwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
   static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
                            bool job_opt_full_matrices);
@@ -101,7 +102,7 @@ struct ComplexGesdd {
                       lapack_int* lwork, typename T::value_type* rwork,
                       lapack_int* iwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
   static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
                            bool job_opt_full_matrices);
@@ -117,7 +118,7 @@ struct RealSyevd {
                       lapack_int* lda, T* w, T* work, lapack_int* lwork,
                       lapack_int* iwork, lapack_int* liwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 lapack_int HeevdWorkSize(int64_t n);
@@ -131,7 +132,7 @@ struct ComplexHeevd {
                       lapack_int* lrwork, lapack_int* iwork, lapack_int* liwork,
                       lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -141,7 +142,7 @@ struct RealGeev {
                       T* vr, lapack_int* ldvr, T* work, lapack_int* lwork,
                       lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -151,7 +152,7 @@ struct ComplexGeev {
                       lapack_int* ldvr, T* work, lapack_int* lwork,
                       typename T::value_type* rwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -161,7 +162,7 @@ struct RealGees {
                       T* wr, T* wi, T* vs, lapack_int* ldvs, T* work,
                       lapack_int* lwork, bool* bwork, lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 template <typename T>
@@ -172,7 +173,7 @@ struct ComplexGees {
                       typename T::value_type* rwork, bool* bwork,
                       lapack_int* info);
   static FnType* fn;
-  static void Kernel(void* out, void** data);
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };
 
 }  // namespace jax
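Since every `Kernel` now has the same shape, the header change can be sanity-checked against the status-returning ABI at compile time. A sketch (not in the diff; assumes `jaxlib/lapack_kernels.h` is on the include path):

```cpp
#include <type_traits>

#include "jaxlib/lapack_kernels.h"

// The status-returning CPU custom call ABI: opaque output, opaque inputs,
// and a trailing status out-parameter.
using StatusReturningKernel = void(void*, void**, XlaCustomCallStatus*);

static_assert(std::is_same<decltype(jax::Getrf<float>::Kernel),
                           StatusReturningKernel>::value,
              "Getrf<float>::Kernel must match the status-returning ABI");
static_assert(std::is_same<decltype(jax::Potrf<double>::Kernel),
                           StatusReturningKernel>::value,
              "Potrf<double>::Kernel must match the status-returning ABI");
```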
diff --git a/jaxlib/pocketfft.py b/jaxlib/pocketfft.py
index 1f12fac4a040..d6516aedbc42 100644
--- a/jaxlib/pocketfft.py
+++ b/jaxlib/pocketfft.py
@@ -20,7 +20,6 @@
 
 from . import pocketfft_flatbuffers_py_generated as pd
 import numpy as np
-
 import flatbuffers
 
 from jaxlib import xla_client
@@ -53,8 +52,9 @@ def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
         pd.PocketFftDtype.COMPLEX64
         if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)
 
-  assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
-      shape, fft_lengths)
+  assert list(
+      shape.dimensions())[-len(fft_lengths):] == fft_lengths, (shape,
+                                                               fft_lengths)
   out_shape = list(shape.dimensions())
   out_shape[-1] = out_shape[-1] // 2 + 1
 
@@ -80,8 +80,9 @@ def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
         pd.PocketFftDtype.COMPLEX64
         if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
 
-  assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
-      shape, fft_lengths)
+  assert list(
+      shape.dimensions())[-len(fft_lengths):] == fft_lengths, (shape,
+                                                               fft_lengths)
   out_shape = shape.dimensions()
 
   # PocketFft does not allow size 0 dimensions.
@@ -156,4 +157,6 @@ def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
               np.dtype(np.uint8), (len(descriptor_bytes),), (0,)),
           xla_client.Shape.array_shape(dtype, shape.dimensions(),
                                        tuple(range(n - 1, -1, -1))),
-      ))
+      ),
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
diff --git a/jaxlib/pocketfft_kernels.cc b/jaxlib/pocketfft_kernels.cc
index ce8cc3e89ddd..aad63a222397 100644
--- a/jaxlib/pocketfft_kernels.cc
+++ b/jaxlib/pocketfft_kernels.cc
@@ -18,10 +18,11 @@ limitations under the License.
 #include "flatbuffers/flatbuffers.h"
 #include "pocketfft/pocketfft_hdronly.h"
 #include "jaxlib/pocketfft_generated.h"
+#include "tensorflow/compiler/xla/service/custom_call_status.h"
 
 namespace jax {
 
-void PocketFft(void* out, void** in) {
+void PocketFft(void* out, void** in, XlaCustomCallStatus*) {
   const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
   pocketfft::shape_t shape(descriptor->shape()->begin(),
                            descriptor->shape()->end());
diff --git a/jaxlib/pocketfft_kernels.h b/jaxlib/pocketfft_kernels.h
index 2f09012f20db..7804ad1ded67 100644
--- a/jaxlib/pocketfft_kernels.h
+++ b/jaxlib/pocketfft_kernels.h
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/service/custom_call_status.h"
+
 namespace jax {
 
-void PocketFft(void* out, void** in);
+void PocketFft(void* out, void** in, XlaCustomCallStatus*);
 
 }  // namespace jax
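As with the LAPACK kernels, `PocketFft` accepts the status argument but does not use it yet; the descriptor built by `pocketfft.py` is trusted. If validation were ever wanted, the same mechanism would apply, e.g. (hypothetical, not part of this change):

```cpp
#include <cstring>

#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Hypothetical guard for a kernel like PocketFft: reject a descriptor field
// the kernel does not understand instead of proceeding into undefined
// behavior. `kKnownDtypeCount` stands in for the number of values in the
// flatbuffer dtype enum.
bool CheckDescriptorDtype(int dtype, int kKnownDtypeCount,
                          XlaCustomCallStatus* status) {
  if (dtype < 0 || dtype >= kKnownDtypeCount) {
    constexpr char kMsg[] = "PocketFft: unsupported dtype in descriptor";
    XlaCustomCallStatusSetFailure(status, kMsg, std::strlen(kMsg));
    return false;
  }
  return true;
}
```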