Expand Up
@@ -15,11 +15,13 @@
#include " llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include " llvm/ADT/Twine.h"
#include " llvm/Analysis/DomTreeUpdater.h"
#include " llvm/Analysis/TargetTransformInfo.h"
#include " llvm/IR/BasicBlock.h"
#include " llvm/IR/Constant.h"
#include " llvm/IR/Constants.h"
#include " llvm/IR/DerivedTypes.h"
#include " llvm/IR/Dominators.h"
#include " llvm/IR/Function.h"
#include " llvm/IR/IRBuilder.h"
#include " llvm/IR/InstrTypes.h"
Expand All
@@ -33,6 +35,7 @@
#include " llvm/Pass.h"
#include " llvm/Support/Casting.h"
#include " llvm/Transforms/Scalar.h"
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
#include < algorithm>
#include < cassert>
Expand All
@@ -59,23 +62,26 @@ class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
// Declares this pass's analysis contract: the dominator tree is marked as
// preserved, and target transform information is required.
void getAnalysisUsage(AnalysisUsage &AU) const override {
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
};
} // end anonymous namespace
static bool optimizeBlock (BasicBlock &BB, bool &ModifiedDT,
const TargetTransformInfo &TTI, const DataLayout &DL);
const TargetTransformInfo &TTI, const DataLayout &DL,
DomTreeUpdater *DTU);
static bool optimizeCallInst (CallInst *CI, bool &ModifiedDT,
const TargetTransformInfo &TTI,
const DataLayout &DL);
const DataLayout &DL, DomTreeUpdater *DTU );
// Static pass ID of the legacy pass, initialised to 0.
char ScalarizeMaskedMemIntrinLegacyPass::ID = 0 ;
// Register the legacy pass under DEBUG_TYPE with its description string, and
// declare the wrapper passes it depends on: target transform info and the
// dominator tree. The BEGIN and END invocations carry the same pass name,
// argument and description.
INITIALIZE_PASS_BEGIN (ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
" Scalarize unsupported masked memory intrinsics" , false ,
false )
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
" Scalarize unsupported masked memory intrinsics" , false ,
false )
Expand Down
Expand Up
@@ -131,7 +137,8 @@ static bool isConstantIntVector(Value *Mask) {
// %10 = extractelement <16 x i1> %mask, i32 2
// br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad (CallInst *CI, bool &ModifiedDT) {
static void scalarizeMaskedLoad (CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Ptr = CI->getArgOperand (0 );
Value *Alignment = CI->getArgOperand (1 );
Value *Mask = CI->getArgOperand (2 );
Expand Down
Expand Up
@@ -213,25 +220,26 @@ static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
// %Elt = load i32* %EltAddr
// VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
//
BasicBlock *CondBlock = IfBlock-> splitBasicBlock (InsertPt-> getIterator (),
" cond.load " );
Builder. SetInsertPoint (InsertPt );
Instruction *ThenTerm =
SplitBlockAndInsertIfThen (Predicate, InsertPt, /* Unreachable= */ false ,
/* BranchWeights= */ nullptr , DTU );
BasicBlock *CondBlock = ThenTerm->getParent ();
CondBlock->setName (" cond.load" );
Builder.SetInsertPoint (CondBlock->getTerminator ());
Value *Gep = Builder.CreateConstInBoundsGEP1_32 (EltTy, FirstEltPtr, Idx);
LoadInst *Load = Builder.CreateAlignedLoad (EltTy, Gep, AdjustedAlignVal);
Value *NewVResult = Builder.CreateInsertElement (VResult, Load, Idx);
// Create "else" block, fill it in the next iteration
BasicBlock *NewIfBlock =
CondBlock->splitBasicBlock (InsertPt->getIterator (), " else" );
Builder.SetInsertPoint (InsertPt);
Instruction *OldBr = IfBlock->getTerminator ();
BranchInst::Create (CondBlock, NewIfBlock, Predicate, OldBr);
OldBr->eraseFromParent ();
BasicBlock *NewIfBlock = ThenTerm->getSuccessor (0 );
NewIfBlock->setName (" else" );
BasicBlock *PrevIfBlock = IfBlock;
IfBlock = NewIfBlock;
// Create the phi to join the new and previous value.
Builder.SetInsertPoint (NewIfBlock, NewIfBlock->begin ());
PHINode *Phi = Builder.CreatePHI (VecType, 2 , " res.phi.else" );
Phi->addIncoming (NewVResult, CondBlock);
Phi->addIncoming (VResult, PrevIfBlock);
Expand Down
Expand Up
@@ -270,7 +278,8 @@ static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
// store i32 %6, i32* %7
// br label %else2
// . . .
static void scalarizeMaskedStore (CallInst *CI, bool &ModifiedDT) {
static void scalarizeMaskedStore (CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Src = CI->getArgOperand (0 );
Value *Ptr = CI->getArgOperand (1 );
Value *Alignment = CI->getArgOperand (2 );
Expand Down
Expand Up
@@ -345,22 +354,24 @@ static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
// %EltAddr = getelementptr i32* %1, i32 0
// %store i32 %OneElt, i32* %EltAddr
//
BasicBlock *CondBlock =
IfBlock->splitBasicBlock (InsertPt->getIterator (), " cond.store" );
Builder.SetInsertPoint (InsertPt);
Instruction *ThenTerm =
SplitBlockAndInsertIfThen (Predicate, InsertPt, /* Unreachable=*/ false ,
/* BranchWeights=*/ nullptr , DTU);
BasicBlock *CondBlock = ThenTerm->getParent ();
CondBlock->setName (" cond.store" );
Builder.SetInsertPoint (CondBlock->getTerminator ());
Value *OneElt = Builder.CreateExtractElement (Src, Idx);
Value *Gep = Builder.CreateConstInBoundsGEP1_32 (EltTy, FirstEltPtr, Idx);
Builder.CreateAlignedStore (OneElt, Gep, AdjustedAlignVal);
// Create "else" block, fill it in the next iteration
BasicBlock *NewIfBlock =
CondBlock->splitBasicBlock (InsertPt->getIterator (), " else" );
Builder.SetInsertPoint (InsertPt);
Instruction *OldBr = IfBlock->getTerminator ();
BranchInst::Create (CondBlock, NewIfBlock, Predicate, OldBr);
OldBr->eraseFromParent ();
BasicBlock *NewIfBlock = ThenTerm->getSuccessor (0 );
NewIfBlock->setName (" else" );
IfBlock = NewIfBlock;
Builder.SetInsertPoint (NewIfBlock, NewIfBlock->begin ());
}
CI->eraseFromParent ();
Expand Down
Expand Up
@@ -396,7 +407,8 @@ static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather (CallInst *CI, bool &ModifiedDT) {
static void scalarizeMaskedGather (CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Ptrs = CI->getArgOperand (0 );
Value *Alignment = CI->getArgOperand (1 );
Value *Mask = CI->getArgOperand (2 );
Expand Down
Expand Up
@@ -464,24 +476,28 @@ static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
// %Elt = load i32* %EltAddr
// VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
//
BasicBlock *CondBlock = IfBlock->splitBasicBlock (InsertPt, " cond.load" );
Builder.SetInsertPoint (InsertPt);
Instruction *ThenTerm =
SplitBlockAndInsertIfThen (Predicate, InsertPt, /* Unreachable=*/ false ,
/* BranchWeights=*/ nullptr , DTU);
BasicBlock *CondBlock = ThenTerm->getParent ();
CondBlock->setName (" cond.load" );
Builder.SetInsertPoint (CondBlock->getTerminator ());
Value *Ptr = Builder.CreateExtractElement (Ptrs, Idx, " Ptr" + Twine (Idx));
LoadInst *Load =
Builder.CreateAlignedLoad (EltTy, Ptr , AlignVal, " Load" + Twine (Idx));
Value *NewVResult =
Builder.CreateInsertElement (VResult, Load, Idx, " Res" + Twine (Idx));
// Create "else" block, fill it in the next iteration
BasicBlock *NewIfBlock = CondBlock->splitBasicBlock (InsertPt, " else" );
Builder.SetInsertPoint (InsertPt);
Instruction *OldBr = IfBlock->getTerminator ();
BranchInst::Create (CondBlock, NewIfBlock, Predicate, OldBr);
OldBr->eraseFromParent ();
BasicBlock *NewIfBlock = ThenTerm->getSuccessor (0 );
NewIfBlock->setName (" else" );
BasicBlock *PrevIfBlock = IfBlock;
IfBlock = NewIfBlock;
// Create the phi to join the new and previous value.
Builder.SetInsertPoint (NewIfBlock, NewIfBlock->begin ());
PHINode *Phi = Builder.CreatePHI (VecType, 2 , " res.phi.else" );
Phi->addIncoming (NewVResult, CondBlock);
Phi->addIncoming (VResult, PrevIfBlock);
Expand Down
Expand Up
@@ -520,7 +536,8 @@ static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter (CallInst *CI, bool &ModifiedDT) {
static void scalarizeMaskedScatter (CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Src = CI->getArgOperand (0 );
Value *Ptrs = CI->getArgOperand (1 );
Value *Alignment = CI->getArgOperand (2 );
Expand Down
Expand Up
@@ -586,27 +603,32 @@ static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %store i32 %Elt1, i32* %Ptr1
//
BasicBlock *CondBlock = IfBlock->splitBasicBlock (InsertPt, " cond.store" );
Builder.SetInsertPoint (InsertPt);
Instruction *ThenTerm =
SplitBlockAndInsertIfThen (Predicate, InsertPt, /* Unreachable=*/ false ,
/* BranchWeights=*/ nullptr , DTU);
BasicBlock *CondBlock = ThenTerm->getParent ();
CondBlock->setName (" cond.store" );
Builder.SetInsertPoint (CondBlock->getTerminator ());
Value *OneElt = Builder.CreateExtractElement (Src, Idx, " Elt" + Twine (Idx));
Value *Ptr = Builder.CreateExtractElement (Ptrs, Idx, " Ptr" + Twine (Idx));
Builder.CreateAlignedStore (OneElt, Ptr , AlignVal);
// Create "else" block, fill it in the next iteration
BasicBlock *NewIfBlock = CondBlock->splitBasicBlock (InsertPt, " else" );
Builder.SetInsertPoint (InsertPt);
Instruction *OldBr = IfBlock->getTerminator ();
BranchInst::Create (CondBlock, NewIfBlock, Predicate, OldBr);
OldBr->eraseFromParent ();
BasicBlock *NewIfBlock = ThenTerm->getSuccessor (0 );
NewIfBlock->setName (" else" );
IfBlock = NewIfBlock;
Builder.SetInsertPoint (NewIfBlock, NewIfBlock->begin ());
}
CI->eraseFromParent ();
ModifiedDT = true ;
}
static void scalarizeMaskedExpandLoad (CallInst *CI, bool &ModifiedDT) {
static void scalarizeMaskedExpandLoad (CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Ptr = CI->getArgOperand (0 );
Value *Mask = CI->getArgOperand (1 );
Value *PassThru = CI->getArgOperand (2 );
Expand Down
Expand Up
@@ -687,10 +709,14 @@ static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
// %Elt = load i32* %EltAddr
// VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
//
BasicBlock *CondBlock = IfBlock->splitBasicBlock (InsertPt->getIterator (),
" cond.load" );
Builder.SetInsertPoint (InsertPt);
Instruction *ThenTerm =
SplitBlockAndInsertIfThen (Predicate, InsertPt, /* Unreachable=*/ false ,
/* BranchWeights=*/ nullptr , DTU);
BasicBlock *CondBlock = ThenTerm->getParent ();
CondBlock->setName (" cond.load" );
Builder.SetInsertPoint (CondBlock->getTerminator ());
LoadInst *Load = Builder.CreateAlignedLoad (EltTy, Ptr , Align (1 ));
Value *NewVResult = Builder.CreateInsertElement (VResult, Load, Idx);
Expand All
@@ -700,16 +726,13 @@ static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
NewPtr = Builder.CreateConstInBoundsGEP1_32 (EltTy, Ptr , 1 );
// Create "else" block, fill it in the next iteration
BasicBlock *NewIfBlock =
CondBlock->splitBasicBlock (InsertPt->getIterator (), " else" );
Builder.SetInsertPoint (InsertPt);
Instruction *OldBr = IfBlock->getTerminator ();
BranchInst::Create (CondBlock, NewIfBlock, Predicate, OldBr);
OldBr->eraseFromParent ();
BasicBlock *NewIfBlock = ThenTerm->getSuccessor (0 );
NewIfBlock->setName (" else" );
BasicBlock *PrevIfBlock = IfBlock;
IfBlock = NewIfBlock;
// Create the phi to join the new and previous value.
Builder.SetInsertPoint (NewIfBlock, NewIfBlock->begin ());
PHINode *ResultPhi = Builder.CreatePHI (VecType, 2 , " res.phi.else" );
ResultPhi->addIncoming (NewVResult, CondBlock);
ResultPhi->addIncoming (VResult, PrevIfBlock);
Expand All
@@ -730,7 +753,8 @@ static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
ModifiedDT = true ;
}
static void scalarizeMaskedCompressStore (CallInst *CI, bool &ModifiedDT) {
static void scalarizeMaskedCompressStore (CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Src = CI->getArgOperand (0 );
Value *Ptr = CI->getArgOperand (1 );
Value *Mask = CI->getArgOperand (2 );
Expand Down
Expand Up
@@ -793,10 +817,14 @@ static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
// %EltAddr = getelementptr i32* %1, i32 0
// %store i32 %OneElt, i32* %EltAddr
//
BasicBlock *CondBlock =
IfBlock->splitBasicBlock (InsertPt->getIterator (), " cond.store" );
Builder.SetInsertPoint (InsertPt);
Instruction *ThenTerm =
SplitBlockAndInsertIfThen (Predicate, InsertPt, /* Unreachable=*/ false ,
/* BranchWeights=*/ nullptr , DTU);
BasicBlock *CondBlock = ThenTerm->getParent ();
CondBlock->setName (" cond.store" );
Builder.SetInsertPoint (CondBlock->getTerminator ());
Value *OneElt = Builder.CreateExtractElement (Src, Idx);
Builder.CreateAlignedStore (OneElt, Ptr , Align (1 ));
Expand All
@@ -806,15 +834,13 @@ static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
NewPtr = Builder.CreateConstInBoundsGEP1_32 (EltTy, Ptr , 1 );
// Create "else" block, fill it in the next iteration
BasicBlock *NewIfBlock =
CondBlock->splitBasicBlock (InsertPt->getIterator (), " else" );
Builder.SetInsertPoint (InsertPt);
Instruction *OldBr = IfBlock->getTerminator ();
BranchInst::Create (CondBlock, NewIfBlock, Predicate, OldBr);
OldBr->eraseFromParent ();
BasicBlock *NewIfBlock = ThenTerm->getSuccessor (0 );
NewIfBlock->setName (" else" );
BasicBlock *PrevIfBlock = IfBlock;
IfBlock = NewIfBlock;
Builder.SetInsertPoint (NewIfBlock, NewIfBlock->begin ());
// Add a PHI for the pointer if this isn't the last iteration.
if ((Idx + 1 ) != VectorWidth) {
PHINode *PtrPhi = Builder.CreatePHI (Ptr ->getType (), 2 , " ptr.phi.else" );
Expand All
@@ -828,7 +854,12 @@ static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
ModifiedDT = true ;
}
static bool runImpl (Function &F, const TargetTransformInfo &TTI) {
static bool runImpl (Function &F, const TargetTransformInfo &TTI,
DominatorTree *DT) {
Optional<DomTreeUpdater> DTU;
if (DT)
DTU.emplace (DT, DomTreeUpdater::UpdateStrategy::Lazy);
bool EverMadeChange = false ;
bool MadeChange = true ;
auto &DL = F.getParent ()->getDataLayout ();
Expand All
@@ -837,7 +868,9 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI) {
for (Function::iterator I = F.begin (); I != F.end ();) {
BasicBlock *BB = &*I++;
bool ModifiedDTOnIteration = false ;
MadeChange |= optimizeBlock (*BB, ModifiedDTOnIteration, TTI, DL);
MadeChange |= optimizeBlock (*BB, ModifiedDTOnIteration, TTI, DL,
DTU.hasValue () ? DTU.getPointer () : nullptr );
// Restart BB iteration if the dominator tree of the Function was changed
if (ModifiedDTOnIteration)
Expand All
@@ -851,28 +884,33 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI) {
// Legacy pass-manager entry point: fetches target transform info for F,
// picks up the dominator tree if one is available, and delegates to runImpl.
// Returns whatever runImpl returns.
bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  // The dominator tree is optional here: it is only looked up, never
  // required, so DT stays null when the wrapper pass is not available.
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}
// New pass-manager entry point: runs the scalarization over F and reports
// which analyses remain valid afterwards.
//
// Returns PreservedAnalyses::all() when runImpl reports no change; otherwise
// a set preserving only TargetIRAnalysis and DominatorTreeAnalysis.
PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  // Only a cached dominator tree is used; the pointer is handed to runImpl
  // as-is, and runImpl checks it before creating a DomTreeUpdater.
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}
static bool optimizeBlock (BasicBlock &BB, bool &ModifiedDT,
const TargetTransformInfo &TTI,
const DataLayout &DL ) {
const TargetTransformInfo &TTI, const DataLayout &DL,
DomTreeUpdater *DTU ) {
bool MadeChange = false ;
BasicBlock::iterator CurInstIterator = BB.begin ();
while (CurInstIterator != BB.end ()) {
if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
MadeChange |= optimizeCallInst (CI, ModifiedDT, TTI, DL);
MadeChange |= optimizeCallInst (CI, ModifiedDT, TTI, DL, DTU );
if (ModifiedDT)
return true ;
}
Expand All
@@ -882,7 +920,7 @@ static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
static bool optimizeCallInst (CallInst *CI, bool &ModifiedDT,
const TargetTransformInfo &TTI,
const DataLayout &DL) {
const DataLayout &DL, DomTreeUpdater *DTU ) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
if (II) {
// The scalarization code below does not work for scalable vectors.
Expand All
@@ -900,14 +938,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getType (),
cast<ConstantInt>(CI->getArgOperand (1 ))->getAlignValue ()))
return false ;
scalarizeMaskedLoad (CI, ModifiedDT);
scalarizeMaskedLoad (CI, DTU, ModifiedDT);
return true ;
case Intrinsic::masked_store:
if (TTI.isLegalMaskedStore (
CI->getArgOperand (0 )->getType (),
cast<ConstantInt>(CI->getArgOperand (2 ))->getAlignValue ()))
return false ;
scalarizeMaskedStore (CI, ModifiedDT);
scalarizeMaskedStore (CI, DTU, ModifiedDT);
return true ;
case Intrinsic::masked_gather: {
unsigned AlignmentInt =
Expand All
@@ -917,7 +955,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
DL.getValueOrABITypeAlignment (MaybeAlign (AlignmentInt), LoadTy);
if (TTI.isLegalMaskedGather (LoadTy, Alignment))
return false ;
scalarizeMaskedGather (CI, ModifiedDT);
scalarizeMaskedGather (CI, DTU, ModifiedDT);
return true ;
}
case Intrinsic::masked_scatter: {
Expand All
@@ -928,18 +966,18 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
DL.getValueOrABITypeAlignment (MaybeAlign (AlignmentInt), StoreTy);
if (TTI.isLegalMaskedScatter (StoreTy, Alignment))
return false ;
scalarizeMaskedScatter (CI, ModifiedDT);
scalarizeMaskedScatter (CI, DTU, ModifiedDT);
return true ;
}
case Intrinsic::masked_expandload:
if (TTI.isLegalMaskedExpandLoad (CI->getType ()))
return false ;
scalarizeMaskedExpandLoad (CI, ModifiedDT);
scalarizeMaskedExpandLoad (CI, DTU, ModifiedDT);
return true ;
case Intrinsic::masked_compressstore:
if (TTI.isLegalMaskedCompressStore (CI->getArgOperand (0 )->getType ()))
return false ;
scalarizeMaskedCompressStore (CI, ModifiedDT);
scalarizeMaskedCompressStore (CI, DTU, ModifiedDT);
return true ;
}
}
Expand Down