From 794250514ef7ed4ad287288ce4415c2a8e1a3f54 Mon Sep 17 00:00:00 2001
From: "Chen, Jian Ping"
Date: Tue, 9 Jun 2020 22:40:02 +0800
Subject: [PATCH] Fix EmbeddingBag unit test failure

Take the sum-mode fast path only when indices is non-empty and no
offset2bag tensor was passed in, and let the dense fast path derive
offset2bag from offsets itself instead of taking it as an argument,
reading it through a TensorAccessor rather than a raw data pointer.
Also drop the offsets.contiguous() shortcut in
embedding_bag_get_offset2bag so the mapping is always rebuilt with
make_offset2bag.
---
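[ Reviewer note, not part of the commit: offset2bag maps each position in
  the flattened indices tensor to the bag that position belongs to. A
  minimal sketch of that mapping in plain C++ follows; offset2bag_sketch is
  a hypothetical stand-in for illustration only, not the ATen
  make_offset2bag helper the patch actually calls. ]

#include <cstdint>
#include <vector>

// Illustration only: recover each index position's bag id from the offsets
// tensor, e.g. offsets [0, 2, 5] over 6 indices gives [0, 0, 1, 1, 1, 2].
std::vector<int64_t> offset2bag_sketch(const std::vector<int64_t>& offsets,
                                       int64_t num_indices) {
  std::vector<int64_t> offset2bag(num_indices, 0);
  // Every position at or past offsets[b] belongs to bag b or a later one,
  // so later passes overwrite earlier ones with the final bag id.
  for (size_t b = 1; b < offsets.size(); ++b)
    for (int64_t i = offsets[b]; i < num_indices; ++i)
      offset2bag[i] = static_cast<int64_t>(b);
  return offset2bag;
}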
 .../csrc/cpu/aten/operators/embedding_bag.cpp | 43 ++++++++++---------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp b/torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp
index bab6988f1..ece001005 100755
--- a/torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp
+++ b/torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp
@@ -180,12 +180,19 @@ count_and_map_uniq(const at::TensorAccessor<int64_t, 1>& indices_accessor, int64
 }
 
 template <typename T>
-static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor grad, const at::Tensor indices, const at::Tensor offsets, const at::Tensor offset2bag, int num_weights, int mode) {
+static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor grad, const at::Tensor indices, const at::Tensor offsets, int num_weights, int mode) {
 
-  assert((mode == MODE_SUM) && (grad.stride(1) == 1));
-
-  auto offset_numel = offsets.numel();
   int64_t indices_numel = indices.numel();
+  assert((mode == MODE_SUM) && (grad.stride(1) == 1) && (indices_numel > 0));
+  auto offset_numel = offsets.numel();
+  at::Tensor offset2bag_ ;
+  if (offset_numel != indices_numel) {
+    offset2bag_ = at::native::full({indices.sizes()[0] + 1}, 0, indices.options());
+    make_offset2bag(offsets, indices, offset2bag_);
+    offset2bag_.resize_({indices.sizes()[0]});
+  } else {
+    offset2bag_ = offsets;
+  }
   auto indices_accessor = indices.accessor<int64_t, 1>();
   std::vector<int64_t> indices_to_index(num_weights, -1ull);
   std::vector<int64_t> index_to_indices;
@@ -220,7 +227,7 @@ static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor
   float* temp_output = temp_grad_weight.data<float>();
   zero_ker(temp_output, unique_indices * ddim);
 
-  int64_t* offset2bag_data = offset2bag.data_ptr<int64_t>();
+  auto offset2bag_accessor = offset2bag_.accessor<int64_t, 1>();
   T* grad_data = grad.data_ptr<T>();
   at::parallel_for(0, max_threads, 0, [&](int64_t start, int64_t end) {
     for(int k = start; k < end; k++) {
@@ -230,7 +237,7 @@ static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor
         int64_t indices_num = indices_accessor[mb];
         int64_t index = indices_to_index[indices_num];
         if (index >= chunk_start && index < chunk_end) {
-          auto s = offset2bag_data[mb];
+          auto s = offset2bag_accessor[mb];
           add_ker((float*)(temp_output + index * ddim), (T*)(grad_data + s * ddim), ddim);
         }
       }
@@ -244,10 +251,11 @@ static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor
   return index_grad_weight;
 }
 
-static inline bool embedding_bag_backward_fast_path_sum(const at::Tensor grad, const at::Tensor indices, const at::Tensor per_sample_weights, bool scale_grad_by_freq, int64_t mode) {
+static inline bool embedding_bag_backward_fast_path_sum(const at::Tensor grad, const at::Tensor indices, const at::Tensor offset2bag, const at::Tensor per_sample_weights, bool scale_grad_by_freq, int64_t mode) {
   if ((grad.scalar_type() != at::kFloat) && (grad.scalar_type() != at::kBFloat16)) return false;
   if ((mode != MODE_SUM) || (grad.stride(1) != 1)) return false;
+  if ((indices.numel() == 0) || (offset2bag.numel() != 0)) return false;
   if (per_sample_weights.defined() || scale_grad_by_freq) return false;
 
   return true;
@@ -256,17 +264,12 @@ static inline bool embedding_bag_backward_fast_path_sum(const at::Tensor grad, c
 
 static inline at::Tensor embedding_bag_get_offset2bag(const at::Tensor &
     indices, const at::Tensor & offsets, const at::Tensor & offset2bag) {
-  auto offset_numel = offsets.numel();
   int64_t indices_numel = indices.numel();
   at::Tensor offset2bag_ ;
   if (indices_numel != 0 && offset2bag.numel() == 0) {
-    if (indices_numel != offset_numel) {
-      offset2bag_ = at::native::full({indices.sizes()[0] + 1}, 0, indices.options());
-      make_offset2bag(offsets, indices, offset2bag_);
-      offset2bag_.resize_({indices.sizes()[0]});
-    } else {
-      offset2bag_ = offsets.contiguous();
-    }
+    offset2bag_ = at::native::full({indices.sizes()[0] + 1}, 0, indices.options());
+    make_offset2bag(offsets, indices, offset2bag_);
+    offset2bag_.resize_({indices.sizes()[0]});
   } else {
     offset2bag_ = offset2bag;
   }
@@ -278,7 +281,7 @@ at::Tensor embedding_bag_backward_impl(const at::Tensor & grad, const at::Tensor
     int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
     const at::Tensor & per_sample_weights) {
   if (sparse) {
-    if (embedding_bag_backward_fast_path_sum(grad, indices, per_sample_weights, scale_grad_by_freq, mode)) {
+    if (embedding_bag_backward_fast_path_sum(grad, indices, offset2bag, per_sample_weights, scale_grad_by_freq, mode)) {
       if (is_bfloat16_tensor(grad)) {
         return embedding_bag_sparse_backward_sum_fast<at::BFloat16>(grad, indices, offsets, num_weights, mode);
       } else {
@@ -291,16 +294,16 @@ at::Tensor embedding_bag_backward_impl(const at::Tensor & grad, const at::Tensor
         bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights);
     }
   } else {
-    at::Tensor offset2bag_ = embedding_bag_get_offset2bag(indices, offsets, offset2bag);
     auto grad_c = grad.contiguous();
-    if (embedding_bag_backward_fast_path_sum(grad_c, indices, per_sample_weights, scale_grad_by_freq, mode)) {
+    if (embedding_bag_backward_fast_path_sum(grad_c, indices, offset2bag, per_sample_weights, scale_grad_by_freq, mode)) {
       if (is_bfloat16_tensor(grad)) {
-        return embedding_bag_dense_backward_sum_fast<at::BFloat16>(grad_c, indices, offsets, offset2bag_, num_weights, mode);
+        return embedding_bag_dense_backward_sum_fast<at::BFloat16>(grad_c, indices, offsets, num_weights, mode);
       } else {
-        return embedding_bag_dense_backward_sum_fast<float>(grad_c, indices, offsets, offset2bag_, num_weights, mode);
+        return embedding_bag_dense_backward_sum_fast<float>(grad_c, indices, offsets, num_weights, mode);
       }
     } else {
       //May need full support for Bfloat16
+      at::Tensor offset2bag_ = embedding_bag_get_offset2bag(indices, offsets, offset2bag);
       return at::_embedding_bag_dense_backward(grad_c, indices, offsets, offset2bag_,
           bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights);
     }
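
[ Reviewer note, not part of the commit: the dense fast path above avoids
  atomics by partitioning the unique embedding rows across threads; every
  thread scans the whole index list but accumulates only the gradient rows
  that fall in its own chunk, so no two threads ever write the same row. A
  minimal sketch of that ownership scheme in plain C++ follows, with
  std::thread standing in for at::parallel_for and a plain loop standing in
  for add_ker; chunked_sum_backward and all other names here are
  illustrative, not the patch's actual code. ]

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

// temp_output is [num_unique, ddim], grad is [num_bags, ddim], and
// indices_to_index maps a raw embedding index to its unique row slot.
void chunked_sum_backward(const std::vector<int64_t>& indices,
                          const std::vector<int64_t>& indices_to_index,
                          const std::vector<int64_t>& offset2bag,
                          const std::vector<float>& grad,
                          std::vector<float>& temp_output,
                          int64_t num_unique, int64_t ddim) {
  int64_t num_threads = std::max(1u, std::thread::hardware_concurrency());
  int64_t chunk = (num_unique + num_threads - 1) / num_threads;
  std::vector<std::thread> workers;
  for (int64_t k = 0; k < num_threads; ++k) {
    workers.emplace_back([&, k] {
      // Each thread owns the unique rows in [chunk_start, chunk_end).
      int64_t chunk_start = k * chunk;
      int64_t chunk_end = std::min(chunk_start + chunk, num_unique);
      for (size_t mb = 0; mb < indices.size(); ++mb) {
        int64_t slot = indices_to_index[indices[mb]];
        // Skip rows owned by other threads: no locks or atomics needed.
        if (slot < chunk_start || slot >= chunk_end) continue;
        int64_t bag = offset2bag[mb];
        for (int64_t d = 0; d < ddim; ++d)
          temp_output[slot * ddim + d] += grad[bag * ddim + d];
      }
    });
  }
  for (auto& t : workers) t.join();
}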