diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 06051552b9d5e9..247b8a6c199165 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -39,6 +39,7 @@ class BlockFrequencyInfo; class DominatorTree; class BranchInst; class CallBase; +class ExtractElementInst; class Function; class GlobalValue; class IntrinsicInst; @@ -805,6 +806,36 @@ class TargetTransformInfo { ///< shuffle mask. }; + /// Kind of the reduction data. + enum ReductionKind { + RK_None, /// Not a reduction. + RK_Arithmetic, /// Binary reduction data. + RK_MinMax, /// Min/max reduction data. + RK_UnsignedMinMax, /// Unsigned min/max reduction data. + }; + + /// Contains opcode + LHS/RHS parts of the reduction operations. + struct ReductionData { + ReductionData() = delete; + ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS) + : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) { + assert(Kind != RK_None && "expected binary or min/max reduction only."); + } + unsigned Opcode = 0; + Value *LHS = nullptr; + Value *RHS = nullptr; + ReductionKind Kind = RK_None; + bool hasSameData(ReductionData &RD) const { + return Kind == RD.Kind && Opcode == RD.Opcode; + } + }; + + static ReductionKind matchPairwiseReduction( + const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty); + + static ReductionKind matchVectorSplittingReduction( + const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty); + /// Additional information about an operand's possible values. enum OperandValueKind { OK_AnyValue, // Operand can have any value. diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index cedffe7da4bada..3d0af73d445c96 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -944,6 +944,54 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase { return TargetTTI->getShuffleCost(TTI::SK_PermuteTwoSrc, VecTy, 0, nullptr); } + case Instruction::ExtractElement: { + unsigned Idx = -1; + auto *EEI = cast(U); + auto *CI = dyn_cast(EEI->getOperand(1)); + if (CI) + Idx = CI->getZExtValue(); + + // Try to match a reduction sequence (series of shufflevector and + // vector adds followed by a extractelement). + unsigned ReduxOpCode; + VectorType *ReduxType; + + switch (TTI::matchVectorSplittingReduction(EEI, ReduxOpCode, + ReduxType)) { + case TTI::RK_Arithmetic: + return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/false, + CostKind); + case TTI::RK_MinMax: + return TargetTTI->getMinMaxReductionCost( + ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/false, CostKind); + case TTI::RK_UnsignedMinMax: + return TargetTTI->getMinMaxReductionCost( + ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/true, CostKind); + case TTI::RK_None: + break; + } + + switch (TTI::matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + case TTI::RK_Arithmetic: + return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/true, CostKind); + case TTI::RK_MinMax: + return TargetTTI->getMinMaxReductionCost( + ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/false, CostKind); + case TTI::RK_UnsignedMinMax: + return TargetTTI->getMinMaxReductionCost( + ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/true, CostKind); + case TTI::RK_None: + break; + } + return TargetTTI->getVectorInstrCost(Opcode, U->getOperand(0)->getType(), + Idx); + } } // By default, just classify everything as 'basic'. return TTI::TCC_Basic; diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index a862397a592cd3..12c5c84e5b0bad 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -970,35 +970,10 @@ static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, return Mask == ActualMask; } -namespace { -/// Kind of the reduction data. -enum ReductionKind { - RK_None, /// Not a reduction. - RK_Arithmetic, /// Binary reduction data. - RK_MinMax, /// Min/max reduction data. - RK_UnsignedMinMax, /// Unsigned min/max reduction data. -}; -/// Contains opcode + LHS/RHS parts of the reduction operations. -struct ReductionData { - ReductionData() = delete; - ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS) - : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) { - assert(Kind != RK_None && "expected binary or min/max reduction only."); - } - unsigned Opcode = 0; - Value *LHS = nullptr; - Value *RHS = nullptr; - ReductionKind Kind = RK_None; - bool hasSameData(ReductionData &RD) const { - return Kind == RD.Kind && Opcode == RD.Opcode; - } -}; -} // namespace - -static Optional getReductionData(Instruction *I) { +static Optional getReductionData(Instruction *I) { Value *L, *R; if (m_BinOp(m_Value(L), m_Value(R)).match(I)) - return ReductionData(RK_Arithmetic, I->getOpcode(), L, R); + return TTI::ReductionData(TTI::RK_Arithmetic, I->getOpcode(), L, R); if (auto *SI = dyn_cast(I)) { if (m_SMin(m_Value(L), m_Value(R)).match(SI) || m_SMax(m_Value(L), m_Value(R)).match(SI) || @@ -1007,20 +982,20 @@ static Optional getReductionData(Instruction *I) { m_UnordFMin(m_Value(L), m_Value(R)).match(SI) || m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) { auto *CI = cast(SI->getCondition()); - return ReductionData(RK_MinMax, CI->getOpcode(), L, R); + return TTI::ReductionData(TTI::RK_MinMax, CI->getOpcode(), L, R); } if (m_UMin(m_Value(L), m_Value(R)).match(SI) || m_UMax(m_Value(L), m_Value(R)).match(SI)) { auto *CI = cast(SI->getCondition()); - return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R); + return TTI::ReductionData(TTI::RK_UnsignedMinMax, CI->getOpcode(), L, R); } } return llvm::None; } -static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, - unsigned Level, - unsigned NumLevels) { +static TTI::ReductionKind matchPairwiseReductionAtLevel(Instruction *I, + unsigned Level, + unsigned NumLevels) { // Match one level of pairwise operations. // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> @@ -1028,24 +1003,24 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, // <4 x i32> // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 if (!I) - return RK_None; + return TTI::RK_None; assert(I->getType()->isVectorTy() && "Expecting a vector type"); - Optional RD = getReductionData(I); + Optional RD = getReductionData(I); if (!RD) - return RK_None; + return TTI::RK_None; ShuffleVectorInst *LS = dyn_cast(RD->LHS); if (!LS && Level) - return RK_None; + return TTI::RK_None; ShuffleVectorInst *RS = dyn_cast(RD->RHS); if (!RS && Level) - return RK_None; + return TTI::RK_None; // On level 0 we can omit one shufflevector instruction. if (!Level && !RS && !LS) - return RK_None; + return TTI::RK_None; // Shuffle inputs must match. Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr; @@ -1054,7 +1029,7 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, if (NextLevelOpR && NextLevelOpL) { // If we have two shuffles their operands must match. if (NextLevelOpL != NextLevelOpR) - return RK_None; + return TTI::RK_None; NextLevelOp = NextLevelOpL; } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) { @@ -1065,32 +1040,32 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, // %NextLevelOpL = shufflevector %R, <1, undef ...> // %BinOp = fadd %NextLevelOpL, %R if (NextLevelOpL && NextLevelOpL != RD->RHS) - return RK_None; + return TTI::RK_None; else if (NextLevelOpR && NextLevelOpR != RD->LHS) - return RK_None; + return TTI::RK_None; NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS; } else - return RK_None; + return TTI::RK_None; // Check that the next levels binary operation exists and matches with the // current one. if (Level + 1 != NumLevels) { - Optional NextLevelRD = + Optional NextLevelRD = getReductionData(cast(NextLevelOp)); if (!NextLevelRD || !RD->hasSameData(*NextLevelRD)) - return RK_None; + return TTI::RK_None; } // Shuffle mask for pairwise operation must match. if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) - return RK_None; + return TTI::RK_None; } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) - return RK_None; + return TTI::RK_None; } else { - return RK_None; + return TTI::RK_None; } if (++Level == NumLevels) @@ -1101,11 +1076,10 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, NumLevels); } -static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot, - unsigned &Opcode, - VectorType *&Ty) { +TTI::ReductionKind TTI::matchPairwiseReduction( + const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty) { if (!EnableReduxCost) - return RK_None; + return TTI::RK_None; // Need to extract the first element. ConstantInt *CI = dyn_cast(ReduxRoot->getOperand(1)); @@ -1113,19 +1087,19 @@ static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot, if (CI) Idx = CI->getZExtValue(); if (Idx != 0) - return RK_None; + return TTI::RK_None; auto *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); if (!RdxStart) - return RK_None; - Optional RD = getReductionData(RdxStart); + return TTI::RK_None; + Optional RD = getReductionData(RdxStart); if (!RD) - return RK_None; + return TTI::RK_None; auto *VecTy = cast(RdxStart->getType()); unsigned NumVecElems = VecTy->getNumElements(); if (!isPowerOf2_32(NumVecElems)) - return RK_None; + return TTI::RK_None; // We look for a sequence of shuffle,shuffle,add triples like the following // that builds a pairwise reduction tree. @@ -1146,8 +1120,8 @@ static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot, // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1 // %r = extractelement <4 x float> %bin.rdx8, i32 0 if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) == - RK_None) - return RK_None; + TTI::RK_None) + return TTI::RK_None; Opcode = RD->Opcode; Ty = VecTy; @@ -1166,11 +1140,11 @@ getShuffleAndOtherOprd(Value *L, Value *R) { return std::make_pair(L, S); } -static ReductionKind -matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, - unsigned &Opcode, VectorType *&Ty) { +TTI::ReductionKind TTI::matchVectorSplittingReduction( + const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty) { + if (!EnableReduxCost) - return RK_None; + return TTI::RK_None; // Need to extract the first element. ConstantInt *CI = dyn_cast(ReduxRoot->getOperand(1)); @@ -1178,19 +1152,19 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, if (CI) Idx = CI->getZExtValue(); if (Idx != 0) - return RK_None; + return TTI::RK_None; auto *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); if (!RdxStart) - return RK_None; - Optional RD = getReductionData(RdxStart); + return TTI::RK_None; + Optional RD = getReductionData(RdxStart); if (!RD) - return RK_None; + return TTI::RK_None; auto *VecTy = cast(ReduxRoot->getOperand(0)->getType()); unsigned NumVecElems = VecTy->getNumElements(); if (!isPowerOf2_32(NumVecElems)) - return RK_None; + return TTI::RK_None; // We look for a sequence of shuffles and adds like the following matching one // fadd, shuffle vector pair at a time. @@ -1210,10 +1184,10 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, while (NumVecElemsRemain - 1) { // Check for the right reduction operation. if (!RdxOp) - return RK_None; - Optional RDLevel = getReductionData(RdxOp); + return TTI::RK_None; + Optional RDLevel = getReductionData(RdxOp); if (!RDLevel || !RDLevel->hasSameData(*RD)) - return RK_None; + return TTI::RK_None; Value *NextRdxOp; ShuffleVectorInst *Shuffle; @@ -1222,9 +1196,9 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, // Check the current reduction operation and the shuffle use the same value. if (Shuffle == nullptr) - return RK_None; + return TTI::RK_None; if (Shuffle->getOperand(0) != NextRdxOp) - return RK_None; + return TTI::RK_None; // Check that shuffle masks matches. for (unsigned j = 0; j != MaskStart; ++j) @@ -1234,7 +1208,7 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, ArrayRef Mask = Shuffle->getShuffleMask(); if (ShuffleMask != Mask) - return RK_None; + return TTI::RK_None; RdxOp = dyn_cast(NextRdxOp); NumVecElemsRemain /= 2; @@ -1291,10 +1265,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { case Instruction::Select: case Instruction::ICmp: case Instruction::FCmp: - return getUserCost(I, CostKind); case Instruction::Store: case Instruction::Load: - return getUserCost(I, CostKind); case Instruction::ZExt: case Instruction::SExt: case Instruction::FPToUI: @@ -1308,59 +1280,10 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { case Instruction::FPTrunc: case Instruction::BitCast: case Instruction::AddrSpaceCast: - return getUserCost(I, CostKind); - case Instruction::ExtractElement: { - const ExtractElementInst *EEI = cast(I); - ConstantInt *CI = dyn_cast(I->getOperand(1)); - unsigned Idx = -1; - if (CI) - Idx = CI->getZExtValue(); - - // Try to match a reduction sequence (series of shufflevector and vector - // adds followed by a extractelement). - unsigned ReduxOpCode; - VectorType *ReduxType; - - switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { - case RK_Arithmetic: - return getArithmeticReductionCost(ReduxOpCode, ReduxType, - /*IsPairwiseForm=*/false, - CostKind); - case RK_MinMax: - return getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/false, /*IsUnsigned=*/false); - case RK_UnsignedMinMax: - return getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/false, /*IsUnsigned=*/true); - case RK_None: - break; - } - - switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { - case RK_Arithmetic: - return getArithmeticReductionCost(ReduxOpCode, ReduxType, - /*IsPairwiseForm=*/true, CostKind); - case RK_MinMax: - return getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/true, /*IsUnsigned=*/false); - case RK_UnsignedMinMax: - return getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/true, /*IsUnsigned=*/true); - case RK_None: - break; - } - - return getVectorInstrCost(I->getOpcode(), EEI->getOperand(0)->getType(), - Idx); - } + case Instruction::ExtractElement: case Instruction::InsertElement: case Instruction::ExtractValue: case Instruction::ShuffleVector: - return getUserCost(I, CostKind); case Instruction::Call: return getUserCost(I, CostKind); default: diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp index 38ee6d1e7aafe0..2c42b7a31aa2c8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp @@ -992,13 +992,6 @@ GCNTTIImpl::getUserCost(const User *U, ArrayRef Operands, // Estimate different operations to be optimized out switch (I->getOpcode()) { - case Instruction::ExtractElement: { - ConstantInt *CI = dyn_cast(I->getOperand(1)); - unsigned Idx = -1; - if (CI) - Idx = CI->getZExtValue(); - return getVectorInstrCost(I->getOpcode(), I->getOperand(0)->getType(), Idx); - } case Instruction::FNeg: return getArithmeticInstrCost(I->getOpcode(), I->getType(), CostKind, TTI::OK_AnyValue, TTI::OK_AnyValue,