diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 99525607f744a..9d9d1d3338a43 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1430,6 +1430,7 @@ class TargetTransformInfo { /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2. LLVM_ABI static OperandValueInfo getOperandInfo(const Value *V); + LLVM_ABI static OperandValueInfo mergeInfo(const Value *X, const Value *Y); /// This is an approximation of reciprocal throughput of a math/logic op. /// A higher cost indicates less expected throughput. diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index c529d87502acd..634f58517e4a0 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -956,6 +956,27 @@ TargetTransformInfo::getOperandInfo(const Value *V) { return {OpInfo, OpProps}; } +TargetTransformInfo::OperandValueInfo +TargetTransformInfo::mergeInfo(const Value *X, const Value *Y) { + auto [OpInfoX, OpPropsX] = TargetTransformInfo::getOperandInfo(X); + auto [OpInfoY, OpPropsY] = TargetTransformInfo::getOperandInfo(Y); + + OperandValueKind MergeInfo = OK_AnyValue; + OperandValueProperties MergeProp = OP_None; + + if (OpInfoX == OK_AnyValue || OpInfoY == OK_AnyValue || + OpInfoX == OK_UniformValue || OpInfoY == OK_UniformValue) + MergeInfo = OK_AnyValue; + else if (OpInfoX == OK_NonUniformConstantValue || + OpInfoY == OK_NonUniformConstantValue) + MergeInfo = OK_NonUniformConstantValue; + else + MergeInfo = X == Y ? OK_UniformConstantValue : OK_NonUniformConstantValue; + + MergeProp = OpPropsX == OpPropsY ? OpPropsX : OP_None; + return {MergeInfo, MergeProp}; +} + InstructionCost TargetTransformInfo::getArithmeticInstrCost( unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, OperandValueInfo Op1Info, OperandValueInfo Op2Info, diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 243f685cf25e2..9801cb3255568 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -2437,6 +2437,10 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) { M -= NumSrcElts; }; + TTI::OperandValueInfo Op0Info, Op1Info; + Op0Info = TTI.mergeInfo(X, Z); + Op1Info = TTI.mergeInfo(Y, W); + SmallVector NewMask0(OldMask); TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc; if (X == Z) { @@ -2500,11 +2504,12 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) { nullptr, {Y, W}); if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) { - NewCost += - TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind); + NewCost += TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, + CostKind, Op0Info, Op1Info); } else { - NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, - ShuffleDstTy, PredLHS, CostKind); + NewCost += + TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, ShuffleDstTy, + PredLHS, CostKind, Op0Info, Op1Info); } LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I diff --git a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll index 77b44d0e40e14..3e628d7865193 100644 --- a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll +++ b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll @@ -20,6 +20,19 @@ define <4 x float> @shuf_fdiv_v4f32_yy(<4 x float> %x, <4 x float> %y, <4 x floa ret <4 x float> %r } +define <16 x i16> @shuf_uniform_shift_v16i16_v8i16(<8 x i16> %a0, <8 x i16> %a1) { +; CHECK-LABEL: define <16 x i16> @shuf_uniform_shift_v16i16_v8i16( +; CHECK-SAME: <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0]], <8 x i16> [[A1]], <16 x i32> +; CHECK-NEXT: [[RES:%.*]] = shl <16 x i16> [[TMP1]], splat (i16 7) +; CHECK-NEXT: ret <16 x i16> [[RES]] +; + %v0 = shl <8 x i16> %a0, splat (i16 7) + %v1 = shl <8 x i16> %a1, splat (i16 7) + %res = shufflevector <8 x i16> %v0, <8 x i16> %v1, <16 x i32> + ret <16 x i16> %res +} + ; Common operand is op0 of the binops. define <4 x i32> @shuf_add_v4i32_xx(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {