Skip to content

Commit

Permalink
[jax_triton] Split C++-only parts of Triton custom callback from Python parts.
Browse files Browse the repository at this point in the history

Register callback with default call target name from C++, enabling Triton calls with the default name to work in C++ only contexts (e.g. serving).

PiperOrigin-RevId: 545211452
  • Loading branch information
chr1sj0nes authored and jax authors committed Jul 3, 2023
1 parent 658e8ff commit 31b862d
Show file tree
Hide file tree
Showing 9 changed files with 720 additions and 620 deletions.
42 changes: 28 additions & 14 deletions jaxlib/cuda/BUILD
Expand Up @@ -57,6 +57,7 @@ cc_library(
"@xla//xla/stream_executor/cuda:cusolver_lib",
"@xla//xla/stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -376,15 +377,40 @@ cc_library(
":cuda_vendor",
":cusolver_kernels",
":cusparse_kernels",
":triton_kernels",
"@xla//xla/service:custom_call_target_registry",
],
alwayslink = 1,
)

# Triton kernel sources built as a plain cc_library: its deps list contains no
# pybind11 targets. Both the alwayslink kernel library above and the `_triton`
# pybind extension below depend on it.
cc_library(
name = "triton_kernels",
srcs = ["//jaxlib/gpu:triton_kernels.cc"],
hdrs = ["//jaxlib/gpu:triton_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib/gpu:triton_cc_proto",
"@xla//xla/service:custom_call_status",
"@xla//xla/stream_executor/cuda:cudart_stub",
"@xla//xla/stream_executor/gpu:asm_compiler",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@zlib",
],
)

pybind_extension(
name = "_triton",
srcs = ["//jaxlib/gpu:triton.cc"],
hdrs = ["//jaxlib/gpu:triton.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
Expand All @@ -400,24 +426,12 @@ pybind_extension(
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":triton_kernels",
"//jaxlib:kernel_pybind11_helpers",
"//jaxlib/gpu:triton_cc_proto",
"@xla//xla/service:custom_call_status",
"@xla//xla/stream_executor/cuda:cudart_stub",
"@xla//xla/stream_executor/gpu:asm_compiler",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@zlib",
],
)

Expand Down
3 changes: 2 additions & 1 deletion jaxlib/gpu/BUILD
Expand Up @@ -48,7 +48,8 @@ exports_files(srcs = [
"sparse_kernels.cc",
"sparse_kernels.h",
"triton.cc",
"triton.h",
"triton_kernels.cc",
"triton_kernels.h",
"vendor.h",
])

Expand Down
16 changes: 16 additions & 0 deletions jaxlib/gpu/gpu_kernel_helpers.cc
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"

#include "absl/base/optimization.h"
#include "absl/log/check.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
Expand All @@ -28,6 +29,12 @@ std::string ErrorString(gpuError_t error) { return gpuGetErrorString(error); }

#ifdef JAX_GPU_CUDA

// Returns the symbolic name of a CUDA driver API result code.
//
// cuGetErrorName itself fails, and yields no usable string, when `error` is
// not a code it recognizes. Aborting the process in that case would mask the
// failure the caller is trying to report, so a placeholder containing the
// numeric code is returned instead of CHECK-failing.
std::string ErrorString(CUresult error) {
  const char* name = nullptr;
  if (cuGetErrorName(error, &name) != CUDA_SUCCESS || name == nullptr) {
    return absl::StrFormat("[Unknown CUresult error code %d]",
                           static_cast<int>(error));
  }
  return name;
}

// Returns the message that cusparseGetErrorString reports for `status`.
std::string ErrorString(gpusparseStatus_t status) {
  const char* description = cusparseGetErrorString(status);
  return description;
}
Expand Down Expand Up @@ -220,6 +227,15 @@ absl::Status AsStatus(gpublasStatus_t status, const char* file,
return absl::OkStatus();
}

#ifdef JAX_GPU_CUDA
// Converts a CUDA driver API result into an absl::Status.
//
// Returns OkStatus for CUDA_SUCCESS; any other value becomes an InternalError
// whose message is built by ErrorString from the error code together with the
// `file`, `line` and `expr` describing the failing call site.
absl::Status AsStatus(CUresult error, const char* file, std::int64_t line,
                      const char* expr) {
  if (ABSL_PREDICT_FALSE(error != CUDA_SUCCESS)) {
    const auto message = ErrorString(error, file, line, expr);
    return absl::InternalError(message);
  }
  return absl::OkStatus();
}
#endif

absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
gpuStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
Expand Down
11 changes: 11 additions & 0 deletions jaxlib/gpu/gpu_kernel_helpers.h
Expand Up @@ -39,6 +39,13 @@ limitations under the License.
if (ABSL_PREDICT_FALSE(!s___.ok())) return s___; \
}

// Token-pasting helpers used to build a per-line unique identifier.
#define JAX_STATUS_MACROS_CONCAT_INNER(x, y) x##y
#define JAX_STATUS_MACROS_CONCAT(x, y) JAX_STATUS_MACROS_CONCAT_INNER(x, y)

// Evaluates `expr` (which must yield a StatusOr-like object exposing ok(),
// status() and operator*). On failure, returns its status from the enclosing
// function; on success, moves the contained value into `lhs`, which may be a
// declaration such as `int x` or an existing lvalue.
//
// The temporary that holds the result is named after the current source line,
// so the macro can be used several times within one scope without
// redeclaring the same variable. Two uses on the same source line would still
// collide.
//
// The expansion consists of several statements, so it must not be used as the
// unbraced body of an `if`, `for` or `while`.
#define JAX_ASSIGN_OR_RETURN(lhs, expr)                                       \
  JAX_ASSIGN_OR_RETURN_IMPL(JAX_STATUS_MACROS_CONCAT(s___, __LINE__), lhs, \
                            expr)

// Implementation detail of JAX_ASSIGN_OR_RETURN; `statusor` is the unique
// name of the temporary.
#define JAX_ASSIGN_OR_RETURN_IMPL(statusor, lhs, expr) \
  auto statusor = (expr);                              \
  if (ABSL_PREDICT_FALSE(!statusor.ok())) {            \
    return statusor.status();                          \
  }                                                    \
  lhs = (*std::move(statusor))

namespace jax {
namespace JAX_GPU_NAMESPACE {

Expand All @@ -51,6 +58,10 @@ absl::Status AsStatus(gpusparseStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(gpublasStatus_t status, const char* file,
std::int64_t line, const char* expr);
#ifdef JAX_GPU_CUDA
// Overload for CUDA driver API results (CUresult). Only declared when
// JAX_GPU_CUDA is defined. `file`, `line` and `expr` identify the call site
// and are included in the error message of a non-OK status.
absl::Status AsStatus(CUresult error, const char* file, std::int64_t line,
const char* expr);
#endif

// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/gpu/gpu_kernels.cc
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "jaxlib/gpu/prng_kernels.h"
#include "jaxlib/gpu/solver_kernels.h"
#include "jaxlib/gpu/sparse_kernels.h"
#include "jaxlib/gpu/triton_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_target_registry.h"

Expand Down Expand Up @@ -66,6 +67,9 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f32", gtsv2_f32,
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64,
"CUDA");

// Registers TritonKernelCall under the call target name "triton_kernel_call"
// for the "CUDA" platform. The registration is done in C++ so the target is
// also available when no Python extension module has been loaded.
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("triton_kernel_call", TritonKernelCall,
"CUDA");

} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

0 comments on commit 31b862d

Please sign in to comment.