Skip to content

Commit

Permalink
Fixed BuildOpInfoWithoutDevice
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 165653933
  • Loading branch information
benoitsteiner authored and tensorflower-gardener committed Aug 18, 2017
1 parent d7e425f commit 513def0
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 18 deletions.
20 changes: 19 additions & 1 deletion tensorflow/core/grappler/costs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ tf_cuda_library(
],
)

cc_test(
name = "utils_test",
srcs = ["utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
":utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

cc_library(
name = "cost_estimator",
hdrs = ["cost_estimator.h"],
Expand Down Expand Up @@ -170,7 +188,7 @@ cc_test(
srcs = ["virtual_placer_test.cc"],
deps = [
":virtual_placer",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
Expand Down
25 changes: 8 additions & 17 deletions tensorflow/core/grappler/costs/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
return tensors;
}

// Annotate the op_info inputs with extra information when possible (e.g. the
// input value if it's known statically).
static void ExtractExtraProperties(
const NodeDef& node,
const std::unordered_map<string, const NodeDef*>& name_to_node,
std::vector<OpInfo::TensorProperties>* extra_inputs,
protobuf::Map<string, AttrValue>* attr_map) {
OpInfo* op_info) {
OpRegistry* op_registry = OpRegistry::Global();
const OpDef* op_def = nullptr;
auto s = op_registry->LookUpOpDef(node.op(), &op_def);
Expand Down Expand Up @@ -102,11 +103,8 @@ static void ExtractExtraProperties(
if (tensors.empty()) continue;

const TensorProto& t = tensors[0];
OpInfo::TensorProperties input;
input.set_dtype(t.dtype());
*(input.mutable_shape()) = t.tensor_shape();
*(input.mutable_value()) = t;
extra_inputs->push_back(input);
OpInfo::TensorProperties* input = op_info->mutable_inputs(i);
*(input->mutable_value()) = t;

// For filename input, the file size can also be useful.
if (op_def && i < op_def->input_arg_size() &&
Expand All @@ -129,7 +127,7 @@ static void ExtractExtraProperties(
AttrValue attr;
attr.set_i(stat.length);
string attr_key = strings::StrCat("input_", i, "_filesize");
(*attr_map)[attr_key] = attr;
(*op_info->mutable_attr())[attr_key] = attr;
}
}

Expand All @@ -140,7 +138,7 @@ static void ExtractExtraProperties(
string new_key = strings::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
(*attr_map)[new_key] = attr;
(*op_info->mutable_attr())[new_key] = attr;
// TODO(yuefengz): Only parent node's op name is copied. Copy inputs
// and attributes when necessary.
}
Expand Down Expand Up @@ -212,14 +210,7 @@ OpInfo BuildOpInfoWithoutDevice(
for (auto& input : inputs) {
*op_info.add_inputs() = input;
}

std::vector<OpInfo::TensorProperties> extra_inputs;
ExtractExtraProperties(node, name_to_node, &extra_inputs,
op_info.mutable_attr());
for (auto& input : extra_inputs) {
*op_info.add_inputs() = input;
}

ExtractExtraProperties(node, name_to_node, &op_info);
return op_info;
}

Expand Down
150 changes: 150 additions & 0 deletions tensorflow/core/grappler/costs/utils_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {

class UtilsTest : public ::testing::Test {
public:
void CreateConstOp(const string& name, std::initializer_list<int64> dims,
NodeDef* node) {
Tensor tensor(DT_FLOAT, TensorShape(dims));
for (int64 i = 0; i < tensor.NumElements(); ++i) {
tensor.flat<float>()(i) = i / 10.0f;
}
TF_CHECK_OK(NodeDefBuilder(name, "Const")
.Attr("dtype", DT_FLOAT)
.Attr("value", tensor)
.Finalize(node));
}

void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
NodeDef* node) {
TensorShape shape;
shape.AddDim(sizes.size());
Tensor tensor(DT_INT32, shape);
for (int64 i = 0; i < tensor.NumElements(); ++i) {
tensor.flat<int32>()(i) = sizes[i];
}
TF_CHECK_OK(NodeDefBuilder(name, "Const")
.Attr("dtype", DT_INT32)
.Attr("value", tensor)
.Finalize(node));
}
};

TEST_F(UtilsTest, ConvOpInfo) {
int batch = 32;
int rows = 7;
int cols = 9;
int filter_rows = 3;
int filter_cols = 3;
int out_rows = 7;
int out_cols = 9;
int in_depth = 3;
int out_depth = 5;
int stride = 1;

std::unordered_map<string, const NodeDef*> name_to_node;
GraphDef graph;
NodeDef* input = graph.add_node();
name_to_node["input"] = input;
CreateConstOp("input", {batch, rows, cols, in_depth}, input);
NodeDef* filter = graph.add_node();
name_to_node["filter"] = filter;
CreateConstOp("filter", {filter_rows, filter_cols, in_depth, out_depth},
filter);
NodeDef* output_backprop = graph.add_node();
name_to_node["output_backprop"] = output_backprop;
CreateConstOp("output_backprop", {batch, out_rows, out_cols, out_depth},
output_backprop);
NodeDef* input_sizes = graph.add_node();
name_to_node["input_sizes"] = input;
CreateConstSizesOp("input_sizes",
std::vector<int32>({batch, rows, cols, in_depth}),
input_sizes);
NodeDef* filter_sizes = graph.add_node();
name_to_node["filter_sizes"] = filter_sizes;
CreateConstSizesOp(
"filter_sizes",
std::vector<int32>({filter_rows, filter_cols, in_depth, out_depth}),
filter_sizes);

TensorShape paddings_shape({4, 2});
Tensor paddings_tensor(DT_INT32, paddings_shape);
for (int64 i = 0; i < paddings_tensor.NumElements(); ++i) {
paddings_tensor.flat<int32>()(i) = 0;
}
TF_CHECK_OK(NodeDefBuilder("paddings", "Const")
.Attr("dtype", DT_INT32)
.Attr("value", paddings_tensor)
.Finalize(graph.add_node()));

// Now add the convolution op
NodeDef* conv = graph.add_node();
TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2D")
.Input("input", 0, DT_FLOAT)
.Input("filter", 0, DT_FLOAT)
.Attr("strides", {1, stride, stride, 1})
.Attr("padding", "SAME")
.Finalize(conv));

NodeDef* conv_bp_in = graph.add_node();
TF_CHECK_OK(NodeDefBuilder("conv2d_bp_in", "Conv2DBackpropInput")
.Input("input_sizes", 0, DT_INT32)
.Input("filter", 0, DT_FLOAT)
.Input("output_backprop", 0, DT_FLOAT)
.Attr("strides", {1, stride, stride, 1})
.Attr("padding", "SAME")
.Finalize(conv_bp_in));

NodeDef* conv_bp_filter = graph.add_node();
TF_CHECK_OK(NodeDefBuilder("conv2d_bp_filter", "Conv2DBackpropFilter")
.Input("input", 0, DT_FLOAT)
.Input("filter_sizes", 0, DT_INT32)
.Input("output_backprop", 0, DT_FLOAT)
.Attr("strides", {1, stride, stride, 1})
.Attr("padding", "SAME")
.Finalize(conv_bp_filter));

for (const auto& node : graph.node()) {
if (node.name().find("conv2d") != 0) {
continue;
}
std::vector<OpInfo::TensorProperties> inputs;
inputs.resize(node.input_size());
OpInfo info = BuildOpInfoWithoutDevice(node, name_to_node, inputs);
if (node.name() == "conv2d") {
EXPECT_EQ(2, info.inputs_size());
} else if (node.name() == "conv2dbp_in") {
EXPECT_EQ(3, info.inputs_size());
} else if (node.name() == "conv2d_bp_filter") {
EXPECT_EQ(3, info.inputs_size());
}
}
}

} // end namespace grappler
} // end namespace tensorflow

0 comments on commit 513def0

Please sign in to comment.