Add grouped binary convolution support (2/3): reference kernel.
Add support for grouped binary convolutions to the reference kernel.
Extend the kernel tests to cover grouped convolutions; since the
built-in Conv2D op doesn't support groups, expected outputs are
computed by emulating each grouped convolution with multiple ordinary
convolutions.
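
The emulation strategy can be sketched in plain C++: split the input channels into `groups` slices, convolve each slice with its group's filters, and concatenate the per-group outputs along the channel axis. A minimal float sketch (hypothetical helper names; single batch, NHWC input / OHWI filters, stride 1, VALID padding), not the actual test code:

```cpp
#include <vector>

// Hypothetical stand-in for the built-in Conv2D op: single batch, NHWC
// input, OHWI filter, stride 1, VALID padding.
std::vector<float> NaiveConv2D(const std::vector<float>& in, int h, int w,
                               int in_c, const std::vector<float>& filt,
                               int f_h, int f_w, int out_c) {
  const int out_h = h - f_h + 1, out_w = w - f_w + 1;
  std::vector<float> out(out_h * out_w * out_c, 0.0f);
  for (int y = 0; y < out_h; ++y)
    for (int x = 0; x < out_w; ++x)
      for (int o = 0; o < out_c; ++o)
        for (int fy = 0; fy < f_h; ++fy)
          for (int fx = 0; fx < f_w; ++fx)
            for (int c = 0; c < in_c; ++c)
              out[(y * out_w + x) * out_c + o] +=
                  in[((y + fy) * w + (x + fx)) * in_c + c] *
                  filt[((o * f_h + fy) * f_w + fx) * in_c + c];
  return out;
}

// Emulate a grouped convolution with `groups` ordinary convolutions. The
// filter tensor has shape [out_c, f_h, f_w, in_c / groups], matching the
// grouped filter layout used by the kernel.
std::vector<float> GroupedConv2D(const std::vector<float>& in, int h, int w,
                                 int in_c, const std::vector<float>& filt,
                                 int f_h, int f_w, int out_c, int groups) {
  const int in_c_g = in_c / groups, out_c_g = out_c / groups;
  const int out_h = h - f_h + 1, out_w = w - f_w + 1;
  std::vector<float> out(out_h * out_w * out_c);
  for (int g = 0; g < groups; ++g) {
    // Slice input channels [g * in_c_g, (g + 1) * in_c_g).
    std::vector<float> in_g(h * w * in_c_g);
    for (int p = 0; p < h * w; ++p)
      for (int c = 0; c < in_c_g; ++c)
        in_g[p * in_c_g + c] = in[p * in_c + g * in_c_g + c];
    // Group g owns output channels [g * out_c_g, (g + 1) * out_c_g); in
    // OHWI layout its filters are a contiguous chunk.
    const std::vector<float> filt_g(
        filt.begin() + g * out_c_g * f_h * f_w * in_c_g,
        filt.begin() + (g + 1) * out_c_g * f_h * f_w * in_c_g);
    const std::vector<float> out_g =
        NaiveConv2D(in_g, h, w, in_c_g, filt_g, f_h, f_w, out_c_g);
    // Concatenate: copy the group's channels into the full output tensor.
    for (int p = 0; p < out_h * out_w; ++p)
      for (int c = 0; c < out_c_g; ++c)
        out[p * out_c + g * out_c_g + c] = out_g[p * out_c_g + c];
  }
  return out;
}
```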
AdamHillier committed Nov 6, 2020
1 parent 4a80d66 commit b15e227
Showing 6 changed files with 189 additions and 56 deletions.
2 changes: 1 addition & 1 deletion larq_compute_engine/core/bconv2d/optimized_bgemm.h
@@ -112,7 +112,7 @@ inline void BConv2DOptimizedBGEMM(
// output tensor with zeroes in advance so that the BGEMM doesn't have to
// worry about doing the padding.
if (std::is_same<DstScalar, TBitpacked>::value &&
-     output_shape.Dims(3) % 32 != 0) {
+     output_shape.Dims(3) % bitpacking_bitwidth != 0) {
std::fill(
output_data,
output_data + FlatSizeSkipDim(output_shape, 3) *
1 change: 1 addition & 0 deletions larq_compute_engine/core/bconv2d/params.h
@@ -15,6 +15,7 @@ struct BConv2DParams {
std::int32_t filter_height;
std::int32_t channels_in;
std::int32_t channels_out;
+ std::int32_t groups;

// Strides
std::int32_t stride_height;
16 changes: 12 additions & 4 deletions larq_compute_engine/core/bconv2d/reference.h
@@ -57,16 +57,21 @@ inline void BConv2DReference(
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

const int batches = MatchingDim(packed_input_shape, 0, output_shape, 0);
- const int input_depth =
-     MatchingDim(packed_input_shape, 3, packed_filter_shape, 3);
+ const int input_depth_per_group = packed_filter_shape.Dims(3);
const int output_depth = packed_filter_shape.Dims(0);
+ const int output_depth_per_group = output_depth / bconv2d_params->groups;
const int input_height = packed_input_shape.Dims(1);
const int input_width = packed_input_shape.Dims(2);
const int filter_height = packed_filter_shape.Dims(1);
const int filter_width = packed_filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);

+ TFLITE_DCHECK_EQ(input_depth_per_group * bconv2d_params->groups,
+                  packed_input_shape.Dims(3));
+ TFLITE_DCHECK_EQ(output_depth_per_group * bconv2d_params->groups,
+                  output_depth);

for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -75,10 +80,12 @@
for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
const int in_x_origin = (out_x * stride_width) - pad_width;
const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int group = out_channel / output_depth_per_group;
AccumScalar accum = AccumScalar(0);
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ for (int in_channel = 0; in_channel < input_depth_per_group;
+      ++in_channel) {
const int in_x = in_x_origin + dilation_width_factor * filter_x;
const int in_y =
in_y_origin + dilation_height_factor * filter_y;
@@ -88,7 +95,8 @@
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
input_value = packed_input_data[Offset(
- packed_input_shape, batch, in_y, in_x, in_channel)];
+ packed_input_shape, batch, in_y, in_x,
+ group * input_depth_per_group + in_channel)];
}
TBitpacked filter_value =
packed_filter_data[Offset(packed_filter_shape, out_channel,
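The indexing change reads directly: each output channel belongs to group out_channel / output_depth_per_group, and that group's filters see only a contiguous slice of the input channels. A tiny standalone illustration with arbitrary example values:

```cpp
#include <cstdio>

int main() {
  // Arbitrary example: 2 groups, 8 output channels and 4 (bitpacked) input
  // channels per group.
  const int output_depth_per_group = 8, input_depth_per_group = 4;
  const int out_channel = 13;  // belongs to group 13 / 8 = 1
  const int in_channel = 2;    // group-local input channel
  const int group = out_channel / output_depth_per_group;
  // The filter's local channel 2 reads global input channel 1 * 4 + 2 = 6.
  std::printf("group=%d, global input channel=%d\n", group,
              group * input_depth_per_group + in_channel);
  return 0;
}
```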
34 changes: 29 additions & 5 deletions larq_compute_engine/tflite/kernels/bconv2d.cc
@@ -176,6 +176,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bconv2d_params->filter_height = SizeOfDimension(filter, 1);
bconv2d_params->filter_width = SizeOfDimension(filter, 2);

+ if (SizeOfDimension(filter, 3) ==
+     GetBitpackedSize(bconv2d_params->channels_in)) {
+   bconv2d_params->groups = 1;
+ } else {
+   TF_LITE_ENSURE_MSG(
+       context, kernel_type == KernelType::kReference,
+       "Grouped binary convolutions are not supported with this kernel.");
+   TF_LITE_ENSURE_EQ(context,
+                     GetBitpackedSize(bconv2d_params->channels_in) %
+                         SizeOfDimension(filter, 3),
+                     0);
+   const std::int32_t groups = GetBitpackedSize(bconv2d_params->channels_in) /
+                               SizeOfDimension(filter, 3);
+   const std::int32_t group_size = bconv2d_params->channels_in / groups;
+   TF_LITE_ENSURE_EQ(context, group_size % core::bitpacking_bitwidth, 0);
+   TF_LITE_ENSURE_EQ(context, bconv2d_params->channels_out % groups, 0);
+   bconv2d_params->groups = groups;
+ }

// Compute the padding and output values (height, width)
int out_width, out_height;
bconv2d_params->padding_values = ComputePaddingHeightWidth(
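The group count is inferred from shapes alone: a filter whose bitpacked channel dimension is smaller than the bitpacked input depth implies grouping. A sketch of the arithmetic above, assuming GetBitpackedSize is a ceiling division by the bitpacking bitwidth (32-bit words assumed here; InferGroups is a hypothetical name, not the TFLite code):

```cpp
#include <cassert>
#include <cstdint>

constexpr std::int32_t kBitwidth = 32;  // assumed bitpacking word size

// Sketch: infer `groups` from the true input channel count and the
// filter's bitpacked channel dimension.
std::int32_t InferGroups(std::int32_t channels_in,
                         std::int32_t filter_packed_channels) {
  const std::int32_t packed_in = (channels_in + kBitwidth - 1) / kBitwidth;
  if (filter_packed_channels == packed_in) return 1;  // ungrouped
  assert(packed_in % filter_packed_channels == 0);
  const std::int32_t groups = packed_in / filter_packed_channels;
  // Each group must span whole bitpacked words; the output channels must
  // also divide evenly between groups (checked separately above).
  assert((channels_in / groups) % kBitwidth == 0);
  return groups;
}

// Example: channels_in = 128 packs to 4 words; a filter depth of 2 words
// gives groups = 2, i.e. 64 input channels per group.
```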
@@ -273,7 +292,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

// Resize the im2col tensor
- int bitpacked_channels_in = GetBitpackedSize(bconv2d_params->channels_in);
+ const std::int32_t bitpacked_channels_in =
+     GetBitpackedSize(bconv2d_params->channels_in);
TfLiteIntArray* im2col_size = TfLiteIntArrayCopy(output_shape);
im2col_size->data[3] = bitpacked_channels_in *
bconv2d_params->filter_height *
@@ -321,6 +341,11 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
const auto* post_activation_bias = GetInput(context, node, 3);
const auto* output = GetOutput(context, node, 0);

+ // Division is safe because at this point we know that channels_in is a
+ // multiple of the number of groups.
+ const std::int32_t channels_in_per_group =
+     bconv2d_params->channels_in / bconv2d_params->groups;

// For 'same-zero' padding, compute the padding-correction.
if (bconv2d_params->padding_type == kTfLitePaddingSame &&
bconv2d_params->pad_value == 0) {
@@ -331,7 +356,7 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
zero_padding_correction::CacheCorrectionValues(
GetTensorData<TBitpacked>(filter), bconv2d_params->filter_height,
bconv2d_params->filter_width, bconv2d_params->channels_out,
- bconv2d_params->channels_in, bconv2d_params->dilation_height_factor,
+ channels_in_per_group, bconv2d_params->dilation_height_factor,
bconv2d_params->dilation_width_factor,
GetTensorData<float>(post_activation_multiplier),
op_data->padding_buffer.data());
@@ -346,9 +371,8 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
LCE_EXTRA_BYTES / sizeof(float));

const auto filter_shape = GetTensorShape(GetInput(context, node, 1));
- const std::int32_t backtransform_add = filter_shape.Dims(1) *
-                                        filter_shape.Dims(2) *
-                                        bconv2d_params->channels_in;
+ const std::int32_t backtransform_add =
+     filter_shape.Dims(1) * filter_shape.Dims(2) * channels_in_per_group;
const double output_scale =
output->type == kTfLiteInt8 ? output->params.scale : 1.0f;
const double output_zero_point =
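channels_in_per_group also enters the back-transformation. Assuming the accumulator counts XOR-popcounts (as the back-transformation's name suggests), a dot product of ±1 values over n bits is recovered as n - 2 * popcount(x XOR w); with groups, each output channel accumulates over only filter_height * filter_width * channels_in_per_group bits, hence the new backtransform_add. A minimal demonstration of the identity (not LCE code):

```cpp
#include <bitset>
#include <cstdint>
#include <iostream>

int main() {
  // Two length-8 vectors of +1/-1 values, encoded one bit per element
  // (bit set = -1). dot(x, w) = n - 2 * popcount(x XOR w).
  const std::uint8_t x = 0b10110010, w = 0b11010110;
  const int n = 8;
  const int flips = std::bitset<8>(x ^ w).count();  // elements that differ
  std::cout << "dot = " << (n - 2 * flips) << "\n";  // prints "dot = 2"
  return 0;
}
```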
3 changes: 0 additions & 3 deletions larq_compute_engine/tflite/tests/bconv2d_op_model.h
@@ -42,9 +42,6 @@ class BaseBConv2DOpModel : public SingleOpModel {

flexbuffers::Builder fbb;
fbb.Map([&]() {
- // This attribute is necessary because if the filters are bitpacked and
- // we're reading bitpacked input then we don't have access to the original
- // 'true' number of input channels.
fbb.Int("channels_in", channels_in);
fbb.Int("stride_height", stride_height);
fbb.Int("stride_width", stride_width);