Skip to content

Commit

Permalink
Performance fix for torch.cat operator on ROCm (pytorch#46097)
Browse files Browse the repository at this point in the history
Summary:
This pull request is a partial revert of pytorch#44833 for ROCm, fixing the performance of the concatenate operator. The changes affect execution on ROCm only and are guarded by the `__HIP_PLATFORM_HCC__` define.

Pull Request resolved: pytorch#46097

Test Plan:
Benchmark
`python -m pt.cat_test --tag_filter all --device cuda`

Results on ROCm before the PR:
```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : all

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1,1,1)_N2_dim0_cuda
# Input: sizes: (1, 1, 1), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 10828.314

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(512,512,2)_N2_dim1_cuda
# Input: sizes: (512, 512, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 11888.028

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(128,1024,2)_N2_dim1_cuda
# Input: sizes: (128, 1024, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 11898.945

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim0_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 11787.744

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1025,1023,2)_N2_dim1_cuda
# Input: sizes: (1025, 1023, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 11792.479

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim2_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 11769.718

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f989e5c2510>,111,65]_N5_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f989e5c2510>, 111, 65], N: 5, dim: 0, device: cuda
Forward Execution Time (us) : 11633.882

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[96,<function<lambda>at0x7f989e5c2620>,64]_N5_dim1_cuda
# Input: sizes: [96, <function <lambda> at 0x7f989e5c2620>, 64], N: 5, dim: 1, device: cuda
Forward Execution Time (us) : 11617.768

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[128,64,<function<lambda>at0x7f96eee4df28>]_N5_dim2_cuda
# Input: sizes: [128, 64, <function <lambda> at 0x7f96eee4df28>], N: 5, dim: 2, device: cuda
Forward Execution Time (us) : 11625.143

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f96ef874048>,32,64]_N50_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f96ef874048>, 32, 64], N: 50, dim: 0, device: cuda
Forward Execution Time (us) : 13079.204

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[32,<function<lambda>at0x7f96ef8740d0>,64]_N50_dim1_cuda
# Input: sizes: [32, <function <lambda> at 0x7f96ef8740d0>, 64], N: 50, dim: 1, device: cuda
Forward Execution Time (us) : 13095.620

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[33,65,<function<lambda>at0x7f96ef874158>]_N50_dim2_cuda
# Input: sizes: [33, 65, <function <lambda> at 0x7f96ef874158>], N: 50, dim: 2, device: cuda
Forward Execution Time (us) : 13403.086

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(64,32,4,16,32)_N2_dim2_cuda
# Input: sizes: (64, 32, 4, 16, 32), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 118.704

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(16,32,4,16,32)_N8_dim2_cuda
# Input: sizes: (16, 32, 4, 16, 32), N: 8, dim: 2, device: cuda
Forward Execution Time (us) : 263.273

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(9,31,5,15,33)_N17_dim4_cuda
# Input: sizes: (9, 31, 5, 15, 33), N: 17, dim: 4, device: cuda
Forward Execution Time (us) : 463.024

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f96ef8741e0>]_N100_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f96ef8741e0>], N: 100, dim: 0, device: cuda
Forward Execution Time (us) : 23818.032

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f96ef874268>]_N1000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f96ef874268>], N: 1000, dim: 0, device: cuda
Forward Execution Time (us) : 234778.296

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f96ef8742f0>]_N2000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f96ef8742f0>], N: 2000, dim: 0, device: cuda
Forward Execution Time (us) : 470288.132

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f96ef874378>]_N3000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f96ef874378>], N: 3000, dim: 0, device: cuda
Forward Execution Time (us) : 704361.221
```

Results on ROCm after the PR:
```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : all

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1,1,1)_N2_dim0_cuda
# Input: sizes: (1, 1, 1), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 29.292

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(512,512,2)_N2_dim1_cuda
# Input: sizes: (512, 512, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 46.320

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(128,1024,2)_N2_dim1_cuda
# Input: sizes: (128, 1024, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 36.969

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim0_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 92.816

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1025,1023,2)_N2_dim1_cuda
# Input: sizes: (1025, 1023, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 93.943

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim2_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 163.914

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f1da3186510>,111,65]_N5_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f1da3186510>, 111, 65], N: 5, dim: 0, device: cuda
Forward Execution Time (us) : 75.475

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[96,<function<lambda>at0x7f1da3186620>,64]_N5_dim1_cuda
# Input: sizes: [96, <function <lambda> at 0x7f1da3186620>, 64], N: 5, dim: 1, device: cuda
Forward Execution Time (us) : 68.880

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[128,64,<function<lambda>at0x7f1bf3c50f28>]_N5_dim2_cuda
# Input: sizes: [128, 64, <function <lambda> at 0x7f1bf3c50f28>], N: 5, dim: 2, device: cuda
Forward Execution Time (us) : 85.268

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f1bf4669048>,32,64]_N50_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f1bf4669048>, 32, 64], N: 50, dim: 0, device: cuda
Forward Execution Time (us) : 111.543

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[32,<function<lambda>at0x7f1bf46690d0>,64]_N50_dim1_cuda
# Input: sizes: [32, <function <lambda> at 0x7f1bf46690d0>, 64], N: 50, dim: 1, device: cuda
Forward Execution Time (us) : 110.644

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[33,65,<function<lambda>at0x7f1bf4669158>]_N50_dim2_cuda
# Input: sizes: [33, 65, <function <lambda> at 0x7f1bf4669158>], N: 50, dim: 2, device: cuda
Forward Execution Time (us) : 116.201

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(64,32,4,16,32)_N2_dim2_cuda
# Input: sizes: (64, 32, 4, 16, 32), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 117.708

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(16,32,4,16,32)_N8_dim2_cuda
# Input: sizes: (16, 32, 4, 16, 32), N: 8, dim: 2, device: cuda
Forward Execution Time (us) : 264.953

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(9,31,5,15,33)_N17_dim4_cuda
# Input: sizes: (9, 31, 5, 15, 33), N: 17, dim: 4, device: cuda
Forward Execution Time (us) : 480.304

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f1bf46691e0>]_N100_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f1bf46691e0>], N: 100, dim: 0, device: cuda
Forward Execution Time (us) : 116.385

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f1bf4669268>]_N1000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f1bf4669268>], N: 1000, dim: 0, device: cuda
Forward Execution Time (us) : 913.591

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f1bf46692f0>]_N2000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f1bf46692f0>], N: 2000, dim: 0, device: cuda
Forward Execution Time (us) : 2003.212

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f1bf4669378>]_N3000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f1bf4669378>], N: 3000, dim: 0, device: cuda
Forward Execution Time (us) : 3004.174
```

Reviewed By: bdhirsh

Differential Revision: D24286324

Pulled By: malfet

fbshipit-source-id: 291f3f3f80f9d2f9ba52a455a942f3fb0406e7d2
  • Loading branch information
ashishfarmer authored and facebook-github-bot committed Oct 14, 2020
1 parent 09842a4 commit d5ca53c
Showing 1 changed file with 168 additions and 5 deletions.
173 changes: 168 additions & 5 deletions aten/src/ATen/native/cuda/Shape.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
namespace at {
namespace native {

#ifdef __HIP_PLATFORM_HCC__
// ROCm uses a larger batch so each kernel launch covers more input tensors;
// chosen to restore pre-#44833 cat performance on ROCm (see commit summary).
constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
#else
// CUDA passes the metadata struct by value as a kernel argument, which is
// size-limited, hence the smaller batch.
constexpr int CAT_ARRAY_BATCH_SIZE = 128;
#endif
// Maximum tensor rank supported by the batched cat copy kernels.
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;

namespace {
Expand Down Expand Up @@ -78,6 +82,46 @@ struct OutputTensorSizeStride {
* The most important assumption made is that the input tensors are contiguous.
*/


// Use pinned memory and pass the struct by pointer on ROCm.
// Per-input metadata consumed by HIP_CatArrayBatchedCopy: one entry per input
// tensor in the current batch. Filled on the host (pinned buffer) and copied
// asynchronously to a device-side buffer in hip_parallel_cat.
template <typename T, typename IndexType>
struct CatArrInputTensor {
  T* input;            // device pointer to this input tensor's data
  IndexType offset;    // running offset of this input along the concat dimension
  IndexType dimSize;   // this input's extent along the concat dimension
  IndexType nElements; // total number of elements to copy from this input
};

// Batched concatenation copy kernel (ROCm path). Each y-block handles one
// input tensor described by inputs[blockIdx.y]; the x-threads of that block
// cooperatively copy all of its elements into the output at the position
// given by the input's offset along the concat dimension.
template <typename T, typename IndexType, int Dims>
C10_LAUNCH_BOUNDS_1(512)
__global__ void HIP_CatArrayBatchedCopy(
    T* output,
    CatArrInputTensor<T, IndexType>* inputs,
    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {

  const IndexType totalElems = inputs[blockIdx.y].nElements;
  const IndexType firstIdx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the end of this input have nothing to copy.
  if (firstIdx >= totalElems) {
    return;
  }

  T* src = inputs[blockIdx.y].input;
  const IndexType dimSize = inputs[blockIdx.y].dimSize;
  // Base offset (in elements) of this input's slice inside the output.
  const IndexType dstBase = inputs[blockIdx.y].offset * dimStride;
  const IndexType step = gridDim.x * blockDim.x;

  for (IndexType idx = firstIdx; idx < totalElems; idx += step) {
    const IndexType dstOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
        os.outputSize, os.outputStride, dimSize, concatDim, idx);
    output[dstBase + dstOffset] = src[idx];
  }
}

// Pass metadata directly through a kernel argument instead of pinned memory
template <typename T, typename IndexType, int n>
struct CatArrInputTensorMetadata {
Expand All @@ -88,9 +132,6 @@ struct CatArrInputTensorMetadata {
};

template <typename T, typename IndexType, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void CatArrayBatchedCopy(
T* output,
CatArrInputTensorMetadata<T, IndexType, CAT_ARRAY_BATCH_SIZE> inputs,
Expand Down Expand Up @@ -141,6 +182,122 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
}
}

// ROCm-only implementation of the batched cat ("concatenate") copy.
// Stages per-input metadata (CatArrInputTensor) in a pinned host buffer,
// copies it asynchronously to a device-side buffer, and launches
// HIP_CatArrayBatchedCopy for up to CAT_ARRAY_BATCH_SIZE inputs per launch.
// Assumes the caller has verified that all inputs are contiguous in
// `memory_format`, share out's dtype, and are 32-bit indexable.
template <typename scalar_t>
void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
// First, let's set up our kernel parameters. We start with a raw pointer to
// the storage for the output Tensor.
scalar_t *data = out.data_ptr<scalar_t>();

// Kernel Parameter: device-side byte buffer that holds the metadata for one
// batch of inputs. Allocated once and refilled every batch iteration.
long tensorMetadataSize =
sizeof(CatArrInputTensor<scalar_t, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
auto d_inputs_storage = at::empty(
{tensorMetadataSize}, out.options().dtype(at::kByte));
auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
d_inputs_storage.data_ptr());

OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;

// Next, let's initialize the size, stride arrays for the output Tensor.
if (memory_format == c10::MemoryFormat::Contiguous) {
for (int i = 0; i < nDims; ++i) {
param.outputSize[i] = at::native::size(out, i);
param.outputStride[i] = out.stride(i);
}
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
// permute the semantics of dims from NCHW to NHWC so that the input
// tensor is now contiguous
param.outputSize[0] = at::native::size(out, 0);
param.outputStride[0] = out.stride(0);
for (int i = 1; i < nDims - 1; ++i) {
param.outputSize[i] = at::native::size(out, i + 1);
param.outputStride[i] = out.stride(i + 1);
}
// The channels dimension (index 1 in NCHW) moves to the innermost slot.
param.outputSize[nDims - 1] = at::native::size(out, 1);
param.outputStride[nDims - 1] = out.stride(1);
} else {
TORCH_CHECK(false, "unsupported memory format");
}

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

// Now we loop over the inputs in batches of CAT_ARRAY_BATCH_SIZE.
int batchCounter = 0;
int64_t offset = 0; // running offset along the concat dimension
for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
// Re-allocate stackInputs every iteration to avoid read-after-write hazard:
// the previous batch's non-blocking H2D copy may still be reading the old
// pinned buffer while we would otherwise be overwriting it.
{
auto stackInputs_storage = at::empty({tensorMetadataSize},
out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true));
auto stackInputs =
static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
stackInputs_storage.data_ptr());
// Fill metadata for every input in this batch.
for (batchCounter = 0;
batchCounter < CAT_ARRAY_BATCH_SIZE &&
(i+batchCounter) < inputs.size();
++batchCounter) {
int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension);

stackInputs[batchCounter].input =
inputs[i+batchCounter].data_ptr<scalar_t>();
stackInputs[batchCounter].offset = offset;
stackInputs[batchCounter].dimSize = dimSize;
stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel();

// update offset
offset += dimSize;
}
// Asynchronous host-to-device copy of the metadata (pinned source).
at::native::copy_(d_inputs_storage, stackInputs_storage,
/* non_blocking= */ true);
}

// Next, let's consider how we set our kernel launch parameters.
// We borrow from THCApply, which the kernel's internal indexing
// is based on.
dim3 applyBlock = dim3(32*16);

// Get a grid whose x dim fills half the GPU and whose y dim is the number
// of tensors in this batch. Concatenating two tensors then fills the whole
// grid, while preventing many threads from needlessly loading metadata
// when the inputs are small.
dim3 catGrid;
getCatGrid(batchCounter, catGrid);

// For channels-last layouts, remap `dimension` from NCHW semantics to the
// permuted (NHWC-style) index space set up in `param` above.
if (memory_format != c10::MemoryFormat::Contiguous) {
switch (dimension) {
case 0:
break;
case 1:
dimension = nDims - dimension;
break;
default:
dimension--;
}
}
// Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
HIP_CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, d_inputs, param, dimension, param.outputStride[dimension]);
switch (nDims) {
case 1:
HANDLE_CASE(1);
break;
case 2:
HANDLE_CASE(2);
break;
case 3:
HANDLE_CASE(3);
break;
case 4:
HANDLE_CASE(4);
break;
}
#undef HANDLE_CASE
AT_CUDA_CHECK(cudaGetLastError());
}
}

template <typename scalar_t>
void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
Expand Down Expand Up @@ -235,7 +392,6 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
AT_CUDA_CHECK(cudaGetLastError());
}
}

} // namespace

Tensor cat_cuda(TensorList inputs, int64_t dimension) {
Expand Down Expand Up @@ -373,12 +529,19 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
all32BitIndexable &&
allSameType) {

#ifdef __HIP_PLATFORM_HCC__
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
hip_parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
});
#else
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
});

#endif
} else {
int64_t offset = 0;
for (int j = 0; j < inputs.size(); j++)
Expand Down

0 comments on commit d5ca53c

Please sign in to comment.