Skip to content

Commit

Permalink
[Engine] Forbid some oneDNN memory classes from allocating underlying …
Browse files Browse the repository at this point in the history
…buffers in operators (#1030)
  • Loading branch information
zhentaoyu committed Jun 15, 2023
1 parent 5acaaf0 commit 5f75df3
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@
#define NEURALENGINE_API_
#endif // GTEST_API_

// Special pointer value that indicates that a oneDNN memory object should not have
// an underlying buffer.
#define DNNL_MEMORY_NONE (NULL)

namespace executor {

using std::max;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ dnnl::binary::desc BinaryAddOperator::PrepareBroadcastBinaryDesc(const memory::d
memory::desc user_dst_md(jit_pass_src0_shape, type2mem[output[0]->dtype()], dt_tag);

// Prepare memory objects (cached)
user_src0_m_ = memory(user_src0_md, eng_);
user_src1_m_ = memory(user_src1_md, eng_);
user_dst_m_ = memory(user_dst_md, eng_);
user_src0_m_ = memory(user_src0_md, eng_, DNNL_MEMORY_NONE);
user_src1_m_ = memory(user_src1_md, eng_, DNNL_MEMORY_NONE);
user_dst_m_ = memory(user_dst_md, eng_, DNNL_MEMORY_NONE);

dnnl::binary::desc binary_d(algo_, user_src0_md, user_src1_md, user_dst_md);
return binary_d;
Expand Down Expand Up @@ -238,9 +238,9 @@ dnnl::binary::desc BinaryAddOperator::PrepareStrideBinaryDesc(const memory::dims
memory::desc any_dst_md(user_dst_md.dims(), user_dst_md.data_type(), memory::format_tag::any);

// 5. Prepare memory objects (cached)
user_src0_m_ = memory(user_src0_md, eng_);
user_src1_m_ = memory(user_src1_md, eng_);
user_dst_m_ = memory(user_dst_md, eng_);
user_src0_m_ = memory(user_src0_md, eng_, DNNL_MEMORY_NONE);
user_src1_m_ = memory(user_src1_md, eng_, DNNL_MEMORY_NONE);
user_dst_m_ = memory(user_dst_md, eng_, DNNL_MEMORY_NONE);

// 6. Prepare op descriptors
dnnl::binary::desc binary_d(algorithm::binary_add, user_src0_md, user_src1_md, any_dst_md);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ void BinaryOpOperator::Reshape(const vector<Tensor*>& input, const vector<Tensor
auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr, eng_);

// Create src memory objects.
src_0_mem_ = memory(src_0_md, eng_);
src_1_mem_ = memory(src_1_md, eng_);
dst_mem_ = memory(dst_md, eng_);
src_0_mem_ = memory(src_0_md, eng_, DNNL_MEMORY_NONE);
src_1_mem_ = memory(src_1_md, eng_, DNNL_MEMORY_NONE);
dst_mem_ = memory(dst_md, eng_, DNNL_MEMORY_NONE);

// Create the primitive.
binary_prim_ = dnnl::binary(binary_pd);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ void ConvolutionOperator::Reshape(const vector<Tensor*>& input, const vector<Ten
// scale_shape.push_back(1);
scale_shape.push_back(weight_min_->size());
memory::desc scale_md_ = memory::desc(scale_shape, memory::data_type::f32, GetStrides(scale_shape));
scale_f32_mem_ = memory(scale_md_, eng_);
scale_f32_mem_ = memory(scale_md_, eng_, DNNL_MEMORY_NONE);
// if (!(dst_->dtype() == "u8" || dst_->dtype() == "s8")) {
// dnnl::post_ops po;
// po.append_binary(algorithm::binary_mul, scale_md_);
Expand All @@ -393,7 +393,8 @@ void ConvolutionOperator::Reshape(const vector<Tensor*>& input, const vector<Ten
if (src_->dtype() == "u8") {
mask = src_min_->size() > 1 ? 2 : 0;
zp_src0_mem_ = memory(
{{static_cast<dnnl_dim_t>(src_min_->size())}, memory::data_type::s32, GetStrides(scale_shape)}, eng_);
{{static_cast<dnnl_dim_t>(src_min_->size())}, memory::data_type::s32, GetStrides(scale_shape)}, eng_,
DNNL_MEMORY_NONE);
attr_.set_zero_points(DNNL_ARG_SRC, mask, {DNNL_RUNTIME_S32_VAL});
}
}
Expand Down Expand Up @@ -516,15 +517,15 @@ void ConvolutionOperator::Reshape(const vector<Tensor*>& input, const vector<Ten
dnnl::eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_gelu_erf, gelu_md, 0.f, 0.f);
gelu_pd_ = dnnl::eltwise_forward::primitive_desc(gelu_d, gelu_eng_);
gelu_p_ = dnnl::eltwise_forward(gelu_pd_);
gelu_m_ = memory(gelu_md, gelu_eng_);
gelu_m_ = memory(gelu_md, gelu_eng_, DNNL_MEMORY_NONE);
}
if (gelu_tanh_ && gelu_split_) {
memory::desc gelu_md = memory::desc(dst_shape_origin, type2mem[dst_->dtype()], dst_stride);
auto gelu_d =
dnnl::eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_gelu_tanh, gelu_md, 0.f, 0.f);
gelu_pd_ = dnnl::eltwise_forward::primitive_desc(gelu_d, gelu_eng_);
gelu_p_ = dnnl::eltwise_forward(gelu_pd_);
gelu_m_ = memory(gelu_md, gelu_eng_);
gelu_m_ = memory(gelu_md, gelu_eng_, DNNL_MEMORY_NONE);
}
if (binary_add_) {
// The binary primitive requires all source and destination tensors to have the same number of dimensions.
Expand All @@ -541,7 +542,7 @@ void ConvolutionOperator::Reshape(const vector<Tensor*>& input, const vector<Ten
memory::desc binary_md = memory::desc(post_shape, type2mem[post_->dtype()], post_stride);
po.append_binary(algorithm::binary_add, binary_md);
attr.set_post_ops(po);
binary_m_ = memory(binary_md, eng_);
binary_m_ = memory(binary_md, eng_, DNNL_MEMORY_NONE);
attr_ = attr;
}

Expand All @@ -558,8 +559,8 @@ void ConvolutionOperator::Reshape(const vector<Tensor*>& input, const vector<Ten
memory_args_[DNNL_ARG_SCRATCHPAD] = scratchpad_m;

// 2.4 Prepare memory objects (cached)
src_m_ = memory(src_md, eng_);
dst_m_ = memory(dst_md, eng_);
src_m_ = memory(src_md, eng_, DNNL_MEMORY_NONE);
dst_m_ = memory(dst_md, eng_, DNNL_MEMORY_NONE);
if (!weight_cached_) {
memory any_weight_m = weight_m_;
if (convolution_pd_.weights_desc() != weight_m_.get_desc()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ void ExpOperator::Reshape(const vector<Tensor*>& input, const vector<Tensor*>& o
exp_p_ = dnnl::eltwise_forward(exp_pd);

// 2.5 Prepare memory objects (cached)
src_m_ = memory(src_md, eng_);
dst_m_ = memory(dst_md, eng_);
src_m_ = memory(src_md, eng_, DNNL_MEMORY_NONE);
dst_m_ = memory(dst_md, eng_, DNNL_MEMORY_NONE);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ void GeluOperator::ReshapeWithOnednn(const vector<Tensor*>& input, const vector<
memory::desc dst_md(dst_shape, type2mem[output[0]->dtype()], dst_stride);

// 1.5 Prepare memory objects (cached)
src_m_ = memory(src_md, eng_);
dst_m_ = memory(dst_md, eng_);
src_m_ = memory(src_md, eng_, DNNL_MEMORY_NONE);
dst_m_ = memory(dst_md, eng_, DNNL_MEMORY_NONE);

//// Part2: Prepare primitive
// 2.1 Prepare op descriptors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ void InnerProductOperator::ReshapeDense(const vector<Tensor*>& input, const vect
attr_.set_output_scales(ic_dim, {DNNL_RUNTIME_F32_VAL});
vector<int64_t> scale_shape;
scale_shape.push_back(src1_min_->size());
scale_f32_mem_ = memory({scale_shape, memory::data_type::f32, GetStrides(scale_shape)}, eng_);
scale_f32_mem_ = memory({scale_shape, memory::data_type::f32, GetStrides(scale_shape)}, eng_, DNNL_MEMORY_NONE);
// need zero point when src0 is u8
if (src0_->dtype() == "u8") {
vector<int64_t> zero_point_shape(src1_shape);
Expand Down Expand Up @@ -1362,22 +1362,22 @@ void InnerProductOperator::ReshapeDense(const vector<Tensor*>& input, const vect
dnnl::eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_gelu_erf, gelu_md, 0.f, 0.f);
gelu_pd_ = dnnl::eltwise_forward::primitive_desc(gelu_d, gelu_eng_);
gelu_p_ = dnnl::eltwise_forward(gelu_pd_);
gelu_m_ = memory(gelu_md, gelu_eng_);
gelu_m_ = memory(gelu_md, gelu_eng_, DNNL_MEMORY_NONE);
}
if (gelu_tanh_ && gelu_split_) {
memory::desc gelu_md = memory::desc(dst_shape_origin, type2mem[dst_->dtype()], dst_stride);
auto gelu_d =
dnnl::eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_gelu_tanh, gelu_md, 0.f, 0.f);
gelu_pd_ = dnnl::eltwise_forward::primitive_desc(gelu_d, gelu_eng_);
gelu_p_ = dnnl::eltwise_forward(gelu_pd_);
gelu_m_ = memory(gelu_md, gelu_eng_);
gelu_m_ = memory(gelu_md, gelu_eng_, DNNL_MEMORY_NONE);
}
if (binary_add_) {
vector<int64_t> post_shape = post_->shape();
vector<int64_t> post_stride = GetStrides(post_shape);
memory::desc binary_md = memory::desc(post_shape, type2mem[post_->dtype()], post_stride);
po.append_binary(algorithm::binary_add, binary_md);
binary_m_ = memory(binary_md, eng_);
binary_m_ = memory(binary_md, eng_, DNNL_MEMORY_NONE);
}
if (append_eltwise_ || append_sum_ || binary_add_ || is_dynamic_) attr_.set_post_ops(po);

Expand All @@ -1398,13 +1398,13 @@ void InnerProductOperator::ReshapeDense(const vector<Tensor*>& input, const vect
memory_args_[DNNL_ARG_SCRATCHPAD] = scratchpad_m;

// 2.4 Prepare memory objects (cached)
src0_m_ = memory(src0_md, eng_);
dst_m_ = memory(dst_md, eng_);
src0_m_ = memory(src0_md, eng_, DNNL_MEMORY_NONE);
dst_m_ = memory(dst_md, eng_, DNNL_MEMORY_NONE);
if (!weight_cached_) {
memory any_src1_m = any_src1_m_last_;
if (inner_product_pd_.weights_desc() != any_src1_m_last_.get_desc()) {
void* cached_w_ptr;
any_src1_m = memory(inner_product_pd_.weights_desc(), eng_);
any_src1_m = memory(inner_product_pd_.weights_desc(), eng_, DNNL_MEMORY_NONE);
if (src1_->is_shared()) {
int64_t weight_size = any_src1_m.get_desc().get_size();
void* weight_shm_ptr =
Expand Down Expand Up @@ -1437,7 +1437,7 @@ void InnerProductOperator::ReshapeDense(const vector<Tensor*>& input, const vect
memory any_bias_m = any_bias_m_last_;
if (inner_product_pd_.bias_desc() != any_bias_m_last_.get_desc()) {
void* cached_b_ptr;
any_bias_m = memory(inner_product_pd_.bias_desc(), eng_);
any_bias_m = memory(inner_product_pd_.bias_desc(), eng_, DNNL_MEMORY_NONE);
if (bias_->is_shared()) {
int64_t bias_size = bias_m_.get_desc().get_size();
void* bias_shm_ptr =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ void LayerNormOperator::ReshapewithOnednn(const vector<Tensor*>& input, const ve
lnorm_p_ = dnnl::layer_normalization_forward(lnorm_pd);

// 2.5 Prepare memory objects (cached)
src_m_ = memory(src_md, eng_);
dst_m_ = memory(dst_md, eng_);
src_m_ = memory(src_md, eng_, DNNL_MEMORY_NONE);
dst_m_ = memory(dst_md, eng_, DNNL_MEMORY_NONE);
memory mean_m(lnorm_pd.mean_desc(), eng_);
memory variance_m(lnorm_pd.variance_desc(), eng_);

Expand Down

0 comments on commit 5f75df3

Please sign in to comment.