From bf3d1b4660a8117e8f66cf61f1725f8c48e02d7f Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Wed, 16 Dec 2020 09:21:45 -0800
Subject: [PATCH] CUDA BFloat embedding (#44848)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44848

Reviewed By: izdeby

Differential Revision: D25574204

Pulled By: ngimel

fbshipit-source-id: b35f7253a6ad2b83f7b6b06862a5ab77295373e0
---
 aten/src/ATen/native/cuda/Embedding.cu        |  50 ++++-----
 .../native/cuda/EmbeddingBackwardKernel.cu    | 100 +++++++++---------
 aten/src/ATen/native/cuda/EmbeddingBag.cu     |  22 ++--
 test/test_nn.py                               |   1 -
 4 files changed, 82 insertions(+), 91 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index fd97d8ab26b6..80a8bfa5a6e8 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -249,23 +249,21 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices,
     "embedding_backward", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
-      using accscalar_t = acc_type<scalar_t, true>;
-      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
-        embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
-          <<<grid, block, sizeof(accscalar_t)*WARP_SIZE*BLOCKDIMY + sizeof(int)*WARP_SIZE*BLOCKDIMY, stream>>>
-          (indices_contig.data_ptr<index_t>(),
-           grad.data_ptr<scalar_t>(),
-           grad_weight.data_ptr<scalar_t>(),
-           static_cast<int>(num_indices),
-           static_cast<int64_t>(stride),
-           static_cast<int>(padding_idx));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      });
-    });
+    using accscalar_t = acc_type<scalar_t, true>;
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
+      embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
+        <<<grid, block, sizeof(accscalar_t)*WARP_SIZE*BLOCKDIMY + sizeof(int)*WARP_SIZE*BLOCKDIMY, stream>>>
+        (indices_contig.data_ptr<index_t>(),
+         grad.data_ptr<scalar_t>(),
+         grad_weight.data_ptr<scalar_t>(),
+         static_cast<int>(num_indices),
+         static_cast<int64_t>(stride),
+         static_cast<int>(padding_idx));
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
   });
   return grad_weight;
 }
@@ -362,16 +360,14 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   int dim = self.stride(0);
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
-      using accscalar_t = acc_type<scalar_t, true>;
-      renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
-        self.data_ptr<scalar_t>(),
-        unique_indices.data_ptr<int64_t>(),
-        static_cast<accscalar_t>(max_norm),
-        static_cast<accscalar_t>(norm_type),
-        dim, self.stride(0), self.stride(1));
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
-    });
+    using accscalar_t = acc_type<scalar_t, true>;
+    renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
+      self.data_ptr<scalar_t>(),
+      unique_indices.data_ptr<int64_t>(),
+      static_cast<accscalar_t>(max_norm),
+      static_cast<accscalar_t>(norm_type),
+      dim, self.stride(0), self.stride(1));
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
   });
   return self;
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
index 689db4347067..dd0730a38bcb 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -272,59 +272,57 @@ Tensor embedding_backward_cuda_kernel(
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-      // For numerical stability, the dtype of `grad_weight_per_segment`
-      // should match `acc_type`
-      using partial_weight_t = acc_type<scalar_t, true>;
-      TensorOptions op;
-      if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
-        op = grad.options().dtype(at::kFloat);
-      } else {
-        op = grad.options();
-      }
-      auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
-      // Compute the sum of each partial-segment and handle bags
-      if (offset2bag.defined()) {
-        compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data_ptr<int64_t>(),
-          grad.data_ptr<scalar_t>(),
-          offset2bag.data_ptr<int64_t>(),
-          count.defined() ? count.data_ptr<int64_t>() : nullptr, numel, stride,
-          mode_mean, bag_size.data_ptr<int64_t>(),
-          per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
-          per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
-          partial_segment_offset.data_ptr<int64_t>(),
-          num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
-          stride_warped);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      } else {
-        compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data_ptr<int64_t>(),
-          grad.data_ptr<scalar_t>(),
-          count.defined() ? count.data_ptr<int64_t>() : nullptr,
-          numel, stride,
-          partial_segment_offset.data_ptr<int64_t>(),
-          num_of_partial_segments,
-          grad_weight_per_segment.data_ptr<partial_weight_t>(),
-          stride_warped);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      }
-
-      // Finally, we sum all the partial-sums and scatter them
-      // into `grad_weight`.
-      const int grid2 = ceil_div(num_of_segments*stride_warped, block);
-      sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
-        sorted_indices.data_ptr<int64_t>(),
-        grad_weight.data_ptr<scalar_t>(),
-        stride,
-        segment_offsets.data_ptr<int64_t>(),
-        num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
-        partials_per_segment_offset.data_ptr<int64_t>(),
+    // For numerical stability, the dtype of `grad_weight_per_segment`
+    // should match `acc_type`
+    using partial_weight_t = acc_type<scalar_t, true>;
+    TensorOptions op;
+    if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
+      op = grad.options().dtype(at::kFloat);
+    } else {
+      op = grad.options();
+    }
+    auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
+    // Compute the sum of each partial-segment and handle bags
+    if (offset2bag.defined()) {
+      compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
+        orig_indices.data_ptr<int64_t>(),
+        grad.data_ptr<scalar_t>(),
+        offset2bag.data_ptr<int64_t>(),
+        count.defined() ? count.data_ptr<int64_t>() : nullptr, numel, stride,
+        mode_mean, bag_size.data_ptr<int64_t>(),
+        per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
+        per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
+        partial_segment_offset.data_ptr<int64_t>(),
+        num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
+        stride_warped);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    } else {
+      compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
+        orig_indices.data_ptr<int64_t>(),
+        grad.data_ptr<scalar_t>(),
+        count.defined() ? count.data_ptr<int64_t>() : nullptr,
+        numel, stride,
+        partial_segment_offset.data_ptr<int64_t>(),
         num_of_partial_segments,
-        padding_idx,
+        grad_weight_per_segment.data_ptr<partial_weight_t>(),
         stride_warped);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
-    });
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }
+
+    // Finally, we sum all the partial-sums and scatter them
+    // into `grad_weight`.
+    const int grid2 = ceil_div(num_of_segments*stride_warped, block);
+    sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
+      sorted_indices.data_ptr<int64_t>(),
+      grad_weight.data_ptr<scalar_t>(),
+      stride,
+      segment_offsets.data_ptr<int64_t>(),
+      num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
+      partials_per_segment_offset.data_ptr<int64_t>(),
+      num_of_partial_segments,
+      padding_idx,
+      stride_warped);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });
  return grad_weight;
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index 651261cf6408..a80de4b45138 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -325,18 +325,16 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
 #endif
   int grid = 1024;
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_cuda", [&] {
-      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
-        EmbeddingBag_updateOutputKernel<scalar_t, index_t><<<grid, block, 0, stream>>>(
-          indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
-          weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
-          offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
-          weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<index_t>(),
-          mode == MODE_MAX ? max_indices.data_ptr<index_t>() : NULL,
-          per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
-          per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      });
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
+      EmbeddingBag_updateOutputKernel<scalar_t, index_t><<<grid, block, 0, stream>>>(
+        indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
+        weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+        offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
+        weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<index_t>(),
+        mode == MODE_MAX ? max_indices.data_ptr<index_t>() : NULL,
+        per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
+        per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
   });
diff --git a/test/test_nn.py b/test/test_nn.py
index 652b4d85cbed..a3d18bc3e49c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -12250,7 +12250,6 @@ def test_embedding_bag_non_contiguous_weight(self, device, dtypes):

     @onlyCUDA
-    @skipCUDAIfNotRocm
     @dtypes(torch.int, torch.long)
     def test_embedding_bag_bfloat16(self, device, dtype):
         self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)
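
Note: with the ROCm-only guard and the `@skipCUDAIfNotRocm` test decorator removed, bfloat16 embedding kernels are expected to run on regular CUDA devices as well. Below is a minimal usage sketch, not part of the patch, of the path the updated `test_embedding_bag_bfloat16` exercises; the sizes, indices, and device string are illustrative, and it assumes a CUDA device is available.

    import torch
    import torch.nn as nn

    # bfloat16 EmbeddingBag forward and backward on plain CUDA
    # (previously skipped unless running on ROCm).
    device = "cuda"
    bag = nn.EmbeddingBag(10, 4, mode="sum").to(device=device, dtype=torch.bfloat16)
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], device=device)
    offsets = torch.tensor([0, 4], device=device)

    out = bag(indices, offsets)   # forward hits EmbeddingBag_updateOutputKernel
    out.sum().backward()          # backward exercises the embedding_bag backward kernels
    print(out.dtype, bag.weight.grad.dtype)  # torch.bfloat16 torch.bfloat16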