Skip to content

Commit

Permalink
Small improvements to the arithmetic optimizer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 165760972
  • Loading branch information
benoitsteiner authored and tensorflower-gardener committed Aug 18, 2017
1 parent b640959 commit a271c37
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
Expand Up @@ -177,7 +177,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
if (rep == node) {
continue;
}
const std::set<NodeDef*> fanouts = map.GetOutputs(node->name());
const std::set<NodeDef*>& fanouts = map.GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
for (string& name : *fanout->mutable_input()) {
int position;
Expand All @@ -190,7 +190,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
} else {
name = strings::StrCat("^", rep->name());
}
map.UpdateOutput(nodename, fanout->name(), name);
map.AddOutput(rep->name(), fanout->name());
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/core/grappler/utils.cc
Expand Up @@ -61,8 +61,9 @@ void NodeMap::AddOutput(const string& node, const string& output) {

void NodeMap::UpdateOutput(const string& node, const string& old_output,
const string& new_output) {
outputs_[node].erase(nodes_[old_output]);
outputs_[node].insert(nodes_[new_output]);
std::set<NodeDef*>& outputs = outputs_[node];
outputs.erase(nodes_[old_output]);
outputs.insert(nodes_[new_output]);
}

bool IsSameInput(const string& name1, const string& name2) {
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/python/grappler/memory_optimizer_test.py
Expand Up @@ -125,6 +125,7 @@ def testRewritingDefaultGradientNames(self):
rewritten_graph_def = tf_optimizer.OptimizeGraph(
rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
original_metagraph)
self.assertGreater(
Expand All @@ -146,6 +147,7 @@ def testRewritingNameScopedGradientNames(self):
rewritten_graph_def = tf_optimizer.OptimizeGraph(
rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
original_metagraph)
Expand Down

0 comments on commit a271c37

Please sign in to comment.