Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions larq_compute_engine/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ package(
gentbl(
name = "lce_ops_inc_gen",
tbl_outs = [
("-gen-enum-decls", "ir/lce_enum.h.inc"),
("-gen-enum-defs", "ir/lce_enum.cc.inc"),
("-gen-op-decls", "ir/lce_ops.h.inc"),
("-gen-op-defs", "ir/lce_ops.cc.inc"),
("-gen-dialect-decls -dialect=lq", "ir/lce_dialect.h.inc"),
Expand Down Expand Up @@ -153,6 +155,8 @@ cc_library(
name = "larq_compute_engine",
srcs = [
"ir/lce_dialect.h.inc",
"ir/lce_enum.cc.inc",
"ir/lce_enum.h.inc",
"ir/lce_ops.cc",
"ir/lce_ops.cc.inc",
"ir/lce_ops.h.inc",
Expand Down Expand Up @@ -263,6 +267,21 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "larq_compute_engine_translate_tflite",
srcs = [
"transforms/translate_tflite.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":larq_compute_engine",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
],
alwayslink = 1,
)

cc_library(
name = "larq_compute_engine_quantize",
srcs = [
Expand Down Expand Up @@ -306,6 +325,7 @@ cc_library(
":larq_compute_engine_optimize",
":larq_compute_engine_prepare",
":larq_compute_engine_quantize",
":larq_compute_engine_translate_tflite",
":set_batch_size",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
Expand Down
23 changes: 5 additions & 18 deletions larq_compute_engine/mlir/ir/lce_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,8 @@
#include "larq_compute_engine/mlir/transforms/bitpack.h"
#include "tensorflow/lite/schema/schema_generated.h"

static tflite::Padding ConvertPaddingAttr(llvm::StringRef str) {
return llvm::StringSwitch<tflite::Padding>(str)
.Case("SAME", tflite::Padding_SAME)
.Case("VALID", tflite::Padding_VALID);
}

static tflite::ActivationFunctionType ConvertActivationAttr(
llvm::StringRef str) {
return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
.Case("NONE", tflite::ActivationFunctionType_NONE)
.Case("RELU", tflite::ActivationFunctionType_RELU)
.Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
.Case("RELU6", tflite::ActivationFunctionType_RELU6);
}

#define GET_OP_CLASSES
#include "larq_compute_engine/mlir/ir/lce_enum.cc.inc"
#include "larq_compute_engine/mlir/ir/lce_ops.cc.inc"

namespace mlir {
Expand All @@ -36,9 +22,10 @@ std::vector<uint8_t> Bconv2dOp::buildCustomOptions() {
fbb.Int("dilation_height_factor", dilation_height_factor());
fbb.Int("dilation_width_factor", dilation_width_factor());
fbb.Int("fused_activation_function",
(int)ConvertActivationAttr(fused_activation_function()));
(int)symbolizeActivationFunctionType(fused_activation_function())
.getValue());
fbb.Int("pad_values", pad_values());
fbb.Int("padding", (int)ConvertPaddingAttr(padding()));
fbb.Int("padding", (int)symbolizePadding(padding()).getValue());
fbb.Int("stride_height", stride_height());
fbb.Int("stride_width", stride_width());
});
Expand All @@ -49,7 +36,7 @@ std::vector<uint8_t> Bconv2dOp::buildCustomOptions() {
std::vector<uint8_t> BMaxPool2dOp::buildCustomOptions() {
flexbuffers::Builder fbb;
fbb.Map([&]() {
fbb.Int("padding", (int)ConvertPaddingAttr(padding()));
fbb.Int("padding", (int)symbolizePadding(padding()).getValue());
fbb.Int("stride_width", stride_width());
fbb.Int("stride_height", stride_height());
fbb.Int("filter_width", filter_width());
Expand Down
2 changes: 2 additions & 0 deletions larq_compute_engine/mlir/ir/lce_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "larq_compute_engine/mlir/ir/lce_dialect.h.inc"
// clang-format on

#include "larq_compute_engine/mlir/ir/lce_enum.h.inc"

#define GET_OP_CLASSES
#include "larq_compute_engine/mlir/ir/lce_ops.h.inc"

Expand Down
8 changes: 4 additions & 4 deletions larq_compute_engine/mlir/tests/lce_ops_options_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ TEST(LCEOpsSerializationTest, BConv2dTest) {
ASSERT_EQ(m["stride_height"].AsInt32(), 1);
ASSERT_EQ(m["stride_width"].AsInt32(), 2);
ASSERT_EQ(m["pad_values"].AsInt32(), 1);
ASSERT_EQ((ActivationFunctionType)m["fused_activation_function"].AsInt32(),
ActivationFunctionType_RELU);
ASSERT_EQ((Padding)m["padding"].AsInt32(), Padding_SAME);
ASSERT_EQ((::ActivationFunctionType)m["fused_activation_function"].AsInt32(),
::ActivationFunctionType::RELU);
ASSERT_EQ((::Padding)m["padding"].AsInt32(), ::Padding::SAME);
}

TEST(LCEOpsSerializationTest, BMaxPool2dTest) {
Expand All @@ -82,7 +82,7 @@ TEST(LCEOpsSerializationTest, BMaxPool2dTest) {
std::vector<uint8_t> v = cast<lq::BMaxPool2dOp>(op).buildCustomOptions();
const flexbuffers::Map& m = flexbuffers::GetRoot(v).AsMap();

ASSERT_EQ((Padding)m["padding"].AsInt32(), Padding_SAME);
ASSERT_EQ((::Padding)m["padding"].AsInt32(), ::Padding::SAME);
ASSERT_EQ(m["stride_width"].AsInt32(), 2);
ASSERT_EQ(m["stride_height"].AsInt32(), 1);
ASSERT_EQ(m["filter_width"].AsInt32(), 3);
Expand Down
13 changes: 13 additions & 0 deletions larq_compute_engine/mlir/tests/legalize-lce.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: lce-tf-opt %s -tfl-legalize-lce -verify-diagnostics | FileCheck %s
// RUN: lce-tf-opt %s -tfl-legalize-lce -lce-translate-tfl -verify-diagnostics | FileCheck %s --check-prefix=TRANSLATE

// CHECK-LABEL: @legalize_bconv2d
func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
Expand All @@ -7,6 +8,9 @@ func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf3

// CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2, %arg3, %arg4) {custom_code = "LceBconv2d", custom_option = opaque<"lq", "0x6368616E6E656C735F696E0064696C6174696F6E5F6865696768745F666163746F720064696C6174696F6E5F77696474685F666163746F720066757365645F61637469766174696F6E5F66756E6374696F6E007061645F76616C7565730070616464696E67007374726964655F686569676874007374726964655F776964746800088277614C3329221508010803010100000101010404040404040404102401"> : tensor<160xi8>} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: return %0

// TRANSLATE: %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
// TRANSLATE-NEXT: return %0 : tensor<256x30x30x16xf32>
}

// CHECK-LABEL: @legalize_bmax_pool2d
Expand All @@ -16,6 +20,9 @@ func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3

// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceBMaxPool2d", custom_option = opaque<"lq", "0x70616464696E67007374726964655F7769647468007374726964655F6865696768740066696C7465725F77696474680066696C7465725F68656967687400050F1D412D3B050105020200020204040404040A2401"> : tensor<84xi8>} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32>
// CHECK-NEXT: return %0

// TRANSLATE: %0 = "lq.BMaxPool2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32>
// TRANSLATE-NEXT: return %0 : tensor<256x16x16x3xi32>
}

// CHECK-LABEL: @legalize_quantize
Expand All @@ -25,6 +32,9 @@ func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi

// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceQuantize", custom_option = opaque<"lq", "0x"> : tensor<0xi8>} : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32>
// CHECK-NEXT: return %0

// TRANSLATE: %0 = "lq.Quantize"(%arg0) : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32>
// TRANSLATE-NEXT: return %0 : tensor<256x32x32x2xi32>
}

// CHECK-LABEL: @legalize_dequantize
Expand All @@ -34,4 +44,7 @@ func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64

// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceDequantize", custom_option = opaque<"lq", "0x"> : tensor<0xi8>} : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32>
// CHECK-NEXT: return %0

// TRANSLATE: %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32>
// TRANSLATE-NEXT: return %0 : tensor<256x32x32x64xf32>
}
3 changes: 3 additions & 0 deletions larq_compute_engine/mlir/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLCEQuantizePass();
// Creates an instance of LegalizeLCE pass.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeLCEPass();

// Creates an instance of TranslateToLCE pass.
std::unique_ptr<OperationPass<FuncOp>> CreateTranslateToLCEPass();

} // namespace TFL

// Creates an instance of the TensorFlow dialect SetBatchSize pass
Expand Down
83 changes: 83 additions & 0 deletions larq_compute_engine/mlir/transforms/translate_tflite.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include "flatbuffers/flexbuffers.h"
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {

namespace {

struct TranslateToLCE : public PassWrapper<TranslateToLCE, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<mlir::TFL::TensorFlowLiteDialect, mlir::lq::LarqDialect>();
}
void runOnFunction() override;
};

struct TranslateToLCEPattern : public OpRewritePattern<TFL::CustomOp> {
using OpRewritePattern<TFL::CustomOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TFL::CustomOp custom_op,
PatternRewriter& rewriter) const override {
auto stringData = custom_op.custom_option().getValue();

// Replace CustomOp with relevant LarqOp
if (custom_op.custom_code() == "LceQuantize") {
rewriter.replaceOpWithNewOp<lq::QuantizeOp>(
custom_op, custom_op->getResultTypes(), custom_op->getOperands());
} else if (custom_op.custom_code() == "LceDequantize") {
rewriter.replaceOpWithNewOp<lq::DequantizeOp>(
custom_op, custom_op->getResultTypes(), custom_op->getOperands());
} else if (custom_op.custom_code() == "LceBMaxPool2d") {
auto map =
flexbuffers::GetRoot((uint8_t*)stringData.data(), stringData.size())
.AsMap();
rewriter.replaceOpWithNewOp<lq::BMaxPool2dOp>(
custom_op, custom_op->getResultTypes(), custom_op->getOperand(0),
stringifyPadding(static_cast<Padding>(map["padding"].AsInt32())),
map["stride_width"].AsInt32(), map["stride_height"].AsInt32(),
map["filter_width"].AsInt32(), map["filter_height"].AsInt32());
} else if (custom_op.custom_code() == "LceBconv2d") {
auto map =
flexbuffers::GetRoot((uint8_t*)stringData.data(), stringData.size())
.AsMap();
rewriter.replaceOpWithNewOp<lq::Bconv2dOp>(
custom_op, custom_op->getResultTypes(), custom_op->getOperand(0),
custom_op->getOperand(1), custom_op->getOperand(2),
custom_op->getOperand(3), custom_op->getOperand(4),
map["channels_in"].AsInt32(), map["dilation_height_factor"].AsInt32(),
map["dilation_width_factor"].AsInt32(),
stringifyActivationFunctionType(static_cast<ActivationFunctionType>(
map["fused_activation_function"].AsInt32())),
map["pad_values"].AsInt32(),
stringifyPadding(static_cast<Padding>(map["padding"].AsInt32())),
map["stride_height"].AsInt32(), map["stride_width"].AsInt32());
}

return success();
}
};

void TranslateToLCE::runOnFunction() {
OwningRewritePatternList patterns(&getContext());
auto* ctx = &getContext();
auto func = getFunction();
patterns.insert<TranslateToLCEPattern>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

// Creates an instance of the TranslateToLCE pass.
std::unique_ptr<OperationPass<FuncOp>> CreateTranslateToLCEPass() {
return std::make_unique<TranslateToLCE>();
}

static PassRegistration<TranslateToLCE> pass(
"lce-translate-tfl", "Translate TFL custom ops to LCE ops");

} // namespace TFL
} // namespace mlir