Skip to content
This repository has been archived by the owner on Feb 29, 2020. It is now read-only.

Commit

Permalink
fix losing optimizations in multi_opt
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Jan 30, 2019
1 parent f9e7691 commit cfc698f
Show file tree
Hide file tree
Showing 14 changed files with 103 additions and 73 deletions.
34 changes: 17 additions & 17 deletions llo/constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ struct Constant final : public ade::iLeaf
return new Constant(data, dtype, shape);
}

template <typename T>
static Constant* get (T scalar, ade::Shape shape)
{
size_t n = shape.n_elems();
T buffer[n];
std::fill(buffer, buffer + n, scalar);
return new Constant((char*) buffer, age::get_type<T>(), shape);
}
template <typename T>
static Constant* get (T scalar, ade::Shape shape)
{
size_t n = shape.n_elems();
T buffer[n];
std::fill(buffer, buffer + n, scalar);
return new Constant((char*) buffer, age::get_type<T>(), shape);
}

Constant (const Constant& other) = delete;

Expand Down Expand Up @@ -73,19 +73,19 @@ struct Constant final : public ade::iLeaf
return dtype_;
}

template <typename T>
T at (size_t i) const
{
std::vector<T> out;
age::type_convert(out, (void*) &data_[i * age::type_size(dtype_)],
dtype_, 1);
return out[0];
}
template <typename T>
T at (size_t i) const
{
std::vector<T> out;
age::type_convert(out, (void*) &data_[i * age::type_size(dtype_)],
dtype_, 1);
return out[0];
}

private:
Constant (const char* data, age::_GENERATED_DTYPE dtype, ade::Shape shape) :
data_(data, shape.n_elems() * age::type_size(dtype)),
shape_(shape), dtype_(dtype) {}
shape_(shape), dtype_(dtype) {}

/// Smartpointer to a block of untyped data
std::string data_;
Expand Down
3 changes: 2 additions & 1 deletion llo/opt/const_merge.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
namespace llo
{

ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args);
ade::TensptrT const_merge_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args);

ade::TensptrT const_merge (ade::TensptrT root);

Expand Down
12 changes: 6 additions & 6 deletions llo/opt/multi_opt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ namespace llo
{

ade::TensptrT multi_optimize (ade::TensptrT root,
std::vector<opt::EditFuncT> edits = {
const_merge_edit,
zero_prune_edit,
one_prune_edit,
ops_merge_edit,
});
std::vector<opt::EditFuncT> edits = {
const_merge_edit,
zero_prune_edit,
one_prune_edit,
ops_merge_edit,
});

}

Expand Down
3 changes: 2 additions & 1 deletion llo/opt/one_prune.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
namespace llo
{

ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args);
ade::TensptrT one_prune_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args);

/// Return tree that prunes one branches in input according to OPCODE
/// For example, mul(x, 1) is converted to simply x, while abs(1) is 1
Expand Down
3 changes: 2 additions & 1 deletion llo/opt/ops_merge.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ static const std::set<age::_GENERATED_OPCODE> nnary_codes = {
age::MAX,
};

ade::TensptrT ops_merge_edit (ade::Opcode opcode, ade::ArgsT args);
ade::TensptrT ops_merge_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args);

ade::TensptrT ops_merge (ade::TensptrT root);

Expand Down
3 changes: 2 additions & 1 deletion llo/opt/zero_prune.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
namespace llo
{

ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args);
ade::TensptrT zero_prune_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args);

/// Return tree that prunes zero branches in input according to OPCODE
/// For example, add(x, 0) is converted to simply x, while mul(x, 0) is 0
Expand Down
8 changes: 6 additions & 2 deletions llo/src/const_merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
namespace llo
{

ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args)
ade::TensptrT const_merge_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args)
{
ade::ArgsT cargs;
std::copy_if(args.begin(), args.end(), std::back_inserter(cargs),
Expand All @@ -33,6 +34,7 @@ ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args)
ade::TensptrT carg(Constant::get(
(char*) tens->data(), age::DOUBLE, temp->shape()));

// assert nnary functions are independent of order
ade::ArgsT vargs;
std::copy_if(args.begin(), args.end(), std::back_inserter(vargs),
[](ade::MappedTensor& arg)
Expand All @@ -41,7 +43,9 @@ ade::TensptrT const_merge_edit (ade::Opcode opcode, ade::ArgsT args)
arg.get_tensor().get());
});
vargs.push_back(ade::identity_map(carg));
return ade::TensptrT(ade::Functor::get(opcode, vargs));
is_optimized = true;
args = vargs;
return nullptr;
}
return nullptr;
}
Expand Down
24 changes: 13 additions & 11 deletions llo/src/multi_opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@ namespace llo
{

ade::TensptrT multi_optimize (ade::TensptrT root,
std::vector<opt::EditFuncT> edits)
std::vector<opt::EditFuncT> edits)
{
return opt::graph_edit(root,
[&edits](ade::Opcode opcode, ade::ArgsT args)
{
ade::TensptrT out;
for (auto it = edits.begin(), et = edits.end();
it != et && nullptr == out; ++it)
{
out = (*it)(opcode, args);
}
return out;
});
[&edits](bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args) -> ade::TensptrT
{
for (auto edit : edits)
{
if (auto out = edit(is_optimized, opcode, args))
{
return out;
}
}
return nullptr;
});
}

}
Expand Down
22 changes: 14 additions & 8 deletions llo/src/one_prune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
namespace llo
{

// todo: change this to target fixed value instead of looking at label
ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args)
ade::TensptrT one_prune_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args)
{
size_t n = args.size();
bool has_one = false;
Expand Down Expand Up @@ -44,8 +44,10 @@ ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args)
{
return args[0].get_tensor();
}
return ade::TensptrT(ade::Functor::get(
ade::Opcode{"SUM", age::SUM}, {args[0]}));
is_optimized = true;
opcode = ade::Opcode{"SUM", age::SUM};
args = {args[0]};
return nullptr;
case age::PROD:
{
ade::ArgsT filtered;
Expand All @@ -60,8 +62,10 @@ ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args)
{
return ade::TensptrT(llo::Constant::get(1, args[0].shape()));
}
return ade::TensptrT(ade::Functor::get(
ade::Opcode{"PROD", age::PROD}, filtered));
is_optimized = true;
opcode = ade::Opcode{"PROD", age::PROD};
args = filtered;
return nullptr;
}
case age::DIV:
if (is_one[1])
Expand All @@ -70,8 +74,10 @@ ade::TensptrT one_prune_edit (ade::Opcode opcode, ade::ArgsT args)
{
return args[0].get_tensor();
}
return ade::TensptrT(ade::Functor::get(
ade::Opcode{"SUM", age::SUM}, {args[0]}));
is_optimized = true;
opcode = ade::Opcode{"SUM", age::SUM};
args = {args[0]};
return nullptr;
}
// else if is_one[0]
break;
Expand Down
9 changes: 6 additions & 3 deletions llo/src/ops_merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ static bool is_bijective (ade::CoordptrT coorder)
coorder->is_bijective();
}

ade::TensptrT ops_merge_edit (ade::Opcode opcode, ade::ArgsT args)
ade::TensptrT ops_merge_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args)
{
if (nnary_codes.end() != nnary_codes.find((age::_GENERATED_OPCODE) opcode.code_))
if (nnary_codes.end() != nnary_codes.find(
(age::_GENERATED_OPCODE) opcode.code_))
{
bool merged = false;
ade::ArgsT newchildren;
Expand Down Expand Up @@ -129,7 +131,8 @@ ade::TensptrT ops_merge_edit (ade::Opcode opcode, ade::ArgsT args)
}
else if (merged)
{
return ade::TensptrT(ade::Functor::get(opcode, newchildren));
is_optimized = true;
args = newchildren;
}
}
return nullptr;
Expand Down
13 changes: 10 additions & 3 deletions llo/src/zero_prune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ namespace llo
{

// todo: change this to target fixed value instead of looking at label
ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args)
ade::TensptrT zero_prune_edit (bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args)
{
size_t n = args.size();
bool has_zero = false;
Expand Down Expand Up @@ -62,7 +63,10 @@ ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args)
{
return ade::TensptrT(llo::Constant::get(0, args[0].shape()));
}
return ade::TensptrT(ade::Functor::get(ade::Opcode{"SUM", age::SUM}, filtered));
is_optimized = true;
opcode = ade::Opcode{"SUM", age::SUM};
args = filtered;
return nullptr;
}
case age::SUB:
if (is_zero[0] && is_zero[1])
Expand All @@ -71,7 +75,10 @@ ade::TensptrT zero_prune_edit (ade::Opcode opcode, ade::ArgsT args)
}
else if (is_zero[0])
{
return ade::TensptrT(ade::Functor::get(ade::Opcode{"NEG", age::NEG}, {args[1]}));
is_optimized = true;
opcode = ade::Opcode{"NEG", age::NEG};
args = {args[1]};
return nullptr;
}
// else if is_zero[1]
return args[0].get_tensor();
Expand Down
2 changes: 1 addition & 1 deletion opt/graph_edit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace opt
{

/// Edit functor type
using EditFuncT = std::function<ade::TensptrT(ade::Opcode,ade::ArgsT)>;
using EditFuncT = std::function<ade::TensptrT(bool&,ade::Opcode&,ade::ArgsT&)>;

/// For some target extractable from iLeaf, prune graph such that reduces the
/// length of branches to target from root
Expand Down
5 changes: 3 additions & 2 deletions opt/src/graph_edit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@ ade::TensptrT graph_edit (ade::TensptrT root, EditFuncT edit)
}
}
auto opcode = func->get_opcode();
auto optimized = edit(opcode, children);
bool is_optimized = false;
auto optimized = edit(is_optimized, opcode, children);
// only record optimization if changed
if (nullptr != optimized)
{
opt_graph.emplace(func, optimized);
}
else if (changed)
else if (changed || is_optimized)
{
opt_graph.emplace(func, ade::TensptrT(
ade::Functor::get(opcode, children)));
Expand Down
35 changes: 19 additions & 16 deletions opt/test/test_editor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,30 @@ TEST(EDITOR, Prune)
ade::identity_map(binar2),
}));

opt::EditFuncT pruner = [&](ade::Opcode opcode, ade::ArgsT args)
{
if (opcode.code_ < 2) // killable
opt::EditFuncT pruner =
[&](bool& is_optimized,
ade::Opcode& opcode, ade::ArgsT& args) -> ade::TensptrT
{
ade::ArgsT filtered;
for (auto arg : args)
if (opcode.code_ < 2) // killable
{
ade::iTensor* tens = arg.get_tensor().get();
if (tens != leaf.get() && tens != leaf2.get())
ade::ArgsT filtered;
for (auto arg : args)
{
filtered.push_back(arg);
ade::iTensor* tens = arg.get_tensor().get();
if (tens != leaf.get() && tens != leaf2.get())
{
filtered.push_back(arg);
}
}
args = filtered;
}
args = filtered;
}
if (args.size() > 0)
{
return ade::TensptrT(ade::Functor::get(opcode, args));
}
return leaf;
};
if (args.size() > 0)
{
is_optimized = true;
return nullptr;
}
return leaf;
};

auto root = opt::graph_edit(repl_binar, pruner);

Expand Down

0 comments on commit cfc698f

Please sign in to comment.