diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp index ff0c5d530747b..3cfe935e2cca3 100644 --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -137,19 +137,12 @@ struct ComplexDeinterleavingCompositeNode { Instruction *Real; Instruction *Imag; - // Instructions that should only exist within this node, there should be no - // users of these instructions outside the node. An example of these would be - // the multiply instructions of a partial multiply operation. - SmallVector InternalInstructions; ComplexDeinterleavingRotation Rotation; SmallVector Operands; Value *ReplacementNode = nullptr; - void addInstruction(Instruction *I) { InternalInstructions.push_back(I); } void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } - bool hasAllInternalUses(SmallPtrSet &AllInstructions); - void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { auto PrintValue = [&](Value *V) { @@ -181,12 +174,6 @@ struct ComplexDeinterleavingCompositeNode { OS << " - "; PrintNodeRef(Op); } - OS << " InternalInstructions:\n"; - for (const auto &I : InternalInstructions) { - OS << " - \""; - I->print(OS, true); - OS << "\"\n"; - } } }; @@ -194,14 +181,22 @@ class ComplexDeinterleavingGraph { public: using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; - explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {} + explicit ComplexDeinterleavingGraph(const TargetLowering *TL, + const TargetLibraryInfo *TLI) + : TL(TL), TLI(TLI) {} private: const TargetLowering *TL = nullptr; - Instruction *RootValue = nullptr; - NodePtr RootNode; + const TargetLibraryInfo *TLI = nullptr; SmallVector CompositeNodes; - SmallPtrSet AllInstructions; + + SmallPtrSet FinalInstructions; + + /// Root instructions are instructions from which complex computation starts + std::map RootToNode; + + /// Topologically sorted root instructions + SmallVector OrderedRoots; NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, Instruction *R, Instruction *I) { @@ -211,10 +206,6 @@ class ComplexDeinterleavingGraph { NodePtr submitCompositeNode(NodePtr Node) { CompositeNodes.push_back(Node); - AllInstructions.insert(Node->Real); - AllInstructions.insert(Node->Imag); - for (auto *I : Node->InternalInstructions) - AllInstructions.insert(I); return Node; } @@ -271,6 +262,10 @@ class ComplexDeinterleavingGraph { /// current graph. bool identifyNodes(Instruction *RootI); + /// Check that every instruction, from the roots to the leaves, has internal + /// uses. + bool checkNodes(); + /// Perform the actual replacement of the underlying instruction graph. void replaceNodes(); }; @@ -368,9 +363,7 @@ static bool isDeinterleavingMask(ArrayRef Mask) { } bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { - bool Changed = false; - - SmallVector DeadInstrRoots; + ComplexDeinterleavingGraph Graph(TL, TLI); for (auto &I : *B) { auto *SVI = dyn_cast(&I); @@ -382,22 +375,15 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { if (!isInterleavingMask(SVI->getShuffleMask())) continue; - ComplexDeinterleavingGraph Graph(TL); - if (!Graph.identifyNodes(SVI)) - continue; - - Graph.replaceNodes(); - DeadInstrRoots.push_back(SVI); - Changed = true; + Graph.identifyNodes(SVI); } - for (const auto &I : DeadInstrRoots) { - if (!I || I->getParent() == nullptr) - continue; - llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI); + if (Graph.checkNodes()) { + Graph.replaceNodes(); + return true; } - return Changed; + return false; } ComplexDeinterleavingGraph::NodePtr @@ -511,7 +497,6 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( Node->Rotation = Rotation; Node->addOperand(CommonNode); Node->addOperand(UncommonNode); - Node->InternalInstructions.append(FNegs); return submitCompositeNode(Node); } @@ -627,8 +612,6 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); - Node->addInstruction(RealMulI); - Node->addInstruction(ImagMulI); Node->Rotation = Rotation; Node->addOperand(CommonRes); Node->addOperand(UncommonRes); @@ -846,6 +829,8 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle, RealShuffle, ImagShuffle); PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); + FinalInstructions.insert(RealShuffle); + FinalInstructions.insert(ImagShuffle); return submitCompositeNode(PlaceholderNode); } if (RealShuffle || ImagShuffle) { @@ -881,9 +866,7 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) return false; - RootValue = RootI; - AllInstructions.insert(RootI); - RootNode = identifyNode(Real, Imag); + auto RootNode = identifyNode(Real, Imag); LLVM_DEBUG({ Function *F = RootI->getFunction(); @@ -894,14 +877,86 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { dbgs() << "\n"; }); - // Check all instructions have internal uses - for (const auto &Node : CompositeNodes) { - if (!Node->hasAllInternalUses(AllInstructions)) { - LLVM_DEBUG(dbgs() << " - Invalid internal uses\n"); - return false; + if (RootNode) { + RootToNode[RootI] = RootNode; + OrderedRoots.push_back(RootI); + return true; + } + + return false; +} + +bool ComplexDeinterleavingGraph::checkNodes() { + // Collect all instructions from roots to leaves + SmallPtrSet AllInstructions; + SmallVector Worklist; + for (auto *I : OrderedRoots) + Worklist.push_back(I); + + // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG + // chains + while (!Worklist.empty()) { + auto *I = Worklist.back(); + Worklist.pop_back(); + + if (!AllInstructions.insert(I).second) + continue; + + for (Value *Op : I->operands()) { + if (auto *OpI = dyn_cast(Op)) { + if (!FinalInstructions.count(I)) + Worklist.emplace_back(OpI); + } } } - return RootNode != nullptr; + + // Find instructions that have users outside of chain + SmallVector OuterInstructions; + for (auto *I : AllInstructions) { + // Skip root nodes + if (RootToNode.count(I)) + continue; + + for (User *U : I->users()) { + if (AllInstructions.count(cast(U))) + continue; + + // Found an instruction that is not used by XCMLA/XCADD chain + Worklist.emplace_back(I); + break; + } + } + + // If any instructions are found to be used outside, find and remove roots + // that somehow connect to those instructions. + SmallPtrSet Visited; + while (!Worklist.empty()) { + auto *I = Worklist.back(); + Worklist.pop_back(); + if (!Visited.insert(I).second) + continue; + + // Found an impacted root node. Removing it from the nodes to be + // deinterleaved + if (RootToNode.count(I)) { + LLVM_DEBUG(dbgs() << "Instruction " << *I + << " could be deinterleaved but its chain of complex " + "operations have an outside user\n"); + RootToNode.erase(I); + } + + if (!AllInstructions.count(I) || FinalInstructions.count(I)) + continue; + + for (User *U : I->users()) + Worklist.emplace_back(cast(U)); + + for (Value *Op : I->operands()) { + if (auto *OpI = dyn_cast(Op)) + Worklist.emplace_back(OpI); + } + } + return !RootToNode.empty(); } static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node, @@ -958,29 +1013,21 @@ Value *ComplexDeinterleavingGraph::replaceNode( } void ComplexDeinterleavingGraph::replaceNodes() { - Value *R = replaceNode(RootNode.get()); - assert(R && "Unable to find replacement for RootValue"); - RootValue->replaceAllUsesWith(R); -} - -bool ComplexDeinterleavingCompositeNode::hasAllInternalUses( - SmallPtrSet &AllInstructions) { - if (Operation == ComplexDeinterleavingOperation::Shuffle) - return true; + SmallVector DeadInstrRoots; + for (auto *RootInstruction : OrderedRoots) { + // Check if this potential root went through check process and we can + // deinterleave it + if (!RootToNode.count(RootInstruction)) + continue; - for (auto *User : Real->users()) { - if (!AllInstructions.contains(cast(User))) - return false; + IRBuilder<> Builder(RootInstruction); + auto RootNode = RootToNode[RootInstruction]; + Value *R = replaceNode(RootNode.get()); + assert(R && "Unable to find replacement for RootInstruction"); + DeadInstrRoots.push_back(RootInstruction); + RootInstruction->replaceAllUsesWith(R); } - for (auto *User : Imag->users()) { - if (!AllInstructions.contains(cast(User))) - return false; - } - for (auto *I : InternalInstructions) { - for (auto *User : I->users()) { - if (!AllInstructions.contains(cast(User))) - return false; - } - } - return true; + + for (auto *I : DeadInstrRoots) + RecursivelyDeleteTriviallyDeadInstructions(I, TLI); } diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll index fe3d30677f084..4d84636e92ca2 100644 --- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll +++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll @@ -2,30 +2,20 @@ ; RUN: llc < %s --mattr=+complxnum,+neon -o - | FileCheck %s target triple = "aarch64-arm-none-eabi" -; Expected to not transform +; Expected to transform ; *p = (a * b); ; return (a * b) * a; define <4 x float> @mul_triangle(<4 x float> %a, <4 x float> %b, ptr %p) { ; CHECK-LABEL: mul_triangle: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8 -; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8 -; CHECK-NEXT: zip2 v4.2s, v0.2s, v2.2s -; CHECK-NEXT: zip1 v0.2s, v0.2s, v2.2s -; CHECK-NEXT: zip2 v5.2s, v1.2s, v3.2s -; CHECK-NEXT: zip1 v1.2s, v1.2s, v3.2s -; CHECK-NEXT: fmul v6.2s, v5.2s, v4.2s -; CHECK-NEXT: fneg v2.2s, v6.2s -; CHECK-NEXT: fmla v2.2s, v0.2s, v1.2s -; CHECK-NEXT: fmul v3.2s, v4.2s, v1.2s -; CHECK-NEXT: fmla v3.2s, v0.2s, v5.2s -; CHECK-NEXT: fmul v1.2s, v3.2s, v4.2s -; CHECK-NEXT: fmul v5.2s, v3.2s, v0.2s -; CHECK-NEXT: st2 { v2.2s, v3.2s }, [x0] -; CHECK-NEXT: fneg v1.2s, v1.2s -; CHECK-NEXT: fmla v5.2s, v4.2s, v2.2s -; CHECK-NEXT: fmla v1.2s, v0.2s, v2.2s -; CHECK-NEXT: zip1 v0.4s, v1.4s, v5.4s +; CHECK-NEXT: movi v3.2d, #0000000000000000 +; CHECK-NEXT: movi v2.2d, #0000000000000000 +; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #0 +; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #90 +; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #0 +; CHECK-NEXT: str q3, [x0] +; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #90 +; CHECK-NEXT: mov v0.16b, v2.16b ; CHECK-NEXT: ret entry: %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32>