Skip to content

Commit

Permalink
Add serialization unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed Jun 8, 2020
1 parent f900164 commit 19e79a3
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
11 changes: 11 additions & 0 deletions larq_compute_engine/mlir/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,14 @@ lce_lit_test_suite(
"@llvm-project//llvm:FileCheck",
],
)

cc_test(
name = "lce_ops_options_test",
srcs = ["lce_ops_options_test.cc"],
deps = [
"//larq_compute_engine/mlir:larq_compute_engine",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
"@llvm-project//mlir:IR",
],
)
74 changes: 74 additions & 0 deletions larq_compute_engine/mlir/tests/lce_ops_options_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include <gtest/gtest.h>

#include "flatbuffers/flexbuffers.h"
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"

using namespace mlir;

IntegerAttr getIntegerAttr(Builder builder, int value) {
return builder.getIntegerAttr(builder.getIntegerType(32), value);
}

TEST(LCEOpsSerializationTest, BsignTest) {
MLIRContext context;
auto* op = Operation::create(UnknownLoc::get(&context),
OperationName("lq.Bsign", &context), llvm::None,
llvm::None, llvm::None, llvm::None, 0);

ASSERT_EQ(cast<TF::BsignOp>(op).buildCustomOptions().size(), 0);
}

TEST(LCEOpsSerializationTest, BConv2dTest) {
MLIRContext context;
Builder builder(&context);
auto op = Operation::create(UnknownLoc::get(&context),
OperationName("lq.Bconv2d", &context), llvm::None,
llvm::None, llvm::None, llvm::None, 0);

op->setAttr("channels_in", getIntegerAttr(builder, 64));
op->setAttr("dilation_height_factor", getIntegerAttr(builder, 3));
op->setAttr("dilation_width_factor", getIntegerAttr(builder, 4));
op->setAttr("stride_height", getIntegerAttr(builder, 1));
op->setAttr("stride_width", getIntegerAttr(builder, 2));
op->setAttr("pad_values", getIntegerAttr(builder, 1));

op->setAttr("fused_activation_function", builder.getStringAttr("RELU"));
op->setAttr("padding", builder.getStringAttr("SAME"));

std::vector<uint8_t> v = cast<TF::Bconv2dOp>(op).buildCustomOptions();
const flexbuffers::Map& m = flexbuffers::GetRoot(v).AsMap();

ASSERT_EQ(m["channels_in"].AsInt32(), 64);
ASSERT_EQ(m["dilation_height_factor"].AsInt32(), 3);
ASSERT_EQ(m["dilation_width_factor"].AsInt32(), 4);
ASSERT_EQ(m["stride_height"].AsInt32(), 1);
ASSERT_EQ(m["stride_width"].AsInt32(), 2);
ASSERT_EQ(m["pad_values"].AsInt32(), 1);
ASSERT_EQ(m["fused_activation_function"].ToString(), "RELU");
ASSERT_EQ(m["padding"].ToString(), "SAME");
}

TEST(LCEOpsSerializationTest, BMaxPool2dTest) {
MLIRContext context;
Builder builder(&context);
auto op = Operation::create(
UnknownLoc::get(&context), OperationName("lq.BMaxPool2d", &context),
llvm::None, llvm::None, llvm::None, llvm::None, 0);

op->setAttr("padding", builder.getStringAttr("SAME"));
op->setAttr("stride_width", getIntegerAttr(builder, 2));
op->setAttr("stride_height", getIntegerAttr(builder, 1));
op->setAttr("filter_width", getIntegerAttr(builder, 3));
op->setAttr("filter_height", getIntegerAttr(builder, 4));

std::vector<uint8_t> v = cast<TF::BMaxPool2dOp>(op).buildCustomOptions();
const flexbuffers::Map& m = flexbuffers::GetRoot(v).AsMap();

ASSERT_EQ(m["padding"].ToString(), "SAME");
ASSERT_EQ(m["stride_width"].AsInt32(), 2);
ASSERT_EQ(m["stride_height"].AsInt32(), 1);
ASSERT_EQ(m["filter_width"].AsInt32(), 3);
ASSERT_EQ(m["filter_height"].AsInt32(), 4);
}

0 comments on commit 19e79a3

Please sign in to comment.