diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 02881109f17d2..fc17d9e96356b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2706,17 +2706,18 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I,
   return LastInst;
 }
 
-/// Match UB-safe variants of the funnel shift intrinsic.
-static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
-                                     const DominatorTree &DT) {
+std::optional<std::tuple<Intrinsic::ID, SmallVector<Value *, 3>>>
+InstCombinerImpl::convertShlOrLShrToFShlOrFShr(Instruction &Or) {
   // TODO: Can we reduce the code duplication between this and the related
   // rotate matching code under visitSelect and visitTrunc?
+  assert(Or.getOpcode() == BinaryOperator::Or && "Expecting or instruction");
+
   unsigned Width = Or.getType()->getScalarSizeInBits();
 
   Instruction *Or0, *Or1;
   if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
       !match(Or.getOperand(1), m_Instruction(Or1)))
-    return nullptr;
+    return std::nullopt;
 
   bool IsFshl = true; // Sub on LSHR.
   SmallVector<Value *, 3> FShiftArgs;
@@ -2730,7 +2731,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
         !match(Or1,
                m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
         Or0->getOpcode() == Or1->getOpcode())
-      return nullptr;
+      return std::nullopt;
 
     // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
     if (Or0->getOpcode() == BinaryOperator::LShr) {
@@ -2766,7 +2767,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
       // might remove it after this fold). This still doesn't guarantee that the
       // final codegen will match this original pattern.
       if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
-        KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+        KnownBits KnownL = computeKnownBits(L, /*Depth*/ 0, &Or);
         return KnownL.getMaxValue().ult(Width) ? L : nullptr;
       }
 
@@ -2810,7 +2811,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
       IsFshl = false; // Sub on SHL.
     }
     if (!ShAmt)
-      return nullptr;
+      return std::nullopt;
 
     FShiftArgs = {ShVal0, ShVal1, ShAmt};
   } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) {
@@ -2832,18 +2833,18 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
     const APInt *ZextHighShlAmt;
     if (!match(Or0,
                m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt)))))
-      return nullptr;
+      return std::nullopt;
 
     if (!match(Or1, m_ZExt(m_Value(Low))) ||
         !match(ZextHigh, m_ZExt(m_Value(High))))
-      return nullptr;
+      return std::nullopt;
 
     unsigned HighSize = High->getType()->getScalarSizeInBits();
     unsigned LowSize = Low->getType()->getScalarSizeInBits();
     // Make sure High does not overlap with Low and most significant bits of
     // High aren't shifted out.
     if (ZextHighShlAmt->ult(LowSize) || ZextHighShlAmt->ugt(Width - HighSize))
-      return nullptr;
+      return std::nullopt;
 
     for (User *U : ZextHigh->users()) {
       Value *X, *Y;
@@ -2874,11 +2875,21 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
   }
 
   if (FShiftArgs.empty())
-    return nullptr;
+    return std::nullopt;
 
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
-  Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
-  return CallInst::Create(F, FShiftArgs);
+  return std::make_tuple(IID, FShiftArgs);
+}
+
+/// Match UB-safe variants of the funnel shift intrinsic.
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
+  if (auto Opt = IC.convertShlOrLShrToFShlOrFShr(Or)) {
+    auto [IID, FShiftArgs] = *Opt;
+    Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
+    return CallInst::Create(F, FShiftArgs);
+  }
+
+  return nullptr;
 }
 
 /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
@@ -3372,7 +3383,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                                   /*MatchBitReversals*/ true))
     return BitOp;
 
-  if (Instruction *Funnel = matchFunnelShift(I, *this, DT))
+  if (Instruction *Funnel = matchFunnelShift(I, *this))
     return Funnel;
 
   if (Instruction *Concat = matchOrConcat(I, Builder))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 0bbb22be71569..303d02cc24fc9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -236,6 +236,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
     return getLosslessTrunc(C, TruncTy, Instruction::SExt);
   }
 
+  std::optional<std::tuple<Intrinsic::ID, SmallVector<Value *, 3>>>
+  convertShlOrLShrToFShlOrFShr(Instruction &Or);
+
 private:
   bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
   bool isDesirableIntType(unsigned BitWidth) const;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index fa076098d63cd..79e83163a6b06 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -610,6 +610,19 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                                     DemandedMask, Known))
             return R;
 
+      // Do not simplify if shl is part of funnel-shift pattern
+      if (I->hasOneUse()) {
+        auto *Inst = dyn_cast<Instruction>(I->user_back());
+        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
+          if (auto Opt = convertShlOrLShrToFShlOrFShr(*Inst)) {
+            auto [IID, FShiftArgs] = *Opt;
+            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
+                FShiftArgs[0] == FShiftArgs[1])
+              return nullptr;
+          }
+        }
+      }
+
       // TODO: If we only want bits that already match the signbit then we don't
       // need to shift.
 
@@ -670,6 +683,19 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     if (match(I->getOperand(1), m_APInt(SA))) {
       uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
 
+      // Do not simplify if lshr is part of funnel-shift pattern
+      if (I->hasOneUse()) {
+        auto *Inst = dyn_cast<Instruction>(I->user_back());
+        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
+          if (auto Opt = convertShlOrLShrToFShlOrFShr(*Inst)) {
+            auto [IID, FShiftArgs] = *Opt;
+            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
+                FShiftArgs[0] == FShiftArgs[1])
+              return nullptr;
+          }
+        }
+      }
+
       // If we are just demanding the shifted sign bit and below, then this can
       // be treated as an ASHR in disguise.
       if (DemandedMask.countl_zero() >= ShiftAmt) {
diff --git a/llvm/test/Transforms/InstCombine/fsh.ll b/llvm/test/Transforms/InstCombine/fsh.ll
index 48bf296993f6a..9fc19fa5f6828 100644
--- a/llvm/test/Transforms/InstCombine/fsh.ll
+++ b/llvm/test/Transforms/InstCombine/fsh.ll
@@ -722,6 +722,134 @@ define i32 @fsh_orconst_rotate(i32 %a) {
   ret i32 %t2
 }
 
+define i32 @fsh_rotate_5(i8 %x, i32 %y) {
+; CHECK-LABEL: @fsh_rotate_5(
+; CHECK-NEXT:    [[T1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[OR1:%.*]] = or i32 [[T1]], [[Y:%.*]]
+; CHECK-NEXT:    [[OR2:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR1]], i32 [[OR1]], i32 5)
+; CHECK-NEXT:    ret i32 [[OR2]]
+;
+
+  %t1 = zext i8 %x to i32
+  %or1 = or i32 %t1, %y
+  %shr = lshr i32 %or1, 27
+  %shl = shl i32 %or1, 5
+  %or2 = or i32 %shr, %shl
+  ret i32 %or2
+}
+
+define i32 @fsh_rotate_18(i8 %x, i32 %y) {
+; CHECK-LABEL: @fsh_rotate_18(
+; CHECK-NEXT:    [[T1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[OR1:%.*]] = or i32 [[T1]], [[Y:%.*]]
+; CHECK-NEXT:    [[OR2:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR1]], i32 [[OR1]], i32 18)
+; CHECK-NEXT:    ret i32 [[OR2]]
+;
+
+  %t1 = zext i8 %x to i32
+  %or1 = or i32 %t1, %y
+  %shr = lshr i32 %or1, 14
+  %shl = shl i32 %or1, 18
+  %or2 = or i32 %shr, %shl
+  ret i32 %or2
+}
+
+define i32 @fsh_load_rotate_12(ptr %data) {
+; CHECK-LABEL: @fsh_load_rotate_12(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, ptr [[DATA:%.*]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP0]] to i32
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[CONV]], 24
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
+; CHECK-NEXT:    [[CONV2:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    [[SHL3:%.*]] = shl nuw nsw i32 [[CONV2]], 16
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHL3]], [[SHL]]
+; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 2
+; CHECK-NEXT:    [[TMP2:%.*]] = load i8, ptr [[ARRAYIDX4]], align 1
+; CHECK-NEXT:    [[CONV5:%.*]] = zext i8 [[TMP2]] to i32
+; CHECK-NEXT:    [[SHL6:%.*]] = shl nuw nsw i32 [[CONV5]], 8
+; CHECK-NEXT:    [[OR7:%.*]] = or i32 [[OR]], [[SHL6]]
+; CHECK-NEXT:    [[ARRAYIDX8:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 3
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, ptr [[ARRAYIDX8]], align 1
+; CHECK-NEXT:    [[CONV9:%.*]] = zext i8 [[TMP3]] to i32
+; CHECK-NEXT:    [[OR10:%.*]] = or i32 [[OR7]], [[CONV9]]
+; CHECK-NEXT:    [[OR15:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR10]], i32 [[OR10]], i32 12)
+; CHECK-NEXT:    ret i32 [[OR15]]
+;
+
+entry:
+  %0 = load i8, ptr %data
+  %conv = zext i8 %0 to i32
+  %shl = shl nuw i32 %conv, 24
+  %arrayidx1 = getelementptr inbounds i8, ptr %data, i64 1
+  %1 = load i8, ptr %arrayidx1
+  %conv2 = zext i8 %1 to i32
+  %shl3 = shl nuw nsw i32 %conv2, 16
+  %or = or i32 %shl3, %shl
+  %arrayidx4 = getelementptr inbounds i8, ptr %data, i64 2
+  %2 = load i8, ptr %arrayidx4
+  %conv5 = zext i8 %2 to i32
+  %shl6 = shl nuw nsw i32 %conv5, 8
+  %or7 = or i32 %or, %shl6
+  %arrayidx8 = getelementptr inbounds i8, ptr %data, i64 3
+  %3 = load i8, ptr %arrayidx8
+  %conv9 = zext i8 %3 to i32
+  %or10 = or i32 %or7, %conv9
+  %shr = lshr i32 %or10, 20
+  %shl7 = shl i32 %or10, 12
+  %or15 = or i32 %shr, %shl7
+  ret i32 %or15
+}
+
+define i32 @fsh_load_rotate_25(ptr %data) {
+; CHECK-LABEL: @fsh_load_rotate_25(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, ptr [[DATA:%.*]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i8 [[TMP0]] to i32
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[CONV]], 24
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
+; CHECK-NEXT:    [[CONV2:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    [[SHL3:%.*]] = shl nuw nsw i32 [[CONV2]], 16
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHL3]], [[SHL]]
+; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 2
+; CHECK-NEXT:    [[TMP2:%.*]] = load i8, ptr [[ARRAYIDX4]], align 1
+; CHECK-NEXT:    [[CONV5:%.*]] = zext i8 [[TMP2]] to i32
+; CHECK-NEXT:    [[SHL6:%.*]] = shl nuw nsw i32 [[CONV5]], 8
+; CHECK-NEXT:    [[OR7:%.*]] = or i32 [[OR]], [[SHL6]]
+; CHECK-NEXT:    [[ARRAYIDX8:%.*]] = getelementptr inbounds i8, ptr [[DATA]], i64 3
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, ptr [[ARRAYIDX8]], align 1
+; CHECK-NEXT:    [[CONV9:%.*]] = zext i8 [[TMP3]] to i32
+; CHECK-NEXT:    [[OR10:%.*]] = or i32 [[OR7]], [[CONV9]]
+; CHECK-NEXT:    [[OR15:%.*]] = call i32 @llvm.fshl.i32(i32 [[OR10]], i32 [[OR10]], i32 25)
+; CHECK-NEXT:    ret i32 [[OR15]]
+;
+
+entry:
+  %0 = load i8, ptr %data
+  %conv = zext i8 %0 to i32
+  %shl = shl nuw i32 %conv, 24
+  %arrayidx1 = getelementptr inbounds i8, ptr %data, i64 1
+  %1 = load i8, ptr %arrayidx1
+  %conv2 = zext i8 %1 to i32
+  %shl3 = shl nuw nsw i32 %conv2, 16
+  %or = or i32 %shl3, %shl
+  %arrayidx4 = getelementptr inbounds i8, ptr %data, i64 2
+  %2 = load i8, ptr %arrayidx4
+  %conv5 = zext i8 %2 to i32
+  %shl6 = shl nuw nsw i32 %conv5, 8
+  %or7 = or i32 %or, %shl6
+  %arrayidx8 = getelementptr inbounds i8, ptr %data, i64 3
+  %3 = load i8, ptr %arrayidx8
+  %conv9 = zext i8 %3 to i32
+  %or10 = or i32 %or7, %conv9
+  %shr = lshr i32 %or10, 7
+  %shl7 = shl i32 %or10, 25
+  %or15 = or i32 %shr, %shl7
+  ret i32 %or15
+}
+
 define <2 x i31> @fshr_mask_args_same_vector(<2 x i31> %a) {
 ; CHECK-LABEL: @fshr_mask_args_same_vector(
 ; CHECK-NEXT:    [[T3:%.*]] = shl <2 x i31> [[A:%.*]],