diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4c9b10a094981..aa73175ab9325 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1643,33 +1643,54 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
 
-/// Return a Constant* for the specified floating-point constant if it fits
-/// in the specified FP type without changing its value.
-static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
+/// Return true if the floating-point value \p F can be converted to the
+/// semantics \p Sem without losing information.
+/// \p F is taken by value because APFloat::convert() modifies it in place.
+static bool fitsInFPType(APFloat F, const fltSemantics &Sem) {
   bool losesInfo;
-  APFloat F = CFP->getValueAPF();
   (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
   return !losesInfo;
 }
 
-static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
-  if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
-    return nullptr;  // No constant folding of this.
+/// Return the first of bfloat (when \p PreferBFloat) or half (otherwise),
+/// float, and double that \p F converts to without losing information.
+/// Returns nullptr if \p F has double semantics but fits no narrower type
+/// tried here, or if \p F fits none of these types.
+static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F,
+                              bool PreferBFloat) {
   // See if the value can be truncated to bfloat and then reextended.
-  if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
-    return Type::getBFloatTy(CFP->getContext());
+  if (PreferBFloat && fitsInFPType(F, APFloat::BFloat()))
+    return Type::getBFloatTy(Ctx);
   // See if the value can be truncated to half and then reextended.
-  if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
-    return Type::getHalfTy(CFP->getContext());
+  if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf()))
+    return Type::getHalfTy(Ctx);
   // See if the value can be truncated to float and then reextended.
-  if (fitsInFPType(CFP, APFloat::IEEEsingle()))
-    return Type::getFloatTy(CFP->getContext());
-  if (CFP->getType()->isDoubleTy())
-    return nullptr;  // Won't shrink.
-  if (fitsInFPType(CFP, APFloat::IEEEdouble()))
-    return Type::getDoubleTy(CFP->getContext());
+  if (fitsInFPType(F, APFloat::IEEEsingle()))
+    return Type::getFloatTy(Ctx);
+  if (&F.getSemantics() == &APFloat::IEEEdouble())
+    return nullptr; // Won't shrink.
+  // See if the value can be truncated to double and then reextended.
+  if (fitsInFPType(F, APFloat::IEEEdouble()))
+    return Type::getDoubleTy(Ctx);
   // Don't try to shrink to various long double types.
   return nullptr;
 }
 
+/// Shrink the scalar or vector-typed constant \p CFP. Returns nullptr for
+/// ppc_fp128 and when the APFloat overload finds no type; for a vector-typed
+/// \p CFP the shrunken type is wrapped in a vector shaped like the original.
+static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
+  Type *Ty = CFP->getType();
+  if (Ty->getScalarType()->isPPC_FP128Ty())
+    return nullptr; // No constant folding of this.
+
+  Type *ShrinkTy =
+      shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat);
+  if (ShrinkTy)
+    if (auto *VecTy = dyn_cast<VectorType>(Ty))
+      ShrinkTy = VectorType::get(ShrinkTy, VecTy);
+
+  return ShrinkTy;
+}
+
 // Determine if this is a vector of ConstantFPs and if so, return the minimal
 // type we can safely truncate all elements to.
 static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
@@ -1720,10 +1741,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
 
   // Try to shrink scalable and fixed splat vectors.
   if (auto *FPC = dyn_cast<Constant>(V))
-    if (isa<VectorType>(V->getType()))
+    if (auto *VTy = dyn_cast<VectorType>(V->getType()))
       if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
         if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
-          return T;
+          return VectorType::get(T, VTy);
 
   // Try to shrink a vector of FP constants. This returns nullptr on scalable
   // vectors
@@ -1796,10 +1817,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
   Type *Ty = FPT.getType();
   auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
   if (BO && BO->hasOneUse()) {
-    Type *LHSMinType =
-        getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
-    Type *RHSMinType =
-        getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
+    bool PreferBFloat = Ty->getScalarType()->isBFloatTy();
+    Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat);
+    Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat);
     unsigned OpWidth = BO->getType()->getFPMantissaWidth();
     unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
     unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index 9125339c00ecf..a65b73b1ca75a 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-fp-for-fixed-length-splat -S | FileCheck %s
 
 define float @test(float %x) nounwind {
 ; CHECK-LABEL: @test(
@@ -449,6 +450,28 @@ define bfloat @bf16_frem(bfloat %x) {
   ret bfloat %t3
 }
 
+define <4 x bfloat> @v4bf16_frem_x_const(<4 x bfloat> %x) {
+; CHECK-LABEL: @v4bf16_frem_x_const(
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <4 x bfloat> [[X:%.*]], splat (bfloat 0xR40C9)
+; CHECK-NEXT:    ret <4 x bfloat> [[TMP1]]
+;
+  %t1 = fpext <4 x bfloat> %x to <4 x float>
+  %t2 = frem <4 x float> %t1, splat(float 6.281250e+00)
+  %t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
+  ret <4 x bfloat> %t3
+}
+
+define <4 x bfloat> @v4bf16_frem_const_x(<4 x bfloat> %x) {
+; CHECK-LABEL: @v4bf16_frem_const_x(
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <4 x bfloat> splat (bfloat 0xR40C9), [[X:%.*]]
+; CHECK-NEXT:    ret <4 x bfloat> [[TMP1]]
+;
+  %t1 = fpext <4 x bfloat> %x to <4 x float>
+  %t2 = frem <4 x float> splat(float 6.281250e+00), %t1
+  %t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
+  ret <4 x bfloat> %t3
+}
+
 define <4 x float> @v4f32_fadd(<4 x float> %a) {
 ; CHECK-LABEL: @v4f32_fadd(
 ; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
@@ -459,3 +482,16 @@ define <4 x float> @v4f32_fadd(<4 x float> %a) {
   %5 = fptrunc <4 x double> %4 to <4 x float>
   ret <4 x float> %5
 }
+
+define <4 x float> @v4f32_fadd_const_not_shrinkable(<4 x float> %a) {
+; CHECK-LABEL: @v4f32_fadd_const_not_shrinkable(
+; CHECK-NEXT:    [[TMP1:%.*]] = fpext <4 x float> [[A:%.*]] to <4 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <4 x double> [[TMP1]], splat (double -1.000000e+100)
+; CHECK-NEXT:    [[TMP3:%.*]] = fptrunc <4 x double> [[TMP2]] to <4 x float>
+; CHECK-NEXT:    ret <4 x float> [[TMP3]]
+;
+  %2 = fpext <4 x float> %a to <4 x double>
+  %4 = fadd <4 x double> %2, splat (double -1.000000e+100)
+  %5 = fptrunc <4 x double> %4 to <4 x float>
+  ret <4 x float> %5
+}