use legacy unrolled kernel for non-trivial offset calc cases (#71710)
Summary:
This leads to across-the-board improvements on Pascal GPUs and big perf improvements for some broadcasting patterns and datatypes on V100 (along with some 3-5% regressions for other patterns). The most common improving pattern on V100 is half-precision x+bias, which improves by ~5%. Full V100 results are in https://docs.google.com/spreadsheets/d/1K67x-6_TPT9Yt6533NfECEhUyfbqBxLH9M5Z3gymzXE/edit#gid=1218963246; the benchmarking script is at https://gist.github.com/ngimel/986ee84a1dd234a0485e99544e0fc8b6
Most importantly, it reduces context size by 40 MB.
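
In rough terms, the change routes every case with non-trivial offset calculation to the legacy elementwise kernel instead of the unrolled kernel. The following is an illustrative, simplified sketch of that dispatch (not code from the commit; pick_kernel_path and the enum are hypothetical names):

// Illustrative sketch only: summarizes the kernel-path choice gpu_kernel_impl
// makes after this commit; all names below are hypothetical.
#include <cstdio>

enum class KernelPath { Vectorized, Unrolled, Legacy };

// contiguous: all operands are contiguous, so offsets are trivial to compute.
// needs_cast: operand dtypes differ from the functor's argument types.
KernelPath pick_kernel_path(bool contiguous, bool needs_cast) {
  if (!needs_cast) {
    // Contiguous same-dtype ops keep the vectorized kernel; non-contiguous
    // ones (non-trivial offset calculation) now take the legacy kernel.
    return contiguous ? KernelPath::Vectorized : KernelPath::Legacy;
  }
  // With dynamic casting, contiguous ops keep the unrolled kernel with
  // trivial offset calculators; non-contiguous ones take the legacy kernel.
  return contiguous ? KernelPath::Unrolled : KernelPath::Legacy;
}

int main() {
  // A broadcasting (non-contiguous) op without casting now maps to Legacy.
  std::printf("%d\n", static_cast<int>(pick_kernel_path(/*contiguous=*/false,
                                                        /*needs_cast=*/false)));
  return 0;
}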

Pull Request resolved: pytorch/pytorch#71710

Reviewed By: mruberry

Differential Revision: D33769330

Pulled By: ngimel

fbshipit-source-id: 5a7942261e06003ca79bfa3b071106aab1a8a4bc
(cherry picked from commit f9b51b4)
ngimel authored and cyyever committed Feb 3, 2022
1 parent dd2fa71 commit d1a8018
Showing 1 changed file with 85 additions and 14 deletions.
aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -406,6 +406,67 @@ static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t da
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

+template<int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel(int N, func_t f) {
+int tid = threadIdx.x;
+int nv = nt * vt;
+int idx = nv * blockIdx.x + tid;
+#pragma unroll
+for (int i = 0; i < vt; i++) {
+if (idx < N) {
+f(idx);
+idx += nt;
+}
+}
+}
+
+template<int nt, int vt, typename func_t>
+static void launch_legacy_kernel(int64_t N, const func_t& f) {
+TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+if (N == 0) {
+return;
+}
+dim3 block(nt);
+dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+auto stream = at::cuda::getCurrentCUDAStream();
+elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename traits, typename func_t, typename index_t, size_t... INDEX>
+C10_HOST_DEVICE typename traits::result_type
+invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
+std::index_sequence<INDEX...>) {
+(void)strides;
+(void)i;
+return f(*(typename traits::template arg<INDEX>::type*)(data[INDEX] + i * strides[INDEX])...);
+}
+
+template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
+C10_HOST_DEVICE typename traits::result_type
+invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
+using Indices = std::make_index_sequence<traits::arity>;
+return invoke_impl<traits>(f, data, strides, i, Indices{});
+}
+
+template <typename traits, typename func_t, typename index_t, size_t... I>
+C10_HOST_DEVICE typename traits::result_type
+invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
+std::index_sequence<I...>) {
+(void)strides;
+(void)i;
+return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(dtypes[I], data[I] + i * strides[I])...);
+}
+
+template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
+C10_HOST_DEVICE typename traits::result_type
+invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
+using Indices = std::make_index_sequence<traits::arity>;
+return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
+}
+
+
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
using traits = function_traits<func_t>;
@@ -430,27 +491,37 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
if (contiguous) {
launch_vectorized_kernel(numel, f, data);
} else {
-auto input_offset_calculator = make_input_offset_calculator<traits::arity>(iter);
-auto output_offset_calculator = make_output_offset_calculator(iter);
-auto loader = memory::LoadWithoutCast();
-auto storer = memory::StoreWithoutCast();
-launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
+auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
+launch_legacy_kernel<128, unroll_factor>(numel, [=]GPU_LAMBDA(int idx) {
+auto offsets = offset_calc.get(idx);
+arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+*out = invoke(f, &data.data[1], &offsets.data[1], 1);
+});
}
} else {
-at::detail::Array<ScalarType, traits::arity> dtypes;
-for (int i = 0; i < traits::arity; i++) {
-dtypes[i] = iter.dtype(i + 1);
-}
-auto loader = memory::LoadWithCast<traits::arity>(dtypes);
-auto storer = memory::StoreWithCast(iter.dtype(0));
if (contiguous) {
+at::detail::Array<ScalarType, traits::arity> dtypes;
+for (int i = 0; i < traits::arity; i++) {
+dtypes[i] = iter.dtype(i + 1);
+}
+auto loader = memory::LoadWithCast<traits::arity>(dtypes);
+auto storer = memory::StoreWithCast(iter.dtype(0));
auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
auto output_offset_calculator = TrivialOffsetCalculator<1>();
launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
} else {
-auto input_offset_calculator = make_input_offset_calculator<traits::arity>(iter);
-auto output_offset_calculator = make_output_offset_calculator(iter);
-launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
+at::detail::Array<ScalarType, ntensors> dtypes;
+for (int i = 0; i < ntensors; i++) {
+dtypes[i] = iter.dtype(i);
+}
+auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+launch_legacy_kernel<128, 4>(numel, [=]GPU_LAMBDA(int idx) {
+auto offsets = offset_calc.get(idx);
+void* out = data[0] + offsets[0];
+arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+});
}
}
}
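
For orientation, here is a minimal standalone sketch of the nt/vt indexing scheme the legacy kernel above uses: each block covers nt*vt elements and each thread handles up to vt of them, spaced nt apart, with the grid size rounded up to ceil(N / (nt*vt)). This is not the PyTorch code; demo_elementwise_kernel and the rest of the names are made up, the c10 launch-check machinery is omitted, and the device lambda needs nvcc's --extended-lambda flag.

// Standalone illustration of the legacy nt/vt indexing (hypothetical names).
#include <cstdio>
#include <cuda_runtime.h>

template <int nt, int vt, typename func_t>
__global__ void demo_elementwise_kernel(int N, func_t f) {
  int tid = threadIdx.x;
  constexpr int nv = nt * vt;        // elements covered by one block
  int idx = nv * blockIdx.x + tid;   // first element owned by this thread
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < N) {
      f(idx);
      idx += nt;                     // next element, nt threads apart
    }
  }
}

int main() {
  constexpr int nt = 128, vt = 4;
  const int N = 1000;
  float* x;
  cudaMallocManaged(&x, N * sizeof(float));
  for (int i = 0; i < N; i++) x[i] = static_cast<float>(i);

  // Same grid-size arithmetic as launch_legacy_kernel: ceil(N / (nt * vt)).
  dim3 block(nt);
  dim3 grid((N + nt * vt - 1) / (nt * vt));
  demo_elementwise_kernel<nt, vt><<<grid, block>>>(
      N, [=] __device__(int idx) { x[idx] = x[idx] * 2.0f; });
  cudaDeviceSynchronize();
  std::printf("x[999] = %f\n", x[999]);  // expect 1998.0
  cudaFree(x);
  return 0;
}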
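The new invoke() helpers rely on std::index_sequence to expand an array of base pointers plus per-element byte offsets into the typed arguments the functor expects. Below is a hedged, host-only sketch of that mechanism: the mini function_traits is a simplified stand-in for c10's, and the add/main driver is purely illustrative.

// Host-only sketch of the invoke()/index_sequence expansion (illustrative).
#include <cstdint>
#include <cstdio>
#include <tuple>
#include <utility>

template <typename T>
struct function_traits;  // simplified stand-in for c10's function_traits

template <typename R, typename... Args>
struct function_traits<R (*)(Args...)> {
  using result_type = R;
  static constexpr std::size_t arity = sizeof...(Args);
  template <std::size_t I>
  using arg_t = std::tuple_element_t<I, std::tuple<Args...>>;
};

template <typename traits, typename func_t, std::size_t... I>
typename traits::result_type invoke_impl(
    const func_t& f, char* const data[], const int64_t offsets[],
    std::index_sequence<I...>) {
  // Reinterpret each operand's base pointer + byte offset as the argument
  // type the functor expects, mirroring the device-side helper above.
  return f(*reinterpret_cast<const typename traits::template arg_t<I>*>(
      data[I] + offsets[I])...);
}

template <typename func_t, typename traits = function_traits<func_t>>
typename traits::result_type invoke(
    const func_t& f, char* const data[], const int64_t offsets[]) {
  return invoke_impl<traits>(f, data, offsets,
                             std::make_index_sequence<traits::arity>{});
}

float add(float a, float b) { return a + b; }

int main() {
  float a[2] = {1.0f, 2.0f};
  float b[2] = {10.0f, 20.0f};
  char* data[2] = {reinterpret_cast<char*>(a), reinterpret_cast<char*>(b)};
  int64_t offsets[2] = {sizeof(float), sizeof(float)};  // element index 1
  std::printf("%f\n", invoke(&add, data, offsets));     // prints 22.0
  return 0;
}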
