Enable onednn.QConv FP32/BF16 output
ghstack-source-id: af28fba9e9a2d6a03f9efe18496a9c173143ecd2
Pull Request resolved: pytorch#112010
leslie-fang-intel committed Nov 3, 2023
1 parent 0d95378 commit efacc46
Showing 8 changed files with 191 additions and 84 deletions.
61 changes: 41 additions & 20 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -1388,7 +1388,7 @@ static at::Tensor _quantized_convolution_onednn(
c10::optional<at::Tensor> accum, // accum to be fused with conv add
double accum_scale,
int64_t accum_zero_point,
bool fp32_output,
c10::optional<c10::ScalarType> output_dtype,
c10::optional<c10::string_view> binary_attr,
c10::optional<at::Scalar> binary_alpha,
c10::optional<c10::string_view> unary_attr,
@@ -1402,13 +1402,15 @@ static at::Tensor _quantized_convolution_onednn(
// inv_scale = 1.0 / scale will be folded.
// So, we can only get inv_scale from the quant node, which is used as
// the output_scale of this op.
if (fp32_output) {
// When fp32_output, oneDNN expects op_attr doesn't set_scales and set_zero_points.
bool fp32_output = output_dtype.has_value() && (output_dtype.value() == c10::kFloat);
bool bfloat16_output = output_dtype.has_value() && (output_dtype.value() == c10::kBFloat16);
if (fp32_output || bfloat16_output) {
// With fp32 or bf16 output, oneDNN expects op_attr not to call set_scales or set_zero_points.
// So we use the default inv_output_scale of 1.0 and output_zero_point of 0:
// when inv_output_scale is 1.0, ideep skips invoking op_attr.set_scales;
// when output_zero_point is 0, ideep skips invoking op_attr.set_zero_points.
TORCH_CHECK(inv_output_scale == 1.0, " (ONEDNN): fp32 output, inv_output_scale must be 1.0.");
TORCH_CHECK(output_zero_point == 0, " (ONEDNN): fp32 output, output_zero_point must be 0");
TORCH_CHECK(inv_output_scale == 1.0, " (ONEDNN): fp32 or bf16 output, inv_output_scale must be 1.0.");
TORCH_CHECK(output_zero_point == 0, " (ONEDNN): fp32 or bf16 output, output_zero_point must be 0");
}

int kSpatialDim = act.dim() - 2;
@@ -1417,7 +1419,14 @@ static at::Tensor _quantized_convolution_onednn(
bool has_binary_post_op = binary_attr.has_value() && binary_attr.value() != "none";
bool has_unary_post_op = unary_attr.has_value() && unary_attr.value() != "none";
// has_accum_postop_sum: an extra input besides the conv input, used for conv-add fusion with post-op sum.
bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "add" && !fp32_output;
bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "add";

if (has_accum_postop_sum && (fp32_output || bfloat16_output)) {
TORCH_CHECK(accum_scale == 1.0, " (ONEDNN): fp32 or bf16 output, accum_scale must be 1.0.");
TORCH_CHECK(accum_zero_point == 0, " (ONEDNN): fp32 or bf16 output, accum_zero_point must be 0");
TORCH_CHECK((accum.value().scalar_type() == c10::kFloat) || (accum.value().scalar_type() == c10::kBFloat16), "The accum tensor should be kFloat or kBFloat16.");
}

std::string func_name = "quantized::packed_weights_conv";
func_name += std::to_string(kSpatialDim) + "d";
if (has_binary_post_op) {
@@ -1523,14 +1532,17 @@ static at::Tensor _quantized_convolution_onednn(
ideep::tensor onednn_bias;
const int output_channels = weight.size(0);
bool with_bias = bias.has_value();

at::Tensor bias_val_float;
if (with_bias) {
at::Tensor bias_val = bias.value();
TORCH_CHECK(bias_val.dim() == 1, "bias should be a vector (1D Tensor)");
// For int8-mixed-bf16, we also use a float32 bias
bias_val_float = bias.value().to(at::kFloat);
TORCH_CHECK(bias_val_float.dim() == 1, "bias should be a vector (1D Tensor)");
TORCH_CHECK(
bias_val.size(0) == output_channels,
bias_val_float.size(0) == output_channels,
"bias should have K elements: " + std::to_string(output_channels));
auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32);
onednn_bias.init(bias_desc, bias.value().data_ptr());
auto bias_desc = ideep::tensor::desc(bias_val_float.sizes().vec(), dnnl::memory::data_type::f32);
onednn_bias.init(bias_desc, bias_val_float.data_ptr());
}

const auto& expected_bias = with_bias ? onednn_bias : ideep::tensor();
@@ -1556,11 +1568,11 @@ static at::Tensor _quantized_convolution_onednn(
ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()});
// Output is not a quantized tensor; its data type is uint8, fp32, or bf16 depending on output_dtype
at::Tensor output;
if (fp32_output) {
if (fp32_output || bfloat16_output) {
output = at::empty(
dst_dims,
device(c10::kCPU)
.dtype(c10::kFloat)
.dtype(fp32_output ? c10::kFloat : c10::kBFloat16)
.memory_format(kSpatialDim == 2 ?
c10::MemoryFormat::ChannelsLast :
c10::MemoryFormat::ChannelsLast3d),
@@ -1581,16 +1593,25 @@ static at::Tensor _quantized_convolution_onednn(
ideep::tensor dst;
at::Tensor accum_contig;
if (has_accum_postop_sum) {
auto dst_desc = ideep::tensor::desc(dst_dims, src_data_type,
auto dst_desc = ideep::tensor::desc(dst_dims, fp32_output ? ideep::tensor::data_type::f32 : (
bfloat16_output ? ideep::tensor::data_type::bf16 : src_data_type),
kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
accum_contig = accum.value().contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
if (fp32_output || bfloat16_output) {
TORCH_CHECK((output.scalar_type() == c10::kFloat) || (output.scalar_type() == c10::kBFloat16), "The output tensor should be kFloat or kBFloat16.");
if (accum_contig.scalar_type() != output.scalar_type()) {
// accum_contig is kFloat and we expect a kBFloat16 output,
// or accum_contig is kBFloat16 and we expect a kFloat output.
accum_contig = accum_contig.to(output.scalar_type());
}
}
TORCH_CHECK(accum_contig.dtype() == output.dtype(), "The output tensor should have the same dtype as the accum tensor.");
// When fused with sum, the dst tensor will share the data ptr as the accum tensor.
dst.init(dst_desc, accum_contig.data_ptr());
} else {
if (fp32_output) {
if (fp32_output || bfloat16_output) {
// Conv without add: int8 in, fp32 or bf16 out
dst = ideep::tensor({dst_dims, ideep::tensor::data_type::f32, {output.strides().cbegin(), output.strides().cend()}},
dst = ideep::tensor({dst_dims, fp32_output ? ideep::tensor::data_type::f32 : ideep::tensor::data_type::bf16, {output.strides().cbegin(), output.strides().cend()}},
output.data_ptr());
} else {
dst = ideep::tensor({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
@@ -1782,7 +1803,7 @@ class QConvoneDNN final {
int64_t groups,
double inv_output_scale, // inv_output_scale is the reciprocal of scale in fake quant
int64_t output_zero_point,
bool fp32_output,
c10::optional<c10::ScalarType> output_dtype,
c10::string_view attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm) {
@@ -1810,7 +1831,7 @@ class QConvoneDNN final {
bias, stride, padding, dilation, /*transposed*/false,
groups, inv_output_scale, output_zero_point,
/*accum*/c10::nullopt, /*accum_scale*/0.0, /*accum_zero_point*/0,
/*fp32_output*/fp32_output, /*binary_attr*/c10::nullopt, /*binary_alpha*/c10::nullopt,
/*output_dtype*/output_dtype, /*binary_attr*/c10::nullopt, /*binary_alpha*/c10::nullopt,
/*unary_attr*/attr, /*unary_scalars*/scalars, /*unary_algorithm*/algorithm
);
#else
@@ -1834,7 +1855,7 @@ class QConvoneDNN final {
int64_t groups,
double inv_output_scale, // inv_output_scale is the reciprocal of scale in fake quant
int64_t output_zero_point,
bool fp32_output,
c10::optional<c10::ScalarType> output_dtype,
c10::string_view binary_attr,
c10::optional<at::Scalar> alpha,
c10::optional<c10::string_view> unary_attr,
Expand Down Expand Up @@ -1862,7 +1883,7 @@ class QConvoneDNN final {
bias, stride, padding, dilation, /*transposed*/false,
groups, inv_output_scale, output_zero_point,
accum, accum_scale, accum_zero_point,
/*fp32_output*/false, binary_attr, alpha,
/*output_dtype*/output_dtype, binary_attr, alpha,
unary_attr, unary_scalars, unary_algorithm
);
#else
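The net effect of the qconv.cpp changes is that callers select the output dtype through the new output_dtype argument instead of a fp32_output flag: None keeps the quantized uint8 output path, while torch.float32 or torch.bfloat16 request a floating-point output, in which case inv_output_scale must be 1.0 and output_zero_point must be 0. Below is a minimal sketch of a direct call to the unary op; the shapes, scales, and example uint8/int8 tensors are illustrative assumptions, while the op names and argument order follow the schemas registered in library.cpp below.

```python
import torch

# uint8 activation; the oneDNN conv path expects channels-last layout.
x = torch.randint(0, 128, (1, 3, 8, 8), dtype=torch.uint8).to(
    memory_format=torch.channels_last)
x_scale, x_zp = 0.02, 0  # illustrative quantization parameters

# int8 weight with per-channel scales/zero points, prepacked for oneDNN.
w = torch.randint(-64, 64, (16, 3, 3, 3), dtype=torch.int8)
w_scales = torch.full((16,), 0.01)
w_zps = torch.zeros(16, dtype=torch.int64)
stride, padding, dilation, groups = [1, 1], [0, 0], [1, 1], 1

packed_w = torch.ops.onednn.qconv_prepack(
    w, w_scales, x_scale, x_zp, stride, padding, dilation, groups)

# With output_dtype set to a floating-point type, inv_output_scale must be
# 1.0 and output_zero_point must be 0 (enforced by the TORCH_CHECKs above).
out = torch.ops.onednn.qconv2d_pointwise(
    x, x_scale, x_zp, packed_w, w_scales, w_zps, None,
    stride, padding, dilation, groups,
    1.0, 0, torch.bfloat16, "none", [], None)
assert out.dtype == torch.bfloat16
```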
8 changes: 4 additions & 4 deletions aten/src/ATen/native/quantized/library.cpp
@@ -258,12 +258,12 @@ TORCH_LIBRARY(onednn, m) {
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_prepack(Tensor weight, Tensor w_scales, float x_scale, int x_zp, int[] stride, int[] padding, int[] dilation, int groups, int[]? x_shape=None) -> Tensor"));

// Conv1D/2D/3D with unary postop
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv1d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv1d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));

// Conv2D with binary postop
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qaccum, float accum_scale, int accum_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qaccum, float accum_scale, int accum_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor"));

// Linear prepack
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor"));
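For the conv-add variant registered just above, the same convention extends to the accum tensor: when output_dtype is floating point, the checks in qconv.cpp require accum_scale == 1.0, accum_zero_point == 0, and a kFloat or kBFloat16 accum, which then doubles as the output buffer for the fused sum. A hedged sketch, mirroring the unary example after the qconv.cpp section (all tensor shapes and values are illustrative assumptions):

```python
import torch

# Inputs mirror the unary sketch above.
x = torch.randint(0, 128, (1, 3, 8, 8), dtype=torch.uint8).to(
    memory_format=torch.channels_last)
w = torch.randint(-64, 64, (16, 3, 3, 3), dtype=torch.int8)
w_scales, w_zps = torch.full((16,), 0.01), torch.zeros(16, dtype=torch.int64)
stride, padding, dilation, groups = [1, 1], [0, 0], [1, 1], 1
packed_w = torch.ops.onednn.qconv_prepack(
    w, w_scales, 0.02, 0, stride, padding, dilation, groups)

# bf16 accumulator with the conv's output shape; fused with post-op sum, it
# shares storage with the output, so its dtype must match output_dtype.
accum = torch.randn(1, 16, 6, 6, dtype=torch.bfloat16).to(
    memory_format=torch.channels_last)

out = torch.ops.onednn.qconv2d_pointwise.binary(
    x, 0.02, 0,
    accum, 1.0, 0,               # accum_scale must be 1.0, accum_zero_point 0
    packed_w, w_scales, w_zps, None,
    stride, padding, dilation, groups,
    1.0, 0, torch.bfloat16,      # inv_output_scale, output_zero_point, output_dtype
    "add", None,                 # binary_attr, alpha
    "none", [], None)            # unary_attr, unary_scalars, unary_algorithm
assert out.dtype == torch.bfloat16
```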
@@ -312,6 +312,10 @@
("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
("onednn::qconv1d_pointwise", datetime.date(2023, 12, 31)),
("onednn::qconv2d_pointwise", datetime.date(2023, 12, 31)),
("onednn::qconv3d_pointwise", datetime.date(2023, 12, 31)),
("onednn::qconv2d_pointwise.binary", datetime.date(2023, 12, 31)),
("onednn::qlinear_pointwise", datetime.date(2023, 12, 31)),
]
