From 634b765eaff7925df0153c233526e2a8016a2d9d Mon Sep 17 00:00:00 2001
From: Lukas Geiger
Date: Wed, 27 May 2020 23:17:35 +0100
Subject: [PATCH] Remove filter_format attribute from IR and bconv kernel

---
 larq_compute_engine/mlir/ir/lce_ops.td        |  1 -
 larq_compute_engine/mlir/tests/optimize.mlir  | 78 +++++++++----------
 .../mlir/tests/prepare-tf.mlir                |  6 +-
 larq_compute_engine/mlir/tests/quantize.mlir  |  2 +-
 .../mlir/transforms/op_removal_patterns.td    |  2 +-
 .../mlir/transforms/optimize_patterns.td      | 24 +++---
 .../mlir/transforms/prepare_patterns.td       |  2 -
 .../mlir/transforms/quantize_patterns.td      |  4 +-
 larq_compute_engine/tflite/kernels/bconv2d.cc | 29 ++-----
 .../tflite/tests/bconv2d_op_model.h           |  1 -
 10 files changed, 65 insertions(+), 84 deletions(-)

diff --git a/larq_compute_engine/mlir/ir/lce_ops.td b/larq_compute_engine/mlir/ir/lce_ops.td
index 5513f4013..5bdafd054 100644
--- a/larq_compute_engine/mlir/ir/lce_ops.td
+++ b/larq_compute_engine/mlir/ir/lce_ops.td
@@ -63,7 +63,6 @@ TODO
     I32Attr:$channels_in,
     I32Attr:$dilation_height_factor,
     I32Attr:$dilation_width_factor,
-    DefaultValuedAttr, "OHWI">:$filter_format,
     TFL_AFAttr:$fused_activation_function,
     DefaultValuedAttr:$pad_values,
     TFL_PaddingAttr:$padding,
diff --git a/larq_compute_engine/mlir/tests/optimize.mlir b/larq_compute_engine/mlir/tests/optimize.mlir
index 4d007c731..b6c59503f 100644
--- a/larq_compute_engine/mlir/tests/optimize.mlir
+++ b/larq_compute_engine/mlir/tests/optimize.mlir
@@ -65,7 +65,7 @@ func @fuse_relu_into_bconv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x
   %1 = "tfl.relu"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>

-  // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "RELU", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32}
+  // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "RELU", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32}
   // CHECK-NEXT: return %0
 }

@@ -77,7 +77,7 @@ func @fuse_relu6_into_bconv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3
   %1 = "tfl.relu6"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>

-  // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "RELU6", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32}
+  // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "RELU6", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32}
   // CHECK-NEXT: return %0
 }

@@ -89,7 +89,7 @@ func @fuse_relu1_into_bconv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3
   %1 = "tfl.relu_n1_to_1"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>

-  // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "RELU_N1_TO_1",
pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "RELU_N1_TO_1", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} // CHECK-NEXT: return %0 } @@ -101,7 +101,7 @@ func @fuse_relu_into_bconv2d_padding_same(%arg0: tensor<256x32x32x3xf32>, %arg1: %1 = "tfl.relu"(%0) : (tensor<256x32x32x16xf32>) -> tensor<256x32x32x16xf32> return %1 : tensor<256x32x32x16xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "RELU", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "RELU", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} // CHECK-NEXT: return %0 } @@ -109,11 +109,11 @@ func @fuse_relu_into_bconv2d_padding_same(%arg0: tensor<256x32x32x3xf32>, %arg1: func @do_not_fuse_relu_into_bconv2d_padding_same(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x32x32x16xf32> { %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> %post_activation_bias = constant dense<0.0> : tensor<16xf32> - %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x32x32x16xf32>) -> tensor<256x32x32x16xf32> return %1 : tensor<256x32x32x16xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} // CHECK-NEXT: %1 = "tfl.relu"(%0) // CHECK-NEXT: return %1 } @@ -122,11 +122,11 @@ func @do_not_fuse_relu_into_bconv2d_padding_same(%arg0: tensor<256x32x32x3xf32>, func @do_not_fuse_relu_into_bconv2d_no_post_activation_bias(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> { %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> %post_activation_bias = 
constant dense<5.0> : tensor<16xf32> - %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} // CHECK-NEXT: %1 = "tfl.relu"(%0) // CHECK-NEXT: return %1 } @@ -135,11 +135,11 @@ func @do_not_fuse_relu_into_bconv2d_no_post_activation_bias(%arg0: tensor<256x32 func @do_not_fuse_relu_into_bconv2d_no_post_activation_multiplier(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> { %post_activation_multiplier = constant dense<0.8> : tensor<16xf32> %post_activation_bias = constant dense<0.0> : tensor<16xf32> - %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 
1 : i32, stride_width = 1 : i32} // CHECK-NEXT: %1 = "tfl.relu"(%0) // CHECK-NEXT: return %1 } @@ -177,62 +177,62 @@ func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16xf return %0 : tensor<256x30x30x16xf32> // CHECK: %cst = constant dense<0> : tensor<16x3x3x1xi32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %cst, %arg1, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI_PACKED", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x1xi32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %cst, %arg1, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x1xi32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> // CHECK-NEXT: return %0 } // CHECK-LABEL: @bitpack_activations_between_two_bconv2ds_valid_padding func @bitpack_activations_between_two_bconv2ds_valid_padding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<65x3x3x3xf32>, %arg2: tensor<65xf32>, %arg3: tensor<65xf32>, %arg4: tensor<8x3x3x65xf32>, %arg5: tensor<8xf32>, %arg6: tensor<8xf32>) -> tensor<256x28x28x8xf32> { - %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x30x30x65xf32> - %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x30x30x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x28x28x8xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x30x30x65xf32> + %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x30x30x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x28x28x8xf32> return %1 : tensor<256x28x28x8xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x30x30x3xi32> - // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, 
dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x30x30x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x28x28x8xf32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x30x30x3xi32> + // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x30x30x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x28x28x8xf32> // CHECK-NEXT: return %1 } // CHECK-LABEL: @bitpack_activations_between_two_bconv2ds_same_one_padding func @bitpack_activations_between_two_bconv2ds_same_one_padding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<65x3x3x3xf32>, %arg2: tensor<65xf32>, %arg3: tensor<65xf32>, %arg4: tensor<8x3x3x65xf32>, %arg5: tensor<8xf32>, %arg6: tensor<8xf32>) -> tensor<256x30x30x8xf32> { - %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> - %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> return %1 : tensor<256x30x30x8xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x3xi32> - // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : 
i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x3xi32> + // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> // CHECK-NEXT: return %1 } // CHECK-LABEL: @do_not_bitpack_activations_between_two_bconv2ds_same_zero_padding func @do_not_bitpack_activations_between_two_bconv2ds_same_zero_padding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<65x3x3x3xf32>, %arg2: tensor<65xf32>, %arg3: tensor<65xf32>, %arg4: tensor<8x3x3x65xf32>, %arg5: tensor<8xf32>, %arg6: tensor<8xf32>) -> tensor<256x30x30x8xf32> { - %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> - %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> return %1 : tensor<256x30x30x8xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> - // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) 
{channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> // CHECK-NEXT: return %1 } // CHECK-LABEL: @do_not_bitpack_activations_between_two_bconv2ds_same_one_padding_multiple_uses func @do_not_bitpack_activations_between_two_bconv2ds_same_one_padding_multiple_uses(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<65x3x3x3xf32>, %arg2: tensor<65xf32>, %arg3: tensor<65xf32>, %arg4: tensor<8x3x3x65xf32>, %arg5: tensor<8xf32>, %arg6: tensor<8xf32>) -> (tensor<256x32x32x65xf32>, tensor<256x30x30x8xf32>) { - %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> - %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> return %0, %1: tensor<256x32x32x65xf32>, tensor<256x30x30x8xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, 
tensor<65xf32>) -> tensor<256x32x32x65xf32> - // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x30x30x8xf32> // CHECK-NEXT: return %0, %1 } // CHECK-LABEL: @bitpack_activations_between_binary_maxpool_and_bconv func @bitpack_activations_between_binary_maxpool_and_bconv(%arg0: tensor<256x32x32x65xf32>, %arg1: tensor<8x3x3x65xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>) -> tensor<256x14x14x8xf32> { %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x16x65xf32> - %1 = "tf.LceBconv2d"(%0, %arg1, %arg2, %arg3) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> + %1 = "tf.LceBconv2d"(%0, %arg1, %arg2, %arg3) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> return %1 : tensor<256x14x14x8xf32> // CHECK: %0 = "tf.LceBMaxPool2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x16x3xi32> - // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg1, %arg2, %arg3) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> + // CHECK-NEXT: %1 = "tf.LceBconv2d"(%0, %arg1, %arg2, %arg3) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) 
-> tensor<256x14x14x8xf32> // CHECK-NEXT: return %1 } @@ -240,38 +240,38 @@ func @bitpack_activations_between_binary_maxpool_and_bconv(%arg0: tensor<256x32x func @bitpack_activations_bconv_relu_maxpool_bconv(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<65x3x3x3xf32>, %arg2: tensor<8x3x3x65xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<256x14x14x8xf32> { %post_activation_multiplier = constant dense<1.0> : tensor<65xf32> %post_activation_bias = constant dense<0.0> : tensor<65xf32> - %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> %1 = "tfl.relu"(%0) : (tensor<256x32x32x65xf32>) -> tensor<256x32x32x65xf32> %2 = "tfl.max_pool_2d"(%1) {filter_height = 2 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x16x65xf32> - %3 = "tf.LceBconv2d"(%2, %arg2, %arg3, %arg4) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> + %3 = "tf.LceBconv2d"(%2, %arg2, %arg3, %arg4) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> return %3 : tensor<256x14x14x8xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "RELU", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x3xi32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %cst, %cst_0) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "RELU", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x3xi32> // CHECK: %1 = "tf.LceBMaxPool2d"(%0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> - // CHECK-NEXT: %2 = "tf.LceBconv2d"(%1, %arg2, %arg3, %arg4) {channels_in = 65 : 
i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> + // CHECK-NEXT: %2 = "tf.LceBconv2d"(%1, %arg2, %arg3, %arg4) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x3xi32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> // CHECK-NEXT: return %2 } // CHECK-LABEL: @do_not_bitpack_activations_with_intermediate_binary_maxpool_multiple_uses func @do_not_bitpack_activations_with_intermediate_binary_maxpool_multiple_uses(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<65x3x3x3xf32>, %arg2: tensor<65xf32>, %arg3: tensor<65xf32>, %arg4: tensor<8x3x3x65xf32>, %arg5: tensor<8xf32>, %arg6: tensor<8xf32>) -> (tensor<256x16x16x65xf32>, tensor<256x14x14x8xf32>) { - %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> %1 = "tfl.max_pool_2d"(%0) {filter_height = 2 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x16x65xf32> - %2 = "tf.LceBconv2d"(%1, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> + %2 = "tf.LceBconv2d"(%1, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> return %1, %2 : tensor<256x16x16x65xf32>, tensor<256x14x14x8xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : 
i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> // CHECK: %1 = "tfl.max_pool_2d"(%0) {filter_height = 2 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x16x65xf32> - // CHECK-NEXT: %2 = "tf.LceBconv2d"(%1, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> + // CHECK-NEXT: %2 = "tf.LceBconv2d"(%1, %arg4, %arg5, %arg6) {channels_in = 65 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x16x16x65xf32>, tensor<8x3x3x65xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<256x14x14x8xf32> // CHECK-NEXT: return %1, %2 } // CHECK-LABEL: @do_not_allow_binary_maxpool_without_final_bconv func @do_not_allow_binary_maxpool_without_final_bconv(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<65x3x3x3xf32>, %arg2: tensor<65xf32>, %arg3: tensor<65xf32>) -> tensor<256x16x16x65xf32> { - %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> %1 = "tfl.max_pool_2d"(%0) {filter_height = 2 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x16x65xf32> return %1 : tensor<256x16x16x65xf32> - // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> + // CHECK: %0 = "tf.LceBconv2d"(%arg0, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>) -> tensor<256x32x32x65xf32> // CHECK-NEXT: %1 = "tfl.max_pool_2d"(%0) {filter_height = 2 : i32, filter_width = 2 : i32, fused_activation_function = "NONE", padding = "SAME", 
stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<256x32x32x65xf32>) -> tensor<256x16x16x65xf32> // CHECK-NEXT: return %1 } diff --git a/larq_compute_engine/mlir/tests/prepare-tf.mlir b/larq_compute_engine/mlir/tests/prepare-tf.mlir index 6ce8e62d7..558b815fd 100644 --- a/larq_compute_engine/mlir/tests/prepare-tf.mlir +++ b/larq_compute_engine/mlir/tests/prepare-tf.mlir @@ -23,7 +23,7 @@ func @fuse_bconv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> { // CHECK: %[[post_activation_multiplier:.*]] = constant dense<1.000000e+00> : tensor<2xf32> // CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<2xf32> // CHECK: %[[transpose:.*]] = "tf.Transpose" - // CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> + // CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> // CHECK-NEXT: return %[[conv]] } @@ -38,7 +38,7 @@ func @fuse_scaled_bconv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2x // CHECK: %[[post_activation_multiplier:.*]] = constant dense<[3.000000e-01, 1.000000e-01]> : tensor<2xf32> // CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<2xf32> // CHECK: %[[transpose:.*]] = "tf.Transpose" - // CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> + // CHECK-NEXT: %[[conv:.*]] = "tf.LceBconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> // CHECK-NEXT: return %[[conv]] } @@ -81,7 +81,7 @@ func @fuse_bconv2d_padding(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x16x16x1 // CHECK: %[[CST1:.*]] = constant dense<1.000000e+00> : tensor<16xf32> // CHECK: %[[CST2:.*]] = constant dense<0.000000e+00> : tensor<16xf32> // CHECK: %[[TRP:.*]] = "tf.Transpose" - // CHECK: %[[CONV:.*]] = "tf.LceBconv2d"(%arg0, %[[TRP]], %[[CST1]], %[[CST2]]) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, filter_format = "OHWI", fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 
: i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x16x16x16xf32> + // CHECK: %[[CONV:.*]] = "tf.LceBconv2d"(%arg0, %[[TRP]], %[[CST1]], %[[CST2]]) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x16x16x16xf32> } // CHECK-LABEL: @do_not_fuse_bconv2d_padding_same diff --git a/larq_compute_engine/mlir/tests/quantize.mlir b/larq_compute_engine/mlir/tests/quantize.mlir index 5ccb14ac1..1ef175ff0 100644 --- a/larq_compute_engine/mlir/tests/quantize.mlir +++ b/larq_compute_engine/mlir/tests/quantize.mlir @@ -297,7 +297,7 @@ func @QuantizeBConv2D(tensor<1x224x224x3x!quant.uniform : tensor<32xf32> %4 = "tfl.quantize"(%cst1) {qtype = tensor<32x!quant.uniform>} : (tensor<32xf32>) -> tensor<32x!quant.uniform> %5 = "tfl.dequantize"(%4) : (tensor<32x!quant.uniform>) -> tensor<32xf32> - %6 = "tf.LceBconv2d"(%1, %arg1, %3, %5) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 3 : i32, filter_format = "OHWI_PACKED", fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 4 : i32, stride_width = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x1xi32>, tensor<32xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> + %6 = "tf.LceBconv2d"(%1, %arg1, %3, %5) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 3 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 4 : i32, stride_width = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x1xi32>, tensor<32xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> %7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> return %7 : tensor<1x112x112x32x!quant.uniform> diff --git a/larq_compute_engine/mlir/transforms/op_removal_patterns.td b/larq_compute_engine/mlir/transforms/op_removal_patterns.td index b1a77545f..e27cc1de1 100644 --- a/larq_compute_engine/mlir/transforms/op_removal_patterns.td +++ b/larq_compute_engine/mlir/transforms/op_removal_patterns.td @@ -31,7 +31,7 @@ def : Pat<(TF_LceBsignOp:$op $x), (replaceWithValue $x), [(HasNoUseOf:$op)]>; def : Pat<(TF_LceBconv2dOp:$op $input, $filter, $post_activation_multiplier, $post_activation_bias, $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, $fused_activation_function, $pad_values, $padding, + $fused_activation_function, $pad_values, $padding, $stride_height, $stride_width), (replaceWithValue $input), [(HasNoUseOf:$op)]>; def : Pat<(TF_LceBMaxPool2dOp:$op $input, $padding, $stride_width, diff --git a/larq_compute_engine/mlir/transforms/optimize_patterns.td b/larq_compute_engine/mlir/transforms/optimize_patterns.td index a5098b3c1..b739ee98b 100644 --- a/larq_compute_engine/mlir/transforms/optimize_patterns.td +++ b/larq_compute_engine/mlir/transforms/optimize_patterns.td @@ -15,13 +15,13 @@ multiclass FuseAddOrSubWithBConv2D { def : Pat<(binaryOp (TF_LceBconv2dOp:$output $input, $filter, $post_activation_multiplier, (ConstantOp F32ElementsAttr:$post_activation_bias), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, $fused_activation_function, + $fused_activation_function, $pad_values, $padding, $stride_height, 
$stride_width), (ConstantOp F32ElementsAttr:$value), TFL_AF_None), (TF_LceBconv2dOp $input, $filter, $post_activation_multiplier, (binaryOp (ConstantOp $post_activation_bias), (ConstantOp $value), TFL_AF_None), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, $fused_activation_function, + $fused_activation_function, $pad_values, $padding, $stride_height, $stride_width), [(HasOneUse $output)], (addBenefit 100)>; } @@ -34,7 +34,7 @@ multiclass FuseMulOrDivWithBConv2D { (ConstantOp F32ElementsAttr:$post_activation_multiplier), (ConstantOp F32ElementsAttr:$post_activation_bias), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, $fused_activation_function, + $fused_activation_function, $pad_values, $padding, $stride_height, $stride_width), (ConstantOp F32ElementsAttr:$value), TFL_AF_None), (TF_LceBconv2dOp $input, $filter, @@ -43,7 +43,7 @@ multiclass FuseMulOrDivWithBConv2D { (binaryOp (ConstantOp $post_activation_bias), (ConstantOp $value), TFL_AF_None), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, $fused_activation_function, + $fused_activation_function, $pad_values, $padding, $stride_height, $stride_width), [(HasOneUse $conv_output)], (addBenefit 100)>; } @@ -55,14 +55,14 @@ def : Pat<(TF_LceBconv2dOp $input, (ConstantOp F32ElementsAttr:$filter), (ConstantOp HasNegativeValues:$post_activation_multiplier), $post_activation_bias, $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, TFL_AF_None, $pad_values, $padding, + TFL_AF_None, $pad_values, $padding, $stride_height, $stride_width), (TF_LceBconv2dOp $input, (TFL_MulOp (ConstantOp $filter), (ConstantOp (ComputeBSignAndExpandTo4D $post_activation_multiplier)), TFL_AF_None), (TFL_AbsOp (ConstantOp $post_activation_multiplier)), $post_activation_bias, $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, TFL_AF_None, $pad_values, $padding, + TFL_AF_None, $pad_values, $padding, $stride_height, $stride_width), [], (addBenefit 90)>; @@ -77,21 +77,21 @@ multiclass FuseActFnIntoConvOpPat { (ConstantOp ConstantValue<"1.0f">:$post_activation_multiplier), (ConstantOp ConstantValue<"0.0f">:$post_activation_bias), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, TFL_AF_None, + TFL_AF_None, $pad_values, ConstantAttr:$padding, $stride_height, $stride_width)), (TF_LceBconv2dOp $input, $filter, (ConstantOp $post_activation_multiplier), (ConstantOp $post_activation_bias), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, ActFnAttr, + ActFnAttr, $pad_values, $padding, $stride_height, $stride_width), [(HasOneUse $conv_output)]>; def : Pat<(ActFnOp (TF_LceBconv2dOp:$conv_output $input, $filter, (ConstantOp ConstantValue<"1.0f">:$post_activation_multiplier), (ConstantOp ConstantValue<"0.0f">:$post_activation_bias), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, TFL_AF_None, + TFL_AF_None, ConstantAttr:$pad_values, ConstantAttr:$padding, $stride_height, $stride_width)), @@ -99,7 +99,7 @@ multiclass FuseActFnIntoConvOpPat { (ConstantOp $post_activation_multiplier), (ConstantOp $post_activation_bias), $channels_in, $dilation_height_factor, $dilation_width_factor, - $filter_format, ActFnAttr, + ActFnAttr, $pad_values, $padding, $stride_height, $stride_width), [(HasOneUse $conv_output)]>; } @@ -117,10 +117,10 @@ def Bitpack : NativeCodeCall<"Bitpack($_builder, $0)">; def : Pat<(TF_LceBconv2dOp $input, (ConstantOp 
Conv2DFilter:$filter),
               $post_activation_multiplier, $post_activation_bias,
               $channels_in, $dilation_height_factor, $dilation_width_factor,
-              ConstantAttr, $fused_activation_function,
+              $fused_activation_function,
               $pad_values, $padding, $stride_height, $stride_width),
           (TF_LceBconv2dOp $input, (ConstantOp (Bitpack $filter)),
               $post_activation_multiplier, $post_activation_bias,
               $channels_in, $dilation_height_factor, $dilation_width_factor,
-              ConstantAttr, $fused_activation_function,
+              $fused_activation_function,
               $pad_values, $padding, $stride_height, $stride_width)>;
diff --git a/larq_compute_engine/mlir/transforms/prepare_patterns.td b/larq_compute_engine/mlir/transforms/prepare_patterns.td
index 66df4e27a..b7f630e26 100644
--- a/larq_compute_engine/mlir/transforms/prepare_patterns.td
+++ b/larq_compute_engine/mlir/transforms/prepare_patterns.td
@@ -48,7 +48,6 @@ def : Pat<(TF_Conv2DOp (TF_LceBsignOp $input), (ConstantOp $filter), IsIntList1X
               (GetNumChannels $input),
               ExtractI32At<1>:$dilations,
               ExtractI32At<2>:$dilations,
-              ConstantAttr,
               TFL_AF_None,
               ConstantAttr,
               $padding,
@@ -75,7 +74,6 @@ def : Pat<(TF_Conv2DOp: $output
               (GetNumChannels $input),
               ExtractI32At<1>:$dilations,
               ExtractI32At<2>:$dilations,
-              ConstantAttr,
               TFL_AF_None,
               ConstantAttr,
               ConstantAttr,
diff --git a/larq_compute_engine/mlir/transforms/quantize_patterns.td b/larq_compute_engine/mlir/transforms/quantize_patterns.td
index f1d378711..a85670be8 100644
--- a/larq_compute_engine/mlir/transforms/quantize_patterns.td
+++ b/larq_compute_engine/mlir/transforms/quantize_patterns.td
@@ -8,10 +8,10 @@ def : Pat<(TF_LceBconv2dOp $input, $filter,
               (TFL_DequantizeOp (TFL_QuantizeOp $post_activation_multiplier, $qtype1)),
               (TFL_DequantizeOp (TFL_QuantizeOp $post_activation_bias, $qtype2)),
               $channels_in, $dilation_height_factor, $dilation_width_factor,
-              $filter_format, $fused_activation_function, $pad_values, $padding,
+              $fused_activation_function, $pad_values, $padding,
               $stride_height, $stride_width),
           (TF_LceBconv2dOp $input, $filter,
               $post_activation_multiplier, $post_activation_bias,
               $channels_in, $dilation_height_factor, $dilation_width_factor,
-              $filter_format, $fused_activation_function, $pad_values, $padding,
+              $fused_activation_function, $pad_values, $padding,
               $stride_height, $stride_width)>;
diff --git a/larq_compute_engine/tflite/kernels/bconv2d.cc b/larq_compute_engine/tflite/kernels/bconv2d.cc
index 156c82b5a..8a4193272 100644
--- a/larq_compute_engine/tflite/kernels/bconv2d.cc
+++ b/larq_compute_engine/tflite/kernels/bconv2d.cc
@@ -67,18 +67,6 @@ void* Init(TfLiteContext* context, const char* buffer, std::size_t length) {
   const std::uint8_t* buffer_t = reinterpret_cast<const std::uint8_t*>(buffer);
   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

-  // Later we can change this so that we only allow "OHWI_PACKED" (prepacked)
-  if (m["filter_format"].IsNull() || m["filter_format"].ToString() == "HWIO") {
-    conv_params->filter_format = ce::core::FilterFormat::HWIO;
-  } else if (m["filter_format"].ToString() == "OHWI") {
-    conv_params->filter_format = ce::core::FilterFormat::OHWI;
-  } else if (m["filter_format"].ToString() == "OHWI_PACKED") {
-    conv_params->filter_format = ce::core::FilterFormat::OHWI_PACKED;
-  } else {
-    context->ReportError(context, "Invalid filter format.");
-    return conv_params;
-  }
-
   // reading the op's input arguments into the "conv_params" struct
   LCE_ENSURE_PARAM(conv_params, context, !m["stride_height"].IsNull());
   LCE_ENSURE_PARAM(conv_params, context, !m["stride_width"].IsNull());
@@ -254,21 +242,18 @@ TfLiteStatus Prepare(KernelType kernel_type,
   }

   // reading the filter dimensions
-  // only OHWI layout is supported for filters
-  TF_LITE_ENSURE(
-      context,
-      conv_params->filter_format == ce::core::FilterFormat::OHWI ||
-          conv_params->filter_format == ce::core::FilterFormat::OHWI_PACKED);
-
   conv_params->channels_out = filter->dims->data[0];
   conv_params->filter_height = filter->dims->data[1];
   conv_params->filter_width = filter->dims->data[2];
-  if (conv_params->filter_format == ce::core::FilterFormat::OHWI) {
-    TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteFloat32);
+
+  if (filter->type == kTfLiteFloat32) {
+    conv_params->filter_format = ce::core::FilterFormat::OHWI;
     TF_LITE_ENSURE_EQ(context, conv_params->channels_in, filter->dims->data[3]);
+  } else if (filter->type == kTfLiteInt32) {
+    conv_params->filter_format = ce::core::FilterFormat::OHWI_PACKED;
   } else {
-    // TF Lite does not support the unsigned int32 type so we use int32 here
-    TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt32);
+    context->ReportError(context, "Invalid filter format.");
+    return kTfLiteError;
   }

   TF_LITE_ENSURE_EQ(context, post_activation_multiplier->dims->data[0],
diff --git a/larq_compute_engine/tflite/tests/bconv2d_op_model.h b/larq_compute_engine/tflite/tests/bconv2d_op_model.h
index a7a93ae2c..17211e221 100644
--- a/larq_compute_engine/tflite/tests/bconv2d_op_model.h
+++ b/larq_compute_engine/tflite/tests/bconv2d_op_model.h
@@ -52,7 +52,6 @@ class BaseBConv2DOpModel : public SingleOpModel {
       fbb.Int("stride_width", stride_width);
       fbb.Int("dilation_height_factor", dilation_height_factor);
       fbb.Int("dilation_width_factor", dilation_width_factor);
-      fbb.String("filter_format", "OHWI_PACKED");
       fbb.String("padding", GetPaddingName(padding));
       fbb.Int("pad_values", pad_values);
       fbb.String("fused_activation_function", getActivationString(activation));
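
With the "filter_format" attribute removed, the kernel's Prepare() step infers the format from the filter tensor's type, as the second bconv2d.cc hunk above shows: a float32 filter is treated as OHWI, an int32 filter as bitpacked OHWI_PACKED, and anything else is rejected. The standalone sketch below (illustrative only, not part of the patch; TensorType and FilterFormat are simplified stand-ins for the real TfLiteType and ce::core::FilterFormat) mirrors that dispatch:

// Illustrative sketch of the dtype-based filter-format inference introduced
// by this patch. The enums below are simplified stand-ins, not the real
// TfLiteType / ce::core::FilterFormat definitions.
#include <iostream>

enum class TensorType { Float32, Int32, Other };
enum class FilterFormat { OHWI, OHWI_PACKED, Invalid };

// float32 filters are plain OHWI; int32 filters are bitpacked OHWI_PACKED;
// any other dtype corresponds to the "Invalid filter format." error path.
FilterFormat InferFilterFormat(TensorType filter_type) {
  switch (filter_type) {
    case TensorType::Float32:
      return FilterFormat::OHWI;
    case TensorType::Int32:
      return FilterFormat::OHWI_PACKED;
    default:
      return FilterFormat::Invalid;
  }
}

int main() {
  // A bitpacked (int32) filter maps to OHWI_PACKED.
  std::cout << (InferFilterFormat(TensorType::Int32) == FilterFormat::OHWI_PACKED) << "\n";  // 1
  // Anything that is neither float32 nor int32 is rejected.
  std::cout << (InferFilterFormat(TensorType::Other) == FilterFormat::Invalid) << "\n";      // 1
  return 0;
}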