Skip to content

Commit

Permalink
[CodeGen] Enable processing of interconnected complex number operations
Browse files Browse the repository at this point in the history
With this patch, ComplexDeinterleavingPass now has the ability to handle
any number of interconnected operations involving complex numbers.
For example, the patch enables the processing of code like the following:

for (int i = 0; i < 1000; ++i) {
    a[i] =  w[i] * v[i];
    b[i] =  w[i] * u[i];
}

This code has multiple arrays containing complex numbers and a common
subexpression `w` that appears in two expressions.

Differential Revision: https://reviews.llvm.org/D146988
  • Loading branch information
igogo-x86 committed Apr 18, 2023
1 parent dc86900 commit c692e87
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 90 deletions.
189 changes: 118 additions & 71 deletions llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,12 @@ struct ComplexDeinterleavingCompositeNode {
Instruction *Real;
Instruction *Imag;

// Instructions that should only exist within this node, there should be no
// users of these instructions outside the node. An example of these would be
// the multiply instructions of a partial multiply operation.
SmallVector<Instruction *> InternalInstructions;
ComplexDeinterleavingRotation Rotation;
SmallVector<RawNodePtr> Operands;
Value *ReplacementNode = nullptr;

void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }

bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);

void dump() { dump(dbgs()); }
void dump(raw_ostream &OS) {
auto PrintValue = [&](Value *V) {
Expand Down Expand Up @@ -181,27 +174,29 @@ struct ComplexDeinterleavingCompositeNode {
OS << " - ";
PrintNodeRef(Op);
}
OS << " InternalInstructions:\n";
for (const auto &I : InternalInstructions) {
OS << " - \"";
I->print(OS, true);
OS << "\"\n";
}
}
};

class ComplexDeinterleavingGraph {
public:
using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
const TargetLibraryInfo *TLI)
: TL(TL), TLI(TLI) {}

private:
const TargetLowering *TL = nullptr;
Instruction *RootValue = nullptr;
NodePtr RootNode;
const TargetLibraryInfo *TLI = nullptr;
SmallVector<NodePtr> CompositeNodes;
SmallPtrSet<Instruction *, 16> AllInstructions;

SmallPtrSet<Instruction *, 16> FinalInstructions;

/// Root instructions are instructions from which complex computation starts
std::map<Instruction *, NodePtr> RootToNode;

/// Topologically sorted root instructions
SmallVector<Instruction *, 1> OrderedRoots;

NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
Instruction *R, Instruction *I) {
Expand All @@ -211,10 +206,6 @@ class ComplexDeinterleavingGraph {

NodePtr submitCompositeNode(NodePtr Node) {
CompositeNodes.push_back(Node);
AllInstructions.insert(Node->Real);
AllInstructions.insert(Node->Imag);
for (auto *I : Node->InternalInstructions)
AllInstructions.insert(I);
return Node;
}

Expand Down Expand Up @@ -271,6 +262,10 @@ class ComplexDeinterleavingGraph {
/// current graph.
bool identifyNodes(Instruction *RootI);

/// Check that every instruction, from the roots to the leaves, has internal
/// uses.
bool checkNodes();

/// Perform the actual replacement of the underlying instruction graph.
void replaceNodes();
};
Expand Down Expand Up @@ -368,9 +363,7 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
}

bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
bool Changed = false;

SmallVector<Instruction *> DeadInstrRoots;
ComplexDeinterleavingGraph Graph(TL, TLI);

for (auto &I : *B) {
auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
Expand All @@ -382,22 +375,15 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
if (!isInterleavingMask(SVI->getShuffleMask()))
continue;

ComplexDeinterleavingGraph Graph(TL);
if (!Graph.identifyNodes(SVI))
continue;

Graph.replaceNodes();
DeadInstrRoots.push_back(SVI);
Changed = true;
Graph.identifyNodes(SVI);
}

for (const auto &I : DeadInstrRoots) {
if (!I || I->getParent() == nullptr)
continue;
llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
if (Graph.checkNodes()) {
Graph.replaceNodes();
return true;
}

return Changed;
return false;
}

ComplexDeinterleavingGraph::NodePtr
Expand Down Expand Up @@ -511,7 +497,6 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
Node->Rotation = Rotation;
Node->addOperand(CommonNode);
Node->addOperand(UncommonNode);
Node->InternalInstructions.append(FNegs);
return submitCompositeNode(Node);
}

Expand Down Expand Up @@ -627,8 +612,6 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,

NodePtr Node = prepareCompositeNode(
ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
Node->addInstruction(RealMulI);
Node->addInstruction(ImagMulI);
Node->Rotation = Rotation;
Node->addOperand(CommonRes);
Node->addOperand(UncommonRes);
Expand Down Expand Up @@ -846,6 +829,8 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
RealShuffle, ImagShuffle);
PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
FinalInstructions.insert(RealShuffle);
FinalInstructions.insert(ImagShuffle);
return submitCompositeNode(PlaceholderNode);
}
if (RealShuffle || ImagShuffle) {
Expand Down Expand Up @@ -881,9 +866,7 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
return false;

RootValue = RootI;
AllInstructions.insert(RootI);
RootNode = identifyNode(Real, Imag);
auto RootNode = identifyNode(Real, Imag);

LLVM_DEBUG({
Function *F = RootI->getFunction();
Expand All @@ -894,14 +877,86 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
dbgs() << "\n";
});

// Check all instructions have internal uses
for (const auto &Node : CompositeNodes) {
if (!Node->hasAllInternalUses(AllInstructions)) {
LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
return false;
if (RootNode) {
RootToNode[RootI] = RootNode;
OrderedRoots.push_back(RootI);
return true;
}

return false;
}

bool ComplexDeinterleavingGraph::checkNodes() {
// Collect all instructions from roots to leaves
SmallPtrSet<Instruction *, 16> AllInstructions;
SmallVector<Instruction *, 8> Worklist;
for (auto *I : OrderedRoots)
Worklist.push_back(I);

// Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
// chains
while (!Worklist.empty()) {
auto *I = Worklist.back();
Worklist.pop_back();

if (!AllInstructions.insert(I).second)
continue;

for (Value *Op : I->operands()) {
if (auto *OpI = dyn_cast<Instruction>(Op)) {
if (!FinalInstructions.count(I))
Worklist.emplace_back(OpI);
}
}
}
return RootNode != nullptr;

// Find instructions that have users outside of chain
SmallVector<Instruction *, 2> OuterInstructions;
for (auto *I : AllInstructions) {
// Skip root nodes
if (RootToNode.count(I))
continue;

for (User *U : I->users()) {
if (AllInstructions.count(cast<Instruction>(U)))
continue;

// Found an instruction that is not used by XCMLA/XCADD chain
Worklist.emplace_back(I);
break;
}
}

// If any instructions are found to be used outside, find and remove roots
// that somehow connect to those instructions.
SmallPtrSet<Instruction *, 16> Visited;
while (!Worklist.empty()) {
auto *I = Worklist.back();
Worklist.pop_back();
if (!Visited.insert(I).second)
continue;

// Found an impacted root node. Removing it from the nodes to be
// deinterleaved
if (RootToNode.count(I)) {
LLVM_DEBUG(dbgs() << "Instruction " << *I
<< " could be deinterleaved but its chain of complex "
"operations have an outside user\n");
RootToNode.erase(I);
}

if (!AllInstructions.count(I) || FinalInstructions.count(I))
continue;

for (User *U : I->users())
Worklist.emplace_back(cast<Instruction>(U));

for (Value *Op : I->operands()) {
if (auto *OpI = dyn_cast<Instruction>(Op))
Worklist.emplace_back(OpI);
}
}
return !RootToNode.empty();
}

static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
Expand Down Expand Up @@ -958,29 +1013,21 @@ Value *ComplexDeinterleavingGraph::replaceNode(
}

void ComplexDeinterleavingGraph::replaceNodes() {
Value *R = replaceNode(RootNode.get());
assert(R && "Unable to find replacement for RootValue");
RootValue->replaceAllUsesWith(R);
}

bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
SmallPtrSet<Instruction *, 16> &AllInstructions) {
if (Operation == ComplexDeinterleavingOperation::Shuffle)
return true;
SmallVector<Instruction *, 16> DeadInstrRoots;
for (auto *RootInstruction : OrderedRoots) {
// Check if this potential root went through check process and we can
// deinterleave it
if (!RootToNode.count(RootInstruction))
continue;

for (auto *User : Real->users()) {
if (!AllInstructions.contains(cast<Instruction>(User)))
return false;
IRBuilder<> Builder(RootInstruction);
auto RootNode = RootToNode[RootInstruction];
Value *R = replaceNode(RootNode.get());
assert(R && "Unable to find replacement for RootInstruction");
DeadInstrRoots.push_back(RootInstruction);
RootInstruction->replaceAllUsesWith(R);
}
for (auto *User : Imag->users()) {
if (!AllInstructions.contains(cast<Instruction>(User)))
return false;
}
for (auto *I : InternalInstructions) {
for (auto *User : I->users()) {
if (!AllInstructions.contains(cast<Instruction>(User)))
return false;
}
}
return true;

for (auto *I : DeadInstrRoots)
RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}
28 changes: 9 additions & 19 deletions llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,20 @@
; RUN: llc < %s --mattr=+complxnum,+neon -o - | FileCheck %s

target triple = "aarch64-arm-none-eabi"
; Expected to not transform
; Expected to transform
; *p = (a * b);
; return (a * b) * a;
define <4 x float> @mul_triangle(<4 x float> %a, <4 x float> %b, ptr %p) {
; CHECK-LABEL: mul_triangle:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: zip2 v4.2s, v0.2s, v2.2s
; CHECK-NEXT: zip1 v0.2s, v0.2s, v2.2s
; CHECK-NEXT: zip2 v5.2s, v1.2s, v3.2s
; CHECK-NEXT: zip1 v1.2s, v1.2s, v3.2s
; CHECK-NEXT: fmul v6.2s, v5.2s, v4.2s
; CHECK-NEXT: fneg v2.2s, v6.2s
; CHECK-NEXT: fmla v2.2s, v0.2s, v1.2s
; CHECK-NEXT: fmul v3.2s, v4.2s, v1.2s
; CHECK-NEXT: fmla v3.2s, v0.2s, v5.2s
; CHECK-NEXT: fmul v1.2s, v3.2s, v4.2s
; CHECK-NEXT: fmul v5.2s, v3.2s, v0.2s
; CHECK-NEXT: st2 { v2.2s, v3.2s }, [x0]
; CHECK-NEXT: fneg v1.2s, v1.2s
; CHECK-NEXT: fmla v5.2s, v4.2s, v2.2s
; CHECK-NEXT: fmla v1.2s, v0.2s, v2.2s
; CHECK-NEXT: zip1 v0.4s, v1.4s, v5.4s
; CHECK-NEXT: movi v3.2d, #0000000000000000
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #0
; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #90
; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #0
; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #90
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
entry:
%strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>
Expand Down

0 comments on commit c692e87

Please sign in to comment.