Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,19 @@ count_and_map_uniq(const at::TensorAccessor<int64_t, 1>& indices_accessor, int64
}

template<typename T>
static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor grad, const at::Tensor indices, const at::Tensor offsets, const at::Tensor offset2bag, int num_weights, int mode) {
static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor grad, const at::Tensor indices, const at::Tensor offsets, int num_weights, int mode) {

assert((mode == MODE_SUM) && (grad.stride(1) == 1));

auto offset_numel = offsets.numel();
int64_t indices_numel = indices.numel();
assert((mode == MODE_SUM) && (grad.stride(1) == 1) && (indices_numel > 0));
auto offset_numel = offsets.numel();
at::Tensor offset2bag_ ;
if (offset_numel != indices_numel) {
offset2bag_ = at::native::full({indices.sizes()[0] + 1}, 0, indices.options());
make_offset2bag(offsets, indices, offset2bag_);
offset2bag_.resize_({indices.sizes()[0]});
} else {
offset2bag_ = offsets;
}
auto indices_accessor = indices.accessor<int64_t, 1>();
std::vector<int64_t> indices_to_index(num_weights, -1ull);
std::vector<int64_t> index_to_indices;
Expand Down Expand Up @@ -220,7 +227,7 @@ static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor
float* temp_output = temp_grad_weight.data();
zero_ker(temp_output, unique_indices * ddim);

int64_t* offset2bag_data = offset2bag.data_ptr<int64_t>();
auto offset2bag_accessor = offset2bag_.accessor<int64_t, 1>();
T* grad_data = grad.data_ptr<T>();
at::parallel_for(0, max_threads, 0, [&](int64_t start, int64_t end) {
for(int k = start; k < end; k++) {
Expand All @@ -230,7 +237,7 @@ static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor
int64_t indices_num = indices_accessor[mb];
int64_t index = indices_to_index[indices_num];
if (index >= chunk_start && index < chunk_end) {
auto s = offset2bag_data[mb];
auto s = offset2bag_accessor[mb];
add_ker((float*)(temp_output + index * ddim), (T*)(grad_data + s * ddim), ddim);
}
}
Expand All @@ -244,10 +251,11 @@ static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor
return index_grad_weight;
}

static inline bool embedding_bag_backward_fast_path_sum(const at::Tensor grad, const at::Tensor indices, const at::Tensor per_sample_weights, bool scale_grad_by_freq, int64_t mode) {
static inline bool embedding_bag_backward_fast_path_sum(const at::Tensor grad, const at::Tensor indices, const at::Tensor offset2bag, const at::Tensor per_sample_weights, bool scale_grad_by_freq, int64_t mode) {

if ((grad.scalar_type() != at::kFloat) && (grad.scalar_type() != at::kBFloat16)) return false;
if ((mode != MODE_SUM) || (grad.stride(1) != 1)) return false;
if ((indices.numel() == 0) || (offset2bag.numel() != 0)) return false;
if (per_sample_weights.defined() || scale_grad_by_freq) return false;

return true;
Expand All @@ -256,17 +264,12 @@ static inline bool embedding_bag_backward_fast_path_sum(const at::Tensor grad, c
static inline at::Tensor
embedding_bag_get_offset2bag(const at::Tensor indices, const at::Tensor & offsets, const at::Tensor & offset2bag)
{
auto offset_numel = offsets.numel();
int64_t indices_numel = indices.numel();
at::Tensor offset2bag_ ;
if (indices_numel != 0 && offset2bag.numel() == 0) {
if (indices_numel != offset_numel) {
offset2bag_ = at::native::full({indices.sizes()[0] + 1}, 0, indices.options());
make_offset2bag(offsets, indices, offset2bag_);
offset2bag_.resize_({indices.sizes()[0]});
} else {
offset2bag_ = offsets.contiguous();
}
offset2bag_ = at::native::full({indices.sizes()[0] + 1}, 0, indices.options());
make_offset2bag(offsets, indices, offset2bag_);
offset2bag_.resize_({indices.sizes()[0]});
} else {
offset2bag_ = offset2bag;
}
Expand All @@ -278,7 +281,7 @@ at::Tensor embedding_bag_backward_impl(const at::Tensor & grad, const at::Tensor
int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
const at::Tensor & per_sample_weights) {
if (sparse) {
if (embedding_bag_backward_fast_path_sum(grad, indices, per_sample_weights, scale_grad_by_freq, mode)) {
if (embedding_bag_backward_fast_path_sum(grad, indices, offset2bag, per_sample_weights, scale_grad_by_freq, mode)) {
if (is_bfloat16_tensor(grad)) {
return embedding_bag_sparse_backward_sum_fast<at::BFloat16>(grad, indices, offsets, num_weights, mode);
} else {
Expand All @@ -291,16 +294,16 @@ at::Tensor embedding_bag_backward_impl(const at::Tensor & grad, const at::Tensor
bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights);
}
} else {
at::Tensor offset2bag_ = embedding_bag_get_offset2bag(indices, offsets, offset2bag);
auto grad_c = grad.contiguous();
if (embedding_bag_backward_fast_path_sum(grad_c, indices, per_sample_weights, scale_grad_by_freq, mode)) {
if (embedding_bag_backward_fast_path_sum(grad_c, indices, offset2bag, per_sample_weights, scale_grad_by_freq, mode)) {
if (is_bfloat16_tensor(grad)) {
return embedding_bag_dense_backward_sum_fast<at::BFloat16>(grad_c, indices, offsets, offset2bag_, num_weights, mode);
return embedding_bag_dense_backward_sum_fast<at::BFloat16>(grad_c, indices, offsets, num_weights, mode);
} else {
return embedding_bag_dense_backward_sum_fast<float>(grad_c, indices, offsets, offset2bag_, num_weights, mode);
return embedding_bag_dense_backward_sum_fast<float>(grad_c, indices, offsets, num_weights, mode);
}
} else {
//May need full support for Bfloat16
at::Tensor offset2bag_ = embedding_bag_get_offset2bag(indices, offsets, offset2bag);
return at::_embedding_bag_dense_backward(grad_c, indices, offsets, offset2bag_, bag_size,
maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights);
}
Expand Down