Add grouped binary convolution support (2/3): reference kernel.
Add support for grouped binary convolutions to the reference kernel.
Extend the kernel tests to cover grouped convolutions, emulating them
with multiple separate convolutions since the built-in Conv2D op
doesn't support groups.
AdamHillier committed Oct 22, 2020
1 parent 9c0acf3 commit 9f32b05
Showing 6 changed files with 207 additions and 79 deletions.
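
For context: a grouped convolution splits the input channels into `groups` equal slices, convolves each slice with its own slice of the filters, and concatenates the per-group results along the output-channel axis. A minimal float model of these semantics (a hypothetical sketch, not code from this commit: valid padding, stride 1, NHWC input, OHWI filters, batch omitted):

#include <vector>

// Grouped 2D convolution over input [height][width][channels_in] with
// filters [channels_out][kh][kw][channels_in / groups].
std::vector<float> GroupedConv2D(const std::vector<float>& input,
                                 const std::vector<float>& filter, int height,
                                 int width, int channels_in, int channels_out,
                                 int kh, int kw, int groups) {
  const int cin_per_group = channels_in / groups;
  const int cout_per_group = channels_out / groups;
  const int out_h = height - kh + 1;
  const int out_w = width - kw + 1;
  std::vector<float> output(out_h * out_w * channels_out, 0.0f);
  for (int y = 0; y < out_h; ++y) {
    for (int x = 0; x < out_w; ++x) {
      for (int oc = 0; oc < channels_out; ++oc) {
        // Each output channel belongs to one group and reads only that
        // group's slice of the input channels.
        const int group = oc / cout_per_group;
        float acc = 0.0f;
        for (int fy = 0; fy < kh; ++fy) {
          for (int fx = 0; fx < kw; ++fx) {
            for (int ic = 0; ic < cin_per_group; ++ic) {
              const int in_c = group * cin_per_group + ic;
              acc += input[((y + fy) * width + (x + fx)) * channels_in + in_c] *
                     filter[((oc * kh + fy) * kw + fx) * cin_per_group + ic];
            }
          }
        }
        output[(y * out_w + x) * channels_out + oc] = acc;
      }
    }
  }
  return output;
}

Splitting the input into `groups` channel slices, running an ordinary convolution per slice, and concatenating the outputs gives the identical result; per the commit message, that is how the extended tests emulate groups on top of the built-in Conv2D op.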
2 changes: 1 addition & 1 deletion larq_compute_engine/core/bconv2d/optimized_bgemm.h
@@ -112,7 +112,7 @@ inline void BConv2DOptimizedBGEMM(
   // output tensor with zeroes in advance so that the BGEMM doesn't have to
   // worry about doing the padding.
   if (std::is_same<DstScalar, TBitpacked>::value &&
-      output_shape.Dims(3) % 32 != 0) {
+      output_shape.Dims(3) % bitpacking_bitwidth != 0) {
     std::fill(
         output_data,
         output_data + FlatSizeSkipDim(output_shape, 3) *
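
This change replaces a magic number: the zero-fill is needed whenever the bitpacked output depth is not a whole number of bitpacking words, and spelling the word width as `bitpacking_bitwidth` keeps the check in sync with the packing code. A sketch of what the constant presumably expands to (an assumption; its real definition lives elsewhere in larq_compute_engine/core and is not part of this diff):

#include <cstddef>
#include <cstdint>

// Assumed definitions, for illustration only: LCE bitpacks binary
// activations into 32-bit words, so the bitwidth is 8 * sizeof(TBitpacked).
using TBitpacked = std::int32_t;
constexpr std::size_t bitpacking_bitwidth = 8 * sizeof(TBitpacked);  // 32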
34 changes: 22 additions & 12 deletions larq_compute_engine/core/bconv2d/reference.h
@@ -22,19 +22,22 @@ limitations under the License.
 #define COMPUTE_ENGINE_CORE_BCONV2D_REFERENCE_H_
 
 #include "larq_compute_engine/core/bconv2d/output_transform.h"
+#include "larq_compute_engine/tflite/kernels/bconv2d_params.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 
-using namespace tflite;
-
 namespace compute_engine {
 namespace core {
 namespace bconv2d {
 
+using namespace tflite;
+using compute_engine::tflite::bconv2d::TfLiteBConv2DParams;
+
 template <typename AccumScalar, typename DstScalar,
           OutputTransformDetails details>
 inline void BConv2DReference(
-    const ConvParams& params, const RuntimeShape& packed_input_shape,
+    const TfLiteBConv2DParams* params, const RuntimeShape& packed_input_shape,
     const TBitpacked* packed_input_data,
     const RuntimeShape& packed_filter_shape,
     const TBitpacked* packed_filter_data,
@@ -47,27 +50,31 @@ inline void BConv2DReference(
       "The reference implementation supports either float "
       "output, bitpacked output or 8-bit quantized output.");
 
-  const int stride_width = params.stride_width;
-  const int stride_height = params.stride_height;
-  const int dilation_width_factor = params.dilation_width_factor;
-  const int dilation_height_factor = params.dilation_height_factor;
-  const int pad_width = params.padding_values.width;
-  const int pad_height = params.padding_values.height;
+  const int stride_width = params->stride_width;
+  const int stride_height = params->stride_height;
+  const int dilation_width_factor = params->dilation_width_factor;
+  const int dilation_height_factor = params->dilation_height_factor;
+  const int pad_width = params->padding_values.width;
+  const int pad_height = params->padding_values.height;
 
   TFLITE_DCHECK_EQ(packed_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(packed_filter_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
 
   const int batches = MatchingDim(packed_input_shape, 0, output_shape, 0);
-  const int input_depth =
-      MatchingDim(packed_input_shape, 3, packed_filter_shape, 3);
+  const int input_depth_per_group = packed_filter_shape.Dims(3);
   const int output_depth = packed_filter_shape.Dims(0);
+  const int output_depth_per_group = output_depth / params->groups;
   const int input_height = packed_input_shape.Dims(1);
   const int input_width = packed_input_shape.Dims(2);
   const int filter_height = packed_filter_shape.Dims(1);
   const int filter_width = packed_filter_shape.Dims(2);
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
 
+  TFLITE_DCHECK_EQ(input_depth_per_group * params->groups,
+                   packed_input_shape.Dims(3));
+
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
       for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -76,10 +83,12 @@
         for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
           const int in_x_origin = (out_x * stride_width) - pad_width;
           const int in_y_origin = (out_y * stride_height) - pad_height;
+          const int group = out_channel / output_depth_per_group;
           AccumScalar accum = AccumScalar(0);
           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
-              for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+              for (int in_channel = 0; in_channel < input_depth_per_group;
+                   ++in_channel) {
                 const int in_x = in_x_origin + dilation_width_factor * filter_x;
                 const int in_y =
                     in_y_origin + dilation_height_factor * filter_y;
@@ -89,7 +98,8 @@
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                     (in_y < input_height)) {
                   input_value = packed_input_data[Offset(
-                      packed_input_shape, batch, in_y, in_x, in_channel)];
+                      packed_input_shape, batch, in_y, in_x,
+                      group * input_depth_per_group + in_channel)];
                 }
                 TBitpacked filter_value =
                     packed_filter_data[Offset(packed_filter_shape, out_channel,
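
The grouping in the kernel above reduces to two small index computations: which group an output channel belongs to, and where that group's input slice starts. A worked example with hypothetical sizes (not taken from this commit):

// Say groups = 2, with 64 output channels and a bitpacked input depth of
// 8 TBitpacked words, so each group owns 4 input words.
const int groups = 2;
const int output_depth = 64;            // packed_filter_shape.Dims(0)
const int input_depth_per_group = 4;    // packed_filter_shape.Dims(3)
const int output_depth_per_group = output_depth / groups;  // 32

// Output channel 40 falls in group 1, so its filters are applied against
// input words 4..7: the in_channel loop reads
// group * input_depth_per_group + in_channel for in_channel in [0, 4).
const int out_channel = 40;
const int group = out_channel / output_depth_per_group;      // 40 / 32 == 1
const int first_input_word = group * input_depth_per_group;  // 4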
44 changes: 30 additions & 14 deletions larq_compute_engine/tflite/kernels/bconv2d.cc
@@ -53,7 +53,7 @@ void* Init(TfLiteContext* context, const char* buffer, std::size_t length) {
   const std::uint8_t* buffer_t = reinterpret_cast<const std::uint8_t*>(buffer);
   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
 
-  // Read the op's input arguments into the "conv_params" struct
+  // Read the op's input arguments into the `conv_params` struct
 
   LCE_ENSURE_PARAM(conv_params, context, !m["stride_height"].IsNull());
   LCE_ENSURE_PARAM(conv_params, context, !m["stride_width"].IsNull());
@@ -83,6 +83,12 @@ void* Init(TfLiteContext* context, const char* buffer, std::size_t length) {
   // attribute added to the op in the converter.
   conv_params->channels_in = m["channels_in"].AsInt32();
 
+  if (!m["groups"].IsNull()) {
+    conv_params->groups = m["groups"].AsInt32();
+  } else {
+    conv_params->groups = 1;
+  }
+
   conv_params->fused_activation_function = ConvertActivation(
       (ActivationFunctionType)m["fused_activation_function"].AsInt32());
   if (conv_params->padding_type == kTfLitePaddingSame &&
@@ -136,6 +142,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   conv_params->filter_height = SizeOfDimension(filter, 1);
   conv_params->filter_width = SizeOfDimension(filter, 2);
 
+  const std::int32_t groups = conv_params->groups;
+  TF_LITE_ENSURE_EQ(context, conv_params->channels_in % groups, 0);
+  TF_LITE_ENSURE_EQ(context, conv_params->channels_out % groups, 0);
+  if (groups > 1) {
+    TF_LITE_ENSURE_EQ(
+        context,
+        (conv_params->channels_in / groups) % core::bitpacking_bitwidth, 0);
+    TF_LITE_ENSURE_MSG(
+        context, kernel_type == KernelType::kReference,
+        "Grouped binary convolutions are not supported with this kernel.");
+  }
+
   // Compute the padding and output values (height, width)
   int out_width, out_height;
   conv_params->padding_values = ComputePaddingHeightWidth(
@@ -232,10 +250,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
 
   // Resize the im2col tensor
-  int channels_in = GetBitpackedSize(conv_params->channels_in);
+  int bitpacked_channels_in = GetBitpackedSize(conv_params->channels_in);
   TfLiteIntArray* im2col_size = TfLiteIntArrayCopy(output_shape);
-  im2col_size->data[3] =
-      channels_in * conv_params->filter_height * conv_params->filter_width;
+  im2col_size->data[3] = bitpacked_channels_in * conv_params->filter_height *
+                         conv_params->filter_width;
   TfLiteTensor* im2col =
       GetTemporary(context, node, conv_params->im2col_index);
   im2col->type = kTfLiteInt32;
@@ -277,6 +295,11 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node,
   const auto* post_activation_bias = GetInput(context, node, 3);
   const auto* output = GetOutput(context, node, 0);
 
+  // Division is safe because at this point we know that channels_in is a
+  // multiple of the number of groups.
+  const std::int32_t channels_in_per_group =
+      params->channels_in / params->groups;
+
   // For 'same-zero' padding, compute the padding-correction.
   if (params->padding_type == kTfLitePaddingSame && params->pad_value == 0) {
     params->padding_buffer.resize(
@@ -285,7 +308,7 @@
         params->dilation_height_factor, params->dilation_width_factor));
     core::bconv2d::zero_padding_correction::CacheCorrectionValues(
         GetTensorData<TBitpacked>(filter), params->filter_height,
-        params->filter_width, params->channels_out, params->channels_in,
+        params->filter_width, params->channels_out, channels_in_per_group,
         params->dilation_height_factor, params->dilation_width_factor,
         GetTensorData<float>(post_activation_multiplier),
         params->padding_buffer.data());
@@ -301,7 +324,7 @@
 
   const auto filter_shape = GetTensorShape(GetInput(context, node, 1));
   const std::int32_t backtransform_add =
-      filter_shape.Dims(1) * filter_shape.Dims(2) * params->channels_in;
+      filter_shape.Dims(1) * filter_shape.Dims(2) * channels_in_per_group;
   const double output_scale =
       output->type == kTfLiteInt8 ? output->params.scale : 1.0f;
   const double output_zero_point =
@@ -427,18 +450,11 @@ void EvalRef(TfLiteContext* context, TfLiteNode* node,
   const auto* packed_filter = GetInput(context, node, 1);
   auto* output = GetOutput(context, node, 0);
 
-  // Using the standard TF Lite ConvParams struct.
-  // This requires extra step of converting the TfLiteBConv2DParams
-  // but unifies the interface with the default TF lite API for CONV params
-  // which is used in internal TF lite im2col functions.
-  ConvParams op_params;
-  GetConvParamsType(*params, op_params);
-
   OutputTransform<DstScalar> output_transform;
   GetOutputTransform(output_transform, context, node, params);
 
   core::bconv2d::BConv2DReference<std::int32_t, DstScalar>(
-      op_params, GetTensorShape(input), GetTensorData<TBitpacked>(input),
+      params, GetTensorShape(input), GetTensorData<TBitpacked>(input),
       GetTensorShape(packed_filter), GetTensorData<TBitpacked>(packed_filter),
       output_transform, GetTensorShape(output),
       GetTensorData<DstScalar>(output), params->pad_value);
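
A note on the new `Prepare` checks above: requiring `(channels_in / groups) % bitpacking_bitwidth == 0` means every group's slice of input channels starts and ends exactly on a bitpacked word boundary, which is what lets the reference kernel index whole `TBitpacked` words per group. A sketch of the invariant, assuming the 32-bit word size the code implies:

// channels_in = 128, groups = 2: 64 channels per group == 2 whole 32-bit
// words, so group g reads words [2 * g, 2 * g + 2) -- accepted.
// channels_in = 96, groups = 2: 48 channels per group == 1.5 words, so a
// group boundary would fall mid-word -- rejected by Prepare.
constexpr int kBitpackingBitwidth = 32;  // assumed value of core::bitpacking_bitwidth
inline bool GroupsAreWordAligned(int channels_in, int groups) {
  return channels_in % groups == 0 &&
         (channels_in / groups) % kBitpackingBitwidth == 0;
}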
1 change: 1 addition & 0 deletions larq_compute_engine/tflite/kernels/bconv2d_params.h
@@ -20,6 +20,7 @@ struct TfLiteBConv2DParams {
   std::int32_t filter_height{0};
   std::int32_t channels_in{0};
   std::int32_t channels_out{0};
+  std::int32_t groups{1};
 
   // Strides
   std::int32_t stride_height{0};
6 changes: 2 additions & 4 deletions larq_compute_engine/tflite/tests/bconv2d_op_model.h
@@ -28,7 +28,7 @@ class BaseBConv2DOpModel : public SingleOpModel {
       const TensorData& filter, const TensorData& output,
       const TensorData& post_activation_multiplier,
       const TensorData& post_activation_bias, const TensorData& thresholds,
-      int channels_in, int stride_width = 1, int stride_height = 1,
+      int channels_in, int groups, int stride_width = 1, int stride_height = 1,
       enum Padding padding = Padding_VALID, int pad_values = 0,
       enum ActivationFunctionType activation = ActivationFunctionType_NONE,
       int dilation_width_factor = 1, int dilation_height_factor = 1,
@@ -42,10 +42,8 @@ class BaseBConv2DOpModel : public SingleOpModel {
 
     flexbuffers::Builder fbb;
     fbb.Map([&]() {
-      // This attribute is necessary because if the filters are bitpacked and
-      // we're reading bitpacked input then we don't have access to the original
-      // 'true' number of input channels.
       fbb.Int("channels_in", channels_in);
+      fbb.Int("groups", groups);
      fbb.Int("stride_height", stride_height);
       fbb.Int("stride_width", stride_width);
       fbb.Int("dilation_height_factor", dilation_height_factor);