Skip to content

Commit

Permalink
Add grouped binary convolution support (1/3): the converter.
Browse files Browse the repository at this point in the history
Add a validation check to the converter to ensure that grouped
convolutions have a group size that is a multiple of 32, and report
an error otherwise.
  • Loading branch information
AdamHillier committed Nov 5, 2020
1 parent c1aa3db commit 1ba74aa
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 10 deletions.
1 change: 1 addition & 0 deletions larq_compute_engine/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ cc_library(
],
deps = [
":larq_compute_engine",
"//larq_compute_engine/core:types",
"@llvm-project//mlir:StandardOps",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/bitpack-weights.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: lce-tf-opt %s -tfl-lce-bitpack-weights | FileCheck %s
// RUN: lce-tf-opt %s -tfl-lce-bitpack-weights -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @bitpack_bconv2d_filters
func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/legalize-lce.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: lce-tf-opt %s -tfl-legalize-lce | FileCheck %s
// RUN: lce-tf-opt %s -tfl-legalize-lce -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @legalize_bconv2d
func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/op-removal.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: lce-tf-opt %s -lce-op-removal-tf | FileCheck %s
// RUN: lce-tf-opt %s -lce-op-removal-tf -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @snapshot
func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/optimize.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: lce-tf-opt %s -tfl-optimize-lce | FileCheck %s
// RUN: lce-tf-opt %s -tfl-optimize-lce -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @fuse_add_into_bconv2d
func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
Expand Down
31 changes: 30 additions & 1 deletion larq_compute_engine/mlir/tests/prepare-tf.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: lce-tf-opt %s -tfl-prepare-lce | FileCheck %s
// RUN: lce-tf-opt %s -tfl-prepare-lce -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @fuse_bsign
func @fuse_bsign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
Expand Down Expand Up @@ -29,6 +29,35 @@ func @fuse_bconv2d(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x112x2xf32> {
// CHECK-NEXT: return %[[conv]]
}

// CHECK-LABEL: @fuse_bconv2d_grouped_convolution
func @fuse_bconv2d_grouped_convolution(%arg0: tensor<1x112x112x4xi32>) -> tensor<1x112x112x16xf32> {
// A 3x3 filter with 128 input channels (64 per-group) and 16 output channels (8 per-group).
%cst = "tf.Const"() { value = dense<1.0> : tensor<3x3x64x16xf32>} : () -> tensor<3x3x64x16xf32>
%0 = "lq.Dequantize"(%arg0) : (tensor<1x112x112x4xi32>) -> tensor<1x112x112x128xf32>
%1 = "tf.Conv2D"(%0, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x112x112x128xf32>, tensor<3x3x64x16xf32>) -> tensor<1x112x112x16xf32>
return %1 : tensor<1x112x112x16xf32>

// CHECK: %cst = constant
// CHECK: %[[post_activation_multiplier:.*]] = constant dense<1.000000e+00> : tensor<16xf32>
// CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[output_threshold:.*]] = constant unit
// CHECK: %[[transpose:.*]] = "tf.Transpose"
// CHECK-NEXT: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]], %[[output_threshold:.*]]) {channels_in = 128 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x4xi32>, tensor<16x3x3x64xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<1x112x112x16xf32>
// CHECK-NEXT: return %[[conv]]
}

// CHECK-LABEL: @do_not_fuse_bconv2d_grouped_convolution_group_size_not_mul_32
func @do_not_fuse_bconv2d_grouped_convolution_group_size_not_mul_32(%arg0: tensor<1x56x56x4xi32>) -> tensor<1x56x56x128xf32> {
// A 3x3 filter with 128 input channels (4 per-group) and 128 output channels
// (4 per-group). We expect an error to be raised:
//
// expected-error @+1 {{Invalid binary grouped convolution: the number of input channels per-group must be a multiple of 32, but is 4}}
%cst = "tf.Const"() { value = dense<1.0> : tensor<3x3x4x128xf32>} : () -> tensor<3x3x4x128xf32>
%0 = "lq.Dequantize"(%arg0) : (tensor<1x56x56x4xi32>) -> tensor<1x56x56x128xf32>
%1 = "tf.Conv2D"(%0, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x56x56x128xf32>, tensor<3x3x4x128xf32>) -> tensor<1x56x56x128xf32>
return %1 : tensor<1x56x56x128xf32>
}

// CHECK-LABEL: @fuse_scaled_bconv2d
func @fuse_scaled_bconv2d(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x112x2xf32> {
%cst = constant dense<[[[[0.3, -0.1], [0.3, 0.1]], [[-0.3, 0.1], [-0.3, 0.1]]]]> : tensor<1x2x2x2xf32>
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/quantize.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: lce-tf-opt %s -lce-quantize | FileCheck %s
// RUN: lce-tf-opt %s -lce-quantize -verify-diagnostics | FileCheck %s

// CHECK-LABEL: quantize_bconv2d
func @quantize_bconv2d(%arg0: tensor<1x224x224x1xi32>, %arg1: tensor<32x3x3x1xi32>, %arg2: none) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
Expand Down
11 changes: 7 additions & 4 deletions larq_compute_engine/mlir/transforms/prepare_patterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ class GetConstantVector<string val> : NativeCodeCall<"GetConstantVector($0, " #
def BinaryFilter : Constraint<CPred<"IsBinaryFilter($0)">>;
def GetScaleVector : NativeCodeCall<"GetScaleVector($0)">;
def GetNumChannels : NativeCodeCall<"GetNumChannels($_builder, $0)">;
def ValidFilterShape : Constraint<CPred<"HasValidFilterShape($0, $1)">>;
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
def CreateNoneAttrValue : NativeCodeCall<"$_builder.getUnitAttr()">;

def : Pat<(TF_Conv2DOp
(LQ_DequantizeOp: $dequantized_input $input),
(ConstantOp $filter),
(ConstantOp: $filter_op $filter),
IsIntList1XY1:$strides,
$use_cudnn,
$padding,
Expand All @@ -67,7 +68,8 @@ def : Pat<(TF_Conv2DOp
$padding,
ExtractI32At<1>:$strides,
ExtractI32At<2>:$strides),
[(BinaryFilter $filter)],
[(BinaryFilter $filter),
(ValidFilterShape $dequantized_input, $filter_op)],
(addBenefit 90)>;

def ConstFloatValueIsOne : Constraint<
Expand All @@ -82,7 +84,7 @@ def : Pat<(TF_Conv2DOp:$output
(LQ_DequantizeOp: $dequantized_input $input),
(ConstantOp $paddings),
(ConstantOp $pad_values)),
(ConstantOp $filter),
(ConstantOp: $filter_op $filter),
IsIntList1XY1:$strides,
$use_cudnn,
ConstantAttr<StrAttr, "VALID">,
Expand All @@ -108,5 +110,6 @@ def : Pat<(TF_Conv2DOp:$output
ExtractI32At<2>:$strides),
[(BinaryFilter $filter),
(ConstFloatValueIsOne $pad_values),
(SamePadding $paddings, $input, $output, $strides)],
(SamePadding $paddings, $input, $output, $strides),
(ValidFilterShape $dequantized_input, $filter_op)],
(addBenefit 90)>;
34 changes: 34 additions & 0 deletions larq_compute_engine/mlir/transforms/prepare_tf.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "larq_compute_engine/core/types.h"
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
Expand All @@ -12,6 +13,8 @@ namespace TFL {

namespace {

using compute_engine::core::bitpacking_bitwidth;

// Prepare LCE operations in functions for subsequent legalization.
struct PrepareLCE : public PassWrapper<PrepareLCE, FunctionPass> {
void runOnFunction() override;
Expand Down Expand Up @@ -101,6 +104,37 @@ bool IsSamePadding(Attribute paddings_attr, Value input, Value output,
paddings.getValue<int>({3, 1}) == 0;
}

// Verify that the filter shape is compatible with the input shape. Will fail if
// any other type is passed. Will emit an error and return false if the two
// shapes are incompatible (specifically, if the shapes imply a grouped
// convolution with a group-shape that is not a multiple of 32).
bool HasValidFilterShape(Value input_val, Value filter_val) {
auto input_type = input_val.getType().cast<ShapedType>();
auto input_shape_vector = input_type.getShape();
auto total_input_channels = input_shape_vector[input_shape_vector.size() - 1];
auto filter_type = filter_val.getType().cast<ShapedType>();
auto filter_shape_vector = filter_type.getShape();
auto filter_input_channels =
filter_shape_vector[filter_shape_vector.size() - 2];
if (total_input_channels % filter_input_channels != 0) {
mlir::emitError(filter_val.getLoc())
<< "Filter dimensions invalid: the number of filter input channels "
<< filter_input_channels
<< " does not divide the total number of input channels "
<< total_input_channels << "\n";
return false;
}
auto num_groups = total_input_channels / filter_input_channels;
if (num_groups > 1 && filter_input_channels % bitpacking_bitwidth != 0) {
mlir::emitError(filter_val.getLoc())
<< "Invalid binary grouped convolution: the number of input channels "
"per-group must be a multiple of "
<< bitpacking_bitwidth << ", but is " << filter_input_channels << "\n";
return false;
}
return true;
}

// Returns the number of channels of a shaped tensor. Will fail if any other
// type is passed.
IntegerAttr GetNumChannels(Builder& b, Value output_val) {
Expand Down

0 comments on commit 1ba74aa

Please sign in to comment.