diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index f709a5ac52a41..e40d08633facc 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1456,6 +1456,36 @@ struct ThreeOps_match {
   }
 };
 
+/// Matches instructions with Opcode and any number of operands
+template <unsigned Opcode, typename... OperandTypes> struct AnyOps_match {
+  std::tuple<OperandTypes...> Operands;
+
+  AnyOps_match(const OperandTypes &...Ops) : Operands(Ops...) {}
+
+  // Operand matching works by recursively calling match_operands, matching the
+  // operands left to right. The first version is called for each operand but
+  // the last, for which the second version is called. The second version of
+  // match_operands is also used to match each individual operand.
+  template <int Idx, int Last>
+  std::enable_if_t<Idx != Last, bool> match_operands(const Instruction *I) {
+    return match_operands<Idx, Idx>(I) && match_operands<Idx + 1, Last>(I);
+  }
+
+  template <int Idx, int Last>
+  std::enable_if_t<Idx == Last, bool> match_operands(const Instruction *I) {
+    return std::get<Idx>(Operands).match(I->getOperand(Idx));
+  }
+
+  template <typename OpTy> bool match(OpTy *V) {
+    if (V->getValueID() == Value::InstructionVal + Opcode) {
+      auto *I = cast<Instruction>(V);
+      return I->getNumOperands() == sizeof...(OperandTypes) &&
+             match_operands<0, sizeof...(OperandTypes) - 1>(I);
+    }
+    return false;
+  }
+};
+
 /// Matches SelectInst.
 template <typename Cond, typename LHS, typename RHS>
 inline ThreeOps_match<Cond, LHS, RHS, Instruction::Select>
@@ -1572,6 +1602,12 @@ m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp) {
                                                                   PointerOp);
 }
 
+/// Matches GetElementPtrInst.
+template <typename... OperandTypes>
+inline auto m_GEP(const OperandTypes &...Ops) {
+  return AnyOps_match<Instruction::GetElementPtr, OperandTypes...>(Ops...);
+}
+
 //===----------------------------------------------------------------------===//
 // Matchers for CastInst classes
 //
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index b1add3c42976f..2972116736660 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -207,6 +207,12 @@ struct FlattenInfo {
         match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
                                   m_Value(MatchedItCount)));
 
+    // Matches the pattern ptr+i*M+j, with the two additions being done via GEP.
+    bool IsGEP = match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)),
+                                m_Specific(InnerInductionPHI))) &&
+                 match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
+                                           m_Value(MatchedItCount)));
+
     if (!MatchedItCount)
       return false;
 
@@ -224,7 +230,7 @@ struct FlattenInfo {
 
     // Look through extends if the IV has been widened. Don't look through
     // extends if we already looked through a trunc.
-    if (Widened && IsAdd &&
+    if (Widened && (IsAdd || IsGEP) &&
         (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
       assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
              "Unexpected type mismatch in types after widening");
@@ -236,7 +242,7 @@ struct FlattenInfo {
     LLVM_DEBUG(dbgs() << "Looking for inner trip count: ";
                InnerTripCount->dump());
 
-    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
+    if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) {
       LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n");
       ValidOuterPHIUses.insert(MatchedMul);
       LinearIVUses.insert(U);
@@ -647,33 +653,40 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
   if (OR != OverflowResult::MayOverflow)
     return OR;
 
-  for (Value *V : FI.LinearIVUses) {
-    for (Value *U : V->users()) {
-      if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
-        for (Value *GEPUser : U->users()) {
-          auto *GEPUserInst = cast<Instruction>(GEPUser);
-          if (!isa<LoadInst>(GEPUserInst) &&
-              !(isa<StoreInst>(GEPUserInst) &&
-                GEP == GEPUserInst->getOperand(1)))
-            continue;
-          if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst,
-                                                      FI.InnerLoop))
-            continue;
-          // The IV is used as the operand of a GEP which dominates the loop
-          // latch, and the IV is at least as wide as the address space of the
-          // GEP. In this case, the GEP would wrap around the address space
-          // before the IV increment wraps, which would be UB.
-          if (GEP->isInBounds() &&
-              V->getType()->getIntegerBitWidth() >=
-                  DL.getPointerTypeSizeInBits(GEP->getType())) {
-            LLVM_DEBUG(
-                dbgs() << "use of linear IV would be UB if overflow occurred: ";
-                GEP->dump());
-            return OverflowResult::NeverOverflows;
-          }
-        }
+  auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) {
+    for (Value *GEPUser : GEP->users()) {
+      auto *GEPUserInst = cast<Instruction>(GEPUser);
+      if (!isa<LoadInst>(GEPUserInst) &&
+          !(isa<StoreInst>(GEPUserInst) && GEP == GEPUserInst->getOperand(1)))
+        continue;
+      if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, FI.InnerLoop))
+        continue;
+      // The IV is used as the operand of a GEP which dominates the loop
+      // latch, and the IV is at least as wide as the address space of the
+      // GEP. In this case, the GEP would wrap around the address space
+      // before the IV increment wraps, which would be UB.
+      if (GEP->isInBounds() &&
+          GEPOperand->getType()->getIntegerBitWidth() >=
+              DL.getPointerTypeSizeInBits(GEP->getType())) {
+        LLVM_DEBUG(
+            dbgs() << "use of linear IV would be UB if overflow occurred: ";
+            GEP->dump());
+        return true;
       }
     }
+    return false;
+  };
+
+  // Check if any IV user is, or is used by, a GEP that would cause UB if the
+  // multiply overflows.
+  for (Value *V : FI.LinearIVUses) {
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
+      if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(1)))
+        return OverflowResult::NeverOverflows;
+    for (Value *U : V->users())
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
+        if (CheckGEP(GEP, V))
+          return OverflowResult::NeverOverflows;
   }
 
   return OverflowResult::MayOverflow;
@@ -779,6 +792,18 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
       OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
                                        "flatten.trunciv");
 
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
+      // Replace the GEP with one that uses OuterValue as the offset.
+      auto *InnerGEP = cast<GetElementPtrInst>(GEP->getOperand(0));
+      Value *Base = InnerGEP->getOperand(0);
+      // When the base of the GEP doesn't dominate the outer induction phi then
+      // we need to insert the new GEP where the old GEP was.
+      if (!DT->dominates(Base, &*Builder.GetInsertPoint()))
+        Builder.SetInsertPoint(cast<Instruction>(V));
+      OuterValue = Builder.CreateGEP(GEP->getSourceElementType(), Base,
+                                     OuterValue, "flatten." + V->getName());
+    }
+
     LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: ";
                OuterValue->dump());
     V->replaceAllUsesWith(OuterValue);
diff --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
new file mode 100644
index 0000000000000..f4b8ea97237fe
--- /dev/null
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
@@ -0,0 +1,137 @@
+; RUN: opt < %s -S -passes='loop(loop-flatten),verify' -verify-loop-info -verify-dom-info -verify-scev | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
+; We should be able to flatten the loops and turn the two geps into one.
+; CHECK-LABEL: test1
+define void @test1(i32 %N, ptr %A) {
+entry:
+  %cmp3 = icmp ult i32 0, %N
+  br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+; CHECK-LABEL: for.outer.preheader:
+; CHECK: %flatten.tripcount = mul i32 %N, %N
+for.outer.preheader:
+  br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK: %flatten.arrayidx = getelementptr i32, ptr %A, i32 %i
+for.inner.preheader:
+  %i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+  br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
+; CHECK: br label %for.outer
+for.inner:
+  %j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+  %mul = mul i32 %i, %N
+  %gep = getelementptr inbounds i32, ptr %A, i32 %mul
+  %arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
+  store i32 0, ptr %arrayidx, align 4
+  %inc1 = add nuw i32 %j, 1
+  %cmp2 = icmp ult i32 %inc1, %N
+  br i1 %cmp2, label %for.inner, label %for.outer
+
+; CHECK-LABEL: for.outer:
+; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
+for.outer:
+  %inc2 = add i32 %i, 1
+  %cmp1 = icmp ult i32 %inc2, %N
+  br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; We can flatten, but the flattened gep has to be inserted after the load it
+; depends on.
+; CHECK-LABEL: test2
+define void @test2(i32 %N, ptr %A) {
+entry:
+  %cmp3 = icmp ult i32 0, %N
+  br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+; CHECK-LABEL: for.outer.preheader:
+; CHECK: %flatten.tripcount = mul i32 %N, %N
+for.outer.preheader:
+  br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK-NOT: getelementptr i32, ptr %ptr, i32 %i
+for.inner.preheader:
+  %i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+  br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK: %flatten.arrayidx = getelementptr i32, ptr %ptr, i32 %i
+; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
+; CHECK: br label %for.outer
+for.inner:
+  %j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+  %ptr = load volatile ptr, ptr %A, align 4
+  %mul = mul i32 %i, %N
+  %gep = getelementptr inbounds i32, ptr %ptr, i32 %mul
+  %arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
+  store i32 0, ptr %arrayidx, align 4
+  %inc1 = add nuw i32 %j, 1
+  %cmp2 = icmp ult i32 %inc1, %N
+  br i1 %cmp2, label %for.inner, label %for.outer
+
+; CHECK-LABEL: for.outer:
+; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
+for.outer:
+  %inc2 = add i32 %i, 1
+  %cmp1 = icmp ult i32 %inc2, %N
+  br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; We can't flatten if the gep offset is smaller than the pointer size.
+; CHECK-LABEL: test3
+define void @test3(i16 %N, ptr %A) {
+entry:
+  %cmp3 = icmp ult i16 0, %N
+  br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+for.outer.preheader:
+  br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
+for.inner.preheader:
+  %i = phi i16 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+  br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
+; CHECK: br i1 %cmp2, label %for.inner, label %for.outer
+for.inner:
+  %j = phi i16 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+  %mul = mul i16 %i, %N
+  %gep = getelementptr inbounds i32, ptr %A, i16 %mul
+  %arrayidx = getelementptr inbounds i32, ptr %gep, i16 %j
+  store i32 0, ptr %arrayidx, align 4
+  %inc1 = add nuw i16 %j, 1
+  %cmp2 = icmp ult i16 %inc1, %N
+  br i1 %cmp2, label %for.inner, label %for.outer
+
+for.outer:
+  %inc2 = add i16 %i, 1
+  %cmp1 = icmp ult i16 %inc2, %N
+  br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
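
Usage note (not part of the patch): a minimal sketch of how the new m_GEP matcher composes to recognise the ptr+i*M+j pattern that LoopFlatten now handles. The helper name isLinearGEPOfIV and its parameters are illustrative placeholders, not LLVM APIs; only m_GEP, m_Value, m_Specific and m_c_Mul come from PatternMatch.h.

#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace llvm::PatternMatch;

static bool isLinearGEPOfIV(Value *U, PHINode *InnerPHI, PHINode *OuterPHI,
                            Value *&MatchedItCount) {
  Value *MatchedMul = nullptr;
  // Outer GEP: (gep (gep %base, i*M), j), where j is the inner IV.
  return match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)),
                        m_Specific(InnerPHI))) &&
         // i*M, with i the outer IV; the multiplier is captured so the caller
         // can compare it against the inner loop's trip count.
         match(MatchedMul,
               m_c_Mul(m_Specific(OuterPHI), m_Value(MatchedItCount)));
}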