Skip to content

Commit

Permalink
Enable Conv1d channels last and Conv1d+gelu fusion in jit path (#657)
Browse files Browse the repository at this point in the history
* Enable Conv1d channels last and conv+gelu fusion in JIT mode

* Add weight prepack for conv1d

* Fix UT error
  • Loading branch information
yanbing-j committed Mar 31, 2022
1 parent e10d5e5 commit a0c063b
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 43 deletions.
29 changes: 25 additions & 4 deletions intel_extension_for_pytorch/csrc/aten/cpu/Conv.cpp
Expand Up @@ -48,7 +48,18 @@ void convolution_kernel_output(
(IS_CONTIGUOUS_ANY(input)) && (IS_CONTIGUOUS_ANY(output)),
"input and output are need contiguous tensor for "
"convolution_kernel_output");
const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
const ideep::tensor mkldnn_input_ = itensor_view_from_dense(input);
ideep::tensor mkldnn_input = mkldnn_input_;
// The following code forces the 3D input to channels last, which is a
// temporary workaround before channels last 1D is formally supported in
// PyTorch.
if (mkldnn_input_.ndims() == 3 &&
!mkldnn_input_.get_desc().is_channels_last()) {
ideep::tensor mkldnn_input_conv1d{
mkldnn_input_.get_desc().to_format(ideep::format_tag::nwc)};
mkldnn_input_conv1d.feed_from(mkldnn_input_);
mkldnn_input = mkldnn_input_conv1d;
}
auto output_sizes = output.sizes();

ideep::tensor mkldnn_output = itensor_view_from_dense(output);
Expand Down Expand Up @@ -109,9 +120,19 @@ at::Tensor convolution_kernel(
std::vector<int64_t> output_sizes =
calc_conv_output_size(input_size, kernel_size, padding, stride, dilation);

auto output = at::empty(
output_sizes,
input.options().memory_format(input.suggest_memory_format()));
at::Tensor output;
if (input.dim() != 3) {
output = at::empty(
output_sizes,
input.options().memory_format(input.suggest_memory_format()));
} else {
// This a temporary workaround before channels last 1D is formally supported
// in PyTorch. We will force to return nwc output.
std::vector<int64_t> output_strides = {
(output_sizes[1] * output_sizes[2]), 1, output_sizes[1]};
output = at::empty_strided(output_sizes, output_strides, input.options());
}

convolution_kernel_output(
input,
mkldnn_weight,
Expand Down
10 changes: 6 additions & 4 deletions intel_extension_for_pytorch/csrc/aten/cpu/ParamUtils.h
Expand Up @@ -35,8 +35,11 @@ inline std::vector<int64_t> gen_dummy_input_size_for(
std::vector<int64_t> kernel_size;
if (5 == input_dim) {
kernel_size.push_back(weight_sizes[input_dim - 3]);
kernel_size.push_back(weight_sizes[input_dim - 2]);
}
if (4 == input_dim) {
kernel_size.push_back(weight_sizes[input_dim - 2]);
}
kernel_size.push_back(weight_sizes[input_dim - 2]);
kernel_size.push_back(weight_sizes[input_dim - 1]);
std::vector<int64_t> input_sizes;
auto grouped = groups > 1;
Expand All @@ -46,11 +49,10 @@ inline std::vector<int64_t> gen_dummy_input_size_for(
auto ic = groups * weights_dims_g[1 + grouped];
input_sizes.push_back(32);
input_sizes.push_back(ic);
input_sizes.push_back(14 * kernel_size[0]);
if (4 == input_dim) {
input_sizes.push_back(14 * kernel_size[0]);
input_sizes.push_back(14 * kernel_size[1]);
} else {
input_sizes.push_back(14 * kernel_size[0]);
} else if (5 == input_dim) {
input_sizes.push_back(14 * kernel_size[1]);
input_sizes.push_back(14 * kernel_size[2]);
}
Expand Down
38 changes: 28 additions & 10 deletions intel_extension_for_pytorch/csrc/cpu/ideep/ideep/operators/conv.hpp
Expand Up @@ -225,8 +225,8 @@ struct convolution_forward
const dims& src_dims = dims(),
const attr_t& attr = attr_t(),
const engine& aengine = engine::cpu_engine()) {
auto src_size =
weights_dims.size(); // weights_dims is 4 for conv2d and 5 for conv3d
auto src_size = weights_dims.size(); // weights_dims is 3 for conv1d, 4 for
// conv2d and 5 for conv3d
auto grouped = groups > 1;
auto weights_dims_g =
grouped ? utils::group_dims(weights_dims, groups) : weights_dims;
Expand All @@ -244,8 +244,11 @@ struct convolution_forward
auto oc = groups * dims_in[0 + grouped];
if (5 == src_size) {
kernel_size.push_back(dims_in[ndims - 3]);
kernel_size.push_back(dims_in[ndims - 2]);
}
if (4 == src_size) {
kernel_size.push_back(dims_in[ndims - 2]);
}
kernel_size.push_back(dims_in[ndims - 2]);
kernel_size.push_back(dims_in[ndims - 1]);
if (src_dims.empty()) {
// Construct a dummy case, those shapes are from resnet50 model,
Expand All @@ -255,11 +258,10 @@ struct convolution_forward
x_dims.push_back(ic);
y_dims.push_back(32);
y_dims.push_back(oc);
x_dims.push_back(14 * kernel_size[0]);
if (4 == src_size) {
x_dims.push_back(14 * kernel_size[0]);
x_dims.push_back(14 * kernel_size[1]);
} else {
x_dims.push_back(14 * kernel_size[0]);
} else if (5 == src_size) {
x_dims.push_back(14 * kernel_size[1]);
x_dims.push_back(14 * kernel_size[2]);
}
Expand All @@ -286,8 +288,17 @@ struct convolution_forward
auto src_query = src_desc;
auto dst_query = dst_desc;
if (channels_last) {
src_query = src_desc.to_format(5 == src_size ? tag::ndhwc : tag::nhwc);
dst_query = dst_desc.to_format(5 == src_size ? tag::ndhwc : tag::nhwc);
if (4 == src_size) {
src_query = src_desc.to_format(tag::nhwc);
dst_query = dst_desc.to_format(tag::nhwc);
} else if (5 == src_size) {
src_query = src_desc.to_format(tag::ndhwc);
dst_query = dst_desc.to_format(tag::ndhwc);
}
}
if (3 == src_size) {
src_query = src_desc.to_format(tag::nwc);
dst_query = dst_desc.to_format(tag::nwc);
}

// FIXME: workaroud winograd format issue in inference
Expand Down Expand Up @@ -345,6 +356,7 @@ struct convolution_forward
auto weights_desc_query = weights_desc;
auto bias_desc_query = with_bias ? bias_desc : tensor::desc();
auto dst_desc_query = dst_desc;
auto src_is_channels_last = src_desc.is_channels_last();
if (!keep_format) {
src_desc_query = src_desc.to_format_any();
weights_desc_query = weights_desc.to_format_any();
Expand All @@ -355,9 +367,15 @@ struct convolution_forward
// For nhwc / ndhwc path, weight uses format_tag::any,
// while activation uses format_tag::nhwc / format_tag::ndhwc.
bool channels_last =
src_desc.is_channels_last() || weights_desc.is_channels_last();
src_is_channels_last || weights_desc.is_channels_last();
if (channels_last) {
auto memory_format = src_desc.get_ndims() == 4 ? tag::nhwc : tag::ndhwc;
const auto dim = src_desc.get_ndims();
auto memory_format = tag::nhwc;
if (dim == 3) {
memory_format = tag::nwc;
} else if (dim == 5) {
memory_format = tag::ndhwc;
}
src_desc_query = src_desc.to_format(memory_format);
weights_desc_query = weights_desc.to_format_any();
bias_desc_query = with_bias ? bias_desc.to_format_any() : tensor::desc();
Expand Down
19 changes: 15 additions & 4 deletions intel_extension_for_pytorch/csrc/cpu/ideep/ideep/tensor.hpp
Expand Up @@ -175,7 +175,8 @@ class tensor : public memory {
};

inline bool is_channels_last() const {
if (!is_plain() || !(data.ndims != 4 || data.ndims != 5))
if (!is_plain() ||
!(data.ndims == 4 || data.ndims == 5 || data.ndims == 3))
return false;
const auto& dims = data.dims;
const auto& strides = blocking_strides();
Expand All @@ -184,12 +185,16 @@ class tensor : public memory {
return strides[n] == dims[h] * dims[w] * dims[c] &&
strides[h] == dims[w] * dims[c] && strides[w] == dims[c] &&
strides[c] == 1;
} else {
} else if (data.ndims == 5) {
const auto n = 0, c = 1, d = 2, h = 3, w = 4;
return strides[n] == dims[d] * dims[h] * dims[w] * dims[c] &&
strides[d] == dims[h] * dims[w] * dims[c] &&
strides[h] == dims[w] * dims[c] && strides[w] == dims[c] &&
strides[c] == 1;
} else {
const auto n = 0, c = 1, w = 2;
return strides[n] == dims[w] * dims[c] && strides[w] == dims[c] &&
strides[c] == 1;
}
};

Expand Down Expand Up @@ -808,8 +813,14 @@ class tensor : public memory {
auto channels_last = old_desc.is_channels_last();
if (channels_last) {
// goihw (abcde) => gohwi (abdec) or goidhw (abcdef) => gohwi (abdefc)
grouped_desc = grouped_desc.to_format(
old_desc.get_ndims() == 4 ? format_tag::abdec : format_tag::abdefc);
auto memory_format = format_tag::abdec;
auto dim = old_desc.get_ndims();
if (dim == 5) {
memory_format = format_tag::abdefc;
} else if (dim == 3) {
memory_format = format_tag::abdc;
}
grouped_desc = grouped_desc.to_format(memory_format);
}
}

Expand Down
59 changes: 50 additions & 9 deletions intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConvPacked.cpp
Expand Up @@ -3,6 +3,7 @@
#include "csrc/aten/cpu/Conv.h"
#include "csrc/aten/cpu/ParamUtils.h"
#include "csrc/aten/cpu/WeightPack.h"
#include "csrc/aten/cpu/utils/utils.h"
#include "csrc/cpu/ideep/IDeepConversions.h"
#include "csrc/cpu/ideep/ideep.hpp"
#include "csrc/cpu/ideep/ideep/utils.hpp"
Expand Down Expand Up @@ -112,6 +113,29 @@ at::Tensor convolution_swish_run(
return op_context->run(input, ideep::attr_t::fuse_swish());
}

// Run a prepacked convolution with a fused GELU post-op.
//
// Maps the ATen `approximate` string onto the matching oneDNN eltwise
// algorithm and executes the packed convolution with that post-op attr.
//
// @param input        Convolution input tensor.
// @param approximate  GELU approximation mode: "none" (erf-based, exact) or
//                     "tanh" (tanh approximation). Any other value aborts via
//                     TORCH_CHECK.
// @param op_context   Prepacked convolution context (weights/params) to run.
// @return             The convolution output with GELU applied.
at::Tensor convolution_gelu_run(
    const at::Tensor& input,
    const c10::string_view approximate,
    const c10::intrusive_ptr<ConvolutionOpContext>& op_context) {
  IPEX_RECORD_FUNCTION(
      "ipex_prepack::convolution_gelu_run", std::vector<c10::IValue>({}));
  // https://github.com/pytorch/pytorch/pull/61439
  // at::gelu supports the "tanh" approximate mode now, and oneDNN supports it
  // as well by switching the eltwise algorithm. If PyTorch ever adds another
  // approximate mode that oneDNN does not support, a fallback path would be
  // needed here.
  dnnl::algorithm gelu_type;
  if (approximate == "none") {
    gelu_type = dnnl::algorithm::eltwise_gelu_erf;
  } else if (approximate == "tanh") {
    gelu_type = dnnl::algorithm::eltwise_gelu_tanh;
  } else {
    // Bug fix: the original message was copy-pasted from linear_gelu_run and
    // misstated both the op name and the supported modes.
    TORCH_CHECK(
        false,
        "ipex_prepack::convolution_gelu_run only supports 'none' and 'tanh' ",
        "approximate now, but got: ",
        approximate);
  }
  return op_context->run(
      input, ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type));
}

at::Tensor convolution_add_run(
const at::Tensor& input,
at::Tensor& accumu,
Expand Down Expand Up @@ -320,13 +344,17 @@ ContextConvolution create(
weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;

auto memory_format = at::MemoryFormat::Contiguous;
auto format_tag = input_size.size() == 4 ? ideep::format_tag::nchw
: ideep::format_tag::ncdhw;
auto format_tag = ideep::format_tag::nchw;
if (input_size.size() == 5) {
format_tag = ideep::format_tag::ncdhw;
} else if (input_size.size() == 3) {
format_tag = ideep::format_tag::nwc;
}
if (weight_is_channels_last_) {
if (input_size.size() == 4) {
memory_format = at::MemoryFormat::ChannelsLast;
format_tag = ideep::format_tag::nhwc;
} else {
} else if (input_size.size() == 5) {
memory_format = at::MemoryFormat::ChannelsLast3d;
format_tag = ideep::format_tag::ndhwc;
}
Expand Down Expand Up @@ -451,17 +479,27 @@ at::Tensor run(
if (use_channels_last) {
if (input.dim() == 4) {
memory_format = at::MemoryFormat::ChannelsLast;
} else {
} else if (input.dim() == 5) {
memory_format = at::MemoryFormat::ChannelsLast3d;
}
}
auto input_ = input.contiguous(memory_format);
auto input_ = input;
if (!is_channels_last_1d(input)) {
input_ = input.contiguous(memory_format);
}
if (input_.sizes().vec() == context.conv_params_.pd.src_desc().dims() &&
attr == context.conv_params_.op_attr &&
omp_get_max_threads() == context.conv_params_.pd_use_threads) {
auto output_sizes = context.conv_params_.pd.dst_desc().dims();
auto output = at::empty(
context.conv_params_.pd.dst_desc().dims(),
output_sizes,
input_.options().memory_format(input_.suggest_memory_format()));
if (input.dim() == 3) {
std::vector<int64_t> output_strides = {
(output_sizes[1] * output_sizes[2]), 1, output_sizes[1]};
output =
at::empty_strided(output_sizes, output_strides, input_.options());
}
const ideep::tensor mkldnn_input = itensor_view_from_dense(input_);
ideep::tensor mkldnn_output = itensor_view_from_dense(output);
if (context.bias_.is_empty()) {
Expand Down Expand Up @@ -507,11 +545,14 @@ at::Tensor& run(
if (use_channels_last) {
if (input.dim() == 4) {
memory_format = at::MemoryFormat::ChannelsLast;
} else {
} else if (input.dim() == 5) {
memory_format = at::MemoryFormat::ChannelsLast3d;
}
}
auto input_ = input.contiguous(memory_format);
auto input_ = input;
if (!is_channels_last_1d(input)) {
input_ = input.contiguous(memory_format);
}
// always align accumu format with inputs' format.
accumu = accumu.contiguous(memory_format);
if (input_.sizes().vec() == context.conv_params_.pd.src_desc().dims() &&
Expand Down Expand Up @@ -608,7 +649,7 @@ at::Tensor unpack(ContextConvolution& context, const at::Tensor& tensor) {
if (context.weight_is_channels_last_) {
if (context.original_desc_.get_ndims() == 4) {
result = result.to(at::MemoryFormat::ChannelsLast);
} else {
} else if (context.original_desc_.get_ndims() == 5) {
result = result.to(at::MemoryFormat::ChannelsLast3d);
}
}
Expand Down
5 changes: 5 additions & 0 deletions intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConvPacked.h
Expand Up @@ -57,6 +57,11 @@ at::Tensor convolution_swish_run(
const at::Tensor& input,
const c10::intrusive_ptr<ConvolutionOpContext>& op_context);

at::Tensor convolution_gelu_run(
const at::Tensor& input,
c10::string_view approximate,
const c10::intrusive_ptr<ConvolutionOpContext>& op_context);

at::Tensor convolution_add_run(
const at::Tensor& input,
at::Tensor& accumu,
Expand Down

0 comments on commit a0c063b

Please sign in to comment.