Skip to content

Commit

Permalink
[SLP]Alternate vectorization for cmp instructions.
Browse files Browse the repository at this point in the history
Added support for alternate ops vectorization of the cmp instructions.
It allows vectorizing either cmp instructions with the same/swapped
predicate but different (swapped) operand kinds, or cmp instructions
with different predicates and compatible operand kinds.

Differential Revision: https://reviews.llvm.org/D115955
  • Loading branch information
alexey-bataev committed Feb 2, 2022
1 parent 287ce6b commit 842a236
Show file tree
Hide file tree
Showing 5 changed files with 363 additions and 251 deletions.
175 changes: 169 additions & 6 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Expand Up @@ -471,17 +471,36 @@ static bool isValidForAlternation(unsigned Opcode) {
return true;
}

static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
unsigned BaseIndex = 0);

/// Decides whether the operand pair (\p Op0, \p Op1) of a cmp instruction can
/// be lined up position-wise with the operand pair (\p BaseOp0, \p BaseOp1) of
/// a base cmp instruction.
///
/// The pairs are treated as compatible when any one of these holds:
///  - the first operands are both constants;
///  - the second operands are both constants;
///  - none of the four operands is an Instruction;
///  - getSameOpcode() reports a non-zero opcode for the two first operands or
///    for the two second operands.
///
/// \returns true if the operands are compatible, false otherwise.
static bool areCompatibleCmpOps(Value *BaseOp0, Value *BaseOp1, Value *Op0,
                                Value *Op1) {
  // Constants sitting in the same operand position are always acceptable.
  if (isConstant(BaseOp0) && isConstant(Op0))
    return true;
  if (isConstant(BaseOp1) && isConstant(Op1))
    return true;

  // Plain (non-instruction) values everywhere: nothing to mismatch on.
  const bool HasInstructionOperand =
      isa<Instruction>(BaseOp0) || isa<Instruction>(Op0) ||
      isa<Instruction>(BaseOp1) || isa<Instruction>(Op1);
  if (!HasInstructionOperand)
    return true;

  // Otherwise, ask the opcode analysis about each operand position in turn;
  // one matching position is sufficient.
  if (getSameOpcode({BaseOp0, Op0}).getOpcode() != 0)
    return true;
  return getSameOpcode({BaseOp1, Op1}).getOpcode() != 0;
}

/// \returns analysis of the Instructions in \p VL described in
/// InstructionsState, the Opcode that we suppose the whole list
/// could be vectorized even if its structure is diverse.
static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
unsigned BaseIndex = 0) {
unsigned BaseIndex) {
// Make sure these are all Instructions.
if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); }))
return InstructionsState(VL[BaseIndex], nullptr, nullptr);

bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]);
CmpInst::Predicate BasePred =
IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate()
: CmpInst::BAD_ICMP_PREDICATE;
unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode();
unsigned AltOpcode = Opcode;
unsigned AltIndex = BaseIndex;
Expand Down Expand Up @@ -514,6 +533,57 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
continue;
}
}
} else if (IsCmpOp && isa<CmpInst>(VL[Cnt])) {
auto *BaseInst = cast<Instruction>(VL[BaseIndex]);
auto *Inst = cast<Instruction>(VL[Cnt]);
Type *Ty0 = BaseInst->getOperand(0)->getType();
Type *Ty1 = Inst->getOperand(0)->getType();
if (Ty0 == Ty1) {
Value *BaseOp0 = BaseInst->getOperand(0);
Value *BaseOp1 = BaseInst->getOperand(1);
Value *Op0 = Inst->getOperand(0);
Value *Op1 = Inst->getOperand(1);
CmpInst::Predicate CurrentPred =
cast<CmpInst>(VL[Cnt])->getPredicate();
CmpInst::Predicate SwappedCurrentPred =
CmpInst::getSwappedPredicate(CurrentPred);
// Check for compatible operands. If the corresponding operands are not
// compatible - need to perform alternate vectorization.
if (InstOpcode == Opcode) {
if (BasePred == CurrentPred &&
areCompatibleCmpOps(BaseOp0, BaseOp1, Op0, Op1))
continue;
if (BasePred == SwappedCurrentPred &&
areCompatibleCmpOps(BaseOp0, BaseOp1, Op1, Op0))
continue;
if (E == 2 &&
(BasePred == CurrentPred || BasePred == SwappedCurrentPred))
continue;
auto *AltInst = cast<CmpInst>(VL[AltIndex]);
CmpInst::Predicate AltPred = AltInst->getPredicate();
Value *AltOp0 = AltInst->getOperand(0);
Value *AltOp1 = AltInst->getOperand(1);
// Check if operands are compatible with alternate operands.
if (AltPred == CurrentPred &&
areCompatibleCmpOps(AltOp0, AltOp1, Op0, Op1))
continue;
if (AltPred == SwappedCurrentPred &&
areCompatibleCmpOps(AltOp0, AltOp1, Op1, Op0))
continue;
}
if (BaseIndex == AltIndex && BasePred != CurrentPred) {
assert(isValidForAlternation(Opcode) &&
isValidForAlternation(InstOpcode) &&
"Cast isn't safe for alternation, logic needs to be updated!");
AltIndex = Cnt;
continue;
}
auto *AltInst = cast<CmpInst>(VL[AltIndex]);
CmpInst::Predicate AltPred = AltInst->getPredicate();
if (BasePred == CurrentPred || BasePred == SwappedCurrentPred ||
AltPred == CurrentPred || AltPred == SwappedCurrentPred)
continue;
}
} else if (InstOpcode == Opcode || InstOpcode == AltOpcode)
continue;
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
Expand Down Expand Up @@ -4354,9 +4424,41 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n");

// Reorder operands if reordering would enable vectorization.
if (isa<BinaryOperator>(VL0)) {
auto *CI = dyn_cast<CmpInst>(VL0);
if (isa<BinaryOperator>(VL0) || CI) {
ValueList Left, Right;
reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this);
if (!CI || all_of(VL, [](Value *V) {
return cast<CmpInst>(V)->isCommutative();
})) {
reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this);
} else {
CmpInst::Predicate P0 = CI->getPredicate();
CmpInst::Predicate AltP0 = cast<CmpInst>(S.AltOp)->getPredicate();
assert(P0 != AltP0 &&
"Expected different main/alternate predicates.");
CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0);
Value *BaseOp0 = VL0->getOperand(0);
Value *BaseOp1 = VL0->getOperand(1);
// Collect operands - commute if it uses the swapped predicate or
// alternate operation.
for (Value *V : VL) {
auto *Cmp = cast<CmpInst>(V);
Value *LHS = Cmp->getOperand(0);
Value *RHS = Cmp->getOperand(1);
CmpInst::Predicate CurrentPred = CI->getPredicate();
if (P0 == AltP0Swapped) {
if ((P0 == CurrentPred &&
!areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)) ||
(AltP0 == CurrentPred &&
areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)))
std::swap(LHS, RHS);
} else if (P0 != CurrentPred && AltP0 != CurrentPred) {
std::swap(LHS, RHS);
}
Left.push_back(LHS);
Right.push_back(RHS);
}
}
TE->setOperand(0, Left);
TE->setOperand(1, Right);
buildTree_rec(Left, Depth + 1, {TE, 0});
Expand Down Expand Up @@ -5288,7 +5390,8 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
((Instruction::isBinaryOp(E->getOpcode()) &&
Instruction::isBinaryOp(E->getAltOpcode())) ||
(Instruction::isCast(E->getOpcode()) &&
Instruction::isCast(E->getAltOpcode()))) &&
Instruction::isCast(E->getAltOpcode())) ||
(isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
"Invalid Shuffle Vector Operand");
InstructionCost ScalarCost = 0;
if (NeedToShuffleReuses) {
Expand Down Expand Up @@ -5336,6 +5439,14 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy, CostKind);
VecCost += TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy,
CostKind);
} else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) {
VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy,
Builder.getInt1Ty(),
CI0->getPredicate(), CostKind, VL0);
VecCost += TTI->getCmpSelInstrCost(
E->getOpcode(), ScalarTy, Builder.getInt1Ty(),
cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind,
E->getAltOp());
} else {
Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType();
Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType();
Expand All @@ -5352,6 +5463,27 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices,
[E](Instruction *I) {
assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
if (auto *CI0 = dyn_cast<CmpInst>(E->getMainOp())) {
auto *AltCI0 = cast<CmpInst>(E->getAltOp());
auto *CI = cast<CmpInst>(I);
CmpInst::Predicate P0 = CI0->getPredicate();
CmpInst::Predicate AltP0 = AltCI0->getPredicate();
assert(P0 != AltP0 &&
"Expected different main/alternate predicates.");
CmpInst::Predicate AltP0Swapped =
CmpInst::getSwappedPredicate(AltP0);
CmpInst::Predicate CurrentPred = CI->getPredicate();
if (P0 == AltP0Swapped)
return (P0 == CurrentPred &&
!areCompatibleCmpOps(
CI0->getOperand(0), CI0->getOperand(1),
CI->getOperand(0), CI->getOperand(1))) ||
(AltP0 == CurrentPred &&
!areCompatibleCmpOps(
CI0->getOperand(0), CI0->getOperand(1),
CI->getOperand(1), CI->getOperand(0)));
return AltP0 == CurrentPred || AltP0Swapped == CurrentPred;
}
return I->getOpcode() == E->getAltOpcode();
},
Mask);
Expand Down Expand Up @@ -6834,11 +6966,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
((Instruction::isBinaryOp(E->getOpcode()) &&
Instruction::isBinaryOp(E->getAltOpcode())) ||
(Instruction::isCast(E->getOpcode()) &&
Instruction::isCast(E->getAltOpcode()))) &&
Instruction::isCast(E->getAltOpcode())) ||
(isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
"Invalid Shuffle Vector Operand");

Value *LHS = nullptr, *RHS = nullptr;
if (Instruction::isBinaryOp(E->getOpcode())) {
if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0)) {
setInsertPointAfterBundle(E);
LHS = vectorizeTree(E->getOperand(0));
RHS = vectorizeTree(E->getOperand(1));
Expand All @@ -6858,6 +6991,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS);
V1 = Builder.CreateBinOp(
static_cast<Instruction::BinaryOps>(E->getAltOpcode()), LHS, RHS);
} else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) {
V0 = Builder.CreateCmp(CI0->getPredicate(), LHS, RHS);
auto *AltCI = cast<CmpInst>(E->getAltOp());
CmpInst::Predicate AltPred = AltCI->getPredicate();
unsigned AltIdx =
std::distance(E->Scalars.begin(), find(E->Scalars, AltCI));
if (AltCI->getOperand(0) != E->getOperand(0)[AltIdx])
AltPred = CmpInst::getSwappedPredicate(AltPred);
V1 = Builder.CreateCmp(AltPred, LHS, RHS);
} else {
V0 = Builder.CreateCast(
static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);
Expand All @@ -6882,6 +7024,27 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices,
[E](Instruction *I) {
assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
if (auto *CI0 = dyn_cast<CmpInst>(E->getMainOp())) {
auto *AltCI0 = cast<CmpInst>(E->getAltOp());
auto *CI = cast<CmpInst>(I);
CmpInst::Predicate P0 = CI0->getPredicate();
CmpInst::Predicate AltP0 = AltCI0->getPredicate();
assert(P0 != AltP0 &&
"Expected different main/alternate predicates.");
CmpInst::Predicate AltP0Swapped =
CmpInst::getSwappedPredicate(AltP0);
CmpInst::Predicate CurrentPred = CI->getPredicate();
if (P0 == AltP0Swapped)
return (P0 == CurrentPred &&
!areCompatibleCmpOps(
CI0->getOperand(0), CI0->getOperand(1),
CI->getOperand(0), CI->getOperand(1))) ||
(AltP0 == CurrentPred &&
!areCompatibleCmpOps(
CI0->getOperand(0), CI0->getOperand(1),
CI->getOperand(1), CI->getOperand(0)));
return AltP0 == CurrentPred || AltP0Swapped == CurrentPred;
}
return I->getOpcode() == E->getAltOpcode();
},
Mask, &OpScalars, &AltScalars);
Expand Down
57 changes: 22 additions & 35 deletions llvm/test/Transforms/PhaseOrdering/X86/vector-reductions-logical.ll
Expand Up @@ -90,24 +90,17 @@ return:
define float @test_merge_anyof_v4sf(<4 x float> %t) {
; CHECK-LABEL: @test_merge_anyof_v4sf(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x float> [[T:%.*]], i64 3
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[T]], i64 2
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[T]], i64 1
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[T]], i64 0
; CHECK-NEXT: [[T_FR:%.*]] = freeze <4 x float> [[T]]
; CHECK-NEXT: [[TMP4:%.*]] = fcmp olt <4 x float> [[T_FR]], zeroinitializer
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <4 x i1> [[TMP4]] to i4
; CHECK-NEXT: [[TMP6:%.*]] = icmp ne i4 [[TMP5]], 0
; CHECK-NEXT: [[CMP19:%.*]] = fcmp ogt float [[TMP3]], 1.000000e+00
; CHECK-NEXT: [[OR_COND3:%.*]] = select i1 [[TMP6]], i1 true, i1 [[CMP19]]
; CHECK-NEXT: [[CMP24:%.*]] = fcmp ogt float [[TMP2]], 1.000000e+00
; CHECK-NEXT: [[OR_COND4:%.*]] = select i1 [[OR_COND3]], i1 true, i1 [[CMP24]]
; CHECK-NEXT: [[CMP29:%.*]] = fcmp ogt float [[TMP1]], 1.000000e+00
; CHECK-NEXT: [[OR_COND5:%.*]] = select i1 [[OR_COND4]], i1 true, i1 [[CMP29]]
; CHECK-NEXT: [[CMP34:%.*]] = fcmp ogt float [[TMP0]], 1.000000e+00
; CHECK-NEXT: [[OR_COND6:%.*]] = select i1 [[OR_COND5]], i1 true, i1 [[CMP34]]
; CHECK-NEXT: [[ADD:%.*]] = fadd float [[TMP3]], [[TMP2]]
; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[OR_COND6]], float 0.000000e+00, float [[ADD]]
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x float> [[T:%.*]], <4 x float> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP0:%.*]] = fcmp olt <8 x float> [[SHUFFLE]], <float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
; CHECK-NEXT: [[TMP1:%.*]] = fcmp ogt <8 x float> [[SHUFFLE]], <float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i1> [[TMP0]], <8 x i1> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP3:%.*]] = freeze <8 x i1> [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i1> [[TMP3]] to i8
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP4]], 0
; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x float> [[T]], <4 x float> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP5:%.*]] = fadd <4 x float> [[SHIFT]], [[T]]
; CHECK-NEXT: [[ADD:%.*]] = extractelement <4 x float> [[TMP5]], i64 0
; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[DOTNOT]], float [[ADD]], float 0.000000e+00
; CHECK-NEXT: ret float [[RETVAL_0]]
;
entry:
Expand Down Expand Up @@ -420,24 +413,18 @@ return:
define float @test_merge_anyof_v4si(<4 x i32> %t) {
; CHECK-LABEL: @test_merge_anyof_v4si(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[T_FR:%.*]] = freeze <4 x i32> [[T:%.*]]
; CHECK-NEXT: [[TMP0:%.*]] = icmp slt <4 x i32> [[T_FR]], <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i1> [[TMP0]] to i4
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i4 [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = icmp sgt <4 x i32> [[T_FR]], <i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i1> [[TMP3]], i64 0
; CHECK-NEXT: [[OR_COND3:%.*]] = or i1 [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i1> [[TMP3]], i64 1
; CHECK-NEXT: [[OR_COND4:%.*]] = or i1 [[OR_COND3]], [[TMP5]]
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i1> [[TMP3]], i64 2
; CHECK-NEXT: [[OR_COND5:%.*]] = or i1 [[OR_COND4]], [[TMP6]]
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x i1> [[TMP3]], i64 3
; CHECK-NEXT: [[OR_COND6:%.*]] = or i1 [[OR_COND5]], [[TMP7]]
; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i32> [[T_FR]], <4 x i32> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP8:%.*]] = add nsw <4 x i32> [[SHIFT]], [[T_FR]]
; CHECK-NEXT: [[ADD:%.*]] = extractelement <4 x i32> [[TMP8]], i64 0
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[T:%.*]], <4 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP0:%.*]] = icmp slt <8 x i32> [[SHUFFLE]], <i32 1, i32 1, i32 1, i32 1, i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt <8 x i32> [[SHUFFLE]], <i32 1, i32 1, i32 1, i32 1, i32 255, i32 255, i32 255, i32 255>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i1> [[TMP0]], <8 x i1> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP3:%.*]] = freeze <8 x i1> [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i1> [[TMP3]] to i8
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP4]], 0
; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i32> [[T]], <4 x i32> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP5:%.*]] = add nsw <4 x i32> [[SHIFT]], [[T]]
; CHECK-NEXT: [[ADD:%.*]] = extractelement <4 x i32> [[TMP5]], i64 0
; CHECK-NEXT: [[CONV:%.*]] = sitofp i32 [[ADD]] to float
; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[OR_COND6]], float 0.000000e+00, float [[CONV]]
; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[DOTNOT]], float [[CONV]], float 0.000000e+00
; CHECK-NEXT: ret float [[RETVAL_0]]
;
entry:
Expand Down

0 comments on commit 842a236

Please sign in to comment.