Skip to content

Commit

Permalink
[Grappler]
Browse files Browse the repository at this point in the history
1) Skip dead branch elimination for merge nodes with control inputs, since these can create cycles in the resulting optimized graph.
2) Optimize a few utility functions.
3) Add more verbose VLOGging when topological sorting fails.

PiperOrigin-RevId: 284871268
Change-Id: I36435402d826e4737b709468d88641d7a7fa2a83
  • Loading branch information
tensorflower-gardener committed Dec 11, 2019
1 parent 77b30d9 commit 74229d4
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 63 deletions.
76 changes: 43 additions & 33 deletions tensorflow/core/grappler/optimizers/loop_optimizer.cc
Expand Up @@ -627,7 +627,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
return Status::OK();
}

VLOG(3) << "Try to find a zero iteration while loop:"
VLOG(4) << "Try to find a zero iteration while loop:"
<< " switch_node=" << switch_node.name();

// Find the boolean predicate from a LoopCond node (e.g. Greater).
Expand Down Expand Up @@ -704,7 +704,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
&constant_switch_value));

if (constant_switch_value == false) {
VLOG(4) << "Remove 0 iteration while loop:"
VLOG(3) << "Remove 0 iteration while loop:"
<< " switch_node=" << switch_node.name();
*has_dead_fanout = true;
*dead_fanout = 1;
Expand Down Expand Up @@ -746,8 +746,6 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
}
if (options_.enable_dead_branch_removal) {
// TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
// optimizer passes.
NodeMap node_map(optimized_graph);
absl::flat_hash_set<string> feed_nodes;
for (const auto& feed : item.feed) {
Expand Down Expand Up @@ -890,43 +888,55 @@ Status LoopOptimizer::RemoveDeadBranches(
// Names of the nodes that were removed from the graph.
absl::flat_hash_set<absl::string_view> dead_node_names;
dead_node_names.reserve(dead_nodes.size());
for (const NodeDef* dead_node : dead_nodes)
for (const NodeDef* dead_node : dead_nodes) {
dead_node_names.insert(dead_node->name());
}

// Remove dead inputs from Merge nodes that were not pruned from the graph.
// Check that the merge nodes are valid.
for (const auto& itr : dead_merge_inputs) {
NodeDef* dead_node = itr.first;
if (dead_nodes.find(dead_node) != dead_nodes.end()) {
// The node has been pruned since all its inputs are dead.
NodeDef* merge_node = itr.first;
if (dead_nodes.find(merge_node) != dead_nodes.end()) {
// The node will be pruned since all its inputs are dead.
continue;
}
// Remove dead data input.
const std::set<int>& dead_inputs = itr.second;
CHECK_LE(dead_inputs.size(), 1);
// (This loop would delete >1 items possibly in the wrong order.)
for (int index : dead_inputs) {
dead_node->mutable_input()->DeleteSubrange(index, 1);
}
// Turn Merge into Identity only if we deleted the other data input.
if (!dead_inputs.empty()) {
const int num_data_inputs = dead_node->attr().at("N").i();
CHECK_EQ(num_data_inputs, dead_inputs.size() + 1);
dead_node->set_op("Identity");
dead_node->mutable_attr()->erase("N");
}
// Remove control inputs from dead nodes.
int pos = 0;
while (pos < dead_node->input_size()) {
TensorId tensor = ParseTensorName(dead_node->input(pos));
if (tensor.index() == Graph::kControlSlot &&
dead_node_names.contains(tensor.node())) {
auto* inputs = dead_node->mutable_input();
inputs->SwapElements(pos, dead_node->input_size() - 1);
inputs->RemoveLast();
} else {
++pos;
}
const int num_data_inputs = merge_node->attr().at("N").i();
if (merge_node->input_size() != num_data_inputs) {
LOG(WARNING)
<< "Skipping loop optimization for Merge node with control input: "
<< merge_node->name();
return Status::OK();
} else if (dead_inputs.size() != 1 || num_data_inputs != 2) {
LOG(WARNING) << "Skipping loop optimization for Merge node ("
<< merge_node->name()
<< ") with unexpected dead_inputs.size() ("
<< dead_inputs.size() << " or num_data_inputs"
<< num_data_inputs;
return Status::OK();
}
}

// Remove dead inputs from Merge nodes that will not be
// pruned from the graph.
for (const auto& itr : dead_merge_inputs) {
NodeDef* merge_node = itr.first;
if (dead_nodes.find(merge_node) != dead_nodes.end()) {
// The node will be pruned since all its inputs are dead.
continue;
}
VLOG(3) << "Merge node before cleanup: " << merge_node->DebugString();
// Remove dead data input.
const std::set<int>& dead_inputs = itr.second;
int index = *dead_inputs.begin();
auto* inputs = merge_node->mutable_input();
inputs->SwapElements(1, index);
inputs->SwapElements(1, merge_node->input_size() - 1);
inputs->RemoveLast();
merge_node->set_op("Identity");
merge_node->mutable_attr()->erase("N");

VLOG(3) << "Merge node after cleanup: " << merge_node->DebugString();
}

EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);
Expand Down
26 changes: 4 additions & 22 deletions tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
Expand Up @@ -777,10 +777,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
{v_in, square1});
ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
{v_in, square1});

ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
Expand Down Expand Up @@ -831,19 +827,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "square1");
EXPECT_EQ(node.input(1), "sqrt2");
} else if (node.name() == "m6") {
// both inputs are alive and the control dependency can get triggered
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 3);
EXPECT_EQ(node.input(0), "v_in");
EXPECT_EQ(node.input(1), "square1");
EXPECT_EQ(node.input(2), "^sqrt2");
} else if (node.name() == "m7") {
// removed control input from dead sqrt1
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "v_in");
EXPECT_EQ(node.input(1), "square1");
} else if (node.name() == "m8") {
// The node is to be preserved because of a fetch
EXPECT_EQ(node.op(), "Merge");
Expand All @@ -859,11 +842,11 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
}
}

auto tensors_expected = EvaluateNodes(item.graph, {"m7", "m8", "m9"});
ASSERT_EQ(tensors_expected.size(), 3);
auto tensors_expected = EvaluateNodes(item.graph, {"m8", "m9"});
ASSERT_EQ(tensors_expected.size(), 2);

auto tensors = EvaluateNodes(output, {"m7", "m8", "m9"});
ASSERT_EQ(tensors.size(), 3);
auto tensors = EvaluateNodes(output, {"m8", "m9"});
ASSERT_EQ(tensors.size(), 2);

test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-6);
Expand Down Expand Up @@ -1098,7 +1081,6 @@ node {
op: "Merge"
input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "N"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/grappler/optimizers/meta_optimizer.cc
Expand Up @@ -221,7 +221,7 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.function_optimization() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<FunctionOptimizer>(
cfg_.function_optimization(),
/*lower_contorl_flow=*/!IsSingleThreadedExecutor()));
/*lower_control_flow=*/!IsSingleThreadedExecutor()));
}
if (cfg_.debug_stripper() == RewriterConfig::ON) {
optimizers->push_back(MakeUnique<DebugStripper>());
Expand Down
28 changes: 21 additions & 7 deletions tensorflow/core/grappler/utils.cc
Expand Up @@ -277,10 +277,22 @@ bool HasRegularInputs(const NodeDef& node) {
}

int NumNonControlInputs(const NodeDef& node) {
int num_inputs = node.input_size();
for (const string& input : node.input()) {
int num_inputs = 0;
for (; num_inputs < node.input_size(); ++num_inputs) {
const string& input = node.input(num_inputs);
if (IsControlInput(input)) {
--num_inputs;
return num_inputs;
}
}
return num_inputs;
}

int NumControlInputs(const NodeDef& node) {
int num_inputs = 0;
for (; num_inputs < node.input_size(); ++num_inputs) {
const string& input = node.input(node.input_size() - num_inputs - 1);
if (!IsControlInput(input)) {
return num_inputs;
}
}
return num_inputs;
Expand All @@ -302,8 +314,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {

bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& node_as_input : output->input()) {
if (!IsControlInput(node_as_input)) continue;
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
const string& node_as_input = output->input(idx);
if (!IsControlInput(node_as_input)) break;

TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
Expand All @@ -317,8 +330,9 @@ bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& node_as_input : output->input()) {
if (!IsControlInput(node_as_input)) continue;
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
const string& node_as_input = output->input(idx);
if (!IsControlInput(node_as_input)) break;

TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/grappler/utils.h
Expand Up @@ -221,6 +221,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
// Returns true iff the node has at least one control output.
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);

// Number of connected control inputs.
int NumControlInputs(const NodeDef& node);

// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);

Expand Down
10 changes: 10 additions & 0 deletions tensorflow/core/grappler/utils/topological_sort.cc
Expand Up @@ -90,6 +90,16 @@ Status ComputeTopologicalOrder(
}

if (back != graph_view.num_nodes()) {
if (VLOG_IS_ON(1)) {
VLOG(1) << "The graph couldn't be sorted in topological order. Stalled "
"at node = "
<< graph.node(back).DebugString();
for (int i = 0; i < graph_view.num_nodes(); ++i) {
if (num_ready_inputs[i] != graph_view.GetFanin(i).size()) {
VLOG(1) << "Node not ready: " << graph.node(i).DebugString();
}
}
}
return errors::InvalidArgument(
"The graph couldn't be sorted in topological order.");
}
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/grappler/utils_test.cc
Expand Up @@ -352,14 +352,17 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
NodeMap node_map(&graph);

const NodeDef* add_node = node_map.GetNode("add");
const NodeDef* mul_node = node_map.GetNode("mul");
ASSERT_NE(add_node, nullptr);

// [a, b] are only non-control inputs
EXPECT_EQ(NumNonControlInputs(*add_node), 2);
EXPECT_EQ(NumControlInputs(*add_node), 1);
// [sqrt, shape] are non control outputs
EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2);
// sqrt is the only data output
EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1);
EXPECT_EQ(NumControlInputs(*mul_node), 0);

EXPECT_TRUE(HasControlInputs(*add_node));
EXPECT_TRUE(HasRegularInputs(*add_node));
Expand Down

0 comments on commit 74229d4

Please sign in to comment.