From 74229d4736accc47e02ac7d440d931f489260c0e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 10 Dec 2019 16:10:36 -0800 Subject: [PATCH] [Grappler] 1) Skip dead branch elimination for merge nodes with control inputs, since these can create cycles in the resulting optimized graph. 2) Optimize a few utility functions. 3) Add more verbose VLOGging when topological sorting fails. PiperOrigin-RevId: 284871268 Change-Id: I36435402d826e4737b709468d88641d7a7fa2a83 --- .../grappler/optimizers/loop_optimizer.cc | 76 +++++++++++-------- .../optimizers/loop_optimizer_test.cc | 26 +------ .../grappler/optimizers/meta_optimizer.cc | 2 +- tensorflow/core/grappler/utils.cc | 28 +++++-- tensorflow/core/grappler/utils.h | 3 + .../core/grappler/utils/topological_sort.cc | 10 +++ tensorflow/core/grappler/utils_test.cc | 3 + 7 files changed, 85 insertions(+), 63 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index cf56e00e148439..a00dcc0634b929 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -627,7 +627,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, return Status::OK(); } - VLOG(3) << "Try to find a zero iteration while loop:" + VLOG(4) << "Try to find a zero iteration while loop:" << " switch_node=" << switch_node.name(); // Find the boolean predicate from a LoopCond node (e.g. Greater). @@ -704,7 +704,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, &constant_switch_value)); if (constant_switch_value == false) { - VLOG(4) << "Remove 0 iteration while loop:" + VLOG(3) << "Remove 0 iteration while loop:" << " switch_node=" << switch_node.name(); *has_dead_fanout = true; *dead_fanout = 1; @@ -746,8 +746,6 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph)); } if (options_.enable_dead_branch_removal) { - // TODO(srjoglekar): Figure out if we can optimize NodeMap creations across - // optimizer passes. NodeMap node_map(optimized_graph); absl::flat_hash_set feed_nodes; for (const auto& feed : item.feed) { @@ -890,43 +888,55 @@ Status LoopOptimizer::RemoveDeadBranches( // Names of the nodes that were removed from the graph. absl::flat_hash_set dead_node_names; dead_node_names.reserve(dead_nodes.size()); - for (const NodeDef* dead_node : dead_nodes) + for (const NodeDef* dead_node : dead_nodes) { dead_node_names.insert(dead_node->name()); + } - // Remove dead inputs from Merge nodes that were not pruned from the graph. + // Check that the merge nodes are valid. for (const auto& itr : dead_merge_inputs) { - NodeDef* dead_node = itr.first; - if (dead_nodes.find(dead_node) != dead_nodes.end()) { - // The node has been pruned since all its inputs are dead. + NodeDef* merge_node = itr.first; + if (dead_nodes.find(merge_node) != dead_nodes.end()) { + // The node will be pruned since all its inputs are dead. continue; } // Remove dead data input. const std::set& dead_inputs = itr.second; - CHECK_LE(dead_inputs.size(), 1); - // (This loop would delete >1 items possibly in the wrong order.) - for (int index : dead_inputs) { - dead_node->mutable_input()->DeleteSubrange(index, 1); - } - // Turn Merge into Identity only if we deleted the other data input. - if (!dead_inputs.empty()) { - const int num_data_inputs = dead_node->attr().at("N").i(); - CHECK_EQ(num_data_inputs, dead_inputs.size() + 1); - dead_node->set_op("Identity"); - dead_node->mutable_attr()->erase("N"); - } - // Remove control inputs from dead nodes. - int pos = 0; - while (pos < dead_node->input_size()) { - TensorId tensor = ParseTensorName(dead_node->input(pos)); - if (tensor.index() == Graph::kControlSlot && - dead_node_names.contains(tensor.node())) { - auto* inputs = dead_node->mutable_input(); - inputs->SwapElements(pos, dead_node->input_size() - 1); - inputs->RemoveLast(); - } else { - ++pos; - } + const int num_data_inputs = merge_node->attr().at("N").i(); + if (merge_node->input_size() != num_data_inputs) { + LOG(WARNING) + << "Skipping loop optimization for Merge node with control input: " + << merge_node->name(); + return Status::OK(); + } else if (dead_inputs.size() != 1 || num_data_inputs != 2) { + LOG(WARNING) << "Skipping loop optimization for Merge node (" + << merge_node->name() + << ") with unexpected dead_inputs.size() (" + << dead_inputs.size() << " or num_data_inputs" + << num_data_inputs; + return Status::OK(); + } + } + + // Remove dead inputs from Merge nodes that will not be not + // pruned from the graph. + for (const auto& itr : dead_merge_inputs) { + NodeDef* merge_node = itr.first; + if (dead_nodes.find(merge_node) != dead_nodes.end()) { + // The node will be pruned since all its inputs are dead. + continue; } + VLOG(3) << "Merge node before cleanup: " << merge_node->DebugString(); + // Remove dead data input. + const std::set& dead_inputs = itr.second; + int index = *dead_inputs.begin(); + auto* inputs = merge_node->mutable_input(); + inputs->SwapElements(1, index); + inputs->SwapElements(1, merge_node->input_size() - 1); + inputs->RemoveLast(); + merge_node->set_op("Identity"); + merge_node->mutable_attr()->erase("N"); + + VLOG(3) << "Merge node after cleanup: " << merge_node->DebugString(); } EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph); diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index a8bedeed663354..f48f5b01a796ba 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -777,10 +777,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) { ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1}); ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2}); ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1}); - ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2), - {v_in, square1}); - ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1), - {v_in, square1}); ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1); Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false); @@ -831,19 +827,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) { ASSERT_EQ(node.input_size(), 2); EXPECT_EQ(node.input(0), "square1"); EXPECT_EQ(node.input(1), "sqrt2"); - } else if (node.name() == "m6") { - // both inputs are alive and the control dependency can get triggered - EXPECT_EQ(node.op(), "Merge"); - ASSERT_EQ(node.input_size(), 3); - EXPECT_EQ(node.input(0), "v_in"); - EXPECT_EQ(node.input(1), "square1"); - EXPECT_EQ(node.input(2), "^sqrt2"); - } else if (node.name() == "m7") { - // removed control input from dead sqrt1 - EXPECT_EQ(node.op(), "Merge"); - ASSERT_EQ(node.input_size(), 2); - EXPECT_EQ(node.input(0), "v_in"); - EXPECT_EQ(node.input(1), "square1"); } else if (node.name() == "m8") { // The node is to be preserved because of a fetch EXPECT_EQ(node.op(), "Merge"); @@ -859,11 +842,11 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) { } } - auto tensors_expected = EvaluateNodes(item.graph, {"m7", "m8", "m9"}); - ASSERT_EQ(tensors_expected.size(), 3); + auto tensors_expected = EvaluateNodes(item.graph, {"m8", "m9"}); + ASSERT_EQ(tensors_expected.size(), 2); - auto tensors = EvaluateNodes(output, {"m7", "m8", "m9"}); - ASSERT_EQ(tensors.size(), 3); + auto tensors = EvaluateNodes(output, {"m8", "m9"}); + ASSERT_EQ(tensors.size(), 2); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); test::ExpectTensorNear(tensors_expected[1], tensors[1], 1e-6); @@ -1098,7 +1081,6 @@ node { op: "Merge" input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1" input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency" - input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert" device: "/job:localhost/replica:0/task:0/device:CPU:0" attr { key: "N" diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index c924f8b52b1552..9fc9be0e5af15e 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -221,7 +221,7 @@ Status MetaOptimizer::InitializeOptimizers( if (cfg_.function_optimization() != RewriterConfig::OFF) { optimizers->push_back(MakeUnique( cfg_.function_optimization(), - /*lower_contorl_flow=*/!IsSingleThreadedExecutor())); + /*lower_control_flow=*/!IsSingleThreadedExecutor())); } if (cfg_.debug_stripper() == RewriterConfig::ON) { optimizers->push_back(MakeUnique()); diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 59307a218dbb28..9213b37ff2907b 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -277,10 +277,22 @@ bool HasRegularInputs(const NodeDef& node) { } int NumNonControlInputs(const NodeDef& node) { - int num_inputs = node.input_size(); - for (const string& input : node.input()) { + int num_inputs = 0; + for (; num_inputs < node.input_size(); ++num_inputs) { + const string& input = node.input(num_inputs); if (IsControlInput(input)) { - --num_inputs; + return num_inputs; + } + } + return num_inputs; +} + +int NumControlInputs(const NodeDef& node) { + int num_inputs = 0; + for (; num_inputs < node.input_size(); ++num_inputs) { + const string& input = node.input(node.input_size() - num_inputs - 1); + if (!IsControlInput(input)) { + return num_inputs; } } return num_inputs; @@ -302,8 +314,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) { bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) { for (const NodeDef* output : node_map.GetOutputs(node.name())) { - for (const string& node_as_input : output->input()) { - if (!IsControlInput(node_as_input)) continue; + for (int idx = output->input_size() - 1; idx >= 0; --idx) { + const string& node_as_input = output->input(idx); + if (!IsControlInput(node_as_input)) break; TensorId tensor = ParseTensorName(node_as_input); if (tensor.node() == node.name()) { @@ -317,8 +330,9 @@ bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) { int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) { int num_outputs = 0; for (const NodeDef* output : node_map.GetOutputs(node.name())) { - for (const string& node_as_input : output->input()) { - if (!IsControlInput(node_as_input)) continue; + for (int idx = output->input_size() - 1; idx >= 0; --idx) { + const string& node_as_input = output->input(idx); + if (!IsControlInput(node_as_input)) break; TensorId tensor = ParseTensorName(node_as_input); if (tensor.node() == node.name()) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 7a1b65e1729f98..87835245762a73 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -221,6 +221,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map); // Returns true iff the node has at least one control output. bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map); +// Number of connected control inputs. +int NumControlInputs(const NodeDef& node); + // Number of connected non-control inputs. int NumNonControlInputs(const NodeDef& node); diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index a6d0f5037bb35c..e24a457593a035 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -90,6 +90,16 @@ Status ComputeTopologicalOrder( } if (back != graph_view.num_nodes()) { + if (VLOG_IS_ON(1)) { + VLOG(1) << "The graph couldn't be sorted in topological order. Stalled " + "at node = " + << graph.node(back).DebugString(); + for (int i = 0; i < graph_view.num_nodes(); ++i) { + if (num_ready_inputs[i] != graph_view.GetFanin(i).size()) { + VLOG(1) << "Node not ready: " << graph.node(i).DebugString(); + } + } + } return errors::InvalidArgument( "The graph couldn't be sorted in topological order."); } diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 777630ff98cd37..7e3d4d90dcd572 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -352,14 +352,17 @@ TEST_F(UtilsTest, NumNonControlOutputs) { NodeMap node_map(&graph); const NodeDef* add_node = node_map.GetNode("add"); + const NodeDef* mul_node = node_map.GetNode("mul"); ASSERT_NE(add_node, nullptr); // [a, b] are only non-control inputs EXPECT_EQ(NumNonControlInputs(*add_node), 2); + EXPECT_EQ(NumControlInputs(*add_node), 1); // [sqrt, shape] are non control outputs EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2); // sqrt is the only data output EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1); + EXPECT_EQ(NumControlInputs(*mul_node), 0); EXPECT_TRUE(HasControlInputs(*add_node)); EXPECT_TRUE(HasRegularInputs(*add_node));