Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1643,33 +1643,46 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {

/// Return a Constant* for the specified floating-point constant if it fits
/// in the specified FP type without changing its value.
static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
static bool fitsInFPType(APFloat F, const fltSemantics &Sem) {
bool losesInfo;
APFloat F = CFP->getValueAPF();
(void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
return !losesInfo;
}

static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
return nullptr; // No constant folding of this.
static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F,
bool PreferBFloat) {
// See if the value can be truncated to bfloat and then reextended.
if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
return Type::getBFloatTy(CFP->getContext());
if (PreferBFloat && fitsInFPType(F, APFloat::BFloat()))
return Type::getBFloatTy(Ctx);
// See if the value can be truncated to half and then reextended.
if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
return Type::getHalfTy(CFP->getContext());
if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf()))
return Type::getHalfTy(Ctx);
// See if the value can be truncated to float and then reextended.
if (fitsInFPType(CFP, APFloat::IEEEsingle()))
return Type::getFloatTy(CFP->getContext());
if (CFP->getType()->isDoubleTy())
return nullptr; // Won't shrink.
if (fitsInFPType(CFP, APFloat::IEEEdouble()))
return Type::getDoubleTy(CFP->getContext());
if (fitsInFPType(F, APFloat::IEEEsingle()))
return Type::getFloatTy(Ctx);
if (&F.getSemantics() == &APFloat::IEEEdouble())
return nullptr; // Won't shrink.
// See if the value can be truncated to double and then reextended.
if (fitsInFPType(F, APFloat::IEEEdouble()))
return Type::getDoubleTy(Ctx);
// Don't try to shrink to various long double types.
return nullptr;
}

static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
Type *Ty = CFP->getType();
if (Ty->getScalarType()->isPPC_FP128Ty())
return nullptr; // No constant folding of this.

Type *ShrinkTy =
shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat);
if (ShrinkTy)
if (auto *VecTy = dyn_cast<VectorType>(Ty))
ShrinkTy = VectorType::get(ShrinkTy, VecTy);

return ShrinkTy;
}

// Determine if this is a vector of ConstantFPs and if so, return the minimal
// type we can safely truncate all elements to.
static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
Expand Down Expand Up @@ -1720,10 +1733,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {

// Try to shrink scalable and fixed splat vectors.
if (auto *FPC = dyn_cast<Constant>(V))
if (isa<VectorType>(V->getType()))
if (auto *VTy = dyn_cast<VectorType>(V->getType()))
if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
return T;
return VectorType::get(T, VTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks right to me, but I don't really get how this worked previously. E.g. v4f32_fadd has a vector splat, why did that work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only the Instruction::FRem handling uses getMinimumFPType()'s return type and even then the bug only occurs when LHS is the vector splat of which there were no tests.


// Try to shrink a vector of FP constants. This returns nullptr on scalable
// vectors
Expand Down Expand Up @@ -1796,10 +1809,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
Type *Ty = FPT.getType();
auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
if (BO && BO->hasOneUse()) {
Type *LHSMinType =
getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
Type *RHSMinType =
getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
bool PreferBFloat = Ty->getScalarType()->isBFloatTy();
Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat);
Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat);
unsigned OpWidth = BO->getType()->getFPMantissaWidth();
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
Expand Down
36 changes: 36 additions & 0 deletions llvm/test/Transforms/InstCombine/fpextend.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
; RUN: opt < %s -passes=instcombine -use-constant-fp-for-fixed-length-splat -S | FileCheck %s

define float @test(float %x) nounwind {
; CHECK-LABEL: @test(
Expand Down Expand Up @@ -449,6 +450,28 @@ define bfloat @bf16_frem(bfloat %x) {
ret bfloat %t3
}

define <4 x bfloat> @v4bf16_frem_x_const(<4 x bfloat> %x) {
; CHECK-LABEL: @v4bf16_frem_x_const(
; CHECK-NEXT: [[TMP1:%.*]] = frem <4 x bfloat> [[X:%.*]], splat (bfloat 0xR40C9)
; CHECK-NEXT: ret <4 x bfloat> [[TMP1]]
;
%t1 = fpext <4 x bfloat> %x to <4 x float>
%t2 = frem <4 x float> %t1, splat(float 6.281250e+00)
%t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
ret <4 x bfloat> %t3
}

define <4 x bfloat> @v4bf16_frem_const_x(<4 x bfloat> %x) {
; CHECK-LABEL: @v4bf16_frem_const_x(
; CHECK-NEXT: [[TMP1:%.*]] = frem <4 x bfloat> splat (bfloat 0xR40C9), [[X:%.*]]
; CHECK-NEXT: ret <4 x bfloat> [[TMP1]]
;
%t1 = fpext <4 x bfloat> %x to <4 x float>
%t2 = frem <4 x float> splat(float 6.281250e+00), %t1
%t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
ret <4 x bfloat> %t3
}

define <4 x float> @v4f32_fadd(<4 x float> %a) {
; CHECK-LABEL: @v4f32_fadd(
; CHECK-NEXT: [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
Expand All @@ -459,3 +482,16 @@ define <4 x float> @v4f32_fadd(<4 x float> %a) {
%5 = fptrunc <4 x double> %4 to <4 x float>
ret <4 x float> %5
}

define <4 x float> @v4f32_fadd_const_not_shrinkable(<4 x float> %a) {
; CHECK-LABEL: @v4f32_fadd_const_not_shrinkable(
; CHECK-NEXT: [[TMP1:%.*]] = fpext <4 x float> [[A:%.*]] to <4 x double>
; CHECK-NEXT: [[TMP2:%.*]] = fadd <4 x double> [[TMP1]], splat (double -1.000000e+100)
; CHECK-NEXT: [[TMP3:%.*]] = fptrunc <4 x double> [[TMP2]] to <4 x float>
; CHECK-NEXT: ret <4 x float> [[TMP3]]
;
%2 = fpext <4 x float> %a to <4 x double>
%4 = fadd <4 x double> %2, splat (double -1.000000e+100)
%5 = fptrunc <4 x double> %4 to <4 x float>
ret <4 x float> %5
}
Loading