diff --git a/larq_compute_engine/mlir/BUILD b/larq_compute_engine/mlir/BUILD index c7a09979..2a324c8a 100644 --- a/larq_compute_engine/mlir/BUILD +++ b/larq_compute_engine/mlir/BUILD @@ -9,6 +9,8 @@ package( gentbl( name = "lce_ops_inc_gen", tbl_outs = [ + ("-gen-enum-decls", "ir/lce_enum.h.inc"), + ("-gen-enum-defs", "ir/lce_enum.cc.inc"), ("-gen-op-decls", "ir/lce_ops.h.inc"), ("-gen-op-defs", "ir/lce_ops.cc.inc"), ("-gen-dialect-decls -dialect=lq", "ir/lce_dialect.h.inc"), @@ -153,6 +155,8 @@ cc_library( name = "larq_compute_engine", srcs = [ "ir/lce_dialect.h.inc", + "ir/lce_enum.cc.inc", + "ir/lce_enum.h.inc", "ir/lce_ops.cc", "ir/lce_ops.cc.inc", "ir/lce_ops.h.inc", @@ -263,6 +267,21 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "larq_compute_engine_translate_tflite", + srcs = [ + "transforms/translate_tflite.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + deps = [ + ":larq_compute_engine", + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite", + ], + alwayslink = 1, +) + cc_library( name = "larq_compute_engine_quantize", srcs = [ @@ -306,6 +325,7 @@ cc_library( ":larq_compute_engine_optimize", ":larq_compute_engine_prepare", ":larq_compute_engine_quantize", + ":larq_compute_engine_translate_tflite", ":set_batch_size", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/larq_compute_engine/mlir/ir/lce_ops.cc b/larq_compute_engine/mlir/ir/lce_ops.cc index ee024d63..f9cc1d1f 100644 --- a/larq_compute_engine/mlir/ir/lce_ops.cc +++ b/larq_compute_engine/mlir/ir/lce_ops.cc @@ -5,22 +5,8 @@ #include "larq_compute_engine/mlir/transforms/bitpack.h" #include "tensorflow/lite/schema/schema_generated.h" -static tflite::Padding ConvertPaddingAttr(llvm::StringRef str) { - return llvm::StringSwitch(str) - .Case("SAME", tflite::Padding_SAME) - .Case("VALID", tflite::Padding_VALID); -} - -static tflite::ActivationFunctionType ConvertActivationAttr( - llvm::StringRef str) { - return llvm::StringSwitch(str) - .Case("NONE", tflite::ActivationFunctionType_NONE) - .Case("RELU", tflite::ActivationFunctionType_RELU) - .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1) - .Case("RELU6", tflite::ActivationFunctionType_RELU6); -} - #define GET_OP_CLASSES +#include "larq_compute_engine/mlir/ir/lce_enum.cc.inc" #include "larq_compute_engine/mlir/ir/lce_ops.cc.inc" namespace mlir { @@ -36,9 +22,10 @@ std::vector Bconv2dOp::buildCustomOptions() { fbb.Int("dilation_height_factor", dilation_height_factor()); fbb.Int("dilation_width_factor", dilation_width_factor()); fbb.Int("fused_activation_function", - (int)ConvertActivationAttr(fused_activation_function())); + (int)symbolizeActivationFunctionType(fused_activation_function()) + .getValue()); fbb.Int("pad_values", pad_values()); - fbb.Int("padding", (int)ConvertPaddingAttr(padding())); + fbb.Int("padding", (int)symbolizePadding(padding()).getValue()); fbb.Int("stride_height", stride_height()); fbb.Int("stride_width", stride_width()); }); @@ -49,7 +36,7 @@ std::vector Bconv2dOp::buildCustomOptions() { std::vector BMaxPool2dOp::buildCustomOptions() { flexbuffers::Builder fbb; fbb.Map([&]() { - fbb.Int("padding", (int)ConvertPaddingAttr(padding())); + fbb.Int("padding", (int)symbolizePadding(padding()).getValue()); fbb.Int("stride_width", stride_width()); fbb.Int("stride_height", stride_height()); fbb.Int("filter_width", filter_width()); diff --git a/larq_compute_engine/mlir/ir/lce_ops.h b/larq_compute_engine/mlir/ir/lce_ops.h index f19dd81b..9becf0a5 100644 --- a/larq_compute_engine/mlir/ir/lce_ops.h +++ b/larq_compute_engine/mlir/ir/lce_ops.h @@ -8,6 +8,8 @@ #include "larq_compute_engine/mlir/ir/lce_dialect.h.inc" // clang-format on +#include "larq_compute_engine/mlir/ir/lce_enum.h.inc" + #define GET_OP_CLASSES #include "larq_compute_engine/mlir/ir/lce_ops.h.inc" diff --git a/larq_compute_engine/mlir/tests/lce_ops_options_test.cc b/larq_compute_engine/mlir/tests/lce_ops_options_test.cc index 381435b1..26b40432 100644 --- a/larq_compute_engine/mlir/tests/lce_ops_options_test.cc +++ b/larq_compute_engine/mlir/tests/lce_ops_options_test.cc @@ -60,9 +60,9 @@ TEST(LCEOpsSerializationTest, BConv2dTest) { ASSERT_EQ(m["stride_height"].AsInt32(), 1); ASSERT_EQ(m["stride_width"].AsInt32(), 2); ASSERT_EQ(m["pad_values"].AsInt32(), 1); - ASSERT_EQ((ActivationFunctionType)m["fused_activation_function"].AsInt32(), - ActivationFunctionType_RELU); - ASSERT_EQ((Padding)m["padding"].AsInt32(), Padding_SAME); + ASSERT_EQ((::ActivationFunctionType)m["fused_activation_function"].AsInt32(), + ::ActivationFunctionType::RELU); + ASSERT_EQ((::Padding)m["padding"].AsInt32(), ::Padding::SAME); } TEST(LCEOpsSerializationTest, BMaxPool2dTest) { @@ -82,7 +82,7 @@ TEST(LCEOpsSerializationTest, BMaxPool2dTest) { std::vector v = cast(op).buildCustomOptions(); const flexbuffers::Map& m = flexbuffers::GetRoot(v).AsMap(); - ASSERT_EQ((Padding)m["padding"].AsInt32(), Padding_SAME); + ASSERT_EQ((::Padding)m["padding"].AsInt32(), ::Padding::SAME); ASSERT_EQ(m["stride_width"].AsInt32(), 2); ASSERT_EQ(m["stride_height"].AsInt32(), 1); ASSERT_EQ(m["filter_width"].AsInt32(), 3); diff --git a/larq_compute_engine/mlir/tests/legalize-lce.mlir b/larq_compute_engine/mlir/tests/legalize-lce.mlir index 4739d725..3230cbe8 100644 --- a/larq_compute_engine/mlir/tests/legalize-lce.mlir +++ b/larq_compute_engine/mlir/tests/legalize-lce.mlir @@ -1,4 +1,5 @@ // RUN: lce-tf-opt %s -tfl-legalize-lce -verify-diagnostics | FileCheck %s +// RUN: lce-tf-opt %s -tfl-legalize-lce -lce-translate-tfl -verify-diagnostics | FileCheck %s --check-prefix=TRANSLATE // CHECK-LABEL: @legalize_bconv2d func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> { @@ -7,6 +8,9 @@ func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf3 // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2, %arg3, %arg4) {custom_code = "LceBconv2d", custom_option = opaque<"lq", "0x6368616E6E656C735F696E0064696C6174696F6E5F6865696768745F666163746F720064696C6174696F6E5F77696474685F666163746F720066757365645F61637469766174696F6E5F66756E6374696F6E007061645F76616C7565730070616464696E67007374726964655F686569676874007374726964655F776964746800088277614C3329221508010803010100000101010404040404040404102401"> : tensor<160xi8>} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> // CHECK-NEXT: return %0 + + // TRANSLATE: %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> + // TRANSLATE-NEXT: return %0 : tensor<256x30x30x16xf32> } // CHECK-LABEL: @legalize_bmax_pool2d @@ -16,6 +20,9 @@ func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3 // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceBMaxPool2d", custom_option = opaque<"lq", "0x70616464696E67007374726964655F7769647468007374726964655F6865696768740066696C7465725F77696474680066696C7465725F68656967687400050F1D412D3B050105020200020204040404040A2401"> : tensor<84xi8>} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> // CHECK-NEXT: return %0 + + // TRANSLATE: %0 = "lq.BMaxPool2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> + // TRANSLATE-NEXT: return %0 : tensor<256x16x16x3xi32> } // CHECK-LABEL: @legalize_quantize @@ -25,6 +32,9 @@ func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceQuantize", custom_option = opaque<"lq", "0x"> : tensor<0xi8>} : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> // CHECK-NEXT: return %0 + + // TRANSLATE: %0 = "lq.Quantize"(%arg0) : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> + // TRANSLATE-NEXT: return %0 : tensor<256x32x32x2xi32> } // CHECK-LABEL: @legalize_dequantize @@ -34,4 +44,7 @@ func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64 // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceDequantize", custom_option = opaque<"lq", "0x"> : tensor<0xi8>} : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> // CHECK-NEXT: return %0 + + // TRANSLATE: %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> + // TRANSLATE-NEXT: return %0 : tensor<256x32x32x64xf32> } diff --git a/larq_compute_engine/mlir/transforms/passes.h b/larq_compute_engine/mlir/transforms/passes.h index 62237cb0..f251cfca 100644 --- a/larq_compute_engine/mlir/transforms/passes.h +++ b/larq_compute_engine/mlir/transforms/passes.h @@ -27,6 +27,9 @@ std::unique_ptr> CreateLCEQuantizePass(); // Creates an instance of LegalizeLCE pass. std::unique_ptr> CreateLegalizeLCEPass(); +// Creates an instance of TranslateToLCE pass. +std::unique_ptr> CreateTranslateToLCEPass(); + } // namespace TFL // Creates an instance of the TensorFlow dialect SetBatchSize pass diff --git a/larq_compute_engine/mlir/transforms/translate_tflite.cc b/larq_compute_engine/mlir/transforms/translate_tflite.cc new file mode 100644 index 00000000..1fe799b7 --- /dev/null +++ b/larq_compute_engine/mlir/transforms/translate_tflite.cc @@ -0,0 +1,83 @@ +#include "flatbuffers/flexbuffers.h" +#include "larq_compute_engine/mlir/ir/lce_ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +namespace { + +struct TranslateToLCE : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override; +}; + +struct TranslateToLCEPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::CustomOp custom_op, + PatternRewriter& rewriter) const override { + auto stringData = custom_op.custom_option().getValue(); + + // Replace CustomOp with relevant LarqOp + if (custom_op.custom_code() == "LceQuantize") { + rewriter.replaceOpWithNewOp( + custom_op, custom_op->getResultTypes(), custom_op->getOperands()); + } else if (custom_op.custom_code() == "LceDequantize") { + rewriter.replaceOpWithNewOp( + custom_op, custom_op->getResultTypes(), custom_op->getOperands()); + } else if (custom_op.custom_code() == "LceBMaxPool2d") { + auto map = + flexbuffers::GetRoot((uint8_t*)stringData.data(), stringData.size()) + .AsMap(); + rewriter.replaceOpWithNewOp( + custom_op, custom_op->getResultTypes(), custom_op->getOperand(0), + stringifyPadding(static_cast(map["padding"].AsInt32())), + map["stride_width"].AsInt32(), map["stride_height"].AsInt32(), + map["filter_width"].AsInt32(), map["filter_height"].AsInt32()); + } else if (custom_op.custom_code() == "LceBconv2d") { + auto map = + flexbuffers::GetRoot((uint8_t*)stringData.data(), stringData.size()) + .AsMap(); + rewriter.replaceOpWithNewOp( + custom_op, custom_op->getResultTypes(), custom_op->getOperand(0), + custom_op->getOperand(1), custom_op->getOperand(2), + custom_op->getOperand(3), custom_op->getOperand(4), + map["channels_in"].AsInt32(), map["dilation_height_factor"].AsInt32(), + map["dilation_width_factor"].AsInt32(), + stringifyActivationFunctionType(static_cast( + map["fused_activation_function"].AsInt32())), + map["pad_values"].AsInt32(), + stringifyPadding(static_cast(map["padding"].AsInt32())), + map["stride_height"].AsInt32(), map["stride_width"].AsInt32()); + } + + return success(); + } +}; + +void TranslateToLCE::runOnFunction() { + OwningRewritePatternList patterns(&getContext()); + auto* ctx = &getContext(); + auto func = getFunction(); + patterns.insert(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // namespace + +// Creates an instance of the TranslateToLCE pass. +std::unique_ptr> CreateTranslateToLCEPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "lce-translate-tfl", "Translate TFL custom ops to LCE ops"); + +} // namespace TFL +} // namespace mlir