AMDGPU: Refactor AMDGPUCodeGenPrepare fdiv handling
NFC-ish. Does trigger some reordering of the fdiv scalarization. Also
skips scalarizing in more cases where nothing was going to happen. We
can still scalarize in some no-op edge cases.

https://reviews.llvm.org/D155740
arsenm committed Jul 21, 2023
1 parent 7c5e4ef commit 6699c37
Showing 3 changed files with 1,287 additions and 1,328 deletions.
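As a rough sketch of the transform this code implements (illustrative IR, not taken from the commit; it assumes f32, flushed denormals, and a 1.0 ulp !fpmath requirement on both instructions), a reciprocal square root written as an fdiv of llvm.sqrt:

  %sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !0
  %div = fdiv contract float 1.000000e+00, %sqrt, !fpmath !0
  !0 = !{float 1.000000e+00}

is contracted into the ~1 ulp hardware reciprocal square root:

  %div = call contract float @llvm.amdgcn.rsq.f32(float %x)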
213 changes: 108 additions & 105 deletions llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
@@ -147,6 +147,11 @@ class AMDGPUCodeGenPrepareImpl
                                      DT);
   }
 
+  bool canIgnoreDenormalInput(const Value *V, const Instruction *CtxI) const {
+    return HasFP32DenormalFlush ||
+           computeKnownFPClass(V, fcSubnormal, CtxI).isKnownNeverSubnormal();
+  }
+
   /// Promotes uniform binary operation \p I to equivalent 32 bit binary
   /// operation.
   ///
@@ -247,13 +252,22 @@ class AMDGPUCodeGenPrepareImpl
   Value *matchFractPat(IntrinsicInst &I);
   Value *applyFractPat(IRBuilder<> &Builder, Value *FractArg);
 
+  bool canOptimizeWithRsq(const FPMathOperator *SqrtOp, FastMathFlags DivFMF,
+                          FastMathFlags SqrtFMF) const;
+
   Value *optimizeWithRsq(IRBuilder<> &Builder, Value *Num, Value *Den,
                          FastMathFlags DivFMF, FastMathFlags SqrtFMF,
-                         const Instruction *CtxI, bool AllowApproxRsq) const;
+                         const Instruction *CtxI) const;
 
   Value *optimizeWithRcp(IRBuilder<> &Builder, Value *Num, Value *Den,
-                         FastMathFlags FMF, const Instruction *CtxI,
-                         bool AllowInaccurateRcp, bool RcpIsAccurate) const;
+                         FastMathFlags FMF, const Instruction *CtxI) const;
+  Value *optimizeWithFDivFast(IRBuilder<> &Builder, Value *Num, Value *Den,
+                              float ReqdAccuracy) const;
+
+  Value *visitFDivElement(IRBuilder<> &Builder, Value *Num, Value *Den,
+                          FastMathFlags DivFMF, FastMathFlags SqrtFMF,
+                          Value *RsqOp, const Instruction *FDiv,
+                          float ReqdAccuracy) const;
 
 public:
   bool visitFDiv(BinaryOperator &I);
@@ -815,13 +829,27 @@ static Value *emitRsqIEEE1ULP(IRBuilder<> &Builder, Value *Src,
   return Builder.CreateFMul(Rsq, OutputScaleFactor);
 }
 
+bool AMDGPUCodeGenPrepareImpl::canOptimizeWithRsq(const FPMathOperator *SqrtOp,
+                                                  FastMathFlags DivFMF,
+                                                  FastMathFlags SqrtFMF) const {
+  // The rsqrt contraction increases accuracy from ~2ulp to ~1ulp.
+  if (!DivFMF.allowContract() || !SqrtFMF.allowContract())
+    return false;
+
+  // v_rsq_f32 gives 1ulp
+  return SqrtFMF.approxFunc() || HasUnsafeFPMath ||
+         SqrtOp->getFPAccuracy() >= 1.0f;
+}
+
 Value *AMDGPUCodeGenPrepareImpl::optimizeWithRsq(
     IRBuilder<> &Builder, Value *Num, Value *Den, FastMathFlags DivFMF,
-    FastMathFlags SqrtFMF, const Instruction *CtxI, bool AllowApproxRsq) const {
+    FastMathFlags SqrtFMF, const Instruction *CtxI) const {
   // The rsqrt contraction increases accuracy from ~2ulp to ~1ulp.
-  if (!DivFMF.allowContract() || !SqrtFMF.allowContract())
-    return nullptr;
+  assert(DivFMF.allowContract() && SqrtFMF.allowContract());
 
+  // rsq_f16 is accurate to 0.51 ulp.
+  // rsq_f32 is accurate for !fpmath >= 1.0ulp and denormals are flushed.
+  // rsq_f64 is never accurate.
   const ConstantFP *CLHS = dyn_cast<ConstantFP>(Num);
   if (!CLHS)
     return nullptr;
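Per canOptimizeWithRsq above, the fold only fires when both the fdiv and the sqrt carry contract, and the sqrt additionally has afn, unsafe-fp-math, or !fpmath of at least 1.0 ulp. A sketch of input that is deliberately left alone (illustrative, not from the commit's tests):

  ; no contract on the sqrt, so the rsq contraction is not licensed
  %sqrt = call float @llvm.sqrt.f32(float %x)
  %div = fdiv contract float 1.000000e+00, %sqrt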
@@ -830,14 +858,16 @@ Value *AMDGPUCodeGenPrepareImpl::optimizeWithRsq(
   assert(Ty->isFloatTy());
 
   bool IsNegative = false;
+
+  // TODO: Handle other numerator values with arcp.
   if (CLHS->isExactlyValue(1.0) || (IsNegative = CLHS->isExactlyValue(-1.0))) {
     // Add in the sqrt flags.
     IRBuilder<>::FastMathFlagGuard Guard(Builder);
     DivFMF |= SqrtFMF;
     Builder.setFastMathFlags(DivFMF);
 
-    if (HasFP32DenormalFlush || AllowApproxRsq ||
-        computeKnownFPClass(Den, fcSubnormal, CtxI).isKnownNeverSubnormal()) {
+    if ((DivFMF.approxFunc() && SqrtFMF.approxFunc()) ||
+        canIgnoreDenormalInput(Den, CtxI)) {
       Value *Result = Builder.CreateUnaryIntrinsic(Intrinsic::amdgcn_rsq, Den);
       // -1.0 / sqrt(x) -> fneg(rsq(x))
       return IsNegative ? Builder.CreateFNeg(Result) : Result;
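For the negative-one numerator handled above, the sign is folded through the intrinsic; roughly (same assumptions as the earlier sketch):

  %sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !0
  %div = fdiv contract float -1.000000e+00, %sqrt, !fpmath !0

becomes

  ; -1.0 / sqrt(x) -> fneg(rsq(x))
  %rsq = call contract float @llvm.amdgcn.rsq.f32(float %x)
  %div = fneg contract float %rsq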
@@ -855,14 +885,13 @@ Value *AMDGPUCodeGenPrepareImpl::optimizeWithRsq(
 //   allowed with unsafe-fp-math or afn.
 //
 // a/b -> a*rcp(b) when arcp is allowed, and we only need provide ULP 1.0
-Value *AMDGPUCodeGenPrepareImpl::optimizeWithRcp(IRBuilder<> &Builder,
-                                                 Value *Num, Value *Den,
-                                                 FastMathFlags FMF,
-                                                 const Instruction *CtxI,
-                                                 bool AllowInaccurateRcp,
-                                                 bool RcpIsAccurate) const {
-  assert(AllowInaccurateRcp || RcpIsAccurate);
-
+Value *
+AMDGPUCodeGenPrepareImpl::optimizeWithRcp(IRBuilder<> &Builder, Value *Num,
+                                          Value *Den, FastMathFlags FMF,
+                                          const Instruction *CtxI) const {
+  // rcp_f16 is accurate to 0.51 ulp.
+  // rcp_f32 is accurate for !fpmath >= 1.0ulp and denormals are flushed.
+  // rcp_f64 is never accurate.
   Type *Ty = Den->getType();
   assert(Ty->isFloatTy());
 
@@ -872,7 +901,7 @@ Value *AMDGPUCodeGenPrepareImpl::optimizeWithRcp(IRBuilder<> &Builder,
       (IsNegative = CLHS->isExactlyValue(-1.0))) {
     Value *Src = Den;
 
-    if (HasFP32DenormalFlush || AllowInaccurateRcp) {
+    if (HasFP32DenormalFlush || FMF.approxFunc()) {
       // -1.0 / x -> 1.0 / fneg(x)
       if (IsNegative)
         Src = Builder.CreateFNeg(Src);
@@ -902,7 +931,7 @@ Value *AMDGPUCodeGenPrepareImpl::optimizeWithRcp(IRBuilder<> &Builder,
 
   // TODO: Could avoid denormal scaling and use raw rcp if we knew the output
   // will never underflow.
-  if (AllowInaccurateRcp || HasFP32DenormalFlush) {
+  if (HasFP32DenormalFlush || FMF.approxFunc()) {
     Value *Recip = Builder.CreateUnaryIntrinsic(Intrinsic::amdgcn_rcp, Den);
     return Builder.CreateFMul(Num, Recip);
   }
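For a general a/b, this path needs arcp plus either flushed f32 denormals or afn; a sketch assuming flushed denormals and a 1.0 ulp !fpmath requirement:

  %div = fdiv arcp float %a, %b, !fpmath !0
  !0 = !{float 1.000000e+00}

becomes

  ; a/b -> a * rcp(b)
  %rcp = call arcp float @llvm.amdgcn.rcp.f32(float %b)
  %div = fmul arcp float %a, %rcp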
@@ -921,9 +950,8 @@ Value *AMDGPUCodeGenPrepareImpl::optimizeWithRcp(IRBuilder<> &Builder,
 // 1/x -> fdiv.fast(1,x) when !fpmath >= 2.5ulp.
 //
 // NOTE: optimizeWithRcp should be tried first because rcp is the preference.
-static Value *optimizeWithFDivFast(Value *Num, Value *Den, float ReqdAccuracy,
-                                   bool HasFP32DenormalFlush,
-                                   IRBuilder<> &Builder, Module *Mod) {
+Value *AMDGPUCodeGenPrepareImpl::optimizeWithFDivFast(
+    IRBuilder<> &Builder, Value *Num, Value *Den, float ReqdAccuracy) const {
   // fdiv.fast can achieve 2.5 ULP accuracy.
   if (ReqdAccuracy < 2.5f)
     return nullptr;
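As the comments say, fdiv.fast only applies at 2.5 ulp or looser; a sketch (assuming f32 denormals are flushed, so a non-unit numerator is acceptable, and no arcp, so the rcp path above did not already fire):

  %div = fdiv float %a, %b, !fpmath !0
  !0 = !{float 2.500000e+00}

becomes

  %div = call float @llvm.amdgcn.fdiv.fast(float %a, float %b)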
@@ -943,8 +971,25 @@ static Value *optimizeWithFDivFast(Value *Num, Value *Den, float ReqdAccuracy,
   if (!HasFP32DenormalFlush && !NumIsOne)
     return nullptr;
 
-  Function *Decl = Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_fdiv_fast);
-  return Builder.CreateCall(Decl, { Num, Den });
+  return Builder.CreateIntrinsic(Intrinsic::amdgcn_fdiv_fast, {}, {Num, Den});
 }
 
+Value *AMDGPUCodeGenPrepareImpl::visitFDivElement(
+    IRBuilder<> &Builder, Value *Num, Value *Den, FastMathFlags DivFMF,
+    FastMathFlags SqrtFMF, Value *RsqOp, const Instruction *FDivInst,
+    float ReqdDivAccuracy) const {
+  if (RsqOp) {
+    Value *Rsq =
+        optimizeWithRsq(Builder, Num, RsqOp, DivFMF, SqrtFMF, FDivInst);
+    if (Rsq)
+      return Rsq;
+  }
+
+  Value *Rcp = optimizeWithRcp(Builder, Num, Den, DivFMF, FDivInst);
+  if (Rcp)
+    return Rcp;
+
+  return optimizeWithFDivFast(Builder, Num, Den, ReqdDivAccuracy);
+}
+
 // Optimizations is performed based on fpmath, fast math flags as well as
@@ -975,8 +1020,7 @@ bool AMDGPUCodeGenPrepareImpl::visitFDiv(BinaryOperator &FDiv) {
 
   const FPMathOperator *FPOp = cast<const FPMathOperator>(&FDiv);
   const FastMathFlags DivFMF = FPOp->getFastMathFlags();
-
-  const float ReqdAccuracy = FPOp->getFPAccuracy();
+  const float ReqdAccuracy = FPOp->getFPAccuracy();
 
   // Inaccurate rcp is allowed with unsafe-fp-math or afn.
   //
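getFPAccuracy() returns the !fpmath value, or 0.0 when the metadata is absent, so with this change a default-precision division with no fast-math flags takes the new ReqdAccuracy < 1.0f early return and is left for the backend's correctly rounded expansion:

  ; no !fpmath, no flags: left untouched by this pass
  %div = fdiv float %a, %b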
@@ -991,110 +1035,69 @@ bool AMDGPUCodeGenPrepareImpl::visitFDiv(BinaryOperator &FDiv) {
   if (AllowInaccurateRcp)
     return false;
 
-  bool AllowApproxRsq = false;
+  // Defer the correct implementations to codegen.
+  if (ReqdAccuracy < 1.0f)
+    return false;
 
   FastMathFlags SqrtFMF;
 
-  // rcp_f16 is accurate to 0.51 ulp.
-  // rcp_f32 is accurate for !fpmath >= 1.0ulp and denormals are flushed.
-  // rcp_f64 is never accurate.
-  const bool RcpIsAccurate = ReqdAccuracy >= 1.0f;
   Value *Num = FDiv.getOperand(0);
   Value *Den = FDiv.getOperand(1);
 
   Value *RsqOp = nullptr;
   auto *DenII = dyn_cast<IntrinsicInst>(Den);
   if (DenII && DenII->getIntrinsicID() == Intrinsic::sqrt &&
-      DenII->hasOneUse() && (RcpIsAccurate || AllowInaccurateRcp)) {
+      DenII->hasOneUse()) {
     const auto *SqrtOp = cast<FPMathOperator>(DenII);
-    AllowApproxRsq = HasUnsafeFPMath || SqrtOp->hasApproxFunc();
-
-    if (AllowApproxRsq || SqrtOp->getFPAccuracy() >= 1.0f) {
-      SqrtFMF = SqrtOp->getFastMathFlags();
+    SqrtFMF = SqrtOp->getFastMathFlags();
+    if (canOptimizeWithRsq(SqrtOp, DivFMF, SqrtFMF))
       RsqOp = SqrtOp->getOperand(0);
-    }
   }
 
   IRBuilder<> Builder(FDiv.getParent(), std::next(FDiv.getIterator()));
   Builder.setFastMathFlags(DivFMF);
   Builder.SetCurrentDebugLocation(FDiv.getDebugLoc());
 
-  Value *NewFDiv = nullptr;
-  if (auto *VT = dyn_cast<FixedVectorType>(FDiv.getType())) {
-    NewFDiv = PoisonValue::get(VT);
-
-    // FIXME: Doesn't do the right thing for cases where the vector is partially
-    // constant. This works when the scalarizer pass is run first.
-    for (unsigned I = 0, E = VT->getNumElements(); I != E; ++I) {
-      Value *NumEltI = Builder.CreateExtractElement(Num, I);
-
-      Value *NewElt = nullptr;
-      if (RsqOp) {
-        Value *DenEltI = Builder.CreateExtractElement(RsqOp, I);
-        NewElt = optimizeWithRsq(Builder, NumEltI, DenEltI, DivFMF, SqrtFMF,
-                                 &FDiv, AllowApproxRsq);
-        if (!NewElt) {
-          // TODO: Avoid inserting dead extract in the first place
-          if (Instruction *Extract = dyn_cast<Instruction>(DenEltI))
-            Extract->eraseFromParent();
-        }
-      }
-
-      Value *DenEltI = nullptr;
-
-      if (!NewElt && (RcpIsAccurate || AllowInaccurateRcp)) {
-        DenEltI = Builder.CreateExtractElement(Den, I);
-
-        // Try rcp first.
-        NewElt = optimizeWithRcp(Builder, NumEltI, DenEltI, DivFMF,
-                                 cast<Instruction>(FPOp), AllowInaccurateRcp,
-                                 RcpIsAccurate);
-        if (!NewElt) // Try fdiv.fast.
-          NewElt = optimizeWithFDivFast(NumEltI, DenEltI, ReqdAccuracy,
-                                        HasFP32DenormalFlush, Builder, Mod);
-      }
-
-      if (!NewElt) {
-        if (!DenEltI)
-          DenEltI = Builder.CreateExtractElement(Den, I);
-
-        // Keep the original, but scalarized.
-        Value *ScalarDiv = Builder.CreateFDiv(NumEltI, DenEltI);
-        if (auto *ScalarDivInst = dyn_cast<Instruction>(ScalarDiv))
-          ScalarDivInst->copyMetadata(FDiv);
-        NewElt = ScalarDiv;
-      }
-
-      NewFDiv = Builder.CreateInsertElement(NewFDiv, NewElt, I);
-    }
-  } else { // Scalar FDiv.
-    if (RsqOp) {
-      NewFDiv = optimizeWithRsq(Builder, Num, RsqOp, DivFMF, SqrtFMF,
-                                cast<Instruction>(FPOp), AllowApproxRsq);
+  SmallVector<Value *, 4> NumVals;
+  SmallVector<Value *, 4> DenVals;
+  SmallVector<Value *, 4> RsqDenVals;
+  extractValues(Builder, NumVals, Num);
+  extractValues(Builder, DenVals, Den);
+
+  if (RsqOp)
+    extractValues(Builder, RsqDenVals, RsqOp);
+
+  SmallVector<Value *, 4> ResultVals(NumVals.size());
+  for (int I = 0, E = NumVals.size(); I != E; ++I) {
+    Value *NumElt = NumVals[I];
+    Value *DenElt = DenVals[I];
+    Value *RsqDenElt = RsqOp ? RsqDenVals[I] : nullptr;
+
+    Value *NewElt =
+        visitFDivElement(Builder, NumElt, DenElt, DivFMF, SqrtFMF, RsqDenElt,
+                         cast<Instruction>(FPOp), ReqdAccuracy);
+    if (!NewElt) {
+      // Keep the original, but scalarized.
+
+      // This has the unfortunate side effect of sometimes scalarizing when
+      // we're not going to do anything.
+      NewElt = Builder.CreateFDiv(NumElt, DenElt);
+      if (auto *NewEltInst = dyn_cast<Instruction>(NewElt))
+        NewEltInst->copyMetadata(FDiv);
     }
 
-    if (!NewFDiv) {
-      // Try rcp first.
-      if (RcpIsAccurate || AllowInaccurateRcp) {
-        NewFDiv =
-            optimizeWithRcp(Builder, Num, Den, DivFMF, cast<Instruction>(FPOp),
-                            AllowInaccurateRcp, RcpIsAccurate);
-      }
-
-      if (!NewFDiv) { // Try fdiv.fast.
-        NewFDiv = optimizeWithFDivFast(Num, Den, ReqdAccuracy,
-                                       HasFP32DenormalFlush, Builder, Mod);
-      }
-    }
+    ResultVals[I] = NewElt;
   }
 
-  if (NewFDiv) {
-    FDiv.replaceAllUsesWith(NewFDiv);
-    NewFDiv->takeName(&FDiv);
+  Value *NewVal = insertValues(Builder, FDiv.getType(), ResultVals);
+
+  if (NewVal) {
+    FDiv.replaceAllUsesWith(NewVal);
+    NewVal->takeName(&FDiv);
     RecursivelyDeleteTriviallyDeadInstructions(&FDiv, TLInfo);
   }
 
-  return !!NewFDiv;
+  return true;
 }
 
 static bool hasUnsafeFPMath(const Function &F) {
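On the vector path, extractValues/insertValues now scalarize once up front and visitFDivElement runs per element; roughly, under the arcp + flushed-denormal assumptions of the earlier sketch:

  %div = fdiv arcp <2 x float> %a, %b, !fpmath !0

becomes, per lane,

  %a0 = extractelement <2 x float> %a, i64 0
  %b0 = extractelement <2 x float> %b, i64 0
  %rcp0 = call arcp float @llvm.amdgcn.rcp.f32(float %b0)
  %mul0 = fmul arcp float %a0, %rcp0
  ; ...same for lane 1...
  %ins = insertelement <2 x float> poison, float %mul0, i64 0
  %div = insertelement <2 x float> %ins, float %mul1, i64 1

When no element can be improved, each lane is rebuilt with a scalar fdiv, which is the residual no-op scalarization the commit message mentions.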
