Merge pull request tensorflow#12 from ROCmSoftwarePlatform/develop-upstream-xla

Enable XLA on TensorFlow ROCm port
whchung committed Jun 7, 2018
2 parents d563339 + a4652d5 commit ecc7295
Showing 43 changed files with 1,899 additions and 210 deletions.
2 changes: 1 addition & 1 deletion configure.py

@@ -1493,7 +1493,7 @@ def main():
   set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
                 'with_kafka_support', False, 'kafka')
   set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
-                False, 'xla')
+                True, 'xla')
   set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
                 False, 'gdr')
   set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
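
The net effect of this one-line change: ./configure now answers "yes" by default when it asks about XLA JIT support on the ROCm port. A minimal sketch of the mechanism, assuming a deliberately simplified set_build_var (the real helper in configure.py also prompts interactively and writes the result into .tf_configure.bazelrc):

  # Hypothetical, simplified model of configure.py's set_build_var: an
  # explicit TF_ENABLE_XLA env var wins; otherwise the default applies,
  # and this commit flips that default from False to True.
  def set_build_var(environ_cp, var_name, query_item, option_name,
                    enabled_by_default, bazel_config_name=None):
      enabled = environ_cp.get(var_name, '1' if enabled_by_default else '0')
      if enabled == '1':
          # In the real script this line lands in .tf_configure.bazelrc.
          print('build:%s --define %s=true' % (bazel_config_name, option_name))

  # With no TF_ENABLE_XLA in the environment, XLA is now on by default:
  set_build_var({}, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
                True, 'xla')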
14 changes: 12 additions & 2 deletions tensorflow/compiler/jit/BUILD

@@ -25,6 +25,8 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")

 # Target that bundles up the XLA CPU and GPU JIT devices.
 cc_library(
@@ -40,6 +42,9 @@ cc_library(
     ] + if_cuda_is_configured([
         ":xla_gpu_device",
         ":xla_gpu_jit",
+    ]) + if_rocm_is_configured([
+        ":xla_gpu_device",
+        ":xla_gpu_jit",
     ]),
     alwayslink = 1,
 )
@@ -59,12 +64,17 @@ cc_library(
 cc_library(
     name = "xla_gpu_jit",
     visibility = ["//visibility:public"],
-    deps = if_cuda([
+    deps = if_cuda_is_configured(if_cuda([
         ":jit_compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_launch_op",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:gpu_plugin",
-    ]),
+    ])) + if_rocm_is_configured(if_rocm([
+        ":jit_compilation_passes",
+        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+        "//tensorflow/compiler/xla/service:gpu_plugin",
+    ])),
     alwayslink = 1,
 )
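
The recurring idiom in these BUILD changes pairs a configure-time check (if_cuda_is_configured / if_rocm_is_configured, fixed when the workspace repository rules run) with a build-time select (if_cuda / if_rocm, resolved by --config flags). A sketch of the conventional shape of the ROCm pair, under the assumption that the generated @local_config_rocm//rocm:build_defs.bzl follows the same template style as its CUDA counterpart; the :using_rocm config_setting label is a stand-in:

  # Baked in by the repository rule when ./configure runs
  # (hypothetical stand-in for the template substitution).
  _ROCM_IS_CONFIGURED = True

  def rocm_is_configured():
      return _ROCM_IS_CONFIGURED

  def if_rocm_is_configured(x):
      # Loading-time: the argument disappears entirely from workspaces
      # where ROCm was never configured.
      if rocm_is_configured():
          return x
      return []

  def if_rocm(if_true, if_false = []):
      # Analysis-time: lets a ROCm-configured workspace still build
      # CPU-only targets without pulling in the GPU deps.
      return select({
          "@local_config_rocm//rocm:using_rocm": if_true,
          "//conditions:default": if_false,
      })

Nesting the two, as in if_rocm_is_configured(if_rocm([...])), presumably keeps the select() and its GPU dependencies out of the build graph altogether when the stack was not configured, rather than relying on the select() falling through to its empty default.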
4 changes: 3 additions & 1 deletion tensorflow/compiler/jit/kernels/xla_launch_op.cc

@@ -51,7 +51,9 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
   if (device_type_ == DeviceType(DEVICE_CPU)) {
     platform_id_ = se::host::kHostPlatformId;
   } else if (device_type_ == DeviceType(DEVICE_GPU)) {
-    platform_id_ = se::cuda::kCudaPlatformId;
+    // XXX FIXME devise a way to cope with multiple platforms
+    //platform_id_ = se::cuda::kCudaPlatformId;
+    platform_id_ = se::rocm::kROCmPlatformId;
   } else {
     platform_id_ = nullptr;
   }
4 changes: 3 additions & 1 deletion tensorflow/compiler/jit/xla_gpu_device.cc

@@ -14,7 +14,7 @@ limitations under the License.
 ==============================================================================*/

 // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
-// operators using XLA via the XLA "CUDA" (GPU) backend.
+// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.

 #include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
 #include "tensorflow/compiler/jit/xla_device.h"
@@ -46,6 +46,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
   std::unique_ptr<XlaDevice> device;
   Status status =
+      // XXX FIXME devise a way to cope with multiple platforms
+      //XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
       XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
                         name_prefix, registration,
                         /*transfer_as_literal=*/false, &device);
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2xla/BUILD

@@ -23,6 +23,7 @@ package(
 )

 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")

 cc_library(
@@ -141,6 +142,8 @@ cc_library(
         "xla_cpu_backend.cc",
     ] + if_cuda_is_configured([
         "xla_gpu_backend.cc",
+    ]) + if_rocm_is_configured([
+        "xla_gpu_backend.cc",
     ]),
     hdrs = [
         "const_analysis.h",
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/BUILD

@@ -711,6 +711,7 @@ cc_library(
         "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
+        "//tensorflow/core/platform/default/build_config:stream_executor_rocm",
     ],
 )
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/service/computation_placer.cc

@@ -147,6 +147,8 @@ static bool InitModule() {
       stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
   xla::ComputationPlacer::RegisterComputationPlacer(
       stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
+  xla::ComputationPlacer::RegisterComputationPlacer(
+      stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
   return true;
 }
 static bool module_initialized = InitModule();
19 changes: 15 additions & 4 deletions tensorflow/compiler/xla/service/gpu/BUILD

@@ -22,6 +22,10 @@ filegroup(
 )

 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")

 cc_library(
     name = "gpu_constants",
@@ -190,6 +194,7 @@ cc_library(
     srcs = ["elemental_ir_emitter.cc"],
     hdrs = ["elemental_ir_emitter.h"],
     deps = [
+        ":ir_emission_utils",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -246,7 +251,8 @@ cc_library(
         "thunk_schedule.cc",
         "tuple_thunk.cc",
         "while_thunk.cc",
-    ],
+    ] + if_cuda_is_configured(if_cuda(["nvptx_executable.cc"])) +
+        if_rocm_is_configured(if_rocm(["amdgpu_executable.cc"])),
     hdrs = [
         "conditional_thunk.h",
         "convolution_thunk.h",
@@ -264,7 +270,8 @@ cc_library(
         "thunk_schedule.h",
         "tuple_thunk.h",
         "while_thunk.h",
-    ],
+    ] + if_cuda_is_configured(if_cuda(["nvptx_executable.h"])) +
+        if_rocm_is_configured(if_rocm(["amdgpu_executable.h"])),
     deps = [
         ":buffer_allocations",
         ":cudnn_convolution_runner",
@@ -296,6 +303,7 @@ cc_library(
         "//tensorflow/core/platform/default/build_config:cudnn_plugin",
         "//tensorflow/core/platform/default/build_config:cufft_plugin",
         "//tensorflow/core/platform/default/build_config:stream_executor_cuda",  # build_cleaner: keep
+        "//tensorflow/core/platform/default/build_config:stream_executor_rocm",
         "//tensorflow/stream_executor",
     ],
 )
@@ -490,8 +498,10 @@ cc_library(

 cc_library(
     name = "gpu_compiler",
-    srcs = ["gpu_compiler.cc"],
-    hdrs = ["gpu_compiler.h"],
+    srcs = if_cuda_is_configured(if_cuda(["nvptx_compiler.cc"])) +
+        if_rocm_is_configured(if_rocm(["amdgpu_compiler.cc"])),
+    hdrs = if_cuda_is_configured(if_cuda(["nvptx_compiler.h"])) +
+        if_rocm_is_configured(if_rocm(["amdgpu_compiler.h"])),
     deps = [
         ":cudnn_convolution_algorithm_picker",
         ":cudnn_convolution_rewriter",
@@ -545,6 +555,7 @@ cc_library(
         "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/core:cuda_libdevice_path",
+        "//tensorflow/core:rocm_rocdl_path",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:regexp_internal",
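
After this change the gpu_compiler target has no unconditional sources at all; the compiler implementation is chosen entirely by which GPU stack the workspace was configured for. A worked expansion for a ROCm-only workspace, where if_cuda_is_configured(...) collapses to [] at loading time (the :using_rocm config_setting label is again a stand-in):

  cc_library(
      name = "gpu_compiler",
      # if_cuda_is_configured(if_cuda(["nvptx_compiler.cc"])) -> []
      # if_rocm_is_configured(if_rocm(["amdgpu_compiler.cc"])) -> select(...)
      srcs = [] + select({
          "@local_config_rocm//rocm:using_rocm": ["amdgpu_compiler.cc"],
          "//conditions:default": [],
      }),
  )

A CUDA-only workspace reduces symmetrically to nvptx_compiler.cc / nvptx_executable.cc, so neither backend's sources are ever built in the other's workspace.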
