Skip to content

Commit

Permalink
[AArch64][GlobalISel] Add G_VECREDUCE fewerElements support for full …
Browse files Browse the repository at this point in the history
…scalarization.

For some reductions like G_VECREDUCE_OR on AArch64, we need to scalarize
completely if the source is <= 64b. This change adds support for that in
the legalizer. If the source has a pow-2 num elements, then we can do
a tree reduction using the scalar operation in the individual elements.
Otherwise, we just create a sequential chain of operations.

For AArch64, we only need to scalarize if the input is <64b. If it's great than
64b then we can first do a fewElements step to 64b, taking advantage of vector
instructions until we reach the point of scalarization.

I also had to relax the verifier checks for reductions because the intrinsics
support <1 x EltTy> types, which we lower to scalars for GlobalISel.

Differential Revision: https://reviews.llvm.org/D108276
  • Loading branch information
aemerson committed Aug 19, 2021
1 parent fbb8e77 commit 95ac3d1
Show file tree
Hide file tree
Showing 8 changed files with 1,137 additions and 37 deletions.
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
Expand Up @@ -403,6 +403,7 @@ class LegalizerHelper {
LegalizeResult lowerAbsToAddXor(MachineInstr &MI);
LegalizeResult lowerAbsToMaxNeg(MachineInstr &MI);
LegalizeResult lowerIsNaN(MachineInstr &MI);
LegalizeResult lowerVectorReduction(MachineInstr &MI);
};

/// Helper function that creates a libcall to the given \p Name using the given
Expand Down
124 changes: 94 additions & 30 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Expand Up @@ -3489,6 +3489,8 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
return lowerRotate(MI);
case G_ISNAN:
return lowerIsNaN(MI);
GISEL_VECREDUCE_CASES_NONSEQ
return lowerVectorReduction(MI);
}
}

Expand Down Expand Up @@ -4637,35 +4639,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
unsigned Opc = MI.getOpcode();
assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
"Sequential reductions not expected");

if (TypeIdx != 1)
return UnableToLegalize;

// The semantics of the normal non-sequential reductions allow us to freely
// re-associate the operation.
Register SrcReg = MI.getOperand(1).getReg();
LLT SrcTy = MRI.getType(SrcReg);
Register DstReg = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(DstReg);

if (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0)
return UnableToLegalize;

SmallVector<Register> SplitSrcs;
const unsigned NumParts = SrcTy.getNumElements() / NarrowTy.getNumElements();
extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
SmallVector<Register> PartialReductions;
for (unsigned Part = 0; Part < NumParts; ++Part) {
PartialReductions.push_back(
MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
}

static unsigned getScalarOpcForReduction(unsigned Opc) {
unsigned ScalarOpc;
switch (Opc) {
case TargetOpcode::G_VECREDUCE_FADD:
Expand Down Expand Up @@ -4708,10 +4682,81 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
ScalarOpc = TargetOpcode::G_UMIN;
break;
default:
LLVM_DEBUG(dbgs() << "Can't legalize: unknown reduction kind.\n");
llvm_unreachable("Unhandled reduction");
}
return ScalarOpc;
}

LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
unsigned Opc = MI.getOpcode();
assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
"Sequential reductions not expected");

if (TypeIdx != 1)
return UnableToLegalize;

// The semantics of the normal non-sequential reductions allow us to freely
// re-associate the operation.
Register SrcReg = MI.getOperand(1).getReg();
LLT SrcTy = MRI.getType(SrcReg);
Register DstReg = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(DstReg);

if (NarrowTy.isVector() &&
(SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
return UnableToLegalize;

unsigned ScalarOpc = getScalarOpcForReduction(Opc);
SmallVector<Register> SplitSrcs;
// If NarrowTy is a scalar then we're being asked to scalarize.
const unsigned NumParts =
NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
: SrcTy.getNumElements();

extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
if (NarrowTy.isScalar()) {
if (DstTy != NarrowTy)
return UnableToLegalize; // FIXME: handle implicit extensions.

if (isPowerOf2_32(NumParts)) {
// Generate a tree of scalar operations to reduce the critical path.
SmallVector<Register> PartialResults;
unsigned NumPartsLeft = NumParts;
while (NumPartsLeft > 1) {
for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) {
PartialResults.emplace_back(
MIRBuilder
.buildInstr(ScalarOpc, {NarrowTy},
{SplitSrcs[Idx], SplitSrcs[Idx + 1]})
.getReg(0));
}
SplitSrcs = PartialResults;
PartialResults.clear();
NumPartsLeft = SplitSrcs.size();
}
assert(SplitSrcs.size() == 1);
MIRBuilder.buildCopy(DstReg, SplitSrcs[0]);
MI.eraseFromParent();
return Legalized;
}
// If we can't generate a tree, then just do sequential operations.
Register Acc = SplitSrcs[0];
for (unsigned Idx = 1; Idx < NumParts; ++Idx)
Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]})
.getReg(0);
MIRBuilder.buildCopy(DstReg, Acc);
MI.eraseFromParent();
return Legalized;
}
SmallVector<Register> PartialReductions;
for (unsigned Part = 0; Part < NumParts; ++Part) {
PartialReductions.push_back(
MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
}


// If the types involved are powers of 2, we can generate intermediate vector
// ops, before generating a final reduction operation.
if (isPowerOf2_32(SrcTy.getNumElements()) &&
Expand Down Expand Up @@ -7389,3 +7434,22 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerIsNaN(MachineInstr &MI) {
MI.eraseFromParent();
return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerVectorReduction(MachineInstr &MI) {
Register SrcReg = MI.getOperand(1).getReg();
LLT SrcTy = MRI.getType(SrcReg);
LLT DstTy = MRI.getType(SrcReg);

// The source could be a scalar if the IR type was <1 x sN>.
if (SrcTy.isScalar()) {
if (DstTy.getSizeInBits() > SrcTy.getSizeInBits())
return UnableToLegalize; // FIXME: handle extension.
// This can be just a plain copy.
Observer.changingInstr(MI);
MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY));
Observer.changedInstr(MI);
return Legalized;
}
return UnableToLegalize;;
}
3 changes: 0 additions & 3 deletions llvm/lib/CodeGen/MachineVerifier.cpp
Expand Up @@ -1589,11 +1589,8 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
case TargetOpcode::G_VECREDUCE_UMAX:
case TargetOpcode::G_VECREDUCE_UMIN: {
LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());
if (!DstTy.isScalar())
report("Vector reduction requires a scalar destination type", MI);
if (!SrcTy.isVector())
report("Vector reduction requires vector source=", MI);
break;
}

Expand Down
21 changes: 21 additions & 0 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
Expand Up @@ -691,6 +691,27 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.clampMaxNumElements(1, s32, 4)
.lower();

getActionDefinitionsBuilder(G_VECREDUCE_OR)
// Try to break down into smaller vectors as long as they're at least 64
// bits. This lets us use vector operations for some parts of the
// reduction.
.fewerElementsIf(
[=](const LegalityQuery &Q) {
LLT SrcTy = Q.Types[1];
if (SrcTy.isScalar())
return false;
if (!isPowerOf2_32(SrcTy.getNumElements()))
return false;
// We can usually perform 64b vector operations.
return SrcTy.getSizeInBits() > 64;
},
[=](const LegalityQuery &Q) {
LLT SrcTy = Q.Types[1];
return std::make_pair(1, SrcTy.divide(2));
})
.scalarize(1)
.lower();

getActionDefinitionsBuilder({G_UADDSAT, G_USUBSAT})
.lowerIf([=](const LegalityQuery &Q) { return Q.Types[0].isScalar(); });

Expand Down

0 comments on commit 95ac3d1

Please sign in to comment.