Skip to content

Commit

Permalink
[CodeGen] Extend ComplexDeinterleaving pass to recognise patterns usi…
Browse files Browse the repository at this point in the history
…ng integer types

AArch64 introduced CMLA and CADD instructions as part of SVE2. This
change allows to generate such instructions when this architecture
feature is available.

Differential Revision: https://reviews.llvm.org/D153808
  • Loading branch information
igogo-x86 committed Jul 19, 2023
1 parent 98b0f13 commit c15557d
Show file tree
Hide file tree
Showing 9 changed files with 910 additions and 58 deletions.
174 changes: 121 additions & 53 deletions llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ static bool isInterleavingMask(ArrayRef<int> Mask);
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if the operation is a negation of V, and it works for both
/// integers and floats.
static bool isNeg(Value *V);

/// Returns the operand for negation operation.
static Value *getNegOperand(Value *V);

namespace {

class ComplexDeinterleavingLegacyPass : public FunctionPass {
Expand Down Expand Up @@ -146,7 +153,7 @@ struct ComplexDeinterleavingCompositeNode {
// This two members are required exclusively for generating
// ComplexDeinterleavingOperation::Symmetric operations.
unsigned Opcode;
FastMathFlags Flags;
std::optional<FastMathFlags> Flags;

ComplexDeinterleavingRotation Rotation =
ComplexDeinterleavingRotation::Rotation_0;
Expand Down Expand Up @@ -333,7 +340,8 @@ class ComplexDeinterleavingGraph {
/// Return nullptr if it is not possible to construct a complex number.
/// \p Flags are needed to generate symmetric Add and Sub operations.
NodePtr identifyAdditions(std::list<Addend> &RealAddends,
std::list<Addend> &ImagAddends, FastMathFlags Flags,
std::list<Addend> &ImagAddends,
std::optional<FastMathFlags> Flags,
NodePtr Accumulator);

/// Extract one addend that have both real and imaginary parts positive.
Expand Down Expand Up @@ -512,6 +520,19 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
return true;
}

bool isNeg(Value *V) {
return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
}

Value *getNegOperand(Value *V) {
assert(isNeg(V));
auto *I = cast<Instruction>(V);
if (I->getOpcode() == Instruction::FNeg)
return I->getOperand(0);

return I->getOperand(1);
}

bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
ComplexDeinterleavingGraph Graph(TL, TLI);
if (Graph.collectPotentialReductions(B))
Expand Down Expand Up @@ -540,9 +561,12 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
return nullptr;
}

if (Real->getOpcode() != Instruction::FMul ||
Imag->getOpcode() != Instruction::FMul) {
LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
if ((Real->getOpcode() != Instruction::FMul &&
Real->getOpcode() != Instruction::Mul) ||
(Imag->getOpcode() != Instruction::FMul &&
Imag->getOpcode() != Instruction::Mul)) {
LLVM_DEBUG(
dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
return nullptr;
}

Expand All @@ -563,7 +587,7 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
R1 = Op;
}

if (match(I0, m_Neg(m_Value(Op)))) {
if (isNeg(I0)) {
Negs |= 2;
Negs ^= 1;
I0 = Op;
Expand Down Expand Up @@ -634,26 +658,29 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
<< "\n");
// Determine rotation
auto IsAdd = [](unsigned Op) {
return Op == Instruction::FAdd || Op == Instruction::Add;
};
auto IsSub = [](unsigned Op) {
return Op == Instruction::FSub || Op == Instruction::Sub;
};
ComplexDeinterleavingRotation Rotation;
if (Real->getOpcode() == Instruction::FAdd &&
Imag->getOpcode() == Instruction::FAdd)
if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
Rotation = ComplexDeinterleavingRotation::Rotation_0;
else if (Real->getOpcode() == Instruction::FSub &&
Imag->getOpcode() == Instruction::FAdd)
else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
Rotation = ComplexDeinterleavingRotation::Rotation_90;
else if (Real->getOpcode() == Instruction::FSub &&
Imag->getOpcode() == Instruction::FSub)
else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
Rotation = ComplexDeinterleavingRotation::Rotation_180;
else if (Real->getOpcode() == Instruction::FAdd &&
Imag->getOpcode() == Instruction::FSub)
else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
Rotation = ComplexDeinterleavingRotation::Rotation_270;
else {
LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
return nullptr;
}

if (!Real->getFastMathFlags().allowContract() ||
!Imag->getFastMathFlags().allowContract()) {
if (isa<FPMathOperator>(Real) &&
(!Real->getFastMathFlags().allowContract() ||
!Imag->getFastMathFlags().allowContract())) {
LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
return nullptr;
}
Expand Down Expand Up @@ -816,6 +843,9 @@ static bool isInstructionPotentiallySymmetric(Instruction *I) {
case Instruction::FSub:
case Instruction::FMul:
case Instruction::FNeg:
case Instruction::Add:
case Instruction::Sub:
case Instruction::Mul:
return true;
default:
return false;
Expand Down Expand Up @@ -925,27 +955,31 @@ ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
Instruction *Imag) {
auto IsOperationSupported = [](unsigned Opcode) -> bool {
return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
Opcode == Instruction::Sub;
};

if ((Real->getOpcode() != Instruction::FAdd &&
Real->getOpcode() != Instruction::FSub &&
Real->getOpcode() != Instruction::FNeg) ||
(Imag->getOpcode() != Instruction::FAdd &&
Imag->getOpcode() != Instruction::FSub &&
Imag->getOpcode() != Instruction::FNeg))
if (!IsOperationSupported(Real->getOpcode()) ||
!IsOperationSupported(Imag->getOpcode()))
return nullptr;

if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
LLVM_DEBUG(
dbgs()
<< "The flags in Real and Imaginary instructions are not identical\n");
return nullptr;
}
std::optional<FastMathFlags> Flags;
if (isa<FPMathOperator>(Real)) {
if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
"not identical\n");
return nullptr;
}

FastMathFlags Flags = Real->getFastMathFlags();
if (!Flags.allowReassoc()) {
LLVM_DEBUG(
dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
return nullptr;
Flags = Real->getFastMathFlags();
if (!Flags->allowReassoc()) {
LLVM_DEBUG(
dbgs()
<< "the 'Reassoc' attribute is missing in the FastMath flags\n");
return nullptr;
}
}

// Collect multiplications and addend instructions from the given instruction
Expand Down Expand Up @@ -978,35 +1012,52 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
Addends.emplace_back(I, IsPositive);
continue;
}

if (I->getOpcode() == Instruction::FAdd) {
switch (I->getOpcode()) {
case Instruction::FAdd:
case Instruction::Add:
Worklist.emplace_back(I->getOperand(1), IsPositive);
Worklist.emplace_back(I->getOperand(0), IsPositive);
} else if (I->getOpcode() == Instruction::FSub) {
break;
case Instruction::FSub:
Worklist.emplace_back(I->getOperand(1), !IsPositive);
Worklist.emplace_back(I->getOperand(0), IsPositive);
} else if (I->getOpcode() == Instruction::FMul) {
break;
case Instruction::Sub:
if (isNeg(I)) {
Worklist.emplace_back(getNegOperand(I), !IsPositive);
} else {
Worklist.emplace_back(I->getOperand(1), !IsPositive);
Worklist.emplace_back(I->getOperand(0), IsPositive);
}
break;
case Instruction::FMul:
case Instruction::Mul: {
Value *A, *B;
if (match(I->getOperand(0), m_FNeg(m_Value(A)))) {
if (isNeg(I->getOperand(0))) {
A = getNegOperand(I->getOperand(0));
IsPositive = !IsPositive;
} else {
A = I->getOperand(0);
}

if (match(I->getOperand(1), m_FNeg(m_Value(B)))) {
if (isNeg(I->getOperand(1))) {
B = getNegOperand(I->getOperand(1));
IsPositive = !IsPositive;
} else {
B = I->getOperand(1);
}
Muls.push_back(Product{A, B, IsPositive});
} else if (I->getOpcode() == Instruction::FNeg) {
break;
}
case Instruction::FNeg:
Worklist.emplace_back(I->getOperand(0), !IsPositive);
} else {
break;
default:
Addends.emplace_back(I, IsPositive);
continue;
}

if (I->getFastMathFlags() != Flags) {
if (Flags && I->getFastMathFlags() != *Flags) {
LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
"inconsistent with the root instructions' flags: "
<< *I << "\n");
Expand Down Expand Up @@ -1258,10 +1309,9 @@ ComplexDeinterleavingGraph::identifyMultiplications(
}

ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
std::list<Addend> &ImagAddends,
FastMathFlags Flags,
NodePtr Accumulator = nullptr) {
ComplexDeinterleavingGraph::identifyAdditions(
std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
if (RealAddends.size() != ImagAddends.size())
return nullptr;

Expand Down Expand Up @@ -1312,14 +1362,22 @@ ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
TmpNode = prepareCompositeNode(
ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
TmpNode->Opcode = Instruction::FAdd;
TmpNode->Flags = Flags;
if (Flags) {
TmpNode->Opcode = Instruction::FAdd;
TmpNode->Flags = *Flags;
} else {
TmpNode->Opcode = Instruction::Add;
}
} else if (Rotation ==
llvm::ComplexDeinterleavingRotation::Rotation_180) {
TmpNode = prepareCompositeNode(
ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
TmpNode->Opcode = Instruction::FSub;
TmpNode->Flags = Flags;
if (Flags) {
TmpNode->Opcode = Instruction::FSub;
TmpNode->Flags = *Flags;
} else {
TmpNode->Opcode = Instruction::Sub;
}
} else {
TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
nullptr, nullptr);
Expand Down Expand Up @@ -1815,8 +1873,8 @@ ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
}

static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
FastMathFlags Flags, Value *InputA,
Value *InputB) {
std::optional<FastMathFlags> Flags,
Value *InputA, Value *InputB) {
Value *I;
switch (Opcode) {
case Instruction::FNeg:
Expand All @@ -1825,16 +1883,26 @@ static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
case Instruction::FAdd:
I = B.CreateFAdd(InputA, InputB);
break;
case Instruction::Add:
I = B.CreateAdd(InputA, InputB);
break;
case Instruction::FSub:
I = B.CreateFSub(InputA, InputB);
break;
case Instruction::Sub:
I = B.CreateSub(InputA, InputB);
break;
case Instruction::FMul:
I = B.CreateFMul(InputA, InputB);
break;
case Instruction::Mul:
I = B.CreateMul(InputA, InputB);
break;
default:
llvm_unreachable("Incorrect symmetric opcode");
}
cast<Instruction>(I)->setFastMathFlags(Flags);
if (Flags)
cast<Instruction>(I)->setFastMathFlags(*Flags);
return I;
}

Expand Down
28 changes: 23 additions & 5 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25858,7 +25858,8 @@ bool AArch64TargetLowering::isConstantUnsignedBitfieldExtractLegal(
}

bool AArch64TargetLowering::isComplexDeinterleavingSupported() const {
return Subtarget->hasSVE() || Subtarget->hasComplxNum();
return Subtarget->hasSVE() || Subtarget->hasSVE2() ||
Subtarget->hasComplxNum();
}

bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
Expand All @@ -25884,6 +25885,11 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
!llvm::isPowerOf2_32(VTyWidth))
return false;

if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2()) {
unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
return 8 <= ScalarWidth && ScalarWidth <= 64;
}

return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
}
Expand All @@ -25894,6 +25900,7 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
Value *Accumulator) const {
VectorType *Ty = cast<VectorType>(InputA->getType());
bool IsScalable = Ty->isScalableTy();
bool IsInt = Ty->getElementType()->isIntegerTy();

unsigned TyWidth =
Ty->getScalarSizeInBits() * Ty->getElementCount().getKnownMinValue();
Expand Down Expand Up @@ -25929,10 +25936,15 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(

if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
if (Accumulator == nullptr)
Accumulator = ConstantFP::get(Ty, 0);
Accumulator = Constant::getNullValue(Ty);

if (IsScalable) {
auto *Mask = B.CreateVectorSplat(Ty->getElementCount(), B.getInt1(true));
if (IsInt)
return B.CreateIntrinsic(
Intrinsic::aarch64_sve_cmla_x, Ty,
{Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});

auto *Mask = B.getAllOnesMask(Ty->getElementCount());
return B.CreateIntrinsic(
Intrinsic::aarch64_sve_fcmla, Ty,
{Mask, Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
Expand All @@ -25950,12 +25962,18 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(

if (OperationType == ComplexDeinterleavingOperation::CAdd) {
if (IsScalable) {
auto *Mask = B.CreateVectorSplat(Ty->getElementCount(), B.getInt1(true));
if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
Rotation == ComplexDeinterleavingRotation::Rotation_270)
Rotation == ComplexDeinterleavingRotation::Rotation_270) {
if (IsInt)
return B.CreateIntrinsic(
Intrinsic::aarch64_sve_cadd_x, Ty,
{InputA, InputB, B.getInt32((int)Rotation * 90)});

auto *Mask = B.getAllOnesMask(Ty->getElementCount());
return B.CreateIntrinsic(
Intrinsic::aarch64_sve_fcadd, Ty,
{Mask, InputA, InputB, B.getInt32((int)Rotation * 90)});
}
return nullptr;
}

Expand Down

0 comments on commit c15557d

Please sign in to comment.