
Commit 0bdf4b2

[Cherry-pick] support load_state_dict after ipex.optimize (#1326) (#1338)

* support load_state_dict after ipex.optimize (#1326)

* support load_state_dict after ipex.optimize

* remove unused import

* fix inference test for NotEqual

* add inplace test && remove unnecessary check

* improve comments

* Update _optimizer_utils.py

cherry-pick intel-innersource/frameworks.ai.pytorch.ipex-cpu#1277

* Update _optimizer_utils.py

---------

Co-authored-by: zhuhaozhe <haozhe.zhu@intel.com>
jianan-gu and zhuhaozhe committed Dec 29, 2022
1 parent 4c29927 commit 0bdf4b2
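At the Python level, this change enables checkpoint loading after model optimization. A minimal sketch of the newly supported workflow (the checkpoint path and dict layout are placeholder assumptions, not part of this commit):

import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Linear(16, 32)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# ipex.optimize prepacks the weight into an opaque blocked format.
model, optimizer = ipex.optimize(model, optimizer=optimizer)

# With this commit, loading a checkpoint *after* optimization works:
# the loaded values are copied in place into the packed tensors.
checkpoint = torch.load("checkpoint.pth")  # placeholder path
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])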
Showing 14 changed files with 579 additions and 105 deletions.
4 changes: 2 additions & 2 deletions csrc/jit/cpu/kernels/ContextConvTranspose.h
@@ -14,7 +14,7 @@ struct ContextConvTranspose final {
   // at_weight will share same memory with weight_packed_
   // at_weight is used for autograd and optimizer update
   at::Tensor at_weight_;
-  c10::optional<at::Tensor> bias_;
+  c10::optional<at::Tensor> at_bias_;
   // paddings_, strided_, dilation_, output_padding_ here are expanded and
   // might different with those stored on ConvTransposeOpContext.
   // For example, aten deconv2d can accept padding = 2, but onednn deconv2d need
@@ -48,7 +48,7 @@ struct ContextConvTranspose final {
       : original_desc_(std::move(original_desc)),
         weight_packed_(std::move(weight_packed)),
         at_weight_(std::move(at_weight)),
-        bias_(std::move(bias)),
+        at_bias_(std::move(bias)),
         padding_(padding),
         output_padding_(output_padding),
         stride_(stride),
4 changes: 2 additions & 2 deletions csrc/jit/cpu/kernels/ContextLinear.h
@@ -13,7 +13,7 @@ struct ContextLinear final {
   // at_weight will share same memory with weight_packed_
   // at_weight is used for autograd and optimizer update
   at::Tensor at_weight_;
-  c10::optional<at::Tensor> bias_;
+  c10::optional<at::Tensor> at_bias_;
 
   ContextLinear() = delete;
 
@@ -25,7 +25,7 @@ struct ContextLinear final {
       : original_desc_(std::move(original_desc)),
         weight_packed_(std::move(weight_packed)),
         at_weight_(std::move(at_weight)),
-        bias_(std::move(bias)) {}
+        at_bias_(std::move(bias)) {}
 
   ContextLinear(ContextLinear&&) = default;
   ContextLinear& operator=(ContextLinear&&) = default;
10 changes: 5 additions & 5 deletions csrc/jit/cpu/kernels/ContextLinearMKL.h
@@ -9,9 +9,9 @@ namespace cpu {
 namespace detail {
 struct ContextLinearMKL final {
   std::vector<int64_t> sgemm_sizes_ = {0, 0, 0};
-  at::Tensor mkl_weight_;
-  at::Tensor ori_weight_;
-  c10::optional<at::Tensor> bias_;
+  at::Tensor at_weight_; // packed at weight
+  at::Tensor ori_weight_; // non-packed at weight
+  c10::optional<at::Tensor> at_bias_;
 
   ContextLinearMKL() = delete;
 
@@ -21,9 +21,9 @@ struct ContextLinearMKL final {
       at::Tensor&& ori_weight,
       c10::optional<at::Tensor>&& bias)
       : sgemm_sizes_(std::move(sgemm_sizes)),
-        mkl_weight_(std::move(mkl_weight)),
+        at_weight_(std::move(mkl_weight)),
         ori_weight_(std::move(ori_weight)),
-        bias_(std::move(bias)) {}
+        at_bias_(std::move(bias)) {}
 
   ContextLinearMKL(ContextLinearMKL&&) = default;
   ContextLinearMKL& operator=(ContextLinearMKL&&) = default;
8 changes: 4 additions & 4 deletions csrc/jit/cpu/kernels/ConvTransposePacked.cpp
@@ -352,7 +352,7 @@ at::Tensor run(
   check_shape_forward(
       input_.sizes(),
       context.origin_weight_dims_,
-      context.bias_,
+      context.at_bias_,
       context.padding_,
       context.stride_,
       context.dilation_,
@@ -361,7 +361,7 @@ at::Tensor run(
   return conv_transpose_kernel_impl(
       input_,
       context.weight_packed_,
-      context.bias_,
+      context.at_bias_,
       context.stride_,
       context.padding_,
       context.output_padding_,
@@ -397,7 +397,7 @@ at::Tensor& run(
   check_shape_forward(
       input_.sizes(),
       context.origin_weight_dims_,
-      context.bias_,
+      context.at_bias_,
       context.padding_,
       context.stride_,
       context.dilation_,
@@ -406,7 +406,7 @@ at::Tensor& run(
   conv_transpose_out_kernel_impl(
       input_,
       context.weight_packed_,
-      context.bias_,
+      context.at_bias_,
       accumu,
       context.stride_,
       context.padding_,
8 changes: 4 additions & 4 deletions csrc/jit/cpu/kernels/LinearMKLPacked.cpp
@@ -60,7 +60,7 @@ at::Tensor run(ContextLinearMKL& context, const at::Tensor& input) {
       "Check the shapes of mat1 and mat2, they cannot be multiplied!");
   auto input_ = input.contiguous();
   c10::MaybeOwned<at::Tensor> bias_maybe_owned =
-      at::borrow_from_optional_tensor(context.bias_);
+      at::borrow_from_optional_tensor(context.at_bias_);
   const at::Tensor& bias = *bias_maybe_owned;
   int64_t input_batch = (int64_t)(input_.numel() / K);
 
@@ -71,7 +71,7 @@ at::Tensor run(ContextLinearMKL& context, const at::Tensor& input) {
   if (input_batch != context.sgemm_sizes_[0])
     return mkl_sgemm_kernel(input_, context.ori_weight_, bias);
   return mkl_prepack_sgemm_kernel(
-      input_, context.mkl_weight_, bias, context.sgemm_sizes_[2]);
+      input_, context.at_weight_, bias, context.sgemm_sizes_[2]);
 }
 
 at::Tensor& run(
@@ -84,14 +84,14 @@ at::Tensor& run(
       "Check the shapes of mat1 and mat2, they cannot be multiplied!");
   auto input_ = input.contiguous();
   c10::MaybeOwned<at::Tensor> bias_maybe_owned =
-      at::borrow_from_optional_tensor(context.bias_);
+      at::borrow_from_optional_tensor(context.at_bias_);
   const at::Tensor& bias = *bias_maybe_owned;
   int64_t input_batch = (int64_t)(input_.numel() / K);
   if (input_batch != context.sgemm_sizes_[0]) {
     mkl_sgemm_kernel_output(input_, context.ori_weight_, bias, accumu);
   } else {
     mkl_prepack_sgemm_kernel_output(
-        input_, context.mkl_weight_, bias, context.sgemm_sizes_[2], accumu);
+        input_, context.at_weight_, bias, context.sgemm_sizes_[2], accumu);
   }
   return accumu;
 }
36 changes: 33 additions & 3 deletions csrc/jit/cpu/kernels/LinearPacked.cpp
@@ -218,7 +218,7 @@ at::Tensor run(
       "Check the shapes of mat1 and mat2, they cannot be multiplied!");
   auto input_ = input.contiguous();
   c10::MaybeOwned<at::Tensor> bias_maybe_owned =
-      at::borrow_from_optional_tensor(context.bias_);
+      at::borrow_from_optional_tensor(context.at_bias_);
   const at::Tensor& bias = *bias_maybe_owned;
   return linear_kernel(input_, context.weight_packed_, bias, attr);
 }
@@ -233,12 +233,42 @@ at::Tensor& run(
       "Check the shapes of mat1 and mat2, they cannot be multiplied!");
   auto input_ = input.contiguous();
   c10::MaybeOwned<at::Tensor> bias_maybe_owned =
-      at::borrow_from_optional_tensor(context.bias_);
+      at::borrow_from_optional_tensor(context.at_bias_);
   const at::Tensor& bias = *bias_maybe_owned;
   linear_kernel_output(input_, context.weight_packed_, bias, accumu, attr);
   return accumu;
 }
 
+void run_core(
+    const ContextLinear& context,
+    const at::Tensor& input,
+    at::Tensor& accumu,
+    const ideep::attr_t attr) {
+  const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
+  ideep::tensor mkldnn_output = itensor_view_from_dense(accumu);
+  ideep::inner_product_forward_params param;
+  TORCH_CHECK(
+      input.size(input.dim() - 1) == context.weight_packed_.get_dims()[1],
+      "Check the shapes of mat1 and mat2, they cannot be multiplied!");
+  if (context.at_bias_) {
+    auto mkl_bias = itensor_view_from_dense(*context.at_bias_);
+    ideep::inner_product_forward::prepare(
+        param,
+        mkldnn_input,
+        context.weight_packed_,
+        mkl_bias,
+        mkldnn_output,
+        attr);
+    ideep::inner_product_forward::compute<true, false>(
+        param, mkldnn_input, context.weight_packed_, mkl_bias, mkldnn_output);
+  } else {
+    ideep::inner_product_forward::prepare(
+        param, mkldnn_input, context.weight_packed_, mkldnn_output, attr);
+    ideep::inner_product_forward::compute<true, false>(
+        param, mkldnn_input, context.weight_packed_, mkldnn_output);
+  }
+}
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> run_backward(
     ContextLinear& context,
     const at::Tensor& input,
@@ -250,7 +280,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> run_backward(
       context.at_weight_,
       output_mask,
       context.weight_packed_,
-      context.bias_);
+      context.at_bias_);
 }
 
 at::Tensor pack(ContextLinear& context, const at::Tensor& tensor) {
50 changes: 48 additions & 2 deletions csrc/jit/cpu/kernels/OpContext.cpp
@@ -8,6 +8,32 @@
 namespace torch_ipex {
 namespace cpu {
 
+template <typename T1, typename T2>
+void load_from_ctx_template(T1* self, c10::intrusive_ptr<T2> other) {
+  auto& other_ctx_ = other->get_context();
+  auto loaded_weight = other_ctx_.at_weight_;
+  auto loaded_bias = other_ctx_.at_bias_;
+  self->get_context().at_weight_.copy_(loaded_weight);
+  if (loaded_bias.has_value()) {
+    self->get_context().at_bias_.value().copy_(loaded_bias.value());
+  }
+  return;
+}
+
+template <>
+void load_from_ctx_template<IpexLinearMKLOpContext, MKLOpContext>(
+    IpexLinearMKLOpContext* self,
+    c10::intrusive_ptr<MKLOpContext> other) {
+  auto& other_ctx_ = other->get_context();
+  auto loaded_weight = other_ctx_.at_weight_;
+  auto loaded_bias = other_ctx_.at_bias_;
+  self->get_context().at_weight_.copy_(loaded_weight);
+  if (loaded_bias.has_value()) {
+    self->get_context().at_bias_.value().copy_(loaded_bias.value());
+  }
+  self->get_context().ori_weight_.copy_(other->get_context().ori_weight_);
+  return;
+}
 c10::intrusive_ptr<ConvolutionOpContext> IpexConvolutionOpContext::
     create_context(
         at::Tensor&& weight,
@@ -99,6 +125,11 @@ at::Tensor IpexConvolutionOpContext::get_data_handle() {
   return ptr;
 }
 
+void IpexConvolutionOpContext::load_from_ctx(
+    c10::intrusive_ptr<ConvolutionOpContext> other) {
+  load_from_ctx_template(this, other);
+}
+
 c10::intrusive_ptr<LinearOpContext> IpexLinearOpContext::create_context(
     at::Tensor&& weight,
     c10::optional<at::Tensor>&& bias,
@@ -153,6 +184,11 @@ at::Tensor IpexLinearOpContext::to_public(const at::Tensor& tensor) {
   return torch_ipex::cpu::detail::linear::unpack(op_context_, tensor);
 }
 
+void IpexLinearOpContext::load_from_ctx(
+    c10::intrusive_ptr<LinearOpContext> other) {
+  load_from_ctx_template(this, other);
+}
+
 c10::intrusive_ptr<ConvTransposeOpContext> IpexConvTransposeOpContext::
     create_context(
         at::Tensor&& weight,
@@ -194,7 +230,7 @@ c10::intrusive_ptr<MKLOpContext> IpexLinearMKLOpContext::create_context(
 }
 
 at::Tensor IpexLinearMKLOpContext::get_at_packed_weight() {
-  return op_context_.mkl_weight_;
+  return op_context_.at_weight_;
 }
 
 at::Tensor IpexLinearMKLOpContext::get_data_handle() {
@@ -221,7 +257,7 @@ at::Tensor IpexLinearMKLOpContext::to_public(const at::Tensor& tensor) {
   return op_context_.ori_weight_.clone();
 }
 
-detail::ContextLinearMKL& IpexLinearMKLOpContext::get_mkl_context() {
+detail::ContextLinearMKL& IpexLinearMKLOpContext::get_context() {
   return op_context_;
 }
 
@@ -233,6 +269,11 @@ int64_t IpexLinearMKLOpContext::get_in_features() {
   return op_context_.sgemm_sizes_[1];
 }
 
+void IpexLinearMKLOpContext::load_from_ctx(
+    c10::intrusive_ptr<MKLOpContext> other) {
+  load_from_ctx_template(this, other);
+}
+
 at::Tensor IpexConvTransposeOpContext::run(
     const at::Tensor& input,
     const ideep::attr_t& attr) {
@@ -288,5 +329,10 @@ detail::ContextConvTranspose& IpexConvTransposeOpContext::get_context() {
   return op_context_;
 }
 
+void IpexConvTransposeOpContext::load_from_ctx(
+    c10::intrusive_ptr<ConvTransposeOpContext> other) {
+  load_from_ctx_template(this, other);
+}
+
 } // namespace cpu
 } // namespace torch_ipex
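All four load_from_ctx overloads funnel into load_from_ctx_template, and the in-place copy_ is the key design point: at_weight_ shares memory with the packed kernel weight and is the tensor autograd and the optimizer hold, so loading must overwrite its storage rather than rebind it (the MKL specialization additionally copies ori_weight_, the non-packed fallback weight). A small runnable sketch of why the in-place copy matters (plain tensors standing in for the packed weights):

import torch

packed_weight = torch.zeros(4, 4)  # stands in for at_weight_
optimizer_ref = packed_weight      # the optimizer holds the same tensor

loaded = torch.ones(4, 4)          # values arriving via load_state_dict
packed_weight.copy_(loaded)        # in place: every alias sees the update

assert optimizer_ref.data_ptr() == packed_weight.data_ptr()
assert torch.equal(optimizer_ref, loaded)
# Rebinding instead (packed_weight = loaded) would leave optimizer_ref
# pointing at the stale zero-filled storage.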
