From cfc698f056846c03f65062a978e7855db1c9bddb Mon Sep 17 00:00:00 2001 From: RaggleDodo Date: Tue, 29 Jan 2019 20:30:58 -0800 Subject: [PATCH] fix losing optimizations in multi_opt --- llo/constant.hpp | 34 +++++++++++++++++----------------- llo/opt/const_merge.hpp | 3 ++- llo/opt/multi_opt.hpp | 12 ++++++------ llo/opt/one_prune.hpp | 3 ++- llo/opt/ops_merge.hpp | 3 ++- llo/opt/zero_prune.hpp | 3 ++- llo/src/const_merge.cpp | 8 ++++++-- llo/src/multi_opt.cpp | 24 +++++++++++++----------- llo/src/one_prune.cpp | 22 ++++++++++++++-------- llo/src/ops_merge.cpp | 9 ++++++--- llo/src/zero_prune.cpp | 13 ++++++++++--- opt/graph_edit.hpp | 2 +- opt/src/graph_edit.cpp | 5 +++-- opt/test/test_editor.cpp | 35 +++++++++++++++++++---------------- 14 files changed, 103 insertions(+), 73 deletions(-) diff --git a/llo/constant.hpp b/llo/constant.hpp index a022461..b92b77c 100644 --- a/llo/constant.hpp +++ b/llo/constant.hpp @@ -25,14 +25,14 @@ struct Constant final : public ade::iLeaf return new Constant(data, dtype, shape); } - template - static Constant* get (T scalar, ade::Shape shape) - { - size_t n = shape.n_elems(); - T buffer[n]; - std::fill(buffer, buffer + n, scalar); - return new Constant((char*) buffer, age::get_type(), shape); - } + template + static Constant* get (T scalar, ade::Shape shape) + { + size_t n = shape.n_elems(); + T buffer[n]; + std::fill(buffer, buffer + n, scalar); + return new Constant((char*) buffer, age::get_type(), shape); + } Constant (const Constant& other) = delete; @@ -73,19 +73,19 @@ struct Constant final : public ade::iLeaf return dtype_; } - template - T at (size_t i) const - { - std::vector out; - age::type_convert(out, (void*) &data_[i * age::type_size(dtype_)], - dtype_, 1); - return out[0]; - } + template + T at (size_t i) const + { + std::vector out; + age::type_convert(out, (void*) &data_[i * age::type_size(dtype_)], + dtype_, 1); + return out[0]; + } private: Constant (const char* data, age::_GENERATED_DTYPE dtype, ade::Shape shape) : data_(data, shape.n_elems() * age::type_size(dtype)), - shape_(shape), dtype_(dtype) {} + shape_(shape), dtype_(dtype) {} /// Smartpointer to a block of untyped data std::string data_; diff --git a/llo/opt/const_merge.hpp b/llo/opt/const_merge.hpp index e3e3037..66b79b0 100644 --- a/llo/opt/const_merge.hpp +++ b/llo/opt/const_merge.hpp @@ -8,7 +8,8 @@ namespace llo { -ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args); +ade::TensptrT const_merge_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args); ade::TensptrT const_merge (ade::TensptrT root); diff --git a/llo/opt/multi_opt.hpp b/llo/opt/multi_opt.hpp index c75fe70..a9bddab 100644 --- a/llo/opt/multi_opt.hpp +++ b/llo/opt/multi_opt.hpp @@ -14,12 +14,12 @@ namespace llo { ade::TensptrT multi_optimize (ade::TensptrT root, - std::vector edits = { - const_merge_edit, - zero_prune_edit, - one_prune_edit, - ops_merge_edit, - }); + std::vector edits = { + const_merge_edit, + zero_prune_edit, + one_prune_edit, + ops_merge_edit, + }); } diff --git a/llo/opt/one_prune.hpp b/llo/opt/one_prune.hpp index 8c0a8f5..d4d3980 100644 --- a/llo/opt/one_prune.hpp +++ b/llo/opt/one_prune.hpp @@ -16,7 +16,8 @@ namespace llo { -ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args); +ade::TensptrT one_prune_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args); /// Return tree that prunes one branches in input according to OPCODE /// For example, mul(x, 1) is converted to simply x, while abs(1) is 1 diff --git a/llo/opt/ops_merge.hpp b/llo/opt/ops_merge.hpp index a2c365f..9c2d37d 100644 --- a/llo/opt/ops_merge.hpp +++ b/llo/opt/ops_merge.hpp @@ -17,7 +17,8 @@ static const std::set nnary_codes = { age::MAX, }; -ade::TensptrT ops_merge_edit (ade::Opcode opcode, ade::ArgsT args); +ade::TensptrT ops_merge_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args); ade::TensptrT ops_merge (ade::TensptrT root); diff --git a/llo/opt/zero_prune.hpp b/llo/opt/zero_prune.hpp index 8e74220..28c9973 100644 --- a/llo/opt/zero_prune.hpp +++ b/llo/opt/zero_prune.hpp @@ -16,7 +16,8 @@ namespace llo { -ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args); +ade::TensptrT zero_prune_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args); /// Return tree that prunes zero branches in input according to OPCODE /// For example, add(x, 0) is converted to simply x, while mul(x, 0) is 0 diff --git a/llo/src/const_merge.cpp b/llo/src/const_merge.cpp index 5aa0af4..1ee0fe2 100644 --- a/llo/src/const_merge.cpp +++ b/llo/src/const_merge.cpp @@ -9,7 +9,8 @@ namespace llo { -ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args) +ade::TensptrT const_merge_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args) { ade::ArgsT cargs; std::copy_if(args.begin(), args.end(), std::back_inserter(cargs), @@ -33,6 +34,7 @@ ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args) ade::TensptrT carg(Constant::get( (char*) tens->data(), age::DOUBLE, temp->shape())); + // assert nnary functions are independent of order ade::ArgsT vargs; std::copy_if(args.begin(), args.end(), std::back_inserter(vargs), [](ade::MappedTensor& arg) @@ -41,7 +43,9 @@ ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args) arg.get_tensor().get()); }); vargs.push_back(ade::identity_map(carg)); - return ade::TensptrT(ade::Functor::get(opcode, vargs)); + is_optimized = true; + args = vargs; + return nullptr; } return nullptr; } diff --git a/llo/src/multi_opt.cpp b/llo/src/multi_opt.cpp index 5548512..eb1e0ed 100644 --- a/llo/src/multi_opt.cpp +++ b/llo/src/multi_opt.cpp @@ -6,19 +6,21 @@ namespace llo { ade::TensptrT multi_optimize (ade::TensptrT root, - std::vector edits) + std::vector edits) { return opt::graph_edit(root, - [&edits](ade::Opcode opcode, ade::ArgsT args) - { - ade::TensptrT out; - for (auto it = edits.begin(), et = edits.end(); - it != et && nullptr == out; ++it) - { - out = (*it)(opcode, args); - } - return out; - }); + [&edits](bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args) -> ade::TensptrT + { + for (auto edit : edits) + { + if (auto out = edit(is_optimized, opcode, args)) + { + return out; + } + } + return nullptr; + }); } } diff --git a/llo/src/one_prune.cpp b/llo/src/one_prune.cpp index bdb8646..3d89989 100644 --- a/llo/src/one_prune.cpp +++ b/llo/src/one_prune.cpp @@ -12,8 +12,8 @@ namespace llo { -// todo: change this to target fixed value instead of looking at label -ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args) +ade::TensptrT one_prune_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args) { size_t n = args.size(); bool has_one = false; @@ -44,8 +44,10 @@ ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args) { return args[0].get_tensor(); } - return ade::TensptrT(ade::Functor::get( - ade::Opcode{"SUM", age::SUM}, {args[0]})); + is_optimized = true; + opcode = ade::Opcode{"SUM", age::SUM}; + args = {args[0]}; + return nullptr; case age::PROD: { ade::ArgsT filtered; @@ -60,8 +62,10 @@ ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args) { return ade::TensptrT(llo::Constant::get(1, args[0].shape())); } - return ade::TensptrT(ade::Functor::get( - ade::Opcode{"PROD", age::PROD}, filtered)); + is_optimized = true; + opcode = ade::Opcode{"PROD", age::PROD}; + args = filtered; + return nullptr; } case age::DIV: if (is_one[1]) @@ -70,8 +74,10 @@ ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args) { return args[0].get_tensor(); } - return ade::TensptrT(ade::Functor::get( - ade::Opcode{"SUM", age::SUM}, {args[0]})); + is_optimized = true; + opcode = ade::Opcode{"SUM", age::SUM}; + args = {args[0]}; + return nullptr; } // else if is_one[0] break; diff --git a/llo/src/ops_merge.cpp b/llo/src/ops_merge.cpp index 4132f18..df70fe2 100644 --- a/llo/src/ops_merge.cpp +++ b/llo/src/ops_merge.cpp @@ -31,9 +31,11 @@ static bool is_bijective (ade::CoordptrT coorder) coorder->is_bijective(); } -ade::TensptrT ops_merge_edit (ade::Opcode opcode, ade::ArgsT args) +ade::TensptrT ops_merge_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args) { - if (nnary_codes.end() != nnary_codes.find((age::_GENERATED_OPCODE) opcode.code_)) + if (nnary_codes.end() != nnary_codes.find( + (age::_GENERATED_OPCODE) opcode.code_)) { bool merged = false; ade::ArgsT newchildren; @@ -129,7 +131,8 @@ ade::TensptrT ops_merge_edit (ade::Opcode opcode, ade::ArgsT args) } else if (merged) { - return ade::TensptrT(ade::Functor::get(opcode, newchildren)); + is_optimized = true; + args = newchildren; } } return nullptr; diff --git a/llo/src/zero_prune.cpp b/llo/src/zero_prune.cpp index 73fa0ce..7614135 100644 --- a/llo/src/zero_prune.cpp +++ b/llo/src/zero_prune.cpp @@ -13,7 +13,8 @@ namespace llo { // todo: change this to target fixed value instead of looking at label -ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args) +ade::TensptrT zero_prune_edit (bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args) { size_t n = args.size(); bool has_zero = false; @@ -62,7 +63,10 @@ ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args) { return ade::TensptrT(llo::Constant::get(0, args[0].shape())); } - return ade::TensptrT(ade::Functor::get(ade::Opcode{"SUM", age::SUM}, filtered)); + is_optimized = true; + opcode = ade::Opcode{"SUM", age::SUM}; + args = filtered; + return nullptr; } case age::SUB: if (is_zero[0] && is_zero[1]) @@ -71,7 +75,10 @@ ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args) } else if (is_zero[0]) { - return ade::TensptrT(ade::Functor::get(ade::Opcode{"NEG", age::NEG}, {args[1]})); + is_optimized = true; + opcode = ade::Opcode{"NEG", age::NEG}; + args = {args[1]}; + return nullptr; } // else if is_zero[1] return args[0].get_tensor(); diff --git a/opt/graph_edit.hpp b/opt/graph_edit.hpp index 6478b36..f6f0034 100644 --- a/opt/graph_edit.hpp +++ b/opt/graph_edit.hpp @@ -15,7 +15,7 @@ namespace opt { /// Edit functor type -using EditFuncT = std::function; +using EditFuncT = std::function; /// For some target extractable from iLeaf, prune graph such that reduces the /// length of branches to target from root diff --git a/opt/src/graph_edit.cpp b/opt/src/graph_edit.cpp index c8c2d6e..0300f42 100644 --- a/opt/src/graph_edit.cpp +++ b/opt/src/graph_edit.cpp @@ -54,13 +54,14 @@ ade::TensptrT graph_edit (ade::TensptrT root, EditFuncT edit) } } auto opcode = func->get_opcode(); - auto optimized = edit(opcode, children); + bool is_optimized = false; + auto optimized = edit(is_optimized, opcode, children); // only record optimization if changed if (nullptr != optimized) { opt_graph.emplace(func, optimized); } - else if (changed) + else if (changed || is_optimized) { opt_graph.emplace(func, ade::TensptrT( ade::Functor::get(opcode, children))); diff --git a/opt/test/test_editor.cpp b/opt/test/test_editor.cpp index d4cfcd0..74f5668 100644 --- a/opt/test/test_editor.cpp +++ b/opt/test/test_editor.cpp @@ -77,27 +77,30 @@ TEST(EDITOR, Prune) ade::identity_map(binar2), })); - opt::EditFuncT pruner = [&](ade::Opcode opcode, ade::ArgsT args) - { - if (opcode.code_ < 2) // killable + opt::EditFuncT pruner = + [&](bool& is_optimized, + ade::Opcode& opcode, ade::ArgsT& args) -> ade::TensptrT { - ade::ArgsT filtered; - for (auto arg : args) + if (opcode.code_ < 2) // killable { - ade::iTensor* tens = arg.get_tensor().get(); - if (tens != leaf.get() && tens != leaf2.get()) + ade::ArgsT filtered; + for (auto arg : args) { - filtered.push_back(arg); + ade::iTensor* tens = arg.get_tensor().get(); + if (tens != leaf.get() && tens != leaf2.get()) + { + filtered.push_back(arg); + } } + args = filtered; } - args = filtered; - } - if (args.size() > 0) - { - return ade::TensptrT(ade::Functor::get(opcode, args)); - } - return leaf; - }; + if (args.size() > 0) + { + is_optimized = true; + return nullptr; + } + return leaf; + }; auto root = opt::graph_edit(repl_binar, pruner);