Skip to content

Commit

Permalink
AMDGPU: Don't store current instruction in AMDGPULibCalls member
Browse files Browse the repository at this point in the history
This was adding confusing global state which was shadowed most of the
time.

https://reviews.llvm.org/D156680
  • Loading branch information
arsenm committed Jul 31, 2023
1 parent d517117 commit c2c22c6
Showing 1 changed file with 63 additions and 59 deletions.
122 changes: 63 additions & 59 deletions llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,17 @@ class AMDGPULibCalls {
FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);

protected:
CallInst *CI;

bool isUnsafeMath(const FPMathOperator *FPOp) const;

bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;

void replaceCall(Value *With) {
CI->replaceAllUsesWith(With);
CI->eraseFromParent();
static void replaceCall(Instruction *I, Value *With) {
I->replaceAllUsesWith(With);
I->eraseFromParent();
}

static void replaceCall(FPMathOperator *I, Value *With) {
replaceCall(cast<Instruction>(I), With);
}

public:
Expand Down Expand Up @@ -501,15 +503,14 @@ bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
<< " with native version of sin/cos");

replaceCall(sinval);
replaceCall(aCI, sinval);
return true;
}
}
return false;
}

bool AMDGPULibCalls::useNative(CallInst *aCI) {
CI = aCI;
Function *Callee = aCI->getCalledFunction();
if (!Callee || aCI->isNoBuiltin())
return false;
Expand Down Expand Up @@ -601,7 +602,6 @@ bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,

// This function returns false if no change; return true otherwise.
bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
this->CI = CI;
Function *Callee = CI->getCalledFunction();
// Ignore indirect calls.
if (!Callee || CI->isNoBuiltin())
Expand Down Expand Up @@ -733,7 +733,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
nval = ConstantDataVector::get(context, tmp);
}
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
replaceCall(nval);
replaceCall(CI, nval);
return true;
}
} else {
Expand All @@ -743,7 +743,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
if (CF->isExactlyValue(tr[i].input)) {
Value *nval = ConstantFP::get(CF->getType(), tr[i].result);
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
replaceCall(nval);
replaceCall(CI, nval);
return true;
}
}
Expand All @@ -765,7 +765,7 @@ bool AMDGPULibCalls::fold_recip(CallInst *CI, IRBuilder<> &B,
opr0,
"recip2div");
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
replaceCall(nval);
replaceCall(CI, nval);
return true;
}
return false;
Expand All @@ -786,7 +786,7 @@ bool AMDGPULibCalls::fold_divide(CallInst *CI, IRBuilder<> &B,
Value *nval1 = B.CreateFDiv(ConstantFP::get(opr1->getType(), 1.0),
opr1, "__div2recip");
Value *nval = B.CreateFMul(opr0, nval1, "__div2mul");
replaceCall(nval);
replaceCall(CI, nval);
return true;
}
return false;
Expand All @@ -813,8 +813,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
ConstantFP *CF;
ConstantInt *CINT;
Type *eltType;
Value *opr0 = CI->getArgOperand(0);
Value *opr1 = CI->getArgOperand(1);
Value *opr0 = FPOp->getOperand(0);
Value *opr1 = FPOp->getOperand(1);
ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1);

if (getVecSize(FInfo) == 1) {
Expand All @@ -841,37 +841,37 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,

if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
// pow/powr/pown(x, 0) == 1
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1\n");
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n");
Constant *cnval = ConstantFP::get(eltType, 1.0);
if (getVecSize(FInfo) > 1) {
cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
}
replaceCall(cnval);
replaceCall(FPOp, cnval);
return true;
}
if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
// pow/powr/pown(x, 1.0) = x
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
replaceCall(opr0);
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
replaceCall(FPOp, opr0);
return true;
}
if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
// pow/powr/pown(x, 2.0) = x*x
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * " << *opr0
<< "\n");
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
<< *opr0 << "\n");
Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}
if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
// pow/powr/pown(x, -1.0) = 1.0/x
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1 / " << *opr0 << "\n");
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n");
Constant *cnval = ConstantFP::get(eltType, 1.0);
if (getVecSize(FInfo) > 1) {
cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
}
Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}

Expand All @@ -882,11 +882,11 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
: AMDGPULibFunc::EI_RSQRT,
FInfo))) {
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << FInfo.getName()
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName()
<< '(' << *opr0 << ")\n");
Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
: "__pow2rsqrt");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}
}
Expand Down Expand Up @@ -939,10 +939,10 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
}
nval = B.CreateFDiv(cnval, nval, "__1powprod");
}
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
<< ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0
<< ")\n");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}

Expand Down Expand Up @@ -1066,7 +1066,7 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
nTy = FixedVectorType::get(nTyS, vTy);
unsigned size = nTy->getScalarSizeInBits();
opr_n = CI->getArgOperand(1);
opr_n = FPOp->getOperand(1);
if (opr_n->getType()->isIntegerTy())
opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
else
Expand All @@ -1078,9 +1078,9 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
nval = B.CreateBitCast(nval, opr0->getType());
}

LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
<< "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
replaceCall(nval);
replaceCall(FPOp, nval);

return true;
}
Expand All @@ -1100,43 +1100,44 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
}
int ci_opr1 = (int)CINT->getSExtValue();
if (ci_opr1 == 1) { // rootn(x, 1) = x
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
replaceCall(opr0);
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
replaceCall(FPOp, opr0);
return true;
}
if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
Module *M = CI->getModule();

Module *M = B.GetInsertBlock()->getModule();
if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
if (FunctionCallee FPExpr =
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> sqrt(" << *opr0 << ")\n");
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
<< ")\n");
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}
} else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
Module *M = CI->getModule();
if (FunctionCallee FPExpr =
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> cbrt(" << *opr0 << ")\n");
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
<< ")\n");
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}
} else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1.0 / " << *opr0 << "\n");
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n");
Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
opr0,
"__rootn2div");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
} else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
Module *M = CI->getModule();
} else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
if (FunctionCallee FPExpr =
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> rsqrt(" << *opr0
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
<< ")\n");
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}
}
Expand All @@ -1154,23 +1155,23 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
if ((CF0 && CF0->isZero()) || (CF1 && CF1->isZero())) {
// fma/mad(a, b, c) = c if a=0 || b=0
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr2 << "\n");
replaceCall(opr2);
replaceCall(CI, opr2);
return true;
}
if (CF0 && CF0->isExactlyValue(1.0f)) {
// fma/mad(a, b, c) = b+c if a=1
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr1 << " + " << *opr2
<< "\n");
Value *nval = B.CreateFAdd(opr1, opr2, "fmaadd");
replaceCall(nval);
replaceCall(CI, nval);
return true;
}
if (CF1 && CF1->isExactlyValue(1.0f)) {
// fma/mad(a, b, c) = a+c if b=1
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " + " << *opr2
<< "\n");
Value *nval = B.CreateFAdd(opr0, opr2, "fmaadd");
replaceCall(nval);
replaceCall(CI, nval);
return true;
}
if (ConstantFP *CF = dyn_cast<ConstantFP>(opr2)) {
Expand All @@ -1179,7 +1180,7 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * "
<< *opr1 << "\n");
Value *nval = B.CreateFMul(opr0, opr1, "fmamul");
replaceCall(nval);
replaceCall(CI, nval);
return true;
}
}
Expand All @@ -1205,13 +1206,15 @@ bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,

if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
(FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
Module *M = B.GetInsertBlock()->getModule();

if (FunctionCallee FPExpr = getNativeFunction(
CI->getModule(), AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
Value *opr0 = CI->getArgOperand(0);
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
Value *opr0 = FPOp->getOperand(0);
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
<< "sqrt(" << *opr0 << ")\n");
Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
replaceCall(nval);
replaceCall(FPOp, nval);
return true;
}
}
Expand All @@ -1231,7 +1234,8 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,

bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

Value *CArgVal = CI->getArgOperand(0);
Value *CArgVal = FPOp->getOperand(0);
CallInst *CI = cast<CallInst>(FPOp);
BasicBlock * const CBB = CI->getParent();

int const MaxScan = 30;
Expand All @@ -1247,7 +1251,7 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
CArgVal->replaceAllUsesWith(AvailableVal);
if (CArgVal->getNumUses() == 0)
LI->eraseFromParent();
CArgVal = CI->getArgOperand(0);
CArgVal = FPOp->getOperand(0);
}
}
}
Expand Down Expand Up @@ -1617,12 +1621,12 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
}
}

LLVMContext &context = CI->getParent()->getParent()->getContext();
LLVMContext &context = aCI->getContext();
Constant *nval0, *nval1;
if (FuncVecSize == 1) {
nval0 = ConstantFP::get(CI->getType(), DVal0[0]);
nval0 = ConstantFP::get(aCI->getType(), DVal0[0]);
if (hasTwoResults)
nval1 = ConstantFP::get(CI->getType(), DVal1[0]);
nval1 = ConstantFP::get(aCI->getType(), DVal1[0]);
} else {
if (getArgType(FInfo) == AMDGPULibFunc::F32) {
SmallVector <float, 0> FVal0, FVal1;
Expand Down Expand Up @@ -1653,7 +1657,7 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
new StoreInst(nval1, aCI->getArgOperand(1), aCI);
}

replaceCall(nval0);
replaceCall(aCI, nval0);
return true;
}

Expand Down

0 comments on commit c2c22c6

Please sign in to comment.