diff --git a/opt/final_inline/FinalInlineV2.cpp b/opt/final_inline/FinalInlineV2.cpp index 8da49d7b3b..a7d3e477b6 100644 --- a/opt/final_inline/FinalInlineV2.cpp +++ b/opt/final_inline/FinalInlineV2.cpp @@ -9,10 +9,12 @@ #include #include +#include #include #include #include "CFGMutation.h" +#include "ConcurrentContainers.h" #include "ConfigFiles.h" #include "Debug.h" #include "DexAccess.h" @@ -31,6 +33,7 @@ #include "TypeSystem.h" #include "TypeUtil.h" #include "Walkers.h" +#include "WeakTopologicalOrdering.h" /* * dx-generated class initializers often use verbose bytecode sequences to @@ -50,6 +53,20 @@ using namespace sparta; namespace { +std::ostream& operator<<(std::ostream& o, + const sparta::WtoComponent& c) { + if (c.is_scc()) { + o << "(" << show(c.head_node()); + for (const auto& sub : c) { + o << " " << sub; + } + o << ")"; + } else { + o << show(c.head_node()); + } + return o; +} + /* * Foo. may read some static fields from class Bar, in which case * Bar. will be executed first by the VM to determine the values of @@ -64,41 +81,130 @@ namespace { * (JLS SE7 12.4.1 indicates that cycles are indeed allowed.) In that case, * this pass cannot safely optimize the static final constants. */ -Scope reverse_tsort_by_clinit_deps(const Scope& scope) { +Scope reverse_tsort_by_clinit_deps(const Scope& scope, size_t& init_cycles) { + Timer timer{"reverse_tsort_by_clinit_deps"}; + std::unordered_set scope_set(scope.begin(), scope.end()); - Scope result; - std::unordered_set visiting; - std::unordered_set visited; - std::function visit = [&](DexClass* cls) { - if (visited.count(cls) != 0 || scope_set.count(cls) == 0) { - return; + + // Collect data for WTO. + // NOTE: Doing this already also as reverse so we don't have to do that later. + + auto [deps, reverse_deps, roots, all_cnt] = [&]() { + ConcurrentMap> deps_parallel; + ConcurrentMap> reverse_deps_parallel; + ConcurrentSet is_target; + ConcurrentSet maybe_roots; + ConcurrentSet all; + walk::parallel::classes(scope, [&](DexClass* cls) { + // TODO: Consider superclass relationship. + auto clinit = cls->get_clinit(); + bool has_deps = false; + if (clinit != nullptr && clinit->get_code() != nullptr) { + std::vector deps_vec; + editable_cfg_adapter::iterate_with_iterator( + clinit->get_code(), [&](const IRList::iterator& it) { + auto insn = it->insn; + if (opcode::is_an_sget(insn->opcode())) { + auto dependee_cls = type_class(insn->get_field()->get_class()); + if (dependee_cls == nullptr || dependee_cls == cls || + scope_set.count(dependee_cls) == 0) { + return editable_cfg_adapter::LOOP_CONTINUE; + } + reverse_deps_parallel.update( + dependee_cls, + [&](auto&, auto& v, auto) { v.push_back(cls); }); + maybe_roots.insert(dependee_cls); + deps_vec.push_back(dependee_cls); + } + // TODO: Consider static methods. + return editable_cfg_adapter::LOOP_CONTINUE; + }); + if (!deps_vec.empty()) { + has_deps = true; + is_target.insert(cls); + deps_parallel.emplace(cls, std::move(deps_vec)); + } + } + // Something with no deps - make it a root so it gets visited. + if (!has_deps) { + maybe_roots.insert(cls); + } + all.insert(cls); + }); + std::unordered_map> deps; + for (auto& kv : deps_parallel) { + deps[kv.first] = std::move(kv.second); } - if (visiting.count(cls)) { - throw final_inline::class_initialization_cycle(cls); + std::unordered_map> reverse_deps; + for (auto& kv : reverse_deps_parallel) { + reverse_deps[kv.first] = std::move(kv.second); } - visiting.emplace(cls); - auto clinit = cls->get_clinit(); - if (clinit != nullptr && clinit->get_code() != nullptr) { - editable_cfg_adapter::iterate_with_iterator( - clinit->get_code(), [&](const IRList::iterator& it) { - auto insn = it->insn; - if (opcode::is_an_sget(insn->opcode())) { - auto dependee_cls = type_class(insn->get_field()->get_class()); - if (dependee_cls == nullptr || dependee_cls == cls) { - return editable_cfg_adapter::LOOP_CONTINUE; - } - visit(dependee_cls); - } - return editable_cfg_adapter::LOOP_CONTINUE; - }); + + std::vector roots; + std::copy_if(maybe_roots.begin(), maybe_roots.end(), + std::back_inserter(roots), + [&](auto* cls) { return is_target.count_unsafe(cls) == 0; }); + return std::make_tuple(std::move(deps), std::move(reverse_deps), + std::move(roots), all.size()); + }(); + + // NOTE: Using nullptr for root node. + + auto wto = sparta::WeakTopologicalOrdering( + nullptr, + [&roots = roots, &reverse_deps = reverse_deps](DexClass* const& cls) { + if (cls == nullptr) { + return roots; + } + + auto it = reverse_deps.find(cls); + if (it == reverse_deps.end()) { + return std::vector(); + } + + return it->second; + }); + + auto it = wto.begin(); + auto it_end = wto.end(); + + redex_assert(it != it_end); + redex_assert(it->is_vertex()); + redex_assert(it->head_node() == nullptr); + ++it; + + Scope result; + std::unordered_set taken; + + for (; it != it_end; ++it) { + if (it->is_scc()) { + // Cycle... + ++init_cycles; + + TRACE(FINALINLINE, 0, "Init cycle detected in %s", + [&]() { + std::ostringstream oss; + oss << *it; + return oss.str(); + }() + .c_str()); + + continue; } - visiting.erase(cls); + + auto* cls = it->head_node(); + auto deps_it = deps.find(cls); + if (deps_it != deps.end() && + !std::all_of(deps_it->second.begin(), deps_it->second.end(), + [&](auto* cls) { return taken.count(cls) != 0; })) { + TRACE(FINALINLINE, 1, "Skipping %s because of missing deps", SHOW(cls)); + continue; + } + result.emplace_back(cls); - visited.emplace(cls); - }; - for (DexClass* cls : scope) { - visit(cls); + taken.insert(cls); } + return result; } @@ -430,7 +536,8 @@ cp::WholeProgramState analyze_and_simplify_clinits( init_classes_with_side_effects, const XStoreRefs* xstores, const std::unordered_set& blocklist_types, - const std::unordered_set& allowed_opaque_callee_names) { + const std::unordered_set& allowed_opaque_callee_names, + size_t& init_cycles) { const std::unordered_set pure_methods = get_pure_methods(); cp::WholeProgramState wps(blocklist_types); @@ -441,7 +548,7 @@ cp::WholeProgramState analyze_and_simplify_clinits( cp::Transform::RuntimeCache runtime_cache{}; - for (DexClass* cls : reverse_tsort_by_clinit_deps(scope)) { + for (DexClass* cls : reverse_tsort_by_clinit_deps(scope, init_cycles)) { ConstantEnvironment env; cp::set_encoded_values(cls, &env); auto clinit = cls->get_clinit(); @@ -1054,16 +1161,14 @@ FinalInlinePassV2::Stats FinalInlinePassV2::run( const XStoreRefs* xstores, const Config& config, std::optional stores) { - try { - auto wps = final_inline::analyze_and_simplify_clinits( - scope, init_classes_with_side_effects, xstores, config.blocklist_types); - return inline_final_gets(stores, scope, min_sdk, - init_classes_with_side_effects, xstores, wps, - config.blocklist_types, cp::FieldType::STATIC); - } catch (final_inline::class_initialization_cycle& e) { - std::cerr << e.what(); - return {0, 0, 1}; - } + size_t clinit_cycles = 0; + auto wps = final_inline::analyze_and_simplify_clinits( + scope, init_classes_with_side_effects, xstores, config.blocklist_types, + {}, clinit_cycles); + auto res = inline_final_gets(stores, scope, min_sdk, + init_classes_with_side_effects, xstores, wps, + config.blocklist_types, cp::FieldType::STATIC); + return {res.inlined_count, res.init_classes, clinit_cycles}; } FinalInlinePassV2::Stats FinalInlinePassV2::run_inline_ifields( diff --git a/opt/final_inline/FinalInlineV2.h b/opt/final_inline/FinalInlineV2.h index 8489502a03..fb3999f809 100644 --- a/opt/final_inline/FinalInlineV2.h +++ b/opt/final_inline/FinalInlineV2.h @@ -80,25 +80,14 @@ class FinalInlinePassV2 : public Pass { namespace final_inline { -class class_initialization_cycle : public std::exception { - public: - explicit class_initialization_cycle(const DexClass* cls) { - m_msg = "Found a class initialization cycle involving " + show(cls); - } - - const char* what() const noexcept override { return m_msg.c_str(); } - - private: - std::string m_msg; -}; - constant_propagation::WholeProgramState analyze_and_simplify_clinits( const Scope& scope, const init_classes::InitClassesWithSideEffects& init_classes_with_side_effects, const XStoreRefs* xstores, - const std::unordered_set& blocklist_types = {}, - const std::unordered_set& allowed_opaque_callee_names = {}); + const std::unordered_set& blocklist_types, + const std::unordered_set& allowed_opaque_callee_names, + size_t& clinit_cycles); class StaticFieldReadAnalysis { public: