Migrate 'jaxlib' CPU custom-calls to the status-returning API
PiperOrigin-RevId: 438165260
agrue authored and jax authors committed Mar 30, 2022
1 parent b31cf89 · commit 8884ce5
Showing 7 changed files with 74 additions and 43 deletions.
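For context: under the legacy CPU custom-call API a kernel has the signature void(void* out, void** data) and has no channel for reporting runtime failures back to XLA. The status-returning API, which this change migrates to, adds a trailing XlaCustomCallStatus* parameter on the C++ side and requires the op to be built with api_version=API_VERSION_STATUS_RETURNING on the Python side. As a minimal sketch (the kernel name and error condition are hypothetical, not part of this commit; the jaxlib kernels below merely adopt the new signature and leave the status untouched), a status-returning kernel that actually reports an error could look roughly like this:

#include <cstdint>
#include <cstring>
#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Hypothetical status-returning CPU custom-call kernel (illustration only).
void ExampleKernel(void* out, void** data, XlaCustomCallStatus* status) {
  int32_t n = *reinterpret_cast<int32_t*>(data[0]);
  if (n < 0) {
    // Setting a failure makes XLA return an error from the executable
    // instead of the kernel having to abort the process.
    static constexpr char kMsg[] = "ExampleKernel: n must be non-negative";
    XlaCustomCallStatusSetFailure(status, kMsg, std::strlen(kMsg));
    return;
  }
  *reinterpret_cast<int32_t*>(out) = 2 * n;
}

On the Python side, the corresponding CustomCallWithLayout op is tagged with api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_STATUS_RETURNING, as in the jaxlib/lapack.py hunks below.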
2 changes: 2 additions & 0 deletions jaxlib/BUILD
@@ -145,6 +145,7 @@ cc_library(
srcs = ["lapack_kernels.cc"],
hdrs = ["lapack_kernels.h"],
deps = [
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/base:dynamic_annotations",
],
)
@@ -198,6 +199,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":pocketfft_flatbuffers_cc",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@flatbuffers//:runtime_cc",
"@pocketfft",
],
36 changes: 27 additions & 9 deletions jaxlib/lapack.py
@@ -94,7 +94,9 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
Shape.array_shape(dtype, (), ()),
Shape.array_shape(dtype, a_shape.dimensions(), layout),
Shape.array_shape(dtype, b_shape.dimensions(), layout),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
jax_trsm = trsm

# # ?getrf: LU decomposition
@@ -149,7 +151,9 @@ def getrf(c, a):
dtype,
batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))

# # ?geqrf: QR decomposition
@@ -212,7 +216,9 @@ def geqrf(c, a):
dtype,
batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))

# # ?orgqr: product of elementary Householder reflectors:
@@ -282,7 +288,9 @@ def orgqr(c, a, tau):
dtype,
batch_dims + (k,),
tuple(range(num_bd, -1, -1))),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return tuple(_ops.GetTupleElement(out, i) for i in range(2))


@@ -326,7 +334,9 @@ def potrf(c, a, lower=False):
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(dtype, dims, layout),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return tuple(_ops.GetTupleElement(out, i) for i in range(2))


@@ -420,7 +430,9 @@ def gesdd(c, a, full_matrices=True, compute_uv=True):
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2),
_ops.GetTupleElement(out, 3), _ops.GetTupleElement(out, 4))

@@ -491,7 +503,9 @@ def syevd(c, a, lower=False):
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(dtype, dims, layout),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))

@@ -575,7 +589,9 @@ def geev(c, a, jobvl=True, jobvr=True):
Shape.array_shape(np.dtype(np.uint8), (), ()),
Shape.array_shape(np.dtype(np.uint8), (), ()),
Shape.array_shape(dtype, dims, layout),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
if real:
return (_ops.Complex(_ops.GetTupleElement(out, 3),
_ops.GetTupleElement(out, 4)),
@@ -653,7 +669,9 @@ def gees(c, a, jobvs=True, sort=False, select=None):
Shape.array_shape(np.dtype(np.uint8), (), ()),
Shape.array_shape(np.dtype(np.uint8), (), ()),
Shape.array_shape(dtype, dims, layout),
))
),
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
if sort == ord('S'):
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 3),
_ops.GetTupleElement(out, 4), _ops.GetTupleElement(out, 5))
30 changes: 17 additions & 13 deletions jaxlib/lapack_kernels.cc
@@ -30,7 +30,7 @@ template <typename T>
typename Trsm<T>::FnType* Trsm<T>::fn = nullptr;

template <typename T>
void Trsm<T>::Kernel(void* out, void** data) {
void Trsm<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) {
int32_t left_side = *reinterpret_cast<int32_t*>(data[0]);
int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
int32_t trans_a = *reinterpret_cast<int32_t*>(data[2]);
@@ -82,7 +82,7 @@ template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;

template <typename T>
void Getrf<T>::Kernel(void* out_tuple, void** data) {
void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int m = *(reinterpret_cast<int32_t*>(data[1]));
int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -116,7 +116,7 @@ template <typename T>
typename Geqrf<T>::FnType* Geqrf<T>::fn = nullptr;

template <typename T>
void Geqrf<T>::Kernel(void* out_tuple, void** data) {
void Geqrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int m = *(reinterpret_cast<int32_t*>(data[1]));
int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -163,7 +163,7 @@ template <typename T>
typename Orgqr<T>::FnType* Orgqr<T>::fn = nullptr;

template <typename T>
void Orgqr<T>::Kernel(void* out_tuple, void** data) {
void Orgqr<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int m = *(reinterpret_cast<int32_t*>(data[1]));
int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -211,7 +211,7 @@ template <typename T>
typename Potrf<T>::FnType* Potrf<T>::fn = nullptr;

template <typename T>
void Potrf<T>::Kernel(void* out_tuple, void** data) {
void Potrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
int b = *(reinterpret_cast<int32_t*>(data[1]));
int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -260,7 +260,7 @@ template <typename T>
typename RealGesdd<T>::FnType* RealGesdd<T>::fn = nullptr;

template <typename T>
void RealGesdd<T>::Kernel(void* out_tuple, void** data) {
void RealGesdd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
int b = *(reinterpret_cast<int32_t*>(data[2]));
@@ -332,7 +332,8 @@ template <typename T>
typename ComplexGesdd<T>::FnType* ComplexGesdd<T>::fn = nullptr;

template <typename T>
void ComplexGesdd<T>::Kernel(void* out_tuple, void** data) {
void ComplexGesdd<T>::Kernel(void* out_tuple, void** data,
XlaCustomCallStatus*) {
int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
int b = *(reinterpret_cast<int32_t*>(data[2]));
@@ -411,7 +412,7 @@ template <typename T>
typename RealSyevd<T>::FnType* RealSyevd<T>::fn = nullptr;

template <typename T>
void RealSyevd<T>::Kernel(void* out_tuple, void** data) {
void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
int b = *(reinterpret_cast<int32_t*>(data[1]));
int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -459,7 +460,8 @@ template <typename T>
typename ComplexHeevd<T>::FnType* ComplexHeevd<T>::fn = nullptr;

template <typename T>
void ComplexHeevd<T>::Kernel(void* out_tuple, void** data) {
void ComplexHeevd<T>::Kernel(void* out_tuple, void** data,
XlaCustomCallStatus*) {
int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
int b = *(reinterpret_cast<int32_t*>(data[1]));
int n = *(reinterpret_cast<int32_t*>(data[2]));
@@ -531,7 +533,7 @@ template <typename T>
typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;

template <typename T>
void RealGeev<T>::Kernel(void* out_tuple, void** data) {
void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
@@ -590,7 +592,8 @@ template <typename T>
typename ComplexGeev<T>::FnType* ComplexGeev<T>::fn = nullptr;

template <typename T>
void ComplexGeev<T>::Kernel(void* out_tuple, void** data) {
void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
@@ -648,7 +651,7 @@ template <typename T>
typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;

template <typename T>
void RealGees<T>::Kernel(void* out_tuple, void** data) {
void RealGees<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
@@ -708,7 +711,8 @@ template <typename T>
typename ComplexGees<T>::FnType* ComplexGees<T>::fn = nullptr;

template <typename T>
void ComplexGees<T>::Kernel(void* out_tuple, void** data) {
void ComplexGees<T>::Kernel(void* out_tuple, void** data,
XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
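Note that the migrated kernels accept the new XlaCustomCallStatus* argument but do not use it yet (the parameter is left unnamed), and the LAPACK info value is still returned through the output tuple as before. If a later change wanted to surface a nonzero info through the status instead, a small helper along these lines could be used (a hypothetical sketch, not part of this commit):

#include <string>
#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Hypothetical helper: turn a nonzero LAPACK `info` into a custom-call failure.
static void SetLapackFailure(XlaCustomCallStatus* status, const char* routine,
                             int info) {
  if (info == 0) return;
  std::string msg =
      std::string(routine) + " failed with info = " + std::to_string(info);
  XlaCustomCallStatusSetFailure(status, msg.c_str(), msg.size());
}

// Example use inside a kernel body, after the LAPACK call has written *info:
//   SetLapackFailure(status, "getrf", *info);

Doing so would also require naming the status parameter in the Kernel signatures, which in this commit is intentionally left anonymous.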
27 changes: 14 additions & 13 deletions jaxlib/lapack_kernels.h
@@ -18,6 +18,7 @@ limitations under the License.

#include <complex>
#include <cstdint>
#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
// by the pybind wrapper that links them to an existing SciPy lapack instance,
@@ -35,7 +36,7 @@ struct Trsm {
lapack_int* lda, T* b, lapack_int* ldb);

static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <typename T>
@@ -44,7 +45,7 @@ struct Getrf {
lapack_int* ipiv, lapack_int* info);

static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <typename T>
@@ -53,7 +54,7 @@ struct Geqrf {
T* tau, T* work, lapack_int* lwork, lapack_int* info);

static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);

static int64_t Workspace(lapack_int m, lapack_int n);
};
@@ -64,7 +65,7 @@ struct Orgqr {
lapack_int* lda, T* tau, T* work, lapack_int* lwork,
lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k);
};

@@ -73,7 +74,7 @@ struct Potrf {
using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

lapack_int GesddIworkSize(int64_t m, int64_t n);
@@ -85,7 +86,7 @@ struct RealGesdd {
lapack_int* ldvt, T* work, lapack_int* lwork,
lapack_int* iwork, lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);

static int64_t Workspace(lapack_int m, lapack_int n,
bool job_opt_compute_uv, bool job_opt_full_matrices);
Expand All @@ -101,7 +102,7 @@ struct ComplexGesdd {
lapack_int* lwork, typename T::value_type* rwork,
lapack_int* iwork, lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);

static int64_t Workspace(lapack_int m, lapack_int n,
bool job_opt_compute_uv, bool job_opt_full_matrices);
Expand All @@ -117,7 +118,7 @@ struct RealSyevd {
lapack_int* lda, T* w, T* work, lapack_int* lwork,
lapack_int* iwork, lapack_int* liwork, lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

lapack_int HeevdWorkSize(int64_t n);
Expand All @@ -131,7 +132,7 @@ struct ComplexHeevd {
lapack_int* lrwork, lapack_int* iwork, lapack_int* liwork,
lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <typename T>
@@ -141,7 +142,7 @@ struct RealGeev {
T* vr, lapack_int* ldvr, T* work, lapack_int* lwork,
lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <typename T>
@@ -151,7 +152,7 @@ struct ComplexGeev {
lapack_int* ldvr, T* work, lapack_int* lwork,
typename T::value_type* rwork, lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <typename T>
@@ -161,7 +162,7 @@ struct RealGees {
T* wr, T* wi, T* vs, lapack_int* ldvs, T* work,
lapack_int* lwork, bool* bwork, lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <typename T>
@@ -172,7 +173,7 @@ struct ComplexGees {
typename T::value_type* rwork, bool* bwork,
lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data);
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

} // namespace jax
