Skip to content

Commit

Permalink
Revert "Revert "[CodeGen] Extend reduction support in ComplexDeinterl…
Browse files Browse the repository at this point in the history
…eaving pass to support predication""

Adds the capability to recognize SelectInsts that appear in the IR.
These instructions are generated during scalable vectorization of reductions
when the code contains conditions inside the loop body or when
"-prefer-predicate-over-epilogue=predicate-dont-vectorize" is set.

Differential Revision: https://reviews.llvm.org/D152558

This reverts commit ab09654.

Reason: Reapplying after removing unnecessary default case in switch expression.
  • Loading branch information
igogo-x86 committed Jun 23, 2023
1 parent 2273741 commit 04a8070
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 56 deletions.
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ enum class ComplexDeinterleavingOperation {
Symmetric,
ReductionPHI,
ReductionOperation,
ReductionSelect,
};

enum class ComplexDeinterleavingRotation {
Expand Down
59 changes: 59 additions & 0 deletions llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,10 @@ class ComplexDeinterleavingGraph {

NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

/// Identifies SelectInsts in a loop that has a reduction with predication
/// masks and/or predicated tail folding.
NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

/// Complete IR modifications after producing new reduction operation:
Expand Down Expand Up @@ -889,6 +893,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
if (NodePtr CN = identifyPHINode(Real, Imag))
return CN;

if (NodePtr CN = identifySelectNode(Real, Imag))
return CN;

auto *VTy = cast<VectorType>(Real->getType());
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);

Expand Down Expand Up @@ -1713,6 +1720,45 @@ ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
return submitCompositeNode(PlaceholderNode);
}

/// Attempts to recognise a pair of SelectInsts -- \p Real yielding the real
/// part and \p Imag yielding the imaginary part -- as one ReductionSelect
/// composite node.
///
/// The pair is accepted only when:
///   * both values are selects whose three operands are all Instructions,
///   * the two masks are the same instruction, or isIdenticalTo() holds for
///     them,
///   * the mask has a vector type, and
///   * identifyNode() succeeds first on the two "true" operands and then on
///     the two "false" operands.
///
/// \returns the submitted composite node, with the "true" node as operand 0
/// and the "false" node as operand 1, or nullptr if any requirement fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
                                               Instruction *Imag) {
  // Nothing to do unless both halves are selects.
  if (!isa<SelectInst>(Real) || !isa<SelectInst>(Imag))
    return nullptr;

  // Decompose select(Mask, OnTrue, OnFalse) for the real half; the pattern
  // only binds when every operand is itself an Instruction.
  Instruction *RealMask, *RealOnTrue, *RealOnFalse;
  if (!match(Real, m_Select(m_Instruction(RealMask), m_Instruction(RealOnTrue),
                            m_Instruction(RealOnFalse))))
    return nullptr;

  // Same decomposition for the imaginary half.
  Instruction *ImagMask, *ImagOnTrue, *ImagOnFalse;
  if (!match(Imag, m_Select(m_Instruction(ImagMask), m_Instruction(ImagOnTrue),
                            m_Instruction(ImagOnFalse))))
    return nullptr;

  // Both selects have to be controlled by an equivalent mask.
  const bool MasksAgree =
      RealMask == ImagMask || RealMask->isIdenticalTo(ImagMask);
  if (!MasksAgree)
    return nullptr;

  // A scalar condition is not handled here; only vector masks are accepted.
  if (!RealMask->getType()->isVectorTy())
    return nullptr;

  // Identify the values chosen when the mask is set ...
  NodePtr OnTrueNode = identifyNode(RealOnTrue, ImagOnTrue);
  if (!OnTrueNode)
    return nullptr;

  // ... and the values chosen when it is clear. The "true" pair is always
  // identified before the "false" pair.
  NodePtr OnFalseNode = identifyNode(RealOnFalse, ImagOnFalse);
  if (!OnFalseNode)
    return nullptr;

  NodePtr SelectNode = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
  SelectNode->addOperand(OnTrueNode);
  SelectNode->addOperand(OnFalseNode);

  // Record both mask instructions in FinalInstructions.
  // NOTE(review): presumably this marks them as values the graph does not
  // replace -- confirm against the other users of FinalInstructions.
  FinalInstructions.insert(RealMask);
  FinalInstructions.insert(ImagMask);
  return submitCompositeNode(SelectNode);
}

static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
FastMathFlags Flags, Value *InputA,
Value *InputB) {
Expand Down Expand Up @@ -1787,6 +1833,19 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
processReductionOperation(ReplacementNode, Node);
break;
case ComplexDeinterleavingOperation::ReductionSelect: {
auto *MaskReal = Node->Real->getOperand(0);
auto *MaskImag = Node->Imag->getOperand(0);
auto *A = replaceNode(Builder, Node->Operands[0]);
auto *B = replaceNode(Builder, Node->Operands[1]);
auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
cast<VectorType>(MaskReal->getType()));
auto *NewMask =
Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
NewMaskTy, {MaskReal, MaskImag});
ReplacementNode = Builder.CreateSelect(NewMask, A, B);
break;
}
}

assert(ReplacementNode && "Target failed to create Intrinsic call.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,37 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
; CHECK-NEXT: mov x11, x10
; CHECK-NEXT: mov z1.d, #0 // =0x0
; CHECK-NEXT: rdvl x12, #2
; CHECK-NEXT: mov z0.d, z1.d
; CHECK-NEXT: whilelo p1.d, xzr, x9
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: .LBB0_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: add x13, x0, x8
; CHECK-NEXT: add x14, x1, x8
; CHECK-NEXT: zip1 p2.d, p1.d, p1.d
; CHECK-NEXT: zip2 p3.d, p1.d, p1.d
; CHECK-NEXT: add x8, x8, x12
; CHECK-NEXT: mov z6.d, z1.d
; CHECK-NEXT: mov z7.d, z0.d
; CHECK-NEXT: ld1d { z2.d }, p3/z, [x13, #1, mul vl]
; CHECK-NEXT: ld1d { z3.d }, p2/z, [x13]
; CHECK-NEXT: ld1d { z4.d }, p3/z, [x14, #1, mul vl]
; CHECK-NEXT: ld1d { z5.d }, p2/z, [x14]
; CHECK-NEXT: uzp2 z6.d, z3.d, z2.d
; CHECK-NEXT: uzp1 z2.d, z3.d, z2.d
; CHECK-NEXT: uzp1 z7.d, z5.d, z4.d
; CHECK-NEXT: uzp2 z4.d, z5.d, z4.d
; CHECK-NEXT: movprfx z3, z1
; CHECK-NEXT: fmla z3.d, p0/m, z7.d, z6.d
; CHECK-NEXT: fmad z7.d, p0/m, z2.d, z0.d
; CHECK-NEXT: fmad z2.d, p0/m, z4.d, z3.d
; CHECK-NEXT: movprfx z3, z7
; CHECK-NEXT: fmls z3.d, p0/m, z4.d, z6.d
; CHECK-NEXT: mov z1.d, p1/m, z2.d
; CHECK-NEXT: mov z0.d, p1/m, z3.d
; CHECK-NEXT: whilelo p1.d, x11, x9
; CHECK-NEXT: add x8, x8, x12
; CHECK-NEXT: add x11, x11, x10
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #0
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: mov z0.d, p3/m, z7.d
; CHECK-NEXT: mov z1.d, p2/m, z6.d
; CHECK-NEXT: b.mi .LBB0_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp2 z2.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z0.d, z1.d, z0.d
; CHECK-NEXT: faddv d0, p0, z0.d
; CHECK-NEXT: faddv d1, p0, z1.d
; CHECK-NEXT: faddv d1, p0, z2.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: // kill: def $d1 killed $d1 killed $z1
; CHECK-NEXT: ret
Expand Down Expand Up @@ -122,39 +121,38 @@ define %"class.std::complex" @complex_mul_predicated_v2f64(ptr %a, ptr %b, ptr %
; CHECK-NEXT: and x11, x11, x12
; CHECK-NEXT: mov z1.d, #0 // =0x0
; CHECK-NEXT: rdvl x12, #2
; CHECK-NEXT: mov z0.d, z1.d
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: .LBB1_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ld1w { z2.d }, p0/z, [x2, x9, lsl #2]
; CHECK-NEXT: add x13, x0, x8
; CHECK-NEXT: add x14, x1, x8
; CHECK-NEXT: mov z6.d, z1.d
; CHECK-NEXT: mov z7.d, z0.d
; CHECK-NEXT: add x9, x9, x10
; CHECK-NEXT: add x8, x8, x12
; CHECK-NEXT: cmpne p1.d, p0/z, z2.d, #0
; CHECK-NEXT: zip1 p2.d, p1.d, p1.d
; CHECK-NEXT: zip2 p3.d, p1.d, p1.d
; CHECK-NEXT: ld1d { z2.d }, p3/z, [x13, #1, mul vl]
; CHECK-NEXT: ld1d { z3.d }, p2/z, [x13]
; CHECK-NEXT: ld1d { z4.d }, p3/z, [x14, #1, mul vl]
; CHECK-NEXT: ld1d { z5.d }, p2/z, [x14]
; CHECK-NEXT: cmpne p2.d, p0/z, z2.d, #0
; CHECK-NEXT: zip1 p1.d, p2.d, p2.d
; CHECK-NEXT: zip2 p2.d, p2.d, p2.d
; CHECK-NEXT: ld1d { z2.d }, p2/z, [x13, #1, mul vl]
; CHECK-NEXT: ld1d { z3.d }, p1/z, [x13]
; CHECK-NEXT: ld1d { z4.d }, p2/z, [x14, #1, mul vl]
; CHECK-NEXT: ld1d { z5.d }, p1/z, [x14]
; CHECK-NEXT: cmp x11, x9
; CHECK-NEXT: uzp2 z6.d, z3.d, z2.d
; CHECK-NEXT: uzp1 z2.d, z3.d, z2.d
; CHECK-NEXT: uzp1 z3.d, z5.d, z4.d
; CHECK-NEXT: movprfx z7, z0
; CHECK-NEXT: fmla z7.d, p0/m, z3.d, z2.d
; CHECK-NEXT: fmad z3.d, p0/m, z6.d, z1.d
; CHECK-NEXT: uzp2 z4.d, z5.d, z4.d
; CHECK-NEXT: fmad z2.d, p0/m, z4.d, z3.d
; CHECK-NEXT: movprfx z5, z7
; CHECK-NEXT: fmls z5.d, p0/m, z4.d, z6.d
; CHECK-NEXT: mov z0.d, p1/m, z5.d
; CHECK-NEXT: mov z1.d, p1/m, z2.d
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #0
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: mov z0.d, p2/m, z7.d
; CHECK-NEXT: mov z1.d, p1/m, z6.d
; CHECK-NEXT: b.ne .LBB1_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp2 z2.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z0.d, z1.d, z0.d
; CHECK-NEXT: faddv d0, p0, z0.d
; CHECK-NEXT: faddv d1, p0, z1.d
; CHECK-NEXT: faddv d1, p0, z2.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: // kill: def $d1 killed $d1 killed $z1
; CHECK-NEXT: ret
Expand Down Expand Up @@ -223,42 +221,41 @@ define %"class.std::complex" @complex_mul_predicated_x2_v2f64(ptr %a, ptr %b, pt
; CHECK-NEXT: mov x8, xzr
; CHECK-NEXT: mov x9, xzr
; CHECK-NEXT: mov z1.d, #0 // =0x0
; CHECK-NEXT: mov z0.d, z1.d
; CHECK-NEXT: cntd x11
; CHECK-NEXT: whilelo p1.d, xzr, x10
; CHECK-NEXT: rdvl x12, #2
; CHECK-NEXT: whilelo p1.d, xzr, x10
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: .LBB2_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ld1w { z2.d }, p1/z, [x2, x9, lsl #2]
; CHECK-NEXT: add x13, x0, x8
; CHECK-NEXT: add x14, x1, x8
; CHECK-NEXT: mov z6.d, z1.d
; CHECK-NEXT: mov z7.d, z0.d
; CHECK-NEXT: add x9, x9, x11
; CHECK-NEXT: add x8, x8, x12
; CHECK-NEXT: cmpne p2.d, p1/z, z2.d, #0
; CHECK-NEXT: zip1 p1.d, p2.d, p2.d
; CHECK-NEXT: zip2 p3.d, p2.d, p2.d
; CHECK-NEXT: cmpne p1.d, p1/z, z2.d, #0
; CHECK-NEXT: zip1 p2.d, p1.d, p1.d
; CHECK-NEXT: zip2 p3.d, p1.d, p1.d
; CHECK-NEXT: ld1d { z2.d }, p3/z, [x13, #1, mul vl]
; CHECK-NEXT: ld1d { z3.d }, p1/z, [x13]
; CHECK-NEXT: ld1d { z3.d }, p2/z, [x13]
; CHECK-NEXT: ld1d { z4.d }, p3/z, [x14, #1, mul vl]
; CHECK-NEXT: ld1d { z5.d }, p1/z, [x14]
; CHECK-NEXT: ld1d { z5.d }, p2/z, [x14]
; CHECK-NEXT: whilelo p1.d, x9, x10
; CHECK-NEXT: uzp1 z6.d, z3.d, z2.d
; CHECK-NEXT: uzp2 z2.d, z3.d, z2.d
; CHECK-NEXT: uzp1 z7.d, z5.d, z4.d
; CHECK-NEXT: uzp2 z4.d, z5.d, z4.d
; CHECK-NEXT: movprfx z3, z0
; CHECK-NEXT: fmla z3.d, p0/m, z7.d, z6.d
; CHECK-NEXT: fmad z7.d, p0/m, z2.d, z1.d
; CHECK-NEXT: fmsb z2.d, p0/m, z4.d, z3.d
; CHECK-NEXT: movprfx z3, z7
; CHECK-NEXT: fmla z3.d, p0/m, z4.d, z6.d
; CHECK-NEXT: mov z1.d, p2/m, z3.d
; CHECK-NEXT: mov z0.d, p2/m, z2.d
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #0
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: mov z0.d, p3/m, z7.d
; CHECK-NEXT: mov z1.d, p2/m, z6.d
; CHECK-NEXT: b.mi .LBB2_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp2 z2.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z0.d, z1.d, z0.d
; CHECK-NEXT: faddv d0, p0, z0.d
; CHECK-NEXT: faddv d1, p0, z1.d
; CHECK-NEXT: faddv d1, p0, z2.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: // kill: def $d1 killed $d1 killed $z1
; CHECK-NEXT: ret
Expand Down

0 comments on commit 04a8070

Please sign in to comment.