Roll cl/517070827 forward with fix for ROCm.
Custom kernel for sum reductions that is intended to run faster than NCCL for small buffers.

Disabled by default, enable with `XLA_FLAGS=--xla_gpu_allow_all_reduce_kernel=true` on sm60+ GPUs.
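
The option can also be set programmatically. A minimal sketch, assuming the DebugOptions proto is built by hand rather than parsed from `XLA_FLAGS` (only the `set_xla_gpu_allow_all_reduce_kernel` setter itself comes from this change; the header path and helper are illustrative):

```cpp
#include "xla/xla.pb.h"  // assumed location of the xla::DebugOptions proto

// Sketch: enable the custom all-reduce kernel on a hand-built DebugOptions.
xla::DebugOptions MakeDebugOptions() {
  xla::DebugOptions opts;
  opts.set_xla_gpu_allow_all_reduce_kernel(true);  // setter added in this commit
  return opts;
}
```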

Reverts 460d208

PiperOrigin-RevId: 517521090
chsigg authored and Copybara-Service committed Mar 17, 2023
1 parent 6061922 commit 16e953a
Showing 12 changed files with 590 additions and 40 deletions.
8 changes: 8 additions & 0 deletions xla/debug_options_flags.cc
@@ -114,6 +114,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_enable_triton_gemm(true);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(false);

opts.set_xla_gpu_allow_all_reduce_kernel(false);

return opts;
}

@@ -886,6 +889,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_triton_gemm_any(),
"Use Triton-based matrix multiplication for any GEMM it "
"supports without filtering only faster ones."));
flag_list->push_back(tsl::Flag(
"xla_gpu_allow_all_reduce_kernel",
bool_setter_for(&DebugOptions::set_xla_gpu_allow_all_reduce_kernel),
debug_options->xla_gpu_allow_all_reduce_kernel(),
"Mark all reduce ops to use costum kernel if feasible."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
17 changes: 17 additions & 0 deletions xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "xla/mlir/backends/gpu/transforms/uid_generator.h"
#include "xla/mlir/runtime/utils/custom_calls.h"
@@ -699,6 +700,22 @@ class CollectiveOpLowering : public OpRewritePattern<CollectiveOp> {
return success();
}

static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, AllReduceOp op,
func::CallOp call) {
auto attr = op->getAttrOfType<BoolAttr>("allow_all_reduce_kernel");
call->setAttr(b.getStringAttr("allow_all_reduce_kernel"),
attr ? attr : b.getBoolAttr(false));
return SetSpecificAttrs<AllReduceOp>(b, op, call);
}
static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b,
AllReduceStartOp op,
func::CallOp call) {
auto attr = op->getAttrOfType<BoolAttr>("allow_all_reduce_kernel");
call->setAttr(b.getStringAttr("allow_all_reduce_kernel"),
attr ? attr : b.getBoolAttr(false));
return SetSpecificAttrs<AllReduceStartOp>(b, op, call);
}

template <typename OpT>
static typename std::enable_if_t<is_any<OpT, AllGatherOp, AllGatherStartOp>,
LogicalResult>
35 changes: 25 additions & 10 deletions xla/service/gpu/BUILD
@@ -2,6 +2,7 @@
# GPU-specific components in XLA service implementation.

load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
load(
"@tsl//tsl/platform:build_config.bzl",
"tf_proto_library",
@@ -604,7 +605,17 @@ tsl_gpu_library(
":ir_emission_utils",
":nccl_utils",
":thunk",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@llvm-project//mlir:IR",
"//xla:shape_util",
"//xla:status",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
@@ -613,22 +624,26 @@ tsl_gpu_library(
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/stream_executor",
"//xla/stream_executor/gpu:gpu_activation",
"//xla/stream_executor/gpu:gpu_activation_header",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/translate/hlo_to_mhlo:hlo_utils",
"//xla/translate/mhlo_to_hlo:attribute_exporter",
"//xla/translate/mhlo_to_hlo:type_to_shape",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:logging",
],
] + if_nccl([":all_reduce_kernel"]),
)

cuda_library(
name = "all_reduce_kernel",
srcs = if_cuda_is_configured(["all_reduce_kernel.cu.cc"]),
hdrs = if_cuda_is_configured(["all_reduce_kernel.h"]),
compatible_with = [],
tags = ["manual"], # only builds if_nccl.
deps = if_cuda_is_configured(["@local_config_nccl//:nccl"]),
alwayslink = 1,
)

# Empty library to implement nested dependency conditions.
254 changes: 254 additions & 0 deletions xla/service/gpu/all_reduce_kernel.cu.cc
@@ -0,0 +1,254 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/all_reduce_kernel.h"

#include <cassert>
#include <cstdint>

namespace {

using xla::gpu::kLaunchBounds;
using xla::gpu::kMaxBuffers;
using xla::gpu::kMaxNumGpus;
using xla::gpu::SyncFlag;

// Like std::array<T, kMaxNumGpus>, without the need for `relaxed-constexpr`.
template <typename T>
struct Array {
__device__ constexpr const T& operator[](int i) const { return data[i]; }

private:
T data[kMaxNumGpus];
};

struct float2 {
__device__ explicit float2(__nv_bfloat162 value)
: x(__bfloat162float(value.x)), y(__bfloat162float(value.y)) {}
__device__ operator __nv_bfloat162() const {
__nv_bfloat162 result;
result.x = __float2bfloat16_rn(x);
result.y = __float2bfloat16_rn(y);
return result;
}
__device__ float2& operator+=(const float2& rhs) {
x += rhs.x;
y += rhs.y;
return *this;
}

private:
float x, y;
};

template <typename T>
struct MathType {
using type = T;
};
template <>
struct MathType<__nv_bfloat16> {
using type = float;
};
template <>
struct MathType<__nv_bfloat162> {
using type = float2;
};
} // namespace

static __device__ uint32_t atomic_inc_release_system(uint32_t* ptr,
uint32_t value) {
#if __CUDA_ARCH__ >= 700
uint32_t result = 0;
asm volatile("atom.inc.release.sys.u32 %0, [%1], %2;"
: "=r"(result)
: "l"(ptr), "r"(value)
: "memory");
return result;
#elif __CUDA_ARCH__ >= 600
return atomicInc_system(ptr, value);
#else
return __trap(), 0; // Unsupported.
#endif
}

static __device__ uint32_t atomic_load_acquire_system(uint32_t* ptr) {
uint32_t result = 0;
#if __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.b32 %0, [%1];"
: "=r"(result)
: "l"(ptr)
: "memory");
#else
asm volatile("ld.volatile.b32 %0, [%1];"
: "=r"(result)
: "l"(ptr)
: "memory");
#endif
return result;
}

static __global__ void SyncKernel(uint32_t* counter) {
atomic_inc_release_system(counter, kMaxNumGpus);
while (atomic_load_acquire_system(counter) != 0) {
}
}

template <typename T>
static __global__ void __launch_bounds__(kLaunchBounds)
AllReduceKernel(int num_gpus, Array<const T* __restrict> send_buffers,
Array<T* __restrict> recv_buffers, int64_t num_elements,
uint32_t* counter, SyncFlag sync_flag) {
if (sync_flag & SyncFlag::SYNC_START) {
if (threadIdx.x == 0) {
while (atomic_load_acquire_system(counter) != num_gpus - 1) {
}
}
__syncthreads();
}

T vals[kMaxNumGpus];
for (int tid = blockDim.x * blockIdx.x + threadIdx.x; tid < num_elements;
tid += blockDim.x * gridDim.x) {
// Static loop bounds are required to store 'vals' in registers.
for (int i = 0; i < kMaxNumGpus; ++i) {
if (i >= num_gpus) break;
vals[i] = send_buffers[i][tid];
}
using MathType = typename MathType<T>::type;
MathType result = static_cast<MathType>(vals[0]);
for (int i = 1; i < kMaxNumGpus; ++i) {
if (i >= num_gpus) break;
result += static_cast<MathType>(vals[i]);
}
for (int i = 0; i < kMaxNumGpus; ++i) {
if (i >= num_gpus) break;
recv_buffers[i][tid] = result;
}
}

if (sync_flag & SyncFlag::SYNC_END) {
__syncthreads();
if (threadIdx.x == 0) {
atomic_inc_release_system(counter, num_gpus + gridDim.x - 2);
}
}
}

// bfloat16x2 kernel for sm80+ that requires num_elements to be a multiple of 32.
static __global__ void __launch_bounds__(kLaunchBounds)
AllReduceKernelAsync(int num_gpus,
Array<const __nv_bfloat162* __restrict> send_buffers,
Array<__nv_bfloat162* __restrict> recv_buffers,
int64_t num_elements, uint32_t* counter,
SyncFlag sync_flag) {
assert(num_elements % 32 == 0);

if (sync_flag & SyncFlag::SYNC_START) {
if (threadIdx.x == 0) {
while (atomic_load_acquire_system(counter) != num_gpus - 1) {
}
}
__syncthreads();
}

#if __CUDA_ARCH__ >= 800
__shared__ __nv_bfloat162 data[kMaxNumGpus][kLaunchBounds];

// Groups of 4 consecutive threads load 4*bfloat16x2 (16B) each from 4
// different GPUs at a time. That is, thread 4*k+i loads
// elements [4*k, 4*k+1, 4*k+2, 4*k+3] from GPUs [i, i+4, i+8, i+12].
int start_gpu = threadIdx.x & 0x3;
int start_offset = threadIdx.x & ~0x3;
uint32_t start_shared =
__cvta_generic_to_shared(data[start_gpu] + start_offset);

for (int offset = blockDim.x * blockIdx.x + start_offset;
offset < num_elements; offset += blockDim.x * gridDim.x) {
uint32_t shared = start_shared;
for (int i = start_gpu; i < kMaxNumGpus; i += 4) {
if (i >= num_gpus) break;
asm volatile(
"cp.async.ca.shared.global [%0], [%1], 16, 16;" ::"r"(shared),
"l"(send_buffers[i] + offset)
: "memory");
shared += 4 * kLaunchBounds * sizeof(__nv_bfloat162);
}
asm volatile("cp.async.wait_all;" ::: "memory");
__syncwarp();

const __nv_bfloat162* ptr = data[0] + threadIdx.x;
auto f32x2 = __bfloat1622float2(*ptr);
for (int i = 1; i < kMaxNumGpus; ++i) {
if (i >= num_gpus) break;
ptr += kLaunchBounds;
auto tmp = __bfloat1622float2(*ptr);
f32x2.x += tmp.x;
f32x2.y += tmp.y;
}
__nv_bfloat162 bf16x2 = __floats2bfloat162_rn(f32x2.x, f32x2.y);
unsigned result = reinterpret_cast<const unsigned&>(bf16x2);
uint4 results = {
__shfl_sync(~0u, result, 0, 4), // x
__shfl_sync(~0u, result, 1, 4), // y
__shfl_sync(~0u, result, 2, 4), // z
__shfl_sync(~0u, result, 3, 4) // w
};

for (int i = start_gpu; i < kMaxNumGpus; i += 4) {
if (i >= num_gpus) break;
*reinterpret_cast<uint4* __restrict>(recv_buffers[i] + offset) = results;
}
}
#else
__trap(); // Unsupported.
#endif

if (sync_flag & SyncFlag::SYNC_END) {
__syncthreads();
if (threadIdx.x == 0) {
atomic_inc_release_system(counter, num_gpus + gridDim.x - 2);
}
}
}

const void* xla::gpu::GetSyncKernel() {
return reinterpret_cast<const void*>(&SyncKernel);
}

const void* xla::gpu::GetAllReduceKernel(ncclDataType_t dtype,
int64_t* num_elements, int cc_major) {
// Clang crashes if not wrapped in an IIFE.
return [&]() -> const void* {
switch (dtype) {
case ncclBfloat16:
if (cc_major >= 8 && *num_elements % 64 == 0) {
*num_elements /= 2;
return reinterpret_cast<const void*>(&AllReduceKernelAsync);
}
if (*num_elements % 2 == 0) {
*num_elements /= 2;
return reinterpret_cast<const void*>(
&AllReduceKernel<__nv_bfloat162>);
}
return reinterpret_cast<const void*>(&AllReduceKernel<__nv_bfloat16>);
case ncclFloat32:
return reinterpret_cast<const void*>(&AllReduceKernel<float>);
case ncclInt32:
return reinterpret_cast<const void*>(&AllReduceKernel<int32_t>);
default:
return nullptr;
}
}();
}
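
As an illustration of the dispatch in GetAllReduceKernel above, a minimal host-side sketch (the call site, the function name PickBf16Kernel, and the element count are hypothetical; only GetAllReduceKernel and its behavior come from this change):

```cpp
#include <cstdint>

#include "xla/service/gpu/all_reduce_kernel.h"  // assumed to expose ncclDataType_t

// Sketch: pick the kernel variant for a bfloat16 all-reduce of 4096 elements
// on an sm80 GPU. 4096 is a multiple of 64, so the async bfloat16x2 kernel is
// chosen and num_elements is halved to count __nv_bfloat162 pairs.
const void* PickBf16Kernel() {
  int64_t num_elements = 4096;
  const void* kernel = xla::gpu::GetAllReduceKernel(
      ncclBfloat16, &num_elements, /*cc_major=*/8);
  // num_elements == 2048 here; the caller launches `kernel` with this count.
  return kernel;
}
```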
