Skip to content

Commit

Permalink
CUDA BFloat embedding (pytorch#44848)
Browse files — browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#44848

Reviewed By: izdeby

Differential Revision: D25574204

Pulled By: ngimel

fbshipit-source-id: b35f7253a6ad2b83f7b6b06862a5ab77295373e0
Branch information: zasdfgbnm authored and hwangdeyu committed Dec 23, 2020
1 parent 9745362 commit bf3d1b4
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 91 deletions.
50 changes: 23 additions & 27 deletions aten/src/ATen/native/cuda/Embedding.cu
Expand Up @@ -249,23 +249,21 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
"embedding_backward",
[&]
{
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
using accscalar_t = acc_type<scalar_t, true>;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
<<<grid,
block,
sizeof(accscalar_t)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
stream>>>
(indices_contig.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
static_cast<int>(num_indices),
static_cast<int64_t>(stride),
static_cast<int>(padding_idx));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
using accscalar_t = acc_type<scalar_t, true>;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
<<<grid,
block,
sizeof(accscalar_t)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
stream>>>
(indices_contig.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
static_cast<int>(num_indices),
static_cast<int64_t>(stride),
static_cast<int>(padding_idx));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return grad_weight;
}
Expand Down Expand Up @@ -362,16 +360,14 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
int dim = self.stride(0);

AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
using accscalar_t = acc_type<scalar_t, true>;
renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
self.data_ptr<scalar_t>(),
unique_indices.data_ptr<index_t>(),
static_cast<accscalar_t>(max_norm),
static_cast<accscalar_t>(norm_type),
dim, self.stride(0), self.stride(1));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
using accscalar_t = acc_type<scalar_t, true>;
renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
self.data_ptr<scalar_t>(),
unique_indices.data_ptr<index_t>(),
static_cast<accscalar_t>(max_norm),
static_cast<accscalar_t>(norm_type),
dim, self.stride(0), self.stride(1));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return self;
Expand Down
100 changes: 49 additions & 51 deletions aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
Expand Up @@ -272,59 +272,57 @@ Tensor embedding_backward_cuda_kernel(

AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_backward_cuda_compute_grad_weight", [&] {
// For numerical stability, the dtype of `grad_weight_per_segment`
// should match `acc_type`
using partial_weight_t = acc_type<scalar_t, true>;
TensorOptions op;
if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
op = grad.options().dtype(at::kFloat);
} else {
op = grad.options();
}
auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
offset2bag.data_ptr<index_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr, numel, stride,
mode_mean, bag_size.data_ptr<index_t>(),
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr,
numel, stride,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Finally, we sum all the partial-sums and scatter them
// into `grad_weight`.
const int grid2 = ceil_div(num_of_segments*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data_ptr<index_t>(),
grad_weight.data_ptr<scalar_t>(),
stride,
segment_offsets.data_ptr<index_t>(),
num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
// For numerical stability, the dtype of `grad_weight_per_segment`
// should match `acc_type`
using partial_weight_t = acc_type<scalar_t, true>;
TensorOptions op;
if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
op = grad.options().dtype(at::kFloat);
} else {
op = grad.options();
}
auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
offset2bag.data_ptr<index_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr, numel, stride,
mode_mean, bag_size.data_ptr<index_t>(),
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr,
numel, stride,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
padding_idx,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Finally, we sum all the partial-sums and scatter them
// into `grad_weight`.
const int grid2 = ceil_div(num_of_segments*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data_ptr<index_t>(),
grad_weight.data_ptr<scalar_t>(),
stride,
segment_offsets.data_ptr<index_t>(),
num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
padding_idx,
stride_warped);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return grad_weight;
Expand Down
22 changes: 10 additions & 12 deletions aten/src/ATen/native/cuda/EmbeddingBag.cu
Expand Up @@ -325,18 +325,16 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
#endif
int grid = 1024;
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_cuda", [&] {
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
EmbeddingBag_updateOutputKernel<scalar_t, index_t><<<grid, block, 0, stream>>>(
indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<index_t>(),
mode == MODE_MAX ? max_indices.data_ptr<index_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
EmbeddingBag_updateOutputKernel<scalar_t, index_t><<<grid, block, 0, stream>>>(
indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<index_t>(),
mode == MODE_MAX ? max_indices.data_ptr<index_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});

Expand Down
1 change: 0 additions & 1 deletion test/test_nn.py
Expand Up @@ -12250,7 +12250,6 @@ def test_embedding_bag_non_contiguous_weight(self, device, dtypes):


@onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.int, torch.long)
def test_embedding_bag_bfloat16(self, device, dtype):
self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)
Expand Down

0 comments on commit bf3d1b4

Please sign in to comment.