Remove filter_format attribute from IR and bconv kernel
lgeiger committed May 29, 2020
1 parent 2383d37 commit 634b765
Showing 10 changed files with 65 additions and 84 deletions.
1 change: 0 additions & 1 deletion larq_compute_engine/mlir/ir/lce_ops.td
@@ -63,7 +63,6 @@ TODO
I32Attr:$channels_in,
I32Attr:$dilation_height_factor,
I32Attr:$dilation_width_factor,
DefaultValuedAttr<TF_AnyStrAttrOf<["OHWI", "OHWI_PACKED"]>, "OHWI">:$filter_format,
TFL_AFAttr:$fused_activation_function,
DefaultValuedAttr<I32Attr, "0">:$pad_values,
TFL_PaddingAttr:$padding,
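For orientation, the attributes that remain on the op after this deletion are the same ones the TFLite kernel later reads from its custom options. A rough sketch follows, using a hypothetical struct; the field names and types are illustrative and not the kernel's actual conv_params definition:

```cpp
#include <cstdint>
#include <string>

// Illustrative only: mirrors the attribute list left in lce_ops.td above.
// Names and types are assumptions, not the kernel's real parameter struct.
struct IllustrativeBConv2dAttrs {
  std::int32_t channels_in;
  std::int32_t dilation_height_factor;
  std::int32_t dilation_width_factor;
  std::string fused_activation_function;  // e.g. "NONE"
  std::int32_t pad_values;                // defaults to 0
  std::string padding;                    // "SAME" or "VALID"
  std::int32_t stride_height;
  std::int32_t stride_width;
  // filter_format is intentionally absent: the kernel now infers OHWI vs.
  // OHWI_PACKED from the filter tensor's element type (see bconv2d.cc below).
};
```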
78 changes: 39 additions & 39 deletions larq_compute_engine/mlir/tests/optimize.mlir

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions larq_compute_engine/mlir/tests/prepare-tf.mlir
@@ -23,7 +23,7 @@ func @fuse_bconv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
// CHECK: %[[post_activation_multiplier:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<2xf32>
// CHECK: %[[transpose:.*]] = "tf.Transpose"
// CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK-NEXT: return %[[conv]]
}

@@ -38,7 +38,7 @@ func @fuse_scaled_bconv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2x
// CHECK: %[[post_activation_multiplier:.*]] = constant dense<[3.000000e-01, 1.000000e-01]> : tensor<2xf32>
// CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<2xf32>
// CHECK: %[[transpose:.*]] = "tf.Transpose"
// CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK-NEXT: return %[[conv]]
}

@@ -81,7 +81,7 @@ func @fuse_bconv2d_padding(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x16x16x1
// CHECK: %[[CST1:.*]] = constant dense<1.000000e+00> : tensor<16xf32>
// CHECK: %[[CST2:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[TRP:.*]] = "tf.Transpose"
// CHECK: %[[CONV:.*]] = "tf.LceBconv2d"(%arg0, %[[TRP]], %[[CST1]], %[[CST2]]) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x16x16x16xf32>
// CHECK: %[[CONV:.*]] = "tf.LceBconv2d"(%arg0, %[[TRP]], %[[CST1]], %[[CST2]]) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x16x16x16xf32>
}

// CHECK-LABEL: @do_not_fuse_bconv2d_padding_same
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/quantize.mlir
@@ -297,7 +297,7 @@ func @QuantizeBConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128
%cst1 = constant dense<1.10976315> : tensor<32xf32>
%4 = "tfl.quantize"(%cst1) {qtype = tensor<32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<32xf32>) -> tensor<32x!quant.uniform<u8:f32, 0.023528476789885875>>
%5 = "tfl.dequantize"(%4) : (tensor<32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<32xf32>
%6 = "tf.LceBconv2d"(%1, %arg1, %3, %5) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 3 : i32, filter_format = "OHWI_PACKED", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 4 : i32, stride_width = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x1xi32>, tensor<32xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%6 = "tf.LceBconv2d"(%1, %arg1, %3, %5) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 3 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 4 : i32, stride_width = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x1xi32>, tensor<32xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
return %7 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>

2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/transforms/op_removal_patterns.td
@@ -31,7 +31,7 @@ def : Pat<(TF_LceBsignOp:$op $x), (replaceWithValue $x), [(HasNoUseOf:$op)]>;
def : Pat<(TF_LceBconv2dOp:$op $input, $filter, $post_activation_multiplier,
$post_activation_bias, $channels_in,
$dilation_height_factor, $dilation_width_factor,
$filter_format, $fused_activation_function, $pad_values, $padding,
$fused_activation_function, $pad_values, $padding,
$stride_height, $stride_width),
(replaceWithValue $input), [(HasNoUseOf:$op)]>;
def : Pat<(TF_LceBMaxPool2dOp:$op $input, $padding, $stride_width,
24 changes: 12 additions & 12 deletions larq_compute_engine/mlir/transforms/optimize_patterns.td
@@ -15,13 +15,13 @@ multiclass FuseAddOrSubWithBConv2D<dag binaryOp> {
def : Pat<(binaryOp (TF_LceBconv2dOp:$output $input, $filter,
$post_activation_multiplier, (ConstantOp F32ElementsAttr:$post_activation_bias),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, $fused_activation_function,
$fused_activation_function,
$pad_values, $padding, $stride_height, $stride_width),
(ConstantOp F32ElementsAttr:$value), TFL_AF_None),
(TF_LceBconv2dOp $input, $filter, $post_activation_multiplier,
(binaryOp (ConstantOp $post_activation_bias), (ConstantOp $value), TFL_AF_None),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, $fused_activation_function,
$fused_activation_function,
$pad_values, $padding, $stride_height, $stride_width),
[(HasOneUse $output)], (addBenefit 100)>;
}
@@ -34,7 +34,7 @@ multiclass FuseMulOrDivWithBConv2D<dag binaryOp> {
(ConstantOp F32ElementsAttr:$post_activation_multiplier),
(ConstantOp F32ElementsAttr:$post_activation_bias),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, $fused_activation_function,
$fused_activation_function,
$pad_values, $padding, $stride_height, $stride_width),
(ConstantOp F32ElementsAttr:$value), TFL_AF_None),
(TF_LceBconv2dOp $input, $filter,
@@ -43,7 +43,7 @@ multiclass FuseMulOrDivWithBConv2D<dag binaryOp> {
(binaryOp (ConstantOp $post_activation_bias),
(ConstantOp $value), TFL_AF_None),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, $fused_activation_function,
$fused_activation_function,
$pad_values, $padding, $stride_height, $stride_width),
[(HasOneUse $conv_output)], (addBenefit 100)>;
}
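These two multiclasses rely on simple per-channel algebra. Assuming the op computes y = m ⊙ conv(x, w) + b, where m and b are the post_activation_multiplier and post_activation_bias (a reading of the attribute names, not a quote of the kernel), a trailing constant Add/Sub or Mul/Div folds entirely into those two constants, which is exactly what the rewrites above produce:

```latex
% Assumed per-channel semantics of LceBconv2d's post-processing:
%   y = m \odot \mathrm{conv}(x, w) + b
\begin{aligned}
y \pm c   &= m \odot \mathrm{conv}(x, w) + (b \pm c)
            && \text{Add/Sub: only the bias changes} \\
y \odot c &= (m \odot c) \odot \mathrm{conv}(x, w) + (b \odot c)
            && \text{Mul/Div: multiplier and bias both rescale}
\end{aligned}
```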
@@ -55,14 +55,14 @@ def : Pat<(TF_LceBconv2dOp $input, (ConstantOp F32ElementsAttr:$filter),
(ConstantOp HasNegativeValues:$post_activation_multiplier),
$post_activation_bias, $channels_in,
$dilation_height_factor, $dilation_width_factor,
$filter_format, TFL_AF_None, $pad_values, $padding,
TFL_AF_None, $pad_values, $padding,
$stride_height, $stride_width),
(TF_LceBconv2dOp $input,
(TFL_MulOp (ConstantOp $filter), (ConstantOp (ComputeBSignAndExpandTo4D $post_activation_multiplier)), TFL_AF_None),
(TFL_AbsOp (ConstantOp $post_activation_multiplier)),
$post_activation_bias, $channels_in,
$dilation_height_factor, $dilation_width_factor,
$filter_format, TFL_AF_None, $pad_values, $padding,
TFL_AF_None, $pad_values, $padding,
$stride_height, $stride_width),
[], (addBenefit 90)>;

@@ -77,29 +77,29 @@ multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
(ConstantOp ConstantValue<"1.0f">:$post_activation_multiplier),
(ConstantOp ConstantValue<"0.0f">:$post_activation_bias),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, TFL_AF_None,
TFL_AF_None,
$pad_values, ConstantAttr<StrAttr, "VALID">:$padding,
$stride_height, $stride_width)),
(TF_LceBconv2dOp $input, $filter,
(ConstantOp $post_activation_multiplier),
(ConstantOp $post_activation_bias),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, ActFnAttr,
ActFnAttr,
$pad_values, $padding, $stride_height, $stride_width),
[(HasOneUse $conv_output)]>;
def : Pat<(ActFnOp (TF_LceBconv2dOp:$conv_output $input, $filter,
(ConstantOp ConstantValue<"1.0f">:$post_activation_multiplier),
(ConstantOp ConstantValue<"0.0f">:$post_activation_bias),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, TFL_AF_None,
TFL_AF_None,
ConstantAttr<I32Attr, "1">:$pad_values,
ConstantAttr<StrAttr, "SAME">:$padding,
$stride_height, $stride_width)),
(TF_LceBconv2dOp $input, $filter,
(ConstantOp $post_activation_multiplier),
(ConstantOp $post_activation_bias),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, ActFnAttr,
ActFnAttr,
$pad_values, $padding, $stride_height, $stride_width),
[(HasOneUse $conv_output)]>;
}
@@ -117,10 +117,10 @@ def Bitpack : NativeCodeCall<"Bitpack($_builder, $0)">;
def : Pat<(TF_LceBconv2dOp $input, (ConstantOp Conv2DFilter:$filter),
$post_activation_multiplier, $post_activation_bias,
$channels_in, $dilation_height_factor, $dilation_width_factor,
ConstantAttr<StrAttr, "OHWI">, $fused_activation_function,
$fused_activation_function,
$pad_values, $padding, $stride_height, $stride_width),
(TF_LceBconv2dOp $input, (ConstantOp (Bitpack $filter)),
$post_activation_multiplier, $post_activation_bias,
$channels_in, $dilation_height_factor, $dilation_width_factor,
ConstantAttr<StrAttr, "OHWI_PACKED">, $fused_activation_function,
$fused_activation_function,
$pad_values, $padding, $stride_height, $stride_width)>;
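The final pattern replaces a float OHWI filter constant with its bitpacked form via the Bitpack native code call, so after this commit the packed-vs-unpacked distinction is carried by the constant's type rather than by a filter_format string. As a minimal sketch of sign-bitpacking, assuming the common convention that a set bit marks a negative weight and that packing runs along the innermost dimension into 32-bit words (illustrative only, not the actual Bitpack implementation in larq_compute_engine):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: pack the signs of 32 consecutive float weights into one
// 32-bit word (bit set <=> weight is negative). Bit order, padding of the
// final partial word, and the packed dimension are assumptions here.
std::vector<std::int32_t> PackSigns(const std::vector<float>& weights) {
  const std::size_t num_words = (weights.size() + 31) / 32;
  std::vector<std::int32_t> packed(num_words, 0);
  for (std::size_t i = 0; i < weights.size(); ++i) {
    if (weights[i] < 0.0f) {
      // TF Lite has no unsigned 32-bit tensor type, so packed filters are
      // stored as int32; build each word with unsigned arithmetic and cast.
      const std::uint32_t bit = std::uint32_t{1} << (i % 32);
      packed[i / 32] = static_cast<std::int32_t>(
          static_cast<std::uint32_t>(packed[i / 32]) | bit);
    }
  }
  return packed;
}
```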
2 changes: 0 additions & 2 deletions larq_compute_engine/mlir/transforms/prepare_patterns.td
@@ -48,7 +48,6 @@ def : Pat<(TF_Conv2DOp (TF_LceBsignOp $input), (ConstantOp $filter), IsIntList1X
(GetNumChannels $input),
ExtractI32At<1>:$dilations,
ExtractI32At<2>:$dilations,
ConstantAttr<StrAttr, "OHWI">,
TFL_AF_None,
ConstantAttr<I32Attr, "0">,
$padding,
@@ -75,7 +74,6 @@ def : Pat<(TF_Conv2DOp: $output
(GetNumChannels $input),
ExtractI32At<1>:$dilations,
ExtractI32At<2>:$dilations,
ConstantAttr<StrAttr, "OHWI">,
TFL_AF_None,
ConstantAttr<I32Attr, "1">,
ConstantAttr<StrAttr, "SAME">,
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/transforms/quantize_patterns.td
@@ -8,10 +8,10 @@ def : Pat<(TF_LceBconv2dOp $input, $filter,
(TFL_DequantizeOp (TFL_QuantizeOp $post_activation_multiplier, $qtype1)),
(TFL_DequantizeOp (TFL_QuantizeOp $post_activation_bias, $qtype2)),
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, $fused_activation_function, $pad_values, $padding,
$fused_activation_function, $pad_values, $padding,
$stride_height, $stride_width),
(TF_LceBconv2dOp $input, $filter,
$post_activation_multiplier, $post_activation_bias,
$channels_in, $dilation_height_factor, $dilation_width_factor,
$filter_format, $fused_activation_function, $pad_values, $padding,
$fused_activation_function, $pad_values, $padding,
$stride_height, $stride_width)>;
29 changes: 7 additions & 22 deletions larq_compute_engine/tflite/kernels/bconv2d.cc
@@ -67,18 +67,6 @@ void* Init(TfLiteContext* context, const char* buffer, std::size_t length) {
const std::uint8_t* buffer_t = reinterpret_cast<const std::uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

// Later we can change this so that we only allow "OHWI_PACKED" (prepacked)
if (m["filter_format"].IsNull() || m["filter_format"].ToString() == "HWIO") {
conv_params->filter_format = ce::core::FilterFormat::HWIO;
} else if (m["filter_format"].ToString() == "OHWI") {
conv_params->filter_format = ce::core::FilterFormat::OHWI;
} else if (m["filter_format"].ToString() == "OHWI_PACKED") {
conv_params->filter_format = ce::core::FilterFormat::OHWI_PACKED;
} else {
context->ReportError(context, "Invalid filter format.");
return conv_params;
}

// reading the op's input arguments into the "conv_params" struct
LCE_ENSURE_PARAM(conv_params, context, !m["stride_height"].IsNull());
LCE_ENSURE_PARAM(conv_params, context, !m["stride_width"].IsNull());
@@ -254,21 +242,18 @@ TfLiteStatus Prepare(KernelType kernel_type,
}

// reading the filter dimensions
// only OHWI layout is supported for filters
TF_LITE_ENSURE(
context,
conv_params->filter_format == ce::core::FilterFormat::OHWI ||
conv_params->filter_format == ce::core::FilterFormat::OHWI_PACKED);

conv_params->channels_out = filter->dims->data[0];
conv_params->filter_height = filter->dims->data[1];
conv_params->filter_width = filter->dims->data[2];
if (conv_params->filter_format == ce::core::FilterFormat::OHWI) {
TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteFloat32);

if (filter->type == kTfLiteFloat32) {
conv_params->filter_format = ce::core::FilterFormat::OHWI;
TF_LITE_ENSURE_EQ(context, conv_params->channels_in, filter->dims->data[3]);
} else if (filter->type == kTfLiteInt32) {
conv_params->filter_format = ce::core::FilterFormat::OHWI_PACKED;
} else {
// TF Lite does not support the unsigned int32 type so we use int32 here
TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt32);
context->ReportError(context, "Invalid filter format.");
return kTfLiteError;
}

TF_LITE_ENSURE_EQ(context, post_activation_multiplier->dims->data[0],
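With the attribute removed, Prepare() now derives the filter layout from the filter tensor's element type, as the hunk above shows: float32 means an unpacked OHWI filter, int32 means a pre-bitpacked OHWI_PACKED filter, and any other type is rejected. A condensed restatement of that decision, using an illustrative helper name and bool-based error handling instead of the kernel's TF_LITE_ENSURE/ReportError flow:

```cpp
#include "tensorflow/lite/c/common.h"  // TfLiteType, kTfLiteFloat32, kTfLiteInt32

// Illustrative enum standing in for ce::core::FilterFormat from the diff above.
enum class FilterFormat { OHWI, OHWI_PACKED };

// Condensed sketch of the inference done in Prepare(); returns false for
// unsupported filter types instead of reporting an error to the context.
bool InferFilterFormat(TfLiteType filter_type, FilterFormat* format) {
  if (filter_type == kTfLiteFloat32) {
    *format = FilterFormat::OHWI;         // plain float filter in OHWI layout
    return true;
  }
  if (filter_type == kTfLiteInt32) {
    // TF Lite has no unsigned 32-bit type, so packed filters are stored as int32.
    *format = FilterFormat::OHWI_PACKED;  // pre-bitpacked filter words
    return true;
  }
  return false;
}
```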
1 change: 0 additions & 1 deletion larq_compute_engine/tflite/tests/bconv2d_op_model.h
@@ -52,7 +52,6 @@ class BaseBConv2DOpModel : public SingleOpModel {
fbb.Int("stride_width", stride_width);
fbb.Int("dilation_height_factor", dilation_height_factor);
fbb.Int("dilation_width_factor", dilation_width_factor);
fbb.String("filter_format", "OHWI_PACKED");
fbb.String("padding", GetPaddingName(padding));
fbb.Int("pad_values", pad_values);
fbb.String("fused_activation_function", getActivationString(activation));
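For completeness, the kernel's Init() consumes exactly the keys this test helper writes (see the LCE_ENSURE_PARAM lines in the bconv2d.cc hunk above); "filter_format" is simply no longer written or read. A minimal sketch of parsing such an options map with the flexbuffers API, using an illustrative struct rather than the kernel's real conv_params:

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include "flatbuffers/flexbuffers.h"  // flexbuffers::GetRoot / Map

// Illustrative parsing of the custom-op options written by the test model
// above; the real Init() also validates key presence with LCE_ENSURE_PARAM.
struct IllustrativeOptions {
  std::int32_t stride_height, stride_width;
  std::int32_t dilation_height_factor, dilation_width_factor;
  std::int32_t pad_values;
  std::string padding, fused_activation_function;
};

IllustrativeOptions ParseOptions(const std::uint8_t* buffer, std::size_t length) {
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer, length).AsMap();
  IllustrativeOptions opts;
  opts.stride_height = m["stride_height"].AsInt32();
  opts.stride_width = m["stride_width"].AsInt32();
  opts.dilation_height_factor = m["dilation_height_factor"].AsInt32();
  opts.dilation_width_factor = m["dilation_width_factor"].AsInt32();
  opts.pad_values = m["pad_values"].AsInt32();
  opts.padding = m["padding"].ToString();
  opts.fused_activation_function = m["fused_activation_function"].ToString();
  // Note: no "filter_format" key any more; the kernel infers the layout from
  // the filter tensor's element type instead.
  return opts;
}
```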
