[AggressiveInstCombine] convert a chain of 'or-shift' bits into masked compare

and (or (lshr X, C), ...), 1 --> (X & C') != 0
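
As a quick sanity check, here is a standalone brute-force sketch of the
equivalence (illustrative only, not part of the patch; the 0x129 mask matches
the @anyset_four_bit_mask test below):

  #include <cassert>
  #include <cstdint>

  // ((X | X>>3 | X>>5 | X>>8) & 1) is 1 exactly when any of bits 0, 3, 5, 8
  // of X is set, i.e. when (X & 0x129) != 0, since
  // 0x129 == (1 << 0) | (1 << 3) | (1 << 5) | (1 << 8).
  int main() {
    for (uint32_t X = 0; X <= 0xFFFF; ++X) {
      uint32_t Chain = (X | (X >> 3) | (X >> 5) | (X >> 8)) & 1;
      uint32_t Masked = (X & 0x129u) != 0 ? 1u : 0u;
      assert(Chain == Masked);
    }
    return 0;
  }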

I initially thought about implementing the minimal pattern in instcombine as mentioned here:
https://bugs.llvm.org/show_bug.cgi?id=37098#c6

...but we need to do better to catch the more general sequence from the motivating test 
(more than 2 bits in the compare). A test-suite run with statistics showed that this 
pattern currently occurs only 2 times. It would potentially occur more often if 
reassociation worked better (D45842), but it's probably still not too frequent.
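
For reference, the motivating sequence typically comes from bitfield tests like
the following sketch (my guess at the shape of the source behind the
PhaseOrdering test; the struct and member names are invented):

  // After SROA and early instcombine, the four 1-bit fields become
  // lshr/or/and on one i32, which is the chain folded here to (x & 15) != 0.
  struct Flags {
    unsigned A : 1;
    unsigned B : 1;
    unsigned C : 1;
    unsigned D : 1;
  };

  int anyset(Flags F) { return F.A | F.B | F.C | F.D; }
  int allclear(Flags F) { return !(F.A | F.B | F.C | F.D); }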

This is small enough that I didn't see a need to create a whole new class/file within 
AggressiveInstCombine. There are likely other relatively small matchers like what was 
discussed in D44266 that would slide under foldUnusualPatterns() (name suggestions welcome). 
We could potentially also consolidate matchers for ctpop, bswap, etc under here.

Differential Revision: https://reviews.llvm.org/D45986

llvm-svn: 331311
rotateright committed May 1, 2018
1 parent 52fd169 commit d2025a2
Showing 3 changed files with 114 additions and 57 deletions.
llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (94 additions & 21 deletions)
@@ -19,11 +19,15 @@
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/Utils/Local.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "aggressive-instcombine"

@@ -53,6 +57,91 @@ class AggressiveInstCombinerLegacyPass : public FunctionPass {
};
} // namespace

/// This is a recursive helper for 'and X, 1' that walks through a chain of 'or'
/// instructions looking for shift ops of a common source value (first member of
/// the pair). The second member of the pair is a mask constant for all of the
/// bits that are being compared. So this:
/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
/// returns {X, 0x129} and those are the operands of an 'and' that is compared
/// to zero.
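/// (Here 0x129 == (1 << 0) | (1 << 3) | (1 << 5) | (1 << 8), i.e. the mask
/// collects the tested bit positions 0, 3, 5, and 8.)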
static bool matchMaskedCmpOp(Value *V, std::pair<Value *, APInt> &Result) {
  // Recurse through a chain of 'or' operands.
  Value *Op0, *Op1;
  if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
    return matchMaskedCmpOp(Op0, Result) && matchMaskedCmpOp(Op1, Result);

  // We need a shift-right or a bare value representing a compare of bit 0 of
  // the original source operand.
  Value *Candidate;
  uint64_t BitIndex = 0;
  if (!match(V, m_LShr(m_Value(Candidate), m_ConstantInt(BitIndex))))
    Candidate = V;

  // Initialize result source operand.
  if (!Result.first)
    Result.first = Candidate;

  // Fill in the mask bit derived from the shift constant. A shift amount that
  // equals or exceeds the bit width would be poison, so bail out rather than
  // set an out-of-range bit.
  if (BitIndex >= Result.second.getBitWidth())
    return false;
  Result.second.setBit(BitIndex);
  return Result.first == Candidate;
}

/// Match an 'and' of a chain of or-shifted bits from a common source value into
/// a masked compare:
///   and (or (lshr X, C), ...), 1 --> (X & C') != 0
static bool foldToMaskedCmp(Instruction &I) {
  // TODO: This is only looking for 'any-bits-set' and 'all-bits-clear'.
  // We should also match 'all-bits-set' and 'any-bits-clear' by looking for a
  // chain of 'and'.
  if (!match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
    return false;

  std::pair<Value *, APInt> MaskOps(
      nullptr, APInt::getNullValue(I.getType()->getScalarSizeInBits()));
  if (!matchMaskedCmpOp(cast<BinaryOperator>(&I)->getOperand(0), MaskOps))
    return false;

  // Replace the 'and' with: zext((X & Mask) != 0) in the original type.
  IRBuilder<> Builder(&I);
  Value *Mask = Builder.CreateAnd(MaskOps.first, MaskOps.second);
  Value *CmpZero = Builder.CreateIsNotNull(Mask);
  Value *Zext = Builder.CreateZExt(CmpZero, I.getType());
  I.replaceAllUsesWith(Zext);
  return true;
}

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions in this loop; that would invalidate the
    // iterator.
    for (Instruction &I : BB)
      MadeChange |= foldToMaskedCmp(I);
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
  bool MadeChange = false;
  const DataLayout &DL = F.getParent()->getDataLayout();
  TruncInstCombine TIC(TLI, DL, DT);
  MadeChange |= TIC.run(F);
  MadeChange |= foldUnusualPatterns(F, DT);
  return MadeChange;
}

void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  AU.setPreservesCFG();
@@ -65,35 +154,19 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
}

bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
-  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
-  auto &DL = F.getParent()->getDataLayout();
-
-  bool MadeIRChange = false;
-
-  // Handle TruncInst patterns
-  TruncInstCombine TIC(TLI, DL, DT);
-  MadeIRChange |= TIC.run(F);
-
-  // TODO: add more patterns to handle...
-
-  return MadeIRChange;
+  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  return runImpl(F, TLI, DT);
}

PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
-  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
-  auto &DL = F.getParent()->getDataLayout();
-  bool MadeIRChange = false;
-
-  // Handle TruncInst patterns
-  TruncInstCombine TIC(TLI, DL, DT);
-  MadeIRChange |= TIC.run(F);
-  if (!MadeIRChange)
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  if (!runImpl(F, TLI, DT)) {
     // No changes, all analyses are preserved.
     return PreservedAnalyses::all();
-
+  }
   // Mark all the analyses that instcombine updates as preserved.
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
llvm/test/Transforms/AggressiveInstCombine/masked-cmp.ll (12 additions & 19 deletions)
@@ -5,10 +5,10 @@

define i32 @anyset_two_bit_mask(i32 %x) {
; CHECK-LABEL: @anyset_two_bit_mask(
-; CHECK-NEXT:    [[S:%.*]] = lshr i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[O:%.*]] = or i32 [[S]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[O]], 1
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[X:%.*]], 9
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
;
%s = lshr i32 %x, 3
%o = or i32 %s, %x
@@ -18,14 +18,10 @@ define i32 @anyset_two_bit_mask(i32 %x) {

define i32 @anyset_four_bit_mask(i32 %x) {
; CHECK-LABEL: @anyset_four_bit_mask(
-; CHECK-NEXT:    [[T1:%.*]] = lshr i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[T2:%.*]] = lshr i32 [[X]], 5
-; CHECK-NEXT:    [[T3:%.*]] = lshr i32 [[X]], 8
-; CHECK-NEXT:    [[O1:%.*]] = or i32 [[T1]], [[X]]
-; CHECK-NEXT:    [[O2:%.*]] = or i32 [[T2]], [[T3]]
-; CHECK-NEXT:    [[O3:%.*]] = or i32 [[O1]], [[O2]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[O3]], 1
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[X:%.*]], 297
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
;
%t1 = lshr i32 %x, 3
%t2 = lshr i32 %x, 5
@@ -41,13 +37,10 @@ define i32 @anyset_four_bit_mask(i32 %x) {

define i32 @anyset_three_bit_mask_all_shifted_bits(i32 %x) {
; CHECK-LABEL: @anyset_three_bit_mask_all_shifted_bits(
-; CHECK-NEXT:    [[T1:%.*]] = lshr i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[T2:%.*]] = lshr i32 [[X]], 5
-; CHECK-NEXT:    [[T3:%.*]] = lshr i32 [[X]], 8
-; CHECK-NEXT:    [[O2:%.*]] = or i32 [[T2]], [[T3]]
-; CHECK-NEXT:    [[O3:%.*]] = or i32 [[T1]], [[O2]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[O3]], 1
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[X:%.*]], 296
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
;
%t1 = lshr i32 %x, 3
%t2 = lshr i32 %x, 5
llvm/test/Transforms/PhaseOrdering/bitfield-bittests.ll (8 additions & 17 deletions)
@@ -18,15 +18,10 @@ target datalayout = "n32"

define i32 @allclear(i32 %a) {
; CHECK-LABEL: @allclear(
-; CHECK-NEXT:    [[BF_LSHR:%.*]] = lshr i32 [[A:%.*]], 1
-; CHECK-NEXT:    [[BF_CLEAR1:%.*]] = or i32 [[BF_LSHR]], [[A]]
-; CHECK-NEXT:    [[BF_LSHR5:%.*]] = lshr i32 [[A]], 2
-; CHECK-NEXT:    [[OR2:%.*]] = or i32 [[BF_CLEAR1]], [[BF_LSHR5]]
-; CHECK-NEXT:    [[BF_LSHR10:%.*]] = lshr i32 [[A]], 3
-; CHECK-NEXT:    [[OR83:%.*]] = or i32 [[OR2]], [[BF_LSHR10]]
-; CHECK-NEXT:    [[OR13:%.*]] = and i32 [[OR83]], 1
-; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[OR13]], 1
-; CHECK-NEXT:    ret i32 [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
;
%a.sroa.0.0.trunc = trunc i32 %a to i8
%a.sroa.5.0.shift = lshr i32 %a, 8
@@ -51,14 +46,10 @@ define i32 @allclear(i32 %a) {

define i32 @anyset(i32 %a) {
; CHECK-LABEL: @anyset(
-; CHECK-NEXT:    [[BF_LSHR:%.*]] = lshr i32 [[A:%.*]], 1
-; CHECK-NEXT:    [[BF_CLEAR1:%.*]] = or i32 [[BF_LSHR]], [[A]]
-; CHECK-NEXT:    [[BF_LSHR5:%.*]] = lshr i32 [[A]], 2
-; CHECK-NEXT:    [[OR2:%.*]] = or i32 [[BF_CLEAR1]], [[BF_LSHR5]]
-; CHECK-NEXT:    [[BF_LSHR10:%.*]] = lshr i32 [[A]], 3
-; CHECK-NEXT:    [[OR83:%.*]] = or i32 [[OR2]], [[BF_LSHR10]]
-; CHECK-NEXT:    [[OR13:%.*]] = and i32 [[OR83]], 1
-; CHECK-NEXT:    ret i32 [[OR13]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
;
%a.sroa.0.0.trunc = trunc i32 %a to i8
%a.sroa.5.0.shift = lshr i32 %a, 8
