- 
                Notifications
    
You must be signed in to change notification settings  - Fork 15.1k
 
[LLVM][InstCombine] Preserve vector types when shrinking FP constants. #163598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LLVM][InstCombine] Preserve vector types when shrinking FP constants. #163598
Conversation
While my original objective was to make the shrinkfp path safe for ConstantFP-based splats, I discovered the following issues for ConstantVector-based splats: 1. PreferBFloat is not set for bfloat vectors. 2. getMinimumFPType() returns a scalar type for vector constants where getSplatValue() is successful. Please let me know if you would rather I upstream those fixes separately from the use-constant-fp-for-fixed-length-splat support.
| 
          
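 @llvm/pr-subscribers-llvm-transforms Author: Paul Walker (paulwalker-arm) Changes: While my objective is to make the shrinkfp path safe for ConstantFP-based splats, I discovered that the following issues also affect ConstantVector-based splats: 1. PreferBFloat is not set for bfloat vectors. 2. getMinimumFPType() returns a scalar type for vector constants where getSplatValue() is successful.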
 Please let me know if you would rather I upstream those fixes separately from the use-constant-fp-for-fixed-length-splat support, which is mainly just refactoring shrinkFPConstant. Full diff: https://github.com/llvm/llvm-project/pull/163598.diff 2 Files Affected: llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp, llvm/test/Transforms/InstCombine/fpextend.ll
 diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4c9b10a094981..46857b7f66bfe 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1643,33 +1643,43 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
 
 /// Return a Constant* for the specified floating-point constant if it fits
 /// in the specified FP type without changing its value.
-static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
+static bool fitsInFPType(APFloat F, const fltSemantics &Sem) {
   bool losesInfo;
-  APFloat F = CFP->getValueAPF();
   (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
   return !losesInfo;
 }
 
-static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
-  if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
-    return nullptr;  // No constant folding of this.
+static Type *shrinkFPConstant(LLVMContext &Ctx, APFloat F, bool PreferBFloat) {
   // See if the value can be truncated to bfloat and then reextended.
-  if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
-    return Type::getBFloatTy(CFP->getContext());
+  if (PreferBFloat && fitsInFPType(F, APFloat::BFloat()))
+    return Type::getBFloatTy(Ctx);
   // See if the value can be truncated to half and then reextended.
-  if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
-    return Type::getHalfTy(CFP->getContext());
+  if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf()))
+    return Type::getHalfTy(Ctx);
   // See if the value can be truncated to float and then reextended.
-  if (fitsInFPType(CFP, APFloat::IEEEsingle()))
-    return Type::getFloatTy(CFP->getContext());
-  if (CFP->getType()->isDoubleTy())
-    return nullptr;  // Won't shrink.
-  if (fitsInFPType(CFP, APFloat::IEEEdouble()))
-    return Type::getDoubleTy(CFP->getContext());
-  // Don't try to shrink to various long double types.
+  if (fitsInFPType(F, APFloat::IEEEsingle()))
+    return Type::getFloatTy(Ctx);
+  // See if the value can be truncated to double and then reextended.
+  if (fitsInFPType(F, APFloat::IEEEdouble()))
+    return Type::getDoubleTy(Ctx);
+  // Does not shrink.
   return nullptr;
 }
 
+static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
+  Type *Ty = CFP->getType();
+  if (Ty->getScalarType() == Type::getPPC_FP128Ty(CFP->getContext()))
+    return nullptr; // No constant folding of this.
+
+  Type *ShrinkTy =
+      shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat);
+  if (auto *VecTy = dyn_cast<VectorType>(Ty))
+    ShrinkTy = VectorType::get(ShrinkTy, VecTy);
+
+  // Does it shrink?
+  return ShrinkTy != Ty ? ShrinkTy : nullptr;
+}
+
 // Determine if this is a vector of ConstantFPs and if so, return the minimal
 // type we can safely truncate all elements to.
 static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
@@ -1720,10 +1730,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
 
   // Try to shrink scalable and fixed splat vectors.
   if (auto *FPC = dyn_cast<Constant>(V))
-    if (isa<VectorType>(V->getType()))
+    if (auto *VTy = dyn_cast<VectorType>(V->getType()))
       if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
         if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
-          return T;
+          return VectorType::get(T, VTy);
 
   // Try to shrink a vector of FP constants. This returns nullptr on scalable
   // vectors
@@ -1796,10 +1806,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
   Type *Ty = FPT.getType();
   auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
   if (BO && BO->hasOneUse()) {
-    Type *LHSMinType =
-        getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
-    Type *RHSMinType =
-        getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
+    bool PreferBFloat = Ty->getScalarType()->isBFloatTy();
+    Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat);
+    Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat);
     unsigned OpWidth = BO->getType()->getFPMantissaWidth();
     unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
     unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index 9125339c00ecf..076a5280eaf63 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine --use-constant-fp-for-fixed-length-splat -S | FileCheck %s
 
 define float @test(float %x) nounwind  {
 ; CHECK-LABEL: @test(
@@ -449,6 +450,28 @@ define bfloat @bf16_frem(bfloat %x) {
   ret bfloat %t3
 }
 
+define <4 x bfloat> @v4bf16_frem_x_const(<4 x bfloat> %x) {
+; CHECK-LABEL: @v4bf16_frem_x_const(
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <4 x bfloat> [[X:%.*]], splat (bfloat 0xR40C9)
+; CHECK-NEXT:    ret <4 x bfloat> [[TMP1]]
+;
+  %t1 = fpext <4 x bfloat> %x to <4 x float>
+  %t2 = frem <4 x float> %t1, splat(float 6.281250e+00)
+  %t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
+  ret <4 x bfloat> %t3
+}
+
+define <4 x bfloat> @v4bf16_frem_const_x(<4 x bfloat> %x) {
+; CHECK-LABEL: @v4bf16_frem_const_x(
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <4 x bfloat> splat (bfloat 0xR40C9), [[X:%.*]]
+; CHECK-NEXT:    ret <4 x bfloat> [[TMP1]]
+;
+  %t1 = fpext <4 x bfloat> %x to <4 x float>
+  %t2 = frem <4 x float> splat(float 6.281250e+00), %t1
+  %t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
+  ret <4 x bfloat> %t3
+}
+
 define <4 x float> @v4f32_fadd(<4 x float> %a) {
 ; CHECK-LABEL: @v4f32_fadd(
 ; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
 | 
    
| if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue())) | ||
| if (Type *T = shrinkFPConstant(Splat, PreferBFloat)) | ||
| return T; | ||
| return VectorType::get(T, VTy); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks right to me, but I don't really get how this worked previously. E.g. v4f32_fadd has a vector splat, why did that work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only the Instruction::FRem handling uses getMinimumFPType()'s return type and even then the bug only occurs when LHS is the vector splat of which there were no tests.
| Type *ShrinkTy = | ||
| shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat); | ||
| if (auto *VecTy = dyn_cast<VectorType>(Ty)) | ||
| ShrinkTy = VectorType::get(ShrinkTy, VecTy); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ShrinkTy may be null.
… for unshrinkable constants.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
llvm#163598) While my objective is to make the shrinkfp path safe for ConstantFP based splats I discovered the following issues also affect ConstantVector based splats: 1. PreferBFloat is not set for bfloat vectors. 2. getMinimumFPType() returns a scalar type for vector constants where getSplatValue() is successful.
llvm#163598) While my objective is to make the shrinkfp path safe for ConstantFP based splats I discovered the following issues also affect ConstantVector based splats: 1. PreferBFloat is not set for bfloat vectors. 2. getMinimumFPType() returns a scalar type for vector constants where getSplatValue() is successful.
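While my objective is to make the shrinkfp path safe for ConstantFP-based splats, I discovered that the following issues also affect ConstantVector-based splats: 1. PreferBFloat is not set for bfloat vectors. 2. getMinimumFPType() returns a scalar type for vector constants where getSplatValue() is successful.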
Please let me know if you would rather I upstream those fixes separately from the use-constant-fp-for-fixed-length-splat support, which is mainly just refactoring shrinkFPConstant.