Skip to content

Commit

Permalink
[Engine] release weight after onednn ip reorder (#981)
Browse files: browse the repository at this point in the history
  • Loading branch information
zhentaoyu committed Jun 9, 2023
1 parent 2be4a11 commit 3f6b473
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 11 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -161,6 +161,8 @@ class InnerProductOperator : public Operator {
memory dst_m_;
memory gelu_m_;
memory binary_m_;
memory any_src1_m_last_;
memory any_bias_m_last_;

Tensor* src0_ = nullptr;
Tensor* src1_ = nullptr;
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -1152,6 +1152,8 @@ void InnerProductOperator::PrepareDense(const vector<Tensor*>& input, const vect
any_src1_md_ = memory::desc(src1_shape, type2mem[src1_->dtype()], memory::format_tag::any);
src1_md_ = memory::desc(src1_shape, type2mem[src1_->dtype()], src1_stride);
src1_m_ = memory(src1_md_, eng_, src1_->mutable_data());
// for weight cache
any_src1_m_last_ = memory(src1_md_, eng_, src1_->mutable_data());
if (!src1_perm_.empty() && src1_perm_ == vector<int64_t>{1, 0}) src1_->set_transpose();
if (!src0_perm_.empty() && src0_perm_ == vector<int64_t>{1, 0}) src0_->set_transpose();

Expand All @@ -1162,6 +1164,7 @@ void InnerProductOperator::PrepareDense(const vector<Tensor*>& input, const vect
bias_md_ = memory::desc(bias_shape, type2mem[bias_->dtype()], bias_stride);
any_bias_md_ = memory::desc(bias_shape, type2mem[bias_->dtype()], memory::format_tag::any);
bias_m_ = memory(bias_md_, eng_, bias_->mutable_data());
any_bias_m_last_ = memory(bias_md_, eng_, bias_->mutable_data());
} else {
bias_md_ = memory::desc(bias_shape, dnnl::memory::data_type::f32, bias_stride);
any_bias_md_ = memory::desc(bias_shape, dnnl::memory::data_type::f32, memory::format_tag::any);
Expand Down Expand Up @@ -1337,11 +1340,20 @@ void InnerProductOperator::ReshapeDense(const vector<Tensor*>& input, const vect
: dnnl::inner_product_forward::desc(prop_kind::forward_inference, src0_md, src1_md_, dst_md);

if (format_any_) {
inner_product_d =
has_bias_ || (is_dynamic_ && src0_->dtype() == "u8")
? dnnl::inner_product_forward::desc(prop_kind::forward_inference, any_src0_md, any_src1_md_, any_bias_md_,
any_dst_md)
: dnnl::inner_product_forward::desc(prop_kind::forward_inference, any_src0_md, any_src1_md_, any_dst_md);
if (MemoryAllocator::CheckMemory(src1_->mutable_data()) == -1) {
inner_product_d =
has_bias_ || (is_dynamic_ && src0_->dtype() == "u8")
? dnnl::inner_product_forward::desc(prop_kind::forward_inference, any_src0_md, any_src1_md_, any_bias_md_,
any_dst_md)
: dnnl::inner_product_forward::desc(prop_kind::forward_inference, any_src0_md, any_src1_md_, any_dst_md);
} else {
inner_product_d =
has_bias_ || (is_dynamic_ && src0_->dtype() == "u8")
? dnnl::inner_product_forward::desc(prop_kind::forward_inference, any_src0_md,
any_src1_m_last_.get_desc(), any_bias_m_last_.get_desc(), any_dst_md)
: dnnl::inner_product_forward::desc(prop_kind::forward_inference, any_src0_md,
any_src1_m_last_.get_desc(), any_dst_md);
}
}

if (gelu_erf_ && gelu_split_) {
Expand Down Expand Up @@ -1389,29 +1401,69 @@ void InnerProductOperator::ReshapeDense(const vector<Tensor*>& input, const vect
src0_m_ = memory(src0_md, eng_);
dst_m_ = memory(dst_md, eng_);
if (!weight_cached_) {
memory any_src1_m = src1_m_;
if (inner_product_pd_.weights_desc() != src1_m_.get_desc()) {
memory any_src1_m = any_src1_m_last_;
if (inner_product_pd_.weights_desc() != any_src1_m_last_.get_desc()) {
void* cached_w_ptr;
any_src1_m = memory(inner_product_pd_.weights_desc(), eng_);
if (src1_->is_shared()) {
int64_t weight_size = any_src1_m.get_desc().get_size();
void* weight_shm_ptr =
MemoryAllocator::ManagedShm().find_or_construct<char>(src1_->name().c_str())[weight_size](0);
any_src1_m.set_data_handle(weight_shm_ptr);
cached_w_ptr = weight_shm_ptr;
} else {
cached_w_ptr = MemoryAllocator::get().GetMemory(any_src1_m.get_desc().get_size(), 1);
any_src1_m.set_data_handle(cached_w_ptr);
}
dnnl::reorder(any_src1_m_last_, any_src1_m).execute(eng_stream_, any_src1_m_last_, any_src1_m);
if (src1_->is_shared() && execution_options_ptr_->execution_mode == ExecutionMode::INFERENCE &&
src1_->life() <= 1) {
MemoryAllocator::ManagedShm().destroy_ptr(src1_->mutable_data());
src1_->set_shm_handle(MemoryAllocator::ManagedShm().get_handle_from_address(cached_w_ptr));
} else {
if (execution_options_ptr_->execution_mode == ExecutionMode::INFERENCE && src1_->life() <= 1) {
if (MemoryAllocator::CheckMemory(src1_->mutable_data()) == -1) {
aligned_free(src1_->mutable_data());
} else {
MemoryAllocator::UnrefMemory(src1_->mutable_data());
}
src1_->set_data(cached_w_ptr);
}
any_src1_m_last_ = memory(inner_product_pd_.weights_desc(), eng_, cached_w_ptr);
}
dnnl::reorder(src1_m_, any_src1_m).execute(eng_stream_, src1_m_, any_src1_m);
}
memory_args_[DNNL_ARG_WEIGHTS] = any_src1_m;
if (!is_dynamic_ && has_bias_) {
memory any_bias_m = bias_m_;
if (inner_product_pd_.bias_desc() != bias_m_.get_desc()) {
memory any_bias_m = any_bias_m_last_;
if (inner_product_pd_.bias_desc() != any_bias_m_last_.get_desc()) {
void* cached_b_ptr;
any_bias_m = memory(inner_product_pd_.bias_desc(), eng_);
if (bias_->is_shared()) {
int64_t bias_size = bias_m_.get_desc().get_size();
void* bias_shm_ptr =
MemoryAllocator::ManagedShm().find_or_construct<char>(bias_->name().c_str())[bias_size](0);
any_bias_m.set_data_handle(bias_shm_ptr);
cached_b_ptr = bias_shm_ptr;
} else {
cached_b_ptr = MemoryAllocator::get().GetMemory(bias_m_.get_desc().get_size(), 1);
any_bias_m.set_data_handle(cached_b_ptr);
}
dnnl::reorder(any_bias_m_last_, any_bias_m).execute(eng_stream_, any_bias_m_last_, any_bias_m);
if (bias_->is_shared() && execution_options_ptr_->execution_mode == ExecutionMode::INFERENCE &&
bias_->life() <= 1) {
MemoryAllocator::ManagedShm().destroy_ptr(bias_->mutable_data());
bias_->set_shm_handle(MemoryAllocator::ManagedShm().get_handle_from_address(cached_b_ptr));
} else {
if (execution_options_ptr_->execution_mode == ExecutionMode::INFERENCE && bias_->life() <= 1) {
if (MemoryAllocator::CheckMemory(bias_->mutable_data()) == -1) {
aligned_free(bias_->mutable_data());
} else {
MemoryAllocator::UnrefMemory(bias_->mutable_data());
}
bias_->set_data(cached_b_ptr);
}
any_bias_m_last_ = memory(inner_product_pd_.bias_desc(), eng_, cached_b_ptr);
}
dnnl::reorder(bias_m_, any_bias_m).execute(eng_stream_, bias_m_, any_bias_m);
}
memory_args_[DNNL_ARG_BIAS] = any_bias_m;
}
Expand Down

0 comments on commit 3f6b473

Please sign in to comment.