Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use custom larq MLIR dialect for our ops #384

Merged
merged 13 commits into from
Jun 9, 2020
41 changes: 37 additions & 4 deletions larq_compute_engine/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ gentbl(
tbl_outs = [
("-gen-op-decls", "ir/lce_ops.h.inc"),
("-gen-op-defs", "ir/lce_ops.cc.inc"),
("-gen-op-doc", "g3doc/lce_ops.md"),
("-gen-dialect-decls -dialect=lq", "ir/lce_dialect.h.inc"),
("-gen-dialect-doc", "g3doc/lce_ops.md"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/lce_ops.td",
td_srcs = [
"@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
Expand Down Expand Up @@ -58,8 +60,6 @@ gentbl(
td_file = "transforms/optimize_patterns.td",
td_srcs = [
"ir/lce_ops.td",
"transforms/op_removal_patterns.td",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@llvm-project//mlir:StdOpsTdFiles",
],
Expand Down Expand Up @@ -99,6 +99,7 @@ gentbl(
cc_library(
name = "larq_compute_engine",
srcs = [
"ir/lce_dialect.h.inc",
"ir/lce_ops.cc",
"ir/lce_ops.cc.inc",
"ir/lce_ops.h.inc",
Expand All @@ -108,11 +109,26 @@ cc_library(
"transforms/passes.h",
],
deps = [
"@flatbuffers",
"@llvm-project//mlir:QuantOps",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
],
alwayslink = 1,
)

# Library with Larq dialect static initialization.
cc_library(
name = "larq_dialect_registration",
srcs = [
"ir/dialect_registration.cc",
],
deps = [
":larq_compute_engine",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)

cc_library(
name = "larq_compute_engine_op_removal",
srcs = [
Expand All @@ -123,7 +139,6 @@ cc_library(
"transforms/passes.h",
],
deps = [
":larq_compute_engine",
"@llvm-project//mlir:StandardOps",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
],
lgeiger marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -185,6 +200,21 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "larq_compute_engine_legalize_tflite",
srcs = [
"transforms/legalize_tflite.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":larq_compute_engine",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
],
alwayslink = 1,
)

cc_library(
name = "larq_compute_engine_quantize",
srcs = [
Expand All @@ -209,6 +239,7 @@ cc_library(
],
deps = [
":larq_compute_engine_bitpack_weights",
":larq_compute_engine_legalize_tflite",
":larq_compute_engine_op_removal",
":larq_compute_engine_optimize",
":larq_compute_engine_prepare",
Expand All @@ -233,6 +264,7 @@ cc_library(
cc_binary(
name = "lce-tf-opt",
deps = [
":larq_dialect_registration",
":lce_tfl_passes",
"@llvm-project//mlir:MlirOptMain",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
Expand All @@ -245,6 +277,7 @@ pybind_extension(
srcs = ["python/graphdef_tfl_flatbuffer.cc"],
module_name = "graphdef_tfl_flatbuffer",
deps = [
":larq_dialect_registration",
":lce_tfl_passes",
"@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export",
Expand Down
4 changes: 4 additions & 0 deletions larq_compute_engine/mlir/ir/dialect_registration.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "larq_compute_engine/mlir/ir/lce_ops.h"

// Static initialization for Larq Compute Engine dialect registration: the
// DialectRegistration constructor registers LarqDialect with the global MLIR
// dialect registry at program start-up (hence alwayslink in the BUILD rule).
static mlir::DialectRegistration<mlir::TF::LarqDialect> lce_ops;
41 changes: 41 additions & 0 deletions larq_compute_engine/mlir/ir/lce_ops.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,51 @@
#include "larq_compute_engine/mlir/ir/lce_ops.h"

#include "flatbuffers/flexbuffers.h"

namespace mlir {
namespace TF {

#define GET_OP_CLASSES
#include "larq_compute_engine/mlir/ir/lce_ops.cc.inc"

// Bsign carries no attributes, so its TFLite custom-options buffer is empty.
std::vector<uint8_t> BsignOp::buildCustomOptions() { return {}; }

// Serializes this op's attributes into a FlexBuffer map; this byte vector is
// stored as the custom-options field of the TFLite custom op so the LCE
// runtime kernels can read the convolution parameters.
// NOTE(review): the scraped source had review-UI text ("AdamHillier marked
// this conversation as resolved...") embedded inside the function body, which
// is not valid C++; it has been removed here.
std::vector<uint8_t> Bconv2dOp::buildCustomOptions() {
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.Int("channels_in", channels_in().getSExtValue());
    fbb.Int("dilation_height_factor", dilation_height_factor().getSExtValue());
    fbb.Int("dilation_width_factor", dilation_width_factor().getSExtValue());
    fbb.String("fused_activation_function",
               std::string(fused_activation_function()));
    fbb.Int("pad_values", pad_values().getSExtValue());
    fbb.String("padding", std::string(padding()));
    fbb.Int("stride_height", stride_height().getSExtValue());
    fbb.Int("stride_width", stride_width().getSExtValue());
  });
  fbb.Finish();
  return fbb.GetBuffer();
}

// Packs the pooling attributes into a FlexBuffer map; the resulting byte
// vector becomes the custom-options field of the TFLite custom op.
std::vector<uint8_t> BMaxPool2dOp::buildCustomOptions() {
  flexbuffers::Builder builder;
  const auto write_attributes = [&]() {
    builder.String("padding", std::string(padding()));
    builder.Int("stride_width", stride_width().getSExtValue());
    builder.Int("stride_height", stride_height().getSExtValue());
    builder.Int("filter_width", filter_width().getSExtValue());
    builder.Int("filter_height", filter_height().getSExtValue());
  };
  builder.Map(write_attributes);
  builder.Finish();
  return builder.GetBuffer();
}

// Dialect constructor: registers every op generated from lce_ops.td (via the
// tblgen-produced GET_OP_LIST section of lce_ops.cc.inc) with this dialect.
LarqDialect::LarqDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "larq_compute_engine/mlir/ir/lce_ops.cc.inc"
      >();
}
} // namespace TF
} // namespace mlir
3 changes: 3 additions & 0 deletions larq_compute_engine/mlir/ir/lce_ops.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#ifndef LARQ_COMPUTE_ENGINE_MLIR_IR_LCE_OPS_H_
#define LARQ_COMPUTE_ENGINE_MLIR_IR_LCE_OPS_H_

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {

#include "larq_compute_engine/mlir/ir/lce_dialect.h.inc"

#define GET_OP_CLASSES
#include "larq_compute_engine/mlir/ir/lce_ops.h.inc"

Expand Down
62 changes: 46 additions & 16 deletions larq_compute_engine/mlir/ir/lce_ops.td
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// This is the operation definition file for Larq Compute engine ops.

// We extend the TensorFlow dialect in order to allow for easy generation
// of the TFLite flatbuffer using TensorFlow's infrastructure.
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for Larq dialect operations.
//
//===----------------------------------------------------------------------===//
lgeiger marked this conversation as resolved.
Show resolved Hide resolved

include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"

#ifndef TFL_OPS
def TFL_AF_None : StrEnumAttrCase<"NONE">;
Expand All @@ -24,13 +29,44 @@ def TFL_PaddingAttr : StrEnumAttr<"Padding", "padding enum", [
]>;
#endif

#ifndef LCE_OPS
#define LCE_OPS
//===----------------------------------------------------------------------===//
// Larq dialect definitions
//===----------------------------------------------------------------------===//

#ifndef LARQ_DIALECT
#define LARQ_DIALECT

// Dialect definition for the 'lq' namespace. Review-UI text that was embedded
// in the scraped source after cppNamespace has been removed; it is not valid
// TableGen.
def LarqDialect : Dialect {
  let name = "lq";

  let summary = "Types and operations for Larq dialect";
  let description = [{
This dialect contains operations for Larq. This dialect will be used in
conjunction with the TensorFlow dialects for converting & optimizing
TF graphs to be deployed on Larq Compute Engine.
  }];

  // Generated C++ classes live in namespace TF so they interoperate with the
  // TensorFlow dialect infrastructure used elsewhere in this project.
  let cppNamespace = "TF";
}

//===----------------------------------------------------------------------===//
// Larq op definitions
//===----------------------------------------------------------------------===//

// Base class for all operations in the Larq dialect. Every LQ op declares a
// buildCustomOptions() hook (implemented in lce_ops.cc) that serializes its
// attributes into the FlexBuffer custom-options blob of the TFLite flatbuffer.
class LQ_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LarqDialect, mnemonic, traits> {

let extraClassDeclaration = [{
std::vector<uint8_t> buildCustomOptions();
}];
}


// Type constraint matching either a tensor of the allowed element types or
// NoneType — used for optional operands (e.g. the post_activation_multiplier
// and post_activation_bias operands of Bconv2d below).
class TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;

def TF_LceBsignOp : TF_Op<"LceBsign", [NoSideEffect, SameOperandsAndResultType]> {
def LQ_BsignOp : LQ_Op<"Bsign", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns an element-wise indication of the binary sign of a number.";

let description = [{
Expand All @@ -44,11 +80,9 @@ def TF_LceBsignOp : TF_Op<"LceBsign", [NoSideEffect, SameOperandsAndResultType]>
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_LceBconv2dOp : TF_Op<"LceBconv2d", [NoSideEffect]> {
def LQ_Bconv2dOp : LQ_Op<"Bconv2d", [NoSideEffect]> {
let summary = [{
Computes a 2D binary convolution by binarizing and bitpacking the input and filter.
}];
Expand All @@ -58,7 +92,7 @@ TODO
}];

let arguments = (ins
TensorOf<[F32]>:$input,
TensorOf<[F32, I32, QI8]>:$input,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correctness of this definition was actually never checked before.

TensorOf<[F32, I32]>:$filter,
TensorOfOrNone<[F32]>:$post_activation_multiplier,
TensorOfOrNone<[F32]>:$post_activation_bias,
Expand All @@ -75,13 +109,11 @@ TODO
);

let results = (outs
TensorOf<[F32]>:$output
TensorOf<[F32, I32, QI8]>:$output
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_LceBMaxPool2dOp : TF_Op<"LceBMaxPool2d", [NoSideEffect]> {
def LQ_BMaxPool2dOp : LQ_Op<"BMaxPool2d", [NoSideEffect]> {
let summary = [{
Binary MaxPool2D op.
}];
Expand All @@ -102,8 +134,6 @@ Computes a MaxPool2D operation and outputs bitpacked binary values, for consumpt
let results = (outs
TensorOf<[I32]>:$output
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

#endif // LCE_OPS
#endif // LARQ_DIALECT
11 changes: 11 additions & 0 deletions larq_compute_engine/mlir/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,14 @@ lce_lit_test_suite(
"@llvm-project//llvm:FileCheck",
],
)

cc_test(
name = "lce_ops_options_test",
srcs = ["lce_ops_options_test.cc"],
deps = [
"//larq_compute_engine/mlir:larq_compute_engine",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
"@llvm-project//mlir:IR",
],
)
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/tests/bitpack-weights.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
// CHECK-LABEL: @bitpack_bconv2d_filters
func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
%cst = constant dense<1.0> : tensor<16x3x3x3xf32>
%0 = "tf.LceBconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
%0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>

// CHECK: %cst = constant dense<0> : tensor<16x3x3x1xi32>
// CHECK: %0 = "tf.LceBconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x1xi32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
// CHECK: %0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x1xi32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: return %0
}
Loading