Skip to content

Commit

Permalink
Use WTO for clinit analysis
Browse files Browse the repository at this point in the history
Summary:
As title.

Collect dependencies and then run WeakTopologicalOrdering to find components and ordering. All SCCs are filtered out (and warned on). Dependendants of SCCs are also not included.

This should deal better with programs with clinit cycles, as the pass will not totally self-disable.

Reviewed By: thezhangwei

Differential Revision: D45259803

fbshipit-source-id: fcf22856d6966f6cee3923d0f3bdb63e95d76dae
  • Loading branch information
agampe authored and facebook-github-bot committed Apr 25, 2023
1 parent dff2f1f commit f471873
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 55 deletions.
187 changes: 146 additions & 41 deletions opt/final_inline/FinalInlineV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

#include <boost/variant.hpp>
#include <iostream>
#include <sstream>
#include <unordered_set>
#include <vector>

#include "CFGMutation.h"
#include "ConcurrentContainers.h"
#include "ConfigFiles.h"
#include "Debug.h"
#include "DexAccess.h"
Expand All @@ -31,6 +33,7 @@
#include "TypeSystem.h"
#include "TypeUtil.h"
#include "Walkers.h"
#include "WeakTopologicalOrdering.h"

/*
* dx-generated class initializers often use verbose bytecode sequences to
Expand All @@ -50,6 +53,20 @@ using namespace sparta;

namespace {

std::ostream& operator<<(std::ostream& o,
const sparta::WtoComponent<DexClass*>& c) {
if (c.is_scc()) {
o << "(" << show(c.head_node());
for (const auto& sub : c) {
o << " " << sub;
}
o << ")";
} else {
o << show(c.head_node());
}
return o;
}

/*
* Foo.<clinit> may read some static fields from class Bar, in which case
* Bar.<clinit> will be executed first by the VM to determine the values of
Expand All @@ -64,41 +81,130 @@ namespace {
* (JLS SE7 12.4.1 indicates that cycles are indeed allowed.) In that case,
* this pass cannot safely optimize the static final constants.
*/
Scope reverse_tsort_by_clinit_deps(const Scope& scope) {
Scope reverse_tsort_by_clinit_deps(const Scope& scope, size_t& init_cycles) {
Timer timer{"reverse_tsort_by_clinit_deps"};

std::unordered_set<const DexClass*> scope_set(scope.begin(), scope.end());
Scope result;
std::unordered_set<const DexClass*> visiting;
std::unordered_set<const DexClass*> visited;
std::function<void(DexClass*)> visit = [&](DexClass* cls) {
if (visited.count(cls) != 0 || scope_set.count(cls) == 0) {
return;

// Collect data for WTO.
// NOTE: Doing this already also as reverse so we don't have to do that later.

auto [deps, reverse_deps, roots, all_cnt] = [&]() {
ConcurrentMap<DexClass*, std::vector<DexClass*>> deps_parallel;
ConcurrentMap<DexClass*, std::vector<DexClass*>> reverse_deps_parallel;
ConcurrentSet<DexClass*> is_target;
ConcurrentSet<DexClass*> maybe_roots;
ConcurrentSet<DexClass*> all;
walk::parallel::classes(scope, [&](DexClass* cls) {
// TODO: Consider superclass relationship.
auto clinit = cls->get_clinit();
bool has_deps = false;
if (clinit != nullptr && clinit->get_code() != nullptr) {
std::vector<DexClass*> deps_vec;
editable_cfg_adapter::iterate_with_iterator(
clinit->get_code(), [&](const IRList::iterator& it) {
auto insn = it->insn;
if (opcode::is_an_sget(insn->opcode())) {
auto dependee_cls = type_class(insn->get_field()->get_class());
if (dependee_cls == nullptr || dependee_cls == cls ||
scope_set.count(dependee_cls) == 0) {
return editable_cfg_adapter::LOOP_CONTINUE;
}
reverse_deps_parallel.update(
dependee_cls,
[&](auto&, auto& v, auto) { v.push_back(cls); });
maybe_roots.insert(dependee_cls);
deps_vec.push_back(dependee_cls);
}
// TODO: Consider static methods.
return editable_cfg_adapter::LOOP_CONTINUE;
});
if (!deps_vec.empty()) {
has_deps = true;
is_target.insert(cls);
deps_parallel.emplace(cls, std::move(deps_vec));
}
}
// Something with no deps - make it a root so it gets visited.
if (!has_deps) {
maybe_roots.insert(cls);
}
all.insert(cls);
});
std::unordered_map<DexClass*, std::vector<DexClass*>> deps;
for (auto& kv : deps_parallel) {
deps[kv.first] = std::move(kv.second);
}
if (visiting.count(cls)) {
throw final_inline::class_initialization_cycle(cls);
std::unordered_map<DexClass*, std::vector<DexClass*>> reverse_deps;
for (auto& kv : reverse_deps_parallel) {
reverse_deps[kv.first] = std::move(kv.second);
}
visiting.emplace(cls);
auto clinit = cls->get_clinit();
if (clinit != nullptr && clinit->get_code() != nullptr) {
editable_cfg_adapter::iterate_with_iterator(
clinit->get_code(), [&](const IRList::iterator& it) {
auto insn = it->insn;
if (opcode::is_an_sget(insn->opcode())) {
auto dependee_cls = type_class(insn->get_field()->get_class());
if (dependee_cls == nullptr || dependee_cls == cls) {
return editable_cfg_adapter::LOOP_CONTINUE;
}
visit(dependee_cls);
}
return editable_cfg_adapter::LOOP_CONTINUE;
});

std::vector<DexClass*> roots;
std::copy_if(maybe_roots.begin(), maybe_roots.end(),
std::back_inserter(roots),
[&](auto* cls) { return is_target.count_unsafe(cls) == 0; });
return std::make_tuple(std::move(deps), std::move(reverse_deps),
std::move(roots), all.size());
}();

// NOTE: Using nullptr for root node.

auto wto = sparta::WeakTopologicalOrdering<DexClass*>(
nullptr,
[&roots = roots, &reverse_deps = reverse_deps](DexClass* const& cls) {
if (cls == nullptr) {
return roots;
}

auto it = reverse_deps.find(cls);
if (it == reverse_deps.end()) {
return std::vector<DexClass*>();
}

return it->second;
});

auto it = wto.begin();
auto it_end = wto.end();

redex_assert(it != it_end);
redex_assert(it->is_vertex());
redex_assert(it->head_node() == nullptr);
++it;

Scope result;
std::unordered_set<DexClass*> taken;

for (; it != it_end; ++it) {
if (it->is_scc()) {
// Cycle...
++init_cycles;

TRACE(FINALINLINE, 0, "Init cycle detected in %s",
[&]() {
std::ostringstream oss;
oss << *it;
return oss.str();
}()
.c_str());

continue;
}
visiting.erase(cls);

auto* cls = it->head_node();
auto deps_it = deps.find(cls);
if (deps_it != deps.end() &&
!std::all_of(deps_it->second.begin(), deps_it->second.end(),
[&](auto* cls) { return taken.count(cls) != 0; })) {
TRACE(FINALINLINE, 1, "Skipping %s because of missing deps", SHOW(cls));
continue;
}

result.emplace_back(cls);
visited.emplace(cls);
};
for (DexClass* cls : scope) {
visit(cls);
taken.insert(cls);
}

return result;
}

Expand Down Expand Up @@ -430,7 +536,8 @@ cp::WholeProgramState analyze_and_simplify_clinits(
init_classes_with_side_effects,
const XStoreRefs* xstores,
const std::unordered_set<const DexType*>& blocklist_types,
const std::unordered_set<std::string>& allowed_opaque_callee_names) {
const std::unordered_set<std::string>& allowed_opaque_callee_names,
size_t& init_cycles) {
const std::unordered_set<DexMethodRef*> pure_methods = get_pure_methods();
cp::WholeProgramState wps(blocklist_types);

Expand All @@ -441,7 +548,7 @@ cp::WholeProgramState analyze_and_simplify_clinits(

cp::Transform::RuntimeCache runtime_cache{};

for (DexClass* cls : reverse_tsort_by_clinit_deps(scope)) {
for (DexClass* cls : reverse_tsort_by_clinit_deps(scope, init_cycles)) {
ConstantEnvironment env;
cp::set_encoded_values(cls, &env);
auto clinit = cls->get_clinit();
Expand Down Expand Up @@ -1054,16 +1161,14 @@ FinalInlinePassV2::Stats FinalInlinePassV2::run(
const XStoreRefs* xstores,
const Config& config,
std::optional<DexStoresVector*> stores) {
try {
auto wps = final_inline::analyze_and_simplify_clinits(
scope, init_classes_with_side_effects, xstores, config.blocklist_types);
return inline_final_gets(stores, scope, min_sdk,
init_classes_with_side_effects, xstores, wps,
config.blocklist_types, cp::FieldType::STATIC);
} catch (final_inline::class_initialization_cycle& e) {
std::cerr << e.what();
return {0, 0, 1};
}
size_t clinit_cycles = 0;
auto wps = final_inline::analyze_and_simplify_clinits(
scope, init_classes_with_side_effects, xstores, config.blocklist_types,
{}, clinit_cycles);
auto res = inline_final_gets(stores, scope, min_sdk,
init_classes_with_side_effects, xstores, wps,
config.blocklist_types, cp::FieldType::STATIC);
return {res.inlined_count, res.init_classes, clinit_cycles};
}

FinalInlinePassV2::Stats FinalInlinePassV2::run_inline_ifields(
Expand Down
17 changes: 3 additions & 14 deletions opt/final_inline/FinalInlineV2.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,14 @@ class FinalInlinePassV2 : public Pass {

namespace final_inline {

class class_initialization_cycle : public std::exception {
public:
explicit class_initialization_cycle(const DexClass* cls) {
m_msg = "Found a class initialization cycle involving " + show(cls);
}

const char* what() const noexcept override { return m_msg.c_str(); }

private:
std::string m_msg;
};

constant_propagation::WholeProgramState analyze_and_simplify_clinits(
const Scope& scope,
const init_classes::InitClassesWithSideEffects&
init_classes_with_side_effects,
const XStoreRefs* xstores,
const std::unordered_set<const DexType*>& blocklist_types = {},
const std::unordered_set<std::string>& allowed_opaque_callee_names = {});
const std::unordered_set<const DexType*>& blocklist_types,
const std::unordered_set<std::string>& allowed_opaque_callee_names,
size_t& clinit_cycles);

class StaticFieldReadAnalysis {
public:
Expand Down

0 comments on commit f471873

Please sign in to comment.