From 012173680f368bff9b4e3db21e1381360422cdc6 Mon Sep 17 00:00:00 2001
From: eopXD
Date: Sat, 14 Aug 2021 15:58:05 +0800
Subject: [PATCH] [LoopIdiom] let the pass deal with runtime memset size

The current LIR does not deal with runtime-determined memset sizes. This
patch utilizes SCEV to check whether the PointerStrideSCEV and the
MemsetSizeSCEV are equal, and promotes the memset when they are. Folding
expressions that are already protected by the loop guard before the
comparison is left as a TODO.

Testcase files `memset-runtime.ll` and `memset-runtime-debug.ll` are added.

This patch handles the proper loop idiom; a follow-up patch will deal with
SCEVs that are unequal until folded with the loop guards.

Reviewed By: lebedev.ri, Whitney

Differential Revision: https://reviews.llvm.org/D107353
---
 .../Transforms/Scalar/LoopIdiomRecognize.cpp  |  96 +++++--
 .../LoopIdiom/memset-runtime-debug.ll         | 270 ++++++++++++++++++
 .../Transforms/LoopIdiom/memset-runtime.ll    | 110 +++++++
 3 files changed, 445 insertions(+), 31 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
 create mode 100644 llvm/test/Transforms/LoopIdiom/memset-runtime.ll

diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index f1dcb10b01bf1..6cf8f5a0b0d96 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -896,8 +896,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
 /// processLoopMemSet - See if this memset can be promoted to a large memset.
 bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
                                            const SCEV *BECount) {
-  // We can only handle non-volatile memsets with a constant size.
-  if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength()))
+  // We can only handle non-volatile memsets.
+  if (MSI->isVolatile())
     return false;
 
   // If we're not allowed to hack on memset, we fail.
@@ -910,23 +910,72 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
   // loop, which indicates a strided store.  If we have something else, it's a
   // random store we can't handle.
   const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer));
-  if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine())
+  if (!Ev || Ev->getLoop() != CurLoop)
     return false;
-
-  // Reject memsets that are so large that they overflow an unsigned.
-  uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
-  if ((SizeInBytes >> 32) != 0)
+  if (!Ev->isAffine()) {
+    LLVM_DEBUG(dbgs() << "  Pointer is not affine, abort\n");
     return false;
+  }
 
-  // Check to see if the stride matches the size of the memset.  If so, then we
-  // know that every byte is touched in the loop.
-  const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
-  if (!ConstStride)
+  const SCEV *PointerStrideSCEV = Ev->getOperand(1);
+  const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
+  if (!PointerStrideSCEV || !MemsetSizeSCEV)
     return false;
 
-  APInt Stride = ConstStride->getAPInt();
-  if (SizeInBytes != Stride && SizeInBytes != -Stride)
-    return false;
+  bool IsNegStride = false;
+  const bool IsConstantSize = isa<ConstantInt>(MSI->getLength());
+
+  if (IsConstantSize) {
+    // Memset size is constant.
+    // Check if the pointer stride matches the memset size. If so, then
+    // we know that every byte is touched in the loop.
+    LLVM_DEBUG(dbgs() << "  memset size is constant\n");
+    uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
+    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
+    if (!ConstStride)
+      return false;
+
+    APInt Stride = ConstStride->getAPInt();
+    if (SizeInBytes != Stride && SizeInBytes != -Stride)
+      return false;
+
+    IsNegStride = SizeInBytes == -Stride;
+  } else {
+    // Memset size is non-constant.
+    // Check if the pointer stride matches the memset size.
+    // To be conservative, the pass would not promote pointers that aren't in
+    // address space zero. Also, the pass only handles memset length and stride
+    // that are invariant for the top level loop.
+    LLVM_DEBUG(dbgs() << "  memset size is non-constant\n");
+    if (Pointer->getType()->getPointerAddressSpace() != 0) {
+      LLVM_DEBUG(dbgs() << "  pointer is not in address space zero, "
+                        << "abort\n");
+      return false;
+    }
+    if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop)) {
+      LLVM_DEBUG(dbgs() << "  memset size is not a loop-invariant, "
+                        << "abort\n");
+      return false;
+    }
+
+    // Compare positive direction PointerStrideSCEV with MemsetSizeSCEV.
+    IsNegStride = PointerStrideSCEV->isNonConstantNegative();
+    const SCEV *PositiveStrideSCEV =
+        IsNegStride ? SE->getNegativeSCEV(PointerStrideSCEV)
+                    : PointerStrideSCEV;
+    LLVM_DEBUG(dbgs() << "  MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
+                      << "  PositiveStrideSCEV: " << *PositiveStrideSCEV
+                      << "\n");
+
+    if (PositiveStrideSCEV != MemsetSizeSCEV) {
+      // TODO: folding can be done to the SCEVs.
+      // The folding is to fold expressions that are covered by the loop guard
+      // at loop entry. After the folding, compare again and proceed with the
+      // optimization if they are equal.
+      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
+      return false;
+    }
+  }
 
   // Verify that the memset value is loop invariant.  If not, we can't promote
   // the memset.
@@ -936,7 +985,6 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
 
   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
-  bool IsNegStride = SizeInBytes == -Stride;
   return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
                                  MaybeAlign(MSI->getDestAlignment()),
                                  SplatValue, MSI, MSIs, Ev, BECount,
@@ -1028,20 +1076,6 @@ static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
 ///
 /// This also maps the SCEV into the provided type and tries to handle the
 /// computation in a way that will fold cleanly.
-static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
-                               unsigned StoreSize, Loop *CurLoop,
-                               const DataLayout *DL, ScalarEvolution *SE) {
-  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
-
-  // And scale it based on the store size.
-  if (StoreSize != 1) {
-    return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize),
-                          SCEV::FlagNUW);
-  }
-  return TripCountSCEV;
-}
-
-/// getNumBytes that takes StoreSize as a SCEV
 static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
                                const SCEV *StoreSizeSCEV, Loop *CurLoop,
                                const DataLayout *DL, ScalarEvolution *SE) {
@@ -1342,8 +1376,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   // Okay, everything is safe, we can transform this!
-  const SCEV *NumBytesS =
-      getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
+  const SCEV *NumBytesS = getNumBytes(
+      BECount, IntIdxTy, SE->getConstant(IntIdxTy, StoreSize), CurLoop, DL, SE);
 
   Value *NumBytes =
       Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
 
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
new file mode 100644
index 0000000000000..8ee554eb6d25e
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
@@ -0,0 +1,270 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; REQUIRES: asserts
+; RUN: opt < %s -S -debug -passes=loop-idiom 2>&1 | FileCheck %s
+; The C code to generate this testcase:
+; void test(int *ar, int n, int m)
+; {
+;   long i;
+;   for (i=0; i<n; i++) {
+; CHECK-NEXT:  PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64)))
+; CHECK-NEXT:  SCEV don't match, abort
+; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader
+; CHECK-NEXT:  memset size is non-constant
+; CHECK-NEXT:  pointer is not in address space zero, abort
+; CHECK: loop-idiom Scanning: F[NonAffinePointer] Countable Loop %for.body
+; CHECK-NEXT:  Pointer is not affine, abort
+
+define void @MemsetSize_LoopVariant(i32* %ar, i32 %n, i32 %m) {
+; CHECK-LABEL: @MemsetSize_LoopVariant(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[N:%.*]] to i64
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i64 0, [[CONV]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[M:%.*]] to i64
+; CHECK-NEXT:    [[CONV2:%.*]] = sext i32 [[M]] to i64
+; CHECK-NEXT:    [[MUL3:%.*]] = mul i64 [[CONV2]], 4
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.body:
+; CHECK-NEXT:    [[I_02:%.*]] = phi i64 [ 0, [[FOR_BODY_LR_PH]] ], [ [[INC:%.*]], [[FOR_INC:%.*]] ]
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[I_02]], [[CONV1]]
+; CHECK-NEXT:    [[ADD_PTR:%.*]] = getelementptr inbounds i32, i32* [[AR:%.*]], i64 [[MUL]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i32* [[ADD_PTR]] to i8*
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[I_02]], [[MUL3]]
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[TMP0]], i8 0, i64 [[ADD]], i1 false)
+; CHECK-NEXT:    br label [[FOR_INC]]
+; CHECK:       for.inc:
+; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_02]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[INC]], [[CONV]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_FOR_END_CRIT_EDGE:%.*]]
+; CHECK:       for.cond.for.end_crit_edge:
+; CHECK-NEXT:    br label [[FOR_END]]
+; CHECK:       for.end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %conv = sext i32 %n to i64
+  %cmp1 = icmp slt i64 0, %conv
+  br i1 %cmp1, label %for.body.lr.ph, label %for.end
+
+for.body.lr.ph:                                   ; preds = %entry
+  %conv1 = sext i32 %m to i64
+  %conv2 = sext i32 %m to i64
+  %mul3 = mul i64 %conv2, 4
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.lr.ph, %for.inc
+  %i.02 = phi i64 [ 0, %for.body.lr.ph ], [ %inc, %for.inc ]
+  %mul = mul nsw i64 %i.02, %conv1
+  %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %mul
+  %0 = bitcast i32* %add.ptr to i8*
+  %add = add nsw i64 %i.02, %mul3
+  call void @llvm.memset.p0i8.i64(i8* align 4 %0, i8 0, i64 %add, i1 false)
+  br label %for.inc
+
+for.inc:                                          ; preds = %for.body
+  %inc = add nuw nsw i64 %i.02, 1
+  %cmp = icmp slt i64 %inc, %conv
+  br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge
+
+for.cond.for.end_crit_edge:                       ; preds = %for.inc
+  br label %for.end
+
+for.end:                                          ; preds = %for.cond.for.end_crit_edge, %entry
+  ret void
+}

diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
@@ -0,0 +1,110 @@
+; void test(int *ar, int n, int m)
+; {
+;   long i;
+;   for (i=n-1; i>=0; i--) {
+;     int *arr = ar + i * m;
+;     memset(arr, 0, m * sizeof(int));
+;   }
+; }
+define void @For_NegativeStride(i32* %ar, i32 %n, i32 %m) {
+; CHECK-LABEL: @For_NegativeStride(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8*
+; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[N:%.*]], 1
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SUB]] to i64
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp sge i64 [[CONV]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[M:%.*]] to i64
+; CHECK-NEXT:    [[CONV2:%.*]] = sext i32 [[M]] to i64
+; CHECK-NEXT:    [[MUL3:%.*]] = mul i64 [[CONV2]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 [[CONV]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[CONV1]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 2
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR1]], i8 0, i64 [[TMP2]], i1 false)
+; CHECK-NEXT:    br label [[FOR_END]]
+; CHECK:       for.end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %sub = sub nsw i32 %n, 1
+  %conv = sext i32 %sub to i64
+  %cmp1 = icmp sge i64 %conv, 0
+  br i1 %cmp1, label %for.body.lr.ph, label %for.end
+
+for.body.lr.ph:                                   ; preds = %entry
+  %conv1 = sext i32 %m to i64
+  %conv2 = sext i32 %m to i64
+  %mul3 = mul i64 %conv2, 4
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.lr.ph, %for.inc
+  %i.02 = phi i64 [ %conv, %for.body.lr.ph ], [ %dec, %for.inc ]
+  %mul = mul nsw i64 %i.02, %conv1
+  %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %mul
+  %0 = bitcast i32* %add.ptr to i8*
+  call void @llvm.memset.p0i8.i64(i8* align 4 %0, i8 0, i64 %mul3, i1 false)
+  br label %for.inc
+
+for.inc:                                          ; preds = %for.body
+  %dec = add nsw i64 %i.02, -1
+  %cmp = icmp sge i64 %dec, 0
+  br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge
+
+for.cond.for.end_crit_edge:                       ; preds = %for.inc
+  br label %for.end
+
+for.end:                                          ; preds = %for.cond.for.end_crit_edge, %entry
+  ret void
+}
+
+declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)
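
For illustration (an editor's sketch, not part of the patch; the C helper
names below are hypothetical): the idiom the patched processLoopMemSet now
recognizes is a loop whose runtime memset length exactly equals the pointer
stride, so the per-iteration stores tile the region with no gap or overlap.

  #include <string.h>

  /* Before: n memsets of runtime size m * sizeof(int); the destination
   * advances by exactly m * sizeof(int) each iteration, so the pass sees
   * MemsetSizeSCEV == PositiveStrideSCEV == (4 * (sext i32 %m to i64)). */
  void zero_rows(int *ar, int n, int m) {
    for (long i = 0; i < n; i++)
      memset(ar + i * m, 0, m * sizeof(int));
  }

  /* After loop-idiom plus loop cleanup, conceptually: one flat memset,
   * guarded the same way the original loop was. */
  void zero_rows_flat(int *ar, int n, int m) {
    if (n > 0)
      memset(ar, 0, (size_t)n * m * sizeof(int));
  }

If the two SCEVs differ, for example a stride of (m + 1) ints against a
length of m ints, the per-iteration regions no longer tile the buffer and a
single flat memset would not be equivalent; the pass then prints
"SCEV don't match, abort" and bails out, which is exactly what
memset-runtime-debug.ll checks for.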