Skip to content

Commit

Permalink
CanAccelerateConv4x4: Check output shift
Browse files Browse the repository at this point in the history
Signed-off-by: David Lattimore <dml@google.com>
  • Loading branch information
davidlattimore committed Jul 6, 2022
1 parent ca93804 commit da08bcf
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ inline void ConvPerChannel(
bool accelerated = false;
#ifdef ACCEL_CONV
if (CanAccelerateConv4x4(params, input_shape, filter_shape, output_shape,
bias_data)) {
output_shift, bias_data)) {
ConvPerChannel4x4(params, output_multiplier, output_shift, input_shape,
input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ bool CanAccelerateConv4x4(const ConvParams& params,
const RuntimeShape& input_shape,
const RuntimeShape& filter_shape,
const RuntimeShape& output_shape,
const int32_t* output_shift,
const int32_t* bias_data) {
// No padding allowed
if (params.padding_type != PaddingType::kValid) return false;
Expand Down Expand Up @@ -259,6 +260,15 @@ bool CanAccelerateConv4x4(const ConvParams& params,
if (output_depth % 4 != 0) return false;
}

// RoundingDivideByPowerOfTwo only supports certain shifts. See
// CFU-Playground/proj/hps_accel/gateware/gen2/post_process.py
for (int i = 0; i < output_depth; i++) {
int32_t shift = output_shift[i];
if (shift < -12 || shift > -2) {
return false;
}
}

// Must be 4x4
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ bool CanAccelerateConv4x4(const ConvParams& params,
const RuntimeShape& input_shape,
const RuntimeShape& filter_shape,
const RuntimeShape& output_shape,
const int32_t* output_shift,
const int32_t* bias_data);

// Accelerated version of ConvPerChannel() specialised for:
Expand Down

0 comments on commit da08bcf

Please sign in to comment.