diff --git a/llo/opt/derive.hpp b/llo/opt/derive.hpp index 83b41e3..f8f8199 100644 --- a/llo/opt/derive.hpp +++ b/llo/opt/derive.hpp @@ -1,4 +1,4 @@ -#include "llo/opt/plugin_opt.hpp" +#include "llo/opt/multi_opt.hpp" #include "llo/generated/grader.hpp" diff --git a/llo/opt/multi_opt.hpp b/llo/opt/multi_opt.hpp new file mode 100644 index 0000000..c75fe70 --- /dev/null +++ b/llo/opt/multi_opt.hpp @@ -0,0 +1,26 @@ +#include "opt/graph_edit.hpp" + +#include "llo/opt/const_merge.hpp" +#include "llo/opt/ops_merge.hpp" +#include "llo/opt/one_prune.hpp" +#include "llo/opt/zero_prune.hpp" + +#include "llo/variable.hpp" + +#ifndef LLO_MULTI_OPT_HPP +#define LLO_MULTI_OPT_HPP + +namespace llo +{ + +ade::TensptrT multi_optimize (ade::TensptrT root, + std::vector edits = { + const_merge_edit, + zero_prune_edit, + one_prune_edit, + ops_merge_edit, + }); + +} + +#endif // LLO_MULTI_OPT_HPP diff --git a/llo/opt/plugin_opt.hpp b/llo/opt/plugin_opt.hpp deleted file mode 100644 index 2fec9eb..0000000 --- a/llo/opt/plugin_opt.hpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "llo/variable.hpp" - -#ifndef LLO_PLUGIN_OPT_HPP -#define LLO_PLUGIN_OPT_HPP - -namespace llo -{ - -ade::TensptrT plugin_optimize (ade::TensptrT root); - -} - -#endif // LLO_PLUGIN_OPT_HPP diff --git a/llo/python/llo.cpp b/llo/python/llo.cpp index fc76122..368220e 100644 --- a/llo/python/llo.cpp +++ b/llo/python/llo.cpp @@ -6,9 +6,11 @@ #include "ade/ade.hpp" +#include "llo/opt/derive.hpp" + +#include "llo/constant.hpp" #include "llo/variable.hpp" #include "llo/eval.hpp" -#include "llo/opt/derive.hpp" namespace py = pybind11; diff --git a/llo/src/derive.cpp b/llo/src/derive.cpp index 4dced88..8154dc0 100644 --- a/llo/src/derive.cpp +++ b/llo/src/derive.cpp @@ -11,7 +11,7 @@ ade::TensptrT derive (ade::TensptrT root, ade::iTensor* target) root->accept(grader); auto it = grader.derivatives_.find(root.get()); assert(grader.derivatives_.end() != it); - return plugin_optimize(it->second); + return multi_optimize(it->second); } } diff --git a/llo/src/multi_opt.cpp b/llo/src/multi_opt.cpp new file mode 100644 index 0000000..5548512 --- /dev/null +++ b/llo/src/multi_opt.cpp @@ -0,0 +1,26 @@ +#include "llo/opt/multi_opt.hpp" + +#ifdef LLO_MULTI_OPT_HPP + +namespace llo +{ + +ade::TensptrT multi_optimize (ade::TensptrT root, + std::vector edits) +{ + return opt::graph_edit(root, + [&edits](ade::Opcode opcode, ade::ArgsT args) + { + ade::TensptrT out; + for (auto it = edits.begin(), et = edits.end(); + it != et && nullptr == out; ++it) + { + out = (*it)(opcode, args); + } + return out; + }); +} + +} + +#endif diff --git a/llo/src/plugin_opt.cpp b/llo/src/plugin_opt.cpp deleted file mode 100644 index d28f8db..0000000 --- a/llo/src/plugin_opt.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "opt/graph_edit.hpp" - -#include "llo/opt/const_merge.hpp" -#include "llo/opt/ops_merge.hpp" -#include "llo/opt/one_prune.hpp" -#include "llo/opt/zero_prune.hpp" -#include "llo/opt/plugin_opt.hpp" - -#ifdef LLO_PLUGIN_OPT_HPP - -namespace llo -{ - -static const std::vector edits = -{ - const_merge_edit, - zero_prune_edit, - one_prune_edit, - ops_merge_edit, -}; - -ade::TensptrT plugin_edit (ade::Opcode opcode, ade::ArgsT args) -{ - ade::TensptrT out; - for (auto it = edits.begin(), et = edits.end(); - it != et && nullptr == out; ++it) - { - out = (*it)(opcode, args); - } - return out; -} - -ade::TensptrT plugin_optimize (ade::TensptrT root) -{ - return opt::graph_edit(root, plugin_edit); -} - -} - -#endif diff --git a/llo/test/common.hpp b/llo/test/common.hpp index a12e11d..1965c66 100644 --- a/llo/test/common.hpp +++ b/llo/test/common.hpp @@ -14,4 +14,5 @@ #define EXPECT_FATAL(EVENT, MSG) try { EVENT; FAIL() << \ "did not expect " << #EVENT << " to succeed"; } \ - catch (std::runtime_error& e) { EXPECT_STREQ(MSG, e.what()); } + catch (std::runtime_error& e) { EXPECT_STREQ(MSG, e.what()); }\ + catch (std::exception& e) { FAIL() << "unexpected throw " << e.what(); }