[AArch64][SVE] Move instcombine-like transforms out of SVEIntrinsicOpts
Instead, move them to the instcombine that happens in AArch64TargetTransformInfo.

Differential Revision: https://reviews.llvm.org/D106144
brads55 committed Jul 20, 2021
1 parent 82834a6 commit 191f9fa
Showing 6 changed files with 194 additions and 252 deletions.
118 changes: 118 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -686,6 +686,115 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
             : None;
}

static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
  IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));

  if (Op1 && Op2 &&
      Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {

    IRBuilder<> Builder(II.getContext());
    Builder.SetInsertPoint(&II);

    Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
    Type *Tys[] = {Op1->getArgOperand(0)->getType()};

    auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);

    PTest->takeName(&II);
    return IC.replaceInstUsesWith(II, PTest);
  }

  return None;
}
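
A hypothetical IR sketch of this fold (the value names %a/%b and the nxv4i1
predicate width are invented for illustration):

; Before: the ptest operates on predicates widened via convert.to.svbool.
%c1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %a)
%c2 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %b)
%r = call i1 @llvm.aarch64.sve.ptest.any.nxv16i1(<vscale x 16 x i1> %c1, <vscale x 16 x i1> %c2)

; After: the ptest is recreated directly on the original, same-typed operands.
; The dead conversions are cleaned up by later instcombine iterations.
%r = call i1 @llvm.aarch64.sve.ptest.any.nxv4i1(<vscale x 4 x i1> %a, <vscale x 4 x i1> %b)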

static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                       IntrinsicInst &II) {
  auto *OpPredicate = II.getOperand(0);
  auto *OpMultiplicand = II.getOperand(1);
  auto *OpMultiplier = II.getOperand(2);

  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);

  // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
  // with a unit splat value, false otherwise.
  auto IsUnitDupX = [](auto *I) {
    auto *IntrI = dyn_cast<IntrinsicInst>(I);
    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
      return false;

    auto *SplatValue = IntrI->getOperand(0);
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  // Return true if a given instruction is an aarch64_sve_dup intrinsic call
  // with a unit splat value, false otherwise.
  auto IsUnitDup = [](auto *I) {
    auto *IntrI = dyn_cast<IntrinsicInst>(I);
    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
      return false;

    auto *SplatValue = IntrI->getOperand(2);
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  // The OpMultiplier variable should always point to the dup (if any), so
  // swap if necessary.
  if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
    std::swap(OpMultiplier, OpMultiplicand);

  if (IsUnitDupX(OpMultiplier)) {
    // [f]mul pg (dupx 1) %n => %n
    OpMultiplicand->takeName(&II);
    return IC.replaceInstUsesWith(II, OpMultiplicand);
  } else if (IsUnitDup(OpMultiplier)) {
    // [f]mul pg (dup pg 1) %n => %n
    auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
    auto *DupPg = DupInst->getOperand(1);
    // TODO: this is naive. The optimization is still valid if DupPg
    // 'encompasses' OpPredicate, not only if they're the same predicate.
    if (OpPredicate == DupPg) {
      OpMultiplicand->takeName(&II);
      return IC.replaceInstUsesWith(II, OpMultiplicand);
    }
  }

  return None;
}
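
A similar hypothetical sketch of the unit-splat multiply fold (names invented):

; Before: a predicated fmul whose multiplier is dup.x of 1.0.
%one = call <vscale x 4 x float> @llvm.aarch64.sve.dup.x.nxv4f32(float 1.000000e+00)
%mul = call <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %n, <vscale x 4 x float> %one)

; After: every use of %mul is replaced by %n, and the now-dead dup.x is
; removed by later instcombine iterations. The aarch64_sve_dup form folds the
; same way, but only when its governing predicate is identical to %pg.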

static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
                                                 IntrinsicInst &II) {
  auto *OpVal = II.getOperand(0);
  auto *OpIndices = II.getOperand(1);
  VectorType *VTy = cast<VectorType>(II.getType());

  // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with
  // constant splat value < minimal element count of result.
  auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
  if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
    return None;

  auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
  if (!SplatValue ||
      SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
    return None;

  // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
  // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);
  auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
  auto *VectorSplat =
      Builder.CreateVectorSplat(VTy->getElementCount(), Extract);

  VectorSplat->takeName(&II);
  return IC.replaceInstUsesWith(II, VectorSplat);
}
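
A hypothetical sketch of the tbl fold (names invented; the splat is shown
roughly as CreateVectorSplat expands it for scalable vectors):

; Before: a tbl whose index vector is a constant splat selecting lane 1.
%idx = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 1)
%tbl = call <vscale x 4 x i32> @llvm.aarch64.sve.tbl.nxv4i32(<vscale x 4 x i32> %v, <vscale x 4 x i32> %idx)

; After: lane 1 is extracted and re-splatted with generic IR, which generic
; optimizations understand better than the target intrinsic.
%elt = extractelement <vscale x 4 x i32> %v, i32 1
%ins = insertelement <vscale x 4 x i32> poison, i32 %elt, i64 0
%splat = shufflevector <vscale x 4 x i32> %ins, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer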

Optional<Instruction *>
AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
                                     IntrinsicInst &II) const {
@@ -713,6 +822,15 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
    return instCombineSVECntElts(IC, II, 8);
  case Intrinsic::aarch64_sve_cntb:
    return instCombineSVECntElts(IC, II, 16);
  case Intrinsic::aarch64_sve_ptest_any:
  case Intrinsic::aarch64_sve_ptest_first:
  case Intrinsic::aarch64_sve_ptest_last:
    return instCombineSVEPTest(IC, II);
  case Intrinsic::aarch64_sve_mul:
  case Intrinsic::aarch64_sve_fmul:
    return instCombineSVEVectorMul(IC, II);
  case Intrinsic::aarch64_sve_tbl:
    return instCombineSVETBL(IC, II);
  }

  return None;
189 changes: 0 additions & 189 deletions llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp
@@ -60,18 +60,9 @@ struct SVEIntrinsicOpts : public ModulePass {
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);

  /// Operates at the instruction-scope. I.e., optimizations are applied local
  /// to individual instructions.
  static bool optimizeIntrinsic(Instruction *I);
  bool optimizeIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);

  /// Operates at the function-scope. I.e., optimizations are applied local to
  /// the functions themselves.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);

  static bool optimizePTest(IntrinsicInst *I);
  static bool optimizeVectorMul(IntrinsicInst *I);
  static bool optimizeTBL(IntrinsicInst *I);
};
} // end anonymous namespace

@@ -285,185 +276,11 @@ bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
  return Changed;
}

bool SVEIntrinsicOpts::optimizePTest(IntrinsicInst *I) {
  IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(I->getArgOperand(0));
  IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(I->getArgOperand(1));

  if (Op1 && Op2 &&
      Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {

    Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
    Type *Tys[] = {Op1->getArgOperand(0)->getType()};
    Module *M = I->getParent()->getParent()->getParent();

    auto Fn = Intrinsic::getDeclaration(M, I->getIntrinsicID(), Tys);
    auto CI = CallInst::Create(Fn, Ops, I->getName(), I);

    I->replaceAllUsesWith(CI);
    I->eraseFromParent();
    if (Op1->use_empty())
      Op1->eraseFromParent();
    if (Op1 != Op2 && Op2->use_empty())
      Op2->eraseFromParent();

    return true;
  }

  return false;
}

bool SVEIntrinsicOpts::optimizeVectorMul(IntrinsicInst *I) {
  assert((I->getIntrinsicID() == Intrinsic::aarch64_sve_mul ||
          I->getIntrinsicID() == Intrinsic::aarch64_sve_fmul) &&
         "Unexpected opcode");

  auto *OpPredicate = I->getOperand(0);
  auto *OpMultiplicand = I->getOperand(1);
  auto *OpMultiplier = I->getOperand(2);

  // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
  // with a unit splat value, false otherwise.
  auto IsUnitDupX = [](auto *I) {
    auto *IntrI = dyn_cast<IntrinsicInst>(I);
    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
      return false;

    auto *SplatValue = IntrI->getOperand(0);
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  // Return true if a given instruction is an aarch64_sve_dup intrinsic call
  // with a unit splat value, false otherwise.
  auto IsUnitDup = [](auto *I) {
    auto *IntrI = dyn_cast<IntrinsicInst>(I);
    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
      return false;

    auto *SplatValue = IntrI->getOperand(2);
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  bool Changed = true;

  // The OpMultiplier variable should always point to the dup (if any), so
  // swap if necessary.
  if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
    std::swap(OpMultiplier, OpMultiplicand);

  if (IsUnitDupX(OpMultiplier)) {
    // [f]mul pg (dupx 1) %n => %n
    I->replaceAllUsesWith(OpMultiplicand);
    I->eraseFromParent();
    Changed = true;
  } else if (IsUnitDup(OpMultiplier)) {
    // [f]mul pg (dup pg 1) %n => %n
    auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
    auto *DupPg = DupInst->getOperand(1);
    // TODO: this is naive. The optimization is still valid if DupPg
    // 'encompasses' OpPredicate, not only if they're the same predicate.
    if (OpPredicate == DupPg) {
      I->replaceAllUsesWith(OpMultiplicand);
      I->eraseFromParent();
      Changed = true;
    }
  }

  // If an instruction was optimized out then it is possible that some dangling
  // instructions are left.
  if (Changed) {
    auto *OpPredicateInst = dyn_cast<Instruction>(OpPredicate);
    auto *OpMultiplierInst = dyn_cast<Instruction>(OpMultiplier);
    if (OpMultiplierInst && OpMultiplierInst->use_empty())
      OpMultiplierInst->eraseFromParent();
    if (OpPredicateInst && OpPredicateInst->use_empty())
      OpPredicateInst->eraseFromParent();
  }

  return Changed;
}

bool SVEIntrinsicOpts::optimizeTBL(IntrinsicInst *I) {
  assert(I->getIntrinsicID() == Intrinsic::aarch64_sve_tbl &&
         "Unexpected opcode");

  auto *OpVal = I->getOperand(0);
  auto *OpIndices = I->getOperand(1);
  VectorType *VTy = cast<VectorType>(I->getType());

  // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with
  // constant splat value < minimal element count of result.
  auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
  if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
    return false;

  auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
  if (!SplatValue ||
      SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
    return false;

  // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
  // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
  LLVMContext &Ctx = I->getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(I);
  auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
  auto *VectorSplat =
      Builder.CreateVectorSplat(VTy->getElementCount(), Extract);

  I->replaceAllUsesWith(VectorSplat);
  I->eraseFromParent();
  if (DupXIntrI->use_empty())
    DupXIntrI->eraseFromParent();
  return true;
}

bool SVEIntrinsicOpts::optimizeIntrinsic(Instruction *I) {
  IntrinsicInst *IntrI = dyn_cast<IntrinsicInst>(I);
  if (!IntrI)
    return false;

  switch (IntrI->getIntrinsicID()) {
  case Intrinsic::aarch64_sve_fmul:
  case Intrinsic::aarch64_sve_mul:
    return optimizeVectorMul(IntrI);
  case Intrinsic::aarch64_sve_ptest_any:
  case Intrinsic::aarch64_sve_ptest_first:
  case Intrinsic::aarch64_sve_ptest_last:
    return optimizePTest(IntrI);
  case Intrinsic::aarch64_sve_tbl:
    return optimizeTBL(IntrI);
  default:
    return false;
  }

  return true;
}

bool SVEIntrinsicOpts::optimizeIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;
  for (auto *F : Functions) {
    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();

    // Traverse the DT with an rpo walk so we see defs before uses, allowing
    // simplification to be done incrementally.
    BasicBlock *Root = DT->getRoot();
    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
    for (auto *BB : RPOT)
      for (Instruction &I : make_early_inc_range(*BB))
        Changed |= optimizeIntrinsic(&I);
  }
  return Changed;
}

bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
  Changed |= optimizeIntrinsicCalls(Functions);

  return Changed;
}
@@ -480,13 +297,7 @@ bool SVEIntrinsicOpts::runOnModule(Module &M) {
      continue;

    switch (F.getIntrinsicID()) {
    case Intrinsic::aarch64_sve_ptest_any:
    case Intrinsic::aarch64_sve_ptest_first:
    case Intrinsic::aarch64_sve_ptest_last:
    case Intrinsic::aarch64_sve_ptrue:
    case Intrinsic::aarch64_sve_mul:
    case Intrinsic::aarch64_sve_fmul:
    case Intrinsic::aarch64_sve_tbl:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
      break;