diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index c7e25c9f3d2c9..f4d9fd2837388 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -243,6 +243,10 @@ class LoopIdiomRecognize {
   bool recognizeShiftUntilBitTest();
   bool recognizeShiftUntilZero();
 
+  bool recognizeAndInsertCtz();
+  void transformLoopToCtz(BasicBlock *Preheader, Instruction *CntInst,
+                          PHINode *CntPhi, Value *InitX);
+
   /// @}
 };
 } // end anonymous namespace
@@ -1484,7 +1488,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
                     << CurLoop->getHeader()->getName() << "\n");
 
   return recognizePopcount() || recognizeAndInsertFFS() ||
-         recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
+         recognizeShiftUntilBitTest() || recognizeShiftUntilZero() ||
+         recognizeAndInsertCtz();
 }
 
 /// Check if the given conditional branch is based on the comparison between
@@ -2868,3 +2873,212 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
   ++NumShiftUntilZero;
   return MadeChange;
 }
+
+/// Detect a single-block loop that counts the trailing zero bits of a value by
+/// shifting it right by one bit per iteration until bit 1 of the current value
+/// is set:
+///
+/// loop:
+///   %count.010 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
+///   %n.addr.09 = phi i32 [ %shr, %while.body ], [ %n, %while.body.preheader ]
+///   %add = add nuw nsw i32 %count.010, 1
+///   %shr = ashr exact i32 %n.addr.09, 1
+///   %0 = and i32 %n.addr.09, 2
+///   %cmp1 = icmp eq i32 %0, 0
+///   br i1 %cmp1, label %while.body, label %if.end.loopexit
+///
+/// On success \p InitX is the value entering the loop from the preheader,
+/// \p CntPhi is the counter recurrence and \p CntInst its "cnt + 1" update.
+static bool detectShiftUntilZeroAndOneIdiom(Loop *CurLoop, Value *&InitX,
+                                            Instruction *&CntInst,
+                                            PHINode *&CntPhi) {
+  CntInst = nullptr;
+  CntPhi = nullptr;
+  BasicBlock *LoopEntry = *(CurLoop->block_begin());
+
+  // The initial value is read from the preheader edge, so one must exist.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  if (!Preheader)
+    return false;
+
+  // Check if the loop-back branch is in desirable form:
+  //   "if (x == 0) goto loop-entry"
+  Value *Cond = matchCondition(
+      dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry, true);
+  if (!Cond) {
+    LLVM_DEBUG(dbgs() << "Bad condition for branch instruction\n");
+    return false;
+  }
+
+  // The operand is compared with 2 because we are looking for "x & 2", which
+  // was optimized by previous passes from "(x >> 1) & 1". Match on the Value
+  // itself: the tested operand is not guaranteed to be an Instruction.
+  Value *VarX = nullptr;
+  if (!match(Cond, m_c_And(PatternMatch::m_Value(VarX),
+                           PatternMatch::m_SpecificInt(2))))
+    return false;
+
+  // The masked value must be a phi node in the loop header.
+  auto *PhiX = dyn_cast<PHINode>(VarX);
+  if (!PhiX || PhiX->getParent() != LoopEntry)
+    return false;
+
+  // The value carried around the backedge must be a right shift. Only the
+  // backedge operand is inspected, so an initial value that happens to be a
+  // shift is never mistaken for the recurrence step.
+  auto *DefXRShift =
+      dyn_cast<Instruction>(PhiX->getIncomingValueForBlock(LoopEntry));
+  if (!DefXRShift || (DefXRShift->getOpcode() != Instruction::AShr &&
+                      DefXRShift->getOpcode() != Instruction::LShr))
+    return false;
+
+  // Check that the shift instruction is "x >> 1" applied to the phi itself.
+  auto *Shft = dyn_cast<ConstantInt>(DefXRShift->getOperand(1));
+  if (!Shft || !Shft->isOne())
+    return false;
+
+  if (DefXRShift->getOperand(0) != VarX)
+    return false;
+
+  InitX = PhiX->getIncomingValueForBlock(Preheader);
+
+  // Find the instruction which counts the trailing zeros: cnt.next = cnt + 1.
+  for (Instruction &Inst : llvm::make_range(
+           LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) {
+    if (Inst.getOpcode() != Instruction::Add)
+      continue;
+
+    ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1));
+    if (!Inc || !Inc->isOne())
+      continue;
+
+    PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry);
+    if (!Phi)
+      continue;
+
+    CntInst = &Inst;
+    CntPhi = Phi;
+    break;
+  }
+
+  return CntInst != nullptr;
+}
+
+/// Recognize the CTTZ idiom in a non-countable loop and replace the uses of
+/// the loop counter outside the loop with a CTTZ of the initial value. The
+/// loop itself is left in place. If CTTZ was inserted, returns true;
+/// otherwise, returns false.
+///
+// int count_trailing_zeroes(uint32_t n) {
+//   int count = 0;
+//   if (n == 0){
+//     return 32;
+//   }
+//   while ((n & 1) == 0) {
+//     count += 1;
+//     n >>= 1;
+//   }
+//
+//   return count;
+// }
+bool LoopIdiomRecognize::recognizeAndInsertCtz() {
+  // Give up if the loop has multiple blocks or multiple backedges.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+    return false;
+
+  Value *InitX = nullptr;
+  PHINode *CntPhi = nullptr;
+  Instruction *CntInst = nullptr;
+  // The canonical idiom consists of exactly 7 instructions (two phis, add,
+  // shift, and, icmp, br); for a loop of that size the transformation is
+  // always profitable.
+  const size_t IdiomCanonicalSize = 7;
+
+  if (!detectShiftUntilZeroAndOneIdiom(CurLoop, InitX, CntInst, CntPhi))
+    return false;
+
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+
+  auto *PreCondBB = PH->getSinglePredecessor();
+  if (!PreCondBB)
+    return false;
+  auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
+  if (!PreCondBI)
+    return false;
+
+  // Check that the initial value is not zero and that "(init & 1) == 0".
+  // The initial value must be known non-zero: with a zero input the original
+  // loop never terminates, whereas the CTTZ emitted here (zero is poison)
+  // would not preserve that behavior.
+  //
+  // Match this shape, where %n is the initial value:
+  // entry:
+  //   %cmp.not = icmp eq i32 %n, 0
+  //   br i1 %cmp.not, label %cleanup, label %while.cond.preheader
+  //
+  // while.cond.preheader:
+  //   %and5 = and i32 %n, 1
+  //   %cmp16 = icmp eq i32 %and5, 0
+  //   br i1 %cmp16, label %while.body.preheader, label %cleanup
+
+  Value *PreCond = matchCondition(PreCondBI, PH, true);
+  if (!PreCond)
+    return false;
+
+  Value *InitPredX = nullptr;
+  if (!match(PreCond, m_c_And(PatternMatch::m_Value(InitPredX),
+                              PatternMatch::m_One())) ||
+      InitPredX != InitX)
+    return false;
+  auto *PrePreCondBB = PreCondBB->getSinglePredecessor();
+  if (!PrePreCondBB)
+    return false;
+  auto *PrePreCondBI = dyn_cast<BranchInst>(PrePreCondBB->getTerminator());
+  if (!PrePreCondBI)
+    return false;
+  if (matchCondition(PrePreCondBI, PreCondBB) != InitX)
+    return false;
+
+  // Only transform the canonical 7-instruction loop.
+  // @llvm.dbg doesn't count as they have no semantic effect.
+  auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
+  uint32_t HeaderSize =
+      std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
+  if (HeaderSize != IdiomCanonicalSize)
+    return false;
+
+  transformLoopToCtz(PH, CntInst, CntPhi, InitX);
+  return true;
+}
+
+/// Insert a CTTZ of \p InitX in \p Preheader and redirect every use of
+/// \p CntInst outside the loop to it, offset by the initial value that
+/// \p CntPhi receives from \p Preheader.
+void LoopIdiomRecognize::transformLoopToCtz(BasicBlock *Preheader,
+                                            Instruction *CntInst,
+                                            PHINode *CntPhi, Value *InitX) {
+  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+  const DebugLoc &DL = CntInst->getDebugLoc();
+
+  // Insert the CTTZ instruction at the end of the preheader block.
+  IRBuilder<> Builder(PreheaderBr);
+  Builder.SetCurrentDebugLocation(DL);
+  Value *Count = createFFSIntrinsic(Builder, InitX, DL,
+                                    /* is zero poison */ true, Intrinsic::cttz);
+
+  // The counter may be narrower or wider than the shifted value.
+  Value *NewCount = Builder.CreateZExtOrTrunc(Count, CntInst->getType());
+
+  // Add the counter's initial value, unless it is the constant zero.
+  Value *CntInitVal = CntPhi->getIncomingValueForBlock(Preheader);
+  ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal);
+  if (!InitConst || !InitConst->isZero())
+    NewCount = Builder.CreateAdd(NewCount, CntInitVal);
+
+  BasicBlock *Body = *(CurLoop->block_begin());
+
+  // All the references to the original counter outside
+  // the loop are replaced with the NewCount.
+  CntInst->replaceUsesOutsideBlock(NewCount, Body);
+  SE->forgetLoop(CurLoop);
+}
diff --git a/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll b/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
new file mode 100644
index 0000000000000..5c32d49782934
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
@@ -0,0 +1,147 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=loop-idiom -mtriple=riscv32 -S < %s | FileCheck %s
+; RUN: opt -passes=loop-idiom -mtriple=riscv64 -S < %s | FileCheck %s
+
+; Copied from popcnt test.
+
+;To recognize this pattern:
+;int ctz(uint32_t n)
+;{
+;    int count = 0;
+;    if (n == 0)
+;    {
+;        return 32;
+;    }
+;    while ((n & 1) == 0)
+;    {
+;        count += 1;
+;        n >>= 1;
+;    }
+;    return count;
+;}
+
+define signext i32 @count_trailing_zeroes(i32 noundef signext %n) local_unnamed_addr #0 {
+; CHECK-LABEL: @count_trailing_zeroes(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
+; CHECK:       while.cond.preheader:
+; CHECK-NEXT:    [[AND4:%.*]] = and i32 [[N]], 1
+; CHECK-NEXT:    [[CMP15:%.*]] = icmp eq i32 [[AND4]], 0
+; CHECK-NEXT:    br i1 [[CMP15]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
+; CHECK:       while.body.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.cttz.i32(i32 [[N]], i1 true)
+; CHECK-NEXT:    br label [[WHILE_BODY:%.*]]
+; CHECK:       while.body:
+; CHECK-NEXT:    [[COUNT_07:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[N_ADDR_06:%.*]] = phi i32 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ADD]] = add nuw nsw i32 [[COUNT_07]], 1
+; CHECK-NEXT:    [[SHR]] = lshr i32 [[N_ADDR_06]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[N_ADDR_06]], 2
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[TMP1]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
+; CHECK:       cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP0]], [[WHILE_BODY]] ]
+; CHECK-NEXT:    br label [[CLEANUP]]
+; CHECK:       cleanup:
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ 32, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT:    ret i32 [[RETVAL_0]]
+;
+entry:
+  %cmp = icmp eq i32 %n, 0
+  br i1 %cmp, label %cleanup, label %while.cond.preheader
+
+while.cond.preheader:                             ; preds = %entry
+  %and4 = and i32 %n, 1
+  %cmp15 = icmp eq i32 %and4, 0
+  br i1 %cmp15, label %while.body, label %cleanup
+
+while.body:                                       ; preds = %while.cond.preheader, %while.body
+  %count.07 = phi i32 [ %add, %while.body ], [ 0, %while.cond.preheader ]
+  %n.addr.06 = phi i32 [ %shr, %while.body ], [ %n, %while.cond.preheader ]
+  %add = add nuw nsw i32 %count.07, 1
+  %shr = lshr i32 %n.addr.06, 1
+  %0 = and i32 %n.addr.06, 2
+  %cmp1 = icmp eq i32 %0, 0
+  br i1 %cmp1, label %while.body, label %cleanup
+
+cleanup:                                          ; preds = %while.body, %while.cond.preheader, %entry
+  %retval.0 = phi i32 [ 32, %entry ], [ 0, %while.cond.preheader ], [ %add, %while.body ]
+  ret i32 %retval.0
+}
+
+;int ctz(uint64_t n)
+;{
+;    int count = 0;
+;    if (n != 0)
+;    {
+;        while ((n & 1) == 0)
+;        {
+;            n >>= 1;
+;            count += 1;
+;        }
+;    }
+;    else
+;    {
+;        return 64;
+;    }
+;    return count;
+;}
+
+define dso_local signext i32 @ctz(i64 noundef %n) local_unnamed_addr {
+; CHECK-LABEL: @ctz(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i64 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP_NOT]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
+; CHECK:       while.cond.preheader:
+; CHECK-NEXT:    [[AND5:%.*]] = and i64 [[N]], 1
+; CHECK-NEXT:    [[CMP16:%.*]] = icmp eq i64 [[AND5]], 0
+; CHECK-NEXT:    br i1 [[CMP16]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
+; CHECK:       while.body.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.cttz.i64(i64 [[N]], i1 true)
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    br label [[WHILE_BODY:%.*]]
+; CHECK:       while.body:
+; CHECK-NEXT:    [[COUNT_08:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[N_ADDR_07:%.*]] = phi i64 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[SHR]] = lshr i64 [[N_ADDR_07]], 1
+; CHECK-NEXT:    [[ADD]] = add nuw nsw i32 [[COUNT_08]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i64 [[N_ADDR_07]], 2
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[TMP2]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
+; CHECK:       cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP1]], [[WHILE_BODY]] ]
+; CHECK-NEXT:    br label [[CLEANUP]]
+; CHECK:       cleanup:
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ 64, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT:    ret i32 [[RETVAL_0]]
+;
+entry:
+  %cmp.not = icmp eq i64 %n, 0
+  br i1 %cmp.not, label %cleanup, label %while.cond.preheader
+
+while.cond.preheader:                             ; preds = %entry
+  %and5 = and i64 %n, 1
+  %cmp16 = icmp eq i64 %and5, 0
+  br i1 %cmp16, label %while.body.preheader, label %cleanup
+
+while.body.preheader:                             ; preds = %while.cond.preheader
+  br label %while.body
+
+while.body:                                       ; preds = %while.body.preheader, %while.body
+  %count.08 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
+  %n.addr.07 = phi i64 [ %shr, %while.body ], [ %n, %while.body.preheader ]
+  %shr = lshr i64 %n.addr.07, 1
+  %add = add nuw nsw i32 %count.08, 1
+  %0 = and i64 %n.addr.07, 2
+  %cmp1 = icmp eq i64 %0, 0
+  br i1 %cmp1, label %while.body, label %cleanup.loopexit
+
+cleanup.loopexit:                                 ; preds = %while.body
+  %add.lcssa = phi i32 [ %add, %while.body ]
+  br label %cleanup
+
+cleanup:                                          ; preds = %cleanup.loopexit, %while.cond.preheader, %entry
+  %retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %add.lcssa, %cleanup.loopexit ]
+  ret i32 %retval.0
+}