Skip to content

Commit

Permalink
enhance ConstantPropagationPass to support removing dupicated null check
Browse files Browse the repository at this point in the history
Summary: As title

Reviewed By: NTillmann

Differential Revision: D44231589

fbshipit-source-id: a3374aefd3a5cd31ecddab2abb1d32e480cc13a4
  • Loading branch information
beicy authored and facebook-github-bot committed Apr 11, 2023
1 parent 64d77be commit b8834dd
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 4 deletions.
6 changes: 6 additions & 0 deletions libredex/MethodUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,4 +541,10 @@ DexMethod* kotlin_jvm_internal_Intrinsics_checkNotNullExpressionValue() {
"Lkotlin/jvm/internal/Intrinsics;.checkNotNullExpressionValue:(Ljava/"
"lang/Object;Ljava/lang/String;)V"));
}

DexMethod* redex_internal_checkObjectNotNull() {
return static_cast<DexMethod*>(DexMethod::get_method(
"Lredex/$NullCheck;.null_check:(Ljava/lang/Object;)V"));
}

}; // namespace method
2 changes: 2 additions & 0 deletions libredex/MethodUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ DexMethod* kotlin_jvm_internal_Intrinsics_checExpressionValueIsNotNull();

DexMethod* kotlin_jvm_internal_Intrinsics_checkNotNullExpressionValue();

DexMethod* redex_internal_checkObjectNotNull();

inline unsigned count_opcode_of_types(
const cfg::ControlFlowGraph& cfg,
const std::unordered_set<IROpcode>& opcodes) {
Expand Down
9 changes: 8 additions & 1 deletion service/constant-propagation/ConstantPropagationAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,7 @@ FixpointIterator::FixpointIterator(
: BaseEdgeAwareIRAnalyzer(cfg),
m_insn_analyzer(std::move(insn_analyzer)),
m_kotlin_null_check_assertions(get_kotlin_null_assertions()),
m_redex_null_check_assertion(method::redex_internal_checkObjectNotNull()),
m_imprecise_switches(imprecise_switches) {}

void FixpointIterator::analyze_instruction_normal(
Expand All @@ -1250,8 +1251,14 @@ void FixpointIterator::analyze_no_throw(const IRInstruction* insn,
get_null_check_object_index(insn, m_kotlin_null_check_assertions);
}
if (!src_index) {
return;
// Check if it is redex null check.
if (insn->opcode() != OPCODE_INVOKE_STATIC ||
insn->get_method() != m_redex_null_check_assertion) {
return;
}
src_index = 0;
}

if (insn->has_dest()) {
auto dest = insn->dest();
if ((dest == *src_index) ||
Expand Down
2 changes: 2 additions & 0 deletions service/constant-propagation/ConstantPropagationAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class FixpointIterator final
mutable std::unordered_map<cfg::Block*, SwitchSuccs> m_switch_succs;
InstructionAnalyzer<ConstantEnvironment> m_insn_analyzer;
const std::unordered_set<DexMethodRef*>& m_kotlin_null_check_assertions;
const DexMethodRef* m_redex_null_check_assertion;

const bool m_imprecise_switches;

const SwitchSuccs& find_switch_succs(cfg::Block* block) const {
Expand Down
11 changes: 11 additions & 0 deletions service/constant-propagation/ConstantPropagationTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ bool Transform::eliminate_redundant_null_check(
auto* insn = cfg_it->insn;
switch (insn->opcode()) {
case OPCODE_INVOKE_STATIC: {
// Kotlin null check.
if (auto index = get_null_check_object_index(
insn, m_runtime_cache.kotlin_null_check_assertions)) {
++m_stats.null_checks_method_calls;
Expand All @@ -84,6 +85,16 @@ bool Transform::eliminate_redundant_null_check(
return true;
}
}
// Redex null check.
if (insn->get_method() == m_runtime_cache.redex_null_check_assertion) {
++m_stats.null_checks_method_calls;
auto val = env.get(insn->src(0)).maybe_get<SignedConstantDomain>();
if (val && val->interval() == sign_domain::Interval::NEZ) {
m_mutation->remove(cfg_it);
++m_stats.null_checks;
return true;
}
}
break;
}
default:
Expand Down
2 changes: 2 additions & 0 deletions service/constant-propagation/ConstantPropagationTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class Transform final {
struct RuntimeCache {
const std::unordered_set<DexMethodRef*> kotlin_null_check_assertions{
kotlin_nullcheck_wrapper::get_kotlin_null_assertions()};
const DexMethodRef* redex_null_check_assertion{
method::redex_internal_checkObjectNotNull()};
};

struct Stats {
Expand Down
3 changes: 3 additions & 0 deletions test/instr/IntrinsifyNullchecks.config
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"passes" : [
"RemoveUnreachablePass",
"IntrinsifyNullChecksPass",
"ConstantPropagationPass",
"CommonSubexpressionEliminationPass",
"InterproceduralConstantPropagationPass",
"RegAllocPass"
]
},
Expand Down
7 changes: 5 additions & 2 deletions test/instr/IntrinsifyNullchecksVerify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ TEST_F(PreVerify, TestNullCheck) {

auto* meth_init = find_dmethod_named(*test_obj_cls, "<init>");
ASSERT_NE(nullptr, meth_init);
// Before opt, there is a invoke-virtual Object;.getClass();
// Before opt, there two 2 invoke-virtual Object;.getClass();
EXPECT_NE(nullptr,
find_invoke(meth_init, DOPCODE_INVOKE_VIRTUAL, "getClass"));
EXPECT_EQ(2, find_num_invoke(meth_init, DOPCODE_INVOKE_VIRTUAL, "getClass"));
}

TEST_F(PostVerify, TestNullCheck) {
Expand All @@ -34,9 +35,11 @@ TEST_F(PostVerify, TestNullCheck) {

auto* meth_init = find_dmethod_named(*test_obj_cls, "<init>");
ASSERT_NE(nullptr, meth_init);
// After opt, getClass() should be replaced with a null_check.
// After opt, getClass() should be replaced with a null_check. And there is
// only 1 left, since the duplicated one is removed.
EXPECT_EQ(nullptr,
find_invoke(meth_init, DOPCODE_INVOKE_VIRTUAL, "getClass"));
EXPECT_NE(nullptr,
find_invoke(meth_init, DOPCODE_INVOKE_STATIC, "null_check"));
EXPECT_EQ(1, find_num_invoke(meth_init, DOPCODE_INVOKE_STATIC, "null_check"));
}
4 changes: 3 additions & 1 deletion test/instr/NullCheckCoversionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class NullCheckConversionTest {
public NullCheckConversionTest(String ss, SampleObj so, String so2) {
Objects.requireNonNull(ss);
this.ss = ss;
Objects.requireNonNull(this.ss);
Objects.requireNonNull(so, "new bar must not be null");
this.so = so;
this.so2 = Preconditions.checkNotNull(so2);
Expand All @@ -49,7 +50,8 @@ public boolean Test2(String ss) {
@DoNotStrip
@Test
public void Test(String[] args) {
NullCheckConversionTest my = new NullCheckConversionTest(args[2], null, args[1]);
String s1 = "abc";
NullCheckConversionTest my = new NullCheckConversionTest(s1, null, args[1]);
System.out.println(my.Test2(args[1]));
}
}
21 changes: 21 additions & 0 deletions test/instr/VerifyUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,27 @@ DexOpcodeMethod* find_invoke(std::vector<DexInstruction*>::iterator begin,
return it == end ? nullptr : static_cast<DexOpcodeMethod*>(*it);
}

size_t find_num_invoke(const DexMethod* m,
DexOpcode opcode,
const char* target_mname,
DexType* receiver) {
size_t num = 0;
for (const auto& insn : m->get_dex_code()->get_instructions()) {
if (insn->opcode() != opcode) {
continue;
}
auto meth = static_cast<DexOpcodeMethod*>(insn)->get_method();
if (receiver && meth->get_class() != receiver) {
continue;
}
auto mname = static_cast<DexOpcodeMethod*>(insn)->get_method()->get_name();
if (mname == DexString::get_string(target_mname)) {
num++;
}
}
return num;
}

// Given a semicolon delimited list of extracted files from the APK, return a
// map of the original APK's file path to its path on disk.
ResourceFiles decode_resource_paths(const char* location, const char* suffix) {
Expand Down
6 changes: 6 additions & 0 deletions test/instr/VerifyUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ DexOpcodeMethod* find_invoke(std::vector<DexInstruction*>::iterator begin,
DexType* receiver = nullptr);
DexInstruction* find_instruction(DexMethod* m, DexOpcode opcode);

/* Find the number of invoke instructions that calls a particular method name */
size_t find_num_invoke(const DexMethod* m,
DexOpcode opcode,
const char* target_mname,
DexType* receiver = nullptr);

void verify_class_merged(const DexClass* cls, size_t num_dmethods = 0);

// A quick helper to dump CFGs before/after verify
Expand Down

0 comments on commit b8834dd

Please sign in to comment.