Skip to content

Commit

Permalink
[InstCombine] Add folds for (fp_binop ({s|u}itofp x), ({s|u}itofp y))
Browse files Browse the repository at this point in the history
The full fold is one of the following:
1) `(fp_binop ({s|u}itofp x), ({s|u}itofp y))`
    -> `({s|u}itofp (int_binop x, y))`
2) `(fp_binop ({s|u}itofp x), FpC)`
    -> `({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))`

And support the following binops:
    `fmul` -> `mul`
    `fadd` -> `add`
    `fsub` -> `sub`

Proofs: https://alive2.llvm.org/ce/z/zuacA8

The proofs time out, so they must be reproduced locally.

Closes #82555
  • Loading branch information
goldsteinn committed Mar 6, 2024
1 parent 0f5849e commit 946ea4e
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 150 deletions.
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2793,6 +2793,9 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;

if (Instruction *R = foldFBinOpOfIntCasts(I))
return R;

Value *X, *Y;
Constant *C;

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (Instruction *R = foldFPSignBitOps(I))
return R;

if (Instruction *R = foldFBinOpOfIntCasts(I))
return R;

// X * -1.0 --> -X
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (match(Op1, m_SpecificFP(-1.0)))
Expand Down
223 changes: 164 additions & 59 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1402,71 +1402,176 @@ Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
}

// Try to fold:
//    1) (fp_binop ({s|u}itofp x), ({s|u}itofp y))
//        -> ({s|u}itofp (int_binop x, y))
//    2) (fp_binop ({s|u}itofp x), FpC)
//        -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))
//
// The fold is valid when the int -> fp casts are provably exact (the integer
// operands fit in the fp significand) and the integer binop provably does not
// wrap, so the fp operation and the promoted integer operation produce the
// same value. Supported binops: fadd->add, fsub->sub, fmul->mul.
Instruction *InstCombinerImpl::foldFBinOpOfIntCasts(BinaryOperator &BO) {
  Value *IntOps[2] = {nullptr, nullptr};
  Constant *Op1FpC = nullptr;

  // Check for:
  //    1) (binop ({s|u}itofp x), ({s|u}itofp y))
  //    2) (binop ({s|u}itofp x), FpC)
  if (!match(BO.getOperand(0), m_SIToFP(m_Value(IntOps[0]))) &&
      !match(BO.getOperand(0), m_UIToFP(m_Value(IntOps[0]))))
    return nullptr;

  if (!match(BO.getOperand(1), m_Constant(Op1FpC)) &&
      !match(BO.getOperand(1), m_SIToFP(m_Value(IntOps[1]))) &&
      !match(BO.getOperand(1), m_UIToFP(m_Value(IntOps[1]))))
    return nullptr;

  Type *FPTy = BO.getType();
  Type *IntTy = IntOps[0]->getType();

  // Do we have signed casts?
  bool OpsFromSigned = isa<SIToFPInst>(BO.getOperand(0));

  unsigned IntSz = IntTy->getScalarSizeInBits();
  // This is the maximum number of in-use bits by the integer where the
  // int -> fp casts are exact.
  unsigned MaxRepresentableBits =
      APFloat::semanticsPrecision(FPTy->getScalarType()->getFltSemantics());

  // Cache KnownBits a bit to potentially save some analysis.
  WithCache<const Value *> OpsKnown[2] = {IntOps[0], IntOps[1]};

  // Preserve known number of leading bits. This can allow us to trivially
  // prove the nsw/nuw checks later on.
  unsigned NumUsedLeadingBits[2] = {IntSz, IntSz};

  auto IsNonZero = [&](unsigned OpNo) -> bool {
    // Prefer the cached KnownBits if available before running the (more
    // expensive) dedicated non-zero analysis.
    if (OpsKnown[OpNo].hasKnownBits() &&
        OpsKnown[OpNo].getKnownBits(SQ).isNonZero())
      return true;
    return isKnownNonZero(IntOps[OpNo], SQ.DL);
  };

  auto IsNonNeg = [&](unsigned OpNo) -> bool {
    if (OpsKnown[OpNo].hasKnownBits() &&
        OpsKnown[OpNo].getKnownBits(SQ).isNonNegative())
      return true;
    return isKnownNonNegative(IntOps[OpNo], SQ);
  };

  // Check if we know for certain that ({s|u}itofp op) is exact.
  auto IsValidPromotion = [&](unsigned OpNo) -> bool {
    // If fp precision >= bitwidth(op) then it's exact.
    // NB: This is slightly conservative for `sitofp`. For signed conversion,
    // we can handle `MaxRepresentableBits == IntSz - 1` as the sign bit will
    // be handled specially. We can't, however, increase the bound arbitrarily
    // for `sitofp` as for larger sizes, it won't sign extend.
    if (MaxRepresentableBits < IntSz) {
      // Otherwise if it's a signed cast, check that
      // fp precision >= bitwidth(op) - numSignBits(op).
      // TODO: If we add support for `WithCache` in `ComputeNumSignBits`,
      // change `IntOps[OpNo]` arguments to `KnownOps[OpNo]`.
      if (OpsFromSigned)
        NumUsedLeadingBits[OpNo] = IntSz - ComputeNumSignBits(IntOps[OpNo]);
      // Finally for unsigned check that
      // fp precision >= bitwidth(op) - numLeadingZeros(op).
      else {
        NumUsedLeadingBits[OpNo] =
            IntSz - OpsKnown[OpNo].getKnownBits(SQ).countMinLeadingZeros();
      }
    }
    // NB: We could also check if op is known to be a power of 2 or zero (which
    // will always be representable). It's unlikely, however, that if we are
    // unable to bound op in any way we will be able to pass the overflow
    // checks later on.

    if (MaxRepresentableBits < NumUsedLeadingBits[OpNo])
      return false;
    // Signed + Mul also requires that op is non-zero to avoid -0 cases.
    return !OpsFromSigned || BO.getOpcode() != Instruction::FMul ||
           IsNonZero(OpNo);
  };

  // If we have a constant rhs, see if we can losslessly convert it to an int.
  if (Op1FpC != nullptr) {
    Constant *Op1IntC = ConstantFoldCastOperand(
        OpsFromSigned ? Instruction::FPToSI : Instruction::FPToUI, Op1FpC,
        IntTy, DL);
    if (Op1IntC == nullptr)
      return nullptr;
    // Round-trip the constant back to fp; if it differs, the fp constant is
    // not exactly representable as an integer and the fold is invalid.
    if (ConstantFoldCastOperand(OpsFromSigned ? Instruction::SIToFP
                                              : Instruction::UIToFP,
                                Op1IntC, FPTy, DL) != Op1FpC)
      return nullptr;

    // First try to keep sign of cast the same.
    IntOps[1] = Op1IntC;
  }

  // Ensure lhs/rhs integer types match.
  if (IntTy != IntOps[1]->getType())
    return nullptr;

  if (Op1FpC == nullptr) {
    if (OpsFromSigned != isa<SIToFPInst>(BO.getOperand(1))) {
      // If we have a signed + unsigned, see if we can treat both as signed
      // (uitofp nneg x) == (sitofp nneg x).
      if (OpsFromSigned ? !IsNonNeg(1) : !IsNonNeg(0))
        return nullptr;
      OpsFromSigned = true;
    }
    if (!IsValidPromotion(1))
      return nullptr;
  }
  if (!IsValidPromotion(0))
    return nullptr;

  // Finally we check that the integer version of the binop will not overflow.
  BinaryOperator::BinaryOps IntOpc;
  // Because of the precision check, we can often rule out overflows.
  bool NeedsOverflowCheck = true;
  // Try to conservatively rule out overflow based on the already done
  // precision checks.
  unsigned OverflowMaxOutputBits = OpsFromSigned ? 2 : 1;
  unsigned OverflowMaxCurBits =
      std::max(NumUsedLeadingBits[0], NumUsedLeadingBits[1]);
  bool OutputSigned = OpsFromSigned;
  switch (BO.getOpcode()) {
  case Instruction::FAdd:
    IntOpc = Instruction::Add;
    OverflowMaxOutputBits += OverflowMaxCurBits;
    break;
  case Instruction::FSub:
    IntOpc = Instruction::Sub;
    OverflowMaxOutputBits += OverflowMaxCurBits;
    break;
  case Instruction::FMul:
    IntOpc = Instruction::Mul;
    OverflowMaxOutputBits += OverflowMaxCurBits * 2;
    break;
  default:
    llvm_unreachable("Unsupported binop");
  }
  // The precision check may have already ruled out overflow.
  if (OverflowMaxOutputBits < IntSz) {
    NeedsOverflowCheck = false;
    // We can bound unsigned overflow from sub to in range signed value (this
    // is what allows us to avoid the overflow check for sub).
    if (IntOpc == Instruction::Sub)
      OutputSigned = true;
  }

  // Precision check did not rule out overflow, so need to check.
  // TODO: If we add support for `WithCache` in `willNotOverflow`, change
  // `IntOps[...]` arguments to `KnownOps[...]`.
  if (NeedsOverflowCheck &&
      !willNotOverflow(IntOpc, IntOps[0], IntOps[1], BO, OutputSigned))
    return nullptr;

  Value *IntBinOp = Builder.CreateBinOp(IntOpc, IntOps[0], IntOps[1]);
  if (auto *IntBO = dyn_cast<BinaryOperator>(IntBinOp)) {
    // Tag the no-wrap flag that the overflow/precision analysis proved.
    IntBO->setHasNoSignedWrap(OutputSigned);
    IntBO->setHasNoUnsignedWrap(!OutputSigned);
  }
  if (OutputSigned)
    return new SIToFPInst(IntBinOp, FPTy);
  return new UIToFPInst(IntBinOp, FPTy);
}

/// A binop with a constant operand and a sign-extended boolean operand may be
Expand Down
8 changes: 3 additions & 5 deletions llvm/test/Transforms/InstCombine/add-sitofp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,13 @@ define float @test_2_neg(i32 %a, i32 %b) {
ret float %res
}

; The float addition can be replaced with the integer addition because the
; result of the operation can be represented in float.
define float @test_3(i32 %a, i32 %b) {
; CHECK-LABEL: @test_3(
; CHECK-NEXT: [[M:%.*]] = lshr i32 [[A:%.*]], 24
; CHECK-NEXT: [[N:%.*]] = and i32 [[M]], [[B:%.*]]
; CHECK-NEXT: [[O:%.*]] = sitofp i32 [[N]] to float
; CHECK-NEXT: [[P:%.*]] = fadd float [[O]], 1.000000e+00
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i32 [[N]], 1
; CHECK-NEXT: [[P:%.*]] = sitofp i32 [[TMP1]] to float
; CHECK-NEXT: ret float [[P]]
;
%m = lshr i32 %a, 24
Expand Down
Loading

0 comments on commit 946ea4e

Please sign in to comment.