Skip to content

Commit

Permalink
Create a partial key for output_scale to improve key creation performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ShengYang1 committed Dec 20, 2019
1 parent 85f4643 commit 5024be7
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions tensorflow/core/kernels/mkl_conv_ops.cc
Expand Up @@ -24,8 +24,8 @@ limitations under the License.
#include <map>
#include <vector>

#include "mkldnn.hpp"
#include "absl/strings/str_join.h"
#include "mkldnn.hpp"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
Expand Down Expand Up @@ -156,6 +156,7 @@ struct MklConvFwdParams {
string name;
mkldnn::algorithm alg;
std::vector<float> param;
std::string param_key;
};
std::vector<PostOpParam> post_op_params;

Expand Down Expand Up @@ -488,17 +489,22 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<float> {

// Generate keys for post-ops
for (auto const& post_op_param : convFwdDims.post_op_params) {
key_creator.AddAsKey(post_op_param.name);
if (post_op_param.name == "activation") {
DCHECK_EQ(post_op_param.param.size(), 3);
for (auto& param : post_op_param.param) {
key_creator.AddAsKey(param);
}
} else if (post_op_param.name == "sum") {
DCHECK_EQ(post_op_param.param.size(), 1);
for (auto& param : post_op_param.param) {
key_creator.AddAsKey(param);
}
} else if (post_op_param.name == "output_scale") {
key_creator.AddAsKey(post_op_param.param_key);
} else if (post_op_param.name != "output_scale") {
return string("not_a_key");
}
key_creator.AddAsKey(post_op_param.name);
for (auto& param : post_op_param.param) {
key_creator.AddAsKey(param);
}
}

return key_creator.GetKey();
Expand Down Expand Up @@ -570,17 +576,15 @@ class MklConvOp : public OpKernel {
OP_REQUIRES(context, dilations_.size() == 5,
errors::InvalidArgument("Dilation rates field must "
"specify 5 dimensions"));
OP_REQUIRES(context,
(GetTensorDim(dilations_, data_format_, 'N') == 1 &&
GetTensorDim(dilations_, data_format_, 'C') == 1),
OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
GetTensorDim(dilations_, data_format_, 'C') == 1),
errors::InvalidArgument(
"Current implementation does not yet support "
"dilations rates in the batch and depth dimensions."));
OP_REQUIRES(
context,
(GetTensorDim(dilations_, data_format_, '0') > 0 &&
GetTensorDim(dilations_, data_format_, '1') > 0 &&
GetTensorDim(dilations_, data_format_, '2') > 0),
context, (GetTensorDim(dilations_, data_format_, '0') > 0 &&
GetTensorDim(dilations_, data_format_, '1') > 0 &&
GetTensorDim(dilations_, data_format_, '2') > 0),
errors::InvalidArgument("Dilated rates should be larger than 0."));
}
}
Expand Down Expand Up @@ -946,11 +950,11 @@ class MklConvOp : public OpKernel {
// NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
// checking `fuse_biasadd_` flag.
if (fuse_add_) {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
}
if (fuse_activation_) {
params.post_op_params.push_back(
{"activation", activation_alg_, {1.0, relu_up_bound_, 0.0}});
{"activation", activation_alg_, {1.0, relu_up_bound_, 0.0}, ""});
}
}

Expand Down Expand Up @@ -1564,8 +1568,19 @@ class MklQuantizedConv2DOp
scales[i] = int_output_limit * float_input_range * float_filter_range /
(int_const_scale_limit * float_output_range);
}
// We create a partial key here, for use with primitive key caching, to
// improve key creation performance. Instead of the actual values, the key
// uses the addresses of min_filter_vector and max_filter_vector; this works
// because the filter vectors here are constant.
FactoryKeyCreator param_key;
param_key.AddAsKey<float>(min_input);
param_key.AddAsKey<float>(max_input);
param_key.AddAsKey<float>(min_freezed_output);
param_key.AddAsKey<float>(max_freezed_output);
param_key.AddAsKey<const Tensor*>(&min_filter_vector);
param_key.AddAsKey<const Tensor*>(&max_filter_vector);
params.post_op_params.push_back(
{"output_scale", ALGORITHM_UNDEF, scales});
{"output_scale", ALGORITHM_UNDEF, scales, param_key.GetKey()});
}
}

Expand Down Expand Up @@ -1745,7 +1760,7 @@ class MklQuantizedConv2DReluOp
bias_enabled,
is_depthwise>::ExtendConvFwdParams(context, params);
params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
}
};

Expand Down Expand Up @@ -1793,17 +1808,18 @@ class MklQuantizedConv2DSumReluOp
// If it is not then it is DT_INT8 and is scaled appropriately.
if (summand_type == DT_QUINT8)
params.post_op_params.push_back(
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}});
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}, ""});
else
params.post_op_params.push_back(
{"sum",
ALGORITHM_UNDEF,
{255.0f * scale_summand / (scale_output * 127.0f)}});
{255.0f * scale_summand / (scale_output * 127.0f)},
""});
} else {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
}
params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
}

void AllocateOutputTensor(OpKernelContext* context,
Expand Down

0 comments on commit 5024be7

Please sign in to comment.