Allocate persistent buffer over scratch buffer for CMSIS-NN kernels
janjongboom committed Aug 30, 2020
1 parent addbc56 commit 5986dd0
Showing 4 changed files with 36 additions and 58 deletions.
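The same change is applied in all four kernels: instead of requesting a scratch buffer from the arena in Prepare and resolving it by index at eval time, each kernel now allocates a persistent buffer once in Prepare and caches the raw pointer in its OpData. The condensed before/after sketch below is for orientation only and assumes the TfLiteContext callbacks used in the hunks that follow; the real functions also take TfLiteNode* and derive buf_size from the relevant CMSIS-NN get_buffer_size helper.

// Sketch of the pattern changed in this commit, not a drop-in kernel.
struct OpData {
  // Before: int buffer_idx;  // index into the arena's scratch buffers, -1 if unused
  // After:
  void* scratch_buffer;       // pointer from AllocatePersistentBuffer, nullptr if unused
};

TfLiteStatus Prepare(TfLiteContext* context, OpData* data, int32_t buf_size) {
  // Before:
  //   TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
  //       context, buf_size, &data->buffer_idx));
  // After: allocate once from the persistent arena area and keep the pointer.
  if (buf_size > 0) {
    data->scratch_buffer = context->AllocatePersistentBuffer(context, buf_size);
  } else {
    data->scratch_buffer = nullptr;
  }
  return kTfLiteOk;
}

void Eval(TfLiteContext* context, OpData* data) {
  cmsis_nn_context ctx;
  // Before:
  //   ctx.buf = nullptr;
  //   if (data->buffer_idx > -1) {
  //     ctx.buf = context->GetScratchBuffer(context, data->buffer_idx);
  //   }
  // After: hand the cached pointer straight to CMSIS-NN.
  ctx.buf = data->scratch_buffer;
  ctx.size = 0;  // ctx.size is currently unused by cmsis-nn
  // ... invoke the corresponding arm_*_s8 kernel with &ctx ...
}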
22 changes: 9 additions & 13 deletions tensorflow/lite/micro/kernels/cmsis-nn/conv.cc
@@ -63,8 +63,8 @@ struct OpData {
int32_t output_activation_min;
int32_t output_activation_max;

-  // Index to buffer for optimizations if applicable.
-  int buffer_idx;
+  // Scratch buffer for optimizations if applicable.
+  void* scratch_buffer;
};

inline PaddingType RuntimePaddingType(TfLitePadding padding) {
@@ -196,10 +196,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

if (buf_size > 0) {
-    TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
-        context, buf_size, &data->buffer_idx));
+    data->scratch_buffer = context->AllocatePersistentBuffer(
+        context, buf_size);
} else {
-    data->buffer_idx = -1;
+    data->scratch_buffer = nullptr;
}
#endif
return kTfLiteOk;
@@ -318,16 +318,12 @@ TfLiteStatus EvalQuantizedPerChannel(

// Initialize cmsis-nn context
cmsis_nn_context ctx;
-  ctx.buf = nullptr;
+  ctx.buf = data.scratch_buffer;
+  // Note: ctx.size is currently not used in cmsis-nn.
+  // The buffer should be allocated in the Prepare function through
+  // arm_convolve_wrapper_s8_get_buffer_size
ctx.size = 0;

-  if (data.buffer_idx > -1) {
-    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
-    // Note: ctx.size is currently not used in cmsis-nn.
-    // The buffer should be allocated in the Prepare function through
-    // arm_convolve_wrapper_s8_get_buffer_size
-  }

// arm_convolve_wrapper_s8 dispatches the optimized kernel accordingly with
// the parameters passed
arm_status status = arm_convolve_wrapper_s8(
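In conv.cc the buf_size passed to AllocatePersistentBuffer is meant to come from the CMSIS-NN wrapper's buffer-size query, as the comment in the hunk above notes. A minimal sketch of that Prepare-side step follows; the exact arm_convolve_wrapper_s8_get_buffer_size parameter list is an assumption from the CMSIS-NN headers of that era, and the cmsis_nn_conv_params / cmsis_nn_dims structs are assumed to be populated from the op's tensors earlier in Prepare.

// Sketch: sizing the persistent buffer before allocating it in Prepare.
const int32_t buf_size = arm_convolve_wrapper_s8_get_buffer_size(
    &conv_params, &input_dims, &filter_dims, &output_dims);
if (buf_size > 0) {
  data->scratch_buffer = context->AllocatePersistentBuffer(context, buf_size);
} else {
  data->scratch_buffer = nullptr;
}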
24 changes: 8 additions & 16 deletions tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc
@@ -63,8 +63,9 @@ struct OpData {
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
-  // Index to buffer for optimizations if applicable.
-  int buffer_idx;
+
+  // Scratch buffer for optimizations if applicable.
+  void* scratch_buffer;
};

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
@@ -77,8 +78,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

int unused_output_height, unused_output_width;
-  // Set buffer index to a reset value
-  data->buffer_idx = -1;
+  data->scratch_buffer = nullptr;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width, 1, 1, height, width,
filter_height, filter_width, params->padding, &unused_output_height,
@@ -130,10 +130,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int filter_height = SizeOfDimension(filter, 1);

if (input->type == kTfLiteInt8) {
-    // Allocate memory for per-channel quantization parameters
-    const int num_channels =
-        filter->dims->data[kDepthwiseConvQuantizedDimension];
-
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
kTfLiteAffineQuantization);

@@ -210,10 +206,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
&dw_conv_params, &input_dims, &filter_dims, &output_dims);

if (buf_size > 0) {
-      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
-          context, buf_size, &data->buffer_idx));
+      data->scratch_buffer = context->AllocatePersistentBuffer(
+          context, buf_size);
} else {
-      data->buffer_idx = -1;
+      data->scratch_buffer = nullptr;
}
}
return kTfLiteOk;
@@ -318,14 +314,10 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
output_dims.c = output_depth;

cmsis_nn_context ctx;
-  ctx.buf = nullptr;
+  ctx.buf = reinterpret_cast<int16_t*>(data->scratch_buffer);
/* 'size' is unused */
ctx.size = 0;

-  if (data->buffer_idx > -1) {
-    ctx.buf = context->GetScratchBuffer(context, data->buffer_idx);
-  }

TFLITE_DCHECK_EQ(
arm_depthwise_conv_wrapper_s8(
&ctx, &dw_conv_params, &quant_params, &input_dims,
20 changes: 8 additions & 12 deletions tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
@@ -42,8 +42,8 @@ struct OpData {
int32_t output_activation_max;
// The index of the temporary tensor where the quantized inputs are cached.
int input_quantized_index;
-  // Index to buffer for optimizations if applicable.
-  int buffer_idx;
+  // Scratch buffer for optimizations if applicable.
+  void* scratch_buffer;

// Cached tensor zero point values for quantized operations.
int32_t input_zero_point;
@@ -63,8 +63,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
const TfLiteTensor* bias, TfLiteTensor* output,
OpData* data) {
TfLiteStatus status = kTfLiteOk;
-  // Set buffer index to a reset value
-  data->buffer_idx = -1;
+  // Set scratch buffer to empty
+  data->scratch_buffer = nullptr;
if (data_type != kTfLiteFloat32) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
@@ -124,10 +124,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
arm_fully_connected_s8_get_buffer_size(&filter_dims);

if (buf_size > 0) {
-      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
-          context, buf_size, &data->buffer_idx));
+      data->scratch_buffer = context->AllocatePersistentBuffer(
+          context, buf_size);
} else {
-      data->buffer_idx = -1;
+      data->scratch_buffer = nullptr;
}
}
return kTfLiteOk;
@@ -187,13 +187,9 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
output_dims.c = output_depth;

cmsis_nn_context ctx;
-  ctx.buf = nullptr;
+  ctx.buf = reinterpret_cast<int16_t*>(data.scratch_buffer);
ctx.size = 0;

-  if (data.buffer_idx > -1) {
-    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
-  }

TF_LITE_ENSURE_EQ(
context,
arm_fully_connected_s8(
28 changes: 11 additions & 17 deletions tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc
@@ -35,8 +35,8 @@ constexpr int kOutputTensor = 0;

struct OpData {
TfLitePaddingValues padding;
-  // Index to buffer for optimizations if applicable.
-  int buffer_idx;
+  // Buffer for optimizations if applicable.
+  void* scratch_buffer;

int32_t activation_min;
int32_t activation_max;
@@ -65,8 +65,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
TFLITE_DCHECK_LE(data->activation_min, data->activation_max);
}

-  // Set buffer index to a reset value
-  data->buffer_idx = -1;
+  // Set scratch buffer to empty
+  data->scratch_buffer = nullptr;

return kTfLiteOk;
}
@@ -150,11 +150,8 @@ void AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
filter_dims.c = 1;

cmsis_nn_context ctx;
-  ctx.buf = nullptr;
+  ctx.buf = static_cast<int16_t*>(data.scratch_buffer);
ctx.size = 0;
-  if (data.buffer_idx > -1) {
-    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
-  }

TFLITE_DCHECK_EQ(
arm_avgpool_s8(&ctx, &pool_params, &input_dims,
@@ -240,11 +237,8 @@ TfLiteStatus MaxEvalInt8(TfLiteContext* context, const TfLiteNode* node,
filter_dims.c = 1;

cmsis_nn_context ctx;
-  ctx.buf = nullptr;
+  ctx.buf = static_cast<int16_t*>(data.scratch_buffer);
ctx.size = 0;
-  if (data.buffer_idx > -1) {
-    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
-  }

TFLITE_DCHECK_EQ(
arm_max_pool_s8(&ctx, &pool_params, &input_dims,
@@ -300,14 +294,14 @@ TfLiteStatus AveragePrepare(TfLiteContext* context, TfLiteNode* node) {
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int output_width = output_shape.Dims(2);

-    const int32_t buffer_size =
+    const int32_t buf_size =
arm_avgpool_s8_get_buffer_size(output_width, depth);

-    if (buffer_size > 0) {
-      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
-          context, buffer_size, &data->buffer_idx));
+    if (buf_size > 0) {
+      data->scratch_buffer = context->AllocatePersistentBuffer(
+          context, buf_size);
} else {
-      data->buffer_idx = -1;
+      data->scratch_buffer = nullptr;
}
}
return kTfLiteOk;
